diff --git a/lib/python3.10/site-packages/sympy/categories/__pycache__/baseclasses.cpython-310.pyc b/lib/python3.10/site-packages/sympy/categories/__pycache__/baseclasses.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..66f789b590f768f55d219ff9b6eaf8b63ffd7f20 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/categories/__pycache__/baseclasses.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/categories/__pycache__/diagram_drawing.cpython-310.pyc b/lib/python3.10/site-packages/sympy/categories/__pycache__/diagram_drawing.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7d73d119ccea83ff0cc604f69c0bdafb7c13f462 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/categories/__pycache__/diagram_drawing.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/categories/tests/__init__.py b/lib/python3.10/site-packages/sympy/categories/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/lib/python3.10/site-packages/sympy/categories/tests/test_baseclasses.py b/lib/python3.10/site-packages/sympy/categories/tests/test_baseclasses.py new file mode 100644 index 0000000000000000000000000000000000000000..cfac32229768fb5903b23b11ffb236912c0b931e --- /dev/null +++ b/lib/python3.10/site-packages/sympy/categories/tests/test_baseclasses.py @@ -0,0 +1,209 @@ +from sympy.categories import (Object, Morphism, IdentityMorphism, + NamedMorphism, CompositeMorphism, + Diagram, Category) +from sympy.categories.baseclasses import Class +from sympy.testing.pytest import raises +from sympy.core.containers import (Dict, Tuple) +from sympy.sets import EmptySet +from sympy.sets.sets import FiniteSet + + +def test_morphisms(): + A = Object("A") + B = Object("B") + C = Object("C") + D = Object("D") + + # Test the base morphism. + f = NamedMorphism(A, B, "f") + assert f.domain == A + assert f.codomain == B + assert f == NamedMorphism(A, B, "f") + + # Test identities. + id_A = IdentityMorphism(A) + id_B = IdentityMorphism(B) + assert id_A.domain == A + assert id_A.codomain == A + assert id_A == IdentityMorphism(A) + assert id_A != id_B + + # Test named morphisms. + g = NamedMorphism(B, C, "g") + assert g.name == "g" + assert g != f + assert g == NamedMorphism(B, C, "g") + assert g != NamedMorphism(B, C, "f") + + # Test composite morphisms. + assert f == CompositeMorphism(f) + + k = g.compose(f) + assert k.domain == A + assert k.codomain == C + assert k.components == Tuple(f, g) + assert g * f == k + assert CompositeMorphism(f, g) == k + + assert CompositeMorphism(g * f) == g * f + + # Test the associativity of composition. + h = NamedMorphism(C, D, "h") + + p = h * g + u = h * g * f + + assert h * k == u + assert p * f == u + assert CompositeMorphism(f, g, h) == u + + # Test flattening. + u2 = u.flatten("u") + assert isinstance(u2, NamedMorphism) + assert u2.name == "u" + assert u2.domain == A + assert u2.codomain == D + + # Test identities. + assert f * id_A == f + assert id_B * f == f + assert id_A * id_A == id_A + assert CompositeMorphism(id_A) == id_A + + # Test bad compositions. + raises(ValueError, lambda: f * g) + + raises(TypeError, lambda: f.compose(None)) + raises(TypeError, lambda: id_A.compose(None)) + raises(TypeError, lambda: f * None) + raises(TypeError, lambda: id_A * None) + + raises(TypeError, lambda: CompositeMorphism(f, None, 1)) + + raises(ValueError, lambda: NamedMorphism(A, B, "")) + raises(NotImplementedError, lambda: Morphism(A, B)) + + +def test_diagram(): + A = Object("A") + B = Object("B") + C = Object("C") + + f = NamedMorphism(A, B, "f") + g = NamedMorphism(B, C, "g") + id_A = IdentityMorphism(A) + id_B = IdentityMorphism(B) + + empty = EmptySet + + # Test the addition of identities. + d1 = Diagram([f]) + + assert d1.objects == FiniteSet(A, B) + assert d1.hom(A, B) == (FiniteSet(f), empty) + assert d1.hom(A, A) == (FiniteSet(id_A), empty) + assert d1.hom(B, B) == (FiniteSet(id_B), empty) + + assert d1 == Diagram([id_A, f]) + assert d1 == Diagram([f, f]) + + # Test the addition of composites. + d2 = Diagram([f, g]) + homAC = d2.hom(A, C)[0] + + assert d2.objects == FiniteSet(A, B, C) + assert g * f in d2.premises.keys() + assert homAC == FiniteSet(g * f) + + # Test equality, inequality and hash. + d11 = Diagram([f]) + + assert d1 == d11 + assert d1 != d2 + assert hash(d1) == hash(d11) + + d11 = Diagram({f: "unique"}) + assert d1 != d11 + + # Make sure that (re-)adding composites (with new properties) + # works as expected. + d = Diagram([f, g], {g * f: "unique"}) + assert d.conclusions == Dict({g * f: FiniteSet("unique")}) + + # Check the hom-sets when there are premises and conclusions. + assert d.hom(A, C) == (FiniteSet(g * f), FiniteSet(g * f)) + d = Diagram([f, g], [g * f]) + assert d.hom(A, C) == (FiniteSet(g * f), FiniteSet(g * f)) + + # Check how the properties of composite morphisms are computed. + d = Diagram({f: ["unique", "isomorphism"], g: "unique"}) + assert d.premises[g * f] == FiniteSet("unique") + + # Check that conclusion morphisms with new objects are not allowed. + d = Diagram([f], [g]) + assert d.conclusions == Dict({}) + + # Test an empty diagram. + d = Diagram() + assert d.premises == Dict({}) + assert d.conclusions == Dict({}) + assert d.objects == empty + + # Check a SymPy Dict object. + d = Diagram(Dict({f: FiniteSet("unique", "isomorphism"), g: "unique"})) + assert d.premises[g * f] == FiniteSet("unique") + + # Check the addition of components of composite morphisms. + d = Diagram([g * f]) + assert f in d.premises + assert g in d.premises + + # Check subdiagrams. + d = Diagram([f, g], {g * f: "unique"}) + + d1 = Diagram([f]) + assert d.is_subdiagram(d1) + assert not d1.is_subdiagram(d) + + d = Diagram([NamedMorphism(B, A, "f'")]) + assert not d.is_subdiagram(d1) + assert not d1.is_subdiagram(d) + + d1 = Diagram([f, g], {g * f: ["unique", "something"]}) + assert not d.is_subdiagram(d1) + assert not d1.is_subdiagram(d) + + d = Diagram({f: "blooh"}) + d1 = Diagram({f: "bleeh"}) + assert not d.is_subdiagram(d1) + assert not d1.is_subdiagram(d) + + d = Diagram([f, g], {f: "unique", g * f: "veryunique"}) + d1 = d.subdiagram_from_objects(FiniteSet(A, B)) + assert d1 == Diagram([f], {f: "unique"}) + raises(ValueError, lambda: d.subdiagram_from_objects(FiniteSet(A, + Object("D")))) + + raises(ValueError, lambda: Diagram({IdentityMorphism(A): "unique"})) + + +def test_category(): + A = Object("A") + B = Object("B") + C = Object("C") + + f = NamedMorphism(A, B, "f") + g = NamedMorphism(B, C, "g") + + d1 = Diagram([f, g]) + d2 = Diagram([f]) + + objects = d1.objects | d2.objects + + K = Category("K", objects, commutative_diagrams=[d1, d2]) + + assert K.name == "K" + assert K.objects == Class(objects) + assert K.commutative_diagrams == FiniteSet(d1, d2) + + raises(ValueError, lambda: Category("")) diff --git a/lib/python3.10/site-packages/sympy/categories/tests/test_drawing.py b/lib/python3.10/site-packages/sympy/categories/tests/test_drawing.py new file mode 100644 index 0000000000000000000000000000000000000000..63a13266cd6b58f6a85aad4af0813b395acbb5e1 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/categories/tests/test_drawing.py @@ -0,0 +1,919 @@ +from sympy.categories.diagram_drawing import _GrowableGrid, ArrowStringDescription +from sympy.categories import (DiagramGrid, Object, NamedMorphism, + Diagram, XypicDiagramDrawer, xypic_draw_diagram) +from sympy.sets.sets import FiniteSet + + +def test_GrowableGrid(): + grid = _GrowableGrid(1, 2) + + # Check dimensions. + assert grid.width == 1 + assert grid.height == 2 + + # Check initialization of elements. + assert grid[0, 0] is None + assert grid[1, 0] is None + + # Check assignment to elements. + grid[0, 0] = 1 + grid[1, 0] = "two" + + assert grid[0, 0] == 1 + assert grid[1, 0] == "two" + + # Check appending a row. + grid.append_row() + + assert grid.width == 1 + assert grid.height == 3 + + assert grid[0, 0] == 1 + assert grid[1, 0] == "two" + assert grid[2, 0] is None + + # Check appending a column. + grid.append_column() + assert grid.width == 2 + assert grid.height == 3 + + assert grid[0, 0] == 1 + assert grid[1, 0] == "two" + assert grid[2, 0] is None + + assert grid[0, 1] is None + assert grid[1, 1] is None + assert grid[2, 1] is None + + grid = _GrowableGrid(1, 2) + grid[0, 0] = 1 + grid[1, 0] = "two" + + # Check prepending a row. + grid.prepend_row() + assert grid.width == 1 + assert grid.height == 3 + + assert grid[0, 0] is None + assert grid[1, 0] == 1 + assert grid[2, 0] == "two" + + # Check prepending a column. + grid.prepend_column() + assert grid.width == 2 + assert grid.height == 3 + + assert grid[0, 0] is None + assert grid[1, 0] is None + assert grid[2, 0] is None + + assert grid[0, 1] is None + assert grid[1, 1] == 1 + assert grid[2, 1] == "two" + + +def test_DiagramGrid(): + # Set up some objects and morphisms. + A = Object("A") + B = Object("B") + C = Object("C") + D = Object("D") + E = Object("E") + + f = NamedMorphism(A, B, "f") + g = NamedMorphism(B, C, "g") + h = NamedMorphism(D, A, "h") + k = NamedMorphism(D, B, "k") + + # A one-morphism diagram. + d = Diagram([f]) + grid = DiagramGrid(d) + + assert grid.width == 2 + assert grid.height == 1 + assert grid[0, 0] == A + assert grid[0, 1] == B + assert grid.morphisms == {f: FiniteSet()} + + # A triangle. + d = Diagram([f, g], {g * f: "unique"}) + grid = DiagramGrid(d) + + assert grid.width == 2 + assert grid.height == 2 + assert grid[0, 0] == A + assert grid[0, 1] == B + assert grid[1, 0] == C + assert grid[1, 1] is None + assert grid.morphisms == {f: FiniteSet(), g: FiniteSet(), + g * f: FiniteSet("unique")} + + # A triangle with a "loop" morphism. + l_A = NamedMorphism(A, A, "l_A") + d = Diagram([f, g, l_A]) + grid = DiagramGrid(d) + + assert grid.width == 2 + assert grid.height == 2 + assert grid[0, 0] == A + assert grid[0, 1] == B + assert grid[1, 0] is None + assert grid[1, 1] == C + assert grid.morphisms == {f: FiniteSet(), g: FiniteSet(), l_A: FiniteSet()} + + # A simple diagram. + d = Diagram([f, g, h, k]) + grid = DiagramGrid(d) + + assert grid.width == 3 + assert grid.height == 2 + assert grid[0, 0] == A + assert grid[0, 1] == B + assert grid[0, 2] == D + assert grid[1, 0] is None + assert grid[1, 1] == C + assert grid[1, 2] is None + assert grid.morphisms == {f: FiniteSet(), g: FiniteSet(), h: FiniteSet(), + k: FiniteSet()} + + assert str(grid) == '[[Object("A"), Object("B"), Object("D")], ' \ + '[None, Object("C"), None]]' + + # A chain of morphisms. + f = NamedMorphism(A, B, "f") + g = NamedMorphism(B, C, "g") + h = NamedMorphism(C, D, "h") + k = NamedMorphism(D, E, "k") + d = Diagram([f, g, h, k]) + grid = DiagramGrid(d) + + assert grid.width == 3 + assert grid.height == 3 + assert grid[0, 0] == A + assert grid[0, 1] == B + assert grid[0, 2] is None + assert grid[1, 0] is None + assert grid[1, 1] == C + assert grid[1, 2] == D + assert grid[2, 0] is None + assert grid[2, 1] is None + assert grid[2, 2] == E + assert grid.morphisms == {f: FiniteSet(), g: FiniteSet(), h: FiniteSet(), + k: FiniteSet()} + + # A square. + f = NamedMorphism(A, B, "f") + g = NamedMorphism(B, D, "g") + h = NamedMorphism(A, C, "h") + k = NamedMorphism(C, D, "k") + d = Diagram([f, g, h, k]) + grid = DiagramGrid(d) + + assert grid.width == 2 + assert grid.height == 2 + assert grid[0, 0] == A + assert grid[0, 1] == B + assert grid[1, 0] == C + assert grid[1, 1] == D + assert grid.morphisms == {f: FiniteSet(), g: FiniteSet(), h: FiniteSet(), + k: FiniteSet()} + + # A strange diagram which resulted from a typo when creating a + # test for five lemma, but which allowed to stop one extra problem + # in the algorithm. + A = Object("A") + B = Object("B") + C = Object("C") + D = Object("D") + E = Object("E") + A_ = Object("A'") + B_ = Object("B'") + C_ = Object("C'") + D_ = Object("D'") + E_ = Object("E'") + + f = NamedMorphism(A, B, "f") + g = NamedMorphism(B, C, "g") + h = NamedMorphism(C, D, "h") + i = NamedMorphism(D, E, "i") + + # These 4 morphisms should be between primed objects. + j = NamedMorphism(A, B, "j") + k = NamedMorphism(B, C, "k") + l = NamedMorphism(C, D, "l") + m = NamedMorphism(D, E, "m") + + o = NamedMorphism(A, A_, "o") + p = NamedMorphism(B, B_, "p") + q = NamedMorphism(C, C_, "q") + r = NamedMorphism(D, D_, "r") + s = NamedMorphism(E, E_, "s") + + d = Diagram([f, g, h, i, j, k, l, m, o, p, q, r, s]) + grid = DiagramGrid(d) + + assert grid.width == 3 + assert grid.height == 4 + assert grid[0, 0] is None + assert grid[0, 1] == A + assert grid[0, 2] == A_ + assert grid[1, 0] == C + assert grid[1, 1] == B + assert grid[1, 2] == B_ + assert grid[2, 0] == C_ + assert grid[2, 1] == D + assert grid[2, 2] == D_ + assert grid[3, 0] is None + assert grid[3, 1] == E + assert grid[3, 2] == E_ + + morphisms = {} + for m in [f, g, h, i, j, k, l, m, o, p, q, r, s]: + morphisms[m] = FiniteSet() + assert grid.morphisms == morphisms + + # A cube. + A1 = Object("A1") + A2 = Object("A2") + A3 = Object("A3") + A4 = Object("A4") + A5 = Object("A5") + A6 = Object("A6") + A7 = Object("A7") + A8 = Object("A8") + + # The top face of the cube. + f1 = NamedMorphism(A1, A2, "f1") + f2 = NamedMorphism(A1, A3, "f2") + f3 = NamedMorphism(A2, A4, "f3") + f4 = NamedMorphism(A3, A4, "f3") + + # The bottom face of the cube. + f5 = NamedMorphism(A5, A6, "f5") + f6 = NamedMorphism(A5, A7, "f6") + f7 = NamedMorphism(A6, A8, "f7") + f8 = NamedMorphism(A7, A8, "f8") + + # The remaining morphisms. + f9 = NamedMorphism(A1, A5, "f9") + f10 = NamedMorphism(A2, A6, "f10") + f11 = NamedMorphism(A3, A7, "f11") + f12 = NamedMorphism(A4, A8, "f11") + + d = Diagram([f1, f2, f3, f4, f5, f6, f7, f8, f9, f10, f11, f12]) + grid = DiagramGrid(d) + + assert grid.width == 4 + assert grid.height == 3 + assert grid[0, 0] is None + assert grid[0, 1] == A5 + assert grid[0, 2] == A6 + assert grid[0, 3] is None + assert grid[1, 0] is None + assert grid[1, 1] == A1 + assert grid[1, 2] == A2 + assert grid[1, 3] is None + assert grid[2, 0] == A7 + assert grid[2, 1] == A3 + assert grid[2, 2] == A4 + assert grid[2, 3] == A8 + + morphisms = {} + for m in [f1, f2, f3, f4, f5, f6, f7, f8, f9, f10, f11, f12]: + morphisms[m] = FiniteSet() + assert grid.morphisms == morphisms + + # A line diagram. + A = Object("A") + B = Object("B") + C = Object("C") + D = Object("D") + E = Object("E") + + f = NamedMorphism(A, B, "f") + g = NamedMorphism(B, C, "g") + h = NamedMorphism(C, D, "h") + i = NamedMorphism(D, E, "i") + d = Diagram([f, g, h, i]) + grid = DiagramGrid(d, layout="sequential") + + assert grid.width == 5 + assert grid.height == 1 + assert grid[0, 0] == A + assert grid[0, 1] == B + assert grid[0, 2] == C + assert grid[0, 3] == D + assert grid[0, 4] == E + assert grid.morphisms == {f: FiniteSet(), g: FiniteSet(), h: FiniteSet(), + i: FiniteSet()} + + # Test the transposed version. + grid = DiagramGrid(d, layout="sequential", transpose=True) + + assert grid.width == 1 + assert grid.height == 5 + assert grid[0, 0] == A + assert grid[1, 0] == B + assert grid[2, 0] == C + assert grid[3, 0] == D + assert grid[4, 0] == E + assert grid.morphisms == {f: FiniteSet(), g: FiniteSet(), h: FiniteSet(), + i: FiniteSet()} + + # A pullback. + m1 = NamedMorphism(A, B, "m1") + m2 = NamedMorphism(A, C, "m2") + s1 = NamedMorphism(B, D, "s1") + s2 = NamedMorphism(C, D, "s2") + f1 = NamedMorphism(E, B, "f1") + f2 = NamedMorphism(E, C, "f2") + g = NamedMorphism(E, A, "g") + + d = Diagram([m1, m2, s1, s2, f1, f2], {g: "unique"}) + grid = DiagramGrid(d) + + assert grid.width == 3 + assert grid.height == 2 + assert grid[0, 0] == A + assert grid[0, 1] == B + assert grid[0, 2] == E + assert grid[1, 0] == C + assert grid[1, 1] == D + assert grid[1, 2] is None + + morphisms = {g: FiniteSet("unique")} + for m in [m1, m2, s1, s2, f1, f2]: + morphisms[m] = FiniteSet() + assert grid.morphisms == morphisms + + # Test the pullback with sequential layout, just for stress + # testing. + grid = DiagramGrid(d, layout="sequential") + + assert grid.width == 5 + assert grid.height == 1 + assert grid[0, 0] == D + assert grid[0, 1] == B + assert grid[0, 2] == A + assert grid[0, 3] == C + assert grid[0, 4] == E + assert grid.morphisms == morphisms + + # Test a pullback with object grouping. + grid = DiagramGrid(d, groups=FiniteSet(E, FiniteSet(A, B, C, D))) + + assert grid.width == 3 + assert grid.height == 2 + assert grid[0, 0] == E + assert grid[0, 1] == A + assert grid[0, 2] == B + assert grid[1, 0] is None + assert grid[1, 1] == C + assert grid[1, 2] == D + assert grid.morphisms == morphisms + + # Five lemma, actually. + A = Object("A") + B = Object("B") + C = Object("C") + D = Object("D") + E = Object("E") + A_ = Object("A'") + B_ = Object("B'") + C_ = Object("C'") + D_ = Object("D'") + E_ = Object("E'") + + f = NamedMorphism(A, B, "f") + g = NamedMorphism(B, C, "g") + h = NamedMorphism(C, D, "h") + i = NamedMorphism(D, E, "i") + + j = NamedMorphism(A_, B_, "j") + k = NamedMorphism(B_, C_, "k") + l = NamedMorphism(C_, D_, "l") + m = NamedMorphism(D_, E_, "m") + + o = NamedMorphism(A, A_, "o") + p = NamedMorphism(B, B_, "p") + q = NamedMorphism(C, C_, "q") + r = NamedMorphism(D, D_, "r") + s = NamedMorphism(E, E_, "s") + + d = Diagram([f, g, h, i, j, k, l, m, o, p, q, r, s]) + grid = DiagramGrid(d) + + assert grid.width == 5 + assert grid.height == 3 + assert grid[0, 0] is None + assert grid[0, 1] == A + assert grid[0, 2] == A_ + assert grid[0, 3] is None + assert grid[0, 4] is None + assert grid[1, 0] == C + assert grid[1, 1] == B + assert grid[1, 2] == B_ + assert grid[1, 3] == C_ + assert grid[1, 4] is None + assert grid[2, 0] == D + assert grid[2, 1] == E + assert grid[2, 2] is None + assert grid[2, 3] == D_ + assert grid[2, 4] == E_ + + morphisms = {} + for m in [f, g, h, i, j, k, l, m, o, p, q, r, s]: + morphisms[m] = FiniteSet() + assert grid.morphisms == morphisms + + # Test the five lemma with object grouping. + grid = DiagramGrid(d, FiniteSet( + FiniteSet(A, B, C, D, E), FiniteSet(A_, B_, C_, D_, E_))) + + assert grid.width == 6 + assert grid.height == 3 + assert grid[0, 0] == A + assert grid[0, 1] == B + assert grid[0, 2] is None + assert grid[0, 3] == A_ + assert grid[0, 4] == B_ + assert grid[0, 5] is None + assert grid[1, 0] is None + assert grid[1, 1] == C + assert grid[1, 2] == D + assert grid[1, 3] is None + assert grid[1, 4] == C_ + assert grid[1, 5] == D_ + assert grid[2, 0] is None + assert grid[2, 1] is None + assert grid[2, 2] == E + assert grid[2, 3] is None + assert grid[2, 4] is None + assert grid[2, 5] == E_ + assert grid.morphisms == morphisms + + # Test the five lemma with object grouping, but mixing containers + # to represent groups. + grid = DiagramGrid(d, [(A, B, C, D, E), {A_, B_, C_, D_, E_}]) + + assert grid.width == 6 + assert grid.height == 3 + assert grid[0, 0] == A + assert grid[0, 1] == B + assert grid[0, 2] is None + assert grid[0, 3] == A_ + assert grid[0, 4] == B_ + assert grid[0, 5] is None + assert grid[1, 0] is None + assert grid[1, 1] == C + assert grid[1, 2] == D + assert grid[1, 3] is None + assert grid[1, 4] == C_ + assert grid[1, 5] == D_ + assert grid[2, 0] is None + assert grid[2, 1] is None + assert grid[2, 2] == E + assert grid[2, 3] is None + assert grid[2, 4] is None + assert grid[2, 5] == E_ + assert grid.morphisms == morphisms + + # Test the five lemma with object grouping and hints. + grid = DiagramGrid(d, { + FiniteSet(A, B, C, D, E): {"layout": "sequential", + "transpose": True}, + FiniteSet(A_, B_, C_, D_, E_): {"layout": "sequential", + "transpose": True}}, + transpose=True) + + assert grid.width == 5 + assert grid.height == 2 + assert grid[0, 0] == A + assert grid[0, 1] == B + assert grid[0, 2] == C + assert grid[0, 3] == D + assert grid[0, 4] == E + assert grid[1, 0] == A_ + assert grid[1, 1] == B_ + assert grid[1, 2] == C_ + assert grid[1, 3] == D_ + assert grid[1, 4] == E_ + assert grid.morphisms == morphisms + + # A two-triangle disconnected diagram. + f = NamedMorphism(A, B, "f") + g = NamedMorphism(B, C, "g") + f_ = NamedMorphism(A_, B_, "f") + g_ = NamedMorphism(B_, C_, "g") + d = Diagram([f, g, f_, g_], {g * f: "unique", g_ * f_: "unique"}) + grid = DiagramGrid(d) + + assert grid.width == 4 + assert grid.height == 2 + assert grid[0, 0] == A + assert grid[0, 1] == B + assert grid[0, 2] == A_ + assert grid[0, 3] == B_ + assert grid[1, 0] == C + assert grid[1, 1] is None + assert grid[1, 2] == C_ + assert grid[1, 3] is None + assert grid.morphisms == {f: FiniteSet(), g: FiniteSet(), f_: FiniteSet(), + g_: FiniteSet(), g * f: FiniteSet("unique"), + g_ * f_: FiniteSet("unique")} + + # A two-morphism disconnected diagram. + f = NamedMorphism(A, B, "f") + g = NamedMorphism(C, D, "g") + d = Diagram([f, g]) + grid = DiagramGrid(d) + + assert grid.width == 4 + assert grid.height == 1 + assert grid[0, 0] == A + assert grid[0, 1] == B + assert grid[0, 2] == C + assert grid[0, 3] == D + assert grid.morphisms == {f: FiniteSet(), g: FiniteSet()} + + # Test a one-object diagram. + f = NamedMorphism(A, A, "f") + d = Diagram([f]) + grid = DiagramGrid(d) + + assert grid.width == 1 + assert grid.height == 1 + assert grid[0, 0] == A + + # Test a two-object disconnected diagram. + g = NamedMorphism(B, B, "g") + d = Diagram([f, g]) + grid = DiagramGrid(d) + + assert grid.width == 2 + assert grid.height == 1 + assert grid[0, 0] == A + assert grid[0, 1] == B + + +def test_DiagramGrid_pseudopod(): + # Test a diagram in which even growing a pseudopod does not + # eventually help. + A = Object("A") + B = Object("B") + C = Object("C") + D = Object("D") + E = Object("E") + F = Object("F") + A_ = Object("A'") + B_ = Object("B'") + C_ = Object("C'") + D_ = Object("D'") + E_ = Object("E'") + + f1 = NamedMorphism(A, B, "f1") + f2 = NamedMorphism(A, C, "f2") + f3 = NamedMorphism(A, D, "f3") + f4 = NamedMorphism(A, E, "f4") + f5 = NamedMorphism(A, A_, "f5") + f6 = NamedMorphism(A, B_, "f6") + f7 = NamedMorphism(A, C_, "f7") + f8 = NamedMorphism(A, D_, "f8") + f9 = NamedMorphism(A, E_, "f9") + f10 = NamedMorphism(A, F, "f10") + d = Diagram([f1, f2, f3, f4, f5, f6, f7, f8, f9, f10]) + grid = DiagramGrid(d) + + assert grid.width == 5 + assert grid.height == 3 + assert grid[0, 0] == E + assert grid[0, 1] == C + assert grid[0, 2] == C_ + assert grid[0, 3] == E_ + assert grid[0, 4] == F + assert grid[1, 0] == D + assert grid[1, 1] == A + assert grid[1, 2] == A_ + assert grid[1, 3] is None + assert grid[1, 4] is None + assert grid[2, 0] == D_ + assert grid[2, 1] == B + assert grid[2, 2] == B_ + assert grid[2, 3] is None + assert grid[2, 4] is None + + morphisms = {} + for f in [f1, f2, f3, f4, f5, f6, f7, f8, f9, f10]: + morphisms[f] = FiniteSet() + assert grid.morphisms == morphisms + + +def test_ArrowStringDescription(): + astr = ArrowStringDescription("cm", "", None, "", "", "d", "r", "_", "f") + assert str(astr) == "\\ar[dr]_{f}" + + astr = ArrowStringDescription("cm", "", 12, "", "", "d", "r", "_", "f") + assert str(astr) == "\\ar[dr]_{f}" + + astr = ArrowStringDescription("cm", "^", 12, "", "", "d", "r", "_", "f") + assert str(astr) == "\\ar@/^12cm/[dr]_{f}" + + astr = ArrowStringDescription("cm", "", 12, "r", "", "d", "r", "_", "f") + assert str(astr) == "\\ar[dr]_{f}" + + astr = ArrowStringDescription("cm", "", 12, "r", "u", "d", "r", "_", "f") + assert str(astr) == "\\ar@(r,u)[dr]_{f}" + + astr = ArrowStringDescription("cm", "", 12, "r", "u", "d", "r", "_", "f") + assert str(astr) == "\\ar@(r,u)[dr]_{f}" + + astr = ArrowStringDescription("cm", "", 12, "r", "u", "d", "r", "_", "f") + astr.arrow_style = "{-->}" + assert str(astr) == "\\ar@(r,u)@{-->}[dr]_{f}" + + astr = ArrowStringDescription("cm", "_", 12, "", "", "d", "r", "_", "f") + astr.arrow_style = "{-->}" + assert str(astr) == "\\ar@/_12cm/@{-->}[dr]_{f}" + + +def test_XypicDiagramDrawer_line(): + # A linear diagram. + A = Object("A") + B = Object("B") + C = Object("C") + D = Object("D") + E = Object("E") + + f = NamedMorphism(A, B, "f") + g = NamedMorphism(B, C, "g") + h = NamedMorphism(C, D, "h") + i = NamedMorphism(D, E, "i") + d = Diagram([f, g, h, i]) + grid = DiagramGrid(d, layout="sequential") + drawer = XypicDiagramDrawer() + assert drawer.draw(d, grid) == "\\xymatrix{\n" \ + "A \\ar[r]^{f} & B \\ar[r]^{g} & C \\ar[r]^{h} & D \\ar[r]^{i} & E \n" \ + "}\n" + + # The same diagram, transposed. + grid = DiagramGrid(d, layout="sequential", transpose=True) + drawer = XypicDiagramDrawer() + assert drawer.draw(d, grid) == "\\xymatrix{\n" \ + "A \\ar[d]^{f} \\\\\n" \ + "B \\ar[d]^{g} \\\\\n" \ + "C \\ar[d]^{h} \\\\\n" \ + "D \\ar[d]^{i} \\\\\n" \ + "E \n" \ + "}\n" + + +def test_XypicDiagramDrawer_triangle(): + # A triangle diagram. + A = Object("A") + B = Object("B") + C = Object("C") + f = NamedMorphism(A, B, "f") + g = NamedMorphism(B, C, "g") + + d = Diagram([f, g], {g * f: "unique"}) + grid = DiagramGrid(d) + drawer = XypicDiagramDrawer() + assert drawer.draw(d, grid) == "\\xymatrix{\n" \ + "A \\ar[d]_{g\\circ f} \\ar[r]^{f} & B \\ar[ld]^{g} \\\\\n" \ + "C & \n" \ + "}\n" + + # The same diagram, transposed. + grid = DiagramGrid(d, transpose=True) + drawer = XypicDiagramDrawer() + assert drawer.draw(d, grid) == "\\xymatrix{\n" \ + "A \\ar[r]^{g\\circ f} \\ar[d]_{f} & C \\\\\n" \ + "B \\ar[ru]_{g} & \n" \ + "}\n" + + # The same diagram, with a masked morphism. + assert drawer.draw(d, grid, masked=[g]) == "\\xymatrix{\n" \ + "A \\ar[r]^{g\\circ f} \\ar[d]_{f} & C \\\\\n" \ + "B & \n" \ + "}\n" + + # The same diagram with a formatter for "unique". + def formatter(astr): + astr.label = "\\exists !" + astr.label + astr.arrow_style = "{-->}" + + drawer.arrow_formatters["unique"] = formatter + assert drawer.draw(d, grid) == "\\xymatrix{\n" \ + "A \\ar@{-->}[r]^{\\exists !g\\circ f} \\ar[d]_{f} & C \\\\\n" \ + "B \\ar[ru]_{g} & \n" \ + "}\n" + + # The same diagram with a default formatter. + def default_formatter(astr): + astr.label_displacement = "(0.45)" + + drawer.default_arrow_formatter = default_formatter + assert drawer.draw(d, grid) == "\\xymatrix{\n" \ + "A \\ar@{-->}[r]^(0.45){\\exists !g\\circ f} \\ar[d]_(0.45){f} & C \\\\\n" \ + "B \\ar[ru]_(0.45){g} & \n" \ + "}\n" + + # A triangle diagram with a lot of morphisms between the same + # objects. + f1 = NamedMorphism(B, A, "f1") + f2 = NamedMorphism(A, B, "f2") + g1 = NamedMorphism(C, B, "g1") + g2 = NamedMorphism(B, C, "g2") + d = Diagram([f, f1, f2, g, g1, g2], {f1 * g1: "unique", g2 * f2: "unique"}) + + grid = DiagramGrid(d, transpose=True) + drawer = XypicDiagramDrawer() + assert drawer.draw(d, grid, masked=[f1*g1*g2*f2, g2*f2*f1*g1]) == \ + "\\xymatrix{\n" \ + "A \\ar[r]^{g_{2}\\circ f_{2}} \\ar[d]_{f} \\ar@/^3mm/[d]^{f_{2}} " \ + "& C \\ar@/^3mm/[l]^{f_{1}\\circ g_{1}} \\ar@/^3mm/[ld]^{g_{1}} \\\\\n" \ + "B \\ar@/^3mm/[u]^{f_{1}} \\ar[ru]_{g} \\ar@/^3mm/[ru]^{g_{2}} & \n" \ + "}\n" + + +def test_XypicDiagramDrawer_cube(): + # A cube diagram. + A1 = Object("A1") + A2 = Object("A2") + A3 = Object("A3") + A4 = Object("A4") + A5 = Object("A5") + A6 = Object("A6") + A7 = Object("A7") + A8 = Object("A8") + + # The top face of the cube. + f1 = NamedMorphism(A1, A2, "f1") + f2 = NamedMorphism(A1, A3, "f2") + f3 = NamedMorphism(A2, A4, "f3") + f4 = NamedMorphism(A3, A4, "f3") + + # The bottom face of the cube. + f5 = NamedMorphism(A5, A6, "f5") + f6 = NamedMorphism(A5, A7, "f6") + f7 = NamedMorphism(A6, A8, "f7") + f8 = NamedMorphism(A7, A8, "f8") + + # The remaining morphisms. + f9 = NamedMorphism(A1, A5, "f9") + f10 = NamedMorphism(A2, A6, "f10") + f11 = NamedMorphism(A3, A7, "f11") + f12 = NamedMorphism(A4, A8, "f11") + + d = Diagram([f1, f2, f3, f4, f5, f6, f7, f8, f9, f10, f11, f12]) + grid = DiagramGrid(d) + drawer = XypicDiagramDrawer() + assert drawer.draw(d, grid) == "\\xymatrix{\n" \ + "& A_{5} \\ar[r]^{f_{5}} \\ar[ldd]_{f_{6}} & A_{6} \\ar[rdd]^{f_{7}} " \ + "& \\\\\n" \ + "& A_{1} \\ar[r]^{f_{1}} \\ar[d]^{f_{2}} \\ar[u]^{f_{9}} & A_{2} " \ + "\\ar[d]^{f_{3}} \\ar[u]_{f_{10}} & \\\\\n" \ + "A_{7} \\ar@/_3mm/[rrr]_{f_{8}} & A_{3} \\ar[r]^{f_{3}} \\ar[l]_{f_{11}} " \ + "& A_{4} \\ar[r]^{f_{11}} & A_{8} \n" \ + "}\n" + + # The same diagram, transposed. + grid = DiagramGrid(d, transpose=True) + drawer = XypicDiagramDrawer() + assert drawer.draw(d, grid) == "\\xymatrix{\n" \ + "& & A_{7} \\ar@/^3mm/[ddd]^{f_{8}} \\\\\n" \ + "A_{5} \\ar[d]_{f_{5}} \\ar[rru]^{f_{6}} & A_{1} \\ar[d]^{f_{1}} " \ + "\\ar[r]^{f_{2}} \\ar[l]^{f_{9}} & A_{3} \\ar[d]_{f_{3}} " \ + "\\ar[u]^{f_{11}} \\\\\n" \ + "A_{6} \\ar[rrd]_{f_{7}} & A_{2} \\ar[r]^{f_{3}} \\ar[l]^{f_{10}} " \ + "& A_{4} \\ar[d]_{f_{11}} \\\\\n" \ + "& & A_{8} \n" \ + "}\n" + + +def test_XypicDiagramDrawer_curved_and_loops(): + # A simple diagram, with a curved arrow. + A = Object("A") + B = Object("B") + C = Object("C") + D = Object("D") + + f = NamedMorphism(A, B, "f") + g = NamedMorphism(B, C, "g") + h = NamedMorphism(D, A, "h") + k = NamedMorphism(D, B, "k") + d = Diagram([f, g, h, k]) + grid = DiagramGrid(d) + drawer = XypicDiagramDrawer() + assert drawer.draw(d, grid) == "\\xymatrix{\n" \ + "A \\ar[r]_{f} & B \\ar[d]^{g} & D \\ar[l]^{k} \\ar@/_3mm/[ll]_{h} \\\\\n" \ + "& C & \n" \ + "}\n" + + # The same diagram, transposed. + grid = DiagramGrid(d, transpose=True) + drawer = XypicDiagramDrawer() + assert drawer.draw(d, grid) == "\\xymatrix{\n" \ + "A \\ar[d]^{f} & \\\\\n" \ + "B \\ar[r]^{g} & C \\\\\n" \ + "D \\ar[u]_{k} \\ar@/^3mm/[uu]^{h} & \n" \ + "}\n" + + # The same diagram, larger and rotated. + assert drawer.draw(d, grid, diagram_format="@+1cm@dr") == \ + "\\xymatrix@+1cm@dr{\n" \ + "A \\ar[d]^{f} & \\\\\n" \ + "B \\ar[r]^{g} & C \\\\\n" \ + "D \\ar[u]_{k} \\ar@/^3mm/[uu]^{h} & \n" \ + "}\n" + + # A simple diagram with three curved arrows. + h1 = NamedMorphism(D, A, "h1") + h2 = NamedMorphism(A, D, "h2") + k = NamedMorphism(D, B, "k") + d = Diagram([f, g, h, k, h1, h2]) + grid = DiagramGrid(d) + drawer = XypicDiagramDrawer() + assert drawer.draw(d, grid) == "\\xymatrix{\n" \ + "A \\ar[r]_{f} \\ar@/^3mm/[rr]^{h_{2}} & B \\ar[d]^{g} & D \\ar[l]^{k} " \ + "\\ar@/_7mm/[ll]_{h} \\ar@/_11mm/[ll]_{h_{1}} \\\\\n" \ + "& C & \n" \ + "}\n" + + # The same diagram, transposed. + grid = DiagramGrid(d, transpose=True) + drawer = XypicDiagramDrawer() + assert drawer.draw(d, grid) == "\\xymatrix{\n" \ + "A \\ar[d]^{f} \\ar@/_3mm/[dd]_{h_{2}} & \\\\\n" \ + "B \\ar[r]^{g} & C \\\\\n" \ + "D \\ar[u]_{k} \\ar@/^7mm/[uu]^{h} \\ar@/^11mm/[uu]^{h_{1}} & \n" \ + "}\n" + + # The same diagram, with "loop" morphisms. + l_A = NamedMorphism(A, A, "l_A") + l_D = NamedMorphism(D, D, "l_D") + l_C = NamedMorphism(C, C, "l_C") + d = Diagram([f, g, h, k, h1, h2, l_A, l_D, l_C]) + grid = DiagramGrid(d) + drawer = XypicDiagramDrawer() + assert drawer.draw(d, grid) == "\\xymatrix{\n" \ + "A \\ar[r]_{f} \\ar@/^3mm/[rr]^{h_{2}} \\ar@(u,l)[]^{l_{A}} " \ + "& B \\ar[d]^{g} & D \\ar[l]^{k} \\ar@/_7mm/[ll]_{h} " \ + "\\ar@/_11mm/[ll]_{h_{1}} \\ar@(r,u)[]^{l_{D}} \\\\\n" \ + "& C \\ar@(l,d)[]^{l_{C}} & \n" \ + "}\n" + + # The same diagram with "loop" morphisms, transposed. + grid = DiagramGrid(d, transpose=True) + drawer = XypicDiagramDrawer() + assert drawer.draw(d, grid) == "\\xymatrix{\n" \ + "A \\ar[d]^{f} \\ar@/_3mm/[dd]_{h_{2}} \\ar@(r,u)[]^{l_{A}} & \\\\\n" \ + "B \\ar[r]^{g} & C \\ar@(r,u)[]^{l_{C}} \\\\\n" \ + "D \\ar[u]_{k} \\ar@/^7mm/[uu]^{h} \\ar@/^11mm/[uu]^{h_{1}} " \ + "\\ar@(l,d)[]^{l_{D}} & \n" \ + "}\n" + + # The same diagram with two "loop" morphisms per object. + l_A_ = NamedMorphism(A, A, "n_A") + l_D_ = NamedMorphism(D, D, "n_D") + l_C_ = NamedMorphism(C, C, "n_C") + d = Diagram([f, g, h, k, h1, h2, l_A, l_D, l_C, l_A_, l_D_, l_C_]) + grid = DiagramGrid(d) + drawer = XypicDiagramDrawer() + assert drawer.draw(d, grid) == "\\xymatrix{\n" \ + "A \\ar[r]_{f} \\ar@/^3mm/[rr]^{h_{2}} \\ar@(u,l)[]^{l_{A}} " \ + "\\ar@/^3mm/@(l,d)[]^{n_{A}} & B \\ar[d]^{g} & D \\ar[l]^{k} " \ + "\\ar@/_7mm/[ll]_{h} \\ar@/_11mm/[ll]_{h_{1}} \\ar@(r,u)[]^{l_{D}} " \ + "\\ar@/^3mm/@(d,r)[]^{n_{D}} \\\\\n" \ + "& C \\ar@(l,d)[]^{l_{C}} \\ar@/^3mm/@(d,r)[]^{n_{C}} & \n" \ + "}\n" + + # The same diagram with two "loop" morphisms per object, transposed. + grid = DiagramGrid(d, transpose=True) + drawer = XypicDiagramDrawer() + assert drawer.draw(d, grid) == "\\xymatrix{\n" \ + "A \\ar[d]^{f} \\ar@/_3mm/[dd]_{h_{2}} \\ar@(r,u)[]^{l_{A}} " \ + "\\ar@/^3mm/@(u,l)[]^{n_{A}} & \\\\\n" \ + "B \\ar[r]^{g} & C \\ar@(r,u)[]^{l_{C}} \\ar@/^3mm/@(d,r)[]^{n_{C}} \\\\\n" \ + "D \\ar[u]_{k} \\ar@/^7mm/[uu]^{h} \\ar@/^11mm/[uu]^{h_{1}} " \ + "\\ar@(l,d)[]^{l_{D}} \\ar@/^3mm/@(d,r)[]^{n_{D}} & \n" \ + "}\n" + + +def test_xypic_draw_diagram(): + # A linear diagram. + A = Object("A") + B = Object("B") + C = Object("C") + D = Object("D") + E = Object("E") + + f = NamedMorphism(A, B, "f") + g = NamedMorphism(B, C, "g") + h = NamedMorphism(C, D, "h") + i = NamedMorphism(D, E, "i") + d = Diagram([f, g, h, i]) + + grid = DiagramGrid(d, layout="sequential") + drawer = XypicDiagramDrawer() + assert drawer.draw(d, grid) == xypic_draw_diagram(d, layout="sequential") diff --git a/lib/python3.10/site-packages/sympy/codegen/__pycache__/__init__.cpython-310.pyc b/lib/python3.10/site-packages/sympy/codegen/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..32855854b63ea8be3998d8a0ecd09ec879328264 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/codegen/__pycache__/__init__.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/codegen/__pycache__/abstract_nodes.cpython-310.pyc b/lib/python3.10/site-packages/sympy/codegen/__pycache__/abstract_nodes.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3252cd45fadb21ab0d0ea2c23f2552e488a4b547 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/codegen/__pycache__/abstract_nodes.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/codegen/__pycache__/algorithms.cpython-310.pyc b/lib/python3.10/site-packages/sympy/codegen/__pycache__/algorithms.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cb7737fbb0e487feedc9b7d203b890e959f57ebc Binary files /dev/null and b/lib/python3.10/site-packages/sympy/codegen/__pycache__/algorithms.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/codegen/__pycache__/approximations.cpython-310.pyc b/lib/python3.10/site-packages/sympy/codegen/__pycache__/approximations.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..347be7115c03a86b63961453fbe340b294d37850 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/codegen/__pycache__/approximations.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/codegen/__pycache__/ast.cpython-310.pyc b/lib/python3.10/site-packages/sympy/codegen/__pycache__/ast.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..935996e06fbd44017eb1862da896333e27e8d8c6 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/codegen/__pycache__/ast.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/codegen/__pycache__/cfunctions.cpython-310.pyc b/lib/python3.10/site-packages/sympy/codegen/__pycache__/cfunctions.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a9d702c9583edd5d98a01709882845e368578877 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/codegen/__pycache__/cfunctions.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/codegen/__pycache__/cnodes.cpython-310.pyc b/lib/python3.10/site-packages/sympy/codegen/__pycache__/cnodes.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3995d9742f4a727ad213ec4967df38870c14b829 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/codegen/__pycache__/cnodes.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/codegen/__pycache__/cutils.cpython-310.pyc b/lib/python3.10/site-packages/sympy/codegen/__pycache__/cutils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1c73d761d15c956e790e23104f9a76e527a21173 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/codegen/__pycache__/cutils.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/codegen/__pycache__/cxxnodes.cpython-310.pyc b/lib/python3.10/site-packages/sympy/codegen/__pycache__/cxxnodes.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..09220efbb1cede9d65e256891554956aee4ea720 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/codegen/__pycache__/cxxnodes.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/codegen/__pycache__/fnodes.cpython-310.pyc b/lib/python3.10/site-packages/sympy/codegen/__pycache__/fnodes.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..641c938740df73cb37a4c6bb7bbb2a3f21890e53 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/codegen/__pycache__/fnodes.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/codegen/__pycache__/futils.cpython-310.pyc b/lib/python3.10/site-packages/sympy/codegen/__pycache__/futils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..33826d74c6ac8f070b9ffffdb6187c0aeaab8a35 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/codegen/__pycache__/futils.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/codegen/__pycache__/matrix_nodes.cpython-310.pyc b/lib/python3.10/site-packages/sympy/codegen/__pycache__/matrix_nodes.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c3427f8ab950c999f95e49ff4ee2956186ec4381 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/codegen/__pycache__/matrix_nodes.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/codegen/__pycache__/numpy_nodes.cpython-310.pyc b/lib/python3.10/site-packages/sympy/codegen/__pycache__/numpy_nodes.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3c74c9fa66c04e856fe7ae72d164434bd3feebc4 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/codegen/__pycache__/numpy_nodes.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/codegen/__pycache__/pynodes.cpython-310.pyc b/lib/python3.10/site-packages/sympy/codegen/__pycache__/pynodes.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3ac1d184e778f705ed09e856ce426f5b14f5df1e Binary files /dev/null and b/lib/python3.10/site-packages/sympy/codegen/__pycache__/pynodes.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/codegen/__pycache__/pyutils.cpython-310.pyc b/lib/python3.10/site-packages/sympy/codegen/__pycache__/pyutils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bc1f7d5dad0a7d9140483630abac830a0ca6b6e3 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/codegen/__pycache__/pyutils.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/codegen/__pycache__/rewriting.cpython-310.pyc b/lib/python3.10/site-packages/sympy/codegen/__pycache__/rewriting.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fe9e957844cf4f3a0d3e137451f597c16ef45425 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/codegen/__pycache__/rewriting.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/codegen/__pycache__/scipy_nodes.cpython-310.pyc b/lib/python3.10/site-packages/sympy/codegen/__pycache__/scipy_nodes.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..33d4aa40daab8ddd48d9ed80154a45f7847ea12c Binary files /dev/null and b/lib/python3.10/site-packages/sympy/codegen/__pycache__/scipy_nodes.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/codegen/tests/__init__.py b/lib/python3.10/site-packages/sympy/codegen/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/lib/python3.10/site-packages/sympy/codegen/tests/__pycache__/__init__.cpython-310.pyc b/lib/python3.10/site-packages/sympy/codegen/tests/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e4710d6329646e1a52d7584a7171d332553a269f Binary files /dev/null and b/lib/python3.10/site-packages/sympy/codegen/tests/__pycache__/__init__.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/codegen/tests/__pycache__/test_abstract_nodes.cpython-310.pyc b/lib/python3.10/site-packages/sympy/codegen/tests/__pycache__/test_abstract_nodes.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2e1c752252fd8ecfedd10d55726f67182a643ba0 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/codegen/tests/__pycache__/test_abstract_nodes.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/codegen/tests/__pycache__/test_algorithms.cpython-310.pyc b/lib/python3.10/site-packages/sympy/codegen/tests/__pycache__/test_algorithms.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..48bfb879390da257e02321ada5d61a76b48f77ca Binary files /dev/null and b/lib/python3.10/site-packages/sympy/codegen/tests/__pycache__/test_algorithms.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/codegen/tests/__pycache__/test_applications.cpython-310.pyc b/lib/python3.10/site-packages/sympy/codegen/tests/__pycache__/test_applications.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fc6ece98a81d38c37524858f25ad6a00b0768741 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/codegen/tests/__pycache__/test_applications.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/codegen/tests/__pycache__/test_approximations.cpython-310.pyc b/lib/python3.10/site-packages/sympy/codegen/tests/__pycache__/test_approximations.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6ba50eccd6c7e74dcb746f68d3bffc1b9034a7c3 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/codegen/tests/__pycache__/test_approximations.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/codegen/tests/__pycache__/test_ast.cpython-310.pyc b/lib/python3.10/site-packages/sympy/codegen/tests/__pycache__/test_ast.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b4d57d321c4cc2e24488046561b696258d627dfc Binary files /dev/null and b/lib/python3.10/site-packages/sympy/codegen/tests/__pycache__/test_ast.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/codegen/tests/__pycache__/test_cfunctions.cpython-310.pyc b/lib/python3.10/site-packages/sympy/codegen/tests/__pycache__/test_cfunctions.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bfd59932866e0e24a0e11c179fa2aed27f1a7c66 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/codegen/tests/__pycache__/test_cfunctions.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/codegen/tests/__pycache__/test_cnodes.cpython-310.pyc b/lib/python3.10/site-packages/sympy/codegen/tests/__pycache__/test_cnodes.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5e2b81471d5a3443d2253f62a081b15f5079dea8 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/codegen/tests/__pycache__/test_cnodes.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/codegen/tests/__pycache__/test_cxxnodes.cpython-310.pyc b/lib/python3.10/site-packages/sympy/codegen/tests/__pycache__/test_cxxnodes.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5eddc4941f0e4bb977a096aeded2a9b0f8dbb7d4 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/codegen/tests/__pycache__/test_cxxnodes.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/codegen/tests/__pycache__/test_fnodes.cpython-310.pyc b/lib/python3.10/site-packages/sympy/codegen/tests/__pycache__/test_fnodes.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c59cde620838ff99d84a6816e62a84a35f69fc5d Binary files /dev/null and b/lib/python3.10/site-packages/sympy/codegen/tests/__pycache__/test_fnodes.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/codegen/tests/__pycache__/test_matrix_nodes.cpython-310.pyc b/lib/python3.10/site-packages/sympy/codegen/tests/__pycache__/test_matrix_nodes.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..516c91c3fddd41eab4b80656cf68e1aca651c9cc Binary files /dev/null and b/lib/python3.10/site-packages/sympy/codegen/tests/__pycache__/test_matrix_nodes.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/codegen/tests/__pycache__/test_numpy_nodes.cpython-310.pyc b/lib/python3.10/site-packages/sympy/codegen/tests/__pycache__/test_numpy_nodes.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..eab8f7bce5f4dfdda4422423a5bed6fe06fc6b8a Binary files /dev/null and b/lib/python3.10/site-packages/sympy/codegen/tests/__pycache__/test_numpy_nodes.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/codegen/tests/__pycache__/test_pynodes.cpython-310.pyc b/lib/python3.10/site-packages/sympy/codegen/tests/__pycache__/test_pynodes.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..85fb9c8d868ef3c72ca61733c506a4a15db27bf2 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/codegen/tests/__pycache__/test_pynodes.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/codegen/tests/__pycache__/test_pyutils.cpython-310.pyc b/lib/python3.10/site-packages/sympy/codegen/tests/__pycache__/test_pyutils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e315b53a59bea1d9bd09bdd3ea7d399043dca8b2 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/codegen/tests/__pycache__/test_pyutils.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/codegen/tests/__pycache__/test_rewriting.cpython-310.pyc b/lib/python3.10/site-packages/sympy/codegen/tests/__pycache__/test_rewriting.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fda1f0d6b36d953796b703733c88e024a08bb9c4 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/codegen/tests/__pycache__/test_rewriting.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/codegen/tests/__pycache__/test_scipy_nodes.cpython-310.pyc b/lib/python3.10/site-packages/sympy/codegen/tests/__pycache__/test_scipy_nodes.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..99e17e90280351418a6734755b6a14d3500e0cac Binary files /dev/null and b/lib/python3.10/site-packages/sympy/codegen/tests/__pycache__/test_scipy_nodes.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/codegen/tests/test_abstract_nodes.py b/lib/python3.10/site-packages/sympy/codegen/tests/test_abstract_nodes.py new file mode 100644 index 0000000000000000000000000000000000000000..89e1f73ff8cb24a4a865aa51304ec66e9901e3cb --- /dev/null +++ b/lib/python3.10/site-packages/sympy/codegen/tests/test_abstract_nodes.py @@ -0,0 +1,14 @@ +from sympy.core.symbol import symbols +from sympy.codegen.abstract_nodes import List + + +def test_List(): + l = List(2, 3, 4) + assert l == List(2, 3, 4) + assert str(l) == "[2, 3, 4]" + x, y, z = symbols('x y z') + l = List(x**2,y**3,z**4) + # contrary to python's built-in list, we can call e.g. "replace" on List. + m = l.replace(lambda arg: arg.is_Pow and arg.exp>2, lambda p: p.base-p.exp) + assert m == [x**2, y-3, z-4] + hash(m) diff --git a/lib/python3.10/site-packages/sympy/codegen/tests/test_algorithms.py b/lib/python3.10/site-packages/sympy/codegen/tests/test_algorithms.py new file mode 100644 index 0000000000000000000000000000000000000000..09446258d461d71e299408555399c7f09fbd8419 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/codegen/tests/test_algorithms.py @@ -0,0 +1,179 @@ +import tempfile +from sympy import log, Min, Max, sqrt +from sympy.core.numbers import Float +from sympy.core.symbol import Symbol, symbols +from sympy.functions.elementary.trigonometric import cos +from sympy.codegen.ast import Assignment, Raise, RuntimeError_, QuotedString +from sympy.codegen.algorithms import newtons_method, newtons_method_function +from sympy.codegen.cfunctions import expm1 +from sympy.codegen.fnodes import bind_C +from sympy.codegen.futils import render_as_module as f_module +from sympy.codegen.pyutils import render_as_module as py_module +from sympy.external import import_module +from sympy.printing.codeprinter import ccode +from sympy.utilities._compilation import compile_link_import_strings, has_c, has_fortran +from sympy.utilities._compilation.util import may_xfail +from sympy.testing.pytest import skip, raises + +cython = import_module('cython') +wurlitzer = import_module('wurlitzer') + +def test_newtons_method(): + x, dx, atol = symbols('x dx atol') + expr = cos(x) - x**3 + algo = newtons_method(expr, x, atol, dx) + assert algo.has(Assignment(dx, -expr/expr.diff(x))) + + +@may_xfail +def test_newtons_method_function__ccode(): + x = Symbol('x', real=True) + expr = cos(x) - x**3 + func = newtons_method_function(expr, x) + + if not cython: + skip("cython not installed.") + if not has_c(): + skip("No C compiler found.") + + compile_kw = {"std": 'c99'} + with tempfile.TemporaryDirectory() as folder: + mod, info = compile_link_import_strings([ + ('newton.c', ('#include \n' + '#include \n') + ccode(func)), + ('_newton.pyx', ("#cython: language_level={}\n".format("3") + + "cdef extern double newton(double)\n" + "def py_newton(x):\n" + " return newton(x)\n")) + ], build_dir=folder, compile_kwargs=compile_kw) + assert abs(mod.py_newton(0.5) - 0.865474033102) < 1e-12 + + +@may_xfail +def test_newtons_method_function__fcode(): + x = Symbol('x', real=True) + expr = cos(x) - x**3 + func = newtons_method_function(expr, x, attrs=[bind_C(name='newton')]) + + if not cython: + skip("cython not installed.") + if not has_fortran(): + skip("No Fortran compiler found.") + + f_mod = f_module([func], 'mod_newton') + with tempfile.TemporaryDirectory() as folder: + mod, info = compile_link_import_strings([ + ('newton.f90', f_mod), + ('_newton.pyx', ("#cython: language_level={}\n".format("3") + + "cdef extern double newton(double*)\n" + "def py_newton(double x):\n" + " return newton(&x)\n")) + ], build_dir=folder) + assert abs(mod.py_newton(0.5) - 0.865474033102) < 1e-12 + + +def test_newtons_method_function__pycode(): + x = Symbol('x', real=True) + expr = cos(x) - x**3 + func = newtons_method_function(expr, x) + py_mod = py_module(func) + namespace = {} + exec(py_mod, namespace, namespace) + res = eval('newton(0.5)', namespace) + assert abs(res - 0.865474033102) < 1e-12 + + +@may_xfail +def test_newtons_method_function__ccode_parameters(): + args = x, A, k, p = symbols('x A k p') + expr = A*cos(k*x) - p*x**3 + raises(ValueError, lambda: newtons_method_function(expr, x)) + use_wurlitzer = wurlitzer + + func = newtons_method_function(expr, x, args, debug=use_wurlitzer) + + if not has_c(): + skip("No C compiler found.") + if not cython: + skip("cython not installed.") + + compile_kw = {"std": 'c99'} + with tempfile.TemporaryDirectory() as folder: + mod, info = compile_link_import_strings([ + ('newton_par.c', ('#include \n' + '#include \n') + ccode(func)), + ('_newton_par.pyx', ("#cython: language_level={}\n".format("3") + + "cdef extern double newton(double, double, double, double)\n" + "def py_newton(x, A=1, k=1, p=1):\n" + " return newton(x, A, k, p)\n")) + ], compile_kwargs=compile_kw, build_dir=folder) + + if use_wurlitzer: + with wurlitzer.pipes() as (out, err): + result = mod.py_newton(0.5) + else: + result = mod.py_newton(0.5) + + assert abs(result - 0.865474033102) < 1e-12 + + if not use_wurlitzer: + skip("C-level output only tested when package 'wurlitzer' is available.") + + out, err = out.read(), err.read() + assert err == '' + assert out == """\ +x= 0.5 +x= 1.1121 d_x= 0.61214 +x= 0.90967 d_x= -0.20247 +x= 0.86726 d_x= -0.042409 +x= 0.86548 d_x= -0.0017867 +x= 0.86547 d_x= -3.1022e-06 +x= 0.86547 d_x= -9.3421e-12 +x= 0.86547 d_x= 3.6902e-17 +""" # try to run tests with LC_ALL=C if this assertion fails + + +def test_newtons_method_function__rtol_cse_nan(): + a, b, c, N_geo, N_tot = symbols('a b c N_geo N_tot', real=True, nonnegative=True) + i = Symbol('i', integer=True, nonnegative=True) + N_ari = N_tot - N_geo - 1 + delta_ari = (c-b)/N_ari + ln_delta_geo = log(b) + log(-expm1((log(a)-log(b))/N_geo)) + eqb_log = ln_delta_geo - log(delta_ari) + + def _clamp(low, expr, high): + return Min(Max(low, expr), high) + + meth_kw = { + 'clamped_newton': {'delta_fn': lambda e, x: _clamp( + (sqrt(a*x)-x)*0.99, + -e/e.diff(x), + (sqrt(c*x)-x)*0.99 + )}, + 'halley': {'delta_fn': lambda e, x: (-2*(e*e.diff(x))/(2*e.diff(x)**2 - e*e.diff(x, 2)))}, + 'halley_alt': {'delta_fn': lambda e, x: (-e/e.diff(x)/(1-e/e.diff(x)*e.diff(x,2)/2/e.diff(x)))}, + } + args = eqb_log, b + for use_cse in [False, True]: + kwargs = { + 'params': (b, a, c, N_geo, N_tot), 'itermax': 60, 'debug': True, 'cse': use_cse, + 'counter': i, 'atol': 1e-100, 'rtol': 2e-16, 'bounds': (a,c), + 'handle_nan': Raise(RuntimeError_(QuotedString("encountered NaN."))) + } + func = {k: newtons_method_function(*args, func_name=f"{k}_b", **dict(kwargs, **kw)) for k, kw in meth_kw.items()} + py_mod = {k: py_module(v) for k, v in func.items()} + namespace = {} + root_find_b = {} + for k, v in py_mod.items(): + ns = namespace[k] = {} + exec(v, ns, ns) + root_find_b[k] = ns[f'{k}_b'] + ref = Float('13.2261515064168768938151923226496') + reftol = {'clamped_newton': 2e-16, 'halley': 2e-16, 'halley_alt': 3e-16} + guess = 4.0 + for meth, func in root_find_b.items(): + result = func(guess, 1e-2, 1e2, 50, 100) + req = ref*reftol[meth] + if use_cse: + req *= 2 + assert abs(result - ref) < req diff --git a/lib/python3.10/site-packages/sympy/codegen/tests/test_applications.py b/lib/python3.10/site-packages/sympy/codegen/tests/test_applications.py new file mode 100644 index 0000000000000000000000000000000000000000..26d5d0f699b947db13b658d793f808d632f67a1a --- /dev/null +++ b/lib/python3.10/site-packages/sympy/codegen/tests/test_applications.py @@ -0,0 +1,57 @@ +# This file contains tests that exercise multiple AST nodes + +import tempfile + +from sympy.external import import_module +from sympy.printing.codeprinter import ccode +from sympy.utilities._compilation import compile_link_import_strings, has_c +from sympy.utilities._compilation.util import may_xfail +from sympy.testing.pytest import skip +from sympy.codegen.ast import ( + FunctionDefinition, FunctionPrototype, Variable, Pointer, real, Assignment, + integer, CodeBlock, While +) +from sympy.codegen.cnodes import void, PreIncrement +from sympy.codegen.cutils import render_as_source_file + +cython = import_module('cython') +np = import_module('numpy') + +def _mk_func1(): + declars = n, inp, out = Variable('n', integer), Pointer('inp', real), Pointer('out', real) + i = Variable('i', integer) + whl = While(i2, lambda p: p.base-p.exp) + assert m == [x**2, y-3, z-4] diff --git a/lib/python3.10/site-packages/sympy/codegen/tests/test_pyutils.py b/lib/python3.10/site-packages/sympy/codegen/tests/test_pyutils.py new file mode 100644 index 0000000000000000000000000000000000000000..0a2f0ff358f333635c8d44195a5c39d63ac8f16f --- /dev/null +++ b/lib/python3.10/site-packages/sympy/codegen/tests/test_pyutils.py @@ -0,0 +1,7 @@ +from sympy.codegen.ast import Print +from sympy.codegen.pyutils import render_as_module + +def test_standard(): + ast = Print('x y'.split(), r"coordinate: %12.5g %12.5g\n") + assert render_as_module(ast, standard='python3') == \ + '\n\nprint("coordinate: %12.5g %12.5g\\n" % (x, y), end="")' diff --git a/lib/python3.10/site-packages/sympy/codegen/tests/test_rewriting.py b/lib/python3.10/site-packages/sympy/codegen/tests/test_rewriting.py new file mode 100644 index 0000000000000000000000000000000000000000..51e0c9ecc940f60186cc04d4bf15650281d31cd8 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/codegen/tests/test_rewriting.py @@ -0,0 +1,479 @@ +import tempfile +from sympy.core.numbers import pi, Rational +from sympy.core.power import Pow +from sympy.core.singleton import S +from sympy.core.symbol import Symbol +from sympy.functions.elementary.complexes import Abs +from sympy.functions.elementary.exponential import (exp, log) +from sympy.functions.elementary.trigonometric import (cos, sin, sinc) +from sympy.matrices.expressions.matexpr import MatrixSymbol +from sympy.assumptions import assuming, Q +from sympy.external import import_module +from sympy.printing.codeprinter import ccode +from sympy.codegen.matrix_nodes import MatrixSolve +from sympy.codegen.cfunctions import log2, exp2, expm1, log1p +from sympy.codegen.numpy_nodes import logaddexp, logaddexp2 +from sympy.codegen.scipy_nodes import cosm1, powm1 +from sympy.codegen.rewriting import ( + optimize, cosm1_opt, log2_opt, exp2_opt, expm1_opt, log1p_opt, powm1_opt, optims_c99, + create_expand_pow_optimization, matinv_opt, logaddexp_opt, logaddexp2_opt, + optims_numpy, optims_scipy, sinc_opts, FuncMinusOneOptim +) +from sympy.testing.pytest import XFAIL, skip +from sympy.utilities import lambdify +from sympy.utilities._compilation import compile_link_import_strings, has_c +from sympy.utilities._compilation.util import may_xfail + +cython = import_module('cython') +numpy = import_module('numpy') +scipy = import_module('scipy') + + +def test_log2_opt(): + x = Symbol('x') + expr1 = 7*log(3*x + 5)/(log(2)) + opt1 = optimize(expr1, [log2_opt]) + assert opt1 == 7*log2(3*x + 5) + assert opt1.rewrite(log) == expr1 + + expr2 = 3*log(5*x + 7)/(13*log(2)) + opt2 = optimize(expr2, [log2_opt]) + assert opt2 == 3*log2(5*x + 7)/13 + assert opt2.rewrite(log) == expr2 + + expr3 = log(x)/log(2) + opt3 = optimize(expr3, [log2_opt]) + assert opt3 == log2(x) + assert opt3.rewrite(log) == expr3 + + expr4 = log(x)/log(2) + log(x+1) + opt4 = optimize(expr4, [log2_opt]) + assert opt4 == log2(x) + log(2)*log2(x+1) + assert opt4.rewrite(log) == expr4 + + expr5 = log(17) + opt5 = optimize(expr5, [log2_opt]) + assert opt5 == expr5 + + expr6 = log(x + 3)/log(2) + opt6 = optimize(expr6, [log2_opt]) + assert str(opt6) == 'log2(x + 3)' + assert opt6.rewrite(log) == expr6 + + +def test_exp2_opt(): + x = Symbol('x') + expr1 = 1 + 2**x + opt1 = optimize(expr1, [exp2_opt]) + assert opt1 == 1 + exp2(x) + assert opt1.rewrite(Pow) == expr1 + + expr2 = 1 + 3**x + assert expr2 == optimize(expr2, [exp2_opt]) + + +def test_expm1_opt(): + x = Symbol('x') + + expr1 = exp(x) - 1 + opt1 = optimize(expr1, [expm1_opt]) + assert expm1(x) - opt1 == 0 + assert opt1.rewrite(exp) == expr1 + + expr2 = 3*exp(x) - 3 + opt2 = optimize(expr2, [expm1_opt]) + assert 3*expm1(x) == opt2 + assert opt2.rewrite(exp) == expr2 + + expr3 = 3*exp(x) - 5 + opt3 = optimize(expr3, [expm1_opt]) + assert 3*expm1(x) - 2 == opt3 + assert opt3.rewrite(exp) == expr3 + expm1_opt_non_opportunistic = FuncMinusOneOptim(exp, expm1, opportunistic=False) + assert expr3 == optimize(expr3, [expm1_opt_non_opportunistic]) + assert opt1 == optimize(expr1, [expm1_opt_non_opportunistic]) + assert opt2 == optimize(expr2, [expm1_opt_non_opportunistic]) + + expr4 = 3*exp(x) + log(x) - 3 + opt4 = optimize(expr4, [expm1_opt]) + assert 3*expm1(x) + log(x) == opt4 + assert opt4.rewrite(exp) == expr4 + + expr5 = 3*exp(2*x) - 3 + opt5 = optimize(expr5, [expm1_opt]) + assert 3*expm1(2*x) == opt5 + assert opt5.rewrite(exp) == expr5 + + expr6 = (2*exp(x) + 1)/(exp(x) + 1) + 1 + opt6 = optimize(expr6, [expm1_opt]) + assert opt6.count_ops() <= expr6.count_ops() + + def ev(e): + return e.subs(x, 3).evalf() + assert abs(ev(expr6) - ev(opt6)) < 1e-15 + + y = Symbol('y') + expr7 = (2*exp(x) - 1)/(1 - exp(y)) - 1/(1-exp(y)) + opt7 = optimize(expr7, [expm1_opt]) + assert -2*expm1(x)/expm1(y) == opt7 + assert (opt7.rewrite(exp) - expr7).factor() == 0 + + expr8 = (1+exp(x))**2 - 4 + opt8 = optimize(expr8, [expm1_opt]) + tgt8a = (exp(x) + 3)*expm1(x) + tgt8b = 2*expm1(x) + expm1(2*x) + # Both tgt8a & tgt8b seem to give full precision (~16 digits for double) + # for x=1e-7 (compare with expr8 which only achieves ~8 significant digits). + # If we can show that either tgt8a or tgt8b is preferable, we can + # change this test to ensure the preferable version is returned. + assert (tgt8a - tgt8b).rewrite(exp).factor() == 0 + assert opt8 in (tgt8a, tgt8b) + assert (opt8.rewrite(exp) - expr8).factor() == 0 + + expr9 = sin(expr8) + opt9 = optimize(expr9, [expm1_opt]) + tgt9a = sin(tgt8a) + tgt9b = sin(tgt8b) + assert opt9 in (tgt9a, tgt9b) + assert (opt9.rewrite(exp) - expr9.rewrite(exp)).factor().is_zero + + +def test_expm1_two_exp_terms(): + x, y = map(Symbol, 'x y'.split()) + expr1 = exp(x) + exp(y) - 2 + opt1 = optimize(expr1, [expm1_opt]) + assert opt1 == expm1(x) + expm1(y) + + +def test_cosm1_opt(): + x = Symbol('x') + + expr1 = cos(x) - 1 + opt1 = optimize(expr1, [cosm1_opt]) + assert cosm1(x) - opt1 == 0 + assert opt1.rewrite(cos) == expr1 + + expr2 = 3*cos(x) - 3 + opt2 = optimize(expr2, [cosm1_opt]) + assert 3*cosm1(x) == opt2 + assert opt2.rewrite(cos) == expr2 + + expr3 = 3*cos(x) - 5 + opt3 = optimize(expr3, [cosm1_opt]) + assert 3*cosm1(x) - 2 == opt3 + assert opt3.rewrite(cos) == expr3 + cosm1_opt_non_opportunistic = FuncMinusOneOptim(cos, cosm1, opportunistic=False) + assert expr3 == optimize(expr3, [cosm1_opt_non_opportunistic]) + assert opt1 == optimize(expr1, [cosm1_opt_non_opportunistic]) + assert opt2 == optimize(expr2, [cosm1_opt_non_opportunistic]) + + expr4 = 3*cos(x) + log(x) - 3 + opt4 = optimize(expr4, [cosm1_opt]) + assert 3*cosm1(x) + log(x) == opt4 + assert opt4.rewrite(cos) == expr4 + + expr5 = 3*cos(2*x) - 3 + opt5 = optimize(expr5, [cosm1_opt]) + assert 3*cosm1(2*x) == opt5 + assert opt5.rewrite(cos) == expr5 + + expr6 = 2 - 2*cos(x) + opt6 = optimize(expr6, [cosm1_opt]) + assert -2*cosm1(x) == opt6 + assert opt6.rewrite(cos) == expr6 + + +def test_cosm1_two_cos_terms(): + x, y = map(Symbol, 'x y'.split()) + expr1 = cos(x) + cos(y) - 2 + opt1 = optimize(expr1, [cosm1_opt]) + assert opt1 == cosm1(x) + cosm1(y) + + +def test_expm1_cosm1_mixed(): + x = Symbol('x') + expr1 = exp(x) + cos(x) - 2 + opt1 = optimize(expr1, [expm1_opt, cosm1_opt]) + assert opt1 == cosm1(x) + expm1(x) + + +def _check_num_lambdify(expr, opt, val_subs, approx_ref, lambdify_kw=None, poorness=1e10): + """ poorness=1e10 signifies that `expr` loses precision of at least ten decimal digits. """ + num_ref = expr.subs(val_subs).evalf() + eps = numpy.finfo(numpy.float64).eps + assert abs(num_ref - approx_ref) < approx_ref*eps + f1 = lambdify(list(val_subs.keys()), opt, **(lambdify_kw or {})) + args_float = tuple(map(float, val_subs.values())) + num_err1 = abs(f1(*args_float) - approx_ref) + assert num_err1 < abs(num_ref*eps) + f2 = lambdify(list(val_subs.keys()), expr, **(lambdify_kw or {})) + num_err2 = abs(f2(*args_float) - approx_ref) + assert num_err2 > abs(num_ref*eps*poorness) # this only ensures that the *test* works as intended + + +def test_cosm1_apart(): + x = Symbol('x') + + expr1 = 1/cos(x) - 1 + opt1 = optimize(expr1, [cosm1_opt]) + assert opt1 == -cosm1(x)/cos(x) + if scipy: + _check_num_lambdify(expr1, opt1, {x: S(10)**-30}, 5e-61, lambdify_kw={"modules": 'scipy'}) + + expr2 = 2/cos(x) - 2 + opt2 = optimize(expr2, optims_scipy) + assert opt2 == -2*cosm1(x)/cos(x) + if scipy: + _check_num_lambdify(expr2, opt2, {x: S(10)**-30}, 1e-60, lambdify_kw={"modules": 'scipy'}) + + expr3 = pi/cos(3*x) - pi + opt3 = optimize(expr3, [cosm1_opt]) + assert opt3 == -pi*cosm1(3*x)/cos(3*x) + if scipy: + _check_num_lambdify(expr3, opt3, {x: S(10)**-30/3}, float(5e-61*pi), lambdify_kw={"modules": 'scipy'}) + + +def test_powm1(): + args = x, y = map(Symbol, "xy") + + expr1 = x**y - 1 + opt1 = optimize(expr1, [powm1_opt]) + assert opt1 == powm1(x, y) + for arg in args: + assert expr1.diff(arg) == opt1.diff(arg) + if scipy and tuple(map(int, scipy.version.version.split('.')[:3])) >= (1, 10, 0): + subs1_a = {x: Rational(*(1.0+1e-13).as_integer_ratio()), y: pi} + ref1_f64_a = 3.139081648208105e-13 + _check_num_lambdify(expr1, opt1, subs1_a, ref1_f64_a, lambdify_kw={"modules": 'scipy'}, poorness=10**11) + + subs1_b = {x: pi, y: Rational(*(1e-10).as_integer_ratio())} + ref1_f64_b = 1.1447298859149205e-10 + _check_num_lambdify(expr1, opt1, subs1_b, ref1_f64_b, lambdify_kw={"modules": 'scipy'}, poorness=10**9) + + +def test_log1p_opt(): + x = Symbol('x') + expr1 = log(x + 1) + opt1 = optimize(expr1, [log1p_opt]) + assert log1p(x) - opt1 == 0 + assert opt1.rewrite(log) == expr1 + + expr2 = log(3*x + 3) + opt2 = optimize(expr2, [log1p_opt]) + assert log1p(x) + log(3) == opt2 + assert (opt2.rewrite(log) - expr2).simplify() == 0 + + expr3 = log(2*x + 1) + opt3 = optimize(expr3, [log1p_opt]) + assert log1p(2*x) - opt3 == 0 + assert opt3.rewrite(log) == expr3 + + expr4 = log(x+3) + opt4 = optimize(expr4, [log1p_opt]) + assert str(opt4) == 'log(x + 3)' + + +def test_optims_c99(): + x = Symbol('x') + + expr1 = 2**x + log(x)/log(2) + log(x + 1) + exp(x) - 1 + opt1 = optimize(expr1, optims_c99).simplify() + assert opt1 == exp2(x) + log2(x) + log1p(x) + expm1(x) + assert opt1.rewrite(exp).rewrite(log).rewrite(Pow) == expr1 + + expr2 = log(x)/log(2) + log(x + 1) + opt2 = optimize(expr2, optims_c99) + assert opt2 == log2(x) + log1p(x) + assert opt2.rewrite(log) == expr2 + + expr3 = log(x)/log(2) + log(17*x + 17) + opt3 = optimize(expr3, optims_c99) + delta3 = opt3 - (log2(x) + log(17) + log1p(x)) + assert delta3 == 0 + assert (opt3.rewrite(log) - expr3).simplify() == 0 + + expr4 = 2**x + 3*log(5*x + 7)/(13*log(2)) + 11*exp(x) - 11 + log(17*x + 17) + opt4 = optimize(expr4, optims_c99).simplify() + delta4 = opt4 - (exp2(x) + 3*log2(5*x + 7)/13 + 11*expm1(x) + log(17) + log1p(x)) + assert delta4 == 0 + assert (opt4.rewrite(exp).rewrite(log).rewrite(Pow) - expr4).simplify() == 0 + + expr5 = 3*exp(2*x) - 3 + opt5 = optimize(expr5, optims_c99) + delta5 = opt5 - 3*expm1(2*x) + assert delta5 == 0 + assert opt5.rewrite(exp) == expr5 + + expr6 = exp(2*x) - 3 + opt6 = optimize(expr6, optims_c99) + assert opt6 in (expm1(2*x) - 2, expr6) # expm1(2*x) - 2 is not better or worse + + expr7 = log(3*x + 3) + opt7 = optimize(expr7, optims_c99) + delta7 = opt7 - (log(3) + log1p(x)) + assert delta7 == 0 + assert (opt7.rewrite(log) - expr7).simplify() == 0 + + expr8 = log(2*x + 3) + opt8 = optimize(expr8, optims_c99) + assert opt8 == expr8 + + +def test_create_expand_pow_optimization(): + cc = lambda x: ccode( + optimize(x, [create_expand_pow_optimization(4)])) + x = Symbol('x') + assert cc(x**4) == 'x*x*x*x' + assert cc(x**4 + x**2) == 'x*x + x*x*x*x' + assert cc(x**5 + x**4) == 'pow(x, 5) + x*x*x*x' + assert cc(sin(x)**4) == 'pow(sin(x), 4)' + # gh issue 15335 + assert cc(x**(-4)) == '1.0/(x*x*x*x)' + assert cc(x**(-5)) == 'pow(x, -5)' + assert cc(-x**4) == '-(x*x*x*x)' + assert cc(x**4 - x**2) == '-(x*x) + x*x*x*x' + i = Symbol('i', integer=True) + assert cc(x**i - x**2) == 'pow(x, i) - (x*x)' + y = Symbol('y', real=True) + assert cc(Abs(exp(y**4))) == "exp(y*y*y*y)" + + # gh issue 20753 + cc2 = lambda x: ccode(optimize(x, [create_expand_pow_optimization( + 4, base_req=lambda b: b.is_Function)])) + assert cc2(x**3 + sin(x)**3) == "pow(x, 3) + sin(x)*sin(x)*sin(x)" + + +def test_matsolve(): + n = Symbol('n', integer=True) + A = MatrixSymbol('A', n, n) + x = MatrixSymbol('x', n, 1) + + with assuming(Q.fullrank(A)): + assert optimize(A**(-1) * x, [matinv_opt]) == MatrixSolve(A, x) + assert optimize(A**(-1) * x + x, [matinv_opt]) == MatrixSolve(A, x) + x + + +def test_logaddexp_opt(): + x, y = map(Symbol, 'x y'.split()) + expr1 = log(exp(x) + exp(y)) + opt1 = optimize(expr1, [logaddexp_opt]) + assert logaddexp(x, y) - opt1 == 0 + assert logaddexp(y, x) - opt1 == 0 + assert opt1.rewrite(log) == expr1 + + +def test_logaddexp2_opt(): + x, y = map(Symbol, 'x y'.split()) + expr1 = log(2**x + 2**y)/log(2) + opt1 = optimize(expr1, [logaddexp2_opt]) + assert logaddexp2(x, y) - opt1 == 0 + assert logaddexp2(y, x) - opt1 == 0 + assert opt1.rewrite(log) == expr1 + + +def test_sinc_opts(): + def check(d): + for k, v in d.items(): + assert optimize(k, sinc_opts) == v + + x = Symbol('x') + check({ + sin(x)/x : sinc(x), + sin(2*x)/(2*x) : sinc(2*x), + sin(3*x)/x : 3*sinc(3*x), + x*sin(x) : x*sin(x) + }) + + y = Symbol('y') + check({ + sin(x*y)/(x*y) : sinc(x*y), + y*sin(x/y)/x : sinc(x/y), + sin(sin(x))/sin(x) : sinc(sin(x)), + sin(3*sin(x))/sin(x) : 3*sinc(3*sin(x)), + sin(x)/y : sin(x)/y + }) + + +def test_optims_numpy(): + def check(d): + for k, v in d.items(): + assert optimize(k, optims_numpy) == v + + x = Symbol('x') + check({ + sin(2*x)/(2*x) + exp(2*x) - 1: sinc(2*x) + expm1(2*x), + log(x+3)/log(2) + log(x**2 + 1): log1p(x**2) + log2(x+3) + }) + + +@XFAIL # room for improvement, ideally this test case should pass. +def test_optims_numpy_TODO(): + def check(d): + for k, v in d.items(): + assert optimize(k, optims_numpy) == v + + x, y = map(Symbol, 'x y'.split()) + check({ + log(x*y)*sin(x*y)*log(x*y+1)/(log(2)*x*y): log2(x*y)*sinc(x*y)*log1p(x*y), + exp(x*sin(y)/y) - 1: expm1(x*sinc(y)) + }) + + +@may_xfail +def test_compiled_ccode_with_rewriting(): + if not cython: + skip("cython not installed.") + if not has_c(): + skip("No C compiler found.") + + x = Symbol('x') + about_two = 2**(58/S(117))*3**(97/S(117))*5**(4/S(39))*7**(92/S(117))/S(30)*pi + # about_two: 1.999999999999581826 + unchanged = 2*exp(x) - about_two + xval = S(10)**-11 + ref = unchanged.subs(x, xval).n(19) # 2.0418173913673213e-11 + + rewritten = optimize(2*exp(x) - about_two, [expm1_opt]) + + # Unfortunately, we need to call ``.n()`` on our expressions before we hand them + # to ``ccode``, and we need to request a large number of significant digits. + # In this test, results converged for double precision when the following number + # of significant digits were chosen: + NUMBER_OF_DIGITS = 25 # TODO: this should ideally be automatically handled. + + func_c = ''' +#include + +double func_unchanged(double x) { + return %(unchanged)s; +} +double func_rewritten(double x) { + return %(rewritten)s; +} +''' % {"unchanged": ccode(unchanged.n(NUMBER_OF_DIGITS)), + "rewritten": ccode(rewritten.n(NUMBER_OF_DIGITS))} + + func_pyx = ''' +#cython: language_level=3 +cdef extern double func_unchanged(double) +cdef extern double func_rewritten(double) +def py_unchanged(x): + return func_unchanged(x) +def py_rewritten(x): + return func_rewritten(x) +''' + with tempfile.TemporaryDirectory() as folder: + mod, info = compile_link_import_strings( + [('func.c', func_c), ('_func.pyx', func_pyx)], + build_dir=folder, compile_kwargs={"std": 'c99'} + ) + err_rewritten = abs(mod.py_rewritten(1e-11) - ref) + err_unchanged = abs(mod.py_unchanged(1e-11) - ref) + assert 1e-27 < err_rewritten < 1e-25 # highly accurate. + assert 1e-19 < err_unchanged < 1e-16 # quite poor. + + # Tolerances used above were determined as follows: + # >>> no_opt = unchanged.subs(x, xval.evalf()).evalf() + # >>> with_opt = rewritten.n(25).subs(x, 1e-11).evalf() + # >>> with_opt - ref, no_opt - ref + # (1.1536301877952077e-26, 1.6547074214222335e-18) diff --git a/lib/python3.10/site-packages/sympy/codegen/tests/test_scipy_nodes.py b/lib/python3.10/site-packages/sympy/codegen/tests/test_scipy_nodes.py new file mode 100644 index 0000000000000000000000000000000000000000..c0d1461037eec81ade0c99b18fbbf5a4517ce0b7 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/codegen/tests/test_scipy_nodes.py @@ -0,0 +1,44 @@ +from itertools import product +from sympy.core.power import Pow +from sympy.core.symbol import symbols +from sympy.functions.elementary.exponential import exp, log +from sympy.functions.elementary.trigonometric import cos +from sympy.core.numbers import pi +from sympy.codegen.scipy_nodes import cosm1, powm1 + +x, y, z = symbols('x y z') + + +def test_cosm1(): + cm1_xy = cosm1(x*y) + ref_xy = cos(x*y) - 1 + for wrt, deriv_order in product([x, y, z], range(3)): + assert ( + cm1_xy.diff(wrt, deriv_order) - + ref_xy.diff(wrt, deriv_order) + ).rewrite(cos).simplify() == 0 + + expr_minus2 = cosm1(pi) + assert expr_minus2.rewrite(cos) == -2 + assert cosm1(3.14).simplify() == cosm1(3.14) # cannot simplify with 3.14 + assert cosm1(pi/2).simplify() == -1 + assert (1/cos(x) - 1 + cosm1(x)/cos(x)).simplify() == 0 + + +def test_powm1(): + cases = { + powm1(x, y): x**y - 1, + powm1(x*y, z): (x*y)**z - 1, + powm1(x, y*z): x**(y*z)-1, + powm1(x*y*z, x*y*z): (x*y*z)**(x*y*z)-1 + } + for pm1_e, ref_e in cases.items(): + for wrt, deriv_order in product([x, y, z], range(3)): + der = pm1_e.diff(wrt, deriv_order) + ref = ref_e.diff(wrt, deriv_order) + delta = (der - ref).rewrite(Pow) + assert delta.simplify() == 0 + + eulers_constant_m1 = powm1(x, 1/log(x)) + assert eulers_constant_m1.rewrite(Pow) == exp(1) - 1 + assert eulers_constant_m1.simplify() == exp(1) - 1 diff --git a/lib/python3.10/site-packages/sympy/combinatorics/__pycache__/__init__.cpython-310.pyc b/lib/python3.10/site-packages/sympy/combinatorics/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..421cc99291f289233ae4abfc1d09613c0c2d40aa Binary files /dev/null and b/lib/python3.10/site-packages/sympy/combinatorics/__pycache__/__init__.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/combinatorics/__pycache__/fp_groups.cpython-310.pyc b/lib/python3.10/site-packages/sympy/combinatorics/__pycache__/fp_groups.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6a394046c78197e6217430533fd43d4ed37972b8 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/combinatorics/__pycache__/fp_groups.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/combinatorics/__pycache__/free_groups.cpython-310.pyc b/lib/python3.10/site-packages/sympy/combinatorics/__pycache__/free_groups.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7c7f0582f540d132b9c04b03977f677eb200c733 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/combinatorics/__pycache__/free_groups.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/combinatorics/__pycache__/galois.cpython-310.pyc b/lib/python3.10/site-packages/sympy/combinatorics/__pycache__/galois.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ffb417d0c20c1c244642fb81c4e7ab9752cb9da0 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/combinatorics/__pycache__/galois.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/combinatorics/__pycache__/generators.cpython-310.pyc b/lib/python3.10/site-packages/sympy/combinatorics/__pycache__/generators.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..61fb889e071e50e4af7c34fc0f91129eb725a38d Binary files /dev/null and b/lib/python3.10/site-packages/sympy/combinatorics/__pycache__/generators.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/combinatorics/__pycache__/graycode.cpython-310.pyc b/lib/python3.10/site-packages/sympy/combinatorics/__pycache__/graycode.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ead365f8517a218b74f2e6badafa7aa5f1391eee Binary files /dev/null and b/lib/python3.10/site-packages/sympy/combinatorics/__pycache__/graycode.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/combinatorics/__pycache__/group_constructs.cpython-310.pyc b/lib/python3.10/site-packages/sympy/combinatorics/__pycache__/group_constructs.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4c2f1db6806cf9ab5426706dae976baca99ab10a Binary files /dev/null and b/lib/python3.10/site-packages/sympy/combinatorics/__pycache__/group_constructs.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/combinatorics/__pycache__/group_numbers.cpython-310.pyc b/lib/python3.10/site-packages/sympy/combinatorics/__pycache__/group_numbers.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7980454d135bd16c7cff131c7b03e4e3078dcc1f Binary files /dev/null and b/lib/python3.10/site-packages/sympy/combinatorics/__pycache__/group_numbers.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/combinatorics/__pycache__/homomorphisms.cpython-310.pyc b/lib/python3.10/site-packages/sympy/combinatorics/__pycache__/homomorphisms.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6dbd31b5dcdfad5d20533c544e98de7195710e3f Binary files /dev/null and b/lib/python3.10/site-packages/sympy/combinatorics/__pycache__/homomorphisms.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/combinatorics/__pycache__/named_groups.cpython-310.pyc b/lib/python3.10/site-packages/sympy/combinatorics/__pycache__/named_groups.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b43dcbcc21d5baf431601604cf62c0d632d8f9ec Binary files /dev/null and b/lib/python3.10/site-packages/sympy/combinatorics/__pycache__/named_groups.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/combinatorics/__pycache__/partitions.cpython-310.pyc b/lib/python3.10/site-packages/sympy/combinatorics/__pycache__/partitions.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..de356aa2f2016acff4d44f873db53791f0fcbefe Binary files /dev/null and b/lib/python3.10/site-packages/sympy/combinatorics/__pycache__/partitions.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/combinatorics/__pycache__/pc_groups.cpython-310.pyc b/lib/python3.10/site-packages/sympy/combinatorics/__pycache__/pc_groups.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6589351ae173d50327205d1a1a3ad5a99cce3346 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/combinatorics/__pycache__/pc_groups.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/combinatorics/__pycache__/permutations.cpython-310.pyc b/lib/python3.10/site-packages/sympy/combinatorics/__pycache__/permutations.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5fba800551bed0ffd6446716281cbc2b00bdaa58 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/combinatorics/__pycache__/permutations.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/combinatorics/__pycache__/polyhedron.cpython-310.pyc b/lib/python3.10/site-packages/sympy/combinatorics/__pycache__/polyhedron.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3a8de0d27fce709c987dd2b1842f00d31cd908c4 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/combinatorics/__pycache__/polyhedron.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/combinatorics/__pycache__/prufer.cpython-310.pyc b/lib/python3.10/site-packages/sympy/combinatorics/__pycache__/prufer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..552d1bb56780fe03c14c8a91ee1da4221e35338c Binary files /dev/null and b/lib/python3.10/site-packages/sympy/combinatorics/__pycache__/prufer.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/combinatorics/__pycache__/rewritingsystem.cpython-310.pyc b/lib/python3.10/site-packages/sympy/combinatorics/__pycache__/rewritingsystem.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..859b3dd02b697e556cf872715d58deaa9b315b9f Binary files /dev/null and b/lib/python3.10/site-packages/sympy/combinatorics/__pycache__/rewritingsystem.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/combinatorics/__pycache__/rewritingsystem_fsm.cpython-310.pyc b/lib/python3.10/site-packages/sympy/combinatorics/__pycache__/rewritingsystem_fsm.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bd6423c3c5a6c33eb5b060f0da05971e00fba6a9 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/combinatorics/__pycache__/rewritingsystem_fsm.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/combinatorics/__pycache__/schur_number.cpython-310.pyc b/lib/python3.10/site-packages/sympy/combinatorics/__pycache__/schur_number.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..21a74e89c2dd959f38f95b4ccf3cbfb8a952e256 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/combinatorics/__pycache__/schur_number.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/combinatorics/__pycache__/subsets.cpython-310.pyc b/lib/python3.10/site-packages/sympy/combinatorics/__pycache__/subsets.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..50e8e48043cce89397ec23da88701e3f444dabcf Binary files /dev/null and b/lib/python3.10/site-packages/sympy/combinatorics/__pycache__/subsets.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/combinatorics/__pycache__/tensor_can.cpython-310.pyc b/lib/python3.10/site-packages/sympy/combinatorics/__pycache__/tensor_can.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..02a6d001ea6ba96857afd123b41e074d2604aab1 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/combinatorics/__pycache__/tensor_can.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/combinatorics/__pycache__/testutil.cpython-310.pyc b/lib/python3.10/site-packages/sympy/combinatorics/__pycache__/testutil.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..919e3bece6d42bdf2edcedda67bdf9a5b6cbed09 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/combinatorics/__pycache__/testutil.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/combinatorics/__pycache__/util.cpython-310.pyc b/lib/python3.10/site-packages/sympy/combinatorics/__pycache__/util.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d5321a02b349ba3c13a99bdbffbbb1ef6f41dead Binary files /dev/null and b/lib/python3.10/site-packages/sympy/combinatorics/__pycache__/util.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/combinatorics/tests/__pycache__/__init__.cpython-310.pyc b/lib/python3.10/site-packages/sympy/combinatorics/tests/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8bc5d1a5eac123be45c6f97a3bebaa5a297a8150 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/combinatorics/tests/__pycache__/__init__.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/combinatorics/tests/__pycache__/test_coset_table.cpython-310.pyc b/lib/python3.10/site-packages/sympy/combinatorics/tests/__pycache__/test_coset_table.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..780af67ae86a7339692ed58d64714f5981f863a6 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/combinatorics/tests/__pycache__/test_coset_table.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/combinatorics/tests/__pycache__/test_fp_groups.cpython-310.pyc b/lib/python3.10/site-packages/sympy/combinatorics/tests/__pycache__/test_fp_groups.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..365f89a541a13333610d4a24e20ec7e51d041c5f Binary files /dev/null and b/lib/python3.10/site-packages/sympy/combinatorics/tests/__pycache__/test_fp_groups.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/combinatorics/tests/__pycache__/test_free_groups.cpython-310.pyc b/lib/python3.10/site-packages/sympy/combinatorics/tests/__pycache__/test_free_groups.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e0d48c493ce4838ce4e62c3e650bd12a7b7101e3 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/combinatorics/tests/__pycache__/test_free_groups.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/combinatorics/tests/__pycache__/test_galois.cpython-310.pyc b/lib/python3.10/site-packages/sympy/combinatorics/tests/__pycache__/test_galois.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cf0645292d24019e7293febdc4f663cdf827d91b Binary files /dev/null and b/lib/python3.10/site-packages/sympy/combinatorics/tests/__pycache__/test_galois.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/combinatorics/tests/__pycache__/test_generators.cpython-310.pyc b/lib/python3.10/site-packages/sympy/combinatorics/tests/__pycache__/test_generators.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e01740e1cf172e40ace9694968b86d61731f42e7 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/combinatorics/tests/__pycache__/test_generators.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/combinatorics/tests/__pycache__/test_graycode.cpython-310.pyc b/lib/python3.10/site-packages/sympy/combinatorics/tests/__pycache__/test_graycode.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..198fd0c7c2e3f1aecfcacae1b296a4cd564c1e81 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/combinatorics/tests/__pycache__/test_graycode.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/combinatorics/tests/__pycache__/test_group_constructs.cpython-310.pyc b/lib/python3.10/site-packages/sympy/combinatorics/tests/__pycache__/test_group_constructs.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6fb4514ead705b7f93e0cbc1c82e02ceb4e8c0d6 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/combinatorics/tests/__pycache__/test_group_constructs.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/combinatorics/tests/__pycache__/test_group_numbers.cpython-310.pyc b/lib/python3.10/site-packages/sympy/combinatorics/tests/__pycache__/test_group_numbers.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..94f54ce6916b0b9628033421a2540a342d97c9c4 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/combinatorics/tests/__pycache__/test_group_numbers.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/combinatorics/tests/__pycache__/test_homomorphisms.cpython-310.pyc b/lib/python3.10/site-packages/sympy/combinatorics/tests/__pycache__/test_homomorphisms.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a2e000080e09bc9cc3603586c39d82fbd169c4a4 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/combinatorics/tests/__pycache__/test_homomorphisms.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/combinatorics/tests/__pycache__/test_named_groups.cpython-310.pyc b/lib/python3.10/site-packages/sympy/combinatorics/tests/__pycache__/test_named_groups.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..eeba681a52cb05d5b26bfd22afb429f0bc0b62e5 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/combinatorics/tests/__pycache__/test_named_groups.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/combinatorics/tests/__pycache__/test_partitions.cpython-310.pyc b/lib/python3.10/site-packages/sympy/combinatorics/tests/__pycache__/test_partitions.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..caa8ba8f6e53f84b53f3e55afa8acdc3835e55d7 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/combinatorics/tests/__pycache__/test_partitions.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/combinatorics/tests/__pycache__/test_pc_groups.cpython-310.pyc b/lib/python3.10/site-packages/sympy/combinatorics/tests/__pycache__/test_pc_groups.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4c530c8e150961d21998a1b76f477a3be1fa4a5d Binary files /dev/null and b/lib/python3.10/site-packages/sympy/combinatorics/tests/__pycache__/test_pc_groups.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/combinatorics/tests/__pycache__/test_perm_groups.cpython-310.pyc b/lib/python3.10/site-packages/sympy/combinatorics/tests/__pycache__/test_perm_groups.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..de5eb2429c3a60c861a246b7e2cccd9dd20d8e7a Binary files /dev/null and b/lib/python3.10/site-packages/sympy/combinatorics/tests/__pycache__/test_perm_groups.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/combinatorics/tests/__pycache__/test_permutations.cpython-310.pyc b/lib/python3.10/site-packages/sympy/combinatorics/tests/__pycache__/test_permutations.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..006cbb77cabb48fb7d08a4e8f5b4d6f6941af667 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/combinatorics/tests/__pycache__/test_permutations.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/combinatorics/tests/__pycache__/test_polyhedron.cpython-310.pyc b/lib/python3.10/site-packages/sympy/combinatorics/tests/__pycache__/test_polyhedron.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..68d49233cbc9be7ac1ac390fd262485da3d93883 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/combinatorics/tests/__pycache__/test_polyhedron.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/combinatorics/tests/__pycache__/test_prufer.cpython-310.pyc b/lib/python3.10/site-packages/sympy/combinatorics/tests/__pycache__/test_prufer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7e5ef9a5aa87230c52a1e698c72f5a727420101c Binary files /dev/null and b/lib/python3.10/site-packages/sympy/combinatorics/tests/__pycache__/test_prufer.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/combinatorics/tests/__pycache__/test_rewriting.cpython-310.pyc b/lib/python3.10/site-packages/sympy/combinatorics/tests/__pycache__/test_rewriting.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f4e6974eb8225a1edbc3b96a8c113632822918ec Binary files /dev/null and b/lib/python3.10/site-packages/sympy/combinatorics/tests/__pycache__/test_rewriting.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/combinatorics/tests/__pycache__/test_schur_number.cpython-310.pyc b/lib/python3.10/site-packages/sympy/combinatorics/tests/__pycache__/test_schur_number.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ab3386d644744051f4073539e58bec9f2f180765 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/combinatorics/tests/__pycache__/test_schur_number.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/combinatorics/tests/__pycache__/test_subsets.cpython-310.pyc b/lib/python3.10/site-packages/sympy/combinatorics/tests/__pycache__/test_subsets.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e282597e8d30adf7f2aa880cee0acdb3c0fe2280 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/combinatorics/tests/__pycache__/test_subsets.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/combinatorics/tests/__pycache__/test_tensor_can.cpython-310.pyc b/lib/python3.10/site-packages/sympy/combinatorics/tests/__pycache__/test_tensor_can.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..56e4d6cf97b17ae995dee1d81a7968531cbb299f Binary files /dev/null and b/lib/python3.10/site-packages/sympy/combinatorics/tests/__pycache__/test_tensor_can.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/combinatorics/tests/__pycache__/test_testutil.cpython-310.pyc b/lib/python3.10/site-packages/sympy/combinatorics/tests/__pycache__/test_testutil.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..151f799143d80b12c88c44f1674b65b9e2ec1225 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/combinatorics/tests/__pycache__/test_testutil.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/combinatorics/tests/__pycache__/test_util.cpython-310.pyc b/lib/python3.10/site-packages/sympy/combinatorics/tests/__pycache__/test_util.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7a375d5ad4131ac14174b3141d655fe66a7be596 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/combinatorics/tests/__pycache__/test_util.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/physics/__pycache__/hydrogen.cpython-310.pyc b/lib/python3.10/site-packages/sympy/physics/__pycache__/hydrogen.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dbc42c5ad21c41f7da2a5b9276c2b9762a64ddd4 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/physics/__pycache__/hydrogen.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/physics/__pycache__/matrices.cpython-310.pyc b/lib/python3.10/site-packages/sympy/physics/__pycache__/matrices.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..edefe62b8b0d8a9df1da3a7ea0222981f403cfdc Binary files /dev/null and b/lib/python3.10/site-packages/sympy/physics/__pycache__/matrices.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/physics/__pycache__/paulialgebra.cpython-310.pyc b/lib/python3.10/site-packages/sympy/physics/__pycache__/paulialgebra.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4bb71ae4ddee3013392efc77177d3e92f814bc2e Binary files /dev/null and b/lib/python3.10/site-packages/sympy/physics/__pycache__/paulialgebra.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/physics/__pycache__/pring.cpython-310.pyc b/lib/python3.10/site-packages/sympy/physics/__pycache__/pring.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6826af40db603ce5e53fa7d8f6e0fe13f6f9b34e Binary files /dev/null and b/lib/python3.10/site-packages/sympy/physics/__pycache__/pring.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/physics/__pycache__/qho_1d.cpython-310.pyc b/lib/python3.10/site-packages/sympy/physics/__pycache__/qho_1d.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6cd36b3db5362d35e2d2749043dff8dfcedfbc6d Binary files /dev/null and b/lib/python3.10/site-packages/sympy/physics/__pycache__/qho_1d.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/physics/__pycache__/secondquant.cpython-310.pyc b/lib/python3.10/site-packages/sympy/physics/__pycache__/secondquant.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..38e771458dde9a965a629770e7275975e3cee268 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/physics/__pycache__/secondquant.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/physics/__pycache__/sho.cpython-310.pyc b/lib/python3.10/site-packages/sympy/physics/__pycache__/sho.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..548371ae4a9ed0ec9a3c683436e01ed94ce4caaf Binary files /dev/null and b/lib/python3.10/site-packages/sympy/physics/__pycache__/sho.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/physics/__pycache__/wigner.cpython-310.pyc b/lib/python3.10/site-packages/sympy/physics/__pycache__/wigner.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6c66356113fdb628b280250e92ff2603d690ad1f Binary files /dev/null and b/lib/python3.10/site-packages/sympy/physics/__pycache__/wigner.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/physics/biomechanics/__init__.py b/lib/python3.10/site-packages/sympy/physics/biomechanics/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3e0f687cc23c1862b65e55117841cfd7d2b8e3f0 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/physics/biomechanics/__init__.py @@ -0,0 +1,53 @@ +"""Biomechanics extension for SymPy. + +Includes biomechanics-related constructs which allows users to extend multibody +models created using `sympy.physics.mechanics` into biomechanical or +musculoskeletal models involding musculotendons and activation dynamics. + +""" + +from .activation import ( + ActivationBase, + FirstOrderActivationDeGroote2016, + ZerothOrderActivation, +) +from .curve import ( + CharacteristicCurveCollection, + CharacteristicCurveFunction, + FiberForceLengthActiveDeGroote2016, + FiberForceLengthPassiveDeGroote2016, + FiberForceLengthPassiveInverseDeGroote2016, + FiberForceVelocityDeGroote2016, + FiberForceVelocityInverseDeGroote2016, + TendonForceLengthDeGroote2016, + TendonForceLengthInverseDeGroote2016, +) +from .musculotendon import ( + MusculotendonBase, + MusculotendonDeGroote2016, + MusculotendonFormulation, +) + + +__all__ = [ + # Musculotendon characteristic curve functions + 'CharacteristicCurveCollection', + 'CharacteristicCurveFunction', + 'FiberForceLengthActiveDeGroote2016', + 'FiberForceLengthPassiveDeGroote2016', + 'FiberForceLengthPassiveInverseDeGroote2016', + 'FiberForceVelocityDeGroote2016', + 'FiberForceVelocityInverseDeGroote2016', + 'TendonForceLengthDeGroote2016', + 'TendonForceLengthInverseDeGroote2016', + + # Activation dynamics classes + 'ActivationBase', + 'FirstOrderActivationDeGroote2016', + 'ZerothOrderActivation', + + # Musculotendon classes + 'MusculotendonBase', + 'MusculotendonDeGroote2016', + 'MusculotendonFormulation', +] diff --git a/lib/python3.10/site-packages/sympy/physics/biomechanics/_mixin.py b/lib/python3.10/site-packages/sympy/physics/biomechanics/_mixin.py new file mode 100644 index 0000000000000000000000000000000000000000..f6ff905100fb4d6f346aaf717cfe9a66b4c2cc9a --- /dev/null +++ b/lib/python3.10/site-packages/sympy/physics/biomechanics/_mixin.py @@ -0,0 +1,53 @@ +"""Mixin classes for sharing functionality between unrelated classes. + +This module is named with a leading underscore to signify to users that it's +"private" and only intended for internal use by the biomechanics module. + +""" + + +__all__ = ['_NamedMixin'] + + +class _NamedMixin: + """Mixin class for adding `name` properties. + + Valid names, as will typically be used by subclasses as a suffix when + naming automatically-instantiated symbol attributes, must be nonzero length + strings. + + Attributes + ========== + + name : str + The name identifier associated with the instance. Must be a string of + length at least 1. + + """ + + @property + def name(self) -> str: + """The name associated with the class instance.""" + return self._name + + @name.setter + def name(self, name: str) -> None: + if hasattr(self, '_name'): + msg = ( + f'Can\'t set attribute `name` to {repr(name)} as it is ' + f'immutable.' + ) + raise AttributeError(msg) + if not isinstance(name, str): + msg = ( + f'Name {repr(name)} passed to `name` was of type ' + f'{type(name)}, must be {str}.' + ) + raise TypeError(msg) + if name in {''}: + msg = ( + f'Name {repr(name)} is invalid, must be a nonzero length ' + f'{type(str)}.' + ) + raise ValueError(msg) + self._name = name diff --git a/lib/python3.10/site-packages/sympy/physics/biomechanics/activation.py b/lib/python3.10/site-packages/sympy/physics/biomechanics/activation.py new file mode 100644 index 0000000000000000000000000000000000000000..36005cc532144a48b0c2732eba5679a23e83b3c4 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/physics/biomechanics/activation.py @@ -0,0 +1,869 @@ +r"""Activation dynamics for musclotendon models. + +Musculotendon models are able to produce active force when they are activated, +which is when a chemical process has taken place within the muscle fibers +causing them to voluntarily contract. Biologically this chemical process (the +diffusion of :math:`\textrm{Ca}^{2+}` ions) is not the input in the system, +electrical signals from the nervous system are. These are termed excitations. +Activation dynamics, which relates the normalized excitation level to the +normalized activation level, can be modeled by the models present in this +module. + +""" + +from abc import ABC, abstractmethod +from functools import cached_property + +from sympy.core.symbol import Symbol +from sympy.core.numbers import Float, Integer, Rational +from sympy.functions.elementary.hyperbolic import tanh +from sympy.matrices.dense import MutableDenseMatrix as Matrix, zeros +from sympy.physics.biomechanics._mixin import _NamedMixin +from sympy.physics.mechanics import dynamicsymbols + + +__all__ = [ + 'ActivationBase', + 'FirstOrderActivationDeGroote2016', + 'ZerothOrderActivation', +] + + +class ActivationBase(ABC, _NamedMixin): + """Abstract base class for all activation dynamics classes to inherit from. + + Notes + ===== + + Instances of this class cannot be directly instantiated by users. However, + it can be used to created custom activation dynamics types through + subclassing. + + """ + + def __init__(self, name): + """Initializer for ``ActivationBase``.""" + self.name = str(name) + + # Symbols + self._e = dynamicsymbols(f"e_{name}") + self._a = dynamicsymbols(f"a_{name}") + + @classmethod + @abstractmethod + def with_defaults(cls, name): + """Alternate constructor that provides recommended defaults for + constants.""" + pass + + @property + def excitation(self): + """Dynamic symbol representing excitation. + + Explanation + =========== + + The alias ``e`` can also be used to access the same attribute. + + """ + return self._e + + @property + def e(self): + """Dynamic symbol representing excitation. + + Explanation + =========== + + The alias ``excitation`` can also be used to access the same attribute. + + """ + return self._e + + @property + def activation(self): + """Dynamic symbol representing activation. + + Explanation + =========== + + The alias ``a`` can also be used to access the same attribute. + + """ + return self._a + + @property + def a(self): + """Dynamic symbol representing activation. + + Explanation + =========== + + The alias ``activation`` can also be used to access the same attribute. + + """ + return self._a + + @property + @abstractmethod + def order(self): + """Order of the (differential) equation governing activation.""" + pass + + @property + @abstractmethod + def state_vars(self): + """Ordered column matrix of functions of time that represent the state + variables. + + Explanation + =========== + + The alias ``x`` can also be used to access the same attribute. + + """ + pass + + @property + @abstractmethod + def x(self): + """Ordered column matrix of functions of time that represent the state + variables. + + Explanation + =========== + + The alias ``state_vars`` can also be used to access the same attribute. + + """ + pass + + @property + @abstractmethod + def input_vars(self): + """Ordered column matrix of functions of time that represent the input + variables. + + Explanation + =========== + + The alias ``r`` can also be used to access the same attribute. + + """ + pass + + @property + @abstractmethod + def r(self): + """Ordered column matrix of functions of time that represent the input + variables. + + Explanation + =========== + + The alias ``input_vars`` can also be used to access the same attribute. + + """ + pass + + @property + @abstractmethod + def constants(self): + """Ordered column matrix of non-time varying symbols present in ``M`` + and ``F``. + + Only symbolic constants are returned. If a numeric type (e.g. ``Float``) + has been used instead of ``Symbol`` for a constant then that attribute + will not be included in the matrix returned by this property. This is + because the primary use of this property attribute is to provide an + ordered sequence of the still-free symbols that require numeric values + during code generation. + + Explanation + =========== + + The alias ``p`` can also be used to access the same attribute. + + """ + pass + + @property + @abstractmethod + def p(self): + """Ordered column matrix of non-time varying symbols present in ``M`` + and ``F``. + + Only symbolic constants are returned. If a numeric type (e.g. ``Float``) + has been used instead of ``Symbol`` for a constant then that attribute + will not be included in the matrix returned by this property. This is + because the primary use of this property attribute is to provide an + ordered sequence of the still-free symbols that require numeric values + during code generation. + + Explanation + =========== + + The alias ``constants`` can also be used to access the same attribute. + + """ + pass + + @property + @abstractmethod + def M(self): + """Ordered square matrix of coefficients on the LHS of ``M x' = F``. + + Explanation + =========== + + The square matrix that forms part of the LHS of the linear system of + ordinary differential equations governing the activation dynamics: + + ``M(x, r, t, p) x' = F(x, r, t, p)``. + + """ + pass + + @property + @abstractmethod + def F(self): + """Ordered column matrix of equations on the RHS of ``M x' = F``. + + Explanation + =========== + + The column matrix that forms the RHS of the linear system of ordinary + differential equations governing the activation dynamics: + + ``M(x, r, t, p) x' = F(x, r, t, p)``. + + """ + pass + + @abstractmethod + def rhs(self): + """ + + Explanation + =========== + + The solution to the linear system of ordinary differential equations + governing the activation dynamics: + + ``M(x, r, t, p) x' = F(x, r, t, p)``. + + """ + pass + + def __eq__(self, other): + """Equality check for activation dynamics.""" + if type(self) != type(other): + return False + if self.name != other.name: + return False + return True + + def __repr__(self): + """Default representation of activation dynamics.""" + return f'{self.__class__.__name__}({self.name!r})' + + +class ZerothOrderActivation(ActivationBase): + """Simple zeroth-order activation dynamics mapping excitation to + activation. + + Explanation + =========== + + Zeroth-order activation dynamics are useful in instances where you want to + reduce the complexity of your musculotendon dynamics as they simple map + exictation to activation. As a result, no additional state equations are + introduced to your system. They also remove a potential source of delay + between the input and dynamics of your system as no (ordinary) differential + equations are involed. + + """ + + def __init__(self, name): + """Initializer for ``ZerothOrderActivation``. + + Parameters + ========== + + name : str + The name identifier associated with the instance. Must be a string + of length at least 1. + + """ + super().__init__(name) + + # Zeroth-order activation dynamics has activation equal excitation so + # overwrite the symbol for activation with the excitation symbol. + self._a = self._e + + @classmethod + def with_defaults(cls, name): + """Alternate constructor that provides recommended defaults for + constants. + + Explanation + =========== + + As this concrete class doesn't implement any constants associated with + its dynamics, this ``classmethod`` simply creates a standard instance + of ``ZerothOrderActivation``. An implementation is provided to ensure + a consistent interface between all ``ActivationBase`` concrete classes. + + """ + return cls(name) + + @property + def order(self): + """Order of the (differential) equation governing activation.""" + return 0 + + @property + def state_vars(self): + """Ordered column matrix of functions of time that represent the state + variables. + + Explanation + =========== + + As zeroth-order activation dynamics simply maps excitation to + activation, this class has no associated state variables and so this + property return an empty column ``Matrix`` with shape (0, 1). + + The alias ``x`` can also be used to access the same attribute. + + """ + return zeros(0, 1) + + @property + def x(self): + """Ordered column matrix of functions of time that represent the state + variables. + + Explanation + =========== + + As zeroth-order activation dynamics simply maps excitation to + activation, this class has no associated state variables and so this + property return an empty column ``Matrix`` with shape (0, 1). + + The alias ``state_vars`` can also be used to access the same attribute. + + """ + return zeros(0, 1) + + @property + def input_vars(self): + """Ordered column matrix of functions of time that represent the input + variables. + + Explanation + =========== + + Excitation is the only input in zeroth-order activation dynamics and so + this property returns a column ``Matrix`` with one entry, ``e``, and + shape (1, 1). + + The alias ``r`` can also be used to access the same attribute. + + """ + return Matrix([self._e]) + + @property + def r(self): + """Ordered column matrix of functions of time that represent the input + variables. + + Explanation + =========== + + Excitation is the only input in zeroth-order activation dynamics and so + this property returns a column ``Matrix`` with one entry, ``e``, and + shape (1, 1). + + The alias ``input_vars`` can also be used to access the same attribute. + + """ + return Matrix([self._e]) + + @property + def constants(self): + """Ordered column matrix of non-time varying symbols present in ``M`` + and ``F``. + + Only symbolic constants are returned. If a numeric type (e.g. ``Float``) + has been used instead of ``Symbol`` for a constant then that attribute + will not be included in the matrix returned by this property. This is + because the primary use of this property attribute is to provide an + ordered sequence of the still-free symbols that require numeric values + during code generation. + + Explanation + =========== + + As zeroth-order activation dynamics simply maps excitation to + activation, this class has no associated constants and so this property + return an empty column ``Matrix`` with shape (0, 1). + + The alias ``p`` can also be used to access the same attribute. + + """ + return zeros(0, 1) + + @property + def p(self): + """Ordered column matrix of non-time varying symbols present in ``M`` + and ``F``. + + Only symbolic constants are returned. If a numeric type (e.g. ``Float``) + has been used instead of ``Symbol`` for a constant then that attribute + will not be included in the matrix returned by this property. This is + because the primary use of this property attribute is to provide an + ordered sequence of the still-free symbols that require numeric values + during code generation. + + Explanation + =========== + + As zeroth-order activation dynamics simply maps excitation to + activation, this class has no associated constants and so this property + return an empty column ``Matrix`` with shape (0, 1). + + The alias ``constants`` can also be used to access the same attribute. + + """ + return zeros(0, 1) + + @property + def M(self): + """Ordered square matrix of coefficients on the LHS of ``M x' = F``. + + Explanation + =========== + + The square matrix that forms part of the LHS of the linear system of + ordinary differential equations governing the activation dynamics: + + ``M(x, r, t, p) x' = F(x, r, t, p)``. + + As zeroth-order activation dynamics have no state variables, this + linear system has dimension 0 and therefore ``M`` is an empty square + ``Matrix`` with shape (0, 0). + + """ + return Matrix([]) + + @property + def F(self): + """Ordered column matrix of equations on the RHS of ``M x' = F``. + + Explanation + =========== + + The column matrix that forms the RHS of the linear system of ordinary + differential equations governing the activation dynamics: + + ``M(x, r, t, p) x' = F(x, r, t, p)``. + + As zeroth-order activation dynamics have no state variables, this + linear system has dimension 0 and therefore ``F`` is an empty column + ``Matrix`` with shape (0, 1). + + """ + return zeros(0, 1) + + def rhs(self): + """Ordered column matrix of equations for the solution of ``M x' = F``. + + Explanation + =========== + + The solution to the linear system of ordinary differential equations + governing the activation dynamics: + + ``M(x, r, t, p) x' = F(x, r, t, p)``. + + As zeroth-order activation dynamics have no state variables, this + linear has dimension 0 and therefore this method returns an empty + column ``Matrix`` with shape (0, 1). + + """ + return zeros(0, 1) + + +class FirstOrderActivationDeGroote2016(ActivationBase): + r"""First-order activation dynamics based on De Groote et al., 2016 [1]_. + + Explanation + =========== + + Gives the first-order activation dynamics equation for the rate of change + of activation with respect to time as a function of excitation and + activation. + + The function is defined by the equation: + + .. math:: + + \frac{da}{dt} = \left(\frac{\frac{1}{2} + a0}{\tau_a \left(\frac{1}{2} + + \frac{3a}{2}\right)} + \frac{\left(\frac{1}{2} + + \frac{3a}{2}\right) \left(\frac{1}{2} - a0\right)}{\tau_d}\right) + \left(e - a\right) + + where + + .. math:: + + a0 = \frac{\tanh{\left(b \left(e - a\right) \right)}}{2} + + with constant values of :math:`tau_a = 0.015`, :math:`tau_d = 0.060`, and + :math:`b = 10`. + + References + ========== + + .. [1] De Groote, F., Kinney, A. L., Rao, A. V., & Fregly, B. J., Evaluation + of direct collocation optimal control problem formulations for + solving the muscle redundancy problem, Annals of biomedical + engineering, 44(10), (2016) pp. 2922-2936 + + """ + + def __init__(self, + name, + activation_time_constant=None, + deactivation_time_constant=None, + smoothing_rate=None, + ): + """Initializer for ``FirstOrderActivationDeGroote2016``. + + Parameters + ========== + activation time constant : Symbol | Number | None + The value of the activation time constant governing the delay + between excitation and activation when excitation exceeds + activation. + deactivation time constant : Symbol | Number | None + The value of the deactivation time constant governing the delay + between excitation and activation when activation exceeds + excitation. + smoothing_rate : Symbol | Number | None + The slope of the hyperbolic tangent function used to smooth between + the switching of the equations where excitation exceed activation + and where activation exceeds excitation. The recommended value to + use is ``10``, but values between ``0.1`` and ``100`` can be used. + + """ + super().__init__(name) + + # Symbols + self.activation_time_constant = activation_time_constant + self.deactivation_time_constant = deactivation_time_constant + self.smoothing_rate = smoothing_rate + + @classmethod + def with_defaults(cls, name): + r"""Alternate constructor that will use the published constants. + + Explanation + =========== + + Returns an instance of ``FirstOrderActivationDeGroote2016`` using the + three constant values specified in the original publication. + + These have the values: + + :math:`tau_a = 0.015` + :math:`tau_d = 0.060` + :math:`b = 10` + + """ + tau_a = Float('0.015') + tau_d = Float('0.060') + b = Float('10.0') + return cls(name, tau_a, tau_d, b) + + @property + def activation_time_constant(self): + """Delay constant for activation. + + Explanation + =========== + + The alias ```tau_a`` can also be used to access the same attribute. + + """ + return self._tau_a + + @activation_time_constant.setter + def activation_time_constant(self, tau_a): + if hasattr(self, '_tau_a'): + msg = ( + f'Can\'t set attribute `activation_time_constant` to ' + f'{repr(tau_a)} as it is immutable and already has value ' + f'{self._tau_a}.' + ) + raise AttributeError(msg) + self._tau_a = Symbol(f'tau_a_{self.name}') if tau_a is None else tau_a + + @property + def tau_a(self): + """Delay constant for activation. + + Explanation + =========== + + The alias ``activation_time_constant`` can also be used to access the + same attribute. + + """ + return self._tau_a + + @property + def deactivation_time_constant(self): + """Delay constant for deactivation. + + Explanation + =========== + + The alias ``tau_d`` can also be used to access the same attribute. + + """ + return self._tau_d + + @deactivation_time_constant.setter + def deactivation_time_constant(self, tau_d): + if hasattr(self, '_tau_d'): + msg = ( + f'Can\'t set attribute `deactivation_time_constant` to ' + f'{repr(tau_d)} as it is immutable and already has value ' + f'{self._tau_d}.' + ) + raise AttributeError(msg) + self._tau_d = Symbol(f'tau_d_{self.name}') if tau_d is None else tau_d + + @property + def tau_d(self): + """Delay constant for deactivation. + + Explanation + =========== + + The alias ``deactivation_time_constant`` can also be used to access the + same attribute. + + """ + return self._tau_d + + @property + def smoothing_rate(self): + """Smoothing constant for the hyperbolic tangent term. + + Explanation + =========== + + The alias ``b`` can also be used to access the same attribute. + + """ + return self._b + + @smoothing_rate.setter + def smoothing_rate(self, b): + if hasattr(self, '_b'): + msg = ( + f'Can\'t set attribute `smoothing_rate` to {b!r} as it is ' + f'immutable and already has value {self._b!r}.' + ) + raise AttributeError(msg) + self._b = Symbol(f'b_{self.name}') if b is None else b + + @property + def b(self): + """Smoothing constant for the hyperbolic tangent term. + + Explanation + =========== + + The alias ``smoothing_rate`` can also be used to access the same + attribute. + + """ + return self._b + + @property + def order(self): + """Order of the (differential) equation governing activation.""" + return 1 + + @property + def state_vars(self): + """Ordered column matrix of functions of time that represent the state + variables. + + Explanation + =========== + + The alias ``x`` can also be used to access the same attribute. + + """ + return Matrix([self._a]) + + @property + def x(self): + """Ordered column matrix of functions of time that represent the state + variables. + + Explanation + =========== + + The alias ``state_vars`` can also be used to access the same attribute. + + """ + return Matrix([self._a]) + + @property + def input_vars(self): + """Ordered column matrix of functions of time that represent the input + variables. + + Explanation + =========== + + The alias ``r`` can also be used to access the same attribute. + + """ + return Matrix([self._e]) + + @property + def r(self): + """Ordered column matrix of functions of time that represent the input + variables. + + Explanation + =========== + + The alias ``input_vars`` can also be used to access the same attribute. + + """ + return Matrix([self._e]) + + @property + def constants(self): + """Ordered column matrix of non-time varying symbols present in ``M`` + and ``F``. + + Only symbolic constants are returned. If a numeric type (e.g. ``Float``) + has been used instead of ``Symbol`` for a constant then that attribute + will not be included in the matrix returned by this property. This is + because the primary use of this property attribute is to provide an + ordered sequence of the still-free symbols that require numeric values + during code generation. + + Explanation + =========== + + The alias ``p`` can also be used to access the same attribute. + + """ + constants = [self._tau_a, self._tau_d, self._b] + symbolic_constants = [c for c in constants if not c.is_number] + return Matrix(symbolic_constants) if symbolic_constants else zeros(0, 1) + + @property + def p(self): + """Ordered column matrix of non-time varying symbols present in ``M`` + and ``F``. + + Explanation + =========== + + Only symbolic constants are returned. If a numeric type (e.g. ``Float``) + has been used instead of ``Symbol`` for a constant then that attribute + will not be included in the matrix returned by this property. This is + because the primary use of this property attribute is to provide an + ordered sequence of the still-free symbols that require numeric values + during code generation. + + The alias ``constants`` can also be used to access the same attribute. + + """ + constants = [self._tau_a, self._tau_d, self._b] + symbolic_constants = [c for c in constants if not c.is_number] + return Matrix(symbolic_constants) if symbolic_constants else zeros(0, 1) + + @property + def M(self): + """Ordered square matrix of coefficients on the LHS of ``M x' = F``. + + Explanation + =========== + + The square matrix that forms part of the LHS of the linear system of + ordinary differential equations governing the activation dynamics: + + ``M(x, r, t, p) x' = F(x, r, t, p)``. + + """ + return Matrix([Integer(1)]) + + @property + def F(self): + """Ordered column matrix of equations on the RHS of ``M x' = F``. + + Explanation + =========== + + The column matrix that forms the RHS of the linear system of ordinary + differential equations governing the activation dynamics: + + ``M(x, r, t, p) x' = F(x, r, t, p)``. + + """ + return Matrix([self._da_eqn]) + + def rhs(self): + """Ordered column matrix of equations for the solution of ``M x' = F``. + + Explanation + =========== + + The solution to the linear system of ordinary differential equations + governing the activation dynamics: + + ``M(x, r, t, p) x' = F(x, r, t, p)``. + + """ + return Matrix([self._da_eqn]) + + @cached_property + def _da_eqn(self): + HALF = Rational(1, 2) + a0 = HALF * tanh(self._b * (self._e - self._a)) + a1 = (HALF + Rational(3, 2) * self._a) + a2 = (HALF + a0) / (self._tau_a * a1) + a3 = a1 * (HALF - a0) / self._tau_d + activation_dynamics_equation = (a2 + a3) * (self._e - self._a) + return activation_dynamics_equation + + def __eq__(self, other): + """Equality check for ``FirstOrderActivationDeGroote2016``.""" + if type(self) != type(other): + return False + self_attrs = (self.name, self.tau_a, self.tau_d, self.b) + other_attrs = (other.name, other.tau_a, other.tau_d, other.b) + if self_attrs == other_attrs: + return True + return False + + def __repr__(self): + """Representation of ``FirstOrderActivationDeGroote2016``.""" + return ( + f'{self.__class__.__name__}({self.name!r}, ' + f'activation_time_constant={self.tau_a!r}, ' + f'deactivation_time_constant={self.tau_d!r}, ' + f'smoothing_rate={self.b!r})' + ) diff --git a/lib/python3.10/site-packages/sympy/physics/biomechanics/curve.py b/lib/python3.10/site-packages/sympy/physics/biomechanics/curve.py new file mode 100644 index 0000000000000000000000000000000000000000..6474dc1517cc34876da833cac524e8b148ab90cc --- /dev/null +++ b/lib/python3.10/site-packages/sympy/physics/biomechanics/curve.py @@ -0,0 +1,1763 @@ +"""Implementations of characteristic curves for musculotendon models.""" + +from dataclasses import dataclass + +from sympy.core.expr import UnevaluatedExpr +from sympy.core.function import ArgumentIndexError, Function +from sympy.core.numbers import Float, Integer +from sympy.functions.elementary.exponential import exp, log +from sympy.functions.elementary.hyperbolic import cosh, sinh +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.printing.precedence import PRECEDENCE + + +__all__ = [ + 'CharacteristicCurveCollection', + 'CharacteristicCurveFunction', + 'FiberForceLengthActiveDeGroote2016', + 'FiberForceLengthPassiveDeGroote2016', + 'FiberForceLengthPassiveInverseDeGroote2016', + 'FiberForceVelocityDeGroote2016', + 'FiberForceVelocityInverseDeGroote2016', + 'TendonForceLengthDeGroote2016', + 'TendonForceLengthInverseDeGroote2016', +] + + +class CharacteristicCurveFunction(Function): + """Base class for all musculotendon characteristic curve functions.""" + + @classmethod + def eval(cls): + msg = ( + f'Cannot directly instantiate {cls.__name__!r}, instances of ' + f'characteristic curves must be of a concrete subclass.' + + ) + raise TypeError(msg) + + def _print_code(self, printer): + """Print code for the function defining the curve using a printer. + + Explanation + =========== + + The order of operations may need to be controlled as constant folding + the numeric terms within the equations of a musculotendon + characteristic curve can sometimes results in a numerically-unstable + expression. + + Parameters + ========== + + printer : Printer + The printer to be used to print a string representation of the + characteristic curve as valid code in the target language. + + """ + return printer._print(printer.parenthesize( + self.doit(deep=False, evaluate=False), PRECEDENCE['Atom'], + )) + + _ccode = _print_code + _cupycode = _print_code + _cxxcode = _print_code + _fcode = _print_code + _jaxcode = _print_code + _lambdacode = _print_code + _mpmathcode = _print_code + _octave = _print_code + _pythoncode = _print_code + _numpycode = _print_code + _scipycode = _print_code + + +class TendonForceLengthDeGroote2016(CharacteristicCurveFunction): + r"""Tendon force-length curve based on De Groote et al., 2016 [1]_. + + Explanation + =========== + + Gives the normalized tendon force produced as a function of normalized + tendon length. + + The function is defined by the equation: + + $fl^T = c_0 \exp{c_3 \left( \tilde{l}^T - c_1 \right)} - c_2$ + + with constant values of $c_0 = 0.2$, $c_1 = 0.995$, $c_2 = 0.25$, and + $c_3 = 33.93669377311689$. + + While it is possible to change the constant values, these were carefully + selected in the original publication to give the characteristic curve + specific and required properties. For example, the function produces no + force when the tendon is in an unstrained state. It also produces a force + of 1 normalized unit when the tendon is under a 5% strain. + + Examples + ======== + + The preferred way to instantiate :class:`TendonForceLengthDeGroote2016` is using + the :meth:`~.with_defaults` constructor because this will automatically + populate the constants within the characteristic curve equation with the + floating point values from the original publication. This constructor takes + a single argument corresponding to normalized tendon length. We'll create a + :class:`~.Symbol` called ``l_T_tilde`` to represent this. + + >>> from sympy import Symbol + >>> from sympy.physics.biomechanics import TendonForceLengthDeGroote2016 + >>> l_T_tilde = Symbol('l_T_tilde') + >>> fl_T = TendonForceLengthDeGroote2016.with_defaults(l_T_tilde) + >>> fl_T + TendonForceLengthDeGroote2016(l_T_tilde, 0.2, 0.995, 0.25, + 33.93669377311689) + + It's also possible to populate the four constants with your own values too. + + >>> from sympy import symbols + >>> c0, c1, c2, c3 = symbols('c0 c1 c2 c3') + >>> fl_T = TendonForceLengthDeGroote2016(l_T_tilde, c0, c1, c2, c3) + >>> fl_T + TendonForceLengthDeGroote2016(l_T_tilde, c0, c1, c2, c3) + + You don't just have to use symbols as the arguments, it's also possible to + use expressions. Let's create a new pair of symbols, ``l_T`` and + ``l_T_slack``, representing tendon length and tendon slack length + respectively. We can then represent ``l_T_tilde`` as an expression, the + ratio of these. + + >>> l_T, l_T_slack = symbols('l_T l_T_slack') + >>> l_T_tilde = l_T/l_T_slack + >>> fl_T = TendonForceLengthDeGroote2016.with_defaults(l_T_tilde) + >>> fl_T + TendonForceLengthDeGroote2016(l_T/l_T_slack, 0.2, 0.995, 0.25, + 33.93669377311689) + + To inspect the actual symbolic expression that this function represents, + we can call the :meth:`~.doit` method on an instance. We'll use the keyword + argument ``evaluate=False`` as this will keep the expression in its + canonical form and won't simplify any constants. + + >>> fl_T.doit(evaluate=False) + -0.25 + 0.2*exp(33.93669377311689*(l_T/l_T_slack - 0.995)) + + The function can also be differentiated. We'll differentiate with respect + to l_T using the ``diff`` method on an instance with the single positional + argument ``l_T``. + + >>> fl_T.diff(l_T) + 6.787338754623378*exp(33.93669377311689*(l_T/l_T_slack - 0.995))/l_T_slack + + References + ========== + + .. [1] De Groote, F., Kinney, A. L., Rao, A. V., & Fregly, B. J., Evaluation + of direct collocation optimal control problem formulations for + solving the muscle redundancy problem, Annals of biomedical + engineering, 44(10), (2016) pp. 2922-2936 + + """ + + @classmethod + def with_defaults(cls, l_T_tilde): + r"""Recommended constructor that will use the published constants. + + Explanation + =========== + + Returns a new instance of the tendon force-length function using the + four constant values specified in the original publication. + + These have the values: + + $c_0 = 0.2$ + $c_1 = 0.995$ + $c_2 = 0.25$ + $c_3 = 33.93669377311689$ + + Parameters + ========== + + l_T_tilde : Any (sympifiable) + Normalized tendon length. + + """ + c0 = Float('0.2') + c1 = Float('0.995') + c2 = Float('0.25') + c3 = Float('33.93669377311689') + return cls(l_T_tilde, c0, c1, c2, c3) + + @classmethod + def eval(cls, l_T_tilde, c0, c1, c2, c3): + """Evaluation of basic inputs. + + Parameters + ========== + + l_T_tilde : Any (sympifiable) + Normalized tendon length. + c0 : Any (sympifiable) + The first constant in the characteristic equation. The published + value is ``0.2``. + c1 : Any (sympifiable) + The second constant in the characteristic equation. The published + value is ``0.995``. + c2 : Any (sympifiable) + The third constant in the characteristic equation. The published + value is ``0.25``. + c3 : Any (sympifiable) + The fourth constant in the characteristic equation. The published + value is ``33.93669377311689``. + + """ + pass + + def _eval_evalf(self, prec): + """Evaluate the expression numerically using ``evalf``.""" + return self.doit(deep=False, evaluate=False)._eval_evalf(prec) + + def doit(self, deep=True, evaluate=True, **hints): + """Evaluate the expression defining the function. + + Parameters + ========== + + deep : bool + Whether ``doit`` should be recursively called. Default is ``True``. + evaluate : bool. + Whether the SymPy expression should be evaluated as it is + constructed. If ``False``, then no constant folding will be + conducted which will leave the expression in a more numerically- + stable for values of ``l_T_tilde`` that correspond to a sensible + operating range for a musculotendon. Default is ``True``. + **kwargs : dict[str, Any] + Additional keyword argument pairs to be recursively passed to + ``doit``. + + """ + l_T_tilde, *constants = self.args + if deep: + hints['evaluate'] = evaluate + l_T_tilde = l_T_tilde.doit(deep=deep, **hints) + c0, c1, c2, c3 = [c.doit(deep=deep, **hints) for c in constants] + else: + c0, c1, c2, c3 = constants + + if evaluate: + return c0*exp(c3*(l_T_tilde - c1)) - c2 + + return c0*exp(c3*UnevaluatedExpr(l_T_tilde - c1)) - c2 + + def fdiff(self, argindex=1): + """Derivative of the function with respect to a single argument. + + Parameters + ========== + + argindex : int + The index of the function's arguments with respect to which the + derivative should be taken. Argument indexes start at ``1``. + Default is ``1``. + + """ + l_T_tilde, c0, c1, c2, c3 = self.args + if argindex == 1: + return c0*c3*exp(c3*UnevaluatedExpr(l_T_tilde - c1)) + elif argindex == 2: + return exp(c3*UnevaluatedExpr(l_T_tilde - c1)) + elif argindex == 3: + return -c0*c3*exp(c3*UnevaluatedExpr(l_T_tilde - c1)) + elif argindex == 4: + return Integer(-1) + elif argindex == 5: + return c0*(l_T_tilde - c1)*exp(c3*UnevaluatedExpr(l_T_tilde - c1)) + + raise ArgumentIndexError(self, argindex) + + def inverse(self, argindex=1): + """Inverse function. + + Parameters + ========== + + argindex : int + Value to start indexing the arguments at. Default is ``1``. + + """ + return TendonForceLengthInverseDeGroote2016 + + def _latex(self, printer): + """Print a LaTeX representation of the function defining the curve. + + Parameters + ========== + + printer : Printer + The printer to be used to print the LaTeX string representation. + + """ + l_T_tilde = self.args[0] + _l_T_tilde = printer._print(l_T_tilde) + return r'\operatorname{fl}^T \left( %s \right)' % _l_T_tilde + + +class TendonForceLengthInverseDeGroote2016(CharacteristicCurveFunction): + r"""Inverse tendon force-length curve based on De Groote et al., 2016 [1]_. + + Explanation + =========== + + Gives the normalized tendon length that produces a specific normalized + tendon force. + + The function is defined by the equation: + + ${fl^T}^{-1} = frac{\log{\frac{fl^T + c_2}{c_0}}}{c_3} + c_1$ + + with constant values of $c_0 = 0.2$, $c_1 = 0.995$, $c_2 = 0.25$, and + $c_3 = 33.93669377311689$. This function is the exact analytical inverse + of the related tendon force-length curve ``TendonForceLengthDeGroote2016``. + + While it is possible to change the constant values, these were carefully + selected in the original publication to give the characteristic curve + specific and required properties. For example, the function produces no + force when the tendon is in an unstrained state. It also produces a force + of 1 normalized unit when the tendon is under a 5% strain. + + Examples + ======== + + The preferred way to instantiate :class:`TendonForceLengthInverseDeGroote2016` is + using the :meth:`~.with_defaults` constructor because this will automatically + populate the constants within the characteristic curve equation with the + floating point values from the original publication. This constructor takes + a single argument corresponding to normalized tendon force-length, which is + equal to the tendon force. We'll create a :class:`~.Symbol` called ``fl_T`` to + represent this. + + >>> from sympy import Symbol + >>> from sympy.physics.biomechanics import TendonForceLengthInverseDeGroote2016 + >>> fl_T = Symbol('fl_T') + >>> l_T_tilde = TendonForceLengthInverseDeGroote2016.with_defaults(fl_T) + >>> l_T_tilde + TendonForceLengthInverseDeGroote2016(fl_T, 0.2, 0.995, 0.25, + 33.93669377311689) + + It's also possible to populate the four constants with your own values too. + + >>> from sympy import symbols + >>> c0, c1, c2, c3 = symbols('c0 c1 c2 c3') + >>> l_T_tilde = TendonForceLengthInverseDeGroote2016(fl_T, c0, c1, c2, c3) + >>> l_T_tilde + TendonForceLengthInverseDeGroote2016(fl_T, c0, c1, c2, c3) + + To inspect the actual symbolic expression that this function represents, + we can call the :meth:`~.doit` method on an instance. We'll use the keyword + argument ``evaluate=False`` as this will keep the expression in its + canonical form and won't simplify any constants. + + >>> l_T_tilde.doit(evaluate=False) + c1 + log((c2 + fl_T)/c0)/c3 + + The function can also be differentiated. We'll differentiate with respect + to l_T using the ``diff`` method on an instance with the single positional + argument ``l_T``. + + >>> l_T_tilde.diff(fl_T) + 1/(c3*(c2 + fl_T)) + + References + ========== + + .. [1] De Groote, F., Kinney, A. L., Rao, A. V., & Fregly, B. J., Evaluation + of direct collocation optimal control problem formulations for + solving the muscle redundancy problem, Annals of biomedical + engineering, 44(10), (2016) pp. 2922-2936 + + """ + + @classmethod + def with_defaults(cls, fl_T): + r"""Recommended constructor that will use the published constants. + + Explanation + =========== + + Returns a new instance of the inverse tendon force-length function + using the four constant values specified in the original publication. + + These have the values: + + $c_0 = 0.2$ + $c_1 = 0.995$ + $c_2 = 0.25$ + $c_3 = 33.93669377311689$ + + Parameters + ========== + + fl_T : Any (sympifiable) + Normalized tendon force as a function of tendon length. + + """ + c0 = Float('0.2') + c1 = Float('0.995') + c2 = Float('0.25') + c3 = Float('33.93669377311689') + return cls(fl_T, c0, c1, c2, c3) + + @classmethod + def eval(cls, fl_T, c0, c1, c2, c3): + """Evaluation of basic inputs. + + Parameters + ========== + + fl_T : Any (sympifiable) + Normalized tendon force as a function of tendon length. + c0 : Any (sympifiable) + The first constant in the characteristic equation. The published + value is ``0.2``. + c1 : Any (sympifiable) + The second constant in the characteristic equation. The published + value is ``0.995``. + c2 : Any (sympifiable) + The third constant in the characteristic equation. The published + value is ``0.25``. + c3 : Any (sympifiable) + The fourth constant in the characteristic equation. The published + value is ``33.93669377311689``. + + """ + pass + + def _eval_evalf(self, prec): + """Evaluate the expression numerically using ``evalf``.""" + return self.doit(deep=False, evaluate=False)._eval_evalf(prec) + + def doit(self, deep=True, evaluate=True, **hints): + """Evaluate the expression defining the function. + + Parameters + ========== + + deep : bool + Whether ``doit`` should be recursively called. Default is ``True``. + evaluate : bool. + Whether the SymPy expression should be evaluated as it is + constructed. If ``False``, then no constant folding will be + conducted which will leave the expression in a more numerically- + stable for values of ``l_T_tilde`` that correspond to a sensible + operating range for a musculotendon. Default is ``True``. + **kwargs : dict[str, Any] + Additional keyword argument pairs to be recursively passed to + ``doit``. + + """ + fl_T, *constants = self.args + if deep: + hints['evaluate'] = evaluate + fl_T = fl_T.doit(deep=deep, **hints) + c0, c1, c2, c3 = [c.doit(deep=deep, **hints) for c in constants] + else: + c0, c1, c2, c3 = constants + + if evaluate: + return log((fl_T + c2)/c0)/c3 + c1 + + return log(UnevaluatedExpr((fl_T + c2)/c0))/c3 + c1 + + def fdiff(self, argindex=1): + """Derivative of the function with respect to a single argument. + + Parameters + ========== + + argindex : int + The index of the function's arguments with respect to which the + derivative should be taken. Argument indexes start at ``1``. + Default is ``1``. + + """ + fl_T, c0, c1, c2, c3 = self.args + if argindex == 1: + return 1/(c3*(fl_T + c2)) + elif argindex == 2: + return -1/(c0*c3) + elif argindex == 3: + return Integer(1) + elif argindex == 4: + return 1/(c3*(fl_T + c2)) + elif argindex == 5: + return -log(UnevaluatedExpr((fl_T + c2)/c0))/c3**2 + + raise ArgumentIndexError(self, argindex) + + def inverse(self, argindex=1): + """Inverse function. + + Parameters + ========== + + argindex : int + Value to start indexing the arguments at. Default is ``1``. + + """ + return TendonForceLengthDeGroote2016 + + def _latex(self, printer): + """Print a LaTeX representation of the function defining the curve. + + Parameters + ========== + + printer : Printer + The printer to be used to print the LaTeX string representation. + + """ + fl_T = self.args[0] + _fl_T = printer._print(fl_T) + return r'\left( \operatorname{fl}^T \right)^{-1} \left( %s \right)' % _fl_T + + +class FiberForceLengthPassiveDeGroote2016(CharacteristicCurveFunction): + r"""Passive muscle fiber force-length curve based on De Groote et al., 2016 + [1]_. + + Explanation + =========== + + The function is defined by the equation: + + $fl^M_{pas} = \frac{\frac{\exp{c_1 \left(\tilde{l^M} - 1\right)}}{c_0} - 1}{\exp{c_1} - 1}$ + + with constant values of $c_0 = 0.6$ and $c_1 = 4.0$. + + While it is possible to change the constant values, these were carefully + selected in the original publication to give the characteristic curve + specific and required properties. For example, the function produces a + passive fiber force very close to 0 for all normalized fiber lengths + between 0 and 1. + + Examples + ======== + + The preferred way to instantiate :class:`FiberForceLengthPassiveDeGroote2016` is + using the :meth:`~.with_defaults` constructor because this will automatically + populate the constants within the characteristic curve equation with the + floating point values from the original publication. This constructor takes + a single argument corresponding to normalized muscle fiber length. We'll + create a :class:`~.Symbol` called ``l_M_tilde`` to represent this. + + >>> from sympy import Symbol + >>> from sympy.physics.biomechanics import FiberForceLengthPassiveDeGroote2016 + >>> l_M_tilde = Symbol('l_M_tilde') + >>> fl_M = FiberForceLengthPassiveDeGroote2016.with_defaults(l_M_tilde) + >>> fl_M + FiberForceLengthPassiveDeGroote2016(l_M_tilde, 0.6, 4.0) + + It's also possible to populate the two constants with your own values too. + + >>> from sympy import symbols + >>> c0, c1 = symbols('c0 c1') + >>> fl_M = FiberForceLengthPassiveDeGroote2016(l_M_tilde, c0, c1) + >>> fl_M + FiberForceLengthPassiveDeGroote2016(l_M_tilde, c0, c1) + + You don't just have to use symbols as the arguments, it's also possible to + use expressions. Let's create a new pair of symbols, ``l_M`` and + ``l_M_opt``, representing muscle fiber length and optimal muscle fiber + length respectively. We can then represent ``l_M_tilde`` as an expression, + the ratio of these. + + >>> l_M, l_M_opt = symbols('l_M l_M_opt') + >>> l_M_tilde = l_M/l_M_opt + >>> fl_M = FiberForceLengthPassiveDeGroote2016.with_defaults(l_M_tilde) + >>> fl_M + FiberForceLengthPassiveDeGroote2016(l_M/l_M_opt, 0.6, 4.0) + + To inspect the actual symbolic expression that this function represents, + we can call the :meth:`~.doit` method on an instance. We'll use the keyword + argument ``evaluate=False`` as this will keep the expression in its + canonical form and won't simplify any constants. + + >>> fl_M.doit(evaluate=False) + 0.0186573603637741*(-1 + exp(6.66666666666667*(l_M/l_M_opt - 1))) + + The function can also be differentiated. We'll differentiate with respect + to l_M using the ``diff`` method on an instance with the single positional + argument ``l_M``. + + >>> fl_M.diff(l_M) + 0.12438240242516*exp(6.66666666666667*(l_M/l_M_opt - 1))/l_M_opt + + References + ========== + + .. [1] De Groote, F., Kinney, A. L., Rao, A. V., & Fregly, B. J., Evaluation + of direct collocation optimal control problem formulations for + solving the muscle redundancy problem, Annals of biomedical + engineering, 44(10), (2016) pp. 2922-2936 + + """ + + @classmethod + def with_defaults(cls, l_M_tilde): + r"""Recommended constructor that will use the published constants. + + Explanation + =========== + + Returns a new instance of the muscle fiber passive force-length + function using the four constant values specified in the original + publication. + + These have the values: + + $c_0 = 0.6$ + $c_1 = 4.0$ + + Parameters + ========== + + l_M_tilde : Any (sympifiable) + Normalized muscle fiber length. + + """ + c0 = Float('0.6') + c1 = Float('4.0') + return cls(l_M_tilde, c0, c1) + + @classmethod + def eval(cls, l_M_tilde, c0, c1): + """Evaluation of basic inputs. + + Parameters + ========== + + l_M_tilde : Any (sympifiable) + Normalized muscle fiber length. + c0 : Any (sympifiable) + The first constant in the characteristic equation. The published + value is ``0.6``. + c1 : Any (sympifiable) + The second constant in the characteristic equation. The published + value is ``4.0``. + + """ + pass + + def _eval_evalf(self, prec): + """Evaluate the expression numerically using ``evalf``.""" + return self.doit(deep=False, evaluate=False)._eval_evalf(prec) + + def doit(self, deep=True, evaluate=True, **hints): + """Evaluate the expression defining the function. + + Parameters + ========== + + deep : bool + Whether ``doit`` should be recursively called. Default is ``True``. + evaluate : bool. + Whether the SymPy expression should be evaluated as it is + constructed. If ``False``, then no constant folding will be + conducted which will leave the expression in a more numerically- + stable for values of ``l_T_tilde`` that correspond to a sensible + operating range for a musculotendon. Default is ``True``. + **kwargs : dict[str, Any] + Additional keyword argument pairs to be recursively passed to + ``doit``. + + """ + l_M_tilde, *constants = self.args + if deep: + hints['evaluate'] = evaluate + l_M_tilde = l_M_tilde.doit(deep=deep, **hints) + c0, c1 = [c.doit(deep=deep, **hints) for c in constants] + else: + c0, c1 = constants + + if evaluate: + return (exp((c1*(l_M_tilde - 1))/c0) - 1)/(exp(c1) - 1) + + return (exp((c1*UnevaluatedExpr(l_M_tilde - 1))/c0) - 1)/(exp(c1) - 1) + + def fdiff(self, argindex=1): + """Derivative of the function with respect to a single argument. + + Parameters + ========== + + argindex : int + The index of the function's arguments with respect to which the + derivative should be taken. Argument indexes start at ``1``. + Default is ``1``. + + """ + l_M_tilde, c0, c1 = self.args + if argindex == 1: + return c1*exp(c1*UnevaluatedExpr(l_M_tilde - 1)/c0)/(c0*(exp(c1) - 1)) + elif argindex == 2: + return ( + -c1*exp(c1*UnevaluatedExpr(l_M_tilde - 1)/c0) + *UnevaluatedExpr(l_M_tilde - 1)/(c0**2*(exp(c1) - 1)) + ) + elif argindex == 3: + return ( + -exp(c1)*(-1 + exp(c1*UnevaluatedExpr(l_M_tilde - 1)/c0))/(exp(c1) - 1)**2 + + exp(c1*UnevaluatedExpr(l_M_tilde - 1)/c0)*(l_M_tilde - 1)/(c0*(exp(c1) - 1)) + ) + + raise ArgumentIndexError(self, argindex) + + def inverse(self, argindex=1): + """Inverse function. + + Parameters + ========== + + argindex : int + Value to start indexing the arguments at. Default is ``1``. + + """ + return FiberForceLengthPassiveInverseDeGroote2016 + + def _latex(self, printer): + """Print a LaTeX representation of the function defining the curve. + + Parameters + ========== + + printer : Printer + The printer to be used to print the LaTeX string representation. + + """ + l_M_tilde = self.args[0] + _l_M_tilde = printer._print(l_M_tilde) + return r'\operatorname{fl}^M_{pas} \left( %s \right)' % _l_M_tilde + + +class FiberForceLengthPassiveInverseDeGroote2016(CharacteristicCurveFunction): + r"""Inverse passive muscle fiber force-length curve based on De Groote et + al., 2016 [1]_. + + Explanation + =========== + + Gives the normalized muscle fiber length that produces a specific normalized + passive muscle fiber force. + + The function is defined by the equation: + + ${fl^M_{pas}}^{-1} = \frac{c_0 \log{\left(\exp{c_1} - 1\right)fl^M_pas + 1}}{c_1} + 1$ + + with constant values of $c_0 = 0.6$ and $c_1 = 4.0$. This function is the + exact analytical inverse of the related tendon force-length curve + ``FiberForceLengthPassiveDeGroote2016``. + + While it is possible to change the constant values, these were carefully + selected in the original publication to give the characteristic curve + specific and required properties. For example, the function produces a + passive fiber force very close to 0 for all normalized fiber lengths + between 0 and 1. + + Examples + ======== + + The preferred way to instantiate + :class:`FiberForceLengthPassiveInverseDeGroote2016` is using the + :meth:`~.with_defaults` constructor because this will automatically populate the + constants within the characteristic curve equation with the floating point + values from the original publication. This constructor takes a single + argument corresponding to the normalized passive muscle fiber length-force + component of the muscle fiber force. We'll create a :class:`~.Symbol` called + ``fl_M_pas`` to represent this. + + >>> from sympy import Symbol + >>> from sympy.physics.biomechanics import FiberForceLengthPassiveInverseDeGroote2016 + >>> fl_M_pas = Symbol('fl_M_pas') + >>> l_M_tilde = FiberForceLengthPassiveInverseDeGroote2016.with_defaults(fl_M_pas) + >>> l_M_tilde + FiberForceLengthPassiveInverseDeGroote2016(fl_M_pas, 0.6, 4.0) + + It's also possible to populate the two constants with your own values too. + + >>> from sympy import symbols + >>> c0, c1 = symbols('c0 c1') + >>> l_M_tilde = FiberForceLengthPassiveInverseDeGroote2016(fl_M_pas, c0, c1) + >>> l_M_tilde + FiberForceLengthPassiveInverseDeGroote2016(fl_M_pas, c0, c1) + + To inspect the actual symbolic expression that this function represents, + we can call the :meth:`~.doit` method on an instance. We'll use the keyword + argument ``evaluate=False`` as this will keep the expression in its + canonical form and won't simplify any constants. + + >>> l_M_tilde.doit(evaluate=False) + c0*log(1 + fl_M_pas*(exp(c1) - 1))/c1 + 1 + + The function can also be differentiated. We'll differentiate with respect + to fl_M_pas using the ``diff`` method on an instance with the single positional + argument ``fl_M_pas``. + + >>> l_M_tilde.diff(fl_M_pas) + c0*(exp(c1) - 1)/(c1*(fl_M_pas*(exp(c1) - 1) + 1)) + + References + ========== + + .. [1] De Groote, F., Kinney, A. L., Rao, A. V., & Fregly, B. J., Evaluation + of direct collocation optimal control problem formulations for + solving the muscle redundancy problem, Annals of biomedical + engineering, 44(10), (2016) pp. 2922-2936 + + """ + + @classmethod + def with_defaults(cls, fl_M_pas): + r"""Recommended constructor that will use the published constants. + + Explanation + =========== + + Returns a new instance of the inverse muscle fiber passive force-length + function using the four constant values specified in the original + publication. + + These have the values: + + $c_0 = 0.6$ + $c_1 = 4.0$ + + Parameters + ========== + + fl_M_pas : Any (sympifiable) + Normalized passive muscle fiber force as a function of muscle fiber + length. + + """ + c0 = Float('0.6') + c1 = Float('4.0') + return cls(fl_M_pas, c0, c1) + + @classmethod + def eval(cls, fl_M_pas, c0, c1): + """Evaluation of basic inputs. + + Parameters + ========== + + fl_M_pas : Any (sympifiable) + Normalized passive muscle fiber force. + c0 : Any (sympifiable) + The first constant in the characteristic equation. The published + value is ``0.6``. + c1 : Any (sympifiable) + The second constant in the characteristic equation. The published + value is ``4.0``. + + """ + pass + + def _eval_evalf(self, prec): + """Evaluate the expression numerically using ``evalf``.""" + return self.doit(deep=False, evaluate=False)._eval_evalf(prec) + + def doit(self, deep=True, evaluate=True, **hints): + """Evaluate the expression defining the function. + + Parameters + ========== + + deep : bool + Whether ``doit`` should be recursively called. Default is ``True``. + evaluate : bool. + Whether the SymPy expression should be evaluated as it is + constructed. If ``False``, then no constant folding will be + conducted which will leave the expression in a more numerically- + stable for values of ``l_T_tilde`` that correspond to a sensible + operating range for a musculotendon. Default is ``True``. + **kwargs : dict[str, Any] + Additional keyword argument pairs to be recursively passed to + ``doit``. + + """ + fl_M_pas, *constants = self.args + if deep: + hints['evaluate'] = evaluate + fl_M_pas = fl_M_pas.doit(deep=deep, **hints) + c0, c1 = [c.doit(deep=deep, **hints) for c in constants] + else: + c0, c1 = constants + + if evaluate: + return c0*log(fl_M_pas*(exp(c1) - 1) + 1)/c1 + 1 + + return c0*log(UnevaluatedExpr(fl_M_pas*(exp(c1) - 1)) + 1)/c1 + 1 + + def fdiff(self, argindex=1): + """Derivative of the function with respect to a single argument. + + Parameters + ========== + + argindex : int + The index of the function's arguments with respect to which the + derivative should be taken. Argument indexes start at ``1``. + Default is ``1``. + + """ + fl_M_pas, c0, c1 = self.args + if argindex == 1: + return c0*(exp(c1) - 1)/(c1*(fl_M_pas*(exp(c1) - 1) + 1)) + elif argindex == 2: + return log(fl_M_pas*(exp(c1) - 1) + 1)/c1 + elif argindex == 3: + return ( + c0*fl_M_pas*exp(c1)/(c1*(fl_M_pas*(exp(c1) - 1) + 1)) + - c0*log(fl_M_pas*(exp(c1) - 1) + 1)/c1**2 + ) + + raise ArgumentIndexError(self, argindex) + + def inverse(self, argindex=1): + """Inverse function. + + Parameters + ========== + + argindex : int + Value to start indexing the arguments at. Default is ``1``. + + """ + return FiberForceLengthPassiveDeGroote2016 + + def _latex(self, printer): + """Print a LaTeX representation of the function defining the curve. + + Parameters + ========== + + printer : Printer + The printer to be used to print the LaTeX string representation. + + """ + fl_M_pas = self.args[0] + _fl_M_pas = printer._print(fl_M_pas) + return r'\left( \operatorname{fl}^M_{pas} \right)^{-1} \left( %s \right)' % _fl_M_pas + + +class FiberForceLengthActiveDeGroote2016(CharacteristicCurveFunction): + r"""Active muscle fiber force-length curve based on De Groote et al., 2016 + [1]_. + + Explanation + =========== + + The function is defined by the equation: + + $fl_{\text{act}}^M = c_0 \exp\left(-\frac{1}{2}\left(\frac{\tilde{l}^M - c_1}{c_2 + c_3 \tilde{l}^M}\right)^2\right) + + c_4 \exp\left(-\frac{1}{2}\left(\frac{\tilde{l}^M - c_5}{c_6 + c_7 \tilde{l}^M}\right)^2\right) + + c_8 \exp\left(-\frac{1}{2}\left(\frac{\tilde{l}^M - c_9}{c_{10} + c_{11} \tilde{l}^M}\right)^2\right)$ + + with constant values of $c0 = 0.814$, $c1 = 1.06$, $c2 = 0.162$, + $c3 = 0.0633$, $c4 = 0.433$, $c5 = 0.717$, $c6 = -0.0299$, $c7 = 0.2$, + $c8 = 0.1$, $c9 = 1.0$, $c10 = 0.354$, and $c11 = 0.0$. + + While it is possible to change the constant values, these were carefully + selected in the original publication to give the characteristic curve + specific and required properties. For example, the function produces a + active fiber force of 1 at a normalized fiber length of 1, and an active + fiber force of 0 at normalized fiber lengths of 0 and 2. + + Examples + ======== + + The preferred way to instantiate :class:`FiberForceLengthActiveDeGroote2016` is + using the :meth:`~.with_defaults` constructor because this will automatically + populate the constants within the characteristic curve equation with the + floating point values from the original publication. This constructor takes + a single argument corresponding to normalized muscle fiber length. We'll + create a :class:`~.Symbol` called ``l_M_tilde`` to represent this. + + >>> from sympy import Symbol + >>> from sympy.physics.biomechanics import FiberForceLengthActiveDeGroote2016 + >>> l_M_tilde = Symbol('l_M_tilde') + >>> fl_M = FiberForceLengthActiveDeGroote2016.with_defaults(l_M_tilde) + >>> fl_M + FiberForceLengthActiveDeGroote2016(l_M_tilde, 0.814, 1.06, 0.162, 0.0633, + 0.433, 0.717, -0.0299, 0.2, 0.1, 1.0, 0.354, 0.0) + + It's also possible to populate the two constants with your own values too. + + >>> from sympy import symbols + >>> c0, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11 = symbols('c0:12') + >>> fl_M = FiberForceLengthActiveDeGroote2016(l_M_tilde, c0, c1, c2, c3, + ... c4, c5, c6, c7, c8, c9, c10, c11) + >>> fl_M + FiberForceLengthActiveDeGroote2016(l_M_tilde, c0, c1, c2, c3, c4, c5, c6, + c7, c8, c9, c10, c11) + + You don't just have to use symbols as the arguments, it's also possible to + use expressions. Let's create a new pair of symbols, ``l_M`` and + ``l_M_opt``, representing muscle fiber length and optimal muscle fiber + length respectively. We can then represent ``l_M_tilde`` as an expression, + the ratio of these. + + >>> l_M, l_M_opt = symbols('l_M l_M_opt') + >>> l_M_tilde = l_M/l_M_opt + >>> fl_M = FiberForceLengthActiveDeGroote2016.with_defaults(l_M_tilde) + >>> fl_M + FiberForceLengthActiveDeGroote2016(l_M/l_M_opt, 0.814, 1.06, 0.162, 0.0633, + 0.433, 0.717, -0.0299, 0.2, 0.1, 1.0, 0.354, 0.0) + + To inspect the actual symbolic expression that this function represents, + we can call the :meth:`~.doit` method on an instance. We'll use the keyword + argument ``evaluate=False`` as this will keep the expression in its + canonical form and won't simplify any constants. + + >>> fl_M.doit(evaluate=False) + 0.814*exp(-19.0519737844841*(l_M/l_M_opt + - 1.06)**2/(0.390740740740741*l_M/l_M_opt + 1)**2) + + 0.433*exp(-12.5*(l_M/l_M_opt - 0.717)**2/(l_M/l_M_opt - 0.1495)**2) + + 0.1*exp(-3.98991349867535*(l_M/l_M_opt - 1.0)**2) + + The function can also be differentiated. We'll differentiate with respect + to l_M using the ``diff`` method on an instance with the single positional + argument ``l_M``. + + >>> fl_M.diff(l_M) + ((-0.79798269973507*l_M/l_M_opt + + 0.79798269973507)*exp(-3.98991349867535*(l_M/l_M_opt - 1.0)**2) + + (10.825*(-l_M/l_M_opt + 0.717)/(l_M/l_M_opt - 0.1495)**2 + + 10.825*(l_M/l_M_opt - 0.717)**2/(l_M/l_M_opt + - 0.1495)**3)*exp(-12.5*(l_M/l_M_opt - 0.717)**2/(l_M/l_M_opt - 0.1495)**2) + + (31.0166133211401*(-l_M/l_M_opt + 1.06)/(0.390740740740741*l_M/l_M_opt + + 1)**2 + 13.6174190361677*(0.943396226415094*l_M/l_M_opt + - 1)**2/(0.390740740740741*l_M/l_M_opt + + 1)**3)*exp(-21.4067977442463*(0.943396226415094*l_M/l_M_opt + - 1)**2/(0.390740740740741*l_M/l_M_opt + 1)**2))/l_M_opt + + References + ========== + + .. [1] De Groote, F., Kinney, A. L., Rao, A. V., & Fregly, B. J., Evaluation + of direct collocation optimal control problem formulations for + solving the muscle redundancy problem, Annals of biomedical + engineering, 44(10), (2016) pp. 2922-2936 + + """ + + @classmethod + def with_defaults(cls, l_M_tilde): + r"""Recommended constructor that will use the published constants. + + Explanation + =========== + + Returns a new instance of the inverse muscle fiber act force-length + function using the four constant values specified in the original + publication. + + These have the values: + + $c0 = 0.814$ + $c1 = 1.06$ + $c2 = 0.162$ + $c3 = 0.0633$ + $c4 = 0.433$ + $c5 = 0.717$ + $c6 = -0.0299$ + $c7 = 0.2$ + $c8 = 0.1$ + $c9 = 1.0$ + $c10 = 0.354$ + $c11 = 0.0$ + + Parameters + ========== + + fl_M_act : Any (sympifiable) + Normalized passive muscle fiber force as a function of muscle fiber + length. + + """ + c0 = Float('0.814') + c1 = Float('1.06') + c2 = Float('0.162') + c3 = Float('0.0633') + c4 = Float('0.433') + c5 = Float('0.717') + c6 = Float('-0.0299') + c7 = Float('0.2') + c8 = Float('0.1') + c9 = Float('1.0') + c10 = Float('0.354') + c11 = Float('0.0') + return cls(l_M_tilde, c0, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11) + + @classmethod + def eval(cls, l_M_tilde, c0, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11): + """Evaluation of basic inputs. + + Parameters + ========== + + l_M_tilde : Any (sympifiable) + Normalized muscle fiber length. + c0 : Any (sympifiable) + The first constant in the characteristic equation. The published + value is ``0.814``. + c1 : Any (sympifiable) + The second constant in the characteristic equation. The published + value is ``1.06``. + c2 : Any (sympifiable) + The third constant in the characteristic equation. The published + value is ``0.162``. + c3 : Any (sympifiable) + The fourth constant in the characteristic equation. The published + value is ``0.0633``. + c4 : Any (sympifiable) + The fifth constant in the characteristic equation. The published + value is ``0.433``. + c5 : Any (sympifiable) + The sixth constant in the characteristic equation. The published + value is ``0.717``. + c6 : Any (sympifiable) + The seventh constant in the characteristic equation. The published + value is ``-0.0299``. + c7 : Any (sympifiable) + The eighth constant in the characteristic equation. The published + value is ``0.2``. + c8 : Any (sympifiable) + The ninth constant in the characteristic equation. The published + value is ``0.1``. + c9 : Any (sympifiable) + The tenth constant in the characteristic equation. The published + value is ``1.0``. + c10 : Any (sympifiable) + The eleventh constant in the characteristic equation. The published + value is ``0.354``. + c11 : Any (sympifiable) + The tweflth constant in the characteristic equation. The published + value is ``0.0``. + + """ + pass + + def _eval_evalf(self, prec): + """Evaluate the expression numerically using ``evalf``.""" + return self.doit(deep=False, evaluate=False)._eval_evalf(prec) + + def doit(self, deep=True, evaluate=True, **hints): + """Evaluate the expression defining the function. + + Parameters + ========== + + deep : bool + Whether ``doit`` should be recursively called. Default is ``True``. + evaluate : bool. + Whether the SymPy expression should be evaluated as it is + constructed. If ``False``, then no constant folding will be + conducted which will leave the expression in a more numerically- + stable for values of ``l_M_tilde`` that correspond to a sensible + operating range for a musculotendon. Default is ``True``. + **kwargs : dict[str, Any] + Additional keyword argument pairs to be recursively passed to + ``doit``. + + """ + l_M_tilde, *constants = self.args + if deep: + hints['evaluate'] = evaluate + l_M_tilde = l_M_tilde.doit(deep=deep, **hints) + constants = [c.doit(deep=deep, **hints) for c in constants] + c0, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11 = constants + + if evaluate: + return ( + c0*exp(-(((l_M_tilde - c1)/(c2 + c3*l_M_tilde))**2)/2) + + c4*exp(-(((l_M_tilde - c5)/(c6 + c7*l_M_tilde))**2)/2) + + c8*exp(-(((l_M_tilde - c9)/(c10 + c11*l_M_tilde))**2)/2) + ) + + return ( + c0*exp(-((UnevaluatedExpr(l_M_tilde - c1)/(c2 + c3*l_M_tilde))**2)/2) + + c4*exp(-((UnevaluatedExpr(l_M_tilde - c5)/(c6 + c7*l_M_tilde))**2)/2) + + c8*exp(-((UnevaluatedExpr(l_M_tilde - c9)/(c10 + c11*l_M_tilde))**2)/2) + ) + + def fdiff(self, argindex=1): + """Derivative of the function with respect to a single argument. + + Parameters + ========== + + argindex : int + The index of the function's arguments with respect to which the + derivative should be taken. Argument indexes start at ``1``. + Default is ``1``. + + """ + l_M_tilde, c0, c1, c2, c3, c4, c5, c6, c7, c8, c9, c10, c11 = self.args + if argindex == 1: + return ( + c0*( + c3*(l_M_tilde - c1)**2/(c2 + c3*l_M_tilde)**3 + + (c1 - l_M_tilde)/((c2 + c3*l_M_tilde)**2) + )*exp(-(l_M_tilde - c1)**2/(2*(c2 + c3*l_M_tilde)**2)) + + c4*( + c7*(l_M_tilde - c5)**2/(c6 + c7*l_M_tilde)**3 + + (c5 - l_M_tilde)/((c6 + c7*l_M_tilde)**2) + )*exp(-(l_M_tilde - c5)**2/(2*(c6 + c7*l_M_tilde)**2)) + + c8*( + c11*(l_M_tilde - c9)**2/(c10 + c11*l_M_tilde)**3 + + (c9 - l_M_tilde)/((c10 + c11*l_M_tilde)**2) + )*exp(-(l_M_tilde - c9)**2/(2*(c10 + c11*l_M_tilde)**2)) + ) + elif argindex == 2: + return exp(-(l_M_tilde - c1)**2/(2*(c2 + c3*l_M_tilde)**2)) + elif argindex == 3: + return ( + c0*(l_M_tilde - c1)/(c2 + c3*l_M_tilde)**2 + *exp(-(l_M_tilde - c1)**2 /(2*(c2 + c3*l_M_tilde)**2)) + ) + elif argindex == 4: + return ( + c0*(l_M_tilde - c1)**2/(c2 + c3*l_M_tilde)**3 + *exp(-(l_M_tilde - c1)**2/(2*(c2 + c3*l_M_tilde)**2)) + ) + elif argindex == 5: + return ( + c0*l_M_tilde*(l_M_tilde - c1)**2/(c2 + c3*l_M_tilde)**3 + *exp(-(l_M_tilde - c1)**2/(2*(c2 + c3*l_M_tilde)**2)) + ) + elif argindex == 6: + return exp(-(l_M_tilde - c5)**2/(2*(c6 + c7*l_M_tilde)**2)) + elif argindex == 7: + return ( + c4*(l_M_tilde - c5)/(c6 + c7*l_M_tilde)**2 + *exp(-(l_M_tilde - c5)**2 /(2*(c6 + c7*l_M_tilde)**2)) + ) + elif argindex == 8: + return ( + c4*(l_M_tilde - c5)**2/(c6 + c7*l_M_tilde)**3 + *exp(-(l_M_tilde - c5)**2/(2*(c6 + c7*l_M_tilde)**2)) + ) + elif argindex == 9: + return ( + c4*l_M_tilde*(l_M_tilde - c5)**2/(c6 + c7*l_M_tilde)**3 + *exp(-(l_M_tilde - c5)**2/(2*(c6 + c7*l_M_tilde)**2)) + ) + elif argindex == 10: + return exp(-(l_M_tilde - c9)**2/(2*(c10 + c11*l_M_tilde)**2)) + elif argindex == 11: + return ( + c8*(l_M_tilde - c9)/(c10 + c11*l_M_tilde)**2 + *exp(-(l_M_tilde - c9)**2 /(2*(c10 + c11*l_M_tilde)**2)) + ) + elif argindex == 12: + return ( + c8*(l_M_tilde - c9)**2/(c10 + c11*l_M_tilde)**3 + *exp(-(l_M_tilde - c9)**2/(2*(c10 + c11*l_M_tilde)**2)) + ) + elif argindex == 13: + return ( + c8*l_M_tilde*(l_M_tilde - c9)**2/(c10 + c11*l_M_tilde)**3 + *exp(-(l_M_tilde - c9)**2/(2*(c10 + c11*l_M_tilde)**2)) + ) + + raise ArgumentIndexError(self, argindex) + + def _latex(self, printer): + """Print a LaTeX representation of the function defining the curve. + + Parameters + ========== + + printer : Printer + The printer to be used to print the LaTeX string representation. + + """ + l_M_tilde = self.args[0] + _l_M_tilde = printer._print(l_M_tilde) + return r'\operatorname{fl}^M_{act} \left( %s \right)' % _l_M_tilde + + +class FiberForceVelocityDeGroote2016(CharacteristicCurveFunction): + r"""Muscle fiber force-velocity curve based on De Groote et al., 2016 [1]_. + + Explanation + =========== + + Gives the normalized muscle fiber force produced as a function of + normalized tendon velocity. + + The function is defined by the equation: + + $fv^M = c_0 \log{\left(c_1 \tilde{v}_m + c_2\right) + \sqrt{\left(c_1 \tilde{v}_m + c_2\right)^2 + 1}} + c_3$ + + with constant values of $c_0 = -0.318$, $c_1 = -8.149$, $c_2 = -0.374$, and + $c_3 = 0.886$. + + While it is possible to change the constant values, these were carefully + selected in the original publication to give the characteristic curve + specific and required properties. For example, the function produces a + normalized muscle fiber force of 1 when the muscle fibers are contracting + isometrically (they have an extension rate of 0). + + Examples + ======== + + The preferred way to instantiate :class:`FiberForceVelocityDeGroote2016` is using + the :meth:`~.with_defaults` constructor because this will automatically populate + the constants within the characteristic curve equation with the floating + point values from the original publication. This constructor takes a single + argument corresponding to normalized muscle fiber extension velocity. We'll + create a :class:`~.Symbol` called ``v_M_tilde`` to represent this. + + >>> from sympy import Symbol + >>> from sympy.physics.biomechanics import FiberForceVelocityDeGroote2016 + >>> v_M_tilde = Symbol('v_M_tilde') + >>> fv_M = FiberForceVelocityDeGroote2016.with_defaults(v_M_tilde) + >>> fv_M + FiberForceVelocityDeGroote2016(v_M_tilde, -0.318, -8.149, -0.374, 0.886) + + It's also possible to populate the four constants with your own values too. + + >>> from sympy import symbols + >>> c0, c1, c2, c3 = symbols('c0 c1 c2 c3') + >>> fv_M = FiberForceVelocityDeGroote2016(v_M_tilde, c0, c1, c2, c3) + >>> fv_M + FiberForceVelocityDeGroote2016(v_M_tilde, c0, c1, c2, c3) + + You don't just have to use symbols as the arguments, it's also possible to + use expressions. Let's create a new pair of symbols, ``v_M`` and + ``v_M_max``, representing muscle fiber extension velocity and maximum + muscle fiber extension velocity respectively. We can then represent + ``v_M_tilde`` as an expression, the ratio of these. + + >>> v_M, v_M_max = symbols('v_M v_M_max') + >>> v_M_tilde = v_M/v_M_max + >>> fv_M = FiberForceVelocityDeGroote2016.with_defaults(v_M_tilde) + >>> fv_M + FiberForceVelocityDeGroote2016(v_M/v_M_max, -0.318, -8.149, -0.374, 0.886) + + To inspect the actual symbolic expression that this function represents, + we can call the :meth:`~.doit` method on an instance. We'll use the keyword + argument ``evaluate=False`` as this will keep the expression in its + canonical form and won't simplify any constants. + + >>> fv_M.doit(evaluate=False) + 0.886 - 0.318*log(-8.149*v_M/v_M_max - 0.374 + sqrt(1 + (-8.149*v_M/v_M_max + - 0.374)**2)) + + The function can also be differentiated. We'll differentiate with respect + to v_M using the ``diff`` method on an instance with the single positional + argument ``v_M``. + + >>> fv_M.diff(v_M) + 2.591382*(1 + (-8.149*v_M/v_M_max - 0.374)**2)**(-1/2)/v_M_max + + References + ========== + + .. [1] De Groote, F., Kinney, A. L., Rao, A. V., & Fregly, B. J., Evaluation + of direct collocation optimal control problem formulations for + solving the muscle redundancy problem, Annals of biomedical + engineering, 44(10), (2016) pp. 2922-2936 + + """ + + @classmethod + def with_defaults(cls, v_M_tilde): + r"""Recommended constructor that will use the published constants. + + Explanation + =========== + + Returns a new instance of the muscle fiber force-velocity function + using the four constant values specified in the original publication. + + These have the values: + + $c_0 = -0.318$ + $c_1 = -8.149$ + $c_2 = -0.374$ + $c_3 = 0.886$ + + Parameters + ========== + + v_M_tilde : Any (sympifiable) + Normalized muscle fiber extension velocity. + + """ + c0 = Float('-0.318') + c1 = Float('-8.149') + c2 = Float('-0.374') + c3 = Float('0.886') + return cls(v_M_tilde, c0, c1, c2, c3) + + @classmethod + def eval(cls, v_M_tilde, c0, c1, c2, c3): + """Evaluation of basic inputs. + + Parameters + ========== + + v_M_tilde : Any (sympifiable) + Normalized muscle fiber extension velocity. + c0 : Any (sympifiable) + The first constant in the characteristic equation. The published + value is ``-0.318``. + c1 : Any (sympifiable) + The second constant in the characteristic equation. The published + value is ``-8.149``. + c2 : Any (sympifiable) + The third constant in the characteristic equation. The published + value is ``-0.374``. + c3 : Any (sympifiable) + The fourth constant in the characteristic equation. The published + value is ``0.886``. + + """ + pass + + def _eval_evalf(self, prec): + """Evaluate the expression numerically using ``evalf``.""" + return self.doit(deep=False, evaluate=False)._eval_evalf(prec) + + def doit(self, deep=True, evaluate=True, **hints): + """Evaluate the expression defining the function. + + Parameters + ========== + + deep : bool + Whether ``doit`` should be recursively called. Default is ``True``. + evaluate : bool. + Whether the SymPy expression should be evaluated as it is + constructed. If ``False``, then no constant folding will be + conducted which will leave the expression in a more numerically- + stable for values of ``v_M_tilde`` that correspond to a sensible + operating range for a musculotendon. Default is ``True``. + **kwargs : dict[str, Any] + Additional keyword argument pairs to be recursively passed to + ``doit``. + + """ + v_M_tilde, *constants = self.args + if deep: + hints['evaluate'] = evaluate + v_M_tilde = v_M_tilde.doit(deep=deep, **hints) + c0, c1, c2, c3 = [c.doit(deep=deep, **hints) for c in constants] + else: + c0, c1, c2, c3 = constants + + if evaluate: + return c0*log(c1*v_M_tilde + c2 + sqrt((c1*v_M_tilde + c2)**2 + 1)) + c3 + + return c0*log(c1*v_M_tilde + c2 + sqrt(UnevaluatedExpr(c1*v_M_tilde + c2)**2 + 1)) + c3 + + def fdiff(self, argindex=1): + """Derivative of the function with respect to a single argument. + + Parameters + ========== + + argindex : int + The index of the function's arguments with respect to which the + derivative should be taken. Argument indexes start at ``1``. + Default is ``1``. + + """ + v_M_tilde, c0, c1, c2, c3 = self.args + if argindex == 1: + return c0*c1/sqrt(UnevaluatedExpr(c1*v_M_tilde + c2)**2 + 1) + elif argindex == 2: + return log( + c1*v_M_tilde + c2 + + sqrt(UnevaluatedExpr(c1*v_M_tilde + c2)**2 + 1) + ) + elif argindex == 3: + return c0*v_M_tilde/sqrt(UnevaluatedExpr(c1*v_M_tilde + c2)**2 + 1) + elif argindex == 4: + return c0/sqrt(UnevaluatedExpr(c1*v_M_tilde + c2)**2 + 1) + elif argindex == 5: + return Integer(1) + + raise ArgumentIndexError(self, argindex) + + def inverse(self, argindex=1): + """Inverse function. + + Parameters + ========== + + argindex : int + Value to start indexing the arguments at. Default is ``1``. + + """ + return FiberForceVelocityInverseDeGroote2016 + + def _latex(self, printer): + """Print a LaTeX representation of the function defining the curve. + + Parameters + ========== + + printer : Printer + The printer to be used to print the LaTeX string representation. + + """ + v_M_tilde = self.args[0] + _v_M_tilde = printer._print(v_M_tilde) + return r'\operatorname{fv}^M \left( %s \right)' % _v_M_tilde + + +class FiberForceVelocityInverseDeGroote2016(CharacteristicCurveFunction): + r"""Inverse muscle fiber force-velocity curve based on De Groote et al., + 2016 [1]_. + + Explanation + =========== + + Gives the normalized muscle fiber velocity that produces a specific + normalized muscle fiber force. + + The function is defined by the equation: + + ${fv^M}^{-1} = \frac{\sinh{\frac{fv^M - c_3}{c_0}} - c_2}{c_1}$ + + with constant values of $c_0 = -0.318$, $c_1 = -8.149$, $c_2 = -0.374$, and + $c_3 = 0.886$. This function is the exact analytical inverse of the related + muscle fiber force-velocity curve ``FiberForceVelocityDeGroote2016``. + + While it is possible to change the constant values, these were carefully + selected in the original publication to give the characteristic curve + specific and required properties. For example, the function produces a + normalized muscle fiber force of 1 when the muscle fibers are contracting + isometrically (they have an extension rate of 0). + + Examples + ======== + + The preferred way to instantiate :class:`FiberForceVelocityInverseDeGroote2016` + is using the :meth:`~.with_defaults` constructor because this will automatically + populate the constants within the characteristic curve equation with the + floating point values from the original publication. This constructor takes + a single argument corresponding to normalized muscle fiber force-velocity + component of the muscle fiber force. We'll create a :class:`~.Symbol` called + ``fv_M`` to represent this. + + >>> from sympy import Symbol + >>> from sympy.physics.biomechanics import FiberForceVelocityInverseDeGroote2016 + >>> fv_M = Symbol('fv_M') + >>> v_M_tilde = FiberForceVelocityInverseDeGroote2016.with_defaults(fv_M) + >>> v_M_tilde + FiberForceVelocityInverseDeGroote2016(fv_M, -0.318, -8.149, -0.374, 0.886) + + It's also possible to populate the four constants with your own values too. + + >>> from sympy import symbols + >>> c0, c1, c2, c3 = symbols('c0 c1 c2 c3') + >>> v_M_tilde = FiberForceVelocityInverseDeGroote2016(fv_M, c0, c1, c2, c3) + >>> v_M_tilde + FiberForceVelocityInverseDeGroote2016(fv_M, c0, c1, c2, c3) + + To inspect the actual symbolic expression that this function represents, + we can call the :meth:`~.doit` method on an instance. We'll use the keyword + argument ``evaluate=False`` as this will keep the expression in its + canonical form and won't simplify any constants. + + >>> v_M_tilde.doit(evaluate=False) + (-c2 + sinh((-c3 + fv_M)/c0))/c1 + + The function can also be differentiated. We'll differentiate with respect + to fv_M using the ``diff`` method on an instance with the single positional + argument ``fv_M``. + + >>> v_M_tilde.diff(fv_M) + cosh((-c3 + fv_M)/c0)/(c0*c1) + + References + ========== + + .. [1] De Groote, F., Kinney, A. L., Rao, A. V., & Fregly, B. J., Evaluation + of direct collocation optimal control problem formulations for + solving the muscle redundancy problem, Annals of biomedical + engineering, 44(10), (2016) pp. 2922-2936 + + """ + + @classmethod + def with_defaults(cls, fv_M): + r"""Recommended constructor that will use the published constants. + + Explanation + =========== + + Returns a new instance of the inverse muscle fiber force-velocity + function using the four constant values specified in the original + publication. + + These have the values: + + $c_0 = -0.318$ + $c_1 = -8.149$ + $c_2 = -0.374$ + $c_3 = 0.886$ + + Parameters + ========== + + fv_M : Any (sympifiable) + Normalized muscle fiber extension velocity. + + """ + c0 = Float('-0.318') + c1 = Float('-8.149') + c2 = Float('-0.374') + c3 = Float('0.886') + return cls(fv_M, c0, c1, c2, c3) + + @classmethod + def eval(cls, fv_M, c0, c1, c2, c3): + """Evaluation of basic inputs. + + Parameters + ========== + + fv_M : Any (sympifiable) + Normalized muscle fiber force as a function of muscle fiber + extension velocity. + c0 : Any (sympifiable) + The first constant in the characteristic equation. The published + value is ``-0.318``. + c1 : Any (sympifiable) + The second constant in the characteristic equation. The published + value is ``-8.149``. + c2 : Any (sympifiable) + The third constant in the characteristic equation. The published + value is ``-0.374``. + c3 : Any (sympifiable) + The fourth constant in the characteristic equation. The published + value is ``0.886``. + + """ + pass + + def _eval_evalf(self, prec): + """Evaluate the expression numerically using ``evalf``.""" + return self.doit(deep=False, evaluate=False)._eval_evalf(prec) + + def doit(self, deep=True, evaluate=True, **hints): + """Evaluate the expression defining the function. + + Parameters + ========== + + deep : bool + Whether ``doit`` should be recursively called. Default is ``True``. + evaluate : bool. + Whether the SymPy expression should be evaluated as it is + constructed. If ``False``, then no constant folding will be + conducted which will leave the expression in a more numerically- + stable for values of ``fv_M`` that correspond to a sensible + operating range for a musculotendon. Default is ``True``. + **kwargs : dict[str, Any] + Additional keyword argument pairs to be recursively passed to + ``doit``. + + """ + fv_M, *constants = self.args + if deep: + hints['evaluate'] = evaluate + fv_M = fv_M.doit(deep=deep, **hints) + c0, c1, c2, c3 = [c.doit(deep=deep, **hints) for c in constants] + else: + c0, c1, c2, c3 = constants + + if evaluate: + return (sinh((fv_M - c3)/c0) - c2)/c1 + + return (sinh(UnevaluatedExpr(fv_M - c3)/c0) - c2)/c1 + + def fdiff(self, argindex=1): + """Derivative of the function with respect to a single argument. + + Parameters + ========== + + argindex : int + The index of the function's arguments with respect to which the + derivative should be taken. Argument indexes start at ``1``. + Default is ``1``. + + """ + fv_M, c0, c1, c2, c3 = self.args + if argindex == 1: + return cosh((fv_M - c3)/c0)/(c0*c1) + elif argindex == 2: + return (c3 - fv_M)*cosh((fv_M - c3)/c0)/(c0**2*c1) + elif argindex == 3: + return (c2 - sinh((fv_M - c3)/c0))/c1**2 + elif argindex == 4: + return -1/c1 + elif argindex == 5: + return -cosh((fv_M - c3)/c0)/(c0*c1) + + raise ArgumentIndexError(self, argindex) + + def inverse(self, argindex=1): + """Inverse function. + + Parameters + ========== + + argindex : int + Value to start indexing the arguments at. Default is ``1``. + + """ + return FiberForceVelocityDeGroote2016 + + def _latex(self, printer): + """Print a LaTeX representation of the function defining the curve. + + Parameters + ========== + + printer : Printer + The printer to be used to print the LaTeX string representation. + + """ + fv_M = self.args[0] + _fv_M = printer._print(fv_M) + return r'\left( \operatorname{fv}^M \right)^{-1} \left( %s \right)' % _fv_M + + +@dataclass(frozen=True) +class CharacteristicCurveCollection: + """Simple data container to group together related characteristic curves.""" + tendon_force_length: CharacteristicCurveFunction + tendon_force_length_inverse: CharacteristicCurveFunction + fiber_force_length_passive: CharacteristicCurveFunction + fiber_force_length_passive_inverse: CharacteristicCurveFunction + fiber_force_length_active: CharacteristicCurveFunction + fiber_force_velocity: CharacteristicCurveFunction + fiber_force_velocity_inverse: CharacteristicCurveFunction + + def __iter__(self): + """Iterator support for ``CharacteristicCurveCollection``.""" + yield self.tendon_force_length + yield self.tendon_force_length_inverse + yield self.fiber_force_length_passive + yield self.fiber_force_length_passive_inverse + yield self.fiber_force_length_active + yield self.fiber_force_velocity + yield self.fiber_force_velocity_inverse diff --git a/lib/python3.10/site-packages/sympy/physics/biomechanics/musculotendon.py b/lib/python3.10/site-packages/sympy/physics/biomechanics/musculotendon.py new file mode 100644 index 0000000000000000000000000000000000000000..8bb1f64fa8f61743ad72b200c4318bbf28916fb1 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/physics/biomechanics/musculotendon.py @@ -0,0 +1,1424 @@ +"""Implementations of musculotendon models. + +Musculotendon models are a critical component of biomechanical models, one that +differentiates them from pure multibody systems. Musculotendon models produce a +force dependent on their level of activation, their length, and their +extension velocity. Length- and extension velocity-dependent force production +are governed by force-length and force-velocity characteristics. +These are normalized functions that are dependent on the musculotendon's state +and are specific to a given musculotendon model. + +""" + +from abc import abstractmethod +from enum import IntEnum, unique + +from sympy.core.numbers import Float, Integer +from sympy.core.symbol import Symbol, symbols +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import cos, sin +from sympy.matrices.dense import MutableDenseMatrix as Matrix, diag, eye, zeros +from sympy.physics.biomechanics.activation import ActivationBase +from sympy.physics.biomechanics.curve import ( + CharacteristicCurveCollection, + FiberForceLengthActiveDeGroote2016, + FiberForceLengthPassiveDeGroote2016, + FiberForceLengthPassiveInverseDeGroote2016, + FiberForceVelocityDeGroote2016, + FiberForceVelocityInverseDeGroote2016, + TendonForceLengthDeGroote2016, + TendonForceLengthInverseDeGroote2016, +) +from sympy.physics.biomechanics._mixin import _NamedMixin +from sympy.physics.mechanics.actuator import ForceActuator +from sympy.physics.vector.functions import dynamicsymbols + + +__all__ = [ + 'MusculotendonBase', + 'MusculotendonDeGroote2016', + 'MusculotendonFormulation', +] + + +@unique +class MusculotendonFormulation(IntEnum): + """Enumeration of types of musculotendon dynamics formulations. + + Explanation + =========== + + An (integer) enumeration is used as it allows for clearer selection of the + different formulations of musculotendon dynamics. + + Members + ======= + + RIGID_TENDON : 0 + A rigid tendon model. + FIBER_LENGTH_EXPLICIT : 1 + An explicit elastic tendon model with the muscle fiber length (l_M) as + the state variable. + TENDON_FORCE_EXPLICIT : 2 + An explicit elastic tendon model with the tendon force (F_T) as the + state variable. + FIBER_LENGTH_IMPLICIT : 3 + An implicit elastic tendon model with the muscle fiber length (l_M) as + the state variable and the muscle fiber velocity as an additional input + variable. + TENDON_FORCE_IMPLICIT : 4 + An implicit elastic tendon model with the tendon force (F_T) as the + state variable as the muscle fiber velocity as an additional input + variable. + + """ + + RIGID_TENDON = 0 + FIBER_LENGTH_EXPLICIT = 1 + TENDON_FORCE_EXPLICIT = 2 + FIBER_LENGTH_IMPLICIT = 3 + TENDON_FORCE_IMPLICIT = 4 + + def __str__(self): + """Returns a string representation of the enumeration value. + + Notes + ===== + + This hard coding is required due to an incompatibility between the + ``IntEnum`` implementations in Python 3.10 and Python 3.11 + (https://github.com/python/cpython/issues/84247). From Python 3.11 + onwards, the ``__str__`` method uses ``int.__str__``, whereas prior it + used ``Enum.__str__``. Once Python 3.11 becomes the minimum version + supported by SymPy, this method override can be removed. + + """ + return str(self.value) + + +_DEFAULT_MUSCULOTENDON_FORMULATION = MusculotendonFormulation.RIGID_TENDON + + +class MusculotendonBase(ForceActuator, _NamedMixin): + r"""Abstract base class for all musculotendon classes to inherit from. + + Explanation + =========== + + A musculotendon generates a contractile force based on its activation, + length, and shortening velocity. This abstract base class is to be inherited + by all musculotendon subclasses that implement different characteristic + musculotendon curves. Characteristic musculotendon curves are required for + the tendon force-length, passive fiber force-length, active fiber force- + length, and fiber force-velocity relationships. + + Parameters + ========== + + name : str + The name identifier associated with the musculotendon. This name is used + as a suffix when automatically generated symbols are instantiated. It + must be a string of nonzero length. + pathway : PathwayBase + The pathway that the actuator follows. This must be an instance of a + concrete subclass of ``PathwayBase``, e.g. ``LinearPathway``. + activation_dynamics : ActivationBase + The activation dynamics that will be modeled within the musculotendon. + This must be an instance of a concrete subclass of ``ActivationBase``, + e.g. ``FirstOrderActivationDeGroote2016``. + musculotendon_dynamics : MusculotendonFormulation | int + The formulation of musculotendon dynamics that should be used + internally, i.e. rigid or elastic tendon model, the choice of + musculotendon state etc. This must be a member of the integer + enumeration ``MusculotendonFormulation`` or an integer that can be cast + to a member. To use a rigid tendon formulation, set this to + ``MusculotendonFormulation.RIGID_TENDON`` (or the integer value ``0``, + which will be cast to the enumeration member). There are four possible + formulations for an elastic tendon model. To use an explicit formulation + with the fiber length as the state, set this to + ``MusculotendonFormulation.FIBER_LENGTH_EXPLICIT`` (or the integer value + ``1``). To use an explicit formulation with the tendon force as the + state, set this to ``MusculotendonFormulation.TENDON_FORCE_EXPLICIT`` + (or the integer value ``2``). To use an implicit formulation with the + fiber length as the state, set this to + ``MusculotendonFormulation.FIBER_LENGTH_IMPLICIT`` (or the integer value + ``3``). To use an implicit formulation with the tendon force as the + state, set this to ``MusculotendonFormulation.TENDON_FORCE_IMPLICIT`` + (or the integer value ``4``). The default is + ``MusculotendonFormulation.RIGID_TENDON``, which corresponds to a rigid + tendon formulation. + tendon_slack_length : Expr | None + The length of the tendon when the musculotendon is in its unloaded + state. In a rigid tendon model the tendon length is the tendon slack + length. In all musculotendon models, tendon slack length is used to + normalize tendon length to give + :math:`\tilde{l}^T = \frac{l^T}{l^T_{slack}}`. + peak_isometric_force : Expr | None + The maximum force that the muscle fiber can produce when it is + undergoing an isometric contraction (no lengthening velocity). In all + musculotendon models, peak isometric force is used to normalized tendon + and muscle fiber force to give + :math:`\tilde{F}^T = \frac{F^T}{F^M_{max}}`. + optimal_fiber_length : Expr | None + The muscle fiber length at which the muscle fibers produce no passive + force and their maximum active force. In all musculotendon models, + optimal fiber length is used to normalize muscle fiber length to give + :math:`\tilde{l}^M = \frac{l^M}{l^M_{opt}}`. + maximal_fiber_velocity : Expr | None + The fiber velocity at which, during muscle fiber shortening, the muscle + fibers are unable to produce any active force. In all musculotendon + models, maximal fiber velocity is used to normalize muscle fiber + extension velocity to give :math:`\tilde{v}^M = \frac{v^M}{v^M_{max}}`. + optimal_pennation_angle : Expr | None + The pennation angle when muscle fiber length equals the optimal fiber + length. + fiber_damping_coefficient : Expr | None + The coefficient of damping to be used in the damping element in the + muscle fiber model. + with_defaults : bool + Whether ``with_defaults`` alternate constructors should be used when + automatically constructing child classes. Default is ``False``. + + """ + + def __init__( + self, + name, + pathway, + activation_dynamics, + *, + musculotendon_dynamics=_DEFAULT_MUSCULOTENDON_FORMULATION, + tendon_slack_length=None, + peak_isometric_force=None, + optimal_fiber_length=None, + maximal_fiber_velocity=None, + optimal_pennation_angle=None, + fiber_damping_coefficient=None, + with_defaults=False, + ): + self.name = name + + # Supply a placeholder force to the super initializer, this will be + # replaced later + super().__init__(Symbol('F'), pathway) + + # Activation dynamics + if not isinstance(activation_dynamics, ActivationBase): + msg = ( + f'Can\'t set attribute `activation_dynamics` to ' + f'{activation_dynamics} as it must be of type ' + f'`ActivationBase`, not {type(activation_dynamics)}.' + ) + raise TypeError(msg) + self._activation_dynamics = activation_dynamics + self._child_objects = (self._activation_dynamics, ) + + # Constants + if tendon_slack_length is not None: + self._l_T_slack = tendon_slack_length + else: + self._l_T_slack = Symbol(f'l_T_slack_{self.name}') + if peak_isometric_force is not None: + self._F_M_max = peak_isometric_force + else: + self._F_M_max = Symbol(f'F_M_max_{self.name}') + if optimal_fiber_length is not None: + self._l_M_opt = optimal_fiber_length + else: + self._l_M_opt = Symbol(f'l_M_opt_{self.name}') + if maximal_fiber_velocity is not None: + self._v_M_max = maximal_fiber_velocity + else: + self._v_M_max = Symbol(f'v_M_max_{self.name}') + if optimal_pennation_angle is not None: + self._alpha_opt = optimal_pennation_angle + else: + self._alpha_opt = Symbol(f'alpha_opt_{self.name}') + if fiber_damping_coefficient is not None: + self._beta = fiber_damping_coefficient + else: + self._beta = Symbol(f'beta_{self.name}') + + # Musculotendon dynamics + self._with_defaults = with_defaults + if musculotendon_dynamics == MusculotendonFormulation.RIGID_TENDON: + self._rigid_tendon_musculotendon_dynamics() + elif musculotendon_dynamics == MusculotendonFormulation.FIBER_LENGTH_EXPLICIT: + self._fiber_length_explicit_musculotendon_dynamics() + elif musculotendon_dynamics == MusculotendonFormulation.TENDON_FORCE_EXPLICIT: + self._tendon_force_explicit_musculotendon_dynamics() + elif musculotendon_dynamics == MusculotendonFormulation.FIBER_LENGTH_IMPLICIT: + self._fiber_length_implicit_musculotendon_dynamics() + elif musculotendon_dynamics == MusculotendonFormulation.TENDON_FORCE_IMPLICIT: + self._tendon_force_implicit_musculotendon_dynamics() + else: + msg = ( + f'Musculotendon dynamics {repr(musculotendon_dynamics)} ' + f'passed to `musculotendon_dynamics` was of type ' + f'{type(musculotendon_dynamics)}, must be ' + f'{MusculotendonFormulation}.' + ) + raise TypeError(msg) + self._musculotendon_dynamics = musculotendon_dynamics + + # Must override the placeholder value in `self._force` now that the + # actual force has been calculated by + # `self.__musculotendon_dynamics`. + # Note that `self._force` assumes forces are expansile, musculotendon + # forces are contractile hence the minus sign preceeding `self._F_T` + # (the tendon force). + self._force = -self._F_T + + @classmethod + def with_defaults( + cls, + name, + pathway, + activation_dynamics, + *, + musculotendon_dynamics=_DEFAULT_MUSCULOTENDON_FORMULATION, + tendon_slack_length=None, + peak_isometric_force=None, + optimal_fiber_length=None, + maximal_fiber_velocity=Float('10.0'), + optimal_pennation_angle=Float('0.0'), + fiber_damping_coefficient=Float('0.1'), + ): + r"""Recommended constructor that will use the published constants. + + Explanation + =========== + + Returns a new instance of the musculotendon class using recommended + values for ``v_M_max``, ``alpha_opt``, and ``beta``. The values are: + + :math:`v^M_{max} = 10` + :math:`\alpha_{opt} = 0` + :math:`\beta = \frac{1}{10}` + + The musculotendon curves are also instantiated using the constants from + the original publication. + + Parameters + ========== + + name : str + The name identifier associated with the musculotendon. This name is + used as a suffix when automatically generated symbols are + instantiated. It must be a string of nonzero length. + pathway : PathwayBase + The pathway that the actuator follows. This must be an instance of a + concrete subclass of ``PathwayBase``, e.g. ``LinearPathway``. + activation_dynamics : ActivationBase + The activation dynamics that will be modeled within the + musculotendon. This must be an instance of a concrete subclass of + ``ActivationBase``, e.g. ``FirstOrderActivationDeGroote2016``. + musculotendon_dynamics : MusculotendonFormulation | int + The formulation of musculotendon dynamics that should be used + internally, i.e. rigid or elastic tendon model, the choice of + musculotendon state etc. This must be a member of the integer + enumeration ``MusculotendonFormulation`` or an integer that can be + cast to a member. To use a rigid tendon formulation, set this to + ``MusculotendonFormulation.RIGID_TENDON`` (or the integer value + ``0``, which will be cast to the enumeration member). There are four + possible formulations for an elastic tendon model. To use an + explicit formulation with the fiber length as the state, set this to + ``MusculotendonFormulation.FIBER_LENGTH_EXPLICIT`` (or the integer + value ``1``). To use an explicit formulation with the tendon force + as the state, set this to + ``MusculotendonFormulation.TENDON_FORCE_EXPLICIT`` (or the integer + value ``2``). To use an implicit formulation with the fiber length + as the state, set this to + ``MusculotendonFormulation.FIBER_LENGTH_IMPLICIT`` (or the integer + value ``3``). To use an implicit formulation with the tendon force + as the state, set this to + ``MusculotendonFormulation.TENDON_FORCE_IMPLICIT`` (or the integer + value ``4``). The default is + ``MusculotendonFormulation.RIGID_TENDON``, which corresponds to a + rigid tendon formulation. + tendon_slack_length : Expr | None + The length of the tendon when the musculotendon is in its unloaded + state. In a rigid tendon model the tendon length is the tendon slack + length. In all musculotendon models, tendon slack length is used to + normalize tendon length to give + :math:`\tilde{l}^T = \frac{l^T}{l^T_{slack}}`. + peak_isometric_force : Expr | None + The maximum force that the muscle fiber can produce when it is + undergoing an isometric contraction (no lengthening velocity). In + all musculotendon models, peak isometric force is used to normalized + tendon and muscle fiber force to give + :math:`\tilde{F}^T = \frac{F^T}{F^M_{max}}`. + optimal_fiber_length : Expr | None + The muscle fiber length at which the muscle fibers produce no + passive force and their maximum active force. In all musculotendon + models, optimal fiber length is used to normalize muscle fiber + length to give :math:`\tilde{l}^M = \frac{l^M}{l^M_{opt}}`. + maximal_fiber_velocity : Expr | None + The fiber velocity at which, during muscle fiber shortening, the + muscle fibers are unable to produce any active force. In all + musculotendon models, maximal fiber velocity is used to normalize + muscle fiber extension velocity to give + :math:`\tilde{v}^M = \frac{v^M}{v^M_{max}}`. + optimal_pennation_angle : Expr | None + The pennation angle when muscle fiber length equals the optimal + fiber length. + fiber_damping_coefficient : Expr | None + The coefficient of damping to be used in the damping element in the + muscle fiber model. + + """ + return cls( + name, + pathway, + activation_dynamics=activation_dynamics, + musculotendon_dynamics=musculotendon_dynamics, + tendon_slack_length=tendon_slack_length, + peak_isometric_force=peak_isometric_force, + optimal_fiber_length=optimal_fiber_length, + maximal_fiber_velocity=maximal_fiber_velocity, + optimal_pennation_angle=optimal_pennation_angle, + fiber_damping_coefficient=fiber_damping_coefficient, + with_defaults=True, + ) + + @abstractmethod + def curves(cls): + """Return a ``CharacteristicCurveCollection`` of the curves related to + the specific model.""" + pass + + @property + def tendon_slack_length(self): + r"""Symbol or value corresponding to the tendon slack length constant. + + Explanation + =========== + + The length of the tendon when the musculotendon is in its unloaded + state. In a rigid tendon model the tendon length is the tendon slack + length. In all musculotendon models, tendon slack length is used to + normalize tendon length to give + :math:`\tilde{l}^T = \frac{l^T}{l^T_{slack}}`. + + The alias ``l_T_slack`` can also be used to access the same attribute. + + """ + return self._l_T_slack + + @property + def l_T_slack(self): + r"""Symbol or value corresponding to the tendon slack length constant. + + Explanation + =========== + + The length of the tendon when the musculotendon is in its unloaded + state. In a rigid tendon model the tendon length is the tendon slack + length. In all musculotendon models, tendon slack length is used to + normalize tendon length to give + :math:`\tilde{l}^T = \frac{l^T}{l^T_{slack}}`. + + The alias ``tendon_slack_length`` can also be used to access the same + attribute. + + """ + return self._l_T_slack + + @property + def peak_isometric_force(self): + r"""Symbol or value corresponding to the peak isometric force constant. + + Explanation + =========== + + The maximum force that the muscle fiber can produce when it is + undergoing an isometric contraction (no lengthening velocity). In all + musculotendon models, peak isometric force is used to normalized tendon + and muscle fiber force to give + :math:`\tilde{F}^T = \frac{F^T}{F^M_{max}}`. + + The alias ``F_M_max`` can also be used to access the same attribute. + + """ + return self._F_M_max + + @property + def F_M_max(self): + r"""Symbol or value corresponding to the peak isometric force constant. + + Explanation + =========== + + The maximum force that the muscle fiber can produce when it is + undergoing an isometric contraction (no lengthening velocity). In all + musculotendon models, peak isometric force is used to normalized tendon + and muscle fiber force to give + :math:`\tilde{F}^T = \frac{F^T}{F^M_{max}}`. + + The alias ``peak_isometric_force`` can also be used to access the same + attribute. + + """ + return self._F_M_max + + @property + def optimal_fiber_length(self): + r"""Symbol or value corresponding to the optimal fiber length constant. + + Explanation + =========== + + The muscle fiber length at which the muscle fibers produce no passive + force and their maximum active force. In all musculotendon models, + optimal fiber length is used to normalize muscle fiber length to give + :math:`\tilde{l}^M = \frac{l^M}{l^M_{opt}}`. + + The alias ``l_M_opt`` can also be used to access the same attribute. + + """ + return self._l_M_opt + + @property + def l_M_opt(self): + r"""Symbol or value corresponding to the optimal fiber length constant. + + Explanation + =========== + + The muscle fiber length at which the muscle fibers produce no passive + force and their maximum active force. In all musculotendon models, + optimal fiber length is used to normalize muscle fiber length to give + :math:`\tilde{l}^M = \frac{l^M}{l^M_{opt}}`. + + The alias ``optimal_fiber_length`` can also be used to access the same + attribute. + + """ + return self._l_M_opt + + @property + def maximal_fiber_velocity(self): + r"""Symbol or value corresponding to the maximal fiber velocity constant. + + Explanation + =========== + + The fiber velocity at which, during muscle fiber shortening, the muscle + fibers are unable to produce any active force. In all musculotendon + models, maximal fiber velocity is used to normalize muscle fiber + extension velocity to give :math:`\tilde{v}^M = \frac{v^M}{v^M_{max}}`. + + The alias ``v_M_max`` can also be used to access the same attribute. + + """ + return self._v_M_max + + @property + def v_M_max(self): + r"""Symbol or value corresponding to the maximal fiber velocity constant. + + Explanation + =========== + + The fiber velocity at which, during muscle fiber shortening, the muscle + fibers are unable to produce any active force. In all musculotendon + models, maximal fiber velocity is used to normalize muscle fiber + extension velocity to give :math:`\tilde{v}^M = \frac{v^M}{v^M_{max}}`. + + The alias ``maximal_fiber_velocity`` can also be used to access the same + attribute. + + """ + return self._v_M_max + + @property + def optimal_pennation_angle(self): + """Symbol or value corresponding to the optimal pennation angle + constant. + + Explanation + =========== + + The pennation angle when muscle fiber length equals the optimal fiber + length. + + The alias ``alpha_opt`` can also be used to access the same attribute. + + """ + return self._alpha_opt + + @property + def alpha_opt(self): + """Symbol or value corresponding to the optimal pennation angle + constant. + + Explanation + =========== + + The pennation angle when muscle fiber length equals the optimal fiber + length. + + The alias ``optimal_pennation_angle`` can also be used to access the + same attribute. + + """ + return self._alpha_opt + + @property + def fiber_damping_coefficient(self): + """Symbol or value corresponding to the fiber damping coefficient + constant. + + Explanation + =========== + + The coefficient of damping to be used in the damping element in the + muscle fiber model. + + The alias ``beta`` can also be used to access the same attribute. + + """ + return self._beta + + @property + def beta(self): + """Symbol or value corresponding to the fiber damping coefficient + constant. + + Explanation + =========== + + The coefficient of damping to be used in the damping element in the + muscle fiber model. + + The alias ``fiber_damping_coefficient`` can also be used to access the + same attribute. + + """ + return self._beta + + @property + def activation_dynamics(self): + """Activation dynamics model governing this musculotendon's activation. + + Explanation + =========== + + Returns the instance of a subclass of ``ActivationBase`` that governs + the relationship between excitation and activation that is used to + represent the activation dynamics of this musculotendon. + + """ + return self._activation_dynamics + + @property + def excitation(self): + """Dynamic symbol representing excitation. + + Explanation + =========== + + The alias ``e`` can also be used to access the same attribute. + + """ + return self._activation_dynamics._e + + @property + def e(self): + """Dynamic symbol representing excitation. + + Explanation + =========== + + The alias ``excitation`` can also be used to access the same attribute. + + """ + return self._activation_dynamics._e + + @property + def activation(self): + """Dynamic symbol representing activation. + + Explanation + =========== + + The alias ``a`` can also be used to access the same attribute. + + """ + return self._activation_dynamics._a + + @property + def a(self): + """Dynamic symbol representing activation. + + Explanation + =========== + + The alias ``activation`` can also be used to access the same attribute. + + """ + return self._activation_dynamics._a + + @property + def musculotendon_dynamics(self): + """The choice of rigid or type of elastic tendon musculotendon dynamics. + + Explanation + =========== + + The formulation of musculotendon dynamics that should be used + internally, i.e. rigid or elastic tendon model, the choice of + musculotendon state etc. This must be a member of the integer + enumeration ``MusculotendonFormulation`` or an integer that can be cast + to a member. To use a rigid tendon formulation, set this to + ``MusculotendonFormulation.RIGID_TENDON`` (or the integer value ``0``, + which will be cast to the enumeration member). There are four possible + formulations for an elastic tendon model. To use an explicit formulation + with the fiber length as the state, set this to + ``MusculotendonFormulation.FIBER_LENGTH_EXPLICIT`` (or the integer value + ``1``). To use an explicit formulation with the tendon force as the + state, set this to ``MusculotendonFormulation.TENDON_FORCE_EXPLICIT`` + (or the integer value ``2``). To use an implicit formulation with the + fiber length as the state, set this to + ``MusculotendonFormulation.FIBER_LENGTH_IMPLICIT`` (or the integer value + ``3``). To use an implicit formulation with the tendon force as the + state, set this to ``MusculotendonFormulation.TENDON_FORCE_IMPLICIT`` + (or the integer value ``4``). The default is + ``MusculotendonFormulation.RIGID_TENDON``, which corresponds to a rigid + tendon formulation. + + """ + return self._musculotendon_dynamics + + def _rigid_tendon_musculotendon_dynamics(self): + """Rigid tendon musculotendon.""" + self._l_MT = self.pathway.length + self._v_MT = self.pathway.extension_velocity + self._l_T = self._l_T_slack + self._l_T_tilde = Integer(1) + self._l_M = sqrt((self._l_MT - self._l_T)**2 + (self._l_M_opt*sin(self._alpha_opt))**2) + self._l_M_tilde = self._l_M/self._l_M_opt + self._v_M = self._v_MT*(self._l_MT - self._l_T_slack)/self._l_M + self._v_M_tilde = self._v_M/self._v_M_max + if self._with_defaults: + self._fl_T = self.curves.tendon_force_length.with_defaults(self._l_T_tilde) + self._fl_M_pas = self.curves.fiber_force_length_passive.with_defaults(self._l_M_tilde) + self._fl_M_act = self.curves.fiber_force_length_active.with_defaults(self._l_M_tilde) + self._fv_M = self.curves.fiber_force_velocity.with_defaults(self._v_M_tilde) + else: + fl_T_constants = symbols(f'c_0:4_fl_T_{self.name}') + self._fl_T = self.curves.tendon_force_length(self._l_T_tilde, *fl_T_constants) + fl_M_pas_constants = symbols(f'c_0:2_fl_M_pas_{self.name}') + self._fl_M_pas = self.curves.fiber_force_length_passive(self._l_M_tilde, *fl_M_pas_constants) + fl_M_act_constants = symbols(f'c_0:12_fl_M_act_{self.name}') + self._fl_M_act = self.curves.fiber_force_length_active(self._l_M_tilde, *fl_M_act_constants) + fv_M_constants = symbols(f'c_0:4_fv_M_{self.name}') + self._fv_M = self.curves.fiber_force_velocity(self._v_M_tilde, *fv_M_constants) + self._F_M_tilde = self.a*self._fl_M_act*self._fv_M + self._fl_M_pas + self._beta*self._v_M_tilde + self._F_T_tilde = self._F_M_tilde + self._F_M = self._F_M_tilde*self._F_M_max + self._cos_alpha = cos(self._alpha_opt) + self._F_T = self._F_M*self._cos_alpha + + # Containers + self._state_vars = zeros(0, 1) + self._input_vars = zeros(0, 1) + self._state_eqns = zeros(0, 1) + self._curve_constants = Matrix( + fl_T_constants + + fl_M_pas_constants + + fl_M_act_constants + + fv_M_constants + ) if not self._with_defaults else zeros(0, 1) + + def _fiber_length_explicit_musculotendon_dynamics(self): + """Elastic tendon musculotendon using `l_M_tilde` as a state.""" + self._l_M_tilde = dynamicsymbols(f'l_M_tilde_{self.name}') + self._l_MT = self.pathway.length + self._v_MT = self.pathway.extension_velocity + self._l_M = self._l_M_tilde*self._l_M_opt + self._l_T = self._l_MT - sqrt(self._l_M**2 - (self._l_M_opt*sin(self._alpha_opt))**2) + self._l_T_tilde = self._l_T/self._l_T_slack + self._cos_alpha = (self._l_MT - self._l_T)/self._l_M + if self._with_defaults: + self._fl_T = self.curves.tendon_force_length.with_defaults(self._l_T_tilde) + self._fl_M_pas = self.curves.fiber_force_length_passive.with_defaults(self._l_M_tilde) + self._fl_M_act = self.curves.fiber_force_length_active.with_defaults(self._l_M_tilde) + else: + fl_T_constants = symbols(f'c_0:4_fl_T_{self.name}') + self._fl_T = self.curves.tendon_force_length(self._l_T_tilde, *fl_T_constants) + fl_M_pas_constants = symbols(f'c_0:2_fl_M_pas_{self.name}') + self._fl_M_pas = self.curves.fiber_force_length_passive(self._l_M_tilde, *fl_M_pas_constants) + fl_M_act_constants = symbols(f'c_0:12_fl_M_act_{self.name}') + self._fl_M_act = self.curves.fiber_force_length_active(self._l_M_tilde, *fl_M_act_constants) + self._F_T_tilde = self._fl_T + self._F_T = self._F_T_tilde*self._F_M_max + self._F_M = self._F_T/self._cos_alpha + self._F_M_tilde = self._F_M/self._F_M_max + self._fv_M = (self._F_M_tilde - self._fl_M_pas)/(self.a*self._fl_M_act) + if self._with_defaults: + self._v_M_tilde = self.curves.fiber_force_velocity_inverse.with_defaults(self._fv_M) + else: + fv_M_constants = symbols(f'c_0:4_fv_M_{self.name}') + self._v_M_tilde = self.curves.fiber_force_velocity_inverse(self._fv_M, *fv_M_constants) + self._dl_M_tilde_dt = (self._v_M_max/self._l_M_opt)*self._v_M_tilde + + self._state_vars = Matrix([self._l_M_tilde]) + self._input_vars = zeros(0, 1) + self._state_eqns = Matrix([self._dl_M_tilde_dt]) + self._curve_constants = Matrix( + fl_T_constants + + fl_M_pas_constants + + fl_M_act_constants + + fv_M_constants + ) if not self._with_defaults else zeros(0, 1) + + def _tendon_force_explicit_musculotendon_dynamics(self): + """Elastic tendon musculotendon using `F_T_tilde` as a state.""" + self._F_T_tilde = dynamicsymbols(f'F_T_tilde_{self.name}') + self._l_MT = self.pathway.length + self._v_MT = self.pathway.extension_velocity + self._fl_T = self._F_T_tilde + if self._with_defaults: + self._fl_T_inv = self.curves.tendon_force_length_inverse.with_defaults(self._fl_T) + else: + fl_T_constants = symbols(f'c_0:4_fl_T_{self.name}') + self._fl_T_inv = self.curves.tendon_force_length_inverse(self._fl_T, *fl_T_constants) + self._l_T_tilde = self._fl_T_inv + self._l_T = self._l_T_tilde*self._l_T_slack + self._l_M = sqrt((self._l_MT - self._l_T)**2 + (self._l_M_opt*sin(self._alpha_opt))**2) + self._l_M_tilde = self._l_M/self._l_M_opt + if self._with_defaults: + self._fl_M_pas = self.curves.fiber_force_length_passive.with_defaults(self._l_M_tilde) + self._fl_M_act = self.curves.fiber_force_length_active.with_defaults(self._l_M_tilde) + else: + fl_M_pas_constants = symbols(f'c_0:2_fl_M_pas_{self.name}') + self._fl_M_pas = self.curves.fiber_force_length_passive(self._l_M_tilde, *fl_M_pas_constants) + fl_M_act_constants = symbols(f'c_0:12_fl_M_act_{self.name}') + self._fl_M_act = self.curves.fiber_force_length_active(self._l_M_tilde, *fl_M_act_constants) + self._cos_alpha = (self._l_MT - self._l_T)/self._l_M + self._F_T = self._F_T_tilde*self._F_M_max + self._F_M = self._F_T/self._cos_alpha + self._F_M_tilde = self._F_M/self._F_M_max + self._fv_M = (self._F_M_tilde - self._fl_M_pas)/(self.a*self._fl_M_act) + if self._with_defaults: + self._fv_M_inv = self.curves.fiber_force_velocity_inverse.with_defaults(self._fv_M) + else: + fv_M_constants = symbols(f'c_0:4_fv_M_{self.name}') + self._fv_M_inv = self.curves.fiber_force_velocity_inverse(self._fv_M, *fv_M_constants) + self._v_M_tilde = self._fv_M_inv + self._v_M = self._v_M_tilde*self._v_M_max + self._v_T = self._v_MT - (self._v_M/self._cos_alpha) + self._v_T_tilde = self._v_T/self._l_T_slack + if self._with_defaults: + self._fl_T = self.curves.tendon_force_length.with_defaults(self._l_T_tilde) + else: + self._fl_T = self.curves.tendon_force_length(self._l_T_tilde, *fl_T_constants) + self._dF_T_tilde_dt = self._fl_T.diff(dynamicsymbols._t).subs({self._l_T_tilde.diff(dynamicsymbols._t): self._v_T_tilde}) + + self._state_vars = Matrix([self._F_T_tilde]) + self._input_vars = zeros(0, 1) + self._state_eqns = Matrix([self._dF_T_tilde_dt]) + self._curve_constants = Matrix( + fl_T_constants + + fl_M_pas_constants + + fl_M_act_constants + + fv_M_constants + ) if not self._with_defaults else zeros(0, 1) + + def _fiber_length_implicit_musculotendon_dynamics(self): + raise NotImplementedError + + def _tendon_force_implicit_musculotendon_dynamics(self): + raise NotImplementedError + + @property + def state_vars(self): + """Ordered column matrix of functions of time that represent the state + variables. + + Explanation + =========== + + The alias ``x`` can also be used to access the same attribute. + + """ + state_vars = [self._state_vars] + for child in self._child_objects: + state_vars.append(child.state_vars) + return Matrix.vstack(*state_vars) + + @property + def x(self): + """Ordered column matrix of functions of time that represent the state + variables. + + Explanation + =========== + + The alias ``state_vars`` can also be used to access the same attribute. + + """ + state_vars = [self._state_vars] + for child in self._child_objects: + state_vars.append(child.state_vars) + return Matrix.vstack(*state_vars) + + @property + def input_vars(self): + """Ordered column matrix of functions of time that represent the input + variables. + + Explanation + =========== + + The alias ``r`` can also be used to access the same attribute. + + """ + input_vars = [self._input_vars] + for child in self._child_objects: + input_vars.append(child.input_vars) + return Matrix.vstack(*input_vars) + + @property + def r(self): + """Ordered column matrix of functions of time that represent the input + variables. + + Explanation + =========== + + The alias ``input_vars`` can also be used to access the same attribute. + + """ + input_vars = [self._input_vars] + for child in self._child_objects: + input_vars.append(child.input_vars) + return Matrix.vstack(*input_vars) + + @property + def constants(self): + """Ordered column matrix of non-time varying symbols present in ``M`` + and ``F``. + + Explanation + =========== + + Only symbolic constants are returned. If a numeric type (e.g. ``Float``) + has been used instead of ``Symbol`` for a constant then that attribute + will not be included in the matrix returned by this property. This is + because the primary use of this property attribute is to provide an + ordered sequence of the still-free symbols that require numeric values + during code generation. + + The alias ``p`` can also be used to access the same attribute. + + """ + musculotendon_constants = [ + self._l_T_slack, + self._F_M_max, + self._l_M_opt, + self._v_M_max, + self._alpha_opt, + self._beta, + ] + musculotendon_constants = [ + c for c in musculotendon_constants if not c.is_number + ] + constants = [ + Matrix(musculotendon_constants) + if musculotendon_constants + else zeros(0, 1) + ] + for child in self._child_objects: + constants.append(child.constants) + constants.append(self._curve_constants) + return Matrix.vstack(*constants) + + @property + def p(self): + """Ordered column matrix of non-time varying symbols present in ``M`` + and ``F``. + + Explanation + =========== + + Only symbolic constants are returned. If a numeric type (e.g. ``Float``) + has been used instead of ``Symbol`` for a constant then that attribute + will not be included in the matrix returned by this property. This is + because the primary use of this property attribute is to provide an + ordered sequence of the still-free symbols that require numeric values + during code generation. + + The alias ``constants`` can also be used to access the same attribute. + + """ + musculotendon_constants = [ + self._l_T_slack, + self._F_M_max, + self._l_M_opt, + self._v_M_max, + self._alpha_opt, + self._beta, + ] + musculotendon_constants = [ + c for c in musculotendon_constants if not c.is_number + ] + constants = [ + Matrix(musculotendon_constants) + if musculotendon_constants + else zeros(0, 1) + ] + for child in self._child_objects: + constants.append(child.constants) + constants.append(self._curve_constants) + return Matrix.vstack(*constants) + + @property + def M(self): + """Ordered square matrix of coefficients on the LHS of ``M x' = F``. + + Explanation + =========== + + The square matrix that forms part of the LHS of the linear system of + ordinary differential equations governing the activation dynamics: + + ``M(x, r, t, p) x' = F(x, r, t, p)``. + + As zeroth-order activation dynamics have no state variables, this + linear system has dimension 0 and therefore ``M`` is an empty square + ``Matrix`` with shape (0, 0). + + """ + M = [eye(len(self._state_vars))] + for child in self._child_objects: + M.append(child.M) + return diag(*M) + + @property + def F(self): + """Ordered column matrix of equations on the RHS of ``M x' = F``. + + Explanation + =========== + + The column matrix that forms the RHS of the linear system of ordinary + differential equations governing the activation dynamics: + + ``M(x, r, t, p) x' = F(x, r, t, p)``. + + As zeroth-order activation dynamics have no state variables, this + linear system has dimension 0 and therefore ``F`` is an empty column + ``Matrix`` with shape (0, 1). + + """ + F = [self._state_eqns] + for child in self._child_objects: + F.append(child.F) + return Matrix.vstack(*F) + + def rhs(self): + """Ordered column matrix of equations for the solution of ``M x' = F``. + + Explanation + =========== + + The solution to the linear system of ordinary differential equations + governing the activation dynamics: + + ``M(x, r, t, p) x' = F(x, r, t, p)``. + + As zeroth-order activation dynamics have no state variables, this + linear has dimension 0 and therefore this method returns an empty + column ``Matrix`` with shape (0, 1). + + """ + is_explicit = ( + MusculotendonFormulation.FIBER_LENGTH_EXPLICIT, + MusculotendonFormulation.TENDON_FORCE_EXPLICIT, + ) + if self.musculotendon_dynamics is MusculotendonFormulation.RIGID_TENDON: + child_rhs = [child.rhs() for child in self._child_objects] + return Matrix.vstack(*child_rhs) + elif self.musculotendon_dynamics in is_explicit: + rhs = self._state_eqns + child_rhs = [child.rhs() for child in self._child_objects] + return Matrix.vstack(rhs, *child_rhs) + return self.M.solve(self.F) + + def __repr__(self): + """Returns a string representation to reinstantiate the model.""" + return ( + f'{self.__class__.__name__}({self.name!r}, ' + f'pathway={self.pathway!r}, ' + f'activation_dynamics={self.activation_dynamics!r}, ' + f'musculotendon_dynamics={self.musculotendon_dynamics}, ' + f'tendon_slack_length={self._l_T_slack!r}, ' + f'peak_isometric_force={self._F_M_max!r}, ' + f'optimal_fiber_length={self._l_M_opt!r}, ' + f'maximal_fiber_velocity={self._v_M_max!r}, ' + f'optimal_pennation_angle={self._alpha_opt!r}, ' + f'fiber_damping_coefficient={self._beta!r})' + ) + + def __str__(self): + """Returns a string representation of the expression for musculotendon + force.""" + return str(self.force) + + +class MusculotendonDeGroote2016(MusculotendonBase): + r"""Musculotendon model using the curves of De Groote et al., 2016 [1]_. + + Examples + ======== + + This class models the musculotendon actuator parametrized by the + characteristic curves described in De Groote et al., 2016 [1]_. Like all + musculotendon models in SymPy's biomechanics module, it requires a pathway + to define its line of action. We'll begin by creating a simple + ``LinearPathway`` between two points that our musculotendon will follow. + We'll create a point ``O`` to represent the musculotendon's origin and + another ``I`` to represent its insertion. + + >>> from sympy import symbols + >>> from sympy.physics.mechanics import (LinearPathway, Point, + ... ReferenceFrame, dynamicsymbols) + + >>> N = ReferenceFrame('N') + >>> O, I = O, P = symbols('O, I', cls=Point) + >>> q, u = dynamicsymbols('q, u', real=True) + >>> I.set_pos(O, q*N.x) + >>> O.set_vel(N, 0) + >>> I.set_vel(N, u*N.x) + >>> pathway = LinearPathway(O, I) + >>> pathway.attachments + (O, I) + >>> pathway.length + Abs(q(t)) + >>> pathway.extension_velocity + sign(q(t))*Derivative(q(t), t) + + A musculotendon also takes an instance of an activation dynamics model as + this will be used to provide symbols for the activation in the formulation + of the musculotendon dynamics. We'll use an instance of + ``FirstOrderActivationDeGroote2016`` to represent first-order activation + dynamics. Note that a single name argument needs to be provided as SymPy + will use this as a suffix. + + >>> from sympy.physics.biomechanics import FirstOrderActivationDeGroote2016 + + >>> activation = FirstOrderActivationDeGroote2016('muscle') + >>> activation.x + Matrix([[a_muscle(t)]]) + >>> activation.r + Matrix([[e_muscle(t)]]) + >>> activation.p + Matrix([ + [tau_a_muscle], + [tau_d_muscle], + [ b_muscle]]) + >>> activation.rhs() + Matrix([[((1/2 - tanh(b_muscle*(-a_muscle(t) + e_muscle(t)))/2)*(3*...]]) + + The musculotendon class requires symbols or values to be passed to represent + the constants in the musculotendon dynamics. We'll use SymPy's ``symbols`` + function to create symbols for the maximum isometric force ``F_M_max``, + optimal fiber length ``l_M_opt``, tendon slack length ``l_T_slack``, maximum + fiber velocity ``v_M_max``, optimal pennation angle ``alpha_opt, and fiber + damping coefficient ``beta``. + + >>> F_M_max = symbols('F_M_max', real=True) + >>> l_M_opt = symbols('l_M_opt', real=True) + >>> l_T_slack = symbols('l_T_slack', real=True) + >>> v_M_max = symbols('v_M_max', real=True) + >>> alpha_opt = symbols('alpha_opt', real=True) + >>> beta = symbols('beta', real=True) + + We can then import the class ``MusculotendonDeGroote2016`` from the + biomechanics module and create an instance by passing in the various objects + we have previously instantiated. By default, a musculotendon model with + rigid tendon musculotendon dynamics will be created. + + >>> from sympy.physics.biomechanics import MusculotendonDeGroote2016 + + >>> rigid_tendon_muscle = MusculotendonDeGroote2016( + ... 'muscle', + ... pathway, + ... activation, + ... tendon_slack_length=l_T_slack, + ... peak_isometric_force=F_M_max, + ... optimal_fiber_length=l_M_opt, + ... maximal_fiber_velocity=v_M_max, + ... optimal_pennation_angle=alpha_opt, + ... fiber_damping_coefficient=beta, + ... ) + + We can inspect the various properties of the musculotendon, including + getting the symbolic expression describing the force it produces using its + ``force`` attribute. + + >>> rigid_tendon_muscle.force + -F_M_max*(beta*(-l_T_slack + Abs(q(t)))*sign(q(t))*Derivative(q(t), t)... + + When we created the musculotendon object, we passed in an instance of an + activation dynamics object that governs the activation within the + musculotendon. SymPy makes a design choice here that the activation dynamics + instance will be treated as a child object of the musculotendon dynamics. + Therefore, if we want to inspect the state and input variables associated + with the musculotendon model, we will also be returned the state and input + variables associated with the child object, or the activation dynamics in + this case. As the musculotendon model that we created here uses rigid tendon + dynamics, no additional states or inputs relating to the musculotendon are + introduces. Consequently, the model has a single state associated with it, + the activation, and a single input associated with it, the excitation. The + states and inputs can be inspected using the ``x`` and ``r`` attributes + respectively. Note that both ``x`` and ``r`` have the alias attributes of + ``state_vars`` and ``input_vars``. + + >>> rigid_tendon_muscle.x + Matrix([[a_muscle(t)]]) + >>> rigid_tendon_muscle.r + Matrix([[e_muscle(t)]]) + + To see which constants are symbolic in the musculotendon model, we can use + the ``p`` or ``constants`` attribute. This returns a ``Matrix`` populated + by the constants that are represented by a ``Symbol`` rather than a numeric + value. + + >>> rigid_tendon_muscle.p + Matrix([ + [ l_T_slack], + [ F_M_max], + [ l_M_opt], + [ v_M_max], + [ alpha_opt], + [ beta], + [ tau_a_muscle], + [ tau_d_muscle], + [ b_muscle], + [ c_0_fl_T_muscle], + [ c_1_fl_T_muscle], + [ c_2_fl_T_muscle], + [ c_3_fl_T_muscle], + [ c_0_fl_M_pas_muscle], + [ c_1_fl_M_pas_muscle], + [ c_0_fl_M_act_muscle], + [ c_1_fl_M_act_muscle], + [ c_2_fl_M_act_muscle], + [ c_3_fl_M_act_muscle], + [ c_4_fl_M_act_muscle], + [ c_5_fl_M_act_muscle], + [ c_6_fl_M_act_muscle], + [ c_7_fl_M_act_muscle], + [ c_8_fl_M_act_muscle], + [ c_9_fl_M_act_muscle], + [c_10_fl_M_act_muscle], + [c_11_fl_M_act_muscle], + [ c_0_fv_M_muscle], + [ c_1_fv_M_muscle], + [ c_2_fv_M_muscle], + [ c_3_fv_M_muscle]]) + + Finally, we can call the ``rhs`` method to return a ``Matrix`` that + contains as its elements the righthand side of the ordinary differential + equations corresponding to each of the musculotendon's states. Like the + method with the same name on the ``Method`` classes in SymPy's mechanics + module, this returns a column vector where the number of rows corresponds to + the number of states. For our example here, we have a single state, the + dynamic symbol ``a_muscle(t)``, so the returned value is a 1-by-1 + ``Matrix``. + + >>> rigid_tendon_muscle.rhs() + Matrix([[((1/2 - tanh(b_muscle*(-a_muscle(t) + e_muscle(t)))/2)*(3*...]]) + + The musculotendon class supports elastic tendon musculotendon models in + addition to rigid tendon ones. You can choose to either use the fiber length + or tendon force as an additional state. You can also specify whether an + explicit or implicit formulation should be used. To select a formulation, + pass a member of the ``MusculotendonFormulation`` enumeration to the + ``musculotendon_dynamics`` parameter when calling the constructor. This + enumeration is an ``IntEnum``, so you can also pass an integer, however it + is recommended to use the enumeration as it is clearer which formulation you + are actually selecting. Below, we'll use the ``FIBER_LENGTH_EXPLICIT`` + member to create a musculotendon with an elastic tendon that will use the + (normalized) muscle fiber length as an additional state and will produce + the governing ordinary differential equation in explicit form. + + >>> from sympy.physics.biomechanics import MusculotendonFormulation + + >>> elastic_tendon_muscle = MusculotendonDeGroote2016( + ... 'muscle', + ... pathway, + ... activation, + ... musculotendon_dynamics=MusculotendonFormulation.FIBER_LENGTH_EXPLICIT, + ... tendon_slack_length=l_T_slack, + ... peak_isometric_force=F_M_max, + ... optimal_fiber_length=l_M_opt, + ... maximal_fiber_velocity=v_M_max, + ... optimal_pennation_angle=alpha_opt, + ... fiber_damping_coefficient=beta, + ... ) + + >>> elastic_tendon_muscle.force + -F_M_max*TendonForceLengthDeGroote2016((-sqrt(l_M_opt**2*... + >>> elastic_tendon_muscle.x + Matrix([ + [l_M_tilde_muscle(t)], + [ a_muscle(t)]]) + >>> elastic_tendon_muscle.r + Matrix([[e_muscle(t)]]) + >>> elastic_tendon_muscle.p + Matrix([ + [ l_T_slack], + [ F_M_max], + [ l_M_opt], + [ v_M_max], + [ alpha_opt], + [ beta], + [ tau_a_muscle], + [ tau_d_muscle], + [ b_muscle], + [ c_0_fl_T_muscle], + [ c_1_fl_T_muscle], + [ c_2_fl_T_muscle], + [ c_3_fl_T_muscle], + [ c_0_fl_M_pas_muscle], + [ c_1_fl_M_pas_muscle], + [ c_0_fl_M_act_muscle], + [ c_1_fl_M_act_muscle], + [ c_2_fl_M_act_muscle], + [ c_3_fl_M_act_muscle], + [ c_4_fl_M_act_muscle], + [ c_5_fl_M_act_muscle], + [ c_6_fl_M_act_muscle], + [ c_7_fl_M_act_muscle], + [ c_8_fl_M_act_muscle], + [ c_9_fl_M_act_muscle], + [c_10_fl_M_act_muscle], + [c_11_fl_M_act_muscle], + [ c_0_fv_M_muscle], + [ c_1_fv_M_muscle], + [ c_2_fv_M_muscle], + [ c_3_fv_M_muscle]]) + >>> elastic_tendon_muscle.rhs() + Matrix([ + [v_M_max*FiberForceVelocityInverseDeGroote2016((l_M_opt*...], + [ ((1/2 - tanh(b_muscle*(-a_muscle(t) + e_muscle(t)))/2)*(3*...]]) + + It is strongly recommended to use the alternate ``with_defaults`` + constructor when creating an instance because this will ensure that the + published constants are used in the musculotendon characteristic curves. + + >>> elastic_tendon_muscle = MusculotendonDeGroote2016.with_defaults( + ... 'muscle', + ... pathway, + ... activation, + ... musculotendon_dynamics=MusculotendonFormulation.FIBER_LENGTH_EXPLICIT, + ... tendon_slack_length=l_T_slack, + ... peak_isometric_force=F_M_max, + ... optimal_fiber_length=l_M_opt, + ... ) + + >>> elastic_tendon_muscle.x + Matrix([ + [l_M_tilde_muscle(t)], + [ a_muscle(t)]]) + >>> elastic_tendon_muscle.r + Matrix([[e_muscle(t)]]) + >>> elastic_tendon_muscle.p + Matrix([ + [ l_T_slack], + [ F_M_max], + [ l_M_opt], + [tau_a_muscle], + [tau_d_muscle], + [ b_muscle]]) + + Parameters + ========== + + name : str + The name identifier associated with the musculotendon. This name is used + as a suffix when automatically generated symbols are instantiated. It + must be a string of nonzero length. + pathway : PathwayBase + The pathway that the actuator follows. This must be an instance of a + concrete subclass of ``PathwayBase``, e.g. ``LinearPathway``. + activation_dynamics : ActivationBase + The activation dynamics that will be modeled within the musculotendon. + This must be an instance of a concrete subclass of ``ActivationBase``, + e.g. ``FirstOrderActivationDeGroote2016``. + musculotendon_dynamics : MusculotendonFormulation | int + The formulation of musculotendon dynamics that should be used + internally, i.e. rigid or elastic tendon model, the choice of + musculotendon state etc. This must be a member of the integer + enumeration ``MusculotendonFormulation`` or an integer that can be cast + to a member. To use a rigid tendon formulation, set this to + ``MusculotendonFormulation.RIGID_TENDON`` (or the integer value ``0``, + which will be cast to the enumeration member). There are four possible + formulations for an elastic tendon model. To use an explicit formulation + with the fiber length as the state, set this to + ``MusculotendonFormulation.FIBER_LENGTH_EXPLICIT`` (or the integer value + ``1``). To use an explicit formulation with the tendon force as the + state, set this to ``MusculotendonFormulation.TENDON_FORCE_EXPLICIT`` + (or the integer value ``2``). To use an implicit formulation with the + fiber length as the state, set this to + ``MusculotendonFormulation.FIBER_LENGTH_IMPLICIT`` (or the integer value + ``3``). To use an implicit formulation with the tendon force as the + state, set this to ``MusculotendonFormulation.TENDON_FORCE_IMPLICIT`` + (or the integer value ``4``). The default is + ``MusculotendonFormulation.RIGID_TENDON``, which corresponds to a rigid + tendon formulation. + tendon_slack_length : Expr | None + The length of the tendon when the musculotendon is in its unloaded + state. In a rigid tendon model the tendon length is the tendon slack + length. In all musculotendon models, tendon slack length is used to + normalize tendon length to give + :math:`\tilde{l}^T = \frac{l^T}{l^T_{slack}}`. + peak_isometric_force : Expr | None + The maximum force that the muscle fiber can produce when it is + undergoing an isometric contraction (no lengthening velocity). In all + musculotendon models, peak isometric force is used to normalized tendon + and muscle fiber force to give + :math:`\tilde{F}^T = \frac{F^T}{F^M_{max}}`. + optimal_fiber_length : Expr | None + The muscle fiber length at which the muscle fibers produce no passive + force and their maximum active force. In all musculotendon models, + optimal fiber length is used to normalize muscle fiber length to give + :math:`\tilde{l}^M = \frac{l^M}{l^M_{opt}}`. + maximal_fiber_velocity : Expr | None + The fiber velocity at which, during muscle fiber shortening, the muscle + fibers are unable to produce any active force. In all musculotendon + models, maximal fiber velocity is used to normalize muscle fiber + extension velocity to give :math:`\tilde{v}^M = \frac{v^M}{v^M_{max}}`. + optimal_pennation_angle : Expr | None + The pennation angle when muscle fiber length equals the optimal fiber + length. + fiber_damping_coefficient : Expr | None + The coefficient of damping to be used in the damping element in the + muscle fiber model. + with_defaults : bool + Whether ``with_defaults`` alternate constructors should be used when + automatically constructing child classes. Default is ``False``. + + References + ========== + + .. [1] De Groote, F., Kinney, A. L., Rao, A. V., & Fregly, B. J., Evaluation + of direct collocation optimal control problem formulations for + solving the muscle redundancy problem, Annals of biomedical + engineering, 44(10), (2016) pp. 2922-2936 + + """ + + curves = CharacteristicCurveCollection( + tendon_force_length=TendonForceLengthDeGroote2016, + tendon_force_length_inverse=TendonForceLengthInverseDeGroote2016, + fiber_force_length_passive=FiberForceLengthPassiveDeGroote2016, + fiber_force_length_passive_inverse=FiberForceLengthPassiveInverseDeGroote2016, + fiber_force_length_active=FiberForceLengthActiveDeGroote2016, + fiber_force_velocity=FiberForceVelocityDeGroote2016, + fiber_force_velocity_inverse=FiberForceVelocityInverseDeGroote2016, + ) diff --git a/lib/python3.10/site-packages/sympy/physics/continuum_mechanics/__init__.py b/lib/python3.10/site-packages/sympy/physics/continuum_mechanics/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b1c040fe7d1f66dc4ef2dc18061d0744f08d5258 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/physics/continuum_mechanics/__init__.py @@ -0,0 +1,6 @@ +__all__ = ['Beam', + 'Truss', 'Cable'] + +from .beam import Beam +from .truss import Truss +from .cable import Cable diff --git a/lib/python3.10/site-packages/sympy/physics/continuum_mechanics/beam.py b/lib/python3.10/site-packages/sympy/physics/continuum_mechanics/beam.py new file mode 100644 index 0000000000000000000000000000000000000000..b89474a6b411e359789eacb8047f9845d19e5393 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/physics/continuum_mechanics/beam.py @@ -0,0 +1,3732 @@ +""" +This module can be used to solve 2D beam bending problems with +singularity functions in mechanics. +""" + +from sympy.core import S, Symbol, diff, symbols +from sympy.core.add import Add +from sympy.core.expr import Expr +from sympy.core.function import (Derivative, Function) +from sympy.core.mul import Mul +from sympy.core.relational import Eq +from sympy.core.sympify import sympify +from sympy.solvers import linsolve +from sympy.solvers.ode.ode import dsolve +from sympy.solvers.solvers import solve +from sympy.printing import sstr +from sympy.functions import SingularityFunction, Piecewise, factorial +from sympy.integrals import integrate +from sympy.series import limit +from sympy.plotting import plot, PlotGrid +from sympy.geometry.entity import GeometryEntity +from sympy.external import import_module +from sympy.sets.sets import Interval +from sympy.utilities.lambdify import lambdify +from sympy.utilities.decorator import doctest_depends_on +from sympy.utilities.iterables import iterable +import warnings + + +__doctest_requires__ = { + ('Beam.draw', + 'Beam.plot_bending_moment', + 'Beam.plot_deflection', + 'Beam.plot_ild_moment', + 'Beam.plot_ild_shear', + 'Beam.plot_shear_force', + 'Beam.plot_shear_stress', + 'Beam.plot_slope'): ['matplotlib'], +} + + +numpy = import_module('numpy', import_kwargs={'fromlist':['arange']}) + + +class Beam: + """ + A Beam is a structural element that is capable of withstanding load + primarily by resisting against bending. Beams are characterized by + their cross sectional profile(Second moment of area), their length + and their material. + + .. note:: + A consistent sign convention must be used while solving a beam + bending problem; the results will + automatically follow the chosen sign convention. However, the + chosen sign convention must respect the rule that, on the positive + side of beam's axis (in respect to current section), a loading force + giving positive shear yields a negative moment, as below (the + curved arrow shows the positive moment and rotation): + + .. image:: allowed-sign-conventions.png + + Examples + ======== + There is a beam of length 4 meters. A constant distributed load of 6 N/m + is applied from half of the beam till the end. There are two simple supports + below the beam, one at the starting point and another at the ending point + of the beam. The deflection of the beam at the end is restricted. + + Using the sign convention of downwards forces being positive. + + >>> from sympy.physics.continuum_mechanics.beam import Beam + >>> from sympy import symbols, Piecewise + >>> E, I = symbols('E, I') + >>> R1, R2 = symbols('R1, R2') + >>> b = Beam(4, E, I) + >>> b.apply_load(R1, 0, -1) + >>> b.apply_load(6, 2, 0) + >>> b.apply_load(R2, 4, -1) + >>> b.bc_deflection = [(0, 0), (4, 0)] + >>> b.boundary_conditions + {'deflection': [(0, 0), (4, 0)], 'slope': []} + >>> b.load + R1*SingularityFunction(x, 0, -1) + R2*SingularityFunction(x, 4, -1) + 6*SingularityFunction(x, 2, 0) + >>> b.solve_for_reaction_loads(R1, R2) + >>> b.load + -3*SingularityFunction(x, 0, -1) + 6*SingularityFunction(x, 2, 0) - 9*SingularityFunction(x, 4, -1) + >>> b.shear_force() + 3*SingularityFunction(x, 0, 0) - 6*SingularityFunction(x, 2, 1) + 9*SingularityFunction(x, 4, 0) + >>> b.bending_moment() + 3*SingularityFunction(x, 0, 1) - 3*SingularityFunction(x, 2, 2) + 9*SingularityFunction(x, 4, 1) + >>> b.slope() + (-3*SingularityFunction(x, 0, 2)/2 + SingularityFunction(x, 2, 3) - 9*SingularityFunction(x, 4, 2)/2 + 7)/(E*I) + >>> b.deflection() + (7*x - SingularityFunction(x, 0, 3)/2 + SingularityFunction(x, 2, 4)/4 - 3*SingularityFunction(x, 4, 3)/2)/(E*I) + >>> b.deflection().rewrite(Piecewise) + (7*x - Piecewise((x**3, x >= 0), (0, True))/2 + - 3*Piecewise(((x - 4)**3, x >= 4), (0, True))/2 + + Piecewise(((x - 2)**4, x >= 2), (0, True))/4)/(E*I) + + Calculate the support reactions for a fully symbolic beam of length L. + There are two simple supports below the beam, one at the starting point + and another at the ending point of the beam. The deflection of the beam + at the end is restricted. The beam is loaded with: + + * a downward point load P1 applied at L/4 + * an upward point load P2 applied at L/8 + * a counterclockwise moment M1 applied at L/2 + * a clockwise moment M2 applied at 3*L/4 + * a distributed constant load q1, applied downward, starting from L/2 + up to 3*L/4 + * a distributed constant load q2, applied upward, starting from 3*L/4 + up to L + + No assumptions are needed for symbolic loads. However, defining a positive + length will help the algorithm to compute the solution. + + >>> E, I = symbols('E, I') + >>> L = symbols("L", positive=True) + >>> P1, P2, M1, M2, q1, q2 = symbols("P1, P2, M1, M2, q1, q2") + >>> R1, R2 = symbols('R1, R2') + >>> b = Beam(L, E, I) + >>> b.apply_load(R1, 0, -1) + >>> b.apply_load(R2, L, -1) + >>> b.apply_load(P1, L/4, -1) + >>> b.apply_load(-P2, L/8, -1) + >>> b.apply_load(M1, L/2, -2) + >>> b.apply_load(-M2, 3*L/4, -2) + >>> b.apply_load(q1, L/2, 0, 3*L/4) + >>> b.apply_load(-q2, 3*L/4, 0, L) + >>> b.bc_deflection = [(0, 0), (L, 0)] + >>> b.solve_for_reaction_loads(R1, R2) + >>> print(b.reaction_loads[R1]) + (-3*L**2*q1 + L**2*q2 - 24*L*P1 + 28*L*P2 - 32*M1 + 32*M2)/(32*L) + >>> print(b.reaction_loads[R2]) + (-5*L**2*q1 + 7*L**2*q2 - 8*L*P1 + 4*L*P2 + 32*M1 - 32*M2)/(32*L) + """ + + def __init__(self, length, elastic_modulus, second_moment, area=Symbol('A'), variable=Symbol('x'), base_char='C'): + """Initializes the class. + + Parameters + ========== + + length : Sympifyable + A Symbol or value representing the Beam's length. + + elastic_modulus : Sympifyable + A SymPy expression representing the Beam's Modulus of Elasticity. + It is a measure of the stiffness of the Beam material. It can + also be a continuous function of position along the beam. + + second_moment : Sympifyable or Geometry object + Describes the cross-section of the beam via a SymPy expression + representing the Beam's second moment of area. It is a geometrical + property of an area which reflects how its points are distributed + with respect to its neutral axis. It can also be a continuous + function of position along the beam. Alternatively ``second_moment`` + can be a shape object such as a ``Polygon`` from the geometry module + representing the shape of the cross-section of the beam. In such cases, + it is assumed that the x-axis of the shape object is aligned with the + bending axis of the beam. The second moment of area will be computed + from the shape object internally. + + area : Symbol/float + Represents the cross-section area of beam + + variable : Symbol, optional + A Symbol object that will be used as the variable along the beam + while representing the load, shear, moment, slope and deflection + curve. By default, it is set to ``Symbol('x')``. + + base_char : String, optional + A String that will be used as base character to generate sequential + symbols for integration constants in cases where boundary conditions + are not sufficient to solve them. + """ + self.length = length + self.elastic_modulus = elastic_modulus + if isinstance(second_moment, GeometryEntity): + self.cross_section = second_moment + else: + self.cross_section = None + self.second_moment = second_moment + self.variable = variable + self._base_char = base_char + self._boundary_conditions = {'deflection': [], 'slope': []} + self._load = 0 + self.area = area + self._applied_supports = [] + self._support_as_loads = [] + self._applied_loads = [] + self._reaction_loads = {} + self._ild_reactions = {} + self._ild_shear = 0 + self._ild_moment = 0 + # _original_load is a copy of _load equations with unsubstituted reaction + # forces. It is used for calculating reaction forces in case of I.L.D. + self._original_load = 0 + self._composite_type = None + self._hinge_position = None + + def __str__(self): + shape_description = self._cross_section if self._cross_section else self._second_moment + str_sol = 'Beam({}, {}, {})'.format(sstr(self._length), sstr(self._elastic_modulus), sstr(shape_description)) + return str_sol + + @property + def reaction_loads(self): + """ Returns the reaction forces in a dictionary.""" + return self._reaction_loads + + @property + def ild_shear(self): + """ Returns the I.L.D. shear equation.""" + return self._ild_shear + + @property + def ild_reactions(self): + """ Returns the I.L.D. reaction forces in a dictionary.""" + return self._ild_reactions + + @property + def ild_moment(self): + """ Returns the I.L.D. moment equation.""" + return self._ild_moment + + @property + def length(self): + """Length of the Beam.""" + return self._length + + @length.setter + def length(self, l): + self._length = sympify(l) + + @property + def area(self): + """Cross-sectional area of the Beam. """ + return self._area + + @area.setter + def area(self, a): + self._area = sympify(a) + + @property + def variable(self): + """ + A symbol that can be used as a variable along the length of the beam + while representing load distribution, shear force curve, bending + moment, slope curve and the deflection curve. By default, it is set + to ``Symbol('x')``, but this property is mutable. + + Examples + ======== + + >>> from sympy.physics.continuum_mechanics.beam import Beam + >>> from sympy import symbols + >>> E, I, A = symbols('E, I, A') + >>> x, y, z = symbols('x, y, z') + >>> b = Beam(4, E, I) + >>> b.variable + x + >>> b.variable = y + >>> b.variable + y + >>> b = Beam(4, E, I, A, z) + >>> b.variable + z + """ + return self._variable + + @variable.setter + def variable(self, v): + if isinstance(v, Symbol): + self._variable = v + else: + raise TypeError("""The variable should be a Symbol object.""") + + @property + def elastic_modulus(self): + """Young's Modulus of the Beam. """ + return self._elastic_modulus + + @elastic_modulus.setter + def elastic_modulus(self, e): + self._elastic_modulus = sympify(e) + + @property + def second_moment(self): + """Second moment of area of the Beam. """ + return self._second_moment + + @second_moment.setter + def second_moment(self, i): + self._cross_section = None + if isinstance(i, GeometryEntity): + raise ValueError("To update cross-section geometry use `cross_section` attribute") + else: + self._second_moment = sympify(i) + + @property + def cross_section(self): + """Cross-section of the beam""" + return self._cross_section + + @cross_section.setter + def cross_section(self, s): + if s: + self._second_moment = s.second_moment_of_area()[0] + self._cross_section = s + + @property + def boundary_conditions(self): + """ + Returns a dictionary of boundary conditions applied on the beam. + The dictionary has three keywords namely moment, slope and deflection. + The value of each keyword is a list of tuple, where each tuple + contains location and value of a boundary condition in the format + (location, value). + + Examples + ======== + There is a beam of length 4 meters. The bending moment at 0 should be 4 + and at 4 it should be 0. The slope of the beam should be 1 at 0. The + deflection should be 2 at 0. + + >>> from sympy.physics.continuum_mechanics.beam import Beam + >>> from sympy import symbols + >>> E, I = symbols('E, I') + >>> b = Beam(4, E, I) + >>> b.bc_deflection = [(0, 2)] + >>> b.bc_slope = [(0, 1)] + >>> b.boundary_conditions + {'deflection': [(0, 2)], 'slope': [(0, 1)]} + + Here the deflection of the beam should be ``2`` at ``0``. + Similarly, the slope of the beam should be ``1`` at ``0``. + """ + return self._boundary_conditions + + @property + def bc_slope(self): + return self._boundary_conditions['slope'] + + @bc_slope.setter + def bc_slope(self, s_bcs): + self._boundary_conditions['slope'] = s_bcs + + @property + def bc_deflection(self): + return self._boundary_conditions['deflection'] + + @bc_deflection.setter + def bc_deflection(self, d_bcs): + self._boundary_conditions['deflection'] = d_bcs + + def join(self, beam, via="fixed"): + """ + This method joins two beams to make a new composite beam system. + Passed Beam class instance is attached to the right end of calling + object. This method can be used to form beams having Discontinuous + values of Elastic modulus or Second moment. + + Parameters + ========== + beam : Beam class object + The Beam object which would be connected to the right of calling + object. + via : String + States the way two Beam object would get connected + - For axially fixed Beams, via="fixed" + - For Beams connected via hinge, via="hinge" + + Examples + ======== + There is a cantilever beam of length 4 meters. For first 2 meters + its moment of inertia is `1.5*I` and `I` for the other end. + A pointload of magnitude 4 N is applied from the top at its free end. + + >>> from sympy.physics.continuum_mechanics.beam import Beam + >>> from sympy import symbols + >>> E, I = symbols('E, I') + >>> R1, R2 = symbols('R1, R2') + >>> b1 = Beam(2, E, 1.5*I) + >>> b2 = Beam(2, E, I) + >>> b = b1.join(b2, "fixed") + >>> b.apply_load(20, 4, -1) + >>> b.apply_load(R1, 0, -1) + >>> b.apply_load(R2, 0, -2) + >>> b.bc_slope = [(0, 0)] + >>> b.bc_deflection = [(0, 0)] + >>> b.solve_for_reaction_loads(R1, R2) + >>> b.load + 80*SingularityFunction(x, 0, -2) - 20*SingularityFunction(x, 0, -1) + 20*SingularityFunction(x, 4, -1) + >>> b.slope() + (-((-80*SingularityFunction(x, 0, 1) + 10*SingularityFunction(x, 0, 2) - 10*SingularityFunction(x, 4, 2))/I + 120/I)/E + 80.0/(E*I))*SingularityFunction(x, 2, 0) + - 0.666666666666667*(-80*SingularityFunction(x, 0, 1) + 10*SingularityFunction(x, 0, 2) - 10*SingularityFunction(x, 4, 2))*SingularityFunction(x, 0, 0)/(E*I) + + 0.666666666666667*(-80*SingularityFunction(x, 0, 1) + 10*SingularityFunction(x, 0, 2) - 10*SingularityFunction(x, 4, 2))*SingularityFunction(x, 2, 0)/(E*I) + """ + x = self.variable + E = self.elastic_modulus + new_length = self.length + beam.length + if self.second_moment != beam.second_moment: + new_second_moment = Piecewise((self.second_moment, x<=self.length), + (beam.second_moment, x<=new_length)) + else: + new_second_moment = self.second_moment + + if via == "fixed": + new_beam = Beam(new_length, E, new_second_moment, x) + new_beam._composite_type = "fixed" + return new_beam + + if via == "hinge": + new_beam = Beam(new_length, E, new_second_moment, x) + new_beam._composite_type = "hinge" + new_beam._hinge_position = self.length + return new_beam + + def apply_support(self, loc, type="fixed"): + """ + This method applies support to a particular beam object and returns + the symbol of the unknown reaction load(s). + + Parameters + ========== + loc : Sympifyable + Location of point at which support is applied. + type : String + Determines type of Beam support applied. To apply support structure + with + - zero degree of freedom, type = "fixed" + - one degree of freedom, type = "pin" + - two degrees of freedom, type = "roller" + + Returns + ======= + Symbol or tuple of Symbol + The unknown reaction load as a symbol. + - Symbol(reaction_force) if type = "pin" or "roller" + - Symbol(reaction_force), Symbol(reaction_moment) if type = "fixed" + + Examples + ======== + There is a beam of length 20 meters. A moment of magnitude 100 Nm is + applied in the clockwise direction at the end of the beam. A pointload + of magnitude 8 N is applied from the top of the beam at a distance of 10 meters. + There is one fixed support at the start of the beam and a roller at the end. + + Using the sign convention of upward forces and clockwise moment + being positive. + + >>> from sympy.physics.continuum_mechanics.beam import Beam + >>> from sympy import symbols + >>> E, I = symbols('E, I') + >>> b = Beam(20, E, I) + >>> p0, m0 = b.apply_support(0, 'fixed') + >>> p1 = b.apply_support(20, 'roller') + >>> b.apply_load(-8, 10, -1) + >>> b.apply_load(100, 20, -2) + >>> b.solve_for_reaction_loads(p0, m0, p1) + >>> b.reaction_loads + {M_0: 20, R_0: -2, R_20: 10} + >>> b.reaction_loads[p0] + -2 + >>> b.load + 20*SingularityFunction(x, 0, -2) - 2*SingularityFunction(x, 0, -1) + - 8*SingularityFunction(x, 10, -1) + 100*SingularityFunction(x, 20, -2) + + 10*SingularityFunction(x, 20, -1) + """ + loc = sympify(loc) + self._applied_supports.append((loc, type)) + if type in ("pin", "roller"): + reaction_load = Symbol('R_'+str(loc)) + self.apply_load(reaction_load, loc, -1) + self.bc_deflection.append((loc, 0)) + else: + reaction_load = Symbol('R_'+str(loc)) + reaction_moment = Symbol('M_'+str(loc)) + self.apply_load(reaction_load, loc, -1) + self.apply_load(reaction_moment, loc, -2) + self.bc_deflection.append((loc, 0)) + self.bc_slope.append((loc, 0)) + self._support_as_loads.append((reaction_moment, loc, -2, None)) + + self._support_as_loads.append((reaction_load, loc, -1, None)) + + if type in ("pin", "roller"): + return reaction_load + else: + return reaction_load, reaction_moment + + def apply_load(self, value, start, order, end=None): + """ + This method adds up the loads given to a particular beam object. + + Parameters + ========== + value : Sympifyable + The value inserted should have the units [Force/(Distance**(n+1)] + where n is the order of applied load. + Units for applied loads: + + - For moments, unit = kN*m + - For point loads, unit = kN + - For constant distributed load, unit = kN/m + - For ramp loads, unit = kN/m/m + - For parabolic ramp loads, unit = kN/m/m/m + - ... so on. + + start : Sympifyable + The starting point of the applied load. For point moments and + point forces this is the location of application. + order : Integer + The order of the applied load. + + - For moments, order = -2 + - For point loads, order =-1 + - For constant distributed load, order = 0 + - For ramp loads, order = 1 + - For parabolic ramp loads, order = 2 + - ... so on. + + end : Sympifyable, optional + An optional argument that can be used if the load has an end point + within the length of the beam. + + Examples + ======== + There is a beam of length 4 meters. A moment of magnitude 3 Nm is + applied in the clockwise direction at the starting point of the beam. + A point load of magnitude 4 N is applied from the top of the beam at + 2 meters from the starting point and a parabolic ramp load of magnitude + 2 N/m is applied below the beam starting from 2 meters to 3 meters + away from the starting point of the beam. + + >>> from sympy.physics.continuum_mechanics.beam import Beam + >>> from sympy import symbols + >>> E, I = symbols('E, I') + >>> b = Beam(4, E, I) + >>> b.apply_load(-3, 0, -2) + >>> b.apply_load(4, 2, -1) + >>> b.apply_load(-2, 2, 2, end=3) + >>> b.load + -3*SingularityFunction(x, 0, -2) + 4*SingularityFunction(x, 2, -1) - 2*SingularityFunction(x, 2, 2) + 2*SingularityFunction(x, 3, 0) + 4*SingularityFunction(x, 3, 1) + 2*SingularityFunction(x, 3, 2) + + """ + x = self.variable + value = sympify(value) + start = sympify(start) + order = sympify(order) + + self._applied_loads.append((value, start, order, end)) + self._load += value*SingularityFunction(x, start, order) + self._original_load += value*SingularityFunction(x, start, order) + + if end: + # load has an end point within the length of the beam. + self._handle_end(x, value, start, order, end, type="apply") + + def remove_load(self, value, start, order, end=None): + """ + This method removes a particular load present on the beam object. + Returns a ValueError if the load passed as an argument is not + present on the beam. + + Parameters + ========== + value : Sympifyable + The magnitude of an applied load. + start : Sympifyable + The starting point of the applied load. For point moments and + point forces this is the location of application. + order : Integer + The order of the applied load. + - For moments, order= -2 + - For point loads, order=-1 + - For constant distributed load, order=0 + - For ramp loads, order=1 + - For parabolic ramp loads, order=2 + - ... so on. + end : Sympifyable, optional + An optional argument that can be used if the load has an end point + within the length of the beam. + + Examples + ======== + There is a beam of length 4 meters. A moment of magnitude 3 Nm is + applied in the clockwise direction at the starting point of the beam. + A pointload of magnitude 4 N is applied from the top of the beam at + 2 meters from the starting point and a parabolic ramp load of magnitude + 2 N/m is applied below the beam starting from 2 meters to 3 meters + away from the starting point of the beam. + + >>> from sympy.physics.continuum_mechanics.beam import Beam + >>> from sympy import symbols + >>> E, I = symbols('E, I') + >>> b = Beam(4, E, I) + >>> b.apply_load(-3, 0, -2) + >>> b.apply_load(4, 2, -1) + >>> b.apply_load(-2, 2, 2, end=3) + >>> b.load + -3*SingularityFunction(x, 0, -2) + 4*SingularityFunction(x, 2, -1) - 2*SingularityFunction(x, 2, 2) + 2*SingularityFunction(x, 3, 0) + 4*SingularityFunction(x, 3, 1) + 2*SingularityFunction(x, 3, 2) + >>> b.remove_load(-2, 2, 2, end = 3) + >>> b.load + -3*SingularityFunction(x, 0, -2) + 4*SingularityFunction(x, 2, -1) + """ + x = self.variable + value = sympify(value) + start = sympify(start) + order = sympify(order) + + if (value, start, order, end) in self._applied_loads: + self._load -= value*SingularityFunction(x, start, order) + self._original_load -= value*SingularityFunction(x, start, order) + self._applied_loads.remove((value, start, order, end)) + else: + msg = "No such load distribution exists on the beam object." + raise ValueError(msg) + + if end: + # load has an end point within the length of the beam. + self._handle_end(x, value, start, order, end, type="remove") + + def _handle_end(self, x, value, start, order, end, type): + """ + This functions handles the optional `end` value in the + `apply_load` and `remove_load` functions. When the value + of end is not NULL, this function will be executed. + """ + if order.is_negative: + msg = ("If 'end' is provided the 'order' of the load cannot " + "be negative, i.e. 'end' is only valid for distributed " + "loads.") + raise ValueError(msg) + # NOTE : A Taylor series can be used to define the summation of + # singularity functions that subtract from the load past the end + # point such that it evaluates to zero past 'end'. + f = value*x**order + + if type == "apply": + # iterating for "apply_load" method + for i in range(0, order + 1): + self._load -= (f.diff(x, i).subs(x, end - start) * + SingularityFunction(x, end, i)/factorial(i)) + self._original_load -= (f.diff(x, i).subs(x, end - start) * + SingularityFunction(x, end, i)/factorial(i)) + elif type == "remove": + # iterating for "remove_load" method + for i in range(0, order + 1): + self._load += (f.diff(x, i).subs(x, end - start) * + SingularityFunction(x, end, i)/factorial(i)) + self._original_load += (f.diff(x, i).subs(x, end - start) * + SingularityFunction(x, end, i)/factorial(i)) + + + @property + def load(self): + """ + Returns a Singularity Function expression which represents + the load distribution curve of the Beam object. + + Examples + ======== + There is a beam of length 4 meters. A moment of magnitude 3 Nm is + applied in the clockwise direction at the starting point of the beam. + A point load of magnitude 4 N is applied from the top of the beam at + 2 meters from the starting point and a parabolic ramp load of magnitude + 2 N/m is applied below the beam starting from 3 meters away from the + starting point of the beam. + + >>> from sympy.physics.continuum_mechanics.beam import Beam + >>> from sympy import symbols + >>> E, I = symbols('E, I') + >>> b = Beam(4, E, I) + >>> b.apply_load(-3, 0, -2) + >>> b.apply_load(4, 2, -1) + >>> b.apply_load(-2, 3, 2) + >>> b.load + -3*SingularityFunction(x, 0, -2) + 4*SingularityFunction(x, 2, -1) - 2*SingularityFunction(x, 3, 2) + """ + return self._load + + @property + def applied_loads(self): + """ + Returns a list of all loads applied on the beam object. + Each load in the list is a tuple of form (value, start, order, end). + + Examples + ======== + There is a beam of length 4 meters. A moment of magnitude 3 Nm is + applied in the clockwise direction at the starting point of the beam. + A pointload of magnitude 4 N is applied from the top of the beam at + 2 meters from the starting point. Another pointload of magnitude 5 N + is applied at same position. + + >>> from sympy.physics.continuum_mechanics.beam import Beam + >>> from sympy import symbols + >>> E, I = symbols('E, I') + >>> b = Beam(4, E, I) + >>> b.apply_load(-3, 0, -2) + >>> b.apply_load(4, 2, -1) + >>> b.apply_load(5, 2, -1) + >>> b.load + -3*SingularityFunction(x, 0, -2) + 9*SingularityFunction(x, 2, -1) + >>> b.applied_loads + [(-3, 0, -2, None), (4, 2, -1, None), (5, 2, -1, None)] + """ + return self._applied_loads + + def _solve_hinge_beams(self, *reactions): + """Method to find integration constants and reactional variables in a + composite beam connected via hinge. + This method resolves the composite Beam into its sub-beams and then + equations of shear force, bending moment, slope and deflection are + evaluated for both of them separately. These equations are then solved + for unknown reactions and integration constants using the boundary + conditions applied on the Beam. Equal deflection of both sub-beams + at the hinge joint gives us another equation to solve the system. + + Examples + ======== + A combined beam, with constant fkexural rigidity E*I, is formed by joining + a Beam of length 2*l to the right of another Beam of length l. The whole beam + is fixed at both of its both end. A point load of magnitude P is also applied + from the top at a distance of 2*l from starting point. + + >>> from sympy.physics.continuum_mechanics.beam import Beam + >>> from sympy import symbols + >>> E, I = symbols('E, I') + >>> l=symbols('l', positive=True) + >>> b1=Beam(l, E, I) + >>> b2=Beam(2*l, E, I) + >>> b=b1.join(b2,"hinge") + >>> M1, A1, M2, A2, P = symbols('M1 A1 M2 A2 P') + >>> b.apply_load(A1,0,-1) + >>> b.apply_load(M1,0,-2) + >>> b.apply_load(P,2*l,-1) + >>> b.apply_load(A2,3*l,-1) + >>> b.apply_load(M2,3*l,-2) + >>> b.bc_slope=[(0,0), (3*l, 0)] + >>> b.bc_deflection=[(0,0), (3*l, 0)] + >>> b.solve_for_reaction_loads(M1, A1, M2, A2) + >>> b.reaction_loads + {A1: -5*P/18, A2: -13*P/18, M1: 5*P*l/18, M2: -4*P*l/9} + >>> b.slope() + (5*P*l*SingularityFunction(x, 0, 1)/18 - 5*P*SingularityFunction(x, 0, 2)/36 + 5*P*SingularityFunction(x, l, 2)/36)*SingularityFunction(x, 0, 0)/(E*I) + - (5*P*l*SingularityFunction(x, 0, 1)/18 - 5*P*SingularityFunction(x, 0, 2)/36 + 5*P*SingularityFunction(x, l, 2)/36)*SingularityFunction(x, l, 0)/(E*I) + + (P*l**2/18 - 4*P*l*SingularityFunction(-l + x, 2*l, 1)/9 - 5*P*SingularityFunction(-l + x, 0, 2)/36 + P*SingularityFunction(-l + x, l, 2)/2 + - 13*P*SingularityFunction(-l + x, 2*l, 2)/36)*SingularityFunction(x, l, 0)/(E*I) + >>> b.deflection() + (5*P*l*SingularityFunction(x, 0, 2)/36 - 5*P*SingularityFunction(x, 0, 3)/108 + 5*P*SingularityFunction(x, l, 3)/108)*SingularityFunction(x, 0, 0)/(E*I) + - (5*P*l*SingularityFunction(x, 0, 2)/36 - 5*P*SingularityFunction(x, 0, 3)/108 + 5*P*SingularityFunction(x, l, 3)/108)*SingularityFunction(x, l, 0)/(E*I) + + (5*P*l**3/54 + P*l**2*(-l + x)/18 - 2*P*l*SingularityFunction(-l + x, 2*l, 2)/9 - 5*P*SingularityFunction(-l + x, 0, 3)/108 + P*SingularityFunction(-l + x, l, 3)/6 + - 13*P*SingularityFunction(-l + x, 2*l, 3)/108)*SingularityFunction(x, l, 0)/(E*I) + """ + x = self.variable + l = self._hinge_position + E = self._elastic_modulus + I = self._second_moment + + if isinstance(I, Piecewise): + I1 = I.args[0][0] + I2 = I.args[1][0] + else: + I1 = I2 = I + + load_1 = 0 # Load equation on first segment of composite beam + load_2 = 0 # Load equation on second segment of composite beam + + # Distributing load on both segments + for load in self.applied_loads: + if load[1] < l: + load_1 += load[0]*SingularityFunction(x, load[1], load[2]) + if load[2] == 0: + load_1 -= load[0]*SingularityFunction(x, load[3], load[2]) + elif load[2] > 0: + load_1 -= load[0]*SingularityFunction(x, load[3], load[2]) + load[0]*SingularityFunction(x, load[3], 0) + elif load[1] == l: + load_1 += load[0]*SingularityFunction(x, load[1], load[2]) + load_2 += load[0]*SingularityFunction(x, load[1] - l, load[2]) + elif load[1] > l: + load_2 += load[0]*SingularityFunction(x, load[1] - l, load[2]) + if load[2] == 0: + load_2 -= load[0]*SingularityFunction(x, load[3] - l, load[2]) + elif load[2] > 0: + load_2 -= load[0]*SingularityFunction(x, load[3] - l, load[2]) + load[0]*SingularityFunction(x, load[3] - l, 0) + + h = Symbol('h') # Force due to hinge + load_1 += h*SingularityFunction(x, l, -1) + load_2 -= h*SingularityFunction(x, 0, -1) + + eq = [] + shear_1 = integrate(load_1, x) + shear_curve_1 = limit(shear_1, x, l) + eq.append(shear_curve_1) + bending_1 = integrate(shear_1, x) + moment_curve_1 = limit(bending_1, x, l) + eq.append(moment_curve_1) + + shear_2 = integrate(load_2, x) + shear_curve_2 = limit(shear_2, x, self.length - l) + eq.append(shear_curve_2) + bending_2 = integrate(shear_2, x) + moment_curve_2 = limit(bending_2, x, self.length - l) + eq.append(moment_curve_2) + + C1 = Symbol('C1') + C2 = Symbol('C2') + C3 = Symbol('C3') + C4 = Symbol('C4') + slope_1 = S.One/(E*I1)*(integrate(bending_1, x) + C1) + def_1 = S.One/(E*I1)*(integrate((E*I)*slope_1, x) + C1*x + C2) + slope_2 = S.One/(E*I2)*(integrate(integrate(integrate(load_2, x), x), x) + C3) + def_2 = S.One/(E*I2)*(integrate((E*I)*slope_2, x) + C4) + + for position, value in self.bc_slope: + if position>> from sympy.physics.continuum_mechanics.beam import Beam + >>> from sympy import symbols + >>> E, I = symbols('E, I') + >>> R1, R2 = symbols('R1, R2') + >>> b = Beam(30, E, I) + >>> b.apply_load(-8, 0, -1) + >>> b.apply_load(R1, 10, -1) # Reaction force at x = 10 + >>> b.apply_load(R2, 30, -1) # Reaction force at x = 30 + >>> b.apply_load(120, 30, -2) + >>> b.bc_deflection = [(10, 0), (30, 0)] + >>> b.load + R1*SingularityFunction(x, 10, -1) + R2*SingularityFunction(x, 30, -1) + - 8*SingularityFunction(x, 0, -1) + 120*SingularityFunction(x, 30, -2) + >>> b.solve_for_reaction_loads(R1, R2) + >>> b.reaction_loads + {R1: 6, R2: 2} + >>> b.load + -8*SingularityFunction(x, 0, -1) + 6*SingularityFunction(x, 10, -1) + + 120*SingularityFunction(x, 30, -2) + 2*SingularityFunction(x, 30, -1) + """ + if self._composite_type == "hinge": + return self._solve_hinge_beams(*reactions) + + x = self.variable + l = self.length + C3 = Symbol('C3') + C4 = Symbol('C4') + + shear_curve = limit(self.shear_force(), x, l) + moment_curve = limit(self.bending_moment(), x, l) + + slope_eqs = [] + deflection_eqs = [] + + slope_curve = integrate(self.bending_moment(), x) + C3 + for position, value in self._boundary_conditions['slope']: + eqs = slope_curve.subs(x, position) - value + slope_eqs.append(eqs) + + deflection_curve = integrate(slope_curve, x) + C4 + for position, value in self._boundary_conditions['deflection']: + eqs = deflection_curve.subs(x, position) - value + deflection_eqs.append(eqs) + + solution = list((linsolve([shear_curve, moment_curve] + slope_eqs + + deflection_eqs, (C3, C4) + reactions).args)[0]) + solution = solution[2:] + + self._reaction_loads = dict(zip(reactions, solution)) + self._load = self._load.subs(self._reaction_loads) + + def shear_force(self): + """ + Returns a Singularity Function expression which represents + the shear force curve of the Beam object. + + Examples + ======== + There is a beam of length 30 meters. A moment of magnitude 120 Nm is + applied in the clockwise direction at the end of the beam. A pointload + of magnitude 8 N is applied from the top of the beam at the starting + point. There are two simple supports below the beam. One at the end + and another one at a distance of 10 meters from the start. The + deflection is restricted at both the supports. + + Using the sign convention of upward forces and clockwise moment + being positive. + + >>> from sympy.physics.continuum_mechanics.beam import Beam + >>> from sympy import symbols + >>> E, I = symbols('E, I') + >>> R1, R2 = symbols('R1, R2') + >>> b = Beam(30, E, I) + >>> b.apply_load(-8, 0, -1) + >>> b.apply_load(R1, 10, -1) + >>> b.apply_load(R2, 30, -1) + >>> b.apply_load(120, 30, -2) + >>> b.bc_deflection = [(10, 0), (30, 0)] + >>> b.solve_for_reaction_loads(R1, R2) + >>> b.shear_force() + 8*SingularityFunction(x, 0, 0) - 6*SingularityFunction(x, 10, 0) - 120*SingularityFunction(x, 30, -1) - 2*SingularityFunction(x, 30, 0) + """ + x = self.variable + return -integrate(self.load, x) + + def max_shear_force(self): + """Returns maximum Shear force and its coordinate + in the Beam object.""" + shear_curve = self.shear_force() + x = self.variable + + terms = shear_curve.args + singularity = [] # Points at which shear function changes + for term in terms: + if isinstance(term, Mul): + term = term.args[-1] # SingularityFunction in the term + singularity.append(term.args[1]) + singularity = list(set(singularity)) + singularity.sort() + + intervals = [] # List of Intervals with discrete value of shear force + shear_values = [] # List of values of shear force in each interval + for i, s in enumerate(singularity): + if s == 0: + continue + try: + shear_slope = Piecewise((float("nan"), x<=singularity[i-1]),(self._load.rewrite(Piecewise), x>> from sympy.physics.continuum_mechanics.beam import Beam + >>> from sympy import symbols + >>> E, I = symbols('E, I') + >>> R1, R2 = symbols('R1, R2') + >>> b = Beam(30, E, I) + >>> b.apply_load(-8, 0, -1) + >>> b.apply_load(R1, 10, -1) + >>> b.apply_load(R2, 30, -1) + >>> b.apply_load(120, 30, -2) + >>> b.bc_deflection = [(10, 0), (30, 0)] + >>> b.solve_for_reaction_loads(R1, R2) + >>> b.bending_moment() + 8*SingularityFunction(x, 0, 1) - 6*SingularityFunction(x, 10, 1) - 120*SingularityFunction(x, 30, 0) - 2*SingularityFunction(x, 30, 1) + """ + x = self.variable + return integrate(self.shear_force(), x) + + def max_bmoment(self): + """Returns maximum Shear force and its coordinate + in the Beam object.""" + bending_curve = self.bending_moment() + x = self.variable + + terms = bending_curve.args + singularity = [] # Points at which bending moment changes + for term in terms: + if isinstance(term, Mul): + term = term.args[-1] # SingularityFunction in the term + singularity.append(term.args[1]) + singularity = list(set(singularity)) + singularity.sort() + + intervals = [] # List of Intervals with discrete value of bending moment + moment_values = [] # List of values of bending moment in each interval + for i, s in enumerate(singularity): + if s == 0: + continue + try: + moment_slope = Piecewise( + (float("nan"), x <= singularity[i - 1]), + (self.shear_force().rewrite(Piecewise), x < s), + (float("nan"), True)) + points = solve(moment_slope, x) + val = [] + for point in points: + val.append(abs(bending_curve.subs(x, point))) + points.extend([singularity[i-1], s]) + val += [abs(limit(bending_curve, x, singularity[i-1], '+')), abs(limit(bending_curve, x, s, '-'))] + max_moment = max(val) + moment_values.append(max_moment) + intervals.append(points[val.index(max_moment)]) + + # If bending moment in a particular Interval has zero or constant + # slope, then above block gives NotImplementedError as solve + # can't represent Interval solutions. + except NotImplementedError: + initial_moment = limit(bending_curve, x, singularity[i-1], '+') + final_moment = limit(bending_curve, x, s, '-') + # If bending_curve has a constant slope(it is a line). + if bending_curve.subs(x, (singularity[i-1] + s)/2) == (initial_moment + final_moment)/2 and initial_moment != final_moment: + moment_values.extend([initial_moment, final_moment]) + intervals.extend([singularity[i-1], s]) + else: # bending_curve has same value in whole Interval + moment_values.append(final_moment) + intervals.append(Interval(singularity[i-1], s)) + + moment_values = list(map(abs, moment_values)) + maximum_moment = max(moment_values) + point = intervals[moment_values.index(maximum_moment)] + return (point, maximum_moment) + + def point_cflexure(self): + """ + Returns a Set of point(s) with zero bending moment and + where bending moment curve of the beam object changes + its sign from negative to positive or vice versa. + + Examples + ======== + There is is 10 meter long overhanging beam. There are + two simple supports below the beam. One at the start + and another one at a distance of 6 meters from the start. + Point loads of magnitude 10KN and 20KN are applied at + 2 meters and 4 meters from start respectively. A Uniformly + distribute load of magnitude of magnitude 3KN/m is also + applied on top starting from 6 meters away from starting + point till end. + Using the sign convention of upward forces and clockwise moment + being positive. + + >>> from sympy.physics.continuum_mechanics.beam import Beam + >>> from sympy import symbols + >>> E, I = symbols('E, I') + >>> b = Beam(10, E, I) + >>> b.apply_load(-4, 0, -1) + >>> b.apply_load(-46, 6, -1) + >>> b.apply_load(10, 2, -1) + >>> b.apply_load(20, 4, -1) + >>> b.apply_load(3, 6, 0) + >>> b.point_cflexure() + [10/3] + """ + + # To restrict the range within length of the Beam + moment_curve = Piecewise((float("nan"), self.variable<=0), + (self.bending_moment(), self.variable>> from sympy.physics.continuum_mechanics.beam import Beam + >>> from sympy import symbols + >>> E, I = symbols('E, I') + >>> R1, R2 = symbols('R1, R2') + >>> b = Beam(30, E, I) + >>> b.apply_load(-8, 0, -1) + >>> b.apply_load(R1, 10, -1) + >>> b.apply_load(R2, 30, -1) + >>> b.apply_load(120, 30, -2) + >>> b.bc_deflection = [(10, 0), (30, 0)] + >>> b.solve_for_reaction_loads(R1, R2) + >>> b.slope() + (-4*SingularityFunction(x, 0, 2) + 3*SingularityFunction(x, 10, 2) + + 120*SingularityFunction(x, 30, 1) + SingularityFunction(x, 30, 2) + 4000/3)/(E*I) + """ + x = self.variable + E = self.elastic_modulus + I = self.second_moment + + if self._composite_type == "hinge": + return self._hinge_beam_slope + if not self._boundary_conditions['slope']: + return diff(self.deflection(), x) + if isinstance(I, Piecewise) and self._composite_type == "fixed": + args = I.args + slope = 0 + prev_slope = 0 + prev_end = 0 + for i in range(len(args)): + if i != 0: + prev_end = args[i-1][1].args[1] + slope_value = -S.One/E*integrate(self.bending_moment()/args[i][0], (x, prev_end, x)) + if i != len(args) - 1: + slope += (prev_slope + slope_value)*SingularityFunction(x, prev_end, 0) - \ + (prev_slope + slope_value)*SingularityFunction(x, args[i][1].args[1], 0) + else: + slope += (prev_slope + slope_value)*SingularityFunction(x, prev_end, 0) + prev_slope = slope_value.subs(x, args[i][1].args[1]) + return slope + + C3 = Symbol('C3') + slope_curve = -integrate(S.One/(E*I)*self.bending_moment(), x) + C3 + + bc_eqs = [] + for position, value in self._boundary_conditions['slope']: + eqs = slope_curve.subs(x, position) - value + bc_eqs.append(eqs) + constants = list(linsolve(bc_eqs, C3)) + slope_curve = slope_curve.subs({C3: constants[0][0]}) + return slope_curve + + def deflection(self): + """ + Returns a Singularity Function expression which represents + the elastic curve or deflection of the Beam object. + + Examples + ======== + There is a beam of length 30 meters. A moment of magnitude 120 Nm is + applied in the clockwise direction at the end of the beam. A pointload + of magnitude 8 N is applied from the top of the beam at the starting + point. There are two simple supports below the beam. One at the end + and another one at a distance of 10 meters from the start. The + deflection is restricted at both the supports. + + Using the sign convention of upward forces and clockwise moment + being positive. + + >>> from sympy.physics.continuum_mechanics.beam import Beam + >>> from sympy import symbols + >>> E, I = symbols('E, I') + >>> R1, R2 = symbols('R1, R2') + >>> b = Beam(30, E, I) + >>> b.apply_load(-8, 0, -1) + >>> b.apply_load(R1, 10, -1) + >>> b.apply_load(R2, 30, -1) + >>> b.apply_load(120, 30, -2) + >>> b.bc_deflection = [(10, 0), (30, 0)] + >>> b.solve_for_reaction_loads(R1, R2) + >>> b.deflection() + (4000*x/3 - 4*SingularityFunction(x, 0, 3)/3 + SingularityFunction(x, 10, 3) + + 60*SingularityFunction(x, 30, 2) + SingularityFunction(x, 30, 3)/3 - 12000)/(E*I) + """ + x = self.variable + E = self.elastic_modulus + I = self.second_moment + if self._composite_type == "hinge": + return self._hinge_beam_deflection + if not self._boundary_conditions['deflection'] and not self._boundary_conditions['slope']: + if isinstance(I, Piecewise) and self._composite_type == "fixed": + args = I.args + prev_slope = 0 + prev_def = 0 + prev_end = 0 + deflection = 0 + for i in range(len(args)): + if i != 0: + prev_end = args[i-1][1].args[1] + slope_value = -S.One/E*integrate(self.bending_moment()/args[i][0], (x, prev_end, x)) + recent_segment_slope = prev_slope + slope_value + deflection_value = integrate(recent_segment_slope, (x, prev_end, x)) + if i != len(args) - 1: + deflection += (prev_def + deflection_value)*SingularityFunction(x, prev_end, 0) \ + - (prev_def + deflection_value)*SingularityFunction(x, args[i][1].args[1], 0) + else: + deflection += (prev_def + deflection_value)*SingularityFunction(x, prev_end, 0) + prev_slope = slope_value.subs(x, args[i][1].args[1]) + prev_def = deflection_value.subs(x, args[i][1].args[1]) + return deflection + base_char = self._base_char + constants = symbols(base_char + '3:5') + return S.One/(E*I)*integrate(-integrate(self.bending_moment(), x), x) + constants[0]*x + constants[1] + elif not self._boundary_conditions['deflection']: + base_char = self._base_char + constant = symbols(base_char + '4') + return integrate(self.slope(), x) + constant + elif not self._boundary_conditions['slope'] and self._boundary_conditions['deflection']: + if isinstance(I, Piecewise) and self._composite_type == "fixed": + args = I.args + prev_slope = 0 + prev_def = 0 + prev_end = 0 + deflection = 0 + for i in range(len(args)): + if i != 0: + prev_end = args[i-1][1].args[1] + slope_value = -S.One/E*integrate(self.bending_moment()/args[i][0], (x, prev_end, x)) + recent_segment_slope = prev_slope + slope_value + deflection_value = integrate(recent_segment_slope, (x, prev_end, x)) + if i != len(args) - 1: + deflection += (prev_def + deflection_value)*SingularityFunction(x, prev_end, 0) \ + - (prev_def + deflection_value)*SingularityFunction(x, args[i][1].args[1], 0) + else: + deflection += (prev_def + deflection_value)*SingularityFunction(x, prev_end, 0) + prev_slope = slope_value.subs(x, args[i][1].args[1]) + prev_def = deflection_value.subs(x, args[i][1].args[1]) + return deflection + base_char = self._base_char + C3, C4 = symbols(base_char + '3:5') # Integration constants + slope_curve = -integrate(self.bending_moment(), x) + C3 + deflection_curve = integrate(slope_curve, x) + C4 + bc_eqs = [] + for position, value in self._boundary_conditions['deflection']: + eqs = deflection_curve.subs(x, position) - value + bc_eqs.append(eqs) + constants = list(linsolve(bc_eqs, (C3, C4))) + deflection_curve = deflection_curve.subs({C3: constants[0][0], C4: constants[0][1]}) + return S.One/(E*I)*deflection_curve + + if isinstance(I, Piecewise) and self._composite_type == "fixed": + args = I.args + prev_slope = 0 + prev_def = 0 + prev_end = 0 + deflection = 0 + for i in range(len(args)): + if i != 0: + prev_end = args[i-1][1].args[1] + slope_value = S.One/E*integrate(self.bending_moment()/args[i][0], (x, prev_end, x)) + recent_segment_slope = prev_slope + slope_value + deflection_value = integrate(recent_segment_slope, (x, prev_end, x)) + if i != len(args) - 1: + deflection += (prev_def + deflection_value)*SingularityFunction(x, prev_end, 0) \ + - (prev_def + deflection_value)*SingularityFunction(x, args[i][1].args[1], 0) + else: + deflection += (prev_def + deflection_value)*SingularityFunction(x, prev_end, 0) + prev_slope = slope_value.subs(x, args[i][1].args[1]) + prev_def = deflection_value.subs(x, args[i][1].args[1]) + return deflection + + C4 = Symbol('C4') + deflection_curve = integrate(self.slope(), x) + C4 + + bc_eqs = [] + for position, value in self._boundary_conditions['deflection']: + eqs = deflection_curve.subs(x, position) - value + bc_eqs.append(eqs) + + constants = list(linsolve(bc_eqs, C4)) + deflection_curve = deflection_curve.subs({C4: constants[0][0]}) + return deflection_curve + + def max_deflection(self): + """ + Returns point of max deflection and its corresponding deflection value + in a Beam object. + """ + + # To restrict the range within length of the Beam + slope_curve = Piecewise((float("nan"), self.variable<=0), + (self.slope(), self.variable>> from sympy.physics.continuum_mechanics.beam import Beam + >>> from sympy import symbols + >>> R1, R2 = symbols('R1, R2') + >>> b = Beam(8, 200*(10**9), 400*(10**-6), 2) + >>> b.apply_load(5000, 2, -1) + >>> b.apply_load(R1, 0, -1) + >>> b.apply_load(R2, 8, -1) + >>> b.apply_load(10000, 4, 0, end=8) + >>> b.bc_deflection = [(0, 0), (8, 0)] + >>> b.solve_for_reaction_loads(R1, R2) + >>> b.plot_shear_stress() + Plot object containing: + [0]: cartesian line: 6875*SingularityFunction(x, 0, 0) - 2500*SingularityFunction(x, 2, 0) + - 5000*SingularityFunction(x, 4, 1) + 15625*SingularityFunction(x, 8, 0) + + 5000*SingularityFunction(x, 8, 1) for x over (0.0, 8.0) + """ + + shear_stress = self.shear_stress() + x = self.variable + length = self.length + + if subs is None: + subs = {} + for sym in shear_stress.atoms(Symbol): + if sym != x and sym not in subs: + raise ValueError('value of %s was not passed.' %sym) + + if length in subs: + length = subs[length] + + # Returns Plot of Shear Stress + return plot (shear_stress.subs(subs), (x, 0, length), + title='Shear Stress', xlabel=r'$\mathrm{x}$', ylabel=r'$\tau$', + line_color='r') + + + def plot_shear_force(self, subs=None): + """ + + Returns a plot for Shear force present in the Beam object. + + Parameters + ========== + subs : dictionary + Python dictionary containing Symbols as key and their + corresponding values. + + Examples + ======== + There is a beam of length 8 meters. A constant distributed load of 10 KN/m + is applied from half of the beam till the end. There are two simple supports + below the beam, one at the starting point and another at the ending point + of the beam. A pointload of magnitude 5 KN is also applied from top of the + beam, at a distance of 4 meters from the starting point. + Take E = 200 GPa and I = 400*(10**-6) meter**4. + + Using the sign convention of downwards forces being positive. + + .. plot:: + :context: close-figs + :format: doctest + :include-source: True + + >>> from sympy.physics.continuum_mechanics.beam import Beam + >>> from sympy import symbols + >>> R1, R2 = symbols('R1, R2') + >>> b = Beam(8, 200*(10**9), 400*(10**-6)) + >>> b.apply_load(5000, 2, -1) + >>> b.apply_load(R1, 0, -1) + >>> b.apply_load(R2, 8, -1) + >>> b.apply_load(10000, 4, 0, end=8) + >>> b.bc_deflection = [(0, 0), (8, 0)] + >>> b.solve_for_reaction_loads(R1, R2) + >>> b.plot_shear_force() + Plot object containing: + [0]: cartesian line: 13750*SingularityFunction(x, 0, 0) - 5000*SingularityFunction(x, 2, 0) + - 10000*SingularityFunction(x, 4, 1) + 31250*SingularityFunction(x, 8, 0) + + 10000*SingularityFunction(x, 8, 1) for x over (0.0, 8.0) + """ + shear_force = self.shear_force() + if subs is None: + subs = {} + for sym in shear_force.atoms(Symbol): + if sym == self.variable: + continue + if sym not in subs: + raise ValueError('Value of %s was not passed.' %sym) + if self.length in subs: + length = subs[self.length] + else: + length = self.length + return plot(shear_force.subs(subs), (self.variable, 0, length), title='Shear Force', + xlabel=r'$\mathrm{x}$', ylabel=r'$\mathrm{V}$', line_color='g') + + def plot_bending_moment(self, subs=None): + """ + + Returns a plot for Bending moment present in the Beam object. + + Parameters + ========== + subs : dictionary + Python dictionary containing Symbols as key and their + corresponding values. + + Examples + ======== + There is a beam of length 8 meters. A constant distributed load of 10 KN/m + is applied from half of the beam till the end. There are two simple supports + below the beam, one at the starting point and another at the ending point + of the beam. A pointload of magnitude 5 KN is also applied from top of the + beam, at a distance of 4 meters from the starting point. + Take E = 200 GPa and I = 400*(10**-6) meter**4. + + Using the sign convention of downwards forces being positive. + + .. plot:: + :context: close-figs + :format: doctest + :include-source: True + + >>> from sympy.physics.continuum_mechanics.beam import Beam + >>> from sympy import symbols + >>> R1, R2 = symbols('R1, R2') + >>> b = Beam(8, 200*(10**9), 400*(10**-6)) + >>> b.apply_load(5000, 2, -1) + >>> b.apply_load(R1, 0, -1) + >>> b.apply_load(R2, 8, -1) + >>> b.apply_load(10000, 4, 0, end=8) + >>> b.bc_deflection = [(0, 0), (8, 0)] + >>> b.solve_for_reaction_loads(R1, R2) + >>> b.plot_bending_moment() + Plot object containing: + [0]: cartesian line: 13750*SingularityFunction(x, 0, 1) - 5000*SingularityFunction(x, 2, 1) + - 5000*SingularityFunction(x, 4, 2) + 31250*SingularityFunction(x, 8, 1) + + 5000*SingularityFunction(x, 8, 2) for x over (0.0, 8.0) + """ + bending_moment = self.bending_moment() + if subs is None: + subs = {} + for sym in bending_moment.atoms(Symbol): + if sym == self.variable: + continue + if sym not in subs: + raise ValueError('Value of %s was not passed.' %sym) + if self.length in subs: + length = subs[self.length] + else: + length = self.length + return plot(bending_moment.subs(subs), (self.variable, 0, length), title='Bending Moment', + xlabel=r'$\mathrm{x}$', ylabel=r'$\mathrm{M}$', line_color='b') + + def plot_slope(self, subs=None): + """ + + Returns a plot for slope of deflection curve of the Beam object. + + Parameters + ========== + subs : dictionary + Python dictionary containing Symbols as key and their + corresponding values. + + Examples + ======== + There is a beam of length 8 meters. A constant distributed load of 10 KN/m + is applied from half of the beam till the end. There are two simple supports + below the beam, one at the starting point and another at the ending point + of the beam. A pointload of magnitude 5 KN is also applied from top of the + beam, at a distance of 4 meters from the starting point. + Take E = 200 GPa and I = 400*(10**-6) meter**4. + + Using the sign convention of downwards forces being positive. + + .. plot:: + :context: close-figs + :format: doctest + :include-source: True + + >>> from sympy.physics.continuum_mechanics.beam import Beam + >>> from sympy import symbols + >>> R1, R2 = symbols('R1, R2') + >>> b = Beam(8, 200*(10**9), 400*(10**-6)) + >>> b.apply_load(5000, 2, -1) + >>> b.apply_load(R1, 0, -1) + >>> b.apply_load(R2, 8, -1) + >>> b.apply_load(10000, 4, 0, end=8) + >>> b.bc_deflection = [(0, 0), (8, 0)] + >>> b.solve_for_reaction_loads(R1, R2) + >>> b.plot_slope() + Plot object containing: + [0]: cartesian line: -8.59375e-5*SingularityFunction(x, 0, 2) + 3.125e-5*SingularityFunction(x, 2, 2) + + 2.08333333333333e-5*SingularityFunction(x, 4, 3) - 0.0001953125*SingularityFunction(x, 8, 2) + - 2.08333333333333e-5*SingularityFunction(x, 8, 3) + 0.00138541666666667 for x over (0.0, 8.0) + """ + slope = self.slope() + if subs is None: + subs = {} + for sym in slope.atoms(Symbol): + if sym == self.variable: + continue + if sym not in subs: + raise ValueError('Value of %s was not passed.' %sym) + if self.length in subs: + length = subs[self.length] + else: + length = self.length + return plot(slope.subs(subs), (self.variable, 0, length), title='Slope', + xlabel=r'$\mathrm{x}$', ylabel=r'$\theta$', line_color='m') + + def plot_deflection(self, subs=None): + """ + + Returns a plot for deflection curve of the Beam object. + + Parameters + ========== + subs : dictionary + Python dictionary containing Symbols as key and their + corresponding values. + + Examples + ======== + There is a beam of length 8 meters. A constant distributed load of 10 KN/m + is applied from half of the beam till the end. There are two simple supports + below the beam, one at the starting point and another at the ending point + of the beam. A pointload of magnitude 5 KN is also applied from top of the + beam, at a distance of 4 meters from the starting point. + Take E = 200 GPa and I = 400*(10**-6) meter**4. + + Using the sign convention of downwards forces being positive. + + .. plot:: + :context: close-figs + :format: doctest + :include-source: True + + >>> from sympy.physics.continuum_mechanics.beam import Beam + >>> from sympy import symbols + >>> R1, R2 = symbols('R1, R2') + >>> b = Beam(8, 200*(10**9), 400*(10**-6)) + >>> b.apply_load(5000, 2, -1) + >>> b.apply_load(R1, 0, -1) + >>> b.apply_load(R2, 8, -1) + >>> b.apply_load(10000, 4, 0, end=8) + >>> b.bc_deflection = [(0, 0), (8, 0)] + >>> b.solve_for_reaction_loads(R1, R2) + >>> b.plot_deflection() + Plot object containing: + [0]: cartesian line: 0.00138541666666667*x - 2.86458333333333e-5*SingularityFunction(x, 0, 3) + + 1.04166666666667e-5*SingularityFunction(x, 2, 3) + 5.20833333333333e-6*SingularityFunction(x, 4, 4) + - 6.51041666666667e-5*SingularityFunction(x, 8, 3) - 5.20833333333333e-6*SingularityFunction(x, 8, 4) + for x over (0.0, 8.0) + """ + deflection = self.deflection() + if subs is None: + subs = {} + for sym in deflection.atoms(Symbol): + if sym == self.variable: + continue + if sym not in subs: + raise ValueError('Value of %s was not passed.' %sym) + if self.length in subs: + length = subs[self.length] + else: + length = self.length + return plot(deflection.subs(subs), (self.variable, 0, length), + title='Deflection', xlabel=r'$\mathrm{x}$', ylabel=r'$\delta$', + line_color='r') + + + def plot_loading_results(self, subs=None): + """ + Returns a subplot of Shear Force, Bending Moment, + Slope and Deflection of the Beam object. + + Parameters + ========== + + subs : dictionary + Python dictionary containing Symbols as key and their + corresponding values. + + Examples + ======== + + There is a beam of length 8 meters. A constant distributed load of 10 KN/m + is applied from half of the beam till the end. There are two simple supports + below the beam, one at the starting point and another at the ending point + of the beam. A pointload of magnitude 5 KN is also applied from top of the + beam, at a distance of 4 meters from the starting point. + Take E = 200 GPa and I = 400*(10**-6) meter**4. + + Using the sign convention of downwards forces being positive. + + .. plot:: + :context: close-figs + :format: doctest + :include-source: True + + >>> from sympy.physics.continuum_mechanics.beam import Beam + >>> from sympy import symbols + >>> R1, R2 = symbols('R1, R2') + >>> b = Beam(8, 200*(10**9), 400*(10**-6)) + >>> b.apply_load(5000, 2, -1) + >>> b.apply_load(R1, 0, -1) + >>> b.apply_load(R2, 8, -1) + >>> b.apply_load(10000, 4, 0, end=8) + >>> b.bc_deflection = [(0, 0), (8, 0)] + >>> b.solve_for_reaction_loads(R1, R2) + >>> axes = b.plot_loading_results() + """ + length = self.length + variable = self.variable + if subs is None: + subs = {} + for sym in self.deflection().atoms(Symbol): + if sym == self.variable: + continue + if sym not in subs: + raise ValueError('Value of %s was not passed.' %sym) + if length in subs: + length = subs[length] + ax1 = plot(self.shear_force().subs(subs), (variable, 0, length), + title="Shear Force", xlabel=r'$\mathrm{x}$', ylabel=r'$\mathrm{V}$', + line_color='g', show=False) + ax2 = plot(self.bending_moment().subs(subs), (variable, 0, length), + title="Bending Moment", xlabel=r'$\mathrm{x}$', ylabel=r'$\mathrm{M}$', + line_color='b', show=False) + ax3 = plot(self.slope().subs(subs), (variable, 0, length), + title="Slope", xlabel=r'$\mathrm{x}$', ylabel=r'$\theta$', + line_color='m', show=False) + ax4 = plot(self.deflection().subs(subs), (variable, 0, length), + title="Deflection", xlabel=r'$\mathrm{x}$', ylabel=r'$\delta$', + line_color='r', show=False) + + return PlotGrid(4, 1, ax1, ax2, ax3, ax4) + + def _solve_for_ild_equations(self): + """ + + Helper function for I.L.D. It takes the unsubstituted + copy of the load equation and uses it to calculate shear force and bending + moment equations. + """ + + x = self.variable + shear_force = -integrate(self._original_load, x) + bending_moment = integrate(shear_force, x) + + return shear_force, bending_moment + + def solve_for_ild_reactions(self, value, *reactions): + """ + + Determines the Influence Line Diagram equations for reaction + forces under the effect of a moving load. + + Parameters + ========== + value : Integer + Magnitude of moving load + reactions : + The reaction forces applied on the beam. + + Examples + ======== + + There is a beam of length 10 meters. There are two simple supports + below the beam, one at the starting point and another at the ending + point of the beam. Calculate the I.L.D. equations for reaction forces + under the effect of a moving load of magnitude 1kN. + + Using the sign convention of downwards forces being positive. + + .. plot:: + :context: close-figs + :format: doctest + :include-source: True + + >>> from sympy import symbols + >>> from sympy.physics.continuum_mechanics.beam import Beam + >>> E, I = symbols('E, I') + >>> R_0, R_10 = symbols('R_0, R_10') + >>> b = Beam(10, E, I) + >>> p0 = b.apply_support(0, 'roller') + >>> p10 = b.apply_support(10, 'roller') + >>> b.solve_for_ild_reactions(1,R_0,R_10) + >>> b.ild_reactions + {R_0: x/10 - 1, R_10: -x/10} + + """ + shear_force, bending_moment = self._solve_for_ild_equations() + x = self.variable + l = self.length + C3 = Symbol('C3') + C4 = Symbol('C4') + + shear_curve = limit(shear_force, x, l) - value + moment_curve = limit(bending_moment, x, l) - value*(l-x) + + slope_eqs = [] + deflection_eqs = [] + + slope_curve = integrate(bending_moment, x) + C3 + for position, value in self._boundary_conditions['slope']: + eqs = slope_curve.subs(x, position) - value + slope_eqs.append(eqs) + + deflection_curve = integrate(slope_curve, x) + C4 + for position, value in self._boundary_conditions['deflection']: + eqs = deflection_curve.subs(x, position) - value + deflection_eqs.append(eqs) + + solution = list((linsolve([shear_curve, moment_curve] + slope_eqs + + deflection_eqs, (C3, C4) + reactions).args)[0]) + solution = solution[2:] + + # Determining the equations and solving them. + self._ild_reactions = dict(zip(reactions, solution)) + + def plot_ild_reactions(self, subs=None): + """ + + Plots the Influence Line Diagram of Reaction Forces + under the effect of a moving load. This function + should be called after calling solve_for_ild_reactions(). + + Parameters + ========== + + subs : dictionary + Python dictionary containing Symbols as key and their + corresponding values. + + Examples + ======== + + There is a beam of length 10 meters. A point load of magnitude 5KN + is also applied from top of the beam, at a distance of 4 meters + from the starting point. There are two simple supports below the + beam, located at the starting point and at a distance of 7 meters + from the starting point. Plot the I.L.D. equations for reactions + at both support points under the effect of a moving load + of magnitude 1kN. + + Using the sign convention of downwards forces being positive. + + .. plot:: + :context: close-figs + :format: doctest + :include-source: True + + >>> from sympy import symbols + >>> from sympy.physics.continuum_mechanics.beam import Beam + >>> E, I = symbols('E, I') + >>> R_0, R_7 = symbols('R_0, R_7') + >>> b = Beam(10, E, I) + >>> p0 = b.apply_support(0, 'roller') + >>> p7 = b.apply_support(7, 'roller') + >>> b.apply_load(5,4,-1) + >>> b.solve_for_ild_reactions(1,R_0,R_7) + >>> b.ild_reactions + {R_0: x/7 - 22/7, R_7: -x/7 - 20/7} + >>> b.plot_ild_reactions() + PlotGrid object containing: + Plot[0]:Plot object containing: + [0]: cartesian line: x/7 - 22/7 for x over (0.0, 10.0) + Plot[1]:Plot object containing: + [0]: cartesian line: -x/7 - 20/7 for x over (0.0, 10.0) + + """ + if not self._ild_reactions: + raise ValueError("I.L.D. reaction equations not found. Please use solve_for_ild_reactions() to generate the I.L.D. reaction equations.") + + x = self.variable + ildplots = [] + + if subs is None: + subs = {} + + for reaction in self._ild_reactions: + for sym in self._ild_reactions[reaction].atoms(Symbol): + if sym != x and sym not in subs: + raise ValueError('Value of %s was not passed.' %sym) + + for sym in self._length.atoms(Symbol): + if sym != x and sym not in subs: + raise ValueError('Value of %s was not passed.' %sym) + + for reaction in self._ild_reactions: + ildplots.append(plot(self._ild_reactions[reaction].subs(subs), + (x, 0, self._length.subs(subs)), title='I.L.D. for Reactions', + xlabel=x, ylabel=reaction, line_color='blue', show=False)) + + return PlotGrid(len(ildplots), 1, *ildplots) + + def solve_for_ild_shear(self, distance, value, *reactions): + """ + + Determines the Influence Line Diagram equations for shear at a + specified point under the effect of a moving load. + + Parameters + ========== + distance : Integer + Distance of the point from the start of the beam + for which equations are to be determined + value : Integer + Magnitude of moving load + reactions : + The reaction forces applied on the beam. + + Examples + ======== + + There is a beam of length 12 meters. There are two simple supports + below the beam, one at the starting point and another at a distance + of 8 meters. Calculate the I.L.D. equations for Shear at a distance + of 4 meters under the effect of a moving load of magnitude 1kN. + + Using the sign convention of downwards forces being positive. + + .. plot:: + :context: close-figs + :format: doctest + :include-source: True + + >>> from sympy import symbols + >>> from sympy.physics.continuum_mechanics.beam import Beam + >>> E, I = symbols('E, I') + >>> R_0, R_8 = symbols('R_0, R_8') + >>> b = Beam(12, E, I) + >>> p0 = b.apply_support(0, 'roller') + >>> p8 = b.apply_support(8, 'roller') + >>> b.solve_for_ild_reactions(1, R_0, R_8) + >>> b.solve_for_ild_shear(4, 1, R_0, R_8) + >>> b.ild_shear + Piecewise((x/8, x < 4), (x/8 - 1, x > 4)) + + """ + + x = self.variable + l = self.length + + shear_force, _ = self._solve_for_ild_equations() + + shear_curve1 = value - limit(shear_force, x, distance) + shear_curve2 = (limit(shear_force, x, l) - limit(shear_force, x, distance)) - value + + for reaction in reactions: + shear_curve1 = shear_curve1.subs(reaction,self._ild_reactions[reaction]) + shear_curve2 = shear_curve2.subs(reaction,self._ild_reactions[reaction]) + + shear_eq = Piecewise((shear_curve1, x < distance), (shear_curve2, x > distance)) + + self._ild_shear = shear_eq + + def plot_ild_shear(self,subs=None): + """ + + Plots the Influence Line Diagram for Shear under the effect + of a moving load. This function should be called after + calling solve_for_ild_shear(). + + Parameters + ========== + + subs : dictionary + Python dictionary containing Symbols as key and their + corresponding values. + + Examples + ======== + + There is a beam of length 12 meters. There are two simple supports + below the beam, one at the starting point and another at a distance + of 8 meters. Plot the I.L.D. for Shear at a distance + of 4 meters under the effect of a moving load of magnitude 1kN. + + Using the sign convention of downwards forces being positive. + + .. plot:: + :context: close-figs + :format: doctest + :include-source: True + + >>> from sympy import symbols + >>> from sympy.physics.continuum_mechanics.beam import Beam + >>> E, I = symbols('E, I') + >>> R_0, R_8 = symbols('R_0, R_8') + >>> b = Beam(12, E, I) + >>> p0 = b.apply_support(0, 'roller') + >>> p8 = b.apply_support(8, 'roller') + >>> b.solve_for_ild_reactions(1, R_0, R_8) + >>> b.solve_for_ild_shear(4, 1, R_0, R_8) + >>> b.ild_shear + Piecewise((x/8, x < 4), (x/8 - 1, x > 4)) + >>> b.plot_ild_shear() + Plot object containing: + [0]: cartesian line: Piecewise((x/8, x < 4), (x/8 - 1, x > 4)) for x over (0.0, 12.0) + + """ + + if not self._ild_shear: + raise ValueError("I.L.D. shear equation not found. Please use solve_for_ild_shear() to generate the I.L.D. shear equations.") + + x = self.variable + l = self._length + + if subs is None: + subs = {} + + for sym in self._ild_shear.atoms(Symbol): + if sym != x and sym not in subs: + raise ValueError('Value of %s was not passed.' %sym) + + for sym in self._length.atoms(Symbol): + if sym != x and sym not in subs: + raise ValueError('Value of %s was not passed.' %sym) + + return plot(self._ild_shear.subs(subs), (x, 0, l), title='I.L.D. for Shear', + xlabel=r'$\mathrm{X}$', ylabel=r'$\mathrm{V}$', line_color='blue',show=True) + + def solve_for_ild_moment(self, distance, value, *reactions): + """ + + Determines the Influence Line Diagram equations for moment at a + specified point under the effect of a moving load. + + Parameters + ========== + distance : Integer + Distance of the point from the start of the beam + for which equations are to be determined + value : Integer + Magnitude of moving load + reactions : + The reaction forces applied on the beam. + + Examples + ======== + + There is a beam of length 12 meters. There are two simple supports + below the beam, one at the starting point and another at a distance + of 8 meters. Calculate the I.L.D. equations for Moment at a distance + of 4 meters under the effect of a moving load of magnitude 1kN. + + Using the sign convention of downwards forces being positive. + + .. plot:: + :context: close-figs + :format: doctest + :include-source: True + + >>> from sympy import symbols + >>> from sympy.physics.continuum_mechanics.beam import Beam + >>> E, I = symbols('E, I') + >>> R_0, R_8 = symbols('R_0, R_8') + >>> b = Beam(12, E, I) + >>> p0 = b.apply_support(0, 'roller') + >>> p8 = b.apply_support(8, 'roller') + >>> b.solve_for_ild_reactions(1, R_0, R_8) + >>> b.solve_for_ild_moment(4, 1, R_0, R_8) + >>> b.ild_moment + Piecewise((-x/2, x < 4), (x/2 - 4, x > 4)) + + """ + + x = self.variable + l = self.length + + _, moment = self._solve_for_ild_equations() + + moment_curve1 = value*(distance-x) - limit(moment, x, distance) + moment_curve2= (limit(moment, x, l)-limit(moment, x, distance))-value*(l-x) + + for reaction in reactions: + moment_curve1 = moment_curve1.subs(reaction, self._ild_reactions[reaction]) + moment_curve2 = moment_curve2.subs(reaction, self._ild_reactions[reaction]) + + moment_eq = Piecewise((moment_curve1, x < distance), (moment_curve2, x > distance)) + self._ild_moment = moment_eq + + def plot_ild_moment(self,subs=None): + """ + + Plots the Influence Line Diagram for Moment under the effect + of a moving load. This function should be called after + calling solve_for_ild_moment(). + + Parameters + ========== + + subs : dictionary + Python dictionary containing Symbols as key and their + corresponding values. + + Examples + ======== + + There is a beam of length 12 meters. There are two simple supports + below the beam, one at the starting point and another at a distance + of 8 meters. Plot the I.L.D. for Moment at a distance + of 4 meters under the effect of a moving load of magnitude 1kN. + + Using the sign convention of downwards forces being positive. + + .. plot:: + :context: close-figs + :format: doctest + :include-source: True + + >>> from sympy import symbols + >>> from sympy.physics.continuum_mechanics.beam import Beam + >>> E, I = symbols('E, I') + >>> R_0, R_8 = symbols('R_0, R_8') + >>> b = Beam(12, E, I) + >>> p0 = b.apply_support(0, 'roller') + >>> p8 = b.apply_support(8, 'roller') + >>> b.solve_for_ild_reactions(1, R_0, R_8) + >>> b.solve_for_ild_moment(4, 1, R_0, R_8) + >>> b.ild_moment + Piecewise((-x/2, x < 4), (x/2 - 4, x > 4)) + >>> b.plot_ild_moment() + Plot object containing: + [0]: cartesian line: Piecewise((-x/2, x < 4), (x/2 - 4, x > 4)) for x over (0.0, 12.0) + + """ + + if not self._ild_moment: + raise ValueError("I.L.D. moment equation not found. Please use solve_for_ild_moment() to generate the I.L.D. moment equations.") + + x = self.variable + + if subs is None: + subs = {} + + for sym in self._ild_moment.atoms(Symbol): + if sym != x and sym not in subs: + raise ValueError('Value of %s was not passed.' %sym) + + for sym in self._length.atoms(Symbol): + if sym != x and sym not in subs: + raise ValueError('Value of %s was not passed.' %sym) + return plot(self._ild_moment.subs(subs), (x, 0, self._length), title='I.L.D. for Moment', + xlabel=r'$\mathrm{X}$', ylabel=r'$\mathrm{M}$', line_color='blue', show=True) + + @doctest_depends_on(modules=('numpy',)) + def draw(self, pictorial=True): + """ + Returns a plot object representing the beam diagram of the beam. + In particular, the diagram might include: + + * the beam. + * vertical black arrows represent point loads and support reaction + forces (the latter if they have been added with the ``apply_load`` + method). + * circular arrows represent moments. + * shaded areas represent distributed loads. + * the support, if ``apply_support`` has been executed. + * if a composite beam has been created with the ``join`` method and + a hinge has been specified, it will be shown with a white disc. + + The diagram shows positive loads on the upper side of the beam, + and negative loads on the lower side. If two or more distributed + loads acts along the same direction over the same region, the + function will add them up together. + + .. note:: + The user must be careful while entering load values. + The draw function assumes a sign convention which is used + for plotting loads. + Given a right handed coordinate system with XYZ coordinates, + the beam's length is assumed to be along the positive X axis. + The draw function recognizes positive loads(with n>-2) as loads + acting along negative Y direction and positive moments acting + along positive Z direction. + + Parameters + ========== + + pictorial: Boolean (default=True) + Setting ``pictorial=True`` would simply create a pictorial (scaled) + view of the beam diagram. On the other hand, ``pictorial=False`` + would create a beam diagram with the exact dimensions on the plot. + + Examples + ======== + + .. plot:: + :context: close-figs + :format: doctest + :include-source: True + + >>> from sympy.physics.continuum_mechanics.beam import Beam + >>> from sympy import symbols + >>> P1, P2, M = symbols('P1, P2, M') + >>> E, I = symbols('E, I') + >>> b = Beam(50, 20, 30) + >>> b.apply_load(-10, 2, -1) + >>> b.apply_load(15, 26, -1) + >>> b.apply_load(P1, 10, -1) + >>> b.apply_load(-P2, 40, -1) + >>> b.apply_load(90, 5, 0, 23) + >>> b.apply_load(10, 30, 1, 50) + >>> b.apply_load(M, 15, -2) + >>> b.apply_load(-M, 30, -2) + >>> p50 = b.apply_support(50, "pin") + >>> p0, m0 = b.apply_support(0, "fixed") + >>> p20 = b.apply_support(20, "roller") + >>> p = b.draw() # doctest: +SKIP + >>> p # doctest: +ELLIPSIS,+SKIP + Plot object containing: + [0]: cartesian line: 25*SingularityFunction(x, 5, 0) - 25*SingularityFunction(x, 23, 0) + + SingularityFunction(x, 30, 1) - 20*SingularityFunction(x, 50, 0) + - SingularityFunction(x, 50, 1) + 5 for x over (0.0, 50.0) + [1]: cartesian line: 5 for x over (0.0, 50.0) + ... + >>> p.show() # doctest: +SKIP + + """ + if not numpy: + raise ImportError("To use this function numpy module is required") + + loads = list(set(self.applied_loads) - set(self._support_as_loads)) + if (not pictorial) and any((len(l[0].free_symbols) > 0) and (l[2] >= 0) for l in loads): + raise ValueError("`pictorial=False` requires numerical " + "distributed loads. Instead, symbolic loads were found. " + "Cannot continue.") + + x = self.variable + + # checking whether length is an expression in terms of any Symbol. + if isinstance(self.length, Expr): + l = list(self.length.atoms(Symbol)) + # assigning every Symbol a default value of 10 + l = dict.fromkeys(l, 10) + length = self.length.subs(l) + else: + l = {} + length = self.length + height = length/10 + + rectangles = [] + rectangles.append({'xy':(0, 0), 'width':length, 'height': height, 'facecolor':"brown"}) + annotations, markers, load_eq,load_eq1, fill = self._draw_load(pictorial, length, l) + support_markers, support_rectangles = self._draw_supports(length, l) + + rectangles += support_rectangles + markers += support_markers + + if self._composite_type == "hinge": + # if self is a composite beam with an hinge, show it + ratio = self._hinge_position / self.length + x_pos = float(ratio) * length + markers += [{'args':[[x_pos], [height / 2]], 'marker':'o', 'markersize':6, 'color':"white"}] + + ylim = (-length, 1.25*length) + if fill: + # when distributed loads are presents, they might get clipped out + # in the figure by the ylim settings. + # It might be necessary to compute new limits. + _min = min(min(fill["y2"]), min(r["xy"][1] for r in rectangles)) + _max = max(max(fill["y1"]), max(r["xy"][1] for r in rectangles)) + if (_min < ylim[0]) or (_max > ylim[1]): + offset = abs(_max - _min) * 0.1 + ylim = (_min - offset, _max + offset) + + sing_plot = plot(height + load_eq, height + load_eq1, (x, 0, length), + xlim=(-height, length + height), ylim=ylim, + annotations=annotations, markers=markers, rectangles=rectangles, + line_color='brown', fill=fill, axis=False, show=False) + + return sing_plot + + + def _is_load_negative(self, load): + """Try to determine if a load is negative or positive, using + expansion and doit if necessary. + + Returns + ======= + True: if the load is negative + False: if the load is positive + None: if it is indeterminate + + """ + rv = load.is_negative + if load.is_Atom or rv is not None: + return rv + return load.doit().expand().is_negative + + def _draw_load(self, pictorial, length, l): + loads = list(set(self.applied_loads) - set(self._support_as_loads)) + height = length/10 + x = self.variable + + annotations = [] + markers = [] + load_args = [] + scaled_load = 0 + load_args1 = [] + scaled_load1 = 0 + load_eq = S.Zero # For positive valued higher order loads + load_eq1 = S.Zero # For negative valued higher order loads + fill = None + + # schematic view should use the class convention as much as possible. + # However, users can add expressions as symbolic loads, for example + # P1 - P2: is this load positive or negative? We can't say. + # On these occasions it is better to inform users about the + # indeterminate state of those loads. + warning_head = "Please, note that this schematic view might not be " \ + "in agreement with the sign convention used by the Beam class " \ + "for load-related computations, because it was not possible " \ + "to determine the sign (hence, the direction) of the " \ + "following loads:\n" + warning_body = "" + + for load in loads: + # check if the position of load is in terms of the beam length. + if l: + pos = load[1].subs(l) + else: + pos = load[1] + + # point loads + if load[2] == -1: + iln = self._is_load_negative(load[0]) + if iln is None: + warning_body += "* Point load %s located at %s\n" % (load[0], load[1]) + if iln: + annotations.append({'text':'', 'xy':(pos, 0), 'xytext':(pos, height - 4*height), 'arrowprops':{'width': 1.5, 'headlength': 5, 'headwidth': 5, 'facecolor': 'black'}}) + else: + annotations.append({'text':'', 'xy':(pos, height), 'xytext':(pos, height*4), 'arrowprops':{"width": 1.5, "headlength": 4, "headwidth": 4, "facecolor": 'black'}}) + # moment loads + elif load[2] == -2: + iln = self._is_load_negative(load[0]) + if iln is None: + warning_body += "* Moment %s located at %s\n" % (load[0], load[1]) + if self._is_load_negative(load[0]): + markers.append({'args':[[pos], [height/2]], 'marker': r'$\circlearrowright$', 'markersize':15}) + else: + markers.append({'args':[[pos], [height/2]], 'marker': r'$\circlearrowleft$', 'markersize':15}) + # higher order loads + elif load[2] >= 0: + # `fill` will be assigned only when higher order loads are present + value, start, order, end = load + + iln = self._is_load_negative(value) + if iln is None: + warning_body += "* Distributed load %s from %s to %s\n" % (value, start, end) + + # Positive loads have their separate equations + if not iln: + # if pictorial is True we remake the load equation again with + # some constant magnitude values. + if pictorial: + # remake the load equation again with some constant + # magnitude values. + value = 10**(1-order) if order > 0 else length/2 + scaled_load += value*SingularityFunction(x, start, order) + if end: + f2 = value*x**order if order >= 0 else length/2*x**order + for i in range(0, order + 1): + scaled_load -= (f2.diff(x, i).subs(x, end - start)* + SingularityFunction(x, end, i)/factorial(i)) + + if isinstance(scaled_load, Add): + load_args = scaled_load.args + else: + # when the load equation consists of only a single term + load_args = (scaled_load,) + load_eq = Add(*[i.subs(l) for i in load_args]) + + # For loads with negative value + else: + if pictorial: + # remake the load equation again with some constant + # magnitude values. + value = 10**(1-order) if order > 0 else length/2 + scaled_load1 += abs(value)*SingularityFunction(x, start, order) + if end: + f2 = abs(value)*x**order if order >= 0 else length/2*x**order + for i in range(0, order + 1): + scaled_load1 -= (f2.diff(x, i).subs(x, end - start)* + SingularityFunction(x, end, i)/factorial(i)) + + if isinstance(scaled_load1, Add): + load_args1 = scaled_load1.args + else: + # when the load equation consists of only a single term + load_args1 = (scaled_load1,) + load_eq1 = [i.subs(l) for i in load_args1] + load_eq1 = -Add(*load_eq1) - height + + if len(warning_body) > 0: + warnings.warn(warning_head + warning_body) + + xx = numpy.arange(0, float(length), 0.001) + yy1 = lambdify([x], height + load_eq.rewrite(Piecewise))(xx) + yy2 = lambdify([x], height + load_eq1.rewrite(Piecewise))(xx) + if not isinstance(yy1, numpy.ndarray): + yy1 *= numpy.ones_like(xx) + if not isinstance(yy2, numpy.ndarray): + yy2 *= numpy.ones_like(xx) + fill = {'x': xx, 'y1': yy1, 'y2': yy2, + 'color':'darkkhaki', "zorder": -1} + return annotations, markers, load_eq, load_eq1, fill + + + def _draw_supports(self, length, l): + height = float(length/10) + + support_markers = [] + support_rectangles = [] + for support in self._applied_supports: + if l: + pos = support[0].subs(l) + else: + pos = support[0] + + if support[1] == "pin": + support_markers.append({'args':[pos, [0]], 'marker':6, 'markersize':13, 'color':"black"}) + + elif support[1] == "roller": + support_markers.append({'args':[pos, [-height/2.5]], 'marker':'o', 'markersize':11, 'color':"black"}) + + elif support[1] == "fixed": + if pos == 0: + support_rectangles.append({'xy':(0, -3*height), 'width':-length/20, 'height':6*height + height, 'fill':False, 'hatch':'/////'}) + else: + support_rectangles.append({'xy':(length, -3*height), 'width':length/20, 'height': 6*height + height, 'fill':False, 'hatch':'/////'}) + + return support_markers, support_rectangles + + +class Beam3D(Beam): + """ + This class handles loads applied in any direction of a 3D space along + with unequal values of Second moment along different axes. + + .. note:: + A consistent sign convention must be used while solving a beam + bending problem; the results will + automatically follow the chosen sign convention. + This class assumes that any kind of distributed load/moment is + applied through out the span of a beam. + + Examples + ======== + There is a beam of l meters long. A constant distributed load of magnitude q + is applied along y-axis from start till the end of beam. A constant distributed + moment of magnitude m is also applied along z-axis from start till the end of beam. + Beam is fixed at both of its end. So, deflection of the beam at the both ends + is restricted. + + >>> from sympy.physics.continuum_mechanics.beam import Beam3D + >>> from sympy import symbols, simplify, collect, factor + >>> l, E, G, I, A = symbols('l, E, G, I, A') + >>> b = Beam3D(l, E, G, I, A) + >>> x, q, m = symbols('x, q, m') + >>> b.apply_load(q, 0, 0, dir="y") + >>> b.apply_moment_load(m, 0, -1, dir="z") + >>> b.shear_force() + [0, -q*x, 0] + >>> b.bending_moment() + [0, 0, -m*x + q*x**2/2] + >>> b.bc_slope = [(0, [0, 0, 0]), (l, [0, 0, 0])] + >>> b.bc_deflection = [(0, [0, 0, 0]), (l, [0, 0, 0])] + >>> b.solve_slope_deflection() + >>> factor(b.slope()) + [0, 0, x*(-l + x)*(-A*G*l**3*q + 2*A*G*l**2*q*x - 12*E*I*l*q + - 72*E*I*m + 24*E*I*q*x)/(12*E*I*(A*G*l**2 + 12*E*I))] + >>> dx, dy, dz = b.deflection() + >>> dy = collect(simplify(dy), x) + >>> dx == dz == 0 + True + >>> dy == (x*(12*E*I*l*(A*G*l**2*q - 2*A*G*l*m + 12*E*I*q) + ... + x*(A*G*l*(3*l*(A*G*l**2*q - 2*A*G*l*m + 12*E*I*q) + x*(-2*A*G*l**2*q + 4*A*G*l*m - 24*E*I*q)) + ... + A*G*(A*G*l**2 + 12*E*I)*(-2*l**2*q + 6*l*m - 4*m*x + q*x**2) + ... - 12*E*I*q*(A*G*l**2 + 12*E*I)))/(24*A*E*G*I*(A*G*l**2 + 12*E*I))) + True + + References + ========== + + .. [1] https://homes.civil.aau.dk/jc/FemteSemester/Beams3D.pdf + + """ + + def __init__(self, length, elastic_modulus, shear_modulus, second_moment, + area, variable=Symbol('x')): + """Initializes the class. + + Parameters + ========== + length : Sympifyable + A Symbol or value representing the Beam's length. + elastic_modulus : Sympifyable + A SymPy expression representing the Beam's Modulus of Elasticity. + It is a measure of the stiffness of the Beam material. + shear_modulus : Sympifyable + A SymPy expression representing the Beam's Modulus of rigidity. + It is a measure of rigidity of the Beam material. + second_moment : Sympifyable or list + A list of two elements having SymPy expression representing the + Beam's Second moment of area. First value represent Second moment + across y-axis and second across z-axis. + Single SymPy expression can be passed if both values are same + area : Sympifyable + A SymPy expression representing the Beam's cross-sectional area + in a plane perpendicular to length of the Beam. + variable : Symbol, optional + A Symbol object that will be used as the variable along the beam + while representing the load, shear, moment, slope and deflection + curve. By default, it is set to ``Symbol('x')``. + """ + super().__init__(length, elastic_modulus, second_moment, variable) + self.shear_modulus = shear_modulus + self.area = area + self._load_vector = [0, 0, 0] + self._moment_load_vector = [0, 0, 0] + self._torsion_moment = {} + self._load_Singularity = [0, 0, 0] + self._slope = [0, 0, 0] + self._deflection = [0, 0, 0] + self._angular_deflection = 0 + + @property + def shear_modulus(self): + """Young's Modulus of the Beam. """ + return self._shear_modulus + + @shear_modulus.setter + def shear_modulus(self, e): + self._shear_modulus = sympify(e) + + @property + def second_moment(self): + """Second moment of area of the Beam. """ + return self._second_moment + + @second_moment.setter + def second_moment(self, i): + if isinstance(i, list): + i = [sympify(x) for x in i] + self._second_moment = i + else: + self._second_moment = sympify(i) + + @property + def area(self): + """Cross-sectional area of the Beam. """ + return self._area + + @area.setter + def area(self, a): + self._area = sympify(a) + + @property + def load_vector(self): + """ + Returns a three element list representing the load vector. + """ + return self._load_vector + + @property + def moment_load_vector(self): + """ + Returns a three element list representing moment loads on Beam. + """ + return self._moment_load_vector + + @property + def boundary_conditions(self): + """ + Returns a dictionary of boundary conditions applied on the beam. + The dictionary has two keywords namely slope and deflection. + The value of each keyword is a list of tuple, where each tuple + contains location and value of a boundary condition in the format + (location, value). Further each value is a list corresponding to + slope or deflection(s) values along three axes at that location. + + Examples + ======== + There is a beam of length 4 meters. The slope at 0 should be 4 along + the x-axis and 0 along others. At the other end of beam, deflection + along all the three axes should be zero. + + >>> from sympy.physics.continuum_mechanics.beam import Beam3D + >>> from sympy import symbols + >>> l, E, G, I, A, x = symbols('l, E, G, I, A, x') + >>> b = Beam3D(30, E, G, I, A, x) + >>> b.bc_slope = [(0, (4, 0, 0))] + >>> b.bc_deflection = [(4, [0, 0, 0])] + >>> b.boundary_conditions + {'deflection': [(4, [0, 0, 0])], 'slope': [(0, (4, 0, 0))]} + + Here the deflection of the beam should be ``0`` along all the three axes at ``4``. + Similarly, the slope of the beam should be ``4`` along x-axis and ``0`` + along y and z axis at ``0``. + """ + return self._boundary_conditions + + def polar_moment(self): + """ + Returns the polar moment of area of the beam + about the X axis with respect to the centroid. + + Examples + ======== + + >>> from sympy.physics.continuum_mechanics.beam import Beam3D + >>> from sympy import symbols + >>> l, E, G, I, A = symbols('l, E, G, I, A') + >>> b = Beam3D(l, E, G, I, A) + >>> b.polar_moment() + 2*I + >>> I1 = [9, 15] + >>> b = Beam3D(l, E, G, I1, A) + >>> b.polar_moment() + 24 + """ + if not iterable(self.second_moment): + return 2*self.second_moment + return sum(self.second_moment) + + def apply_load(self, value, start, order, dir="y"): + """ + This method adds up the force load to a particular beam object. + + Parameters + ========== + value : Sympifyable + The magnitude of an applied load. + dir : String + Axis along which load is applied. + order : Integer + The order of the applied load. + - For point loads, order=-1 + - For constant distributed load, order=0 + - For ramp loads, order=1 + - For parabolic ramp loads, order=2 + - ... so on. + """ + x = self.variable + value = sympify(value) + start = sympify(start) + order = sympify(order) + + if dir == "x": + if not order == -1: + self._load_vector[0] += value + self._load_Singularity[0] += value*SingularityFunction(x, start, order) + + elif dir == "y": + if not order == -1: + self._load_vector[1] += value + self._load_Singularity[1] += value*SingularityFunction(x, start, order) + + else: + if not order == -1: + self._load_vector[2] += value + self._load_Singularity[2] += value*SingularityFunction(x, start, order) + + def apply_moment_load(self, value, start, order, dir="y"): + """ + This method adds up the moment loads to a particular beam object. + + Parameters + ========== + value : Sympifyable + The magnitude of an applied moment. + dir : String + Axis along which moment is applied. + order : Integer + The order of the applied load. + - For point moments, order=-2 + - For constant distributed moment, order=-1 + - For ramp moments, order=0 + - For parabolic ramp moments, order=1 + - ... so on. + """ + x = self.variable + value = sympify(value) + start = sympify(start) + order = sympify(order) + + if dir == "x": + if not order == -2: + self._moment_load_vector[0] += value + else: + if start in list(self._torsion_moment): + self._torsion_moment[start] += value + else: + self._torsion_moment[start] = value + self._load_Singularity[0] += value*SingularityFunction(x, start, order) + elif dir == "y": + if not order == -2: + self._moment_load_vector[1] += value + self._load_Singularity[0] += value*SingularityFunction(x, start, order) + else: + if not order == -2: + self._moment_load_vector[2] += value + self._load_Singularity[0] += value*SingularityFunction(x, start, order) + + def apply_support(self, loc, type="fixed"): + if type in ("pin", "roller"): + reaction_load = Symbol('R_'+str(loc)) + self._reaction_loads[reaction_load] = reaction_load + self.bc_deflection.append((loc, [0, 0, 0])) + else: + reaction_load = Symbol('R_'+str(loc)) + reaction_moment = Symbol('M_'+str(loc)) + self._reaction_loads[reaction_load] = [reaction_load, reaction_moment] + self.bc_deflection.append((loc, [0, 0, 0])) + self.bc_slope.append((loc, [0, 0, 0])) + + def solve_for_reaction_loads(self, *reaction): + """ + Solves for the reaction forces. + + Examples + ======== + There is a beam of length 30 meters. It it supported by rollers at + of its end. A constant distributed load of magnitude 8 N is applied + from start till its end along y-axis. Another linear load having + slope equal to 9 is applied along z-axis. + + >>> from sympy.physics.continuum_mechanics.beam import Beam3D + >>> from sympy import symbols + >>> l, E, G, I, A, x = symbols('l, E, G, I, A, x') + >>> b = Beam3D(30, E, G, I, A, x) + >>> b.apply_load(8, start=0, order=0, dir="y") + >>> b.apply_load(9*x, start=0, order=0, dir="z") + >>> b.bc_deflection = [(0, [0, 0, 0]), (30, [0, 0, 0])] + >>> R1, R2, R3, R4 = symbols('R1, R2, R3, R4') + >>> b.apply_load(R1, start=0, order=-1, dir="y") + >>> b.apply_load(R2, start=30, order=-1, dir="y") + >>> b.apply_load(R3, start=0, order=-1, dir="z") + >>> b.apply_load(R4, start=30, order=-1, dir="z") + >>> b.solve_for_reaction_loads(R1, R2, R3, R4) + >>> b.reaction_loads + {R1: -120, R2: -120, R3: -1350, R4: -2700} + """ + x = self.variable + l = self.length + q = self._load_Singularity + shear_curves = [integrate(load, x) for load in q] + moment_curves = [integrate(shear, x) for shear in shear_curves] + for i in range(3): + react = [r for r in reaction if (shear_curves[i].has(r) or moment_curves[i].has(r))] + if len(react) == 0: + continue + shear_curve = limit(shear_curves[i], x, l) + moment_curve = limit(moment_curves[i], x, l) + sol = list((linsolve([shear_curve, moment_curve], react).args)[0]) + sol_dict = dict(zip(react, sol)) + reaction_loads = self._reaction_loads + # Check if any of the evaluated reaction exists in another direction + # and if it exists then it should have same value. + for key in sol_dict: + if key in reaction_loads and sol_dict[key] != reaction_loads[key]: + raise ValueError("Ambiguous solution for %s in different directions." % key) + self._reaction_loads.update(sol_dict) + + def shear_force(self): + """ + Returns a list of three expressions which represents the shear force + curve of the Beam object along all three axes. + """ + x = self.variable + q = self._load_vector + return [integrate(-q[0], x), integrate(-q[1], x), integrate(-q[2], x)] + + def axial_force(self): + """ + Returns expression of Axial shear force present inside the Beam object. + """ + return self.shear_force()[0] + + def shear_stress(self): + """ + Returns a list of three expressions which represents the shear stress + curve of the Beam object along all three axes. + """ + return [self.shear_force()[0]/self._area, self.shear_force()[1]/self._area, self.shear_force()[2]/self._area] + + def axial_stress(self): + """ + Returns expression of Axial stress present inside the Beam object. + """ + return self.axial_force()/self._area + + def bending_moment(self): + """ + Returns a list of three expressions which represents the bending moment + curve of the Beam object along all three axes. + """ + x = self.variable + m = self._moment_load_vector + shear = self.shear_force() + + return [integrate(-m[0], x), integrate(-m[1] + shear[2], x), + integrate(-m[2] - shear[1], x) ] + + def torsional_moment(self): + """ + Returns expression of Torsional moment present inside the Beam object. + """ + return self.bending_moment()[0] + + def solve_for_torsion(self): + """ + Solves for the angular deflection due to the torsional effects of + moments being applied in the x-direction i.e. out of or into the beam. + + Here, a positive torque means the direction of the torque is positive + i.e. out of the beam along the beam-axis. Likewise, a negative torque + signifies a torque into the beam cross-section. + + Examples + ======== + + >>> from sympy.physics.continuum_mechanics.beam import Beam3D + >>> from sympy import symbols + >>> l, E, G, I, A, x = symbols('l, E, G, I, A, x') + >>> b = Beam3D(20, E, G, I, A, x) + >>> b.apply_moment_load(4, 4, -2, dir='x') + >>> b.apply_moment_load(4, 8, -2, dir='x') + >>> b.apply_moment_load(4, 8, -2, dir='x') + >>> b.solve_for_torsion() + >>> b.angular_deflection().subs(x, 3) + 18/(G*I) + """ + x = self.variable + sum_moments = 0 + for point in list(self._torsion_moment): + sum_moments += self._torsion_moment[point] + list(self._torsion_moment).sort() + pointsList = list(self._torsion_moment) + torque_diagram = Piecewise((sum_moments, x<=pointsList[0]), (0, x>=pointsList[0])) + for i in range(len(pointsList))[1:]: + sum_moments -= self._torsion_moment[pointsList[i-1]] + torque_diagram += Piecewise((0, x<=pointsList[i-1]), (sum_moments, x<=pointsList[i]), (0, x>=pointsList[i])) + integrated_torque_diagram = integrate(torque_diagram) + self._angular_deflection = integrated_torque_diagram/(self.shear_modulus*self.polar_moment()) + + def solve_slope_deflection(self): + x = self.variable + l = self.length + E = self.elastic_modulus + G = self.shear_modulus + I = self.second_moment + if isinstance(I, list): + I_y, I_z = I[0], I[1] + else: + I_y = I_z = I + A = self._area + load = self._load_vector + moment = self._moment_load_vector + defl = Function('defl') + theta = Function('theta') + + # Finding deflection along x-axis(and corresponding slope value by differentiating it) + # Equation used: Derivative(E*A*Derivative(def_x(x), x), x) + load_x = 0 + eq = Derivative(E*A*Derivative(defl(x), x), x) + load[0] + def_x = dsolve(Eq(eq, 0), defl(x)).args[1] + # Solving constants originated from dsolve + C1 = Symbol('C1') + C2 = Symbol('C2') + constants = list((linsolve([def_x.subs(x, 0), def_x.subs(x, l)], C1, C2).args)[0]) + def_x = def_x.subs({C1:constants[0], C2:constants[1]}) + slope_x = def_x.diff(x) + self._deflection[0] = def_x + self._slope[0] = slope_x + + # Finding deflection along y-axis and slope across z-axis. System of equation involved: + # 1: Derivative(E*I_z*Derivative(theta_z(x), x), x) + G*A*(Derivative(defl_y(x), x) - theta_z(x)) + moment_z = 0 + # 2: Derivative(G*A*(Derivative(defl_y(x), x) - theta_z(x)), x) + load_y = 0 + C_i = Symbol('C_i') + # Substitute value of `G*A*(Derivative(defl_y(x), x) - theta_z(x))` from (2) in (1) + eq1 = Derivative(E*I_z*Derivative(theta(x), x), x) + (integrate(-load[1], x) + C_i) + moment[2] + slope_z = dsolve(Eq(eq1, 0)).args[1] + + # Solve for constants originated from using dsolve on eq1 + constants = list((linsolve([slope_z.subs(x, 0), slope_z.subs(x, l)], C1, C2).args)[0]) + slope_z = slope_z.subs({C1:constants[0], C2:constants[1]}) + + # Put value of slope obtained back in (2) to solve for `C_i` and find deflection across y-axis + eq2 = G*A*(Derivative(defl(x), x)) + load[1]*x - C_i - G*A*slope_z + def_y = dsolve(Eq(eq2, 0), defl(x)).args[1] + # Solve for constants originated from using dsolve on eq2 + constants = list((linsolve([def_y.subs(x, 0), def_y.subs(x, l)], C1, C_i).args)[0]) + self._deflection[1] = def_y.subs({C1:constants[0], C_i:constants[1]}) + self._slope[2] = slope_z.subs(C_i, constants[1]) + + # Finding deflection along z-axis and slope across y-axis. System of equation involved: + # 1: Derivative(E*I_y*Derivative(theta_y(x), x), x) - G*A*(Derivative(defl_z(x), x) + theta_y(x)) + moment_y = 0 + # 2: Derivative(G*A*(Derivative(defl_z(x), x) + theta_y(x)), x) + load_z = 0 + + # Substitute value of `G*A*(Derivative(defl_y(x), x) + theta_z(x))` from (2) in (1) + eq1 = Derivative(E*I_y*Derivative(theta(x), x), x) + (integrate(load[2], x) - C_i) + moment[1] + slope_y = dsolve(Eq(eq1, 0)).args[1] + # Solve for constants originated from using dsolve on eq1 + constants = list((linsolve([slope_y.subs(x, 0), slope_y.subs(x, l)], C1, C2).args)[0]) + slope_y = slope_y.subs({C1:constants[0], C2:constants[1]}) + + # Put value of slope obtained back in (2) to solve for `C_i` and find deflection across z-axis + eq2 = G*A*(Derivative(defl(x), x)) + load[2]*x - C_i + G*A*slope_y + def_z = dsolve(Eq(eq2,0)).args[1] + # Solve for constants originated from using dsolve on eq2 + constants = list((linsolve([def_z.subs(x, 0), def_z.subs(x, l)], C1, C_i).args)[0]) + self._deflection[2] = def_z.subs({C1:constants[0], C_i:constants[1]}) + self._slope[1] = slope_y.subs(C_i, constants[1]) + + def slope(self): + """ + Returns a three element list representing slope of deflection curve + along all the three axes. + """ + return self._slope + + def deflection(self): + """ + Returns a three element list representing deflection curve along all + the three axes. + """ + return self._deflection + + def angular_deflection(self): + """ + Returns a function in x depicting how the angular deflection, due to moments + in the x-axis on the beam, varies with x. + """ + return self._angular_deflection + + def _plot_shear_force(self, dir, subs=None): + + shear_force = self.shear_force() + + if dir == 'x': + dir_num = 0 + color = 'r' + + elif dir == 'y': + dir_num = 1 + color = 'g' + + elif dir == 'z': + dir_num = 2 + color = 'b' + + if subs is None: + subs = {} + + for sym in shear_force[dir_num].atoms(Symbol): + if sym != self.variable and sym not in subs: + raise ValueError('Value of %s was not passed.' %sym) + if self.length in subs: + length = subs[self.length] + else: + length = self.length + + return plot(shear_force[dir_num].subs(subs), (self.variable, 0, length), show = False, title='Shear Force along %c direction'%dir, + xlabel=r'$\mathrm{X}$', ylabel=r'$\mathrm{V(%c)}$'%dir, line_color=color) + + def plot_shear_force(self, dir="all", subs=None): + + """ + + Returns a plot for Shear force along all three directions + present in the Beam object. + + Parameters + ========== + dir : string (default : "all") + Direction along which shear force plot is required. + If no direction is specified, all plots are displayed. + subs : dictionary + Python dictionary containing Symbols as key and their + corresponding values. + + Examples + ======== + There is a beam of length 20 meters. It is supported by rollers + at both of its ends. A linear load having slope equal to 12 is applied + along y-axis. A constant distributed load of magnitude 15 N is + applied from start till its end along z-axis. + + .. plot:: + :context: close-figs + :format: doctest + :include-source: True + + >>> from sympy.physics.continuum_mechanics.beam import Beam3D + >>> from sympy import symbols + >>> l, E, G, I, A, x = symbols('l, E, G, I, A, x') + >>> b = Beam3D(20, E, G, I, A, x) + >>> b.apply_load(15, start=0, order=0, dir="z") + >>> b.apply_load(12*x, start=0, order=0, dir="y") + >>> b.bc_deflection = [(0, [0, 0, 0]), (20, [0, 0, 0])] + >>> R1, R2, R3, R4 = symbols('R1, R2, R3, R4') + >>> b.apply_load(R1, start=0, order=-1, dir="z") + >>> b.apply_load(R2, start=20, order=-1, dir="z") + >>> b.apply_load(R3, start=0, order=-1, dir="y") + >>> b.apply_load(R4, start=20, order=-1, dir="y") + >>> b.solve_for_reaction_loads(R1, R2, R3, R4) + >>> b.plot_shear_force() + PlotGrid object containing: + Plot[0]:Plot object containing: + [0]: cartesian line: 0 for x over (0.0, 20.0) + Plot[1]:Plot object containing: + [0]: cartesian line: -6*x**2 for x over (0.0, 20.0) + Plot[2]:Plot object containing: + [0]: cartesian line: -15*x for x over (0.0, 20.0) + + """ + + dir = dir.lower() + # For shear force along x direction + if dir == "x": + Px = self._plot_shear_force('x', subs) + return Px.show() + # For shear force along y direction + elif dir == "y": + Py = self._plot_shear_force('y', subs) + return Py.show() + # For shear force along z direction + elif dir == "z": + Pz = self._plot_shear_force('z', subs) + return Pz.show() + # For shear force along all direction + else: + Px = self._plot_shear_force('x', subs) + Py = self._plot_shear_force('y', subs) + Pz = self._plot_shear_force('z', subs) + return PlotGrid(3, 1, Px, Py, Pz) + + def _plot_bending_moment(self, dir, subs=None): + + bending_moment = self.bending_moment() + + if dir == 'x': + dir_num = 0 + color = 'g' + + elif dir == 'y': + dir_num = 1 + color = 'c' + + elif dir == 'z': + dir_num = 2 + color = 'm' + + if subs is None: + subs = {} + + for sym in bending_moment[dir_num].atoms(Symbol): + if sym != self.variable and sym not in subs: + raise ValueError('Value of %s was not passed.' %sym) + if self.length in subs: + length = subs[self.length] + else: + length = self.length + + return plot(bending_moment[dir_num].subs(subs), (self.variable, 0, length), show = False, title='Bending Moment along %c direction'%dir, + xlabel=r'$\mathrm{X}$', ylabel=r'$\mathrm{M(%c)}$'%dir, line_color=color) + + def plot_bending_moment(self, dir="all", subs=None): + + """ + + Returns a plot for bending moment along all three directions + present in the Beam object. + + Parameters + ========== + dir : string (default : "all") + Direction along which bending moment plot is required. + If no direction is specified, all plots are displayed. + subs : dictionary + Python dictionary containing Symbols as key and their + corresponding values. + + Examples + ======== + There is a beam of length 20 meters. It is supported by rollers + at both of its ends. A linear load having slope equal to 12 is applied + along y-axis. A constant distributed load of magnitude 15 N is + applied from start till its end along z-axis. + + .. plot:: + :context: close-figs + :format: doctest + :include-source: True + + >>> from sympy.physics.continuum_mechanics.beam import Beam3D + >>> from sympy import symbols + >>> l, E, G, I, A, x = symbols('l, E, G, I, A, x') + >>> b = Beam3D(20, E, G, I, A, x) + >>> b.apply_load(15, start=0, order=0, dir="z") + >>> b.apply_load(12*x, start=0, order=0, dir="y") + >>> b.bc_deflection = [(0, [0, 0, 0]), (20, [0, 0, 0])] + >>> R1, R2, R3, R4 = symbols('R1, R2, R3, R4') + >>> b.apply_load(R1, start=0, order=-1, dir="z") + >>> b.apply_load(R2, start=20, order=-1, dir="z") + >>> b.apply_load(R3, start=0, order=-1, dir="y") + >>> b.apply_load(R4, start=20, order=-1, dir="y") + >>> b.solve_for_reaction_loads(R1, R2, R3, R4) + >>> b.plot_bending_moment() + PlotGrid object containing: + Plot[0]:Plot object containing: + [0]: cartesian line: 0 for x over (0.0, 20.0) + Plot[1]:Plot object containing: + [0]: cartesian line: -15*x**2/2 for x over (0.0, 20.0) + Plot[2]:Plot object containing: + [0]: cartesian line: 2*x**3 for x over (0.0, 20.0) + + """ + + dir = dir.lower() + # For bending moment along x direction + if dir == "x": + Px = self._plot_bending_moment('x', subs) + return Px.show() + # For bending moment along y direction + elif dir == "y": + Py = self._plot_bending_moment('y', subs) + return Py.show() + # For bending moment along z direction + elif dir == "z": + Pz = self._plot_bending_moment('z', subs) + return Pz.show() + # For bending moment along all direction + else: + Px = self._plot_bending_moment('x', subs) + Py = self._plot_bending_moment('y', subs) + Pz = self._plot_bending_moment('z', subs) + return PlotGrid(3, 1, Px, Py, Pz) + + def _plot_slope(self, dir, subs=None): + + slope = self.slope() + + if dir == 'x': + dir_num = 0 + color = 'b' + + elif dir == 'y': + dir_num = 1 + color = 'm' + + elif dir == 'z': + dir_num = 2 + color = 'g' + + if subs is None: + subs = {} + + for sym in slope[dir_num].atoms(Symbol): + if sym != self.variable and sym not in subs: + raise ValueError('Value of %s was not passed.' %sym) + if self.length in subs: + length = subs[self.length] + else: + length = self.length + + + return plot(slope[dir_num].subs(subs), (self.variable, 0, length), show = False, title='Slope along %c direction'%dir, + xlabel=r'$\mathrm{X}$', ylabel=r'$\mathrm{\theta(%c)}$'%dir, line_color=color) + + def plot_slope(self, dir="all", subs=None): + + """ + + Returns a plot for Slope along all three directions + present in the Beam object. + + Parameters + ========== + dir : string (default : "all") + Direction along which Slope plot is required. + If no direction is specified, all plots are displayed. + subs : dictionary + Python dictionary containing Symbols as keys and their + corresponding values. + + Examples + ======== + There is a beam of length 20 meters. It is supported by rollers + at both of its ends. A linear load having slope equal to 12 is applied + along y-axis. A constant distributed load of magnitude 15 N is + applied from start till its end along z-axis. + + .. plot:: + :context: close-figs + :format: doctest + :include-source: True + + >>> from sympy.physics.continuum_mechanics.beam import Beam3D + >>> from sympy import symbols + >>> l, E, G, I, A, x = symbols('l, E, G, I, A, x') + >>> b = Beam3D(20, 40, 21, 100, 25, x) + >>> b.apply_load(15, start=0, order=0, dir="z") + >>> b.apply_load(12*x, start=0, order=0, dir="y") + >>> b.bc_deflection = [(0, [0, 0, 0]), (20, [0, 0, 0])] + >>> R1, R2, R3, R4 = symbols('R1, R2, R3, R4') + >>> b.apply_load(R1, start=0, order=-1, dir="z") + >>> b.apply_load(R2, start=20, order=-1, dir="z") + >>> b.apply_load(R3, start=0, order=-1, dir="y") + >>> b.apply_load(R4, start=20, order=-1, dir="y") + >>> b.solve_for_reaction_loads(R1, R2, R3, R4) + >>> b.solve_slope_deflection() + >>> b.plot_slope() + PlotGrid object containing: + Plot[0]:Plot object containing: + [0]: cartesian line: 0 for x over (0.0, 20.0) + Plot[1]:Plot object containing: + [0]: cartesian line: -x**3/1600 + 3*x**2/160 - x/8 for x over (0.0, 20.0) + Plot[2]:Plot object containing: + [0]: cartesian line: x**4/8000 - 19*x**2/172 + 52*x/43 for x over (0.0, 20.0) + + """ + + dir = dir.lower() + # For Slope along x direction + if dir == "x": + Px = self._plot_slope('x', subs) + return Px.show() + # For Slope along y direction + elif dir == "y": + Py = self._plot_slope('y', subs) + return Py.show() + # For Slope along z direction + elif dir == "z": + Pz = self._plot_slope('z', subs) + return Pz.show() + # For Slope along all direction + else: + Px = self._plot_slope('x', subs) + Py = self._plot_slope('y', subs) + Pz = self._plot_slope('z', subs) + return PlotGrid(3, 1, Px, Py, Pz) + + def _plot_deflection(self, dir, subs=None): + + deflection = self.deflection() + + if dir == 'x': + dir_num = 0 + color = 'm' + + elif dir == 'y': + dir_num = 1 + color = 'r' + + elif dir == 'z': + dir_num = 2 + color = 'c' + + if subs is None: + subs = {} + + for sym in deflection[dir_num].atoms(Symbol): + if sym != self.variable and sym not in subs: + raise ValueError('Value of %s was not passed.' %sym) + if self.length in subs: + length = subs[self.length] + else: + length = self.length + + return plot(deflection[dir_num].subs(subs), (self.variable, 0, length), show = False, title='Deflection along %c direction'%dir, + xlabel=r'$\mathrm{X}$', ylabel=r'$\mathrm{\delta(%c)}$'%dir, line_color=color) + + def plot_deflection(self, dir="all", subs=None): + + """ + + Returns a plot for Deflection along all three directions + present in the Beam object. + + Parameters + ========== + dir : string (default : "all") + Direction along which deflection plot is required. + If no direction is specified, all plots are displayed. + subs : dictionary + Python dictionary containing Symbols as keys and their + corresponding values. + + Examples + ======== + There is a beam of length 20 meters. It is supported by rollers + at both of its ends. A linear load having slope equal to 12 is applied + along y-axis. A constant distributed load of magnitude 15 N is + applied from start till its end along z-axis. + + .. plot:: + :context: close-figs + :format: doctest + :include-source: True + + >>> from sympy.physics.continuum_mechanics.beam import Beam3D + >>> from sympy import symbols + >>> l, E, G, I, A, x = symbols('l, E, G, I, A, x') + >>> b = Beam3D(20, 40, 21, 100, 25, x) + >>> b.apply_load(15, start=0, order=0, dir="z") + >>> b.apply_load(12*x, start=0, order=0, dir="y") + >>> b.bc_deflection = [(0, [0, 0, 0]), (20, [0, 0, 0])] + >>> R1, R2, R3, R4 = symbols('R1, R2, R3, R4') + >>> b.apply_load(R1, start=0, order=-1, dir="z") + >>> b.apply_load(R2, start=20, order=-1, dir="z") + >>> b.apply_load(R3, start=0, order=-1, dir="y") + >>> b.apply_load(R4, start=20, order=-1, dir="y") + >>> b.solve_for_reaction_loads(R1, R2, R3, R4) + >>> b.solve_slope_deflection() + >>> b.plot_deflection() + PlotGrid object containing: + Plot[0]:Plot object containing: + [0]: cartesian line: 0 for x over (0.0, 20.0) + Plot[1]:Plot object containing: + [0]: cartesian line: x**5/40000 - 4013*x**3/90300 + 26*x**2/43 + 1520*x/903 for x over (0.0, 20.0) + Plot[2]:Plot object containing: + [0]: cartesian line: x**4/6400 - x**3/160 + 27*x**2/560 + 2*x/7 for x over (0.0, 20.0) + + + """ + + dir = dir.lower() + # For deflection along x direction + if dir == "x": + Px = self._plot_deflection('x', subs) + return Px.show() + # For deflection along y direction + elif dir == "y": + Py = self._plot_deflection('y', subs) + return Py.show() + # For deflection along z direction + elif dir == "z": + Pz = self._plot_deflection('z', subs) + return Pz.show() + # For deflection along all direction + else: + Px = self._plot_deflection('x', subs) + Py = self._plot_deflection('y', subs) + Pz = self._plot_deflection('z', subs) + return PlotGrid(3, 1, Px, Py, Pz) + + def plot_loading_results(self, dir='x', subs=None): + + """ + + Returns a subplot of Shear Force, Bending Moment, + Slope and Deflection of the Beam object along the direction specified. + + Parameters + ========== + + dir : string (default : "x") + Direction along which plots are required. + If no direction is specified, plots along x-axis are displayed. + subs : dictionary + Python dictionary containing Symbols as key and their + corresponding values. + + Examples + ======== + There is a beam of length 20 meters. It is supported by rollers + at both of its ends. A linear load having slope equal to 12 is applied + along y-axis. A constant distributed load of magnitude 15 N is + applied from start till its end along z-axis. + + .. plot:: + :context: close-figs + :format: doctest + :include-source: True + + >>> from sympy.physics.continuum_mechanics.beam import Beam3D + >>> from sympy import symbols + >>> l, E, G, I, A, x = symbols('l, E, G, I, A, x') + >>> b = Beam3D(20, E, G, I, A, x) + >>> subs = {E:40, G:21, I:100, A:25} + >>> b.apply_load(15, start=0, order=0, dir="z") + >>> b.apply_load(12*x, start=0, order=0, dir="y") + >>> b.bc_deflection = [(0, [0, 0, 0]), (20, [0, 0, 0])] + >>> R1, R2, R3, R4 = symbols('R1, R2, R3, R4') + >>> b.apply_load(R1, start=0, order=-1, dir="z") + >>> b.apply_load(R2, start=20, order=-1, dir="z") + >>> b.apply_load(R3, start=0, order=-1, dir="y") + >>> b.apply_load(R4, start=20, order=-1, dir="y") + >>> b.solve_for_reaction_loads(R1, R2, R3, R4) + >>> b.solve_slope_deflection() + >>> b.plot_loading_results('y',subs) + PlotGrid object containing: + Plot[0]:Plot object containing: + [0]: cartesian line: -6*x**2 for x over (0.0, 20.0) + Plot[1]:Plot object containing: + [0]: cartesian line: -15*x**2/2 for x over (0.0, 20.0) + Plot[2]:Plot object containing: + [0]: cartesian line: -x**3/1600 + 3*x**2/160 - x/8 for x over (0.0, 20.0) + Plot[3]:Plot object containing: + [0]: cartesian line: x**5/40000 - 4013*x**3/90300 + 26*x**2/43 + 1520*x/903 for x over (0.0, 20.0) + + """ + + dir = dir.lower() + if subs is None: + subs = {} + + ax1 = self._plot_shear_force(dir, subs) + ax2 = self._plot_bending_moment(dir, subs) + ax3 = self._plot_slope(dir, subs) + ax4 = self._plot_deflection(dir, subs) + + return PlotGrid(4, 1, ax1, ax2, ax3, ax4) + + def _plot_shear_stress(self, dir, subs=None): + + shear_stress = self.shear_stress() + + if dir == 'x': + dir_num = 0 + color = 'r' + + elif dir == 'y': + dir_num = 1 + color = 'g' + + elif dir == 'z': + dir_num = 2 + color = 'b' + + if subs is None: + subs = {} + + for sym in shear_stress[dir_num].atoms(Symbol): + if sym != self.variable and sym not in subs: + raise ValueError('Value of %s was not passed.' %sym) + if self.length in subs: + length = subs[self.length] + else: + length = self.length + + return plot(shear_stress[dir_num].subs(subs), (self.variable, 0, length), show = False, title='Shear stress along %c direction'%dir, + xlabel=r'$\mathrm{X}$', ylabel=r'$\tau(%c)$'%dir, line_color=color) + + def plot_shear_stress(self, dir="all", subs=None): + + """ + + Returns a plot for Shear Stress along all three directions + present in the Beam object. + + Parameters + ========== + dir : string (default : "all") + Direction along which shear stress plot is required. + If no direction is specified, all plots are displayed. + subs : dictionary + Python dictionary containing Symbols as key and their + corresponding values. + + Examples + ======== + There is a beam of length 20 meters and area of cross section 2 square + meters. It is supported by rollers at both of its ends. A linear load having + slope equal to 12 is applied along y-axis. A constant distributed load + of magnitude 15 N is applied from start till its end along z-axis. + + .. plot:: + :context: close-figs + :format: doctest + :include-source: True + + >>> from sympy.physics.continuum_mechanics.beam import Beam3D + >>> from sympy import symbols + >>> l, E, G, I, A, x = symbols('l, E, G, I, A, x') + >>> b = Beam3D(20, E, G, I, 2, x) + >>> b.apply_load(15, start=0, order=0, dir="z") + >>> b.apply_load(12*x, start=0, order=0, dir="y") + >>> b.bc_deflection = [(0, [0, 0, 0]), (20, [0, 0, 0])] + >>> R1, R2, R3, R4 = symbols('R1, R2, R3, R4') + >>> b.apply_load(R1, start=0, order=-1, dir="z") + >>> b.apply_load(R2, start=20, order=-1, dir="z") + >>> b.apply_load(R3, start=0, order=-1, dir="y") + >>> b.apply_load(R4, start=20, order=-1, dir="y") + >>> b.solve_for_reaction_loads(R1, R2, R3, R4) + >>> b.plot_shear_stress() + PlotGrid object containing: + Plot[0]:Plot object containing: + [0]: cartesian line: 0 for x over (0.0, 20.0) + Plot[1]:Plot object containing: + [0]: cartesian line: -3*x**2 for x over (0.0, 20.0) + Plot[2]:Plot object containing: + [0]: cartesian line: -15*x/2 for x over (0.0, 20.0) + + """ + + dir = dir.lower() + # For shear stress along x direction + if dir == "x": + Px = self._plot_shear_stress('x', subs) + return Px.show() + # For shear stress along y direction + elif dir == "y": + Py = self._plot_shear_stress('y', subs) + return Py.show() + # For shear stress along z direction + elif dir == "z": + Pz = self._plot_shear_stress('z', subs) + return Pz.show() + # For shear stress along all direction + else: + Px = self._plot_shear_stress('x', subs) + Py = self._plot_shear_stress('y', subs) + Pz = self._plot_shear_stress('z', subs) + return PlotGrid(3, 1, Px, Py, Pz) + + def _max_shear_force(self, dir): + """ + Helper function for max_shear_force(). + """ + + dir = dir.lower() + + if dir == 'x': + dir_num = 0 + + elif dir == 'y': + dir_num = 1 + + elif dir == 'z': + dir_num = 2 + + if not self.shear_force()[dir_num]: + return (0,0) + # To restrict the range within length of the Beam + load_curve = Piecewise((float("nan"), self.variable<=0), + (self._load_vector[dir_num], self.variable>> from sympy.physics.continuum_mechanics.beam import Beam3D + >>> from sympy import symbols + >>> l, E, G, I, A, x = symbols('l, E, G, I, A, x') + >>> b = Beam3D(20, 40, 21, 100, 25, x) + >>> b.apply_load(15, start=0, order=0, dir="z") + >>> b.apply_load(12*x, start=0, order=0, dir="y") + >>> b.bc_deflection = [(0, [0, 0, 0]), (20, [0, 0, 0])] + >>> R1, R2, R3, R4 = symbols('R1, R2, R3, R4') + >>> b.apply_load(R1, start=0, order=-1, dir="z") + >>> b.apply_load(R2, start=20, order=-1, dir="z") + >>> b.apply_load(R3, start=0, order=-1, dir="y") + >>> b.apply_load(R4, start=20, order=-1, dir="y") + >>> b.solve_for_reaction_loads(R1, R2, R3, R4) + >>> b.max_shear_force() + [(0, 0), (20, 2400), (20, 300)] + """ + + max_shear = [] + max_shear.append(self._max_shear_force('x')) + max_shear.append(self._max_shear_force('y')) + max_shear.append(self._max_shear_force('z')) + return max_shear + + def _max_bending_moment(self, dir): + """ + Helper function for max_bending_moment(). + """ + + dir = dir.lower() + + if dir == 'x': + dir_num = 0 + + elif dir == 'y': + dir_num = 1 + + elif dir == 'z': + dir_num = 2 + + if not self.bending_moment()[dir_num]: + return (0,0) + # To restrict the range within length of the Beam + shear_curve = Piecewise((float("nan"), self.variable<=0), + (self.shear_force()[dir_num], self.variable>> from sympy.physics.continuum_mechanics.beam import Beam3D + >>> from sympy import symbols + >>> l, E, G, I, A, x = symbols('l, E, G, I, A, x') + >>> b = Beam3D(20, 40, 21, 100, 25, x) + >>> b.apply_load(15, start=0, order=0, dir="z") + >>> b.apply_load(12*x, start=0, order=0, dir="y") + >>> b.bc_deflection = [(0, [0, 0, 0]), (20, [0, 0, 0])] + >>> R1, R2, R3, R4 = symbols('R1, R2, R3, R4') + >>> b.apply_load(R1, start=0, order=-1, dir="z") + >>> b.apply_load(R2, start=20, order=-1, dir="z") + >>> b.apply_load(R3, start=0, order=-1, dir="y") + >>> b.apply_load(R4, start=20, order=-1, dir="y") + >>> b.solve_for_reaction_loads(R1, R2, R3, R4) + >>> b.max_bending_moment() + [(0, 0), (20, 3000), (20, 16000)] + """ + + max_bmoment = [] + max_bmoment.append(self._max_bending_moment('x')) + max_bmoment.append(self._max_bending_moment('y')) + max_bmoment.append(self._max_bending_moment('z')) + return max_bmoment + + max_bmoment = max_bending_moment + + def _max_deflection(self, dir): + """ + Helper function for max_Deflection() + """ + + dir = dir.lower() + + if dir == 'x': + dir_num = 0 + + elif dir == 'y': + dir_num = 1 + + elif dir == 'z': + dir_num = 2 + + if not self.deflection()[dir_num]: + return (0,0) + # To restrict the range within length of the Beam + slope_curve = Piecewise((float("nan"), self.variable<=0), + (self.slope()[dir_num], self.variable>> from sympy.physics.continuum_mechanics.beam import Beam3D + >>> from sympy import symbols + >>> l, E, G, I, A, x = symbols('l, E, G, I, A, x') + >>> b = Beam3D(20, 40, 21, 100, 25, x) + >>> b.apply_load(15, start=0, order=0, dir="z") + >>> b.apply_load(12*x, start=0, order=0, dir="y") + >>> b.bc_deflection = [(0, [0, 0, 0]), (20, [0, 0, 0])] + >>> R1, R2, R3, R4 = symbols('R1, R2, R3, R4') + >>> b.apply_load(R1, start=0, order=-1, dir="z") + >>> b.apply_load(R2, start=20, order=-1, dir="z") + >>> b.apply_load(R3, start=0, order=-1, dir="y") + >>> b.apply_load(R4, start=20, order=-1, dir="y") + >>> b.solve_for_reaction_loads(R1, R2, R3, R4) + >>> b.solve_slope_deflection() + >>> b.max_deflection() + [(0, 0), (10, 495/14), (-10 + 10*sqrt(10793)/43, (10 - 10*sqrt(10793)/43)**3/160 - 20/7 + (10 - 10*sqrt(10793)/43)**4/6400 + 20*sqrt(10793)/301 + 27*(10 - 10*sqrt(10793)/43)**2/560)] + """ + + max_def = [] + max_def.append(self._max_deflection('x')) + max_def.append(self._max_deflection('y')) + max_def.append(self._max_deflection('z')) + return max_def diff --git a/lib/python3.10/site-packages/sympy/physics/continuum_mechanics/cable.py b/lib/python3.10/site-packages/sympy/physics/continuum_mechanics/cable.py new file mode 100644 index 0000000000000000000000000000000000000000..40a32d2c636edc5eb7d729173a1ee5cd011ffc9a --- /dev/null +++ b/lib/python3.10/site-packages/sympy/physics/continuum_mechanics/cable.py @@ -0,0 +1,587 @@ +""" +This module can be used to solve problems related +to 2D Cables. +""" + +from sympy.core.sympify import sympify +from sympy.core.symbol import Symbol +from sympy import sin, cos, pi, atan, diff +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.solvers.solveset import linsolve +from sympy.matrices import Matrix + + +class Cable: + """ + Cables are structures in engineering that support + the applied transverse loads through the tensile + resistance developed in its members. + + Cables are widely used in suspension bridges, tension + leg offshore platforms, transmission lines, and find + use in several other engineering applications. + + Examples + ======== + A cable is supported at (0, 10) and (10, 10). Two point loads + acting vertically downwards act on the cable, one with magnitude 3 kN + and acting 2 meters from the left support and 3 meters below it, while + the other with magnitude 2 kN is 6 meters from the left support and + 6 meters below it. + + >>> from sympy.physics.continuum_mechanics.cable import Cable + >>> c = Cable(('A', 0, 10), ('B', 10, 10)) + >>> c.apply_load(-1, ('P', 2, 7, 3, 270)) + >>> c.apply_load(-1, ('Q', 6, 4, 2, 270)) + >>> c.loads + {'distributed': {}, 'point_load': {'P': [3, 270], 'Q': [2, 270]}} + >>> c.loads_position + {'P': [2, 7], 'Q': [6, 4]} + """ + def __init__(self, support_1, support_2): + """ + Initializes the class. + + Parameters + ========== + + support_1 and support_2 are tuples of the form + (label, x, y), where + + label : String or symbol + The label of the support + + x : Sympifyable + The x coordinate of the position of the support + + y : Sympifyable + The y coordinate of the position of the support + """ + self._left_support = [] + self._right_support = [] + self._supports = {} + self._support_labels = [] + self._loads = {"distributed": {}, "point_load": {}} + self._loads_position = {} + self._length = 0 + self._reaction_loads = {} + self._tension = {} + self._lowest_x_global = sympify(0) + + if support_1[0] == support_2[0]: + raise ValueError("Supports can not have the same label") + + elif support_1[1] == support_2[1]: + raise ValueError("Supports can not be at the same location") + + x1 = sympify(support_1[1]) + y1 = sympify(support_1[2]) + self._supports[support_1[0]] = [x1, y1] + + x2 = sympify(support_2[1]) + y2 = sympify(support_2[2]) + self._supports[support_2[0]] = [x2, y2] + + if support_1[1] < support_2[1]: + self._left_support.append(x1) + self._left_support.append(y1) + self._right_support.append(x2) + self._right_support.append(y2) + self._support_labels.append(support_1[0]) + self._support_labels.append(support_2[0]) + + else: + self._left_support.append(x2) + self._left_support.append(y2) + self._right_support.append(x1) + self._right_support.append(y1) + self._support_labels.append(support_2[0]) + self._support_labels.append(support_1[0]) + + for i in self._support_labels: + self._reaction_loads[Symbol("R_"+ i +"_x")] = 0 + self._reaction_loads[Symbol("R_"+ i +"_y")] = 0 + + @property + def supports(self): + """ + Returns the supports of the cable along with their + positions. + """ + return self._supports + + @property + def left_support(self): + """ + Returns the position of the left support. + """ + return self._left_support + + @property + def right_support(self): + """ + Returns the position of the right support. + """ + return self._right_support + + @property + def loads(self): + """ + Returns the magnitude and direction of the loads + acting on the cable. + """ + return self._loads + + @property + def loads_position(self): + """ + Returns the position of the point loads acting on the + cable. + """ + return self._loads_position + + @property + def length(self): + """ + Returns the length of the cable. + """ + return self._length + + @property + def reaction_loads(self): + """ + Returns the reaction forces at the supports, which are + initialized to 0. + """ + return self._reaction_loads + + @property + def tension(self): + """ + Returns the tension developed in the cable due to the loads + applied. + """ + return self._tension + + def tension_at(self, x): + """ + Returns the tension at a given value of x developed due to + distributed load. + """ + if 'distributed' not in self._tension.keys(): + raise ValueError("No distributed load added or solve method not called") + + if x > self._right_support[0] or x < self._left_support[0]: + raise ValueError("The value of x should be between the two supports") + + A = self._tension['distributed'] + X = Symbol('X') + + return A.subs({X:(x-self._lowest_x_global)}) + + def apply_length(self, length): + """ + This method specifies the length of the cable + + Parameters + ========== + + length : Sympifyable + The length of the cable + + Examples + ======== + + >>> from sympy.physics.continuum_mechanics.cable import Cable + >>> c = Cable(('A', 0, 10), ('B', 10, 10)) + >>> c.apply_length(20) + >>> c.length + 20 + """ + dist = ((self._left_support[0] - self._right_support[0])**2 + - (self._left_support[1] - self._right_support[1])**2)**(1/2) + + if length < dist: + raise ValueError("length should not be less than the distance between the supports") + + self._length = length + + def change_support(self, label, new_support): + """ + This method changes the mentioned support with a new support. + + Parameters + ========== + label: String or symbol + The label of the support to be changed + + new_support: Tuple of the form (new_label, x, y) + new_label: String or symbol + The label of the new support + + x: Sympifyable + The x-coordinate of the position of the new support. + + y: Sympifyable + The y-coordinate of the position of the new support. + + Examples + ======== + + >>> from sympy.physics.continuum_mechanics.cable import Cable + >>> c = Cable(('A', 0, 10), ('B', 10, 10)) + >>> c.supports + {'A': [0, 10], 'B': [10, 10]} + >>> c.change_support('B', ('C', 5, 6)) + >>> c.supports + {'A': [0, 10], 'C': [5, 6]} + """ + if label not in self._supports: + raise ValueError("No support exists with the given label") + + i = self._support_labels.index(label) + rem_label = self._support_labels[(i+1)%2] + x1 = self._supports[rem_label][0] + y1 = self._supports[rem_label][1] + + x = sympify(new_support[1]) + y = sympify(new_support[2]) + + for l in self._loads_position: + if l[0] >= max(x, x1) or l[0] <= min(x, x1): + raise ValueError("The change in support will throw an existing load out of range") + + self._supports.pop(label) + self._left_support.clear() + self._right_support.clear() + self._reaction_loads.clear() + self._support_labels.remove(label) + + self._supports[new_support[0]] = [x, y] + + if x1 < x: + self._left_support.append(x1) + self._left_support.append(y1) + self._right_support.append(x) + self._right_support.append(y) + self._support_labels.append(new_support[0]) + + else: + self._left_support.append(x) + self._left_support.append(y) + self._right_support.append(x1) + self._right_support.append(y1) + self._support_labels.insert(0, new_support[0]) + + for i in self._support_labels: + self._reaction_loads[Symbol("R_"+ i +"_x")] = 0 + self._reaction_loads[Symbol("R_"+ i +"_y")] = 0 + + def apply_load(self, order, load): + """ + This method adds load to the cable. + + Parameters + ========== + + order : Integer + The order of the applied load. + + - For point loads, order = -1 + - For distributed load, order = 0 + + load : tuple + + * For point loads, load is of the form (label, x, y, magnitude, direction), where: + + label : String or symbol + The label of the load + + x : Sympifyable + The x coordinate of the position of the load + + y : Sympifyable + The y coordinate of the position of the load + + magnitude : Sympifyable + The magnitude of the load. It must always be positive + + direction : Sympifyable + The angle, in degrees, that the load vector makes with the horizontal + in the counter-clockwise direction. It takes the values 0 to 360, + inclusive. + + + * For uniformly distributed load, load is of the form (label, magnitude) + + label : String or symbol + The label of the load + + magnitude : Sympifyable + The magnitude of the load. It must always be positive + + Examples + ======== + + For a point load of magnitude 12 units inclined at 30 degrees with the horizontal: + + >>> from sympy.physics.continuum_mechanics.cable import Cable + >>> c = Cable(('A', 0, 10), ('B', 10, 10)) + >>> c.apply_load(-1, ('Z', 5, 5, 12, 30)) + >>> c.loads + {'distributed': {}, 'point_load': {'Z': [12, 30]}} + >>> c.loads_position + {'Z': [5, 5]} + + + For a uniformly distributed load of magnitude 9 units: + + >>> from sympy.physics.continuum_mechanics.cable import Cable + >>> c = Cable(('A', 0, 10), ('B', 10, 10)) + >>> c.apply_load(0, ('X', 9)) + >>> c.loads + {'distributed': {'X': 9}, 'point_load': {}} + """ + if order == -1: + if len(self._loads["distributed"]) != 0: + raise ValueError("Distributed load already exists") + + label = load[0] + if label in self._loads["point_load"]: + raise ValueError("Label already exists") + + x = sympify(load[1]) + y = sympify(load[2]) + + if x > self._right_support[0] or x < self._left_support[0]: + raise ValueError("The load should be positioned between the supports") + + magnitude = sympify(load[3]) + direction = sympify(load[4]) + + self._loads["point_load"][label] = [magnitude, direction] + self._loads_position[label] = [x, y] + + elif order == 0: + if len(self._loads_position) != 0: + raise ValueError("Point load(s) already exist") + + label = load[0] + if label in self._loads["distributed"]: + raise ValueError("Label already exists") + + magnitude = sympify(load[1]) + + self._loads["distributed"][label] = magnitude + + else: + raise ValueError("Order should be either -1 or 0") + + def remove_loads(self, *args): + """ + This methods removes the specified loads. + + Parameters + ========== + This input takes multiple label(s) as input + label(s): String or symbol + The label(s) of the loads to be removed. + + Examples + ======== + + >>> from sympy.physics.continuum_mechanics.cable import Cable + >>> c = Cable(('A', 0, 10), ('B', 10, 10)) + >>> c.apply_load(-1, ('Z', 5, 5, 12, 30)) + >>> c.loads + {'distributed': {}, 'point_load': {'Z': [12, 30]}} + >>> c.remove_loads('Z') + >>> c.loads + {'distributed': {}, 'point_load': {}} + """ + for i in args: + if len(self._loads_position) == 0: + if i not in self._loads['distributed']: + raise ValueError("Error removing load " + i + ": no such load exists") + + else: + self._loads['disrtibuted'].pop(i) + + else: + if i not in self._loads['point_load']: + raise ValueError("Error removing load " + i + ": no such load exists") + + else: + self._loads['point_load'].pop(i) + self._loads_position.pop(i) + + def solve(self, *args): + """ + This method solves for the reaction forces at the supports, the tension developed in + the cable, and updates the length of the cable. + + Parameters + ========== + This method requires no input when solving for point loads + For distributed load, the x and y coordinates of the lowest point of the cable are + required as + + x: Sympifyable + The x coordinate of the lowest point + + y: Sympifyable + The y coordinate of the lowest point + + Examples + ======== + For point loads, + + >>> from sympy.physics.continuum_mechanics.cable import Cable + >>> c = Cable(("A", 0, 10), ("B", 10, 10)) + >>> c.apply_load(-1, ('Z', 2, 7.26, 3, 270)) + >>> c.apply_load(-1, ('X', 4, 6, 8, 270)) + >>> c.solve() + >>> c.tension + {A_Z: 8.91403453669861, X_B: 19*sqrt(13)/10, Z_X: 4.79150773600774} + >>> c.reaction_loads + {R_A_x: -5.25547445255474, R_A_y: 7.2, R_B_x: 5.25547445255474, R_B_y: 3.8} + >>> c.length + 5.7560958484519 + 2*sqrt(13) + + For distributed load, + + >>> from sympy.physics.continuum_mechanics.cable import Cable + >>> c=Cable(("A", 0, 40),("B", 100, 20)) + >>> c.apply_load(0, ("X", 850)) + >>> c.solve(58.58, 0) + >>> c.tension + {'distributed': 36456.8485*sqrt(0.000543529004799705*(X + 0.00135624381275735)**2 + 1)} + >>> c.tension_at(0) + 61709.0363315913 + >>> c.reaction_loads + {R_A_x: 36456.8485, R_A_y: -49788.5866682485, R_B_x: 44389.8401587246, R_B_y: 42866.621696333} + """ + + if len(self._loads_position) != 0: + sorted_position = sorted(self._loads_position.items(), key = lambda item : item[1][0]) + + sorted_position.append(self._support_labels[1]) + sorted_position.insert(0, self._support_labels[0]) + + self._tension.clear() + moment_sum_from_left_support = 0 + moment_sum_from_right_support = 0 + F_x = 0 + F_y = 0 + self._length = 0 + + for i in range(1, len(sorted_position)-1): + if i == 1: + self._length+=sqrt((self._left_support[0] - self._loads_position[sorted_position[i][0]][0])**2 + (self._left_support[1] - self._loads_position[sorted_position[i][0]][1])**2) + + else: + self._length+=sqrt((self._loads_position[sorted_position[i-1][0]][0] - self._loads_position[sorted_position[i][0]][0])**2 + (self._loads_position[sorted_position[i-1][0]][1] - self._loads_position[sorted_position[i][0]][1])**2) + + if i == len(sorted_position)-2: + self._length+=sqrt((self._right_support[0] - self._loads_position[sorted_position[i][0]][0])**2 + (self._right_support[1] - self._loads_position[sorted_position[i][0]][1])**2) + + moment_sum_from_left_support += self._loads['point_load'][sorted_position[i][0]][0] * cos(pi * self._loads['point_load'][sorted_position[i][0]][1] / 180) * abs(self._left_support[1] - self._loads_position[sorted_position[i][0]][1]) + moment_sum_from_left_support += self._loads['point_load'][sorted_position[i][0]][0] * sin(pi * self._loads['point_load'][sorted_position[i][0]][1] / 180) * abs(self._left_support[0] - self._loads_position[sorted_position[i][0]][0]) + + F_x += self._loads['point_load'][sorted_position[i][0]][0] * cos(pi * self._loads['point_load'][sorted_position[i][0]][1] / 180) + F_y += self._loads['point_load'][sorted_position[i][0]][0] * sin(pi * self._loads['point_load'][sorted_position[i][0]][1] / 180) + + label = Symbol(sorted_position[i][0]+"_"+sorted_position[i+1][0]) + y2 = self._loads_position[sorted_position[i][0]][1] + x2 = self._loads_position[sorted_position[i][0]][0] + y1 = 0 + x1 = 0 + + if i == len(sorted_position)-2: + x1 = self._right_support[0] + y1 = self._right_support[1] + + else: + x1 = self._loads_position[sorted_position[i+1][0]][0] + y1 = self._loads_position[sorted_position[i+1][0]][1] + + angle_with_horizontal = atan((y1 - y2)/(x1 - x2)) + + tension = -(moment_sum_from_left_support)/(abs(self._left_support[1] - self._loads_position[sorted_position[i][0]][1])*cos(angle_with_horizontal) + abs(self._left_support[0] - self._loads_position[sorted_position[i][0]][0])*sin(angle_with_horizontal)) + self._tension[label] = tension + moment_sum_from_right_support += self._loads['point_load'][sorted_position[i][0]][0] * cos(pi * self._loads['point_load'][sorted_position[i][0]][1] / 180) * abs(self._right_support[1] - self._loads_position[sorted_position[i][0]][1]) + moment_sum_from_right_support += self._loads['point_load'][sorted_position[i][0]][0] * sin(pi * self._loads['point_load'][sorted_position[i][0]][1] / 180) * abs(self._right_support[0] - self._loads_position[sorted_position[i][0]][0]) + + label = Symbol(sorted_position[0][0]+"_"+sorted_position[1][0]) + y2 = self._loads_position[sorted_position[1][0]][1] + x2 = self._loads_position[sorted_position[1][0]][0] + x1 = self._left_support[0] + y1 = self._left_support[1] + + angle_with_horizontal = -atan((y2 - y1)/(x2 - x1)) + tension = -(moment_sum_from_right_support)/(abs(self._right_support[1] - self._loads_position[sorted_position[1][0]][1])*cos(angle_with_horizontal) + abs(self._right_support[0] - self._loads_position[sorted_position[1][0]][0])*sin(angle_with_horizontal)) + self._tension[label] = tension + + angle_with_horizontal = pi/2 - angle_with_horizontal + label = self._support_labels[0] + self._reaction_loads[Symbol("R_"+label+"_x")] = -sin(angle_with_horizontal) * tension + F_x += -sin(angle_with_horizontal) * tension + self._reaction_loads[Symbol("R_"+label+"_y")] = cos(angle_with_horizontal) * tension + F_y += cos(angle_with_horizontal) * tension + + label = self._support_labels[1] + self._reaction_loads[Symbol("R_"+label+"_x")] = -F_x + self._reaction_loads[Symbol("R_"+label+"_y")] = -F_y + + elif len(self._loads['distributed']) != 0 : + + if len(args) == 0: + raise ValueError("Provide the lowest point of the cable") + + lowest_x = sympify(args[0]) + lowest_y = sympify(args[1]) + self._lowest_x_global = lowest_x + + a = Symbol('a') + b = Symbol('b') + c = Symbol('c') + # augmented matrix form of linsolve + + M = Matrix( + [[self._left_support[0]**2, self._left_support[0], 1, self._left_support[1]], + [self._right_support[0]**2, self._right_support[0], 1, self._right_support[1]], + [lowest_x**2, lowest_x, 1, lowest_y] ] + ) + + coefficient_solution = list(linsolve(M, (a, b, c))) + + if len(coefficient_solution) == 0: + raise ValueError("The lowest point is inconsistent with the supports") + + A = coefficient_solution[0][0] + B = coefficient_solution[0][1] + C = coefficient_solution[0][2] + + + # y = A*x**2 + B*x + C + # shifting origin to lowest point + X = Symbol('X') + Y = Symbol('Y') + Y = A*(X + lowest_x)**2 + B*(X + lowest_x) + C - lowest_y + + temp_list = list(self._loads['distributed'].values()) + applied_force = temp_list[0] + + horizontal_force_constant = (applied_force * (self._right_support[0] - lowest_x)**2) / (2 * (self._right_support[1] - lowest_y)) + + self._tension.clear() + tangent_slope_to_curve = diff(Y, X) + self._tension['distributed'] = horizontal_force_constant / (cos(atan(tangent_slope_to_curve))) + + label = self._support_labels[0] + self._reaction_loads[Symbol("R_"+label+"_x")] = self.tension_at(self._left_support[0]) * cos(atan(tangent_slope_to_curve.subs(X, self._left_support[0] - lowest_x))) + self._reaction_loads[Symbol("R_"+label+"_y")] = self.tension_at(self._left_support[0]) * sin(atan(tangent_slope_to_curve.subs(X, self._left_support[0] - lowest_x))) + + label = self._support_labels[1] + self._reaction_loads[Symbol("R_"+label+"_x")] = self.tension_at(self._left_support[0]) * cos(atan(tangent_slope_to_curve.subs(X, self._right_support[0] - lowest_x))) + self._reaction_loads[Symbol("R_"+label+"_y")] = self.tension_at(self._left_support[0]) * sin(atan(tangent_slope_to_curve.subs(X, self._right_support[0] - lowest_x))) diff --git a/lib/python3.10/site-packages/sympy/physics/continuum_mechanics/truss.py b/lib/python3.10/site-packages/sympy/physics/continuum_mechanics/truss.py new file mode 100644 index 0000000000000000000000000000000000000000..f7fd0ea3f5e18574f21e2f656477c7af987d8eb6 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/physics/continuum_mechanics/truss.py @@ -0,0 +1,1108 @@ +""" +This module can be used to solve problems related +to 2D Trusses. +""" + + +from cmath import atan, inf +from sympy.core.add import Add +from sympy.core.evalf import INF +from sympy.core.mul import Mul +from sympy.core.symbol import Symbol +from sympy.core.sympify import sympify +from sympy import Matrix, pi +from sympy.external.importtools import import_module +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.matrices.dense import zeros +import math +from sympy.physics.units.quantities import Quantity +from sympy.plotting import plot +from sympy.utilities.decorator import doctest_depends_on +from sympy import sin, cos + + +__doctest_requires__ = {('Truss.draw'): ['matplotlib']} + + +numpy = import_module('numpy', import_kwargs={'fromlist':['arange']}) + + +class Truss: + """ + A Truss is an assembly of members such as beams, + connected by nodes, that create a rigid structure. + In engineering, a truss is a structure that + consists of two-force members only. + + Trusses are extremely important in engineering applications + and can be seen in numerous real-world applications like bridges. + + Examples + ======== + + There is a Truss consisting of four nodes and five + members connecting the nodes. A force P acts + downward on the node D and there also exist pinned + and roller joints on the nodes A and B respectively. + + .. image:: truss_example.png + + >>> from sympy.physics.continuum_mechanics.truss import Truss + >>> t = Truss() + >>> t.add_node(("node_1", 0, 0), ("node_2", 6, 0), ("node_3", 2, 2), ("node_4", 2, 0)) + >>> t.add_member(("member_1", "node_1", "node_4"), ("member_2", "node_2", "node_4"), ("member_3", "node_1", "node_3")) + >>> t.add_member(("member_4", "node_2", "node_3"), ("member_5", "node_3", "node_4")) + >>> t.apply_load(("node_4", 10, 270)) + >>> t.apply_support(("node_1", "pinned"), ("node_2", "roller")) + """ + + def __init__(self): + """ + Initializes the class + """ + self._nodes = [] + self._members = {} + self._loads = {} + self._supports = {} + self._node_labels = [] + self._node_positions = [] + self._node_position_x = [] + self._node_position_y = [] + self._nodes_occupied = {} + self._member_lengths = {} + self._reaction_loads = {} + self._internal_forces = {} + self._node_coordinates = {} + + @property + def nodes(self): + """ + Returns the nodes of the truss along with their positions. + """ + return self._nodes + + @property + def node_labels(self): + """ + Returns the node labels of the truss. + """ + return self._node_labels + + @property + def node_positions(self): + """ + Returns the positions of the nodes of the truss. + """ + return self._node_positions + + @property + def members(self): + """ + Returns the members of the truss along with the start and end points. + """ + return self._members + + @property + def member_lengths(self): + """ + Returns the length of each member of the truss. + """ + return self._member_lengths + + @property + def supports(self): + """ + Returns the nodes with provided supports along with the kind of support provided i.e. + pinned or roller. + """ + return self._supports + + @property + def loads(self): + """ + Returns the loads acting on the truss. + """ + return self._loads + + @property + def reaction_loads(self): + """ + Returns the reaction forces for all supports which are all initialized to 0. + """ + return self._reaction_loads + + @property + def internal_forces(self): + """ + Returns the internal forces for all members which are all initialized to 0. + """ + return self._internal_forces + + def add_node(self, *args): + """ + This method adds a node to the truss along with its name/label and its location. + Multiple nodes can be added at the same time. + + Parameters + ========== + The input(s) for this method are tuples of the form (label, x, y). + + label: String or a Symbol + The label for a node. It is the only way to identify a particular node. + + x: Sympifyable + The x-coordinate of the position of the node. + + y: Sympifyable + The y-coordinate of the position of the node. + + Examples + ======== + + >>> from sympy.physics.continuum_mechanics.truss import Truss + >>> t = Truss() + >>> t.add_node(('A', 0, 0)) + >>> t.nodes + [('A', 0, 0)] + >>> t.add_node(('B', 3, 0), ('C', 4, 1)) + >>> t.nodes + [('A', 0, 0), ('B', 3, 0), ('C', 4, 1)] + """ + + for i in args: + label = i[0] + x = i[1] + x = sympify(x) + y=i[2] + y = sympify(y) + if label in self._node_coordinates: + raise ValueError("Node needs to have a unique label") + + elif [x, y] in self._node_coordinates.values(): + raise ValueError("A node already exists at the given position") + + else : + self._nodes.append((label, x, y)) + self._node_labels.append(label) + self._node_positions.append((x, y)) + self._node_position_x.append(x) + self._node_position_y.append(y) + self._node_coordinates[label] = [x, y] + + + + def remove_node(self, *args): + """ + This method removes a node from the truss. + Multiple nodes can be removed at the same time. + + Parameters + ========== + The input(s) for this method are the labels of the nodes to be removed. + + label: String or Symbol + The label of the node to be removed. + + Examples + ======== + + >>> from sympy.physics.continuum_mechanics.truss import Truss + >>> t = Truss() + >>> t.add_node(('A', 0, 0), ('B', 3, 0), ('C', 5, 0)) + >>> t.nodes + [('A', 0, 0), ('B', 3, 0), ('C', 5, 0)] + >>> t.remove_node('A', 'C') + >>> t.nodes + [('B', 3, 0)] + """ + for label in args: + for i in range(len(self.nodes)): + if self._node_labels[i] == label: + x = self._node_position_x[i] + y = self._node_position_y[i] + + if label not in self._node_coordinates: + raise ValueError("No such node exists in the truss") + + else: + members_duplicate = self._members.copy() + for member in members_duplicate: + if label == self._members[member][0] or label == self._members[member][1]: + raise ValueError("The given node already has member attached to it") + self._nodes.remove((label, x, y)) + self._node_labels.remove(label) + self._node_positions.remove((x, y)) + self._node_position_x.remove(x) + self._node_position_y.remove(y) + if label in self._loads: + self._loads.pop(label) + if label in self._supports: + self._supports.pop(label) + self._node_coordinates.pop(label) + + + + def add_member(self, *args): + """ + This method adds a member between any two nodes in the given truss. + + Parameters + ========== + The input(s) of the method are tuple(s) of the form (label, start, end). + + label: String or Symbol + The label for a member. It is the only way to identify a particular member. + + start: String or Symbol + The label of the starting point/node of the member. + + end: String or Symbol + The label of the ending point/node of the member. + + Examples + ======== + + >>> from sympy.physics.continuum_mechanics.truss import Truss + >>> t = Truss() + >>> t.add_node(('A', 0, 0), ('B', 3, 0), ('C', 2, 2)) + >>> t.add_member(('AB', 'A', 'B'), ('BC', 'B', 'C')) + >>> t.members + {'AB': ['A', 'B'], 'BC': ['B', 'C']} + """ + for i in args: + label = i[0] + start = i[1] + end = i[2] + + if start not in self._node_coordinates or end not in self._node_coordinates or start==end: + raise ValueError("The start and end points of the member must be unique nodes") + + elif label in self._members: + raise ValueError("A member with the same label already exists for the truss") + + elif self._nodes_occupied.get((start, end)): + raise ValueError("A member already exists between the two nodes") + + else: + self._members[label] = [start, end] + self._member_lengths[label] = sqrt((self._node_coordinates[end][0]-self._node_coordinates[start][0])**2 + (self._node_coordinates[end][1]-self._node_coordinates[start][1])**2) + self._nodes_occupied[start, end] = True + self._nodes_occupied[end, start] = True + self._internal_forces[label] = 0 + + def remove_member(self, *args): + """ + This method removes members from the given truss. + + Parameters + ========== + labels: String or Symbol + The label for the member to be removed. + + Examples + ======== + + >>> from sympy.physics.continuum_mechanics.truss import Truss + >>> t = Truss() + >>> t.add_node(('A', 0, 0), ('B', 3, 0), ('C', 2, 2)) + >>> t.add_member(('AB', 'A', 'B'), ('AC', 'A', 'C'), ('BC', 'B', 'C')) + >>> t.members + {'AB': ['A', 'B'], 'AC': ['A', 'C'], 'BC': ['B', 'C']} + >>> t.remove_member('AC', 'BC') + >>> t.members + {'AB': ['A', 'B']} + """ + for label in args: + if label not in self._members: + raise ValueError("No such member exists in the Truss") + + else: + self._nodes_occupied.pop((self._members[label][0], self._members[label][1])) + self._nodes_occupied.pop((self._members[label][1], self._members[label][0])) + self._members.pop(label) + self._member_lengths.pop(label) + self._internal_forces.pop(label) + + def change_node_label(self, *args): + """ + This method changes the label(s) of the specified node(s). + + Parameters + ========== + The input(s) of this method are tuple(s) of the form (label, new_label). + + label: String or Symbol + The label of the node for which the label has + to be changed. + + new_label: String or Symbol + The new label of the node. + + Examples + ======== + + >>> from sympy.physics.continuum_mechanics.truss import Truss + >>> t = Truss() + >>> t.add_node(('A', 0, 0), ('B', 3, 0)) + >>> t.nodes + [('A', 0, 0), ('B', 3, 0)] + >>> t.change_node_label(('A', 'C'), ('B', 'D')) + >>> t.nodes + [('C', 0, 0), ('D', 3, 0)] + """ + for i in args: + label = i[0] + new_label = i[1] + if label not in self._node_coordinates: + raise ValueError("No such node exists for the Truss") + elif new_label in self._node_coordinates: + raise ValueError("A node with the given label already exists") + else: + for node in self._nodes: + if node[0] == label: + self._nodes[self._nodes.index((label, node[1], node[2]))] = (new_label, node[1], node[2]) + self._node_labels[self._node_labels.index(node[0])] = new_label + self._node_coordinates[new_label] = self._node_coordinates[label] + self._node_coordinates.pop(label) + if node[0] in self._supports: + self._supports[new_label] = self._supports[node[0]] + self._supports.pop(node[0]) + if new_label in self._supports: + if self._supports[new_label] == 'pinned': + if 'R_'+str(label)+'_x' in self._reaction_loads and 'R_'+str(label)+'_y' in self._reaction_loads: + self._reaction_loads['R_'+str(new_label)+'_x'] = self._reaction_loads['R_'+str(label)+'_x'] + self._reaction_loads['R_'+str(new_label)+'_y'] = self._reaction_loads['R_'+str(label)+'_y'] + self._reaction_loads.pop('R_'+str(label)+'_x') + self._reaction_loads.pop('R_'+str(label)+'_y') + self._loads[new_label] = self._loads[label] + for load in self._loads[new_label]: + if load[1] == 90: + load[0] -= Symbol('R_'+str(label)+'_y') + if load[0] == 0: + self._loads[label].remove(load) + break + for load in self._loads[new_label]: + if load[1] == 0: + load[0] -= Symbol('R_'+str(label)+'_x') + if load[0] == 0: + self._loads[label].remove(load) + break + self.apply_load(new_label, Symbol('R_'+str(new_label)+'_x'), 0) + self.apply_load(new_label, Symbol('R_'+str(new_label)+'_y'), 90) + self._loads.pop(label) + elif self._supports[new_label] == 'roller': + self._loads[new_label] = self._loads[label] + for load in self._loads[label]: + if load[1] == 90: + load[0] -= Symbol('R_'+str(label)+'_y') + if load[0] == 0: + self._loads[label].remove(load) + break + self.apply_load(new_label, Symbol('R_'+str(new_label)+'_y'), 90) + self._loads.pop(label) + else: + if label in self._loads: + self._loads[new_label] = self._loads[label] + self._loads.pop(label) + for member in self._members: + if self._members[member][0] == node[0]: + self._members[member][0] = new_label + self._nodes_occupied[(new_label, self._members[member][1])] = True + self._nodes_occupied[(self._members[member][1], new_label)] = True + self._nodes_occupied.pop((label, self._members[member][1])) + self._nodes_occupied.pop((self._members[member][1], label)) + elif self._members[member][1] == node[0]: + self._members[member][1] = new_label + self._nodes_occupied[(self._members[member][0], new_label)] = True + self._nodes_occupied[(new_label, self._members[member][0])] = True + self._nodes_occupied.pop((self._members[member][0], label)) + self._nodes_occupied.pop((label, self._members[member][0])) + + def change_member_label(self, *args): + """ + This method changes the label(s) of the specified member(s). + + Parameters + ========== + The input(s) of this method are tuple(s) of the form (label, new_label) + + label: String or Symbol + The label of the member for which the label has + to be changed. + + new_label: String or Symbol + The new label of the member. + + Examples + ======== + + >>> from sympy.physics.continuum_mechanics.truss import Truss + >>> t = Truss() + >>> t.add_node(('A', 0, 0), ('B', 3, 0), ('D', 5, 0)) + >>> t.nodes + [('A', 0, 0), ('B', 3, 0), ('D', 5, 0)] + >>> t.change_node_label(('A', 'C')) + >>> t.nodes + [('C', 0, 0), ('B', 3, 0), ('D', 5, 0)] + >>> t.add_member(('BC', 'B', 'C'), ('BD', 'B', 'D')) + >>> t.members + {'BC': ['B', 'C'], 'BD': ['B', 'D']} + >>> t.change_member_label(('BC', 'BC_new'), ('BD', 'BD_new')) + >>> t.members + {'BC_new': ['B', 'C'], 'BD_new': ['B', 'D']} + """ + for i in args: + label = i[0] + new_label = i[1] + if label not in self._members: + raise ValueError("No such member exists for the Truss") + else: + members_duplicate = list(self._members).copy() + for member in members_duplicate: + if member == label: + self._members[new_label] = [self._members[member][0], self._members[member][1]] + self._members.pop(label) + self._member_lengths[new_label] = self._member_lengths[label] + self._member_lengths.pop(label) + self._internal_forces[new_label] = self._internal_forces[label] + self._internal_forces.pop(label) + + def apply_load(self, *args): + """ + This method applies external load(s) at the specified node(s). + + Parameters + ========== + The input(s) of the method are tuple(s) of the form (location, magnitude, direction). + + location: String or Symbol + Label of the Node at which load is applied. + + magnitude: Sympifyable + Magnitude of the load applied. It must always be positive and any changes in + the direction of the load are not reflected here. + + direction: Sympifyable + The angle, in degrees, that the load vector makes with the horizontal + in the counter-clockwise direction. It takes the values 0 to 360, + inclusive. + + Examples + ======== + + >>> from sympy.physics.continuum_mechanics.truss import Truss + >>> from sympy import symbols + >>> t = Truss() + >>> t.add_node(('A', 0, 0), ('B', 3, 0)) + >>> P = symbols('P') + >>> t.apply_load(('A', P, 90), ('A', P/2, 45), ('A', P/4, 90)) + >>> t.loads + {'A': [[P, 90], [P/2, 45], [P/4, 90]]} + """ + for i in args: + location = i[0] + magnitude = i[1] + direction = i[2] + magnitude = sympify(magnitude) + direction = sympify(direction) + + if location not in self._node_coordinates: + raise ValueError("Load must be applied at a known node") + + else: + if location in self._loads: + self._loads[location].append([magnitude, direction]) + else: + self._loads[location] = [[magnitude, direction]] + + def remove_load(self, *args): + """ + This method removes already + present external load(s) at specified node(s). + + Parameters + ========== + The input(s) of this method are tuple(s) of the form (location, magnitude, direction). + + location: String or Symbol + Label of the Node at which load is applied and is to be removed. + + magnitude: Sympifyable + Magnitude of the load applied. + + direction: Sympifyable + The angle, in degrees, that the load vector makes with the horizontal + in the counter-clockwise direction. It takes the values 0 to 360, + inclusive. + + Examples + ======== + + >>> from sympy.physics.continuum_mechanics.truss import Truss + >>> from sympy import symbols + >>> t = Truss() + >>> t.add_node(('A', 0, 0), ('B', 3, 0)) + >>> P = symbols('P') + >>> t.apply_load(('A', P, 90), ('A', P/2, 45), ('A', P/4, 90)) + >>> t.loads + {'A': [[P, 90], [P/2, 45], [P/4, 90]]} + >>> t.remove_load(('A', P/4, 90), ('A', P/2, 45)) + >>> t.loads + {'A': [[P, 90]]} + """ + for i in args: + location = i[0] + magnitude = i[1] + direction = i[2] + magnitude = sympify(magnitude) + direction = sympify(direction) + + if location not in self._node_coordinates: + raise ValueError("Load must be removed from a known node") + + else: + if [magnitude, direction] not in self._loads[location]: + raise ValueError("No load of this magnitude and direction has been applied at this node") + else: + self._loads[location].remove([magnitude, direction]) + if self._loads[location] == []: + self._loads.pop(location) + + def apply_support(self, *args): + """ + This method adds a pinned or roller support at specified node(s). + + Parameters + ========== + The input(s) of this method are of the form (location, type). + + location: String or Symbol + Label of the Node at which support is added. + + type: String + Type of the support being provided at the node. + + Examples + ======== + + >>> from sympy.physics.continuum_mechanics.truss import Truss + >>> t = Truss() + >>> t.add_node(('A', 0, 0), ('B', 3, 0)) + >>> t.apply_support(('A', 'pinned'), ('B', 'roller')) + >>> t.supports + {'A': 'pinned', 'B': 'roller'} + """ + for i in args: + location = i[0] + type = i[1] + if location not in self._node_coordinates: + raise ValueError("Support must be added on a known node") + + else: + if location not in self._supports: + if type == 'pinned': + self.apply_load((location, Symbol('R_'+str(location)+'_x'), 0)) + self.apply_load((location, Symbol('R_'+str(location)+'_y'), 90)) + elif type == 'roller': + self.apply_load((location, Symbol('R_'+str(location)+'_y'), 90)) + elif self._supports[location] == 'pinned': + if type == 'roller': + self.remove_load((location, Symbol('R_'+str(location)+'_x'), 0)) + elif self._supports[location] == 'roller': + if type == 'pinned': + self.apply_load((location, Symbol('R_'+str(location)+'_x'), 0)) + self._supports[location] = type + + def remove_support(self, *args): + """ + This method removes support from specified node(s.) + + Parameters + ========== + + locations: String or Symbol + Label of the Node(s) at which support is to be removed. + + Examples + ======== + + >>> from sympy.physics.continuum_mechanics.truss import Truss + >>> t = Truss() + >>> t.add_node(('A', 0, 0), ('B', 3, 0)) + >>> t.apply_support(('A', 'pinned'), ('B', 'roller')) + >>> t.supports + {'A': 'pinned', 'B': 'roller'} + >>> t.remove_support('A','B') + >>> t.supports + {} + """ + for location in args: + + if location not in self._node_coordinates: + raise ValueError("No such node exists in the Truss") + + elif location not in self._supports: + raise ValueError("No support has been added to the given node") + + else: + if self._supports[location] == 'pinned': + self.remove_load((location, Symbol('R_'+str(location)+'_x'), 0)) + self.remove_load((location, Symbol('R_'+str(location)+'_y'), 90)) + elif self._supports[location] == 'roller': + self.remove_load((location, Symbol('R_'+str(location)+'_y'), 90)) + self._supports.pop(location) + + def solve(self): + """ + This method solves for all reaction forces of all supports and all internal forces + of all the members in the truss, provided the Truss is solvable. + + A Truss is solvable if the following condition is met, + + 2n >= r + m + + Where n is the number of nodes, r is the number of reaction forces, where each pinned + support has 2 reaction forces and each roller has 1, and m is the number of members. + + The given condition is derived from the fact that a system of equations is solvable + only when the number of variables is lesser than or equal to the number of equations. + Equilibrium Equations in x and y directions give two equations per node giving 2n number + equations. However, the truss needs to be stable as well and may be unstable if 2n > r + m. + The number of variables is simply the sum of the number of reaction forces and member + forces. + + .. note:: + The sign convention for the internal forces present in a member revolves around whether each + force is compressive or tensile. While forming equations for each node, internal force due + to a member on the node is assumed to be away from the node i.e. each force is assumed to + be compressive by default. Hence, a positive value for an internal force implies the + presence of compressive force in the member and a negative value implies a tensile force. + + Examples + ======== + + >>> from sympy.physics.continuum_mechanics.truss import Truss + >>> t = Truss() + >>> t.add_node(("node_1", 0, 0), ("node_2", 6, 0), ("node_3", 2, 2), ("node_4", 2, 0)) + >>> t.add_member(("member_1", "node_1", "node_4"), ("member_2", "node_2", "node_4"), ("member_3", "node_1", "node_3")) + >>> t.add_member(("member_4", "node_2", "node_3"), ("member_5", "node_3", "node_4")) + >>> t.apply_load(("node_4", 10, 270)) + >>> t.apply_support(("node_1", "pinned"), ("node_2", "roller")) + >>> t.solve() + >>> t.reaction_loads + {'R_node_1_x': 0, 'R_node_1_y': 20/3, 'R_node_2_y': 10/3} + >>> t.internal_forces + {'member_1': 20/3, 'member_2': 20/3, 'member_3': -20*sqrt(2)/3, 'member_4': -10*sqrt(5)/3, 'member_5': 10} + """ + count_reaction_loads = 0 + for node in self._nodes: + if node[0] in self._supports: + if self._supports[node[0]]=='pinned': + count_reaction_loads += 2 + elif self._supports[node[0]]=='roller': + count_reaction_loads += 1 + if 2*len(self._nodes) != len(self._members) + count_reaction_loads: + raise ValueError("The given truss cannot be solved") + coefficients_matrix = [[0 for i in range(2*len(self._nodes))] for j in range(2*len(self._nodes))] + load_matrix = zeros(2*len(self.nodes), 1) + load_matrix_row = 0 + for node in self._nodes: + if node[0] in self._loads: + for load in self._loads[node[0]]: + if load[0]!=Symbol('R_'+str(node[0])+'_x') and load[0]!=Symbol('R_'+str(node[0])+'_y'): + load_matrix[load_matrix_row] -= load[0]*cos(pi*load[1]/180) + load_matrix[load_matrix_row + 1] -= load[0]*sin(pi*load[1]/180) + load_matrix_row += 2 + cols = 0 + row = 0 + for node in self._nodes: + if node[0] in self._supports: + if self._supports[node[0]]=='pinned': + coefficients_matrix[row][cols] += 1 + coefficients_matrix[row+1][cols+1] += 1 + cols += 2 + elif self._supports[node[0]]=='roller': + coefficients_matrix[row+1][cols] += 1 + cols += 1 + row += 2 + for member in self._members: + start = self._members[member][0] + end = self._members[member][1] + length = sqrt((self._node_coordinates[start][0]-self._node_coordinates[end][0])**2 + (self._node_coordinates[start][1]-self._node_coordinates[end][1])**2) + start_index = self._node_labels.index(start) + end_index = self._node_labels.index(end) + horizontal_component_start = (self._node_coordinates[end][0]-self._node_coordinates[start][0])/length + vertical_component_start = (self._node_coordinates[end][1]-self._node_coordinates[start][1])/length + horizontal_component_end = (self._node_coordinates[start][0]-self._node_coordinates[end][0])/length + vertical_component_end = (self._node_coordinates[start][1]-self._node_coordinates[end][1])/length + coefficients_matrix[start_index*2][cols] += horizontal_component_start + coefficients_matrix[start_index*2+1][cols] += vertical_component_start + coefficients_matrix[end_index*2][cols] += horizontal_component_end + coefficients_matrix[end_index*2+1][cols] += vertical_component_end + cols += 1 + forces_matrix = (Matrix(coefficients_matrix)**-1)*load_matrix + self._reaction_loads = {} + i = 0 + min_load = inf + for node in self._nodes: + if node[0] in self._loads: + for load in self._loads[node[0]]: + if type(load[0]) not in [Symbol, Mul, Add]: + min_load = min(min_load, load[0]) + for j in range(len(forces_matrix)): + if type(forces_matrix[j]) not in [Symbol, Mul, Add]: + if abs(forces_matrix[j]/min_load) <1E-10: + forces_matrix[j] = 0 + for node in self._nodes: + if node[0] in self._supports: + if self._supports[node[0]]=='pinned': + self._reaction_loads['R_'+str(node[0])+'_x'] = forces_matrix[i] + self._reaction_loads['R_'+str(node[0])+'_y'] = forces_matrix[i+1] + i += 2 + elif self._supports[node[0]]=='roller': + self._reaction_loads['R_'+str(node[0])+'_y'] = forces_matrix[i] + i += 1 + for member in self._members: + self._internal_forces[member] = forces_matrix[i] + i += 1 + return + + @doctest_depends_on(modules=('numpy',)) + def draw(self, subs_dict=None): + """ + Returns a plot object of the Truss with all its nodes, members, + supports and loads. + + .. note:: + The user must be careful while entering load values in their + directions. The draw function assumes a sign convention that + is used for plotting loads. + + Given a right-handed coordinate system with XYZ coordinates, + the supports are assumed to be such that the reaction forces of a + pinned support is in the +X and +Y direction while those of a + roller support is in the +Y direction. For the load, the range + of angles, one can input goes all the way to 360 degrees which, in the + the plot is the angle that the load vector makes with the positive x-axis in the anticlockwise direction. + + For example, for a 90-degree angle, the load will be a vertically + directed along +Y while a 270-degree angle denotes a vertical + load as well but along -Y. + + Examples + ======== + + .. plot:: + :context: close-figs + :format: doctest + :include-source: True + + >>> from sympy.physics.continuum_mechanics.truss import Truss + >>> import math + >>> t = Truss() + >>> t.add_node(("A", -4, 0), ("B", 0, 0), ("C", 4, 0), ("D", 8, 0)) + >>> t.add_node(("E", 6, 2/math.sqrt(3))) + >>> t.add_node(("F", 2, 2*math.sqrt(3))) + >>> t.add_node(("G", -2, 2/math.sqrt(3))) + >>> t.add_member(("AB","A","B"), ("BC","B","C"), ("CD","C","D")) + >>> t.add_member(("AG","A","G"), ("GB","G","B"), ("GF","G","F")) + >>> t.add_member(("BF","B","F"), ("FC","F","C"), ("CE","C","E")) + >>> t.add_member(("FE","F","E"), ("DE","D","E")) + >>> t.apply_support(("A","pinned"), ("D","roller")) + >>> t.apply_load(("G", 3, 90), ("E", 3, 90), ("F", 2, 90)) + >>> p = t.draw() + >>> p # doctest: +ELLIPSIS + Plot object containing: + [0]: cartesian line: 1 for x over (1.0, 1.0) + ... + >>> p.show() + """ + if not numpy: + raise ImportError("To use this function numpy module is required") + + x = Symbol('x') + + markers = [] + annotations = [] + rectangles = [] + + node_markers = self._draw_nodes(subs_dict) + markers += node_markers + + member_rectangles = self._draw_members() + rectangles += member_rectangles + + support_markers = self._draw_supports() + markers += support_markers + + load_annotations = self._draw_loads() + annotations += load_annotations + + xmax = -INF + xmin = INF + ymax = -INF + ymin = INF + + for node in self._node_coordinates: + xmax = max(xmax, self._node_coordinates[node][0]) + xmin = min(xmin, self._node_coordinates[node][0]) + ymax = max(ymax, self._node_coordinates[node][1]) + ymin = min(ymin, self._node_coordinates[node][1]) + + lim = max(xmax*1.1-xmin*0.8+1, ymax*1.1-ymin*0.8+1) + + if lim==xmax*1.1-xmin*0.8+1: + sing_plot = plot(1, (x, 1, 1), markers=markers, show=False, annotations=annotations, xlim=(xmin-0.05*lim, xmax*1.1), ylim=(xmin-0.05*lim, xmax*1.1), axis=False, rectangles=rectangles) + + else: + sing_plot = plot(1, (x, 1, 1), markers=markers, show=False, annotations=annotations, xlim=(ymin-0.05*lim, ymax*1.1), ylim=(ymin-0.05*lim, ymax*1.1), axis=False, rectangles=rectangles) + + return sing_plot + + + def _draw_nodes(self, subs_dict): + node_markers = [] + + for node in self._node_coordinates: + if (type(self._node_coordinates[node][0]) in (Symbol, Quantity)): + if self._node_coordinates[node][0] in subs_dict: + self._node_coordinates[node][0] = subs_dict[self._node_coordinates[node][0]] + else: + raise ValueError("provided substituted dictionary is not adequate") + elif (type(self._node_coordinates[node][0]) == Mul): + objects = self._node_coordinates[node][0].as_coeff_Mul() + for object in objects: + if type(object) in (Symbol, Quantity): + if subs_dict==None or object not in subs_dict: + raise ValueError("provided substituted dictionary is not adequate") + else: + self._node_coordinates[node][0] /= object + self._node_coordinates[node][0] *= subs_dict[object] + + if (type(self._node_coordinates[node][1]) in (Symbol, Quantity)): + if self._node_coordinates[node][1] in subs_dict: + self._node_coordinates[node][1] = subs_dict[self._node_coordinates[node][1]] + else: + raise ValueError("provided substituted dictionary is not adequate") + elif (type(self._node_coordinates[node][1]) == Mul): + objects = self._node_coordinates[node][1].as_coeff_Mul() + for object in objects: + if type(object) in (Symbol, Quantity): + if subs_dict==None or object not in subs_dict: + raise ValueError("provided substituted dictionary is not adequate") + else: + self._node_coordinates[node][1] /= object + self._node_coordinates[node][1] *= subs_dict[object] + + for node in self._node_coordinates: + node_markers.append( + { + 'args':[[self._node_coordinates[node][0]], [self._node_coordinates[node][1]]], + 'marker':'o', + 'markersize':5, + 'color':'black' + } + ) + return node_markers + + def _draw_members(self): + + member_rectangles = [] + + xmax = -INF + xmin = INF + ymax = -INF + ymin = INF + + for node in self._node_coordinates: + xmax = max(xmax, self._node_coordinates[node][0]) + xmin = min(xmin, self._node_coordinates[node][0]) + ymax = max(ymax, self._node_coordinates[node][1]) + ymin = min(ymin, self._node_coordinates[node][1]) + + if abs(1.1*xmax-0.8*xmin)>abs(1.1*ymax-0.8*ymin): + max_diff = 1.1*xmax-0.8*xmin + else: + max_diff = 1.1*ymax-0.8*ymin + + for member in self._members: + x1 = self._node_coordinates[self._members[member][0]][0] + y1 = self._node_coordinates[self._members[member][0]][1] + x2 = self._node_coordinates[self._members[member][1]][0] + y2 = self._node_coordinates[self._members[member][1]][1] + if x2!=x1 and y2!=y1: + if x2>x1: + member_rectangles.append( + { + 'xy':(x1-0.005*max_diff*cos(pi/4+atan((y2-y1)/(x2-x1)))/2, y1-0.005*max_diff*sin(pi/4+atan((y2-y1)/(x2-x1)))/2), + 'width':sqrt((x1-x2)**2+(y1-y2)**2)+0.005*max_diff/math.sqrt(2), + 'height':0.005*max_diff, + 'angle':180*atan((y2-y1)/(x2-x1))/pi, + 'color':'brown' + } + ) + else: + member_rectangles.append( + { + 'xy':(x2-0.005*max_diff*cos(pi/4+atan((y2-y1)/(x2-x1)))/2, y2-0.005*max_diff*sin(pi/4+atan((y2-y1)/(x2-x1)))/2), + 'width':sqrt((x1-x2)**2+(y1-y2)**2)+0.005*max_diff/math.sqrt(2), + 'height':0.005*max_diff, + 'angle':180*atan((y2-y1)/(x2-x1))/pi, + 'color':'brown' + } + ) + elif y2==y1: + if x2>x1: + member_rectangles.append( + { + 'xy':(x1-0.005*max_diff/2, y1-0.005*max_diff/2), + 'width':sqrt((x1-x2)**2+(y1-y2)**2), + 'height':0.005*max_diff, + 'angle':90*(1-math.copysign(1, x2-x1)), + 'color':'brown' + } + ) + else: + member_rectangles.append( + { + 'xy':(x1-0.005*max_diff/2, y1-0.005*max_diff/2), + 'width':sqrt((x1-x2)**2+(y1-y2)**2), + 'height':-0.005*max_diff, + 'angle':90*(1-math.copysign(1, x2-x1)), + 'color':'brown' + } + ) + else: + if y1abs(1.1*ymax-0.8*ymin): + max_diff = 1.1*xmax-0.8*xmin + else: + max_diff = 1.1*ymax-0.8*ymin + + for node in self._supports: + if self._supports[node]=='pinned': + support_markers.append( + { + 'args':[ + [self._node_coordinates[node][0]], + [self._node_coordinates[node][1]] + ], + 'marker':6, + 'markersize':15, + 'color':'black', + 'markerfacecolor':'none' + } + ) + support_markers.append( + { + 'args':[ + [self._node_coordinates[node][0]], + [self._node_coordinates[node][1]-0.035*max_diff] + ], + 'marker':'_', + 'markersize':14, + 'color':'black' + } + ) + + elif self._supports[node]=='roller': + support_markers.append( + { + 'args':[ + [self._node_coordinates[node][0]], + [self._node_coordinates[node][1]-0.02*max_diff] + ], + 'marker':'o', + 'markersize':11, + 'color':'black', + 'markerfacecolor':'none' + } + ) + support_markers.append( + { + 'args':[ + [self._node_coordinates[node][0]], + [self._node_coordinates[node][1]-0.0375*max_diff] + ], + 'marker':'_', + 'markersize':14, + 'color':'black' + } + ) + return support_markers + + def _draw_loads(self): + load_annotations = [] + + xmax = -INF + xmin = INF + ymax = -INF + ymin = INF + + for node in self._node_coordinates: + xmax = max(xmax, self._node_coordinates[node][0]) + xmin = min(xmin, self._node_coordinates[node][0]) + ymax = max(ymax, self._node_coordinates[node][1]) + ymin = min(ymin, self._node_coordinates[node][1]) + + if abs(1.1*xmax-0.8*xmin)>abs(1.1*ymax-0.8*ymin): + max_diff = 1.1*xmax-0.8*xmin+5 + else: + max_diff = 1.1*ymax-0.8*ymin+5 + + for node in self._loads: + for load in self._loads[node]: + if load[0] in [Symbol('R_'+str(node)+'_x'), Symbol('R_'+str(node)+'_y')]: + continue + x = self._node_coordinates[node][0] + y = self._node_coordinates[node][1] + load_annotations.append( + { + 'text':'', + 'xy':( + x-math.cos(pi*load[1]/180)*(max_diff/100), + y-math.sin(pi*load[1]/180)*(max_diff/100) + ), + 'xytext':( + x-(max_diff/100+abs(xmax-xmin)+abs(ymax-ymin))*math.cos(pi*load[1]/180)/20, + y-(max_diff/100+abs(xmax-xmin)+abs(ymax-ymin))*math.sin(pi*load[1]/180)/20 + ), + 'arrowprops':{'width':1.5, 'headlength':5, 'headwidth':5, 'facecolor':'black'} + } + ) + return load_annotations diff --git a/lib/python3.10/site-packages/sympy/physics/control/__init__.py b/lib/python3.10/site-packages/sympy/physics/control/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..fb8c13ff147b3603466c8c4b2d9c8c0b25e3b360 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/physics/control/__init__.py @@ -0,0 +1,16 @@ +from .lti import (TransferFunction, Series, MIMOSeries, Parallel, MIMOParallel, + Feedback, MIMOFeedback, TransferFunctionMatrix, StateSpace, gbt, bilinear, forward_diff, + backward_diff, phase_margin, gain_margin) +from .control_plots import (pole_zero_numerical_data, pole_zero_plot, step_response_numerical_data, + step_response_plot, impulse_response_numerical_data, impulse_response_plot, ramp_response_numerical_data, + ramp_response_plot, bode_magnitude_numerical_data, bode_phase_numerical_data, bode_magnitude_plot, + bode_phase_plot, bode_plot) + +__all__ = ['TransferFunction', 'Series', 'MIMOSeries', 'Parallel', + 'MIMOParallel', 'Feedback', 'MIMOFeedback', 'TransferFunctionMatrix', 'StateSpace', + 'gbt', 'bilinear', 'forward_diff', 'backward_diff', 'phase_margin', 'gain_margin', + 'pole_zero_numerical_data', 'pole_zero_plot', 'step_response_numerical_data', + 'step_response_plot', 'impulse_response_numerical_data', 'impulse_response_plot', + 'ramp_response_numerical_data', 'ramp_response_plot', + 'bode_magnitude_numerical_data', 'bode_phase_numerical_data', + 'bode_magnitude_plot', 'bode_phase_plot', 'bode_plot'] diff --git a/lib/python3.10/site-packages/sympy/physics/control/control_plots.py b/lib/python3.10/site-packages/sympy/physics/control/control_plots.py new file mode 100644 index 0000000000000000000000000000000000000000..3742de329e61a84ff604accaced369261bc4befe --- /dev/null +++ b/lib/python3.10/site-packages/sympy/physics/control/control_plots.py @@ -0,0 +1,978 @@ +from sympy.core.numbers import I, pi +from sympy.functions.elementary.exponential import (exp, log) +from sympy.polys.partfrac import apart +from sympy.core.symbol import Dummy +from sympy.external import import_module +from sympy.functions import arg, Abs +from sympy.integrals.laplace import _fast_inverse_laplace +from sympy.physics.control.lti import SISOLinearTimeInvariant +from sympy.plotting.series import LineOver1DRangeSeries +from sympy.polys.polytools import Poly +from sympy.printing.latex import latex + +__all__ = ['pole_zero_numerical_data', 'pole_zero_plot', + 'step_response_numerical_data', 'step_response_plot', + 'impulse_response_numerical_data', 'impulse_response_plot', + 'ramp_response_numerical_data', 'ramp_response_plot', + 'bode_magnitude_numerical_data', 'bode_phase_numerical_data', + 'bode_magnitude_plot', 'bode_phase_plot', 'bode_plot'] + +matplotlib = import_module( + 'matplotlib', import_kwargs={'fromlist': ['pyplot']}, + catch=(RuntimeError,)) + +numpy = import_module('numpy') + +if matplotlib: + plt = matplotlib.pyplot + +if numpy: + np = numpy # Matplotlib already has numpy as a compulsory dependency. No need to install it separately. + + +def _check_system(system): + """Function to check whether the dynamical system passed for plots is + compatible or not.""" + if not isinstance(system, SISOLinearTimeInvariant): + raise NotImplementedError("Only SISO LTI systems are currently supported.") + sys = system.to_expr() + len_free_symbols = len(sys.free_symbols) + if len_free_symbols > 1: + raise ValueError("Extra degree of freedom found. Make sure" + " that there are no free symbols in the dynamical system other" + " than the variable of Laplace transform.") + if sys.has(exp): + # Should test that exp is not part of a constant, in which case + # no exception is required, compare exp(s) with s*exp(1) + raise NotImplementedError("Time delay terms are not supported.") + + +def pole_zero_numerical_data(system): + """ + Returns the numerical data of poles and zeros of the system. + It is internally used by ``pole_zero_plot`` to get the data + for plotting poles and zeros. Users can use this data to further + analyse the dynamics of the system or plot using a different + backend/plotting-module. + + Parameters + ========== + + system : SISOLinearTimeInvariant + The system for which the pole-zero data is to be computed. + + Returns + ======= + + tuple : (zeros, poles) + zeros = Zeros of the system. NumPy array of complex numbers. + poles = Poles of the system. NumPy array of complex numbers. + + Raises + ====== + + NotImplementedError + When a SISO LTI system is not passed. + + When time delay terms are present in the system. + + ValueError + When more than one free symbol is present in the system. + The only variable in the transfer function should be + the variable of the Laplace transform. + + Examples + ======== + + >>> from sympy.abc import s + >>> from sympy.physics.control.lti import TransferFunction + >>> from sympy.physics.control.control_plots import pole_zero_numerical_data + >>> tf1 = TransferFunction(s**2 + 1, s**4 + 4*s**3 + 6*s**2 + 5*s + 2, s) + >>> pole_zero_numerical_data(tf1) # doctest: +SKIP + ([-0.+1.j 0.-1.j], [-2. +0.j -0.5+0.8660254j -0.5-0.8660254j -1. +0.j ]) + + See Also + ======== + + pole_zero_plot + + """ + _check_system(system) + system = system.doit() # Get the equivalent TransferFunction object. + + num_poly = Poly(system.num, system.var).all_coeffs() + den_poly = Poly(system.den, system.var).all_coeffs() + + num_poly = np.array(num_poly, dtype=np.complex128) + den_poly = np.array(den_poly, dtype=np.complex128) + + zeros = np.roots(num_poly) + poles = np.roots(den_poly) + + return zeros, poles + + +def pole_zero_plot(system, pole_color='blue', pole_markersize=10, + zero_color='orange', zero_markersize=7, grid=True, show_axes=True, + show=True, **kwargs): + r""" + Returns the Pole-Zero plot (also known as PZ Plot or PZ Map) of a system. + + A Pole-Zero plot is a graphical representation of a system's poles and + zeros. It is plotted on a complex plane, with circular markers representing + the system's zeros and 'x' shaped markers representing the system's poles. + + Parameters + ========== + + system : SISOLinearTimeInvariant type systems + The system for which the pole-zero plot is to be computed. + pole_color : str, tuple, optional + The color of the pole points on the plot. Default color + is blue. The color can be provided as a matplotlib color string, + or a 3-tuple of floats each in the 0-1 range. + pole_markersize : Number, optional + The size of the markers used to mark the poles in the plot. + Default pole markersize is 10. + zero_color : str, tuple, optional + The color of the zero points on the plot. Default color + is orange. The color can be provided as a matplotlib color string, + or a 3-tuple of floats each in the 0-1 range. + zero_markersize : Number, optional + The size of the markers used to mark the zeros in the plot. + Default zero markersize is 7. + grid : boolean, optional + If ``True``, the plot will have a grid. Defaults to True. + show_axes : boolean, optional + If ``True``, the coordinate axes will be shown. Defaults to False. + show : boolean, optional + If ``True``, the plot will be displayed otherwise + the equivalent matplotlib ``plot`` object will be returned. + Defaults to True. + + Examples + ======== + + .. plot:: + :context: close-figs + :format: doctest + :include-source: True + + >>> from sympy.abc import s + >>> from sympy.physics.control.lti import TransferFunction + >>> from sympy.physics.control.control_plots import pole_zero_plot + >>> tf1 = TransferFunction(s**2 + 1, s**4 + 4*s**3 + 6*s**2 + 5*s + 2, s) + >>> pole_zero_plot(tf1) # doctest: +SKIP + + See Also + ======== + + pole_zero_numerical_data + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Pole%E2%80%93zero_plot + + """ + zeros, poles = pole_zero_numerical_data(system) + + zero_real = np.real(zeros) + zero_imag = np.imag(zeros) + + pole_real = np.real(poles) + pole_imag = np.imag(poles) + + plt.plot(pole_real, pole_imag, 'x', mfc='none', + markersize=pole_markersize, color=pole_color) + plt.plot(zero_real, zero_imag, 'o', markersize=zero_markersize, + color=zero_color) + plt.xlabel('Real Axis') + plt.ylabel('Imaginary Axis') + plt.title(f'Poles and Zeros of ${latex(system)}$', pad=20) + + if grid: + plt.grid() + if show_axes: + plt.axhline(0, color='black') + plt.axvline(0, color='black') + if show: + plt.show() + return + + return plt + + +def step_response_numerical_data(system, prec=8, lower_limit=0, + upper_limit=10, **kwargs): + """ + Returns the numerical values of the points in the step response plot + of a SISO continuous-time system. By default, adaptive sampling + is used. If the user wants to instead get an uniformly + sampled response, then ``adaptive`` kwarg should be passed ``False`` + and ``n`` must be passed as additional kwargs. + Refer to the parameters of class :class:`sympy.plotting.series.LineOver1DRangeSeries` + for more details. + + Parameters + ========== + + system : SISOLinearTimeInvariant + The system for which the unit step response data is to be computed. + prec : int, optional + The decimal point precision for the point coordinate values. + Defaults to 8. + lower_limit : Number, optional + The lower limit of the plot range. Defaults to 0. + upper_limit : Number, optional + The upper limit of the plot range. Defaults to 10. + kwargs : + Additional keyword arguments are passed to the underlying + :class:`sympy.plotting.series.LineOver1DRangeSeries` class. + + Returns + ======= + + tuple : (x, y) + x = Time-axis values of the points in the step response. NumPy array. + y = Amplitude-axis values of the points in the step response. NumPy array. + + Raises + ====== + + NotImplementedError + When a SISO LTI system is not passed. + + When time delay terms are present in the system. + + ValueError + When more than one free symbol is present in the system. + The only variable in the transfer function should be + the variable of the Laplace transform. + + When ``lower_limit`` parameter is less than 0. + + Examples + ======== + + >>> from sympy.abc import s + >>> from sympy.physics.control.lti import TransferFunction + >>> from sympy.physics.control.control_plots import step_response_numerical_data + >>> tf1 = TransferFunction(s, s**2 + 5*s + 8, s) + >>> step_response_numerical_data(tf1) # doctest: +SKIP + ([0.0, 0.025413462339411542, 0.0484508722725343, ... , 9.670250533855183, 9.844291913708725, 10.0], + [0.0, 0.023844582399907256, 0.042894276802320226, ..., 6.828770759094287e-12, 6.456457160755703e-12]) + + See Also + ======== + + step_response_plot + + """ + if lower_limit < 0: + raise ValueError("Lower limit of time must be greater " + "than or equal to zero.") + _check_system(system) + _x = Dummy("x") + expr = system.to_expr()/(system.var) + expr = apart(expr, system.var, full=True) + _y = _fast_inverse_laplace(expr, system.var, _x).evalf(prec) + return LineOver1DRangeSeries(_y, (_x, lower_limit, upper_limit), + **kwargs).get_points() + + +def step_response_plot(system, color='b', prec=8, lower_limit=0, + upper_limit=10, show_axes=False, grid=True, show=True, **kwargs): + r""" + Returns the unit step response of a continuous-time system. It is + the response of the system when the input signal is a step function. + + Parameters + ========== + + system : SISOLinearTimeInvariant type + The LTI SISO system for which the Step Response is to be computed. + color : str, tuple, optional + The color of the line. Default is Blue. + show : boolean, optional + If ``True``, the plot will be displayed otherwise + the equivalent matplotlib ``plot`` object will be returned. + Defaults to True. + lower_limit : Number, optional + The lower limit of the plot range. Defaults to 0. + upper_limit : Number, optional + The upper limit of the plot range. Defaults to 10. + prec : int, optional + The decimal point precision for the point coordinate values. + Defaults to 8. + show_axes : boolean, optional + If ``True``, the coordinate axes will be shown. Defaults to False. + grid : boolean, optional + If ``True``, the plot will have a grid. Defaults to True. + + Examples + ======== + + .. plot:: + :context: close-figs + :format: doctest + :include-source: True + + >>> from sympy.abc import s + >>> from sympy.physics.control.lti import TransferFunction + >>> from sympy.physics.control.control_plots import step_response_plot + >>> tf1 = TransferFunction(8*s**2 + 18*s + 32, s**3 + 6*s**2 + 14*s + 24, s) + >>> step_response_plot(tf1) # doctest: +SKIP + + See Also + ======== + + impulse_response_plot, ramp_response_plot + + References + ========== + + .. [1] https://www.mathworks.com/help/control/ref/lti.step.html + + """ + x, y = step_response_numerical_data(system, prec=prec, + lower_limit=lower_limit, upper_limit=upper_limit, **kwargs) + plt.plot(x, y, color=color) + plt.xlabel('Time (s)') + plt.ylabel('Amplitude') + plt.title(f'Unit Step Response of ${latex(system)}$', pad=20) + + if grid: + plt.grid() + if show_axes: + plt.axhline(0, color='black') + plt.axvline(0, color='black') + if show: + plt.show() + return + + return plt + + +def impulse_response_numerical_data(system, prec=8, lower_limit=0, + upper_limit=10, **kwargs): + """ + Returns the numerical values of the points in the impulse response plot + of a SISO continuous-time system. By default, adaptive sampling + is used. If the user wants to instead get an uniformly + sampled response, then ``adaptive`` kwarg should be passed ``False`` + and ``n`` must be passed as additional kwargs. + Refer to the parameters of class :class:`sympy.plotting.series.LineOver1DRangeSeries` + for more details. + + Parameters + ========== + + system : SISOLinearTimeInvariant + The system for which the impulse response data is to be computed. + prec : int, optional + The decimal point precision for the point coordinate values. + Defaults to 8. + lower_limit : Number, optional + The lower limit of the plot range. Defaults to 0. + upper_limit : Number, optional + The upper limit of the plot range. Defaults to 10. + kwargs : + Additional keyword arguments are passed to the underlying + :class:`sympy.plotting.series.LineOver1DRangeSeries` class. + + Returns + ======= + + tuple : (x, y) + x = Time-axis values of the points in the impulse response. NumPy array. + y = Amplitude-axis values of the points in the impulse response. NumPy array. + + Raises + ====== + + NotImplementedError + When a SISO LTI system is not passed. + + When time delay terms are present in the system. + + ValueError + When more than one free symbol is present in the system. + The only variable in the transfer function should be + the variable of the Laplace transform. + + When ``lower_limit`` parameter is less than 0. + + Examples + ======== + + >>> from sympy.abc import s + >>> from sympy.physics.control.lti import TransferFunction + >>> from sympy.physics.control.control_plots import impulse_response_numerical_data + >>> tf1 = TransferFunction(s, s**2 + 5*s + 8, s) + >>> impulse_response_numerical_data(tf1) # doctest: +SKIP + ([0.0, 0.06616480200395854,... , 9.854500743565858, 10.0], + [0.9999999799999999, 0.7042848373025861,...,7.170748906965121e-13, -5.1901263495547205e-12]) + + See Also + ======== + + impulse_response_plot + + """ + if lower_limit < 0: + raise ValueError("Lower limit of time must be greater " + "than or equal to zero.") + _check_system(system) + _x = Dummy("x") + expr = system.to_expr() + expr = apart(expr, system.var, full=True) + _y = _fast_inverse_laplace(expr, system.var, _x).evalf(prec) + return LineOver1DRangeSeries(_y, (_x, lower_limit, upper_limit), + **kwargs).get_points() + + +def impulse_response_plot(system, color='b', prec=8, lower_limit=0, + upper_limit=10, show_axes=False, grid=True, show=True, **kwargs): + r""" + Returns the unit impulse response (Input is the Dirac-Delta Function) of a + continuous-time system. + + Parameters + ========== + + system : SISOLinearTimeInvariant type + The LTI SISO system for which the Impulse Response is to be computed. + color : str, tuple, optional + The color of the line. Default is Blue. + show : boolean, optional + If ``True``, the plot will be displayed otherwise + the equivalent matplotlib ``plot`` object will be returned. + Defaults to True. + lower_limit : Number, optional + The lower limit of the plot range. Defaults to 0. + upper_limit : Number, optional + The upper limit of the plot range. Defaults to 10. + prec : int, optional + The decimal point precision for the point coordinate values. + Defaults to 8. + show_axes : boolean, optional + If ``True``, the coordinate axes will be shown. Defaults to False. + grid : boolean, optional + If ``True``, the plot will have a grid. Defaults to True. + + Examples + ======== + + .. plot:: + :context: close-figs + :format: doctest + :include-source: True + + >>> from sympy.abc import s + >>> from sympy.physics.control.lti import TransferFunction + >>> from sympy.physics.control.control_plots import impulse_response_plot + >>> tf1 = TransferFunction(8*s**2 + 18*s + 32, s**3 + 6*s**2 + 14*s + 24, s) + >>> impulse_response_plot(tf1) # doctest: +SKIP + + See Also + ======== + + step_response_plot, ramp_response_plot + + References + ========== + + .. [1] https://www.mathworks.com/help/control/ref/dynamicsystem.impulse.html + + """ + x, y = impulse_response_numerical_data(system, prec=prec, + lower_limit=lower_limit, upper_limit=upper_limit, **kwargs) + plt.plot(x, y, color=color) + plt.xlabel('Time (s)') + plt.ylabel('Amplitude') + plt.title(f'Impulse Response of ${latex(system)}$', pad=20) + + if grid: + plt.grid() + if show_axes: + plt.axhline(0, color='black') + plt.axvline(0, color='black') + if show: + plt.show() + return + + return plt + + +def ramp_response_numerical_data(system, slope=1, prec=8, + lower_limit=0, upper_limit=10, **kwargs): + """ + Returns the numerical values of the points in the ramp response plot + of a SISO continuous-time system. By default, adaptive sampling + is used. If the user wants to instead get an uniformly + sampled response, then ``adaptive`` kwarg should be passed ``False`` + and ``n`` must be passed as additional kwargs. + Refer to the parameters of class :class:`sympy.plotting.series.LineOver1DRangeSeries` + for more details. + + Parameters + ========== + + system : SISOLinearTimeInvariant + The system for which the ramp response data is to be computed. + slope : Number, optional + The slope of the input ramp function. Defaults to 1. + prec : int, optional + The decimal point precision for the point coordinate values. + Defaults to 8. + lower_limit : Number, optional + The lower limit of the plot range. Defaults to 0. + upper_limit : Number, optional + The upper limit of the plot range. Defaults to 10. + kwargs : + Additional keyword arguments are passed to the underlying + :class:`sympy.plotting.series.LineOver1DRangeSeries` class. + + Returns + ======= + + tuple : (x, y) + x = Time-axis values of the points in the ramp response plot. NumPy array. + y = Amplitude-axis values of the points in the ramp response plot. NumPy array. + + Raises + ====== + + NotImplementedError + When a SISO LTI system is not passed. + + When time delay terms are present in the system. + + ValueError + When more than one free symbol is present in the system. + The only variable in the transfer function should be + the variable of the Laplace transform. + + When ``lower_limit`` parameter is less than 0. + + When ``slope`` is negative. + + Examples + ======== + + >>> from sympy.abc import s + >>> from sympy.physics.control.lti import TransferFunction + >>> from sympy.physics.control.control_plots import ramp_response_numerical_data + >>> tf1 = TransferFunction(s, s**2 + 5*s + 8, s) + >>> ramp_response_numerical_data(tf1) # doctest: +SKIP + (([0.0, 0.12166980856813935,..., 9.861246379582118, 10.0], + [1.4504508011325967e-09, 0.006046440489058766,..., 0.12499999999568202, 0.12499999999661349])) + + See Also + ======== + + ramp_response_plot + + """ + if slope < 0: + raise ValueError("Slope must be greater than or equal" + " to zero.") + if lower_limit < 0: + raise ValueError("Lower limit of time must be greater " + "than or equal to zero.") + _check_system(system) + _x = Dummy("x") + expr = (slope*system.to_expr())/((system.var)**2) + expr = apart(expr, system.var, full=True) + _y = _fast_inverse_laplace(expr, system.var, _x).evalf(prec) + return LineOver1DRangeSeries(_y, (_x, lower_limit, upper_limit), + **kwargs).get_points() + + +def ramp_response_plot(system, slope=1, color='b', prec=8, lower_limit=0, + upper_limit=10, show_axes=False, grid=True, show=True, **kwargs): + r""" + Returns the ramp response of a continuous-time system. + + Ramp function is defined as the straight line + passing through origin ($f(x) = mx$). The slope of + the ramp function can be varied by the user and + the default value is 1. + + Parameters + ========== + + system : SISOLinearTimeInvariant type + The LTI SISO system for which the Ramp Response is to be computed. + slope : Number, optional + The slope of the input ramp function. Defaults to 1. + color : str, tuple, optional + The color of the line. Default is Blue. + show : boolean, optional + If ``True``, the plot will be displayed otherwise + the equivalent matplotlib ``plot`` object will be returned. + Defaults to True. + lower_limit : Number, optional + The lower limit of the plot range. Defaults to 0. + upper_limit : Number, optional + The upper limit of the plot range. Defaults to 10. + prec : int, optional + The decimal point precision for the point coordinate values. + Defaults to 8. + show_axes : boolean, optional + If ``True``, the coordinate axes will be shown. Defaults to False. + grid : boolean, optional + If ``True``, the plot will have a grid. Defaults to True. + + Examples + ======== + + .. plot:: + :context: close-figs + :format: doctest + :include-source: True + + >>> from sympy.abc import s + >>> from sympy.physics.control.lti import TransferFunction + >>> from sympy.physics.control.control_plots import ramp_response_plot + >>> tf1 = TransferFunction(s, (s+4)*(s+8), s) + >>> ramp_response_plot(tf1, upper_limit=2) # doctest: +SKIP + + See Also + ======== + + step_response_plot, impulse_response_plot + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Ramp_function + + """ + x, y = ramp_response_numerical_data(system, slope=slope, prec=prec, + lower_limit=lower_limit, upper_limit=upper_limit, **kwargs) + plt.plot(x, y, color=color) + plt.xlabel('Time (s)') + plt.ylabel('Amplitude') + plt.title(f'Ramp Response of ${latex(system)}$ [Slope = {slope}]', pad=20) + + if grid: + plt.grid() + if show_axes: + plt.axhline(0, color='black') + plt.axvline(0, color='black') + if show: + plt.show() + return + + return plt + + +def bode_magnitude_numerical_data(system, initial_exp=-5, final_exp=5, freq_unit='rad/sec', **kwargs): + """ + Returns the numerical data of the Bode magnitude plot of the system. + It is internally used by ``bode_magnitude_plot`` to get the data + for plotting Bode magnitude plot. Users can use this data to further + analyse the dynamics of the system or plot using a different + backend/plotting-module. + + Parameters + ========== + + system : SISOLinearTimeInvariant + The system for which the data is to be computed. + initial_exp : Number, optional + The initial exponent of 10 of the semilog plot. Defaults to -5. + final_exp : Number, optional + The final exponent of 10 of the semilog plot. Defaults to 5. + freq_unit : string, optional + User can choose between ``'rad/sec'`` (radians/second) and ``'Hz'`` (Hertz) as frequency units. + + Returns + ======= + + tuple : (x, y) + x = x-axis values of the Bode magnitude plot. + y = y-axis values of the Bode magnitude plot. + + Raises + ====== + + NotImplementedError + When a SISO LTI system is not passed. + + When time delay terms are present in the system. + + ValueError + When more than one free symbol is present in the system. + The only variable in the transfer function should be + the variable of the Laplace transform. + + When incorrect frequency units are given as input. + + Examples + ======== + + >>> from sympy.abc import s + >>> from sympy.physics.control.lti import TransferFunction + >>> from sympy.physics.control.control_plots import bode_magnitude_numerical_data + >>> tf1 = TransferFunction(s**2 + 1, s**4 + 4*s**3 + 6*s**2 + 5*s + 2, s) + >>> bode_magnitude_numerical_data(tf1) # doctest: +SKIP + ([1e-05, 1.5148378120533502e-05,..., 68437.36188804005, 100000.0], + [-6.020599914256786, -6.0205999155219505,..., -193.4117304087953, -200.00000000260573]) + + See Also + ======== + + bode_magnitude_plot, bode_phase_numerical_data + + """ + _check_system(system) + expr = system.to_expr() + freq_units = ('rad/sec', 'Hz') + if freq_unit not in freq_units: + raise ValueError('Only "rad/sec" and "Hz" are accepted frequency units.') + + _w = Dummy("w", real=True) + if freq_unit == 'Hz': + repl = I*_w*2*pi + else: + repl = I*_w + w_expr = expr.subs({system.var: repl}) + + mag = 20*log(Abs(w_expr), 10) + + x, y = LineOver1DRangeSeries(mag, + (_w, 10**initial_exp, 10**final_exp), xscale='log', **kwargs).get_points() + + return x, y + + +def bode_magnitude_plot(system, initial_exp=-5, final_exp=5, + color='b', show_axes=False, grid=True, show=True, freq_unit='rad/sec', **kwargs): + r""" + Returns the Bode magnitude plot of a continuous-time system. + + See ``bode_plot`` for all the parameters. + """ + x, y = bode_magnitude_numerical_data(system, initial_exp=initial_exp, + final_exp=final_exp, freq_unit=freq_unit) + plt.plot(x, y, color=color, **kwargs) + plt.xscale('log') + + + plt.xlabel('Frequency (%s) [Log Scale]' % freq_unit) + plt.ylabel('Magnitude (dB)') + plt.title(f'Bode Plot (Magnitude) of ${latex(system)}$', pad=20) + + if grid: + plt.grid(True) + if show_axes: + plt.axhline(0, color='black') + plt.axvline(0, color='black') + if show: + plt.show() + return + + return plt + + +def bode_phase_numerical_data(system, initial_exp=-5, final_exp=5, freq_unit='rad/sec', phase_unit='rad', phase_unwrap = True, **kwargs): + """ + Returns the numerical data of the Bode phase plot of the system. + It is internally used by ``bode_phase_plot`` to get the data + for plotting Bode phase plot. Users can use this data to further + analyse the dynamics of the system or plot using a different + backend/plotting-module. + + Parameters + ========== + + system : SISOLinearTimeInvariant + The system for which the Bode phase plot data is to be computed. + initial_exp : Number, optional + The initial exponent of 10 of the semilog plot. Defaults to -5. + final_exp : Number, optional + The final exponent of 10 of the semilog plot. Defaults to 5. + freq_unit : string, optional + User can choose between ``'rad/sec'`` (radians/second) and '``'Hz'`` (Hertz) as frequency units. + phase_unit : string, optional + User can choose between ``'rad'`` (radians) and ``'deg'`` (degree) as phase units. + phase_unwrap : bool, optional + Set to ``True`` by default. + + Returns + ======= + + tuple : (x, y) + x = x-axis values of the Bode phase plot. + y = y-axis values of the Bode phase plot. + + Raises + ====== + + NotImplementedError + When a SISO LTI system is not passed. + + When time delay terms are present in the system. + + ValueError + When more than one free symbol is present in the system. + The only variable in the transfer function should be + the variable of the Laplace transform. + + When incorrect frequency or phase units are given as input. + + Examples + ======== + + >>> from sympy.abc import s + >>> from sympy.physics.control.lti import TransferFunction + >>> from sympy.physics.control.control_plots import bode_phase_numerical_data + >>> tf1 = TransferFunction(s**2 + 1, s**4 + 4*s**3 + 6*s**2 + 5*s + 2, s) + >>> bode_phase_numerical_data(tf1) # doctest: +SKIP + ([1e-05, 1.4472354033813751e-05, 2.035581932165858e-05,..., 47577.3248186011, 67884.09326036123, 100000.0], + [-2.5000000000291665e-05, -3.6180885085e-05, -5.08895483066e-05,...,-3.1415085799262523, -3.14155265358979]) + + See Also + ======== + + bode_magnitude_plot, bode_phase_numerical_data + + """ + _check_system(system) + expr = system.to_expr() + freq_units = ('rad/sec', 'Hz') + phase_units = ('rad', 'deg') + if freq_unit not in freq_units: + raise ValueError('Only "rad/sec" and "Hz" are accepted frequency units.') + if phase_unit not in phase_units: + raise ValueError('Only "rad" and "deg" are accepted phase units.') + + _w = Dummy("w", real=True) + if freq_unit == 'Hz': + repl = I*_w*2*pi + else: + repl = I*_w + w_expr = expr.subs({system.var: repl}) + + if phase_unit == 'deg': + phase = arg(w_expr)*180/pi + else: + phase = arg(w_expr) + + x, y = LineOver1DRangeSeries(phase, + (_w, 10**initial_exp, 10**final_exp), xscale='log', **kwargs).get_points() + + half = None + if phase_unwrap: + if(phase_unit == 'rad'): + half = pi + elif(phase_unit == 'deg'): + half = 180 + if half: + unit = 2*half + for i in range(1, len(y)): + diff = y[i] - y[i - 1] + if diff > half: # Jump from -half to half + y[i] = (y[i] - unit) + elif diff < -half: # Jump from half to -half + y[i] = (y[i] + unit) + + return x, y + + +def bode_phase_plot(system, initial_exp=-5, final_exp=5, + color='b', show_axes=False, grid=True, show=True, freq_unit='rad/sec', phase_unit='rad', phase_unwrap=True, **kwargs): + r""" + Returns the Bode phase plot of a continuous-time system. + + See ``bode_plot`` for all the parameters. + """ + x, y = bode_phase_numerical_data(system, initial_exp=initial_exp, + final_exp=final_exp, freq_unit=freq_unit, phase_unit=phase_unit, phase_unwrap=phase_unwrap) + plt.plot(x, y, color=color, **kwargs) + plt.xscale('log') + + plt.xlabel('Frequency (%s) [Log Scale]' % freq_unit) + plt.ylabel('Phase (%s)' % phase_unit) + plt.title(f'Bode Plot (Phase) of ${latex(system)}$', pad=20) + + if grid: + plt.grid(True) + if show_axes: + plt.axhline(0, color='black') + plt.axvline(0, color='black') + if show: + plt.show() + return + + return plt + + +def bode_plot(system, initial_exp=-5, final_exp=5, + grid=True, show_axes=False, show=True, freq_unit='rad/sec', phase_unit='rad', phase_unwrap=True, **kwargs): + r""" + Returns the Bode phase and magnitude plots of a continuous-time system. + + Parameters + ========== + + system : SISOLinearTimeInvariant type + The LTI SISO system for which the Bode Plot is to be computed. + initial_exp : Number, optional + The initial exponent of 10 of the semilog plot. Defaults to -5. + final_exp : Number, optional + The final exponent of 10 of the semilog plot. Defaults to 5. + show : boolean, optional + If ``True``, the plot will be displayed otherwise + the equivalent matplotlib ``plot`` object will be returned. + Defaults to True. + prec : int, optional + The decimal point precision for the point coordinate values. + Defaults to 8. + grid : boolean, optional + If ``True``, the plot will have a grid. Defaults to True. + show_axes : boolean, optional + If ``True``, the coordinate axes will be shown. Defaults to False. + freq_unit : string, optional + User can choose between ``'rad/sec'`` (radians/second) and ``'Hz'`` (Hertz) as frequency units. + phase_unit : string, optional + User can choose between ``'rad'`` (radians) and ``'deg'`` (degree) as phase units. + + Examples + ======== + + .. plot:: + :context: close-figs + :format: doctest + :include-source: True + + >>> from sympy.abc import s + >>> from sympy.physics.control.lti import TransferFunction + >>> from sympy.physics.control.control_plots import bode_plot + >>> tf1 = TransferFunction(1*s**2 + 0.1*s + 7.5, 1*s**4 + 0.12*s**3 + 9*s**2, s) + >>> bode_plot(tf1, initial_exp=0.2, final_exp=0.7) # doctest: +SKIP + + See Also + ======== + + bode_magnitude_plot, bode_phase_plot + + """ + plt.subplot(211) + mag = bode_magnitude_plot(system, initial_exp=initial_exp, final_exp=final_exp, + show=False, grid=grid, show_axes=show_axes, + freq_unit=freq_unit, **kwargs) + mag.title(f'Bode Plot of ${latex(system)}$', pad=20) + mag.xlabel(None) + plt.subplot(212) + bode_phase_plot(system, initial_exp=initial_exp, final_exp=final_exp, + show=False, grid=grid, show_axes=show_axes, freq_unit=freq_unit, phase_unit=phase_unit, phase_unwrap=phase_unwrap, **kwargs).title(None) + + if show: + plt.show() + return + + return plt diff --git a/lib/python3.10/site-packages/sympy/physics/control/lti.py b/lib/python3.10/site-packages/sympy/physics/control/lti.py new file mode 100644 index 0000000000000000000000000000000000000000..54349e50e087077435ed2fcdf01c2aed23f0edea --- /dev/null +++ b/lib/python3.10/site-packages/sympy/physics/control/lti.py @@ -0,0 +1,4304 @@ +from typing import Type +from sympy import Interval, numer, Rational, solveset +from sympy.core.add import Add +from sympy.core.basic import Basic +from sympy.core.containers import Tuple +from sympy.core.evalf import EvalfMixin +from sympy.core.expr import Expr +from sympy.core.function import expand +from sympy.core.logic import fuzzy_and +from sympy.core.mul import Mul +from sympy.core.numbers import I, pi, oo +from sympy.core.power import Pow +from sympy.core.singleton import S +from sympy.core.symbol import Dummy, Symbol +from sympy.functions import Abs +from sympy.core.sympify import sympify, _sympify +from sympy.matrices import Matrix, ImmutableMatrix, ImmutableDenseMatrix, eye, ShapeError, zeros +from sympy.functions.elementary.exponential import (exp, log) +from sympy.matrices.expressions import MatMul, MatAdd +from sympy.polys import Poly, rootof +from sympy.polys.polyroots import roots +from sympy.polys.polytools import (cancel, degree) +from sympy.series import limit +from sympy.utilities.misc import filldedent + +from mpmath.libmp.libmpf import prec_to_dps + +__all__ = ['TransferFunction', 'Series', 'MIMOSeries', 'Parallel', 'MIMOParallel', + 'Feedback', 'MIMOFeedback', 'TransferFunctionMatrix', 'StateSpace', 'gbt', 'bilinear', 'forward_diff', 'backward_diff', + 'phase_margin', 'gain_margin'] + +def _roots(poly, var): + """ like roots, but works on higher-order polynomials. """ + r = roots(poly, var, multiple=True) + n = degree(poly) + if len(r) != n: + r = [rootof(poly, var, k) for k in range(n)] + return r + +def gbt(tf, sample_per, alpha): + r""" + Returns falling coefficients of H(z) from numerator and denominator. + + Explanation + =========== + + Where H(z) is the corresponding discretized transfer function, + discretized with the generalised bilinear transformation method. + H(z) is obtained from the continuous transfer function H(s) + by substituting $s(z) = \frac{z-1}{T(\alpha z + (1-\alpha))}$ into H(s), where T is the + sample period. + Coefficients are falling, i.e. $H(z) = \frac{az+b}{cz+d}$ is returned + as [a, b], [c, d]. + + Examples + ======== + + >>> from sympy.physics.control.lti import TransferFunction, gbt + >>> from sympy.abc import s, L, R, T + + >>> tf = TransferFunction(1, s*L + R, s) + >>> numZ, denZ = gbt(tf, T, 0.5) + >>> numZ + [T/(2*(L + R*T/2)), T/(2*(L + R*T/2))] + >>> denZ + [1, (-L + R*T/2)/(L + R*T/2)] + + >>> numZ, denZ = gbt(tf, T, 0) + >>> numZ + [T/L] + >>> denZ + [1, (-L + R*T)/L] + + >>> numZ, denZ = gbt(tf, T, 1) + >>> numZ + [T/(L + R*T), 0] + >>> denZ + [1, -L/(L + R*T)] + + >>> numZ, denZ = gbt(tf, T, 0.3) + >>> numZ + [3*T/(10*(L + 3*R*T/10)), 7*T/(10*(L + 3*R*T/10))] + >>> denZ + [1, (-L + 7*R*T/10)/(L + 3*R*T/10)] + + References + ========== + + .. [1] https://www.polyu.edu.hk/ama/profile/gfzhang/Research/ZCC09_IJC.pdf + """ + if not tf.is_SISO: + raise NotImplementedError("Not implemented for MIMO systems.") + + T = sample_per # and sample period T + s = tf.var + z = s # dummy discrete variable z + + np = tf.num.as_poly(s).all_coeffs() + dp = tf.den.as_poly(s).all_coeffs() + alpha = Rational(alpha).limit_denominator(1000) + + # The next line results from multiplying H(z) with z^N/z^N + N = max(len(np), len(dp)) - 1 + num = Add(*[ T**(N-i) * c * (z-1)**i * (alpha * z + 1 - alpha)**(N-i) for c, i in zip(np[::-1], range(len(np))) ]) + den = Add(*[ T**(N-i) * c * (z-1)**i * (alpha * z + 1 - alpha)**(N-i) for c, i in zip(dp[::-1], range(len(dp))) ]) + + num_coefs = num.as_poly(z).all_coeffs() + den_coefs = den.as_poly(z).all_coeffs() + + para = den_coefs[0] + num_coefs = [coef/para for coef in num_coefs] + den_coefs = [coef/para for coef in den_coefs] + + return num_coefs, den_coefs + +def bilinear(tf, sample_per): + r""" + Returns falling coefficients of H(z) from numerator and denominator. + + Explanation + =========== + + Where H(z) is the corresponding discretized transfer function, + discretized with the bilinear transform method. + H(z) is obtained from the continuous transfer function H(s) + by substituting $s(z) = \frac{2}{T}\frac{z-1}{z+1}$ into H(s), where T is the + sample period. + Coefficients are falling, i.e. $H(z) = \frac{az+b}{cz+d}$ is returned + as [a, b], [c, d]. + + Examples + ======== + + >>> from sympy.physics.control.lti import TransferFunction, bilinear + >>> from sympy.abc import s, L, R, T + + >>> tf = TransferFunction(1, s*L + R, s) + >>> numZ, denZ = bilinear(tf, T) + >>> numZ + [T/(2*(L + R*T/2)), T/(2*(L + R*T/2))] + >>> denZ + [1, (-L + R*T/2)/(L + R*T/2)] + """ + return gbt(tf, sample_per, S.Half) + +def forward_diff(tf, sample_per): + r""" + Returns falling coefficients of H(z) from numerator and denominator. + + Explanation + =========== + + Where H(z) is the corresponding discretized transfer function, + discretized with the forward difference transform method. + H(z) is obtained from the continuous transfer function H(s) + by substituting $s(z) = \frac{z-1}{T}$ into H(s), where T is the + sample period. + Coefficients are falling, i.e. $H(z) = \frac{az+b}{cz+d}$ is returned + as [a, b], [c, d]. + + Examples + ======== + + >>> from sympy.physics.control.lti import TransferFunction, forward_diff + >>> from sympy.abc import s, L, R, T + + >>> tf = TransferFunction(1, s*L + R, s) + >>> numZ, denZ = forward_diff(tf, T) + >>> numZ + [T/L] + >>> denZ + [1, (-L + R*T)/L] + """ + return gbt(tf, sample_per, S.Zero) + +def backward_diff(tf, sample_per): + r""" + Returns falling coefficients of H(z) from numerator and denominator. + + Explanation + =========== + + Where H(z) is the corresponding discretized transfer function, + discretized with the backward difference transform method. + H(z) is obtained from the continuous transfer function H(s) + by substituting $s(z) = \frac{z-1}{Tz}$ into H(s), where T is the + sample period. + Coefficients are falling, i.e. $H(z) = \frac{az+b}{cz+d}$ is returned + as [a, b], [c, d]. + + Examples + ======== + + >>> from sympy.physics.control.lti import TransferFunction, backward_diff + >>> from sympy.abc import s, L, R, T + + >>> tf = TransferFunction(1, s*L + R, s) + >>> numZ, denZ = backward_diff(tf, T) + >>> numZ + [T/(L + R*T), 0] + >>> denZ + [1, -L/(L + R*T)] + """ + return gbt(tf, sample_per, S.One) + +def phase_margin(system): + r""" + Returns the phase margin of a continuous time system. + Only applicable to Transfer Functions which can generate valid bode plots. + + Raises + ====== + + NotImplementedError + When time delay terms are present in the system. + + ValueError + When a SISO LTI system is not passed. + + When more than one free symbol is present in the system. + The only variable in the transfer function should be + the variable of the Laplace transform. + + Examples + ======== + + >>> from sympy.physics.control import TransferFunction, phase_margin + >>> from sympy.abc import s + + >>> tf = TransferFunction(1, s**3 + 2*s**2 + s, s) + >>> phase_margin(tf) + 180*(-pi + atan((-1 + (-2*18**(1/3)/(9 + sqrt(93))**(1/3) + 12**(1/3)*(9 + sqrt(93))**(1/3))**2/36)/(-12**(1/3)*(9 + sqrt(93))**(1/3)/3 + 2*18**(1/3)/(3*(9 + sqrt(93))**(1/3)))))/pi + 180 + >>> phase_margin(tf).n() + 21.3863897518751 + + >>> tf1 = TransferFunction(s**3, s**2 + 5*s, s) + >>> phase_margin(tf1) + -180 + 180*(atan(sqrt(2)*(-51/10 - sqrt(101)/10)*sqrt(1 + sqrt(101))/(2*(sqrt(101)/2 + 51/2))) + pi)/pi + >>> phase_margin(tf1).n() + -25.1783920627277 + + >>> tf2 = TransferFunction(1, s + 1, s) + >>> phase_margin(tf2) + -180 + + See Also + ======== + + gain_margin + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Phase_margin + + """ + from sympy.functions import arg + + if not isinstance(system, SISOLinearTimeInvariant): + raise ValueError("Margins are only applicable for SISO LTI systems.") + + _w = Dummy("w", real=True) + repl = I*_w + expr = system.to_expr() + len_free_symbols = len(expr.free_symbols) + if expr.has(exp): + raise NotImplementedError("Margins for systems with Time delay terms are not supported.") + elif len_free_symbols > 1: + raise ValueError("Extra degree of freedom found. Make sure" + " that there are no free symbols in the dynamical system other" + " than the variable of Laplace transform.") + + w_expr = expr.subs({system.var: repl}) + + mag = 20*log(Abs(w_expr), 10) + mag_sol = list(solveset(mag, _w, Interval(0, oo, left_open=True))) + + if (len(mag_sol) == 0): + pm = S(-180) + else: + wcp = mag_sol[0] + pm = ((arg(w_expr)*S(180)/pi).subs({_w:wcp}) + S(180)) % 360 + + if(pm >= 180): + pm = pm - 360 + + return pm + +def gain_margin(system): + r""" + Returns the gain margin of a continuous time system. + Only applicable to Transfer Functions which can generate valid bode plots. + + Raises + ====== + + NotImplementedError + When time delay terms are present in the system. + + ValueError + When a SISO LTI system is not passed. + + When more than one free symbol is present in the system. + The only variable in the transfer function should be + the variable of the Laplace transform. + + Examples + ======== + + >>> from sympy.physics.control import TransferFunction, gain_margin + >>> from sympy.abc import s + + >>> tf = TransferFunction(1, s**3 + 2*s**2 + s, s) + >>> gain_margin(tf) + 20*log(2)/log(10) + >>> gain_margin(tf).n() + 6.02059991327962 + + >>> tf1 = TransferFunction(s**3, s**2 + 5*s, s) + >>> gain_margin(tf1) + oo + + See Also + ======== + + phase_margin + + References + ========== + + https://en.wikipedia.org/wiki/Bode_plot + + """ + if not isinstance(system, SISOLinearTimeInvariant): + raise ValueError("Margins are only applicable for SISO LTI systems.") + + _w = Dummy("w", real=True) + repl = I*_w + expr = system.to_expr() + len_free_symbols = len(expr.free_symbols) + if expr.has(exp): + raise NotImplementedError("Margins for systems with Time delay terms are not supported.") + elif len_free_symbols > 1: + raise ValueError("Extra degree of freedom found. Make sure" + " that there are no free symbols in the dynamical system other" + " than the variable of Laplace transform.") + + w_expr = expr.subs({system.var: repl}) + + mag = 20*log(Abs(w_expr), 10) + phase = w_expr + phase_sol = list(solveset(numer(phase.as_real_imag()[1].cancel()),_w, Interval(0, oo, left_open = True))) + + if (len(phase_sol) == 0): + gm = oo + else: + wcg = phase_sol[0] + gm = -mag.subs({_w:wcg}) + + return gm + +class LinearTimeInvariant(Basic, EvalfMixin): + """A common class for all the Linear Time-Invariant Dynamical Systems.""" + + _clstype: Type + + # Users should not directly interact with this class. + def __new__(cls, *system, **kwargs): + if cls is LinearTimeInvariant: + raise NotImplementedError('The LTICommon class is not meant to be used directly.') + return super(LinearTimeInvariant, cls).__new__(cls, *system, **kwargs) + + @classmethod + def _check_args(cls, args): + if not args: + raise ValueError("At least 1 argument must be passed.") + if not all(isinstance(arg, cls._clstype) for arg in args): + raise TypeError(f"All arguments must be of type {cls._clstype}.") + var_set = {arg.var for arg in args} + if len(var_set) != 1: + raise ValueError(filldedent(f""" + All transfer functions should use the same complex variable + of the Laplace transform. {len(var_set)} different + values found.""")) + + @property + def is_SISO(self): + """Returns `True` if the passed LTI system is SISO else returns False.""" + return self._is_SISO + + +class SISOLinearTimeInvariant(LinearTimeInvariant): + """A common class for all the SISO Linear Time-Invariant Dynamical Systems.""" + # Users should not directly interact with this class. + _is_SISO = True + + +class MIMOLinearTimeInvariant(LinearTimeInvariant): + """A common class for all the MIMO Linear Time-Invariant Dynamical Systems.""" + # Users should not directly interact with this class. + _is_SISO = False + + +SISOLinearTimeInvariant._clstype = SISOLinearTimeInvariant +MIMOLinearTimeInvariant._clstype = MIMOLinearTimeInvariant + + +def _check_other_SISO(func): + def wrapper(*args, **kwargs): + if not isinstance(args[-1], SISOLinearTimeInvariant): + return NotImplemented + else: + return func(*args, **kwargs) + return wrapper + + +def _check_other_MIMO(func): + def wrapper(*args, **kwargs): + if not isinstance(args[-1], MIMOLinearTimeInvariant): + return NotImplemented + else: + return func(*args, **kwargs) + return wrapper + + +class TransferFunction(SISOLinearTimeInvariant): + r""" + A class for representing LTI (Linear, time-invariant) systems that can be strictly described + by ratio of polynomials in the Laplace transform complex variable. The arguments + are ``num``, ``den``, and ``var``, where ``num`` and ``den`` are numerator and + denominator polynomials of the ``TransferFunction`` respectively, and the third argument is + a complex variable of the Laplace transform used by these polynomials of the transfer function. + ``num`` and ``den`` can be either polynomials or numbers, whereas ``var`` + has to be a :py:class:`~.Symbol`. + + Explanation + =========== + + Generally, a dynamical system representing a physical model can be described in terms of Linear + Ordinary Differential Equations like - + + $\small{b_{m}y^{\left(m\right)}+b_{m-1}y^{\left(m-1\right)}+\dots+b_{1}y^{\left(1\right)}+b_{0}y= + a_{n}x^{\left(n\right)}+a_{n-1}x^{\left(n-1\right)}+\dots+a_{1}x^{\left(1\right)}+a_{0}x}$ + + Here, $x$ is the input signal and $y$ is the output signal and superscript on both is the order of derivative + (not exponent). Derivative is taken with respect to the independent variable, $t$. Also, generally $m$ is greater + than $n$. + + It is not feasible to analyse the properties of such systems in their native form therefore, we use + mathematical tools like Laplace transform to get a better perspective. Taking the Laplace transform + of both the sides in the equation (at zero initial conditions), we get - + + $\small{\mathcal{L}[b_{m}y^{\left(m\right)}+b_{m-1}y^{\left(m-1\right)}+\dots+b_{1}y^{\left(1\right)}+b_{0}y]= + \mathcal{L}[a_{n}x^{\left(n\right)}+a_{n-1}x^{\left(n-1\right)}+\dots+a_{1}x^{\left(1\right)}+a_{0}x]}$ + + Using the linearity property of Laplace transform and also considering zero initial conditions + (i.e. $\small{y(0^{-}) = 0}$, $\small{y'(0^{-}) = 0}$ and so on), the equation + above gets translated to - + + $\small{b_{m}\mathcal{L}[y^{\left(m\right)}]+\dots+b_{1}\mathcal{L}[y^{\left(1\right)}]+b_{0}\mathcal{L}[y]= + a_{n}\mathcal{L}[x^{\left(n\right)}]+\dots+a_{1}\mathcal{L}[x^{\left(1\right)}]+a_{0}\mathcal{L}[x]}$ + + Now, applying Derivative property of Laplace transform, + + $\small{b_{m}s^{m}\mathcal{L}[y]+\dots+b_{1}s\mathcal{L}[y]+b_{0}\mathcal{L}[y]= + a_{n}s^{n}\mathcal{L}[x]+\dots+a_{1}s\mathcal{L}[x]+a_{0}\mathcal{L}[x]}$ + + Here, the superscript on $s$ is **exponent**. Note that the zero initial conditions assumption, mentioned above, is very important + and cannot be ignored otherwise the dynamical system cannot be considered time-independent and the simplified equation above + cannot be reached. + + Collecting $\mathcal{L}[y]$ and $\mathcal{L}[x]$ terms from both the sides and taking the ratio + $\frac{ \mathcal{L}\left\{y\right\} }{ \mathcal{L}\left\{x\right\} }$, we get the typical rational form of transfer + function. + + The numerator of the transfer function is, therefore, the Laplace transform of the output signal + (The signals are represented as functions of time) and similarly, the denominator + of the transfer function is the Laplace transform of the input signal. It is also a convention + to denote the input and output signal's Laplace transform with capital alphabets like shown below. + + $H(s) = \frac{Y(s)}{X(s)} = \frac{ \mathcal{L}\left\{y(t)\right\} }{ \mathcal{L}\left\{x(t)\right\} }$ + + $s$, also known as complex frequency, is a complex variable in the Laplace domain. It corresponds to the + equivalent variable $t$, in the time domain. Transfer functions are sometimes also referred to as the Laplace + transform of the system's impulse response. Transfer function, $H$, is represented as a rational + function in $s$ like, + + $H(s) =\ \frac{a_{n}s^{n}+a_{n-1}s^{n-1}+\dots+a_{1}s+a_{0}}{b_{m}s^{m}+b_{m-1}s^{m-1}+\dots+b_{1}s+b_{0}}$ + + Parameters + ========== + + num : Expr, Number + The numerator polynomial of the transfer function. + den : Expr, Number + The denominator polynomial of the transfer function. + var : Symbol + Complex variable of the Laplace transform used by the + polynomials of the transfer function. + + Raises + ====== + + TypeError + When ``var`` is not a Symbol or when ``num`` or ``den`` is not a + number or a polynomial. + ValueError + When ``den`` is zero. + + Examples + ======== + + >>> from sympy.abc import s, p, a + >>> from sympy.physics.control.lti import TransferFunction + >>> tf1 = TransferFunction(s + a, s**2 + s + 1, s) + >>> tf1 + TransferFunction(a + s, s**2 + s + 1, s) + >>> tf1.num + a + s + >>> tf1.den + s**2 + s + 1 + >>> tf1.var + s + >>> tf1.args + (a + s, s**2 + s + 1, s) + + Any complex variable can be used for ``var``. + + >>> tf2 = TransferFunction(a*p**3 - a*p**2 + s*p, p + a**2, p) + >>> tf2 + TransferFunction(a*p**3 - a*p**2 + p*s, a**2 + p, p) + >>> tf3 = TransferFunction((p + 3)*(p - 1), (p - 1)*(p + 5), p) + >>> tf3 + TransferFunction((p - 1)*(p + 3), (p - 1)*(p + 5), p) + + To negate a transfer function the ``-`` operator can be prepended: + + >>> tf4 = TransferFunction(-a + s, p**2 + s, p) + >>> -tf4 + TransferFunction(a - s, p**2 + s, p) + >>> tf5 = TransferFunction(s**4 - 2*s**3 + 5*s + 4, s + 4, s) + >>> -tf5 + TransferFunction(-s**4 + 2*s**3 - 5*s - 4, s + 4, s) + + You can use a float or an integer (or other constants) as numerator and denominator: + + >>> tf6 = TransferFunction(1/2, 4, s) + >>> tf6.num + 0.500000000000000 + >>> tf6.den + 4 + >>> tf6.var + s + >>> tf6.args + (0.5, 4, s) + + You can take the integer power of a transfer function using the ``**`` operator: + + >>> tf7 = TransferFunction(s + a, s - a, s) + >>> tf7**3 + TransferFunction((a + s)**3, (-a + s)**3, s) + >>> tf7**0 + TransferFunction(1, 1, s) + >>> tf8 = TransferFunction(p + 4, p - 3, p) + >>> tf8**-1 + TransferFunction(p - 3, p + 4, p) + + Addition, subtraction, and multiplication of transfer functions can form + unevaluated ``Series`` or ``Parallel`` objects. + + >>> tf9 = TransferFunction(s + 1, s**2 + s + 1, s) + >>> tf10 = TransferFunction(s - p, s + 3, s) + >>> tf11 = TransferFunction(4*s**2 + 2*s - 4, s - 1, s) + >>> tf12 = TransferFunction(1 - s, s**2 + 4, s) + >>> tf9 + tf10 + Parallel(TransferFunction(s + 1, s**2 + s + 1, s), TransferFunction(-p + s, s + 3, s)) + >>> tf10 - tf11 + Parallel(TransferFunction(-p + s, s + 3, s), TransferFunction(-4*s**2 - 2*s + 4, s - 1, s)) + >>> tf9 * tf10 + Series(TransferFunction(s + 1, s**2 + s + 1, s), TransferFunction(-p + s, s + 3, s)) + >>> tf10 - (tf9 + tf12) + Parallel(TransferFunction(-p + s, s + 3, s), TransferFunction(-s - 1, s**2 + s + 1, s), TransferFunction(s - 1, s**2 + 4, s)) + >>> tf10 - (tf9 * tf12) + Parallel(TransferFunction(-p + s, s + 3, s), Series(TransferFunction(-1, 1, s), TransferFunction(s + 1, s**2 + s + 1, s), TransferFunction(1 - s, s**2 + 4, s))) + >>> tf11 * tf10 * tf9 + Series(TransferFunction(4*s**2 + 2*s - 4, s - 1, s), TransferFunction(-p + s, s + 3, s), TransferFunction(s + 1, s**2 + s + 1, s)) + >>> tf9 * tf11 + tf10 * tf12 + Parallel(Series(TransferFunction(s + 1, s**2 + s + 1, s), TransferFunction(4*s**2 + 2*s - 4, s - 1, s)), Series(TransferFunction(-p + s, s + 3, s), TransferFunction(1 - s, s**2 + 4, s))) + >>> (tf9 + tf12) * (tf10 + tf11) + Series(Parallel(TransferFunction(s + 1, s**2 + s + 1, s), TransferFunction(1 - s, s**2 + 4, s)), Parallel(TransferFunction(-p + s, s + 3, s), TransferFunction(4*s**2 + 2*s - 4, s - 1, s))) + + These unevaluated ``Series`` or ``Parallel`` objects can convert into the + resultant transfer function using ``.doit()`` method or by ``.rewrite(TransferFunction)``. + + >>> ((tf9 + tf10) * tf12).doit() + TransferFunction((1 - s)*((-p + s)*(s**2 + s + 1) + (s + 1)*(s + 3)), (s + 3)*(s**2 + 4)*(s**2 + s + 1), s) + >>> (tf9 * tf10 - tf11 * tf12).rewrite(TransferFunction) + TransferFunction(-(1 - s)*(s + 3)*(s**2 + s + 1)*(4*s**2 + 2*s - 4) + (-p + s)*(s - 1)*(s + 1)*(s**2 + 4), (s - 1)*(s + 3)*(s**2 + 4)*(s**2 + s + 1), s) + + See Also + ======== + + Feedback, Series, Parallel + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Transfer_function + .. [2] https://en.wikipedia.org/wiki/Laplace_transform + + """ + def __new__(cls, num, den, var): + num, den = _sympify(num), _sympify(den) + + if not isinstance(var, Symbol): + raise TypeError("Variable input must be a Symbol.") + + if den == 0: + raise ValueError("TransferFunction cannot have a zero denominator.") + + if (((isinstance(num, (Expr, TransferFunction, Series, Parallel)) and num.has(Symbol)) or num.is_number) and + ((isinstance(den, (Expr, TransferFunction, Series, Parallel)) and den.has(Symbol)) or den.is_number)): + return super(TransferFunction, cls).__new__(cls, num, den, var) + + else: + raise TypeError("Unsupported type for numerator or denominator of TransferFunction.") + + @classmethod + def from_rational_expression(cls, expr, var=None): + r""" + Creates a new ``TransferFunction`` efficiently from a rational expression. + + Parameters + ========== + + expr : Expr, Number + The rational expression representing the ``TransferFunction``. + var : Symbol, optional + Complex variable of the Laplace transform used by the + polynomials of the transfer function. + + Raises + ====== + + ValueError + When ``expr`` is of type ``Number`` and optional parameter ``var`` + is not passed. + + When ``expr`` has more than one variables and an optional parameter + ``var`` is not passed. + ZeroDivisionError + When denominator of ``expr`` is zero or it has ``ComplexInfinity`` + in its numerator. + + Examples + ======== + + >>> from sympy.abc import s, p, a + >>> from sympy.physics.control.lti import TransferFunction + >>> expr1 = (s + 5)/(3*s**2 + 2*s + 1) + >>> tf1 = TransferFunction.from_rational_expression(expr1) + >>> tf1 + TransferFunction(s + 5, 3*s**2 + 2*s + 1, s) + >>> expr2 = (a*p**3 - a*p**2 + s*p)/(p + a**2) # Expr with more than one variables + >>> tf2 = TransferFunction.from_rational_expression(expr2, p) + >>> tf2 + TransferFunction(a*p**3 - a*p**2 + p*s, a**2 + p, p) + + In case of conflict between two or more variables in a expression, SymPy will + raise a ``ValueError``, if ``var`` is not passed by the user. + + >>> tf = TransferFunction.from_rational_expression((a + a*s)/(s**2 + s + 1)) + Traceback (most recent call last): + ... + ValueError: Conflicting values found for positional argument `var` ({a, s}). Specify it manually. + + This can be corrected by specifying the ``var`` parameter manually. + + >>> tf = TransferFunction.from_rational_expression((a + a*s)/(s**2 + s + 1), s) + >>> tf + TransferFunction(a*s + a, s**2 + s + 1, s) + + ``var`` also need to be specified when ``expr`` is a ``Number`` + + >>> tf3 = TransferFunction.from_rational_expression(10, s) + >>> tf3 + TransferFunction(10, 1, s) + + """ + expr = _sympify(expr) + if var is None: + _free_symbols = expr.free_symbols + _len_free_symbols = len(_free_symbols) + if _len_free_symbols == 1: + var = list(_free_symbols)[0] + elif _len_free_symbols == 0: + raise ValueError(filldedent(""" + Positional argument `var` not found in the + TransferFunction defined. Specify it manually.""")) + else: + raise ValueError(filldedent(""" + Conflicting values found for positional argument `var` ({}). + Specify it manually.""".format(_free_symbols))) + + _num, _den = expr.as_numer_denom() + if _den == 0 or _num.has(S.ComplexInfinity): + raise ZeroDivisionError("TransferFunction cannot have a zero denominator.") + return cls(_num, _den, var) + + @classmethod + def from_coeff_lists(cls, num_list, den_list, var): + r""" + Creates a new ``TransferFunction`` efficiently from a list of coefficients. + + Parameters + ========== + + num_list : Sequence + Sequence comprising of numerator coefficients. + den_list : Sequence + Sequence comprising of denominator coefficients. + var : Symbol + Complex variable of the Laplace transform used by the + polynomials of the transfer function. + + Raises + ====== + + ZeroDivisionError + When the constructed denominator is zero. + + Examples + ======== + + >>> from sympy.abc import s, p + >>> from sympy.physics.control.lti import TransferFunction + >>> num = [1, 0, 2] + >>> den = [3, 2, 2, 1] + >>> tf = TransferFunction.from_coeff_lists(num, den, s) + >>> tf + TransferFunction(s**2 + 2, 3*s**3 + 2*s**2 + 2*s + 1, s) + + # Create a Transfer Function with more than one variable + >>> tf1 = TransferFunction.from_coeff_lists([p, 1], [2*p, 0, 4], s) + >>> tf1 + TransferFunction(p*s + 1, 2*p*s**2 + 4, s) + + """ + num_list = num_list[::-1] + den_list = den_list[::-1] + num_var_powers = [var**i for i in range(len(num_list))] + den_var_powers = [var**i for i in range(len(den_list))] + + _num = sum(coeff * var_power for coeff, var_power in zip(num_list, num_var_powers)) + _den = sum(coeff * var_power for coeff, var_power in zip(den_list, den_var_powers)) + + if _den == 0: + raise ZeroDivisionError("TransferFunction cannot have a zero denominator.") + + return cls(_num, _den, var) + + @classmethod + def from_zpk(cls, zeros, poles, gain, var): + r""" + Creates a new ``TransferFunction`` from given zeros, poles and gain. + + Parameters + ========== + + zeros : Sequence + Sequence comprising of zeros of transfer function. + poles : Sequence + Sequence comprising of poles of transfer function. + gain : Number, Symbol, Expression + A scalar value specifying gain of the model. + var : Symbol + Complex variable of the Laplace transform used by the + polynomials of the transfer function. + + Examples + ======== + + >>> from sympy.abc import s, p, k + >>> from sympy.physics.control.lti import TransferFunction + >>> zeros = [1, 2, 3] + >>> poles = [6, 5, 4] + >>> gain = 7 + >>> tf = TransferFunction.from_zpk(zeros, poles, gain, s) + >>> tf + TransferFunction(7*(s - 3)*(s - 2)*(s - 1), (s - 6)*(s - 5)*(s - 4), s) + + # Create a Transfer Function with variable poles and zeros + >>> tf1 = TransferFunction.from_zpk([p, k], [p + k, p - k], 2, s) + >>> tf1 + TransferFunction(2*(-k + s)*(-p + s), (-k - p + s)*(k - p + s), s) + + # Complex poles or zeros are acceptable + >>> tf2 = TransferFunction.from_zpk([0], [1-1j, 1+1j, 2], -2, s) + >>> tf2 + TransferFunction(-2*s, (s - 2)*(s - 1.0 - 1.0*I)*(s - 1.0 + 1.0*I), s) + + """ + num_poly = 1 + den_poly = 1 + for zero in zeros: + num_poly *= var - zero + for pole in poles: + den_poly *= var - pole + + return cls(gain*num_poly, den_poly, var) + + @property + def num(self): + """ + Returns the numerator polynomial of the transfer function. + + Examples + ======== + + >>> from sympy.abc import s, p + >>> from sympy.physics.control.lti import TransferFunction + >>> G1 = TransferFunction(s**2 + p*s + 3, s - 4, s) + >>> G1.num + p*s + s**2 + 3 + >>> G2 = TransferFunction((p + 5)*(p - 3), (p - 3)*(p + 1), p) + >>> G2.num + (p - 3)*(p + 5) + + """ + return self.args[0] + + @property + def den(self): + """ + Returns the denominator polynomial of the transfer function. + + Examples + ======== + + >>> from sympy.abc import s, p + >>> from sympy.physics.control.lti import TransferFunction + >>> G1 = TransferFunction(s + 4, p**3 - 2*p + 4, s) + >>> G1.den + p**3 - 2*p + 4 + >>> G2 = TransferFunction(3, 4, s) + >>> G2.den + 4 + + """ + return self.args[1] + + @property + def var(self): + """ + Returns the complex variable of the Laplace transform used by the polynomials of + the transfer function. + + Examples + ======== + + >>> from sympy.abc import s, p + >>> from sympy.physics.control.lti import TransferFunction + >>> G1 = TransferFunction(p**2 + 2*p + 4, p - 6, p) + >>> G1.var + p + >>> G2 = TransferFunction(0, s - 5, s) + >>> G2.var + s + + """ + return self.args[2] + + def _eval_subs(self, old, new): + arg_num = self.num.subs(old, new) + arg_den = self.den.subs(old, new) + argnew = TransferFunction(arg_num, arg_den, self.var) + return self if old == self.var else argnew + + def _eval_evalf(self, prec): + return TransferFunction( + self.num._eval_evalf(prec), + self.den._eval_evalf(prec), + self.var) + + def _eval_simplify(self, **kwargs): + tf = cancel(Mul(self.num, 1/self.den, evaluate=False), expand=False).as_numer_denom() + num_, den_ = tf[0], tf[1] + return TransferFunction(num_, den_, self.var) + + def _eval_rewrite_as_StateSpace(self, *args): + """ + Returns the equivalent space space model of the transfer function model. + The state space model will be returned in the controllable cannonical form. + + Unlike the space state to transfer function model conversion, the transfer function + to state space model conversion is not unique. There can be multiple state space + representations of a given transfer function model. + + Examples + ======== + + >>> from sympy.abc import s + >>> from sympy.physics.control import TransferFunction, StateSpace + >>> tf = TransferFunction(s**2 + 1, s**3 + 2*s + 10, s) + >>> tf.rewrite(StateSpace) + StateSpace(Matrix([ + [ 0, 1, 0], + [ 0, 0, 1], + [-10, -2, 0]]), Matrix([ + [0], + [0], + [1]]), Matrix([[1, 0, 1]]), Matrix([[0]])) + + """ + if not self.is_proper: + raise ValueError("Transfer Function must be proper.") + + num_poly = Poly(self.num, self.var) + den_poly = Poly(self.den, self.var) + n = den_poly.degree() + + num_coeffs = num_poly.all_coeffs() + den_coeffs = den_poly.all_coeffs() + diff = n - num_poly.degree() + num_coeffs = [0]*diff + num_coeffs + + a = den_coeffs[1:] + a_mat = Matrix([[(-1)*coefficient/den_coeffs[0] for coefficient in reversed(a)]]) + vert = zeros(n-1, 1) + mat = eye(n-1) + A = vert.row_join(mat) + A = A.col_join(a_mat) + + B = zeros(n, 1) + B[n-1] = 1 + + i = n + C = [] + while(i > 0): + C.append(num_coeffs[i] - den_coeffs[i]*num_coeffs[0]) + i -= 1 + C = Matrix([C]) + + D = Matrix([num_coeffs[0]]) + + return StateSpace(A, B, C, D) + + def expand(self): + """ + Returns the transfer function with numerator and denominator + in expanded form. + + Examples + ======== + + >>> from sympy.abc import s, p, a, b + >>> from sympy.physics.control.lti import TransferFunction + >>> G1 = TransferFunction((a - s)**2, (s**2 + a)**2, s) + >>> G1.expand() + TransferFunction(a**2 - 2*a*s + s**2, a**2 + 2*a*s**2 + s**4, s) + >>> G2 = TransferFunction((p + 3*b)*(p - b), (p - b)*(p + 2*b), p) + >>> G2.expand() + TransferFunction(-3*b**2 + 2*b*p + p**2, -2*b**2 + b*p + p**2, p) + + """ + return TransferFunction(expand(self.num), expand(self.den), self.var) + + def dc_gain(self): + """ + Computes the gain of the response as the frequency approaches zero. + + The DC gain is infinite for systems with pure integrators. + + Examples + ======== + + >>> from sympy.abc import s, p, a, b + >>> from sympy.physics.control.lti import TransferFunction + >>> tf1 = TransferFunction(s + 3, s**2 - 9, s) + >>> tf1.dc_gain() + -1/3 + >>> tf2 = TransferFunction(p**2, p - 3 + p**3, p) + >>> tf2.dc_gain() + 0 + >>> tf3 = TransferFunction(a*p**2 - b, s + b, s) + >>> tf3.dc_gain() + (a*p**2 - b)/b + >>> tf4 = TransferFunction(1, s, s) + >>> tf4.dc_gain() + oo + + """ + m = Mul(self.num, Pow(self.den, -1, evaluate=False), evaluate=False) + return limit(m, self.var, 0) + + def poles(self): + """ + Returns the poles of a transfer function. + + Examples + ======== + + >>> from sympy.abc import s, p, a + >>> from sympy.physics.control.lti import TransferFunction + >>> tf1 = TransferFunction((p + 3)*(p - 1), (p - 1)*(p + 5), p) + >>> tf1.poles() + [-5, 1] + >>> tf2 = TransferFunction((1 - s)**2, (s**2 + 1)**2, s) + >>> tf2.poles() + [I, I, -I, -I] + >>> tf3 = TransferFunction(s**2, a*s + p, s) + >>> tf3.poles() + [-p/a] + + """ + return _roots(Poly(self.den, self.var), self.var) + + def zeros(self): + """ + Returns the zeros of a transfer function. + + Examples + ======== + + >>> from sympy.abc import s, p, a + >>> from sympy.physics.control.lti import TransferFunction + >>> tf1 = TransferFunction((p + 3)*(p - 1), (p - 1)*(p + 5), p) + >>> tf1.zeros() + [-3, 1] + >>> tf2 = TransferFunction((1 - s)**2, (s**2 + 1)**2, s) + >>> tf2.zeros() + [1, 1] + >>> tf3 = TransferFunction(s**2, a*s + p, s) + >>> tf3.zeros() + [0, 0] + + """ + return _roots(Poly(self.num, self.var), self.var) + + def eval_frequency(self, other): + """ + Returns the system response at any point in the real or complex plane. + + Examples + ======== + + >>> from sympy.abc import s, p, a + >>> from sympy.physics.control.lti import TransferFunction + >>> from sympy import I + >>> tf1 = TransferFunction(1, s**2 + 2*s + 1, s) + >>> omega = 0.1 + >>> tf1.eval_frequency(I*omega) + 1/(0.99 + 0.2*I) + >>> tf2 = TransferFunction(s**2, a*s + p, s) + >>> tf2.eval_frequency(2) + 4/(2*a + p) + >>> tf2.eval_frequency(I*2) + -4/(2*I*a + p) + """ + arg_num = self.num.subs(self.var, other) + arg_den = self.den.subs(self.var, other) + argnew = TransferFunction(arg_num, arg_den, self.var).to_expr() + return argnew.expand() + + def is_stable(self): + """ + Returns True if the transfer function is asymptotically stable; else False. + + This would not check the marginal or conditional stability of the system. + + Examples + ======== + + >>> from sympy.abc import s, p, a + >>> from sympy import symbols + >>> from sympy.physics.control.lti import TransferFunction + >>> q, r = symbols('q, r', negative=True) + >>> tf1 = TransferFunction((1 - s)**2, (s + 1)**2, s) + >>> tf1.is_stable() + True + >>> tf2 = TransferFunction((1 - p)**2, (s**2 + 1)**2, s) + >>> tf2.is_stable() + False + >>> tf3 = TransferFunction(4, q*s - r, s) + >>> tf3.is_stable() + False + >>> tf4 = TransferFunction(p + 1, a*p - s**2, p) + >>> tf4.is_stable() is None # Not enough info about the symbols to determine stability + True + + """ + return fuzzy_and(pole.as_real_imag()[0].is_negative for pole in self.poles()) + + def __add__(self, other): + if isinstance(other, (TransferFunction, Series)): + if not self.var == other.var: + raise ValueError(filldedent(""" + All the transfer functions should use the same complex variable + of the Laplace transform.""")) + return Parallel(self, other) + elif isinstance(other, Parallel): + if not self.var == other.var: + raise ValueError(filldedent(""" + All the transfer functions should use the same complex variable + of the Laplace transform.""")) + arg_list = list(other.args) + return Parallel(self, *arg_list) + else: + raise ValueError("TransferFunction cannot be added with {}.". + format(type(other))) + + def __radd__(self, other): + return self + other + + def __sub__(self, other): + if isinstance(other, (TransferFunction, Series)): + if not self.var == other.var: + raise ValueError(filldedent(""" + All the transfer functions should use the same complex variable + of the Laplace transform.""")) + return Parallel(self, -other) + elif isinstance(other, Parallel): + if not self.var == other.var: + raise ValueError(filldedent(""" + All the transfer functions should use the same complex variable + of the Laplace transform.""")) + arg_list = [-i for i in list(other.args)] + return Parallel(self, *arg_list) + else: + raise ValueError("{} cannot be subtracted from a TransferFunction." + .format(type(other))) + + def __rsub__(self, other): + return -self + other + + def __mul__(self, other): + if isinstance(other, (TransferFunction, Parallel)): + if not self.var == other.var: + raise ValueError(filldedent(""" + All the transfer functions should use the same complex variable + of the Laplace transform.""")) + return Series(self, other) + elif isinstance(other, Series): + if not self.var == other.var: + raise ValueError(filldedent(""" + All the transfer functions should use the same complex variable + of the Laplace transform.""")) + arg_list = list(other.args) + return Series(self, *arg_list) + else: + raise ValueError("TransferFunction cannot be multiplied with {}." + .format(type(other))) + + __rmul__ = __mul__ + + def __truediv__(self, other): + if isinstance(other, TransferFunction): + if not self.var == other.var: + raise ValueError(filldedent(""" + All the transfer functions should use the same complex variable + of the Laplace transform.""")) + return Series(self, TransferFunction(other.den, other.num, self.var)) + elif (isinstance(other, Parallel) and len(other.args + ) == 2 and isinstance(other.args[0], TransferFunction) + and isinstance(other.args[1], (Series, TransferFunction))): + + if not self.var == other.var: + raise ValueError(filldedent(""" + Both TransferFunction and Parallel should use the + same complex variable of the Laplace transform.""")) + if other.args[1] == self: + # plant and controller with unit feedback. + return Feedback(self, other.args[0]) + other_arg_list = list(other.args[1].args) if isinstance( + other.args[1], Series) else other.args[1] + if other_arg_list == other.args[1]: + return Feedback(self, other_arg_list) + elif self in other_arg_list: + other_arg_list.remove(self) + else: + return Feedback(self, Series(*other_arg_list)) + + if len(other_arg_list) == 1: + return Feedback(self, *other_arg_list) + else: + return Feedback(self, Series(*other_arg_list)) + else: + raise ValueError("TransferFunction cannot be divided by {}.". + format(type(other))) + + __rtruediv__ = __truediv__ + + def __pow__(self, p): + p = sympify(p) + if not p.is_Integer: + raise ValueError("Exponent must be an integer.") + if p is S.Zero: + return TransferFunction(1, 1, self.var) + elif p > 0: + num_, den_ = self.num**p, self.den**p + else: + p = abs(p) + num_, den_ = self.den**p, self.num**p + + return TransferFunction(num_, den_, self.var) + + def __neg__(self): + return TransferFunction(-self.num, self.den, self.var) + + @property + def is_proper(self): + """ + Returns True if degree of the numerator polynomial is less than + or equal to degree of the denominator polynomial, else False. + + Examples + ======== + + >>> from sympy.abc import s, p, a, b + >>> from sympy.physics.control.lti import TransferFunction + >>> tf1 = TransferFunction(b*s**2 + p**2 - a*p + s, b - p**2, s) + >>> tf1.is_proper + False + >>> tf2 = TransferFunction(p**2 - 4*p, p**3 + 3*p + 2, p) + >>> tf2.is_proper + True + + """ + return degree(self.num, self.var) <= degree(self.den, self.var) + + @property + def is_strictly_proper(self): + """ + Returns True if degree of the numerator polynomial is strictly less + than degree of the denominator polynomial, else False. + + Examples + ======== + + >>> from sympy.abc import s, p, a, b + >>> from sympy.physics.control.lti import TransferFunction + >>> tf1 = TransferFunction(a*p**2 + b*s, s - p, s) + >>> tf1.is_strictly_proper + False + >>> tf2 = TransferFunction(s**3 - 2, s**4 + 5*s + 6, s) + >>> tf2.is_strictly_proper + True + + """ + return degree(self.num, self.var) < degree(self.den, self.var) + + @property + def is_biproper(self): + """ + Returns True if degree of the numerator polynomial is equal to + degree of the denominator polynomial, else False. + + Examples + ======== + + >>> from sympy.abc import s, p, a, b + >>> from sympy.physics.control.lti import TransferFunction + >>> tf1 = TransferFunction(a*p**2 + b*s, s - p, s) + >>> tf1.is_biproper + True + >>> tf2 = TransferFunction(p**2, p + a, p) + >>> tf2.is_biproper + False + + """ + return degree(self.num, self.var) == degree(self.den, self.var) + + def to_expr(self): + """ + Converts a ``TransferFunction`` object to SymPy Expr. + + Examples + ======== + + >>> from sympy.abc import s, p, a, b + >>> from sympy.physics.control.lti import TransferFunction + >>> from sympy import Expr + >>> tf1 = TransferFunction(s, a*s**2 + 1, s) + >>> tf1.to_expr() + s/(a*s**2 + 1) + >>> isinstance(_, Expr) + True + >>> tf2 = TransferFunction(1, (p + 3*b)*(b - p), p) + >>> tf2.to_expr() + 1/((b - p)*(3*b + p)) + >>> tf3 = TransferFunction((s - 2)*(s - 3), (s - 1)*(s - 2)*(s - 3), s) + >>> tf3.to_expr() + ((s - 3)*(s - 2))/(((s - 3)*(s - 2)*(s - 1))) + + """ + + if self.num != 1: + return Mul(self.num, Pow(self.den, -1, evaluate=False), evaluate=False) + else: + return Pow(self.den, -1, evaluate=False) + + +def _flatten_args(args, _cls): + temp_args = [] + for arg in args: + if isinstance(arg, _cls): + temp_args.extend(arg.args) + else: + temp_args.append(arg) + return tuple(temp_args) + + +def _dummify_args(_arg, var): + dummy_dict = {} + dummy_arg_list = [] + + for arg in _arg: + _s = Dummy() + dummy_dict[_s] = var + dummy_arg = arg.subs({var: _s}) + dummy_arg_list.append(dummy_arg) + + return dummy_arg_list, dummy_dict + + +class Series(SISOLinearTimeInvariant): + r""" + A class for representing a series configuration of SISO systems. + + Parameters + ========== + + args : SISOLinearTimeInvariant + SISO systems in a series configuration. + evaluate : Boolean, Keyword + When passed ``True``, returns the equivalent + ``Series(*args).doit()``. Set to ``False`` by default. + + Raises + ====== + + ValueError + When no argument is passed. + + ``var`` attribute is not same for every system. + TypeError + Any of the passed ``*args`` has unsupported type + + A combination of SISO and MIMO systems is + passed. There should be homogeneity in the + type of systems passed, SISO in this case. + + Examples + ======== + + >>> from sympy.abc import s, p, a, b + >>> from sympy.physics.control.lti import TransferFunction, Series, Parallel + >>> tf1 = TransferFunction(a*p**2 + b*s, s - p, s) + >>> tf2 = TransferFunction(s**3 - 2, s**4 + 5*s + 6, s) + >>> tf3 = TransferFunction(p**2, p + s, s) + >>> S1 = Series(tf1, tf2) + >>> S1 + Series(TransferFunction(a*p**2 + b*s, -p + s, s), TransferFunction(s**3 - 2, s**4 + 5*s + 6, s)) + >>> S1.var + s + >>> S2 = Series(tf2, Parallel(tf3, -tf1)) + >>> S2 + Series(TransferFunction(s**3 - 2, s**4 + 5*s + 6, s), Parallel(TransferFunction(p**2, p + s, s), TransferFunction(-a*p**2 - b*s, -p + s, s))) + >>> S2.var + s + >>> S3 = Series(Parallel(tf1, tf2), Parallel(tf2, tf3)) + >>> S3 + Series(Parallel(TransferFunction(a*p**2 + b*s, -p + s, s), TransferFunction(s**3 - 2, s**4 + 5*s + 6, s)), Parallel(TransferFunction(s**3 - 2, s**4 + 5*s + 6, s), TransferFunction(p**2, p + s, s))) + >>> S3.var + s + + You can get the resultant transfer function by using ``.doit()`` method: + + >>> S3 = Series(tf1, tf2, -tf3) + >>> S3.doit() + TransferFunction(-p**2*(s**3 - 2)*(a*p**2 + b*s), (-p + s)*(p + s)*(s**4 + 5*s + 6), s) + >>> S4 = Series(tf2, Parallel(tf1, -tf3)) + >>> S4.doit() + TransferFunction((s**3 - 2)*(-p**2*(-p + s) + (p + s)*(a*p**2 + b*s)), (-p + s)*(p + s)*(s**4 + 5*s + 6), s) + + Notes + ===== + + All the transfer functions should use the same complex variable + ``var`` of the Laplace transform. + + See Also + ======== + + MIMOSeries, Parallel, TransferFunction, Feedback + + """ + def __new__(cls, *args, evaluate=False): + + args = _flatten_args(args, Series) + cls._check_args(args) + obj = super().__new__(cls, *args) + + return obj.doit() if evaluate else obj + + @property + def var(self): + """ + Returns the complex variable used by all the transfer functions. + + Examples + ======== + + >>> from sympy.abc import p + >>> from sympy.physics.control.lti import TransferFunction, Series, Parallel + >>> G1 = TransferFunction(p**2 + 2*p + 4, p - 6, p) + >>> G2 = TransferFunction(p, 4 - p, p) + >>> G3 = TransferFunction(0, p**4 - 1, p) + >>> Series(G1, G2).var + p + >>> Series(-G3, Parallel(G1, G2)).var + p + + """ + return self.args[0].var + + def doit(self, **hints): + """ + Returns the resultant transfer function obtained after evaluating + the transfer functions in series configuration. + + Examples + ======== + + >>> from sympy.abc import s, p, a, b + >>> from sympy.physics.control.lti import TransferFunction, Series + >>> tf1 = TransferFunction(a*p**2 + b*s, s - p, s) + >>> tf2 = TransferFunction(s**3 - 2, s**4 + 5*s + 6, s) + >>> Series(tf2, tf1).doit() + TransferFunction((s**3 - 2)*(a*p**2 + b*s), (-p + s)*(s**4 + 5*s + 6), s) + >>> Series(-tf1, -tf2).doit() + TransferFunction((2 - s**3)*(-a*p**2 - b*s), (-p + s)*(s**4 + 5*s + 6), s) + + """ + + _num_arg = (arg.doit().num for arg in self.args) + _den_arg = (arg.doit().den for arg in self.args) + res_num = Mul(*_num_arg, evaluate=True) + res_den = Mul(*_den_arg, evaluate=True) + return TransferFunction(res_num, res_den, self.var) + + def _eval_rewrite_as_TransferFunction(self, *args, **kwargs): + return self.doit() + + @_check_other_SISO + def __add__(self, other): + + if isinstance(other, Parallel): + arg_list = list(other.args) + return Parallel(self, *arg_list) + + return Parallel(self, other) + + __radd__ = __add__ + + @_check_other_SISO + def __sub__(self, other): + return self + (-other) + + def __rsub__(self, other): + return -self + other + + @_check_other_SISO + def __mul__(self, other): + + arg_list = list(self.args) + return Series(*arg_list, other) + + def __truediv__(self, other): + if isinstance(other, TransferFunction): + return Series(*self.args, TransferFunction(other.den, other.num, other.var)) + elif isinstance(other, Series): + tf_self = self.rewrite(TransferFunction) + tf_other = other.rewrite(TransferFunction) + return tf_self / tf_other + elif (isinstance(other, Parallel) and len(other.args) == 2 + and isinstance(other.args[0], TransferFunction) and isinstance(other.args[1], Series)): + + if not self.var == other.var: + raise ValueError(filldedent(""" + All the transfer functions should use the same complex variable + of the Laplace transform.""")) + self_arg_list = set(self.args) + other_arg_list = set(other.args[1].args) + res = list(self_arg_list ^ other_arg_list) + if len(res) == 0: + return Feedback(self, other.args[0]) + elif len(res) == 1: + return Feedback(self, *res) + else: + return Feedback(self, Series(*res)) + else: + raise ValueError("This transfer function expression is invalid.") + + def __neg__(self): + return Series(TransferFunction(-1, 1, self.var), self) + + def to_expr(self): + """Returns the equivalent ``Expr`` object.""" + return Mul(*(arg.to_expr() for arg in self.args), evaluate=False) + + @property + def is_proper(self): + """ + Returns True if degree of the numerator polynomial of the resultant transfer + function is less than or equal to degree of the denominator polynomial of + the same, else False. + + Examples + ======== + + >>> from sympy.abc import s, p, a, b + >>> from sympy.physics.control.lti import TransferFunction, Series + >>> tf1 = TransferFunction(b*s**2 + p**2 - a*p + s, b - p**2, s) + >>> tf2 = TransferFunction(p**2 - 4*p, p**3 + 3*s + 2, s) + >>> tf3 = TransferFunction(s, s**2 + s + 1, s) + >>> S1 = Series(-tf2, tf1) + >>> S1.is_proper + False + >>> S2 = Series(tf1, tf2, tf3) + >>> S2.is_proper + True + + """ + return self.doit().is_proper + + @property + def is_strictly_proper(self): + """ + Returns True if degree of the numerator polynomial of the resultant transfer + function is strictly less than degree of the denominator polynomial of + the same, else False. + + Examples + ======== + + >>> from sympy.abc import s, p, a, b + >>> from sympy.physics.control.lti import TransferFunction, Series + >>> tf1 = TransferFunction(a*p**2 + b*s, s - p, s) + >>> tf2 = TransferFunction(s**3 - 2, s**2 + 5*s + 6, s) + >>> tf3 = TransferFunction(1, s**2 + s + 1, s) + >>> S1 = Series(tf1, tf2) + >>> S1.is_strictly_proper + False + >>> S2 = Series(tf1, tf2, tf3) + >>> S2.is_strictly_proper + True + + """ + return self.doit().is_strictly_proper + + @property + def is_biproper(self): + r""" + Returns True if degree of the numerator polynomial of the resultant transfer + function is equal to degree of the denominator polynomial of + the same, else False. + + Examples + ======== + + >>> from sympy.abc import s, p, a, b + >>> from sympy.physics.control.lti import TransferFunction, Series + >>> tf1 = TransferFunction(a*p**2 + b*s, s - p, s) + >>> tf2 = TransferFunction(p, s**2, s) + >>> tf3 = TransferFunction(s**2, 1, s) + >>> S1 = Series(tf1, -tf2) + >>> S1.is_biproper + False + >>> S2 = Series(tf2, tf3) + >>> S2.is_biproper + True + + """ + return self.doit().is_biproper + + +def _mat_mul_compatible(*args): + """To check whether shapes are compatible for matrix mul.""" + return all(args[i].num_outputs == args[i+1].num_inputs for i in range(len(args)-1)) + + +class MIMOSeries(MIMOLinearTimeInvariant): + r""" + A class for representing a series configuration of MIMO systems. + + Parameters + ========== + + args : MIMOLinearTimeInvariant + MIMO systems in a series configuration. + evaluate : Boolean, Keyword + When passed ``True``, returns the equivalent + ``MIMOSeries(*args).doit()``. Set to ``False`` by default. + + Raises + ====== + + ValueError + When no argument is passed. + + ``var`` attribute is not same for every system. + + ``num_outputs`` of the MIMO system is not equal to the + ``num_inputs`` of its adjacent MIMO system. (Matrix + multiplication constraint, basically) + TypeError + Any of the passed ``*args`` has unsupported type + + A combination of SISO and MIMO systems is + passed. There should be homogeneity in the + type of systems passed, MIMO in this case. + + Examples + ======== + + >>> from sympy.abc import s + >>> from sympy.physics.control.lti import MIMOSeries, TransferFunctionMatrix + >>> from sympy import Matrix, pprint + >>> mat_a = Matrix([[5*s], [5]]) # 2 Outputs 1 Input + >>> mat_b = Matrix([[5, 1/(6*s**2)]]) # 1 Output 2 Inputs + >>> mat_c = Matrix([[1, s], [5/s, 1]]) # 2 Outputs 2 Inputs + >>> tfm_a = TransferFunctionMatrix.from_Matrix(mat_a, s) + >>> tfm_b = TransferFunctionMatrix.from_Matrix(mat_b, s) + >>> tfm_c = TransferFunctionMatrix.from_Matrix(mat_c, s) + >>> MIMOSeries(tfm_c, tfm_b, tfm_a) + MIMOSeries(TransferFunctionMatrix(((TransferFunction(1, 1, s), TransferFunction(s, 1, s)), (TransferFunction(5, s, s), TransferFunction(1, 1, s)))), TransferFunctionMatrix(((TransferFunction(5, 1, s), TransferFunction(1, 6*s**2, s)),)), TransferFunctionMatrix(((TransferFunction(5*s, 1, s),), (TransferFunction(5, 1, s),)))) + >>> pprint(_, use_unicode=False) # For Better Visualization + [5*s] [1 s] + [---] [5 1 ] [- -] + [ 1 ] [- ----] [1 1] + [ ] *[1 2] *[ ] + [ 5 ] [ 6*s ]{t} [5 1] + [ - ] [- -] + [ 1 ]{t} [s 1]{t} + >>> MIMOSeries(tfm_c, tfm_b, tfm_a).doit() + TransferFunctionMatrix(((TransferFunction(150*s**4 + 25*s, 6*s**3, s), TransferFunction(150*s**4 + 5*s, 6*s**2, s)), (TransferFunction(150*s**3 + 25, 6*s**3, s), TransferFunction(150*s**3 + 5, 6*s**2, s)))) + >>> pprint(_, use_unicode=False) # (2 Inputs -A-> 2 Outputs) -> (2 Inputs -B-> 1 Output) -> (1 Input -C-> 2 Outputs) is equivalent to (2 Inputs -Series Equivalent-> 2 Outputs). + [ 4 4 ] + [150*s + 25*s 150*s + 5*s] + [------------- ------------] + [ 3 2 ] + [ 6*s 6*s ] + [ ] + [ 3 3 ] + [ 150*s + 25 150*s + 5 ] + [ ----------- ---------- ] + [ 3 2 ] + [ 6*s 6*s ]{t} + + Notes + ===== + + All the transfer function matrices should use the same complex variable ``var`` of the Laplace transform. + + ``MIMOSeries(A, B)`` is not equivalent to ``A*B``. It is always in the reverse order, that is ``B*A``. + + See Also + ======== + + Series, MIMOParallel + + """ + def __new__(cls, *args, evaluate=False): + + cls._check_args(args) + + if _mat_mul_compatible(*args): + obj = super().__new__(cls, *args) + + else: + raise ValueError(filldedent(""" + Number of input signals do not match the number + of output signals of adjacent systems for some args.""")) + + return obj.doit() if evaluate else obj + + @property + def var(self): + """ + Returns the complex variable used by all the transfer functions. + + Examples + ======== + + >>> from sympy.abc import p + >>> from sympy.physics.control.lti import TransferFunction, MIMOSeries, TransferFunctionMatrix + >>> G1 = TransferFunction(p**2 + 2*p + 4, p - 6, p) + >>> G2 = TransferFunction(p, 4 - p, p) + >>> G3 = TransferFunction(0, p**4 - 1, p) + >>> tfm_1 = TransferFunctionMatrix([[G1, G2, G3]]) + >>> tfm_2 = TransferFunctionMatrix([[G1], [G2], [G3]]) + >>> MIMOSeries(tfm_2, tfm_1).var + p + + """ + return self.args[0].var + + @property + def num_inputs(self): + """Returns the number of input signals of the series system.""" + return self.args[0].num_inputs + + @property + def num_outputs(self): + """Returns the number of output signals of the series system.""" + return self.args[-1].num_outputs + + @property + def shape(self): + """Returns the shape of the equivalent MIMO system.""" + return self.num_outputs, self.num_inputs + + def doit(self, cancel=False, **kwargs): + """ + Returns the resultant transfer function matrix obtained after evaluating + the MIMO systems arranged in a series configuration. + + Examples + ======== + + >>> from sympy.abc import s, p, a, b + >>> from sympy.physics.control.lti import TransferFunction, MIMOSeries, TransferFunctionMatrix + >>> tf1 = TransferFunction(a*p**2 + b*s, s - p, s) + >>> tf2 = TransferFunction(s**3 - 2, s**4 + 5*s + 6, s) + >>> tfm1 = TransferFunctionMatrix([[tf1, tf2], [tf2, tf2]]) + >>> tfm2 = TransferFunctionMatrix([[tf2, tf1], [tf1, tf1]]) + >>> MIMOSeries(tfm2, tfm1).doit() + TransferFunctionMatrix(((TransferFunction(2*(-p + s)*(s**3 - 2)*(a*p**2 + b*s)*(s**4 + 5*s + 6), (-p + s)**2*(s**4 + 5*s + 6)**2, s), TransferFunction((-p + s)**2*(s**3 - 2)*(a*p**2 + b*s) + (-p + s)*(a*p**2 + b*s)**2*(s**4 + 5*s + 6), (-p + s)**3*(s**4 + 5*s + 6), s)), (TransferFunction((-p + s)*(s**3 - 2)**2*(s**4 + 5*s + 6) + (s**3 - 2)*(a*p**2 + b*s)*(s**4 + 5*s + 6)**2, (-p + s)*(s**4 + 5*s + 6)**3, s), TransferFunction(2*(s**3 - 2)*(a*p**2 + b*s), (-p + s)*(s**4 + 5*s + 6), s)))) + + """ + _arg = (arg.doit()._expr_mat for arg in reversed(self.args)) + + if cancel: + res = MatMul(*_arg, evaluate=True) + return TransferFunctionMatrix.from_Matrix(res, self.var) + + _dummy_args, _dummy_dict = _dummify_args(_arg, self.var) + res = MatMul(*_dummy_args, evaluate=True) + temp_tfm = TransferFunctionMatrix.from_Matrix(res, self.var) + return temp_tfm.subs(_dummy_dict) + + def _eval_rewrite_as_TransferFunctionMatrix(self, *args, **kwargs): + return self.doit() + + @_check_other_MIMO + def __add__(self, other): + + if isinstance(other, MIMOParallel): + arg_list = list(other.args) + return MIMOParallel(self, *arg_list) + + return MIMOParallel(self, other) + + __radd__ = __add__ + + @_check_other_MIMO + def __sub__(self, other): + return self + (-other) + + def __rsub__(self, other): + return -self + other + + @_check_other_MIMO + def __mul__(self, other): + + if isinstance(other, MIMOSeries): + self_arg_list = list(self.args) + other_arg_list = list(other.args) + return MIMOSeries(*other_arg_list, *self_arg_list) # A*B = MIMOSeries(B, A) + + arg_list = list(self.args) + return MIMOSeries(other, *arg_list) + + def __neg__(self): + arg_list = list(self.args) + arg_list[0] = -arg_list[0] + return MIMOSeries(*arg_list) + + +class Parallel(SISOLinearTimeInvariant): + r""" + A class for representing a parallel configuration of SISO systems. + + Parameters + ========== + + args : SISOLinearTimeInvariant + SISO systems in a parallel arrangement. + evaluate : Boolean, Keyword + When passed ``True``, returns the equivalent + ``Parallel(*args).doit()``. Set to ``False`` by default. + + Raises + ====== + + ValueError + When no argument is passed. + + ``var`` attribute is not same for every system. + TypeError + Any of the passed ``*args`` has unsupported type + + A combination of SISO and MIMO systems is + passed. There should be homogeneity in the + type of systems passed. + + Examples + ======== + + >>> from sympy.abc import s, p, a, b + >>> from sympy.physics.control.lti import TransferFunction, Parallel, Series + >>> tf1 = TransferFunction(a*p**2 + b*s, s - p, s) + >>> tf2 = TransferFunction(s**3 - 2, s**4 + 5*s + 6, s) + >>> tf3 = TransferFunction(p**2, p + s, s) + >>> P1 = Parallel(tf1, tf2) + >>> P1 + Parallel(TransferFunction(a*p**2 + b*s, -p + s, s), TransferFunction(s**3 - 2, s**4 + 5*s + 6, s)) + >>> P1.var + s + >>> P2 = Parallel(tf2, Series(tf3, -tf1)) + >>> P2 + Parallel(TransferFunction(s**3 - 2, s**4 + 5*s + 6, s), Series(TransferFunction(p**2, p + s, s), TransferFunction(-a*p**2 - b*s, -p + s, s))) + >>> P2.var + s + >>> P3 = Parallel(Series(tf1, tf2), Series(tf2, tf3)) + >>> P3 + Parallel(Series(TransferFunction(a*p**2 + b*s, -p + s, s), TransferFunction(s**3 - 2, s**4 + 5*s + 6, s)), Series(TransferFunction(s**3 - 2, s**4 + 5*s + 6, s), TransferFunction(p**2, p + s, s))) + >>> P3.var + s + + You can get the resultant transfer function by using ``.doit()`` method: + + >>> Parallel(tf1, tf2, -tf3).doit() + TransferFunction(-p**2*(-p + s)*(s**4 + 5*s + 6) + (-p + s)*(p + s)*(s**3 - 2) + (p + s)*(a*p**2 + b*s)*(s**4 + 5*s + 6), (-p + s)*(p + s)*(s**4 + 5*s + 6), s) + >>> Parallel(tf2, Series(tf1, -tf3)).doit() + TransferFunction(-p**2*(a*p**2 + b*s)*(s**4 + 5*s + 6) + (-p + s)*(p + s)*(s**3 - 2), (-p + s)*(p + s)*(s**4 + 5*s + 6), s) + + Notes + ===== + + All the transfer functions should use the same complex variable + ``var`` of the Laplace transform. + + See Also + ======== + + Series, TransferFunction, Feedback + + """ + def __new__(cls, *args, evaluate=False): + + args = _flatten_args(args, Parallel) + cls._check_args(args) + obj = super().__new__(cls, *args) + + return obj.doit() if evaluate else obj + + @property + def var(self): + """ + Returns the complex variable used by all the transfer functions. + + Examples + ======== + + >>> from sympy.abc import p + >>> from sympy.physics.control.lti import TransferFunction, Parallel, Series + >>> G1 = TransferFunction(p**2 + 2*p + 4, p - 6, p) + >>> G2 = TransferFunction(p, 4 - p, p) + >>> G3 = TransferFunction(0, p**4 - 1, p) + >>> Parallel(G1, G2).var + p + >>> Parallel(-G3, Series(G1, G2)).var + p + + """ + return self.args[0].var + + def doit(self, **hints): + """ + Returns the resultant transfer function obtained after evaluating + the transfer functions in parallel configuration. + + Examples + ======== + + >>> from sympy.abc import s, p, a, b + >>> from sympy.physics.control.lti import TransferFunction, Parallel + >>> tf1 = TransferFunction(a*p**2 + b*s, s - p, s) + >>> tf2 = TransferFunction(s**3 - 2, s**4 + 5*s + 6, s) + >>> Parallel(tf2, tf1).doit() + TransferFunction((-p + s)*(s**3 - 2) + (a*p**2 + b*s)*(s**4 + 5*s + 6), (-p + s)*(s**4 + 5*s + 6), s) + >>> Parallel(-tf1, -tf2).doit() + TransferFunction((2 - s**3)*(-p + s) + (-a*p**2 - b*s)*(s**4 + 5*s + 6), (-p + s)*(s**4 + 5*s + 6), s) + + """ + + _arg = (arg.doit().to_expr() for arg in self.args) + res = Add(*_arg).as_numer_denom() + return TransferFunction(*res, self.var) + + def _eval_rewrite_as_TransferFunction(self, *args, **kwargs): + return self.doit() + + @_check_other_SISO + def __add__(self, other): + + self_arg_list = list(self.args) + return Parallel(*self_arg_list, other) + + __radd__ = __add__ + + @_check_other_SISO + def __sub__(self, other): + return self + (-other) + + def __rsub__(self, other): + return -self + other + + @_check_other_SISO + def __mul__(self, other): + + if isinstance(other, Series): + arg_list = list(other.args) + return Series(self, *arg_list) + + return Series(self, other) + + def __neg__(self): + return Series(TransferFunction(-1, 1, self.var), self) + + def to_expr(self): + """Returns the equivalent ``Expr`` object.""" + return Add(*(arg.to_expr() for arg in self.args), evaluate=False) + + @property + def is_proper(self): + """ + Returns True if degree of the numerator polynomial of the resultant transfer + function is less than or equal to degree of the denominator polynomial of + the same, else False. + + Examples + ======== + + >>> from sympy.abc import s, p, a, b + >>> from sympy.physics.control.lti import TransferFunction, Parallel + >>> tf1 = TransferFunction(b*s**2 + p**2 - a*p + s, b - p**2, s) + >>> tf2 = TransferFunction(p**2 - 4*p, p**3 + 3*s + 2, s) + >>> tf3 = TransferFunction(s, s**2 + s + 1, s) + >>> P1 = Parallel(-tf2, tf1) + >>> P1.is_proper + False + >>> P2 = Parallel(tf2, tf3) + >>> P2.is_proper + True + + """ + return self.doit().is_proper + + @property + def is_strictly_proper(self): + """ + Returns True if degree of the numerator polynomial of the resultant transfer + function is strictly less than degree of the denominator polynomial of + the same, else False. + + Examples + ======== + + >>> from sympy.abc import s, p, a, b + >>> from sympy.physics.control.lti import TransferFunction, Parallel + >>> tf1 = TransferFunction(a*p**2 + b*s, s - p, s) + >>> tf2 = TransferFunction(s**3 - 2, s**4 + 5*s + 6, s) + >>> tf3 = TransferFunction(s, s**2 + s + 1, s) + >>> P1 = Parallel(tf1, tf2) + >>> P1.is_strictly_proper + False + >>> P2 = Parallel(tf2, tf3) + >>> P2.is_strictly_proper + True + + """ + return self.doit().is_strictly_proper + + @property + def is_biproper(self): + """ + Returns True if degree of the numerator polynomial of the resultant transfer + function is equal to degree of the denominator polynomial of + the same, else False. + + Examples + ======== + + >>> from sympy.abc import s, p, a, b + >>> from sympy.physics.control.lti import TransferFunction, Parallel + >>> tf1 = TransferFunction(a*p**2 + b*s, s - p, s) + >>> tf2 = TransferFunction(p**2, p + s, s) + >>> tf3 = TransferFunction(s, s**2 + s + 1, s) + >>> P1 = Parallel(tf1, -tf2) + >>> P1.is_biproper + True + >>> P2 = Parallel(tf2, tf3) + >>> P2.is_biproper + False + + """ + return self.doit().is_biproper + + +class MIMOParallel(MIMOLinearTimeInvariant): + r""" + A class for representing a parallel configuration of MIMO systems. + + Parameters + ========== + + args : MIMOLinearTimeInvariant + MIMO Systems in a parallel arrangement. + evaluate : Boolean, Keyword + When passed ``True``, returns the equivalent + ``MIMOParallel(*args).doit()``. Set to ``False`` by default. + + Raises + ====== + + ValueError + When no argument is passed. + + ``var`` attribute is not same for every system. + + All MIMO systems passed do not have same shape. + TypeError + Any of the passed ``*args`` has unsupported type + + A combination of SISO and MIMO systems is + passed. There should be homogeneity in the + type of systems passed, MIMO in this case. + + Examples + ======== + + >>> from sympy.abc import s + >>> from sympy.physics.control.lti import TransferFunctionMatrix, MIMOParallel + >>> from sympy import Matrix, pprint + >>> expr_1 = 1/s + >>> expr_2 = s/(s**2-1) + >>> expr_3 = (2 + s)/(s**2 - 1) + >>> expr_4 = 5 + >>> tfm_a = TransferFunctionMatrix.from_Matrix(Matrix([[expr_1, expr_2], [expr_3, expr_4]]), s) + >>> tfm_b = TransferFunctionMatrix.from_Matrix(Matrix([[expr_2, expr_1], [expr_4, expr_3]]), s) + >>> tfm_c = TransferFunctionMatrix.from_Matrix(Matrix([[expr_3, expr_4], [expr_1, expr_2]]), s) + >>> MIMOParallel(tfm_a, tfm_b, tfm_c) + MIMOParallel(TransferFunctionMatrix(((TransferFunction(1, s, s), TransferFunction(s, s**2 - 1, s)), (TransferFunction(s + 2, s**2 - 1, s), TransferFunction(5, 1, s)))), TransferFunctionMatrix(((TransferFunction(s, s**2 - 1, s), TransferFunction(1, s, s)), (TransferFunction(5, 1, s), TransferFunction(s + 2, s**2 - 1, s)))), TransferFunctionMatrix(((TransferFunction(s + 2, s**2 - 1, s), TransferFunction(5, 1, s)), (TransferFunction(1, s, s), TransferFunction(s, s**2 - 1, s))))) + >>> pprint(_, use_unicode=False) # For Better Visualization + [ 1 s ] [ s 1 ] [s + 2 5 ] + [ - ------] [------ - ] [------ - ] + [ s 2 ] [ 2 s ] [ 2 1 ] + [ s - 1] [s - 1 ] [s - 1 ] + [ ] + [ ] + [ ] + [s + 2 5 ] [ 5 s + 2 ] [ 1 s ] + [------ - ] [ - ------] [ - ------] + [ 2 1 ] [ 1 2 ] [ s 2 ] + [s - 1 ]{t} [ s - 1]{t} [ s - 1]{t} + >>> MIMOParallel(tfm_a, tfm_b, tfm_c).doit() + TransferFunctionMatrix(((TransferFunction(s**2 + s*(2*s + 2) - 1, s*(s**2 - 1), s), TransferFunction(2*s**2 + 5*s*(s**2 - 1) - 1, s*(s**2 - 1), s)), (TransferFunction(s**2 + s*(s + 2) + 5*s*(s**2 - 1) - 1, s*(s**2 - 1), s), TransferFunction(5*s**2 + 2*s - 3, s**2 - 1, s)))) + >>> pprint(_, use_unicode=False) + [ 2 2 / 2 \ ] + [ s + s*(2*s + 2) - 1 2*s + 5*s*\s - 1/ - 1] + [ -------------------- -----------------------] + [ / 2 \ / 2 \ ] + [ s*\s - 1/ s*\s - 1/ ] + [ ] + [ 2 / 2 \ 2 ] + [s + s*(s + 2) + 5*s*\s - 1/ - 1 5*s + 2*s - 3 ] + [--------------------------------- -------------- ] + [ / 2 \ 2 ] + [ s*\s - 1/ s - 1 ]{t} + + Notes + ===== + + All the transfer function matrices should use the same complex variable + ``var`` of the Laplace transform. + + See Also + ======== + + Parallel, MIMOSeries + + """ + def __new__(cls, *args, evaluate=False): + + args = _flatten_args(args, MIMOParallel) + + cls._check_args(args) + + if any(arg.shape != args[0].shape for arg in args): + raise TypeError("Shape of all the args is not equal.") + + obj = super().__new__(cls, *args) + + return obj.doit() if evaluate else obj + + @property + def var(self): + """ + Returns the complex variable used by all the systems. + + Examples + ======== + + >>> from sympy.abc import p + >>> from sympy.physics.control.lti import TransferFunction, TransferFunctionMatrix, MIMOParallel + >>> G1 = TransferFunction(p**2 + 2*p + 4, p - 6, p) + >>> G2 = TransferFunction(p, 4 - p, p) + >>> G3 = TransferFunction(0, p**4 - 1, p) + >>> G4 = TransferFunction(p**2, p**2 - 1, p) + >>> tfm_a = TransferFunctionMatrix([[G1, G2], [G3, G4]]) + >>> tfm_b = TransferFunctionMatrix([[G2, G1], [G4, G3]]) + >>> MIMOParallel(tfm_a, tfm_b).var + p + + """ + return self.args[0].var + + @property + def num_inputs(self): + """Returns the number of input signals of the parallel system.""" + return self.args[0].num_inputs + + @property + def num_outputs(self): + """Returns the number of output signals of the parallel system.""" + return self.args[0].num_outputs + + @property + def shape(self): + """Returns the shape of the equivalent MIMO system.""" + return self.num_outputs, self.num_inputs + + def doit(self, **hints): + """ + Returns the resultant transfer function matrix obtained after evaluating + the MIMO systems arranged in a parallel configuration. + + Examples + ======== + + >>> from sympy.abc import s, p, a, b + >>> from sympy.physics.control.lti import TransferFunction, MIMOParallel, TransferFunctionMatrix + >>> tf1 = TransferFunction(a*p**2 + b*s, s - p, s) + >>> tf2 = TransferFunction(s**3 - 2, s**4 + 5*s + 6, s) + >>> tfm_1 = TransferFunctionMatrix([[tf1, tf2], [tf2, tf1]]) + >>> tfm_2 = TransferFunctionMatrix([[tf2, tf1], [tf1, tf2]]) + >>> MIMOParallel(tfm_1, tfm_2).doit() + TransferFunctionMatrix(((TransferFunction((-p + s)*(s**3 - 2) + (a*p**2 + b*s)*(s**4 + 5*s + 6), (-p + s)*(s**4 + 5*s + 6), s), TransferFunction((-p + s)*(s**3 - 2) + (a*p**2 + b*s)*(s**4 + 5*s + 6), (-p + s)*(s**4 + 5*s + 6), s)), (TransferFunction((-p + s)*(s**3 - 2) + (a*p**2 + b*s)*(s**4 + 5*s + 6), (-p + s)*(s**4 + 5*s + 6), s), TransferFunction((-p + s)*(s**3 - 2) + (a*p**2 + b*s)*(s**4 + 5*s + 6), (-p + s)*(s**4 + 5*s + 6), s)))) + + """ + _arg = (arg.doit()._expr_mat for arg in self.args) + res = MatAdd(*_arg, evaluate=True) + return TransferFunctionMatrix.from_Matrix(res, self.var) + + def _eval_rewrite_as_TransferFunctionMatrix(self, *args, **kwargs): + return self.doit() + + @_check_other_MIMO + def __add__(self, other): + + self_arg_list = list(self.args) + return MIMOParallel(*self_arg_list, other) + + __radd__ = __add__ + + @_check_other_MIMO + def __sub__(self, other): + return self + (-other) + + def __rsub__(self, other): + return -self + other + + @_check_other_MIMO + def __mul__(self, other): + + if isinstance(other, MIMOSeries): + arg_list = list(other.args) + return MIMOSeries(*arg_list, self) + + return MIMOSeries(other, self) + + def __neg__(self): + arg_list = [-arg for arg in list(self.args)] + return MIMOParallel(*arg_list) + + +class Feedback(TransferFunction): + r""" + A class for representing closed-loop feedback interconnection between two + SISO input/output systems. + + The first argument, ``sys1``, is the feedforward part of the closed-loop + system or in simple words, the dynamical model representing the process + to be controlled. The second argument, ``sys2``, is the feedback system + and controls the fed back signal to ``sys1``. Both ``sys1`` and ``sys2`` + can either be ``Series`` or ``TransferFunction`` objects. + + Parameters + ========== + + sys1 : Series, TransferFunction + The feedforward path system. + sys2 : Series, TransferFunction, optional + The feedback path system (often a feedback controller). + It is the model sitting on the feedback path. + + If not specified explicitly, the sys2 is + assumed to be unit (1.0) transfer function. + sign : int, optional + The sign of feedback. Can either be ``1`` + (for positive feedback) or ``-1`` (for negative feedback). + Default value is `-1`. + + Raises + ====== + + ValueError + When ``sys1`` and ``sys2`` are not using the + same complex variable of the Laplace transform. + + When a combination of ``sys1`` and ``sys2`` yields + zero denominator. + + TypeError + When either ``sys1`` or ``sys2`` is not a ``Series`` or a + ``TransferFunction`` object. + + Examples + ======== + + >>> from sympy.abc import s + >>> from sympy.physics.control.lti import TransferFunction, Feedback + >>> plant = TransferFunction(3*s**2 + 7*s - 3, s**2 - 4*s + 2, s) + >>> controller = TransferFunction(5*s - 10, s + 7, s) + >>> F1 = Feedback(plant, controller) + >>> F1 + Feedback(TransferFunction(3*s**2 + 7*s - 3, s**2 - 4*s + 2, s), TransferFunction(5*s - 10, s + 7, s), -1) + >>> F1.var + s + >>> F1.args + (TransferFunction(3*s**2 + 7*s - 3, s**2 - 4*s + 2, s), TransferFunction(5*s - 10, s + 7, s), -1) + + You can get the feedforward and feedback path systems by using ``.sys1`` and ``.sys2`` respectively. + + >>> F1.sys1 + TransferFunction(3*s**2 + 7*s - 3, s**2 - 4*s + 2, s) + >>> F1.sys2 + TransferFunction(5*s - 10, s + 7, s) + + You can get the resultant closed loop transfer function obtained by negative feedback + interconnection using ``.doit()`` method. + + >>> F1.doit() + TransferFunction((s + 7)*(s**2 - 4*s + 2)*(3*s**2 + 7*s - 3), ((s + 7)*(s**2 - 4*s + 2) + (5*s - 10)*(3*s**2 + 7*s - 3))*(s**2 - 4*s + 2), s) + >>> G = TransferFunction(2*s**2 + 5*s + 1, s**2 + 2*s + 3, s) + >>> C = TransferFunction(5*s + 10, s + 10, s) + >>> F2 = Feedback(G*C, TransferFunction(1, 1, s)) + >>> F2.doit() + TransferFunction((s + 10)*(5*s + 10)*(s**2 + 2*s + 3)*(2*s**2 + 5*s + 1), (s + 10)*((s + 10)*(s**2 + 2*s + 3) + (5*s + 10)*(2*s**2 + 5*s + 1))*(s**2 + 2*s + 3), s) + + To negate a ``Feedback`` object, the ``-`` operator can be prepended: + + >>> -F1 + Feedback(TransferFunction(-3*s**2 - 7*s + 3, s**2 - 4*s + 2, s), TransferFunction(10 - 5*s, s + 7, s), -1) + >>> -F2 + Feedback(Series(TransferFunction(-1, 1, s), TransferFunction(2*s**2 + 5*s + 1, s**2 + 2*s + 3, s), TransferFunction(5*s + 10, s + 10, s)), TransferFunction(-1, 1, s), -1) + + See Also + ======== + + MIMOFeedback, Series, Parallel + + """ + def __new__(cls, sys1, sys2=None, sign=-1): + if not sys2: + sys2 = TransferFunction(1, 1, sys1.var) + + if not (isinstance(sys1, (TransferFunction, Series, Feedback)) + and isinstance(sys2, (TransferFunction, Series, Feedback))): + raise TypeError("Unsupported type for `sys1` or `sys2` of Feedback.") + + if sign not in [-1, 1]: + raise ValueError(filldedent(""" + Unsupported type for feedback. `sign` arg should + either be 1 (positive feedback loop) or -1 + (negative feedback loop).""")) + + if Mul(sys1.to_expr(), sys2.to_expr()).simplify() == sign: + raise ValueError("The equivalent system will have zero denominator.") + + if sys1.var != sys2.var: + raise ValueError(filldedent(""" + Both `sys1` and `sys2` should be using the + same complex variable.""")) + + return super(TransferFunction, cls).__new__(cls, sys1, sys2, _sympify(sign)) + + @property + def sys1(self): + """ + Returns the feedforward system of the feedback interconnection. + + Examples + ======== + + >>> from sympy.abc import s, p + >>> from sympy.physics.control.lti import TransferFunction, Feedback + >>> plant = TransferFunction(3*s**2 + 7*s - 3, s**2 - 4*s + 2, s) + >>> controller = TransferFunction(5*s - 10, s + 7, s) + >>> F1 = Feedback(plant, controller) + >>> F1.sys1 + TransferFunction(3*s**2 + 7*s - 3, s**2 - 4*s + 2, s) + >>> G = TransferFunction(2*s**2 + 5*s + 1, p**2 + 2*p + 3, p) + >>> C = TransferFunction(5*p + 10, p + 10, p) + >>> P = TransferFunction(1 - s, p + 2, p) + >>> F2 = Feedback(TransferFunction(1, 1, p), G*C*P) + >>> F2.sys1 + TransferFunction(1, 1, p) + + """ + return self.args[0] + + @property + def sys2(self): + """ + Returns the feedback controller of the feedback interconnection. + + Examples + ======== + + >>> from sympy.abc import s, p + >>> from sympy.physics.control.lti import TransferFunction, Feedback + >>> plant = TransferFunction(3*s**2 + 7*s - 3, s**2 - 4*s + 2, s) + >>> controller = TransferFunction(5*s - 10, s + 7, s) + >>> F1 = Feedback(plant, controller) + >>> F1.sys2 + TransferFunction(5*s - 10, s + 7, s) + >>> G = TransferFunction(2*s**2 + 5*s + 1, p**2 + 2*p + 3, p) + >>> C = TransferFunction(5*p + 10, p + 10, p) + >>> P = TransferFunction(1 - s, p + 2, p) + >>> F2 = Feedback(TransferFunction(1, 1, p), G*C*P) + >>> F2.sys2 + Series(TransferFunction(2*s**2 + 5*s + 1, p**2 + 2*p + 3, p), TransferFunction(5*p + 10, p + 10, p), TransferFunction(1 - s, p + 2, p)) + + """ + return self.args[1] + + @property + def var(self): + """ + Returns the complex variable of the Laplace transform used by all + the transfer functions involved in the feedback interconnection. + + Examples + ======== + + >>> from sympy.abc import s, p + >>> from sympy.physics.control.lti import TransferFunction, Feedback + >>> plant = TransferFunction(3*s**2 + 7*s - 3, s**2 - 4*s + 2, s) + >>> controller = TransferFunction(5*s - 10, s + 7, s) + >>> F1 = Feedback(plant, controller) + >>> F1.var + s + >>> G = TransferFunction(2*s**2 + 5*s + 1, p**2 + 2*p + 3, p) + >>> C = TransferFunction(5*p + 10, p + 10, p) + >>> P = TransferFunction(1 - s, p + 2, p) + >>> F2 = Feedback(TransferFunction(1, 1, p), G*C*P) + >>> F2.var + p + + """ + return self.sys1.var + + @property + def sign(self): + """ + Returns the type of MIMO Feedback model. ``1`` + for Positive and ``-1`` for Negative. + """ + return self.args[2] + + @property + def num(self): + """ + Returns the numerator of the closed loop feedback system. + """ + return self.sys1 + + @property + def den(self): + """ + Returns the denominator of the closed loop feedback model. + """ + unit = TransferFunction(1, 1, self.var) + arg_list = list(self.sys1.args) if isinstance(self.sys1, Series) else [self.sys1] + if self.sign == 1: + return Parallel(unit, -Series(self.sys2, *arg_list)) + return Parallel(unit, Series(self.sys2, *arg_list)) + + @property + def sensitivity(self): + """ + Returns the sensitivity function of the feedback loop. + + Sensitivity of a Feedback system is the ratio + of change in the open loop gain to the change in + the closed loop gain. + + .. note:: + This method would not return the complementary + sensitivity function. + + Examples + ======== + + >>> from sympy.abc import p + >>> from sympy.physics.control.lti import TransferFunction, Feedback + >>> C = TransferFunction(5*p + 10, p + 10, p) + >>> P = TransferFunction(1 - p, p + 2, p) + >>> F_1 = Feedback(P, C) + >>> F_1.sensitivity + 1/((1 - p)*(5*p + 10)/((p + 2)*(p + 10)) + 1) + + """ + + return 1/(1 - self.sign*self.sys1.to_expr()*self.sys2.to_expr()) + + def doit(self, cancel=False, expand=False, **hints): + """ + Returns the resultant transfer function obtained by the + feedback interconnection. + + Examples + ======== + + >>> from sympy.abc import s + >>> from sympy.physics.control.lti import TransferFunction, Feedback + >>> plant = TransferFunction(3*s**2 + 7*s - 3, s**2 - 4*s + 2, s) + >>> controller = TransferFunction(5*s - 10, s + 7, s) + >>> F1 = Feedback(plant, controller) + >>> F1.doit() + TransferFunction((s + 7)*(s**2 - 4*s + 2)*(3*s**2 + 7*s - 3), ((s + 7)*(s**2 - 4*s + 2) + (5*s - 10)*(3*s**2 + 7*s - 3))*(s**2 - 4*s + 2), s) + >>> G = TransferFunction(2*s**2 + 5*s + 1, s**2 + 2*s + 3, s) + >>> F2 = Feedback(G, TransferFunction(1, 1, s)) + >>> F2.doit() + TransferFunction((s**2 + 2*s + 3)*(2*s**2 + 5*s + 1), (s**2 + 2*s + 3)*(3*s**2 + 7*s + 4), s) + + Use kwarg ``expand=True`` to expand the resultant transfer function. + Use ``cancel=True`` to cancel out the common terms in numerator and + denominator. + + >>> F2.doit(cancel=True, expand=True) + TransferFunction(2*s**2 + 5*s + 1, 3*s**2 + 7*s + 4, s) + >>> F2.doit(expand=True) + TransferFunction(2*s**4 + 9*s**3 + 17*s**2 + 17*s + 3, 3*s**4 + 13*s**3 + 27*s**2 + 29*s + 12, s) + + """ + arg_list = list(self.sys1.args) if isinstance(self.sys1, Series) else [self.sys1] + # F_n and F_d are resultant TFs of num and den of Feedback. + F_n, unit = self.sys1.doit(), TransferFunction(1, 1, self.sys1.var) + if self.sign == -1: + F_d = Parallel(unit, Series(self.sys2, *arg_list)).doit() + else: + F_d = Parallel(unit, -Series(self.sys2, *arg_list)).doit() + + _resultant_tf = TransferFunction(F_n.num * F_d.den, F_n.den * F_d.num, F_n.var) + + if cancel: + _resultant_tf = _resultant_tf.simplify() + + if expand: + _resultant_tf = _resultant_tf.expand() + + return _resultant_tf + + def _eval_rewrite_as_TransferFunction(self, num, den, sign, **kwargs): + return self.doit() + + def to_expr(self): + """ + Converts a ``Feedback`` object to SymPy Expr. + + Examples + ======== + + >>> from sympy.abc import s, a, b + >>> from sympy.physics.control.lti import TransferFunction, Feedback + >>> from sympy import Expr + >>> tf1 = TransferFunction(a+s, 1, s) + >>> tf2 = TransferFunction(b+s, 1, s) + >>> fd1 = Feedback(tf1, tf2) + >>> fd1.to_expr() + (a + s)/((a + s)*(b + s) + 1) + >>> isinstance(_, Expr) + True + """ + + return self.doit().to_expr() + + def __neg__(self): + return Feedback(-self.sys1, -self.sys2, self.sign) + + +def _is_invertible(a, b, sign): + """ + Checks whether a given pair of MIMO + systems passed is invertible or not. + """ + _mat = eye(a.num_outputs) - sign*(a.doit()._expr_mat)*(b.doit()._expr_mat) + _det = _mat.det() + + return _det != 0 + + +class MIMOFeedback(MIMOLinearTimeInvariant): + r""" + A class for representing closed-loop feedback interconnection between two + MIMO input/output systems. + + Parameters + ========== + + sys1 : MIMOSeries, TransferFunctionMatrix + The MIMO system placed on the feedforward path. + sys2 : MIMOSeries, TransferFunctionMatrix + The system placed on the feedback path + (often a feedback controller). + sign : int, optional + The sign of feedback. Can either be ``1`` + (for positive feedback) or ``-1`` (for negative feedback). + Default value is `-1`. + + Raises + ====== + + ValueError + When ``sys1`` and ``sys2`` are not using the + same complex variable of the Laplace transform. + + Forward path model should have an equal number of inputs/outputs + to the feedback path outputs/inputs. + + When product of ``sys1`` and ``sys2`` is not a square matrix. + + When the equivalent MIMO system is not invertible. + + TypeError + When either ``sys1`` or ``sys2`` is not a ``MIMOSeries`` or a + ``TransferFunctionMatrix`` object. + + Examples + ======== + + >>> from sympy import Matrix, pprint + >>> from sympy.abc import s + >>> from sympy.physics.control.lti import TransferFunctionMatrix, MIMOFeedback + >>> plant_mat = Matrix([[1, 1/s], [0, 1]]) + >>> controller_mat = Matrix([[10, 0], [0, 10]]) # Constant Gain + >>> plant = TransferFunctionMatrix.from_Matrix(plant_mat, s) + >>> controller = TransferFunctionMatrix.from_Matrix(controller_mat, s) + >>> feedback = MIMOFeedback(plant, controller) # Negative Feedback (default) + >>> pprint(feedback, use_unicode=False) + / [1 1] [10 0 ] \-1 [1 1] + | [- -] [-- - ] | [- -] + | [1 s] [1 1 ] | [1 s] + |I + [ ] *[ ] | * [ ] + | [0 1] [0 10] | [0 1] + | [- -] [- --] | [- -] + \ [1 1]{t} [1 1 ]{t}/ [1 1]{t} + + To get the equivalent system matrix, use either ``doit`` or ``rewrite`` method. + + >>> pprint(feedback.doit(), use_unicode=False) + [1 1 ] + [-- -----] + [11 121*s] + [ ] + [0 1 ] + [- -- ] + [1 11 ]{t} + + To negate the ``MIMOFeedback`` object, use ``-`` operator. + + >>> neg_feedback = -feedback + >>> pprint(neg_feedback.doit(), use_unicode=False) + [-1 -1 ] + [--- -----] + [11 121*s] + [ ] + [ 0 -1 ] + [ - --- ] + [ 1 11 ]{t} + + See Also + ======== + + Feedback, MIMOSeries, MIMOParallel + + """ + def __new__(cls, sys1, sys2, sign=-1): + if not (isinstance(sys1, (TransferFunctionMatrix, MIMOSeries)) + and isinstance(sys2, (TransferFunctionMatrix, MIMOSeries))): + raise TypeError("Unsupported type for `sys1` or `sys2` of MIMO Feedback.") + + if sys1.num_inputs != sys2.num_outputs or \ + sys1.num_outputs != sys2.num_inputs: + raise ValueError(filldedent(""" + Product of `sys1` and `sys2` must + yield a square matrix.""")) + + if sign not in (-1, 1): + raise ValueError(filldedent(""" + Unsupported type for feedback. `sign` arg should + either be 1 (positive feedback loop) or -1 + (negative feedback loop).""")) + + if not _is_invertible(sys1, sys2, sign): + raise ValueError("Non-Invertible system inputted.") + if sys1.var != sys2.var: + raise ValueError(filldedent(""" + Both `sys1` and `sys2` should be using the + same complex variable.""")) + + return super().__new__(cls, sys1, sys2, _sympify(sign)) + + @property + def sys1(self): + r""" + Returns the system placed on the feedforward path of the MIMO feedback interconnection. + + Examples + ======== + + >>> from sympy import pprint + >>> from sympy.abc import s + >>> from sympy.physics.control.lti import TransferFunction, TransferFunctionMatrix, MIMOFeedback + >>> tf1 = TransferFunction(s**2 + s + 1, s**2 - s + 1, s) + >>> tf2 = TransferFunction(1, s, s) + >>> tf3 = TransferFunction(1, 1, s) + >>> sys1 = TransferFunctionMatrix([[tf1, tf2], [tf2, tf1]]) + >>> sys2 = TransferFunctionMatrix([[tf3, tf3], [tf3, tf2]]) + >>> F_1 = MIMOFeedback(sys1, sys2, 1) + >>> F_1.sys1 + TransferFunctionMatrix(((TransferFunction(s**2 + s + 1, s**2 - s + 1, s), TransferFunction(1, s, s)), (TransferFunction(1, s, s), TransferFunction(s**2 + s + 1, s**2 - s + 1, s)))) + >>> pprint(_, use_unicode=False) + [ 2 ] + [s + s + 1 1 ] + [---------- - ] + [ 2 s ] + [s - s + 1 ] + [ ] + [ 2 ] + [ 1 s + s + 1] + [ - ----------] + [ s 2 ] + [ s - s + 1]{t} + + """ + return self.args[0] + + @property + def sys2(self): + r""" + Returns the feedback controller of the MIMO feedback interconnection. + + Examples + ======== + + >>> from sympy import pprint + >>> from sympy.abc import s + >>> from sympy.physics.control.lti import TransferFunction, TransferFunctionMatrix, MIMOFeedback + >>> tf1 = TransferFunction(s**2, s**3 - s + 1, s) + >>> tf2 = TransferFunction(1, s, s) + >>> tf3 = TransferFunction(1, 1, s) + >>> sys1 = TransferFunctionMatrix([[tf1, tf2], [tf2, tf1]]) + >>> sys2 = TransferFunctionMatrix([[tf1, tf3], [tf3, tf2]]) + >>> F_1 = MIMOFeedback(sys1, sys2) + >>> F_1.sys2 + TransferFunctionMatrix(((TransferFunction(s**2, s**3 - s + 1, s), TransferFunction(1, 1, s)), (TransferFunction(1, 1, s), TransferFunction(1, s, s)))) + >>> pprint(_, use_unicode=False) + [ 2 ] + [ s 1] + [---------- -] + [ 3 1] + [s - s + 1 ] + [ ] + [ 1 1] + [ - -] + [ 1 s]{t} + + """ + return self.args[1] + + @property + def var(self): + r""" + Returns the complex variable of the Laplace transform used by all + the transfer functions involved in the MIMO feedback loop. + + Examples + ======== + + >>> from sympy.abc import p + >>> from sympy.physics.control.lti import TransferFunction, TransferFunctionMatrix, MIMOFeedback + >>> tf1 = TransferFunction(p, 1 - p, p) + >>> tf2 = TransferFunction(1, p, p) + >>> tf3 = TransferFunction(1, 1, p) + >>> sys1 = TransferFunctionMatrix([[tf1, tf2], [tf2, tf1]]) + >>> sys2 = TransferFunctionMatrix([[tf1, tf3], [tf3, tf2]]) + >>> F_1 = MIMOFeedback(sys1, sys2, 1) # Positive feedback + >>> F_1.var + p + + """ + return self.sys1.var + + @property + def sign(self): + r""" + Returns the type of feedback interconnection of two models. ``1`` + for Positive and ``-1`` for Negative. + """ + return self.args[2] + + @property + def sensitivity(self): + r""" + Returns the sensitivity function matrix of the feedback loop. + + Sensitivity of a closed-loop system is the ratio of change + in the open loop gain to the change in the closed loop gain. + + .. note:: + This method would not return the complementary + sensitivity function. + + Examples + ======== + + >>> from sympy import pprint + >>> from sympy.abc import p + >>> from sympy.physics.control.lti import TransferFunction, TransferFunctionMatrix, MIMOFeedback + >>> tf1 = TransferFunction(p, 1 - p, p) + >>> tf2 = TransferFunction(1, p, p) + >>> tf3 = TransferFunction(1, 1, p) + >>> sys1 = TransferFunctionMatrix([[tf1, tf2], [tf2, tf1]]) + >>> sys2 = TransferFunctionMatrix([[tf1, tf3], [tf3, tf2]]) + >>> F_1 = MIMOFeedback(sys1, sys2, 1) # Positive feedback + >>> F_2 = MIMOFeedback(sys1, sys2) # Negative feedback + >>> pprint(F_1.sensitivity, use_unicode=False) + [ 4 3 2 5 4 2 ] + [- p + 3*p - 4*p + 3*p - 1 p - 2*p + 3*p - 3*p + 1 ] + [---------------------------- -----------------------------] + [ 4 3 2 5 4 3 2 ] + [ p + 3*p - 8*p + 8*p - 3 p + 3*p - 8*p + 8*p - 3*p] + [ ] + [ 4 3 2 3 2 ] + [ p - p - p + p 3*p - 6*p + 4*p - 1 ] + [ -------------------------- -------------------------- ] + [ 4 3 2 4 3 2 ] + [ p + 3*p - 8*p + 8*p - 3 p + 3*p - 8*p + 8*p - 3 ] + >>> pprint(F_2.sensitivity, use_unicode=False) + [ 4 3 2 5 4 2 ] + [p - 3*p + 2*p + p - 1 p - 2*p + 3*p - 3*p + 1] + [------------------------ --------------------------] + [ 4 3 5 4 2 ] + [ p - 3*p + 2*p - 1 p - 3*p + 2*p - p ] + [ ] + [ 4 3 2 4 3 ] + [ p - p - p + p 2*p - 3*p + 2*p - 1 ] + [ ------------------- --------------------- ] + [ 4 3 4 3 ] + [ p - 3*p + 2*p - 1 p - 3*p + 2*p - 1 ] + + """ + _sys1_mat = self.sys1.doit()._expr_mat + _sys2_mat = self.sys2.doit()._expr_mat + + return (eye(self.sys1.num_inputs) - \ + self.sign*_sys1_mat*_sys2_mat).inv() + + def doit(self, cancel=True, expand=False, **hints): + r""" + Returns the resultant transfer function matrix obtained by the + feedback interconnection. + + Examples + ======== + + >>> from sympy import pprint + >>> from sympy.abc import s + >>> from sympy.physics.control.lti import TransferFunction, TransferFunctionMatrix, MIMOFeedback + >>> tf1 = TransferFunction(s, 1 - s, s) + >>> tf2 = TransferFunction(1, s, s) + >>> tf3 = TransferFunction(5, 1, s) + >>> tf4 = TransferFunction(s - 1, s, s) + >>> tf5 = TransferFunction(0, 1, s) + >>> sys1 = TransferFunctionMatrix([[tf1, tf2], [tf3, tf4]]) + >>> sys2 = TransferFunctionMatrix([[tf3, tf5], [tf5, tf5]]) + >>> F_1 = MIMOFeedback(sys1, sys2, 1) + >>> pprint(F_1, use_unicode=False) + / [ s 1 ] [5 0] \-1 [ s 1 ] + | [----- - ] [- -] | [----- - ] + | [1 - s s ] [1 1] | [1 - s s ] + |I - [ ] *[ ] | * [ ] + | [ 5 s - 1] [0 0] | [ 5 s - 1] + | [ - -----] [- -] | [ - -----] + \ [ 1 s ]{t} [1 1]{t}/ [ 1 s ]{t} + >>> pprint(F_1.doit(), use_unicode=False) + [ -s s - 1 ] + [------- ----------- ] + [6*s - 1 s*(6*s - 1) ] + [ ] + [5*s - 5 (s - 1)*(6*s + 24)] + [------- ------------------] + [6*s - 1 s*(6*s - 1) ]{t} + + If the user wants the resultant ``TransferFunctionMatrix`` object without + canceling the common factors then the ``cancel`` kwarg should be passed ``False``. + + >>> pprint(F_1.doit(cancel=False), use_unicode=False) + [ s*(s - 1) s - 1 ] + [ ----------------- ----------- ] + [ (1 - s)*(6*s - 1) s*(6*s - 1) ] + [ ] + [s*(25*s - 25) + 5*(1 - s)*(6*s - 1) s*(s - 1)*(6*s - 1) + s*(25*s - 25)] + [----------------------------------- -----------------------------------] + [ (1 - s)*(6*s - 1) 2 ] + [ s *(6*s - 1) ]{t} + + If the user wants the expanded form of the resultant transfer function matrix, + the ``expand`` kwarg should be passed as ``True``. + + >>> pprint(F_1.doit(expand=True), use_unicode=False) + [ -s s - 1 ] + [------- -------- ] + [6*s - 1 2 ] + [ 6*s - s ] + [ ] + [ 2 ] + [5*s - 5 6*s + 18*s - 24] + [------- ----------------] + [6*s - 1 2 ] + [ 6*s - s ]{t} + + """ + _mat = self.sensitivity * self.sys1.doit()._expr_mat + + _resultant_tfm = _to_TFM(_mat, self.var) + + if cancel: + _resultant_tfm = _resultant_tfm.simplify() + + if expand: + _resultant_tfm = _resultant_tfm.expand() + + return _resultant_tfm + + def _eval_rewrite_as_TransferFunctionMatrix(self, sys1, sys2, sign, **kwargs): + return self.doit() + + def __neg__(self): + return MIMOFeedback(-self.sys1, -self.sys2, self.sign) + + +def _to_TFM(mat, var): + """Private method to convert ImmutableMatrix to TransferFunctionMatrix efficiently""" + to_tf = lambda expr: TransferFunction.from_rational_expression(expr, var) + arg = [[to_tf(expr) for expr in row] for row in mat.tolist()] + return TransferFunctionMatrix(arg) + + +class TransferFunctionMatrix(MIMOLinearTimeInvariant): + r""" + A class for representing the MIMO (multiple-input and multiple-output) + generalization of the SISO (single-input and single-output) transfer function. + + It is a matrix of transfer functions (``TransferFunction``, SISO-``Series`` or SISO-``Parallel``). + There is only one argument, ``arg`` which is also the compulsory argument. + ``arg`` is expected to be strictly of the type list of lists + which holds the transfer functions or reducible to transfer functions. + + Parameters + ========== + + arg : Nested ``List`` (strictly). + Users are expected to input a nested list of ``TransferFunction``, ``Series`` + and/or ``Parallel`` objects. + + Examples + ======== + + .. note:: + ``pprint()`` can be used for better visualization of ``TransferFunctionMatrix`` objects. + + >>> from sympy.abc import s, p, a + >>> from sympy import pprint + >>> from sympy.physics.control.lti import TransferFunction, TransferFunctionMatrix, Series, Parallel + >>> tf_1 = TransferFunction(s + a, s**2 + s + 1, s) + >>> tf_2 = TransferFunction(p**4 - 3*p + 2, s + p, s) + >>> tf_3 = TransferFunction(3, s + 2, s) + >>> tf_4 = TransferFunction(-a + p, 9*s - 9, s) + >>> tfm_1 = TransferFunctionMatrix([[tf_1], [tf_2], [tf_3]]) + >>> tfm_1 + TransferFunctionMatrix(((TransferFunction(a + s, s**2 + s + 1, s),), (TransferFunction(p**4 - 3*p + 2, p + s, s),), (TransferFunction(3, s + 2, s),))) + >>> tfm_1.var + s + >>> tfm_1.num_inputs + 1 + >>> tfm_1.num_outputs + 3 + >>> tfm_1.shape + (3, 1) + >>> tfm_1.args + (((TransferFunction(a + s, s**2 + s + 1, s),), (TransferFunction(p**4 - 3*p + 2, p + s, s),), (TransferFunction(3, s + 2, s),)),) + >>> tfm_2 = TransferFunctionMatrix([[tf_1, -tf_3], [tf_2, -tf_1], [tf_3, -tf_2]]) + >>> tfm_2 + TransferFunctionMatrix(((TransferFunction(a + s, s**2 + s + 1, s), TransferFunction(-3, s + 2, s)), (TransferFunction(p**4 - 3*p + 2, p + s, s), TransferFunction(-a - s, s**2 + s + 1, s)), (TransferFunction(3, s + 2, s), TransferFunction(-p**4 + 3*p - 2, p + s, s)))) + >>> pprint(tfm_2, use_unicode=False) # pretty-printing for better visualization + [ a + s -3 ] + [ ---------- ----- ] + [ 2 s + 2 ] + [ s + s + 1 ] + [ ] + [ 4 ] + [p - 3*p + 2 -a - s ] + [------------ ---------- ] + [ p + s 2 ] + [ s + s + 1 ] + [ ] + [ 4 ] + [ 3 - p + 3*p - 2] + [ ----- --------------] + [ s + 2 p + s ]{t} + + TransferFunctionMatrix can be transposed, if user wants to switch the input and output transfer functions + + >>> tfm_2.transpose() + TransferFunctionMatrix(((TransferFunction(a + s, s**2 + s + 1, s), TransferFunction(p**4 - 3*p + 2, p + s, s), TransferFunction(3, s + 2, s)), (TransferFunction(-3, s + 2, s), TransferFunction(-a - s, s**2 + s + 1, s), TransferFunction(-p**4 + 3*p - 2, p + s, s)))) + >>> pprint(_, use_unicode=False) + [ 4 ] + [ a + s p - 3*p + 2 3 ] + [---------- ------------ ----- ] + [ 2 p + s s + 2 ] + [s + s + 1 ] + [ ] + [ 4 ] + [ -3 -a - s - p + 3*p - 2] + [ ----- ---------- --------------] + [ s + 2 2 p + s ] + [ s + s + 1 ]{t} + + >>> tf_5 = TransferFunction(5, s, s) + >>> tf_6 = TransferFunction(5*s, (2 + s**2), s) + >>> tf_7 = TransferFunction(5, (s*(2 + s**2)), s) + >>> tf_8 = TransferFunction(5, 1, s) + >>> tfm_3 = TransferFunctionMatrix([[tf_5, tf_6], [tf_7, tf_8]]) + >>> tfm_3 + TransferFunctionMatrix(((TransferFunction(5, s, s), TransferFunction(5*s, s**2 + 2, s)), (TransferFunction(5, s*(s**2 + 2), s), TransferFunction(5, 1, s)))) + >>> pprint(tfm_3, use_unicode=False) + [ 5 5*s ] + [ - ------] + [ s 2 ] + [ s + 2] + [ ] + [ 5 5 ] + [---------- - ] + [ / 2 \ 1 ] + [s*\s + 2/ ]{t} + >>> tfm_3.var + s + >>> tfm_3.shape + (2, 2) + >>> tfm_3.num_outputs + 2 + >>> tfm_3.num_inputs + 2 + >>> tfm_3.args + (((TransferFunction(5, s, s), TransferFunction(5*s, s**2 + 2, s)), (TransferFunction(5, s*(s**2 + 2), s), TransferFunction(5, 1, s))),) + + To access the ``TransferFunction`` at any index in the ``TransferFunctionMatrix``, use the index notation. + + >>> tfm_3[1, 0] # gives the TransferFunction present at 2nd Row and 1st Col. Similar to that in Matrix classes + TransferFunction(5, s*(s**2 + 2), s) + >>> tfm_3[0, 0] # gives the TransferFunction present at 1st Row and 1st Col. + TransferFunction(5, s, s) + >>> tfm_3[:, 0] # gives the first column + TransferFunctionMatrix(((TransferFunction(5, s, s),), (TransferFunction(5, s*(s**2 + 2), s),))) + >>> pprint(_, use_unicode=False) + [ 5 ] + [ - ] + [ s ] + [ ] + [ 5 ] + [----------] + [ / 2 \] + [s*\s + 2/]{t} + >>> tfm_3[0, :] # gives the first row + TransferFunctionMatrix(((TransferFunction(5, s, s), TransferFunction(5*s, s**2 + 2, s)),)) + >>> pprint(_, use_unicode=False) + [5 5*s ] + [- ------] + [s 2 ] + [ s + 2]{t} + + To negate a transfer function matrix, ``-`` operator can be prepended: + + >>> tfm_4 = TransferFunctionMatrix([[tf_2], [-tf_1], [tf_3]]) + >>> -tfm_4 + TransferFunctionMatrix(((TransferFunction(-p**4 + 3*p - 2, p + s, s),), (TransferFunction(a + s, s**2 + s + 1, s),), (TransferFunction(-3, s + 2, s),))) + >>> tfm_5 = TransferFunctionMatrix([[tf_1, tf_2], [tf_3, -tf_1]]) + >>> -tfm_5 + TransferFunctionMatrix(((TransferFunction(-a - s, s**2 + s + 1, s), TransferFunction(-p**4 + 3*p - 2, p + s, s)), (TransferFunction(-3, s + 2, s), TransferFunction(a + s, s**2 + s + 1, s)))) + + ``subs()`` returns the ``TransferFunctionMatrix`` object with the value substituted in the expression. This will not + mutate your original ``TransferFunctionMatrix``. + + >>> tfm_2.subs(p, 2) # substituting p everywhere in tfm_2 with 2. + TransferFunctionMatrix(((TransferFunction(a + s, s**2 + s + 1, s), TransferFunction(-3, s + 2, s)), (TransferFunction(12, s + 2, s), TransferFunction(-a - s, s**2 + s + 1, s)), (TransferFunction(3, s + 2, s), TransferFunction(-12, s + 2, s)))) + >>> pprint(_, use_unicode=False) + [ a + s -3 ] + [---------- ----- ] + [ 2 s + 2 ] + [s + s + 1 ] + [ ] + [ 12 -a - s ] + [ ----- ----------] + [ s + 2 2 ] + [ s + s + 1] + [ ] + [ 3 -12 ] + [ ----- ----- ] + [ s + 2 s + 2 ]{t} + >>> pprint(tfm_2, use_unicode=False) # State of tfm_2 is unchanged after substitution + [ a + s -3 ] + [ ---------- ----- ] + [ 2 s + 2 ] + [ s + s + 1 ] + [ ] + [ 4 ] + [p - 3*p + 2 -a - s ] + [------------ ---------- ] + [ p + s 2 ] + [ s + s + 1 ] + [ ] + [ 4 ] + [ 3 - p + 3*p - 2] + [ ----- --------------] + [ s + 2 p + s ]{t} + + ``subs()`` also supports multiple substitutions. + + >>> tfm_2.subs({p: 2, a: 1}) # substituting p with 2 and a with 1 + TransferFunctionMatrix(((TransferFunction(s + 1, s**2 + s + 1, s), TransferFunction(-3, s + 2, s)), (TransferFunction(12, s + 2, s), TransferFunction(-s - 1, s**2 + s + 1, s)), (TransferFunction(3, s + 2, s), TransferFunction(-12, s + 2, s)))) + >>> pprint(_, use_unicode=False) + [ s + 1 -3 ] + [---------- ----- ] + [ 2 s + 2 ] + [s + s + 1 ] + [ ] + [ 12 -s - 1 ] + [ ----- ----------] + [ s + 2 2 ] + [ s + s + 1] + [ ] + [ 3 -12 ] + [ ----- ----- ] + [ s + 2 s + 2 ]{t} + + Users can reduce the ``Series`` and ``Parallel`` elements of the matrix to ``TransferFunction`` by using + ``doit()``. + + >>> tfm_6 = TransferFunctionMatrix([[Series(tf_3, tf_4), Parallel(tf_3, tf_4)]]) + >>> tfm_6 + TransferFunctionMatrix(((Series(TransferFunction(3, s + 2, s), TransferFunction(-a + p, 9*s - 9, s)), Parallel(TransferFunction(3, s + 2, s), TransferFunction(-a + p, 9*s - 9, s))),)) + >>> pprint(tfm_6, use_unicode=False) + [-a + p 3 -a + p 3 ] + [-------*----- ------- + -----] + [9*s - 9 s + 2 9*s - 9 s + 2]{t} + >>> tfm_6.doit() + TransferFunctionMatrix(((TransferFunction(-3*a + 3*p, (s + 2)*(9*s - 9), s), TransferFunction(27*s + (-a + p)*(s + 2) - 27, (s + 2)*(9*s - 9), s)),)) + >>> pprint(_, use_unicode=False) + [ -3*a + 3*p 27*s + (-a + p)*(s + 2) - 27] + [----------------- ----------------------------] + [(s + 2)*(9*s - 9) (s + 2)*(9*s - 9) ]{t} + >>> tf_9 = TransferFunction(1, s, s) + >>> tf_10 = TransferFunction(1, s**2, s) + >>> tfm_7 = TransferFunctionMatrix([[Series(tf_9, tf_10), tf_9], [tf_10, Parallel(tf_9, tf_10)]]) + >>> tfm_7 + TransferFunctionMatrix(((Series(TransferFunction(1, s, s), TransferFunction(1, s**2, s)), TransferFunction(1, s, s)), (TransferFunction(1, s**2, s), Parallel(TransferFunction(1, s, s), TransferFunction(1, s**2, s))))) + >>> pprint(tfm_7, use_unicode=False) + [ 1 1 ] + [---- - ] + [ 2 s ] + [s*s ] + [ ] + [ 1 1 1] + [ -- -- + -] + [ 2 2 s] + [ s s ]{t} + >>> tfm_7.doit() + TransferFunctionMatrix(((TransferFunction(1, s**3, s), TransferFunction(1, s, s)), (TransferFunction(1, s**2, s), TransferFunction(s**2 + s, s**3, s)))) + >>> pprint(_, use_unicode=False) + [1 1 ] + [-- - ] + [ 3 s ] + [s ] + [ ] + [ 2 ] + [1 s + s] + [-- ------] + [ 2 3 ] + [s s ]{t} + + Addition, subtraction, and multiplication of transfer function matrices can form + unevaluated ``Series`` or ``Parallel`` objects. + + - For addition and subtraction: + All the transfer function matrices must have the same shape. + + - For multiplication (C = A * B): + The number of inputs of the first transfer function matrix (A) must be equal to the + number of outputs of the second transfer function matrix (B). + + Also, use pretty-printing (``pprint``) to analyse better. + + >>> tfm_8 = TransferFunctionMatrix([[tf_3], [tf_2], [-tf_1]]) + >>> tfm_9 = TransferFunctionMatrix([[-tf_3]]) + >>> tfm_10 = TransferFunctionMatrix([[tf_1], [tf_2], [tf_4]]) + >>> tfm_11 = TransferFunctionMatrix([[tf_4], [-tf_1]]) + >>> tfm_12 = TransferFunctionMatrix([[tf_4, -tf_1, tf_3], [-tf_2, -tf_4, -tf_3]]) + >>> tfm_8 + tfm_10 + MIMOParallel(TransferFunctionMatrix(((TransferFunction(3, s + 2, s),), (TransferFunction(p**4 - 3*p + 2, p + s, s),), (TransferFunction(-a - s, s**2 + s + 1, s),))), TransferFunctionMatrix(((TransferFunction(a + s, s**2 + s + 1, s),), (TransferFunction(p**4 - 3*p + 2, p + s, s),), (TransferFunction(-a + p, 9*s - 9, s),)))) + >>> pprint(_, use_unicode=False) + [ 3 ] [ a + s ] + [ ----- ] [ ---------- ] + [ s + 2 ] [ 2 ] + [ ] [ s + s + 1 ] + [ 4 ] [ ] + [p - 3*p + 2] [ 4 ] + [------------] + [p - 3*p + 2] + [ p + s ] [------------] + [ ] [ p + s ] + [ -a - s ] [ ] + [ ---------- ] [ -a + p ] + [ 2 ] [ ------- ] + [ s + s + 1 ]{t} [ 9*s - 9 ]{t} + >>> -tfm_10 - tfm_8 + MIMOParallel(TransferFunctionMatrix(((TransferFunction(-a - s, s**2 + s + 1, s),), (TransferFunction(-p**4 + 3*p - 2, p + s, s),), (TransferFunction(a - p, 9*s - 9, s),))), TransferFunctionMatrix(((TransferFunction(-3, s + 2, s),), (TransferFunction(-p**4 + 3*p - 2, p + s, s),), (TransferFunction(a + s, s**2 + s + 1, s),)))) + >>> pprint(_, use_unicode=False) + [ -a - s ] [ -3 ] + [ ---------- ] [ ----- ] + [ 2 ] [ s + 2 ] + [ s + s + 1 ] [ ] + [ ] [ 4 ] + [ 4 ] [- p + 3*p - 2] + [- p + 3*p - 2] + [--------------] + [--------------] [ p + s ] + [ p + s ] [ ] + [ ] [ a + s ] + [ a - p ] [ ---------- ] + [ ------- ] [ 2 ] + [ 9*s - 9 ]{t} [ s + s + 1 ]{t} + >>> tfm_12 * tfm_8 + MIMOSeries(TransferFunctionMatrix(((TransferFunction(3, s + 2, s),), (TransferFunction(p**4 - 3*p + 2, p + s, s),), (TransferFunction(-a - s, s**2 + s + 1, s),))), TransferFunctionMatrix(((TransferFunction(-a + p, 9*s - 9, s), TransferFunction(-a - s, s**2 + s + 1, s), TransferFunction(3, s + 2, s)), (TransferFunction(-p**4 + 3*p - 2, p + s, s), TransferFunction(a - p, 9*s - 9, s), TransferFunction(-3, s + 2, s))))) + >>> pprint(_, use_unicode=False) + [ 3 ] + [ ----- ] + [ -a + p -a - s 3 ] [ s + 2 ] + [ ------- ---------- -----] [ ] + [ 9*s - 9 2 s + 2] [ 4 ] + [ s + s + 1 ] [p - 3*p + 2] + [ ] *[------------] + [ 4 ] [ p + s ] + [- p + 3*p - 2 a - p -3 ] [ ] + [-------------- ------- -----] [ -a - s ] + [ p + s 9*s - 9 s + 2]{t} [ ---------- ] + [ 2 ] + [ s + s + 1 ]{t} + >>> tfm_12 * tfm_8 * tfm_9 + MIMOSeries(TransferFunctionMatrix(((TransferFunction(-3, s + 2, s),),)), TransferFunctionMatrix(((TransferFunction(3, s + 2, s),), (TransferFunction(p**4 - 3*p + 2, p + s, s),), (TransferFunction(-a - s, s**2 + s + 1, s),))), TransferFunctionMatrix(((TransferFunction(-a + p, 9*s - 9, s), TransferFunction(-a - s, s**2 + s + 1, s), TransferFunction(3, s + 2, s)), (TransferFunction(-p**4 + 3*p - 2, p + s, s), TransferFunction(a - p, 9*s - 9, s), TransferFunction(-3, s + 2, s))))) + >>> pprint(_, use_unicode=False) + [ 3 ] + [ ----- ] + [ -a + p -a - s 3 ] [ s + 2 ] + [ ------- ---------- -----] [ ] + [ 9*s - 9 2 s + 2] [ 4 ] + [ s + s + 1 ] [p - 3*p + 2] [ -3 ] + [ ] *[------------] *[-----] + [ 4 ] [ p + s ] [s + 2]{t} + [- p + 3*p - 2 a - p -3 ] [ ] + [-------------- ------- -----] [ -a - s ] + [ p + s 9*s - 9 s + 2]{t} [ ---------- ] + [ 2 ] + [ s + s + 1 ]{t} + >>> tfm_10 + tfm_8*tfm_9 + MIMOParallel(TransferFunctionMatrix(((TransferFunction(a + s, s**2 + s + 1, s),), (TransferFunction(p**4 - 3*p + 2, p + s, s),), (TransferFunction(-a + p, 9*s - 9, s),))), MIMOSeries(TransferFunctionMatrix(((TransferFunction(-3, s + 2, s),),)), TransferFunctionMatrix(((TransferFunction(3, s + 2, s),), (TransferFunction(p**4 - 3*p + 2, p + s, s),), (TransferFunction(-a - s, s**2 + s + 1, s),))))) + >>> pprint(_, use_unicode=False) + [ a + s ] [ 3 ] + [ ---------- ] [ ----- ] + [ 2 ] [ s + 2 ] + [ s + s + 1 ] [ ] + [ ] [ 4 ] + [ 4 ] [p - 3*p + 2] [ -3 ] + [p - 3*p + 2] + [------------] *[-----] + [------------] [ p + s ] [s + 2]{t} + [ p + s ] [ ] + [ ] [ -a - s ] + [ -a + p ] [ ---------- ] + [ ------- ] [ 2 ] + [ 9*s - 9 ]{t} [ s + s + 1 ]{t} + + These unevaluated ``Series`` or ``Parallel`` objects can convert into the + resultant transfer function matrix using ``.doit()`` method or by + ``.rewrite(TransferFunctionMatrix)``. + + >>> (-tfm_8 + tfm_10 + tfm_8*tfm_9).doit() + TransferFunctionMatrix(((TransferFunction((a + s)*(s + 2)**3 - 3*(s + 2)**2*(s**2 + s + 1) - 9*(s + 2)*(s**2 + s + 1), (s + 2)**3*(s**2 + s + 1), s),), (TransferFunction((p + s)*(-3*p**4 + 9*p - 6), (p + s)**2*(s + 2), s),), (TransferFunction((-a + p)*(s + 2)*(s**2 + s + 1)**2 + (a + s)*(s + 2)*(9*s - 9)*(s**2 + s + 1) + (3*a + 3*s)*(9*s - 9)*(s**2 + s + 1), (s + 2)*(9*s - 9)*(s**2 + s + 1)**2, s),))) + >>> (-tfm_12 * -tfm_8 * -tfm_9).rewrite(TransferFunctionMatrix) + TransferFunctionMatrix(((TransferFunction(3*(-3*a + 3*p)*(p + s)*(s + 2)*(s**2 + s + 1)**2 + 3*(-3*a - 3*s)*(p + s)*(s + 2)*(9*s - 9)*(s**2 + s + 1) + 3*(a + s)*(s + 2)**2*(9*s - 9)*(-p**4 + 3*p - 2)*(s**2 + s + 1), (p + s)*(s + 2)**3*(9*s - 9)*(s**2 + s + 1)**2, s),), (TransferFunction(3*(-a + p)*(p + s)*(s + 2)**2*(-p**4 + 3*p - 2)*(s**2 + s + 1) + 3*(3*a + 3*s)*(p + s)**2*(s + 2)*(9*s - 9) + 3*(p + s)*(s + 2)*(9*s - 9)*(-3*p**4 + 9*p - 6)*(s**2 + s + 1), (p + s)**2*(s + 2)**3*(9*s - 9)*(s**2 + s + 1), s),))) + + See Also + ======== + + TransferFunction, MIMOSeries, MIMOParallel, Feedback + + """ + def __new__(cls, arg): + + expr_mat_arg = [] + try: + var = arg[0][0].var + except TypeError: + raise ValueError(filldedent(""" + `arg` param in TransferFunctionMatrix should + strictly be a nested list containing TransferFunction + objects.""")) + for row in arg: + temp = [] + for element in row: + if not isinstance(element, SISOLinearTimeInvariant): + raise TypeError(filldedent(""" + Each element is expected to be of + type `SISOLinearTimeInvariant`.""")) + + if var != element.var: + raise ValueError(filldedent(""" + Conflicting value(s) found for `var`. All TransferFunction + instances in TransferFunctionMatrix should use the same + complex variable in Laplace domain.""")) + + temp.append(element.to_expr()) + expr_mat_arg.append(temp) + + if isinstance(arg, (tuple, list, Tuple)): + # Making nested Tuple (sympy.core.containers.Tuple) from nested list or nested Python tuple + arg = Tuple(*(Tuple(*r, sympify=False) for r in arg), sympify=False) + + obj = super(TransferFunctionMatrix, cls).__new__(cls, arg) + obj._expr_mat = ImmutableMatrix(expr_mat_arg) + return obj + + @classmethod + def from_Matrix(cls, matrix, var): + """ + Creates a new ``TransferFunctionMatrix`` efficiently from a SymPy Matrix of ``Expr`` objects. + + Parameters + ========== + + matrix : ``ImmutableMatrix`` having ``Expr``/``Number`` elements. + var : Symbol + Complex variable of the Laplace transform which will be used by the + all the ``TransferFunction`` objects in the ``TransferFunctionMatrix``. + + Examples + ======== + + >>> from sympy.abc import s + >>> from sympy.physics.control.lti import TransferFunctionMatrix + >>> from sympy import Matrix, pprint + >>> M = Matrix([[s, 1/s], [1/(s+1), s]]) + >>> M_tf = TransferFunctionMatrix.from_Matrix(M, s) + >>> pprint(M_tf, use_unicode=False) + [ s 1] + [ - -] + [ 1 s] + [ ] + [ 1 s] + [----- -] + [s + 1 1]{t} + >>> M_tf.elem_poles() + [[[], [0]], [[-1], []]] + >>> M_tf.elem_zeros() + [[[0], []], [[], [0]]] + + """ + return _to_TFM(matrix, var) + + @property + def var(self): + """ + Returns the complex variable used by all the transfer functions or + ``Series``/``Parallel`` objects in a transfer function matrix. + + Examples + ======== + + >>> from sympy.abc import p, s + >>> from sympy.physics.control.lti import TransferFunction, TransferFunctionMatrix, Series, Parallel + >>> G1 = TransferFunction(p**2 + 2*p + 4, p - 6, p) + >>> G2 = TransferFunction(p, 4 - p, p) + >>> G3 = TransferFunction(0, p**4 - 1, p) + >>> G4 = TransferFunction(s + 1, s**2 + s + 1, s) + >>> S1 = Series(G1, G2) + >>> S2 = Series(-G3, Parallel(G2, -G1)) + >>> tfm1 = TransferFunctionMatrix([[G1], [G2], [G3]]) + >>> tfm1.var + p + >>> tfm2 = TransferFunctionMatrix([[-S1, -S2], [S1, S2]]) + >>> tfm2.var + p + >>> tfm3 = TransferFunctionMatrix([[G4]]) + >>> tfm3.var + s + + """ + return self.args[0][0][0].var + + @property + def num_inputs(self): + """ + Returns the number of inputs of the system. + + Examples + ======== + + >>> from sympy.abc import s, p + >>> from sympy.physics.control.lti import TransferFunction, TransferFunctionMatrix + >>> G1 = TransferFunction(s + 3, s**2 - 3, s) + >>> G2 = TransferFunction(4, s**2, s) + >>> G3 = TransferFunction(p**2 + s**2, p - 3, s) + >>> tfm_1 = TransferFunctionMatrix([[G2, -G1, G3], [-G2, -G1, -G3]]) + >>> tfm_1.num_inputs + 3 + + See Also + ======== + + num_outputs + + """ + return self._expr_mat.shape[1] + + @property + def num_outputs(self): + """ + Returns the number of outputs of the system. + + Examples + ======== + + >>> from sympy.abc import s + >>> from sympy.physics.control.lti import TransferFunctionMatrix + >>> from sympy import Matrix + >>> M_1 = Matrix([[s], [1/s]]) + >>> TFM = TransferFunctionMatrix.from_Matrix(M_1, s) + >>> print(TFM) + TransferFunctionMatrix(((TransferFunction(s, 1, s),), (TransferFunction(1, s, s),))) + >>> TFM.num_outputs + 2 + + See Also + ======== + + num_inputs + + """ + return self._expr_mat.shape[0] + + @property + def shape(self): + """ + Returns the shape of the transfer function matrix, that is, ``(# of outputs, # of inputs)``. + + Examples + ======== + + >>> from sympy.abc import s, p + >>> from sympy.physics.control.lti import TransferFunction, TransferFunctionMatrix + >>> tf1 = TransferFunction(p**2 - 1, s**4 + s**3 - p, p) + >>> tf2 = TransferFunction(1 - p, p**2 - 3*p + 7, p) + >>> tf3 = TransferFunction(3, 4, p) + >>> tfm1 = TransferFunctionMatrix([[tf1, -tf2]]) + >>> tfm1.shape + (1, 2) + >>> tfm2 = TransferFunctionMatrix([[-tf2, tf3], [tf1, -tf1]]) + >>> tfm2.shape + (2, 2) + + """ + return self._expr_mat.shape + + def __neg__(self): + neg = -self._expr_mat + return _to_TFM(neg, self.var) + + @_check_other_MIMO + def __add__(self, other): + + if not isinstance(other, MIMOParallel): + return MIMOParallel(self, other) + other_arg_list = list(other.args) + return MIMOParallel(self, *other_arg_list) + + @_check_other_MIMO + def __sub__(self, other): + return self + (-other) + + @_check_other_MIMO + def __mul__(self, other): + + if not isinstance(other, MIMOSeries): + return MIMOSeries(other, self) + other_arg_list = list(other.args) + return MIMOSeries(*other_arg_list, self) + + def __getitem__(self, key): + trunc = self._expr_mat.__getitem__(key) + if isinstance(trunc, ImmutableMatrix): + return _to_TFM(trunc, self.var) + return TransferFunction.from_rational_expression(trunc, self.var) + + def transpose(self): + """Returns the transpose of the ``TransferFunctionMatrix`` (switched input and output layers).""" + transposed_mat = self._expr_mat.transpose() + return _to_TFM(transposed_mat, self.var) + + def elem_poles(self): + """ + Returns the poles of each element of the ``TransferFunctionMatrix``. + + .. note:: + Actual poles of a MIMO system are NOT the poles of individual elements. + + Examples + ======== + + >>> from sympy.abc import s + >>> from sympy.physics.control.lti import TransferFunction, TransferFunctionMatrix + >>> tf_1 = TransferFunction(3, (s + 1), s) + >>> tf_2 = TransferFunction(s + 6, (s + 1)*(s + 2), s) + >>> tf_3 = TransferFunction(s + 3, s**2 + 3*s + 2, s) + >>> tf_4 = TransferFunction(s + 2, s**2 + 5*s - 10, s) + >>> tfm_1 = TransferFunctionMatrix([[tf_1, tf_2], [tf_3, tf_4]]) + >>> tfm_1 + TransferFunctionMatrix(((TransferFunction(3, s + 1, s), TransferFunction(s + 6, (s + 1)*(s + 2), s)), (TransferFunction(s + 3, s**2 + 3*s + 2, s), TransferFunction(s + 2, s**2 + 5*s - 10, s)))) + >>> tfm_1.elem_poles() + [[[-1], [-2, -1]], [[-2, -1], [-5/2 + sqrt(65)/2, -sqrt(65)/2 - 5/2]]] + + See Also + ======== + + elem_zeros + + """ + return [[element.poles() for element in row] for row in self.doit().args[0]] + + def elem_zeros(self): + """ + Returns the zeros of each element of the ``TransferFunctionMatrix``. + + .. note:: + Actual zeros of a MIMO system are NOT the zeros of individual elements. + + Examples + ======== + + >>> from sympy.abc import s + >>> from sympy.physics.control.lti import TransferFunction, TransferFunctionMatrix + >>> tf_1 = TransferFunction(3, (s + 1), s) + >>> tf_2 = TransferFunction(s + 6, (s + 1)*(s + 2), s) + >>> tf_3 = TransferFunction(s + 3, s**2 + 3*s + 2, s) + >>> tf_4 = TransferFunction(s**2 - 9*s + 20, s**2 + 5*s - 10, s) + >>> tfm_1 = TransferFunctionMatrix([[tf_1, tf_2], [tf_3, tf_4]]) + >>> tfm_1 + TransferFunctionMatrix(((TransferFunction(3, s + 1, s), TransferFunction(s + 6, (s + 1)*(s + 2), s)), (TransferFunction(s + 3, s**2 + 3*s + 2, s), TransferFunction(s**2 - 9*s + 20, s**2 + 5*s - 10, s)))) + >>> tfm_1.elem_zeros() + [[[], [-6]], [[-3], [4, 5]]] + + See Also + ======== + + elem_poles + + """ + return [[element.zeros() for element in row] for row in self.doit().args[0]] + + def eval_frequency(self, other): + """ + Evaluates system response of each transfer function in the ``TransferFunctionMatrix`` at any point in the real or complex plane. + + Examples + ======== + + >>> from sympy.abc import s + >>> from sympy.physics.control.lti import TransferFunction, TransferFunctionMatrix + >>> from sympy import I + >>> tf_1 = TransferFunction(3, (s + 1), s) + >>> tf_2 = TransferFunction(s + 6, (s + 1)*(s + 2), s) + >>> tf_3 = TransferFunction(s + 3, s**2 + 3*s + 2, s) + >>> tf_4 = TransferFunction(s**2 - 9*s + 20, s**2 + 5*s - 10, s) + >>> tfm_1 = TransferFunctionMatrix([[tf_1, tf_2], [tf_3, tf_4]]) + >>> tfm_1 + TransferFunctionMatrix(((TransferFunction(3, s + 1, s), TransferFunction(s + 6, (s + 1)*(s + 2), s)), (TransferFunction(s + 3, s**2 + 3*s + 2, s), TransferFunction(s**2 - 9*s + 20, s**2 + 5*s - 10, s)))) + >>> tfm_1.eval_frequency(2) + Matrix([ + [ 1, 2/3], + [5/12, 3/2]]) + >>> tfm_1.eval_frequency(I*2) + Matrix([ + [ 3/5 - 6*I/5, -I], + [3/20 - 11*I/20, -101/74 + 23*I/74]]) + """ + mat = self._expr_mat.subs(self.var, other) + return mat.expand() + + def _flat(self): + """Returns flattened list of args in TransferFunctionMatrix""" + return [elem for tup in self.args[0] for elem in tup] + + def _eval_evalf(self, prec): + """Calls evalf() on each transfer function in the transfer function matrix""" + dps = prec_to_dps(prec) + mat = self._expr_mat.applyfunc(lambda a: a.evalf(n=dps)) + return _to_TFM(mat, self.var) + + def _eval_simplify(self, **kwargs): + """Simplifies the transfer function matrix""" + simp_mat = self._expr_mat.applyfunc(lambda a: cancel(a, expand=False)) + return _to_TFM(simp_mat, self.var) + + def expand(self, **hints): + """Expands the transfer function matrix""" + expand_mat = self._expr_mat.expand(**hints) + return _to_TFM(expand_mat, self.var) + +class StateSpace(LinearTimeInvariant): + r""" + State space model (ssm) of a linear, time invariant control system. + + Represents the standard state-space model with A, B, C, D as state-space matrices. + This makes the linear control system: + (1) x'(t) = A * x(t) + B * u(t); x in R^n , u in R^k + (2) y(t) = C * x(t) + D * u(t); y in R^m + where u(t) is any input signal, y(t) the corresponding output, and x(t) the system's state. + + Parameters + ========== + + A : Matrix + The State matrix of the state space model. + B : Matrix + The Input-to-State matrix of the state space model. + C : Matrix + The State-to-Output matrix of the state space model. + D : Matrix + The Feedthrough matrix of the state space model. + + Examples + ======== + + >>> from sympy import Matrix + >>> from sympy.physics.control import StateSpace + + The easiest way to create a StateSpaceModel is via four matrices: + + >>> A = Matrix([[1, 2], [1, 0]]) + >>> B = Matrix([1, 1]) + >>> C = Matrix([[0, 1]]) + >>> D = Matrix([0]) + >>> StateSpace(A, B, C, D) + StateSpace(Matrix([ + [1, 2], + [1, 0]]), Matrix([ + [1], + [1]]), Matrix([[0, 1]]), Matrix([[0]])) + + + One can use less matrices. The rest will be filled with a minimum of zeros: + + >>> StateSpace(A, B) + StateSpace(Matrix([ + [1, 2], + [1, 0]]), Matrix([ + [1], + [1]]), Matrix([[0, 0]]), Matrix([[0]])) + + + See Also + ======== + + TransferFunction, TransferFunctionMatrix + + References + ========== + .. [1] https://en.wikipedia.org/wiki/State-space_representation + .. [2] https://in.mathworks.com/help/control/ref/ss.html + + """ + def __new__(cls, A=None, B=None, C=None, D=None): + if A is None: + A = zeros(1) + if B is None: + B = zeros(A.rows, 1) + if C is None: + C = zeros(1, A.cols) + if D is None: + D = zeros(C.rows, B.cols) + + A = _sympify(A) + B = _sympify(B) + C = _sympify(C) + D = _sympify(D) + + if (isinstance(A, ImmutableDenseMatrix) and isinstance(B, ImmutableDenseMatrix) and + isinstance(C, ImmutableDenseMatrix) and isinstance(D, ImmutableDenseMatrix)): + # Check State Matrix is square + if A.rows != A.cols: + raise ShapeError("Matrix A must be a square matrix.") + + # Check State and Input matrices have same rows + if A.rows != B.rows: + raise ShapeError("Matrices A and B must have the same number of rows.") + + # Check Ouput and Feedthrough matrices have same rows + if C.rows != D.rows: + raise ShapeError("Matrices C and D must have the same number of rows.") + + # Check State and Ouput matrices have same columns + if A.cols != C.cols: + raise ShapeError("Matrices A and C must have the same number of columns.") + + # Check Input and Feedthrough matrices have same columns + if B.cols != D.cols: + raise ShapeError("Matrices B and D must have the same number of columns.") + + obj = super(StateSpace, cls).__new__(cls, A, B, C, D) + obj._A = A + obj._B = B + obj._C = C + obj._D = D + + # Determine if the system is SISO or MIMO + num_outputs = D.rows + num_inputs = D.cols + if num_inputs == 1 and num_outputs == 1: + obj._is_SISO = True + obj._clstype = SISOLinearTimeInvariant + else: + obj._is_SISO = False + obj._clstype = MIMOLinearTimeInvariant + + return obj + + else: + raise TypeError("A, B, C and D inputs must all be sympy Matrices.") + + @property + def state_matrix(self): + """ + Returns the state matrix of the model. + + Examples + ======== + + >>> from sympy import Matrix + >>> from sympy.physics.control import StateSpace + >>> A = Matrix([[1, 2], [1, 0]]) + >>> B = Matrix([1, 1]) + >>> C = Matrix([[0, 1]]) + >>> D = Matrix([0]) + >>> ss = StateSpace(A, B, C, D) + >>> ss.state_matrix + Matrix([ + [1, 2], + [1, 0]]) + + """ + return self._A + + @property + def input_matrix(self): + """ + Returns the input matrix of the model. + + Examples + ======== + + >>> from sympy import Matrix + >>> from sympy.physics.control import StateSpace + >>> A = Matrix([[1, 2], [1, 0]]) + >>> B = Matrix([1, 1]) + >>> C = Matrix([[0, 1]]) + >>> D = Matrix([0]) + >>> ss = StateSpace(A, B, C, D) + >>> ss.input_matrix + Matrix([ + [1], + [1]]) + + """ + return self._B + + @property + def output_matrix(self): + """ + Returns the output matrix of the model. + + Examples + ======== + + >>> from sympy import Matrix + >>> from sympy.physics.control import StateSpace + >>> A = Matrix([[1, 2], [1, 0]]) + >>> B = Matrix([1, 1]) + >>> C = Matrix([[0, 1]]) + >>> D = Matrix([0]) + >>> ss = StateSpace(A, B, C, D) + >>> ss.output_matrix + Matrix([[0, 1]]) + + """ + return self._C + + @property + def feedforward_matrix(self): + """ + Returns the feedforward matrix of the model. + + Examples + ======== + + >>> from sympy import Matrix + >>> from sympy.physics.control import StateSpace + >>> A = Matrix([[1, 2], [1, 0]]) + >>> B = Matrix([1, 1]) + >>> C = Matrix([[0, 1]]) + >>> D = Matrix([0]) + >>> ss = StateSpace(A, B, C, D) + >>> ss.feedforward_matrix + Matrix([[0]]) + + """ + return self._D + + @property + def num_states(self): + """ + Returns the number of states of the model. + + Examples + ======== + + >>> from sympy import Matrix + >>> from sympy.physics.control import StateSpace + >>> A = Matrix([[1, 2], [1, 0]]) + >>> B = Matrix([1, 1]) + >>> C = Matrix([[0, 1]]) + >>> D = Matrix([0]) + >>> ss = StateSpace(A, B, C, D) + >>> ss.num_states + 2 + + """ + return self._A.rows + + @property + def num_inputs(self): + """ + Returns the number of inputs of the model. + + Examples + ======== + + >>> from sympy import Matrix + >>> from sympy.physics.control import StateSpace + >>> A = Matrix([[1, 2], [1, 0]]) + >>> B = Matrix([1, 1]) + >>> C = Matrix([[0, 1]]) + >>> D = Matrix([0]) + >>> ss = StateSpace(A, B, C, D) + >>> ss.num_inputs + 1 + + """ + return self._D.cols + + @property + def num_outputs(self): + """ + Returns the number of outputs of the model. + + Examples + ======== + + >>> from sympy import Matrix + >>> from sympy.physics.control import StateSpace + >>> A = Matrix([[1, 2], [1, 0]]) + >>> B = Matrix([1, 1]) + >>> C = Matrix([[0, 1]]) + >>> D = Matrix([0]) + >>> ss = StateSpace(A, B, C, D) + >>> ss.num_outputs + 1 + + """ + return self._D.rows + + def _eval_evalf(self, prec): + """ + Returns state space model where numerical expressions are evaluated into floating point numbers. + """ + dps = prec_to_dps(prec) + return StateSpace( + self._A.evalf(n = dps), + self._B.evalf(n = dps), + self._C.evalf(n = dps), + self._D.evalf(n = dps)) + + def _eval_rewrite_as_TransferFunction(self, *args): + """ + Returns the equivalent Transfer Function of the state space model. + + Examples + ======== + + >>> from sympy import Matrix + >>> from sympy.physics.control import TransferFunction, StateSpace + >>> A = Matrix([[-5, -1], [3, -1]]) + >>> B = Matrix([2, 5]) + >>> C = Matrix([[1, 2]]) + >>> D = Matrix([0]) + >>> ss = StateSpace(A, B, C, D) + >>> ss.rewrite(TransferFunction) + [[TransferFunction(12*s + 59, s**2 + 6*s + 8, s)]] + + """ + s = Symbol('s') + n = self._A.shape[0] + I = eye(n) + G = self._C*(s*I - self._A).solve(self._B) + self._D + G = G.simplify() + to_tf = lambda expr: TransferFunction.from_rational_expression(expr, s) + tf_mat = [[to_tf(expr) for expr in sublist] for sublist in G.tolist()] + return tf_mat + + def __add__(self, other): + """ + Add two State Space systems (parallel connection). + + Examples + ======== + + >>> from sympy import Matrix + >>> from sympy.physics.control import StateSpace + >>> A1 = Matrix([[1]]) + >>> B1 = Matrix([[2]]) + >>> C1 = Matrix([[-1]]) + >>> D1 = Matrix([[-2]]) + >>> A2 = Matrix([[-1]]) + >>> B2 = Matrix([[-2]]) + >>> C2 = Matrix([[1]]) + >>> D2 = Matrix([[2]]) + >>> ss1 = StateSpace(A1, B1, C1, D1) + >>> ss2 = StateSpace(A2, B2, C2, D2) + >>> ss1 + ss2 + StateSpace(Matrix([ + [1, 0], + [0, -1]]), Matrix([ + [ 2], + [-2]]), Matrix([[-1, 1]]), Matrix([[0]])) + + """ + # Check for scalars + if isinstance(other, (int, float, complex, Symbol)): + A = self._A + B = self._B + C = self._C + D = self._D.applyfunc(lambda element: element + other) + + else: + # Check nature of system + if not isinstance(other, StateSpace): + raise ValueError("Addition is only supported for 2 State Space models.") + # Check dimensions of system + elif ((self.num_inputs != other.num_inputs) or (self.num_outputs != other.num_outputs)): + raise ShapeError("Systems with incompatible inputs and outputs cannot be added.") + + m1 = (self._A).row_join(zeros(self._A.shape[0], other._A.shape[-1])) + m2 = zeros(other._A.shape[0], self._A.shape[-1]).row_join(other._A) + + A = m1.col_join(m2) + B = self._B.col_join(other._B) + C = self._C.row_join(other._C) + D = self._D + other._D + + return StateSpace(A, B, C, D) + + def __radd__(self, other): + """ + Right add two State Space systems. + + Examples + ======== + + >>> from sympy.physics.control import StateSpace + >>> s = StateSpace() + >>> 5 + s + StateSpace(Matrix([[0]]), Matrix([[0]]), Matrix([[0]]), Matrix([[5]])) + + """ + return self + other + + def __sub__(self, other): + """ + Subtract two State Space systems. + + Examples + ======== + + >>> from sympy import Matrix + >>> from sympy.physics.control import StateSpace + >>> A1 = Matrix([[1]]) + >>> B1 = Matrix([[2]]) + >>> C1 = Matrix([[-1]]) + >>> D1 = Matrix([[-2]]) + >>> A2 = Matrix([[-1]]) + >>> B2 = Matrix([[-2]]) + >>> C2 = Matrix([[1]]) + >>> D2 = Matrix([[2]]) + >>> ss1 = StateSpace(A1, B1, C1, D1) + >>> ss2 = StateSpace(A2, B2, C2, D2) + >>> ss1 - ss2 + StateSpace(Matrix([ + [1, 0], + [0, -1]]), Matrix([ + [ 2], + [-2]]), Matrix([[-1, -1]]), Matrix([[-4]])) + + """ + return self + (-other) + + def __rsub__(self, other): + """ + Right subtract two tate Space systems. + + Examples + ======== + + >>> from sympy.physics.control import StateSpace + >>> s = StateSpace() + >>> 5 - s + StateSpace(Matrix([[0]]), Matrix([[0]]), Matrix([[0]]), Matrix([[5]])) + + """ + return other + (-self) + + def __neg__(self): + """ + Returns the negation of the state space model. + + Examples + ======== + + >>> from sympy import Matrix + >>> from sympy.physics.control import StateSpace + >>> A = Matrix([[-5, -1], [3, -1]]) + >>> B = Matrix([2, 5]) + >>> C = Matrix([[1, 2]]) + >>> D = Matrix([0]) + >>> ss = StateSpace(A, B, C, D) + >>> -ss + StateSpace(Matrix([ + [-5, -1], + [ 3, -1]]), Matrix([ + [2], + [5]]), Matrix([[-1, -2]]), Matrix([[0]])) + + """ + return StateSpace(self._A, self._B, -self._C, -self._D) + + def __mul__(self, other): + """ + Multiplication of two State Space systems (serial connection). + + Examples + ======== + + >>> from sympy import Matrix + >>> from sympy.physics.control import StateSpace + >>> A = Matrix([[-5, -1], [3, -1]]) + >>> B = Matrix([2, 5]) + >>> C = Matrix([[1, 2]]) + >>> D = Matrix([0]) + >>> ss = StateSpace(A, B, C, D) + >>> ss*5 + StateSpace(Matrix([ + [-5, -1], + [ 3, -1]]), Matrix([ + [2], + [5]]), Matrix([[5, 10]]), Matrix([[0]])) + + """ + # Check for scalars + if isinstance(other, (int, float, complex, Symbol)): + A = self._A + B = self._B + C = self._C.applyfunc(lambda element: element*other) + D = self._D.applyfunc(lambda element: element*other) + + else: + # Check nature of system + if not isinstance(other, StateSpace): + raise ValueError("Multiplication is only supported for 2 State Space models.") + # Check dimensions of system + elif self.num_inputs != other.num_outputs: + raise ShapeError("Systems with incompatible inputs and outputs cannot be multiplied.") + + m1 = (other._A).row_join(zeros(other._A.shape[0], self._A.shape[1])) + m2 = (self._B * other._C).row_join(self._A) + + A = m1.col_join(m2) + B = (other._B).col_join(self._B * other._D) + C = (self._D * other._C).row_join(self._C) + D = self._D * other._D + + return StateSpace(A, B, C, D) + + def __rmul__(self, other): + """ + Right multiply two tate Space systems. + + Examples + ======== + + >>> from sympy import Matrix + >>> from sympy.physics.control import StateSpace + >>> A = Matrix([[-5, -1], [3, -1]]) + >>> B = Matrix([2, 5]) + >>> C = Matrix([[1, 2]]) + >>> D = Matrix([0]) + >>> ss = StateSpace(A, B, C, D) + >>> 5*ss + StateSpace(Matrix([ + [-5, -1], + [ 3, -1]]), Matrix([ + [10], + [25]]), Matrix([[1, 2]]), Matrix([[0]])) + + """ + if isinstance(other, (int, float, complex, Symbol)): + A = self._A + C = self._C + B = self._B.applyfunc(lambda element: element*other) + D = self._D.applyfunc(lambda element: element*other) + return StateSpace(A, B, C, D) + else: + return self*other + + def __repr__(self): + A_str = self._A.__repr__() + B_str = self._B.__repr__() + C_str = self._C.__repr__() + D_str = self._D.__repr__() + + return f"StateSpace(\n{A_str},\n\n{B_str},\n\n{C_str},\n\n{D_str})" + + + def append(self, other): + """ + Returns the first model appended with the second model. The order is preserved. + + Examples + ======== + + >>> from sympy import Matrix + >>> from sympy.physics.control import StateSpace + >>> A1 = Matrix([[1]]) + >>> B1 = Matrix([[2]]) + >>> C1 = Matrix([[-1]]) + >>> D1 = Matrix([[-2]]) + >>> A2 = Matrix([[-1]]) + >>> B2 = Matrix([[-2]]) + >>> C2 = Matrix([[1]]) + >>> D2 = Matrix([[2]]) + >>> ss1 = StateSpace(A1, B1, C1, D1) + >>> ss2 = StateSpace(A2, B2, C2, D2) + >>> ss1.append(ss2) + StateSpace(Matrix([ + [1, 0], + [0, -1]]), Matrix([ + [2, 0], + [0, -2]]), Matrix([ + [-1, 0], + [ 0, 1]]), Matrix([ + [-2, 0], + [ 0, 2]])) + + """ + n = self.num_states + other.num_states + m = self.num_inputs + other.num_inputs + p = self.num_outputs + other.num_outputs + + A = zeros(n, n) + B = zeros(n, m) + C = zeros(p, n) + D = zeros(p, m) + + A[:self.num_states, :self.num_states] = self._A + A[self.num_states:, self.num_states:] = other._A + B[:self.num_states, :self.num_inputs] = self._B + B[self.num_states:, self.num_inputs:] = other._B + C[:self.num_outputs, :self.num_states] = self._C + C[self.num_outputs:, self.num_states:] = other._C + D[:self.num_outputs, :self.num_inputs] = self._D + D[self.num_outputs:, self.num_inputs:] = other._D + return StateSpace(A, B, C, D) + + def observability_matrix(self): + """ + Returns the observability matrix of the state space model: + [C, C * A^1, C * A^2, .. , C * A^(n-1)]; A in R^(n x n), C in R^(m x k) + + Examples + ======== + + >>> from sympy import Matrix + >>> from sympy.physics.control import StateSpace + >>> A = Matrix([[-1.5, -2], [1, 0]]) + >>> B = Matrix([0.5, 0]) + >>> C = Matrix([[0, 1]]) + >>> D = Matrix([1]) + >>> ss = StateSpace(A, B, C, D) + >>> ob = ss.observability_matrix() + >>> ob + Matrix([ + [0, 1], + [1, 0]]) + + References + ========== + .. [1] https://in.mathworks.com/help/control/ref/statespacemodel.obsv.html + + """ + n = self.num_states + ob = self._C + for i in range(1,n): + ob = ob.col_join(self._C * self._A**i) + + return ob + + def observable_subspace(self): + """ + Returns the observable subspace of the state space model. + + Examples + ======== + + >>> from sympy import Matrix + >>> from sympy.physics.control import StateSpace + >>> A = Matrix([[-1.5, -2], [1, 0]]) + >>> B = Matrix([0.5, 0]) + >>> C = Matrix([[0, 1]]) + >>> D = Matrix([1]) + >>> ss = StateSpace(A, B, C, D) + >>> ob_subspace = ss.observable_subspace() + >>> ob_subspace + [Matrix([ + [0], + [1]]), Matrix([ + [1], + [0]])] + + """ + return self.observability_matrix().columnspace() + + def is_observable(self): + """ + Returns if the state space model is observable. + + Examples + ======== + + >>> from sympy import Matrix + >>> from sympy.physics.control import StateSpace + >>> A = Matrix([[-1.5, -2], [1, 0]]) + >>> B = Matrix([0.5, 0]) + >>> C = Matrix([[0, 1]]) + >>> D = Matrix([1]) + >>> ss = StateSpace(A, B, C, D) + >>> ss.is_observable() + True + + """ + return self.observability_matrix().rank() == self.num_states + + def controllability_matrix(self): + """ + Returns the controllability matrix of the system: + [B, A * B, A^2 * B, .. , A^(n-1) * B]; A in R^(n x n), B in R^(n x m) + + Examples + ======== + + >>> from sympy import Matrix + >>> from sympy.physics.control import StateSpace + >>> A = Matrix([[-1.5, -2], [1, 0]]) + >>> B = Matrix([0.5, 0]) + >>> C = Matrix([[0, 1]]) + >>> D = Matrix([1]) + >>> ss = StateSpace(A, B, C, D) + >>> ss.controllability_matrix() + Matrix([ + [0.5, -0.75], + [ 0, 0.5]]) + + References + ========== + .. [1] https://in.mathworks.com/help/control/ref/statespacemodel.ctrb.html + + """ + co = self._B + n = self._A.shape[0] + for i in range(1, n): + co = co.row_join(((self._A)**i) * self._B) + + return co + + def controllable_subspace(self): + """ + Returns the controllable subspace of the state space model. + + Examples + ======== + + >>> from sympy import Matrix + >>> from sympy.physics.control import StateSpace + >>> A = Matrix([[-1.5, -2], [1, 0]]) + >>> B = Matrix([0.5, 0]) + >>> C = Matrix([[0, 1]]) + >>> D = Matrix([1]) + >>> ss = StateSpace(A, B, C, D) + >>> co_subspace = ss.controllable_subspace() + >>> co_subspace + [Matrix([ + [0.5], + [ 0]]), Matrix([ + [-0.75], + [ 0.5]])] + + """ + return self.controllability_matrix().columnspace() + + def is_controllable(self): + """ + Returns if the state space model is controllable. + + Examples + ======== + + >>> from sympy import Matrix + >>> from sympy.physics.control import StateSpace + >>> A = Matrix([[-1.5, -2], [1, 0]]) + >>> B = Matrix([0.5, 0]) + >>> C = Matrix([[0, 1]]) + >>> D = Matrix([1]) + >>> ss = StateSpace(A, B, C, D) + >>> ss.is_controllable() + True + + """ + return self.controllability_matrix().rank() == self.num_states diff --git a/lib/python3.10/site-packages/sympy/physics/hep/__init__.py b/lib/python3.10/site-packages/sympy/physics/hep/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/lib/python3.10/site-packages/sympy/physics/hep/gamma_matrices.py b/lib/python3.10/site-packages/sympy/physics/hep/gamma_matrices.py new file mode 100644 index 0000000000000000000000000000000000000000..40c3d0754438902f304d01c2df354dd09f9ea257 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/physics/hep/gamma_matrices.py @@ -0,0 +1,716 @@ +""" + Module to handle gamma matrices expressed as tensor objects. + + Examples + ======== + + >>> from sympy.physics.hep.gamma_matrices import GammaMatrix as G, LorentzIndex + >>> from sympy.tensor.tensor import tensor_indices + >>> i = tensor_indices('i', LorentzIndex) + >>> G(i) + GammaMatrix(i) + + Note that there is already an instance of GammaMatrixHead in four dimensions: + GammaMatrix, which is simply declare as + + >>> from sympy.physics.hep.gamma_matrices import GammaMatrix + >>> from sympy.tensor.tensor import tensor_indices + >>> i = tensor_indices('i', LorentzIndex) + >>> GammaMatrix(i) + GammaMatrix(i) + + To access the metric tensor + + >>> LorentzIndex.metric + metric(LorentzIndex,LorentzIndex) + +""" +from sympy.core.mul import Mul +from sympy.core.singleton import S +from sympy.matrices.dense import eye +from sympy.matrices.expressions.trace import trace +from sympy.tensor.tensor import TensorIndexType, TensorIndex,\ + TensMul, TensAdd, tensor_mul, Tensor, TensorHead, TensorSymmetry + + +# DiracSpinorIndex = TensorIndexType('DiracSpinorIndex', dim=4, dummy_name="S") + + +LorentzIndex = TensorIndexType('LorentzIndex', dim=4, dummy_name="L") + + +GammaMatrix = TensorHead("GammaMatrix", [LorentzIndex], + TensorSymmetry.no_symmetry(1), comm=None) + + +def extract_type_tens(expression, component): + """ + Extract from a ``TensExpr`` all tensors with `component`. + + Returns two tensor expressions: + + * the first contains all ``Tensor`` of having `component`. + * the second contains all remaining. + + + """ + if isinstance(expression, Tensor): + sp = [expression] + elif isinstance(expression, TensMul): + sp = expression.args + else: + raise ValueError('wrong type') + + # Collect all gamma matrices of the same dimension + new_expr = S.One + residual_expr = S.One + for i in sp: + if isinstance(i, Tensor) and i.component == component: + new_expr *= i + else: + residual_expr *= i + return new_expr, residual_expr + + +def simplify_gamma_expression(expression): + extracted_expr, residual_expr = extract_type_tens(expression, GammaMatrix) + res_expr = _simplify_single_line(extracted_expr) + return res_expr * residual_expr + + +def simplify_gpgp(ex, sort=True): + """ + simplify products ``G(i)*p(-i)*G(j)*p(-j) -> p(i)*p(-i)`` + + Examples + ======== + + >>> from sympy.physics.hep.gamma_matrices import GammaMatrix as G, \ + LorentzIndex, simplify_gpgp + >>> from sympy.tensor.tensor import tensor_indices, tensor_heads + >>> p, q = tensor_heads('p, q', [LorentzIndex]) + >>> i0,i1,i2,i3,i4,i5 = tensor_indices('i0:6', LorentzIndex) + >>> ps = p(i0)*G(-i0) + >>> qs = q(i0)*G(-i0) + >>> simplify_gpgp(ps*qs*qs) + GammaMatrix(-L_0)*p(L_0)*q(L_1)*q(-L_1) + """ + def _simplify_gpgp(ex): + components = ex.components + a = [] + comp_map = [] + for i, comp in enumerate(components): + comp_map.extend([i]*comp.rank) + dum = [(i[0], i[1], comp_map[i[0]], comp_map[i[1]]) for i in ex.dum] + for i in range(len(components)): + if components[i] != GammaMatrix: + continue + for dx in dum: + if dx[2] == i: + p_pos1 = dx[3] + elif dx[3] == i: + p_pos1 = dx[2] + else: + continue + comp1 = components[p_pos1] + if comp1.comm == 0 and comp1.rank == 1: + a.append((i, p_pos1)) + if not a: + return ex + elim = set() + tv = [] + hit = True + coeff = S.One + ta = None + while hit: + hit = False + for i, ai in enumerate(a[:-1]): + if ai[0] in elim: + continue + if ai[0] != a[i + 1][0] - 1: + continue + if components[ai[1]] != components[a[i + 1][1]]: + continue + elim.add(ai[0]) + elim.add(ai[1]) + elim.add(a[i + 1][0]) + elim.add(a[i + 1][1]) + if not ta: + ta = ex.split() + mu = TensorIndex('mu', LorentzIndex) + hit = True + if i == 0: + coeff = ex.coeff + tx = components[ai[1]](mu)*components[ai[1]](-mu) + if len(a) == 2: + tx *= 4 # eye(4) + tv.append(tx) + break + + if tv: + a = [x for j, x in enumerate(ta) if j not in elim] + a.extend(tv) + t = tensor_mul(*a)*coeff + # t = t.replace(lambda x: x.is_Matrix, lambda x: 1) + return t + else: + return ex + + if sort: + ex = ex.sorted_components() + # this would be better off with pattern matching + while 1: + t = _simplify_gpgp(ex) + if t != ex: + ex = t + else: + return t + + +def gamma_trace(t): + """ + trace of a single line of gamma matrices + + Examples + ======== + + >>> from sympy.physics.hep.gamma_matrices import GammaMatrix as G, \ + gamma_trace, LorentzIndex + >>> from sympy.tensor.tensor import tensor_indices, tensor_heads + >>> p, q = tensor_heads('p, q', [LorentzIndex]) + >>> i0,i1,i2,i3,i4,i5 = tensor_indices('i0:6', LorentzIndex) + >>> ps = p(i0)*G(-i0) + >>> qs = q(i0)*G(-i0) + >>> gamma_trace(G(i0)*G(i1)) + 4*metric(i0, i1) + >>> gamma_trace(ps*ps) - 4*p(i0)*p(-i0) + 0 + >>> gamma_trace(ps*qs + ps*ps) - 4*p(i0)*p(-i0) - 4*p(i0)*q(-i0) + 0 + + """ + if isinstance(t, TensAdd): + res = TensAdd(*[gamma_trace(x) for x in t.args]) + return res + t = _simplify_single_line(t) + res = _trace_single_line(t) + return res + + +def _simplify_single_line(expression): + """ + Simplify single-line product of gamma matrices. + + Examples + ======== + + >>> from sympy.physics.hep.gamma_matrices import GammaMatrix as G, \ + LorentzIndex, _simplify_single_line + >>> from sympy.tensor.tensor import tensor_indices, TensorHead + >>> p = TensorHead('p', [LorentzIndex]) + >>> i0,i1 = tensor_indices('i0:2', LorentzIndex) + >>> _simplify_single_line(G(i0)*G(i1)*p(-i1)*G(-i0)) + 2*G(i0)*p(-i0) + 0 + + """ + t1, t2 = extract_type_tens(expression, GammaMatrix) + if t1 != 1: + t1 = kahane_simplify(t1) + res = t1*t2 + return res + + +def _trace_single_line(t): + """ + Evaluate the trace of a single gamma matrix line inside a ``TensExpr``. + + Notes + ===== + + If there are ``DiracSpinorIndex.auto_left`` and ``DiracSpinorIndex.auto_right`` + indices trace over them; otherwise traces are not implied (explain) + + + Examples + ======== + + >>> from sympy.physics.hep.gamma_matrices import GammaMatrix as G, \ + LorentzIndex, _trace_single_line + >>> from sympy.tensor.tensor import tensor_indices, TensorHead + >>> p = TensorHead('p', [LorentzIndex]) + >>> i0,i1,i2,i3,i4,i5 = tensor_indices('i0:6', LorentzIndex) + >>> _trace_single_line(G(i0)*G(i1)) + 4*metric(i0, i1) + >>> _trace_single_line(G(i0)*p(-i0)*G(i1)*p(-i1)) - 4*p(i0)*p(-i0) + 0 + + """ + def _trace_single_line1(t): + t = t.sorted_components() + components = t.components + ncomps = len(components) + g = LorentzIndex.metric + # gamma matirices are in a[i:j] + hit = 0 + for i in range(ncomps): + if components[i] == GammaMatrix: + hit = 1 + break + + for j in range(i + hit, ncomps): + if components[j] != GammaMatrix: + break + else: + j = ncomps + numG = j - i + if numG == 0: + tcoeff = t.coeff + return t.nocoeff if tcoeff else t + if numG % 2 == 1: + return TensMul.from_data(S.Zero, [], [], []) + elif numG > 4: + # find the open matrix indices and connect them: + a = t.split() + ind1 = a[i].get_indices()[0] + ind2 = a[i + 1].get_indices()[0] + aa = a[:i] + a[i + 2:] + t1 = tensor_mul(*aa)*g(ind1, ind2) + t1 = t1.contract_metric(g) + args = [t1] + sign = 1 + for k in range(i + 2, j): + sign = -sign + ind2 = a[k].get_indices()[0] + aa = a[:i] + a[i + 1:k] + a[k + 1:] + t2 = sign*tensor_mul(*aa)*g(ind1, ind2) + t2 = t2.contract_metric(g) + t2 = simplify_gpgp(t2, False) + args.append(t2) + t3 = TensAdd(*args) + t3 = _trace_single_line(t3) + return t3 + else: + a = t.split() + t1 = _gamma_trace1(*a[i:j]) + a2 = a[:i] + a[j:] + t2 = tensor_mul(*a2) + t3 = t1*t2 + if not t3: + return t3 + t3 = t3.contract_metric(g) + return t3 + + t = t.expand() + if isinstance(t, TensAdd): + a = [_trace_single_line1(x)*x.coeff for x in t.args] + return TensAdd(*a) + elif isinstance(t, (Tensor, TensMul)): + r = t.coeff*_trace_single_line1(t) + return r + else: + return trace(t) + + +def _gamma_trace1(*a): + gctr = 4 # FIXME specific for d=4 + g = LorentzIndex.metric + if not a: + return gctr + n = len(a) + if n%2 == 1: + #return TensMul.from_data(S.Zero, [], [], []) + return S.Zero + if n == 2: + ind0 = a[0].get_indices()[0] + ind1 = a[1].get_indices()[0] + return gctr*g(ind0, ind1) + if n == 4: + ind0 = a[0].get_indices()[0] + ind1 = a[1].get_indices()[0] + ind2 = a[2].get_indices()[0] + ind3 = a[3].get_indices()[0] + + return gctr*(g(ind0, ind1)*g(ind2, ind3) - \ + g(ind0, ind2)*g(ind1, ind3) + g(ind0, ind3)*g(ind1, ind2)) + + +def kahane_simplify(expression): + r""" + This function cancels contracted elements in a product of four + dimensional gamma matrices, resulting in an expression equal to the given + one, without the contracted gamma matrices. + + Parameters + ========== + + `expression` the tensor expression containing the gamma matrices to simplify. + + Notes + ===== + + If spinor indices are given, the matrices must be given in + the order given in the product. + + Algorithm + ========= + + The idea behind the algorithm is to use some well-known identities, + i.e., for contractions enclosing an even number of `\gamma` matrices + + `\gamma^\mu \gamma_{a_1} \cdots \gamma_{a_{2N}} \gamma_\mu = 2 (\gamma_{a_{2N}} \gamma_{a_1} \cdots \gamma_{a_{2N-1}} + \gamma_{a_{2N-1}} \cdots \gamma_{a_1} \gamma_{a_{2N}} )` + + for an odd number of `\gamma` matrices + + `\gamma^\mu \gamma_{a_1} \cdots \gamma_{a_{2N+1}} \gamma_\mu = -2 \gamma_{a_{2N+1}} \gamma_{a_{2N}} \cdots \gamma_{a_{1}}` + + Instead of repeatedly applying these identities to cancel out all contracted indices, + it is possible to recognize the links that would result from such an operation, + the problem is thus reduced to a simple rearrangement of free gamma matrices. + + Examples + ======== + + When using, always remember that the original expression coefficient + has to be handled separately + + >>> from sympy.physics.hep.gamma_matrices import GammaMatrix as G, LorentzIndex + >>> from sympy.physics.hep.gamma_matrices import kahane_simplify + >>> from sympy.tensor.tensor import tensor_indices + >>> i0, i1, i2 = tensor_indices('i0:3', LorentzIndex) + >>> ta = G(i0)*G(-i0) + >>> kahane_simplify(ta) + Matrix([ + [4, 0, 0, 0], + [0, 4, 0, 0], + [0, 0, 4, 0], + [0, 0, 0, 4]]) + >>> tb = G(i0)*G(i1)*G(-i0) + >>> kahane_simplify(tb) + -2*GammaMatrix(i1) + >>> t = G(i0)*G(-i0) + >>> kahane_simplify(t) + Matrix([ + [4, 0, 0, 0], + [0, 4, 0, 0], + [0, 0, 4, 0], + [0, 0, 0, 4]]) + >>> t = G(i0)*G(-i0) + >>> kahane_simplify(t) + Matrix([ + [4, 0, 0, 0], + [0, 4, 0, 0], + [0, 0, 4, 0], + [0, 0, 0, 4]]) + + If there are no contractions, the same expression is returned + + >>> tc = G(i0)*G(i1) + >>> kahane_simplify(tc) + GammaMatrix(i0)*GammaMatrix(i1) + + References + ========== + + [1] Algorithm for Reducing Contracted Products of gamma Matrices, + Joseph Kahane, Journal of Mathematical Physics, Vol. 9, No. 10, October 1968. + """ + + if isinstance(expression, Mul): + return expression + if isinstance(expression, TensAdd): + return TensAdd(*[kahane_simplify(arg) for arg in expression.args]) + + if isinstance(expression, Tensor): + return expression + + assert isinstance(expression, TensMul) + + gammas = expression.args + + for gamma in gammas: + assert gamma.component == GammaMatrix + + free = expression.free + # spinor_free = [_ for _ in expression.free_in_args if _[1] != 0] + + # if len(spinor_free) == 2: + # spinor_free.sort(key=lambda x: x[2]) + # assert spinor_free[0][1] == 1 and spinor_free[-1][1] == 2 + # assert spinor_free[0][2] == 0 + # elif spinor_free: + # raise ValueError('spinor indices do not match') + + dum = [] + for dum_pair in expression.dum: + if expression.index_types[dum_pair[0]] == LorentzIndex: + dum.append((dum_pair[0], dum_pair[1])) + + dum = sorted(dum) + + if len(dum) == 0: # or GammaMatrixHead: + # no contractions in `expression`, just return it. + return expression + + # find the `first_dum_pos`, i.e. the position of the first contracted + # gamma matrix, Kahane's algorithm as described in his paper requires the + # gamma matrix expression to start with a contracted gamma matrix, this is + # a workaround which ignores possible initial free indices, and re-adds + # them later. + + first_dum_pos = min(map(min, dum)) + + # for p1, p2, a1, a2 in expression.dum_in_args: + # if p1 != 0 or p2 != 0: + # # only Lorentz indices, skip Dirac indices: + # continue + # first_dum_pos = min(p1, p2) + # break + + total_number = len(free) + len(dum)*2 + number_of_contractions = len(dum) + + free_pos = [None]*total_number + for i in free: + free_pos[i[1]] = i[0] + + # `index_is_free` is a list of booleans, to identify index position + # and whether that index is free or dummy. + index_is_free = [False]*total_number + + for i, indx in enumerate(free): + index_is_free[indx[1]] = True + + # `links` is a dictionary containing the graph described in Kahane's paper, + # to every key correspond one or two values, representing the linked indices. + # All values in `links` are integers, negative numbers are used in the case + # where it is necessary to insert gamma matrices between free indices, in + # order to make Kahane's algorithm work (see paper). + links = {i: [] for i in range(first_dum_pos, total_number)} + + # `cum_sign` is a step variable to mark the sign of every index, see paper. + cum_sign = -1 + # `cum_sign_list` keeps storage for all `cum_sign` (every index). + cum_sign_list = [None]*total_number + block_free_count = 0 + + # multiply `resulting_coeff` by the coefficient parameter, the rest + # of the algorithm ignores a scalar coefficient. + resulting_coeff = S.One + + # initialize a list of lists of indices. The outer list will contain all + # additive tensor expressions, while the inner list will contain the + # free indices (rearranged according to the algorithm). + resulting_indices = [[]] + + # start to count the `connected_components`, which together with the number + # of contractions, determines a -1 or +1 factor to be multiplied. + connected_components = 1 + + # First loop: here we fill `cum_sign_list`, and draw the links + # among consecutive indices (they are stored in `links`). Links among + # non-consecutive indices will be drawn later. + for i, is_free in enumerate(index_is_free): + # if `expression` starts with free indices, they are ignored here; + # they are later added as they are to the beginning of all + # `resulting_indices` list of lists of indices. + if i < first_dum_pos: + continue + + if is_free: + block_free_count += 1 + # if previous index was free as well, draw an arch in `links`. + if block_free_count > 1: + links[i - 1].append(i) + links[i].append(i - 1) + else: + # Change the sign of the index (`cum_sign`) if the number of free + # indices preceding it is even. + cum_sign *= 1 if (block_free_count % 2) else -1 + if block_free_count == 0 and i != first_dum_pos: + # check if there are two consecutive dummy indices: + # in this case create virtual indices with negative position, + # these "virtual" indices represent the insertion of two + # gamma^0 matrices to separate consecutive dummy indices, as + # Kahane's algorithm requires dummy indices to be separated by + # free indices. The product of two gamma^0 matrices is unity, + # so the new expression being examined is the same as the + # original one. + if cum_sign == -1: + links[-1-i] = [-1-i+1] + links[-1-i+1] = [-1-i] + if (i - cum_sign) in links: + if i != first_dum_pos: + links[i].append(i - cum_sign) + if block_free_count != 0: + if i - cum_sign < len(index_is_free): + if index_is_free[i - cum_sign]: + links[i - cum_sign].append(i) + block_free_count = 0 + + cum_sign_list[i] = cum_sign + + # The previous loop has only created links between consecutive free indices, + # it is necessary to properly create links among dummy (contracted) indices, + # according to the rules described in Kahane's paper. There is only one exception + # to Kahane's rules: the negative indices, which handle the case of some + # consecutive free indices (Kahane's paper just describes dummy indices + # separated by free indices, hinting that free indices can be added without + # altering the expression result). + for i in dum: + # get the positions of the two contracted indices: + pos1 = i[0] + pos2 = i[1] + + # create Kahane's upper links, i.e. the upper arcs between dummy + # (i.e. contracted) indices: + links[pos1].append(pos2) + links[pos2].append(pos1) + + # create Kahane's lower links, this corresponds to the arcs below + # the line described in the paper: + + # first we move `pos1` and `pos2` according to the sign of the indices: + linkpos1 = pos1 + cum_sign_list[pos1] + linkpos2 = pos2 + cum_sign_list[pos2] + + # otherwise, perform some checks before creating the lower arcs: + + # make sure we are not exceeding the total number of indices: + if linkpos1 >= total_number: + continue + if linkpos2 >= total_number: + continue + + # make sure we are not below the first dummy index in `expression`: + if linkpos1 < first_dum_pos: + continue + if linkpos2 < first_dum_pos: + continue + + # check if the previous loop created "virtual" indices between dummy + # indices, in such a case relink `linkpos1` and `linkpos2`: + if (-1-linkpos1) in links: + linkpos1 = -1-linkpos1 + if (-1-linkpos2) in links: + linkpos2 = -1-linkpos2 + + # move only if not next to free index: + if linkpos1 >= 0 and not index_is_free[linkpos1]: + linkpos1 = pos1 + + if linkpos2 >=0 and not index_is_free[linkpos2]: + linkpos2 = pos2 + + # create the lower arcs: + if linkpos2 not in links[linkpos1]: + links[linkpos1].append(linkpos2) + if linkpos1 not in links[linkpos2]: + links[linkpos2].append(linkpos1) + + # This loop starts from the `first_dum_pos` index (first dummy index) + # walks through the graph deleting the visited indices from `links`, + # it adds a gamma matrix for every free index in encounters, while it + # completely ignores dummy indices and virtual indices. + pointer = first_dum_pos + previous_pointer = 0 + while True: + if pointer in links: + next_ones = links.pop(pointer) + else: + break + + if previous_pointer in next_ones: + next_ones.remove(previous_pointer) + + previous_pointer = pointer + + if next_ones: + pointer = next_ones[0] + else: + break + + if pointer == previous_pointer: + break + if pointer >=0 and free_pos[pointer] is not None: + for ri in resulting_indices: + ri.append(free_pos[pointer]) + + # The following loop removes the remaining connected components in `links`. + # If there are free indices inside a connected component, it gives a + # contribution to the resulting expression given by the factor + # `gamma_a gamma_b ... gamma_z + gamma_z ... gamma_b gamma_a`, in Kahanes's + # paper represented as {gamma_a, gamma_b, ... , gamma_z}, + # virtual indices are ignored. The variable `connected_components` is + # increased by one for every connected component this loop encounters. + + # If the connected component has virtual and dummy indices only + # (no free indices), it contributes to `resulting_indices` by a factor of two. + # The multiplication by two is a result of the + # factor {gamma^0, gamma^0} = 2 I, as it appears in Kahane's paper. + # Note: curly brackets are meant as in the paper, as a generalized + # multi-element anticommutator! + + while links: + connected_components += 1 + pointer = min(links.keys()) + previous_pointer = pointer + # the inner loop erases the visited indices from `links`, and it adds + # all free indices to `prepend_indices` list, virtual indices are + # ignored. + prepend_indices = [] + while True: + if pointer in links: + next_ones = links.pop(pointer) + else: + break + + if previous_pointer in next_ones: + if len(next_ones) > 1: + next_ones.remove(previous_pointer) + + previous_pointer = pointer + + if next_ones: + pointer = next_ones[0] + + if pointer >= first_dum_pos and free_pos[pointer] is not None: + prepend_indices.insert(0, free_pos[pointer]) + # if `prepend_indices` is void, it means there are no free indices + # in the loop (and it can be shown that there must be a virtual index), + # loops of virtual indices only contribute by a factor of two: + if len(prepend_indices) == 0: + resulting_coeff *= 2 + # otherwise, add the free indices in `prepend_indices` to + # the `resulting_indices`: + else: + expr1 = prepend_indices + expr2 = list(reversed(prepend_indices)) + resulting_indices = [expri + ri for ri in resulting_indices for expri in (expr1, expr2)] + + # sign correction, as described in Kahane's paper: + resulting_coeff *= -1 if (number_of_contractions - connected_components + 1) % 2 else 1 + # power of two factor, as described in Kahane's paper: + resulting_coeff *= 2**(number_of_contractions) + + # If `first_dum_pos` is not zero, it means that there are trailing free gamma + # matrices in front of `expression`, so multiply by them: + resulting_indices = [ free_pos[0:first_dum_pos] + ri for ri in resulting_indices ] + + resulting_expr = S.Zero + for i in resulting_indices: + temp_expr = S.One + for j in i: + temp_expr *= GammaMatrix(j) + resulting_expr += temp_expr + + t = resulting_coeff * resulting_expr + t1 = None + if isinstance(t, TensAdd): + t1 = t.args[0] + elif isinstance(t, TensMul): + t1 = t + if t1: + pass + else: + t = eye(4)*t + return t diff --git a/lib/python3.10/site-packages/sympy/physics/mechanics/__init__.py b/lib/python3.10/site-packages/sympy/physics/mechanics/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0f1ac5d49d514eab763e56007096dd44cdde87dc --- /dev/null +++ b/lib/python3.10/site-packages/sympy/physics/mechanics/__init__.py @@ -0,0 +1,90 @@ +__all__ = [ + 'vector', + + 'CoordinateSym', 'ReferenceFrame', 'Dyadic', 'Vector', 'Point', 'cross', + 'dot', 'express', 'time_derivative', 'outer', 'kinematic_equations', + 'get_motion_params', 'partial_velocity', 'dynamicsymbols', 'vprint', + 'vsstrrepr', 'vsprint', 'vpprint', 'vlatex', 'init_vprinting', 'curl', + 'divergence', 'gradient', 'is_conservative', 'is_solenoidal', + 'scalar_potential', 'scalar_potential_difference', + + 'KanesMethod', + + 'RigidBody', + + 'linear_momentum', 'angular_momentum', 'kinetic_energy', 'potential_energy', + 'Lagrangian', 'mechanics_printing', 'mprint', 'msprint', 'mpprint', + 'mlatex', 'msubs', 'find_dynamicsymbols', + + 'inertia', 'inertia_of_point_mass', 'Inertia', + + 'Force', 'Torque', + + 'Particle', + + 'LagrangesMethod', + + 'Linearizer', + + 'Body', + + 'SymbolicSystem', 'System', + + 'PinJoint', 'PrismaticJoint', 'CylindricalJoint', 'PlanarJoint', + 'SphericalJoint', 'WeldJoint', + + 'JointsMethod', + + 'WrappingCylinder', 'WrappingGeometryBase', 'WrappingSphere', + + 'PathwayBase', 'LinearPathway', 'ObstacleSetPathway', 'WrappingPathway', + + 'ActuatorBase', 'ForceActuator', 'LinearDamper', 'LinearSpring', + 'TorqueActuator', 'DuffingSpring' +] + +from sympy.physics import vector + +from sympy.physics.vector import (CoordinateSym, ReferenceFrame, Dyadic, Vector, Point, + cross, dot, express, time_derivative, outer, kinematic_equations, + get_motion_params, partial_velocity, dynamicsymbols, vprint, + vsstrrepr, vsprint, vpprint, vlatex, init_vprinting, curl, divergence, + gradient, is_conservative, is_solenoidal, scalar_potential, + scalar_potential_difference) + +from .kane import KanesMethod + +from .rigidbody import RigidBody + +from .functions import (linear_momentum, angular_momentum, kinetic_energy, + potential_energy, Lagrangian, mechanics_printing, + mprint, msprint, mpprint, mlatex, msubs, + find_dynamicsymbols) + +from .inertia import inertia, inertia_of_point_mass, Inertia + +from .loads import Force, Torque + +from .particle import Particle + +from .lagrange import LagrangesMethod + +from .linearize import Linearizer + +from .body import Body + +from .system import SymbolicSystem, System + +from .jointsmethod import JointsMethod + +from .joint import (PinJoint, PrismaticJoint, CylindricalJoint, PlanarJoint, + SphericalJoint, WeldJoint) + +from .wrapping_geometry import (WrappingCylinder, WrappingGeometryBase, + WrappingSphere) + +from .pathway import (PathwayBase, LinearPathway, ObstacleSetPathway, + WrappingPathway) + +from .actuator import (ActuatorBase, ForceActuator, LinearDamper, LinearSpring, + TorqueActuator, DuffingSpring) diff --git a/lib/python3.10/site-packages/sympy/physics/mechanics/actuator.py b/lib/python3.10/site-packages/sympy/physics/mechanics/actuator.py new file mode 100644 index 0000000000000000000000000000000000000000..537b21444b7073b6c60bcb81cb038f15cb864c55 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/physics/mechanics/actuator.py @@ -0,0 +1,992 @@ +"""Implementations of actuators for linked force and torque application.""" + +from abc import ABC, abstractmethod + +from sympy import S, sympify +from sympy.physics.mechanics.joint import PinJoint +from sympy.physics.mechanics.loads import Torque +from sympy.physics.mechanics.pathway import PathwayBase +from sympy.physics.mechanics.rigidbody import RigidBody +from sympy.physics.vector import ReferenceFrame, Vector + + +__all__ = [ + 'ActuatorBase', + 'ForceActuator', + 'LinearDamper', + 'LinearSpring', + 'TorqueActuator', + 'DuffingSpring' +] + + +class ActuatorBase(ABC): + """Abstract base class for all actuator classes to inherit from. + + Notes + ===== + + Instances of this class cannot be directly instantiated by users. However, + it can be used to created custom actuator types through subclassing. + + """ + + def __init__(self): + """Initializer for ``ActuatorBase``.""" + pass + + @abstractmethod + def to_loads(self): + """Loads required by the equations of motion method classes. + + Explanation + =========== + + ``KanesMethod`` requires a list of ``Point``-``Vector`` tuples to be + passed to the ``loads`` parameters of its ``kanes_equations`` method + when constructing the equations of motion. This method acts as a + utility to produce the correctly-structred pairs of points and vectors + required so that these can be easily concatenated with other items in + the list of loads and passed to ``KanesMethod.kanes_equations``. These + loads are also in the correct form to also be passed to the other + equations of motion method classes, e.g. ``LagrangesMethod``. + + """ + pass + + def __repr__(self): + """Default representation of an actuator.""" + return f'{self.__class__.__name__}()' + + +class ForceActuator(ActuatorBase): + """Force-producing actuator. + + Explanation + =========== + + A ``ForceActuator`` is an actuator that produces a (expansile) force along + its length. + + A force actuator uses a pathway instance to determine the direction and + number of forces that it applies to a system. Consider the simplest case + where a ``LinearPathway`` instance is used. This pathway is made up of two + points that can move relative to each other, and results in a pair of equal + and opposite forces acting on the endpoints. If the positive time-varying + Euclidean distance between the two points is defined, then the "extension + velocity" is the time derivative of this distance. The extension velocity + is positive when the two points are moving away from each other and + negative when moving closer to each other. The direction for the force + acting on either point is determined by constructing a unit vector directed + from the other point to this point. This establishes a sign convention such + that a positive force magnitude tends to push the points apart, this is the + meaning of "expansile" in this context. The following diagram shows the + positive force sense and the distance between the points:: + + P Q + o<--- F --->o + | | + |<--l(t)--->| + + Examples + ======== + + To construct an actuator, an expression (or symbol) must be supplied to + represent the force it can produce, alongside a pathway specifying its line + of action. Let's also create a global reference frame and spatially fix one + of the points in it while setting the other to be positioned such that it + can freely move in the frame's x direction specified by the coordinate + ``q``. + + >>> from sympy import symbols + >>> from sympy.physics.mechanics import (ForceActuator, LinearPathway, + ... Point, ReferenceFrame) + >>> from sympy.physics.vector import dynamicsymbols + >>> N = ReferenceFrame('N') + >>> q = dynamicsymbols('q') + >>> force = symbols('F') + >>> pA, pB = Point('pA'), Point('pB') + >>> pA.set_vel(N, 0) + >>> pB.set_pos(pA, q*N.x) + >>> pB.pos_from(pA) + q(t)*N.x + >>> linear_pathway = LinearPathway(pA, pB) + >>> actuator = ForceActuator(force, linear_pathway) + >>> actuator + ForceActuator(F, LinearPathway(pA, pB)) + + Parameters + ========== + + force : Expr + The scalar expression defining the (expansile) force that the actuator + produces. + pathway : PathwayBase + The pathway that the actuator follows. This must be an instance of a + concrete subclass of ``PathwayBase``, e.g. ``LinearPathway``. + + """ + + def __init__(self, force, pathway): + """Initializer for ``ForceActuator``. + + Parameters + ========== + + force : Expr + The scalar expression defining the (expansile) force that the + actuator produces. + pathway : PathwayBase + The pathway that the actuator follows. This must be an instance of + a concrete subclass of ``PathwayBase``, e.g. ``LinearPathway``. + + """ + self.force = force + self.pathway = pathway + + @property + def force(self): + """The magnitude of the force produced by the actuator.""" + return self._force + + @force.setter + def force(self, force): + if hasattr(self, '_force'): + msg = ( + f'Can\'t set attribute `force` to {repr(force)} as it is ' + f'immutable.' + ) + raise AttributeError(msg) + self._force = sympify(force, strict=True) + + @property + def pathway(self): + """The ``Pathway`` defining the actuator's line of action.""" + return self._pathway + + @pathway.setter + def pathway(self, pathway): + if hasattr(self, '_pathway'): + msg = ( + f'Can\'t set attribute `pathway` to {repr(pathway)} as it is ' + f'immutable.' + ) + raise AttributeError(msg) + if not isinstance(pathway, PathwayBase): + msg = ( + f'Value {repr(pathway)} passed to `pathway` was of type ' + f'{type(pathway)}, must be {PathwayBase}.' + ) + raise TypeError(msg) + self._pathway = pathway + + def to_loads(self): + """Loads required by the equations of motion method classes. + + Explanation + =========== + + ``KanesMethod`` requires a list of ``Point``-``Vector`` tuples to be + passed to the ``loads`` parameters of its ``kanes_equations`` method + when constructing the equations of motion. This method acts as a + utility to produce the correctly-structred pairs of points and vectors + required so that these can be easily concatenated with other items in + the list of loads and passed to ``KanesMethod.kanes_equations``. These + loads are also in the correct form to also be passed to the other + equations of motion method classes, e.g. ``LagrangesMethod``. + + Examples + ======== + + The below example shows how to generate the loads produced by a force + actuator that follows a linear pathway. In this example we'll assume + that the force actuator is being used to model a simple linear spring. + First, create a linear pathway between two points separated by the + coordinate ``q`` in the ``x`` direction of the global frame ``N``. + + >>> from sympy.physics.mechanics import (LinearPathway, Point, + ... ReferenceFrame) + >>> from sympy.physics.vector import dynamicsymbols + >>> q = dynamicsymbols('q') + >>> N = ReferenceFrame('N') + >>> pA, pB = Point('pA'), Point('pB') + >>> pB.set_pos(pA, q*N.x) + >>> pathway = LinearPathway(pA, pB) + + Now create a symbol ``k`` to describe the spring's stiffness and + instantiate a force actuator that produces a (contractile) force + proportional to both the spring's stiffness and the pathway's length. + Note that actuator classes use the sign convention that expansile + forces are positive, so for a spring to produce a contractile force the + spring force needs to be calculated as the negative for the stiffness + multiplied by the length. + + >>> from sympy import symbols + >>> from sympy.physics.mechanics import ForceActuator + >>> stiffness = symbols('k') + >>> spring_force = -stiffness*pathway.length + >>> spring = ForceActuator(spring_force, pathway) + + The forces produced by the spring can be generated in the list of loads + form that ``KanesMethod`` (and other equations of motion methods) + requires by calling the ``to_loads`` method. + + >>> spring.to_loads() + [(pA, k*q(t)*N.x), (pB, - k*q(t)*N.x)] + + A simple linear damper can be modeled in a similar way. Create another + symbol ``c`` to describe the dampers damping coefficient. This time + instantiate a force actuator that produces a force proportional to both + the damper's damping coefficient and the pathway's extension velocity. + Note that the damping force is negative as it acts in the opposite + direction to which the damper is changing in length. + + >>> damping_coefficient = symbols('c') + >>> damping_force = -damping_coefficient*pathway.extension_velocity + >>> damper = ForceActuator(damping_force, pathway) + + Again, the forces produces by the damper can be generated by calling + the ``to_loads`` method. + + >>> damper.to_loads() + [(pA, c*Derivative(q(t), t)*N.x), (pB, - c*Derivative(q(t), t)*N.x)] + + """ + return self.pathway.to_loads(self.force) + + def __repr__(self): + """Representation of a ``ForceActuator``.""" + return f'{self.__class__.__name__}({self.force}, {self.pathway})' + + +class LinearSpring(ForceActuator): + """A spring with its spring force as a linear function of its length. + + Explanation + =========== + + Note that the "linear" in the name ``LinearSpring`` refers to the fact that + the spring force is a linear function of the springs length. I.e. for a + linear spring with stiffness ``k``, distance between its ends of ``x``, and + an equilibrium length of ``0``, the spring force will be ``-k*x``, which is + a linear function in ``x``. To create a spring that follows a linear, or + straight, pathway between its two ends, a ``LinearPathway`` instance needs + to be passed to the ``pathway`` parameter. + + A ``LinearSpring`` is a subclass of ``ForceActuator`` and so follows the + same sign conventions for length, extension velocity, and the direction of + the forces it applies to its points of attachment on bodies. The sign + convention for the direction of forces is such that, for the case where a + linear spring is instantiated with a ``LinearPathway`` instance as its + pathway, they act to push the two ends of the spring away from one another. + Because springs produces a contractile force and acts to pull the two ends + together towards the equilibrium length when stretched, the scalar portion + of the forces on the endpoint are negative in order to flip the sign of the + forces on the endpoints when converted into vector quantities. The + following diagram shows the positive force sense and the distance between + the points:: + + P Q + o<--- F --->o + | | + |<--l(t)--->| + + Examples + ======== + + To construct a linear spring, an expression (or symbol) must be supplied to + represent the stiffness (spring constant) of the spring, alongside a + pathway specifying its line of action. Let's also create a global reference + frame and spatially fix one of the points in it while setting the other to + be positioned such that it can freely move in the frame's x direction + specified by the coordinate ``q``. + + >>> from sympy import symbols + >>> from sympy.physics.mechanics import (LinearPathway, LinearSpring, + ... Point, ReferenceFrame) + >>> from sympy.physics.vector import dynamicsymbols + >>> N = ReferenceFrame('N') + >>> q = dynamicsymbols('q') + >>> stiffness = symbols('k') + >>> pA, pB = Point('pA'), Point('pB') + >>> pA.set_vel(N, 0) + >>> pB.set_pos(pA, q*N.x) + >>> pB.pos_from(pA) + q(t)*N.x + >>> linear_pathway = LinearPathway(pA, pB) + >>> spring = LinearSpring(stiffness, linear_pathway) + >>> spring + LinearSpring(k, LinearPathway(pA, pB)) + + This spring will produce a force that is proportional to both its stiffness + and the pathway's length. Note that this force is negative as SymPy's sign + convention for actuators is that negative forces are contractile. + + >>> spring.force + -k*sqrt(q(t)**2) + + To create a linear spring with a non-zero equilibrium length, an expression + (or symbol) can be passed to the ``equilibrium_length`` parameter on + construction on a ``LinearSpring`` instance. Let's create a symbol ``l`` + to denote a non-zero equilibrium length and create another linear spring. + + >>> l = symbols('l') + >>> spring = LinearSpring(stiffness, linear_pathway, equilibrium_length=l) + >>> spring + LinearSpring(k, LinearPathway(pA, pB), equilibrium_length=l) + + The spring force of this new spring is again proportional to both its + stiffness and the pathway's length. However, the spring will not produce + any force when ``q(t)`` equals ``l``. Note that the force will become + expansile when ``q(t)`` is less than ``l``, as expected. + + >>> spring.force + -k*(-l + sqrt(q(t)**2)) + + Parameters + ========== + + stiffness : Expr + The spring constant. + pathway : PathwayBase + The pathway that the actuator follows. This must be an instance of a + concrete subclass of ``PathwayBase``, e.g. ``LinearPathway``. + equilibrium_length : Expr, optional + The length at which the spring is in equilibrium, i.e. it produces no + force. The default value is 0, i.e. the spring force is a linear + function of the pathway's length with no constant offset. + + See Also + ======== + + ForceActuator: force-producing actuator (superclass of ``LinearSpring``). + LinearPathway: straight-line pathway between a pair of points. + + """ + + def __init__(self, stiffness, pathway, equilibrium_length=S.Zero): + """Initializer for ``LinearSpring``. + + Parameters + ========== + + stiffness : Expr + The spring constant. + pathway : PathwayBase + The pathway that the actuator follows. This must be an instance of + a concrete subclass of ``PathwayBase``, e.g. ``LinearPathway``. + equilibrium_length : Expr, optional + The length at which the spring is in equilibrium, i.e. it produces + no force. The default value is 0, i.e. the spring force is a linear + function of the pathway's length with no constant offset. + + """ + self.stiffness = stiffness + self.pathway = pathway + self.equilibrium_length = equilibrium_length + + @property + def force(self): + """The spring force produced by the linear spring.""" + return -self.stiffness*(self.pathway.length - self.equilibrium_length) + + @force.setter + def force(self, force): + raise AttributeError('Can\'t set computed attribute `force`.') + + @property + def stiffness(self): + """The spring constant for the linear spring.""" + return self._stiffness + + @stiffness.setter + def stiffness(self, stiffness): + if hasattr(self, '_stiffness'): + msg = ( + f'Can\'t set attribute `stiffness` to {repr(stiffness)} as it ' + f'is immutable.' + ) + raise AttributeError(msg) + self._stiffness = sympify(stiffness, strict=True) + + @property + def equilibrium_length(self): + """The length of the spring at which it produces no force.""" + return self._equilibrium_length + + @equilibrium_length.setter + def equilibrium_length(self, equilibrium_length): + if hasattr(self, '_equilibrium_length'): + msg = ( + f'Can\'t set attribute `equilibrium_length` to ' + f'{repr(equilibrium_length)} as it is immutable.' + ) + raise AttributeError(msg) + self._equilibrium_length = sympify(equilibrium_length, strict=True) + + def __repr__(self): + """Representation of a ``LinearSpring``.""" + string = f'{self.__class__.__name__}({self.stiffness}, {self.pathway}' + if self.equilibrium_length == S.Zero: + string += ')' + else: + string += f', equilibrium_length={self.equilibrium_length})' + return string + + +class LinearDamper(ForceActuator): + """A damper whose force is a linear function of its extension velocity. + + Explanation + =========== + + Note that the "linear" in the name ``LinearDamper`` refers to the fact that + the damping force is a linear function of the damper's rate of change in + its length. I.e. for a linear damper with damping ``c`` and extension + velocity ``v``, the damping force will be ``-c*v``, which is a linear + function in ``v``. To create a damper that follows a linear, or straight, + pathway between its two ends, a ``LinearPathway`` instance needs to be + passed to the ``pathway`` parameter. + + A ``LinearDamper`` is a subclass of ``ForceActuator`` and so follows the + same sign conventions for length, extension velocity, and the direction of + the forces it applies to its points of attachment on bodies. The sign + convention for the direction of forces is such that, for the case where a + linear damper is instantiated with a ``LinearPathway`` instance as its + pathway, they act to push the two ends of the damper away from one another. + Because dampers produce a force that opposes the direction of change in + length, when extension velocity is positive the scalar portions of the + forces applied at the two endpoints are negative in order to flip the sign + of the forces on the endpoints wen converted into vector quantities. When + extension velocity is negative (i.e. when the damper is shortening), the + scalar portions of the fofces applied are also negative so that the signs + cancel producing forces on the endpoints that are in the same direction as + the positive sign convention for the forces at the endpoints of the pathway + (i.e. they act to push the endpoints away from one another). The following + diagram shows the positive force sense and the distance between the + points:: + + P Q + o<--- F --->o + | | + |<--l(t)--->| + + Examples + ======== + + To construct a linear damper, an expression (or symbol) must be supplied to + represent the damping coefficient of the damper (we'll use the symbol + ``c``), alongside a pathway specifying its line of action. Let's also + create a global reference frame and spatially fix one of the points in it + while setting the other to be positioned such that it can freely move in + the frame's x direction specified by the coordinate ``q``. The velocity + that the two points move away from one another can be specified by the + coordinate ``u`` where ``u`` is the first time derivative of ``q`` + (i.e., ``u = Derivative(q(t), t)``). + + >>> from sympy import symbols + >>> from sympy.physics.mechanics import (LinearDamper, LinearPathway, + ... Point, ReferenceFrame) + >>> from sympy.physics.vector import dynamicsymbols + >>> N = ReferenceFrame('N') + >>> q = dynamicsymbols('q') + >>> damping = symbols('c') + >>> pA, pB = Point('pA'), Point('pB') + >>> pA.set_vel(N, 0) + >>> pB.set_pos(pA, q*N.x) + >>> pB.pos_from(pA) + q(t)*N.x + >>> pB.vel(N) + Derivative(q(t), t)*N.x + >>> linear_pathway = LinearPathway(pA, pB) + >>> damper = LinearDamper(damping, linear_pathway) + >>> damper + LinearDamper(c, LinearPathway(pA, pB)) + + This damper will produce a force that is proportional to both its damping + coefficient and the pathway's extension length. Note that this force is + negative as SymPy's sign convention for actuators is that negative forces + are contractile and the damping force of the damper will oppose the + direction of length change. + + >>> damper.force + -c*sqrt(q(t)**2)*Derivative(q(t), t)/q(t) + + Parameters + ========== + + damping : Expr + The damping constant. + pathway : PathwayBase + The pathway that the actuator follows. This must be an instance of a + concrete subclass of ``PathwayBase``, e.g. ``LinearPathway``. + + See Also + ======== + + ForceActuator: force-producing actuator (superclass of ``LinearDamper``). + LinearPathway: straight-line pathway between a pair of points. + + """ + + def __init__(self, damping, pathway): + """Initializer for ``LinearDamper``. + + Parameters + ========== + + damping : Expr + The damping constant. + pathway : PathwayBase + The pathway that the actuator follows. This must be an instance of + a concrete subclass of ``PathwayBase``, e.g. ``LinearPathway``. + + """ + self.damping = damping + self.pathway = pathway + + @property + def force(self): + """The damping force produced by the linear damper.""" + return -self.damping*self.pathway.extension_velocity + + @force.setter + def force(self, force): + raise AttributeError('Can\'t set computed attribute `force`.') + + @property + def damping(self): + """The damping constant for the linear damper.""" + return self._damping + + @damping.setter + def damping(self, damping): + if hasattr(self, '_damping'): + msg = ( + f'Can\'t set attribute `damping` to {repr(damping)} as it is ' + f'immutable.' + ) + raise AttributeError(msg) + self._damping = sympify(damping, strict=True) + + def __repr__(self): + """Representation of a ``LinearDamper``.""" + return f'{self.__class__.__name__}({self.damping}, {self.pathway})' + + +class TorqueActuator(ActuatorBase): + """Torque-producing actuator. + + Explanation + =========== + + A ``TorqueActuator`` is an actuator that produces a pair of equal and + opposite torques on a pair of bodies. + + Examples + ======== + + To construct a torque actuator, an expression (or symbol) must be supplied + to represent the torque it can produce, alongside a vector specifying the + axis about which the torque will act, and a pair of frames on which the + torque will act. + + >>> from sympy import symbols + >>> from sympy.physics.mechanics import (ReferenceFrame, RigidBody, + ... TorqueActuator) + >>> N = ReferenceFrame('N') + >>> A = ReferenceFrame('A') + >>> torque = symbols('T') + >>> axis = N.z + >>> parent = RigidBody('parent', frame=N) + >>> child = RigidBody('child', frame=A) + >>> bodies = (child, parent) + >>> actuator = TorqueActuator(torque, axis, *bodies) + >>> actuator + TorqueActuator(T, axis=N.z, target_frame=A, reaction_frame=N) + + Note that because torques actually act on frames, not bodies, + ``TorqueActuator`` will extract the frame associated with a ``RigidBody`` + when one is passed instead of a ``ReferenceFrame``. + + Parameters + ========== + + torque : Expr + The scalar expression defining the torque that the actuator produces. + axis : Vector + The axis about which the actuator applies torques. + target_frame : ReferenceFrame | RigidBody + The primary frame on which the actuator will apply the torque. + reaction_frame : ReferenceFrame | RigidBody | None + The secondary frame on which the actuator will apply the torque. Note + that the (equal and opposite) reaction torque is applied to this frame. + + """ + + def __init__(self, torque, axis, target_frame, reaction_frame=None): + """Initializer for ``TorqueActuator``. + + Parameters + ========== + + torque : Expr + The scalar expression defining the torque that the actuator + produces. + axis : Vector + The axis about which the actuator applies torques. + target_frame : ReferenceFrame | RigidBody + The primary frame on which the actuator will apply the torque. + reaction_frame : ReferenceFrame | RigidBody | None + The secondary frame on which the actuator will apply the torque. + Note that the (equal and opposite) reaction torque is applied to + this frame. + + """ + self.torque = torque + self.axis = axis + self.target_frame = target_frame + self.reaction_frame = reaction_frame + + @classmethod + def at_pin_joint(cls, torque, pin_joint): + """Alternate construtor to instantiate from a ``PinJoint`` instance. + + Examples + ======== + + To create a pin joint the ``PinJoint`` class requires a name, parent + body, and child body to be passed to its constructor. It is also + possible to control the joint axis using the ``joint_axis`` keyword + argument. In this example let's use the parent body's reference frame's + z-axis as the joint axis. + + >>> from sympy.physics.mechanics import (PinJoint, ReferenceFrame, + ... RigidBody, TorqueActuator) + >>> N = ReferenceFrame('N') + >>> A = ReferenceFrame('A') + >>> parent = RigidBody('parent', frame=N) + >>> child = RigidBody('child', frame=A) + >>> pin_joint = PinJoint( + ... 'pin', + ... parent, + ... child, + ... joint_axis=N.z, + ... ) + + Let's also create a symbol ``T`` that will represent the torque applied + by the torque actuator. + + >>> from sympy import symbols + >>> torque = symbols('T') + + To create the torque actuator from the ``torque`` and ``pin_joint`` + variables previously instantiated, these can be passed to the alternate + constructor class method ``at_pin_joint`` of the ``TorqueActuator`` + class. It should be noted that a positive torque will cause a positive + displacement of the joint coordinate or that the torque is applied on + the child body with a reaction torque on the parent. + + >>> actuator = TorqueActuator.at_pin_joint(torque, pin_joint) + >>> actuator + TorqueActuator(T, axis=N.z, target_frame=A, reaction_frame=N) + + Parameters + ========== + + torque : Expr + The scalar expression defining the torque that the actuator + produces. + pin_joint : PinJoint + The pin joint, and by association the parent and child bodies, on + which the torque actuator will act. The pair of bodies acted upon + by the torque actuator are the parent and child bodies of the pin + joint, with the child acting as the reaction body. The pin joint's + axis is used as the axis about which the torque actuator will apply + its torque. + + """ + if not isinstance(pin_joint, PinJoint): + msg = ( + f'Value {repr(pin_joint)} passed to `pin_joint` was of type ' + f'{type(pin_joint)}, must be {PinJoint}.' + ) + raise TypeError(msg) + return cls( + torque, + pin_joint.joint_axis, + pin_joint.child_interframe, + pin_joint.parent_interframe, + ) + + @property + def torque(self): + """The magnitude of the torque produced by the actuator.""" + return self._torque + + @torque.setter + def torque(self, torque): + if hasattr(self, '_torque'): + msg = ( + f'Can\'t set attribute `torque` to {repr(torque)} as it is ' + f'immutable.' + ) + raise AttributeError(msg) + self._torque = sympify(torque, strict=True) + + @property + def axis(self): + """The axis about which the torque acts.""" + return self._axis + + @axis.setter + def axis(self, axis): + if hasattr(self, '_axis'): + msg = ( + f'Can\'t set attribute `axis` to {repr(axis)} as it is ' + f'immutable.' + ) + raise AttributeError(msg) + if not isinstance(axis, Vector): + msg = ( + f'Value {repr(axis)} passed to `axis` was of type ' + f'{type(axis)}, must be {Vector}.' + ) + raise TypeError(msg) + self._axis = axis + + @property + def target_frame(self): + """The primary reference frames on which the torque will act.""" + return self._target_frame + + @target_frame.setter + def target_frame(self, target_frame): + if hasattr(self, '_target_frame'): + msg = ( + f'Can\'t set attribute `target_frame` to {repr(target_frame)} ' + f'as it is immutable.' + ) + raise AttributeError(msg) + if isinstance(target_frame, RigidBody): + target_frame = target_frame.frame + elif not isinstance(target_frame, ReferenceFrame): + msg = ( + f'Value {repr(target_frame)} passed to `target_frame` was of ' + f'type {type(target_frame)}, must be {ReferenceFrame}.' + ) + raise TypeError(msg) + self._target_frame = target_frame + + @property + def reaction_frame(self): + """The primary reference frames on which the torque will act.""" + return self._reaction_frame + + @reaction_frame.setter + def reaction_frame(self, reaction_frame): + if hasattr(self, '_reaction_frame'): + msg = ( + f'Can\'t set attribute `reaction_frame` to ' + f'{repr(reaction_frame)} as it is immutable.' + ) + raise AttributeError(msg) + if isinstance(reaction_frame, RigidBody): + reaction_frame = reaction_frame.frame + elif ( + not isinstance(reaction_frame, ReferenceFrame) + and reaction_frame is not None + ): + msg = ( + f'Value {repr(reaction_frame)} passed to `reaction_frame` was ' + f'of type {type(reaction_frame)}, must be {ReferenceFrame}.' + ) + raise TypeError(msg) + self._reaction_frame = reaction_frame + + def to_loads(self): + """Loads required by the equations of motion method classes. + + Explanation + =========== + + ``KanesMethod`` requires a list of ``Point``-``Vector`` tuples to be + passed to the ``loads`` parameters of its ``kanes_equations`` method + when constructing the equations of motion. This method acts as a + utility to produce the correctly-structred pairs of points and vectors + required so that these can be easily concatenated with other items in + the list of loads and passed to ``KanesMethod.kanes_equations``. These + loads are also in the correct form to also be passed to the other + equations of motion method classes, e.g. ``LagrangesMethod``. + + Examples + ======== + + The below example shows how to generate the loads produced by a torque + actuator that acts on a pair of bodies attached by a pin joint. + + >>> from sympy import symbols + >>> from sympy.physics.mechanics import (PinJoint, ReferenceFrame, + ... RigidBody, TorqueActuator) + >>> torque = symbols('T') + >>> N = ReferenceFrame('N') + >>> A = ReferenceFrame('A') + >>> parent = RigidBody('parent', frame=N) + >>> child = RigidBody('child', frame=A) + >>> pin_joint = PinJoint( + ... 'pin', + ... parent, + ... child, + ... joint_axis=N.z, + ... ) + >>> actuator = TorqueActuator.at_pin_joint(torque, pin_joint) + + The forces produces by the damper can be generated by calling the + ``to_loads`` method. + + >>> actuator.to_loads() + [(A, T*N.z), (N, - T*N.z)] + + Alternatively, if a torque actuator is created without a reaction frame + then the loads returned by the ``to_loads`` method will contain just + the single load acting on the target frame. + + >>> actuator = TorqueActuator(torque, N.z, N) + >>> actuator.to_loads() + [(N, T*N.z)] + + """ + loads = [ + Torque(self.target_frame, self.torque*self.axis), + ] + if self.reaction_frame is not None: + loads.append(Torque(self.reaction_frame, -self.torque*self.axis)) + return loads + + def __repr__(self): + """Representation of a ``TorqueActuator``.""" + string = ( + f'{self.__class__.__name__}({self.torque}, axis={self.axis}, ' + f'target_frame={self.target_frame}' + ) + if self.reaction_frame is not None: + string += f', reaction_frame={self.reaction_frame})' + else: + string += ')' + return string + + +class DuffingSpring(ForceActuator): + """A nonlinear spring based on the Duffing equation. + + Explanation + =========== + + Here, ``DuffingSpring`` represents the force exerted by a nonlinear spring based on the Duffing equation: + F = -beta*x-alpha*x**3, where x is the displacement from the equilibrium position, beta is the linear spring constant, + and alpha is the coefficient for the nonlinear cubic term. + + Parameters + ========== + + linear_stiffness : Expr + The linear stiffness coefficient (beta). + nonlinear_stiffness : Expr + The nonlinear stiffness coefficient (alpha). + pathway : PathwayBase + The pathway that the actuator follows. + equilibrium_length : Expr, optional + The length at which the spring is in equilibrium (x). + """ + + def __init__(self, linear_stiffness, nonlinear_stiffness, pathway, equilibrium_length=S.Zero): + self.linear_stiffness = sympify(linear_stiffness, strict=True) + self.nonlinear_stiffness = sympify(nonlinear_stiffness, strict=True) + self.equilibrium_length = sympify(equilibrium_length, strict=True) + + if not isinstance(pathway, PathwayBase): + raise TypeError("pathway must be an instance of PathwayBase.") + self._pathway = pathway + + @property + def linear_stiffness(self): + return self._linear_stiffness + + @linear_stiffness.setter + def linear_stiffness(self, linear_stiffness): + if hasattr(self, '_linear_stiffness'): + msg = ( + f'Can\'t set attribute `linear_stiffness` to ' + f'{repr(linear_stiffness)} as it is immutable.' + ) + raise AttributeError(msg) + self._linear_stiffness = sympify(linear_stiffness, strict=True) + + @property + def nonlinear_stiffness(self): + return self._nonlinear_stiffness + + @nonlinear_stiffness.setter + def nonlinear_stiffness(self, nonlinear_stiffness): + if hasattr(self, '_nonlinear_stiffness'): + msg = ( + f'Can\'t set attribute `nonlinear_stiffness` to ' + f'{repr(nonlinear_stiffness)} as it is immutable.' + ) + raise AttributeError(msg) + self._nonlinear_stiffness = sympify(nonlinear_stiffness, strict=True) + + @property + def pathway(self): + return self._pathway + + @pathway.setter + def pathway(self, pathway): + if hasattr(self, '_pathway'): + msg = ( + f'Can\'t set attribute `pathway` to {repr(pathway)} as it is ' + f'immutable.' + ) + raise AttributeError(msg) + if not isinstance(pathway, PathwayBase): + msg = ( + f'Value {repr(pathway)} passed to `pathway` was of type ' + f'{type(pathway)}, must be {PathwayBase}.' + ) + raise TypeError(msg) + self._pathway = pathway + + @property + def equilibrium_length(self): + return self._equilibrium_length + + @equilibrium_length.setter + def equilibrium_length(self, equilibrium_length): + if hasattr(self, '_equilibrium_length'): + msg = ( + f'Can\'t set attribute `equilibrium_length` to ' + f'{repr(equilibrium_length)} as it is immutable.' + ) + raise AttributeError(msg) + self._equilibrium_length = sympify(equilibrium_length, strict=True) + + @property + def force(self): + """The force produced by the Duffing spring.""" + displacement = self.pathway.length - self.equilibrium_length + return -self.linear_stiffness * displacement - self.nonlinear_stiffness * displacement**3 + + @force.setter + def force(self, force): + if hasattr(self, '_force'): + msg = ( + f'Can\'t set attribute `force` to {repr(force)} as it is ' + f'immutable.' + ) + raise AttributeError(msg) + self._force = sympify(force, strict=True) + + def __repr__(self): + return (f"{self.__class__.__name__}(" + f"{self.linear_stiffness}, {self.nonlinear_stiffness}, {self.pathway}, " + f"equilibrium_length={self.equilibrium_length})") diff --git a/lib/python3.10/site-packages/sympy/physics/mechanics/body.py b/lib/python3.10/site-packages/sympy/physics/mechanics/body.py new file mode 100644 index 0000000000000000000000000000000000000000..efc367158bbf51e7d9929318ac9286ba5c3fb3ac --- /dev/null +++ b/lib/python3.10/site-packages/sympy/physics/mechanics/body.py @@ -0,0 +1,710 @@ +from sympy import Symbol +from sympy.physics.vector import Point, Vector, ReferenceFrame, Dyadic +from sympy.physics.mechanics import RigidBody, Particle, Inertia +from sympy.physics.mechanics.body_base import BodyBase +from sympy.utilities.exceptions import sympy_deprecation_warning + +__all__ = ['Body'] + + +# XXX: We use type:ignore because the classes RigidBody and Particle have +# inconsistent parallel axis methods that take different numbers of arguments. +class Body(RigidBody, Particle): # type: ignore + """ + Body is a common representation of either a RigidBody or a Particle SymPy + object depending on what is passed in during initialization. If a mass is + passed in and central_inertia is left as None, the Particle object is + created. Otherwise a RigidBody object will be created. + + .. deprecated:: 1.13 + The Body class is deprecated. Its functionality is captured by + :class:`~.RigidBody` and :class:`~.Particle`. + + Explanation + =========== + + The attributes that Body possesses will be the same as a Particle instance + or a Rigid Body instance depending on which was created. Additional + attributes are listed below. + + Attributes + ========== + + name : string + The body's name + masscenter : Point + The point which represents the center of mass of the rigid body + frame : ReferenceFrame + The reference frame which the body is fixed in + mass : Sympifyable + The body's mass + inertia : (Dyadic, Point) + The body's inertia around its center of mass. This attribute is specific + to the rigid body form of Body and is left undefined for the Particle + form + loads : iterable + This list contains information on the different loads acting on the + Body. Forces are listed as a (point, vector) tuple and torques are + listed as (reference frame, vector) tuples. + + Parameters + ========== + + name : String + Defines the name of the body. It is used as the base for defining + body specific properties. + masscenter : Point, optional + A point that represents the center of mass of the body or particle. + If no point is given, a point is generated. + mass : Sympifyable, optional + A Sympifyable object which represents the mass of the body. If no + mass is passed, one is generated. + frame : ReferenceFrame, optional + The ReferenceFrame that represents the reference frame of the body. + If no frame is given, a frame is generated. + central_inertia : Dyadic, optional + Central inertia dyadic of the body. If none is passed while creating + RigidBody, a default inertia is generated. + + Examples + ======== + + As Body has been deprecated, the following examples are for illustrative + purposes only. The functionality of Body is fully captured by + :class:`~.RigidBody` and :class:`~.Particle`. To ignore the deprecation + warning we can use the ignore_warnings context manager. + + >>> from sympy.utilities.exceptions import ignore_warnings + + Default behaviour. This results in the creation of a RigidBody object for + which the mass, mass center, frame and inertia attributes are given default + values. :: + + >>> from sympy.physics.mechanics import Body + >>> with ignore_warnings(DeprecationWarning): + ... body = Body('name_of_body') + + This next example demonstrates the code required to specify all of the + values of the Body object. Note this will also create a RigidBody version of + the Body object. :: + + >>> from sympy import Symbol + >>> from sympy.physics.mechanics import ReferenceFrame, Point, inertia + >>> from sympy.physics.mechanics import Body + >>> mass = Symbol('mass') + >>> masscenter = Point('masscenter') + >>> frame = ReferenceFrame('frame') + >>> ixx = Symbol('ixx') + >>> body_inertia = inertia(frame, ixx, 0, 0) + >>> with ignore_warnings(DeprecationWarning): + ... body = Body('name_of_body', masscenter, mass, frame, body_inertia) + + The minimal code required to create a Particle version of the Body object + involves simply passing in a name and a mass. :: + + >>> from sympy import Symbol + >>> from sympy.physics.mechanics import Body + >>> mass = Symbol('mass') + >>> with ignore_warnings(DeprecationWarning): + ... body = Body('name_of_body', mass=mass) + + The Particle version of the Body object can also receive a masscenter point + and a reference frame, just not an inertia. + """ + + def __init__(self, name, masscenter=None, mass=None, frame=None, + central_inertia=None): + sympy_deprecation_warning( + """ + Support for the Body class has been removed, as its functionality is + fully captured by RigidBody and Particle. + """, + deprecated_since_version="1.13", + active_deprecations_target="deprecated-mechanics-body-class" + ) + + self._loads = [] + + if frame is None: + frame = ReferenceFrame(name + '_frame') + + if masscenter is None: + masscenter = Point(name + '_masscenter') + + if central_inertia is None and mass is None: + ixx = Symbol(name + '_ixx') + iyy = Symbol(name + '_iyy') + izz = Symbol(name + '_izz') + izx = Symbol(name + '_izx') + ixy = Symbol(name + '_ixy') + iyz = Symbol(name + '_iyz') + _inertia = Inertia.from_inertia_scalars(masscenter, frame, ixx, iyy, + izz, ixy, iyz, izx) + else: + _inertia = (central_inertia, masscenter) + + if mass is None: + _mass = Symbol(name + '_mass') + else: + _mass = mass + + masscenter.set_vel(frame, 0) + + # If user passes masscenter and mass then a particle is created + # otherwise a rigidbody. As a result a body may or may not have inertia. + # Note: BodyBase.__init__ is used to prevent problems with super() calls in + # Particle and RigidBody arising due to multiple inheritance. + if central_inertia is None and mass is not None: + BodyBase.__init__(self, name, masscenter, _mass) + self.frame = frame + self._central_inertia = Dyadic(0) + else: + BodyBase.__init__(self, name, masscenter, _mass) + self.frame = frame + self.inertia = _inertia + + def __repr__(self): + if self.is_rigidbody: + return RigidBody.__repr__(self) + return Particle.__repr__(self) + + @property + def loads(self): + return self._loads + + @property + def x(self): + """The basis Vector for the Body, in the x direction.""" + return self.frame.x + + @property + def y(self): + """The basis Vector for the Body, in the y direction.""" + return self.frame.y + + @property + def z(self): + """The basis Vector for the Body, in the z direction.""" + return self.frame.z + + @property + def inertia(self): + """The body's inertia about a point; stored as (Dyadic, Point).""" + if self.is_rigidbody: + return RigidBody.inertia.fget(self) + return (self.central_inertia, self.masscenter) + + @inertia.setter + def inertia(self, I): + RigidBody.inertia.fset(self, I) + + @property + def is_rigidbody(self): + if hasattr(self, '_inertia'): + return True + return False + + def kinetic_energy(self, frame): + """Kinetic energy of the body. + + Parameters + ========== + + frame : ReferenceFrame or Body + The Body's angular velocity and the velocity of it's mass + center are typically defined with respect to an inertial frame but + any relevant frame in which the velocities are known can be supplied. + + Examples + ======== + + As Body has been deprecated, the following examples are for illustrative + purposes only. The functionality of Body is fully captured by + :class:`~.RigidBody` and :class:`~.Particle`. To ignore the deprecation + warning we can use the ignore_warnings context manager. + + >>> from sympy.utilities.exceptions import ignore_warnings + >>> from sympy.physics.mechanics import Body, ReferenceFrame, Point + >>> from sympy import symbols + >>> m, v, r, omega = symbols('m v r omega') + >>> N = ReferenceFrame('N') + >>> O = Point('O') + >>> with ignore_warnings(DeprecationWarning): + ... P = Body('P', masscenter=O, mass=m) + >>> P.masscenter.set_vel(N, v * N.y) + >>> P.kinetic_energy(N) + m*v**2/2 + + >>> N = ReferenceFrame('N') + >>> b = ReferenceFrame('b') + >>> b.set_ang_vel(N, omega * b.x) + >>> P = Point('P') + >>> P.set_vel(N, v * N.x) + >>> with ignore_warnings(DeprecationWarning): + ... B = Body('B', masscenter=P, frame=b) + >>> B.kinetic_energy(N) + B_ixx*omega**2/2 + B_mass*v**2/2 + + See Also + ======== + + sympy.physics.mechanics : Particle, RigidBody + + """ + if isinstance(frame, Body): + frame = Body.frame + if self.is_rigidbody: + return RigidBody(self.name, self.masscenter, self.frame, self.mass, + (self.central_inertia, self.masscenter)).kinetic_energy(frame) + return Particle(self.name, self.masscenter, self.mass).kinetic_energy(frame) + + def apply_force(self, force, point=None, reaction_body=None, reaction_point=None): + """Add force to the body(s). + + Explanation + =========== + + Applies the force on self or equal and opposite forces on + self and other body if both are given on the desired point on the bodies. + The force applied on other body is taken opposite of self, i.e, -force. + + Parameters + ========== + + force: Vector + The force to be applied. + point: Point, optional + The point on self on which force is applied. + By default self's masscenter. + reaction_body: Body, optional + Second body on which equal and opposite force + is to be applied. + reaction_point : Point, optional + The point on other body on which equal and opposite + force is applied. By default masscenter of other body. + + Example + ======= + + As Body has been deprecated, the following examples are for illustrative + purposes only. The functionality of Body is fully captured by + :class:`~.RigidBody` and :class:`~.Particle`. To ignore the deprecation + warning we can use the ignore_warnings context manager. + + >>> from sympy.utilities.exceptions import ignore_warnings + >>> from sympy import symbols + >>> from sympy.physics.mechanics import Body, Point, dynamicsymbols + >>> m, g = symbols('m g') + >>> with ignore_warnings(DeprecationWarning): + ... B = Body('B') + >>> force1 = m*g*B.z + >>> B.apply_force(force1) #Applying force on B's masscenter + >>> B.loads + [(B_masscenter, g*m*B_frame.z)] + + We can also remove some part of force from any point on the body by + adding the opposite force to the body on that point. + + >>> f1, f2 = dynamicsymbols('f1 f2') + >>> P = Point('P') #Considering point P on body B + >>> B.apply_force(f1*B.x + f2*B.y, P) + >>> B.loads + [(B_masscenter, g*m*B_frame.z), (P, f1(t)*B_frame.x + f2(t)*B_frame.y)] + + Let's remove f1 from point P on body B. + + >>> B.apply_force(-f1*B.x, P) + >>> B.loads + [(B_masscenter, g*m*B_frame.z), (P, f2(t)*B_frame.y)] + + To further demonstrate the use of ``apply_force`` attribute, + consider two bodies connected through a spring. + + >>> from sympy.physics.mechanics import Body, dynamicsymbols + >>> with ignore_warnings(DeprecationWarning): + ... N = Body('N') #Newtonion Frame + >>> x = dynamicsymbols('x') + >>> with ignore_warnings(DeprecationWarning): + ... B1 = Body('B1') + ... B2 = Body('B2') + >>> spring_force = x*N.x + + Now let's apply equal and opposite spring force to the bodies. + + >>> P1 = Point('P1') + >>> P2 = Point('P2') + >>> B1.apply_force(spring_force, point=P1, reaction_body=B2, reaction_point=P2) + + We can check the loads(forces) applied to bodies now. + + >>> B1.loads + [(P1, x(t)*N_frame.x)] + >>> B2.loads + [(P2, - x(t)*N_frame.x)] + + Notes + ===== + + If a new force is applied to a body on a point which already has some + force applied on it, then the new force is added to the already applied + force on that point. + + """ + + if not isinstance(point, Point): + if point is None: + point = self.masscenter # masscenter + else: + raise TypeError("Force must be applied to a point on the body.") + if not isinstance(force, Vector): + raise TypeError("Force must be a vector.") + + if reaction_body is not None: + reaction_body.apply_force(-force, point=reaction_point) + + for load in self._loads: + if point in load: + force += load[1] + self._loads.remove(load) + break + + self._loads.append((point, force)) + + def apply_torque(self, torque, reaction_body=None): + """Add torque to the body(s). + + Explanation + =========== + + Applies the torque on self or equal and opposite torques on + self and other body if both are given. + The torque applied on other body is taken opposite of self, + i.e, -torque. + + Parameters + ========== + + torque: Vector + The torque to be applied. + reaction_body: Body, optional + Second body on which equal and opposite torque + is to be applied. + + Example + ======= + + As Body has been deprecated, the following examples are for illustrative + purposes only. The functionality of Body is fully captured by + :class:`~.RigidBody` and :class:`~.Particle`. To ignore the deprecation + warning we can use the ignore_warnings context manager. + + >>> from sympy.utilities.exceptions import ignore_warnings + >>> from sympy import symbols + >>> from sympy.physics.mechanics import Body, dynamicsymbols + >>> t = symbols('t') + >>> with ignore_warnings(DeprecationWarning): + ... B = Body('B') + >>> torque1 = t*B.z + >>> B.apply_torque(torque1) + >>> B.loads + [(B_frame, t*B_frame.z)] + + We can also remove some part of torque from the body by + adding the opposite torque to the body. + + >>> t1, t2 = dynamicsymbols('t1 t2') + >>> B.apply_torque(t1*B.x + t2*B.y) + >>> B.loads + [(B_frame, t1(t)*B_frame.x + t2(t)*B_frame.y + t*B_frame.z)] + + Let's remove t1 from Body B. + + >>> B.apply_torque(-t1*B.x) + >>> B.loads + [(B_frame, t2(t)*B_frame.y + t*B_frame.z)] + + To further demonstrate the use, let us consider two bodies such that + a torque `T` is acting on one body, and `-T` on the other. + + >>> from sympy.physics.mechanics import Body, dynamicsymbols + >>> with ignore_warnings(DeprecationWarning): + ... N = Body('N') #Newtonion frame + ... B1 = Body('B1') + ... B2 = Body('B2') + >>> v = dynamicsymbols('v') + >>> T = v*N.y #Torque + + Now let's apply equal and opposite torque to the bodies. + + >>> B1.apply_torque(T, B2) + + We can check the loads (torques) applied to bodies now. + + >>> B1.loads + [(B1_frame, v(t)*N_frame.y)] + >>> B2.loads + [(B2_frame, - v(t)*N_frame.y)] + + Notes + ===== + + If a new torque is applied on body which already has some torque applied on it, + then the new torque is added to the previous torque about the body's frame. + + """ + + if not isinstance(torque, Vector): + raise TypeError("A Vector must be supplied to add torque.") + + if reaction_body is not None: + reaction_body.apply_torque(-torque) + + for load in self._loads: + if self.frame in load: + torque += load[1] + self._loads.remove(load) + break + self._loads.append((self.frame, torque)) + + def clear_loads(self): + """ + Clears the Body's loads list. + + Example + ======= + + As Body has been deprecated, the following examples are for illustrative + purposes only. The functionality of Body is fully captured by + :class:`~.RigidBody` and :class:`~.Particle`. To ignore the deprecation + warning we can use the ignore_warnings context manager. + + >>> from sympy.utilities.exceptions import ignore_warnings + >>> from sympy.physics.mechanics import Body + >>> with ignore_warnings(DeprecationWarning): + ... B = Body('B') + >>> force = B.x + B.y + >>> B.apply_force(force) + >>> B.loads + [(B_masscenter, B_frame.x + B_frame.y)] + >>> B.clear_loads() + >>> B.loads + [] + + """ + + self._loads = [] + + def remove_load(self, about=None): + """ + Remove load about a point or frame. + + Parameters + ========== + + about : Point or ReferenceFrame, optional + The point about which force is applied, + and is to be removed. + If about is None, then the torque about + self's frame is removed. + + Example + ======= + + As Body has been deprecated, the following examples are for illustrative + purposes only. The functionality of Body is fully captured by + :class:`~.RigidBody` and :class:`~.Particle`. To ignore the deprecation + warning we can use the ignore_warnings context manager. + + >>> from sympy.utilities.exceptions import ignore_warnings + >>> from sympy.physics.mechanics import Body, Point + >>> with ignore_warnings(DeprecationWarning): + ... B = Body('B') + >>> P = Point('P') + >>> f1 = B.x + >>> f2 = B.y + >>> B.apply_force(f1) + >>> B.apply_force(f2, P) + >>> B.loads + [(B_masscenter, B_frame.x), (P, B_frame.y)] + + >>> B.remove_load(P) + >>> B.loads + [(B_masscenter, B_frame.x)] + + """ + + if about is not None: + if not isinstance(about, Point): + raise TypeError('Load is applied about Point or ReferenceFrame.') + else: + about = self.frame + + for load in self._loads: + if about in load: + self._loads.remove(load) + break + + def masscenter_vel(self, body): + """ + Returns the velocity of the mass center with respect to the provided + rigid body or reference frame. + + Parameters + ========== + + body: Body or ReferenceFrame + The rigid body or reference frame to calculate the velocity in. + + Example + ======= + + As Body has been deprecated, the following examples are for illustrative + purposes only. The functionality of Body is fully captured by + :class:`~.RigidBody` and :class:`~.Particle`. To ignore the deprecation + warning we can use the ignore_warnings context manager. + + >>> from sympy.utilities.exceptions import ignore_warnings + >>> from sympy.physics.mechanics import Body + >>> with ignore_warnings(DeprecationWarning): + ... A = Body('A') + ... B = Body('B') + >>> A.masscenter.set_vel(B.frame, 5*B.frame.x) + >>> A.masscenter_vel(B) + 5*B_frame.x + >>> A.masscenter_vel(B.frame) + 5*B_frame.x + + """ + + if isinstance(body, ReferenceFrame): + frame=body + elif isinstance(body, Body): + frame = body.frame + return self.masscenter.vel(frame) + + def ang_vel_in(self, body): + """ + Returns this body's angular velocity with respect to the provided + rigid body or reference frame. + + Parameters + ========== + + body: Body or ReferenceFrame + The rigid body or reference frame to calculate the angular velocity in. + + Example + ======= + + As Body has been deprecated, the following examples are for illustrative + purposes only. The functionality of Body is fully captured by + :class:`~.RigidBody` and :class:`~.Particle`. To ignore the deprecation + warning we can use the ignore_warnings context manager. + + >>> from sympy.utilities.exceptions import ignore_warnings + >>> from sympy.physics.mechanics import Body, ReferenceFrame + >>> with ignore_warnings(DeprecationWarning): + ... A = Body('A') + >>> N = ReferenceFrame('N') + >>> with ignore_warnings(DeprecationWarning): + ... B = Body('B', frame=N) + >>> A.frame.set_ang_vel(N, 5*N.x) + >>> A.ang_vel_in(B) + 5*N.x + >>> A.ang_vel_in(N) + 5*N.x + + """ + + if isinstance(body, ReferenceFrame): + frame=body + elif isinstance(body, Body): + frame = body.frame + return self.frame.ang_vel_in(frame) + + def dcm(self, body): + """ + Returns the direction cosine matrix of this body relative to the + provided rigid body or reference frame. + + Parameters + ========== + + body: Body or ReferenceFrame + The rigid body or reference frame to calculate the dcm. + + Example + ======= + + As Body has been deprecated, the following examples are for illustrative + purposes only. The functionality of Body is fully captured by + :class:`~.RigidBody` and :class:`~.Particle`. To ignore the deprecation + warning we can use the ignore_warnings context manager. + + >>> from sympy.utilities.exceptions import ignore_warnings + >>> from sympy.physics.mechanics import Body + >>> with ignore_warnings(DeprecationWarning): + ... A = Body('A') + ... B = Body('B') + >>> A.frame.orient_axis(B.frame, B.frame.x, 5) + >>> A.dcm(B) + Matrix([ + [1, 0, 0], + [0, cos(5), sin(5)], + [0, -sin(5), cos(5)]]) + >>> A.dcm(B.frame) + Matrix([ + [1, 0, 0], + [0, cos(5), sin(5)], + [0, -sin(5), cos(5)]]) + + """ + + if isinstance(body, ReferenceFrame): + frame=body + elif isinstance(body, Body): + frame = body.frame + return self.frame.dcm(frame) + + def parallel_axis(self, point, frame=None): + """Returns the inertia dyadic of the body with respect to another + point. + + Parameters + ========== + + point : sympy.physics.vector.Point + The point to express the inertia dyadic about. + frame : sympy.physics.vector.ReferenceFrame + The reference frame used to construct the dyadic. + + Returns + ======= + + inertia : sympy.physics.vector.Dyadic + The inertia dyadic of the rigid body expressed about the provided + point. + + Example + ======= + + As Body has been deprecated, the following examples are for illustrative + purposes only. The functionality of Body is fully captured by + :class:`~.RigidBody` and :class:`~.Particle`. To ignore the deprecation + warning we can use the ignore_warnings context manager. + + >>> from sympy.utilities.exceptions import ignore_warnings + >>> from sympy.physics.mechanics import Body + >>> with ignore_warnings(DeprecationWarning): + ... A = Body('A') + >>> P = A.masscenter.locatenew('point', 3 * A.x + 5 * A.y) + >>> A.parallel_axis(P).to_matrix(A.frame) + Matrix([ + [A_ixx + 25*A_mass, A_ixy - 15*A_mass, A_izx], + [A_ixy - 15*A_mass, A_iyy + 9*A_mass, A_iyz], + [ A_izx, A_iyz, A_izz + 34*A_mass]]) + + """ + if self.is_rigidbody: + return RigidBody.parallel_axis(self, point, frame) + return Particle.parallel_axis(self, point, frame) diff --git a/lib/python3.10/site-packages/sympy/physics/mechanics/body_base.py b/lib/python3.10/site-packages/sympy/physics/mechanics/body_base.py new file mode 100644 index 0000000000000000000000000000000000000000..d2546faf685f579d2aea10ed7f139a4beced7dd0 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/physics/mechanics/body_base.py @@ -0,0 +1,94 @@ +from abc import ABC, abstractmethod +from sympy import Symbol, sympify +from sympy.physics.vector import Point + +__all__ = ['BodyBase'] + + +class BodyBase(ABC): + """Abstract class for body type objects.""" + def __init__(self, name, masscenter=None, mass=None): + # Note: If frame=None, no auto-generated frame is created, because a + # Particle does not need to have a frame by default. + if not isinstance(name, str): + raise TypeError('Supply a valid name.') + self._name = name + if mass is None: + mass = Symbol(f'{name}_mass') + if masscenter is None: + masscenter = Point(f'{name}_masscenter') + self.mass = mass + self.masscenter = masscenter + self.potential_energy = 0 + self.points = [] + + def __str__(self): + return self.name + + def __repr__(self): + return (f'{self.__class__.__name__}({repr(self.name)}, masscenter=' + f'{repr(self.masscenter)}, mass={repr(self.mass)})') + + @property + def name(self): + """The name of the body.""" + return self._name + + @property + def masscenter(self): + """The body's center of mass.""" + return self._masscenter + + @masscenter.setter + def masscenter(self, point): + if not isinstance(point, Point): + raise TypeError("The body's center of mass must be a Point object.") + self._masscenter = point + + @property + def mass(self): + """The body's mass.""" + return self._mass + + @mass.setter + def mass(self, mass): + self._mass = sympify(mass) + + @property + def potential_energy(self): + """The potential energy of the body. + + Examples + ======== + + >>> from sympy.physics.mechanics import Particle, Point + >>> from sympy import symbols + >>> m, g, h = symbols('m g h') + >>> O = Point('O') + >>> P = Particle('P', O, m) + >>> P.potential_energy = m * g * h + >>> P.potential_energy + g*h*m + + """ + return self._potential_energy + + @potential_energy.setter + def potential_energy(self, scalar): + self._potential_energy = sympify(scalar) + + @abstractmethod + def kinetic_energy(self, frame): + pass + + @abstractmethod + def linear_momentum(self, frame): + pass + + @abstractmethod + def angular_momentum(self, point, frame): + pass + + @abstractmethod + def parallel_axis(self, point, frame): + pass diff --git a/lib/python3.10/site-packages/sympy/physics/mechanics/functions.py b/lib/python3.10/site-packages/sympy/physics/mechanics/functions.py new file mode 100644 index 0000000000000000000000000000000000000000..42abe2b7fe608b4602cdab518f209b446b2dbe03 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/physics/mechanics/functions.py @@ -0,0 +1,735 @@ +from sympy.utilities import dict_merge +from sympy.utilities.iterables import iterable +from sympy.physics.vector import (Dyadic, Vector, ReferenceFrame, + Point, dynamicsymbols) +from sympy.physics.vector.printing import (vprint, vsprint, vpprint, vlatex, + init_vprinting) +from sympy.physics.mechanics.particle import Particle +from sympy.physics.mechanics.rigidbody import RigidBody +from sympy.simplify.simplify import simplify +from sympy import Matrix, Mul, Derivative, sin, cos, tan, S +from sympy.core.function import AppliedUndef +from sympy.physics.mechanics.inertia import (inertia as _inertia, + inertia_of_point_mass as _inertia_of_point_mass) +from sympy.utilities.exceptions import sympy_deprecation_warning + +__all__ = ['linear_momentum', + 'angular_momentum', + 'kinetic_energy', + 'potential_energy', + 'Lagrangian', + 'mechanics_printing', + 'mprint', + 'msprint', + 'mpprint', + 'mlatex', + 'msubs', + 'find_dynamicsymbols'] + +# These are functions that we've moved and renamed during extracting the +# basic vector calculus code from the mechanics packages. + +mprint = vprint +msprint = vsprint +mpprint = vpprint +mlatex = vlatex + + +def mechanics_printing(**kwargs): + """ + Initializes time derivative printing for all SymPy objects in + mechanics module. + """ + + init_vprinting(**kwargs) + +mechanics_printing.__doc__ = init_vprinting.__doc__ + + +def inertia(frame, ixx, iyy, izz, ixy=0, iyz=0, izx=0): + sympy_deprecation_warning( + """ + The inertia function has been moved. + Import it from "sympy.physics.mechanics". + """, + deprecated_since_version="1.13", + active_deprecations_target="moved-mechanics-functions" + ) + return _inertia(frame, ixx, iyy, izz, ixy, iyz, izx) + + +def inertia_of_point_mass(mass, pos_vec, frame): + sympy_deprecation_warning( + """ + The inertia_of_point_mass function has been moved. + Import it from "sympy.physics.mechanics". + """, + deprecated_since_version="1.13", + active_deprecations_target="moved-mechanics-functions" + ) + return _inertia_of_point_mass(mass, pos_vec, frame) + + +def linear_momentum(frame, *body): + """Linear momentum of the system. + + Explanation + =========== + + This function returns the linear momentum of a system of Particle's and/or + RigidBody's. The linear momentum of a system is equal to the vector sum of + the linear momentum of its constituents. Consider a system, S, comprised of + a rigid body, A, and a particle, P. The linear momentum of the system, L, + is equal to the vector sum of the linear momentum of the particle, L1, and + the linear momentum of the rigid body, L2, i.e. + + L = L1 + L2 + + Parameters + ========== + + frame : ReferenceFrame + The frame in which linear momentum is desired. + body1, body2, body3... : Particle and/or RigidBody + The body (or bodies) whose linear momentum is required. + + Examples + ======== + + >>> from sympy.physics.mechanics import Point, Particle, ReferenceFrame + >>> from sympy.physics.mechanics import RigidBody, outer, linear_momentum + >>> N = ReferenceFrame('N') + >>> P = Point('P') + >>> P.set_vel(N, 10 * N.x) + >>> Pa = Particle('Pa', P, 1) + >>> Ac = Point('Ac') + >>> Ac.set_vel(N, 25 * N.y) + >>> I = outer(N.x, N.x) + >>> A = RigidBody('A', Ac, N, 20, (I, Ac)) + >>> linear_momentum(N, A, Pa) + 10*N.x + 500*N.y + + """ + + if not isinstance(frame, ReferenceFrame): + raise TypeError('Please specify a valid ReferenceFrame') + else: + linear_momentum_sys = Vector(0) + for e in body: + if isinstance(e, (RigidBody, Particle)): + linear_momentum_sys += e.linear_momentum(frame) + else: + raise TypeError('*body must have only Particle or RigidBody') + return linear_momentum_sys + + +def angular_momentum(point, frame, *body): + """Angular momentum of a system. + + Explanation + =========== + + This function returns the angular momentum of a system of Particle's and/or + RigidBody's. The angular momentum of such a system is equal to the vector + sum of the angular momentum of its constituents. Consider a system, S, + comprised of a rigid body, A, and a particle, P. The angular momentum of + the system, H, is equal to the vector sum of the angular momentum of the + particle, H1, and the angular momentum of the rigid body, H2, i.e. + + H = H1 + H2 + + Parameters + ========== + + point : Point + The point about which angular momentum of the system is desired. + frame : ReferenceFrame + The frame in which angular momentum is desired. + body1, body2, body3... : Particle and/or RigidBody + The body (or bodies) whose angular momentum is required. + + Examples + ======== + + >>> from sympy.physics.mechanics import Point, Particle, ReferenceFrame + >>> from sympy.physics.mechanics import RigidBody, outer, angular_momentum + >>> N = ReferenceFrame('N') + >>> O = Point('O') + >>> O.set_vel(N, 0 * N.x) + >>> P = O.locatenew('P', 1 * N.x) + >>> P.set_vel(N, 10 * N.x) + >>> Pa = Particle('Pa', P, 1) + >>> Ac = O.locatenew('Ac', 2 * N.y) + >>> Ac.set_vel(N, 5 * N.y) + >>> a = ReferenceFrame('a') + >>> a.set_ang_vel(N, 10 * N.z) + >>> I = outer(N.z, N.z) + >>> A = RigidBody('A', Ac, a, 20, (I, Ac)) + >>> angular_momentum(O, N, Pa, A) + 10*N.z + + """ + + if not isinstance(frame, ReferenceFrame): + raise TypeError('Please enter a valid ReferenceFrame') + if not isinstance(point, Point): + raise TypeError('Please specify a valid Point') + else: + angular_momentum_sys = Vector(0) + for e in body: + if isinstance(e, (RigidBody, Particle)): + angular_momentum_sys += e.angular_momentum(point, frame) + else: + raise TypeError('*body must have only Particle or RigidBody') + return angular_momentum_sys + + +def kinetic_energy(frame, *body): + """Kinetic energy of a multibody system. + + Explanation + =========== + + This function returns the kinetic energy of a system of Particle's and/or + RigidBody's. The kinetic energy of such a system is equal to the sum of + the kinetic energies of its constituents. Consider a system, S, comprising + a rigid body, A, and a particle, P. The kinetic energy of the system, T, + is equal to the vector sum of the kinetic energy of the particle, T1, and + the kinetic energy of the rigid body, T2, i.e. + + T = T1 + T2 + + Kinetic energy is a scalar. + + Parameters + ========== + + frame : ReferenceFrame + The frame in which the velocity or angular velocity of the body is + defined. + body1, body2, body3... : Particle and/or RigidBody + The body (or bodies) whose kinetic energy is required. + + Examples + ======== + + >>> from sympy.physics.mechanics import Point, Particle, ReferenceFrame + >>> from sympy.physics.mechanics import RigidBody, outer, kinetic_energy + >>> N = ReferenceFrame('N') + >>> O = Point('O') + >>> O.set_vel(N, 0 * N.x) + >>> P = O.locatenew('P', 1 * N.x) + >>> P.set_vel(N, 10 * N.x) + >>> Pa = Particle('Pa', P, 1) + >>> Ac = O.locatenew('Ac', 2 * N.y) + >>> Ac.set_vel(N, 5 * N.y) + >>> a = ReferenceFrame('a') + >>> a.set_ang_vel(N, 10 * N.z) + >>> I = outer(N.z, N.z) + >>> A = RigidBody('A', Ac, a, 20, (I, Ac)) + >>> kinetic_energy(N, Pa, A) + 350 + + """ + + if not isinstance(frame, ReferenceFrame): + raise TypeError('Please enter a valid ReferenceFrame') + ke_sys = S.Zero + for e in body: + if isinstance(e, (RigidBody, Particle)): + ke_sys += e.kinetic_energy(frame) + else: + raise TypeError('*body must have only Particle or RigidBody') + return ke_sys + + +def potential_energy(*body): + """Potential energy of a multibody system. + + Explanation + =========== + + This function returns the potential energy of a system of Particle's and/or + RigidBody's. The potential energy of such a system is equal to the sum of + the potential energy of its constituents. Consider a system, S, comprising + a rigid body, A, and a particle, P. The potential energy of the system, V, + is equal to the vector sum of the potential energy of the particle, V1, and + the potential energy of the rigid body, V2, i.e. + + V = V1 + V2 + + Potential energy is a scalar. + + Parameters + ========== + + body1, body2, body3... : Particle and/or RigidBody + The body (or bodies) whose potential energy is required. + + Examples + ======== + + >>> from sympy.physics.mechanics import Point, Particle, ReferenceFrame + >>> from sympy.physics.mechanics import RigidBody, outer, potential_energy + >>> from sympy import symbols + >>> M, m, g, h = symbols('M m g h') + >>> N = ReferenceFrame('N') + >>> O = Point('O') + >>> O.set_vel(N, 0 * N.x) + >>> P = O.locatenew('P', 1 * N.x) + >>> Pa = Particle('Pa', P, m) + >>> Ac = O.locatenew('Ac', 2 * N.y) + >>> a = ReferenceFrame('a') + >>> I = outer(N.z, N.z) + >>> A = RigidBody('A', Ac, a, M, (I, Ac)) + >>> Pa.potential_energy = m * g * h + >>> A.potential_energy = M * g * h + >>> potential_energy(Pa, A) + M*g*h + g*h*m + + """ + + pe_sys = S.Zero + for e in body: + if isinstance(e, (RigidBody, Particle)): + pe_sys += e.potential_energy + else: + raise TypeError('*body must have only Particle or RigidBody') + return pe_sys + + +def gravity(acceleration, *bodies): + from sympy.physics.mechanics.loads import gravity as _gravity + sympy_deprecation_warning( + """ + The gravity function has been moved. + Import it from "sympy.physics.mechanics.loads". + """, + deprecated_since_version="1.13", + active_deprecations_target="moved-mechanics-functions" + ) + return _gravity(acceleration, *bodies) + + +def center_of_mass(point, *bodies): + """ + Returns the position vector from the given point to the center of mass + of the given bodies(particles or rigidbodies). + + Example + ======= + + >>> from sympy import symbols, S + >>> from sympy.physics.vector import Point + >>> from sympy.physics.mechanics import Particle, ReferenceFrame, RigidBody, outer + >>> from sympy.physics.mechanics.functions import center_of_mass + >>> a = ReferenceFrame('a') + >>> m = symbols('m', real=True) + >>> p1 = Particle('p1', Point('p1_pt'), S(1)) + >>> p2 = Particle('p2', Point('p2_pt'), S(2)) + >>> p3 = Particle('p3', Point('p3_pt'), S(3)) + >>> p4 = Particle('p4', Point('p4_pt'), m) + >>> b_f = ReferenceFrame('b_f') + >>> b_cm = Point('b_cm') + >>> mb = symbols('mb') + >>> b = RigidBody('b', b_cm, b_f, mb, (outer(b_f.x, b_f.x), b_cm)) + >>> p2.point.set_pos(p1.point, a.x) + >>> p3.point.set_pos(p1.point, a.x + a.y) + >>> p4.point.set_pos(p1.point, a.y) + >>> b.masscenter.set_pos(p1.point, a.y + a.z) + >>> point_o=Point('o') + >>> point_o.set_pos(p1.point, center_of_mass(p1.point, p1, p2, p3, p4, b)) + >>> expr = 5/(m + mb + 6)*a.x + (m + mb + 3)/(m + mb + 6)*a.y + mb/(m + mb + 6)*a.z + >>> point_o.pos_from(p1.point) + 5/(m + mb + 6)*a.x + (m + mb + 3)/(m + mb + 6)*a.y + mb/(m + mb + 6)*a.z + + """ + if not bodies: + raise TypeError("No bodies(instances of Particle or Rigidbody) were passed.") + + total_mass = 0 + vec = Vector(0) + for i in bodies: + total_mass += i.mass + + masscenter = getattr(i, 'masscenter', None) + if masscenter is None: + masscenter = i.point + vec += i.mass*masscenter.pos_from(point) + + return vec/total_mass + + +def Lagrangian(frame, *body): + """Lagrangian of a multibody system. + + Explanation + =========== + + This function returns the Lagrangian of a system of Particle's and/or + RigidBody's. The Lagrangian of such a system is equal to the difference + between the kinetic energies and potential energies of its constituents. If + T and V are the kinetic and potential energies of a system then it's + Lagrangian, L, is defined as + + L = T - V + + The Lagrangian is a scalar. + + Parameters + ========== + + frame : ReferenceFrame + The frame in which the velocity or angular velocity of the body is + defined to determine the kinetic energy. + + body1, body2, body3... : Particle and/or RigidBody + The body (or bodies) whose Lagrangian is required. + + Examples + ======== + + >>> from sympy.physics.mechanics import Point, Particle, ReferenceFrame + >>> from sympy.physics.mechanics import RigidBody, outer, Lagrangian + >>> from sympy import symbols + >>> M, m, g, h = symbols('M m g h') + >>> N = ReferenceFrame('N') + >>> O = Point('O') + >>> O.set_vel(N, 0 * N.x) + >>> P = O.locatenew('P', 1 * N.x) + >>> P.set_vel(N, 10 * N.x) + >>> Pa = Particle('Pa', P, 1) + >>> Ac = O.locatenew('Ac', 2 * N.y) + >>> Ac.set_vel(N, 5 * N.y) + >>> a = ReferenceFrame('a') + >>> a.set_ang_vel(N, 10 * N.z) + >>> I = outer(N.z, N.z) + >>> A = RigidBody('A', Ac, a, 20, (I, Ac)) + >>> Pa.potential_energy = m * g * h + >>> A.potential_energy = M * g * h + >>> Lagrangian(N, Pa, A) + -M*g*h - g*h*m + 350 + + """ + + if not isinstance(frame, ReferenceFrame): + raise TypeError('Please supply a valid ReferenceFrame') + for e in body: + if not isinstance(e, (RigidBody, Particle)): + raise TypeError('*body must have only Particle or RigidBody') + return kinetic_energy(frame, *body) - potential_energy(*body) + + +def find_dynamicsymbols(expression, exclude=None, reference_frame=None): + """Find all dynamicsymbols in expression. + + Explanation + =========== + + If the optional ``exclude`` kwarg is used, only dynamicsymbols + not in the iterable ``exclude`` are returned. + If we intend to apply this function on a vector, the optional + ``reference_frame`` is also used to inform about the corresponding frame + with respect to which the dynamic symbols of the given vector is to be + determined. + + Parameters + ========== + + expression : SymPy expression + + exclude : iterable of dynamicsymbols, optional + + reference_frame : ReferenceFrame, optional + The frame with respect to which the dynamic symbols of the + given vector is to be determined. + + Examples + ======== + + >>> from sympy.physics.mechanics import dynamicsymbols, find_dynamicsymbols + >>> from sympy.physics.mechanics import ReferenceFrame + >>> x, y = dynamicsymbols('x, y') + >>> expr = x + x.diff()*y + >>> find_dynamicsymbols(expr) + {x(t), y(t), Derivative(x(t), t)} + >>> find_dynamicsymbols(expr, exclude=[x, y]) + {Derivative(x(t), t)} + >>> a, b, c = dynamicsymbols('a, b, c') + >>> A = ReferenceFrame('A') + >>> v = a * A.x + b * A.y + c * A.z + >>> find_dynamicsymbols(v, reference_frame=A) + {a(t), b(t), c(t)} + + """ + t_set = {dynamicsymbols._t} + if exclude: + if iterable(exclude): + exclude_set = set(exclude) + else: + raise TypeError("exclude kwarg must be iterable") + else: + exclude_set = set() + if isinstance(expression, Vector): + if reference_frame is None: + raise ValueError("You must provide reference_frame when passing a " + "vector expression, got %s." % reference_frame) + else: + expression = expression.to_matrix(reference_frame) + return {i for i in expression.atoms(AppliedUndef, Derivative) if + i.free_symbols == t_set} - exclude_set + + +def msubs(expr, *sub_dicts, smart=False, **kwargs): + """A custom subs for use on expressions derived in physics.mechanics. + + Traverses the expression tree once, performing the subs found in sub_dicts. + Terms inside ``Derivative`` expressions are ignored: + + Examples + ======== + + >>> from sympy.physics.mechanics import dynamicsymbols, msubs + >>> x = dynamicsymbols('x') + >>> msubs(x.diff() + x, {x: 1}) + Derivative(x(t), t) + 1 + + Note that sub_dicts can be a single dictionary, or several dictionaries: + + >>> x, y, z = dynamicsymbols('x, y, z') + >>> sub1 = {x: 1, y: 2} + >>> sub2 = {z: 3, x.diff(): 4} + >>> msubs(x.diff() + x + y + z, sub1, sub2) + 10 + + If smart=True (default False), also checks for conditions that may result + in ``nan``, but if simplified would yield a valid expression. For example: + + >>> from sympy import sin, tan + >>> (sin(x)/tan(x)).subs(x, 0) + nan + >>> msubs(sin(x)/tan(x), {x: 0}, smart=True) + 1 + + It does this by first replacing all ``tan`` with ``sin/cos``. Then each + node is traversed. If the node is a fraction, subs is first evaluated on + the denominator. If this results in 0, simplification of the entire + fraction is attempted. Using this selective simplification, only + subexpressions that result in 1/0 are targeted, resulting in faster + performance. + + """ + + sub_dict = dict_merge(*sub_dicts) + if smart: + func = _smart_subs + elif hasattr(expr, 'msubs'): + return expr.msubs(sub_dict) + else: + func = lambda expr, sub_dict: _crawl(expr, _sub_func, sub_dict) + if isinstance(expr, (Matrix, Vector, Dyadic)): + return expr.applyfunc(lambda x: func(x, sub_dict)) + else: + return func(expr, sub_dict) + + +def _crawl(expr, func, *args, **kwargs): + """Crawl the expression tree, and apply func to every node.""" + val = func(expr, *args, **kwargs) + if val is not None: + return val + new_args = (_crawl(arg, func, *args, **kwargs) for arg in expr.args) + return expr.func(*new_args) + + +def _sub_func(expr, sub_dict): + """Perform direct matching substitution, ignoring derivatives.""" + if expr in sub_dict: + return sub_dict[expr] + elif not expr.args or expr.is_Derivative: + return expr + + +def _tan_repl_func(expr): + """Replace tan with sin/cos.""" + if isinstance(expr, tan): + return sin(*expr.args) / cos(*expr.args) + elif not expr.args or expr.is_Derivative: + return expr + + +def _smart_subs(expr, sub_dict): + """Performs subs, checking for conditions that may result in `nan` or + `oo`, and attempts to simplify them out. + + The expression tree is traversed twice, and the following steps are + performed on each expression node: + - First traverse: + Replace all `tan` with `sin/cos`. + - Second traverse: + If node is a fraction, check if the denominator evaluates to 0. + If so, attempt to simplify it out. Then if node is in sub_dict, + sub in the corresponding value. + + """ + expr = _crawl(expr, _tan_repl_func) + + def _recurser(expr, sub_dict): + # Decompose the expression into num, den + num, den = _fraction_decomp(expr) + if den != 1: + # If there is a non trivial denominator, we need to handle it + denom_subbed = _recurser(den, sub_dict) + if denom_subbed.evalf() == 0: + # If denom is 0 after this, attempt to simplify the bad expr + expr = simplify(expr) + else: + # Expression won't result in nan, find numerator + num_subbed = _recurser(num, sub_dict) + return num_subbed / denom_subbed + # We have to crawl the tree manually, because `expr` may have been + # modified in the simplify step. First, perform subs as normal: + val = _sub_func(expr, sub_dict) + if val is not None: + return val + new_args = (_recurser(arg, sub_dict) for arg in expr.args) + return expr.func(*new_args) + return _recurser(expr, sub_dict) + + +def _fraction_decomp(expr): + """Return num, den such that expr = num/den.""" + if not isinstance(expr, Mul): + return expr, 1 + num = [] + den = [] + for a in expr.args: + if a.is_Pow and a.args[1] < 0: + den.append(1 / a) + else: + num.append(a) + if not den: + return expr, 1 + num = Mul(*num) + den = Mul(*den) + return num, den + + +def _f_list_parser(fl, ref_frame): + """Parses the provided forcelist composed of items + of the form (obj, force). + Returns a tuple containing: + vel_list: The velocity (ang_vel for Frames, vel for Points) in + the provided reference frame. + f_list: The forces. + + Used internally in the KanesMethod and LagrangesMethod classes. + + """ + def flist_iter(): + for pair in fl: + obj, force = pair + if isinstance(obj, ReferenceFrame): + yield obj.ang_vel_in(ref_frame), force + elif isinstance(obj, Point): + yield obj.vel(ref_frame), force + else: + raise TypeError('First entry in each forcelist pair must ' + 'be a point or frame.') + + if not fl: + vel_list, f_list = (), () + else: + unzip = lambda l: list(zip(*l)) if l[0] else [(), ()] + vel_list, f_list = unzip(list(flist_iter())) + return vel_list, f_list + + +def _validate_coordinates(coordinates=None, speeds=None, check_duplicates=True, + is_dynamicsymbols=True, u_auxiliary=None): + """Validate the generalized coordinates and generalized speeds. + + Parameters + ========== + coordinates : iterable, optional + Generalized coordinates to be validated. + speeds : iterable, optional + Generalized speeds to be validated. + check_duplicates : bool, optional + Checks if there are duplicates in the generalized coordinates and + generalized speeds. If so it will raise a ValueError. The default is + True. + is_dynamicsymbols : iterable, optional + Checks if all the generalized coordinates and generalized speeds are + dynamicsymbols. If any is not a dynamicsymbol, a ValueError will be + raised. The default is True. + u_auxiliary : iterable, optional + Auxiliary generalized speeds to be validated. + + """ + t_set = {dynamicsymbols._t} + # Convert input to iterables + if coordinates is None: + coordinates = [] + elif not iterable(coordinates): + coordinates = [coordinates] + if speeds is None: + speeds = [] + elif not iterable(speeds): + speeds = [speeds] + if u_auxiliary is None: + u_auxiliary = [] + elif not iterable(u_auxiliary): + u_auxiliary = [u_auxiliary] + + msgs = [] + if check_duplicates: # Check for duplicates + seen = set() + coord_duplicates = {x for x in coordinates if x in seen or seen.add(x)} + seen = set() + speed_duplicates = {x for x in speeds if x in seen or seen.add(x)} + seen = set() + aux_duplicates = {x for x in u_auxiliary if x in seen or seen.add(x)} + overlap_coords = set(coordinates).intersection(speeds) + overlap_aux = set(coordinates).union(speeds).intersection(u_auxiliary) + if coord_duplicates: + msgs.append(f'The generalized coordinates {coord_duplicates} are ' + f'duplicated, all generalized coordinates should be ' + f'unique.') + if speed_duplicates: + msgs.append(f'The generalized speeds {speed_duplicates} are ' + f'duplicated, all generalized speeds should be unique.') + if aux_duplicates: + msgs.append(f'The auxiliary speeds {aux_duplicates} are duplicated,' + f' all auxiliary speeds should be unique.') + if overlap_coords: + msgs.append(f'{overlap_coords} are defined as both generalized ' + f'coordinates and generalized speeds.') + if overlap_aux: + msgs.append(f'The auxiliary speeds {overlap_aux} are also defined ' + f'as generalized coordinates or generalized speeds.') + if is_dynamicsymbols: # Check whether all coordinates are dynamicsymbols + for coordinate in coordinates: + if not (isinstance(coordinate, (AppliedUndef, Derivative)) and + coordinate.free_symbols == t_set): + msgs.append(f'Generalized coordinate "{coordinate}" is not a ' + f'dynamicsymbol.') + for speed in speeds: + if not (isinstance(speed, (AppliedUndef, Derivative)) and + speed.free_symbols == t_set): + msgs.append( + f'Generalized speed "{speed}" is not a dynamicsymbol.') + for aux in u_auxiliary: + if not (isinstance(aux, (AppliedUndef, Derivative)) and + aux.free_symbols == t_set): + msgs.append( + f'Auxiliary speed "{aux}" is not a dynamicsymbol.') + if msgs: + raise ValueError('\n'.join(msgs)) + + +def _parse_linear_solver(linear_solver): + """Helper function to retrieve a specified linear solver.""" + if callable(linear_solver): + return linear_solver + return lambda A, b: Matrix.solve(A, b, method=linear_solver) diff --git a/lib/python3.10/site-packages/sympy/physics/mechanics/inertia.py b/lib/python3.10/site-packages/sympy/physics/mechanics/inertia.py new file mode 100644 index 0000000000000000000000000000000000000000..d2fe37c7f9f39c692f98c8bf038a73326e171dd7 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/physics/mechanics/inertia.py @@ -0,0 +1,197 @@ +from sympy import sympify +from sympy.physics.vector import Point, Dyadic, ReferenceFrame, outer +from collections import namedtuple + +__all__ = ['inertia', 'inertia_of_point_mass', 'Inertia'] + + +def inertia(frame, ixx, iyy, izz, ixy=0, iyz=0, izx=0): + """Simple way to create inertia Dyadic object. + + Explanation + =========== + + Creates an inertia Dyadic based on the given tensor values and a body-fixed + reference frame. + + Parameters + ========== + + frame : ReferenceFrame + The frame the inertia is defined in. + ixx : Sympifyable + The xx element in the inertia dyadic. + iyy : Sympifyable + The yy element in the inertia dyadic. + izz : Sympifyable + The zz element in the inertia dyadic. + ixy : Sympifyable + The xy element in the inertia dyadic. + iyz : Sympifyable + The yz element in the inertia dyadic. + izx : Sympifyable + The zx element in the inertia dyadic. + + Examples + ======== + + >>> from sympy.physics.mechanics import ReferenceFrame, inertia + >>> N = ReferenceFrame('N') + >>> inertia(N, 1, 2, 3) + (N.x|N.x) + 2*(N.y|N.y) + 3*(N.z|N.z) + + """ + + if not isinstance(frame, ReferenceFrame): + raise TypeError('Need to define the inertia in a frame') + ixx, iyy, izz = sympify(ixx), sympify(iyy), sympify(izz) + ixy, iyz, izx = sympify(ixy), sympify(iyz), sympify(izx) + return (ixx*outer(frame.x, frame.x) + ixy*outer(frame.x, frame.y) + + izx*outer(frame.x, frame.z) + ixy*outer(frame.y, frame.x) + + iyy*outer(frame.y, frame.y) + iyz*outer(frame.y, frame.z) + + izx*outer(frame.z, frame.x) + iyz*outer(frame.z, frame.y) + + izz*outer(frame.z, frame.z)) + + +def inertia_of_point_mass(mass, pos_vec, frame): + """Inertia dyadic of a point mass relative to point O. + + Parameters + ========== + + mass : Sympifyable + Mass of the point mass + pos_vec : Vector + Position from point O to point mass + frame : ReferenceFrame + Reference frame to express the dyadic in + + Examples + ======== + + >>> from sympy import symbols + >>> from sympy.physics.mechanics import ReferenceFrame, inertia_of_point_mass + >>> N = ReferenceFrame('N') + >>> r, m = symbols('r m') + >>> px = r * N.x + >>> inertia_of_point_mass(m, px, N) + m*r**2*(N.y|N.y) + m*r**2*(N.z|N.z) + + """ + + return mass*( + (outer(frame.x, frame.x) + + outer(frame.y, frame.y) + + outer(frame.z, frame.z)) * + (pos_vec.dot(pos_vec)) - outer(pos_vec, pos_vec)) + + +class Inertia(namedtuple('Inertia', ['dyadic', 'point'])): + """Inertia object consisting of a Dyadic and a Point of reference. + + Explanation + =========== + + This is a simple class to store the Point and Dyadic, belonging to an + inertia. + + Attributes + ========== + + dyadic : Dyadic + The dyadic of the inertia. + point : Point + The reference point of the inertia. + + Examples + ======== + + >>> from sympy.physics.mechanics import ReferenceFrame, Point, Inertia + >>> N = ReferenceFrame('N') + >>> Po = Point('Po') + >>> Inertia(N.x.outer(N.x) + N.y.outer(N.y) + N.z.outer(N.z), Po) + ((N.x|N.x) + (N.y|N.y) + (N.z|N.z), Po) + + In the example above the Dyadic was created manually, one can however also + use the ``inertia`` function for this or the class method ``from_tensor`` as + shown below. + + >>> Inertia.from_inertia_scalars(Po, N, 1, 1, 1) + ((N.x|N.x) + (N.y|N.y) + (N.z|N.z), Po) + + """ + def __new__(cls, dyadic, point): + # Switch order if given in the wrong order + if isinstance(dyadic, Point) and isinstance(point, Dyadic): + point, dyadic = dyadic, point + if not isinstance(point, Point): + raise TypeError('Reference point should be of type Point') + if not isinstance(dyadic, Dyadic): + raise TypeError('Inertia value should be expressed as a Dyadic') + return super().__new__(cls, dyadic, point) + + @classmethod + def from_inertia_scalars(cls, point, frame, ixx, iyy, izz, ixy=0, iyz=0, + izx=0): + """Simple way to create an Inertia object based on the tensor values. + + Explanation + =========== + + This class method uses the :func`~.inertia` to create the Dyadic based + on the tensor values. + + Parameters + ========== + + point : Point + The reference point of the inertia. + frame : ReferenceFrame + The frame the inertia is defined in. + ixx : Sympifyable + The xx element in the inertia dyadic. + iyy : Sympifyable + The yy element in the inertia dyadic. + izz : Sympifyable + The zz element in the inertia dyadic. + ixy : Sympifyable + The xy element in the inertia dyadic. + iyz : Sympifyable + The yz element in the inertia dyadic. + izx : Sympifyable + The zx element in the inertia dyadic. + + Examples + ======== + + >>> from sympy import symbols + >>> from sympy.physics.mechanics import ReferenceFrame, Point, Inertia + >>> ixx, iyy, izz, ixy, iyz, izx = symbols('ixx iyy izz ixy iyz izx') + >>> N = ReferenceFrame('N') + >>> P = Point('P') + >>> I = Inertia.from_inertia_scalars(P, N, ixx, iyy, izz, ixy, iyz, izx) + + The tensor values can easily be seen when converting the dyadic to a + matrix. + + >>> I.dyadic.to_matrix(N) + Matrix([ + [ixx, ixy, izx], + [ixy, iyy, iyz], + [izx, iyz, izz]]) + + """ + return cls(inertia(frame, ixx, iyy, izz, ixy, iyz, izx), point) + + def __add__(self, other): + raise TypeError(f"unsupported operand type(s) for +: " + f"'{self.__class__.__name__}' and " + f"'{other.__class__.__name__}'") + + def __mul__(self, other): + raise TypeError(f"unsupported operand type(s) for *: " + f"'{self.__class__.__name__}' and " + f"'{other.__class__.__name__}'") + + __radd__ = __add__ + __rmul__ = __mul__ diff --git a/lib/python3.10/site-packages/sympy/physics/mechanics/joint.py b/lib/python3.10/site-packages/sympy/physics/mechanics/joint.py new file mode 100644 index 0000000000000000000000000000000000000000..af53cc67e4d70abc7b651da123a9361ddf263b60 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/physics/mechanics/joint.py @@ -0,0 +1,2188 @@ +# coding=utf-8 + +from abc import ABC, abstractmethod + +from sympy import pi, Derivative, Matrix +from sympy.core.function import AppliedUndef +from sympy.physics.mechanics.body_base import BodyBase +from sympy.physics.mechanics.functions import _validate_coordinates +from sympy.physics.vector import (Vector, dynamicsymbols, cross, Point, + ReferenceFrame) +from sympy.utilities.iterables import iterable +from sympy.utilities.exceptions import sympy_deprecation_warning + +__all__ = ['Joint', 'PinJoint', 'PrismaticJoint', 'CylindricalJoint', + 'PlanarJoint', 'SphericalJoint', 'WeldJoint'] + + +class Joint(ABC): + """Abstract base class for all specific joints. + + Explanation + =========== + + A joint subtracts degrees of freedom from a body. This is the base class + for all specific joints and holds all common methods acting as an interface + for all joints. Custom joint can be created by inheriting Joint class and + defining all abstract functions. + + The abstract methods are: + + - ``_generate_coordinates`` + - ``_generate_speeds`` + - ``_orient_frames`` + - ``_set_angular_velocity`` + - ``_set_linear_velocity`` + + Parameters + ========== + + name : string + A unique name for the joint. + parent : Particle or RigidBody or Body + The parent body of joint. + child : Particle or RigidBody or Body + The child body of joint. + coordinates : iterable of dynamicsymbols, optional + Generalized coordinates of the joint. + speeds : iterable of dynamicsymbols, optional + Generalized speeds of joint. + parent_point : Point or Vector, optional + Attachment point where the joint is fixed to the parent body. If a + vector is provided, then the attachment point is computed by adding the + vector to the body's mass center. The default value is the parent's mass + center. + child_point : Point or Vector, optional + Attachment point where the joint is fixed to the child body. If a + vector is provided, then the attachment point is computed by adding the + vector to the body's mass center. The default value is the child's mass + center. + parent_axis : Vector, optional + .. deprecated:: 1.12 + Axis fixed in the parent body which aligns with an axis fixed in the + child body. The default is the x axis of parent's reference frame. + For more information on this deprecation, see + :ref:`deprecated-mechanics-joint-axis`. + child_axis : Vector, optional + .. deprecated:: 1.12 + Axis fixed in the child body which aligns with an axis fixed in the + parent body. The default is the x axis of child's reference frame. + For more information on this deprecation, see + :ref:`deprecated-mechanics-joint-axis`. + parent_interframe : ReferenceFrame, optional + Intermediate frame of the parent body with respect to which the joint + transformation is formulated. If a Vector is provided then an interframe + is created which aligns its X axis with the given vector. The default + value is the parent's own frame. + child_interframe : ReferenceFrame, optional + Intermediate frame of the child body with respect to which the joint + transformation is formulated. If a Vector is provided then an interframe + is created which aligns its X axis with the given vector. The default + value is the child's own frame. + parent_joint_pos : Point or Vector, optional + .. deprecated:: 1.12 + This argument is replaced by parent_point and will be removed in a + future version. + See :ref:`deprecated-mechanics-joint-pos` for more information. + child_joint_pos : Point or Vector, optional + .. deprecated:: 1.12 + This argument is replaced by child_point and will be removed in a + future version. + See :ref:`deprecated-mechanics-joint-pos` for more information. + + Attributes + ========== + + name : string + The joint's name. + parent : Particle or RigidBody or Body + The joint's parent body. + child : Particle or RigidBody or Body + The joint's child body. + coordinates : Matrix + Matrix of the joint's generalized coordinates. + speeds : Matrix + Matrix of the joint's generalized speeds. + parent_point : Point + Attachment point where the joint is fixed to the parent body. + child_point : Point + Attachment point where the joint is fixed to the child body. + parent_axis : Vector + The axis fixed in the parent frame that represents the joint. + child_axis : Vector + The axis fixed in the child frame that represents the joint. + parent_interframe : ReferenceFrame + Intermediate frame of the parent body with respect to which the joint + transformation is formulated. + child_interframe : ReferenceFrame + Intermediate frame of the child body with respect to which the joint + transformation is formulated. + kdes : Matrix + Kinematical differential equations of the joint. + + Notes + ===== + + When providing a vector as the intermediate frame, a new intermediate frame + is created which aligns its X axis with the provided vector. This is done + with a single fixed rotation about a rotation axis. This rotation axis is + determined by taking the cross product of the ``body.x`` axis with the + provided vector. In the case where the provided vector is in the ``-body.x`` + direction, the rotation is done about the ``body.y`` axis. + + """ + + def __init__(self, name, parent, child, coordinates=None, speeds=None, + parent_point=None, child_point=None, parent_interframe=None, + child_interframe=None, parent_axis=None, child_axis=None, + parent_joint_pos=None, child_joint_pos=None): + + if not isinstance(name, str): + raise TypeError('Supply a valid name.') + self._name = name + + if not isinstance(parent, BodyBase): + raise TypeError('Parent must be a body.') + self._parent = parent + + if not isinstance(child, BodyBase): + raise TypeError('Child must be a body.') + self._child = child + + if parent_axis is not None or child_axis is not None: + sympy_deprecation_warning( + """ + The parent_axis and child_axis arguments for the Joint classes + are deprecated. Instead use parent_interframe, child_interframe. + """, + deprecated_since_version="1.12", + active_deprecations_target="deprecated-mechanics-joint-axis", + stacklevel=4 + ) + if parent_interframe is None: + parent_interframe = parent_axis + if child_interframe is None: + child_interframe = child_axis + + # Set parent and child frame attributes + if hasattr(self._parent, 'frame'): + self._parent_frame = self._parent.frame + else: + if isinstance(parent_interframe, ReferenceFrame): + self._parent_frame = parent_interframe + else: + self._parent_frame = ReferenceFrame( + f'{self.name}_{self._parent.name}_frame') + if hasattr(self._child, 'frame'): + self._child_frame = self._child.frame + else: + if isinstance(child_interframe, ReferenceFrame): + self._child_frame = child_interframe + else: + self._child_frame = ReferenceFrame( + f'{self.name}_{self._child.name}_frame') + + self._parent_interframe = self._locate_joint_frame( + self._parent, parent_interframe, self._parent_frame) + self._child_interframe = self._locate_joint_frame( + self._child, child_interframe, self._child_frame) + self._parent_axis = self._axis(parent_axis, self._parent_frame) + self._child_axis = self._axis(child_axis, self._child_frame) + + if parent_joint_pos is not None or child_joint_pos is not None: + sympy_deprecation_warning( + """ + The parent_joint_pos and child_joint_pos arguments for the Joint + classes are deprecated. Instead use parent_point and child_point. + """, + deprecated_since_version="1.12", + active_deprecations_target="deprecated-mechanics-joint-pos", + stacklevel=4 + ) + if parent_point is None: + parent_point = parent_joint_pos + if child_point is None: + child_point = child_joint_pos + self._parent_point = self._locate_joint_pos( + self._parent, parent_point, self._parent_frame) + self._child_point = self._locate_joint_pos( + self._child, child_point, self._child_frame) + + self._coordinates = self._generate_coordinates(coordinates) + self._speeds = self._generate_speeds(speeds) + _validate_coordinates(self.coordinates, self.speeds) + self._kdes = self._generate_kdes() + + self._orient_frames() + self._set_angular_velocity() + self._set_linear_velocity() + + def __str__(self): + return self.name + + def __repr__(self): + return self.__str__() + + @property + def name(self): + """Name of the joint.""" + return self._name + + @property + def parent(self): + """Parent body of Joint.""" + return self._parent + + @property + def child(self): + """Child body of Joint.""" + return self._child + + @property + def coordinates(self): + """Matrix of the joint's generalized coordinates.""" + return self._coordinates + + @property + def speeds(self): + """Matrix of the joint's generalized speeds.""" + return self._speeds + + @property + def kdes(self): + """Kinematical differential equations of the joint.""" + return self._kdes + + @property + def parent_axis(self): + """The axis of parent frame.""" + # Will be removed with `deprecated-mechanics-joint-axis` + return self._parent_axis + + @property + def child_axis(self): + """The axis of child frame.""" + # Will be removed with `deprecated-mechanics-joint-axis` + return self._child_axis + + @property + def parent_point(self): + """Attachment point where the joint is fixed to the parent body.""" + return self._parent_point + + @property + def child_point(self): + """Attachment point where the joint is fixed to the child body.""" + return self._child_point + + @property + def parent_interframe(self): + return self._parent_interframe + + @property + def child_interframe(self): + return self._child_interframe + + @abstractmethod + def _generate_coordinates(self, coordinates): + """Generate Matrix of the joint's generalized coordinates.""" + pass + + @abstractmethod + def _generate_speeds(self, speeds): + """Generate Matrix of the joint's generalized speeds.""" + pass + + @abstractmethod + def _orient_frames(self): + """Orient frames as per the joint.""" + pass + + @abstractmethod + def _set_angular_velocity(self): + """Set angular velocity of the joint related frames.""" + pass + + @abstractmethod + def _set_linear_velocity(self): + """Set velocity of related points to the joint.""" + pass + + @staticmethod + def _to_vector(matrix, frame): + """Converts a matrix to a vector in the given frame.""" + return Vector([(matrix, frame)]) + + @staticmethod + def _axis(ax, *frames): + """Check whether an axis is fixed in one of the frames.""" + if ax is None: + ax = frames[0].x + return ax + if not isinstance(ax, Vector): + raise TypeError("Axis must be a Vector.") + ref_frame = None # Find a body in which the axis can be expressed + for frame in frames: + try: + ax.to_matrix(frame) + except ValueError: + pass + else: + ref_frame = frame + break + if ref_frame is None: + raise ValueError("Axis cannot be expressed in one of the body's " + "frames.") + if not ax.dt(ref_frame) == 0: + raise ValueError('Axis cannot be time-varying when viewed from the ' + 'associated body.') + return ax + + @staticmethod + def _choose_rotation_axis(frame, axis): + components = axis.to_matrix(frame) + x, y, z = components[0], components[1], components[2] + + if x != 0: + if y != 0: + if z != 0: + return cross(axis, frame.x) + if z != 0: + return frame.y + return frame.z + else: + if y != 0: + return frame.x + return frame.y + + @staticmethod + def _create_aligned_interframe(frame, align_axis, frame_axis=None, + frame_name=None): + """ + Returns an intermediate frame, where the ``frame_axis`` defined in + ``frame`` is aligned with ``axis``. By default this means that the X + axis will be aligned with ``axis``. + + Parameters + ========== + + frame : BodyBase or ReferenceFrame + The body or reference frame with respect to which the intermediate + frame is oriented. + align_axis : Vector + The vector with respect to which the intermediate frame will be + aligned. + frame_axis : Vector + The vector of the frame which should get aligned with ``axis``. The + default is the X axis of the frame. + frame_name : string + Name of the to be created intermediate frame. The default adds + "_int_frame" to the name of ``frame``. + + Example + ======= + + An intermediate frame, where the X axis of the parent becomes aligned + with ``parent.y + parent.z`` can be created as follows: + + >>> from sympy.physics.mechanics.joint import Joint + >>> from sympy.physics.mechanics import RigidBody + >>> parent = RigidBody('parent') + >>> parent_interframe = Joint._create_aligned_interframe( + ... parent, parent.y + parent.z) + >>> parent_interframe + parent_int_frame + >>> parent.frame.dcm(parent_interframe) + Matrix([ + [ 0, -sqrt(2)/2, -sqrt(2)/2], + [sqrt(2)/2, 1/2, -1/2], + [sqrt(2)/2, -1/2, 1/2]]) + >>> (parent.y + parent.z).express(parent_interframe) + sqrt(2)*parent_int_frame.x + + Notes + ===== + + The direction cosine matrix between the given frame and intermediate + frame is formed using a simple rotation about an axis that is normal to + both ``align_axis`` and ``frame_axis``. In general, the normal axis is + formed by crossing the ``frame_axis`` with the ``align_axis``. The + exception is if the axes are parallel with opposite directions, in which + case the rotation vector is chosen using the rules in the following + table with the vectors expressed in the given frame: + + .. list-table:: + :header-rows: 1 + + * - ``align_axis`` + - ``frame_axis`` + - ``rotation_axis`` + * - ``-x`` + - ``x`` + - ``z`` + * - ``-y`` + - ``y`` + - ``x`` + * - ``-z`` + - ``z`` + - ``y`` + * - ``-x-y`` + - ``x+y`` + - ``z`` + * - ``-y-z`` + - ``y+z`` + - ``x`` + * - ``-x-z`` + - ``x+z`` + - ``y`` + * - ``-x-y-z`` + - ``x+y+z`` + - ``(x+y+z) × x`` + + """ + if isinstance(frame, BodyBase): + frame = frame.frame + if frame_axis is None: + frame_axis = frame.x + if frame_name is None: + if frame.name[-6:] == '_frame': + frame_name = f'{frame.name[:-6]}_int_frame' + else: + frame_name = f'{frame.name}_int_frame' + angle = frame_axis.angle_between(align_axis) + rotation_axis = cross(frame_axis, align_axis) + if rotation_axis == Vector(0) and angle == 0: + return frame + if angle == pi: + rotation_axis = Joint._choose_rotation_axis(frame, align_axis) + + int_frame = ReferenceFrame(frame_name) + int_frame.orient_axis(frame, rotation_axis, angle) + int_frame.set_ang_vel(frame, 0 * rotation_axis) + return int_frame + + def _generate_kdes(self): + """Generate kinematical differential equations.""" + kdes = [] + t = dynamicsymbols._t + for i in range(len(self.coordinates)): + kdes.append(-self.coordinates[i].diff(t) + self.speeds[i]) + return Matrix(kdes) + + def _locate_joint_pos(self, body, joint_pos, body_frame=None): + """Returns the attachment point of a body.""" + if body_frame is None: + body_frame = body.frame + if joint_pos is None: + return body.masscenter + if not isinstance(joint_pos, (Point, Vector)): + raise TypeError('Attachment point must be a Point or Vector.') + if isinstance(joint_pos, Vector): + point_name = f'{self.name}_{body.name}_joint' + joint_pos = body.masscenter.locatenew(point_name, joint_pos) + if not joint_pos.pos_from(body.masscenter).dt(body_frame) == 0: + raise ValueError('Attachment point must be fixed to the associated ' + 'body.') + return joint_pos + + def _locate_joint_frame(self, body, interframe, body_frame=None): + """Returns the attachment frame of a body.""" + if body_frame is None: + body_frame = body.frame + if interframe is None: + return body_frame + if isinstance(interframe, Vector): + interframe = Joint._create_aligned_interframe( + body_frame, interframe, + frame_name=f'{self.name}_{body.name}_int_frame') + elif not isinstance(interframe, ReferenceFrame): + raise TypeError('Interframe must be a ReferenceFrame.') + if not interframe.ang_vel_in(body_frame) == 0: + raise ValueError(f'Interframe {interframe} is not fixed to body ' + f'{body}.') + body.masscenter.set_vel(interframe, 0) # Fixate interframe to body + return interframe + + def _fill_coordinate_list(self, coordinates, n_coords, label='q', offset=0, + number_single=False): + """Helper method for _generate_coordinates and _generate_speeds. + + Parameters + ========== + + coordinates : iterable + Iterable of coordinates or speeds that have been provided. + n_coords : Integer + Number of coordinates that should be returned. + label : String, optional + Coordinate type either 'q' (coordinates) or 'u' (speeds). The + Default is 'q'. + offset : Integer + Count offset when creating new dynamicsymbols. The default is 0. + number_single : Boolean + Boolean whether if n_coords == 1, number should still be used. The + default is False. + + """ + + def create_symbol(number): + if n_coords == 1 and not number_single: + return dynamicsymbols(f'{label}_{self.name}') + return dynamicsymbols(f'{label}{number}_{self.name}') + + name = 'generalized coordinate' if label == 'q' else 'generalized speed' + generated_coordinates = [] + if coordinates is None: + coordinates = [] + elif not iterable(coordinates): + coordinates = [coordinates] + if not (len(coordinates) == 0 or len(coordinates) == n_coords): + raise ValueError(f'Expected {n_coords} {name}s, instead got ' + f'{len(coordinates)} {name}s.') + # Supports more iterables, also Matrix + for i, coord in enumerate(coordinates): + if coord is None: + generated_coordinates.append(create_symbol(i + offset)) + elif isinstance(coord, (AppliedUndef, Derivative)): + generated_coordinates.append(coord) + else: + raise TypeError(f'The {name} {coord} should have been a ' + f'dynamicsymbol.') + for i in range(len(coordinates) + offset, n_coords + offset): + generated_coordinates.append(create_symbol(i)) + return Matrix(generated_coordinates) + + +class PinJoint(Joint): + """Pin (Revolute) Joint. + + .. raw:: html + :file: ../../../doc/src/modules/physics/mechanics/api/PinJoint.svg + + Explanation + =========== + + A pin joint is defined such that the joint rotation axis is fixed in both + the child and parent and the location of the joint is relative to the mass + center of each body. The child rotates an angle, θ, from the parent about + the rotation axis and has a simple angular speed, ω, relative to the + parent. The direction cosine matrix between the child interframe and + parent interframe is formed using a simple rotation about the joint axis. + The page on the joints framework gives a more detailed explanation of the + intermediate frames. + + Parameters + ========== + + name : string + A unique name for the joint. + parent : Particle or RigidBody or Body + The parent body of joint. + child : Particle or RigidBody or Body + The child body of joint. + coordinates : dynamicsymbol, optional + Generalized coordinates of the joint. + speeds : dynamicsymbol, optional + Generalized speeds of joint. + parent_point : Point or Vector, optional + Attachment point where the joint is fixed to the parent body. If a + vector is provided, then the attachment point is computed by adding the + vector to the body's mass center. The default value is the parent's mass + center. + child_point : Point or Vector, optional + Attachment point where the joint is fixed to the child body. If a + vector is provided, then the attachment point is computed by adding the + vector to the body's mass center. The default value is the child's mass + center. + parent_axis : Vector, optional + .. deprecated:: 1.12 + Axis fixed in the parent body which aligns with an axis fixed in the + child body. The default is the x axis of parent's reference frame. + For more information on this deprecation, see + :ref:`deprecated-mechanics-joint-axis`. + child_axis : Vector, optional + .. deprecated:: 1.12 + Axis fixed in the child body which aligns with an axis fixed in the + parent body. The default is the x axis of child's reference frame. + For more information on this deprecation, see + :ref:`deprecated-mechanics-joint-axis`. + parent_interframe : ReferenceFrame, optional + Intermediate frame of the parent body with respect to which the joint + transformation is formulated. If a Vector is provided then an interframe + is created which aligns its X axis with the given vector. The default + value is the parent's own frame. + child_interframe : ReferenceFrame, optional + Intermediate frame of the child body with respect to which the joint + transformation is formulated. If a Vector is provided then an interframe + is created which aligns its X axis with the given vector. The default + value is the child's own frame. + joint_axis : Vector + The axis about which the rotation occurs. Note that the components + of this axis are the same in the parent_interframe and child_interframe. + parent_joint_pos : Point or Vector, optional + .. deprecated:: 1.12 + This argument is replaced by parent_point and will be removed in a + future version. + See :ref:`deprecated-mechanics-joint-pos` for more information. + child_joint_pos : Point or Vector, optional + .. deprecated:: 1.12 + This argument is replaced by child_point and will be removed in a + future version. + See :ref:`deprecated-mechanics-joint-pos` for more information. + + Attributes + ========== + + name : string + The joint's name. + parent : Particle or RigidBody or Body + The joint's parent body. + child : Particle or RigidBody or Body + The joint's child body. + coordinates : Matrix + Matrix of the joint's generalized coordinates. The default value is + ``dynamicsymbols(f'q_{joint.name}')``. + speeds : Matrix + Matrix of the joint's generalized speeds. The default value is + ``dynamicsymbols(f'u_{joint.name}')``. + parent_point : Point + Attachment point where the joint is fixed to the parent body. + child_point : Point + Attachment point where the joint is fixed to the child body. + parent_axis : Vector + The axis fixed in the parent frame that represents the joint. + child_axis : Vector + The axis fixed in the child frame that represents the joint. + parent_interframe : ReferenceFrame + Intermediate frame of the parent body with respect to which the joint + transformation is formulated. + child_interframe : ReferenceFrame + Intermediate frame of the child body with respect to which the joint + transformation is formulated. + joint_axis : Vector + The axis about which the rotation occurs. Note that the components of + this axis are the same in the parent_interframe and child_interframe. + kdes : Matrix + Kinematical differential equations of the joint. + + Examples + ========= + + A single pin joint is created from two bodies and has the following basic + attributes: + + >>> from sympy.physics.mechanics import RigidBody, PinJoint + >>> parent = RigidBody('P') + >>> parent + P + >>> child = RigidBody('C') + >>> child + C + >>> joint = PinJoint('PC', parent, child) + >>> joint + PinJoint: PC parent: P child: C + >>> joint.name + 'PC' + >>> joint.parent + P + >>> joint.child + C + >>> joint.parent_point + P_masscenter + >>> joint.child_point + C_masscenter + >>> joint.parent_axis + P_frame.x + >>> joint.child_axis + C_frame.x + >>> joint.coordinates + Matrix([[q_PC(t)]]) + >>> joint.speeds + Matrix([[u_PC(t)]]) + >>> child.frame.ang_vel_in(parent.frame) + u_PC(t)*P_frame.x + >>> child.frame.dcm(parent.frame) + Matrix([ + [1, 0, 0], + [0, cos(q_PC(t)), sin(q_PC(t))], + [0, -sin(q_PC(t)), cos(q_PC(t))]]) + >>> joint.child_point.pos_from(joint.parent_point) + 0 + + To further demonstrate the use of the pin joint, the kinematics of simple + double pendulum that rotates about the Z axis of each connected body can be + created as follows. + + >>> from sympy import symbols, trigsimp + >>> from sympy.physics.mechanics import RigidBody, PinJoint + >>> l1, l2 = symbols('l1 l2') + + First create bodies to represent the fixed ceiling and one to represent + each pendulum bob. + + >>> ceiling = RigidBody('C') + >>> upper_bob = RigidBody('U') + >>> lower_bob = RigidBody('L') + + The first joint will connect the upper bob to the ceiling by a distance of + ``l1`` and the joint axis will be about the Z axis for each body. + + >>> ceiling_joint = PinJoint('P1', ceiling, upper_bob, + ... child_point=-l1*upper_bob.frame.x, + ... joint_axis=ceiling.frame.z) + + The second joint will connect the lower bob to the upper bob by a distance + of ``l2`` and the joint axis will also be about the Z axis for each body. + + >>> pendulum_joint = PinJoint('P2', upper_bob, lower_bob, + ... child_point=-l2*lower_bob.frame.x, + ... joint_axis=upper_bob.frame.z) + + Once the joints are established the kinematics of the connected bodies can + be accessed. First the direction cosine matrices of pendulum link relative + to the ceiling are found: + + >>> upper_bob.frame.dcm(ceiling.frame) + Matrix([ + [ cos(q_P1(t)), sin(q_P1(t)), 0], + [-sin(q_P1(t)), cos(q_P1(t)), 0], + [ 0, 0, 1]]) + >>> trigsimp(lower_bob.frame.dcm(ceiling.frame)) + Matrix([ + [ cos(q_P1(t) + q_P2(t)), sin(q_P1(t) + q_P2(t)), 0], + [-sin(q_P1(t) + q_P2(t)), cos(q_P1(t) + q_P2(t)), 0], + [ 0, 0, 1]]) + + The position of the lower bob's masscenter is found with: + + >>> lower_bob.masscenter.pos_from(ceiling.masscenter) + l1*U_frame.x + l2*L_frame.x + + The angular velocities of the two pendulum links can be computed with + respect to the ceiling. + + >>> upper_bob.frame.ang_vel_in(ceiling.frame) + u_P1(t)*C_frame.z + >>> lower_bob.frame.ang_vel_in(ceiling.frame) + u_P1(t)*C_frame.z + u_P2(t)*U_frame.z + + And finally, the linear velocities of the two pendulum bobs can be computed + with respect to the ceiling. + + >>> upper_bob.masscenter.vel(ceiling.frame) + l1*u_P1(t)*U_frame.y + >>> lower_bob.masscenter.vel(ceiling.frame) + l1*u_P1(t)*U_frame.y + l2*(u_P1(t) + u_P2(t))*L_frame.y + + """ + + def __init__(self, name, parent, child, coordinates=None, speeds=None, + parent_point=None, child_point=None, parent_interframe=None, + child_interframe=None, parent_axis=None, child_axis=None, + joint_axis=None, parent_joint_pos=None, child_joint_pos=None): + + self._joint_axis = joint_axis + super().__init__(name, parent, child, coordinates, speeds, parent_point, + child_point, parent_interframe, child_interframe, + parent_axis, child_axis, parent_joint_pos, + child_joint_pos) + + def __str__(self): + return (f'PinJoint: {self.name} parent: {self.parent} ' + f'child: {self.child}') + + @property + def joint_axis(self): + """Axis about which the child rotates with respect to the parent.""" + return self._joint_axis + + def _generate_coordinates(self, coordinate): + return self._fill_coordinate_list(coordinate, 1, 'q') + + def _generate_speeds(self, speed): + return self._fill_coordinate_list(speed, 1, 'u') + + def _orient_frames(self): + self._joint_axis = self._axis(self.joint_axis, self.parent_interframe) + self.child_interframe.orient_axis( + self.parent_interframe, self.joint_axis, self.coordinates[0]) + + def _set_angular_velocity(self): + self.child_interframe.set_ang_vel(self.parent_interframe, self.speeds[ + 0] * self.joint_axis.normalize()) + + def _set_linear_velocity(self): + self.child_point.set_pos(self.parent_point, 0) + self.parent_point.set_vel(self._parent_frame, 0) + self.child_point.set_vel(self._child_frame, 0) + self.child.masscenter.v2pt_theory(self.parent_point, + self._parent_frame, self._child_frame) + + +class PrismaticJoint(Joint): + """Prismatic (Sliding) Joint. + + .. image:: PrismaticJoint.svg + + Explanation + =========== + + It is defined such that the child body translates with respect to the parent + body along the body-fixed joint axis. The location of the joint is defined + by two points, one in each body, which coincide when the generalized + coordinate is zero. The direction cosine matrix between the + parent_interframe and child_interframe is the identity matrix. Therefore, + the direction cosine matrix between the parent and child frames is fully + defined by the definition of the intermediate frames. The page on the joints + framework gives a more detailed explanation of the intermediate frames. + + Parameters + ========== + + name : string + A unique name for the joint. + parent : Particle or RigidBody or Body + The parent body of joint. + child : Particle or RigidBody or Body + The child body of joint. + coordinates : dynamicsymbol, optional + Generalized coordinates of the joint. The default value is + ``dynamicsymbols(f'q_{joint.name}')``. + speeds : dynamicsymbol, optional + Generalized speeds of joint. The default value is + ``dynamicsymbols(f'u_{joint.name}')``. + parent_point : Point or Vector, optional + Attachment point where the joint is fixed to the parent body. If a + vector is provided, then the attachment point is computed by adding the + vector to the body's mass center. The default value is the parent's mass + center. + child_point : Point or Vector, optional + Attachment point where the joint is fixed to the child body. If a + vector is provided, then the attachment point is computed by adding the + vector to the body's mass center. The default value is the child's mass + center. + parent_axis : Vector, optional + .. deprecated:: 1.12 + Axis fixed in the parent body which aligns with an axis fixed in the + child body. The default is the x axis of parent's reference frame. + For more information on this deprecation, see + :ref:`deprecated-mechanics-joint-axis`. + child_axis : Vector, optional + .. deprecated:: 1.12 + Axis fixed in the child body which aligns with an axis fixed in the + parent body. The default is the x axis of child's reference frame. + For more information on this deprecation, see + :ref:`deprecated-mechanics-joint-axis`. + parent_interframe : ReferenceFrame, optional + Intermediate frame of the parent body with respect to which the joint + transformation is formulated. If a Vector is provided then an interframe + is created which aligns its X axis with the given vector. The default + value is the parent's own frame. + child_interframe : ReferenceFrame, optional + Intermediate frame of the child body with respect to which the joint + transformation is formulated. If a Vector is provided then an interframe + is created which aligns its X axis with the given vector. The default + value is the child's own frame. + joint_axis : Vector + The axis along which the translation occurs. Note that the components + of this axis are the same in the parent_interframe and child_interframe. + parent_joint_pos : Point or Vector, optional + .. deprecated:: 1.12 + This argument is replaced by parent_point and will be removed in a + future version. + See :ref:`deprecated-mechanics-joint-pos` for more information. + child_joint_pos : Point or Vector, optional + .. deprecated:: 1.12 + This argument is replaced by child_point and will be removed in a + future version. + See :ref:`deprecated-mechanics-joint-pos` for more information. + + Attributes + ========== + + name : string + The joint's name. + parent : Particle or RigidBody or Body + The joint's parent body. + child : Particle or RigidBody or Body + The joint's child body. + coordinates : Matrix + Matrix of the joint's generalized coordinates. + speeds : Matrix + Matrix of the joint's generalized speeds. + parent_point : Point + Attachment point where the joint is fixed to the parent body. + child_point : Point + Attachment point where the joint is fixed to the child body. + parent_axis : Vector + The axis fixed in the parent frame that represents the joint. + child_axis : Vector + The axis fixed in the child frame that represents the joint. + parent_interframe : ReferenceFrame + Intermediate frame of the parent body with respect to which the joint + transformation is formulated. + child_interframe : ReferenceFrame + Intermediate frame of the child body with respect to which the joint + transformation is formulated. + kdes : Matrix + Kinematical differential equations of the joint. + + Examples + ========= + + A single prismatic joint is created from two bodies and has the following + basic attributes: + + >>> from sympy.physics.mechanics import RigidBody, PrismaticJoint + >>> parent = RigidBody('P') + >>> parent + P + >>> child = RigidBody('C') + >>> child + C + >>> joint = PrismaticJoint('PC', parent, child) + >>> joint + PrismaticJoint: PC parent: P child: C + >>> joint.name + 'PC' + >>> joint.parent + P + >>> joint.child + C + >>> joint.parent_point + P_masscenter + >>> joint.child_point + C_masscenter + >>> joint.parent_axis + P_frame.x + >>> joint.child_axis + C_frame.x + >>> joint.coordinates + Matrix([[q_PC(t)]]) + >>> joint.speeds + Matrix([[u_PC(t)]]) + >>> child.frame.ang_vel_in(parent.frame) + 0 + >>> child.frame.dcm(parent.frame) + Matrix([ + [1, 0, 0], + [0, 1, 0], + [0, 0, 1]]) + >>> joint.child_point.pos_from(joint.parent_point) + q_PC(t)*P_frame.x + + To further demonstrate the use of the prismatic joint, the kinematics of two + masses sliding, one moving relative to a fixed body and the other relative + to the moving body. about the X axis of each connected body can be created + as follows. + + >>> from sympy.physics.mechanics import PrismaticJoint, RigidBody + + First create bodies to represent the fixed ceiling and one to represent + a particle. + + >>> wall = RigidBody('W') + >>> Part1 = RigidBody('P1') + >>> Part2 = RigidBody('P2') + + The first joint will connect the particle to the ceiling and the + joint axis will be about the X axis for each body. + + >>> J1 = PrismaticJoint('J1', wall, Part1) + + The second joint will connect the second particle to the first particle + and the joint axis will also be about the X axis for each body. + + >>> J2 = PrismaticJoint('J2', Part1, Part2) + + Once the joint is established the kinematics of the connected bodies can + be accessed. First the direction cosine matrices of Part relative + to the ceiling are found: + + >>> Part1.frame.dcm(wall.frame) + Matrix([ + [1, 0, 0], + [0, 1, 0], + [0, 0, 1]]) + + >>> Part2.frame.dcm(wall.frame) + Matrix([ + [1, 0, 0], + [0, 1, 0], + [0, 0, 1]]) + + The position of the particles' masscenter is found with: + + >>> Part1.masscenter.pos_from(wall.masscenter) + q_J1(t)*W_frame.x + + >>> Part2.masscenter.pos_from(wall.masscenter) + q_J1(t)*W_frame.x + q_J2(t)*P1_frame.x + + The angular velocities of the two particle links can be computed with + respect to the ceiling. + + >>> Part1.frame.ang_vel_in(wall.frame) + 0 + + >>> Part2.frame.ang_vel_in(wall.frame) + 0 + + And finally, the linear velocities of the two particles can be computed + with respect to the ceiling. + + >>> Part1.masscenter.vel(wall.frame) + u_J1(t)*W_frame.x + + >>> Part2.masscenter.vel(wall.frame) + u_J1(t)*W_frame.x + Derivative(q_J2(t), t)*P1_frame.x + + """ + + def __init__(self, name, parent, child, coordinates=None, speeds=None, + parent_point=None, child_point=None, parent_interframe=None, + child_interframe=None, parent_axis=None, child_axis=None, + joint_axis=None, parent_joint_pos=None, child_joint_pos=None): + + self._joint_axis = joint_axis + super().__init__(name, parent, child, coordinates, speeds, parent_point, + child_point, parent_interframe, child_interframe, + parent_axis, child_axis, parent_joint_pos, + child_joint_pos) + + def __str__(self): + return (f'PrismaticJoint: {self.name} parent: {self.parent} ' + f'child: {self.child}') + + @property + def joint_axis(self): + """Axis along which the child translates with respect to the parent.""" + return self._joint_axis + + def _generate_coordinates(self, coordinate): + return self._fill_coordinate_list(coordinate, 1, 'q') + + def _generate_speeds(self, speed): + return self._fill_coordinate_list(speed, 1, 'u') + + def _orient_frames(self): + self._joint_axis = self._axis(self.joint_axis, self.parent_interframe) + self.child_interframe.orient_axis( + self.parent_interframe, self.joint_axis, 0) + + def _set_angular_velocity(self): + self.child_interframe.set_ang_vel(self.parent_interframe, 0) + + def _set_linear_velocity(self): + axis = self.joint_axis.normalize() + self.child_point.set_pos(self.parent_point, self.coordinates[0] * axis) + self.parent_point.set_vel(self._parent_frame, 0) + self.child_point.set_vel(self._child_frame, 0) + self.child_point.set_vel(self._parent_frame, self.speeds[0] * axis) + self.child.masscenter.set_vel(self._parent_frame, self.speeds[0] * axis) + + +class CylindricalJoint(Joint): + """Cylindrical Joint. + + .. image:: CylindricalJoint.svg + :align: center + :width: 600 + + Explanation + =========== + + A cylindrical joint is defined such that the child body both rotates about + and translates along the body-fixed joint axis with respect to the parent + body. The joint axis is both the rotation axis and translation axis. The + location of the joint is defined by two points, one in each body, which + coincide when the generalized coordinate corresponding to the translation is + zero. The direction cosine matrix between the child interframe and parent + interframe is formed using a simple rotation about the joint axis. The page + on the joints framework gives a more detailed explanation of the + intermediate frames. + + Parameters + ========== + + name : string + A unique name for the joint. + parent : Particle or RigidBody or Body + The parent body of joint. + child : Particle or RigidBody or Body + The child body of joint. + rotation_coordinate : dynamicsymbol, optional + Generalized coordinate corresponding to the rotation angle. The default + value is ``dynamicsymbols(f'q0_{joint.name}')``. + translation_coordinate : dynamicsymbol, optional + Generalized coordinate corresponding to the translation distance. The + default value is ``dynamicsymbols(f'q1_{joint.name}')``. + rotation_speed : dynamicsymbol, optional + Generalized speed corresponding to the angular velocity. The default + value is ``dynamicsymbols(f'u0_{joint.name}')``. + translation_speed : dynamicsymbol, optional + Generalized speed corresponding to the translation velocity. The default + value is ``dynamicsymbols(f'u1_{joint.name}')``. + parent_point : Point or Vector, optional + Attachment point where the joint is fixed to the parent body. If a + vector is provided, then the attachment point is computed by adding the + vector to the body's mass center. The default value is the parent's mass + center. + child_point : Point or Vector, optional + Attachment point where the joint is fixed to the child body. If a + vector is provided, then the attachment point is computed by adding the + vector to the body's mass center. The default value is the child's mass + center. + parent_interframe : ReferenceFrame, optional + Intermediate frame of the parent body with respect to which the joint + transformation is formulated. If a Vector is provided then an interframe + is created which aligns its X axis with the given vector. The default + value is the parent's own frame. + child_interframe : ReferenceFrame, optional + Intermediate frame of the child body with respect to which the joint + transformation is formulated. If a Vector is provided then an interframe + is created which aligns its X axis with the given vector. The default + value is the child's own frame. + joint_axis : Vector, optional + The rotation as well as translation axis. Note that the components of + this axis are the same in the parent_interframe and child_interframe. + + Attributes + ========== + + name : string + The joint's name. + parent : Particle or RigidBody or Body + The joint's parent body. + child : Particle or RigidBody or Body + The joint's child body. + rotation_coordinate : dynamicsymbol + Generalized coordinate corresponding to the rotation angle. + translation_coordinate : dynamicsymbol + Generalized coordinate corresponding to the translation distance. + rotation_speed : dynamicsymbol + Generalized speed corresponding to the angular velocity. + translation_speed : dynamicsymbol + Generalized speed corresponding to the translation velocity. + coordinates : Matrix + Matrix of the joint's generalized coordinates. + speeds : Matrix + Matrix of the joint's generalized speeds. + parent_point : Point + Attachment point where the joint is fixed to the parent body. + child_point : Point + Attachment point where the joint is fixed to the child body. + parent_interframe : ReferenceFrame + Intermediate frame of the parent body with respect to which the joint + transformation is formulated. + child_interframe : ReferenceFrame + Intermediate frame of the child body with respect to which the joint + transformation is formulated. + kdes : Matrix + Kinematical differential equations of the joint. + joint_axis : Vector + The axis of rotation and translation. + + Examples + ========= + + A single cylindrical joint is created between two bodies and has the + following basic attributes: + + >>> from sympy.physics.mechanics import RigidBody, CylindricalJoint + >>> parent = RigidBody('P') + >>> parent + P + >>> child = RigidBody('C') + >>> child + C + >>> joint = CylindricalJoint('PC', parent, child) + >>> joint + CylindricalJoint: PC parent: P child: C + >>> joint.name + 'PC' + >>> joint.parent + P + >>> joint.child + C + >>> joint.parent_point + P_masscenter + >>> joint.child_point + C_masscenter + >>> joint.parent_axis + P_frame.x + >>> joint.child_axis + C_frame.x + >>> joint.coordinates + Matrix([ + [q0_PC(t)], + [q1_PC(t)]]) + >>> joint.speeds + Matrix([ + [u0_PC(t)], + [u1_PC(t)]]) + >>> child.frame.ang_vel_in(parent.frame) + u0_PC(t)*P_frame.x + >>> child.frame.dcm(parent.frame) + Matrix([ + [1, 0, 0], + [0, cos(q0_PC(t)), sin(q0_PC(t))], + [0, -sin(q0_PC(t)), cos(q0_PC(t))]]) + >>> joint.child_point.pos_from(joint.parent_point) + q1_PC(t)*P_frame.x + >>> child.masscenter.vel(parent.frame) + u1_PC(t)*P_frame.x + + To further demonstrate the use of the cylindrical joint, the kinematics of + two cylindrical joints perpendicular to each other can be created as follows. + + >>> from sympy import symbols + >>> from sympy.physics.mechanics import RigidBody, CylindricalJoint + >>> r, l, w = symbols('r l w') + + First create bodies to represent the fixed floor with a fixed pole on it. + The second body represents a freely moving tube around that pole. The third + body represents a solid flag freely translating along and rotating around + the Y axis of the tube. + + >>> floor = RigidBody('floor') + >>> tube = RigidBody('tube') + >>> flag = RigidBody('flag') + + The first joint will connect the first tube to the floor with it translating + along and rotating around the Z axis of both bodies. + + >>> floor_joint = CylindricalJoint('C1', floor, tube, joint_axis=floor.z) + + The second joint will connect the tube perpendicular to the flag along the Y + axis of both the tube and the flag, with the joint located at a distance + ``r`` from the tube's center of mass and a combination of the distances + ``l`` and ``w`` from the flag's center of mass. + + >>> flag_joint = CylindricalJoint('C2', tube, flag, + ... parent_point=r * tube.y, + ... child_point=-w * flag.y + l * flag.z, + ... joint_axis=tube.y) + + Once the joints are established the kinematics of the connected bodies can + be accessed. First the direction cosine matrices of both the body and the + flag relative to the floor are found: + + >>> tube.frame.dcm(floor.frame) + Matrix([ + [ cos(q0_C1(t)), sin(q0_C1(t)), 0], + [-sin(q0_C1(t)), cos(q0_C1(t)), 0], + [ 0, 0, 1]]) + >>> flag.frame.dcm(floor.frame) + Matrix([ + [cos(q0_C1(t))*cos(q0_C2(t)), sin(q0_C1(t))*cos(q0_C2(t)), -sin(q0_C2(t))], + [ -sin(q0_C1(t)), cos(q0_C1(t)), 0], + [sin(q0_C2(t))*cos(q0_C1(t)), sin(q0_C1(t))*sin(q0_C2(t)), cos(q0_C2(t))]]) + + The position of the flag's center of mass is found with: + + >>> flag.masscenter.pos_from(floor.masscenter) + q1_C1(t)*floor_frame.z + (r + q1_C2(t))*tube_frame.y + w*flag_frame.y - l*flag_frame.z + + The angular velocities of the two tubes can be computed with respect to the + floor. + + >>> tube.frame.ang_vel_in(floor.frame) + u0_C1(t)*floor_frame.z + >>> flag.frame.ang_vel_in(floor.frame) + u0_C1(t)*floor_frame.z + u0_C2(t)*tube_frame.y + + Finally, the linear velocities of the two tube centers of mass can be + computed with respect to the floor, while expressed in the tube's frame. + + >>> tube.masscenter.vel(floor.frame).to_matrix(tube.frame) + Matrix([ + [ 0], + [ 0], + [u1_C1(t)]]) + >>> flag.masscenter.vel(floor.frame).to_matrix(tube.frame).simplify() + Matrix([ + [-l*u0_C2(t)*cos(q0_C2(t)) - r*u0_C1(t) - w*u0_C1(t) - q1_C2(t)*u0_C1(t)], + [ -l*u0_C1(t)*sin(q0_C2(t)) + Derivative(q1_C2(t), t)], + [ l*u0_C2(t)*sin(q0_C2(t)) + u1_C1(t)]]) + + """ + + def __init__(self, name, parent, child, rotation_coordinate=None, + translation_coordinate=None, rotation_speed=None, + translation_speed=None, parent_point=None, child_point=None, + parent_interframe=None, child_interframe=None, + joint_axis=None): + self._joint_axis = joint_axis + coordinates = (rotation_coordinate, translation_coordinate) + speeds = (rotation_speed, translation_speed) + super().__init__(name, parent, child, coordinates, speeds, + parent_point, child_point, + parent_interframe=parent_interframe, + child_interframe=child_interframe) + + def __str__(self): + return (f'CylindricalJoint: {self.name} parent: {self.parent} ' + f'child: {self.child}') + + @property + def joint_axis(self): + """Axis about and along which the rotation and translation occurs.""" + return self._joint_axis + + @property + def rotation_coordinate(self): + """Generalized coordinate corresponding to the rotation angle.""" + return self.coordinates[0] + + @property + def translation_coordinate(self): + """Generalized coordinate corresponding to the translation distance.""" + return self.coordinates[1] + + @property + def rotation_speed(self): + """Generalized speed corresponding to the angular velocity.""" + return self.speeds[0] + + @property + def translation_speed(self): + """Generalized speed corresponding to the translation velocity.""" + return self.speeds[1] + + def _generate_coordinates(self, coordinates): + return self._fill_coordinate_list(coordinates, 2, 'q') + + def _generate_speeds(self, speeds): + return self._fill_coordinate_list(speeds, 2, 'u') + + def _orient_frames(self): + self._joint_axis = self._axis(self.joint_axis, self.parent_interframe) + self.child_interframe.orient_axis( + self.parent_interframe, self.joint_axis, self.rotation_coordinate) + + def _set_angular_velocity(self): + self.child_interframe.set_ang_vel( + self.parent_interframe, + self.rotation_speed * self.joint_axis.normalize()) + + def _set_linear_velocity(self): + self.child_point.set_pos( + self.parent_point, + self.translation_coordinate * self.joint_axis.normalize()) + self.parent_point.set_vel(self._parent_frame, 0) + self.child_point.set_vel(self._child_frame, 0) + self.child_point.set_vel( + self._parent_frame, + self.translation_speed * self.joint_axis.normalize()) + self.child.masscenter.v2pt_theory(self.child_point, self._parent_frame, + self.child_interframe) + + +class PlanarJoint(Joint): + """Planar Joint. + + .. raw:: html + :file: ../../../doc/src/modules/physics/mechanics/api/PlanarJoint.svg + + Explanation + =========== + + A planar joint is defined such that the child body translates over a fixed + plane of the parent body as well as rotate about the rotation axis, which + is perpendicular to that plane. The origin of this plane is the + ``parent_point`` and the plane is spanned by two nonparallel planar vectors. + The location of the ``child_point`` is based on the planar vectors + ($\\vec{v}_1$, $\\vec{v}_2$) and generalized coordinates ($q_1$, $q_2$), + i.e. $\\vec{r} = q_1 \\hat{v}_1 + q_2 \\hat{v}_2$. The direction cosine + matrix between the ``child_interframe`` and ``parent_interframe`` is formed + using a simple rotation ($q_0$) about the rotation axis. + + In order to simplify the definition of the ``PlanarJoint``, the + ``rotation_axis`` and ``planar_vectors`` are set to be the unit vectors of + the ``parent_interframe`` according to the table below. This ensures that + you can only define these vectors by creating a separate frame and supplying + that as the interframe. If you however would only like to supply the normals + of the plane with respect to the parent and child bodies, then you can also + supply those to the ``parent_interframe`` and ``child_interframe`` + arguments. An example of both of these cases is in the examples section + below and the page on the joints framework provides a more detailed + explanation of the intermediate frames. + + .. list-table:: + + * - ``rotation_axis`` + - ``parent_interframe.x`` + * - ``planar_vectors[0]`` + - ``parent_interframe.y`` + * - ``planar_vectors[1]`` + - ``parent_interframe.z`` + + Parameters + ========== + + name : string + A unique name for the joint. + parent : Particle or RigidBody or Body + The parent body of joint. + child : Particle or RigidBody or Body + The child body of joint. + rotation_coordinate : dynamicsymbol, optional + Generalized coordinate corresponding to the rotation angle. The default + value is ``dynamicsymbols(f'q0_{joint.name}')``. + planar_coordinates : iterable of dynamicsymbols, optional + Two generalized coordinates used for the planar translation. The default + value is ``dynamicsymbols(f'q1_{joint.name} q2_{joint.name}')``. + rotation_speed : dynamicsymbol, optional + Generalized speed corresponding to the angular velocity. The default + value is ``dynamicsymbols(f'u0_{joint.name}')``. + planar_speeds : dynamicsymbols, optional + Two generalized speeds used for the planar translation velocity. The + default value is ``dynamicsymbols(f'u1_{joint.name} u2_{joint.name}')``. + parent_point : Point or Vector, optional + Attachment point where the joint is fixed to the parent body. If a + vector is provided, then the attachment point is computed by adding the + vector to the body's mass center. The default value is the parent's mass + center. + child_point : Point or Vector, optional + Attachment point where the joint is fixed to the child body. If a + vector is provided, then the attachment point is computed by adding the + vector to the body's mass center. The default value is the child's mass + center. + parent_interframe : ReferenceFrame, optional + Intermediate frame of the parent body with respect to which the joint + transformation is formulated. If a Vector is provided then an interframe + is created which aligns its X axis with the given vector. The default + value is the parent's own frame. + child_interframe : ReferenceFrame, optional + Intermediate frame of the child body with respect to which the joint + transformation is formulated. If a Vector is provided then an interframe + is created which aligns its X axis with the given vector. The default + value is the child's own frame. + + Attributes + ========== + + name : string + The joint's name. + parent : Particle or RigidBody or Body + The joint's parent body. + child : Particle or RigidBody or Body + The joint's child body. + rotation_coordinate : dynamicsymbol + Generalized coordinate corresponding to the rotation angle. + planar_coordinates : Matrix + Two generalized coordinates used for the planar translation. + rotation_speed : dynamicsymbol + Generalized speed corresponding to the angular velocity. + planar_speeds : Matrix + Two generalized speeds used for the planar translation velocity. + coordinates : Matrix + Matrix of the joint's generalized coordinates. + speeds : Matrix + Matrix of the joint's generalized speeds. + parent_point : Point + Attachment point where the joint is fixed to the parent body. + child_point : Point + Attachment point where the joint is fixed to the child body. + parent_interframe : ReferenceFrame + Intermediate frame of the parent body with respect to which the joint + transformation is formulated. + child_interframe : ReferenceFrame + Intermediate frame of the child body with respect to which the joint + transformation is formulated. + kdes : Matrix + Kinematical differential equations of the joint. + rotation_axis : Vector + The axis about which the rotation occurs. + planar_vectors : list + The vectors that describe the planar translation directions. + + Examples + ========= + + A single planar joint is created between two bodies and has the following + basic attributes: + + >>> from sympy.physics.mechanics import RigidBody, PlanarJoint + >>> parent = RigidBody('P') + >>> parent + P + >>> child = RigidBody('C') + >>> child + C + >>> joint = PlanarJoint('PC', parent, child) + >>> joint + PlanarJoint: PC parent: P child: C + >>> joint.name + 'PC' + >>> joint.parent + P + >>> joint.child + C + >>> joint.parent_point + P_masscenter + >>> joint.child_point + C_masscenter + >>> joint.rotation_axis + P_frame.x + >>> joint.planar_vectors + [P_frame.y, P_frame.z] + >>> joint.rotation_coordinate + q0_PC(t) + >>> joint.planar_coordinates + Matrix([ + [q1_PC(t)], + [q2_PC(t)]]) + >>> joint.coordinates + Matrix([ + [q0_PC(t)], + [q1_PC(t)], + [q2_PC(t)]]) + >>> joint.rotation_speed + u0_PC(t) + >>> joint.planar_speeds + Matrix([ + [u1_PC(t)], + [u2_PC(t)]]) + >>> joint.speeds + Matrix([ + [u0_PC(t)], + [u1_PC(t)], + [u2_PC(t)]]) + >>> child.frame.ang_vel_in(parent.frame) + u0_PC(t)*P_frame.x + >>> child.frame.dcm(parent.frame) + Matrix([ + [1, 0, 0], + [0, cos(q0_PC(t)), sin(q0_PC(t))], + [0, -sin(q0_PC(t)), cos(q0_PC(t))]]) + >>> joint.child_point.pos_from(joint.parent_point) + q1_PC(t)*P_frame.y + q2_PC(t)*P_frame.z + >>> child.masscenter.vel(parent.frame) + u1_PC(t)*P_frame.y + u2_PC(t)*P_frame.z + + To further demonstrate the use of the planar joint, the kinematics of a + block sliding on a slope, can be created as follows. + + >>> from sympy import symbols + >>> from sympy.physics.mechanics import PlanarJoint, RigidBody, ReferenceFrame + >>> a, d, h = symbols('a d h') + + First create bodies to represent the slope and the block. + + >>> ground = RigidBody('G') + >>> block = RigidBody('B') + + To define the slope you can either define the plane by specifying the + ``planar_vectors`` or/and the ``rotation_axis``. However it is advisable to + create a rotated intermediate frame, so that the ``parent_vectors`` and + ``rotation_axis`` will be the unit vectors of this intermediate frame. + + >>> slope = ReferenceFrame('A') + >>> slope.orient_axis(ground.frame, ground.y, a) + + The planar joint can be created using these bodies and intermediate frame. + We can specify the origin of the slope to be ``d`` above the slope's center + of mass and the block's center of mass to be a distance ``h`` above the + slope's surface. Note that we can specify the normal of the plane using the + rotation axis argument. + + >>> joint = PlanarJoint('PC', ground, block, parent_point=d * ground.x, + ... child_point=-h * block.x, parent_interframe=slope) + + Once the joint is established the kinematics of the bodies can be accessed. + First the ``rotation_axis``, which is normal to the plane and the + ``plane_vectors``, can be found. + + >>> joint.rotation_axis + A.x + >>> joint.planar_vectors + [A.y, A.z] + + The direction cosine matrix of the block with respect to the ground can be + found with: + + >>> block.frame.dcm(ground.frame) + Matrix([ + [ cos(a), 0, -sin(a)], + [sin(a)*sin(q0_PC(t)), cos(q0_PC(t)), sin(q0_PC(t))*cos(a)], + [sin(a)*cos(q0_PC(t)), -sin(q0_PC(t)), cos(a)*cos(q0_PC(t))]]) + + The angular velocity of the block can be computed with respect to the + ground. + + >>> block.frame.ang_vel_in(ground.frame) + u0_PC(t)*A.x + + The position of the block's center of mass can be found with: + + >>> block.masscenter.pos_from(ground.masscenter) + d*G_frame.x + h*B_frame.x + q1_PC(t)*A.y + q2_PC(t)*A.z + + Finally, the linear velocity of the block's center of mass can be + computed with respect to the ground. + + >>> block.masscenter.vel(ground.frame) + u1_PC(t)*A.y + u2_PC(t)*A.z + + In some cases it could be your preference to only define the normals of the + plane with respect to both bodies. This can most easily be done by supplying + vectors to the ``interframe`` arguments. What will happen in this case is + that an interframe will be created with its ``x`` axis aligned with the + provided vector. For a further explanation of how this is done see the notes + of the ``Joint`` class. In the code below, the above example (with the block + on the slope) is recreated by supplying vectors to the interframe arguments. + Note that the previously described option is however more computationally + efficient, because the algorithm now has to compute the rotation angle + between the provided vector and the 'x' axis. + + >>> from sympy import symbols, cos, sin + >>> from sympy.physics.mechanics import PlanarJoint, RigidBody + >>> a, d, h = symbols('a d h') + >>> ground = RigidBody('G') + >>> block = RigidBody('B') + >>> joint = PlanarJoint( + ... 'PC', ground, block, parent_point=d * ground.x, + ... child_point=-h * block.x, child_interframe=block.x, + ... parent_interframe=cos(a) * ground.x + sin(a) * ground.z) + >>> block.frame.dcm(ground.frame).simplify() + Matrix([ + [ cos(a), 0, sin(a)], + [-sin(a)*sin(q0_PC(t)), cos(q0_PC(t)), sin(q0_PC(t))*cos(a)], + [-sin(a)*cos(q0_PC(t)), -sin(q0_PC(t)), cos(a)*cos(q0_PC(t))]]) + + """ + + def __init__(self, name, parent, child, rotation_coordinate=None, + planar_coordinates=None, rotation_speed=None, + planar_speeds=None, parent_point=None, child_point=None, + parent_interframe=None, child_interframe=None): + # A ready to merge implementation of setting the planar_vectors and + # rotation_axis was added and removed in PR #24046 + coordinates = (rotation_coordinate, planar_coordinates) + speeds = (rotation_speed, planar_speeds) + super().__init__(name, parent, child, coordinates, speeds, + parent_point, child_point, + parent_interframe=parent_interframe, + child_interframe=child_interframe) + + def __str__(self): + return (f'PlanarJoint: {self.name} parent: {self.parent} ' + f'child: {self.child}') + + @property + def rotation_coordinate(self): + """Generalized coordinate corresponding to the rotation angle.""" + return self.coordinates[0] + + @property + def planar_coordinates(self): + """Two generalized coordinates used for the planar translation.""" + return self.coordinates[1:, 0] + + @property + def rotation_speed(self): + """Generalized speed corresponding to the angular velocity.""" + return self.speeds[0] + + @property + def planar_speeds(self): + """Two generalized speeds used for the planar translation velocity.""" + return self.speeds[1:, 0] + + @property + def rotation_axis(self): + """The axis about which the rotation occurs.""" + return self.parent_interframe.x + + @property + def planar_vectors(self): + """The vectors that describe the planar translation directions.""" + return [self.parent_interframe.y, self.parent_interframe.z] + + def _generate_coordinates(self, coordinates): + rotation_speed = self._fill_coordinate_list(coordinates[0], 1, 'q', + number_single=True) + planar_speeds = self._fill_coordinate_list(coordinates[1], 2, 'q', 1) + return rotation_speed.col_join(planar_speeds) + + def _generate_speeds(self, speeds): + rotation_speed = self._fill_coordinate_list(speeds[0], 1, 'u', + number_single=True) + planar_speeds = self._fill_coordinate_list(speeds[1], 2, 'u', 1) + return rotation_speed.col_join(planar_speeds) + + def _orient_frames(self): + self.child_interframe.orient_axis( + self.parent_interframe, self.rotation_axis, + self.rotation_coordinate) + + def _set_angular_velocity(self): + self.child_interframe.set_ang_vel( + self.parent_interframe, + self.rotation_speed * self.rotation_axis) + + def _set_linear_velocity(self): + self.child_point.set_pos( + self.parent_point, + self.planar_coordinates[0] * self.planar_vectors[0] + + self.planar_coordinates[1] * self.planar_vectors[1]) + self.parent_point.set_vel(self.parent_interframe, 0) + self.child_point.set_vel(self.child_interframe, 0) + self.child_point.set_vel( + self._parent_frame, self.planar_speeds[0] * self.planar_vectors[0] + + self.planar_speeds[1] * self.planar_vectors[1]) + self.child.masscenter.v2pt_theory(self.child_point, self._parent_frame, + self._child_frame) + + +class SphericalJoint(Joint): + """Spherical (Ball-and-Socket) Joint. + + .. image:: SphericalJoint.svg + :align: center + :width: 600 + + Explanation + =========== + + A spherical joint is defined such that the child body is free to rotate in + any direction, without allowing a translation of the ``child_point``. As can + also be seen in the image, the ``parent_point`` and ``child_point`` are + fixed on top of each other, i.e. the ``joint_point``. This rotation is + defined using the :func:`parent_interframe.orient(child_interframe, + rot_type, amounts, rot_order) + ` method. The default + rotation consists of three relative rotations, i.e. body-fixed rotations. + Based on the direction cosine matrix following from these rotations, the + angular velocity is computed based on the generalized coordinates and + generalized speeds. + + Parameters + ========== + + name : string + A unique name for the joint. + parent : Particle or RigidBody or Body + The parent body of joint. + child : Particle or RigidBody or Body + The child body of joint. + coordinates: iterable of dynamicsymbols, optional + Generalized coordinates of the joint. + speeds : iterable of dynamicsymbols, optional + Generalized speeds of joint. + parent_point : Point or Vector, optional + Attachment point where the joint is fixed to the parent body. If a + vector is provided, then the attachment point is computed by adding the + vector to the body's mass center. The default value is the parent's mass + center. + child_point : Point or Vector, optional + Attachment point where the joint is fixed to the child body. If a + vector is provided, then the attachment point is computed by adding the + vector to the body's mass center. The default value is the child's mass + center. + parent_interframe : ReferenceFrame, optional + Intermediate frame of the parent body with respect to which the joint + transformation is formulated. If a Vector is provided then an interframe + is created which aligns its X axis with the given vector. The default + value is the parent's own frame. + child_interframe : ReferenceFrame, optional + Intermediate frame of the child body with respect to which the joint + transformation is formulated. If a Vector is provided then an interframe + is created which aligns its X axis with the given vector. The default + value is the child's own frame. + rot_type : str, optional + The method used to generate the direction cosine matrix. Supported + methods are: + + - ``'Body'``: three successive rotations about new intermediate axes, + also called "Euler and Tait-Bryan angles" + - ``'Space'``: three successive rotations about the parent frames' unit + vectors + + The default method is ``'Body'``. + amounts : + Expressions defining the rotation angles or direction cosine matrix. + These must match the ``rot_type``. See examples below for details. The + input types are: + + - ``'Body'``: 3-tuple of expressions, symbols, or functions + - ``'Space'``: 3-tuple of expressions, symbols, or functions + + The default amounts are the given ``coordinates``. + rot_order : str or int, optional + If applicable, the order of the successive of rotations. The string + ``'123'`` and integer ``123`` are equivalent, for example. Required for + ``'Body'`` and ``'Space'``. The default value is ``123``. + + Attributes + ========== + + name : string + The joint's name. + parent : Particle or RigidBody or Body + The joint's parent body. + child : Particle or RigidBody or Body + The joint's child body. + coordinates : Matrix + Matrix of the joint's generalized coordinates. + speeds : Matrix + Matrix of the joint's generalized speeds. + parent_point : Point + Attachment point where the joint is fixed to the parent body. + child_point : Point + Attachment point where the joint is fixed to the child body. + parent_interframe : ReferenceFrame + Intermediate frame of the parent body with respect to which the joint + transformation is formulated. + child_interframe : ReferenceFrame + Intermediate frame of the child body with respect to which the joint + transformation is formulated. + kdes : Matrix + Kinematical differential equations of the joint. + + Examples + ========= + + A single spherical joint is created from two bodies and has the following + basic attributes: + + >>> from sympy.physics.mechanics import RigidBody, SphericalJoint + >>> parent = RigidBody('P') + >>> parent + P + >>> child = RigidBody('C') + >>> child + C + >>> joint = SphericalJoint('PC', parent, child) + >>> joint + SphericalJoint: PC parent: P child: C + >>> joint.name + 'PC' + >>> joint.parent + P + >>> joint.child + C + >>> joint.parent_point + P_masscenter + >>> joint.child_point + C_masscenter + >>> joint.parent_interframe + P_frame + >>> joint.child_interframe + C_frame + >>> joint.coordinates + Matrix([ + [q0_PC(t)], + [q1_PC(t)], + [q2_PC(t)]]) + >>> joint.speeds + Matrix([ + [u0_PC(t)], + [u1_PC(t)], + [u2_PC(t)]]) + >>> child.frame.ang_vel_in(parent.frame).to_matrix(child.frame) + Matrix([ + [ u0_PC(t)*cos(q1_PC(t))*cos(q2_PC(t)) + u1_PC(t)*sin(q2_PC(t))], + [-u0_PC(t)*sin(q2_PC(t))*cos(q1_PC(t)) + u1_PC(t)*cos(q2_PC(t))], + [ u0_PC(t)*sin(q1_PC(t)) + u2_PC(t)]]) + >>> child.frame.x.to_matrix(parent.frame) + Matrix([ + [ cos(q1_PC(t))*cos(q2_PC(t))], + [sin(q0_PC(t))*sin(q1_PC(t))*cos(q2_PC(t)) + sin(q2_PC(t))*cos(q0_PC(t))], + [sin(q0_PC(t))*sin(q2_PC(t)) - sin(q1_PC(t))*cos(q0_PC(t))*cos(q2_PC(t))]]) + >>> joint.child_point.pos_from(joint.parent_point) + 0 + + To further demonstrate the use of the spherical joint, the kinematics of a + spherical joint with a ZXZ rotation can be created as follows. + + >>> from sympy import symbols + >>> from sympy.physics.mechanics import RigidBody, SphericalJoint + >>> l1 = symbols('l1') + + First create bodies to represent the fixed floor and a pendulum bob. + + >>> floor = RigidBody('F') + >>> bob = RigidBody('B') + + The joint will connect the bob to the floor, with the joint located at a + distance of ``l1`` from the child's center of mass and the rotation set to a + body-fixed ZXZ rotation. + + >>> joint = SphericalJoint('S', floor, bob, child_point=l1 * bob.y, + ... rot_type='body', rot_order='ZXZ') + + Now that the joint is established, the kinematics of the connected body can + be accessed. + + The position of the bob's masscenter is found with: + + >>> bob.masscenter.pos_from(floor.masscenter) + - l1*B_frame.y + + The angular velocities of the pendulum link can be computed with respect to + the floor. + + >>> bob.frame.ang_vel_in(floor.frame).to_matrix( + ... floor.frame).simplify() + Matrix([ + [u1_S(t)*cos(q0_S(t)) + u2_S(t)*sin(q0_S(t))*sin(q1_S(t))], + [u1_S(t)*sin(q0_S(t)) - u2_S(t)*sin(q1_S(t))*cos(q0_S(t))], + [ u0_S(t) + u2_S(t)*cos(q1_S(t))]]) + + Finally, the linear velocity of the bob's center of mass can be computed. + + >>> bob.masscenter.vel(floor.frame).to_matrix(bob.frame) + Matrix([ + [ l1*(u0_S(t)*cos(q1_S(t)) + u2_S(t))], + [ 0], + [-l1*(u0_S(t)*sin(q1_S(t))*sin(q2_S(t)) + u1_S(t)*cos(q2_S(t)))]]) + + """ + def __init__(self, name, parent, child, coordinates=None, speeds=None, + parent_point=None, child_point=None, parent_interframe=None, + child_interframe=None, rot_type='BODY', amounts=None, + rot_order=123): + self._rot_type = rot_type + self._amounts = amounts + self._rot_order = rot_order + super().__init__(name, parent, child, coordinates, speeds, + parent_point, child_point, + parent_interframe=parent_interframe, + child_interframe=child_interframe) + + def __str__(self): + return (f'SphericalJoint: {self.name} parent: {self.parent} ' + f'child: {self.child}') + + def _generate_coordinates(self, coordinates): + return self._fill_coordinate_list(coordinates, 3, 'q') + + def _generate_speeds(self, speeds): + return self._fill_coordinate_list(speeds, len(self.coordinates), 'u') + + def _orient_frames(self): + supported_rot_types = ('BODY', 'SPACE') + if self._rot_type.upper() not in supported_rot_types: + raise NotImplementedError( + f'Rotation type "{self._rot_type}" is not implemented. ' + f'Implemented rotation types are: {supported_rot_types}') + amounts = self.coordinates if self._amounts is None else self._amounts + self.child_interframe.orient(self.parent_interframe, self._rot_type, + amounts, self._rot_order) + + def _set_angular_velocity(self): + t = dynamicsymbols._t + vel = self.child_interframe.ang_vel_in(self.parent_interframe).xreplace( + {q.diff(t): u for q, u in zip(self.coordinates, self.speeds)} + ) + self.child_interframe.set_ang_vel(self.parent_interframe, vel) + + def _set_linear_velocity(self): + self.child_point.set_pos(self.parent_point, 0) + self.parent_point.set_vel(self._parent_frame, 0) + self.child_point.set_vel(self._child_frame, 0) + self.child.masscenter.v2pt_theory(self.parent_point, self._parent_frame, + self._child_frame) + + +class WeldJoint(Joint): + """Weld Joint. + + .. raw:: html + :file: ../../../doc/src/modules/physics/mechanics/api/WeldJoint.svg + + Explanation + =========== + + A weld joint is defined such that there is no relative motion between the + child and parent bodies. The direction cosine matrix between the attachment + frame (``parent_interframe`` and ``child_interframe``) is the identity + matrix and the attachment points (``parent_point`` and ``child_point``) are + coincident. The page on the joints framework gives a more detailed + explanation of the intermediate frames. + + Parameters + ========== + + name : string + A unique name for the joint. + parent : Particle or RigidBody or Body + The parent body of joint. + child : Particle or RigidBody or Body + The child body of joint. + parent_point : Point or Vector, optional + Attachment point where the joint is fixed to the parent body. If a + vector is provided, then the attachment point is computed by adding the + vector to the body's mass center. The default value is the parent's mass + center. + child_point : Point or Vector, optional + Attachment point where the joint is fixed to the child body. If a + vector is provided, then the attachment point is computed by adding the + vector to the body's mass center. The default value is the child's mass + center. + parent_interframe : ReferenceFrame, optional + Intermediate frame of the parent body with respect to which the joint + transformation is formulated. If a Vector is provided then an interframe + is created which aligns its X axis with the given vector. The default + value is the parent's own frame. + child_interframe : ReferenceFrame, optional + Intermediate frame of the child body with respect to which the joint + transformation is formulated. If a Vector is provided then an interframe + is created which aligns its X axis with the given vector. The default + value is the child's own frame. + + Attributes + ========== + + name : string + The joint's name. + parent : Particle or RigidBody or Body + The joint's parent body. + child : Particle or RigidBody or Body + The joint's child body. + coordinates : Matrix + Matrix of the joint's generalized coordinates. The default value is + ``dynamicsymbols(f'q_{joint.name}')``. + speeds : Matrix + Matrix of the joint's generalized speeds. The default value is + ``dynamicsymbols(f'u_{joint.name}')``. + parent_point : Point + Attachment point where the joint is fixed to the parent body. + child_point : Point + Attachment point where the joint is fixed to the child body. + parent_interframe : ReferenceFrame + Intermediate frame of the parent body with respect to which the joint + transformation is formulated. + child_interframe : ReferenceFrame + Intermediate frame of the child body with respect to which the joint + transformation is formulated. + kdes : Matrix + Kinematical differential equations of the joint. + + Examples + ========= + + A single weld joint is created from two bodies and has the following basic + attributes: + + >>> from sympy.physics.mechanics import RigidBody, WeldJoint + >>> parent = RigidBody('P') + >>> parent + P + >>> child = RigidBody('C') + >>> child + C + >>> joint = WeldJoint('PC', parent, child) + >>> joint + WeldJoint: PC parent: P child: C + >>> joint.name + 'PC' + >>> joint.parent + P + >>> joint.child + C + >>> joint.parent_point + P_masscenter + >>> joint.child_point + C_masscenter + >>> joint.coordinates + Matrix(0, 0, []) + >>> joint.speeds + Matrix(0, 0, []) + >>> child.frame.ang_vel_in(parent.frame) + 0 + >>> child.frame.dcm(parent.frame) + Matrix([ + [1, 0, 0], + [0, 1, 0], + [0, 0, 1]]) + >>> joint.child_point.pos_from(joint.parent_point) + 0 + + To further demonstrate the use of the weld joint, two relatively-fixed + bodies rotated by a quarter turn about the Y axis can be created as follows: + + >>> from sympy import symbols, pi + >>> from sympy.physics.mechanics import ReferenceFrame, RigidBody, WeldJoint + >>> l1, l2 = symbols('l1 l2') + + First create the bodies to represent the parent and rotated child body. + + >>> parent = RigidBody('P') + >>> child = RigidBody('C') + + Next the intermediate frame specifying the fixed rotation with respect to + the parent can be created. + + >>> rotated_frame = ReferenceFrame('Pr') + >>> rotated_frame.orient_axis(parent.frame, parent.y, pi / 2) + + The weld between the parent body and child body is located at a distance + ``l1`` from the parent's center of mass in the X direction and ``l2`` from + the child's center of mass in the child's negative X direction. + + >>> weld = WeldJoint('weld', parent, child, parent_point=l1 * parent.x, + ... child_point=-l2 * child.x, + ... parent_interframe=rotated_frame) + + Now that the joint has been established, the kinematics of the bodies can be + accessed. The direction cosine matrix of the child body with respect to the + parent can be found: + + >>> child.frame.dcm(parent.frame) + Matrix([ + [0, 0, -1], + [0, 1, 0], + [1, 0, 0]]) + + As can also been seen from the direction cosine matrix, the parent X axis is + aligned with the child's Z axis: + >>> parent.x == child.z + True + + The position of the child's center of mass with respect to the parent's + center of mass can be found with: + + >>> child.masscenter.pos_from(parent.masscenter) + l1*P_frame.x + l2*C_frame.x + + The angular velocity of the child with respect to the parent is 0 as one + would expect. + + >>> child.frame.ang_vel_in(parent.frame) + 0 + + """ + + def __init__(self, name, parent, child, parent_point=None, child_point=None, + parent_interframe=None, child_interframe=None): + super().__init__(name, parent, child, [], [], parent_point, + child_point, parent_interframe=parent_interframe, + child_interframe=child_interframe) + self._kdes = Matrix(1, 0, []).T # Removes stackability problems #10770 + + def __str__(self): + return (f'WeldJoint: {self.name} parent: {self.parent} ' + f'child: {self.child}') + + def _generate_coordinates(self, coordinate): + return Matrix() + + def _generate_speeds(self, speed): + return Matrix() + + def _orient_frames(self): + self.child_interframe.orient_axis(self.parent_interframe, + self.parent_interframe.x, 0) + + def _set_angular_velocity(self): + self.child_interframe.set_ang_vel(self.parent_interframe, 0) + + def _set_linear_velocity(self): + self.child_point.set_pos(self.parent_point, 0) + self.parent_point.set_vel(self._parent_frame, 0) + self.child_point.set_vel(self._child_frame, 0) + self.child.masscenter.set_vel(self._parent_frame, 0) diff --git a/lib/python3.10/site-packages/sympy/physics/mechanics/jointsmethod.py b/lib/python3.10/site-packages/sympy/physics/mechanics/jointsmethod.py new file mode 100644 index 0000000000000000000000000000000000000000..df7bd56360072feb57a65e5f78c2d116f0d4842d --- /dev/null +++ b/lib/python3.10/site-packages/sympy/physics/mechanics/jointsmethod.py @@ -0,0 +1,318 @@ +from sympy.physics.mechanics import (Body, Lagrangian, KanesMethod, LagrangesMethod, + RigidBody, Particle) +from sympy.physics.mechanics.body_base import BodyBase +from sympy.physics.mechanics.method import _Methods +from sympy import Matrix +from sympy.utilities.exceptions import sympy_deprecation_warning + +__all__ = ['JointsMethod'] + + +class JointsMethod(_Methods): + """Method for formulating the equations of motion using a set of interconnected bodies with joints. + + .. deprecated:: 1.13 + The JointsMethod class is deprecated. Its functionality has been + replaced by the new :class:`~.System` class. + + Parameters + ========== + + newtonion : Body or ReferenceFrame + The newtonion(inertial) frame. + *joints : Joint + The joints in the system + + Attributes + ========== + + q, u : iterable + Iterable of the generalized coordinates and speeds + bodies : iterable + Iterable of Body objects in the system. + loads : iterable + Iterable of (Point, vector) or (ReferenceFrame, vector) tuples + describing the forces on the system. + mass_matrix : Matrix, shape(n, n) + The system's mass matrix + forcing : Matrix, shape(n, 1) + The system's forcing vector + mass_matrix_full : Matrix, shape(2*n, 2*n) + The "mass matrix" for the u's and q's + forcing_full : Matrix, shape(2*n, 1) + The "forcing vector" for the u's and q's + method : KanesMethod or Lagrange's method + Method's object. + kdes : iterable + Iterable of kde in they system. + + Examples + ======== + + As Body and JointsMethod have been deprecated, the following examples are + for illustrative purposes only. The functionality of Body is fully captured + by :class:`~.RigidBody` and :class:`~.Particle` and the functionality of + JointsMethod is fully captured by :class:`~.System`. To ignore the + deprecation warning we can use the ignore_warnings context manager. + + >>> from sympy.utilities.exceptions import ignore_warnings + + This is a simple example for a one degree of freedom translational + spring-mass-damper. + + >>> from sympy import symbols + >>> from sympy.physics.mechanics import Body, JointsMethod, PrismaticJoint + >>> from sympy.physics.vector import dynamicsymbols + >>> c, k = symbols('c k') + >>> x, v = dynamicsymbols('x v') + >>> with ignore_warnings(DeprecationWarning): + ... wall = Body('W') + ... body = Body('B') + >>> J = PrismaticJoint('J', wall, body, coordinates=x, speeds=v) + >>> wall.apply_force(c*v*wall.x, reaction_body=body) + >>> wall.apply_force(k*x*wall.x, reaction_body=body) + >>> with ignore_warnings(DeprecationWarning): + ... method = JointsMethod(wall, J) + >>> method.form_eoms() + Matrix([[-B_mass*Derivative(v(t), t) - c*v(t) - k*x(t)]]) + >>> M = method.mass_matrix_full + >>> F = method.forcing_full + >>> rhs = M.LUsolve(F) + >>> rhs + Matrix([ + [ v(t)], + [(-c*v(t) - k*x(t))/B_mass]]) + + Notes + ===== + + ``JointsMethod`` currently only works with systems that do not have any + configuration or motion constraints. + + """ + + def __init__(self, newtonion, *joints): + sympy_deprecation_warning( + """ + The JointsMethod class is deprecated. + Its functionality has been replaced by the new System class. + """, + deprecated_since_version="1.13", + active_deprecations_target="deprecated-mechanics-jointsmethod" + ) + if isinstance(newtonion, BodyBase): + self.frame = newtonion.frame + else: + self.frame = newtonion + + self._joints = joints + self._bodies = self._generate_bodylist() + self._loads = self._generate_loadlist() + self._q = self._generate_q() + self._u = self._generate_u() + self._kdes = self._generate_kdes() + + self._method = None + + @property + def bodies(self): + """List of bodies in they system.""" + return self._bodies + + @property + def loads(self): + """List of loads on the system.""" + return self._loads + + @property + def q(self): + """List of the generalized coordinates.""" + return self._q + + @property + def u(self): + """List of the generalized speeds.""" + return self._u + + @property + def kdes(self): + """List of the generalized coordinates.""" + return self._kdes + + @property + def forcing_full(self): + """The "forcing vector" for the u's and q's.""" + return self.method.forcing_full + + @property + def mass_matrix_full(self): + """The "mass matrix" for the u's and q's.""" + return self.method.mass_matrix_full + + @property + def mass_matrix(self): + """The system's mass matrix.""" + return self.method.mass_matrix + + @property + def forcing(self): + """The system's forcing vector.""" + return self.method.forcing + + @property + def method(self): + """Object of method used to form equations of systems.""" + return self._method + + def _generate_bodylist(self): + bodies = [] + for joint in self._joints: + if joint.child not in bodies: + bodies.append(joint.child) + if joint.parent not in bodies: + bodies.append(joint.parent) + return bodies + + def _generate_loadlist(self): + load_list = [] + for body in self.bodies: + if isinstance(body, Body): + load_list.extend(body.loads) + return load_list + + def _generate_q(self): + q_ind = [] + for joint in self._joints: + for coordinate in joint.coordinates: + if coordinate in q_ind: + raise ValueError('Coordinates of joints should be unique.') + q_ind.append(coordinate) + return Matrix(q_ind) + + def _generate_u(self): + u_ind = [] + for joint in self._joints: + for speed in joint.speeds: + if speed in u_ind: + raise ValueError('Speeds of joints should be unique.') + u_ind.append(speed) + return Matrix(u_ind) + + def _generate_kdes(self): + kd_ind = Matrix(1, 0, []).T + for joint in self._joints: + kd_ind = kd_ind.col_join(joint.kdes) + return kd_ind + + def _convert_bodies(self): + # Convert `Body` to `Particle` and `RigidBody` + bodylist = [] + for body in self.bodies: + if not isinstance(body, Body): + bodylist.append(body) + continue + if body.is_rigidbody: + rb = RigidBody(body.name, body.masscenter, body.frame, body.mass, + (body.central_inertia, body.masscenter)) + rb.potential_energy = body.potential_energy + bodylist.append(rb) + else: + part = Particle(body.name, body.masscenter, body.mass) + part.potential_energy = body.potential_energy + bodylist.append(part) + return bodylist + + def form_eoms(self, method=KanesMethod): + """Method to form system's equation of motions. + + Parameters + ========== + + method : Class + Class name of method. + + Returns + ======== + + Matrix + Vector of equations of motions. + + Examples + ======== + + As Body and JointsMethod have been deprecated, the following examples + are for illustrative purposes only. The functionality of Body is fully + captured by :class:`~.RigidBody` and :class:`~.Particle` and the + functionality of JointsMethod is fully captured by :class:`~.System`. To + ignore the deprecation warning we can use the ignore_warnings context + manager. + + >>> from sympy.utilities.exceptions import ignore_warnings + + This is a simple example for a one degree of freedom translational + spring-mass-damper. + + >>> from sympy import S, symbols + >>> from sympy.physics.mechanics import LagrangesMethod, dynamicsymbols, Body + >>> from sympy.physics.mechanics import PrismaticJoint, JointsMethod + >>> q = dynamicsymbols('q') + >>> qd = dynamicsymbols('q', 1) + >>> m, k, b = symbols('m k b') + >>> with ignore_warnings(DeprecationWarning): + ... wall = Body('W') + ... part = Body('P', mass=m) + >>> part.potential_energy = k * q**2 / S(2) + >>> J = PrismaticJoint('J', wall, part, coordinates=q, speeds=qd) + >>> wall.apply_force(b * qd * wall.x, reaction_body=part) + >>> with ignore_warnings(DeprecationWarning): + ... method = JointsMethod(wall, J) + >>> method.form_eoms(LagrangesMethod) + Matrix([[b*Derivative(q(t), t) + k*q(t) + m*Derivative(q(t), (t, 2))]]) + + We can also solve for the states using the 'rhs' method. + + >>> method.rhs() + Matrix([ + [ Derivative(q(t), t)], + [(-b*Derivative(q(t), t) - k*q(t))/m]]) + + """ + + bodylist = self._convert_bodies() + if issubclass(method, LagrangesMethod): #LagrangesMethod or similar + L = Lagrangian(self.frame, *bodylist) + self._method = method(L, self.q, self.loads, bodylist, self.frame) + else: #KanesMethod or similar + self._method = method(self.frame, q_ind=self.q, u_ind=self.u, kd_eqs=self.kdes, + forcelist=self.loads, bodies=bodylist) + soln = self.method._form_eoms() + return soln + + def rhs(self, inv_method=None): + """Returns equations that can be solved numerically. + + Parameters + ========== + + inv_method : str + The specific sympy inverse matrix calculation method to use. For a + list of valid methods, see + :meth:`~sympy.matrices.matrixbase.MatrixBase.inv` + + Returns + ======== + + Matrix + Numerically solvable equations. + + See Also + ======== + + sympy.physics.mechanics.kane.KanesMethod.rhs: + KanesMethod's rhs function. + sympy.physics.mechanics.lagrange.LagrangesMethod.rhs: + LagrangesMethod's rhs function. + + """ + + return self.method.rhs(inv_method=inv_method) diff --git a/lib/python3.10/site-packages/sympy/physics/mechanics/kane.py b/lib/python3.10/site-packages/sympy/physics/mechanics/kane.py new file mode 100644 index 0000000000000000000000000000000000000000..7edea5fa881cd89ae366c3141e575b00c0e5ac34 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/physics/mechanics/kane.py @@ -0,0 +1,860 @@ +from sympy import zeros, Matrix, diff, eye +from sympy.core.sorting import default_sort_key +from sympy.physics.vector import (ReferenceFrame, dynamicsymbols, + partial_velocity) +from sympy.physics.mechanics.method import _Methods +from sympy.physics.mechanics.particle import Particle +from sympy.physics.mechanics.rigidbody import RigidBody +from sympy.physics.mechanics.functions import (msubs, find_dynamicsymbols, + _f_list_parser, + _validate_coordinates, + _parse_linear_solver) +from sympy.physics.mechanics.linearize import Linearizer +from sympy.utilities.iterables import iterable + +__all__ = ['KanesMethod'] + + +class KanesMethod(_Methods): + r"""Kane's method object. + + Explanation + =========== + + This object is used to do the "book-keeping" as you go through and form + equations of motion in the way Kane presents in: + Kane, T., Levinson, D. Dynamics Theory and Applications. 1985 McGraw-Hill + + The attributes are for equations in the form [M] udot = forcing. + + Attributes + ========== + + q, u : Matrix + Matrices of the generalized coordinates and speeds + bodies : iterable + Iterable of Particle and RigidBody objects in the system. + loads : iterable + Iterable of (Point, vector) or (ReferenceFrame, vector) tuples + describing the forces on the system. + auxiliary_eqs : Matrix + If applicable, the set of auxiliary Kane's + equations used to solve for non-contributing + forces. + mass_matrix : Matrix + The system's dynamics mass matrix: [k_d; k_dnh] + forcing : Matrix + The system's dynamics forcing vector: -[f_d; f_dnh] + mass_matrix_kin : Matrix + The "mass matrix" for kinematic differential equations: k_kqdot + forcing_kin : Matrix + The forcing vector for kinematic differential equations: -(k_ku*u + f_k) + mass_matrix_full : Matrix + The "mass matrix" for the u's and q's with dynamics and kinematics + forcing_full : Matrix + The "forcing vector" for the u's and q's with dynamics and kinematics + + Parameters + ========== + + frame : ReferenceFrame + The inertial reference frame for the system. + q_ind : iterable of dynamicsymbols + Independent generalized coordinates. + u_ind : iterable of dynamicsymbols + Independent generalized speeds. + kd_eqs : iterable of Expr, optional + Kinematic differential equations, which linearly relate the generalized + speeds to the time-derivatives of the generalized coordinates. + q_dependent : iterable of dynamicsymbols, optional + Dependent generalized coordinates. + configuration_constraints : iterable of Expr, optional + Constraints on the system's configuration, i.e. holonomic constraints. + u_dependent : iterable of dynamicsymbols, optional + Dependent generalized speeds. + velocity_constraints : iterable of Expr, optional + Constraints on the system's velocity, i.e. the combination of the + nonholonomic constraints and the time-derivative of the holonomic + constraints. + acceleration_constraints : iterable of Expr, optional + Constraints on the system's acceleration, by default these are the + time-derivative of the velocity constraints. + u_auxiliary : iterable of dynamicsymbols, optional + Auxiliary generalized speeds. + bodies : iterable of Particle and/or RigidBody, optional + The particles and rigid bodies in the system. + forcelist : iterable of tuple[Point | ReferenceFrame, Vector], optional + Forces and torques applied on the system. + explicit_kinematics : bool + Boolean whether the mass matrices and forcing vectors should use the + explicit form (default) or implicit form for kinematics. + See the notes for more details. + kd_eqs_solver : str, callable + Method used to solve the kinematic differential equations. If a string + is supplied, it should be a valid method that can be used with the + :meth:`sympy.matrices.matrixbase.MatrixBase.solve`. If a callable is + supplied, it should have the format ``f(A, rhs)``, where it solves the + equations and returns the solution. The default utilizes LU solve. See + the notes for more information. + constraint_solver : str, callable + Method used to solve the velocity constraints. If a string is + supplied, it should be a valid method that can be used with the + :meth:`sympy.matrices.matrixbase.MatrixBase.solve`. If a callable is + supplied, it should have the format ``f(A, rhs)``, where it solves the + equations and returns the solution. The default utilizes LU solve. See + the notes for more information. + + Notes + ===== + + The mass matrices and forcing vectors related to kinematic equations + are given in the explicit form by default. In other words, the kinematic + mass matrix is $\mathbf{k_{k\dot{q}}} = \mathbf{I}$. + In order to get the implicit form of those matrices/vectors, you can set the + ``explicit_kinematics`` attribute to ``False``. So $\mathbf{k_{k\dot{q}}}$ + is not necessarily an identity matrix. This can provide more compact + equations for non-simple kinematics. + + Two linear solvers can be supplied to ``KanesMethod``: one for solving the + kinematic differential equations and one to solve the velocity constraints. + Both of these sets of equations can be expressed as a linear system ``Ax = rhs``, + which have to be solved in order to obtain the equations of motion. + + The default solver ``'LU'``, which stands for LU solve, results relatively low + number of operations. The weakness of this method is that it can result in zero + division errors. + + If zero divisions are encountered, a possible solver which may solve the problem + is ``"CRAMER"``. This method uses Cramer's rule to solve the system. This method + is slower and results in more operations than the default solver. However it only + uses a single division by default per entry of the solution. + + While a valid list of solvers can be found at + :meth:`sympy.matrices.matrixbase.MatrixBase.solve`, it is also possible to supply a + `callable`. This way it is possible to use a different solver routine. If the + kinematic differential equations are not too complex it can be worth it to simplify + the solution by using ``lambda A, b: simplify(Matrix.LUsolve(A, b))``. Another + option solver one may use is :func:`sympy.solvers.solveset.linsolve`. This can be + done using `lambda A, b: tuple(linsolve((A, b)))[0]`, where we select the first + solution as our system should have only one unique solution. + + Examples + ======== + + This is a simple example for a one degree of freedom translational + spring-mass-damper. + + In this example, we first need to do the kinematics. + This involves creating generalized speeds and coordinates and their + derivatives. + Then we create a point and set its velocity in a frame. + + >>> from sympy import symbols + >>> from sympy.physics.mechanics import dynamicsymbols, ReferenceFrame + >>> from sympy.physics.mechanics import Point, Particle, KanesMethod + >>> q, u = dynamicsymbols('q u') + >>> qd, ud = dynamicsymbols('q u', 1) + >>> m, c, k = symbols('m c k') + >>> N = ReferenceFrame('N') + >>> P = Point('P') + >>> P.set_vel(N, u * N.x) + + Next we need to arrange/store information in the way that KanesMethod + requires. The kinematic differential equations should be an iterable of + expressions. A list of forces/torques must be constructed, where each entry + in the list is a (Point, Vector) or (ReferenceFrame, Vector) tuple, where + the Vectors represent the Force or Torque. + Next a particle needs to be created, and it needs to have a point and mass + assigned to it. + Finally, a list of all bodies and particles needs to be created. + + >>> kd = [qd - u] + >>> FL = [(P, (-k * q - c * u) * N.x)] + >>> pa = Particle('pa', P, m) + >>> BL = [pa] + + Finally we can generate the equations of motion. + First we create the KanesMethod object and supply an inertial frame, + coordinates, generalized speeds, and the kinematic differential equations. + Additional quantities such as configuration and motion constraints, + dependent coordinates and speeds, and auxiliary speeds are also supplied + here (see the online documentation). + Next we form FR* and FR to complete: Fr + Fr* = 0. + We have the equations of motion at this point. + It makes sense to rearrange them though, so we calculate the mass matrix and + the forcing terms, for E.o.M. in the form: [MM] udot = forcing, where MM is + the mass matrix, udot is a vector of the time derivatives of the + generalized speeds, and forcing is a vector representing "forcing" terms. + + >>> KM = KanesMethod(N, q_ind=[q], u_ind=[u], kd_eqs=kd) + >>> (fr, frstar) = KM.kanes_equations(BL, FL) + >>> MM = KM.mass_matrix + >>> forcing = KM.forcing + >>> rhs = MM.inv() * forcing + >>> rhs + Matrix([[(-c*u(t) - k*q(t))/m]]) + >>> KM.linearize(A_and_B=True)[0] + Matrix([ + [ 0, 1], + [-k/m, -c/m]]) + + Please look at the documentation pages for more information on how to + perform linearization and how to deal with dependent coordinates & speeds, + and how do deal with bringing non-contributing forces into evidence. + + """ + + def __init__(self, frame, q_ind, u_ind, kd_eqs=None, q_dependent=None, + configuration_constraints=None, u_dependent=None, + velocity_constraints=None, acceleration_constraints=None, + u_auxiliary=None, bodies=None, forcelist=None, + explicit_kinematics=True, kd_eqs_solver='LU', + constraint_solver='LU'): + + """Please read the online documentation. """ + if not q_ind: + q_ind = [dynamicsymbols('dummy_q')] + kd_eqs = [dynamicsymbols('dummy_kd')] + + if not isinstance(frame, ReferenceFrame): + raise TypeError('An inertial ReferenceFrame must be supplied') + self._inertial = frame + + self._fr = None + self._frstar = None + + self._forcelist = forcelist + self._bodylist = bodies + + self.explicit_kinematics = explicit_kinematics + self._constraint_solver = constraint_solver + self._initialize_vectors(q_ind, q_dependent, u_ind, u_dependent, + u_auxiliary) + _validate_coordinates(self.q, self.u) + self._initialize_kindiffeq_matrices(kd_eqs, kd_eqs_solver) + self._initialize_constraint_matrices( + configuration_constraints, velocity_constraints, + acceleration_constraints, constraint_solver) + + def _initialize_vectors(self, q_ind, q_dep, u_ind, u_dep, u_aux): + """Initialize the coordinate and speed vectors.""" + + none_handler = lambda x: Matrix(x) if x else Matrix() + + # Initialize generalized coordinates + q_dep = none_handler(q_dep) + if not iterable(q_ind): + raise TypeError('Generalized coordinates must be an iterable.') + if not iterable(q_dep): + raise TypeError('Dependent coordinates must be an iterable.') + q_ind = Matrix(q_ind) + self._qdep = q_dep + self._q = Matrix([q_ind, q_dep]) + self._qdot = self.q.diff(dynamicsymbols._t) + + # Initialize generalized speeds + u_dep = none_handler(u_dep) + if not iterable(u_ind): + raise TypeError('Generalized speeds must be an iterable.') + if not iterable(u_dep): + raise TypeError('Dependent speeds must be an iterable.') + u_ind = Matrix(u_ind) + self._udep = u_dep + self._u = Matrix([u_ind, u_dep]) + self._udot = self.u.diff(dynamicsymbols._t) + self._uaux = none_handler(u_aux) + + def _initialize_constraint_matrices(self, config, vel, acc, linear_solver='LU'): + """Initializes constraint matrices.""" + linear_solver = _parse_linear_solver(linear_solver) + # Define vector dimensions + o = len(self.u) + m = len(self._udep) + p = o - m + none_handler = lambda x: Matrix(x) if x else Matrix() + + # Initialize configuration constraints + config = none_handler(config) + if len(self._qdep) != len(config): + raise ValueError('There must be an equal number of dependent ' + 'coordinates and configuration constraints.') + self._f_h = none_handler(config) + + # Initialize velocity and acceleration constraints + vel = none_handler(vel) + acc = none_handler(acc) + if len(vel) != m: + raise ValueError('There must be an equal number of dependent ' + 'speeds and velocity constraints.') + if acc and (len(acc) != m): + raise ValueError('There must be an equal number of dependent ' + 'speeds and acceleration constraints.') + if vel: + u_zero = dict.fromkeys(self.u, 0) + udot_zero = dict.fromkeys(self._udot, 0) + + # When calling kanes_equations, another class instance will be + # created if auxiliary u's are present. In this case, the + # computation of kinetic differential equation matrices will be + # skipped as this was computed during the original KanesMethod + # object, and the qd_u_map will not be available. + if self._qdot_u_map is not None: + vel = msubs(vel, self._qdot_u_map) + + self._f_nh = msubs(vel, u_zero) + self._k_nh = (vel - self._f_nh).jacobian(self.u) + # If no acceleration constraints given, calculate them. + if not acc: + _f_dnh = (self._k_nh.diff(dynamicsymbols._t) * self.u + + self._f_nh.diff(dynamicsymbols._t)) + if self._qdot_u_map is not None: + _f_dnh = msubs(_f_dnh, self._qdot_u_map) + self._f_dnh = _f_dnh + self._k_dnh = self._k_nh + else: + if self._qdot_u_map is not None: + acc = msubs(acc, self._qdot_u_map) + self._f_dnh = msubs(acc, udot_zero) + self._k_dnh = (acc - self._f_dnh).jacobian(self._udot) + + # Form of non-holonomic constraints is B*u + C = 0. + # We partition B into independent and dependent columns: + # Ars is then -B_dep.inv() * B_ind, and it relates dependent speeds + # to independent speeds as: udep = Ars*uind, neglecting the C term. + B_ind = self._k_nh[:, :p] + B_dep = self._k_nh[:, p:o] + self._Ars = -linear_solver(B_dep, B_ind) + else: + self._f_nh = Matrix() + self._k_nh = Matrix() + self._f_dnh = Matrix() + self._k_dnh = Matrix() + self._Ars = Matrix() + + def _initialize_kindiffeq_matrices(self, kdeqs, linear_solver='LU'): + """Initialize the kinematic differential equation matrices. + + Parameters + ========== + kdeqs : sequence of sympy expressions + Kinematic differential equations in the form of f(u,q',q,t) where + f() = 0. The equations have to be linear in the generalized + coordinates and generalized speeds. + + """ + linear_solver = _parse_linear_solver(linear_solver) + if kdeqs: + if len(self.q) != len(kdeqs): + raise ValueError('There must be an equal number of kinematic ' + 'differential equations and coordinates.') + + u = self.u + qdot = self._qdot + + kdeqs = Matrix(kdeqs) + + u_zero = dict.fromkeys(u, 0) + uaux_zero = dict.fromkeys(self._uaux, 0) + qdot_zero = dict.fromkeys(qdot, 0) + + # Extract the linear coefficient matrices as per the following + # equation: + # + # k_ku(q,t)*u(t) + k_kqdot(q,t)*q'(t) + f_k(q,t) = 0 + # + k_ku = kdeqs.jacobian(u) + k_kqdot = kdeqs.jacobian(qdot) + f_k = kdeqs.xreplace(u_zero).xreplace(qdot_zero) + + # The kinematic differential equations should be linear in both q' + # and u, so check for u and q' in the components. + dy_syms = find_dynamicsymbols(k_ku.row_join(k_kqdot).row_join(f_k)) + nonlin_vars = [vari for vari in u[:] + qdot[:] if vari in dy_syms] + if nonlin_vars: + msg = ('The provided kinematic differential equations are ' + 'nonlinear in {}. They must be linear in the ' + 'generalized speeds and derivatives of the generalized ' + 'coordinates.') + raise ValueError(msg.format(nonlin_vars)) + + self._f_k_implicit = f_k.xreplace(uaux_zero) + self._k_ku_implicit = k_ku.xreplace(uaux_zero) + self._k_kqdot_implicit = k_kqdot + + # Solve for q'(t) such that the coefficient matrices are now in + # this form: + # + # k_kqdot^-1*k_ku*u(t) + I*q'(t) + k_kqdot^-1*f_k = 0 + # + # NOTE : Solving the kinematic differential equations here is not + # necessary and prevents the equations from being provided in fully + # implicit form. + f_k_explicit = linear_solver(k_kqdot, f_k) + k_ku_explicit = linear_solver(k_kqdot, k_ku) + self._qdot_u_map = dict(zip(qdot, -(k_ku_explicit*u + f_k_explicit))) + + self._f_k = f_k_explicit.xreplace(uaux_zero) + self._k_ku = k_ku_explicit.xreplace(uaux_zero) + self._k_kqdot = eye(len(qdot)) + + else: + self._qdot_u_map = None + self._f_k_implicit = self._f_k = Matrix() + self._k_ku_implicit = self._k_ku = Matrix() + self._k_kqdot_implicit = self._k_kqdot = Matrix() + + def _form_fr(self, fl): + """Form the generalized active force.""" + if fl is not None and (len(fl) == 0 or not iterable(fl)): + raise ValueError('Force pairs must be supplied in an ' + 'non-empty iterable or None.') + + N = self._inertial + # pull out relevant velocities for constructing partial velocities + vel_list, f_list = _f_list_parser(fl, N) + vel_list = [msubs(i, self._qdot_u_map) for i in vel_list] + f_list = [msubs(i, self._qdot_u_map) for i in f_list] + + # Fill Fr with dot product of partial velocities and forces + o = len(self.u) + b = len(f_list) + FR = zeros(o, 1) + partials = partial_velocity(vel_list, self.u, N) + for i in range(o): + FR[i] = sum(partials[j][i].dot(f_list[j]) for j in range(b)) + + # In case there are dependent speeds + if self._udep: + p = o - len(self._udep) + FRtilde = FR[:p, 0] + FRold = FR[p:o, 0] + FRtilde += self._Ars.T * FRold + FR = FRtilde + + self._forcelist = fl + self._fr = FR + return FR + + def _form_frstar(self, bl): + """Form the generalized inertia force.""" + + if not iterable(bl): + raise TypeError('Bodies must be supplied in an iterable.') + + t = dynamicsymbols._t + N = self._inertial + # Dicts setting things to zero + udot_zero = dict.fromkeys(self._udot, 0) + uaux_zero = dict.fromkeys(self._uaux, 0) + uauxdot = [diff(i, t) for i in self._uaux] + uauxdot_zero = dict.fromkeys(uauxdot, 0) + # Dictionary of q' and q'' to u and u' + q_ddot_u_map = {k.diff(t): v.diff(t).xreplace( + self._qdot_u_map) for (k, v) in self._qdot_u_map.items()} + q_ddot_u_map.update(self._qdot_u_map) + + # Fill up the list of partials: format is a list with num elements + # equal to number of entries in body list. Each of these elements is a + # list - either of length 1 for the translational components of + # particles or of length 2 for the translational and rotational + # components of rigid bodies. The inner most list is the list of + # partial velocities. + def get_partial_velocity(body): + if isinstance(body, RigidBody): + vlist = [body.masscenter.vel(N), body.frame.ang_vel_in(N)] + elif isinstance(body, Particle): + vlist = [body.point.vel(N),] + else: + raise TypeError('The body list may only contain either ' + 'RigidBody or Particle as list elements.') + v = [msubs(vel, self._qdot_u_map) for vel in vlist] + return partial_velocity(v, self.u, N) + partials = [get_partial_velocity(body) for body in bl] + + # Compute fr_star in two components: + # fr_star = -(MM*u' + nonMM) + o = len(self.u) + MM = zeros(o, o) + nonMM = zeros(o, 1) + zero_uaux = lambda expr: msubs(expr, uaux_zero) + zero_udot_uaux = lambda expr: msubs(msubs(expr, udot_zero), uaux_zero) + for i, body in enumerate(bl): + if isinstance(body, RigidBody): + M = zero_uaux(body.mass) + I = zero_uaux(body.central_inertia) + vel = zero_uaux(body.masscenter.vel(N)) + omega = zero_uaux(body.frame.ang_vel_in(N)) + acc = zero_udot_uaux(body.masscenter.acc(N)) + inertial_force = (M.diff(t) * vel + M * acc) + inertial_torque = zero_uaux((I.dt(body.frame).dot(omega)) + + msubs(I.dot(body.frame.ang_acc_in(N)), udot_zero) + + (omega.cross(I.dot(omega)))) + for j in range(o): + tmp_vel = zero_uaux(partials[i][0][j]) + tmp_ang = zero_uaux(I.dot(partials[i][1][j])) + for k in range(o): + # translational + MM[j, k] += M*tmp_vel.dot(partials[i][0][k]) + # rotational + MM[j, k] += tmp_ang.dot(partials[i][1][k]) + nonMM[j] += inertial_force.dot(partials[i][0][j]) + nonMM[j] += inertial_torque.dot(partials[i][1][j]) + else: + M = zero_uaux(body.mass) + vel = zero_uaux(body.point.vel(N)) + acc = zero_udot_uaux(body.point.acc(N)) + inertial_force = (M.diff(t) * vel + M * acc) + for j in range(o): + temp = zero_uaux(partials[i][0][j]) + for k in range(o): + MM[j, k] += M*temp.dot(partials[i][0][k]) + nonMM[j] += inertial_force.dot(partials[i][0][j]) + # Compose fr_star out of MM and nonMM + MM = zero_uaux(msubs(MM, q_ddot_u_map)) + nonMM = msubs(msubs(nonMM, q_ddot_u_map), + udot_zero, uauxdot_zero, uaux_zero) + fr_star = -(MM * msubs(Matrix(self._udot), uauxdot_zero) + nonMM) + + # If there are dependent speeds, we need to find fr_star_tilde + if self._udep: + p = o - len(self._udep) + fr_star_ind = fr_star[:p, 0] + fr_star_dep = fr_star[p:o, 0] + fr_star = fr_star_ind + (self._Ars.T * fr_star_dep) + # Apply the same to MM + MMi = MM[:p, :] + MMd = MM[p:o, :] + MM = MMi + (self._Ars.T * MMd) + # Apply the same to nonMM + nonMM = nonMM[:p, :] + (self._Ars.T * nonMM[p:o, :]) + + self._bodylist = bl + self._frstar = fr_star + self._k_d = MM + self._f_d = -(self._fr - nonMM) + return fr_star + + def to_linearizer(self, linear_solver='LU'): + """Returns an instance of the Linearizer class, initiated from the + data in the KanesMethod class. This may be more desirable than using + the linearize class method, as the Linearizer object will allow more + efficient recalculation (i.e. about varying operating points). + + Parameters + ========== + linear_solver : str, callable + Method used to solve the several symbolic linear systems of the + form ``A*x=b`` in the linearization process. If a string is + supplied, it should be a valid method that can be used with the + :meth:`sympy.matrices.matrixbase.MatrixBase.solve`. If a callable is + supplied, it should have the format ``x = f(A, b)``, where it + solves the equations and returns the solution. The default is + ``'LU'`` which corresponds to SymPy's ``A.LUsolve(b)``. + ``LUsolve()`` is fast to compute but will often result in + divide-by-zero and thus ``nan`` results. + + Returns + ======= + Linearizer + An instantiated + :class:`sympy.physics.mechanics.linearize.Linearizer`. + + """ + + if (self._fr is None) or (self._frstar is None): + raise ValueError('Need to compute Fr, Fr* first.') + + # Get required equation components. The Kane's method class breaks + # these into pieces. Need to reassemble + f_c = self._f_h + if self._f_nh and self._k_nh: + f_v = self._f_nh + self._k_nh*Matrix(self.u) + else: + f_v = Matrix() + if self._f_dnh and self._k_dnh: + f_a = self._f_dnh + self._k_dnh*Matrix(self._udot) + else: + f_a = Matrix() + # Dicts to sub to zero, for splitting up expressions + u_zero = dict.fromkeys(self.u, 0) + ud_zero = dict.fromkeys(self._udot, 0) + qd_zero = dict.fromkeys(self._qdot, 0) + qd_u_zero = dict.fromkeys(Matrix([self._qdot, self.u]), 0) + # Break the kinematic differential eqs apart into f_0 and f_1 + f_0 = msubs(self._f_k, u_zero) + self._k_kqdot*Matrix(self._qdot) + f_1 = msubs(self._f_k, qd_zero) + self._k_ku*Matrix(self.u) + # Break the dynamic differential eqs into f_2 and f_3 + f_2 = msubs(self._frstar, qd_u_zero) + f_3 = msubs(self._frstar, ud_zero) + self._fr + f_4 = zeros(len(f_2), 1) + + # Get the required vector components + q = self.q + u = self.u + if self._qdep: + q_i = q[:-len(self._qdep)] + else: + q_i = q + q_d = self._qdep + if self._udep: + u_i = u[:-len(self._udep)] + else: + u_i = u + u_d = self._udep + + # Form dictionary to set auxiliary speeds & their derivatives to 0. + uaux = self._uaux + uauxdot = uaux.diff(dynamicsymbols._t) + uaux_zero = dict.fromkeys(Matrix([uaux, uauxdot]), 0) + + # Checking for dynamic symbols outside the dynamic differential + # equations; throws error if there is. + sym_list = set(Matrix([q, self._qdot, u, self._udot, uaux, uauxdot])) + if any(find_dynamicsymbols(i, sym_list) for i in [self._k_kqdot, + self._k_ku, self._f_k, self._k_dnh, self._f_dnh, self._k_d]): + raise ValueError('Cannot have dynamicsymbols outside dynamic \ + forcing vector.') + + # Find all other dynamic symbols, forming the forcing vector r. + # Sort r to make it canonical. + r = list(find_dynamicsymbols(msubs(self._f_d, uaux_zero), sym_list)) + r.sort(key=default_sort_key) + + # Check for any derivatives of variables in r that are also found in r. + for i in r: + if diff(i, dynamicsymbols._t) in r: + raise ValueError('Cannot have derivatives of specified \ + quantities when linearizing forcing terms.') + return Linearizer(f_0, f_1, f_2, f_3, f_4, f_c, f_v, f_a, q, u, q_i, + q_d, u_i, u_d, r, linear_solver=linear_solver) + + # TODO : Remove `new_method` after 1.1 has been released. + def linearize(self, *, new_method=None, linear_solver='LU', **kwargs): + """ Linearize the equations of motion about a symbolic operating point. + + Parameters + ========== + new_method + Deprecated, does nothing and will be removed. + linear_solver : str, callable + Method used to solve the several symbolic linear systems of the + form ``A*x=b`` in the linearization process. If a string is + supplied, it should be a valid method that can be used with the + :meth:`sympy.matrices.matrixbase.MatrixBase.solve`. If a callable is + supplied, it should have the format ``x = f(A, b)``, where it + solves the equations and returns the solution. The default is + ``'LU'`` which corresponds to SymPy's ``A.LUsolve(b)``. + ``LUsolve()`` is fast to compute but will often result in + divide-by-zero and thus ``nan`` results. + **kwargs + Extra keyword arguments are passed to + :meth:`sympy.physics.mechanics.linearize.Linearizer.linearize`. + + Explanation + =========== + + If kwarg A_and_B is False (default), returns M, A, B, r for the + linearized form, M*[q', u']^T = A*[q_ind, u_ind]^T + B*r. + + If kwarg A_and_B is True, returns A, B, r for the linearized form + dx = A*x + B*r, where x = [q_ind, u_ind]^T. Note that this is + computationally intensive if there are many symbolic parameters. For + this reason, it may be more desirable to use the default A_and_B=False, + returning M, A, and B. Values may then be substituted in to these + matrices, and the state space form found as + A = P.T*M.inv()*A, B = P.T*M.inv()*B, where P = Linearizer.perm_mat. + + In both cases, r is found as all dynamicsymbols in the equations of + motion that are not part of q, u, q', or u'. They are sorted in + canonical form. + + The operating points may be also entered using the ``op_point`` kwarg. + This takes a dictionary of {symbol: value}, or a an iterable of such + dictionaries. The values may be numeric or symbolic. The more values + you can specify beforehand, the faster this computation will run. + + For more documentation, please see the ``Linearizer`` class. + + """ + + linearizer = self.to_linearizer(linear_solver=linear_solver) + result = linearizer.linearize(**kwargs) + return result + (linearizer.r,) + + def kanes_equations(self, bodies=None, loads=None): + """ Method to form Kane's equations, Fr + Fr* = 0. + + Explanation + =========== + + Returns (Fr, Fr*). In the case where auxiliary generalized speeds are + present (say, s auxiliary speeds, o generalized speeds, and m motion + constraints) the length of the returned vectors will be o - m + s in + length. The first o - m equations will be the constrained Kane's + equations, then the s auxiliary Kane's equations. These auxiliary + equations can be accessed with the auxiliary_eqs property. + + Parameters + ========== + + bodies : iterable + An iterable of all RigidBody's and Particle's in the system. + A system must have at least one body. + loads : iterable + Takes in an iterable of (Particle, Vector) or (ReferenceFrame, Vector) + tuples which represent the force at a point or torque on a frame. + Must be either a non-empty iterable of tuples or None which corresponds + to a system with no constraints. + """ + if bodies is None: + bodies = self.bodies + if loads is None and self._forcelist is not None: + loads = self._forcelist + if loads == []: + loads = None + if not self._k_kqdot: + raise AttributeError('Create an instance of KanesMethod with ' + 'kinematic differential equations to use this method.') + fr = self._form_fr(loads) + frstar = self._form_frstar(bodies) + if self._uaux: + if not self._udep: + km = KanesMethod(self._inertial, self.q, self._uaux, + u_auxiliary=self._uaux, constraint_solver=self._constraint_solver) + else: + km = KanesMethod(self._inertial, self.q, self._uaux, + u_auxiliary=self._uaux, u_dependent=self._udep, + velocity_constraints=(self._k_nh * self.u + + self._f_nh), + acceleration_constraints=(self._k_dnh * self._udot + + self._f_dnh), + constraint_solver=self._constraint_solver + ) + km._qdot_u_map = self._qdot_u_map + self._km = km + fraux = km._form_fr(loads) + frstaraux = km._form_frstar(bodies) + self._aux_eq = fraux + frstaraux + self._fr = fr.col_join(fraux) + self._frstar = frstar.col_join(frstaraux) + return (self._fr, self._frstar) + + def _form_eoms(self): + fr, frstar = self.kanes_equations(self.bodylist, self.forcelist) + return fr + frstar + + def rhs(self, inv_method=None): + """Returns the system's equations of motion in first order form. The + output is the right hand side of:: + + x' = |q'| =: f(q, u, r, p, t) + |u'| + + The right hand side is what is needed by most numerical ODE + integrators. + + Parameters + ========== + + inv_method : str + The specific sympy inverse matrix calculation method to use. For a + list of valid methods, see + :meth:`~sympy.matrices.matrixbase.MatrixBase.inv` + + """ + rhs = zeros(len(self.q) + len(self.u), 1) + kdes = self.kindiffdict() + for i, q_i in enumerate(self.q): + rhs[i] = kdes[q_i.diff()] + + if inv_method is None: + rhs[len(self.q):, 0] = self.mass_matrix.LUsolve(self.forcing) + else: + rhs[len(self.q):, 0] = (self.mass_matrix.inv(inv_method, + try_block_diag=True) * + self.forcing) + + return rhs + + def kindiffdict(self): + """Returns a dictionary mapping q' to u.""" + if not self._qdot_u_map: + raise AttributeError('Create an instance of KanesMethod with ' + 'kinematic differential equations to use this method.') + return self._qdot_u_map + + @property + def auxiliary_eqs(self): + """A matrix containing the auxiliary equations.""" + if not self._fr or not self._frstar: + raise ValueError('Need to compute Fr, Fr* first.') + if not self._uaux: + raise ValueError('No auxiliary speeds have been declared.') + return self._aux_eq + + @property + def mass_matrix_kin(self): + r"""The kinematic "mass matrix" $\mathbf{k_{k\dot{q}}}$ of the system.""" + return self._k_kqdot if self.explicit_kinematics else self._k_kqdot_implicit + + @property + def forcing_kin(self): + """The kinematic "forcing vector" of the system.""" + if self.explicit_kinematics: + return -(self._k_ku * Matrix(self.u) + self._f_k) + else: + return -(self._k_ku_implicit * Matrix(self.u) + self._f_k_implicit) + + @property + def mass_matrix(self): + """The mass matrix of the system.""" + if not self._fr or not self._frstar: + raise ValueError('Need to compute Fr, Fr* first.') + return Matrix([self._k_d, self._k_dnh]) + + @property + def forcing(self): + """The forcing vector of the system.""" + if not self._fr or not self._frstar: + raise ValueError('Need to compute Fr, Fr* first.') + return -Matrix([self._f_d, self._f_dnh]) + + @property + def mass_matrix_full(self): + """The mass matrix of the system, augmented by the kinematic + differential equations in explicit or implicit form.""" + if not self._fr or not self._frstar: + raise ValueError('Need to compute Fr, Fr* first.') + o, n = len(self.u), len(self.q) + return (self.mass_matrix_kin.row_join(zeros(n, o))).col_join( + zeros(o, n).row_join(self.mass_matrix)) + + @property + def forcing_full(self): + """The forcing vector of the system, augmented by the kinematic + differential equations in explicit or implicit form.""" + return Matrix([self.forcing_kin, self.forcing]) + + @property + def q(self): + return self._q + + @property + def u(self): + return self._u + + @property + def bodylist(self): + return self._bodylist + + @property + def forcelist(self): + return self._forcelist + + @property + def bodies(self): + return self._bodylist + + @property + def loads(self): + return self._forcelist diff --git a/lib/python3.10/site-packages/sympy/physics/mechanics/lagrange.py b/lib/python3.10/site-packages/sympy/physics/mechanics/lagrange.py new file mode 100644 index 0000000000000000000000000000000000000000..282176a404f77762abc3ee8c6a575519b2de1f02 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/physics/mechanics/lagrange.py @@ -0,0 +1,512 @@ +from sympy import diff, zeros, Matrix, eye, sympify +from sympy.core.sorting import default_sort_key +from sympy.physics.vector import dynamicsymbols, ReferenceFrame +from sympy.physics.mechanics.method import _Methods +from sympy.physics.mechanics.functions import ( + find_dynamicsymbols, msubs, _f_list_parser, _validate_coordinates) +from sympy.physics.mechanics.linearize import Linearizer +from sympy.utilities.iterables import iterable + +__all__ = ['LagrangesMethod'] + + +class LagrangesMethod(_Methods): + """Lagrange's method object. + + Explanation + =========== + + This object generates the equations of motion in a two step procedure. The + first step involves the initialization of LagrangesMethod by supplying the + Lagrangian and the generalized coordinates, at the bare minimum. If there + are any constraint equations, they can be supplied as keyword arguments. + The Lagrange multipliers are automatically generated and are equal in + number to the constraint equations. Similarly any non-conservative forces + can be supplied in an iterable (as described below and also shown in the + example) along with a ReferenceFrame. This is also discussed further in the + __init__ method. + + Attributes + ========== + + q, u : Matrix + Matrices of the generalized coordinates and speeds + loads : iterable + Iterable of (Point, vector) or (ReferenceFrame, vector) tuples + describing the forces on the system. + bodies : iterable + Iterable containing the rigid bodies and particles of the system. + mass_matrix : Matrix + The system's mass matrix + forcing : Matrix + The system's forcing vector + mass_matrix_full : Matrix + The "mass matrix" for the qdot's, qdoubledot's, and the + lagrange multipliers (lam) + forcing_full : Matrix + The forcing vector for the qdot's, qdoubledot's and + lagrange multipliers (lam) + + Examples + ======== + + This is a simple example for a one degree of freedom translational + spring-mass-damper. + + In this example, we first need to do the kinematics. + This involves creating generalized coordinates and their derivatives. + Then we create a point and set its velocity in a frame. + + >>> from sympy.physics.mechanics import LagrangesMethod, Lagrangian + >>> from sympy.physics.mechanics import ReferenceFrame, Particle, Point + >>> from sympy.physics.mechanics import dynamicsymbols + >>> from sympy import symbols + >>> q = dynamicsymbols('q') + >>> qd = dynamicsymbols('q', 1) + >>> m, k, b = symbols('m k b') + >>> N = ReferenceFrame('N') + >>> P = Point('P') + >>> P.set_vel(N, qd * N.x) + + We need to then prepare the information as required by LagrangesMethod to + generate equations of motion. + First we create the Particle, which has a point attached to it. + Following this the lagrangian is created from the kinetic and potential + energies. + Then, an iterable of nonconservative forces/torques must be constructed, + where each item is a (Point, Vector) or (ReferenceFrame, Vector) tuple, + with the Vectors representing the nonconservative forces or torques. + + >>> Pa = Particle('Pa', P, m) + >>> Pa.potential_energy = k * q**2 / 2.0 + >>> L = Lagrangian(N, Pa) + >>> fl = [(P, -b * qd * N.x)] + + Finally we can generate the equations of motion. + First we create the LagrangesMethod object. To do this one must supply + the Lagrangian, and the generalized coordinates. The constraint equations, + the forcelist, and the inertial frame may also be provided, if relevant. + Next we generate Lagrange's equations of motion, such that: + Lagrange's equations of motion = 0. + We have the equations of motion at this point. + + >>> l = LagrangesMethod(L, [q], forcelist = fl, frame = N) + >>> print(l.form_lagranges_equations()) + Matrix([[b*Derivative(q(t), t) + 1.0*k*q(t) + m*Derivative(q(t), (t, 2))]]) + + We can also solve for the states using the 'rhs' method. + + >>> print(l.rhs()) + Matrix([[Derivative(q(t), t)], [(-b*Derivative(q(t), t) - 1.0*k*q(t))/m]]) + + Please refer to the docstrings on each method for more details. + """ + + def __init__(self, Lagrangian, qs, forcelist=None, bodies=None, frame=None, + hol_coneqs=None, nonhol_coneqs=None): + """Supply the following for the initialization of LagrangesMethod. + + Lagrangian : Sympifyable + + qs : array_like + The generalized coordinates + + hol_coneqs : array_like, optional + The holonomic constraint equations + + nonhol_coneqs : array_like, optional + The nonholonomic constraint equations + + forcelist : iterable, optional + Takes an iterable of (Point, Vector) or (ReferenceFrame, Vector) + tuples which represent the force at a point or torque on a frame. + This feature is primarily to account for the nonconservative forces + and/or moments. + + bodies : iterable, optional + Takes an iterable containing the rigid bodies and particles of the + system. + + frame : ReferenceFrame, optional + Supply the inertial frame. This is used to determine the + generalized forces due to non-conservative forces. + """ + + self._L = Matrix([sympify(Lagrangian)]) + self.eom = None + self._m_cd = Matrix() # Mass Matrix of differentiated coneqs + self._m_d = Matrix() # Mass Matrix of dynamic equations + self._f_cd = Matrix() # Forcing part of the diff coneqs + self._f_d = Matrix() # Forcing part of the dynamic equations + self.lam_coeffs = Matrix() # The coeffecients of the multipliers + + forcelist = forcelist if forcelist else [] + if not iterable(forcelist): + raise TypeError('Force pairs must be supplied in an iterable.') + self._forcelist = forcelist + if frame and not isinstance(frame, ReferenceFrame): + raise TypeError('frame must be a valid ReferenceFrame') + self._bodies = bodies + self.inertial = frame + + self.lam_vec = Matrix() + + self._term1 = Matrix() + self._term2 = Matrix() + self._term3 = Matrix() + self._term4 = Matrix() + + # Creating the qs, qdots and qdoubledots + if not iterable(qs): + raise TypeError('Generalized coordinates must be an iterable') + self._q = Matrix(qs) + self._qdots = self.q.diff(dynamicsymbols._t) + self._qdoubledots = self._qdots.diff(dynamicsymbols._t) + _validate_coordinates(self.q) + + mat_build = lambda x: Matrix(x) if x else Matrix() + hol_coneqs = mat_build(hol_coneqs) + nonhol_coneqs = mat_build(nonhol_coneqs) + self.coneqs = Matrix([hol_coneqs.diff(dynamicsymbols._t), + nonhol_coneqs]) + self._hol_coneqs = hol_coneqs + + def form_lagranges_equations(self): + """Method to form Lagrange's equations of motion. + + Returns a vector of equations of motion using Lagrange's equations of + the second kind. + """ + + qds = self._qdots + qdd_zero = dict.fromkeys(self._qdoubledots, 0) + n = len(self.q) + + # Internally we represent the EOM as four terms: + # EOM = term1 - term2 - term3 - term4 = 0 + + # First term + self._term1 = self._L.jacobian(qds) + self._term1 = self._term1.diff(dynamicsymbols._t).T + + # Second term + self._term2 = self._L.jacobian(self.q).T + + # Third term + if self.coneqs: + coneqs = self.coneqs + m = len(coneqs) + # Creating the multipliers + self.lam_vec = Matrix(dynamicsymbols('lam1:' + str(m + 1))) + self.lam_coeffs = -coneqs.jacobian(qds) + self._term3 = self.lam_coeffs.T * self.lam_vec + # Extracting the coeffecients of the qdds from the diff coneqs + diffconeqs = coneqs.diff(dynamicsymbols._t) + self._m_cd = diffconeqs.jacobian(self._qdoubledots) + # The remaining terms i.e. the 'forcing' terms in diff coneqs + self._f_cd = -diffconeqs.subs(qdd_zero) + else: + self._term3 = zeros(n, 1) + + # Fourth term + if self.forcelist: + N = self.inertial + self._term4 = zeros(n, 1) + for i, qd in enumerate(qds): + flist = zip(*_f_list_parser(self.forcelist, N)) + self._term4[i] = sum(v.diff(qd, N).dot(f) for (v, f) in flist) + else: + self._term4 = zeros(n, 1) + + # Form the dynamic mass and forcing matrices + without_lam = self._term1 - self._term2 - self._term4 + self._m_d = without_lam.jacobian(self._qdoubledots) + self._f_d = -without_lam.subs(qdd_zero) + + # Form the EOM + self.eom = without_lam - self._term3 + return self.eom + + def _form_eoms(self): + return self.form_lagranges_equations() + + @property + def mass_matrix(self): + """Returns the mass matrix, which is augmented by the Lagrange + multipliers, if necessary. + + Explanation + =========== + + If the system is described by 'n' generalized coordinates and there are + no constraint equations then an n X n matrix is returned. + + If there are 'n' generalized coordinates and 'm' constraint equations + have been supplied during initialization then an n X (n+m) matrix is + returned. The (n + m - 1)th and (n + m)th columns contain the + coefficients of the Lagrange multipliers. + """ + + if self.eom is None: + raise ValueError('Need to compute the equations of motion first') + if self.coneqs: + return (self._m_d).row_join(self.lam_coeffs.T) + else: + return self._m_d + + @property + def mass_matrix_full(self): + """Augments the coefficients of qdots to the mass_matrix.""" + + if self.eom is None: + raise ValueError('Need to compute the equations of motion first') + n = len(self.q) + m = len(self.coneqs) + row1 = eye(n).row_join(zeros(n, n + m)) + row2 = zeros(n, n).row_join(self.mass_matrix) + if self.coneqs: + row3 = zeros(m, n).row_join(self._m_cd).row_join(zeros(m, m)) + return row1.col_join(row2).col_join(row3) + else: + return row1.col_join(row2) + + @property + def forcing(self): + """Returns the forcing vector from 'lagranges_equations' method.""" + + if self.eom is None: + raise ValueError('Need to compute the equations of motion first') + return self._f_d + + @property + def forcing_full(self): + """Augments qdots to the forcing vector above.""" + + if self.eom is None: + raise ValueError('Need to compute the equations of motion first') + if self.coneqs: + return self._qdots.col_join(self.forcing).col_join(self._f_cd) + else: + return self._qdots.col_join(self.forcing) + + def to_linearizer(self, q_ind=None, qd_ind=None, q_dep=None, qd_dep=None, + linear_solver='LU'): + """Returns an instance of the Linearizer class, initiated from the data + in the LagrangesMethod class. This may be more desirable than using the + linearize class method, as the Linearizer object will allow more + efficient recalculation (i.e. about varying operating points). + + Parameters + ========== + + q_ind, qd_ind : array_like, optional + The independent generalized coordinates and speeds. + q_dep, qd_dep : array_like, optional + The dependent generalized coordinates and speeds. + linear_solver : str, callable + Method used to solve the several symbolic linear systems of the + form ``A*x=b`` in the linearization process. If a string is + supplied, it should be a valid method that can be used with the + :meth:`sympy.matrices.matrixbase.MatrixBase.solve`. If a callable is + supplied, it should have the format ``x = f(A, b)``, where it + solves the equations and returns the solution. The default is + ``'LU'`` which corresponds to SymPy's ``A.LUsolve(b)``. + ``LUsolve()`` is fast to compute but will often result in + divide-by-zero and thus ``nan`` results. + + Returns + ======= + Linearizer + An instantiated + :class:`sympy.physics.mechanics.linearize.Linearizer`. + + """ + + # Compose vectors + t = dynamicsymbols._t + q = self.q + u = self._qdots + ud = u.diff(t) + # Get vector of lagrange multipliers + lams = self.lam_vec + + mat_build = lambda x: Matrix(x) if x else Matrix() + q_i = mat_build(q_ind) + q_d = mat_build(q_dep) + u_i = mat_build(qd_ind) + u_d = mat_build(qd_dep) + + # Compose general form equations + f_c = self._hol_coneqs + f_v = self.coneqs + f_a = f_v.diff(t) + f_0 = u + f_1 = -u + f_2 = self._term1 + f_3 = -(self._term2 + self._term4) + f_4 = -self._term3 + + # Check that there are an appropriate number of independent and + # dependent coordinates + if len(q_d) != len(f_c) or len(u_d) != len(f_v): + raise ValueError(("Must supply {:} dependent coordinates, and " + + "{:} dependent speeds").format(len(f_c), len(f_v))) + if set(Matrix([q_i, q_d])) != set(q): + raise ValueError("Must partition q into q_ind and q_dep, with " + + "no extra or missing symbols.") + if set(Matrix([u_i, u_d])) != set(u): + raise ValueError("Must partition qd into qd_ind and qd_dep, " + + "with no extra or missing symbols.") + + # Find all other dynamic symbols, forming the forcing vector r. + # Sort r to make it canonical. + insyms = set(Matrix([q, u, ud, lams])) + r = list(find_dynamicsymbols(f_3, insyms)) + r.sort(key=default_sort_key) + # Check for any derivatives of variables in r that are also found in r. + for i in r: + if diff(i, dynamicsymbols._t) in r: + raise ValueError('Cannot have derivatives of specified \ + quantities when linearizing forcing terms.') + + return Linearizer(f_0, f_1, f_2, f_3, f_4, f_c, f_v, f_a, q, u, q_i, + q_d, u_i, u_d, r, lams, linear_solver=linear_solver) + + def linearize(self, q_ind=None, qd_ind=None, q_dep=None, qd_dep=None, + linear_solver='LU', **kwargs): + """Linearize the equations of motion about a symbolic operating point. + + Parameters + ========== + linear_solver : str, callable + Method used to solve the several symbolic linear systems of the + form ``A*x=b`` in the linearization process. If a string is + supplied, it should be a valid method that can be used with the + :meth:`sympy.matrices.matrixbase.MatrixBase.solve`. If a callable is + supplied, it should have the format ``x = f(A, b)``, where it + solves the equations and returns the solution. The default is + ``'LU'`` which corresponds to SymPy's ``A.LUsolve(b)``. + ``LUsolve()`` is fast to compute but will often result in + divide-by-zero and thus ``nan`` results. + **kwargs + Extra keyword arguments are passed to + :meth:`sympy.physics.mechanics.linearize.Linearizer.linearize`. + + Explanation + =========== + + If kwarg A_and_B is False (default), returns M, A, B, r for the + linearized form, M*[q', u']^T = A*[q_ind, u_ind]^T + B*r. + + If kwarg A_and_B is True, returns A, B, r for the linearized form + dx = A*x + B*r, where x = [q_ind, u_ind]^T. Note that this is + computationally intensive if there are many symbolic parameters. For + this reason, it may be more desirable to use the default A_and_B=False, + returning M, A, and B. Values may then be substituted in to these + matrices, and the state space form found as + A = P.T*M.inv()*A, B = P.T*M.inv()*B, where P = Linearizer.perm_mat. + + In both cases, r is found as all dynamicsymbols in the equations of + motion that are not part of q, u, q', or u'. They are sorted in + canonical form. + + The operating points may be also entered using the ``op_point`` kwarg. + This takes a dictionary of {symbol: value}, or a an iterable of such + dictionaries. The values may be numeric or symbolic. The more values + you can specify beforehand, the faster this computation will run. + + For more documentation, please see the ``Linearizer`` class.""" + + linearizer = self.to_linearizer(q_ind, qd_ind, q_dep, qd_dep, + linear_solver=linear_solver) + result = linearizer.linearize(**kwargs) + return result + (linearizer.r,) + + def solve_multipliers(self, op_point=None, sol_type='dict'): + """Solves for the values of the lagrange multipliers symbolically at + the specified operating point. + + Parameters + ========== + + op_point : dict or iterable of dicts, optional + Point at which to solve at. The operating point is specified as + a dictionary or iterable of dictionaries of {symbol: value}. The + value may be numeric or symbolic itself. + + sol_type : str, optional + Solution return type. Valid options are: + - 'dict': A dict of {symbol : value} (default) + - 'Matrix': An ordered column matrix of the solution + """ + + # Determine number of multipliers + k = len(self.lam_vec) + if k == 0: + raise ValueError("System has no lagrange multipliers to solve for.") + # Compose dict of operating conditions + if isinstance(op_point, dict): + op_point_dict = op_point + elif iterable(op_point): + op_point_dict = {} + for op in op_point: + op_point_dict.update(op) + elif op_point is None: + op_point_dict = {} + else: + raise TypeError("op_point must be either a dictionary or an " + "iterable of dictionaries.") + # Compose the system to be solved + mass_matrix = self.mass_matrix.col_join(-self.lam_coeffs.row_join( + zeros(k, k))) + force_matrix = self.forcing.col_join(self._f_cd) + # Sub in the operating point + mass_matrix = msubs(mass_matrix, op_point_dict) + force_matrix = msubs(force_matrix, op_point_dict) + # Solve for the multipliers + sol_list = mass_matrix.LUsolve(-force_matrix)[-k:] + if sol_type == 'dict': + return dict(zip(self.lam_vec, sol_list)) + elif sol_type == 'Matrix': + return Matrix(sol_list) + else: + raise ValueError("Unknown sol_type {:}.".format(sol_type)) + + def rhs(self, inv_method=None, **kwargs): + """Returns equations that can be solved numerically. + + Parameters + ========== + + inv_method : str + The specific sympy inverse matrix calculation method to use. For a + list of valid methods, see + :meth:`~sympy.matrices.matrixbase.MatrixBase.inv` + """ + + if inv_method is None: + self._rhs = self.mass_matrix_full.LUsolve(self.forcing_full) + else: + self._rhs = (self.mass_matrix_full.inv(inv_method, + try_block_diag=True) * self.forcing_full) + return self._rhs + + @property + def q(self): + return self._q + + @property + def u(self): + return self._qdots + + @property + def bodies(self): + return self._bodies + + @property + def forcelist(self): + return self._forcelist + + @property + def loads(self): + return self._forcelist diff --git a/lib/python3.10/site-packages/sympy/physics/mechanics/linearize.py b/lib/python3.10/site-packages/sympy/physics/mechanics/linearize.py new file mode 100644 index 0000000000000000000000000000000000000000..9d102c61f8f60318f1a2c9896e94324c8cf02889 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/physics/mechanics/linearize.py @@ -0,0 +1,474 @@ +__all__ = ['Linearizer'] + +from sympy import Matrix, eye, zeros +from sympy.core.symbol import Dummy +from sympy.utilities.iterables import flatten +from sympy.physics.vector import dynamicsymbols +from sympy.physics.mechanics.functions import msubs, _parse_linear_solver + +from collections import namedtuple +from collections.abc import Iterable + + +class Linearizer: + """This object holds the general model form for a dynamic system. This + model is used for computing the linearized form of the system, while + properly dealing with constraints leading to dependent coordinates and + speeds. The notation and method is described in [1]_. + + Attributes + ========== + + f_0, f_1, f_2, f_3, f_4, f_c, f_v, f_a : Matrix + Matrices holding the general system form. + q, u, r : Matrix + Matrices holding the generalized coordinates, speeds, and + input vectors. + q_i, u_i : Matrix + Matrices of the independent generalized coordinates and speeds. + q_d, u_d : Matrix + Matrices of the dependent generalized coordinates and speeds. + perm_mat : Matrix + Permutation matrix such that [q_ind, u_ind]^T = perm_mat*[q, u]^T + + References + ========== + + .. [1] D. L. Peterson, G. Gede, and M. Hubbard, "Symbolic linearization of + equations of motion of constrained multibody systems," Multibody + Syst Dyn, vol. 33, no. 2, pp. 143-161, Feb. 2015, doi: + 10.1007/s11044-014-9436-5. + + """ + + def __init__(self, f_0, f_1, f_2, f_3, f_4, f_c, f_v, f_a, q, u, q_i=None, + q_d=None, u_i=None, u_d=None, r=None, lams=None, + linear_solver='LU'): + """ + Parameters + ========== + + f_0, f_1, f_2, f_3, f_4, f_c, f_v, f_a : array_like + System of equations holding the general system form. + Supply empty array or Matrix if the parameter + does not exist. + q : array_like + The generalized coordinates. + u : array_like + The generalized speeds + q_i, u_i : array_like, optional + The independent generalized coordinates and speeds. + q_d, u_d : array_like, optional + The dependent generalized coordinates and speeds. + r : array_like, optional + The input variables. + lams : array_like, optional + The lagrange multipliers + linear_solver : str, callable + Method used to solve the several symbolic linear systems of the + form ``A*x=b`` in the linearization process. If a string is + supplied, it should be a valid method that can be used with the + :meth:`sympy.matrices.matrixbase.MatrixBase.solve`. If a callable is + supplied, it should have the format ``x = f(A, b)``, where it + solves the equations and returns the solution. The default is + ``'LU'`` which corresponds to SymPy's ``A.LUsolve(b)``. + ``LUsolve()`` is fast to compute but will often result in + divide-by-zero and thus ``nan`` results. + + """ + self.linear_solver = _parse_linear_solver(linear_solver) + + # Generalized equation form + self.f_0 = Matrix(f_0) + self.f_1 = Matrix(f_1) + self.f_2 = Matrix(f_2) + self.f_3 = Matrix(f_3) + self.f_4 = Matrix(f_4) + self.f_c = Matrix(f_c) + self.f_v = Matrix(f_v) + self.f_a = Matrix(f_a) + + # Generalized equation variables + self.q = Matrix(q) + self.u = Matrix(u) + none_handler = lambda x: Matrix(x) if x else Matrix() + self.q_i = none_handler(q_i) + self.q_d = none_handler(q_d) + self.u_i = none_handler(u_i) + self.u_d = none_handler(u_d) + self.r = none_handler(r) + self.lams = none_handler(lams) + + # Derivatives of generalized equation variables + self._qd = self.q.diff(dynamicsymbols._t) + self._ud = self.u.diff(dynamicsymbols._t) + # If the user doesn't actually use generalized variables, and the + # qd and u vectors have any intersecting variables, this can cause + # problems. We'll fix this with some hackery, and Dummy variables + dup_vars = set(self._qd).intersection(self.u) + self._qd_dup = Matrix([var if var not in dup_vars else Dummy() for var + in self._qd]) + + # Derive dimesion terms + l = len(self.f_c) + m = len(self.f_v) + n = len(self.q) + o = len(self.u) + s = len(self.r) + k = len(self.lams) + dims = namedtuple('dims', ['l', 'm', 'n', 'o', 's', 'k']) + self._dims = dims(l, m, n, o, s, k) + + self._Pq = None + self._Pqi = None + self._Pqd = None + self._Pu = None + self._Pui = None + self._Pud = None + self._C_0 = None + self._C_1 = None + self._C_2 = None + self.perm_mat = None + + self._setup_done = False + + def _setup(self): + # Calculations here only need to be run once. They are moved out of + # the __init__ method to increase the speed of Linearizer creation. + self._form_permutation_matrices() + self._form_block_matrices() + self._form_coefficient_matrices() + self._setup_done = True + + def _form_permutation_matrices(self): + """Form the permutation matrices Pq and Pu.""" + + # Extract dimension variables + l, m, n, o, s, k = self._dims + # Compute permutation matrices + if n != 0: + self._Pq = permutation_matrix(self.q, Matrix([self.q_i, self.q_d])) + if l > 0: + self._Pqi = self._Pq[:, :-l] + self._Pqd = self._Pq[:, -l:] + else: + self._Pqi = self._Pq + self._Pqd = Matrix() + if o != 0: + self._Pu = permutation_matrix(self.u, Matrix([self.u_i, self.u_d])) + if m > 0: + self._Pui = self._Pu[:, :-m] + self._Pud = self._Pu[:, -m:] + else: + self._Pui = self._Pu + self._Pud = Matrix() + # Compute combination permutation matrix for computing A and B + P_col1 = Matrix([self._Pqi, zeros(o + k, n - l)]) + P_col2 = Matrix([zeros(n, o - m), self._Pui, zeros(k, o - m)]) + if P_col1: + if P_col2: + self.perm_mat = P_col1.row_join(P_col2) + else: + self.perm_mat = P_col1 + else: + self.perm_mat = P_col2 + + def _form_coefficient_matrices(self): + """Form the coefficient matrices C_0, C_1, and C_2.""" + + # Extract dimension variables + l, m, n, o, s, k = self._dims + # Build up the coefficient matrices C_0, C_1, and C_2 + # If there are configuration constraints (l > 0), form C_0 as normal. + # If not, C_0 is I_(nxn). Note that this works even if n=0 + if l > 0: + f_c_jac_q = self.f_c.jacobian(self.q) + self._C_0 = (eye(n) - self._Pqd * + self.linear_solver(f_c_jac_q*self._Pqd, + f_c_jac_q))*self._Pqi + else: + self._C_0 = eye(n) + # If there are motion constraints (m > 0), form C_1 and C_2 as normal. + # If not, C_1 is 0, and C_2 is I_(oxo). Note that this works even if + # o = 0. + if m > 0: + f_v_jac_u = self.f_v.jacobian(self.u) + temp = f_v_jac_u * self._Pud + if n != 0: + f_v_jac_q = self.f_v.jacobian(self.q) + self._C_1 = -self._Pud * self.linear_solver(temp, f_v_jac_q) + else: + self._C_1 = zeros(o, n) + self._C_2 = (eye(o) - self._Pud * + self.linear_solver(temp, f_v_jac_u))*self._Pui + else: + self._C_1 = zeros(o, n) + self._C_2 = eye(o) + + def _form_block_matrices(self): + """Form the block matrices for composing M, A, and B.""" + + # Extract dimension variables + l, m, n, o, s, k = self._dims + # Block Matrix Definitions. These are only defined if under certain + # conditions. If undefined, an empty matrix is used instead + if n != 0: + self._M_qq = self.f_0.jacobian(self._qd) + self._A_qq = -(self.f_0 + self.f_1).jacobian(self.q) + else: + self._M_qq = Matrix() + self._A_qq = Matrix() + if n != 0 and m != 0: + self._M_uqc = self.f_a.jacobian(self._qd_dup) + self._A_uqc = -self.f_a.jacobian(self.q) + else: + self._M_uqc = Matrix() + self._A_uqc = Matrix() + if n != 0 and o - m + k != 0: + self._M_uqd = self.f_3.jacobian(self._qd_dup) + self._A_uqd = -(self.f_2 + self.f_3 + self.f_4).jacobian(self.q) + else: + self._M_uqd = Matrix() + self._A_uqd = Matrix() + if o != 0 and m != 0: + self._M_uuc = self.f_a.jacobian(self._ud) + self._A_uuc = -self.f_a.jacobian(self.u) + else: + self._M_uuc = Matrix() + self._A_uuc = Matrix() + if o != 0 and o - m + k != 0: + self._M_uud = self.f_2.jacobian(self._ud) + self._A_uud = -(self.f_2 + self.f_3).jacobian(self.u) + else: + self._M_uud = Matrix() + self._A_uud = Matrix() + if o != 0 and n != 0: + self._A_qu = -self.f_1.jacobian(self.u) + else: + self._A_qu = Matrix() + if k != 0 and o - m + k != 0: + self._M_uld = self.f_4.jacobian(self.lams) + else: + self._M_uld = Matrix() + if s != 0 and o - m + k != 0: + self._B_u = -self.f_3.jacobian(self.r) + else: + self._B_u = Matrix() + + def linearize(self, op_point=None, A_and_B=False, simplify=False): + """Linearize the system about the operating point. Note that + q_op, u_op, qd_op, ud_op must satisfy the equations of motion. + These may be either symbolic or numeric. + + Parameters + ========== + op_point : dict or iterable of dicts, optional + Dictionary or iterable of dictionaries containing the operating + point conditions for all or a subset of the generalized + coordinates, generalized speeds, and time derivatives of the + generalized speeds. These will be substituted into the linearized + system before the linearization is complete. Leave set to ``None`` + if you want the operating point to be an arbitrary set of symbols. + Note that any reduction in symbols (whether substituted for numbers + or expressions with a common parameter) will result in faster + runtime. + A_and_B : bool, optional + If A_and_B=False (default), (M, A, B) is returned and of + A_and_B=True, (A, B) is returned. See below. + simplify : bool, optional + Determines if returned values are simplified before return. + For large expressions this may be time consuming. Default is False. + + Returns + ======= + M, A, B : Matrices, ``A_and_B=False`` + Matrices from the implicit form: + ``[M]*[q', u']^T = [A]*[q_ind, u_ind]^T + [B]*r`` + A, B : Matrices, ``A_and_B=True`` + Matrices from the explicit form: + ``[q_ind', u_ind']^T = [A]*[q_ind, u_ind]^T + [B]*r`` + + Notes + ===== + + Note that the process of solving with A_and_B=True is computationally + intensive if there are many symbolic parameters. For this reason, it + may be more desirable to use the default A_and_B=False, returning M, A, + and B. More values may then be substituted in to these matrices later + on. The state space form can then be found as A = P.T*M.LUsolve(A), B = + P.T*M.LUsolve(B), where P = Linearizer.perm_mat. + + """ + + # Run the setup if needed: + if not self._setup_done: + self._setup() + + # Compose dict of operating conditions + if isinstance(op_point, dict): + op_point_dict = op_point + elif isinstance(op_point, Iterable): + op_point_dict = {} + for op in op_point: + op_point_dict.update(op) + else: + op_point_dict = {} + + # Extract dimension variables + l, m, n, o, s, k = self._dims + + # Rename terms to shorten expressions + M_qq = self._M_qq + M_uqc = self._M_uqc + M_uqd = self._M_uqd + M_uuc = self._M_uuc + M_uud = self._M_uud + M_uld = self._M_uld + A_qq = self._A_qq + A_uqc = self._A_uqc + A_uqd = self._A_uqd + A_qu = self._A_qu + A_uuc = self._A_uuc + A_uud = self._A_uud + B_u = self._B_u + C_0 = self._C_0 + C_1 = self._C_1 + C_2 = self._C_2 + + # Build up Mass Matrix + # |M_qq 0_nxo 0_nxk| + # M = |M_uqc M_uuc 0_mxk| + # |M_uqd M_uud M_uld| + if o != 0: + col2 = Matrix([zeros(n, o), M_uuc, M_uud]) + if k != 0: + col3 = Matrix([zeros(n + m, k), M_uld]) + if n != 0: + col1 = Matrix([M_qq, M_uqc, M_uqd]) + if o != 0 and k != 0: + M = col1.row_join(col2).row_join(col3) + elif o != 0: + M = col1.row_join(col2) + else: + M = col1 + elif k != 0: + M = col2.row_join(col3) + else: + M = col2 + M_eq = msubs(M, op_point_dict) + + # Build up state coefficient matrix A + # |(A_qq + A_qu*C_1)*C_0 A_qu*C_2| + # A = |(A_uqc + A_uuc*C_1)*C_0 A_uuc*C_2| + # |(A_uqd + A_uud*C_1)*C_0 A_uud*C_2| + # Col 1 is only defined if n != 0 + if n != 0: + r1c1 = A_qq + if o != 0: + r1c1 += (A_qu * C_1) + r1c1 = r1c1 * C_0 + if m != 0: + r2c1 = A_uqc + if o != 0: + r2c1 += (A_uuc * C_1) + r2c1 = r2c1 * C_0 + else: + r2c1 = Matrix() + if o - m + k != 0: + r3c1 = A_uqd + if o != 0: + r3c1 += (A_uud * C_1) + r3c1 = r3c1 * C_0 + else: + r3c1 = Matrix() + col1 = Matrix([r1c1, r2c1, r3c1]) + else: + col1 = Matrix() + # Col 2 is only defined if o != 0 + if o != 0: + if n != 0: + r1c2 = A_qu * C_2 + else: + r1c2 = Matrix() + if m != 0: + r2c2 = A_uuc * C_2 + else: + r2c2 = Matrix() + if o - m + k != 0: + r3c2 = A_uud * C_2 + else: + r3c2 = Matrix() + col2 = Matrix([r1c2, r2c2, r3c2]) + else: + col2 = Matrix() + if col1: + if col2: + Amat = col1.row_join(col2) + else: + Amat = col1 + else: + Amat = col2 + Amat_eq = msubs(Amat, op_point_dict) + + # Build up the B matrix if there are forcing variables + # |0_(n + m)xs| + # B = |B_u | + if s != 0 and o - m + k != 0: + Bmat = zeros(n + m, s).col_join(B_u) + Bmat_eq = msubs(Bmat, op_point_dict) + else: + Bmat_eq = Matrix() + + # kwarg A_and_B indicates to return A, B for forming the equation + # dx = [A]x + [B]r, where x = [q_indnd, u_indnd]^T, + if A_and_B: + A_cont = self.perm_mat.T * self.linear_solver(M_eq, Amat_eq) + if Bmat_eq: + B_cont = self.perm_mat.T * self.linear_solver(M_eq, Bmat_eq) + else: + # Bmat = Matrix([]), so no need to sub + B_cont = Bmat_eq + if simplify: + A_cont.simplify() + B_cont.simplify() + return A_cont, B_cont + # Otherwise return M, A, B for forming the equation + # [M]dx = [A]x + [B]r, where x = [q, u]^T + else: + if simplify: + M_eq.simplify() + Amat_eq.simplify() + Bmat_eq.simplify() + return M_eq, Amat_eq, Bmat_eq + + +def permutation_matrix(orig_vec, per_vec): + """Compute the permutation matrix to change order of + orig_vec into order of per_vec. + + Parameters + ========== + + orig_vec : array_like + Symbols in original ordering. + per_vec : array_like + Symbols in new ordering. + + Returns + ======= + + p_matrix : Matrix + Permutation matrix such that orig_vec == (p_matrix * per_vec). + """ + if not isinstance(orig_vec, (list, tuple)): + orig_vec = flatten(orig_vec) + if not isinstance(per_vec, (list, tuple)): + per_vec = flatten(per_vec) + if set(orig_vec) != set(per_vec): + raise ValueError("orig_vec and per_vec must be the same length, " + "and contain the same symbols.") + ind_list = [orig_vec.index(i) for i in per_vec] + p_matrix = zeros(len(orig_vec)) + for i, j in enumerate(ind_list): + p_matrix[i, j] = 1 + return p_matrix diff --git a/lib/python3.10/site-packages/sympy/physics/mechanics/loads.py b/lib/python3.10/site-packages/sympy/physics/mechanics/loads.py new file mode 100644 index 0000000000000000000000000000000000000000..3b9db763ffd6f99905e9d17fdc07f4171de4801b --- /dev/null +++ b/lib/python3.10/site-packages/sympy/physics/mechanics/loads.py @@ -0,0 +1,177 @@ +from abc import ABC +from collections import namedtuple +from sympy.physics.mechanics.body_base import BodyBase +from sympy.physics.vector import Vector, ReferenceFrame, Point + +__all__ = ['LoadBase', 'Force', 'Torque'] + + +class LoadBase(ABC, namedtuple('LoadBase', ['location', 'vector'])): + """Abstract base class for the various loading types.""" + + def __add__(self, other): + raise TypeError(f"unsupported operand type(s) for +: " + f"'{self.__class__.__name__}' and " + f"'{other.__class__.__name__}'") + + def __mul__(self, other): + raise TypeError(f"unsupported operand type(s) for *: " + f"'{self.__class__.__name__}' and " + f"'{other.__class__.__name__}'") + + __radd__ = __add__ + __rmul__ = __mul__ + + +class Force(LoadBase): + """Force acting upon a point. + + Explanation + =========== + + A force is a vector that is bound to a line of action. This class stores + both a point, which lies on the line of action, and the vector. A tuple can + also be used, with the location as the first entry and the vector as second + entry. + + Examples + ======== + + A force of magnitude 2 along N.x acting on a point Po can be created as + follows: + + >>> from sympy.physics.mechanics import Point, ReferenceFrame, Force + >>> N = ReferenceFrame('N') + >>> Po = Point('Po') + >>> Force(Po, 2 * N.x) + (Po, 2*N.x) + + If a body is supplied, then the center of mass of that body is used. + + >>> from sympy.physics.mechanics import Particle + >>> P = Particle('P', point=Po) + >>> Force(P, 2 * N.x) + (Po, 2*N.x) + + """ + + def __new__(cls, point, force): + if isinstance(point, BodyBase): + point = point.masscenter + if not isinstance(point, Point): + raise TypeError('Force location should be a Point.') + if not isinstance(force, Vector): + raise TypeError('Force vector should be a Vector.') + return super().__new__(cls, point, force) + + def __repr__(self): + return (f'{self.__class__.__name__}(point={self.point}, ' + f'force={self.force})') + + @property + def point(self): + return self.location + + @property + def force(self): + return self.vector + + +class Torque(LoadBase): + """Torque acting upon a frame. + + Explanation + =========== + + A torque is a free vector that is acting on a reference frame, which is + associated with a rigid body. This class stores both the frame and the + vector. A tuple can also be used, with the location as the first item and + the vector as second item. + + Examples + ======== + + A torque of magnitude 2 about N.x acting on a frame N can be created as + follows: + + >>> from sympy.physics.mechanics import ReferenceFrame, Torque + >>> N = ReferenceFrame('N') + >>> Torque(N, 2 * N.x) + (N, 2*N.x) + + If a body is supplied, then the frame fixed to that body is used. + + >>> from sympy.physics.mechanics import RigidBody + >>> rb = RigidBody('rb', frame=N) + >>> Torque(rb, 2 * N.x) + (N, 2*N.x) + + """ + + def __new__(cls, frame, torque): + if isinstance(frame, BodyBase): + frame = frame.frame + if not isinstance(frame, ReferenceFrame): + raise TypeError('Torque location should be a ReferenceFrame.') + if not isinstance(torque, Vector): + raise TypeError('Torque vector should be a Vector.') + return super().__new__(cls, frame, torque) + + def __repr__(self): + return (f'{self.__class__.__name__}(frame={self.frame}, ' + f'torque={self.torque})') + + @property + def frame(self): + return self.location + + @property + def torque(self): + return self.vector + + +def gravity(acceleration, *bodies): + """ + Returns a list of gravity forces given the acceleration + due to gravity and any number of particles or rigidbodies. + + Example + ======= + + >>> from sympy.physics.mechanics import ReferenceFrame, Particle, RigidBody + >>> from sympy.physics.mechanics.loads import gravity + >>> from sympy import symbols + >>> N = ReferenceFrame('N') + >>> g = symbols('g') + >>> P = Particle('P') + >>> B = RigidBody('B') + >>> gravity(g*N.y, P, B) + [(P_masscenter, P_mass*g*N.y), + (B_masscenter, B_mass*g*N.y)] + + """ + + gravity_force = [] + for body in bodies: + if not isinstance(body, BodyBase): + raise TypeError(f'{type(body)} is not a body type') + gravity_force.append(Force(body.masscenter, body.mass * acceleration)) + return gravity_force + + +def _parse_load(load): + """Helper function to parse loads and convert tuples to load objects.""" + if isinstance(load, LoadBase): + return load + elif isinstance(load, tuple): + if len(load) != 2: + raise ValueError(f'Load {load} should have a length of 2.') + if isinstance(load[0], Point): + return Force(load[0], load[1]) + elif isinstance(load[0], ReferenceFrame): + return Torque(load[0], load[1]) + else: + raise ValueError(f'Load not recognized. The load location {load[0]}' + f' should either be a Point or a ReferenceFrame.') + raise TypeError(f'Load type {type(load)} not recognized as a load. It ' + f'should be a Force, Torque or tuple.') diff --git a/lib/python3.10/site-packages/sympy/physics/mechanics/method.py b/lib/python3.10/site-packages/sympy/physics/mechanics/method.py new file mode 100644 index 0000000000000000000000000000000000000000..5c2c4a5f388e56e37bd9ecdf6daffc08ffa51070 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/physics/mechanics/method.py @@ -0,0 +1,39 @@ +from abc import ABC, abstractmethod + +class _Methods(ABC): + """Abstract Base Class for all methods.""" + + @abstractmethod + def q(self): + pass + + @abstractmethod + def u(self): + pass + + @abstractmethod + def bodies(self): + pass + + @abstractmethod + def loads(self): + pass + + @abstractmethod + def mass_matrix(self): + pass + + @abstractmethod + def forcing(self): + pass + + @abstractmethod + def mass_matrix_full(self): + pass + + @abstractmethod + def forcing_full(self): + pass + + def _form_eoms(self): + raise NotImplementedError("Subclasses must implement this.") diff --git a/lib/python3.10/site-packages/sympy/physics/mechanics/models.py b/lib/python3.10/site-packages/sympy/physics/mechanics/models.py new file mode 100644 index 0000000000000000000000000000000000000000..a89b929ffd540a07787f6f94714850b348c90781 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/physics/mechanics/models.py @@ -0,0 +1,230 @@ +#!/usr/bin/env python +"""This module contains some sample symbolic models used for testing and +examples.""" + +# Internal imports +from sympy.core import backend as sm +import sympy.physics.mechanics as me + + +def multi_mass_spring_damper(n=1, apply_gravity=False, + apply_external_forces=False): + r"""Returns a system containing the symbolic equations of motion and + associated variables for a simple multi-degree of freedom point mass, + spring, damper system with optional gravitational and external + specified forces. For example, a two mass system under the influence of + gravity and external forces looks like: + + :: + + ---------------- + | | | | g + \ | | | V + k0 / --- c0 | + | | | x0, v0 + --------- V + | m0 | ----- + --------- | + | | | | + \ v | | | + k1 / f0 --- c1 | + | | | x1, v1 + --------- V + | m1 | ----- + --------- + | f1 + V + + Parameters + ========== + + n : integer + The number of masses in the serial chain. + apply_gravity : boolean + If true, gravity will be applied to each mass. + apply_external_forces : boolean + If true, a time varying external force will be applied to each mass. + + Returns + ======= + + kane : sympy.physics.mechanics.kane.KanesMethod + A KanesMethod object. + + """ + + mass = sm.symbols('m:{}'.format(n)) + stiffness = sm.symbols('k:{}'.format(n)) + damping = sm.symbols('c:{}'.format(n)) + + acceleration_due_to_gravity = sm.symbols('g') + + coordinates = me.dynamicsymbols('x:{}'.format(n)) + speeds = me.dynamicsymbols('v:{}'.format(n)) + specifieds = me.dynamicsymbols('f:{}'.format(n)) + + ceiling = me.ReferenceFrame('N') + origin = me.Point('origin') + origin.set_vel(ceiling, 0) + + points = [origin] + kinematic_equations = [] + particles = [] + forces = [] + + for i in range(n): + + center = points[-1].locatenew('center{}'.format(i), + coordinates[i] * ceiling.x) + center.set_vel(ceiling, points[-1].vel(ceiling) + + speeds[i] * ceiling.x) + points.append(center) + + block = me.Particle('block{}'.format(i), center, mass[i]) + + kinematic_equations.append(speeds[i] - coordinates[i].diff()) + + total_force = (-stiffness[i] * coordinates[i] - + damping[i] * speeds[i]) + try: + total_force += (stiffness[i + 1] * coordinates[i + 1] + + damping[i + 1] * speeds[i + 1]) + except IndexError: # no force from below on last mass + pass + + if apply_gravity: + total_force += mass[i] * acceleration_due_to_gravity + + if apply_external_forces: + total_force += specifieds[i] + + forces.append((center, total_force * ceiling.x)) + + particles.append(block) + + kane = me.KanesMethod(ceiling, q_ind=coordinates, u_ind=speeds, + kd_eqs=kinematic_equations) + kane.kanes_equations(particles, forces) + + return kane + + +def n_link_pendulum_on_cart(n=1, cart_force=True, joint_torques=False): + r"""Returns the system containing the symbolic first order equations of + motion for a 2D n-link pendulum on a sliding cart under the influence of + gravity. + + :: + + | + o y v + \ 0 ^ g + \ | + --\-|---- + | \| | + F-> | o --|---> x + | | + --------- + o o + + Parameters + ========== + + n : integer + The number of links in the pendulum. + cart_force : boolean, default=True + If true an external specified lateral force is applied to the cart. + joint_torques : boolean, default=False + If true joint torques will be added as specified inputs at each + joint. + + Returns + ======= + + kane : sympy.physics.mechanics.kane.KanesMethod + A KanesMethod object. + + Notes + ===== + + The degrees of freedom of the system are n + 1, i.e. one for each + pendulum link and one for the lateral motion of the cart. + + M x' = F, where x = [u0, ..., un+1, q0, ..., qn+1] + + The joint angles are all defined relative to the ground where the x axis + defines the ground line and the y axis points up. The joint torques are + applied between each adjacent link and the between the cart and the + lower link where a positive torque corresponds to positive angle. + + """ + if n <= 0: + raise ValueError('The number of links must be a positive integer.') + + q = me.dynamicsymbols('q:{}'.format(n + 1)) + u = me.dynamicsymbols('u:{}'.format(n + 1)) + + if joint_torques is True: + T = me.dynamicsymbols('T1:{}'.format(n + 1)) + + m = sm.symbols('m:{}'.format(n + 1)) + l = sm.symbols('l:{}'.format(n)) + g, t = sm.symbols('g t') + + I = me.ReferenceFrame('I') + O = me.Point('O') + O.set_vel(I, 0) + + P0 = me.Point('P0') + P0.set_pos(O, q[0] * I.x) + P0.set_vel(I, u[0] * I.x) + Pa0 = me.Particle('Pa0', P0, m[0]) + + frames = [I] + points = [P0] + particles = [Pa0] + forces = [(P0, -m[0] * g * I.y)] + kindiffs = [q[0].diff(t) - u[0]] + + if cart_force is True or joint_torques is True: + specified = [] + else: + specified = None + + for i in range(n): + Bi = I.orientnew('B{}'.format(i), 'Axis', [q[i + 1], I.z]) + Bi.set_ang_vel(I, u[i + 1] * I.z) + frames.append(Bi) + + Pi = points[-1].locatenew('P{}'.format(i + 1), l[i] * Bi.y) + Pi.v2pt_theory(points[-1], I, Bi) + points.append(Pi) + + Pai = me.Particle('Pa' + str(i + 1), Pi, m[i + 1]) + particles.append(Pai) + + forces.append((Pi, -m[i + 1] * g * I.y)) + + if joint_torques is True: + + specified.append(T[i]) + + if i == 0: + forces.append((I, -T[i] * I.z)) + + if i == n - 1: + forces.append((Bi, T[i] * I.z)) + else: + forces.append((Bi, T[i] * I.z - T[i + 1] * I.z)) + + kindiffs.append(q[i + 1].diff(t) - u[i + 1]) + + if cart_force is True: + F = me.dynamicsymbols('F') + forces.append((P0, F * I.x)) + specified.append(F) + + kane = me.KanesMethod(I, q_ind=q, u_ind=u, kd_eqs=kindiffs) + kane.kanes_equations(particles, forces) + + return kane diff --git a/lib/python3.10/site-packages/sympy/physics/mechanics/particle.py b/lib/python3.10/site-packages/sympy/physics/mechanics/particle.py new file mode 100644 index 0000000000000000000000000000000000000000..5d49d4f811b8d1c7fff16c71991f5e01da6ded02 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/physics/mechanics/particle.py @@ -0,0 +1,209 @@ +from sympy import S +from sympy.physics.vector import cross, dot +from sympy.physics.mechanics.body_base import BodyBase +from sympy.physics.mechanics.inertia import inertia_of_point_mass +from sympy.utilities.exceptions import sympy_deprecation_warning + +__all__ = ['Particle'] + + +class Particle(BodyBase): + """A particle. + + Explanation + =========== + + Particles have a non-zero mass and lack spatial extension; they take up no + space. + + Values need to be supplied on initialization, but can be changed later. + + Parameters + ========== + + name : str + Name of particle + point : Point + A physics/mechanics Point which represents the position, velocity, and + acceleration of this Particle + mass : Sympifyable + A SymPy expression representing the Particle's mass + potential_energy : Sympifyable + The potential energy of the Particle. + + Examples + ======== + + >>> from sympy.physics.mechanics import Particle, Point + >>> from sympy import Symbol + >>> po = Point('po') + >>> m = Symbol('m') + >>> pa = Particle('pa', po, m) + >>> # Or you could change these later + >>> pa.mass = m + >>> pa.point = po + + """ + point = BodyBase.masscenter + + def __init__(self, name, point=None, mass=None): + super().__init__(name, point, mass) + + def linear_momentum(self, frame): + """Linear momentum of the particle. + + Explanation + =========== + + The linear momentum L, of a particle P, with respect to frame N is + given by: + + L = m * v + + where m is the mass of the particle, and v is the velocity of the + particle in the frame N. + + Parameters + ========== + + frame : ReferenceFrame + The frame in which linear momentum is desired. + + Examples + ======== + + >>> from sympy.physics.mechanics import Particle, Point, ReferenceFrame + >>> from sympy.physics.mechanics import dynamicsymbols + >>> from sympy.physics.vector import init_vprinting + >>> init_vprinting(pretty_print=False) + >>> m, v = dynamicsymbols('m v') + >>> N = ReferenceFrame('N') + >>> P = Point('P') + >>> A = Particle('A', P, m) + >>> P.set_vel(N, v * N.x) + >>> A.linear_momentum(N) + m*v*N.x + + """ + + return self.mass * self.point.vel(frame) + + def angular_momentum(self, point, frame): + """Angular momentum of the particle about the point. + + Explanation + =========== + + The angular momentum H, about some point O of a particle, P, is given + by: + + ``H = cross(r, m * v)`` + + where r is the position vector from point O to the particle P, m is + the mass of the particle, and v is the velocity of the particle in + the inertial frame, N. + + Parameters + ========== + + point : Point + The point about which angular momentum of the particle is desired. + + frame : ReferenceFrame + The frame in which angular momentum is desired. + + Examples + ======== + + >>> from sympy.physics.mechanics import Particle, Point, ReferenceFrame + >>> from sympy.physics.mechanics import dynamicsymbols + >>> from sympy.physics.vector import init_vprinting + >>> init_vprinting(pretty_print=False) + >>> m, v, r = dynamicsymbols('m v r') + >>> N = ReferenceFrame('N') + >>> O = Point('O') + >>> A = O.locatenew('A', r * N.x) + >>> P = Particle('P', A, m) + >>> P.point.set_vel(N, v * N.y) + >>> P.angular_momentum(O, N) + m*r*v*N.z + + """ + + return cross(self.point.pos_from(point), + self.mass * self.point.vel(frame)) + + def kinetic_energy(self, frame): + """Kinetic energy of the particle. + + Explanation + =========== + + The kinetic energy, T, of a particle, P, is given by: + + ``T = 1/2 (dot(m * v, v))`` + + where m is the mass of particle P, and v is the velocity of the + particle in the supplied ReferenceFrame. + + Parameters + ========== + + frame : ReferenceFrame + The Particle's velocity is typically defined with respect to + an inertial frame but any relevant frame in which the velocity is + known can be supplied. + + Examples + ======== + + >>> from sympy.physics.mechanics import Particle, Point, ReferenceFrame + >>> from sympy import symbols + >>> m, v, r = symbols('m v r') + >>> N = ReferenceFrame('N') + >>> O = Point('O') + >>> P = Particle('P', O, m) + >>> P.point.set_vel(N, v * N.y) + >>> P.kinetic_energy(N) + m*v**2/2 + + """ + + return S.Half * self.mass * dot(self.point.vel(frame), + self.point.vel(frame)) + + def set_potential_energy(self, scalar): + sympy_deprecation_warning( + """ +The sympy.physics.mechanics.Particle.set_potential_energy() +method is deprecated. Instead use + + P.potential_energy = scalar + """, + deprecated_since_version="1.5", + active_deprecations_target="deprecated-set-potential-energy", + ) + self.potential_energy = scalar + + def parallel_axis(self, point, frame): + """Returns an inertia dyadic of the particle with respect to another + point and frame. + + Parameters + ========== + + point : sympy.physics.vector.Point + The point to express the inertia dyadic about. + frame : sympy.physics.vector.ReferenceFrame + The reference frame used to construct the dyadic. + + Returns + ======= + + inertia : sympy.physics.vector.Dyadic + The inertia dyadic of the particle expressed about the provided + point and frame. + + """ + return inertia_of_point_mass(self.mass, self.point.pos_from(point), + frame) diff --git a/lib/python3.10/site-packages/sympy/physics/mechanics/pathway.py b/lib/python3.10/site-packages/sympy/physics/mechanics/pathway.py new file mode 100644 index 0000000000000000000000000000000000000000..3823750b0aeddcd8a8c4c55c02f816d823141fbb --- /dev/null +++ b/lib/python3.10/site-packages/sympy/physics/mechanics/pathway.py @@ -0,0 +1,688 @@ +"""Implementations of pathways for use by actuators.""" + +from abc import ABC, abstractmethod + +from sympy.core.singleton import S +from sympy.physics.mechanics.loads import Force +from sympy.physics.mechanics.wrapping_geometry import WrappingGeometryBase +from sympy.physics.vector import Point, dynamicsymbols + + +__all__ = ['PathwayBase', 'LinearPathway', 'ObstacleSetPathway', + 'WrappingPathway'] + + +class PathwayBase(ABC): + """Abstract base class for all pathway classes to inherit from. + + Notes + ===== + + Instances of this class cannot be directly instantiated by users. However, + it can be used to created custom pathway types through subclassing. + + """ + + def __init__(self, *attachments): + """Initializer for ``PathwayBase``.""" + self.attachments = attachments + + @property + def attachments(self): + """The pair of points defining a pathway's ends.""" + return self._attachments + + @attachments.setter + def attachments(self, attachments): + if hasattr(self, '_attachments'): + msg = ( + f'Can\'t set attribute `attachments` to {repr(attachments)} ' + f'as it is immutable.' + ) + raise AttributeError(msg) + if len(attachments) != 2: + msg = ( + f'Value {repr(attachments)} passed to `attachments` was an ' + f'iterable of length {len(attachments)}, must be an iterable ' + f'of length 2.' + ) + raise ValueError(msg) + for i, point in enumerate(attachments): + if not isinstance(point, Point): + msg = ( + f'Value {repr(point)} passed to `attachments` at index ' + f'{i} was of type {type(point)}, must be {Point}.' + ) + raise TypeError(msg) + self._attachments = tuple(attachments) + + @property + @abstractmethod + def length(self): + """An expression representing the pathway's length.""" + pass + + @property + @abstractmethod + def extension_velocity(self): + """An expression representing the pathway's extension velocity.""" + pass + + @abstractmethod + def to_loads(self, force): + """Loads required by the equations of motion method classes. + + Explanation + =========== + + ``KanesMethod`` requires a list of ``Point``-``Vector`` tuples to be + passed to the ``loads`` parameters of its ``kanes_equations`` method + when constructing the equations of motion. This method acts as a + utility to produce the correctly-structred pairs of points and vectors + required so that these can be easily concatenated with other items in + the list of loads and passed to ``KanesMethod.kanes_equations``. These + loads are also in the correct form to also be passed to the other + equations of motion method classes, e.g. ``LagrangesMethod``. + + """ + pass + + def __repr__(self): + """Default representation of a pathway.""" + attachments = ', '.join(str(a) for a in self.attachments) + return f'{self.__class__.__name__}({attachments})' + + +class LinearPathway(PathwayBase): + """Linear pathway between a pair of attachment points. + + Explanation + =========== + + A linear pathway forms a straight-line segment between two points and is + the simplest pathway that can be formed. It will not interact with any + other objects in the system, i.e. a ``LinearPathway`` will intersect other + objects to ensure that the path between its two ends (its attachments) is + the shortest possible. + + A linear pathway is made up of two points that can move relative to each + other, and a pair of equal and opposite forces acting on the points. If the + positive time-varying Euclidean distance between the two points is defined, + then the "extension velocity" is the time derivative of this distance. The + extension velocity is positive when the two points are moving away from + each other and negative when moving closer to each other. The direction for + the force acting on either point is determined by constructing a unit + vector directed from the other point to this point. This establishes a sign + convention such that a positive force magnitude tends to push the points + apart. The following diagram shows the positive force sense and the + distance between the points:: + + P Q + o<--- F --->o + | | + |<--l(t)--->| + + Examples + ======== + + >>> from sympy.physics.mechanics import LinearPathway + + To construct a pathway, two points are required to be passed to the + ``attachments`` parameter as a ``tuple``. + + >>> from sympy.physics.mechanics import Point + >>> pA, pB = Point('pA'), Point('pB') + >>> linear_pathway = LinearPathway(pA, pB) + >>> linear_pathway + LinearPathway(pA, pB) + + The pathway created above isn't very interesting without the positions and + velocities of its attachment points being described. Without this its not + possible to describe how the pathway moves, i.e. its length or its + extension velocity. + + >>> from sympy.physics.mechanics import ReferenceFrame + >>> from sympy.physics.vector import dynamicsymbols + >>> N = ReferenceFrame('N') + >>> q = dynamicsymbols('q') + >>> pB.set_pos(pA, q*N.x) + >>> pB.pos_from(pA) + q(t)*N.x + + A pathway's length can be accessed via its ``length`` attribute. + + >>> linear_pathway.length + sqrt(q(t)**2) + + Note how what appears to be an overly-complex expression is returned. This + is actually required as it ensures that a pathway's length is always + positive. + + A pathway's extension velocity can be accessed similarly via its + ``extension_velocity`` attribute. + + >>> linear_pathway.extension_velocity + sqrt(q(t)**2)*Derivative(q(t), t)/q(t) + + Parameters + ========== + + attachments : tuple[Point, Point] + Pair of ``Point`` objects between which the linear pathway spans. + Constructor expects two points to be passed, e.g. + ``LinearPathway(Point('pA'), Point('pB'))``. More or fewer points will + cause an error to be thrown. + + """ + + def __init__(self, *attachments): + """Initializer for ``LinearPathway``. + + Parameters + ========== + + attachments : Point + Pair of ``Point`` objects between which the linear pathway spans. + Constructor expects two points to be passed, e.g. + ``LinearPathway(Point('pA'), Point('pB'))``. More or fewer points + will cause an error to be thrown. + + """ + super().__init__(*attachments) + + @property + def length(self): + """Exact analytical expression for the pathway's length.""" + return _point_pair_length(*self.attachments) + + @property + def extension_velocity(self): + """Exact analytical expression for the pathway's extension velocity.""" + return _point_pair_extension_velocity(*self.attachments) + + def to_loads(self, force): + """Loads required by the equations of motion method classes. + + Explanation + =========== + + ``KanesMethod`` requires a list of ``Point``-``Vector`` tuples to be + passed to the ``loads`` parameters of its ``kanes_equations`` method + when constructing the equations of motion. This method acts as a + utility to produce the correctly-structred pairs of points and vectors + required so that these can be easily concatenated with other items in + the list of loads and passed to ``KanesMethod.kanes_equations``. These + loads are also in the correct form to also be passed to the other + equations of motion method classes, e.g. ``LagrangesMethod``. + + Examples + ======== + + The below example shows how to generate the loads produced in a linear + actuator that produces an expansile force ``F``. First, create a linear + actuator between two points separated by the coordinate ``q`` in the + ``x`` direction of the global frame ``N``. + + >>> from sympy.physics.mechanics import (LinearPathway, Point, + ... ReferenceFrame) + >>> from sympy.physics.vector import dynamicsymbols + >>> q = dynamicsymbols('q') + >>> N = ReferenceFrame('N') + >>> pA, pB = Point('pA'), Point('pB') + >>> pB.set_pos(pA, q*N.x) + >>> linear_pathway = LinearPathway(pA, pB) + + Now create a symbol ``F`` to describe the magnitude of the (expansile) + force that will be produced along the pathway. The list of loads that + ``KanesMethod`` requires can be produced by calling the pathway's + ``to_loads`` method with ``F`` passed as the only argument. + + >>> from sympy import symbols + >>> F = symbols('F') + >>> linear_pathway.to_loads(F) + [(pA, - F*q(t)/sqrt(q(t)**2)*N.x), (pB, F*q(t)/sqrt(q(t)**2)*N.x)] + + Parameters + ========== + + force : Expr + Magnitude of the force acting along the length of the pathway. As + per the sign conventions for the pathway length, pathway extension + velocity, and pair of point forces, if this ``Expr`` is positive + then the force will act to push the pair of points away from one + another (it is expansile). + + """ + relative_position = _point_pair_relative_position(*self.attachments) + loads = [ + Force(self.attachments[0], -force*relative_position/self.length), + Force(self.attachments[-1], force*relative_position/self.length), + ] + return loads + + +class ObstacleSetPathway(PathwayBase): + """Obstacle-set pathway between a set of attachment points. + + Explanation + =========== + + An obstacle-set pathway forms a series of straight-line segment between + pairs of consecutive points in a set of points. It is similiar to multiple + linear pathways joined end-to-end. It will not interact with any other + objects in the system, i.e. an ``ObstacleSetPathway`` will intersect other + objects to ensure that the path between its pairs of points (its + attachments) is the shortest possible. + + Examples + ======== + + To construct an obstacle-set pathway, three or more points are required to + be passed to the ``attachments`` parameter as a ``tuple``. + + >>> from sympy.physics.mechanics import ObstacleSetPathway, Point + >>> pA, pB, pC, pD = Point('pA'), Point('pB'), Point('pC'), Point('pD') + >>> obstacle_set_pathway = ObstacleSetPathway(pA, pB, pC, pD) + >>> obstacle_set_pathway + ObstacleSetPathway(pA, pB, pC, pD) + + The pathway created above isn't very interesting without the positions and + velocities of its attachment points being described. Without this its not + possible to describe how the pathway moves, i.e. its length or its + extension velocity. + + >>> from sympy import cos, sin + >>> from sympy.physics.mechanics import ReferenceFrame + >>> from sympy.physics.vector import dynamicsymbols + >>> N = ReferenceFrame('N') + >>> q = dynamicsymbols('q') + >>> pO = Point('pO') + >>> pA.set_pos(pO, N.y) + >>> pB.set_pos(pO, -N.x) + >>> pC.set_pos(pA, cos(q) * N.x - (sin(q) + 1) * N.y) + >>> pD.set_pos(pA, sin(q) * N.x + (cos(q) - 1) * N.y) + >>> pB.pos_from(pA) + - N.x - N.y + >>> pC.pos_from(pA) + cos(q(t))*N.x + (-sin(q(t)) - 1)*N.y + >>> pD.pos_from(pA) + sin(q(t))*N.x + (cos(q(t)) - 1)*N.y + + A pathway's length can be accessed via its ``length`` attribute. + + >>> obstacle_set_pathway.length.simplify() + sqrt(2)*(sqrt(cos(q(t)) + 1) + 2) + + A pathway's extension velocity can be accessed similarly via its + ``extension_velocity`` attribute. + + >>> obstacle_set_pathway.extension_velocity.simplify() + -sqrt(2)*sin(q(t))*Derivative(q(t), t)/(2*sqrt(cos(q(t)) + 1)) + + Parameters + ========== + + attachments : tuple[Point, Point] + The set of ``Point`` objects that define the segmented obstacle-set + pathway. + + """ + + def __init__(self, *attachments): + """Initializer for ``ObstacleSetPathway``. + + Parameters + ========== + + attachments : tuple[Point, ...] + The set of ``Point`` objects that define the segmented obstacle-set + pathway. + + """ + super().__init__(*attachments) + + @property + def attachments(self): + """The set of points defining a pathway's segmented path.""" + return self._attachments + + @attachments.setter + def attachments(self, attachments): + if hasattr(self, '_attachments'): + msg = ( + f'Can\'t set attribute `attachments` to {repr(attachments)} ' + f'as it is immutable.' + ) + raise AttributeError(msg) + if len(attachments) <= 2: + msg = ( + f'Value {repr(attachments)} passed to `attachments` was an ' + f'iterable of length {len(attachments)}, must be an iterable ' + f'of length 3 or greater.' + ) + raise ValueError(msg) + for i, point in enumerate(attachments): + if not isinstance(point, Point): + msg = ( + f'Value {repr(point)} passed to `attachments` at index ' + f'{i} was of type {type(point)}, must be {Point}.' + ) + raise TypeError(msg) + self._attachments = tuple(attachments) + + @property + def length(self): + """Exact analytical expression for the pathway's length.""" + length = S.Zero + attachment_pairs = zip(self.attachments[:-1], self.attachments[1:]) + for attachment_pair in attachment_pairs: + length += _point_pair_length(*attachment_pair) + return length + + @property + def extension_velocity(self): + """Exact analytical expression for the pathway's extension velocity.""" + extension_velocity = S.Zero + attachment_pairs = zip(self.attachments[:-1], self.attachments[1:]) + for attachment_pair in attachment_pairs: + extension_velocity += _point_pair_extension_velocity(*attachment_pair) + return extension_velocity + + def to_loads(self, force): + """Loads required by the equations of motion method classes. + + Explanation + =========== + + ``KanesMethod`` requires a list of ``Point``-``Vector`` tuples to be + passed to the ``loads`` parameters of its ``kanes_equations`` method + when constructing the equations of motion. This method acts as a + utility to produce the correctly-structred pairs of points and vectors + required so that these can be easily concatenated with other items in + the list of loads and passed to ``KanesMethod.kanes_equations``. These + loads are also in the correct form to also be passed to the other + equations of motion method classes, e.g. ``LagrangesMethod``. + + Examples + ======== + + The below example shows how to generate the loads produced in an + actuator that follows an obstacle-set pathway between four points and + produces an expansile force ``F``. First, create a pair of reference + frames, ``A`` and ``B``, in which the four points ``pA``, ``pB``, + ``pC``, and ``pD`` will be located. The first two points in frame ``A`` + and the second two in frame ``B``. Frame ``B`` will also be oriented + such that it relates to ``A`` via a rotation of ``q`` about an axis + ``N.z`` in a global frame (``N.z``, ``A.z``, and ``B.z`` are parallel). + + >>> from sympy.physics.mechanics import (ObstacleSetPathway, Point, + ... ReferenceFrame) + >>> from sympy.physics.vector import dynamicsymbols + >>> q = dynamicsymbols('q') + >>> N = ReferenceFrame('N') + >>> N = ReferenceFrame('N') + >>> A = N.orientnew('A', 'axis', (0, N.x)) + >>> B = A.orientnew('B', 'axis', (q, N.z)) + >>> pO = Point('pO') + >>> pA, pB, pC, pD = Point('pA'), Point('pB'), Point('pC'), Point('pD') + >>> pA.set_pos(pO, A.x) + >>> pB.set_pos(pO, -A.y) + >>> pC.set_pos(pO, B.y) + >>> pD.set_pos(pO, B.x) + >>> obstacle_set_pathway = ObstacleSetPathway(pA, pB, pC, pD) + + Now create a symbol ``F`` to describe the magnitude of the (expansile) + force that will be produced along the pathway. The list of loads that + ``KanesMethod`` requires can be produced by calling the pathway's + ``to_loads`` method with ``F`` passed as the only argument. + + >>> from sympy import Symbol + >>> F = Symbol('F') + >>> obstacle_set_pathway.to_loads(F) + [(pA, sqrt(2)*F/2*A.x + sqrt(2)*F/2*A.y), + (pB, - sqrt(2)*F/2*A.x - sqrt(2)*F/2*A.y), + (pB, - F/sqrt(2*cos(q(t)) + 2)*A.y - F/sqrt(2*cos(q(t)) + 2)*B.y), + (pC, F/sqrt(2*cos(q(t)) + 2)*A.y + F/sqrt(2*cos(q(t)) + 2)*B.y), + (pC, - sqrt(2)*F/2*B.x + sqrt(2)*F/2*B.y), + (pD, sqrt(2)*F/2*B.x - sqrt(2)*F/2*B.y)] + + Parameters + ========== + + force : Expr + The force acting along the length of the pathway. It is assumed + that this ``Expr`` represents an expansile force. + + """ + loads = [] + attachment_pairs = zip(self.attachments[:-1], self.attachments[1:]) + for attachment_pair in attachment_pairs: + relative_position = _point_pair_relative_position(*attachment_pair) + length = _point_pair_length(*attachment_pair) + loads.extend([ + Force(attachment_pair[0], -force*relative_position/length), + Force(attachment_pair[1], force*relative_position/length), + ]) + return loads + + +class WrappingPathway(PathwayBase): + """Pathway that wraps a geometry object. + + Explanation + =========== + + A wrapping pathway interacts with a geometry object and forms a path that + wraps smoothly along its surface. The wrapping pathway along the geometry + object will be the geodesic that the geometry object defines based on the + two points. It will not interact with any other objects in the system, i.e. + a ``WrappingPathway`` will intersect other objects to ensure that the path + between its two ends (its attachments) is the shortest possible. + + To explain the sign conventions used for pathway length, extension + velocity, and direction of applied forces, we can ignore the geometry with + which the wrapping pathway interacts. A wrapping pathway is made up of two + points that can move relative to each other, and a pair of equal and + opposite forces acting on the points. If the positive time-varying + Euclidean distance between the two points is defined, then the "extension + velocity" is the time derivative of this distance. The extension velocity + is positive when the two points are moving away from each other and + negative when moving closer to each other. The direction for the force + acting on either point is determined by constructing a unit vector directed + from the other point to this point. This establishes a sign convention such + that a positive force magnitude tends to push the points apart. The + following diagram shows the positive force sense and the distance between + the points:: + + P Q + o<--- F --->o + | | + |<--l(t)--->| + + Examples + ======== + + >>> from sympy.physics.mechanics import WrappingPathway + + To construct a wrapping pathway, like other pathways, a pair of points must + be passed, followed by an instance of a wrapping geometry class as a + keyword argument. We'll use a cylinder with radius ``r`` and its axis + parallel to ``N.x`` passing through a point ``pO``. + + >>> from sympy import symbols + >>> from sympy.physics.mechanics import Point, ReferenceFrame, WrappingCylinder + >>> r = symbols('r') + >>> N = ReferenceFrame('N') + >>> pA, pB, pO = Point('pA'), Point('pB'), Point('pO') + >>> cylinder = WrappingCylinder(r, pO, N.x) + >>> wrapping_pathway = WrappingPathway(pA, pB, cylinder) + >>> wrapping_pathway + WrappingPathway(pA, pB, geometry=WrappingCylinder(radius=r, point=pO, + axis=N.x)) + + Parameters + ========== + + attachment_1 : Point + First of the pair of ``Point`` objects between which the wrapping + pathway spans. + attachment_2 : Point + Second of the pair of ``Point`` objects between which the wrapping + pathway spans. + geometry : WrappingGeometryBase + Geometry about which the pathway wraps. + + """ + + def __init__(self, attachment_1, attachment_2, geometry): + """Initializer for ``WrappingPathway``. + + Parameters + ========== + + attachment_1 : Point + First of the pair of ``Point`` objects between which the wrapping + pathway spans. + attachment_2 : Point + Second of the pair of ``Point`` objects between which the wrapping + pathway spans. + geometry : WrappingGeometryBase + Geometry about which the pathway wraps. + The geometry about which the pathway wraps. + + """ + super().__init__(attachment_1, attachment_2) + self.geometry = geometry + + @property + def geometry(self): + """Geometry around which the pathway wraps.""" + return self._geometry + + @geometry.setter + def geometry(self, geometry): + if hasattr(self, '_geometry'): + msg = ( + f'Can\'t set attribute `geometry` to {repr(geometry)} as it ' + f'is immutable.' + ) + raise AttributeError(msg) + if not isinstance(geometry, WrappingGeometryBase): + msg = ( + f'Value {repr(geometry)} passed to `geometry` was of type ' + f'{type(geometry)}, must be {WrappingGeometryBase}.' + ) + raise TypeError(msg) + self._geometry = geometry + + @property + def length(self): + """Exact analytical expression for the pathway's length.""" + return self.geometry.geodesic_length(*self.attachments) + + @property + def extension_velocity(self): + """Exact analytical expression for the pathway's extension velocity.""" + return self.length.diff(dynamicsymbols._t) + + def to_loads(self, force): + """Loads required by the equations of motion method classes. + + Explanation + =========== + + ``KanesMethod`` requires a list of ``Point``-``Vector`` tuples to be + passed to the ``loads`` parameters of its ``kanes_equations`` method + when constructing the equations of motion. This method acts as a + utility to produce the correctly-structred pairs of points and vectors + required so that these can be easily concatenated with other items in + the list of loads and passed to ``KanesMethod.kanes_equations``. These + loads are also in the correct form to also be passed to the other + equations of motion method classes, e.g. ``LagrangesMethod``. + + Examples + ======== + + The below example shows how to generate the loads produced in an + actuator that produces an expansile force ``F`` while wrapping around a + cylinder. First, create a cylinder with radius ``r`` and an axis + parallel to the ``N.z`` direction of the global frame ``N`` that also + passes through a point ``pO``. + + >>> from sympy import symbols + >>> from sympy.physics.mechanics import (Point, ReferenceFrame, + ... WrappingCylinder) + >>> N = ReferenceFrame('N') + >>> r = symbols('r', positive=True) + >>> pO = Point('pO') + >>> cylinder = WrappingCylinder(r, pO, N.z) + + Create the pathway of the actuator using the ``WrappingPathway`` class, + defined to span between two points ``pA`` and ``pB``. Both points lie + on the surface of the cylinder and the location of ``pB`` is defined + relative to ``pA`` by the dynamics symbol ``q``. + + >>> from sympy import cos, sin + >>> from sympy.physics.mechanics import WrappingPathway, dynamicsymbols + >>> q = dynamicsymbols('q') + >>> pA = Point('pA') + >>> pB = Point('pB') + >>> pA.set_pos(pO, r*N.x) + >>> pB.set_pos(pO, r*(cos(q)*N.x + sin(q)*N.y)) + >>> pB.pos_from(pA) + (r*cos(q(t)) - r)*N.x + r*sin(q(t))*N.y + >>> pathway = WrappingPathway(pA, pB, cylinder) + + Now create a symbol ``F`` to describe the magnitude of the (expansile) + force that will be produced along the pathway. The list of loads that + ``KanesMethod`` requires can be produced by calling the pathway's + ``to_loads`` method with ``F`` passed as the only argument. + + >>> F = symbols('F') + >>> loads = pathway.to_loads(F) + >>> [load.__class__(load.location, load.vector.simplify()) for load in loads] + [(pA, F*N.y), (pB, F*sin(q(t))*N.x - F*cos(q(t))*N.y), + (pO, - F*sin(q(t))*N.x + F*(cos(q(t)) - 1)*N.y)] + + Parameters + ========== + + force : Expr + Magnitude of the force acting along the length of the pathway. It + is assumed that this ``Expr`` represents an expansile force. + + """ + pA, pB = self.attachments + pO = self.geometry.point + pA_force, pB_force = self.geometry.geodesic_end_vectors(pA, pB) + pO_force = -(pA_force + pB_force) + + loads = [ + Force(pA, force * pA_force), + Force(pB, force * pB_force), + Force(pO, force * pO_force), + ] + return loads + + def __repr__(self): + """Representation of a ``WrappingPathway``.""" + attachments = ', '.join(str(a) for a in self.attachments) + return ( + f'{self.__class__.__name__}({attachments}, ' + f'geometry={self.geometry})' + ) + + +def _point_pair_relative_position(point_1, point_2): + """The relative position between a pair of points.""" + return point_2.pos_from(point_1) + + +def _point_pair_length(point_1, point_2): + """The length of the direct linear path between two points.""" + return _point_pair_relative_position(point_1, point_2).magnitude() + + +def _point_pair_extension_velocity(point_1, point_2): + """The extension velocity of the direct linear path between two points.""" + return _point_pair_length(point_1, point_2).diff(dynamicsymbols._t) diff --git a/lib/python3.10/site-packages/sympy/physics/mechanics/rigidbody.py b/lib/python3.10/site-packages/sympy/physics/mechanics/rigidbody.py new file mode 100644 index 0000000000000000000000000000000000000000..7cc61ff468f7f26d98209a48ca59ffa12a570490 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/physics/mechanics/rigidbody.py @@ -0,0 +1,314 @@ +from sympy import Symbol, S +from sympy.physics.vector import ReferenceFrame, Dyadic, Point, dot +from sympy.physics.mechanics.body_base import BodyBase +from sympy.physics.mechanics.inertia import inertia_of_point_mass, Inertia +from sympy.utilities.exceptions import sympy_deprecation_warning + +__all__ = ['RigidBody'] + + +class RigidBody(BodyBase): + """An idealized rigid body. + + Explanation + =========== + + This is essentially a container which holds the various components which + describe a rigid body: a name, mass, center of mass, reference frame, and + inertia. + + All of these need to be supplied on creation, but can be changed + afterwards. + + Attributes + ========== + + name : string + The body's name. + masscenter : Point + The point which represents the center of mass of the rigid body. + frame : ReferenceFrame + The ReferenceFrame which the rigid body is fixed in. + mass : Sympifyable + The body's mass. + inertia : (Dyadic, Point) + The body's inertia about a point; stored in a tuple as shown above. + potential_energy : Sympifyable + The potential energy of the RigidBody. + + Examples + ======== + + >>> from sympy import Symbol + >>> from sympy.physics.mechanics import ReferenceFrame, Point, RigidBody + >>> from sympy.physics.mechanics import outer + >>> m = Symbol('m') + >>> A = ReferenceFrame('A') + >>> P = Point('P') + >>> I = outer (A.x, A.x) + >>> inertia_tuple = (I, P) + >>> B = RigidBody('B', P, A, m, inertia_tuple) + >>> # Or you could change them afterwards + >>> m2 = Symbol('m2') + >>> B.mass = m2 + + """ + + def __init__(self, name, masscenter=None, frame=None, mass=None, + inertia=None): + super().__init__(name, masscenter, mass) + if frame is None: + frame = ReferenceFrame(f'{name}_frame') + self.frame = frame + if inertia is None: + ixx = Symbol(f'{name}_ixx') + iyy = Symbol(f'{name}_iyy') + izz = Symbol(f'{name}_izz') + izx = Symbol(f'{name}_izx') + ixy = Symbol(f'{name}_ixy') + iyz = Symbol(f'{name}_iyz') + inertia = Inertia.from_inertia_scalars(self.masscenter, self.frame, + ixx, iyy, izz, ixy, iyz, izx) + self.inertia = inertia + + def __repr__(self): + return (f'{self.__class__.__name__}({repr(self.name)}, masscenter=' + f'{repr(self.masscenter)}, frame={repr(self.frame)}, mass=' + f'{repr(self.mass)}, inertia={repr(self.inertia)})') + + @property + def frame(self): + """The ReferenceFrame fixed to the body.""" + return self._frame + + @frame.setter + def frame(self, F): + if not isinstance(F, ReferenceFrame): + raise TypeError("RigidBody frame must be a ReferenceFrame object.") + self._frame = F + + @property + def x(self): + """The basis Vector for the body, in the x direction. """ + return self.frame.x + + @property + def y(self): + """The basis Vector for the body, in the y direction. """ + return self.frame.y + + @property + def z(self): + """The basis Vector for the body, in the z direction. """ + return self.frame.z + + @property + def inertia(self): + """The body's inertia about a point; stored as (Dyadic, Point).""" + return self._inertia + + @inertia.setter + def inertia(self, I): + # check if I is of the form (Dyadic, Point) + if len(I) != 2 or not isinstance(I[0], Dyadic) or not isinstance(I[1], Point): + raise TypeError("RigidBody inertia must be a tuple of the form (Dyadic, Point).") + + self._inertia = Inertia(I[0], I[1]) + # have I S/O, want I S/S* + # I S/O = I S/S* + I S*/O; I S/S* = I S/O - I S*/O + # I_S/S* = I_S/O - I_S*/O + I_Ss_O = inertia_of_point_mass(self.mass, + self.masscenter.pos_from(I[1]), + self.frame) + self._central_inertia = I[0] - I_Ss_O + + @property + def central_inertia(self): + """The body's central inertia dyadic.""" + return self._central_inertia + + @central_inertia.setter + def central_inertia(self, I): + if not isinstance(I, Dyadic): + raise TypeError("RigidBody inertia must be a Dyadic object.") + self.inertia = Inertia(I, self.masscenter) + + def linear_momentum(self, frame): + """ Linear momentum of the rigid body. + + Explanation + =========== + + The linear momentum L, of a rigid body B, with respect to frame N is + given by: + + ``L = m * v`` + + where m is the mass of the rigid body, and v is the velocity of the mass + center of B in the frame N. + + Parameters + ========== + + frame : ReferenceFrame + The frame in which linear momentum is desired. + + Examples + ======== + + >>> from sympy.physics.mechanics import Point, ReferenceFrame, outer + >>> from sympy.physics.mechanics import RigidBody, dynamicsymbols + >>> from sympy.physics.vector import init_vprinting + >>> init_vprinting(pretty_print=False) + >>> m, v = dynamicsymbols('m v') + >>> N = ReferenceFrame('N') + >>> P = Point('P') + >>> P.set_vel(N, v * N.x) + >>> I = outer (N.x, N.x) + >>> Inertia_tuple = (I, P) + >>> B = RigidBody('B', P, N, m, Inertia_tuple) + >>> B.linear_momentum(N) + m*v*N.x + + """ + + return self.mass * self.masscenter.vel(frame) + + def angular_momentum(self, point, frame): + """Returns the angular momentum of the rigid body about a point in the + given frame. + + Explanation + =========== + + The angular momentum H of a rigid body B about some point O in a frame N + is given by: + + ``H = dot(I, w) + cross(r, m * v)`` + + where I and m are the central inertia dyadic and mass of rigid body B, w + is the angular velocity of body B in the frame N, r is the position + vector from point O to the mass center of B, and v is the velocity of + the mass center in the frame N. + + Parameters + ========== + + point : Point + The point about which angular momentum is desired. + frame : ReferenceFrame + The frame in which angular momentum is desired. + + Examples + ======== + + >>> from sympy.physics.mechanics import Point, ReferenceFrame, outer + >>> from sympy.physics.mechanics import RigidBody, dynamicsymbols + >>> from sympy.physics.vector import init_vprinting + >>> init_vprinting(pretty_print=False) + >>> m, v, r, omega = dynamicsymbols('m v r omega') + >>> N = ReferenceFrame('N') + >>> b = ReferenceFrame('b') + >>> b.set_ang_vel(N, omega * b.x) + >>> P = Point('P') + >>> P.set_vel(N, 1 * N.x) + >>> I = outer(b.x, b.x) + >>> B = RigidBody('B', P, b, m, (I, P)) + >>> B.angular_momentum(P, N) + omega*b.x + + """ + I = self.central_inertia + w = self.frame.ang_vel_in(frame) + m = self.mass + r = self.masscenter.pos_from(point) + v = self.masscenter.vel(frame) + + return I.dot(w) + r.cross(m * v) + + def kinetic_energy(self, frame): + """Kinetic energy of the rigid body. + + Explanation + =========== + + The kinetic energy, T, of a rigid body, B, is given by: + + ``T = 1/2 * (dot(dot(I, w), w) + dot(m * v, v))`` + + where I and m are the central inertia dyadic and mass of rigid body B + respectively, w is the body's angular velocity, and v is the velocity of + the body's mass center in the supplied ReferenceFrame. + + Parameters + ========== + + frame : ReferenceFrame + The RigidBody's angular velocity and the velocity of it's mass + center are typically defined with respect to an inertial frame but + any relevant frame in which the velocities are known can be + supplied. + + Examples + ======== + + >>> from sympy.physics.mechanics import Point, ReferenceFrame, outer + >>> from sympy.physics.mechanics import RigidBody + >>> from sympy import symbols + >>> m, v, r, omega = symbols('m v r omega') + >>> N = ReferenceFrame('N') + >>> b = ReferenceFrame('b') + >>> b.set_ang_vel(N, omega * b.x) + >>> P = Point('P') + >>> P.set_vel(N, v * N.x) + >>> I = outer (b.x, b.x) + >>> inertia_tuple = (I, P) + >>> B = RigidBody('B', P, b, m, inertia_tuple) + >>> B.kinetic_energy(N) + m*v**2/2 + omega**2/2 + + """ + + rotational_KE = S.Half * dot( + self.frame.ang_vel_in(frame), + dot(self.central_inertia, self.frame.ang_vel_in(frame))) + translational_KE = S.Half * self.mass * dot(self.masscenter.vel(frame), + self.masscenter.vel(frame)) + return rotational_KE + translational_KE + + def set_potential_energy(self, scalar): + sympy_deprecation_warning( + """ +The sympy.physics.mechanics.RigidBody.set_potential_energy() +method is deprecated. Instead use + + B.potential_energy = scalar + """, + deprecated_since_version="1.5", + active_deprecations_target="deprecated-set-potential-energy", + ) + self.potential_energy = scalar + + def parallel_axis(self, point, frame=None): + """Returns the inertia dyadic of the body with respect to another point. + + Parameters + ========== + + point : sympy.physics.vector.Point + The point to express the inertia dyadic about. + frame : sympy.physics.vector.ReferenceFrame + The reference frame used to construct the dyadic. + + Returns + ======= + + inertia : sympy.physics.vector.Dyadic + The inertia dyadic of the rigid body expressed about the provided + point. + + """ + if frame is None: + frame = self.frame + return self.central_inertia + inertia_of_point_mass( + self.mass, self.masscenter.pos_from(point), frame) diff --git a/lib/python3.10/site-packages/sympy/physics/mechanics/system.py b/lib/python3.10/site-packages/sympy/physics/mechanics/system.py new file mode 100644 index 0000000000000000000000000000000000000000..c8e0657d7da54ca5aaad9b37b816235641968470 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/physics/mechanics/system.py @@ -0,0 +1,1553 @@ +from functools import wraps + +from sympy.core.basic import Basic +from sympy.matrices.immutable import ImmutableMatrix +from sympy.matrices.dense import Matrix, eye, zeros +from sympy.core.containers import OrderedSet +from sympy.physics.mechanics.actuator import ActuatorBase +from sympy.physics.mechanics.body_base import BodyBase +from sympy.physics.mechanics.functions import ( + Lagrangian, _validate_coordinates, find_dynamicsymbols) +from sympy.physics.mechanics.joint import Joint +from sympy.physics.mechanics.kane import KanesMethod +from sympy.physics.mechanics.lagrange import LagrangesMethod +from sympy.physics.mechanics.loads import _parse_load, gravity +from sympy.physics.mechanics.method import _Methods +from sympy.physics.mechanics.particle import Particle +from sympy.physics.vector import Point, ReferenceFrame, dynamicsymbols +from sympy.utilities.iterables import iterable +from sympy.utilities.misc import filldedent + +__all__ = ['SymbolicSystem', 'System'] + + +def _reset_eom_method(method): + """Decorator to reset the eom_method if a property is changed.""" + + @wraps(method) + def wrapper(self, *args, **kwargs): + self._eom_method = None + return method(self, *args, **kwargs) + + return wrapper + + +class System(_Methods): + """Class to define a multibody system and form its equations of motion. + + Explanation + =========== + + A ``System`` instance stores the different objects associated with a model, + including bodies, joints, constraints, and other relevant information. With + all the relationships between components defined, the ``System`` can be used + to form the equations of motion using a backend, such as ``KanesMethod``. + The ``System`` has been designed to be compatible with third-party + libraries for greater flexibility and integration with other tools. + + Attributes + ========== + + frame : ReferenceFrame + Inertial reference frame of the system. + fixed_point : Point + A fixed point in the inertial reference frame. + x : Vector + Unit vector fixed in the inertial reference frame. + y : Vector + Unit vector fixed in the inertial reference frame. + z : Vector + Unit vector fixed in the inertial reference frame. + q : ImmutableMatrix + Matrix of all the generalized coordinates, i.e. the independent + generalized coordinates stacked upon the dependent. + u : ImmutableMatrix + Matrix of all the generalized speeds, i.e. the independent generealized + speeds stacked upon the dependent. + q_ind : ImmutableMatrix + Matrix of the independent generalized coordinates. + q_dep : ImmutableMatrix + Matrix of the dependent generalized coordinates. + u_ind : ImmutableMatrix + Matrix of the independent generalized speeds. + u_dep : ImmutableMatrix + Matrix of the dependent generalized speeds. + u_aux : ImmutableMatrix + Matrix of auxiliary generalized speeds. + kdes : ImmutableMatrix + Matrix of the kinematical differential equations as expressions equated + to the zero matrix. + bodies : tuple of BodyBase subclasses + Tuple of all bodies that make up the system. + joints : tuple of Joint + Tuple of all joints that connect bodies in the system. + loads : tuple of LoadBase subclasses + Tuple of all loads that have been applied to the system. + actuators : tuple of ActuatorBase subclasses + Tuple of all actuators present in the system. + holonomic_constraints : ImmutableMatrix + Matrix with the holonomic constraints as expressions equated to the zero + matrix. + nonholonomic_constraints : ImmutableMatrix + Matrix with the nonholonomic constraints as expressions equated to the + zero matrix. + velocity_constraints : ImmutableMatrix + Matrix with the velocity constraints as expressions equated to the zero + matrix. These are by default derived as the time derivatives of the + holonomic constraints extended with the nonholonomic constraints. + eom_method : subclass of KanesMethod or LagrangesMethod + Backend for forming the equations of motion. + + Examples + ======== + + In the example below a cart with a pendulum is created. The cart moves along + the x axis of the rail and the pendulum rotates about the z axis. The length + of the pendulum is ``l`` with the pendulum represented as a particle. To + move the cart a time dependent force ``F`` is applied to the cart. + + We first need to import some functions and create some of our variables. + + >>> from sympy import symbols, simplify + >>> from sympy.physics.mechanics import ( + ... mechanics_printing, dynamicsymbols, RigidBody, Particle, + ... ReferenceFrame, PrismaticJoint, PinJoint, System) + >>> mechanics_printing(pretty_print=False) + >>> g, l = symbols('g l') + >>> F = dynamicsymbols('F') + + The next step is to create bodies. It is also useful to create a frame for + locating the particle with respect to the pin joint later on, as a particle + does not have a body-fixed frame. + + >>> rail = RigidBody('rail') + >>> cart = RigidBody('cart') + >>> bob = Particle('bob') + >>> bob_frame = ReferenceFrame('bob_frame') + + Initialize the system, with the rail as the Newtonian reference. The body is + also automatically added to the system. + + >>> system = System.from_newtonian(rail) + >>> print(system.bodies[0]) + rail + + Create the joints, while immediately also adding them to the system. + + >>> system.add_joints( + ... PrismaticJoint('slider', rail, cart, joint_axis=rail.x), + ... PinJoint('pin', cart, bob, joint_axis=cart.z, + ... child_interframe=bob_frame, + ... child_point=l * bob_frame.y) + ... ) + >>> system.joints + (PrismaticJoint: slider parent: rail child: cart, + PinJoint: pin parent: cart child: bob) + + While adding the joints, the associated generalized coordinates, generalized + speeds, kinematic differential equations and bodies are also added to the + system. + + >>> system.q + Matrix([ + [q_slider], + [ q_pin]]) + >>> system.u + Matrix([ + [u_slider], + [ u_pin]]) + >>> system.kdes + Matrix([ + [u_slider - q_slider'], + [ u_pin - q_pin']]) + >>> [body.name for body in system.bodies] + ['rail', 'cart', 'bob'] + + With the kinematics established, we can now apply gravity and the cart force + ``F``. + + >>> system.apply_uniform_gravity(-g * system.y) + >>> system.add_loads((cart.masscenter, F * rail.x)) + >>> system.loads + ((rail_masscenter, - g*rail_mass*rail_frame.y), + (cart_masscenter, - cart_mass*g*rail_frame.y), + (bob_masscenter, - bob_mass*g*rail_frame.y), + (cart_masscenter, F*rail_frame.x)) + + With the entire system defined, we can now form the equations of motion. + Before forming the equations of motion, one can also run some checks that + will try to identify some common errors. + + >>> system.validate_system() + >>> system.form_eoms() + Matrix([ + [bob_mass*l*u_pin**2*sin(q_pin) - bob_mass*l*cos(q_pin)*u_pin' + - (bob_mass + cart_mass)*u_slider' + F], + [ -bob_mass*g*l*sin(q_pin) - bob_mass*l**2*u_pin' + - bob_mass*l*cos(q_pin)*u_slider']]) + >>> simplify(system.mass_matrix) + Matrix([ + [ bob_mass + cart_mass, bob_mass*l*cos(q_pin)], + [bob_mass*l*cos(q_pin), bob_mass*l**2]]) + >>> system.forcing + Matrix([ + [bob_mass*l*u_pin**2*sin(q_pin) + F], + [ -bob_mass*g*l*sin(q_pin)]]) + + The complexity of the above example can be increased if we add a constraint + to prevent the particle from moving in the horizontal (x) direction. This + can be done by adding a holonomic constraint. After which we should also + redefine what our (in)dependent generalized coordinates and speeds are. + + >>> system.add_holonomic_constraints( + ... bob.masscenter.pos_from(rail.masscenter).dot(system.x) + ... ) + >>> system.q_ind = system.get_joint('pin').coordinates + >>> system.q_dep = system.get_joint('slider').coordinates + >>> system.u_ind = system.get_joint('pin').speeds + >>> system.u_dep = system.get_joint('slider').speeds + + With the updated system the equations of motion can be formed again. + + >>> system.validate_system() + >>> system.form_eoms() + Matrix([[-bob_mass*g*l*sin(q_pin) + - bob_mass*l**2*u_pin' + - bob_mass*l*cos(q_pin)*u_slider' + - l*(bob_mass*l*u_pin**2*sin(q_pin) + - bob_mass*l*cos(q_pin)*u_pin' + - (bob_mass + cart_mass)*u_slider')*cos(q_pin) + - l*F*cos(q_pin)]]) + >>> simplify(system.mass_matrix) + Matrix([ + [bob_mass*l**2*sin(q_pin)**2, -cart_mass*l*cos(q_pin)], + [ l*cos(q_pin), 1]]) + >>> simplify(system.forcing) + Matrix([ + [-l*(bob_mass*g*sin(q_pin) + bob_mass*l*u_pin**2*sin(2*q_pin)/2 + + F*cos(q_pin))], + [ + l*u_pin**2*sin(q_pin)]]) + + """ + + def __init__(self, frame=None, fixed_point=None): + """Initialize the system. + + Parameters + ========== + + frame : ReferenceFrame, optional + The inertial frame of the system. If none is supplied, a new frame + will be created. + fixed_point : Point, optional + A fixed point in the inertial reference frame. If none is supplied, + a new fixed_point will be created. + + """ + if frame is None: + frame = ReferenceFrame('inertial_frame') + elif not isinstance(frame, ReferenceFrame): + raise TypeError('Frame must be an instance of ReferenceFrame.') + self._frame = frame + if fixed_point is None: + fixed_point = Point('inertial_point') + elif not isinstance(fixed_point, Point): + raise TypeError('Fixed point must be an instance of Point.') + self._fixed_point = fixed_point + self._fixed_point.set_vel(self._frame, 0) + self._q_ind = ImmutableMatrix(1, 0, []).T + self._q_dep = ImmutableMatrix(1, 0, []).T + self._u_ind = ImmutableMatrix(1, 0, []).T + self._u_dep = ImmutableMatrix(1, 0, []).T + self._u_aux = ImmutableMatrix(1, 0, []).T + self._kdes = ImmutableMatrix(1, 0, []).T + self._hol_coneqs = ImmutableMatrix(1, 0, []).T + self._nonhol_coneqs = ImmutableMatrix(1, 0, []).T + self._vel_constrs = None + self._bodies = [] + self._joints = [] + self._loads = [] + self._actuators = [] + self._eom_method = None + + @classmethod + def from_newtonian(cls, newtonian): + """Constructs the system with respect to a Newtonian body.""" + if isinstance(newtonian, Particle): + raise TypeError('A Particle has no frame so cannot act as ' + 'the Newtonian.') + system = cls(frame=newtonian.frame, fixed_point=newtonian.masscenter) + system.add_bodies(newtonian) + return system + + @property + def fixed_point(self): + """Fixed point in the inertial reference frame.""" + return self._fixed_point + + @property + def frame(self): + """Inertial reference frame of the system.""" + return self._frame + + @property + def x(self): + """Unit vector fixed in the inertial reference frame.""" + return self._frame.x + + @property + def y(self): + """Unit vector fixed in the inertial reference frame.""" + return self._frame.y + + @property + def z(self): + """Unit vector fixed in the inertial reference frame.""" + return self._frame.z + + @property + def bodies(self): + """Tuple of all bodies that have been added to the system.""" + return tuple(self._bodies) + + @bodies.setter + @_reset_eom_method + def bodies(self, bodies): + bodies = self._objects_to_list(bodies) + self._check_objects(bodies, [], BodyBase, 'Bodies', 'bodies') + self._bodies = bodies + + @property + def joints(self): + """Tuple of all joints that have been added to the system.""" + return tuple(self._joints) + + @joints.setter + @_reset_eom_method + def joints(self, joints): + joints = self._objects_to_list(joints) + self._check_objects(joints, [], Joint, 'Joints', 'joints') + self._joints = [] + self.add_joints(*joints) + + @property + def loads(self): + """Tuple of loads that have been applied on the system.""" + return tuple(self._loads) + + @loads.setter + @_reset_eom_method + def loads(self, loads): + loads = self._objects_to_list(loads) + self._loads = [_parse_load(load) for load in loads] + + @property + def actuators(self): + """Tuple of actuators present in the system.""" + return tuple(self._actuators) + + @actuators.setter + @_reset_eom_method + def actuators(self, actuators): + actuators = self._objects_to_list(actuators) + self._check_objects(actuators, [], ActuatorBase, 'Actuators', + 'actuators') + self._actuators = actuators + + @property + def q(self): + """Matrix of all the generalized coordinates with the independent + stacked upon the dependent.""" + return self._q_ind.col_join(self._q_dep) + + @property + def u(self): + """Matrix of all the generalized speeds with the independent stacked + upon the dependent.""" + return self._u_ind.col_join(self._u_dep) + + @property + def q_ind(self): + """Matrix of the independent generalized coordinates.""" + return self._q_ind + + @q_ind.setter + @_reset_eom_method + def q_ind(self, q_ind): + self._q_ind, self._q_dep = self._parse_coordinates( + self._objects_to_list(q_ind), True, [], self.q_dep, 'coordinates') + + @property + def q_dep(self): + """Matrix of the dependent generalized coordinates.""" + return self._q_dep + + @q_dep.setter + @_reset_eom_method + def q_dep(self, q_dep): + self._q_ind, self._q_dep = self._parse_coordinates( + self._objects_to_list(q_dep), False, self.q_ind, [], 'coordinates') + + @property + def u_ind(self): + """Matrix of the independent generalized speeds.""" + return self._u_ind + + @u_ind.setter + @_reset_eom_method + def u_ind(self, u_ind): + self._u_ind, self._u_dep = self._parse_coordinates( + self._objects_to_list(u_ind), True, [], self.u_dep, 'speeds') + + @property + def u_dep(self): + """Matrix of the dependent generalized speeds.""" + return self._u_dep + + @u_dep.setter + @_reset_eom_method + def u_dep(self, u_dep): + self._u_ind, self._u_dep = self._parse_coordinates( + self._objects_to_list(u_dep), False, self.u_ind, [], 'speeds') + + @property + def u_aux(self): + """Matrix of auxiliary generalized speeds.""" + return self._u_aux + + @u_aux.setter + @_reset_eom_method + def u_aux(self, u_aux): + self._u_aux = self._parse_coordinates( + self._objects_to_list(u_aux), True, [], [], 'u_auxiliary')[0] + + @property + def kdes(self): + """Kinematical differential equations as expressions equated to the zero + matrix. These equations describe the coupling between the generalized + coordinates and the generalized speeds.""" + return self._kdes + + @kdes.setter + @_reset_eom_method + def kdes(self, kdes): + kdes = self._objects_to_list(kdes) + self._kdes = self._parse_expressions( + kdes, [], 'kinematic differential equations') + + @property + def holonomic_constraints(self): + """Matrix with the holonomic constraints as expressions equated to the + zero matrix.""" + return self._hol_coneqs + + @holonomic_constraints.setter + @_reset_eom_method + def holonomic_constraints(self, constraints): + constraints = self._objects_to_list(constraints) + self._hol_coneqs = self._parse_expressions( + constraints, [], 'holonomic constraints') + + @property + def nonholonomic_constraints(self): + """Matrix with the nonholonomic constraints as expressions equated to + the zero matrix.""" + return self._nonhol_coneqs + + @nonholonomic_constraints.setter + @_reset_eom_method + def nonholonomic_constraints(self, constraints): + constraints = self._objects_to_list(constraints) + self._nonhol_coneqs = self._parse_expressions( + constraints, [], 'nonholonomic constraints') + + @property + def velocity_constraints(self): + """Matrix with the velocity constraints as expressions equated to the + zero matrix. The velocity constraints are by default derived from the + holonomic and nonholonomic constraints unless they are explicitly set. + """ + if self._vel_constrs is None: + return self.holonomic_constraints.diff(dynamicsymbols._t).col_join( + self.nonholonomic_constraints) + return self._vel_constrs + + @velocity_constraints.setter + @_reset_eom_method + def velocity_constraints(self, constraints): + if constraints is None: + self._vel_constrs = None + return + constraints = self._objects_to_list(constraints) + self._vel_constrs = self._parse_expressions( + constraints, [], 'velocity constraints') + + @property + def eom_method(self): + """Backend for forming the equations of motion.""" + return self._eom_method + + @staticmethod + def _objects_to_list(lst): + """Helper to convert passed objects to a list.""" + if not iterable(lst): # Only one object + return [lst] + return list(lst[:]) # converts Matrix and tuple to flattened list + + @staticmethod + def _check_objects(objects, obj_lst, expected_type, obj_name, type_name): + """Helper to check the objects that are being added to the system. + + Explanation + =========== + This method checks that the objects that are being added to the system + are of the correct type and have not already been added. If any of the + objects are not of the correct type or have already been added, then + an error is raised. + + Parameters + ========== + objects : iterable + The objects that would be added to the system. + obj_lst : list + The list of objects that are already in the system. + expected_type : type + The type that the objects should be. + obj_name : str + The name of the category of objects. This string is used to + formulate the error message for the user. + type_name : str + The name of the type that the objects should be. This string is used + to formulate the error message for the user. + + """ + seen = set(obj_lst) + duplicates = set() + wrong_types = set() + for obj in objects: + if not isinstance(obj, expected_type): + wrong_types.add(obj) + if obj in seen: + duplicates.add(obj) + else: + seen.add(obj) + if wrong_types: + raise TypeError(f'{obj_name} {wrong_types} are not {type_name}.') + if duplicates: + raise ValueError(f'{obj_name} {duplicates} have already been added ' + f'to the system.') + + def _parse_coordinates(self, new_coords, independent, old_coords_ind, + old_coords_dep, coord_type='coordinates'): + """Helper to parse coordinates and speeds.""" + # Construct lists of the independent and dependent coordinates + coords_ind, coords_dep = old_coords_ind[:], old_coords_dep[:] + if not iterable(independent): + independent = [independent] * len(new_coords) + for coord, indep in zip(new_coords, independent): + if indep: + coords_ind.append(coord) + else: + coords_dep.append(coord) + # Check types and duplicates + current = {'coordinates': self.q_ind[:] + self.q_dep[:], + 'speeds': self.u_ind[:] + self.u_dep[:], + 'u_auxiliary': self._u_aux[:], + coord_type: coords_ind + coords_dep} + _validate_coordinates(**current) + return (ImmutableMatrix(1, len(coords_ind), coords_ind).T, + ImmutableMatrix(1, len(coords_dep), coords_dep).T) + + @staticmethod + def _parse_expressions(new_expressions, old_expressions, name, + check_negatives=False): + """Helper to parse expressions like constraints.""" + old_expressions = old_expressions[:] + new_expressions = list(new_expressions) # Converts a possible tuple + if check_negatives: + check_exprs = old_expressions + [-expr for expr in old_expressions] + else: + check_exprs = old_expressions + System._check_objects(new_expressions, check_exprs, Basic, name, + 'expressions') + for expr in new_expressions: + if expr == 0: + raise ValueError(f'Parsed {name} are zero.') + return ImmutableMatrix(1, len(old_expressions) + len(new_expressions), + old_expressions + new_expressions).T + + @_reset_eom_method + def add_coordinates(self, *coordinates, independent=True): + """Add generalized coordinate(s) to the system. + + Parameters + ========== + + *coordinates : dynamicsymbols + One or more generalized coordinates to be added to the system. + independent : bool or list of bool, optional + Boolean whether a coordinate is dependent or independent. The + default is True, so the coordinates are added as independent by + default. + + """ + self._q_ind, self._q_dep = self._parse_coordinates( + coordinates, independent, self.q_ind, self.q_dep, 'coordinates') + + @_reset_eom_method + def add_speeds(self, *speeds, independent=True): + """Add generalized speed(s) to the system. + + Parameters + ========== + + *speeds : dynamicsymbols + One or more generalized speeds to be added to the system. + independent : bool or list of bool, optional + Boolean whether a speed is dependent or independent. The default is + True, so the speeds are added as independent by default. + + """ + self._u_ind, self._u_dep = self._parse_coordinates( + speeds, independent, self.u_ind, self.u_dep, 'speeds') + + @_reset_eom_method + def add_auxiliary_speeds(self, *speeds): + """Add auxiliary speed(s) to the system. + + Parameters + ========== + + *speeds : dynamicsymbols + One or more auxiliary speeds to be added to the system. + + """ + self._u_aux = self._parse_coordinates( + speeds, True, self._u_aux, [], 'u_auxiliary')[0] + + @_reset_eom_method + def add_kdes(self, *kdes): + """Add kinematic differential equation(s) to the system. + + Parameters + ========== + + *kdes : Expr + One or more kinematic differential equations. + + """ + self._kdes = self._parse_expressions( + kdes, self.kdes, 'kinematic differential equations', + check_negatives=True) + + @_reset_eom_method + def add_holonomic_constraints(self, *constraints): + """Add holonomic constraint(s) to the system. + + Parameters + ========== + + *constraints : Expr + One or more holonomic constraints, which are expressions that should + be zero. + + """ + self._hol_coneqs = self._parse_expressions( + constraints, self._hol_coneqs, 'holonomic constraints', + check_negatives=True) + + @_reset_eom_method + def add_nonholonomic_constraints(self, *constraints): + """Add nonholonomic constraint(s) to the system. + + Parameters + ========== + + *constraints : Expr + One or more nonholonomic constraints, which are expressions that + should be zero. + + """ + self._nonhol_coneqs = self._parse_expressions( + constraints, self._nonhol_coneqs, 'nonholonomic constraints', + check_negatives=True) + + @_reset_eom_method + def add_bodies(self, *bodies): + """Add body(ies) to the system. + + Parameters + ========== + + bodies : Particle or RigidBody + One or more bodies. + + """ + self._check_objects(bodies, self.bodies, BodyBase, 'Bodies', 'bodies') + self._bodies.extend(bodies) + + @_reset_eom_method + def add_loads(self, *loads): + """Add load(s) to the system. + + Parameters + ========== + + *loads : Force or Torque + One or more loads. + + """ + loads = [_parse_load(load) for load in loads] # Checks the loads + self._loads.extend(loads) + + @_reset_eom_method + def apply_uniform_gravity(self, acceleration): + """Apply uniform gravity to all bodies in the system by adding loads. + + Parameters + ========== + + acceleration : Vector + The acceleration due to gravity. + + """ + self.add_loads(*gravity(acceleration, *self.bodies)) + + @_reset_eom_method + def add_actuators(self, *actuators): + """Add actuator(s) to the system. + + Parameters + ========== + + *actuators : subclass of ActuatorBase + One or more actuators. + + """ + self._check_objects(actuators, self.actuators, ActuatorBase, + 'Actuators', 'actuators') + self._actuators.extend(actuators) + + @_reset_eom_method + def add_joints(self, *joints): + """Add joint(s) to the system. + + Explanation + =========== + + This methods adds one or more joints to the system including its + associated objects, i.e. generalized coordinates, generalized speeds, + kinematic differential equations and the bodies. + + Parameters + ========== + + *joints : subclass of Joint + One or more joints. + + Notes + ===== + + For the generalized coordinates, generalized speeds and bodies it is + checked whether they are already known by the system instance. If they + are, then they are not added. The kinematic differential equations are + however always added to the system, so you should not also manually add + those on beforehand. + + """ + self._check_objects(joints, self.joints, Joint, 'Joints', 'joints') + self._joints.extend(joints) + coordinates, speeds, kdes, bodies = (OrderedSet() for _ in range(4)) + for joint in joints: + coordinates.update(joint.coordinates) + speeds.update(joint.speeds) + kdes.update(joint.kdes) + bodies.update((joint.parent, joint.child)) + coordinates = coordinates.difference(self.q) + speeds = speeds.difference(self.u) + kdes = kdes.difference(self.kdes[:] + (-self.kdes)[:]) + bodies = bodies.difference(self.bodies) + self.add_coordinates(*tuple(coordinates)) + self.add_speeds(*tuple(speeds)) + self.add_kdes(*(kde for kde in tuple(kdes) if not kde == 0)) + self.add_bodies(*tuple(bodies)) + + def get_body(self, name): + """Retrieve a body from the system by name. + + Parameters + ========== + + name : str + The name of the body to retrieve. + + Returns + ======= + + RigidBody or Particle + The body with the given name, or None if no such body exists. + + """ + for body in self._bodies: + if body.name == name: + return body + + def get_joint(self, name): + """Retrieve a joint from the system by name. + + Parameters + ========== + + name : str + The name of the joint to retrieve. + + Returns + ======= + + subclass of Joint + The joint with the given name, or None if no such joint exists. + + """ + for joint in self._joints: + if joint.name == name: + return joint + + def _form_eoms(self): + return self.form_eoms() + + def form_eoms(self, eom_method=KanesMethod, **kwargs): + """Form the equations of motion of the system. + + Parameters + ========== + + eom_method : subclass of KanesMethod or LagrangesMethod + Backend class to be used for forming the equations of motion. The + default is ``KanesMethod``. + + Returns + ======== + + ImmutableMatrix + Vector of equations of motions. + + Examples + ======== + + This is a simple example for a one degree of freedom translational + spring-mass-damper. + + >>> from sympy import S, symbols + >>> from sympy.physics.mechanics import ( + ... LagrangesMethod, dynamicsymbols, PrismaticJoint, Particle, + ... RigidBody, System) + >>> q = dynamicsymbols('q') + >>> qd = dynamicsymbols('q', 1) + >>> m, k, b = symbols('m k b') + >>> wall = RigidBody('W') + >>> system = System.from_newtonian(wall) + >>> bob = Particle('P', mass=m) + >>> bob.potential_energy = S.Half * k * q**2 + >>> system.add_joints(PrismaticJoint('J', wall, bob, q, qd)) + >>> system.add_loads((bob.masscenter, b * qd * system.x)) + >>> system.form_eoms(LagrangesMethod) + Matrix([[-b*Derivative(q(t), t) + k*q(t) + m*Derivative(q(t), (t, 2))]]) + + We can also solve for the states using the 'rhs' method. + + >>> system.rhs() + Matrix([ + [ Derivative(q(t), t)], + [(b*Derivative(q(t), t) - k*q(t))/m]]) + + """ + # KanesMethod does not accept empty iterables + loads = self.loads + tuple( + load for act in self.actuators for load in act.to_loads()) + loads = loads if loads else None + if issubclass(eom_method, KanesMethod): + disallowed_kwargs = { + "frame", "q_ind", "u_ind", "kd_eqs", "q_dependent", + "u_dependent", "u_auxiliary", "configuration_constraints", + "velocity_constraints", "forcelist", "bodies"} + wrong_kwargs = disallowed_kwargs.intersection(kwargs) + if wrong_kwargs: + raise ValueError( + f"The following keyword arguments are not allowed to be " + f"overwritten in {eom_method.__name__}: {wrong_kwargs}.") + kwargs = {"frame": self.frame, "q_ind": self.q_ind, + "u_ind": self.u_ind, "kd_eqs": self.kdes, + "q_dependent": self.q_dep, "u_dependent": self.u_dep, + "configuration_constraints": self.holonomic_constraints, + "velocity_constraints": self.velocity_constraints, + "u_auxiliary": self.u_aux, + "forcelist": loads, "bodies": self.bodies, + "explicit_kinematics": False, **kwargs} + self._eom_method = eom_method(**kwargs) + elif issubclass(eom_method, LagrangesMethod): + disallowed_kwargs = { + "frame", "qs", "forcelist", "bodies", "hol_coneqs", + "nonhol_coneqs", "Lagrangian"} + wrong_kwargs = disallowed_kwargs.intersection(kwargs) + if wrong_kwargs: + raise ValueError( + f"The following keyword arguments are not allowed to be " + f"overwritten in {eom_method.__name__}: {wrong_kwargs}.") + kwargs = {"frame": self.frame, "qs": self.q, "forcelist": loads, + "bodies": self.bodies, + "hol_coneqs": self.holonomic_constraints, + "nonhol_coneqs": self.nonholonomic_constraints, **kwargs} + if "Lagrangian" not in kwargs: + kwargs["Lagrangian"] = Lagrangian(kwargs["frame"], + *kwargs["bodies"]) + self._eom_method = eom_method(**kwargs) + else: + raise NotImplementedError(f'{eom_method} has not been implemented.') + return self.eom_method._form_eoms() + + def rhs(self, inv_method=None): + """Compute the equations of motion in the explicit form. + + Parameters + ========== + + inv_method : str + The specific sympy inverse matrix calculation method to use. For a + list of valid methods, see + :meth:`~sympy.matrices.matrixbase.MatrixBase.inv` + + Returns + ======== + + ImmutableMatrix + Equations of motion in the explicit form. + + See Also + ======== + + sympy.physics.mechanics.kane.KanesMethod.rhs: + KanesMethod's ``rhs`` function. + sympy.physics.mechanics.lagrange.LagrangesMethod.rhs: + LagrangesMethod's ``rhs`` function. + + """ + return self.eom_method.rhs(inv_method=inv_method) + + @property + def mass_matrix(self): + r"""The mass matrix of the system. + + Explanation + =========== + + The mass matrix $M_d$ and the forcing vector $f_d$ of a system describe + the system's dynamics according to the following equations: + + .. math:: + M_d \dot{u} = f_d + + where $\dot{u}$ is the time derivative of the generalized speeds. + + """ + return self.eom_method.mass_matrix + + @property + def mass_matrix_full(self): + r"""The mass matrix of the system, augmented by the kinematic + differential equations in explicit or implicit form. + + Explanation + =========== + + The full mass matrix $M_m$ and the full forcing vector $f_m$ of a system + describe the dynamics and kinematics according to the following + equation: + + .. math:: + M_m \dot{x} = f_m + + where $x$ is the state vector stacking $q$ and $u$. + + """ + return self.eom_method.mass_matrix_full + + @property + def forcing(self): + """The forcing vector of the system.""" + return self.eom_method.forcing + + @property + def forcing_full(self): + """The forcing vector of the system, augmented by the kinematic + differential equations in explicit or implicit form.""" + return self.eom_method.forcing_full + + def validate_system(self, eom_method=KanesMethod, check_duplicates=False): + """Validates the system using some basic checks. + + Explanation + =========== + + This method validates the system based on the following checks: + + - The number of dependent generalized coordinates should equal the + number of holonomic constraints. + - All generalized coordinates defined by the joints should also be known + to the system. + - If ``KanesMethod`` is used as a ``eom_method``: + - All generalized speeds and kinematic differential equations + defined by the joints should also be known to the system. + - The number of dependent generalized speeds should equal the number + of velocity constraints. + - The number of generalized coordinates should be less than or equal + to the number of generalized speeds. + - The number of generalized coordinates should equal the number of + kinematic differential equations. + - If ``LagrangesMethod`` is used as ``eom_method``: + - There should not be any generalized speeds that are not + derivatives of the generalized coordinates (this includes the + generalized speeds defined by the joints). + + Parameters + ========== + + eom_method : subclass of KanesMethod or LagrangesMethod + Backend class that will be used for forming the equations of motion. + There are different checks for the different backends. The default + is ``KanesMethod``. + check_duplicates : bool + Boolean whether the system should be checked for duplicate + definitions. The default is False, because duplicates are already + checked when adding objects to the system. + + Notes + ===== + + This method is not guaranteed to be backwards compatible as it may + improve over time. The method can become both more and less strict in + certain areas. However a well-defined system should always pass all + these tests. + + """ + msgs = [] + # Save some data in variables + n_hc = self.holonomic_constraints.shape[0] + n_vc = self.velocity_constraints.shape[0] + n_q_dep, n_u_dep = self.q_dep.shape[0], self.u_dep.shape[0] + q_set, u_set = set(self.q), set(self.u) + n_q, n_u = len(q_set), len(u_set) + # Check number of holonomic constraints + if n_q_dep != n_hc: + msgs.append(filldedent(f""" + The number of dependent generalized coordinates {n_q_dep} should be + equal to the number of holonomic constraints {n_hc}.""")) + # Check if all joint coordinates and speeds are present + missing_q = set() + for joint in self.joints: + missing_q.update(set(joint.coordinates).difference(q_set)) + if missing_q: + msgs.append(filldedent(f""" + The generalized coordinates {missing_q} used in joints are not added + to the system.""")) + # Method dependent checks + if issubclass(eom_method, KanesMethod): + n_kdes = len(self.kdes) + missing_kdes, missing_u = set(), set() + for joint in self.joints: + missing_u.update(set(joint.speeds).difference(u_set)) + missing_kdes.update(set(joint.kdes).difference( + self.kdes[:] + (-self.kdes)[:])) + if missing_u: + msgs.append(filldedent(f""" + The generalized speeds {missing_u} used in joints are not added + to the system.""")) + if missing_kdes: + msgs.append(filldedent(f""" + The kinematic differential equations {missing_kdes} used in + joints are not added to the system.""")) + if n_u_dep != n_vc: + msgs.append(filldedent(f""" + The number of dependent generalized speeds {n_u_dep} should be + equal to the number of velocity constraints {n_vc}.""")) + if n_q > n_u: + msgs.append(filldedent(f""" + The number of generalized coordinates {n_q} should be less than + or equal to the number of generalized speeds {n_u}.""")) + if n_u != n_kdes: + msgs.append(filldedent(f""" + The number of generalized speeds {n_u} should be equal to the + number of kinematic differential equations {n_kdes}.""")) + elif issubclass(eom_method, LagrangesMethod): + not_qdots = set(self.u).difference(self.q.diff(dynamicsymbols._t)) + for joint in self.joints: + not_qdots.update(set( + joint.speeds).difference(self.q.diff(dynamicsymbols._t))) + if not_qdots: + msgs.append(filldedent(f""" + The generalized speeds {not_qdots} are not supported by this + method. Only derivatives of the generalized coordinates are + supported. If these symbols are used in your expressions, then + this will result in wrong equations of motion.""")) + if self.u_aux: + msgs.append(filldedent(f""" + This method does not support auxiliary speeds. If these symbols + are used in your expressions, then this will result in wrong + equations of motion. The auxiliary speeds are {self.u_aux}.""")) + else: + raise NotImplementedError(f'{eom_method} has not been implemented.') + if check_duplicates: # Should be redundant + duplicates_to_check = [('generalized coordinates', self.q), + ('generalized speeds', self.u), + ('auxiliary speeds', self.u_aux), + ('bodies', self.bodies), + ('joints', self.joints)] + for name, lst in duplicates_to_check: + seen = set() + duplicates = {x for x in lst if x in seen or seen.add(x)} + if duplicates: + msgs.append(filldedent(f""" + The {name} {duplicates} exist multiple times within the + system.""")) + if msgs: + raise ValueError('\n'.join(msgs)) + + +class SymbolicSystem: + """SymbolicSystem is a class that contains all the information about a + system in a symbolic format such as the equations of motions and the bodies + and loads in the system. + + There are three ways that the equations of motion can be described for + Symbolic System: + + + [1] Explicit form where the kinematics and dynamics are combined + x' = F_1(x, t, r, p) + + [2] Implicit form where the kinematics and dynamics are combined + M_2(x, p) x' = F_2(x, t, r, p) + + [3] Implicit form where the kinematics and dynamics are separate + M_3(q, p) u' = F_3(q, u, t, r, p) + q' = G(q, u, t, r, p) + + where + + x : states, e.g. [q, u] + t : time + r : specified (exogenous) inputs + p : constants + q : generalized coordinates + u : generalized speeds + F_1 : right hand side of the combined equations in explicit form + F_2 : right hand side of the combined equations in implicit form + F_3 : right hand side of the dynamical equations in implicit form + M_2 : mass matrix of the combined equations in implicit form + M_3 : mass matrix of the dynamical equations in implicit form + G : right hand side of the kinematical differential equations + + Parameters + ========== + + coord_states : ordered iterable of functions of time + This input will either be a collection of the coordinates or states + of the system depending on whether or not the speeds are also + given. If speeds are specified this input will be assumed to + be the coordinates otherwise this input will be assumed to + be the states. + + right_hand_side : Matrix + This variable is the right hand side of the equations of motion in + any of the forms. The specific form will be assumed depending on + whether a mass matrix or coordinate derivatives are given. + + speeds : ordered iterable of functions of time, optional + This is a collection of the generalized speeds of the system. If + given it will be assumed that the first argument (coord_states) + will represent the generalized coordinates of the system. + + mass_matrix : Matrix, optional + The matrix of the implicit forms of the equations of motion (forms + [2] and [3]). The distinction between the forms is determined by + whether or not the coordinate derivatives are passed in. If + they are given form [3] will be assumed otherwise form [2] is + assumed. + + coordinate_derivatives : Matrix, optional + The right hand side of the kinematical equations in explicit form. + If given it will be assumed that the equations of motion are being + entered in form [3]. + + alg_con : Iterable, optional + The indexes of the rows in the equations of motion that contain + algebraic constraints instead of differential equations. If the + equations are input in form [3], it will be assumed the indexes are + referencing the mass_matrix/right_hand_side combination and not the + coordinate_derivatives. + + output_eqns : Dictionary, optional + Any output equations that are desired to be tracked are stored in a + dictionary where the key corresponds to the name given for the + specific equation and the value is the equation itself in symbolic + form + + coord_idxs : Iterable, optional + If coord_states corresponds to the states rather than the + coordinates this variable will tell SymbolicSystem which indexes of + the states correspond to generalized coordinates. + + speed_idxs : Iterable, optional + If coord_states corresponds to the states rather than the + coordinates this variable will tell SymbolicSystem which indexes of + the states correspond to generalized speeds. + + bodies : iterable of Body/Rigidbody objects, optional + Iterable containing the bodies of the system + + loads : iterable of load instances (described below), optional + Iterable containing the loads of the system where forces are given + by (point of application, force vector) and torques are given by + (reference frame acting upon, torque vector). Ex [(point, force), + (ref_frame, torque)] + + Attributes + ========== + + coordinates : Matrix, shape(n, 1) + This is a matrix containing the generalized coordinates of the system + + speeds : Matrix, shape(m, 1) + This is a matrix containing the generalized speeds of the system + + states : Matrix, shape(o, 1) + This is a matrix containing the state variables of the system + + alg_con : List + This list contains the indices of the algebraic constraints in the + combined equations of motion. The presence of these constraints + requires that a DAE solver be used instead of an ODE solver. + If the system is given in form [3] the alg_con variable will be + adjusted such that it is a representation of the combined kinematics + and dynamics thus make sure it always matches the mass matrix + entered. + + dyn_implicit_mat : Matrix, shape(m, m) + This is the M matrix in form [3] of the equations of motion (the mass + matrix or generalized inertia matrix of the dynamical equations of + motion in implicit form). + + dyn_implicit_rhs : Matrix, shape(m, 1) + This is the F vector in form [3] of the equations of motion (the right + hand side of the dynamical equations of motion in implicit form). + + comb_implicit_mat : Matrix, shape(o, o) + This is the M matrix in form [2] of the equations of motion. + This matrix contains a block diagonal structure where the top + left block (the first rows) represent the matrix in the + implicit form of the kinematical equations and the bottom right + block (the last rows) represent the matrix in the implicit form + of the dynamical equations. + + comb_implicit_rhs : Matrix, shape(o, 1) + This is the F vector in form [2] of the equations of motion. The top + part of the vector represents the right hand side of the implicit form + of the kinemaical equations and the bottom of the vector represents the + right hand side of the implicit form of the dynamical equations of + motion. + + comb_explicit_rhs : Matrix, shape(o, 1) + This vector represents the right hand side of the combined equations of + motion in explicit form (form [1] from above). + + kin_explicit_rhs : Matrix, shape(m, 1) + This is the right hand side of the explicit form of the kinematical + equations of motion as can be seen in form [3] (the G matrix). + + output_eqns : Dictionary + If output equations were given they are stored in a dictionary where + the key corresponds to the name given for the specific equation and + the value is the equation itself in symbolic form + + bodies : Tuple + If the bodies in the system were given they are stored in a tuple for + future access + + loads : Tuple + If the loads in the system were given they are stored in a tuple for + future access. This includes forces and torques where forces are given + by (point of application, force vector) and torques are given by + (reference frame acted upon, torque vector). + + Example + ======= + + As a simple example, the dynamics of a simple pendulum will be input into a + SymbolicSystem object manually. First some imports will be needed and then + symbols will be set up for the length of the pendulum (l), mass at the end + of the pendulum (m), and a constant for gravity (g). :: + + >>> from sympy import Matrix, sin, symbols + >>> from sympy.physics.mechanics import dynamicsymbols, SymbolicSystem + >>> l, m, g = symbols('l m g') + + The system will be defined by an angle of theta from the vertical and a + generalized speed of omega will be used where omega = theta_dot. :: + + >>> theta, omega = dynamicsymbols('theta omega') + + Now the equations of motion are ready to be formed and passed to the + SymbolicSystem object. :: + + >>> kin_explicit_rhs = Matrix([omega]) + >>> dyn_implicit_mat = Matrix([l**2 * m]) + >>> dyn_implicit_rhs = Matrix([-g * l * m * sin(theta)]) + >>> symsystem = SymbolicSystem([theta], dyn_implicit_rhs, [omega], + ... dyn_implicit_mat) + + Notes + ===== + + m : number of generalized speeds + n : number of generalized coordinates + o : number of states + + """ + + def __init__(self, coord_states, right_hand_side, speeds=None, + mass_matrix=None, coordinate_derivatives=None, alg_con=None, + output_eqns={}, coord_idxs=None, speed_idxs=None, bodies=None, + loads=None): + """Initializes a SymbolicSystem object""" + + # Extract information on speeds, coordinates and states + if speeds is None: + self._states = Matrix(coord_states) + + if coord_idxs is None: + self._coordinates = None + else: + coords = [coord_states[i] for i in coord_idxs] + self._coordinates = Matrix(coords) + + if speed_idxs is None: + self._speeds = None + else: + speeds_inter = [coord_states[i] for i in speed_idxs] + self._speeds = Matrix(speeds_inter) + else: + self._coordinates = Matrix(coord_states) + self._speeds = Matrix(speeds) + self._states = self._coordinates.col_join(self._speeds) + + # Extract equations of motion form + if coordinate_derivatives is not None: + self._kin_explicit_rhs = coordinate_derivatives + self._dyn_implicit_rhs = right_hand_side + self._dyn_implicit_mat = mass_matrix + self._comb_implicit_rhs = None + self._comb_implicit_mat = None + self._comb_explicit_rhs = None + elif mass_matrix is not None: + self._kin_explicit_rhs = None + self._dyn_implicit_rhs = None + self._dyn_implicit_mat = None + self._comb_implicit_rhs = right_hand_side + self._comb_implicit_mat = mass_matrix + self._comb_explicit_rhs = None + else: + self._kin_explicit_rhs = None + self._dyn_implicit_rhs = None + self._dyn_implicit_mat = None + self._comb_implicit_rhs = None + self._comb_implicit_mat = None + self._comb_explicit_rhs = right_hand_side + + # Set the remainder of the inputs as instance attributes + if alg_con is not None and coordinate_derivatives is not None: + alg_con = [i + len(coordinate_derivatives) for i in alg_con] + self._alg_con = alg_con + self.output_eqns = output_eqns + + # Change the body and loads iterables to tuples if they are not tuples + # already + if not isinstance(bodies, tuple) and bodies is not None: + bodies = tuple(bodies) + if not isinstance(loads, tuple) and loads is not None: + loads = tuple(loads) + self._bodies = bodies + self._loads = loads + + @property + def coordinates(self): + """Returns the column matrix of the generalized coordinates""" + if self._coordinates is None: + raise AttributeError("The coordinates were not specified.") + else: + return self._coordinates + + @property + def speeds(self): + """Returns the column matrix of generalized speeds""" + if self._speeds is None: + raise AttributeError("The speeds were not specified.") + else: + return self._speeds + + @property + def states(self): + """Returns the column matrix of the state variables""" + return self._states + + @property + def alg_con(self): + """Returns a list with the indices of the rows containing algebraic + constraints in the combined form of the equations of motion""" + return self._alg_con + + @property + def dyn_implicit_mat(self): + """Returns the matrix, M, corresponding to the dynamic equations in + implicit form, M x' = F, where the kinematical equations are not + included""" + if self._dyn_implicit_mat is None: + raise AttributeError("dyn_implicit_mat is not specified for " + "equations of motion form [1] or [2].") + else: + return self._dyn_implicit_mat + + @property + def dyn_implicit_rhs(self): + """Returns the column matrix, F, corresponding to the dynamic equations + in implicit form, M x' = F, where the kinematical equations are not + included""" + if self._dyn_implicit_rhs is None: + raise AttributeError("dyn_implicit_rhs is not specified for " + "equations of motion form [1] or [2].") + else: + return self._dyn_implicit_rhs + + @property + def comb_implicit_mat(self): + """Returns the matrix, M, corresponding to the equations of motion in + implicit form (form [2]), M x' = F, where the kinematical equations are + included""" + if self._comb_implicit_mat is None: + if self._dyn_implicit_mat is not None: + num_kin_eqns = len(self._kin_explicit_rhs) + num_dyn_eqns = len(self._dyn_implicit_rhs) + zeros1 = zeros(num_kin_eqns, num_dyn_eqns) + zeros2 = zeros(num_dyn_eqns, num_kin_eqns) + inter1 = eye(num_kin_eqns).row_join(zeros1) + inter2 = zeros2.row_join(self._dyn_implicit_mat) + self._comb_implicit_mat = inter1.col_join(inter2) + return self._comb_implicit_mat + else: + raise AttributeError("comb_implicit_mat is not specified for " + "equations of motion form [1].") + else: + return self._comb_implicit_mat + + @property + def comb_implicit_rhs(self): + """Returns the column matrix, F, corresponding to the equations of + motion in implicit form (form [2]), M x' = F, where the kinematical + equations are included""" + if self._comb_implicit_rhs is None: + if self._dyn_implicit_rhs is not None: + kin_inter = self._kin_explicit_rhs + dyn_inter = self._dyn_implicit_rhs + self._comb_implicit_rhs = kin_inter.col_join(dyn_inter) + return self._comb_implicit_rhs + else: + raise AttributeError("comb_implicit_mat is not specified for " + "equations of motion in form [1].") + else: + return self._comb_implicit_rhs + + def compute_explicit_form(self): + """If the explicit right hand side of the combined equations of motion + is to provided upon initialization, this method will calculate it. This + calculation can potentially take awhile to compute.""" + if self._comb_explicit_rhs is not None: + raise AttributeError("comb_explicit_rhs is already formed.") + + inter1 = getattr(self, 'kin_explicit_rhs', None) + if inter1 is not None: + inter2 = self._dyn_implicit_mat.LUsolve(self._dyn_implicit_rhs) + out = inter1.col_join(inter2) + else: + out = self._comb_implicit_mat.LUsolve(self._comb_implicit_rhs) + + self._comb_explicit_rhs = out + + @property + def comb_explicit_rhs(self): + """Returns the right hand side of the equations of motion in explicit + form, x' = F, where the kinematical equations are included""" + if self._comb_explicit_rhs is None: + raise AttributeError("Please run .combute_explicit_form before " + "attempting to access comb_explicit_rhs.") + else: + return self._comb_explicit_rhs + + @property + def kin_explicit_rhs(self): + """Returns the right hand side of the kinematical equations in explicit + form, q' = G""" + if self._kin_explicit_rhs is None: + raise AttributeError("kin_explicit_rhs is not specified for " + "equations of motion form [1] or [2].") + else: + return self._kin_explicit_rhs + + def dynamic_symbols(self): + """Returns a column matrix containing all of the symbols in the system + that depend on time""" + # Create a list of all of the expressions in the equations of motion + if self._comb_explicit_rhs is None: + eom_expressions = (self.comb_implicit_mat[:] + + self.comb_implicit_rhs[:]) + else: + eom_expressions = (self._comb_explicit_rhs[:]) + + functions_of_time = set() + for expr in eom_expressions: + functions_of_time = functions_of_time.union( + find_dynamicsymbols(expr)) + functions_of_time = functions_of_time.union(self._states) + + return tuple(functions_of_time) + + def constant_symbols(self): + """Returns a column matrix containing all of the symbols in the system + that do not depend on time""" + # Create a list of all of the expressions in the equations of motion + if self._comb_explicit_rhs is None: + eom_expressions = (self.comb_implicit_mat[:] + + self.comb_implicit_rhs[:]) + else: + eom_expressions = (self._comb_explicit_rhs[:]) + + constants = set() + for expr in eom_expressions: + constants = constants.union(expr.free_symbols) + constants.remove(dynamicsymbols._t) + + return tuple(constants) + + @property + def bodies(self): + """Returns the bodies in the system""" + if self._bodies is None: + raise AttributeError("bodies were not specified for the system.") + else: + return self._bodies + + @property + def loads(self): + """Returns the loads in the system""" + if self._loads is None: + raise AttributeError("loads were not specified for the system.") + else: + return self._loads diff --git a/lib/python3.10/site-packages/sympy/physics/mechanics/wrapping_geometry.py b/lib/python3.10/site-packages/sympy/physics/mechanics/wrapping_geometry.py new file mode 100644 index 0000000000000000000000000000000000000000..47ed3c1c463499b024afb9e31cfa2ecd77534132 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/physics/mechanics/wrapping_geometry.py @@ -0,0 +1,641 @@ +"""Geometry objects for use by wrapping pathways.""" + +from abc import ABC, abstractmethod + +from sympy import Integer, acos, pi, sqrt, sympify, tan +from sympy.core.relational import Eq +from sympy.functions.elementary.trigonometric import atan2 +from sympy.polys.polytools import cancel +from sympy.physics.vector import Vector, dot +from sympy.simplify.simplify import trigsimp + + +__all__ = [ + 'WrappingGeometryBase', + 'WrappingCylinder', + 'WrappingSphere', +] + + +class WrappingGeometryBase(ABC): + """Abstract base class for all geometry classes to inherit from. + + Notes + ===== + + Instances of this class cannot be directly instantiated by users. However, + it can be used to created custom geometry types through subclassing. + + """ + + @property + @abstractmethod + def point(cls): + """The point with which the geometry is associated.""" + pass + + @abstractmethod + def point_on_surface(self, point): + """Returns ``True`` if a point is on the geometry's surface. + + Parameters + ========== + point : Point + The point for which it's to be ascertained if it's on the + geometry's surface or not. + + """ + pass + + @abstractmethod + def geodesic_length(self, point_1, point_2): + """Returns the shortest distance between two points on a geometry's + surface. + + Parameters + ========== + + point_1 : Point + The point from which the geodesic length should be calculated. + point_2 : Point + The point to which the geodesic length should be calculated. + + """ + pass + + @abstractmethod + def geodesic_end_vectors(self, point_1, point_2): + """The vectors parallel to the geodesic at the two end points. + + Parameters + ========== + + point_1 : Point + The point from which the geodesic originates. + point_2 : Point + The point at which the geodesic terminates. + + """ + pass + + def __repr__(self): + """Default representation of a geometry object.""" + return f'{self.__class__.__name__}()' + + +class WrappingSphere(WrappingGeometryBase): + """A solid spherical object. + + Explanation + =========== + + A wrapping geometry that allows for circular arcs to be defined between + pairs of points. These paths are always geodetic (the shortest possible). + + Examples + ======== + + To create a ``WrappingSphere`` instance, a ``Symbol`` denoting its radius + and ``Point`` at which its center will be located are needed: + + >>> from sympy import symbols + >>> from sympy.physics.mechanics import Point, WrappingSphere + >>> r = symbols('r') + >>> pO = Point('pO') + + A sphere with radius ``r`` centered on ``pO`` can be instantiated with: + + >>> WrappingSphere(r, pO) + WrappingSphere(radius=r, point=pO) + + Parameters + ========== + + radius : Symbol + Radius of the sphere. This symbol must represent a value that is + positive and constant, i.e. it cannot be a dynamic symbol, nor can it + be an expression. + point : Point + A point at which the sphere is centered. + + See Also + ======== + + WrappingCylinder: Cylindrical geometry where the wrapping direction can be + defined. + + """ + + def __init__(self, radius, point): + """Initializer for ``WrappingSphere``. + + Parameters + ========== + + radius : Symbol + The radius of the sphere. + point : Point + A point on which the sphere is centered. + + """ + self.radius = radius + self.point = point + + @property + def radius(self): + """Radius of the sphere.""" + return self._radius + + @radius.setter + def radius(self, radius): + self._radius = radius + + @property + def point(self): + """A point on which the sphere is centered.""" + return self._point + + @point.setter + def point(self, point): + self._point = point + + def point_on_surface(self, point): + """Returns ``True`` if a point is on the sphere's surface. + + Parameters + ========== + + point : Point + The point for which it's to be ascertained if it's on the sphere's + surface or not. This point's position relative to the sphere's + center must be a simple expression involving the radius of the + sphere, otherwise this check will likely not work. + + """ + point_vector = point.pos_from(self.point) + if isinstance(point_vector, Vector): + point_radius_squared = dot(point_vector, point_vector) + else: + point_radius_squared = point_vector**2 + return Eq(point_radius_squared, self.radius**2) == True + + def geodesic_length(self, point_1, point_2): + r"""Returns the shortest distance between two points on the sphere's + surface. + + Explanation + =========== + + The geodesic length, i.e. the shortest arc along the surface of a + sphere, connecting two points can be calculated using the formula: + + .. math:: + + l = \arccos\left(\mathbf{v}_1 \cdot \mathbf{v}_2\right) + + where $\mathbf{v}_1$ and $\mathbf{v}_2$ are the unit vectors from the + sphere's center to the first and second points on the sphere's surface + respectively. Note that the actual path that the geodesic will take is + undefined when the two points are directly opposite one another. + + Examples + ======== + + A geodesic length can only be calculated between two points on the + sphere's surface. Firstly, a ``WrappingSphere`` instance must be + created along with two points that will lie on its surface: + + >>> from sympy import symbols + >>> from sympy.physics.mechanics import (Point, ReferenceFrame, + ... WrappingSphere) + >>> N = ReferenceFrame('N') + >>> r = symbols('r') + >>> pO = Point('pO') + >>> pO.set_vel(N, 0) + >>> sphere = WrappingSphere(r, pO) + >>> p1 = Point('p1') + >>> p2 = Point('p2') + + Let's assume that ``p1`` lies at a distance of ``r`` in the ``N.x`` + direction from ``pO`` and that ``p2`` is located on the sphere's + surface in the ``N.y + N.z`` direction from ``pO``. These positions can + be set with: + + >>> p1.set_pos(pO, r*N.x) + >>> p1.pos_from(pO) + r*N.x + >>> p2.set_pos(pO, r*(N.y + N.z).normalize()) + >>> p2.pos_from(pO) + sqrt(2)*r/2*N.y + sqrt(2)*r/2*N.z + + The geodesic length, which is in this case is a quarter of the sphere's + circumference, can be calculated using the ``geodesic_length`` method: + + >>> sphere.geodesic_length(p1, p2) + pi*r/2 + + If the ``geodesic_length`` method is passed an argument, the ``Point`` + that doesn't lie on the sphere's surface then a ``ValueError`` is + raised because it's not possible to calculate a value in this case. + + Parameters + ========== + + point_1 : Point + Point from which the geodesic length should be calculated. + point_2 : Point + Point to which the geodesic length should be calculated. + + """ + for point in (point_1, point_2): + if not self.point_on_surface(point): + msg = ( + f'Geodesic length cannot be calculated as point {point} ' + f'with radius {point.pos_from(self.point).magnitude()} ' + f'from the sphere\'s center {self.point} does not lie on ' + f'the surface of {self} with radius {self.radius}.' + ) + raise ValueError(msg) + point_1_vector = point_1.pos_from(self.point).normalize() + point_2_vector = point_2.pos_from(self.point).normalize() + central_angle = acos(point_2_vector.dot(point_1_vector)) + geodesic_length = self.radius*central_angle + return geodesic_length + + def geodesic_end_vectors(self, point_1, point_2): + """The vectors parallel to the geodesic at the two end points. + + Parameters + ========== + + point_1 : Point + The point from which the geodesic originates. + point_2 : Point + The point at which the geodesic terminates. + + """ + pA, pB = point_1, point_2 + pO = self.point + pA_vec = pA.pos_from(pO) + pB_vec = pB.pos_from(pO) + + if pA_vec.cross(pB_vec) == 0: + msg = ( + f'Can\'t compute geodesic end vectors for the pair of points ' + f'{pA} and {pB} on a sphere {self} as they are diametrically ' + f'opposed, thus the geodesic is not defined.' + ) + raise ValueError(msg) + + return ( + pA_vec.cross(pB.pos_from(pA)).cross(pA_vec).normalize(), + pB_vec.cross(pA.pos_from(pB)).cross(pB_vec).normalize(), + ) + + def __repr__(self): + """Representation of a ``WrappingSphere``.""" + return ( + f'{self.__class__.__name__}(radius={self.radius}, ' + f'point={self.point})' + ) + + +class WrappingCylinder(WrappingGeometryBase): + """A solid (infinite) cylindrical object. + + Explanation + =========== + + A wrapping geometry that allows for circular arcs to be defined between + pairs of points. These paths are always geodetic (the shortest possible) in + the sense that they will be a straight line on the unwrapped cylinder's + surface. However, it is also possible for a direction to be specified, i.e. + paths can be influenced such that they either wrap along the shortest side + or the longest side of the cylinder. To define these directions, rotations + are in the positive direction following the right-hand rule. + + Examples + ======== + + To create a ``WrappingCylinder`` instance, a ``Symbol`` denoting its + radius, a ``Vector`` defining its axis, and a ``Point`` through which its + axis passes are needed: + + >>> from sympy import symbols + >>> from sympy.physics.mechanics import (Point, ReferenceFrame, + ... WrappingCylinder) + >>> N = ReferenceFrame('N') + >>> r = symbols('r') + >>> pO = Point('pO') + >>> ax = N.x + + A cylinder with radius ``r``, and axis parallel to ``N.x`` passing through + ``pO`` can be instantiated with: + + >>> WrappingCylinder(r, pO, ax) + WrappingCylinder(radius=r, point=pO, axis=N.x) + + Parameters + ========== + + radius : Symbol + The radius of the cylinder. + point : Point + A point through which the cylinder's axis passes. + axis : Vector + The axis along which the cylinder is aligned. + + See Also + ======== + + WrappingSphere: Spherical geometry where the wrapping direction is always + geodetic. + + """ + + def __init__(self, radius, point, axis): + """Initializer for ``WrappingCylinder``. + + Parameters + ========== + + radius : Symbol + The radius of the cylinder. This symbol must represent a value that + is positive and constant, i.e. it cannot be a dynamic symbol. + point : Point + A point through which the cylinder's axis passes. + axis : Vector + The axis along which the cylinder is aligned. + + """ + self.radius = radius + self.point = point + self.axis = axis + + @property + def radius(self): + """Radius of the cylinder.""" + return self._radius + + @radius.setter + def radius(self, radius): + self._radius = radius + + @property + def point(self): + """A point through which the cylinder's axis passes.""" + return self._point + + @point.setter + def point(self, point): + self._point = point + + @property + def axis(self): + """Axis along which the cylinder is aligned.""" + return self._axis + + @axis.setter + def axis(self, axis): + self._axis = axis.normalize() + + def point_on_surface(self, point): + """Returns ``True`` if a point is on the cylinder's surface. + + Parameters + ========== + + point : Point + The point for which it's to be ascertained if it's on the + cylinder's surface or not. This point's position relative to the + cylinder's axis must be a simple expression involving the radius of + the sphere, otherwise this check will likely not work. + + """ + relative_position = point.pos_from(self.point) + parallel = relative_position.dot(self.axis) * self.axis + point_vector = relative_position - parallel + if isinstance(point_vector, Vector): + point_radius_squared = dot(point_vector, point_vector) + else: + point_radius_squared = point_vector**2 + return Eq(trigsimp(point_radius_squared), self.radius**2) == True + + def geodesic_length(self, point_1, point_2): + """The shortest distance between two points on a geometry's surface. + + Explanation + =========== + + The geodesic length, i.e. the shortest arc along the surface of a + cylinder, connecting two points. It can be calculated using Pythagoras' + theorem. The first short side is the distance between the two points on + the cylinder's surface parallel to the cylinder's axis. The second + short side is the arc of a circle between the two points of the + cylinder's surface perpendicular to the cylinder's axis. The resulting + hypotenuse is the geodesic length. + + Examples + ======== + + A geodesic length can only be calculated between two points on the + cylinder's surface. Firstly, a ``WrappingCylinder`` instance must be + created along with two points that will lie on its surface: + + >>> from sympy import symbols, cos, sin + >>> from sympy.physics.mechanics import (Point, ReferenceFrame, + ... WrappingCylinder, dynamicsymbols) + >>> N = ReferenceFrame('N') + >>> r = symbols('r') + >>> pO = Point('pO') + >>> pO.set_vel(N, 0) + >>> cylinder = WrappingCylinder(r, pO, N.x) + >>> p1 = Point('p1') + >>> p2 = Point('p2') + + Let's assume that ``p1`` is located at ``N.x + r*N.y`` relative to + ``pO`` and that ``p2`` is located at ``r*(cos(q)*N.y + sin(q)*N.z)`` + relative to ``pO``, where ``q(t)`` is a generalized coordinate + specifying the angle rotated around the ``N.x`` axis according to the + right-hand rule where ``N.y`` is zero. These positions can be set with: + + >>> q = dynamicsymbols('q') + >>> p1.set_pos(pO, N.x + r*N.y) + >>> p1.pos_from(pO) + N.x + r*N.y + >>> p2.set_pos(pO, r*(cos(q)*N.y + sin(q)*N.z).normalize()) + >>> p2.pos_from(pO).simplify() + r*cos(q(t))*N.y + r*sin(q(t))*N.z + + The geodesic length, which is in this case a is the hypotenuse of a + right triangle where the other two side lengths are ``1`` (parallel to + the cylinder's axis) and ``r*q(t)`` (parallel to the cylinder's cross + section), can be calculated using the ``geodesic_length`` method: + + >>> cylinder.geodesic_length(p1, p2).simplify() + sqrt(r**2*q(t)**2 + 1) + + If the ``geodesic_length`` method is passed an argument ``Point`` that + doesn't lie on the sphere's surface then a ``ValueError`` is raised + because it's not possible to calculate a value in this case. + + Parameters + ========== + + point_1 : Point + Point from which the geodesic length should be calculated. + point_2 : Point + Point to which the geodesic length should be calculated. + + """ + for point in (point_1, point_2): + if not self.point_on_surface(point): + msg = ( + f'Geodesic length cannot be calculated as point {point} ' + f'with radius {point.pos_from(self.point).magnitude()} ' + f'from the cylinder\'s center {self.point} does not lie on ' + f'the surface of {self} with radius {self.radius} and axis ' + f'{self.axis}.' + ) + raise ValueError(msg) + + relative_position = point_2.pos_from(point_1) + parallel_length = relative_position.dot(self.axis) + + point_1_relative_position = point_1.pos_from(self.point) + point_1_perpendicular_vector = ( + point_1_relative_position + - point_1_relative_position.dot(self.axis)*self.axis + ).normalize() + + point_2_relative_position = point_2.pos_from(self.point) + point_2_perpendicular_vector = ( + point_2_relative_position + - point_2_relative_position.dot(self.axis)*self.axis + ).normalize() + + central_angle = _directional_atan( + cancel(point_1_perpendicular_vector + .cross(point_2_perpendicular_vector) + .dot(self.axis)), + cancel(point_1_perpendicular_vector.dot(point_2_perpendicular_vector)), + ) + + planar_arc_length = self.radius*central_angle + geodesic_length = sqrt(parallel_length**2 + planar_arc_length**2) + return geodesic_length + + def geodesic_end_vectors(self, point_1, point_2): + """The vectors parallel to the geodesic at the two end points. + + Parameters + ========== + + point_1 : Point + The point from which the geodesic originates. + point_2 : Point + The point at which the geodesic terminates. + + """ + point_1_from_origin_point = point_1.pos_from(self.point) + point_2_from_origin_point = point_2.pos_from(self.point) + + if point_1_from_origin_point == point_2_from_origin_point: + msg = ( + f'Cannot compute geodesic end vectors for coincident points ' + f'{point_1} and {point_2} as no geodesic exists.' + ) + raise ValueError(msg) + + point_1_parallel = point_1_from_origin_point.dot(self.axis) * self.axis + point_2_parallel = point_2_from_origin_point.dot(self.axis) * self.axis + point_1_normal = (point_1_from_origin_point - point_1_parallel) + point_2_normal = (point_2_from_origin_point - point_2_parallel) + + if point_1_normal == point_2_normal: + point_1_perpendicular = Vector(0) + point_2_perpendicular = Vector(0) + else: + point_1_perpendicular = self.axis.cross(point_1_normal).normalize() + point_2_perpendicular = -self.axis.cross(point_2_normal).normalize() + + geodesic_length = self.geodesic_length(point_1, point_2) + relative_position = point_2.pos_from(point_1) + parallel_length = relative_position.dot(self.axis) + planar_arc_length = sqrt(geodesic_length**2 - parallel_length**2) + + point_1_vector = ( + planar_arc_length * point_1_perpendicular + + parallel_length * self.axis + ).normalize() + point_2_vector = ( + planar_arc_length * point_2_perpendicular + - parallel_length * self.axis + ).normalize() + + return (point_1_vector, point_2_vector) + + def __repr__(self): + """Representation of a ``WrappingCylinder``.""" + return ( + f'{self.__class__.__name__}(radius={self.radius}, ' + f'point={self.point}, axis={self.axis})' + ) + + +def _directional_atan(numerator, denominator): + """Compute atan in a directional sense as required for geodesics. + + Explanation + =========== + + To be able to control the direction of the geodesic length along the + surface of a cylinder a dedicated arctangent function is needed that + properly handles the directionality of different case. This function + ensures that the central angle is always positive but shifting the case + where ``atan2`` would return a negative angle to be centered around + ``2*pi``. + + Notes + ===== + + This function only handles very specific cases, i.e. the ones that are + expected to be encountered when calculating symbolic geodesics on uniformly + curved surfaces. As such, ``NotImplemented`` errors can be raised in many + cases. This function is named with a leader underscore to indicate that it + only aims to provide very specific functionality within the private scope + of this module. + + """ + + if numerator.is_number and denominator.is_number: + angle = atan2(numerator, denominator) + if angle < 0: + angle += 2 * pi + elif numerator.is_number: + msg = ( + f'Cannot compute a directional atan when the numerator {numerator} ' + f'is numeric and the denominator {denominator} is symbolic.' + ) + raise NotImplementedError(msg) + elif denominator.is_number: + msg = ( + f'Cannot compute a directional atan when the numerator {numerator} ' + f'is symbolic and the denominator {denominator} is numeric.' + ) + raise NotImplementedError(msg) + else: + ratio = sympify(trigsimp(numerator / denominator)) + if isinstance(ratio, tan): + angle = ratio.args[0] + elif ( + ratio.is_Mul + and ratio.args[0] == Integer(-1) + and isinstance(ratio.args[1], tan) + ): + angle = 2 * pi - ratio.args[1].args[0] + else: + msg = f'Cannot compute a directional atan for the value {ratio}.' + raise NotImplementedError(msg) + + return angle diff --git a/lib/python3.10/site-packages/sympy/physics/optics/__init__.py b/lib/python3.10/site-packages/sympy/physics/optics/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d2d83d452fd30e718546c0eac26fe03bbef59c06 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/physics/optics/__init__.py @@ -0,0 +1,38 @@ +__all__ = [ + 'TWave', + + 'RayTransferMatrix', 'FreeSpace', 'FlatRefraction', 'CurvedRefraction', + 'FlatMirror', 'CurvedMirror', 'ThinLens', 'GeometricRay', 'BeamParameter', + 'waist2rayleigh', 'rayleigh2waist', 'geometric_conj_ab', + 'geometric_conj_af', 'geometric_conj_bf', 'gaussian_conj', + 'conjugate_gauss_beams', + + 'Medium', + + 'refraction_angle', 'deviation', 'fresnel_coefficients', 'brewster_angle', + 'critical_angle', 'lens_makers_formula', 'mirror_formula', 'lens_formula', + 'hyperfocal_distance', 'transverse_magnification', + + 'jones_vector', 'stokes_vector', 'jones_2_stokes', 'linear_polarizer', + 'phase_retarder', 'half_wave_retarder', 'quarter_wave_retarder', + 'transmissive_filter', 'reflective_filter', 'mueller_matrix', + 'polarizing_beam_splitter', +] +from .waves import TWave + +from .gaussopt import (RayTransferMatrix, FreeSpace, FlatRefraction, + CurvedRefraction, FlatMirror, CurvedMirror, ThinLens, GeometricRay, + BeamParameter, waist2rayleigh, rayleigh2waist, geometric_conj_ab, + geometric_conj_af, geometric_conj_bf, gaussian_conj, + conjugate_gauss_beams) + +from .medium import Medium + +from .utils import (refraction_angle, deviation, fresnel_coefficients, + brewster_angle, critical_angle, lens_makers_formula, mirror_formula, + lens_formula, hyperfocal_distance, transverse_magnification) + +from .polarization import (jones_vector, stokes_vector, jones_2_stokes, + linear_polarizer, phase_retarder, half_wave_retarder, + quarter_wave_retarder, transmissive_filter, reflective_filter, + mueller_matrix, polarizing_beam_splitter) diff --git a/lib/python3.10/site-packages/sympy/physics/optics/__pycache__/__init__.cpython-310.pyc b/lib/python3.10/site-packages/sympy/physics/optics/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ac4c10028396f860b37111ddb8e8d75b6d307a5b Binary files /dev/null and b/lib/python3.10/site-packages/sympy/physics/optics/__pycache__/__init__.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/physics/optics/__pycache__/gaussopt.cpython-310.pyc b/lib/python3.10/site-packages/sympy/physics/optics/__pycache__/gaussopt.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..303bef9d6fd1c4a26af2f7289a3121b2e30d240a Binary files /dev/null and b/lib/python3.10/site-packages/sympy/physics/optics/__pycache__/gaussopt.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/physics/optics/__pycache__/medium.cpython-310.pyc b/lib/python3.10/site-packages/sympy/physics/optics/__pycache__/medium.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..941afb22c36f6f7ee5bc18f839ee935044a9b884 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/physics/optics/__pycache__/medium.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/physics/optics/__pycache__/polarization.cpython-310.pyc b/lib/python3.10/site-packages/sympy/physics/optics/__pycache__/polarization.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8af1f03544dac43e2afa8d421d2d1aa09930008e Binary files /dev/null and b/lib/python3.10/site-packages/sympy/physics/optics/__pycache__/polarization.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/physics/optics/__pycache__/utils.cpython-310.pyc b/lib/python3.10/site-packages/sympy/physics/optics/__pycache__/utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fb7064ccf3f94d2e1b3705e35c1c65f800dbb201 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/physics/optics/__pycache__/utils.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/physics/optics/__pycache__/waves.cpython-310.pyc b/lib/python3.10/site-packages/sympy/physics/optics/__pycache__/waves.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2a92afa07ecddc8b13f8a4c403eb7821479f513d Binary files /dev/null and b/lib/python3.10/site-packages/sympy/physics/optics/__pycache__/waves.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/physics/optics/gaussopt.py b/lib/python3.10/site-packages/sympy/physics/optics/gaussopt.py new file mode 100644 index 0000000000000000000000000000000000000000..d9e8ef555d60e3204341cdc65cdd05fb02b2f196 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/physics/optics/gaussopt.py @@ -0,0 +1,923 @@ +""" +Gaussian optics. + +The module implements: + +- Ray transfer matrices for geometrical and gaussian optics. + + See RayTransferMatrix, GeometricRay and BeamParameter + +- Conjugation relations for geometrical and gaussian optics. + + See geometric_conj*, gauss_conj and conjugate_gauss_beams + +The conventions for the distances are as follows: + +focal distance + positive for convergent lenses +object distance + positive for real objects +image distance + positive for real images +""" + +__all__ = [ + 'RayTransferMatrix', + 'FreeSpace', + 'FlatRefraction', + 'CurvedRefraction', + 'FlatMirror', + 'CurvedMirror', + 'ThinLens', + 'GeometricRay', + 'BeamParameter', + 'waist2rayleigh', + 'rayleigh2waist', + 'geometric_conj_ab', + 'geometric_conj_af', + 'geometric_conj_bf', + 'gaussian_conj', + 'conjugate_gauss_beams', +] + + +from sympy.core.expr import Expr +from sympy.core.numbers import (I, pi) +from sympy.core.sympify import sympify +from sympy.functions.elementary.complexes import (im, re) +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import atan2 +from sympy.matrices.dense import Matrix, MutableDenseMatrix +from sympy.polys.rationaltools import together +from sympy.utilities.misc import filldedent + +### +# A, B, C, D matrices +### + + +class RayTransferMatrix(MutableDenseMatrix): + """ + Base class for a Ray Transfer Matrix. + + It should be used if there is not already a more specific subclass mentioned + in See Also. + + Parameters + ========== + + parameters : + A, B, C and D or 2x2 matrix (Matrix(2, 2, [A, B, C, D])) + + Examples + ======== + + >>> from sympy.physics.optics import RayTransferMatrix, ThinLens + >>> from sympy import Symbol, Matrix + + >>> mat = RayTransferMatrix(1, 2, 3, 4) + >>> mat + Matrix([ + [1, 2], + [3, 4]]) + + >>> RayTransferMatrix(Matrix([[1, 2], [3, 4]])) + Matrix([ + [1, 2], + [3, 4]]) + + >>> mat.A + 1 + + >>> f = Symbol('f') + >>> lens = ThinLens(f) + >>> lens + Matrix([ + [ 1, 0], + [-1/f, 1]]) + + >>> lens.C + -1/f + + See Also + ======== + + GeometricRay, BeamParameter, + FreeSpace, FlatRefraction, CurvedRefraction, + FlatMirror, CurvedMirror, ThinLens + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Ray_transfer_matrix_analysis + """ + + def __new__(cls, *args): + + if len(args) == 4: + temp = ((args[0], args[1]), (args[2], args[3])) + elif len(args) == 1 \ + and isinstance(args[0], Matrix) \ + and args[0].shape == (2, 2): + temp = args[0] + else: + raise ValueError(filldedent(''' + Expecting 2x2 Matrix or the 4 elements of + the Matrix but got %s''' % str(args))) + return Matrix.__new__(cls, temp) + + def __mul__(self, other): + if isinstance(other, RayTransferMatrix): + return RayTransferMatrix(Matrix(self)*Matrix(other)) + elif isinstance(other, GeometricRay): + return GeometricRay(Matrix(self)*Matrix(other)) + elif isinstance(other, BeamParameter): + temp = Matrix(self)*Matrix(((other.q,), (1,))) + q = (temp[0]/temp[1]).expand(complex=True) + return BeamParameter(other.wavelen, + together(re(q)), + z_r=together(im(q))) + else: + return Matrix.__mul__(self, other) + + @property + def A(self): + """ + The A parameter of the Matrix. + + Examples + ======== + + >>> from sympy.physics.optics import RayTransferMatrix + >>> mat = RayTransferMatrix(1, 2, 3, 4) + >>> mat.A + 1 + """ + return self[0, 0] + + @property + def B(self): + """ + The B parameter of the Matrix. + + Examples + ======== + + >>> from sympy.physics.optics import RayTransferMatrix + >>> mat = RayTransferMatrix(1, 2, 3, 4) + >>> mat.B + 2 + """ + return self[0, 1] + + @property + def C(self): + """ + The C parameter of the Matrix. + + Examples + ======== + + >>> from sympy.physics.optics import RayTransferMatrix + >>> mat = RayTransferMatrix(1, 2, 3, 4) + >>> mat.C + 3 + """ + return self[1, 0] + + @property + def D(self): + """ + The D parameter of the Matrix. + + Examples + ======== + + >>> from sympy.physics.optics import RayTransferMatrix + >>> mat = RayTransferMatrix(1, 2, 3, 4) + >>> mat.D + 4 + """ + return self[1, 1] + + +class FreeSpace(RayTransferMatrix): + """ + Ray Transfer Matrix for free space. + + Parameters + ========== + + distance + + See Also + ======== + + RayTransferMatrix + + Examples + ======== + + >>> from sympy.physics.optics import FreeSpace + >>> from sympy import symbols + >>> d = symbols('d') + >>> FreeSpace(d) + Matrix([ + [1, d], + [0, 1]]) + """ + def __new__(cls, d): + return RayTransferMatrix.__new__(cls, 1, d, 0, 1) + + +class FlatRefraction(RayTransferMatrix): + """ + Ray Transfer Matrix for refraction. + + Parameters + ========== + + n1 : + Refractive index of one medium. + n2 : + Refractive index of other medium. + + See Also + ======== + + RayTransferMatrix + + Examples + ======== + + >>> from sympy.physics.optics import FlatRefraction + >>> from sympy import symbols + >>> n1, n2 = symbols('n1 n2') + >>> FlatRefraction(n1, n2) + Matrix([ + [1, 0], + [0, n1/n2]]) + """ + def __new__(cls, n1, n2): + n1, n2 = map(sympify, (n1, n2)) + return RayTransferMatrix.__new__(cls, 1, 0, 0, n1/n2) + + +class CurvedRefraction(RayTransferMatrix): + """ + Ray Transfer Matrix for refraction on curved interface. + + Parameters + ========== + + R : + Radius of curvature (positive for concave). + n1 : + Refractive index of one medium. + n2 : + Refractive index of other medium. + + See Also + ======== + + RayTransferMatrix + + Examples + ======== + + >>> from sympy.physics.optics import CurvedRefraction + >>> from sympy import symbols + >>> R, n1, n2 = symbols('R n1 n2') + >>> CurvedRefraction(R, n1, n2) + Matrix([ + [ 1, 0], + [(n1 - n2)/(R*n2), n1/n2]]) + """ + def __new__(cls, R, n1, n2): + R, n1, n2 = map(sympify, (R, n1, n2)) + return RayTransferMatrix.__new__(cls, 1, 0, (n1 - n2)/R/n2, n1/n2) + + +class FlatMirror(RayTransferMatrix): + """ + Ray Transfer Matrix for reflection. + + See Also + ======== + + RayTransferMatrix + + Examples + ======== + + >>> from sympy.physics.optics import FlatMirror + >>> FlatMirror() + Matrix([ + [1, 0], + [0, 1]]) + """ + def __new__(cls): + return RayTransferMatrix.__new__(cls, 1, 0, 0, 1) + + +class CurvedMirror(RayTransferMatrix): + """ + Ray Transfer Matrix for reflection from curved surface. + + Parameters + ========== + + R : radius of curvature (positive for concave) + + See Also + ======== + + RayTransferMatrix + + Examples + ======== + + >>> from sympy.physics.optics import CurvedMirror + >>> from sympy import symbols + >>> R = symbols('R') + >>> CurvedMirror(R) + Matrix([ + [ 1, 0], + [-2/R, 1]]) + """ + def __new__(cls, R): + R = sympify(R) + return RayTransferMatrix.__new__(cls, 1, 0, -2/R, 1) + + +class ThinLens(RayTransferMatrix): + """ + Ray Transfer Matrix for a thin lens. + + Parameters + ========== + + f : + The focal distance. + + See Also + ======== + + RayTransferMatrix + + Examples + ======== + + >>> from sympy.physics.optics import ThinLens + >>> from sympy import symbols + >>> f = symbols('f') + >>> ThinLens(f) + Matrix([ + [ 1, 0], + [-1/f, 1]]) + """ + def __new__(cls, f): + f = sympify(f) + return RayTransferMatrix.__new__(cls, 1, 0, -1/f, 1) + + +### +# Representation for geometric ray +### + +class GeometricRay(MutableDenseMatrix): + """ + Representation for a geometric ray in the Ray Transfer Matrix formalism. + + Parameters + ========== + + h : height, and + angle : angle, or + matrix : a 2x1 matrix (Matrix(2, 1, [height, angle])) + + Examples + ======== + + >>> from sympy.physics.optics import GeometricRay, FreeSpace + >>> from sympy import symbols, Matrix + >>> d, h, angle = symbols('d, h, angle') + + >>> GeometricRay(h, angle) + Matrix([ + [ h], + [angle]]) + + >>> FreeSpace(d)*GeometricRay(h, angle) + Matrix([ + [angle*d + h], + [ angle]]) + + >>> GeometricRay( Matrix( ((h,), (angle,)) ) ) + Matrix([ + [ h], + [angle]]) + + See Also + ======== + + RayTransferMatrix + + """ + + def __new__(cls, *args): + if len(args) == 1 and isinstance(args[0], Matrix) \ + and args[0].shape == (2, 1): + temp = args[0] + elif len(args) == 2: + temp = ((args[0],), (args[1],)) + else: + raise ValueError(filldedent(''' + Expecting 2x1 Matrix or the 2 elements of + the Matrix but got %s''' % str(args))) + return Matrix.__new__(cls, temp) + + @property + def height(self): + """ + The distance from the optical axis. + + Examples + ======== + + >>> from sympy.physics.optics import GeometricRay + >>> from sympy import symbols + >>> h, angle = symbols('h, angle') + >>> gRay = GeometricRay(h, angle) + >>> gRay.height + h + """ + return self[0] + + @property + def angle(self): + """ + The angle with the optical axis. + + Examples + ======== + + >>> from sympy.physics.optics import GeometricRay + >>> from sympy import symbols + >>> h, angle = symbols('h, angle') + >>> gRay = GeometricRay(h, angle) + >>> gRay.angle + angle + """ + return self[1] + + +### +# Representation for gauss beam +### + +class BeamParameter(Expr): + """ + Representation for a gaussian ray in the Ray Transfer Matrix formalism. + + Parameters + ========== + + wavelen : the wavelength, + z : the distance to waist, and + w : the waist, or + z_r : the rayleigh range. + n : the refractive index of medium. + + Examples + ======== + + >>> from sympy.physics.optics import BeamParameter + >>> p = BeamParameter(530e-9, 1, w=1e-3) + >>> p.q + 1 + 1.88679245283019*I*pi + + >>> p.q.n() + 1.0 + 5.92753330865999*I + >>> p.w_0.n() + 0.00100000000000000 + >>> p.z_r.n() + 5.92753330865999 + + >>> from sympy.physics.optics import FreeSpace + >>> fs = FreeSpace(10) + >>> p1 = fs*p + >>> p.w.n() + 0.00101413072159615 + >>> p1.w.n() + 0.00210803120913829 + + See Also + ======== + + RayTransferMatrix + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Complex_beam_parameter + .. [2] https://en.wikipedia.org/wiki/Gaussian_beam + """ + #TODO A class Complex may be implemented. The BeamParameter may + # subclass it. See: + # https://groups.google.com/d/topic/sympy/7XkU07NRBEs/discussion + + def __new__(cls, wavelen, z, z_r=None, w=None, n=1): + wavelen = sympify(wavelen) + z = sympify(z) + n = sympify(n) + + if z_r is not None and w is None: + z_r = sympify(z_r) + elif w is not None and z_r is None: + z_r = waist2rayleigh(sympify(w), wavelen, n) + elif z_r is None and w is None: + raise ValueError('Must specify one of w and z_r.') + + return Expr.__new__(cls, wavelen, z, z_r, n) + + @property + def wavelen(self): + return self.args[0] + + @property + def z(self): + return self.args[1] + + @property + def z_r(self): + return self.args[2] + + @property + def n(self): + return self.args[3] + + @property + def q(self): + """ + The complex parameter representing the beam. + + Examples + ======== + + >>> from sympy.physics.optics import BeamParameter + >>> p = BeamParameter(530e-9, 1, w=1e-3) + >>> p.q + 1 + 1.88679245283019*I*pi + """ + return self.z + I*self.z_r + + @property + def radius(self): + """ + The radius of curvature of the phase front. + + Examples + ======== + + >>> from sympy.physics.optics import BeamParameter + >>> p = BeamParameter(530e-9, 1, w=1e-3) + >>> p.radius + 1 + 3.55998576005696*pi**2 + """ + return self.z*(1 + (self.z_r/self.z)**2) + + @property + def w(self): + """ + The radius of the beam w(z), at any position z along the beam. + The beam radius at `1/e^2` intensity (axial value). + + See Also + ======== + + w_0 : + The minimal radius of beam. + + Examples + ======== + + >>> from sympy.physics.optics import BeamParameter + >>> p = BeamParameter(530e-9, 1, w=1e-3) + >>> p.w + 0.001*sqrt(0.2809/pi**2 + 1) + """ + return self.w_0*sqrt(1 + (self.z/self.z_r)**2) + + @property + def w_0(self): + """ + The minimal radius of beam at `1/e^2` intensity (peak value). + + See Also + ======== + + w : the beam radius at `1/e^2` intensity (axial value). + + Examples + ======== + + >>> from sympy.physics.optics import BeamParameter + >>> p = BeamParameter(530e-9, 1, w=1e-3) + >>> p.w_0 + 0.00100000000000000 + """ + return sqrt(self.z_r/(pi*self.n)*self.wavelen) + + @property + def divergence(self): + """ + Half of the total angular spread. + + Examples + ======== + + >>> from sympy.physics.optics import BeamParameter + >>> p = BeamParameter(530e-9, 1, w=1e-3) + >>> p.divergence + 0.00053/pi + """ + return self.wavelen/pi/self.w_0 + + @property + def gouy(self): + """ + The Gouy phase. + + Examples + ======== + + >>> from sympy.physics.optics import BeamParameter + >>> p = BeamParameter(530e-9, 1, w=1e-3) + >>> p.gouy + atan(0.53/pi) + """ + return atan2(self.z, self.z_r) + + @property + def waist_approximation_limit(self): + """ + The minimal waist for which the gauss beam approximation is valid. + + Explanation + =========== + + The gauss beam is a solution to the paraxial equation. For curvatures + that are too great it is not a valid approximation. + + Examples + ======== + + >>> from sympy.physics.optics import BeamParameter + >>> p = BeamParameter(530e-9, 1, w=1e-3) + >>> p.waist_approximation_limit + 1.06e-6/pi + """ + return 2*self.wavelen/pi + + +### +# Utilities +### + +def waist2rayleigh(w, wavelen, n=1): + """ + Calculate the rayleigh range from the waist of a gaussian beam. + + See Also + ======== + + rayleigh2waist, BeamParameter + + Examples + ======== + + >>> from sympy.physics.optics import waist2rayleigh + >>> from sympy import symbols + >>> w, wavelen = symbols('w wavelen') + >>> waist2rayleigh(w, wavelen) + pi*w**2/wavelen + """ + w, wavelen = map(sympify, (w, wavelen)) + return w**2*n*pi/wavelen + + +def rayleigh2waist(z_r, wavelen): + """Calculate the waist from the rayleigh range of a gaussian beam. + + See Also + ======== + + waist2rayleigh, BeamParameter + + Examples + ======== + + >>> from sympy.physics.optics import rayleigh2waist + >>> from sympy import symbols + >>> z_r, wavelen = symbols('z_r wavelen') + >>> rayleigh2waist(z_r, wavelen) + sqrt(wavelen*z_r)/sqrt(pi) + """ + z_r, wavelen = map(sympify, (z_r, wavelen)) + return sqrt(z_r/pi*wavelen) + + +def geometric_conj_ab(a, b): + """ + Conjugation relation for geometrical beams under paraxial conditions. + + Explanation + =========== + + Takes the distances to the optical element and returns the needed + focal distance. + + See Also + ======== + + geometric_conj_af, geometric_conj_bf + + Examples + ======== + + >>> from sympy.physics.optics import geometric_conj_ab + >>> from sympy import symbols + >>> a, b = symbols('a b') + >>> geometric_conj_ab(a, b) + a*b/(a + b) + """ + a, b = map(sympify, (a, b)) + if a.is_infinite or b.is_infinite: + return a if b.is_infinite else b + else: + return a*b/(a + b) + + +def geometric_conj_af(a, f): + """ + Conjugation relation for geometrical beams under paraxial conditions. + + Explanation + =========== + + Takes the object distance (for geometric_conj_af) or the image distance + (for geometric_conj_bf) to the optical element and the focal distance. + Then it returns the other distance needed for conjugation. + + See Also + ======== + + geometric_conj_ab + + Examples + ======== + + >>> from sympy.physics.optics.gaussopt import geometric_conj_af, geometric_conj_bf + >>> from sympy import symbols + >>> a, b, f = symbols('a b f') + >>> geometric_conj_af(a, f) + a*f/(a - f) + >>> geometric_conj_bf(b, f) + b*f/(b - f) + """ + a, f = map(sympify, (a, f)) + return -geometric_conj_ab(a, -f) + +geometric_conj_bf = geometric_conj_af + + +def gaussian_conj(s_in, z_r_in, f): + """ + Conjugation relation for gaussian beams. + + Parameters + ========== + + s_in : + The distance to optical element from the waist. + z_r_in : + The rayleigh range of the incident beam. + f : + The focal length of the optical element. + + Returns + ======= + + a tuple containing (s_out, z_r_out, m) + s_out : + The distance between the new waist and the optical element. + z_r_out : + The rayleigh range of the emergent beam. + m : + The ration between the new and the old waists. + + Examples + ======== + + >>> from sympy.physics.optics import gaussian_conj + >>> from sympy import symbols + >>> s_in, z_r_in, f = symbols('s_in z_r_in f') + + >>> gaussian_conj(s_in, z_r_in, f)[0] + 1/(-1/(s_in + z_r_in**2/(-f + s_in)) + 1/f) + + >>> gaussian_conj(s_in, z_r_in, f)[1] + z_r_in/(1 - s_in**2/f**2 + z_r_in**2/f**2) + + >>> gaussian_conj(s_in, z_r_in, f)[2] + 1/sqrt(1 - s_in**2/f**2 + z_r_in**2/f**2) + """ + s_in, z_r_in, f = map(sympify, (s_in, z_r_in, f)) + s_out = 1 / ( -1/(s_in + z_r_in**2/(s_in - f)) + 1/f ) + m = 1/sqrt((1 - (s_in/f)**2) + (z_r_in/f)**2) + z_r_out = z_r_in / ((1 - (s_in/f)**2) + (z_r_in/f)**2) + return (s_out, z_r_out, m) + + +def conjugate_gauss_beams(wavelen, waist_in, waist_out, **kwargs): + """ + Find the optical setup conjugating the object/image waists. + + Parameters + ========== + + wavelen : + The wavelength of the beam. + waist_in and waist_out : + The waists to be conjugated. + f : + The focal distance of the element used in the conjugation. + + Returns + ======= + + a tuple containing (s_in, s_out, f) + s_in : + The distance before the optical element. + s_out : + The distance after the optical element. + f : + The focal distance of the optical element. + + Examples + ======== + + >>> from sympy.physics.optics import conjugate_gauss_beams + >>> from sympy import symbols, factor + >>> l, w_i, w_o, f = symbols('l w_i w_o f') + + >>> conjugate_gauss_beams(l, w_i, w_o, f=f)[0] + f*(1 - sqrt(w_i**2/w_o**2 - pi**2*w_i**4/(f**2*l**2))) + + >>> factor(conjugate_gauss_beams(l, w_i, w_o, f=f)[1]) + f*w_o**2*(w_i**2/w_o**2 - sqrt(w_i**2/w_o**2 - + pi**2*w_i**4/(f**2*l**2)))/w_i**2 + + >>> conjugate_gauss_beams(l, w_i, w_o, f=f)[2] + f + """ + #TODO add the other possible arguments + wavelen, waist_in, waist_out = map(sympify, (wavelen, waist_in, waist_out)) + m = waist_out / waist_in + z = waist2rayleigh(waist_in, wavelen) + if len(kwargs) != 1: + raise ValueError("The function expects only one named argument") + elif 'dist' in kwargs: + raise NotImplementedError(filldedent(''' + Currently only focal length is supported as a parameter''')) + elif 'f' in kwargs: + f = sympify(kwargs['f']) + s_in = f * (1 - sqrt(1/m**2 - z**2/f**2)) + s_out = gaussian_conj(s_in, z, f)[0] + elif 's_in' in kwargs: + raise NotImplementedError(filldedent(''' + Currently only focal length is supported as a parameter''')) + else: + raise ValueError(filldedent(''' + The functions expects the focal length as a named argument''')) + return (s_in, s_out, f) + +#TODO +#def plot_beam(): +# """Plot the beam radius as it propagates in space.""" +# pass + +#TODO +#def plot_beam_conjugation(): +# """ +# Plot the intersection of two beams. +# +# Represents the conjugation relation. +# +# See Also +# ======== +# +# conjugate_gauss_beams +# """ +# pass diff --git a/lib/python3.10/site-packages/sympy/physics/optics/medium.py b/lib/python3.10/site-packages/sympy/physics/optics/medium.py new file mode 100644 index 0000000000000000000000000000000000000000..764b68caad5865b8f3cee028a14cfa304796b4c0 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/physics/optics/medium.py @@ -0,0 +1,253 @@ +""" +**Contains** + +* Medium +""" +from sympy.physics.units import second, meter, kilogram, ampere + +__all__ = ['Medium'] + +from sympy.core.basic import Basic +from sympy.core.symbol import Str +from sympy.core.sympify import _sympify +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.physics.units import speed_of_light, u0, e0 + + +c = speed_of_light.convert_to(meter/second) +_e0mksa = e0.convert_to(ampere**2*second**4/(kilogram*meter**3)) +_u0mksa = u0.convert_to(meter*kilogram/(ampere**2*second**2)) + + +class Medium(Basic): + + """ + This class represents an optical medium. The prime reason to implement this is + to facilitate refraction, Fermat's principle, etc. + + Explanation + =========== + + An optical medium is a material through which electromagnetic waves propagate. + The permittivity and permeability of the medium define how electromagnetic + waves propagate in it. + + + Parameters + ========== + + name: string + The display name of the Medium. + + permittivity: Sympifyable + Electric permittivity of the space. + + permeability: Sympifyable + Magnetic permeability of the space. + + n: Sympifyable + Index of refraction of the medium. + + + Examples + ======== + + >>> from sympy.abc import epsilon, mu + >>> from sympy.physics.optics import Medium + >>> m1 = Medium('m1') + >>> m2 = Medium('m2', epsilon, mu) + >>> m1.intrinsic_impedance + 149896229*pi*kilogram*meter**2/(1250000*ampere**2*second**3) + >>> m2.refractive_index + 299792458*meter*sqrt(epsilon*mu)/second + + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Optical_medium + + """ + + def __new__(cls, name, permittivity=None, permeability=None, n=None): + if not isinstance(name, Str): + name = Str(name) + + permittivity = _sympify(permittivity) if permittivity is not None else permittivity + permeability = _sympify(permeability) if permeability is not None else permeability + n = _sympify(n) if n is not None else n + + if n is not None: + if permittivity is not None and permeability is None: + permeability = n**2/(c**2*permittivity) + return MediumPP(name, permittivity, permeability) + elif permeability is not None and permittivity is None: + permittivity = n**2/(c**2*permeability) + return MediumPP(name, permittivity, permeability) + elif permittivity is not None and permittivity is not None: + raise ValueError("Specifying all of permittivity, permeability, and n is not allowed") + else: + return MediumN(name, n) + elif permittivity is not None and permeability is not None: + return MediumPP(name, permittivity, permeability) + elif permittivity is None and permeability is None: + return MediumPP(name, _e0mksa, _u0mksa) + else: + raise ValueError("Arguments are underspecified. Either specify n or any two of permittivity, " + "permeability, and n") + + @property + def name(self): + return self.args[0] + + @property + def speed(self): + """ + Returns speed of the electromagnetic wave travelling in the medium. + + Examples + ======== + + >>> from sympy.physics.optics import Medium + >>> m = Medium('m') + >>> m.speed + 299792458*meter/second + >>> m2 = Medium('m2', n=1) + >>> m.speed == m2.speed + True + + """ + return c / self.n + + @property + def refractive_index(self): + """ + Returns refractive index of the medium. + + Examples + ======== + + >>> from sympy.physics.optics import Medium + >>> m = Medium('m') + >>> m.refractive_index + 1 + + """ + return (c/self.speed) + + +class MediumN(Medium): + + """ + Represents an optical medium for which only the refractive index is known. + Useful for simple ray optics. + + This class should never be instantiated directly. + Instead it should be instantiated indirectly by instantiating Medium with + only n specified. + + Examples + ======== + >>> from sympy.physics.optics import Medium + >>> m = Medium('m', n=2) + >>> m + MediumN(Str('m'), 2) + """ + + def __new__(cls, name, n): + obj = super(Medium, cls).__new__(cls, name, n) + return obj + + @property + def n(self): + return self.args[1] + + +class MediumPP(Medium): + """ + Represents an optical medium for which the permittivity and permeability are known. + + This class should never be instantiated directly. Instead it should be + instantiated indirectly by instantiating Medium with any two of + permittivity, permeability, and n specified, or by not specifying any + of permittivity, permeability, or n, in which case default values for + permittivity and permeability will be used. + + Examples + ======== + >>> from sympy.physics.optics import Medium + >>> from sympy.abc import epsilon, mu + >>> m1 = Medium('m1', permittivity=epsilon, permeability=mu) + >>> m1 + MediumPP(Str('m1'), epsilon, mu) + >>> m2 = Medium('m2') + >>> m2 + MediumPP(Str('m2'), 625000*ampere**2*second**4/(22468879468420441*pi*kilogram*meter**3), pi*kilogram*meter/(2500000*ampere**2*second**2)) + """ + + + def __new__(cls, name, permittivity, permeability): + obj = super(Medium, cls).__new__(cls, name, permittivity, permeability) + return obj + + @property + def intrinsic_impedance(self): + """ + Returns intrinsic impedance of the medium. + + Explanation + =========== + + The intrinsic impedance of a medium is the ratio of the + transverse components of the electric and magnetic fields + of the electromagnetic wave travelling in the medium. + In a region with no electrical conductivity it simplifies + to the square root of ratio of magnetic permeability to + electric permittivity. + + Examples + ======== + + >>> from sympy.physics.optics import Medium + >>> m = Medium('m') + >>> m.intrinsic_impedance + 149896229*pi*kilogram*meter**2/(1250000*ampere**2*second**3) + + """ + return sqrt(self.permeability / self.permittivity) + + @property + def permittivity(self): + """ + Returns electric permittivity of the medium. + + Examples + ======== + + >>> from sympy.physics.optics import Medium + >>> m = Medium('m') + >>> m.permittivity + 625000*ampere**2*second**4/(22468879468420441*pi*kilogram*meter**3) + + """ + return self.args[1] + + @property + def permeability(self): + """ + Returns magnetic permeability of the medium. + + Examples + ======== + + >>> from sympy.physics.optics import Medium + >>> m = Medium('m') + >>> m.permeability + pi*kilogram*meter/(2500000*ampere**2*second**2) + + """ + return self.args[2] + + @property + def n(self): + return c*sqrt(self.permittivity*self.permeability) diff --git a/lib/python3.10/site-packages/sympy/physics/optics/polarization.py b/lib/python3.10/site-packages/sympy/physics/optics/polarization.py new file mode 100644 index 0000000000000000000000000000000000000000..0bdb546548ad082ef38f5f0c159d7eadd38f6d30 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/physics/optics/polarization.py @@ -0,0 +1,732 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +The module implements routines to model the polarization of optical fields +and can be used to calculate the effects of polarization optical elements on +the fields. + +- Jones vectors. + +- Stokes vectors. + +- Jones matrices. + +- Mueller matrices. + +Examples +======== + +We calculate a generic Jones vector: + +>>> from sympy import symbols, pprint, zeros, simplify +>>> from sympy.physics.optics.polarization import (jones_vector, stokes_vector, +... half_wave_retarder, polarizing_beam_splitter, jones_2_stokes) + +>>> psi, chi, p, I0 = symbols("psi, chi, p, I0", real=True) +>>> x0 = jones_vector(psi, chi) +>>> pprint(x0, use_unicode=True) +⎡-ⅈ⋅sin(χ)⋅sin(ψ) + cos(χ)⋅cos(ψ)⎤ +⎢ ⎥ +⎣ⅈ⋅sin(χ)⋅cos(ψ) + sin(ψ)⋅cos(χ) ⎦ + +And the more general Stokes vector: + +>>> s0 = stokes_vector(psi, chi, p, I0) +>>> pprint(s0, use_unicode=True) +⎡ I₀ ⎤ +⎢ ⎥ +⎢I₀⋅p⋅cos(2⋅χ)⋅cos(2⋅ψ)⎥ +⎢ ⎥ +⎢I₀⋅p⋅sin(2⋅ψ)⋅cos(2⋅χ)⎥ +⎢ ⎥ +⎣ I₀⋅p⋅sin(2⋅χ) ⎦ + +We calculate how the Jones vector is modified by a half-wave plate: + +>>> alpha = symbols("alpha", real=True) +>>> HWP = half_wave_retarder(alpha) +>>> x1 = simplify(HWP*x0) + +We calculate the very common operation of passing a beam through a half-wave +plate and then through a polarizing beam-splitter. We do this by putting this +Jones vector as the first entry of a two-Jones-vector state that is transformed +by a 4x4 Jones matrix modelling the polarizing beam-splitter to get the +transmitted and reflected Jones vectors: + +>>> PBS = polarizing_beam_splitter() +>>> X1 = zeros(4, 1) +>>> X1[:2, :] = x1 +>>> X2 = PBS*X1 +>>> transmitted_port = X2[:2, :] +>>> reflected_port = X2[2:, :] + +This allows us to calculate how the power in both ports depends on the initial +polarization: + +>>> transmitted_power = jones_2_stokes(transmitted_port)[0] +>>> reflected_power = jones_2_stokes(reflected_port)[0] +>>> print(transmitted_power) +cos(-2*alpha + chi + psi)**2/2 + cos(2*alpha + chi - psi)**2/2 + + +>>> print(reflected_power) +sin(-2*alpha + chi + psi)**2/2 + sin(2*alpha + chi - psi)**2/2 + +Please see the description of the individual functions for further +details and examples. + +References +========== + +.. [1] https://en.wikipedia.org/wiki/Jones_calculus +.. [2] https://en.wikipedia.org/wiki/Mueller_calculus +.. [3] https://en.wikipedia.org/wiki/Stokes_parameters + +""" + +from sympy.core.numbers import (I, pi) +from sympy.functions.elementary.complexes import (Abs, im, re) +from sympy.functions.elementary.exponential import exp +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import (cos, sin) +from sympy.matrices.dense import Matrix +from sympy.simplify.simplify import simplify +from sympy.physics.quantum import TensorProduct + + +def jones_vector(psi, chi): + """A Jones vector corresponding to a polarization ellipse with `psi` tilt, + and `chi` circularity. + + Parameters + ========== + + psi : numeric type or SymPy Symbol + The tilt of the polarization relative to the `x` axis. + + chi : numeric type or SymPy Symbol + The angle adjacent to the mayor axis of the polarization ellipse. + + + Returns + ======= + + Matrix : + A Jones vector. + + Examples + ======== + + The axes on the Poincaré sphere. + + >>> from sympy import pprint, symbols, pi + >>> from sympy.physics.optics.polarization import jones_vector + >>> psi, chi = symbols("psi, chi", real=True) + + A general Jones vector. + + >>> pprint(jones_vector(psi, chi), use_unicode=True) + ⎡-ⅈ⋅sin(χ)⋅sin(ψ) + cos(χ)⋅cos(ψ)⎤ + ⎢ ⎥ + ⎣ⅈ⋅sin(χ)⋅cos(ψ) + sin(ψ)⋅cos(χ) ⎦ + + Horizontal polarization. + + >>> pprint(jones_vector(0, 0), use_unicode=True) + ⎡1⎤ + ⎢ ⎥ + ⎣0⎦ + + Vertical polarization. + + >>> pprint(jones_vector(pi/2, 0), use_unicode=True) + ⎡0⎤ + ⎢ ⎥ + ⎣1⎦ + + Diagonal polarization. + + >>> pprint(jones_vector(pi/4, 0), use_unicode=True) + ⎡√2⎤ + ⎢──⎥ + ⎢2 ⎥ + ⎢ ⎥ + ⎢√2⎥ + ⎢──⎥ + ⎣2 ⎦ + + Anti-diagonal polarization. + + >>> pprint(jones_vector(-pi/4, 0), use_unicode=True) + ⎡ √2 ⎤ + ⎢ ── ⎥ + ⎢ 2 ⎥ + ⎢ ⎥ + ⎢-√2 ⎥ + ⎢────⎥ + ⎣ 2 ⎦ + + Right-hand circular polarization. + + >>> pprint(jones_vector(0, pi/4), use_unicode=True) + ⎡ √2 ⎤ + ⎢ ── ⎥ + ⎢ 2 ⎥ + ⎢ ⎥ + ⎢√2⋅ⅈ⎥ + ⎢────⎥ + ⎣ 2 ⎦ + + Left-hand circular polarization. + + >>> pprint(jones_vector(0, -pi/4), use_unicode=True) + ⎡ √2 ⎤ + ⎢ ── ⎥ + ⎢ 2 ⎥ + ⎢ ⎥ + ⎢-√2⋅ⅈ ⎥ + ⎢──────⎥ + ⎣ 2 ⎦ + + """ + return Matrix([-I*sin(chi)*sin(psi) + cos(chi)*cos(psi), + I*sin(chi)*cos(psi) + sin(psi)*cos(chi)]) + + +def stokes_vector(psi, chi, p=1, I=1): + """A Stokes vector corresponding to a polarization ellipse with ``psi`` + tilt, and ``chi`` circularity. + + Parameters + ========== + + psi : numeric type or SymPy Symbol + The tilt of the polarization relative to the ``x`` axis. + chi : numeric type or SymPy Symbol + The angle adjacent to the mayor axis of the polarization ellipse. + p : numeric type or SymPy Symbol + The degree of polarization. + I : numeric type or SymPy Symbol + The intensity of the field. + + + Returns + ======= + + Matrix : + A Stokes vector. + + Examples + ======== + + The axes on the Poincaré sphere. + + >>> from sympy import pprint, symbols, pi + >>> from sympy.physics.optics.polarization import stokes_vector + >>> psi, chi, p, I = symbols("psi, chi, p, I", real=True) + >>> pprint(stokes_vector(psi, chi, p, I), use_unicode=True) + ⎡ I ⎤ + ⎢ ⎥ + ⎢I⋅p⋅cos(2⋅χ)⋅cos(2⋅ψ)⎥ + ⎢ ⎥ + ⎢I⋅p⋅sin(2⋅ψ)⋅cos(2⋅χ)⎥ + ⎢ ⎥ + ⎣ I⋅p⋅sin(2⋅χ) ⎦ + + + Horizontal polarization + + >>> pprint(stokes_vector(0, 0), use_unicode=True) + ⎡1⎤ + ⎢ ⎥ + ⎢1⎥ + ⎢ ⎥ + ⎢0⎥ + ⎢ ⎥ + ⎣0⎦ + + Vertical polarization + + >>> pprint(stokes_vector(pi/2, 0), use_unicode=True) + ⎡1 ⎤ + ⎢ ⎥ + ⎢-1⎥ + ⎢ ⎥ + ⎢0 ⎥ + ⎢ ⎥ + ⎣0 ⎦ + + Diagonal polarization + + >>> pprint(stokes_vector(pi/4, 0), use_unicode=True) + ⎡1⎤ + ⎢ ⎥ + ⎢0⎥ + ⎢ ⎥ + ⎢1⎥ + ⎢ ⎥ + ⎣0⎦ + + Anti-diagonal polarization + + >>> pprint(stokes_vector(-pi/4, 0), use_unicode=True) + ⎡1 ⎤ + ⎢ ⎥ + ⎢0 ⎥ + ⎢ ⎥ + ⎢-1⎥ + ⎢ ⎥ + ⎣0 ⎦ + + Right-hand circular polarization + + >>> pprint(stokes_vector(0, pi/4), use_unicode=True) + ⎡1⎤ + ⎢ ⎥ + ⎢0⎥ + ⎢ ⎥ + ⎢0⎥ + ⎢ ⎥ + ⎣1⎦ + + Left-hand circular polarization + + >>> pprint(stokes_vector(0, -pi/4), use_unicode=True) + ⎡1 ⎤ + ⎢ ⎥ + ⎢0 ⎥ + ⎢ ⎥ + ⎢0 ⎥ + ⎢ ⎥ + ⎣-1⎦ + + Unpolarized light + + >>> pprint(stokes_vector(0, 0, 0), use_unicode=True) + ⎡1⎤ + ⎢ ⎥ + ⎢0⎥ + ⎢ ⎥ + ⎢0⎥ + ⎢ ⎥ + ⎣0⎦ + + """ + S0 = I + S1 = I*p*cos(2*psi)*cos(2*chi) + S2 = I*p*sin(2*psi)*cos(2*chi) + S3 = I*p*sin(2*chi) + return Matrix([S0, S1, S2, S3]) + + +def jones_2_stokes(e): + """Return the Stokes vector for a Jones vector ``e``. + + Parameters + ========== + + e : SymPy Matrix + A Jones vector. + + Returns + ======= + + SymPy Matrix + A Jones vector. + + Examples + ======== + + The axes on the Poincaré sphere. + + >>> from sympy import pprint, pi + >>> from sympy.physics.optics.polarization import jones_vector + >>> from sympy.physics.optics.polarization import jones_2_stokes + >>> H = jones_vector(0, 0) + >>> V = jones_vector(pi/2, 0) + >>> D = jones_vector(pi/4, 0) + >>> A = jones_vector(-pi/4, 0) + >>> R = jones_vector(0, pi/4) + >>> L = jones_vector(0, -pi/4) + >>> pprint([jones_2_stokes(e) for e in [H, V, D, A, R, L]], + ... use_unicode=True) + ⎡⎡1⎤ ⎡1 ⎤ ⎡1⎤ ⎡1 ⎤ ⎡1⎤ ⎡1 ⎤⎤ + ⎢⎢ ⎥ ⎢ ⎥ ⎢ ⎥ ⎢ ⎥ ⎢ ⎥ ⎢ ⎥⎥ + ⎢⎢1⎥ ⎢-1⎥ ⎢0⎥ ⎢0 ⎥ ⎢0⎥ ⎢0 ⎥⎥ + ⎢⎢ ⎥, ⎢ ⎥, ⎢ ⎥, ⎢ ⎥, ⎢ ⎥, ⎢ ⎥⎥ + ⎢⎢0⎥ ⎢0 ⎥ ⎢1⎥ ⎢-1⎥ ⎢0⎥ ⎢0 ⎥⎥ + ⎢⎢ ⎥ ⎢ ⎥ ⎢ ⎥ ⎢ ⎥ ⎢ ⎥ ⎢ ⎥⎥ + ⎣⎣0⎦ ⎣0 ⎦ ⎣0⎦ ⎣0 ⎦ ⎣1⎦ ⎣-1⎦⎦ + + """ + ex, ey = e + return Matrix([Abs(ex)**2 + Abs(ey)**2, + Abs(ex)**2 - Abs(ey)**2, + 2*re(ex*ey.conjugate()), + -2*im(ex*ey.conjugate())]) + + +def linear_polarizer(theta=0): + """A linear polarizer Jones matrix with transmission axis at + an angle ``theta``. + + Parameters + ========== + + theta : numeric type or SymPy Symbol + The angle of the transmission axis relative to the horizontal plane. + + Returns + ======= + + SymPy Matrix + A Jones matrix representing the polarizer. + + Examples + ======== + + A generic polarizer. + + >>> from sympy import pprint, symbols + >>> from sympy.physics.optics.polarization import linear_polarizer + >>> theta = symbols("theta", real=True) + >>> J = linear_polarizer(theta) + >>> pprint(J, use_unicode=True) + ⎡ 2 ⎤ + ⎢ cos (θ) sin(θ)⋅cos(θ)⎥ + ⎢ ⎥ + ⎢ 2 ⎥ + ⎣sin(θ)⋅cos(θ) sin (θ) ⎦ + + + """ + M = Matrix([[cos(theta)**2, sin(theta)*cos(theta)], + [sin(theta)*cos(theta), sin(theta)**2]]) + return M + + +def phase_retarder(theta=0, delta=0): + """A phase retarder Jones matrix with retardance ``delta`` at angle ``theta``. + + Parameters + ========== + + theta : numeric type or SymPy Symbol + The angle of the fast axis relative to the horizontal plane. + delta : numeric type or SymPy Symbol + The phase difference between the fast and slow axes of the + transmitted light. + + Returns + ======= + + SymPy Matrix : + A Jones matrix representing the retarder. + + Examples + ======== + + A generic retarder. + + >>> from sympy import pprint, symbols + >>> from sympy.physics.optics.polarization import phase_retarder + >>> theta, delta = symbols("theta, delta", real=True) + >>> R = phase_retarder(theta, delta) + >>> pprint(R, use_unicode=True) + ⎡ -ⅈ⋅δ -ⅈ⋅δ ⎤ + ⎢ ───── ───── ⎥ + ⎢⎛ ⅈ⋅δ 2 2 ⎞ 2 ⎛ ⅈ⋅δ⎞ 2 ⎥ + ⎢⎝ℯ ⋅sin (θ) + cos (θ)⎠⋅ℯ ⎝1 - ℯ ⎠⋅ℯ ⋅sin(θ)⋅cos(θ)⎥ + ⎢ ⎥ + ⎢ -ⅈ⋅δ -ⅈ⋅δ ⎥ + ⎢ ───── ─────⎥ + ⎢⎛ ⅈ⋅δ⎞ 2 ⎛ ⅈ⋅δ 2 2 ⎞ 2 ⎥ + ⎣⎝1 - ℯ ⎠⋅ℯ ⋅sin(θ)⋅cos(θ) ⎝ℯ ⋅cos (θ) + sin (θ)⎠⋅ℯ ⎦ + + """ + R = Matrix([[cos(theta)**2 + exp(I*delta)*sin(theta)**2, + (1-exp(I*delta))*cos(theta)*sin(theta)], + [(1-exp(I*delta))*cos(theta)*sin(theta), + sin(theta)**2 + exp(I*delta)*cos(theta)**2]]) + return R*exp(-I*delta/2) + + +def half_wave_retarder(theta): + """A half-wave retarder Jones matrix at angle ``theta``. + + Parameters + ========== + + theta : numeric type or SymPy Symbol + The angle of the fast axis relative to the horizontal plane. + + Returns + ======= + + SymPy Matrix + A Jones matrix representing the retarder. + + Examples + ======== + + A generic half-wave plate. + + >>> from sympy import pprint, symbols + >>> from sympy.physics.optics.polarization import half_wave_retarder + >>> theta= symbols("theta", real=True) + >>> HWP = half_wave_retarder(theta) + >>> pprint(HWP, use_unicode=True) + ⎡ ⎛ 2 2 ⎞ ⎤ + ⎢-ⅈ⋅⎝- sin (θ) + cos (θ)⎠ -2⋅ⅈ⋅sin(θ)⋅cos(θ) ⎥ + ⎢ ⎥ + ⎢ ⎛ 2 2 ⎞⎥ + ⎣ -2⋅ⅈ⋅sin(θ)⋅cos(θ) -ⅈ⋅⎝sin (θ) - cos (θ)⎠⎦ + + """ + return phase_retarder(theta, pi) + + +def quarter_wave_retarder(theta): + """A quarter-wave retarder Jones matrix at angle ``theta``. + + Parameters + ========== + + theta : numeric type or SymPy Symbol + The angle of the fast axis relative to the horizontal plane. + + Returns + ======= + + SymPy Matrix + A Jones matrix representing the retarder. + + Examples + ======== + + A generic quarter-wave plate. + + >>> from sympy import pprint, symbols + >>> from sympy.physics.optics.polarization import quarter_wave_retarder + >>> theta= symbols("theta", real=True) + >>> QWP = quarter_wave_retarder(theta) + >>> pprint(QWP, use_unicode=True) + ⎡ -ⅈ⋅π -ⅈ⋅π ⎤ + ⎢ ───── ───── ⎥ + ⎢⎛ 2 2 ⎞ 4 4 ⎥ + ⎢⎝ⅈ⋅sin (θ) + cos (θ)⎠⋅ℯ (1 - ⅈ)⋅ℯ ⋅sin(θ)⋅cos(θ)⎥ + ⎢ ⎥ + ⎢ -ⅈ⋅π -ⅈ⋅π ⎥ + ⎢ ───── ─────⎥ + ⎢ 4 ⎛ 2 2 ⎞ 4 ⎥ + ⎣(1 - ⅈ)⋅ℯ ⋅sin(θ)⋅cos(θ) ⎝sin (θ) + ⅈ⋅cos (θ)⎠⋅ℯ ⎦ + + """ + return phase_retarder(theta, pi/2) + + +def transmissive_filter(T): + """An attenuator Jones matrix with transmittance ``T``. + + Parameters + ========== + + T : numeric type or SymPy Symbol + The transmittance of the attenuator. + + Returns + ======= + + SymPy Matrix + A Jones matrix representing the filter. + + Examples + ======== + + A generic filter. + + >>> from sympy import pprint, symbols + >>> from sympy.physics.optics.polarization import transmissive_filter + >>> T = symbols("T", real=True) + >>> NDF = transmissive_filter(T) + >>> pprint(NDF, use_unicode=True) + ⎡√T 0 ⎤ + ⎢ ⎥ + ⎣0 √T⎦ + + """ + return Matrix([[sqrt(T), 0], [0, sqrt(T)]]) + + +def reflective_filter(R): + """A reflective filter Jones matrix with reflectance ``R``. + + Parameters + ========== + + R : numeric type or SymPy Symbol + The reflectance of the filter. + + Returns + ======= + + SymPy Matrix + A Jones matrix representing the filter. + + Examples + ======== + + A generic filter. + + >>> from sympy import pprint, symbols + >>> from sympy.physics.optics.polarization import reflective_filter + >>> R = symbols("R", real=True) + >>> pprint(reflective_filter(R), use_unicode=True) + ⎡√R 0 ⎤ + ⎢ ⎥ + ⎣0 -√R⎦ + + """ + return Matrix([[sqrt(R), 0], [0, -sqrt(R)]]) + + +def mueller_matrix(J): + """The Mueller matrix corresponding to Jones matrix `J`. + + Parameters + ========== + + J : SymPy Matrix + A Jones matrix. + + Returns + ======= + + SymPy Matrix + The corresponding Mueller matrix. + + Examples + ======== + + Generic optical components. + + >>> from sympy import pprint, symbols + >>> from sympy.physics.optics.polarization import (mueller_matrix, + ... linear_polarizer, half_wave_retarder, quarter_wave_retarder) + >>> theta = symbols("theta", real=True) + + A linear_polarizer + + >>> pprint(mueller_matrix(linear_polarizer(theta)), use_unicode=True) + ⎡ cos(2⋅θ) sin(2⋅θ) ⎤ + ⎢ 1/2 ──────── ──────── 0⎥ + ⎢ 2 2 ⎥ + ⎢ ⎥ + ⎢cos(2⋅θ) cos(4⋅θ) 1 sin(4⋅θ) ⎥ + ⎢──────── ──────── + ─ ──────── 0⎥ + ⎢ 2 4 4 4 ⎥ + ⎢ ⎥ + ⎢sin(2⋅θ) sin(4⋅θ) 1 cos(4⋅θ) ⎥ + ⎢──────── ──────── ─ - ──────── 0⎥ + ⎢ 2 4 4 4 ⎥ + ⎢ ⎥ + ⎣ 0 0 0 0⎦ + + A half-wave plate + + >>> pprint(mueller_matrix(half_wave_retarder(theta)), use_unicode=True) + ⎡1 0 0 0 ⎤ + ⎢ ⎥ + ⎢ 4 2 ⎥ + ⎢0 8⋅sin (θ) - 8⋅sin (θ) + 1 sin(4⋅θ) 0 ⎥ + ⎢ ⎥ + ⎢ 4 2 ⎥ + ⎢0 sin(4⋅θ) - 8⋅sin (θ) + 8⋅sin (θ) - 1 0 ⎥ + ⎢ ⎥ + ⎣0 0 0 -1⎦ + + A quarter-wave plate + + >>> pprint(mueller_matrix(quarter_wave_retarder(theta)), use_unicode=True) + ⎡1 0 0 0 ⎤ + ⎢ ⎥ + ⎢ cos(4⋅θ) 1 sin(4⋅θ) ⎥ + ⎢0 ──────── + ─ ──────── -sin(2⋅θ)⎥ + ⎢ 2 2 2 ⎥ + ⎢ ⎥ + ⎢ sin(4⋅θ) 1 cos(4⋅θ) ⎥ + ⎢0 ──────── ─ - ──────── cos(2⋅θ) ⎥ + ⎢ 2 2 2 ⎥ + ⎢ ⎥ + ⎣0 sin(2⋅θ) -cos(2⋅θ) 0 ⎦ + + """ + A = Matrix([[1, 0, 0, 1], + [1, 0, 0, -1], + [0, 1, 1, 0], + [0, -I, I, 0]]) + + return simplify(A*TensorProduct(J, J.conjugate())*A.inv()) + + +def polarizing_beam_splitter(Tp=1, Rs=1, Ts=0, Rp=0, phia=0, phib=0): + r"""A polarizing beam splitter Jones matrix at angle `theta`. + + Parameters + ========== + + J : SymPy Matrix + A Jones matrix. + Tp : numeric type or SymPy Symbol + The transmissivity of the P-polarized component. + Rs : numeric type or SymPy Symbol + The reflectivity of the S-polarized component. + Ts : numeric type or SymPy Symbol + The transmissivity of the S-polarized component. + Rp : numeric type or SymPy Symbol + The reflectivity of the P-polarized component. + phia : numeric type or SymPy Symbol + The phase difference between transmitted and reflected component for + output mode a. + phib : numeric type or SymPy Symbol + The phase difference between transmitted and reflected component for + output mode b. + + + Returns + ======= + + SymPy Matrix + A 4x4 matrix representing the PBS. This matrix acts on a 4x1 vector + whose first two entries are the Jones vector on one of the PBS ports, + and the last two entries the Jones vector on the other port. + + Examples + ======== + + Generic polarizing beam-splitter. + + >>> from sympy import pprint, symbols + >>> from sympy.physics.optics.polarization import polarizing_beam_splitter + >>> Ts, Rs, Tp, Rp = symbols(r"Ts, Rs, Tp, Rp", positive=True) + >>> phia, phib = symbols("phi_a, phi_b", real=True) + >>> PBS = polarizing_beam_splitter(Tp, Rs, Ts, Rp, phia, phib) + >>> pprint(PBS, use_unicode=False) + [ ____ ____ ] + [ \/ Tp 0 I*\/ Rp 0 ] + [ ] + [ ____ ____ I*phi_a] + [ 0 \/ Ts 0 -I*\/ Rs *e ] + [ ] + [ ____ ____ ] + [I*\/ Rp 0 \/ Tp 0 ] + [ ] + [ ____ I*phi_b ____ ] + [ 0 -I*\/ Rs *e 0 \/ Ts ] + + """ + PBS = Matrix([[sqrt(Tp), 0, I*sqrt(Rp), 0], + [0, sqrt(Ts), 0, -I*sqrt(Rs)*exp(I*phia)], + [I*sqrt(Rp), 0, sqrt(Tp), 0], + [0, -I*sqrt(Rs)*exp(I*phib), 0, sqrt(Ts)]]) + return PBS diff --git a/lib/python3.10/site-packages/sympy/physics/optics/tests/__init__.py b/lib/python3.10/site-packages/sympy/physics/optics/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/lib/python3.10/site-packages/sympy/physics/optics/tests/__pycache__/test_gaussopt.cpython-310.pyc b/lib/python3.10/site-packages/sympy/physics/optics/tests/__pycache__/test_gaussopt.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..41239d52e4d3d54a8e7285e5b3bdd81d90f458de Binary files /dev/null and b/lib/python3.10/site-packages/sympy/physics/optics/tests/__pycache__/test_gaussopt.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/physics/optics/tests/__pycache__/test_medium.cpython-310.pyc b/lib/python3.10/site-packages/sympy/physics/optics/tests/__pycache__/test_medium.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b851bc13671d9138cf2ca4908c23fa281e792dc0 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/physics/optics/tests/__pycache__/test_medium.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/physics/optics/tests/__pycache__/test_polarization.cpython-310.pyc b/lib/python3.10/site-packages/sympy/physics/optics/tests/__pycache__/test_polarization.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..860a168021ccb503dafb413a4ed9bb33064352e0 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/physics/optics/tests/__pycache__/test_polarization.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/physics/optics/tests/test_gaussopt.py b/lib/python3.10/site-packages/sympy/physics/optics/tests/test_gaussopt.py new file mode 100644 index 0000000000000000000000000000000000000000..5271f3cbb69cf5de861ff332d36418b79daeb1b5 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/physics/optics/tests/test_gaussopt.py @@ -0,0 +1,102 @@ +from sympy.core.evalf import N +from sympy.core.numbers import (Float, I, oo, pi) +from sympy.core.symbol import symbols +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import atan2 +from sympy.matrices.dense import Matrix +from sympy.polys.polytools import factor + +from sympy.physics.optics import (BeamParameter, CurvedMirror, + CurvedRefraction, FlatMirror, FlatRefraction, FreeSpace, GeometricRay, + RayTransferMatrix, ThinLens, conjugate_gauss_beams, + gaussian_conj, geometric_conj_ab, geometric_conj_af, geometric_conj_bf, + rayleigh2waist, waist2rayleigh) + + +def streq(a, b): + return str(a) == str(b) + + +def test_gauss_opt(): + mat = RayTransferMatrix(1, 2, 3, 4) + assert mat == Matrix([[1, 2], [3, 4]]) + assert mat == RayTransferMatrix( Matrix([[1, 2], [3, 4]]) ) + assert [mat.A, mat.B, mat.C, mat.D] == [1, 2, 3, 4] + + d, f, h, n1, n2, R = symbols('d f h n1 n2 R') + lens = ThinLens(f) + assert lens == Matrix([[ 1, 0], [-1/f, 1]]) + assert lens.C == -1/f + assert FreeSpace(d) == Matrix([[ 1, d], [0, 1]]) + assert FlatRefraction(n1, n2) == Matrix([[1, 0], [0, n1/n2]]) + assert CurvedRefraction( + R, n1, n2) == Matrix([[1, 0], [(n1 - n2)/(R*n2), n1/n2]]) + assert FlatMirror() == Matrix([[1, 0], [0, 1]]) + assert CurvedMirror(R) == Matrix([[ 1, 0], [-2/R, 1]]) + assert ThinLens(f) == Matrix([[ 1, 0], [-1/f, 1]]) + + mul = CurvedMirror(R)*FreeSpace(d) + mul_mat = Matrix([[ 1, 0], [-2/R, 1]])*Matrix([[ 1, d], [0, 1]]) + assert mul.A == mul_mat[0, 0] + assert mul.B == mul_mat[0, 1] + assert mul.C == mul_mat[1, 0] + assert mul.D == mul_mat[1, 1] + + angle = symbols('angle') + assert GeometricRay(h, angle) == Matrix([[ h], [angle]]) + assert FreeSpace( + d)*GeometricRay(h, angle) == Matrix([[angle*d + h], [angle]]) + assert GeometricRay( Matrix( ((h,), (angle,)) ) ) == Matrix([[h], [angle]]) + assert (FreeSpace(d)*GeometricRay(h, angle)).height == angle*d + h + assert (FreeSpace(d)*GeometricRay(h, angle)).angle == angle + + p = BeamParameter(530e-9, 1, w=1e-3) + assert streq(p.q, 1 + 1.88679245283019*I*pi) + assert streq(N(p.q), 1.0 + 5.92753330865999*I) + assert streq(N(p.w_0), Float(0.00100000000000000)) + assert streq(N(p.z_r), Float(5.92753330865999)) + fs = FreeSpace(10) + p1 = fs*p + assert streq(N(p.w), Float(0.00101413072159615)) + assert streq(N(p1.w), Float(0.00210803120913829)) + + w, wavelen = symbols('w wavelen') + assert waist2rayleigh(w, wavelen) == pi*w**2/wavelen + z_r, wavelen = symbols('z_r wavelen') + assert rayleigh2waist(z_r, wavelen) == sqrt(wavelen*z_r)/sqrt(pi) + + a, b, f = symbols('a b f') + assert geometric_conj_ab(a, b) == a*b/(a + b) + assert geometric_conj_af(a, f) == a*f/(a - f) + assert geometric_conj_bf(b, f) == b*f/(b - f) + assert geometric_conj_ab(oo, b) == b + assert geometric_conj_ab(a, oo) == a + + s_in, z_r_in, f = symbols('s_in z_r_in f') + assert gaussian_conj( + s_in, z_r_in, f)[0] == 1/(-1/(s_in + z_r_in**2/(-f + s_in)) + 1/f) + assert gaussian_conj( + s_in, z_r_in, f)[1] == z_r_in/(1 - s_in**2/f**2 + z_r_in**2/f**2) + assert gaussian_conj( + s_in, z_r_in, f)[2] == 1/sqrt(1 - s_in**2/f**2 + z_r_in**2/f**2) + + l, w_i, w_o, f = symbols('l w_i w_o f') + assert conjugate_gauss_beams(l, w_i, w_o, f=f)[0] == f*( + -sqrt(w_i**2/w_o**2 - pi**2*w_i**4/(f**2*l**2)) + 1) + assert factor(conjugate_gauss_beams(l, w_i, w_o, f=f)[1]) == f*w_o**2*( + w_i**2/w_o**2 - sqrt(w_i**2/w_o**2 - pi**2*w_i**4/(f**2*l**2)))/w_i**2 + assert conjugate_gauss_beams(l, w_i, w_o, f=f)[2] == f + + z, l, w_0 = symbols('z l w_0', positive=True) + p = BeamParameter(l, z, w=w_0) + assert p.radius == z*(pi**2*w_0**4/(l**2*z**2) + 1) + assert p.w == w_0*sqrt(l**2*z**2/(pi**2*w_0**4) + 1) + assert p.w_0 == w_0 + assert p.divergence == l/(pi*w_0) + assert p.gouy == atan2(z, pi*w_0**2/l) + assert p.waist_approximation_limit == 2*l/pi + + p = BeamParameter(530e-9, 1, w=1e-3, n=2) + assert streq(p.q, 1 + 3.77358490566038*I*pi) + assert streq(N(p.z_r), Float(11.8550666173200)) + assert streq(N(p.w_0), Float(0.00100000000000000)) diff --git a/lib/python3.10/site-packages/sympy/physics/optics/tests/test_medium.py b/lib/python3.10/site-packages/sympy/physics/optics/tests/test_medium.py new file mode 100644 index 0000000000000000000000000000000000000000..dfbb485f5b8e401f38c7f1cfa573f960a2479d7b --- /dev/null +++ b/lib/python3.10/site-packages/sympy/physics/optics/tests/test_medium.py @@ -0,0 +1,48 @@ +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.physics.optics import Medium +from sympy.abc import epsilon, mu, n +from sympy.physics.units import speed_of_light, u0, e0, m, kg, s, A + +from sympy.testing.pytest import raises + +c = speed_of_light.convert_to(m/s) +e0 = e0.convert_to(A**2*s**4/(kg*m**3)) +u0 = u0.convert_to(m*kg/(A**2*s**2)) + + +def test_medium(): + m1 = Medium('m1') + assert m1.intrinsic_impedance == sqrt(u0/e0) + assert m1.speed == 1/sqrt(e0*u0) + assert m1.refractive_index == c*sqrt(e0*u0) + assert m1.permittivity == e0 + assert m1.permeability == u0 + m2 = Medium('m2', epsilon, mu) + assert m2.intrinsic_impedance == sqrt(mu/epsilon) + assert m2.speed == 1/sqrt(epsilon*mu) + assert m2.refractive_index == c*sqrt(epsilon*mu) + assert m2.permittivity == epsilon + assert m2.permeability == mu + # Increasing electric permittivity and magnetic permeability + # by small amount from its value in vacuum. + m3 = Medium('m3', 9.0*10**(-12)*s**4*A**2/(m**3*kg), 1.45*10**(-6)*kg*m/(A**2*s**2)) + assert m3.refractive_index > m1.refractive_index + assert m3 != m1 + # Decreasing electric permittivity and magnetic permeability + # by small amount from its value in vacuum. + m4 = Medium('m4', 7.0*10**(-12)*s**4*A**2/(m**3*kg), 1.15*10**(-6)*kg*m/(A**2*s**2)) + assert m4.refractive_index < m1.refractive_index + m5 = Medium('m5', permittivity=710*10**(-12)*s**4*A**2/(m**3*kg), n=1.33) + assert abs(m5.intrinsic_impedance - 6.24845417765552*kg*m**2/(A**2*s**3)) \ + < 1e-12*kg*m**2/(A**2*s**3) + assert abs(m5.speed - 225407863.157895*m/s) < 1e-6*m/s + assert abs(m5.refractive_index - 1.33000000000000) < 1e-12 + assert abs(m5.permittivity - 7.1e-10*A**2*s**4/(kg*m**3)) \ + < 1e-20*A**2*s**4/(kg*m**3) + assert abs(m5.permeability - 2.77206575232851e-8*kg*m/(A**2*s**2)) \ + < 1e-20*kg*m/(A**2*s**2) + m6 = Medium('m6', None, mu, n) + assert m6.permittivity == n**2/(c**2*mu) + # test for equality of refractive indices + assert Medium('m7').refractive_index == Medium('m8', e0, u0).refractive_index + raises(ValueError, lambda:Medium('m9', e0, u0, 2)) diff --git a/lib/python3.10/site-packages/sympy/physics/optics/tests/test_polarization.py b/lib/python3.10/site-packages/sympy/physics/optics/tests/test_polarization.py new file mode 100644 index 0000000000000000000000000000000000000000..99c595d82a4a296066d5075f6182895a8de54d91 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/physics/optics/tests/test_polarization.py @@ -0,0 +1,57 @@ +from sympy.physics.optics.polarization import (jones_vector, stokes_vector, + jones_2_stokes, linear_polarizer, phase_retarder, half_wave_retarder, + quarter_wave_retarder, transmissive_filter, reflective_filter, + mueller_matrix, polarizing_beam_splitter) +from sympy.core.numbers import (I, pi) +from sympy.core.singleton import S +from sympy.core.symbol import symbols +from sympy.functions.elementary.exponential import exp +from sympy.matrices.dense import Matrix + + +def test_polarization(): + assert jones_vector(0, 0) == Matrix([1, 0]) + assert jones_vector(pi/2, 0) == Matrix([0, 1]) + ################################################################# + assert stokes_vector(0, 0) == Matrix([1, 1, 0, 0]) + assert stokes_vector(pi/2, 0) == Matrix([1, -1, 0, 0]) + ################################################################# + H = jones_vector(0, 0) + V = jones_vector(pi/2, 0) + D = jones_vector(pi/4, 0) + A = jones_vector(-pi/4, 0) + R = jones_vector(0, pi/4) + L = jones_vector(0, -pi/4) + + res = [Matrix([1, 1, 0, 0]), + Matrix([1, -1, 0, 0]), + Matrix([1, 0, 1, 0]), + Matrix([1, 0, -1, 0]), + Matrix([1, 0, 0, 1]), + Matrix([1, 0, 0, -1])] + + assert [jones_2_stokes(e) for e in [H, V, D, A, R, L]] == res + ################################################################# + assert linear_polarizer(0) == Matrix([[1, 0], [0, 0]]) + ################################################################# + delta = symbols("delta", real=True) + res = Matrix([[exp(-I*delta/2), 0], [0, exp(I*delta/2)]]) + assert phase_retarder(0, delta) == res + ################################################################# + assert half_wave_retarder(0) == Matrix([[-I, 0], [0, I]]) + ################################################################# + res = Matrix([[exp(-I*pi/4), 0], [0, I*exp(-I*pi/4)]]) + assert quarter_wave_retarder(0) == res + ################################################################# + assert transmissive_filter(1) == Matrix([[1, 0], [0, 1]]) + ################################################################# + assert reflective_filter(1) == Matrix([[1, 0], [0, -1]]) + + res = Matrix([[S(1)/2, S(1)/2, 0, 0], + [S(1)/2, S(1)/2, 0, 0], + [0, 0, 0, 0], + [0, 0, 0, 0]]) + assert mueller_matrix(linear_polarizer(0)) == res + ################################################################# + res = Matrix([[1, 0, 0, 0], [0, 0, 0, -I], [0, 0, 1, 0], [0, -I, 0, 0]]) + assert polarizing_beam_splitter() == res diff --git a/lib/python3.10/site-packages/sympy/physics/optics/tests/test_utils.py b/lib/python3.10/site-packages/sympy/physics/optics/tests/test_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..6c93883a081d3614a604aeadc8a4b617181de669 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/physics/optics/tests/test_utils.py @@ -0,0 +1,202 @@ +from sympy.core.numbers import comp, Rational +from sympy.physics.optics.utils import (refraction_angle, fresnel_coefficients, + deviation, brewster_angle, critical_angle, lens_makers_formula, + mirror_formula, lens_formula, hyperfocal_distance, + transverse_magnification) +from sympy.physics.optics.medium import Medium +from sympy.physics.units import e0 + +from sympy.core.numbers import oo +from sympy.core.symbol import symbols +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.matrices.dense import Matrix +from sympy.geometry.point import Point3D +from sympy.geometry.line import Ray3D +from sympy.geometry.plane import Plane + +from sympy.testing.pytest import raises + + +ae = lambda a, b, n: comp(a, b, 10**-n) + + +def test_refraction_angle(): + n1, n2 = symbols('n1, n2') + m1 = Medium('m1') + m2 = Medium('m2') + r1 = Ray3D(Point3D(-1, -1, 1), Point3D(0, 0, 0)) + i = Matrix([1, 1, 1]) + n = Matrix([0, 0, 1]) + normal_ray = Ray3D(Point3D(0, 0, 0), Point3D(0, 0, 1)) + P = Plane(Point3D(0, 0, 0), normal_vector=[0, 0, 1]) + assert refraction_angle(r1, 1, 1, n) == Matrix([ + [ 1], + [ 1], + [-1]]) + assert refraction_angle([1, 1, 1], 1, 1, n) == Matrix([ + [ 1], + [ 1], + [-1]]) + assert refraction_angle((1, 1, 1), 1, 1, n) == Matrix([ + [ 1], + [ 1], + [-1]]) + assert refraction_angle(i, 1, 1, [0, 0, 1]) == Matrix([ + [ 1], + [ 1], + [-1]]) + assert refraction_angle(i, 1, 1, (0, 0, 1)) == Matrix([ + [ 1], + [ 1], + [-1]]) + assert refraction_angle(i, 1, 1, normal_ray) == Matrix([ + [ 1], + [ 1], + [-1]]) + assert refraction_angle(i, 1, 1, plane=P) == Matrix([ + [ 1], + [ 1], + [-1]]) + assert refraction_angle(r1, 1, 1, plane=P) == \ + Ray3D(Point3D(0, 0, 0), Point3D(1, 1, -1)) + assert refraction_angle(r1, m1, 1.33, plane=P) == \ + Ray3D(Point3D(0, 0, 0), Point3D(Rational(100, 133), Rational(100, 133), -789378201649271*sqrt(3)/1000000000000000)) + assert refraction_angle(r1, 1, m2, plane=P) == \ + Ray3D(Point3D(0, 0, 0), Point3D(1, 1, -1)) + assert refraction_angle(r1, n1, n2, plane=P) == \ + Ray3D(Point3D(0, 0, 0), Point3D(n1/n2, n1/n2, -sqrt(3)*sqrt(-2*n1**2/(3*n2**2) + 1))) + assert refraction_angle(r1, 1.33, 1, plane=P) == 0 # TIR + assert refraction_angle(r1, 1, 1, normal_ray) == \ + Ray3D(Point3D(0, 0, 0), direction_ratio=[1, 1, -1]) + assert ae(refraction_angle(0.5, 1, 2), 0.24207, 5) + assert ae(refraction_angle(0.5, 2, 1), 1.28293, 5) + raises(ValueError, lambda: refraction_angle(r1, m1, m2, normal_ray, P)) + raises(TypeError, lambda: refraction_angle(m1, m1, m2)) # can add other values for arg[0] + raises(TypeError, lambda: refraction_angle(r1, m1, m2, None, i)) + raises(TypeError, lambda: refraction_angle(r1, m1, m2, m2)) + + +def test_fresnel_coefficients(): + assert all(ae(i, j, 5) for i, j in zip( + fresnel_coefficients(0.5, 1, 1.33), + [0.11163, -0.17138, 0.83581, 0.82862])) + assert all(ae(i, j, 5) for i, j in zip( + fresnel_coefficients(0.5, 1.33, 1), + [-0.07726, 0.20482, 1.22724, 1.20482])) + m1 = Medium('m1') + m2 = Medium('m2', n=2) + assert all(ae(i, j, 5) for i, j in zip( + fresnel_coefficients(0.3, m1, m2), + [0.31784, -0.34865, 0.65892, 0.65135])) + ans = [[-0.23563, -0.97184], [0.81648, -0.57738]] + got = fresnel_coefficients(0.6, m2, m1) + for i, j in zip(got, ans): + for a, b in zip(i.as_real_imag(), j): + assert ae(a, b, 5) + + +def test_deviation(): + n1, n2 = symbols('n1, n2') + r1 = Ray3D(Point3D(-1, -1, 1), Point3D(0, 0, 0)) + n = Matrix([0, 0, 1]) + i = Matrix([-1, -1, -1]) + normal_ray = Ray3D(Point3D(0, 0, 0), Point3D(0, 0, 1)) + P = Plane(Point3D(0, 0, 0), normal_vector=[0, 0, 1]) + assert deviation(r1, 1, 1, normal=n) == 0 + assert deviation(r1, 1, 1, plane=P) == 0 + assert deviation(r1, 1, 1.1, plane=P).evalf(3) + 0.119 < 1e-3 + assert deviation(i, 1, 1.1, normal=normal_ray).evalf(3) + 0.119 < 1e-3 + assert deviation(r1, 1.33, 1, plane=P) is None # TIR + assert deviation(r1, 1, 1, normal=[0, 0, 1]) == 0 + assert deviation([-1, -1, -1], 1, 1, normal=[0, 0, 1]) == 0 + assert ae(deviation(0.5, 1, 2), -0.25793, 5) + assert ae(deviation(0.5, 2, 1), 0.78293, 5) + + +def test_brewster_angle(): + m1 = Medium('m1', n=1) + m2 = Medium('m2', n=1.33) + assert ae(brewster_angle(m1, m2), 0.93, 2) + m1 = Medium('m1', permittivity=e0, n=1) + m2 = Medium('m2', permittivity=e0, n=1.33) + assert ae(brewster_angle(m1, m2), 0.93, 2) + assert ae(brewster_angle(1, 1.33), 0.93, 2) + + +def test_critical_angle(): + m1 = Medium('m1', n=1) + m2 = Medium('m2', n=1.33) + assert ae(critical_angle(m2, m1), 0.85, 2) + + +def test_lens_makers_formula(): + n1, n2 = symbols('n1, n2') + m1 = Medium('m1', permittivity=e0, n=1) + m2 = Medium('m2', permittivity=e0, n=1.33) + assert lens_makers_formula(n1, n2, 10, -10) == 5.0*n2/(n1 - n2) + assert ae(lens_makers_formula(m1, m2, 10, -10), -20.15, 2) + assert ae(lens_makers_formula(1.33, 1, 10, -10), 15.15, 2) + + +def test_mirror_formula(): + u, v, f = symbols('u, v, f') + assert mirror_formula(focal_length=f, u=u) == f*u/(-f + u) + assert mirror_formula(focal_length=f, v=v) == f*v/(-f + v) + assert mirror_formula(u=u, v=v) == u*v/(u + v) + assert mirror_formula(u=oo, v=v) == v + assert mirror_formula(u=oo, v=oo) is oo + assert mirror_formula(focal_length=oo, u=u) == -u + assert mirror_formula(u=u, v=oo) == u + assert mirror_formula(focal_length=oo, v=oo) is oo + assert mirror_formula(focal_length=f, v=oo) == f + assert mirror_formula(focal_length=oo, v=v) == -v + assert mirror_formula(focal_length=oo, u=oo) is oo + assert mirror_formula(focal_length=f, u=oo) == f + assert mirror_formula(focal_length=oo, u=u) == -u + raises(ValueError, lambda: mirror_formula(focal_length=f, u=u, v=v)) + + +def test_lens_formula(): + u, v, f = symbols('u, v, f') + assert lens_formula(focal_length=f, u=u) == f*u/(f + u) + assert lens_formula(focal_length=f, v=v) == f*v/(f - v) + assert lens_formula(u=u, v=v) == u*v/(u - v) + assert lens_formula(u=oo, v=v) == v + assert lens_formula(u=oo, v=oo) is oo + assert lens_formula(focal_length=oo, u=u) == u + assert lens_formula(u=u, v=oo) == -u + assert lens_formula(focal_length=oo, v=oo) is -oo + assert lens_formula(focal_length=oo, v=v) == v + assert lens_formula(focal_length=f, v=oo) == -f + assert lens_formula(focal_length=oo, u=oo) is oo + assert lens_formula(focal_length=oo, u=u) == u + assert lens_formula(focal_length=f, u=oo) == f + raises(ValueError, lambda: lens_formula(focal_length=f, u=u, v=v)) + + +def test_hyperfocal_distance(): + f, N, c = symbols('f, N, c') + assert hyperfocal_distance(f=f, N=N, c=c) == f**2/(N*c) + assert ae(hyperfocal_distance(f=0.5, N=8, c=0.0033), 9.47, 2) + + +def test_transverse_magnification(): + si, so = symbols('si, so') + assert transverse_magnification(si, so) == -si/so + assert transverse_magnification(30, 15) == -2 + + +def test_lens_makers_formula_thick_lens(): + n1, n2 = symbols('n1, n2') + m1 = Medium('m1', permittivity=e0, n=1) + m2 = Medium('m2', permittivity=e0, n=1.33) + assert ae(lens_makers_formula(m1, m2, 10, -10, d=1), -19.82, 2) + assert lens_makers_formula(n1, n2, 1, -1, d=0.1) == n2/((2.0 - (0.1*n1 - 0.1*n2)/n1)*(n1 - n2)) + + +def test_lens_makers_formula_plano_lens(): + n1, n2 = symbols('n1, n2') + m1 = Medium('m1', permittivity=e0, n=1) + m2 = Medium('m2', permittivity=e0, n=1.33) + assert ae(lens_makers_formula(m1, m2, 10, oo), -40.30, 2) + assert lens_makers_formula(n1, n2, 10, oo) == 10.0*n2/(n1 - n2) diff --git a/lib/python3.10/site-packages/sympy/physics/optics/tests/test_waves.py b/lib/python3.10/site-packages/sympy/physics/optics/tests/test_waves.py new file mode 100644 index 0000000000000000000000000000000000000000..3cb8f804fb5be86d6174cb7c7b15fd8979c85ff8 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/physics/optics/tests/test_waves.py @@ -0,0 +1,82 @@ +from sympy.core.function import (Derivative, Function) +from sympy.core.numbers import (I, pi) +from sympy.core.symbol import (Symbol, symbols) +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import (atan2, cos, sin) +from sympy.simplify.simplify import simplify +from sympy.abc import epsilon, mu +from sympy.functions.elementary.exponential import exp +from sympy.physics.units import speed_of_light, m, s +from sympy.physics.optics import TWave + +from sympy.testing.pytest import raises + +c = speed_of_light.convert_to(m/s) + +def test_twave(): + A1, phi1, A2, phi2, f = symbols('A1, phi1, A2, phi2, f') + n = Symbol('n') # Refractive index + t = Symbol('t') # Time + x = Symbol('x') # Spatial variable + E = Function('E') + w1 = TWave(A1, f, phi1) + w2 = TWave(A2, f, phi2) + assert w1.amplitude == A1 + assert w1.frequency == f + assert w1.phase == phi1 + assert w1.wavelength == c/(f*n) + assert w1.time_period == 1/f + assert w1.angular_velocity == 2*pi*f + assert w1.wavenumber == 2*pi*f*n/c + assert w1.speed == c/n + + w3 = w1 + w2 + assert w3.amplitude == sqrt(A1**2 + 2*A1*A2*cos(phi1 - phi2) + A2**2) + assert w3.frequency == f + assert w3.phase == atan2(A1*sin(phi1) + A2*sin(phi2), A1*cos(phi1) + A2*cos(phi2)) + assert w3.wavelength == c/(f*n) + assert w3.time_period == 1/f + assert w3.angular_velocity == 2*pi*f + assert w3.wavenumber == 2*pi*f*n/c + assert w3.speed == c/n + assert simplify(w3.rewrite(sin) - w2.rewrite(sin) - w1.rewrite(sin)) == 0 + assert w3.rewrite('pde') == epsilon*mu*Derivative(E(x, t), t, t) + Derivative(E(x, t), x, x) + assert w3.rewrite(cos) == sqrt(A1**2 + 2*A1*A2*cos(phi1 - phi2) + + A2**2)*cos(pi*f*n*x*s/(149896229*m) - 2*pi*f*t + atan2(A1*sin(phi1) + + A2*sin(phi2), A1*cos(phi1) + A2*cos(phi2))) + assert w3.rewrite(exp) == sqrt(A1**2 + 2*A1*A2*cos(phi1 - phi2) + + A2**2)*exp(I*(-2*pi*f*t + atan2(A1*sin(phi1) + A2*sin(phi2), A1*cos(phi1) + + A2*cos(phi2)) + pi*s*f*n*x/(149896229*m))) + + w4 = TWave(A1, None, 0, 1/f) + assert w4.frequency == f + + w5 = w1 - w2 + assert w5.amplitude == sqrt(A1**2 - 2*A1*A2*cos(phi1 - phi2) + A2**2) + assert w5.frequency == f + assert w5.phase == atan2(A1*sin(phi1) - A2*sin(phi2), A1*cos(phi1) - A2*cos(phi2)) + assert w5.wavelength == c/(f*n) + assert w5.time_period == 1/f + assert w5.angular_velocity == 2*pi*f + assert w5.wavenumber == 2*pi*f*n/c + assert w5.speed == c/n + assert simplify(w5.rewrite(sin) - w1.rewrite(sin) + w2.rewrite(sin)) == 0 + assert w5.rewrite('pde') == epsilon*mu*Derivative(E(x, t), t, t) + Derivative(E(x, t), x, x) + assert w5.rewrite(cos) == sqrt(A1**2 - 2*A1*A2*cos(phi1 - phi2) + + A2**2)*cos(-2*pi*f*t + atan2(A1*sin(phi1) - A2*sin(phi2), A1*cos(phi1) + - A2*cos(phi2)) + pi*s*f*n*x/(149896229*m)) + assert w5.rewrite(exp) == sqrt(A1**2 - 2*A1*A2*cos(phi1 - phi2) + + A2**2)*exp(I*(-2*pi*f*t + atan2(A1*sin(phi1) - A2*sin(phi2), A1*cos(phi1) + - A2*cos(phi2)) + pi*s*f*n*x/(149896229*m))) + + w6 = 2*w1 + assert w6.amplitude == 2*A1 + assert w6.frequency == f + assert w6.phase == phi1 + w7 = -w6 + assert w7.amplitude == -2*A1 + assert w7.frequency == f + assert w7.phase == phi1 + + raises(ValueError, lambda:TWave(A1)) + raises(ValueError, lambda:TWave(A1, f, phi1, t)) diff --git a/lib/python3.10/site-packages/sympy/physics/optics/utils.py b/lib/python3.10/site-packages/sympy/physics/optics/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..72c3b78bd4b09eb069757fb3f8d3632f09ec4b80 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/physics/optics/utils.py @@ -0,0 +1,698 @@ +""" +**Contains** + +* refraction_angle +* fresnel_coefficients +* deviation +* brewster_angle +* critical_angle +* lens_makers_formula +* mirror_formula +* lens_formula +* hyperfocal_distance +* transverse_magnification +""" + +__all__ = ['refraction_angle', + 'deviation', + 'fresnel_coefficients', + 'brewster_angle', + 'critical_angle', + 'lens_makers_formula', + 'mirror_formula', + 'lens_formula', + 'hyperfocal_distance', + 'transverse_magnification' + ] + +from sympy.core.numbers import (Float, I, oo, pi, zoo) +from sympy.core.singleton import S +from sympy.core.symbol import Symbol +from sympy.core.sympify import sympify +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import (acos, asin, atan2, cos, sin, tan) +from sympy.matrices.dense import Matrix +from sympy.polys.polytools import cancel +from sympy.series.limits import Limit +from sympy.geometry.line import Ray3D +from sympy.geometry.util import intersection +from sympy.geometry.plane import Plane +from sympy.utilities.iterables import is_sequence +from .medium import Medium + + +def refractive_index_of_medium(medium): + """ + Helper function that returns refractive index, given a medium + """ + if isinstance(medium, Medium): + n = medium.refractive_index + else: + n = sympify(medium) + return n + + +def refraction_angle(incident, medium1, medium2, normal=None, plane=None): + """ + This function calculates transmitted vector after refraction at planar + surface. ``medium1`` and ``medium2`` can be ``Medium`` or any sympifiable object. + If ``incident`` is a number then treated as angle of incidence (in radians) + in which case refraction angle is returned. + + If ``incident`` is an object of `Ray3D`, `normal` also has to be an instance + of `Ray3D` in order to get the output as a `Ray3D`. Please note that if + plane of separation is not provided and normal is an instance of `Ray3D`, + ``normal`` will be assumed to be intersecting incident ray at the plane of + separation. This will not be the case when `normal` is a `Matrix` or + any other sequence. + If ``incident`` is an instance of `Ray3D` and `plane` has not been provided + and ``normal`` is not `Ray3D`, output will be a `Matrix`. + + Parameters + ========== + + incident : Matrix, Ray3D, sequence or a number + Incident vector or angle of incidence + medium1 : sympy.physics.optics.medium.Medium or sympifiable + Medium 1 or its refractive index + medium2 : sympy.physics.optics.medium.Medium or sympifiable + Medium 2 or its refractive index + normal : Matrix, Ray3D, or sequence + Normal vector + plane : Plane + Plane of separation of the two media. + + Returns + ======= + + Returns an angle of refraction or a refracted ray depending on inputs. + + Examples + ======== + + >>> from sympy.physics.optics import refraction_angle + >>> from sympy.geometry import Point3D, Ray3D, Plane + >>> from sympy.matrices import Matrix + >>> from sympy import symbols, pi + >>> n = Matrix([0, 0, 1]) + >>> P = Plane(Point3D(0, 0, 0), normal_vector=[0, 0, 1]) + >>> r1 = Ray3D(Point3D(-1, -1, 1), Point3D(0, 0, 0)) + >>> refraction_angle(r1, 1, 1, n) + Matrix([ + [ 1], + [ 1], + [-1]]) + >>> refraction_angle(r1, 1, 1, plane=P) + Ray3D(Point3D(0, 0, 0), Point3D(1, 1, -1)) + + With different index of refraction of the two media + + >>> n1, n2 = symbols('n1, n2') + >>> refraction_angle(r1, n1, n2, n) + Matrix([ + [ n1/n2], + [ n1/n2], + [-sqrt(3)*sqrt(-2*n1**2/(3*n2**2) + 1)]]) + >>> refraction_angle(r1, n1, n2, plane=P) + Ray3D(Point3D(0, 0, 0), Point3D(n1/n2, n1/n2, -sqrt(3)*sqrt(-2*n1**2/(3*n2**2) + 1))) + >>> round(refraction_angle(pi/6, 1.2, 1.5), 5) + 0.41152 + """ + + n1 = refractive_index_of_medium(medium1) + n2 = refractive_index_of_medium(medium2) + + # check if an incidence angle was supplied instead of a ray + try: + angle_of_incidence = float(incident) + except TypeError: + angle_of_incidence = None + + try: + critical_angle_ = critical_angle(medium1, medium2) + except (ValueError, TypeError): + critical_angle_ = None + + if angle_of_incidence is not None: + if normal is not None or plane is not None: + raise ValueError('Normal/plane not allowed if incident is an angle') + + if not 0.0 <= angle_of_incidence < pi*0.5: + raise ValueError('Angle of incidence not in range [0:pi/2)') + + if critical_angle_ and angle_of_incidence > critical_angle_: + raise ValueError('Ray undergoes total internal reflection') + return asin(n1*sin(angle_of_incidence)/n2) + + # Treat the incident as ray below + # A flag to check whether to return Ray3D or not + return_ray = False + + if plane is not None and normal is not None: + raise ValueError("Either plane or normal is acceptable.") + + if not isinstance(incident, Matrix): + if is_sequence(incident): + _incident = Matrix(incident) + elif isinstance(incident, Ray3D): + _incident = Matrix(incident.direction_ratio) + else: + raise TypeError( + "incident should be a Matrix, Ray3D, or sequence") + else: + _incident = incident + + # If plane is provided, get direction ratios of the normal + # to the plane from the plane else go with `normal` param. + if plane is not None: + if not isinstance(plane, Plane): + raise TypeError("plane should be an instance of geometry.plane.Plane") + # If we have the plane, we can get the intersection + # point of incident ray and the plane and thus return + # an instance of Ray3D. + if isinstance(incident, Ray3D): + return_ray = True + intersection_pt = plane.intersection(incident)[0] + _normal = Matrix(plane.normal_vector) + else: + if not isinstance(normal, Matrix): + if is_sequence(normal): + _normal = Matrix(normal) + elif isinstance(normal, Ray3D): + _normal = Matrix(normal.direction_ratio) + if isinstance(incident, Ray3D): + intersection_pt = intersection(incident, normal) + if len(intersection_pt) == 0: + raise ValueError( + "Normal isn't concurrent with the incident ray.") + else: + return_ray = True + intersection_pt = intersection_pt[0] + else: + raise TypeError( + "Normal should be a Matrix, Ray3D, or sequence") + else: + _normal = normal + + eta = n1/n2 # Relative index of refraction + # Calculating magnitude of the vectors + mag_incident = sqrt(sum(i**2 for i in _incident)) + mag_normal = sqrt(sum(i**2 for i in _normal)) + # Converting vectors to unit vectors by dividing + # them with their magnitudes + _incident /= mag_incident + _normal /= mag_normal + c1 = -_incident.dot(_normal) # cos(angle_of_incidence) + cs2 = 1 - eta**2*(1 - c1**2) # cos(angle_of_refraction)**2 + if cs2.is_negative: # This is the case of total internal reflection(TIR). + return S.Zero + drs = eta*_incident + (eta*c1 - sqrt(cs2))*_normal + # Multiplying unit vector by its magnitude + drs = drs*mag_incident + if not return_ray: + return drs + else: + return Ray3D(intersection_pt, direction_ratio=drs) + + +def fresnel_coefficients(angle_of_incidence, medium1, medium2): + """ + This function uses Fresnel equations to calculate reflection and + transmission coefficients. Those are obtained for both polarisations + when the electric field vector is in the plane of incidence (labelled 'p') + and when the electric field vector is perpendicular to the plane of + incidence (labelled 's'). There are four real coefficients unless the + incident ray reflects in total internal in which case there are two complex + ones. Angle of incidence is the angle between the incident ray and the + surface normal. ``medium1`` and ``medium2`` can be ``Medium`` or any + sympifiable object. + + Parameters + ========== + + angle_of_incidence : sympifiable + + medium1 : Medium or sympifiable + Medium 1 or its refractive index + + medium2 : Medium or sympifiable + Medium 2 or its refractive index + + Returns + ======= + + Returns a list with four real Fresnel coefficients: + [reflection p (TM), reflection s (TE), + transmission p (TM), transmission s (TE)] + If the ray is undergoes total internal reflection then returns a + list of two complex Fresnel coefficients: + [reflection p (TM), reflection s (TE)] + + Examples + ======== + + >>> from sympy.physics.optics import fresnel_coefficients + >>> fresnel_coefficients(0.3, 1, 2) + [0.317843553417859, -0.348645229818821, + 0.658921776708929, 0.651354770181179] + >>> fresnel_coefficients(0.6, 2, 1) + [-0.235625382192159 - 0.971843958291041*I, + 0.816477005968898 - 0.577377951366403*I] + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Fresnel_equations + """ + if not 0 <= 2*angle_of_incidence < pi: + raise ValueError('Angle of incidence not in range [0:pi/2)') + + n1 = refractive_index_of_medium(medium1) + n2 = refractive_index_of_medium(medium2) + + angle_of_refraction = asin(n1*sin(angle_of_incidence)/n2) + try: + angle_of_total_internal_reflection_onset = critical_angle(n1, n2) + except ValueError: + angle_of_total_internal_reflection_onset = None + + if angle_of_total_internal_reflection_onset is None or\ + angle_of_total_internal_reflection_onset > angle_of_incidence: + R_s = -sin(angle_of_incidence - angle_of_refraction)\ + /sin(angle_of_incidence + angle_of_refraction) + R_p = tan(angle_of_incidence - angle_of_refraction)\ + /tan(angle_of_incidence + angle_of_refraction) + T_s = 2*sin(angle_of_refraction)*cos(angle_of_incidence)\ + /sin(angle_of_incidence + angle_of_refraction) + T_p = 2*sin(angle_of_refraction)*cos(angle_of_incidence)\ + /(sin(angle_of_incidence + angle_of_refraction)\ + *cos(angle_of_incidence - angle_of_refraction)) + return [R_p, R_s, T_p, T_s] + else: + n = n2/n1 + R_s = cancel((cos(angle_of_incidence)-\ + I*sqrt(sin(angle_of_incidence)**2 - n**2))\ + /(cos(angle_of_incidence)+\ + I*sqrt(sin(angle_of_incidence)**2 - n**2))) + R_p = cancel((n**2*cos(angle_of_incidence)-\ + I*sqrt(sin(angle_of_incidence)**2 - n**2))\ + /(n**2*cos(angle_of_incidence)+\ + I*sqrt(sin(angle_of_incidence)**2 - n**2))) + return [R_p, R_s] + + +def deviation(incident, medium1, medium2, normal=None, plane=None): + """ + This function calculates the angle of deviation of a ray + due to refraction at planar surface. + + Parameters + ========== + + incident : Matrix, Ray3D, sequence or float + Incident vector or angle of incidence + medium1 : sympy.physics.optics.medium.Medium or sympifiable + Medium 1 or its refractive index + medium2 : sympy.physics.optics.medium.Medium or sympifiable + Medium 2 or its refractive index + normal : Matrix, Ray3D, or sequence + Normal vector + plane : Plane + Plane of separation of the two media. + + Returns angular deviation between incident and refracted rays + + Examples + ======== + + >>> from sympy.physics.optics import deviation + >>> from sympy.geometry import Point3D, Ray3D, Plane + >>> from sympy.matrices import Matrix + >>> from sympy import symbols + >>> n1, n2 = symbols('n1, n2') + >>> n = Matrix([0, 0, 1]) + >>> P = Plane(Point3D(0, 0, 0), normal_vector=[0, 0, 1]) + >>> r1 = Ray3D(Point3D(-1, -1, 1), Point3D(0, 0, 0)) + >>> deviation(r1, 1, 1, n) + 0 + >>> deviation(r1, n1, n2, plane=P) + -acos(-sqrt(-2*n1**2/(3*n2**2) + 1)) + acos(-sqrt(3)/3) + >>> round(deviation(0.1, 1.2, 1.5), 5) + -0.02005 + """ + refracted = refraction_angle(incident, + medium1, + medium2, + normal=normal, + plane=plane) + try: + angle_of_incidence = Float(incident) + except TypeError: + angle_of_incidence = None + + if angle_of_incidence is not None: + return float(refracted) - angle_of_incidence + + if refracted != 0: + if isinstance(refracted, Ray3D): + refracted = Matrix(refracted.direction_ratio) + + if not isinstance(incident, Matrix): + if is_sequence(incident): + _incident = Matrix(incident) + elif isinstance(incident, Ray3D): + _incident = Matrix(incident.direction_ratio) + else: + raise TypeError( + "incident should be a Matrix, Ray3D, or sequence") + else: + _incident = incident + + if plane is None: + if not isinstance(normal, Matrix): + if is_sequence(normal): + _normal = Matrix(normal) + elif isinstance(normal, Ray3D): + _normal = Matrix(normal.direction_ratio) + else: + raise TypeError( + "normal should be a Matrix, Ray3D, or sequence") + else: + _normal = normal + else: + _normal = Matrix(plane.normal_vector) + + mag_incident = sqrt(sum(i**2 for i in _incident)) + mag_normal = sqrt(sum(i**2 for i in _normal)) + mag_refracted = sqrt(sum(i**2 for i in refracted)) + _incident /= mag_incident + _normal /= mag_normal + refracted /= mag_refracted + i = acos(_incident.dot(_normal)) + r = acos(refracted.dot(_normal)) + return i - r + + +def brewster_angle(medium1, medium2): + """ + This function calculates the Brewster's angle of incidence to Medium 2 from + Medium 1 in radians. + + Parameters + ========== + + medium 1 : Medium or sympifiable + Refractive index of Medium 1 + medium 2 : Medium or sympifiable + Refractive index of Medium 1 + + Examples + ======== + + >>> from sympy.physics.optics import brewster_angle + >>> brewster_angle(1, 1.33) + 0.926093295503462 + + """ + + n1 = refractive_index_of_medium(medium1) + n2 = refractive_index_of_medium(medium2) + + return atan2(n2, n1) + +def critical_angle(medium1, medium2): + """ + This function calculates the critical angle of incidence (marking the onset + of total internal) to Medium 2 from Medium 1 in radians. + + Parameters + ========== + + medium 1 : Medium or sympifiable + Refractive index of Medium 1. + medium 2 : Medium or sympifiable + Refractive index of Medium 1. + + Examples + ======== + + >>> from sympy.physics.optics import critical_angle + >>> critical_angle(1.33, 1) + 0.850908514477849 + + """ + + n1 = refractive_index_of_medium(medium1) + n2 = refractive_index_of_medium(medium2) + + if n2 > n1: + raise ValueError('Total internal reflection impossible for n1 < n2') + else: + return asin(n2/n1) + + + +def lens_makers_formula(n_lens, n_surr, r1, r2, d=0): + """ + This function calculates focal length of a lens. + It follows cartesian sign convention. + + Parameters + ========== + + n_lens : Medium or sympifiable + Index of refraction of lens. + n_surr : Medium or sympifiable + Index of reflection of surrounding. + r1 : sympifiable + Radius of curvature of first surface. + r2 : sympifiable + Radius of curvature of second surface. + d : sympifiable, optional + Thickness of lens, default value is 0. + + Examples + ======== + + >>> from sympy.physics.optics import lens_makers_formula + >>> from sympy import S + >>> lens_makers_formula(1.33, 1, 10, -10) + 15.1515151515151 + >>> lens_makers_formula(1.2, 1, 10, S.Infinity) + 50.0000000000000 + >>> lens_makers_formula(1.33, 1, 10, -10, d=1) + 15.3418463277618 + + """ + + if isinstance(n_lens, Medium): + n_lens = n_lens.refractive_index + else: + n_lens = sympify(n_lens) + if isinstance(n_surr, Medium): + n_surr = n_surr.refractive_index + else: + n_surr = sympify(n_surr) + d = sympify(d) + + focal_length = 1/((n_lens - n_surr) / n_surr*(1/r1 - 1/r2 + (((n_lens - n_surr) * d) / (n_lens * r1 * r2)))) + + if focal_length == zoo: + return S.Infinity + return focal_length + + +def mirror_formula(focal_length=None, u=None, v=None): + """ + This function provides one of the three parameters + when two of them are supplied. + This is valid only for paraxial rays. + + Parameters + ========== + + focal_length : sympifiable + Focal length of the mirror. + u : sympifiable + Distance of object from the pole on + the principal axis. + v : sympifiable + Distance of the image from the pole + on the principal axis. + + Examples + ======== + + >>> from sympy.physics.optics import mirror_formula + >>> from sympy.abc import f, u, v + >>> mirror_formula(focal_length=f, u=u) + f*u/(-f + u) + >>> mirror_formula(focal_length=f, v=v) + f*v/(-f + v) + >>> mirror_formula(u=u, v=v) + u*v/(u + v) + + """ + if focal_length and u and v: + raise ValueError("Please provide only two parameters") + + focal_length = sympify(focal_length) + u = sympify(u) + v = sympify(v) + if u is oo: + _u = Symbol('u') + if v is oo: + _v = Symbol('v') + if focal_length is oo: + _f = Symbol('f') + if focal_length is None: + if u is oo and v is oo: + return Limit(Limit(_v*_u/(_v + _u), _u, oo), _v, oo).doit() + if u is oo: + return Limit(v*_u/(v + _u), _u, oo).doit() + if v is oo: + return Limit(_v*u/(_v + u), _v, oo).doit() + return v*u/(v + u) + if u is None: + if v is oo and focal_length is oo: + return Limit(Limit(_v*_f/(_v - _f), _v, oo), _f, oo).doit() + if v is oo: + return Limit(_v*focal_length/(_v - focal_length), _v, oo).doit() + if focal_length is oo: + return Limit(v*_f/(v - _f), _f, oo).doit() + return v*focal_length/(v - focal_length) + if v is None: + if u is oo and focal_length is oo: + return Limit(Limit(_u*_f/(_u - _f), _u, oo), _f, oo).doit() + if u is oo: + return Limit(_u*focal_length/(_u - focal_length), _u, oo).doit() + if focal_length is oo: + return Limit(u*_f/(u - _f), _f, oo).doit() + return u*focal_length/(u - focal_length) + + +def lens_formula(focal_length=None, u=None, v=None): + """ + This function provides one of the three parameters + when two of them are supplied. + This is valid only for paraxial rays. + + Parameters + ========== + + focal_length : sympifiable + Focal length of the mirror. + u : sympifiable + Distance of object from the optical center on + the principal axis. + v : sympifiable + Distance of the image from the optical center + on the principal axis. + + Examples + ======== + + >>> from sympy.physics.optics import lens_formula + >>> from sympy.abc import f, u, v + >>> lens_formula(focal_length=f, u=u) + f*u/(f + u) + >>> lens_formula(focal_length=f, v=v) + f*v/(f - v) + >>> lens_formula(u=u, v=v) + u*v/(u - v) + + """ + if focal_length and u and v: + raise ValueError("Please provide only two parameters") + + focal_length = sympify(focal_length) + u = sympify(u) + v = sympify(v) + if u is oo: + _u = Symbol('u') + if v is oo: + _v = Symbol('v') + if focal_length is oo: + _f = Symbol('f') + if focal_length is None: + if u is oo and v is oo: + return Limit(Limit(_v*_u/(_u - _v), _u, oo), _v, oo).doit() + if u is oo: + return Limit(v*_u/(_u - v), _u, oo).doit() + if v is oo: + return Limit(_v*u/(u - _v), _v, oo).doit() + return v*u/(u - v) + if u is None: + if v is oo and focal_length is oo: + return Limit(Limit(_v*_f/(_f - _v), _v, oo), _f, oo).doit() + if v is oo: + return Limit(_v*focal_length/(focal_length - _v), _v, oo).doit() + if focal_length is oo: + return Limit(v*_f/(_f - v), _f, oo).doit() + return v*focal_length/(focal_length - v) + if v is None: + if u is oo and focal_length is oo: + return Limit(Limit(_u*_f/(_u + _f), _u, oo), _f, oo).doit() + if u is oo: + return Limit(_u*focal_length/(_u + focal_length), _u, oo).doit() + if focal_length is oo: + return Limit(u*_f/(u + _f), _f, oo).doit() + return u*focal_length/(u + focal_length) + +def hyperfocal_distance(f, N, c): + """ + + Parameters + ========== + + f: sympifiable + Focal length of a given lens. + + N: sympifiable + F-number of a given lens. + + c: sympifiable + Circle of Confusion (CoC) of a given image format. + + Example + ======= + + >>> from sympy.physics.optics import hyperfocal_distance + >>> round(hyperfocal_distance(f = 0.5, N = 8, c = 0.0033), 2) + 9.47 + """ + + f = sympify(f) + N = sympify(N) + c = sympify(c) + + return (1/(N * c))*(f**2) + +def transverse_magnification(si, so): + """ + + Calculates the transverse magnification upon reflection in a mirror, + which is the ratio of the image size to the object size. + + Parameters + ========== + + so: sympifiable + Lens-object distance. + + si: sympifiable + Lens-image distance. + + Example + ======= + + >>> from sympy.physics.optics import transverse_magnification + >>> transverse_magnification(30, 15) + -2 + + """ + + si = sympify(si) + so = sympify(so) + + return (-(si/so)) diff --git a/lib/python3.10/site-packages/sympy/physics/optics/waves.py b/lib/python3.10/site-packages/sympy/physics/optics/waves.py new file mode 100644 index 0000000000000000000000000000000000000000..61e2ff4db578543f9f2694f239f03439bfab2c41 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/physics/optics/waves.py @@ -0,0 +1,340 @@ +""" +This module has all the classes and functions related to waves in optics. + +**Contains** + +* TWave +""" + +__all__ = ['TWave'] + +from sympy.core.basic import Basic +from sympy.core.expr import Expr +from sympy.core.function import Derivative, Function +from sympy.core.numbers import (Number, pi, I) +from sympy.core.singleton import S +from sympy.core.symbol import (Symbol, symbols) +from sympy.core.sympify import _sympify, sympify +from sympy.functions.elementary.exponential import exp +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import (atan2, cos, sin) +from sympy.physics.units import speed_of_light, meter, second + + +c = speed_of_light.convert_to(meter/second) + + +class TWave(Expr): + + r""" + This is a simple transverse sine wave travelling in a one-dimensional space. + Basic properties are required at the time of creation of the object, + but they can be changed later with respective methods provided. + + Explanation + =========== + + It is represented as :math:`A \times cos(k*x - \omega \times t + \phi )`, + where :math:`A` is the amplitude, :math:`\omega` is the angular frequency, + :math:`k` is the wavenumber (spatial frequency), :math:`x` is a spatial variable + to represent the position on the dimension on which the wave propagates, + and :math:`\phi` is the phase angle of the wave. + + + Arguments + ========= + + amplitude : Sympifyable + Amplitude of the wave. + frequency : Sympifyable + Frequency of the wave. + phase : Sympifyable + Phase angle of the wave. + time_period : Sympifyable + Time period of the wave. + n : Sympifyable + Refractive index of the medium. + + Raises + ======= + + ValueError : When neither frequency nor time period is provided + or they are not consistent. + TypeError : When anything other than TWave objects is added. + + + Examples + ======== + + >>> from sympy import symbols + >>> from sympy.physics.optics import TWave + >>> A1, phi1, A2, phi2, f = symbols('A1, phi1, A2, phi2, f') + >>> w1 = TWave(A1, f, phi1) + >>> w2 = TWave(A2, f, phi2) + >>> w3 = w1 + w2 # Superposition of two waves + >>> w3 + TWave(sqrt(A1**2 + 2*A1*A2*cos(phi1 - phi2) + A2**2), f, + atan2(A1*sin(phi1) + A2*sin(phi2), A1*cos(phi1) + A2*cos(phi2)), 1/f, n) + >>> w3.amplitude + sqrt(A1**2 + 2*A1*A2*cos(phi1 - phi2) + A2**2) + >>> w3.phase + atan2(A1*sin(phi1) + A2*sin(phi2), A1*cos(phi1) + A2*cos(phi2)) + >>> w3.speed + 299792458*meter/(second*n) + >>> w3.angular_velocity + 2*pi*f + + """ + + def __new__( + cls, + amplitude, + frequency=None, + phase=S.Zero, + time_period=None, + n=Symbol('n')): + if time_period is not None: + time_period = _sympify(time_period) + _frequency = S.One/time_period + if frequency is not None: + frequency = _sympify(frequency) + _time_period = S.One/frequency + if time_period is not None: + if frequency != S.One/time_period: + raise ValueError("frequency and time_period should be consistent.") + if frequency is None and time_period is None: + raise ValueError("Either frequency or time period is needed.") + if frequency is None: + frequency = _frequency + if time_period is None: + time_period = _time_period + + amplitude = _sympify(amplitude) + phase = _sympify(phase) + n = sympify(n) + obj = Basic.__new__(cls, amplitude, frequency, phase, time_period, n) + return obj + + @property + def amplitude(self): + """ + Returns the amplitude of the wave. + + Examples + ======== + + >>> from sympy import symbols + >>> from sympy.physics.optics import TWave + >>> A, phi, f = symbols('A, phi, f') + >>> w = TWave(A, f, phi) + >>> w.amplitude + A + """ + return self.args[0] + + @property + def frequency(self): + """ + Returns the frequency of the wave, + in cycles per second. + + Examples + ======== + + >>> from sympy import symbols + >>> from sympy.physics.optics import TWave + >>> A, phi, f = symbols('A, phi, f') + >>> w = TWave(A, f, phi) + >>> w.frequency + f + """ + return self.args[1] + + @property + def phase(self): + """ + Returns the phase angle of the wave, + in radians. + + Examples + ======== + + >>> from sympy import symbols + >>> from sympy.physics.optics import TWave + >>> A, phi, f = symbols('A, phi, f') + >>> w = TWave(A, f, phi) + >>> w.phase + phi + """ + return self.args[2] + + @property + def time_period(self): + """ + Returns the temporal period of the wave, + in seconds per cycle. + + Examples + ======== + + >>> from sympy import symbols + >>> from sympy.physics.optics import TWave + >>> A, phi, f = symbols('A, phi, f') + >>> w = TWave(A, f, phi) + >>> w.time_period + 1/f + """ + return self.args[3] + + @property + def n(self): + """ + Returns the refractive index of the medium + """ + return self.args[4] + + @property + def wavelength(self): + """ + Returns the wavelength (spatial period) of the wave, + in meters per cycle. + It depends on the medium of the wave. + + Examples + ======== + + >>> from sympy import symbols + >>> from sympy.physics.optics import TWave + >>> A, phi, f = symbols('A, phi, f') + >>> w = TWave(A, f, phi) + >>> w.wavelength + 299792458*meter/(second*f*n) + """ + return c/(self.frequency*self.n) + + + @property + def speed(self): + """ + Returns the propagation speed of the wave, + in meters per second. + It is dependent on the propagation medium. + + Examples + ======== + + >>> from sympy import symbols + >>> from sympy.physics.optics import TWave + >>> A, phi, f = symbols('A, phi, f') + >>> w = TWave(A, f, phi) + >>> w.speed + 299792458*meter/(second*n) + """ + return self.wavelength*self.frequency + + @property + def angular_velocity(self): + """ + Returns the angular velocity of the wave, + in radians per second. + + Examples + ======== + + >>> from sympy import symbols + >>> from sympy.physics.optics import TWave + >>> A, phi, f = symbols('A, phi, f') + >>> w = TWave(A, f, phi) + >>> w.angular_velocity + 2*pi*f + """ + return 2*pi*self.frequency + + @property + def wavenumber(self): + """ + Returns the wavenumber of the wave, + in radians per meter. + + Examples + ======== + + >>> from sympy import symbols + >>> from sympy.physics.optics import TWave + >>> A, phi, f = symbols('A, phi, f') + >>> w = TWave(A, f, phi) + >>> w.wavenumber + pi*second*f*n/(149896229*meter) + """ + return 2*pi/self.wavelength + + def __str__(self): + """String representation of a TWave.""" + from sympy.printing import sstr + return type(self).__name__ + sstr(self.args) + + __repr__ = __str__ + + def __add__(self, other): + """ + Addition of two waves will result in their superposition. + The type of interference will depend on their phase angles. + """ + if isinstance(other, TWave): + if self.frequency == other.frequency and self.wavelength == other.wavelength: + return TWave(sqrt(self.amplitude**2 + other.amplitude**2 + 2 * + self.amplitude*other.amplitude*cos( + self.phase - other.phase)), + self.frequency, + atan2(self.amplitude*sin(self.phase) + + other.amplitude*sin(other.phase), + self.amplitude*cos(self.phase) + + other.amplitude*cos(other.phase)) + ) + else: + raise NotImplementedError("Interference of waves with different frequencies" + " has not been implemented.") + else: + raise TypeError(type(other).__name__ + " and TWave objects cannot be added.") + + def __mul__(self, other): + """ + Multiplying a wave by a scalar rescales the amplitude of the wave. + """ + other = sympify(other) + if isinstance(other, Number): + return TWave(self.amplitude*other, *self.args[1:]) + else: + raise TypeError(type(other).__name__ + " and TWave objects cannot be multiplied.") + + def __sub__(self, other): + return self.__add__(-1*other) + + def __neg__(self): + return self.__mul__(-1) + + def __radd__(self, other): + return self.__add__(other) + + def __rmul__(self, other): + return self.__mul__(other) + + def __rsub__(self, other): + return (-self).__radd__(other) + + def _eval_rewrite_as_sin(self, *args, **kwargs): + return self.amplitude*sin(self.wavenumber*Symbol('x') + - self.angular_velocity*Symbol('t') + self.phase + pi/2, evaluate=False) + + def _eval_rewrite_as_cos(self, *args, **kwargs): + return self.amplitude*cos(self.wavenumber*Symbol('x') + - self.angular_velocity*Symbol('t') + self.phase) + + def _eval_rewrite_as_pde(self, *args, **kwargs): + mu, epsilon, x, t = symbols('mu, epsilon, x, t') + E = Function('E') + return Derivative(E(x, t), x, 2) + mu*epsilon*Derivative(E(x, t), t, 2) + + def _eval_rewrite_as_exp(self, *args, **kwargs): + return self.amplitude*exp(I*(self.wavenumber*Symbol('x') + - self.angular_velocity*Symbol('t') + self.phase)) diff --git a/lib/python3.10/site-packages/sympy/physics/quantum/__init__.py b/lib/python3.10/site-packages/sympy/physics/quantum/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..bf08e1f7a383eb09cac9400f772c487cf6176375 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/physics/quantum/__init__.py @@ -0,0 +1,59 @@ +# Names exposed by 'from sympy.physics.quantum import *' + +__all__ = [ + 'AntiCommutator', + + 'qapply', + + 'Commutator', + + 'Dagger', + + 'HilbertSpaceError', 'HilbertSpace', 'TensorProductHilbertSpace', + 'TensorPowerHilbertSpace', 'DirectSumHilbertSpace', 'ComplexSpace', 'L2', + 'FockSpace', + + 'InnerProduct', + + 'Operator', 'HermitianOperator', 'UnitaryOperator', 'IdentityOperator', + 'OuterProduct', 'DifferentialOperator', + + 'represent', 'rep_innerproduct', 'rep_expectation', 'integrate_result', + 'get_basis', 'enumerate_states', + + 'KetBase', 'BraBase', 'StateBase', 'State', 'Ket', 'Bra', 'TimeDepState', + 'TimeDepBra', 'TimeDepKet', 'OrthogonalKet', 'OrthogonalBra', + 'OrthogonalState', 'Wavefunction', + + 'TensorProduct', 'tensor_product_simp', + + 'hbar', 'HBar', + +] +from .anticommutator import AntiCommutator + +from .qapply import qapply + +from .commutator import Commutator + +from .dagger import Dagger + +from .hilbert import (HilbertSpaceError, HilbertSpace, + TensorProductHilbertSpace, TensorPowerHilbertSpace, + DirectSumHilbertSpace, ComplexSpace, L2, FockSpace) + +from .innerproduct import InnerProduct + +from .operator import (Operator, HermitianOperator, UnitaryOperator, + IdentityOperator, OuterProduct, DifferentialOperator) + +from .represent import (represent, rep_innerproduct, rep_expectation, + integrate_result, get_basis, enumerate_states) + +from .state import (KetBase, BraBase, StateBase, State, Ket, Bra, + TimeDepState, TimeDepBra, TimeDepKet, OrthogonalKet, + OrthogonalBra, OrthogonalState, Wavefunction) + +from .tensorproduct import TensorProduct, tensor_product_simp + +from .constants import hbar, HBar diff --git a/lib/python3.10/site-packages/sympy/physics/quantum/__pycache__/__init__.cpython-310.pyc b/lib/python3.10/site-packages/sympy/physics/quantum/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..18b2b9dc9b56117843b21f0e9bc25e648636de53 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/physics/quantum/__pycache__/__init__.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/physics/quantum/__pycache__/anticommutator.cpython-310.pyc b/lib/python3.10/site-packages/sympy/physics/quantum/__pycache__/anticommutator.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c5778551cd84ab57974f3306be8716065fc38258 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/physics/quantum/__pycache__/anticommutator.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/physics/quantum/__pycache__/boson.cpython-310.pyc b/lib/python3.10/site-packages/sympy/physics/quantum/__pycache__/boson.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..13416c8baa67d4fb13f678e6838288408d2cd52a Binary files /dev/null and b/lib/python3.10/site-packages/sympy/physics/quantum/__pycache__/boson.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/physics/quantum/__pycache__/cartesian.cpython-310.pyc b/lib/python3.10/site-packages/sympy/physics/quantum/__pycache__/cartesian.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8a6f17e601b7b410d1996c857515eb69ff2286c4 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/physics/quantum/__pycache__/cartesian.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/physics/quantum/__pycache__/cg.cpython-310.pyc b/lib/python3.10/site-packages/sympy/physics/quantum/__pycache__/cg.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7a59acb1e906541d0e3c2a6a3949b3ebcb56e757 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/physics/quantum/__pycache__/cg.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/physics/quantum/__pycache__/circuitplot.cpython-310.pyc b/lib/python3.10/site-packages/sympy/physics/quantum/__pycache__/circuitplot.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b5a89dff46ef6a16f11e6a50c226fe6a96776370 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/physics/quantum/__pycache__/circuitplot.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/physics/quantum/__pycache__/circuitutils.cpython-310.pyc b/lib/python3.10/site-packages/sympy/physics/quantum/__pycache__/circuitutils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f196e4eb97721c50040a4c4b9bc96206fb6814ee Binary files /dev/null and b/lib/python3.10/site-packages/sympy/physics/quantum/__pycache__/circuitutils.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/physics/quantum/__pycache__/commutator.cpython-310.pyc b/lib/python3.10/site-packages/sympy/physics/quantum/__pycache__/commutator.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..088c653ebf95c55a827515cb4674b1cdd0e5a177 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/physics/quantum/__pycache__/commutator.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/physics/quantum/__pycache__/constants.cpython-310.pyc b/lib/python3.10/site-packages/sympy/physics/quantum/__pycache__/constants.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4a8427218ddcd845222a1fe2cd4205d8c2d1c702 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/physics/quantum/__pycache__/constants.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/physics/quantum/__pycache__/dagger.cpython-310.pyc b/lib/python3.10/site-packages/sympy/physics/quantum/__pycache__/dagger.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d962d9ada4d417b2b0d01ee2eac1aa968be06c77 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/physics/quantum/__pycache__/dagger.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/physics/quantum/__pycache__/density.cpython-310.pyc b/lib/python3.10/site-packages/sympy/physics/quantum/__pycache__/density.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9c7c6710c18b1e2479f02f8532d9912e12605fbe Binary files /dev/null and b/lib/python3.10/site-packages/sympy/physics/quantum/__pycache__/density.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/physics/quantum/__pycache__/fermion.cpython-310.pyc b/lib/python3.10/site-packages/sympy/physics/quantum/__pycache__/fermion.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..656296f651c3d8ed6efbec4fec03a0c9f03bab05 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/physics/quantum/__pycache__/fermion.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/physics/quantum/__pycache__/gate.cpython-310.pyc b/lib/python3.10/site-packages/sympy/physics/quantum/__pycache__/gate.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..40eda50a600f4bb937cd6320aec8adb95e4455d0 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/physics/quantum/__pycache__/gate.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/physics/quantum/__pycache__/grover.cpython-310.pyc b/lib/python3.10/site-packages/sympy/physics/quantum/__pycache__/grover.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..752ff8326e0689e18843bcb4e8a331f3794a2e93 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/physics/quantum/__pycache__/grover.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/physics/quantum/__pycache__/hilbert.cpython-310.pyc b/lib/python3.10/site-packages/sympy/physics/quantum/__pycache__/hilbert.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..78588e935959f55b7e7b21804dd0aab416185a11 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/physics/quantum/__pycache__/hilbert.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/physics/quantum/__pycache__/identitysearch.cpython-310.pyc b/lib/python3.10/site-packages/sympy/physics/quantum/__pycache__/identitysearch.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9042a56d22c70fc47bdadaf5fa4558bcc9d256b3 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/physics/quantum/__pycache__/identitysearch.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/physics/quantum/__pycache__/innerproduct.cpython-310.pyc b/lib/python3.10/site-packages/sympy/physics/quantum/__pycache__/innerproduct.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fd1cbef5ce46ce70844a9a0ed8061f1e7f784fd5 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/physics/quantum/__pycache__/innerproduct.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/physics/quantum/__pycache__/matrixcache.cpython-310.pyc b/lib/python3.10/site-packages/sympy/physics/quantum/__pycache__/matrixcache.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..39a6e2f8ffe330ced957a687530a394ebc8b9583 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/physics/quantum/__pycache__/matrixcache.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/physics/quantum/__pycache__/matrixutils.cpython-310.pyc b/lib/python3.10/site-packages/sympy/physics/quantum/__pycache__/matrixutils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..80bd3ad15575d3b4188a6b45eb4fc0037aa31421 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/physics/quantum/__pycache__/matrixutils.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/physics/quantum/__pycache__/operator.cpython-310.pyc b/lib/python3.10/site-packages/sympy/physics/quantum/__pycache__/operator.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b28f8cd6ff74cf25d881e33869feaf65d7514ffb Binary files /dev/null and b/lib/python3.10/site-packages/sympy/physics/quantum/__pycache__/operator.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/physics/quantum/__pycache__/operatorordering.cpython-310.pyc b/lib/python3.10/site-packages/sympy/physics/quantum/__pycache__/operatorordering.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cb10ceb3543538e9325ec60cd33f7d1c0ea2b733 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/physics/quantum/__pycache__/operatorordering.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/physics/quantum/__pycache__/operatorset.cpython-310.pyc b/lib/python3.10/site-packages/sympy/physics/quantum/__pycache__/operatorset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0146bc683ea7f96295c82b89baa78cea3aee2bde Binary files /dev/null and b/lib/python3.10/site-packages/sympy/physics/quantum/__pycache__/operatorset.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/physics/quantum/__pycache__/pauli.cpython-310.pyc b/lib/python3.10/site-packages/sympy/physics/quantum/__pycache__/pauli.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..67bae082e382e542c18060e747dbe82d3ea76c31 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/physics/quantum/__pycache__/pauli.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/physics/quantum/__pycache__/piab.cpython-310.pyc b/lib/python3.10/site-packages/sympy/physics/quantum/__pycache__/piab.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..595b98ff0535d2ed1ff30456cf3c251c9bfe928a Binary files /dev/null and b/lib/python3.10/site-packages/sympy/physics/quantum/__pycache__/piab.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/physics/quantum/__pycache__/qapply.cpython-310.pyc b/lib/python3.10/site-packages/sympy/physics/quantum/__pycache__/qapply.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c59a31ff18de059d8088b7379c94bc0db5f57d3e Binary files /dev/null and b/lib/python3.10/site-packages/sympy/physics/quantum/__pycache__/qapply.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/physics/quantum/__pycache__/qasm.cpython-310.pyc b/lib/python3.10/site-packages/sympy/physics/quantum/__pycache__/qasm.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b6995527f010d83ad8771911c9ab122e1e221ce0 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/physics/quantum/__pycache__/qasm.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/physics/quantum/__pycache__/qexpr.cpython-310.pyc b/lib/python3.10/site-packages/sympy/physics/quantum/__pycache__/qexpr.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6b4e78b737d10b1cb619464d7e578e1b44d4d961 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/physics/quantum/__pycache__/qexpr.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/physics/quantum/__pycache__/qft.cpython-310.pyc b/lib/python3.10/site-packages/sympy/physics/quantum/__pycache__/qft.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5fc7871efd7aff9380895cba88ae3cbda80f7189 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/physics/quantum/__pycache__/qft.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/physics/quantum/__pycache__/qubit.cpython-310.pyc b/lib/python3.10/site-packages/sympy/physics/quantum/__pycache__/qubit.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..07a658273d5d50aeeea5fcb5da50de4454eb0b4e Binary files /dev/null and b/lib/python3.10/site-packages/sympy/physics/quantum/__pycache__/qubit.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/physics/quantum/__pycache__/represent.cpython-310.pyc b/lib/python3.10/site-packages/sympy/physics/quantum/__pycache__/represent.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..039ca2db7ff2f0f1db3b165736c43e5250d05ad3 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/physics/quantum/__pycache__/represent.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/physics/quantum/__pycache__/sho1d.cpython-310.pyc b/lib/python3.10/site-packages/sympy/physics/quantum/__pycache__/sho1d.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..eae586de3cd5ea21484444672a416f14a5d7c8db Binary files /dev/null and b/lib/python3.10/site-packages/sympy/physics/quantum/__pycache__/sho1d.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/physics/quantum/__pycache__/shor.cpython-310.pyc b/lib/python3.10/site-packages/sympy/physics/quantum/__pycache__/shor.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6879860360cab8c6cc759ab8f0603dcdc41a2a65 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/physics/quantum/__pycache__/shor.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/physics/quantum/__pycache__/spin.cpython-310.pyc b/lib/python3.10/site-packages/sympy/physics/quantum/__pycache__/spin.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7122dccb8ee184e356c5ad04bd56216a953cddc6 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/physics/quantum/__pycache__/spin.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/physics/quantum/__pycache__/state.cpython-310.pyc b/lib/python3.10/site-packages/sympy/physics/quantum/__pycache__/state.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..39c6e37457ed9ba1a961ff829262bfee2d1623a3 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/physics/quantum/__pycache__/state.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/physics/quantum/__pycache__/tensorproduct.cpython-310.pyc b/lib/python3.10/site-packages/sympy/physics/quantum/__pycache__/tensorproduct.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7586df30d45d127bf7da8e264679d6e382d0886d Binary files /dev/null and b/lib/python3.10/site-packages/sympy/physics/quantum/__pycache__/tensorproduct.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/physics/quantum/__pycache__/trace.cpython-310.pyc b/lib/python3.10/site-packages/sympy/physics/quantum/__pycache__/trace.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..850dffacbadf32145a9d8aa1e005df13666b5457 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/physics/quantum/__pycache__/trace.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/physics/quantum/anticommutator.py b/lib/python3.10/site-packages/sympy/physics/quantum/anticommutator.py new file mode 100644 index 0000000000000000000000000000000000000000..a73f1c20779322d47356b619231fa418e88ab101 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/physics/quantum/anticommutator.py @@ -0,0 +1,149 @@ +"""The anti-commutator: ``{A,B} = A*B + B*A``.""" + +from sympy.core.expr import Expr +from sympy.core.mul import Mul +from sympy.core.numbers import Integer +from sympy.core.singleton import S +from sympy.printing.pretty.stringpict import prettyForm + +from sympy.physics.quantum.operator import Operator +from sympy.physics.quantum.dagger import Dagger + +__all__ = [ + 'AntiCommutator' +] + +#----------------------------------------------------------------------------- +# Anti-commutator +#----------------------------------------------------------------------------- + + +class AntiCommutator(Expr): + """The standard anticommutator, in an unevaluated state. + + Explanation + =========== + + Evaluating an anticommutator is defined [1]_ as: ``{A, B} = A*B + B*A``. + This class returns the anticommutator in an unevaluated form. To evaluate + the anticommutator, use the ``.doit()`` method. + + Canonical ordering of an anticommutator is ``{A, B}`` for ``A < B``. The + arguments of the anticommutator are put into canonical order using + ``__cmp__``. If ``B < A``, then ``{A, B}`` is returned as ``{B, A}``. + + Parameters + ========== + + A : Expr + The first argument of the anticommutator {A,B}. + B : Expr + The second argument of the anticommutator {A,B}. + + Examples + ======== + + >>> from sympy import symbols + >>> from sympy.physics.quantum import AntiCommutator + >>> from sympy.physics.quantum import Operator, Dagger + >>> x, y = symbols('x,y') + >>> A = Operator('A') + >>> B = Operator('B') + + Create an anticommutator and use ``doit()`` to multiply them out. + + >>> ac = AntiCommutator(A,B); ac + {A,B} + >>> ac.doit() + A*B + B*A + + The commutator orders it arguments in canonical order: + + >>> ac = AntiCommutator(B,A); ac + {A,B} + + Commutative constants are factored out: + + >>> AntiCommutator(3*x*A,x*y*B) + 3*x**2*y*{A,B} + + Adjoint operations applied to the anticommutator are properly applied to + the arguments: + + >>> Dagger(AntiCommutator(A,B)) + {Dagger(A),Dagger(B)} + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Commutator + """ + is_commutative = False + + def __new__(cls, A, B): + r = cls.eval(A, B) + if r is not None: + return r + obj = Expr.__new__(cls, A, B) + return obj + + @classmethod + def eval(cls, a, b): + if not (a and b): + return S.Zero + if a == b: + return Integer(2)*a**2 + if a.is_commutative or b.is_commutative: + return Integer(2)*a*b + + # [xA,yB] -> xy*[A,B] + ca, nca = a.args_cnc() + cb, ncb = b.args_cnc() + c_part = ca + cb + if c_part: + return Mul(Mul(*c_part), cls(Mul._from_args(nca), Mul._from_args(ncb))) + + # Canonical ordering of arguments + #The Commutator [A,B] is on canonical form if A < B. + if a.compare(b) == 1: + return cls(b, a) + + def doit(self, **hints): + """ Evaluate anticommutator """ + A = self.args[0] + B = self.args[1] + if isinstance(A, Operator) and isinstance(B, Operator): + try: + comm = A._eval_anticommutator(B, **hints) + except NotImplementedError: + try: + comm = B._eval_anticommutator(A, **hints) + except NotImplementedError: + comm = None + if comm is not None: + return comm.doit(**hints) + return (A*B + B*A).doit(**hints) + + def _eval_adjoint(self): + return AntiCommutator(Dagger(self.args[0]), Dagger(self.args[1])) + + def _sympyrepr(self, printer, *args): + return "%s(%s,%s)" % ( + self.__class__.__name__, printer._print( + self.args[0]), printer._print(self.args[1]) + ) + + def _sympystr(self, printer, *args): + return "{%s,%s}" % ( + printer._print(self.args[0]), printer._print(self.args[1])) + + def _pretty(self, printer, *args): + pform = printer._print(self.args[0], *args) + pform = prettyForm(*pform.right(prettyForm(','))) + pform = prettyForm(*pform.right(printer._print(self.args[1], *args))) + pform = prettyForm(*pform.parens(left='{', right='}')) + return pform + + def _latex(self, printer, *args): + return "\\left\\{%s,%s\\right\\}" % tuple([ + printer._print(arg, *args) for arg in self.args]) diff --git a/lib/python3.10/site-packages/sympy/physics/quantum/boson.py b/lib/python3.10/site-packages/sympy/physics/quantum/boson.py new file mode 100644 index 0000000000000000000000000000000000000000..3be2ebc45c392e8733de7e58528e9a0567273e73 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/physics/quantum/boson.py @@ -0,0 +1,259 @@ +"""Bosonic quantum operators.""" + +from sympy.core.mul import Mul +from sympy.core.numbers import Integer +from sympy.core.singleton import S +from sympy.functions.elementary.complexes import conjugate +from sympy.functions.elementary.exponential import exp +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.physics.quantum import Operator +from sympy.physics.quantum import HilbertSpace, FockSpace, Ket, Bra, IdentityOperator +from sympy.functions.special.tensor_functions import KroneckerDelta + + +__all__ = [ + 'BosonOp', + 'BosonFockKet', + 'BosonFockBra', + 'BosonCoherentKet', + 'BosonCoherentBra' +] + + +class BosonOp(Operator): + """A bosonic operator that satisfies [a, Dagger(a)] == 1. + + Parameters + ========== + + name : str + A string that labels the bosonic mode. + + annihilation : bool + A bool that indicates if the bosonic operator is an annihilation (True, + default value) or creation operator (False) + + Examples + ======== + + >>> from sympy.physics.quantum import Dagger, Commutator + >>> from sympy.physics.quantum.boson import BosonOp + >>> a = BosonOp("a") + >>> Commutator(a, Dagger(a)).doit() + 1 + """ + + @property + def name(self): + return self.args[0] + + @property + def is_annihilation(self): + return bool(self.args[1]) + + @classmethod + def default_args(self): + return ("a", True) + + def __new__(cls, *args, **hints): + if not len(args) in [1, 2]: + raise ValueError('1 or 2 parameters expected, got %s' % args) + + if len(args) == 1: + args = (args[0], S.One) + + if len(args) == 2: + args = (args[0], Integer(args[1])) + + return Operator.__new__(cls, *args) + + def _eval_commutator_BosonOp(self, other, **hints): + if self.name == other.name: + # [a^\dagger, a] = -1 + if not self.is_annihilation and other.is_annihilation: + return S.NegativeOne + + elif 'independent' in hints and hints['independent']: + # [a, b] = 0 + return S.Zero + + return None + + def _eval_commutator_FermionOp(self, other, **hints): + return S.Zero + + def _eval_anticommutator_BosonOp(self, other, **hints): + if 'independent' in hints and hints['independent']: + # {a, b} = 2 * a * b, because [a, b] = 0 + return 2 * self * other + + return None + + def _eval_adjoint(self): + return BosonOp(str(self.name), not self.is_annihilation) + + def __mul__(self, other): + + if other == IdentityOperator(2): + return self + + if isinstance(other, Mul): + args1 = tuple(arg for arg in other.args if arg.is_commutative) + args2 = tuple(arg for arg in other.args if not arg.is_commutative) + x = self + for y in args2: + x = x * y + return Mul(*args1) * x + + return Mul(self, other) + + def _print_contents_latex(self, printer, *args): + if self.is_annihilation: + return r'{%s}' % str(self.name) + else: + return r'{{%s}^\dagger}' % str(self.name) + + def _print_contents(self, printer, *args): + if self.is_annihilation: + return r'%s' % str(self.name) + else: + return r'Dagger(%s)' % str(self.name) + + def _print_contents_pretty(self, printer, *args): + from sympy.printing.pretty.stringpict import prettyForm + pform = printer._print(self.args[0], *args) + if self.is_annihilation: + return pform + else: + return pform**prettyForm('\N{DAGGER}') + + +class BosonFockKet(Ket): + """Fock state ket for a bosonic mode. + + Parameters + ========== + + n : Number + The Fock state number. + + """ + + def __new__(cls, n): + return Ket.__new__(cls, n) + + @property + def n(self): + return self.label[0] + + @classmethod + def dual_class(self): + return BosonFockBra + + @classmethod + def _eval_hilbert_space(cls, label): + return FockSpace() + + def _eval_innerproduct_BosonFockBra(self, bra, **hints): + return KroneckerDelta(self.n, bra.n) + + def _apply_from_right_to_BosonOp(self, op, **options): + if op.is_annihilation: + return sqrt(self.n) * BosonFockKet(self.n - 1) + else: + return sqrt(self.n + 1) * BosonFockKet(self.n + 1) + + +class BosonFockBra(Bra): + """Fock state bra for a bosonic mode. + + Parameters + ========== + + n : Number + The Fock state number. + + """ + + def __new__(cls, n): + return Bra.__new__(cls, n) + + @property + def n(self): + return self.label[0] + + @classmethod + def dual_class(self): + return BosonFockKet + + @classmethod + def _eval_hilbert_space(cls, label): + return FockSpace() + + +class BosonCoherentKet(Ket): + """Coherent state ket for a bosonic mode. + + Parameters + ========== + + alpha : Number, Symbol + The complex amplitude of the coherent state. + + """ + + def __new__(cls, alpha): + return Ket.__new__(cls, alpha) + + @property + def alpha(self): + return self.label[0] + + @classmethod + def dual_class(self): + return BosonCoherentBra + + @classmethod + def _eval_hilbert_space(cls, label): + return HilbertSpace() + + def _eval_innerproduct_BosonCoherentBra(self, bra, **hints): + if self.alpha == bra.alpha: + return S.One + else: + return exp(-(abs(self.alpha)**2 + abs(bra.alpha)**2 - 2 * conjugate(bra.alpha) * self.alpha)/2) + + def _apply_from_right_to_BosonOp(self, op, **options): + if op.is_annihilation: + return self.alpha * self + else: + return None + + +class BosonCoherentBra(Bra): + """Coherent state bra for a bosonic mode. + + Parameters + ========== + + alpha : Number, Symbol + The complex amplitude of the coherent state. + + """ + + def __new__(cls, alpha): + return Bra.__new__(cls, alpha) + + @property + def alpha(self): + return self.label[0] + + @classmethod + def dual_class(self): + return BosonCoherentKet + + def _apply_operator_BosonOp(self, op, **options): + if not op.is_annihilation: + return self.alpha * self + else: + return None diff --git a/lib/python3.10/site-packages/sympy/physics/quantum/cartesian.py b/lib/python3.10/site-packages/sympy/physics/quantum/cartesian.py new file mode 100644 index 0000000000000000000000000000000000000000..f3af1856f22c8fe4535b24be30bf99d0b3541a50 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/physics/quantum/cartesian.py @@ -0,0 +1,341 @@ +"""Operators and states for 1D cartesian position and momentum. + +TODO: + +* Add 3D classes to mappings in operatorset.py + +""" + +from sympy.core.numbers import (I, pi) +from sympy.core.singleton import S +from sympy.functions.elementary.exponential import exp +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.special.delta_functions import DiracDelta +from sympy.sets.sets import Interval + +from sympy.physics.quantum.constants import hbar +from sympy.physics.quantum.hilbert import L2 +from sympy.physics.quantum.operator import DifferentialOperator, HermitianOperator +from sympy.physics.quantum.state import Ket, Bra, State + +__all__ = [ + 'XOp', + 'YOp', + 'ZOp', + 'PxOp', + 'X', + 'Y', + 'Z', + 'Px', + 'XKet', + 'XBra', + 'PxKet', + 'PxBra', + 'PositionState3D', + 'PositionKet3D', + 'PositionBra3D' +] + +#------------------------------------------------------------------------- +# Position operators +#------------------------------------------------------------------------- + + +class XOp(HermitianOperator): + """1D cartesian position operator.""" + + @classmethod + def default_args(self): + return ("X",) + + @classmethod + def _eval_hilbert_space(self, args): + return L2(Interval(S.NegativeInfinity, S.Infinity)) + + def _eval_commutator_PxOp(self, other): + return I*hbar + + def _apply_operator_XKet(self, ket, **options): + return ket.position*ket + + def _apply_operator_PositionKet3D(self, ket, **options): + return ket.position_x*ket + + def _represent_PxKet(self, basis, *, index=1, **options): + states = basis._enumerate_state(2, start_index=index) + coord1 = states[0].momentum + coord2 = states[1].momentum + d = DifferentialOperator(coord1) + delta = DiracDelta(coord1 - coord2) + + return I*hbar*(d*delta) + + +class YOp(HermitianOperator): + """ Y cartesian coordinate operator (for 2D or 3D systems) """ + + @classmethod + def default_args(self): + return ("Y",) + + @classmethod + def _eval_hilbert_space(self, args): + return L2(Interval(S.NegativeInfinity, S.Infinity)) + + def _apply_operator_PositionKet3D(self, ket, **options): + return ket.position_y*ket + + +class ZOp(HermitianOperator): + """ Z cartesian coordinate operator (for 3D systems) """ + + @classmethod + def default_args(self): + return ("Z",) + + @classmethod + def _eval_hilbert_space(self, args): + return L2(Interval(S.NegativeInfinity, S.Infinity)) + + def _apply_operator_PositionKet3D(self, ket, **options): + return ket.position_z*ket + +#------------------------------------------------------------------------- +# Momentum operators +#------------------------------------------------------------------------- + + +class PxOp(HermitianOperator): + """1D cartesian momentum operator.""" + + @classmethod + def default_args(self): + return ("Px",) + + @classmethod + def _eval_hilbert_space(self, args): + return L2(Interval(S.NegativeInfinity, S.Infinity)) + + def _apply_operator_PxKet(self, ket, **options): + return ket.momentum*ket + + def _represent_XKet(self, basis, *, index=1, **options): + states = basis._enumerate_state(2, start_index=index) + coord1 = states[0].position + coord2 = states[1].position + d = DifferentialOperator(coord1) + delta = DiracDelta(coord1 - coord2) + + return -I*hbar*(d*delta) + +X = XOp('X') +Y = YOp('Y') +Z = ZOp('Z') +Px = PxOp('Px') + +#------------------------------------------------------------------------- +# Position eigenstates +#------------------------------------------------------------------------- + + +class XKet(Ket): + """1D cartesian position eigenket.""" + + @classmethod + def _operators_to_state(self, op, **options): + return self.__new__(self, *_lowercase_labels(op), **options) + + def _state_to_operators(self, op_class, **options): + return op_class.__new__(op_class, + *_uppercase_labels(self), **options) + + @classmethod + def default_args(self): + return ("x",) + + @classmethod + def dual_class(self): + return XBra + + @property + def position(self): + """The position of the state.""" + return self.label[0] + + def _enumerate_state(self, num_states, **options): + return _enumerate_continuous_1D(self, num_states, **options) + + def _eval_innerproduct_XBra(self, bra, **hints): + return DiracDelta(self.position - bra.position) + + def _eval_innerproduct_PxBra(self, bra, **hints): + return exp(-I*self.position*bra.momentum/hbar)/sqrt(2*pi*hbar) + + +class XBra(Bra): + """1D cartesian position eigenbra.""" + + @classmethod + def default_args(self): + return ("x",) + + @classmethod + def dual_class(self): + return XKet + + @property + def position(self): + """The position of the state.""" + return self.label[0] + + +class PositionState3D(State): + """ Base class for 3D cartesian position eigenstates """ + + @classmethod + def _operators_to_state(self, op, **options): + return self.__new__(self, *_lowercase_labels(op), **options) + + def _state_to_operators(self, op_class, **options): + return op_class.__new__(op_class, + *_uppercase_labels(self), **options) + + @classmethod + def default_args(self): + return ("x", "y", "z") + + @property + def position_x(self): + """ The x coordinate of the state """ + return self.label[0] + + @property + def position_y(self): + """ The y coordinate of the state """ + return self.label[1] + + @property + def position_z(self): + """ The z coordinate of the state """ + return self.label[2] + + +class PositionKet3D(Ket, PositionState3D): + """ 3D cartesian position eigenket """ + + def _eval_innerproduct_PositionBra3D(self, bra, **options): + x_diff = self.position_x - bra.position_x + y_diff = self.position_y - bra.position_y + z_diff = self.position_z - bra.position_z + + return DiracDelta(x_diff)*DiracDelta(y_diff)*DiracDelta(z_diff) + + @classmethod + def dual_class(self): + return PositionBra3D + + +# XXX: The type:ignore here is because mypy gives Definition of +# "_state_to_operators" in base class "PositionState3D" is incompatible with +# definition in base class "BraBase" +class PositionBra3D(Bra, PositionState3D): # type: ignore + """ 3D cartesian position eigenbra """ + + @classmethod + def dual_class(self): + return PositionKet3D + +#------------------------------------------------------------------------- +# Momentum eigenstates +#------------------------------------------------------------------------- + + +class PxKet(Ket): + """1D cartesian momentum eigenket.""" + + @classmethod + def _operators_to_state(self, op, **options): + return self.__new__(self, *_lowercase_labels(op), **options) + + def _state_to_operators(self, op_class, **options): + return op_class.__new__(op_class, + *_uppercase_labels(self), **options) + + @classmethod + def default_args(self): + return ("px",) + + @classmethod + def dual_class(self): + return PxBra + + @property + def momentum(self): + """The momentum of the state.""" + return self.label[0] + + def _enumerate_state(self, *args, **options): + return _enumerate_continuous_1D(self, *args, **options) + + def _eval_innerproduct_XBra(self, bra, **hints): + return exp(I*self.momentum*bra.position/hbar)/sqrt(2*pi*hbar) + + def _eval_innerproduct_PxBra(self, bra, **hints): + return DiracDelta(self.momentum - bra.momentum) + + +class PxBra(Bra): + """1D cartesian momentum eigenbra.""" + + @classmethod + def default_args(self): + return ("px",) + + @classmethod + def dual_class(self): + return PxKet + + @property + def momentum(self): + """The momentum of the state.""" + return self.label[0] + +#------------------------------------------------------------------------- +# Global helper functions +#------------------------------------------------------------------------- + + +def _enumerate_continuous_1D(*args, **options): + state = args[0] + num_states = args[1] + state_class = state.__class__ + index_list = options.pop('index_list', []) + + if len(index_list) == 0: + start_index = options.pop('start_index', 1) + index_list = list(range(start_index, start_index + num_states)) + + enum_states = [0 for i in range(len(index_list))] + + for i, ind in enumerate(index_list): + label = state.args[0] + enum_states[i] = state_class(str(label) + "_" + str(ind), **options) + + return enum_states + + +def _lowercase_labels(ops): + if not isinstance(ops, set): + ops = [ops] + + return [str(arg.label[0]).lower() for arg in ops] + + +def _uppercase_labels(ops): + if not isinstance(ops, set): + ops = [ops] + + new_args = [str(arg.label[0])[0].upper() + + str(arg.label[0])[1:] for arg in ops] + + return new_args diff --git a/lib/python3.10/site-packages/sympy/physics/quantum/cg.py b/lib/python3.10/site-packages/sympy/physics/quantum/cg.py new file mode 100644 index 0000000000000000000000000000000000000000..0f285cd39413a953246777c42fb6763c22a5716b --- /dev/null +++ b/lib/python3.10/site-packages/sympy/physics/quantum/cg.py @@ -0,0 +1,754 @@ +#TODO: +# -Implement Clebsch-Gordan symmetries +# -Improve simplification method +# -Implement new simplifications +"""Clebsch-Gordon Coefficients.""" + +from sympy.concrete.summations import Sum +from sympy.core.add import Add +from sympy.core.expr import Expr +from sympy.core.function import expand +from sympy.core.mul import Mul +from sympy.core.power import Pow +from sympy.core.relational import Eq +from sympy.core.singleton import S +from sympy.core.symbol import (Wild, symbols) +from sympy.core.sympify import sympify +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.piecewise import Piecewise +from sympy.printing.pretty.stringpict import prettyForm, stringPict + +from sympy.functions.special.tensor_functions import KroneckerDelta +from sympy.physics.wigner import clebsch_gordan, wigner_3j, wigner_6j, wigner_9j +from sympy.printing.precedence import PRECEDENCE + +__all__ = [ + 'CG', + 'Wigner3j', + 'Wigner6j', + 'Wigner9j', + 'cg_simp' +] + +#----------------------------------------------------------------------------- +# CG Coefficients +#----------------------------------------------------------------------------- + + +class Wigner3j(Expr): + """Class for the Wigner-3j symbols. + + Explanation + =========== + + Wigner 3j-symbols are coefficients determined by the coupling of + two angular momenta. When created, they are expressed as symbolic + quantities that, for numerical parameters, can be evaluated using the + ``.doit()`` method [1]_. + + Parameters + ========== + + j1, m1, j2, m2, j3, m3 : Number, Symbol + Terms determining the angular momentum of coupled angular momentum + systems. + + Examples + ======== + + Declare a Wigner-3j coefficient and calculate its value + + >>> from sympy.physics.quantum.cg import Wigner3j + >>> w3j = Wigner3j(6,0,4,0,2,0) + >>> w3j + Wigner3j(6, 0, 4, 0, 2, 0) + >>> w3j.doit() + sqrt(715)/143 + + See Also + ======== + + CG: Clebsch-Gordan coefficients + + References + ========== + + .. [1] Varshalovich, D A, Quantum Theory of Angular Momentum. 1988. + """ + + is_commutative = True + + def __new__(cls, j1, m1, j2, m2, j3, m3): + args = map(sympify, (j1, m1, j2, m2, j3, m3)) + return Expr.__new__(cls, *args) + + @property + def j1(self): + return self.args[0] + + @property + def m1(self): + return self.args[1] + + @property + def j2(self): + return self.args[2] + + @property + def m2(self): + return self.args[3] + + @property + def j3(self): + return self.args[4] + + @property + def m3(self): + return self.args[5] + + @property + def is_symbolic(self): + return not all(arg.is_number for arg in self.args) + + # This is modified from the _print_Matrix method + def _pretty(self, printer, *args): + m = ((printer._print(self.j1), printer._print(self.m1)), + (printer._print(self.j2), printer._print(self.m2)), + (printer._print(self.j3), printer._print(self.m3))) + hsep = 2 + vsep = 1 + maxw = [-1]*3 + for j in range(3): + maxw[j] = max(m[j][i].width() for i in range(2)) + D = None + for i in range(2): + D_row = None + for j in range(3): + s = m[j][i] + wdelta = maxw[j] - s.width() + wleft = wdelta //2 + wright = wdelta - wleft + + s = prettyForm(*s.right(' '*wright)) + s = prettyForm(*s.left(' '*wleft)) + + if D_row is None: + D_row = s + continue + D_row = prettyForm(*D_row.right(' '*hsep)) + D_row = prettyForm(*D_row.right(s)) + if D is None: + D = D_row + continue + for _ in range(vsep): + D = prettyForm(*D.below(' ')) + D = prettyForm(*D.below(D_row)) + D = prettyForm(*D.parens()) + return D + + def _latex(self, printer, *args): + label = map(printer._print, (self.j1, self.j2, self.j3, + self.m1, self.m2, self.m3)) + return r'\left(\begin{array}{ccc} %s & %s & %s \\ %s & %s & %s \end{array}\right)' % \ + tuple(label) + + def doit(self, **hints): + if self.is_symbolic: + raise ValueError("Coefficients must be numerical") + return wigner_3j(self.j1, self.j2, self.j3, self.m1, self.m2, self.m3) + + +class CG(Wigner3j): + r"""Class for Clebsch-Gordan coefficient. + + Explanation + =========== + + Clebsch-Gordan coefficients describe the angular momentum coupling between + two systems. The coefficients give the expansion of a coupled total angular + momentum state and an uncoupled tensor product state. The Clebsch-Gordan + coefficients are defined as [1]_: + + .. math :: + C^{j_3,m_3}_{j_1,m_1,j_2,m_2} = \left\langle j_1,m_1;j_2,m_2 | j_3,m_3\right\rangle + + Parameters + ========== + + j1, m1, j2, m2 : Number, Symbol + Angular momenta of states 1 and 2. + + j3, m3: Number, Symbol + Total angular momentum of the coupled system. + + Examples + ======== + + Define a Clebsch-Gordan coefficient and evaluate its value + + >>> from sympy.physics.quantum.cg import CG + >>> from sympy import S + >>> cg = CG(S(3)/2, S(3)/2, S(1)/2, -S(1)/2, 1, 1) + >>> cg + CG(3/2, 3/2, 1/2, -1/2, 1, 1) + >>> cg.doit() + sqrt(3)/2 + >>> CG(j1=S(1)/2, m1=-S(1)/2, j2=S(1)/2, m2=+S(1)/2, j3=1, m3=0).doit() + sqrt(2)/2 + + + Compare [2]_. + + See Also + ======== + + Wigner3j: Wigner-3j symbols + + References + ========== + + .. [1] Varshalovich, D A, Quantum Theory of Angular Momentum. 1988. + .. [2] `Clebsch-Gordan Coefficients, Spherical Harmonics, and d Functions + `_ + in P.A. Zyla *et al.* (Particle Data Group), Prog. Theor. Exp. Phys. + 2020, 083C01 (2020). + """ + precedence = PRECEDENCE["Pow"] - 1 + + def doit(self, **hints): + if self.is_symbolic: + raise ValueError("Coefficients must be numerical") + return clebsch_gordan(self.j1, self.j2, self.j3, self.m1, self.m2, self.m3) + + def _pretty(self, printer, *args): + bot = printer._print_seq( + (self.j1, self.m1, self.j2, self.m2), delimiter=',') + top = printer._print_seq((self.j3, self.m3), delimiter=',') + + pad = max(top.width(), bot.width()) + bot = prettyForm(*bot.left(' ')) + top = prettyForm(*top.left(' ')) + + if not pad == bot.width(): + bot = prettyForm(*bot.right(' '*(pad - bot.width()))) + if not pad == top.width(): + top = prettyForm(*top.right(' '*(pad - top.width()))) + s = stringPict('C' + ' '*pad) + s = prettyForm(*s.below(bot)) + s = prettyForm(*s.above(top)) + return s + + def _latex(self, printer, *args): + label = map(printer._print, (self.j3, self.m3, self.j1, + self.m1, self.j2, self.m2)) + return r'C^{%s,%s}_{%s,%s,%s,%s}' % tuple(label) + + +class Wigner6j(Expr): + """Class for the Wigner-6j symbols + + See Also + ======== + + Wigner3j: Wigner-3j symbols + + """ + def __new__(cls, j1, j2, j12, j3, j, j23): + args = map(sympify, (j1, j2, j12, j3, j, j23)) + return Expr.__new__(cls, *args) + + @property + def j1(self): + return self.args[0] + + @property + def j2(self): + return self.args[1] + + @property + def j12(self): + return self.args[2] + + @property + def j3(self): + return self.args[3] + + @property + def j(self): + return self.args[4] + + @property + def j23(self): + return self.args[5] + + @property + def is_symbolic(self): + return not all(arg.is_number for arg in self.args) + + # This is modified from the _print_Matrix method + def _pretty(self, printer, *args): + m = ((printer._print(self.j1), printer._print(self.j3)), + (printer._print(self.j2), printer._print(self.j)), + (printer._print(self.j12), printer._print(self.j23))) + hsep = 2 + vsep = 1 + maxw = [-1]*3 + for j in range(3): + maxw[j] = max(m[j][i].width() for i in range(2)) + D = None + for i in range(2): + D_row = None + for j in range(3): + s = m[j][i] + wdelta = maxw[j] - s.width() + wleft = wdelta //2 + wright = wdelta - wleft + + s = prettyForm(*s.right(' '*wright)) + s = prettyForm(*s.left(' '*wleft)) + + if D_row is None: + D_row = s + continue + D_row = prettyForm(*D_row.right(' '*hsep)) + D_row = prettyForm(*D_row.right(s)) + if D is None: + D = D_row + continue + for _ in range(vsep): + D = prettyForm(*D.below(' ')) + D = prettyForm(*D.below(D_row)) + D = prettyForm(*D.parens(left='{', right='}')) + return D + + def _latex(self, printer, *args): + label = map(printer._print, (self.j1, self.j2, self.j12, + self.j3, self.j, self.j23)) + return r'\left\{\begin{array}{ccc} %s & %s & %s \\ %s & %s & %s \end{array}\right\}' % \ + tuple(label) + + def doit(self, **hints): + if self.is_symbolic: + raise ValueError("Coefficients must be numerical") + return wigner_6j(self.j1, self.j2, self.j12, self.j3, self.j, self.j23) + + +class Wigner9j(Expr): + """Class for the Wigner-9j symbols + + See Also + ======== + + Wigner3j: Wigner-3j symbols + + """ + def __new__(cls, j1, j2, j12, j3, j4, j34, j13, j24, j): + args = map(sympify, (j1, j2, j12, j3, j4, j34, j13, j24, j)) + return Expr.__new__(cls, *args) + + @property + def j1(self): + return self.args[0] + + @property + def j2(self): + return self.args[1] + + @property + def j12(self): + return self.args[2] + + @property + def j3(self): + return self.args[3] + + @property + def j4(self): + return self.args[4] + + @property + def j34(self): + return self.args[5] + + @property + def j13(self): + return self.args[6] + + @property + def j24(self): + return self.args[7] + + @property + def j(self): + return self.args[8] + + @property + def is_symbolic(self): + return not all(arg.is_number for arg in self.args) + + # This is modified from the _print_Matrix method + def _pretty(self, printer, *args): + m = ( + (printer._print( + self.j1), printer._print(self.j3), printer._print(self.j13)), + (printer._print( + self.j2), printer._print(self.j4), printer._print(self.j24)), + (printer._print(self.j12), printer._print(self.j34), printer._print(self.j))) + hsep = 2 + vsep = 1 + maxw = [-1]*3 + for j in range(3): + maxw[j] = max(m[j][i].width() for i in range(3)) + D = None + for i in range(3): + D_row = None + for j in range(3): + s = m[j][i] + wdelta = maxw[j] - s.width() + wleft = wdelta //2 + wright = wdelta - wleft + + s = prettyForm(*s.right(' '*wright)) + s = prettyForm(*s.left(' '*wleft)) + + if D_row is None: + D_row = s + continue + D_row = prettyForm(*D_row.right(' '*hsep)) + D_row = prettyForm(*D_row.right(s)) + if D is None: + D = D_row + continue + for _ in range(vsep): + D = prettyForm(*D.below(' ')) + D = prettyForm(*D.below(D_row)) + D = prettyForm(*D.parens(left='{', right='}')) + return D + + def _latex(self, printer, *args): + label = map(printer._print, (self.j1, self.j2, self.j12, self.j3, + self.j4, self.j34, self.j13, self.j24, self.j)) + return r'\left\{\begin{array}{ccc} %s & %s & %s \\ %s & %s & %s \\ %s & %s & %s \end{array}\right\}' % \ + tuple(label) + + def doit(self, **hints): + if self.is_symbolic: + raise ValueError("Coefficients must be numerical") + return wigner_9j(self.j1, self.j2, self.j12, self.j3, self.j4, self.j34, self.j13, self.j24, self.j) + + +def cg_simp(e): + """Simplify and combine CG coefficients. + + Explanation + =========== + + This function uses various symmetry and properties of sums and + products of Clebsch-Gordan coefficients to simplify statements + involving these terms [1]_. + + Examples + ======== + + Simplify the sum over CG(a,alpha,0,0,a,alpha) for all alpha to + 2*a+1 + + >>> from sympy.physics.quantum.cg import CG, cg_simp + >>> a = CG(1,1,0,0,1,1) + >>> b = CG(1,0,0,0,1,0) + >>> c = CG(1,-1,0,0,1,-1) + >>> cg_simp(a+b+c) + 3 + + See Also + ======== + + CG: Clebsh-Gordan coefficients + + References + ========== + + .. [1] Varshalovich, D A, Quantum Theory of Angular Momentum. 1988. + """ + if isinstance(e, Add): + return _cg_simp_add(e) + elif isinstance(e, Sum): + return _cg_simp_sum(e) + elif isinstance(e, Mul): + return Mul(*[cg_simp(arg) for arg in e.args]) + elif isinstance(e, Pow): + return Pow(cg_simp(e.base), e.exp) + else: + return e + + +def _cg_simp_add(e): + #TODO: Improve simplification method + """Takes a sum of terms involving Clebsch-Gordan coefficients and + simplifies the terms. + + Explanation + =========== + + First, we create two lists, cg_part, which is all the terms involving CG + coefficients, and other_part, which is all other terms. The cg_part list + is then passed to the simplification methods, which return the new cg_part + and any additional terms that are added to other_part + """ + cg_part = [] + other_part = [] + + e = expand(e) + for arg in e.args: + if arg.has(CG): + if isinstance(arg, Sum): + other_part.append(_cg_simp_sum(arg)) + elif isinstance(arg, Mul): + terms = 1 + for term in arg.args: + if isinstance(term, Sum): + terms *= _cg_simp_sum(term) + else: + terms *= term + if terms.has(CG): + cg_part.append(terms) + else: + other_part.append(terms) + else: + cg_part.append(arg) + else: + other_part.append(arg) + + cg_part, other = _check_varsh_871_1(cg_part) + other_part.append(other) + cg_part, other = _check_varsh_871_2(cg_part) + other_part.append(other) + cg_part, other = _check_varsh_872_9(cg_part) + other_part.append(other) + return Add(*cg_part) + Add(*other_part) + + +def _check_varsh_871_1(term_list): + # Sum( CG(a,alpha,b,0,a,alpha), (alpha, -a, a)) == KroneckerDelta(b,0) + a, alpha, b, lt = map(Wild, ('a', 'alpha', 'b', 'lt')) + expr = lt*CG(a, alpha, b, 0, a, alpha) + simp = (2*a + 1)*KroneckerDelta(b, 0) + sign = lt/abs(lt) + build_expr = 2*a + 1 + index_expr = a + alpha + return _check_cg_simp(expr, simp, sign, lt, term_list, (a, alpha, b, lt), (a, b), build_expr, index_expr) + + +def _check_varsh_871_2(term_list): + # Sum((-1)**(a-alpha)*CG(a,alpha,a,-alpha,c,0),(alpha,-a,a)) + a, alpha, c, lt = map(Wild, ('a', 'alpha', 'c', 'lt')) + expr = lt*CG(a, alpha, a, -alpha, c, 0) + simp = sqrt(2*a + 1)*KroneckerDelta(c, 0) + sign = (-1)**(a - alpha)*lt/abs(lt) + build_expr = 2*a + 1 + index_expr = a + alpha + return _check_cg_simp(expr, simp, sign, lt, term_list, (a, alpha, c, lt), (a, c), build_expr, index_expr) + + +def _check_varsh_872_9(term_list): + # Sum( CG(a,alpha,b,beta,c,gamma)*CG(a,alpha',b,beta',c,gamma), (gamma, -c, c), (c, abs(a-b), a+b)) + a, alpha, alphap, b, beta, betap, c, gamma, lt = map(Wild, ( + 'a', 'alpha', 'alphap', 'b', 'beta', 'betap', 'c', 'gamma', 'lt')) + # Case alpha==alphap, beta==betap + + # For numerical alpha,beta + expr = lt*CG(a, alpha, b, beta, c, gamma)**2 + simp = S.One + sign = lt/abs(lt) + x = abs(a - b) + y = abs(alpha + beta) + build_expr = a + b + 1 - Piecewise((x, x > y), (0, Eq(x, y)), (y, y > x)) + index_expr = a + b - c + term_list, other1 = _check_cg_simp(expr, simp, sign, lt, term_list, (a, alpha, b, beta, c, gamma, lt), (a, alpha, b, beta), build_expr, index_expr) + + # For symbolic alpha,beta + x = abs(a - b) + y = a + b + build_expr = (y + 1 - x)*(x + y + 1) + index_expr = (c - x)*(x + c) + c + gamma + term_list, other2 = _check_cg_simp(expr, simp, sign, lt, term_list, (a, alpha, b, beta, c, gamma, lt), (a, alpha, b, beta), build_expr, index_expr) + + # Case alpha!=alphap or beta!=betap + # Note: this only works with leading term of 1, pattern matching is unable to match when there is a Wild leading term + # For numerical alpha,alphap,beta,betap + expr = CG(a, alpha, b, beta, c, gamma)*CG(a, alphap, b, betap, c, gamma) + simp = KroneckerDelta(alpha, alphap)*KroneckerDelta(beta, betap) + sign = S.One + x = abs(a - b) + y = abs(alpha + beta) + build_expr = a + b + 1 - Piecewise((x, x > y), (0, Eq(x, y)), (y, y > x)) + index_expr = a + b - c + term_list, other3 = _check_cg_simp(expr, simp, sign, S.One, term_list, (a, alpha, alphap, b, beta, betap, c, gamma), (a, alpha, alphap, b, beta, betap), build_expr, index_expr) + + # For symbolic alpha,alphap,beta,betap + x = abs(a - b) + y = a + b + build_expr = (y + 1 - x)*(x + y + 1) + index_expr = (c - x)*(x + c) + c + gamma + term_list, other4 = _check_cg_simp(expr, simp, sign, S.One, term_list, (a, alpha, alphap, b, beta, betap, c, gamma), (a, alpha, alphap, b, beta, betap), build_expr, index_expr) + + return term_list, other1 + other2 + other4 + + +def _check_cg_simp(expr, simp, sign, lt, term_list, variables, dep_variables, build_index_expr, index_expr): + """ Checks for simplifications that can be made, returning a tuple of the + simplified list of terms and any terms generated by simplification. + + Parameters + ========== + + expr: expression + The expression with Wild terms that will be matched to the terms in + the sum + + simp: expression + The expression with Wild terms that is substituted in place of the CG + terms in the case of simplification + + sign: expression + The expression with Wild terms denoting the sign that is on expr that + must match + + lt: expression + The expression with Wild terms that gives the leading term of the + matched expr + + term_list: list + A list of all of the terms is the sum to be simplified + + variables: list + A list of all the variables that appears in expr + + dep_variables: list + A list of the variables that must match for all the terms in the sum, + i.e. the dependent variables + + build_index_expr: expression + Expression with Wild terms giving the number of elements in cg_index + + index_expr: expression + Expression with Wild terms giving the index terms have when storing + them to cg_index + + """ + other_part = 0 + i = 0 + while i < len(term_list): + sub_1 = _check_cg(term_list[i], expr, len(variables)) + if sub_1 is None: + i += 1 + continue + if not build_index_expr.subs(sub_1).is_number: + i += 1 + continue + sub_dep = [(x, sub_1[x]) for x in dep_variables] + cg_index = [None]*build_index_expr.subs(sub_1) + for j in range(i, len(term_list)): + sub_2 = _check_cg(term_list[j], expr.subs(sub_dep), len(variables) - len(dep_variables), sign=(sign.subs(sub_1), sign.subs(sub_dep))) + if sub_2 is None: + continue + if not index_expr.subs(sub_dep).subs(sub_2).is_number: + continue + cg_index[index_expr.subs(sub_dep).subs(sub_2)] = j, expr.subs(lt, 1).subs(sub_dep).subs(sub_2), lt.subs(sub_2), sign.subs(sub_dep).subs(sub_2) + if not any(i is None for i in cg_index): + min_lt = min(*[ abs(term[2]) for term in cg_index ]) + indices = [ term[0] for term in cg_index] + indices.sort() + indices.reverse() + [ term_list.pop(j) for j in indices ] + for term in cg_index: + if abs(term[2]) > min_lt: + term_list.append( (term[2] - min_lt*term[3])*term[1] ) + other_part += min_lt*(sign*simp).subs(sub_1) + else: + i += 1 + return term_list, other_part + + +def _check_cg(cg_term, expr, length, sign=None): + """Checks whether a term matches the given expression""" + # TODO: Check for symmetries + matches = cg_term.match(expr) + if matches is None: + return + if sign is not None: + if not isinstance(sign, tuple): + raise TypeError('sign must be a tuple') + if not sign[0] == (sign[1]).subs(matches): + return + if len(matches) == length: + return matches + + +def _cg_simp_sum(e): + e = _check_varsh_sum_871_1(e) + e = _check_varsh_sum_871_2(e) + e = _check_varsh_sum_872_4(e) + return e + + +def _check_varsh_sum_871_1(e): + a = Wild('a') + alpha = symbols('alpha') + b = Wild('b') + match = e.match(Sum(CG(a, alpha, b, 0, a, alpha), (alpha, -a, a))) + if match is not None and len(match) == 2: + return ((2*a + 1)*KroneckerDelta(b, 0)).subs(match) + return e + + +def _check_varsh_sum_871_2(e): + a = Wild('a') + alpha = symbols('alpha') + c = Wild('c') + match = e.match( + Sum((-1)**(a - alpha)*CG(a, alpha, a, -alpha, c, 0), (alpha, -a, a))) + if match is not None and len(match) == 2: + return (sqrt(2*a + 1)*KroneckerDelta(c, 0)).subs(match) + return e + + +def _check_varsh_sum_872_4(e): + alpha = symbols('alpha') + beta = symbols('beta') + a = Wild('a') + b = Wild('b') + c = Wild('c') + cp = Wild('cp') + gamma = Wild('gamma') + gammap = Wild('gammap') + cg1 = CG(a, alpha, b, beta, c, gamma) + cg2 = CG(a, alpha, b, beta, cp, gammap) + match1 = e.match(Sum(cg1*cg2, (alpha, -a, a), (beta, -b, b))) + if match1 is not None and len(match1) == 6: + return (KroneckerDelta(c, cp)*KroneckerDelta(gamma, gammap)).subs(match1) + match2 = e.match(Sum(cg1**2, (alpha, -a, a), (beta, -b, b))) + if match2 is not None and len(match2) == 4: + return S.One + return e + + +def _cg_list(term): + if isinstance(term, CG): + return (term,), 1, 1 + cg = [] + coeff = 1 + if not isinstance(term, (Mul, Pow)): + raise NotImplementedError('term must be CG, Add, Mul or Pow') + if isinstance(term, Pow) and term.exp.is_number: + if term.exp.is_number: + [ cg.append(term.base) for _ in range(term.exp) ] + else: + return (term,), 1, 1 + if isinstance(term, Mul): + for arg in term.args: + if isinstance(arg, CG): + cg.append(arg) + else: + coeff *= arg + return cg, coeff, coeff/abs(coeff) diff --git a/lib/python3.10/site-packages/sympy/physics/quantum/circuitplot.py b/lib/python3.10/site-packages/sympy/physics/quantum/circuitplot.py new file mode 100644 index 0000000000000000000000000000000000000000..316a4be613b2e275565999130c06ea678acd8b96 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/physics/quantum/circuitplot.py @@ -0,0 +1,370 @@ +"""Matplotlib based plotting of quantum circuits. + +Todo: + +* Optimize printing of large circuits. +* Get this to work with single gates. +* Do a better job checking the form of circuits to make sure it is a Mul of + Gates. +* Get multi-target gates plotting. +* Get initial and final states to plot. +* Get measurements to plot. Might need to rethink measurement as a gate + issue. +* Get scale and figsize to be handled in a better way. +* Write some tests/examples! +""" + +from __future__ import annotations + +from sympy.core.mul import Mul +from sympy.external import import_module +from sympy.physics.quantum.gate import Gate, OneQubitGate, CGate, CGateS + + +__all__ = [ + 'CircuitPlot', + 'circuit_plot', + 'labeller', + 'Mz', + 'Mx', + 'CreateOneQubitGate', + 'CreateCGate', +] + +np = import_module('numpy') +matplotlib = import_module( + 'matplotlib', import_kwargs={'fromlist': ['pyplot']}, + catch=(RuntimeError,)) # This is raised in environments that have no display. + +if np and matplotlib: + pyplot = matplotlib.pyplot + Line2D = matplotlib.lines.Line2D + Circle = matplotlib.patches.Circle + +#from matplotlib import rc +#rc('text',usetex=True) + +class CircuitPlot: + """A class for managing a circuit plot.""" + + scale = 1.0 + fontsize = 20.0 + linewidth = 1.0 + control_radius = 0.05 + not_radius = 0.15 + swap_delta = 0.05 + labels: list[str] = [] + inits: dict[str, str] = {} + label_buffer = 0.5 + + def __init__(self, c, nqubits, **kwargs): + if not np or not matplotlib: + raise ImportError('numpy or matplotlib not available.') + self.circuit = c + self.ngates = len(self.circuit.args) + self.nqubits = nqubits + self.update(kwargs) + self._create_grid() + self._create_figure() + self._plot_wires() + self._plot_gates() + self._finish() + + def update(self, kwargs): + """Load the kwargs into the instance dict.""" + self.__dict__.update(kwargs) + + def _create_grid(self): + """Create the grid of wires.""" + scale = self.scale + wire_grid = np.arange(0.0, self.nqubits*scale, scale, dtype=float) + gate_grid = np.arange(0.0, self.ngates*scale, scale, dtype=float) + self._wire_grid = wire_grid + self._gate_grid = gate_grid + + def _create_figure(self): + """Create the main matplotlib figure.""" + self._figure = pyplot.figure( + figsize=(self.ngates*self.scale, self.nqubits*self.scale), + facecolor='w', + edgecolor='w' + ) + ax = self._figure.add_subplot( + 1, 1, 1, + frameon=True + ) + ax.set_axis_off() + offset = 0.5*self.scale + ax.set_xlim(self._gate_grid[0] - offset, self._gate_grid[-1] + offset) + ax.set_ylim(self._wire_grid[0] - offset, self._wire_grid[-1] + offset) + ax.set_aspect('equal') + self._axes = ax + + def _plot_wires(self): + """Plot the wires of the circuit diagram.""" + xstart = self._gate_grid[0] + xstop = self._gate_grid[-1] + xdata = (xstart - self.scale, xstop + self.scale) + for i in range(self.nqubits): + ydata = (self._wire_grid[i], self._wire_grid[i]) + line = Line2D( + xdata, ydata, + color='k', + lw=self.linewidth + ) + self._axes.add_line(line) + if self.labels: + init_label_buffer = 0 + if self.inits.get(self.labels[i]): init_label_buffer = 0.25 + self._axes.text( + xdata[0]-self.label_buffer-init_label_buffer,ydata[0], + render_label(self.labels[i],self.inits), + size=self.fontsize, + color='k',ha='center',va='center') + self._plot_measured_wires() + + def _plot_measured_wires(self): + ismeasured = self._measurements() + xstop = self._gate_grid[-1] + dy = 0.04 # amount to shift wires when doubled + # Plot doubled wires after they are measured + for im in ismeasured: + xdata = (self._gate_grid[ismeasured[im]],xstop+self.scale) + ydata = (self._wire_grid[im]+dy,self._wire_grid[im]+dy) + line = Line2D( + xdata, ydata, + color='k', + lw=self.linewidth + ) + self._axes.add_line(line) + # Also double any controlled lines off these wires + for i,g in enumerate(self._gates()): + if isinstance(g, (CGate, CGateS)): + wires = g.controls + g.targets + for wire in wires: + if wire in ismeasured and \ + self._gate_grid[i] > self._gate_grid[ismeasured[wire]]: + ydata = min(wires), max(wires) + xdata = self._gate_grid[i]-dy, self._gate_grid[i]-dy + line = Line2D( + xdata, ydata, + color='k', + lw=self.linewidth + ) + self._axes.add_line(line) + def _gates(self): + """Create a list of all gates in the circuit plot.""" + gates = [] + if isinstance(self.circuit, Mul): + for g in reversed(self.circuit.args): + if isinstance(g, Gate): + gates.append(g) + elif isinstance(self.circuit, Gate): + gates.append(self.circuit) + return gates + + def _plot_gates(self): + """Iterate through the gates and plot each of them.""" + for i, gate in enumerate(self._gates()): + gate.plot_gate(self, i) + + def _measurements(self): + """Return a dict ``{i:j}`` where i is the index of the wire that has + been measured, and j is the gate where the wire is measured. + """ + ismeasured = {} + for i,g in enumerate(self._gates()): + if getattr(g,'measurement',False): + for target in g.targets: + if target in ismeasured: + if ismeasured[target] > i: + ismeasured[target] = i + else: + ismeasured[target] = i + return ismeasured + + def _finish(self): + # Disable clipping to make panning work well for large circuits. + for o in self._figure.findobj(): + o.set_clip_on(False) + + def one_qubit_box(self, t, gate_idx, wire_idx): + """Draw a box for a single qubit gate.""" + x = self._gate_grid[gate_idx] + y = self._wire_grid[wire_idx] + self._axes.text( + x, y, t, + color='k', + ha='center', + va='center', + bbox={"ec": 'k', "fc": 'w', "fill": True, "lw": self.linewidth}, + size=self.fontsize + ) + + def two_qubit_box(self, t, gate_idx, wire_idx): + """Draw a box for a two qubit gate. Does not work yet. + """ + # x = self._gate_grid[gate_idx] + # y = self._wire_grid[wire_idx]+0.5 + print(self._gate_grid) + print(self._wire_grid) + # unused: + # obj = self._axes.text( + # x, y, t, + # color='k', + # ha='center', + # va='center', + # bbox=dict(ec='k', fc='w', fill=True, lw=self.linewidth), + # size=self.fontsize + # ) + + def control_line(self, gate_idx, min_wire, max_wire): + """Draw a vertical control line.""" + xdata = (self._gate_grid[gate_idx], self._gate_grid[gate_idx]) + ydata = (self._wire_grid[min_wire], self._wire_grid[max_wire]) + line = Line2D( + xdata, ydata, + color='k', + lw=self.linewidth + ) + self._axes.add_line(line) + + def control_point(self, gate_idx, wire_idx): + """Draw a control point.""" + x = self._gate_grid[gate_idx] + y = self._wire_grid[wire_idx] + radius = self.control_radius + c = Circle( + (x, y), + radius*self.scale, + ec='k', + fc='k', + fill=True, + lw=self.linewidth + ) + self._axes.add_patch(c) + + def not_point(self, gate_idx, wire_idx): + """Draw a NOT gates as the circle with plus in the middle.""" + x = self._gate_grid[gate_idx] + y = self._wire_grid[wire_idx] + radius = self.not_radius + c = Circle( + (x, y), + radius, + ec='k', + fc='w', + fill=False, + lw=self.linewidth + ) + self._axes.add_patch(c) + l = Line2D( + (x, x), (y - radius, y + radius), + color='k', + lw=self.linewidth + ) + self._axes.add_line(l) + + def swap_point(self, gate_idx, wire_idx): + """Draw a swap point as a cross.""" + x = self._gate_grid[gate_idx] + y = self._wire_grid[wire_idx] + d = self.swap_delta + l1 = Line2D( + (x - d, x + d), + (y - d, y + d), + color='k', + lw=self.linewidth + ) + l2 = Line2D( + (x - d, x + d), + (y + d, y - d), + color='k', + lw=self.linewidth + ) + self._axes.add_line(l1) + self._axes.add_line(l2) + +def circuit_plot(c, nqubits, **kwargs): + """Draw the circuit diagram for the circuit with nqubits. + + Parameters + ========== + + c : circuit + The circuit to plot. Should be a product of Gate instances. + nqubits : int + The number of qubits to include in the circuit. Must be at least + as big as the largest ``min_qubits`` of the gates. + """ + return CircuitPlot(c, nqubits, **kwargs) + +def render_label(label, inits={}): + """Slightly more flexible way to render labels. + + >>> from sympy.physics.quantum.circuitplot import render_label + >>> render_label('q0') + '$\\\\left|q0\\\\right\\\\rangle$' + >>> render_label('q0', {'q0':'0'}) + '$\\\\left|q0\\\\right\\\\rangle=\\\\left|0\\\\right\\\\rangle$' + """ + init = inits.get(label) + if init: + return r'$\left|%s\right\rangle=\left|%s\right\rangle$' % (label, init) + return r'$\left|%s\right\rangle$' % label + +def labeller(n, symbol='q'): + """Autogenerate labels for wires of quantum circuits. + + Parameters + ========== + + n : int + number of qubits in the circuit. + symbol : string + A character string to precede all gate labels. E.g. 'q_0', 'q_1', etc. + + >>> from sympy.physics.quantum.circuitplot import labeller + >>> labeller(2) + ['q_1', 'q_0'] + >>> labeller(3,'j') + ['j_2', 'j_1', 'j_0'] + """ + return ['%s_%d' % (symbol,n-i-1) for i in range(n)] + +class Mz(OneQubitGate): + """Mock-up of a z measurement gate. + + This is in circuitplot rather than gate.py because it's not a real + gate, it just draws one. + """ + measurement = True + gate_name='Mz' + gate_name_latex='M_z' + +class Mx(OneQubitGate): + """Mock-up of an x measurement gate. + + This is in circuitplot rather than gate.py because it's not a real + gate, it just draws one. + """ + measurement = True + gate_name='Mx' + gate_name_latex='M_x' + +class CreateOneQubitGate(type): + def __new__(mcl, name, latexname=None): + if not latexname: + latexname = name + return type(name + "Gate", (OneQubitGate,), + {'gate_name': name, 'gate_name_latex': latexname}) + +def CreateCGate(name, latexname=None): + """Use a lexical closure to make a controlled gate. + """ + if not latexname: + latexname = name + onequbitgate = CreateOneQubitGate(name, latexname) + def ControlledGate(ctrls,target): + return CGate(tuple(ctrls),onequbitgate(target)) + return ControlledGate diff --git a/lib/python3.10/site-packages/sympy/physics/quantum/circuitutils.py b/lib/python3.10/site-packages/sympy/physics/quantum/circuitutils.py new file mode 100644 index 0000000000000000000000000000000000000000..84955d3d724a2658f2dc3b26738133bd46f1aa57 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/physics/quantum/circuitutils.py @@ -0,0 +1,488 @@ +"""Primitive circuit operations on quantum circuits.""" + +from functools import reduce + +from sympy.core.sorting import default_sort_key +from sympy.core.containers import Tuple +from sympy.core.mul import Mul +from sympy.core.symbol import Symbol +from sympy.core.sympify import sympify +from sympy.utilities import numbered_symbols +from sympy.physics.quantum.gate import Gate + +__all__ = [ + 'kmp_table', + 'find_subcircuit', + 'replace_subcircuit', + 'convert_to_symbolic_indices', + 'convert_to_real_indices', + 'random_reduce', + 'random_insert' +] + + +def kmp_table(word): + """Build the 'partial match' table of the Knuth-Morris-Pratt algorithm. + + Note: This is applicable to strings or + quantum circuits represented as tuples. + """ + + # Current position in subcircuit + pos = 2 + # Beginning position of candidate substring that + # may reappear later in word + cnd = 0 + # The 'partial match' table that helps one determine + # the next location to start substring search + table = [] + table.append(-1) + table.append(0) + + while pos < len(word): + if word[pos - 1] == word[cnd]: + cnd = cnd + 1 + table.append(cnd) + pos = pos + 1 + elif cnd > 0: + cnd = table[cnd] + else: + table.append(0) + pos = pos + 1 + + return table + + +def find_subcircuit(circuit, subcircuit, start=0, end=0): + """Finds the subcircuit in circuit, if it exists. + + Explanation + =========== + + If the subcircuit exists, the index of the start of + the subcircuit in circuit is returned; otherwise, + -1 is returned. The algorithm that is implemented + is the Knuth-Morris-Pratt algorithm. + + Parameters + ========== + + circuit : tuple, Gate or Mul + A tuple of Gates or Mul representing a quantum circuit + subcircuit : tuple, Gate or Mul + A tuple of Gates or Mul to find in circuit + start : int + The location to start looking for subcircuit. + If start is the same or past end, -1 is returned. + end : int + The last place to look for a subcircuit. If end + is less than 1 (one), then the length of circuit + is taken to be end. + + Examples + ======== + + Find the first instance of a subcircuit: + + >>> from sympy.physics.quantum.circuitutils import find_subcircuit + >>> from sympy.physics.quantum.gate import X, Y, Z, H + >>> circuit = X(0)*Z(0)*Y(0)*H(0) + >>> subcircuit = Z(0)*Y(0) + >>> find_subcircuit(circuit, subcircuit) + 1 + + Find the first instance starting at a specific position: + + >>> find_subcircuit(circuit, subcircuit, start=1) + 1 + + >>> find_subcircuit(circuit, subcircuit, start=2) + -1 + + >>> circuit = circuit*subcircuit + >>> find_subcircuit(circuit, subcircuit, start=2) + 4 + + Find the subcircuit within some interval: + + >>> find_subcircuit(circuit, subcircuit, start=2, end=2) + -1 + """ + + if isinstance(circuit, Mul): + circuit = circuit.args + + if isinstance(subcircuit, Mul): + subcircuit = subcircuit.args + + if len(subcircuit) == 0 or len(subcircuit) > len(circuit): + return -1 + + if end < 1: + end = len(circuit) + + # Location in circuit + pos = start + # Location in the subcircuit + index = 0 + # 'Partial match' table + table = kmp_table(subcircuit) + + while (pos + index) < end: + if subcircuit[index] == circuit[pos + index]: + index = index + 1 + else: + pos = pos + index - table[index] + index = table[index] if table[index] > -1 else 0 + + if index == len(subcircuit): + return pos + + return -1 + + +def replace_subcircuit(circuit, subcircuit, replace=None, pos=0): + """Replaces a subcircuit with another subcircuit in circuit, + if it exists. + + Explanation + =========== + + If multiple instances of subcircuit exists, the first instance is + replaced. The position to being searching from (if different from + 0) may be optionally given. If subcircuit cannot be found, circuit + is returned. + + Parameters + ========== + + circuit : tuple, Gate or Mul + A quantum circuit. + subcircuit : tuple, Gate or Mul + The circuit to be replaced. + replace : tuple, Gate or Mul + The replacement circuit. + pos : int + The location to start search and replace + subcircuit, if it exists. This may be used + if it is known beforehand that multiple + instances exist, and it is desirable to + replace a specific instance. If a negative number + is given, pos will be defaulted to 0. + + Examples + ======== + + Find and remove the subcircuit: + + >>> from sympy.physics.quantum.circuitutils import replace_subcircuit + >>> from sympy.physics.quantum.gate import X, Y, Z, H + >>> circuit = X(0)*Z(0)*Y(0)*H(0)*X(0)*H(0)*Y(0) + >>> subcircuit = Z(0)*Y(0) + >>> replace_subcircuit(circuit, subcircuit) + (X(0), H(0), X(0), H(0), Y(0)) + + Remove the subcircuit given a starting search point: + + >>> replace_subcircuit(circuit, subcircuit, pos=1) + (X(0), H(0), X(0), H(0), Y(0)) + + >>> replace_subcircuit(circuit, subcircuit, pos=2) + (X(0), Z(0), Y(0), H(0), X(0), H(0), Y(0)) + + Replace the subcircuit: + + >>> replacement = H(0)*Z(0) + >>> replace_subcircuit(circuit, subcircuit, replace=replacement) + (X(0), H(0), Z(0), H(0), X(0), H(0), Y(0)) + """ + + if pos < 0: + pos = 0 + + if isinstance(circuit, Mul): + circuit = circuit.args + + if isinstance(subcircuit, Mul): + subcircuit = subcircuit.args + + if isinstance(replace, Mul): + replace = replace.args + elif replace is None: + replace = () + + # Look for the subcircuit starting at pos + loc = find_subcircuit(circuit, subcircuit, start=pos) + + # If subcircuit was found + if loc > -1: + # Get the gates to the left of subcircuit + left = circuit[0:loc] + # Get the gates to the right of subcircuit + right = circuit[loc + len(subcircuit):len(circuit)] + # Recombine the left and right side gates into a circuit + circuit = left + replace + right + + return circuit + + +def _sympify_qubit_map(mapping): + new_map = {} + for key in mapping: + new_map[key] = sympify(mapping[key]) + return new_map + + +def convert_to_symbolic_indices(seq, start=None, gen=None, qubit_map=None): + """Returns the circuit with symbolic indices and the + dictionary mapping symbolic indices to real indices. + + The mapping is 1 to 1 and onto (bijective). + + Parameters + ========== + + seq : tuple, Gate/Integer/tuple or Mul + A tuple of Gate, Integer, or tuple objects, or a Mul + start : Symbol + An optional starting symbolic index + gen : object + An optional numbered symbol generator + qubit_map : dict + An existing mapping of symbolic indices to real indices + + All symbolic indices have the format 'i#', where # is + some number >= 0. + """ + + if isinstance(seq, Mul): + seq = seq.args + + # A numbered symbol generator + index_gen = numbered_symbols(prefix='i', start=-1) + cur_ndx = next(index_gen) + + # keys are symbolic indices; values are real indices + ndx_map = {} + + def create_inverse_map(symb_to_real_map): + rev_items = lambda item: (item[1], item[0]) + return dict(map(rev_items, symb_to_real_map.items())) + + if start is not None: + if not isinstance(start, Symbol): + msg = 'Expected Symbol for starting index, got %r.' % start + raise TypeError(msg) + cur_ndx = start + + if gen is not None: + if not isinstance(gen, numbered_symbols().__class__): + msg = 'Expected a generator, got %r.' % gen + raise TypeError(msg) + index_gen = gen + + if qubit_map is not None: + if not isinstance(qubit_map, dict): + msg = ('Expected dict for existing map, got ' + + '%r.' % qubit_map) + raise TypeError(msg) + ndx_map = qubit_map + + ndx_map = _sympify_qubit_map(ndx_map) + # keys are real indices; keys are symbolic indices + inv_map = create_inverse_map(ndx_map) + + sym_seq = () + for item in seq: + # Nested items, so recurse + if isinstance(item, Gate): + result = convert_to_symbolic_indices(item.args, + qubit_map=ndx_map, + start=cur_ndx, + gen=index_gen) + sym_item, new_map, cur_ndx, index_gen = result + ndx_map.update(new_map) + inv_map = create_inverse_map(ndx_map) + + elif isinstance(item, (tuple, Tuple)): + result = convert_to_symbolic_indices(item, + qubit_map=ndx_map, + start=cur_ndx, + gen=index_gen) + sym_item, new_map, cur_ndx, index_gen = result + ndx_map.update(new_map) + inv_map = create_inverse_map(ndx_map) + + elif item in inv_map: + sym_item = inv_map[item] + + else: + cur_ndx = next(gen) + ndx_map[cur_ndx] = item + inv_map[item] = cur_ndx + sym_item = cur_ndx + + if isinstance(item, Gate): + sym_item = item.__class__(*sym_item) + + sym_seq = sym_seq + (sym_item,) + + return sym_seq, ndx_map, cur_ndx, index_gen + + +def convert_to_real_indices(seq, qubit_map): + """Returns the circuit with real indices. + + Parameters + ========== + + seq : tuple, Gate/Integer/tuple or Mul + A tuple of Gate, Integer, or tuple objects or a Mul + qubit_map : dict + A dictionary mapping symbolic indices to real indices. + + Examples + ======== + + Change the symbolic indices to real integers: + + >>> from sympy import symbols + >>> from sympy.physics.quantum.circuitutils import convert_to_real_indices + >>> from sympy.physics.quantum.gate import X, Y, H + >>> i0, i1 = symbols('i:2') + >>> index_map = {i0 : 0, i1 : 1} + >>> convert_to_real_indices(X(i0)*Y(i1)*H(i0)*X(i1), index_map) + (X(0), Y(1), H(0), X(1)) + """ + + if isinstance(seq, Mul): + seq = seq.args + + if not isinstance(qubit_map, dict): + msg = 'Expected dict for qubit_map, got %r.' % qubit_map + raise TypeError(msg) + + qubit_map = _sympify_qubit_map(qubit_map) + real_seq = () + for item in seq: + # Nested items, so recurse + if isinstance(item, Gate): + real_item = convert_to_real_indices(item.args, qubit_map) + + elif isinstance(item, (tuple, Tuple)): + real_item = convert_to_real_indices(item, qubit_map) + + else: + real_item = qubit_map[item] + + if isinstance(item, Gate): + real_item = item.__class__(*real_item) + + real_seq = real_seq + (real_item,) + + return real_seq + + +def random_reduce(circuit, gate_ids, seed=None): + """Shorten the length of a quantum circuit. + + Explanation + =========== + + random_reduce looks for circuit identities in circuit, randomly chooses + one to remove, and returns a shorter yet equivalent circuit. If no + identities are found, the same circuit is returned. + + Parameters + ========== + + circuit : Gate tuple of Mul + A tuple of Gates representing a quantum circuit + gate_ids : list, GateIdentity + List of gate identities to find in circuit + seed : int or list + seed used for _randrange; to override the random selection, provide a + list of integers: the elements of gate_ids will be tested in the order + given by the list + + """ + from sympy.core.random import _randrange + + if not gate_ids: + return circuit + + if isinstance(circuit, Mul): + circuit = circuit.args + + ids = flatten_ids(gate_ids) + + # Create the random integer generator with the seed + randrange = _randrange(seed) + + # Look for an identity in the circuit + while ids: + i = randrange(len(ids)) + id = ids.pop(i) + if find_subcircuit(circuit, id) != -1: + break + else: + # no identity was found + return circuit + + # return circuit with the identity removed + return replace_subcircuit(circuit, id) + + +def random_insert(circuit, choices, seed=None): + """Insert a circuit into another quantum circuit. + + Explanation + =========== + + random_insert randomly chooses a location in the circuit to insert + a randomly selected circuit from amongst the given choices. + + Parameters + ========== + + circuit : Gate tuple or Mul + A tuple or Mul of Gates representing a quantum circuit + choices : list + Set of circuit choices + seed : int or list + seed used for _randrange; to override the random selections, give + a list two integers, [i, j] where i is the circuit location where + choice[j] will be inserted. + + Notes + ===== + + Indices for insertion should be [0, n] if n is the length of the + circuit. + """ + from sympy.core.random import _randrange + + if not choices: + return circuit + + if isinstance(circuit, Mul): + circuit = circuit.args + + # get the location in the circuit and the element to insert from choices + randrange = _randrange(seed) + loc = randrange(len(circuit) + 1) + choice = choices[randrange(len(choices))] + + circuit = list(circuit) + circuit[loc: loc] = choice + return tuple(circuit) + +# Flatten the GateIdentity objects (with gate rules) into one single list + + +def flatten_ids(ids): + collapse = lambda acc, an_id: acc + sorted(an_id.equivalent_ids, + key=default_sort_key) + ids = reduce(collapse, ids, []) + ids.sort(key=default_sort_key) + return ids diff --git a/lib/python3.10/site-packages/sympy/physics/quantum/commutator.py b/lib/python3.10/site-packages/sympy/physics/quantum/commutator.py new file mode 100644 index 0000000000000000000000000000000000000000..627158657481a4b66875e1d23107c1ca3bdb6969 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/physics/quantum/commutator.py @@ -0,0 +1,239 @@ +"""The commutator: [A,B] = A*B - B*A.""" + +from sympy.core.add import Add +from sympy.core.expr import Expr +from sympy.core.mul import Mul +from sympy.core.power import Pow +from sympy.core.singleton import S +from sympy.printing.pretty.stringpict import prettyForm + +from sympy.physics.quantum.dagger import Dagger +from sympy.physics.quantum.operator import Operator + + +__all__ = [ + 'Commutator' +] + +#----------------------------------------------------------------------------- +# Commutator +#----------------------------------------------------------------------------- + + +class Commutator(Expr): + """The standard commutator, in an unevaluated state. + + Explanation + =========== + + Evaluating a commutator is defined [1]_ as: ``[A, B] = A*B - B*A``. This + class returns the commutator in an unevaluated form. To evaluate the + commutator, use the ``.doit()`` method. + + Canonical ordering of a commutator is ``[A, B]`` for ``A < B``. The + arguments of the commutator are put into canonical order using ``__cmp__``. + If ``B < A``, then ``[B, A]`` is returned as ``-[A, B]``. + + Parameters + ========== + + A : Expr + The first argument of the commutator [A,B]. + B : Expr + The second argument of the commutator [A,B]. + + Examples + ======== + + >>> from sympy.physics.quantum import Commutator, Dagger, Operator + >>> from sympy.abc import x, y + >>> A = Operator('A') + >>> B = Operator('B') + >>> C = Operator('C') + + Create a commutator and use ``.doit()`` to evaluate it: + + >>> comm = Commutator(A, B) + >>> comm + [A,B] + >>> comm.doit() + A*B - B*A + + The commutator orders it arguments in canonical order: + + >>> comm = Commutator(B, A); comm + -[A,B] + + Commutative constants are factored out: + + >>> Commutator(3*x*A, x*y*B) + 3*x**2*y*[A,B] + + Using ``.expand(commutator=True)``, the standard commutator expansion rules + can be applied: + + >>> Commutator(A+B, C).expand(commutator=True) + [A,C] + [B,C] + >>> Commutator(A, B+C).expand(commutator=True) + [A,B] + [A,C] + >>> Commutator(A*B, C).expand(commutator=True) + [A,C]*B + A*[B,C] + >>> Commutator(A, B*C).expand(commutator=True) + [A,B]*C + B*[A,C] + + Adjoint operations applied to the commutator are properly applied to the + arguments: + + >>> Dagger(Commutator(A, B)) + -[Dagger(A),Dagger(B)] + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Commutator + """ + is_commutative = False + + def __new__(cls, A, B): + r = cls.eval(A, B) + if r is not None: + return r + obj = Expr.__new__(cls, A, B) + return obj + + @classmethod + def eval(cls, a, b): + if not (a and b): + return S.Zero + if a == b: + return S.Zero + if a.is_commutative or b.is_commutative: + return S.Zero + + # [xA,yB] -> xy*[A,B] + ca, nca = a.args_cnc() + cb, ncb = b.args_cnc() + c_part = ca + cb + if c_part: + return Mul(Mul(*c_part), cls(Mul._from_args(nca), Mul._from_args(ncb))) + + # Canonical ordering of arguments + # The Commutator [A, B] is in canonical form if A < B. + if a.compare(b) == 1: + return S.NegativeOne*cls(b, a) + + def _expand_pow(self, A, B, sign): + exp = A.exp + if not exp.is_integer or not exp.is_constant() or abs(exp) <= 1: + # nothing to do + return self + base = A.base + if exp.is_negative: + base = A.base**-1 + exp = -exp + comm = Commutator(base, B).expand(commutator=True) + + result = base**(exp - 1) * comm + for i in range(1, exp): + result += base**(exp - 1 - i) * comm * base**i + return sign*result.expand() + + def _eval_expand_commutator(self, **hints): + A = self.args[0] + B = self.args[1] + + if isinstance(A, Add): + # [A + B, C] -> [A, C] + [B, C] + sargs = [] + for term in A.args: + comm = Commutator(term, B) + if isinstance(comm, Commutator): + comm = comm._eval_expand_commutator() + sargs.append(comm) + return Add(*sargs) + elif isinstance(B, Add): + # [A, B + C] -> [A, B] + [A, C] + sargs = [] + for term in B.args: + comm = Commutator(A, term) + if isinstance(comm, Commutator): + comm = comm._eval_expand_commutator() + sargs.append(comm) + return Add(*sargs) + elif isinstance(A, Mul): + # [A*B, C] -> A*[B, C] + [A, C]*B + a = A.args[0] + b = Mul(*A.args[1:]) + c = B + comm1 = Commutator(b, c) + comm2 = Commutator(a, c) + if isinstance(comm1, Commutator): + comm1 = comm1._eval_expand_commutator() + if isinstance(comm2, Commutator): + comm2 = comm2._eval_expand_commutator() + first = Mul(a, comm1) + second = Mul(comm2, b) + return Add(first, second) + elif isinstance(B, Mul): + # [A, B*C] -> [A, B]*C + B*[A, C] + a = A + b = B.args[0] + c = Mul(*B.args[1:]) + comm1 = Commutator(a, b) + comm2 = Commutator(a, c) + if isinstance(comm1, Commutator): + comm1 = comm1._eval_expand_commutator() + if isinstance(comm2, Commutator): + comm2 = comm2._eval_expand_commutator() + first = Mul(comm1, c) + second = Mul(b, comm2) + return Add(first, second) + elif isinstance(A, Pow): + # [A**n, C] -> A**(n - 1)*[A, C] + A**(n - 2)*[A, C]*A + ... + [A, C]*A**(n-1) + return self._expand_pow(A, B, 1) + elif isinstance(B, Pow): + # [A, C**n] -> C**(n - 1)*[C, A] + C**(n - 2)*[C, A]*C + ... + [C, A]*C**(n-1) + return self._expand_pow(B, A, -1) + + # No changes, so return self + return self + + def doit(self, **hints): + """ Evaluate commutator """ + A = self.args[0] + B = self.args[1] + if isinstance(A, Operator) and isinstance(B, Operator): + try: + comm = A._eval_commutator(B, **hints) + except NotImplementedError: + try: + comm = -1*B._eval_commutator(A, **hints) + except NotImplementedError: + comm = None + if comm is not None: + return comm.doit(**hints) + return (A*B - B*A).doit(**hints) + + def _eval_adjoint(self): + return Commutator(Dagger(self.args[1]), Dagger(self.args[0])) + + def _sympyrepr(self, printer, *args): + return "%s(%s,%s)" % ( + self.__class__.__name__, printer._print( + self.args[0]), printer._print(self.args[1]) + ) + + def _sympystr(self, printer, *args): + return "[%s,%s]" % ( + printer._print(self.args[0]), printer._print(self.args[1])) + + def _pretty(self, printer, *args): + pform = printer._print(self.args[0], *args) + pform = prettyForm(*pform.right(prettyForm(','))) + pform = prettyForm(*pform.right(printer._print(self.args[1], *args))) + pform = prettyForm(*pform.parens(left='[', right=']')) + return pform + + def _latex(self, printer, *args): + return "\\left[%s,%s\\right]" % tuple([ + printer._print(arg, *args) for arg in self.args]) diff --git a/lib/python3.10/site-packages/sympy/physics/quantum/constants.py b/lib/python3.10/site-packages/sympy/physics/quantum/constants.py new file mode 100644 index 0000000000000000000000000000000000000000..3e848bf24e95e3bd612169128a1845202066c6e9 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/physics/quantum/constants.py @@ -0,0 +1,59 @@ +"""Constants (like hbar) related to quantum mechanics.""" + +from sympy.core.numbers import NumberSymbol +from sympy.core.singleton import Singleton +from sympy.printing.pretty.stringpict import prettyForm +import mpmath.libmp as mlib + +#----------------------------------------------------------------------------- +# Constants +#----------------------------------------------------------------------------- + +__all__ = [ + 'hbar', + 'HBar', +] + + +class HBar(NumberSymbol, metaclass=Singleton): + """Reduced Plank's constant in numerical and symbolic form [1]_. + + Examples + ======== + + >>> from sympy.physics.quantum.constants import hbar + >>> hbar.evalf() + 1.05457162000000e-34 + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Planck_constant + """ + + is_real = True + is_positive = True + is_negative = False + is_irrational = True + + __slots__ = () + + def _as_mpf_val(self, prec): + return mlib.from_float(1.05457162e-34, prec) + + def _sympyrepr(self, printer, *args): + return 'HBar()' + + def _sympystr(self, printer, *args): + return 'hbar' + + def _pretty(self, printer, *args): + if printer._use_unicode: + return prettyForm('\N{PLANCK CONSTANT OVER TWO PI}') + return prettyForm('hbar') + + def _latex(self, printer, *args): + return r'\hbar' + +# Create an instance for everyone to use. +hbar = HBar() diff --git a/lib/python3.10/site-packages/sympy/physics/quantum/dagger.py b/lib/python3.10/site-packages/sympy/physics/quantum/dagger.py new file mode 100644 index 0000000000000000000000000000000000000000..6305a656c3664c3be023bea5c07916915ff86d5c --- /dev/null +++ b/lib/python3.10/site-packages/sympy/physics/quantum/dagger.py @@ -0,0 +1,97 @@ +"""Hermitian conjugation.""" + +from sympy.core import Expr, Mul, sympify +from sympy.functions.elementary.complexes import adjoint + +__all__ = [ + 'Dagger' +] + + +class Dagger(adjoint): + """General Hermitian conjugate operation. + + Explanation + =========== + + Take the Hermetian conjugate of an argument [1]_. For matrices this + operation is equivalent to transpose and complex conjugate [2]_. + + Parameters + ========== + + arg : Expr + The SymPy expression that we want to take the dagger of. + evaluate : bool + Whether the resulting expression should be directly evaluated. + + Examples + ======== + + Daggering various quantum objects: + + >>> from sympy.physics.quantum.dagger import Dagger + >>> from sympy.physics.quantum.state import Ket, Bra + >>> from sympy.physics.quantum.operator import Operator + >>> Dagger(Ket('psi')) + >> Dagger(Bra('phi')) + |phi> + >>> Dagger(Operator('A')) + Dagger(A) + + Inner and outer products:: + + >>> from sympy.physics.quantum import InnerProduct, OuterProduct + >>> Dagger(InnerProduct(Bra('a'), Ket('b'))) + + >>> Dagger(OuterProduct(Ket('a'), Bra('b'))) + |b>>> A = Operator('A') + >>> B = Operator('B') + >>> Dagger(A*B) + Dagger(B)*Dagger(A) + >>> Dagger(A+B) + Dagger(A) + Dagger(B) + >>> Dagger(A**2) + Dagger(A)**2 + + Dagger also seamlessly handles complex numbers and matrices:: + + >>> from sympy import Matrix, I + >>> m = Matrix([[1,I],[2,I]]) + >>> m + Matrix([ + [1, I], + [2, I]]) + >>> Dagger(m) + Matrix([ + [ 1, 2], + [-I, -I]]) + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Hermitian_adjoint + .. [2] https://en.wikipedia.org/wiki/Hermitian_transpose + """ + + def __new__(cls, arg, evaluate=True): + if hasattr(arg, 'adjoint') and evaluate: + return arg.adjoint() + elif hasattr(arg, 'conjugate') and hasattr(arg, 'transpose') and evaluate: + return arg.conjugate().transpose() + return Expr.__new__(cls, sympify(arg)) + + def __mul__(self, other): + from sympy.physics.quantum import IdentityOperator + if isinstance(other, IdentityOperator): + return self + + return Mul(self, other) + +adjoint.__name__ = "Dagger" +adjoint._sympyrepr = lambda a, b: "Dagger(%s)" % b._print(a.args[0]) diff --git a/lib/python3.10/site-packages/sympy/physics/quantum/density.py b/lib/python3.10/site-packages/sympy/physics/quantum/density.py new file mode 100644 index 0000000000000000000000000000000000000000..aa1f408d93fd3eb7fdcaebd7206cf0fcca2e2f18 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/physics/quantum/density.py @@ -0,0 +1,319 @@ +from itertools import product + +from sympy.core.add import Add +from sympy.core.containers import Tuple +from sympy.core.function import expand +from sympy.core.mul import Mul +from sympy.core.singleton import S +from sympy.functions.elementary.exponential import log +from sympy.matrices.dense import MutableDenseMatrix as Matrix +from sympy.printing.pretty.stringpict import prettyForm +from sympy.physics.quantum.dagger import Dagger +from sympy.physics.quantum.operator import HermitianOperator +from sympy.physics.quantum.represent import represent +from sympy.physics.quantum.matrixutils import numpy_ndarray, scipy_sparse_matrix, to_numpy +from sympy.physics.quantum.tensorproduct import TensorProduct, tensor_product_simp +from sympy.physics.quantum.trace import Tr + + +class Density(HermitianOperator): + """Density operator for representing mixed states. + + TODO: Density operator support for Qubits + + Parameters + ========== + + values : tuples/lists + Each tuple/list should be of form (state, prob) or [state,prob] + + Examples + ======== + + Create a density operator with 2 states represented by Kets. + + >>> from sympy.physics.quantum.state import Ket + >>> from sympy.physics.quantum.density import Density + >>> d = Density([Ket(0), 0.5], [Ket(1),0.5]) + >>> d + Density((|0>, 0.5),(|1>, 0.5)) + + """ + @classmethod + def _eval_args(cls, args): + # call this to qsympify the args + args = super()._eval_args(args) + + for arg in args: + # Check if arg is a tuple + if not (isinstance(arg, Tuple) and len(arg) == 2): + raise ValueError("Each argument should be of form [state,prob]" + " or ( state, prob )") + + return args + + def states(self): + """Return list of all states. + + Examples + ======== + + >>> from sympy.physics.quantum.state import Ket + >>> from sympy.physics.quantum.density import Density + >>> d = Density([Ket(0), 0.5], [Ket(1),0.5]) + >>> d.states() + (|0>, |1>) + + """ + return Tuple(*[arg[0] for arg in self.args]) + + def probs(self): + """Return list of all probabilities. + + Examples + ======== + + >>> from sympy.physics.quantum.state import Ket + >>> from sympy.physics.quantum.density import Density + >>> d = Density([Ket(0), 0.5], [Ket(1),0.5]) + >>> d.probs() + (0.5, 0.5) + + """ + return Tuple(*[arg[1] for arg in self.args]) + + def get_state(self, index): + """Return specific state by index. + + Parameters + ========== + + index : index of state to be returned + + Examples + ======== + + >>> from sympy.physics.quantum.state import Ket + >>> from sympy.physics.quantum.density import Density + >>> d = Density([Ket(0), 0.5], [Ket(1),0.5]) + >>> d.states()[1] + |1> + + """ + state = self.args[index][0] + return state + + def get_prob(self, index): + """Return probability of specific state by index. + + Parameters + =========== + + index : index of states whose probability is returned. + + Examples + ======== + + >>> from sympy.physics.quantum.state import Ket + >>> from sympy.physics.quantum.density import Density + >>> d = Density([Ket(0), 0.5], [Ket(1),0.5]) + >>> d.probs()[1] + 0.500000000000000 + + """ + prob = self.args[index][1] + return prob + + def apply_op(self, op): + """op will operate on each individual state. + + Parameters + ========== + + op : Operator + + Examples + ======== + + >>> from sympy.physics.quantum.state import Ket + >>> from sympy.physics.quantum.density import Density + >>> from sympy.physics.quantum.operator import Operator + >>> A = Operator('A') + >>> d = Density([Ket(0), 0.5], [Ket(1),0.5]) + >>> d.apply_op(A) + Density((A*|0>, 0.5),(A*|1>, 0.5)) + + """ + new_args = [(op*state, prob) for (state, prob) in self.args] + return Density(*new_args) + + def doit(self, **hints): + """Expand the density operator into an outer product format. + + Examples + ======== + + >>> from sympy.physics.quantum.state import Ket + >>> from sympy.physics.quantum.density import Density + >>> from sympy.physics.quantum.operator import Operator + >>> A = Operator('A') + >>> d = Density([Ket(0), 0.5], [Ket(1),0.5]) + >>> d.doit() + 0.5*|0><0| + 0.5*|1><1| + + """ + + terms = [] + for (state, prob) in self.args: + state = state.expand() # needed to break up (a+b)*c + if (isinstance(state, Add)): + for arg in product(state.args, repeat=2): + terms.append(prob*self._generate_outer_prod(arg[0], + arg[1])) + else: + terms.append(prob*self._generate_outer_prod(state, state)) + + return Add(*terms) + + def _generate_outer_prod(self, arg1, arg2): + c_part1, nc_part1 = arg1.args_cnc() + c_part2, nc_part2 = arg2.args_cnc() + + if (len(nc_part1) == 0 or len(nc_part2) == 0): + raise ValueError('Atleast one-pair of' + ' Non-commutative instance required' + ' for outer product.') + + # Muls of Tensor Products should be expanded + # before this function is called + if (isinstance(nc_part1[0], TensorProduct) and len(nc_part1) == 1 + and len(nc_part2) == 1): + op = tensor_product_simp(nc_part1[0]*Dagger(nc_part2[0])) + else: + op = Mul(*nc_part1)*Dagger(Mul(*nc_part2)) + + return Mul(*c_part1)*Mul(*c_part2) * op + + def _represent(self, **options): + return represent(self.doit(), **options) + + def _print_operator_name_latex(self, printer, *args): + return r'\rho' + + def _print_operator_name_pretty(self, printer, *args): + return prettyForm('\N{GREEK SMALL LETTER RHO}') + + def _eval_trace(self, **kwargs): + indices = kwargs.get('indices', []) + return Tr(self.doit(), indices).doit() + + def entropy(self): + """ Compute the entropy of a density matrix. + + Refer to density.entropy() method for examples. + """ + return entropy(self) + + +def entropy(density): + """Compute the entropy of a matrix/density object. + + This computes -Tr(density*ln(density)) using the eigenvalue decomposition + of density, which is given as either a Density instance or a matrix + (numpy.ndarray, sympy.Matrix or scipy.sparse). + + Parameters + ========== + + density : density matrix of type Density, SymPy matrix, + scipy.sparse or numpy.ndarray + + Examples + ======== + + >>> from sympy.physics.quantum.density import Density, entropy + >>> from sympy.physics.quantum.spin import JzKet + >>> from sympy import S + >>> up = JzKet(S(1)/2,S(1)/2) + >>> down = JzKet(S(1)/2,-S(1)/2) + >>> d = Density((up,S(1)/2),(down,S(1)/2)) + >>> entropy(d) + log(2)/2 + + """ + if isinstance(density, Density): + density = represent(density) # represent in Matrix + + if isinstance(density, scipy_sparse_matrix): + density = to_numpy(density) + + if isinstance(density, Matrix): + eigvals = density.eigenvals().keys() + return expand(-sum(e*log(e) for e in eigvals)) + elif isinstance(density, numpy_ndarray): + import numpy as np + eigvals = np.linalg.eigvals(density) + return -np.sum(eigvals*np.log(eigvals)) + else: + raise ValueError( + "numpy.ndarray, scipy.sparse or SymPy matrix expected") + + +def fidelity(state1, state2): + """ Computes the fidelity [1]_ between two quantum states + + The arguments provided to this function should be a square matrix or a + Density object. If it is a square matrix, it is assumed to be diagonalizable. + + Parameters + ========== + + state1, state2 : a density matrix or Matrix + + + Examples + ======== + + >>> from sympy import S, sqrt + >>> from sympy.physics.quantum.dagger import Dagger + >>> from sympy.physics.quantum.spin import JzKet + >>> from sympy.physics.quantum.density import fidelity + >>> from sympy.physics.quantum.represent import represent + >>> + >>> up = JzKet(S(1)/2,S(1)/2) + >>> down = JzKet(S(1)/2,-S(1)/2) + >>> amp = 1/sqrt(2) + >>> updown = (amp*up) + (amp*down) + >>> + >>> # represent turns Kets into matrices + >>> up_dm = represent(up*Dagger(up)) + >>> down_dm = represent(down*Dagger(down)) + >>> updown_dm = represent(updown*Dagger(updown)) + >>> + >>> fidelity(up_dm, up_dm) + 1 + >>> fidelity(up_dm, down_dm) #orthogonal states + 0 + >>> fidelity(up_dm, updown_dm).evalf().round(3) + 0.707 + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Fidelity_of_quantum_states + + """ + state1 = represent(state1) if isinstance(state1, Density) else state1 + state2 = represent(state2) if isinstance(state2, Density) else state2 + + if not isinstance(state1, Matrix) or not isinstance(state2, Matrix): + raise ValueError("state1 and state2 must be of type Density or Matrix " + "received type=%s for state1 and type=%s for state2" % + (type(state1), type(state2))) + + if state1.shape != state2.shape and state1.is_square: + raise ValueError("The dimensions of both args should be equal and the " + "matrix obtained should be a square matrix") + + sqrt_state1 = state1**S.Half + return Tr((sqrt_state1*state2*sqrt_state1)**S.Half).doit() diff --git a/lib/python3.10/site-packages/sympy/physics/quantum/fermion.py b/lib/python3.10/site-packages/sympy/physics/quantum/fermion.py new file mode 100644 index 0000000000000000000000000000000000000000..8080bd3b0904b837652fdae7be0bd526da2d508f --- /dev/null +++ b/lib/python3.10/site-packages/sympy/physics/quantum/fermion.py @@ -0,0 +1,191 @@ +"""Fermionic quantum operators.""" + +from sympy.core.numbers import Integer +from sympy.core.singleton import S +from sympy.physics.quantum import Operator +from sympy.physics.quantum import HilbertSpace, Ket, Bra +from sympy.functions.special.tensor_functions import KroneckerDelta + + +__all__ = [ + 'FermionOp', + 'FermionFockKet', + 'FermionFockBra' +] + + +class FermionOp(Operator): + """A fermionic operator that satisfies {c, Dagger(c)} == 1. + + Parameters + ========== + + name : str + A string that labels the fermionic mode. + + annihilation : bool + A bool that indicates if the fermionic operator is an annihilation + (True, default value) or creation operator (False) + + Examples + ======== + + >>> from sympy.physics.quantum import Dagger, AntiCommutator + >>> from sympy.physics.quantum.fermion import FermionOp + >>> c = FermionOp("c") + >>> AntiCommutator(c, Dagger(c)).doit() + 1 + """ + @property + def name(self): + return self.args[0] + + @property + def is_annihilation(self): + return bool(self.args[1]) + + @classmethod + def default_args(self): + return ("c", True) + + def __new__(cls, *args, **hints): + if not len(args) in [1, 2]: + raise ValueError('1 or 2 parameters expected, got %s' % args) + + if len(args) == 1: + args = (args[0], S.One) + + if len(args) == 2: + args = (args[0], Integer(args[1])) + + return Operator.__new__(cls, *args) + + def _eval_commutator_FermionOp(self, other, **hints): + if 'independent' in hints and hints['independent']: + # [c, d] = 0 + return S.Zero + + return None + + def _eval_anticommutator_FermionOp(self, other, **hints): + if self.name == other.name: + # {a^\dagger, a} = 1 + if not self.is_annihilation and other.is_annihilation: + return S.One + + elif 'independent' in hints and hints['independent']: + # {c, d} = 2 * c * d, because [c, d] = 0 for independent operators + return 2 * self * other + + return None + + def _eval_anticommutator_BosonOp(self, other, **hints): + # because fermions and bosons commute + return 2 * self * other + + def _eval_commutator_BosonOp(self, other, **hints): + return S.Zero + + def _eval_adjoint(self): + return FermionOp(str(self.name), not self.is_annihilation) + + def _print_contents_latex(self, printer, *args): + if self.is_annihilation: + return r'{%s}' % str(self.name) + else: + return r'{{%s}^\dagger}' % str(self.name) + + def _print_contents(self, printer, *args): + if self.is_annihilation: + return r'%s' % str(self.name) + else: + return r'Dagger(%s)' % str(self.name) + + def _print_contents_pretty(self, printer, *args): + from sympy.printing.pretty.stringpict import prettyForm + pform = printer._print(self.args[0], *args) + if self.is_annihilation: + return pform + else: + return pform**prettyForm('\N{DAGGER}') + + def _eval_power(self, exp): + from sympy.core.singleton import S + if exp == 0: + return S.One + elif exp == 1: + return self + elif (exp > 1) == True and exp.is_integer == True: + return S.Zero + elif (exp < 0) == True or exp.is_integer == False: + raise ValueError("Fermionic operators can only be raised to a" + " positive integer power") + return Operator._eval_power(self, exp) + +class FermionFockKet(Ket): + """Fock state ket for a fermionic mode. + + Parameters + ========== + + n : Number + The Fock state number. + + """ + + def __new__(cls, n): + if n not in (0, 1): + raise ValueError("n must be 0 or 1") + return Ket.__new__(cls, n) + + @property + def n(self): + return self.label[0] + + @classmethod + def dual_class(self): + return FermionFockBra + + @classmethod + def _eval_hilbert_space(cls, label): + return HilbertSpace() + + def _eval_innerproduct_FermionFockBra(self, bra, **hints): + return KroneckerDelta(self.n, bra.n) + + def _apply_from_right_to_FermionOp(self, op, **options): + if op.is_annihilation: + if self.n == 1: + return FermionFockKet(0) + else: + return S.Zero + else: + if self.n == 0: + return FermionFockKet(1) + else: + return S.Zero + + +class FermionFockBra(Bra): + """Fock state bra for a fermionic mode. + + Parameters + ========== + + n : Number + The Fock state number. + + """ + + def __new__(cls, n): + if n not in (0, 1): + raise ValueError("n must be 0 or 1") + return Bra.__new__(cls, n) + + @property + def n(self): + return self.label[0] + + @classmethod + def dual_class(self): + return FermionFockKet diff --git a/lib/python3.10/site-packages/sympy/physics/quantum/gate.py b/lib/python3.10/site-packages/sympy/physics/quantum/gate.py new file mode 100644 index 0000000000000000000000000000000000000000..f8bcf5cd3611173cd9ebd6308dbbc896f5257f20 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/physics/quantum/gate.py @@ -0,0 +1,1309 @@ +"""An implementation of gates that act on qubits. + +Gates are unitary operators that act on the space of qubits. + +Medium Term Todo: + +* Optimize Gate._apply_operators_Qubit to remove the creation of many + intermediate Qubit objects. +* Add commutation relationships to all operators and use this in gate_sort. +* Fix gate_sort and gate_simp. +* Get multi-target UGates plotting properly. +* Get UGate to work with either sympy/numpy matrices and output either + format. This should also use the matrix slots. +""" + +from itertools import chain +import random + +from sympy.core.add import Add +from sympy.core.containers import Tuple +from sympy.core.mul import Mul +from sympy.core.numbers import (I, Integer) +from sympy.core.power import Pow +from sympy.core.numbers import Number +from sympy.core.singleton import S as _S +from sympy.core.sorting import default_sort_key +from sympy.core.sympify import _sympify +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.printing.pretty.stringpict import prettyForm, stringPict + +from sympy.physics.quantum.anticommutator import AntiCommutator +from sympy.physics.quantum.commutator import Commutator +from sympy.physics.quantum.qexpr import QuantumError +from sympy.physics.quantum.hilbert import ComplexSpace +from sympy.physics.quantum.operator import (UnitaryOperator, Operator, + HermitianOperator) +from sympy.physics.quantum.matrixutils import matrix_tensor_product, matrix_eye +from sympy.physics.quantum.matrixcache import matrix_cache + +from sympy.matrices.matrixbase import MatrixBase + +from sympy.utilities.iterables import is_sequence + +__all__ = [ + 'Gate', + 'CGate', + 'UGate', + 'OneQubitGate', + 'TwoQubitGate', + 'IdentityGate', + 'HadamardGate', + 'XGate', + 'YGate', + 'ZGate', + 'TGate', + 'PhaseGate', + 'SwapGate', + 'CNotGate', + # Aliased gate names + 'CNOT', + 'SWAP', + 'H', + 'X', + 'Y', + 'Z', + 'T', + 'S', + 'Phase', + 'normalized', + 'gate_sort', + 'gate_simp', + 'random_circuit', + 'CPHASE', + 'CGateS', +] + +#----------------------------------------------------------------------------- +# Gate Super-Classes +#----------------------------------------------------------------------------- + +_normalized = True + + +def _max(*args, **kwargs): + if "key" not in kwargs: + kwargs["key"] = default_sort_key + return max(*args, **kwargs) + + +def _min(*args, **kwargs): + if "key" not in kwargs: + kwargs["key"] = default_sort_key + return min(*args, **kwargs) + + +def normalized(normalize): + r"""Set flag controlling normalization of Hadamard gates by `1/\sqrt{2}`. + + This is a global setting that can be used to simplify the look of various + expressions, by leaving off the leading `1/\sqrt{2}` of the Hadamard gate. + + Parameters + ---------- + normalize : bool + Should the Hadamard gate include the `1/\sqrt{2}` normalization factor? + When True, the Hadamard gate will have the `1/\sqrt{2}`. When False, the + Hadamard gate will not have this factor. + """ + global _normalized + _normalized = normalize + + +def _validate_targets_controls(tandc): + tandc = list(tandc) + # Check for integers + for bit in tandc: + if not bit.is_Integer and not bit.is_Symbol: + raise TypeError('Integer expected, got: %r' % tandc[bit]) + # Detect duplicates + if len(set(tandc)) != len(tandc): + raise QuantumError( + 'Target/control qubits in a gate cannot be duplicated' + ) + + +class Gate(UnitaryOperator): + """Non-controlled unitary gate operator that acts on qubits. + + This is a general abstract gate that needs to be subclassed to do anything + useful. + + Parameters + ---------- + label : tuple, int + A list of the target qubits (as ints) that the gate will apply to. + + Examples + ======== + + + """ + + _label_separator = ',' + + gate_name = 'G' + gate_name_latex = 'G' + + #------------------------------------------------------------------------- + # Initialization/creation + #------------------------------------------------------------------------- + + @classmethod + def _eval_args(cls, args): + args = Tuple(*UnitaryOperator._eval_args(args)) + _validate_targets_controls(args) + return args + + @classmethod + def _eval_hilbert_space(cls, args): + """This returns the smallest possible Hilbert space.""" + return ComplexSpace(2)**(_max(args) + 1) + + #------------------------------------------------------------------------- + # Properties + #------------------------------------------------------------------------- + + @property + def nqubits(self): + """The total number of qubits this gate acts on. + + For controlled gate subclasses this includes both target and control + qubits, so that, for examples the CNOT gate acts on 2 qubits. + """ + return len(self.targets) + + @property + def min_qubits(self): + """The minimum number of qubits this gate needs to act on.""" + return _max(self.targets) + 1 + + @property + def targets(self): + """A tuple of target qubits.""" + return self.label + + @property + def gate_name_plot(self): + return r'$%s$' % self.gate_name_latex + + #------------------------------------------------------------------------- + # Gate methods + #------------------------------------------------------------------------- + + def get_target_matrix(self, format='sympy'): + """The matrix representation of the target part of the gate. + + Parameters + ---------- + format : str + The format string ('sympy','numpy', etc.) + """ + raise NotImplementedError( + 'get_target_matrix is not implemented in Gate.') + + #------------------------------------------------------------------------- + # Apply + #------------------------------------------------------------------------- + + def _apply_operator_IntQubit(self, qubits, **options): + """Redirect an apply from IntQubit to Qubit""" + return self._apply_operator_Qubit(qubits, **options) + + def _apply_operator_Qubit(self, qubits, **options): + """Apply this gate to a Qubit.""" + + # Check number of qubits this gate acts on. + if qubits.nqubits < self.min_qubits: + raise QuantumError( + 'Gate needs a minimum of %r qubits to act on, got: %r' % + (self.min_qubits, qubits.nqubits) + ) + + # If the controls are not met, just return + if isinstance(self, CGate): + if not self.eval_controls(qubits): + return qubits + + targets = self.targets + target_matrix = self.get_target_matrix(format='sympy') + + # Find which column of the target matrix this applies to. + column_index = 0 + n = 1 + for target in targets: + column_index += n*qubits[target] + n = n << 1 + column = target_matrix[:, int(column_index)] + + # Now apply each column element to the qubit. + result = 0 + for index in range(column.rows): + # TODO: This can be optimized to reduce the number of Qubit + # creations. We should simply manipulate the raw list of qubit + # values and then build the new Qubit object once. + # Make a copy of the incoming qubits. + new_qubit = qubits.__class__(*qubits.args) + # Flip the bits that need to be flipped. + for bit, target in enumerate(targets): + if new_qubit[target] != (index >> bit) & 1: + new_qubit = new_qubit.flip(target) + # The value in that row and column times the flipped-bit qubit + # is the result for that part. + result += column[index]*new_qubit + return result + + #------------------------------------------------------------------------- + # Represent + #------------------------------------------------------------------------- + + def _represent_default_basis(self, **options): + return self._represent_ZGate(None, **options) + + def _represent_ZGate(self, basis, **options): + format = options.get('format', 'sympy') + nqubits = options.get('nqubits', 0) + if nqubits == 0: + raise QuantumError( + 'The number of qubits must be given as nqubits.') + + # Make sure we have enough qubits for the gate. + if nqubits < self.min_qubits: + raise QuantumError( + 'The number of qubits %r is too small for the gate.' % nqubits + ) + + target_matrix = self.get_target_matrix(format) + targets = self.targets + if isinstance(self, CGate): + controls = self.controls + else: + controls = [] + m = represent_zbasis( + controls, targets, target_matrix, nqubits, format + ) + return m + + #------------------------------------------------------------------------- + # Print methods + #------------------------------------------------------------------------- + + def _sympystr(self, printer, *args): + label = self._print_label(printer, *args) + return '%s(%s)' % (self.gate_name, label) + + def _pretty(self, printer, *args): + a = stringPict(self.gate_name) + b = self._print_label_pretty(printer, *args) + return self._print_subscript_pretty(a, b) + + def _latex(self, printer, *args): + label = self._print_label(printer, *args) + return '%s_{%s}' % (self.gate_name_latex, label) + + def plot_gate(self, axes, gate_idx, gate_grid, wire_grid): + raise NotImplementedError('plot_gate is not implemented.') + + +class CGate(Gate): + """A general unitary gate with control qubits. + + A general control gate applies a target gate to a set of targets if all + of the control qubits have a particular values (set by + ``CGate.control_value``). + + Parameters + ---------- + label : tuple + The label in this case has the form (controls, gate), where controls + is a tuple/list of control qubits (as ints) and gate is a ``Gate`` + instance that is the target operator. + + Examples + ======== + + """ + + gate_name = 'C' + gate_name_latex = 'C' + + # The values this class controls for. + control_value = _S.One + + simplify_cgate = False + + #------------------------------------------------------------------------- + # Initialization + #------------------------------------------------------------------------- + + @classmethod + def _eval_args(cls, args): + # _eval_args has the right logic for the controls argument. + controls = args[0] + gate = args[1] + if not is_sequence(controls): + controls = (controls,) + controls = UnitaryOperator._eval_args(controls) + _validate_targets_controls(chain(controls, gate.targets)) + return (Tuple(*controls), gate) + + @classmethod + def _eval_hilbert_space(cls, args): + """This returns the smallest possible Hilbert space.""" + return ComplexSpace(2)**_max(_max(args[0]) + 1, args[1].min_qubits) + + #------------------------------------------------------------------------- + # Properties + #------------------------------------------------------------------------- + + @property + def nqubits(self): + """The total number of qubits this gate acts on. + + For controlled gate subclasses this includes both target and control + qubits, so that, for examples the CNOT gate acts on 2 qubits. + """ + return len(self.targets) + len(self.controls) + + @property + def min_qubits(self): + """The minimum number of qubits this gate needs to act on.""" + return _max(_max(self.controls), _max(self.targets)) + 1 + + @property + def targets(self): + """A tuple of target qubits.""" + return self.gate.targets + + @property + def controls(self): + """A tuple of control qubits.""" + return tuple(self.label[0]) + + @property + def gate(self): + """The non-controlled gate that will be applied to the targets.""" + return self.label[1] + + #------------------------------------------------------------------------- + # Gate methods + #------------------------------------------------------------------------- + + def get_target_matrix(self, format='sympy'): + return self.gate.get_target_matrix(format) + + def eval_controls(self, qubit): + """Return True/False to indicate if the controls are satisfied.""" + return all(qubit[bit] == self.control_value for bit in self.controls) + + def decompose(self, **options): + """Decompose the controlled gate into CNOT and single qubits gates.""" + if len(self.controls) == 1: + c = self.controls[0] + t = self.gate.targets[0] + if isinstance(self.gate, YGate): + g1 = PhaseGate(t) + g2 = CNotGate(c, t) + g3 = PhaseGate(t) + g4 = ZGate(t) + return g1*g2*g3*g4 + if isinstance(self.gate, ZGate): + g1 = HadamardGate(t) + g2 = CNotGate(c, t) + g3 = HadamardGate(t) + return g1*g2*g3 + else: + return self + + #------------------------------------------------------------------------- + # Print methods + #------------------------------------------------------------------------- + + def _print_label(self, printer, *args): + controls = self._print_sequence(self.controls, ',', printer, *args) + gate = printer._print(self.gate, *args) + return '(%s),%s' % (controls, gate) + + def _pretty(self, printer, *args): + controls = self._print_sequence_pretty( + self.controls, ',', printer, *args) + gate = printer._print(self.gate) + gate_name = stringPict(self.gate_name) + first = self._print_subscript_pretty(gate_name, controls) + gate = self._print_parens_pretty(gate) + final = prettyForm(*first.right(gate)) + return final + + def _latex(self, printer, *args): + controls = self._print_sequence(self.controls, ',', printer, *args) + gate = printer._print(self.gate, *args) + return r'%s_{%s}{\left(%s\right)}' % \ + (self.gate_name_latex, controls, gate) + + def plot_gate(self, circ_plot, gate_idx): + """ + Plot the controlled gate. If *simplify_cgate* is true, simplify + C-X and C-Z gates into their more familiar forms. + """ + min_wire = int(_min(chain(self.controls, self.targets))) + max_wire = int(_max(chain(self.controls, self.targets))) + circ_plot.control_line(gate_idx, min_wire, max_wire) + for c in self.controls: + circ_plot.control_point(gate_idx, int(c)) + if self.simplify_cgate: + if self.gate.gate_name == 'X': + self.gate.plot_gate_plus(circ_plot, gate_idx) + elif self.gate.gate_name == 'Z': + circ_plot.control_point(gate_idx, self.targets[0]) + else: + self.gate.plot_gate(circ_plot, gate_idx) + else: + self.gate.plot_gate(circ_plot, gate_idx) + + #------------------------------------------------------------------------- + # Miscellaneous + #------------------------------------------------------------------------- + + def _eval_dagger(self): + if isinstance(self.gate, HermitianOperator): + return self + else: + return Gate._eval_dagger(self) + + def _eval_inverse(self): + if isinstance(self.gate, HermitianOperator): + return self + else: + return Gate._eval_inverse(self) + + def _eval_power(self, exp): + if isinstance(self.gate, HermitianOperator): + if exp == -1: + return Gate._eval_power(self, exp) + elif abs(exp) % 2 == 0: + return self*(Gate._eval_inverse(self)) + else: + return self + else: + return Gate._eval_power(self, exp) + +class CGateS(CGate): + """Version of CGate that allows gate simplifications. + I.e. cnot looks like an oplus, cphase has dots, etc. + """ + simplify_cgate=True + + +class UGate(Gate): + """General gate specified by a set of targets and a target matrix. + + Parameters + ---------- + label : tuple + A tuple of the form (targets, U), where targets is a tuple of the + target qubits and U is a unitary matrix with dimension of + len(targets). + """ + gate_name = 'U' + gate_name_latex = 'U' + + #------------------------------------------------------------------------- + # Initialization + #------------------------------------------------------------------------- + + @classmethod + def _eval_args(cls, args): + targets = args[0] + if not is_sequence(targets): + targets = (targets,) + targets = Gate._eval_args(targets) + _validate_targets_controls(targets) + mat = args[1] + if not isinstance(mat, MatrixBase): + raise TypeError('Matrix expected, got: %r' % mat) + #make sure this matrix is of a Basic type + mat = _sympify(mat) + dim = 2**len(targets) + if not all(dim == shape for shape in mat.shape): + raise IndexError( + 'Number of targets must match the matrix size: %r %r' % + (targets, mat) + ) + return (targets, mat) + + @classmethod + def _eval_hilbert_space(cls, args): + """This returns the smallest possible Hilbert space.""" + return ComplexSpace(2)**(_max(args[0]) + 1) + + #------------------------------------------------------------------------- + # Properties + #------------------------------------------------------------------------- + + @property + def targets(self): + """A tuple of target qubits.""" + return tuple(self.label[0]) + + #------------------------------------------------------------------------- + # Gate methods + #------------------------------------------------------------------------- + + def get_target_matrix(self, format='sympy'): + """The matrix rep. of the target part of the gate. + + Parameters + ---------- + format : str + The format string ('sympy','numpy', etc.) + """ + return self.label[1] + + #------------------------------------------------------------------------- + # Print methods + #------------------------------------------------------------------------- + def _pretty(self, printer, *args): + targets = self._print_sequence_pretty( + self.targets, ',', printer, *args) + gate_name = stringPict(self.gate_name) + return self._print_subscript_pretty(gate_name, targets) + + def _latex(self, printer, *args): + targets = self._print_sequence(self.targets, ',', printer, *args) + return r'%s_{%s}' % (self.gate_name_latex, targets) + + def plot_gate(self, circ_plot, gate_idx): + circ_plot.one_qubit_box( + self.gate_name_plot, + gate_idx, int(self.targets[0]) + ) + + +class OneQubitGate(Gate): + """A single qubit unitary gate base class.""" + + nqubits = _S.One + + def plot_gate(self, circ_plot, gate_idx): + circ_plot.one_qubit_box( + self.gate_name_plot, + gate_idx, int(self.targets[0]) + ) + + def _eval_commutator(self, other, **hints): + if isinstance(other, OneQubitGate): + if self.targets != other.targets or self.__class__ == other.__class__: + return _S.Zero + return Operator._eval_commutator(self, other, **hints) + + def _eval_anticommutator(self, other, **hints): + if isinstance(other, OneQubitGate): + if self.targets != other.targets or self.__class__ == other.__class__: + return Integer(2)*self*other + return Operator._eval_anticommutator(self, other, **hints) + + +class TwoQubitGate(Gate): + """A two qubit unitary gate base class.""" + + nqubits = Integer(2) + +#----------------------------------------------------------------------------- +# Single Qubit Gates +#----------------------------------------------------------------------------- + + +class IdentityGate(OneQubitGate): + """The single qubit identity gate. + + Parameters + ---------- + target : int + The target qubit this gate will apply to. + + Examples + ======== + + """ + is_hermitian = True + gate_name = '1' + gate_name_latex = '1' + + # Short cut version of gate._apply_operator_Qubit + def _apply_operator_Qubit(self, qubits, **options): + # Check number of qubits this gate acts on (see gate._apply_operator_Qubit) + if qubits.nqubits < self.min_qubits: + raise QuantumError( + 'Gate needs a minimum of %r qubits to act on, got: %r' % + (self.min_qubits, qubits.nqubits) + ) + return qubits # no computation required for IdentityGate + + def get_target_matrix(self, format='sympy'): + return matrix_cache.get_matrix('eye2', format) + + def _eval_commutator(self, other, **hints): + return _S.Zero + + def _eval_anticommutator(self, other, **hints): + return Integer(2)*other + + +class HadamardGate(HermitianOperator, OneQubitGate): + """The single qubit Hadamard gate. + + Parameters + ---------- + target : int + The target qubit this gate will apply to. + + Examples + ======== + + >>> from sympy import sqrt + >>> from sympy.physics.quantum.qubit import Qubit + >>> from sympy.physics.quantum.gate import HadamardGate + >>> from sympy.physics.quantum.qapply import qapply + >>> qapply(HadamardGate(0)*Qubit('1')) + sqrt(2)*|0>/2 - sqrt(2)*|1>/2 + >>> # Hadamard on bell state, applied on 2 qubits. + >>> psi = 1/sqrt(2)*(Qubit('00')+Qubit('11')) + >>> qapply(HadamardGate(0)*HadamardGate(1)*psi) + sqrt(2)*|00>/2 + sqrt(2)*|11>/2 + + """ + gate_name = 'H' + gate_name_latex = 'H' + + def get_target_matrix(self, format='sympy'): + if _normalized: + return matrix_cache.get_matrix('H', format) + else: + return matrix_cache.get_matrix('Hsqrt2', format) + + def _eval_commutator_XGate(self, other, **hints): + return I*sqrt(2)*YGate(self.targets[0]) + + def _eval_commutator_YGate(self, other, **hints): + return I*sqrt(2)*(ZGate(self.targets[0]) - XGate(self.targets[0])) + + def _eval_commutator_ZGate(self, other, **hints): + return -I*sqrt(2)*YGate(self.targets[0]) + + def _eval_anticommutator_XGate(self, other, **hints): + return sqrt(2)*IdentityGate(self.targets[0]) + + def _eval_anticommutator_YGate(self, other, **hints): + return _S.Zero + + def _eval_anticommutator_ZGate(self, other, **hints): + return sqrt(2)*IdentityGate(self.targets[0]) + + +class XGate(HermitianOperator, OneQubitGate): + """The single qubit X, or NOT, gate. + + Parameters + ---------- + target : int + The target qubit this gate will apply to. + + Examples + ======== + + """ + gate_name = 'X' + gate_name_latex = 'X' + + def get_target_matrix(self, format='sympy'): + return matrix_cache.get_matrix('X', format) + + def plot_gate(self, circ_plot, gate_idx): + OneQubitGate.plot_gate(self,circ_plot,gate_idx) + + def plot_gate_plus(self, circ_plot, gate_idx): + circ_plot.not_point( + gate_idx, int(self.label[0]) + ) + + def _eval_commutator_YGate(self, other, **hints): + return Integer(2)*I*ZGate(self.targets[0]) + + def _eval_anticommutator_XGate(self, other, **hints): + return Integer(2)*IdentityGate(self.targets[0]) + + def _eval_anticommutator_YGate(self, other, **hints): + return _S.Zero + + def _eval_anticommutator_ZGate(self, other, **hints): + return _S.Zero + + +class YGate(HermitianOperator, OneQubitGate): + """The single qubit Y gate. + + Parameters + ---------- + target : int + The target qubit this gate will apply to. + + Examples + ======== + + """ + gate_name = 'Y' + gate_name_latex = 'Y' + + def get_target_matrix(self, format='sympy'): + return matrix_cache.get_matrix('Y', format) + + def _eval_commutator_ZGate(self, other, **hints): + return Integer(2)*I*XGate(self.targets[0]) + + def _eval_anticommutator_YGate(self, other, **hints): + return Integer(2)*IdentityGate(self.targets[0]) + + def _eval_anticommutator_ZGate(self, other, **hints): + return _S.Zero + + +class ZGate(HermitianOperator, OneQubitGate): + """The single qubit Z gate. + + Parameters + ---------- + target : int + The target qubit this gate will apply to. + + Examples + ======== + + """ + gate_name = 'Z' + gate_name_latex = 'Z' + + def get_target_matrix(self, format='sympy'): + return matrix_cache.get_matrix('Z', format) + + def _eval_commutator_XGate(self, other, **hints): + return Integer(2)*I*YGate(self.targets[0]) + + def _eval_anticommutator_YGate(self, other, **hints): + return _S.Zero + + +class PhaseGate(OneQubitGate): + """The single qubit phase, or S, gate. + + This gate rotates the phase of the state by pi/2 if the state is ``|1>`` and + does nothing if the state is ``|0>``. + + Parameters + ---------- + target : int + The target qubit this gate will apply to. + + Examples + ======== + + """ + is_hermitian = False + gate_name = 'S' + gate_name_latex = 'S' + + def get_target_matrix(self, format='sympy'): + return matrix_cache.get_matrix('S', format) + + def _eval_commutator_ZGate(self, other, **hints): + return _S.Zero + + def _eval_commutator_TGate(self, other, **hints): + return _S.Zero + + +class TGate(OneQubitGate): + """The single qubit pi/8 gate. + + This gate rotates the phase of the state by pi/4 if the state is ``|1>`` and + does nothing if the state is ``|0>``. + + Parameters + ---------- + target : int + The target qubit this gate will apply to. + + Examples + ======== + + """ + is_hermitian = False + gate_name = 'T' + gate_name_latex = 'T' + + def get_target_matrix(self, format='sympy'): + return matrix_cache.get_matrix('T', format) + + def _eval_commutator_ZGate(self, other, **hints): + return _S.Zero + + def _eval_commutator_PhaseGate(self, other, **hints): + return _S.Zero + + +# Aliases for gate names. +H = HadamardGate +X = XGate +Y = YGate +Z = ZGate +T = TGate +Phase = S = PhaseGate + + +#----------------------------------------------------------------------------- +# 2 Qubit Gates +#----------------------------------------------------------------------------- + + +class CNotGate(HermitianOperator, CGate, TwoQubitGate): + """Two qubit controlled-NOT. + + This gate performs the NOT or X gate on the target qubit if the control + qubits all have the value 1. + + Parameters + ---------- + label : tuple + A tuple of the form (control, target). + + Examples + ======== + + >>> from sympy.physics.quantum.gate import CNOT + >>> from sympy.physics.quantum.qapply import qapply + >>> from sympy.physics.quantum.qubit import Qubit + >>> c = CNOT(1,0) + >>> qapply(c*Qubit('10')) # note that qubits are indexed from right to left + |11> + + """ + gate_name = 'CNOT' + gate_name_latex = r'\text{CNOT}' + simplify_cgate = True + + #------------------------------------------------------------------------- + # Initialization + #------------------------------------------------------------------------- + + @classmethod + def _eval_args(cls, args): + args = Gate._eval_args(args) + return args + + @classmethod + def _eval_hilbert_space(cls, args): + """This returns the smallest possible Hilbert space.""" + return ComplexSpace(2)**(_max(args) + 1) + + #------------------------------------------------------------------------- + # Properties + #------------------------------------------------------------------------- + + @property + def min_qubits(self): + """The minimum number of qubits this gate needs to act on.""" + return _max(self.label) + 1 + + @property + def targets(self): + """A tuple of target qubits.""" + return (self.label[1],) + + @property + def controls(self): + """A tuple of control qubits.""" + return (self.label[0],) + + @property + def gate(self): + """The non-controlled gate that will be applied to the targets.""" + return XGate(self.label[1]) + + #------------------------------------------------------------------------- + # Properties + #------------------------------------------------------------------------- + + # The default printing of Gate works better than those of CGate, so we + # go around the overridden methods in CGate. + + def _print_label(self, printer, *args): + return Gate._print_label(self, printer, *args) + + def _pretty(self, printer, *args): + return Gate._pretty(self, printer, *args) + + def _latex(self, printer, *args): + return Gate._latex(self, printer, *args) + + #------------------------------------------------------------------------- + # Commutator/AntiCommutator + #------------------------------------------------------------------------- + + def _eval_commutator_ZGate(self, other, **hints): + """[CNOT(i, j), Z(i)] == 0.""" + if self.controls[0] == other.targets[0]: + return _S.Zero + else: + raise NotImplementedError('Commutator not implemented: %r' % other) + + def _eval_commutator_TGate(self, other, **hints): + """[CNOT(i, j), T(i)] == 0.""" + return self._eval_commutator_ZGate(other, **hints) + + def _eval_commutator_PhaseGate(self, other, **hints): + """[CNOT(i, j), S(i)] == 0.""" + return self._eval_commutator_ZGate(other, **hints) + + def _eval_commutator_XGate(self, other, **hints): + """[CNOT(i, j), X(j)] == 0.""" + if self.targets[0] == other.targets[0]: + return _S.Zero + else: + raise NotImplementedError('Commutator not implemented: %r' % other) + + def _eval_commutator_CNotGate(self, other, **hints): + """[CNOT(i, j), CNOT(i,k)] == 0.""" + if self.controls[0] == other.controls[0]: + return _S.Zero + else: + raise NotImplementedError('Commutator not implemented: %r' % other) + + +class SwapGate(TwoQubitGate): + """Two qubit SWAP gate. + + This gate swap the values of the two qubits. + + Parameters + ---------- + label : tuple + A tuple of the form (target1, target2). + + Examples + ======== + + """ + is_hermitian = True + gate_name = 'SWAP' + gate_name_latex = r'\text{SWAP}' + + def get_target_matrix(self, format='sympy'): + return matrix_cache.get_matrix('SWAP', format) + + def decompose(self, **options): + """Decompose the SWAP gate into CNOT gates.""" + i, j = self.targets[0], self.targets[1] + g1 = CNotGate(i, j) + g2 = CNotGate(j, i) + return g1*g2*g1 + + def plot_gate(self, circ_plot, gate_idx): + min_wire = int(_min(self.targets)) + max_wire = int(_max(self.targets)) + circ_plot.control_line(gate_idx, min_wire, max_wire) + circ_plot.swap_point(gate_idx, min_wire) + circ_plot.swap_point(gate_idx, max_wire) + + def _represent_ZGate(self, basis, **options): + """Represent the SWAP gate in the computational basis. + + The following representation is used to compute this: + + SWAP = |1><1|x|1><1| + |0><0|x|0><0| + |1><0|x|0><1| + |0><1|x|1><0| + """ + format = options.get('format', 'sympy') + targets = [int(t) for t in self.targets] + min_target = _min(targets) + max_target = _max(targets) + nqubits = options.get('nqubits', self.min_qubits) + + op01 = matrix_cache.get_matrix('op01', format) + op10 = matrix_cache.get_matrix('op10', format) + op11 = matrix_cache.get_matrix('op11', format) + op00 = matrix_cache.get_matrix('op00', format) + eye2 = matrix_cache.get_matrix('eye2', format) + + result = None + for i, j in ((op01, op10), (op10, op01), (op00, op00), (op11, op11)): + product = nqubits*[eye2] + product[nqubits - min_target - 1] = i + product[nqubits - max_target - 1] = j + new_result = matrix_tensor_product(*product) + if result is None: + result = new_result + else: + result = result + new_result + + return result + + +# Aliases for gate names. +CNOT = CNotGate +SWAP = SwapGate +def CPHASE(a,b): return CGateS((a,),Z(b)) + + +#----------------------------------------------------------------------------- +# Represent +#----------------------------------------------------------------------------- + + +def represent_zbasis(controls, targets, target_matrix, nqubits, format='sympy'): + """Represent a gate with controls, targets and target_matrix. + + This function does the low-level work of representing gates as matrices + in the standard computational basis (ZGate). Currently, we support two + main cases: + + 1. One target qubit and no control qubits. + 2. One target qubits and multiple control qubits. + + For the base of multiple controls, we use the following expression [1]: + + 1_{2**n} + (|1><1|)^{(n-1)} x (target-matrix - 1_{2}) + + Parameters + ---------- + controls : list, tuple + A sequence of control qubits. + targets : list, tuple + A sequence of target qubits. + target_matrix : sympy.Matrix, numpy.matrix, scipy.sparse + The matrix form of the transformation to be performed on the target + qubits. The format of this matrix must match that passed into + the `format` argument. + nqubits : int + The total number of qubits used for the representation. + format : str + The format of the final matrix ('sympy', 'numpy', 'scipy.sparse'). + + Examples + ======== + + References + ---------- + [1] http://www.johnlapeyre.com/qinf/qinf_html/node6.html. + """ + controls = [int(x) for x in controls] + targets = [int(x) for x in targets] + nqubits = int(nqubits) + + # This checks for the format as well. + op11 = matrix_cache.get_matrix('op11', format) + eye2 = matrix_cache.get_matrix('eye2', format) + + # Plain single qubit case + if len(controls) == 0 and len(targets) == 1: + product = [] + bit = targets[0] + # Fill product with [I1,Gate,I2] such that the unitaries, + # I, cause the gate to be applied to the correct Qubit + if bit != nqubits - 1: + product.append(matrix_eye(2**(nqubits - bit - 1), format=format)) + product.append(target_matrix) + if bit != 0: + product.append(matrix_eye(2**bit, format=format)) + return matrix_tensor_product(*product) + + # Single target, multiple controls. + elif len(targets) == 1 and len(controls) >= 1: + target = targets[0] + + # Build the non-trivial part. + product2 = [] + for i in range(nqubits): + product2.append(matrix_eye(2, format=format)) + for control in controls: + product2[nqubits - 1 - control] = op11 + product2[nqubits - 1 - target] = target_matrix - eye2 + + return matrix_eye(2**nqubits, format=format) + \ + matrix_tensor_product(*product2) + + # Multi-target, multi-control is not yet implemented. + else: + raise NotImplementedError( + 'The representation of multi-target, multi-control gates ' + 'is not implemented.' + ) + + +#----------------------------------------------------------------------------- +# Gate manipulation functions. +#----------------------------------------------------------------------------- + + +def gate_simp(circuit): + """Simplifies gates symbolically + + It first sorts gates using gate_sort. It then applies basic + simplification rules to the circuit, e.g., XGate**2 = Identity + """ + + # Bubble sort out gates that commute. + circuit = gate_sort(circuit) + + # Do simplifications by subing a simplification into the first element + # which can be simplified. We recursively call gate_simp with new circuit + # as input more simplifications exist. + if isinstance(circuit, Add): + return sum(gate_simp(t) for t in circuit.args) + elif isinstance(circuit, Mul): + circuit_args = circuit.args + elif isinstance(circuit, Pow): + b, e = circuit.as_base_exp() + circuit_args = (gate_simp(b)**e,) + else: + return circuit + + # Iterate through each element in circuit, simplify if possible. + for i in range(len(circuit_args)): + # H,X,Y or Z squared is 1. + # T**2 = S, S**2 = Z + if isinstance(circuit_args[i], Pow): + if isinstance(circuit_args[i].base, + (HadamardGate, XGate, YGate, ZGate)) \ + and isinstance(circuit_args[i].exp, Number): + # Build a new circuit taking replacing the + # H,X,Y,Z squared with one. + newargs = (circuit_args[:i] + + (circuit_args[i].base**(circuit_args[i].exp % 2),) + + circuit_args[i + 1:]) + # Recursively simplify the new circuit. + circuit = gate_simp(Mul(*newargs)) + break + elif isinstance(circuit_args[i].base, PhaseGate): + # Build a new circuit taking old circuit but splicing + # in simplification. + newargs = circuit_args[:i] + # Replace PhaseGate**2 with ZGate. + newargs = newargs + (ZGate(circuit_args[i].base.args[0])** + (Integer(circuit_args[i].exp/2)), circuit_args[i].base** + (circuit_args[i].exp % 2)) + # Append the last elements. + newargs = newargs + circuit_args[i + 1:] + # Recursively simplify the new circuit. + circuit = gate_simp(Mul(*newargs)) + break + elif isinstance(circuit_args[i].base, TGate): + # Build a new circuit taking all the old elements. + newargs = circuit_args[:i] + + # Put an Phasegate in place of any TGate**2. + newargs = newargs + (PhaseGate(circuit_args[i].base.args[0])** + Integer(circuit_args[i].exp/2), circuit_args[i].base** + (circuit_args[i].exp % 2)) + + # Append the last elements. + newargs = newargs + circuit_args[i + 1:] + # Recursively simplify the new circuit. + circuit = gate_simp(Mul(*newargs)) + break + return circuit + + +def gate_sort(circuit): + """Sorts the gates while keeping track of commutation relations + + This function uses a bubble sort to rearrange the order of gate + application. Keeps track of Quantum computations special commutation + relations (e.g. things that apply to the same Qubit do not commute with + each other) + + circuit is the Mul of gates that are to be sorted. + """ + # Make sure we have an Add or Mul. + if isinstance(circuit, Add): + return sum(gate_sort(t) for t in circuit.args) + if isinstance(circuit, Pow): + return gate_sort(circuit.base)**circuit.exp + elif isinstance(circuit, Gate): + return circuit + if not isinstance(circuit, Mul): + return circuit + + changes = True + while changes: + changes = False + circ_array = circuit.args + for i in range(len(circ_array) - 1): + # Go through each element and switch ones that are in wrong order + if isinstance(circ_array[i], (Gate, Pow)) and \ + isinstance(circ_array[i + 1], (Gate, Pow)): + # If we have a Pow object, look at only the base + first_base, first_exp = circ_array[i].as_base_exp() + second_base, second_exp = circ_array[i + 1].as_base_exp() + + # Use SymPy's hash based sorting. This is not mathematical + # sorting, but is rather based on comparing hashes of objects. + # See Basic.compare for details. + if first_base.compare(second_base) > 0: + if Commutator(first_base, second_base).doit() == 0: + new_args = (circuit.args[:i] + (circuit.args[i + 1],) + + (circuit.args[i],) + circuit.args[i + 2:]) + circuit = Mul(*new_args) + changes = True + break + if AntiCommutator(first_base, second_base).doit() == 0: + new_args = (circuit.args[:i] + (circuit.args[i + 1],) + + (circuit.args[i],) + circuit.args[i + 2:]) + sign = _S.NegativeOne**(first_exp*second_exp) + circuit = sign*Mul(*new_args) + changes = True + break + return circuit + + +#----------------------------------------------------------------------------- +# Utility functions +#----------------------------------------------------------------------------- + + +def random_circuit(ngates, nqubits, gate_space=(X, Y, Z, S, T, H, CNOT, SWAP)): + """Return a random circuit of ngates and nqubits. + + This uses an equally weighted sample of (X, Y, Z, S, T, H, CNOT, SWAP) + gates. + + Parameters + ---------- + ngates : int + The number of gates in the circuit. + nqubits : int + The number of qubits in the circuit. + gate_space : tuple + A tuple of the gate classes that will be used in the circuit. + Repeating gate classes multiple times in this tuple will increase + the frequency they appear in the random circuit. + """ + qubit_space = range(nqubits) + result = [] + for i in range(ngates): + g = random.choice(gate_space) + if g == CNotGate or g == SwapGate: + qubits = random.sample(qubit_space, 2) + g = g(*qubits) + else: + qubit = random.choice(qubit_space) + g = g(qubit) + result.append(g) + return Mul(*result) + + +def zx_basis_transform(self, format='sympy'): + """Transformation matrix from Z to X basis.""" + return matrix_cache.get_matrix('ZX', format) + + +def zy_basis_transform(self, format='sympy'): + """Transformation matrix from Z to Y basis.""" + return matrix_cache.get_matrix('ZY', format) diff --git a/lib/python3.10/site-packages/sympy/physics/quantum/grover.py b/lib/python3.10/site-packages/sympy/physics/quantum/grover.py new file mode 100644 index 0000000000000000000000000000000000000000..a03bd3a61a6e0960ab66d55bcc0fc7f25936199e --- /dev/null +++ b/lib/python3.10/site-packages/sympy/physics/quantum/grover.py @@ -0,0 +1,345 @@ +"""Grover's algorithm and helper functions. + +Todo: + +* W gate construction (or perhaps -W gate based on Mermin's book) +* Generalize the algorithm for an unknown function that returns 1 on multiple + qubit states, not just one. +* Implement _represent_ZGate in OracleGate +""" + +from sympy.core.numbers import pi +from sympy.core.sympify import sympify +from sympy.core.basic import Atom +from sympy.functions.elementary.integers import floor +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.matrices.dense import eye +from sympy.core.numbers import NegativeOne +from sympy.physics.quantum.qapply import qapply +from sympy.physics.quantum.qexpr import QuantumError +from sympy.physics.quantum.hilbert import ComplexSpace +from sympy.physics.quantum.operator import UnitaryOperator +from sympy.physics.quantum.gate import Gate +from sympy.physics.quantum.qubit import IntQubit + +__all__ = [ + 'OracleGate', + 'WGate', + 'superposition_basis', + 'grover_iteration', + 'apply_grover' +] + + +def superposition_basis(nqubits): + """Creates an equal superposition of the computational basis. + + Parameters + ========== + + nqubits : int + The number of qubits. + + Returns + ======= + + state : Qubit + An equal superposition of the computational basis with nqubits. + + Examples + ======== + + Create an equal superposition of 2 qubits:: + + >>> from sympy.physics.quantum.grover import superposition_basis + >>> superposition_basis(2) + |0>/2 + |1>/2 + |2>/2 + |3>/2 + """ + + amp = 1/sqrt(2**nqubits) + return sum(amp*IntQubit(n, nqubits=nqubits) for n in range(2**nqubits)) + +class OracleGateFunction(Atom): + """Wrapper for python functions used in `OracleGate`s""" + + def __new__(cls, function): + if not callable(function): + raise TypeError('Callable expected, got: %r' % function) + obj = Atom.__new__(cls) + obj.function = function + return obj + + def _hashable_content(self): + return type(self), self.function + + def __call__(self, *args): + return self.function(*args) + + +class OracleGate(Gate): + """A black box gate. + + The gate marks the desired qubits of an unknown function by flipping + the sign of the qubits. The unknown function returns true when it + finds its desired qubits and false otherwise. + + Parameters + ========== + + qubits : int + Number of qubits. + + oracle : callable + A callable function that returns a boolean on a computational basis. + + Examples + ======== + + Apply an Oracle gate that flips the sign of ``|2>`` on different qubits:: + + >>> from sympy.physics.quantum.qubit import IntQubit + >>> from sympy.physics.quantum.qapply import qapply + >>> from sympy.physics.quantum.grover import OracleGate + >>> f = lambda qubits: qubits == IntQubit(2) + >>> v = OracleGate(2, f) + >>> qapply(v*IntQubit(2)) + -|2> + >>> qapply(v*IntQubit(3)) + |3> + """ + + gate_name = 'V' + gate_name_latex = 'V' + + #------------------------------------------------------------------------- + # Initialization/creation + #------------------------------------------------------------------------- + + @classmethod + def _eval_args(cls, args): + if len(args) != 2: + raise QuantumError( + 'Insufficient/excessive arguments to Oracle. Please ' + + 'supply the number of qubits and an unknown function.' + ) + sub_args = (args[0],) + sub_args = UnitaryOperator._eval_args(sub_args) + if not sub_args[0].is_Integer: + raise TypeError('Integer expected, got: %r' % sub_args[0]) + + function = args[1] + if not isinstance(function, OracleGateFunction): + function = OracleGateFunction(function) + + return (sub_args[0], function) + + @classmethod + def _eval_hilbert_space(cls, args): + """This returns the smallest possible Hilbert space.""" + return ComplexSpace(2)**args[0] + + #------------------------------------------------------------------------- + # Properties + #------------------------------------------------------------------------- + + @property + def search_function(self): + """The unknown function that helps find the sought after qubits.""" + return self.label[1] + + @property + def targets(self): + """A tuple of target qubits.""" + return sympify(tuple(range(self.args[0]))) + + #------------------------------------------------------------------------- + # Apply + #------------------------------------------------------------------------- + + def _apply_operator_Qubit(self, qubits, **options): + """Apply this operator to a Qubit subclass. + + Parameters + ========== + + qubits : Qubit + The qubit subclass to apply this operator to. + + Returns + ======= + + state : Expr + The resulting quantum state. + """ + if qubits.nqubits != self.nqubits: + raise QuantumError( + 'OracleGate operates on %r qubits, got: %r' + % (self.nqubits, qubits.nqubits) + ) + # If function returns 1 on qubits + # return the negative of the qubits (flip the sign) + if self.search_function(qubits): + return -qubits + else: + return qubits + + #------------------------------------------------------------------------- + # Represent + #------------------------------------------------------------------------- + + def _represent_ZGate(self, basis, **options): + """ + Represent the OracleGate in the computational basis. + """ + nbasis = 2**self.nqubits # compute it only once + matrixOracle = eye(nbasis) + # Flip the sign given the output of the oracle function + for i in range(nbasis): + if self.search_function(IntQubit(i, nqubits=self.nqubits)): + matrixOracle[i, i] = NegativeOne() + return matrixOracle + + +class WGate(Gate): + """General n qubit W Gate in Grover's algorithm. + + The gate performs the operation ``2|phi> = (tensor product of n Hadamards)*(|0> with n qubits)`` + + Parameters + ========== + + nqubits : int + The number of qubits to operate on + + """ + + gate_name = 'W' + gate_name_latex = 'W' + + @classmethod + def _eval_args(cls, args): + if len(args) != 1: + raise QuantumError( + 'Insufficient/excessive arguments to W gate. Please ' + + 'supply the number of qubits to operate on.' + ) + args = UnitaryOperator._eval_args(args) + if not args[0].is_Integer: + raise TypeError('Integer expected, got: %r' % args[0]) + return args + + #------------------------------------------------------------------------- + # Properties + #------------------------------------------------------------------------- + + @property + def targets(self): + return sympify(tuple(reversed(range(self.args[0])))) + + #------------------------------------------------------------------------- + # Apply + #------------------------------------------------------------------------- + + def _apply_operator_Qubit(self, qubits, **options): + """ + qubits: a set of qubits (Qubit) + Returns: quantum object (quantum expression - QExpr) + """ + if qubits.nqubits != self.nqubits: + raise QuantumError( + 'WGate operates on %r qubits, got: %r' + % (self.nqubits, qubits.nqubits) + ) + + # See 'Quantum Computer Science' by David Mermin p.92 -> W|a> result + # Return (2/(sqrt(2^n)))|phi> - |a> where |a> is the current basis + # state and phi is the superposition of basis states (see function + # create_computational_basis above) + basis_states = superposition_basis(self.nqubits) + change_to_basis = (2/sqrt(2**self.nqubits))*basis_states + return change_to_basis - qubits + + +def grover_iteration(qstate, oracle): + """Applies one application of the Oracle and W Gate, WV. + + Parameters + ========== + + qstate : Qubit + A superposition of qubits. + oracle : OracleGate + The black box operator that flips the sign of the desired basis qubits. + + Returns + ======= + + Qubit : The qubits after applying the Oracle and W gate. + + Examples + ======== + + Perform one iteration of grover's algorithm to see a phase change:: + + >>> from sympy.physics.quantum.qapply import qapply + >>> from sympy.physics.quantum.qubit import IntQubit + >>> from sympy.physics.quantum.grover import OracleGate + >>> from sympy.physics.quantum.grover import superposition_basis + >>> from sympy.physics.quantum.grover import grover_iteration + >>> numqubits = 2 + >>> basis_states = superposition_basis(numqubits) + >>> f = lambda qubits: qubits == IntQubit(2) + >>> v = OracleGate(numqubits, f) + >>> qapply(grover_iteration(basis_states, v)) + |2> + + """ + wgate = WGate(oracle.nqubits) + return wgate*oracle*qstate + + +def apply_grover(oracle, nqubits, iterations=None): + """Applies grover's algorithm. + + Parameters + ========== + + oracle : callable + The unknown callable function that returns true when applied to the + desired qubits and false otherwise. + + Returns + ======= + + state : Expr + The resulting state after Grover's algorithm has been iterated. + + Examples + ======== + + Apply grover's algorithm to an even superposition of 2 qubits:: + + >>> from sympy.physics.quantum.qapply import qapply + >>> from sympy.physics.quantum.qubit import IntQubit + >>> from sympy.physics.quantum.grover import apply_grover + >>> f = lambda qubits: qubits == IntQubit(2) + >>> qapply(apply_grover(f, 2)) + |2> + + """ + if nqubits <= 0: + raise QuantumError( + 'Grover\'s algorithm needs nqubits > 0, received %r qubits' + % nqubits + ) + if iterations is None: + iterations = floor(sqrt(2**nqubits)*(pi/4)) + + v = OracleGate(nqubits, oracle) + iterated = superposition_basis(nqubits) + for iter in range(iterations): + iterated = grover_iteration(iterated, v) + iterated = qapply(iterated) + + return iterated diff --git a/lib/python3.10/site-packages/sympy/physics/quantum/hilbert.py b/lib/python3.10/site-packages/sympy/physics/quantum/hilbert.py new file mode 100644 index 0000000000000000000000000000000000000000..f475a9e83a6ccc93e9e2dbb9873ad111c1d05f93 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/physics/quantum/hilbert.py @@ -0,0 +1,653 @@ +"""Hilbert spaces for quantum mechanics. + +Authors: +* Brian Granger +* Matt Curry +""" + +from functools import reduce + +from sympy.core.basic import Basic +from sympy.core.singleton import S +from sympy.core.sympify import sympify +from sympy.sets.sets import Interval +from sympy.printing.pretty.stringpict import prettyForm +from sympy.physics.quantum.qexpr import QuantumError + + +__all__ = [ + 'HilbertSpaceError', + 'HilbertSpace', + 'TensorProductHilbertSpace', + 'TensorPowerHilbertSpace', + 'DirectSumHilbertSpace', + 'ComplexSpace', + 'L2', + 'FockSpace' +] + +#----------------------------------------------------------------------------- +# Main objects +#----------------------------------------------------------------------------- + + +class HilbertSpaceError(QuantumError): + pass + +#----------------------------------------------------------------------------- +# Main objects +#----------------------------------------------------------------------------- + + +class HilbertSpace(Basic): + """An abstract Hilbert space for quantum mechanics. + + In short, a Hilbert space is an abstract vector space that is complete + with inner products defined [1]_. + + Examples + ======== + + >>> from sympy.physics.quantum.hilbert import HilbertSpace + >>> hs = HilbertSpace() + >>> hs + H + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Hilbert_space + """ + + def __new__(cls): + obj = Basic.__new__(cls) + return obj + + @property + def dimension(self): + """Return the Hilbert dimension of the space.""" + raise NotImplementedError('This Hilbert space has no dimension.') + + def __add__(self, other): + return DirectSumHilbertSpace(self, other) + + def __radd__(self, other): + return DirectSumHilbertSpace(other, self) + + def __mul__(self, other): + return TensorProductHilbertSpace(self, other) + + def __rmul__(self, other): + return TensorProductHilbertSpace(other, self) + + def __pow__(self, other, mod=None): + if mod is not None: + raise ValueError('The third argument to __pow__ is not supported \ + for Hilbert spaces.') + return TensorPowerHilbertSpace(self, other) + + def __contains__(self, other): + """Is the operator or state in this Hilbert space. + + This is checked by comparing the classes of the Hilbert spaces, not + the instances. This is to allow Hilbert Spaces with symbolic + dimensions. + """ + if other.hilbert_space.__class__ == self.__class__: + return True + else: + return False + + def _sympystr(self, printer, *args): + return 'H' + + def _pretty(self, printer, *args): + ustr = '\N{LATIN CAPITAL LETTER H}' + return prettyForm(ustr) + + def _latex(self, printer, *args): + return r'\mathcal{H}' + + +class ComplexSpace(HilbertSpace): + """Finite dimensional Hilbert space of complex vectors. + + The elements of this Hilbert space are n-dimensional complex valued + vectors with the usual inner product that takes the complex conjugate + of the vector on the right. + + A classic example of this type of Hilbert space is spin-1/2, which is + ``ComplexSpace(2)``. Generalizing to spin-s, the space is + ``ComplexSpace(2*s+1)``. Quantum computing with N qubits is done with the + direct product space ``ComplexSpace(2)**N``. + + Examples + ======== + + >>> from sympy import symbols + >>> from sympy.physics.quantum.hilbert import ComplexSpace + >>> c1 = ComplexSpace(2) + >>> c1 + C(2) + >>> c1.dimension + 2 + + >>> n = symbols('n') + >>> c2 = ComplexSpace(n) + >>> c2 + C(n) + >>> c2.dimension + n + + """ + + def __new__(cls, dimension): + dimension = sympify(dimension) + r = cls.eval(dimension) + if isinstance(r, Basic): + return r + obj = Basic.__new__(cls, dimension) + return obj + + @classmethod + def eval(cls, dimension): + if len(dimension.atoms()) == 1: + if not (dimension.is_Integer and dimension > 0 or dimension is S.Infinity + or dimension.is_Symbol): + raise TypeError('The dimension of a ComplexSpace can only' + 'be a positive integer, oo, or a Symbol: %r' + % dimension) + else: + for dim in dimension.atoms(): + if not (dim.is_Integer or dim is S.Infinity or dim.is_Symbol): + raise TypeError('The dimension of a ComplexSpace can only' + ' contain integers, oo, or a Symbol: %r' + % dim) + + @property + def dimension(self): + return self.args[0] + + def _sympyrepr(self, printer, *args): + return "%s(%s)" % (self.__class__.__name__, + printer._print(self.dimension, *args)) + + def _sympystr(self, printer, *args): + return "C(%s)" % printer._print(self.dimension, *args) + + def _pretty(self, printer, *args): + ustr = '\N{LATIN CAPITAL LETTER C}' + pform_exp = printer._print(self.dimension, *args) + pform_base = prettyForm(ustr) + return pform_base**pform_exp + + def _latex(self, printer, *args): + return r'\mathcal{C}^{%s}' % printer._print(self.dimension, *args) + + +class L2(HilbertSpace): + """The Hilbert space of square integrable functions on an interval. + + An L2 object takes in a single SymPy Interval argument which represents + the interval its functions (vectors) are defined on. + + Examples + ======== + + >>> from sympy import Interval, oo + >>> from sympy.physics.quantum.hilbert import L2 + >>> hs = L2(Interval(0,oo)) + >>> hs + L2(Interval(0, oo)) + >>> hs.dimension + oo + >>> hs.interval + Interval(0, oo) + + """ + + def __new__(cls, interval): + if not isinstance(interval, Interval): + raise TypeError('L2 interval must be an Interval instance: %r' + % interval) + obj = Basic.__new__(cls, interval) + return obj + + @property + def dimension(self): + return S.Infinity + + @property + def interval(self): + return self.args[0] + + def _sympyrepr(self, printer, *args): + return "L2(%s)" % printer._print(self.interval, *args) + + def _sympystr(self, printer, *args): + return "L2(%s)" % printer._print(self.interval, *args) + + def _pretty(self, printer, *args): + pform_exp = prettyForm('2') + pform_base = prettyForm('L') + return pform_base**pform_exp + + def _latex(self, printer, *args): + interval = printer._print(self.interval, *args) + return r'{\mathcal{L}^2}\left( %s \right)' % interval + + +class FockSpace(HilbertSpace): + """The Hilbert space for second quantization. + + Technically, this Hilbert space is a infinite direct sum of direct + products of single particle Hilbert spaces [1]_. This is a mess, so we have + a class to represent it directly. + + Examples + ======== + + >>> from sympy.physics.quantum.hilbert import FockSpace + >>> hs = FockSpace() + >>> hs + F + >>> hs.dimension + oo + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Fock_space + """ + + def __new__(cls): + obj = Basic.__new__(cls) + return obj + + @property + def dimension(self): + return S.Infinity + + def _sympyrepr(self, printer, *args): + return "FockSpace()" + + def _sympystr(self, printer, *args): + return "F" + + def _pretty(self, printer, *args): + ustr = '\N{LATIN CAPITAL LETTER F}' + return prettyForm(ustr) + + def _latex(self, printer, *args): + return r'\mathcal{F}' + + +class TensorProductHilbertSpace(HilbertSpace): + """A tensor product of Hilbert spaces [1]_. + + The tensor product between Hilbert spaces is represented by the + operator ``*`` Products of the same Hilbert space will be combined into + tensor powers. + + A ``TensorProductHilbertSpace`` object takes in an arbitrary number of + ``HilbertSpace`` objects as its arguments. In addition, multiplication of + ``HilbertSpace`` objects will automatically return this tensor product + object. + + Examples + ======== + + >>> from sympy.physics.quantum.hilbert import ComplexSpace, FockSpace + >>> from sympy import symbols + + >>> c = ComplexSpace(2) + >>> f = FockSpace() + >>> hs = c*f + >>> hs + C(2)*F + >>> hs.dimension + oo + >>> hs.spaces + (C(2), F) + + >>> c1 = ComplexSpace(2) + >>> n = symbols('n') + >>> c2 = ComplexSpace(n) + >>> hs = c1*c2 + >>> hs + C(2)*C(n) + >>> hs.dimension + 2*n + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Hilbert_space#Tensor_products + """ + + def __new__(cls, *args): + r = cls.eval(args) + if isinstance(r, Basic): + return r + obj = Basic.__new__(cls, *args) + return obj + + @classmethod + def eval(cls, args): + """Evaluates the direct product.""" + new_args = [] + recall = False + #flatten arguments + for arg in args: + if isinstance(arg, TensorProductHilbertSpace): + new_args.extend(arg.args) + recall = True + elif isinstance(arg, (HilbertSpace, TensorPowerHilbertSpace)): + new_args.append(arg) + else: + raise TypeError('Hilbert spaces can only be multiplied by \ + other Hilbert spaces: %r' % arg) + #combine like arguments into direct powers + comb_args = [] + prev_arg = None + for new_arg in new_args: + if prev_arg is not None: + if isinstance(new_arg, TensorPowerHilbertSpace) and \ + isinstance(prev_arg, TensorPowerHilbertSpace) and \ + new_arg.base == prev_arg.base: + prev_arg = new_arg.base**(new_arg.exp + prev_arg.exp) + elif isinstance(new_arg, TensorPowerHilbertSpace) and \ + new_arg.base == prev_arg: + prev_arg = prev_arg**(new_arg.exp + 1) + elif isinstance(prev_arg, TensorPowerHilbertSpace) and \ + new_arg == prev_arg.base: + prev_arg = new_arg**(prev_arg.exp + 1) + elif new_arg == prev_arg: + prev_arg = new_arg**2 + else: + comb_args.append(prev_arg) + prev_arg = new_arg + elif prev_arg is None: + prev_arg = new_arg + comb_args.append(prev_arg) + if recall: + return TensorProductHilbertSpace(*comb_args) + elif len(comb_args) == 1: + return TensorPowerHilbertSpace(comb_args[0].base, comb_args[0].exp) + else: + return None + + @property + def dimension(self): + arg_list = [arg.dimension for arg in self.args] + if S.Infinity in arg_list: + return S.Infinity + else: + return reduce(lambda x, y: x*y, arg_list) + + @property + def spaces(self): + """A tuple of the Hilbert spaces in this tensor product.""" + return self.args + + def _spaces_printer(self, printer, *args): + spaces_strs = [] + for arg in self.args: + s = printer._print(arg, *args) + if isinstance(arg, DirectSumHilbertSpace): + s = '(%s)' % s + spaces_strs.append(s) + return spaces_strs + + def _sympyrepr(self, printer, *args): + spaces_reprs = self._spaces_printer(printer, *args) + return "TensorProductHilbertSpace(%s)" % ','.join(spaces_reprs) + + def _sympystr(self, printer, *args): + spaces_strs = self._spaces_printer(printer, *args) + return '*'.join(spaces_strs) + + def _pretty(self, printer, *args): + length = len(self.args) + pform = printer._print('', *args) + for i in range(length): + next_pform = printer._print(self.args[i], *args) + if isinstance(self.args[i], (DirectSumHilbertSpace, + TensorProductHilbertSpace)): + next_pform = prettyForm( + *next_pform.parens(left='(', right=')') + ) + pform = prettyForm(*pform.right(next_pform)) + if i != length - 1: + if printer._use_unicode: + pform = prettyForm(*pform.right(' ' + '\N{N-ARY CIRCLED TIMES OPERATOR}' + ' ')) + else: + pform = prettyForm(*pform.right(' x ')) + return pform + + def _latex(self, printer, *args): + length = len(self.args) + s = '' + for i in range(length): + arg_s = printer._print(self.args[i], *args) + if isinstance(self.args[i], (DirectSumHilbertSpace, + TensorProductHilbertSpace)): + arg_s = r'\left(%s\right)' % arg_s + s = s + arg_s + if i != length - 1: + s = s + r'\otimes ' + return s + + +class DirectSumHilbertSpace(HilbertSpace): + """A direct sum of Hilbert spaces [1]_. + + This class uses the ``+`` operator to represent direct sums between + different Hilbert spaces. + + A ``DirectSumHilbertSpace`` object takes in an arbitrary number of + ``HilbertSpace`` objects as its arguments. Also, addition of + ``HilbertSpace`` objects will automatically return a direct sum object. + + Examples + ======== + + >>> from sympy.physics.quantum.hilbert import ComplexSpace, FockSpace + + >>> c = ComplexSpace(2) + >>> f = FockSpace() + >>> hs = c+f + >>> hs + C(2)+F + >>> hs.dimension + oo + >>> list(hs.spaces) + [C(2), F] + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Hilbert_space#Direct_sums + """ + def __new__(cls, *args): + r = cls.eval(args) + if isinstance(r, Basic): + return r + obj = Basic.__new__(cls, *args) + return obj + + @classmethod + def eval(cls, args): + """Evaluates the direct product.""" + new_args = [] + recall = False + #flatten arguments + for arg in args: + if isinstance(arg, DirectSumHilbertSpace): + new_args.extend(arg.args) + recall = True + elif isinstance(arg, HilbertSpace): + new_args.append(arg) + else: + raise TypeError('Hilbert spaces can only be summed with other \ + Hilbert spaces: %r' % arg) + if recall: + return DirectSumHilbertSpace(*new_args) + else: + return None + + @property + def dimension(self): + arg_list = [arg.dimension for arg in self.args] + if S.Infinity in arg_list: + return S.Infinity + else: + return reduce(lambda x, y: x + y, arg_list) + + @property + def spaces(self): + """A tuple of the Hilbert spaces in this direct sum.""" + return self.args + + def _sympyrepr(self, printer, *args): + spaces_reprs = [printer._print(arg, *args) for arg in self.args] + return "DirectSumHilbertSpace(%s)" % ','.join(spaces_reprs) + + def _sympystr(self, printer, *args): + spaces_strs = [printer._print(arg, *args) for arg in self.args] + return '+'.join(spaces_strs) + + def _pretty(self, printer, *args): + length = len(self.args) + pform = printer._print('', *args) + for i in range(length): + next_pform = printer._print(self.args[i], *args) + if isinstance(self.args[i], (DirectSumHilbertSpace, + TensorProductHilbertSpace)): + next_pform = prettyForm( + *next_pform.parens(left='(', right=')') + ) + pform = prettyForm(*pform.right(next_pform)) + if i != length - 1: + if printer._use_unicode: + pform = prettyForm(*pform.right(' \N{CIRCLED PLUS} ')) + else: + pform = prettyForm(*pform.right(' + ')) + return pform + + def _latex(self, printer, *args): + length = len(self.args) + s = '' + for i in range(length): + arg_s = printer._print(self.args[i], *args) + if isinstance(self.args[i], (DirectSumHilbertSpace, + TensorProductHilbertSpace)): + arg_s = r'\left(%s\right)' % arg_s + s = s + arg_s + if i != length - 1: + s = s + r'\oplus ' + return s + + +class TensorPowerHilbertSpace(HilbertSpace): + """An exponentiated Hilbert space [1]_. + + Tensor powers (repeated tensor products) are represented by the + operator ``**`` Identical Hilbert spaces that are multiplied together + will be automatically combined into a single tensor power object. + + Any Hilbert space, product, or sum may be raised to a tensor power. The + ``TensorPowerHilbertSpace`` takes two arguments: the Hilbert space; and the + tensor power (number). + + Examples + ======== + + >>> from sympy.physics.quantum.hilbert import ComplexSpace, FockSpace + >>> from sympy import symbols + + >>> n = symbols('n') + >>> c = ComplexSpace(2) + >>> hs = c**n + >>> hs + C(2)**n + >>> hs.dimension + 2**n + + >>> c = ComplexSpace(2) + >>> c*c + C(2)**2 + >>> f = FockSpace() + >>> c*f*f + C(2)*F**2 + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Hilbert_space#Tensor_products + """ + + def __new__(cls, *args): + r = cls.eval(args) + if isinstance(r, Basic): + return r + return Basic.__new__(cls, *r) + + @classmethod + def eval(cls, args): + new_args = args[0], sympify(args[1]) + exp = new_args[1] + #simplify hs**1 -> hs + if exp is S.One: + return args[0] + #simplify hs**0 -> 1 + if exp is S.Zero: + return S.One + #check (and allow) for hs**(x+42+y...) case + if len(exp.atoms()) == 1: + if not (exp.is_Integer and exp >= 0 or exp.is_Symbol): + raise ValueError('Hilbert spaces can only be raised to \ + positive integers or Symbols: %r' % exp) + else: + for power in exp.atoms(): + if not (power.is_Integer or power.is_Symbol): + raise ValueError('Tensor powers can only contain integers \ + or Symbols: %r' % power) + return new_args + + @property + def base(self): + return self.args[0] + + @property + def exp(self): + return self.args[1] + + @property + def dimension(self): + if self.base.dimension is S.Infinity: + return S.Infinity + else: + return self.base.dimension**self.exp + + def _sympyrepr(self, printer, *args): + return "TensorPowerHilbertSpace(%s,%s)" % (printer._print(self.base, + *args), printer._print(self.exp, *args)) + + def _sympystr(self, printer, *args): + return "%s**%s" % (printer._print(self.base, *args), + printer._print(self.exp, *args)) + + def _pretty(self, printer, *args): + pform_exp = printer._print(self.exp, *args) + if printer._use_unicode: + pform_exp = prettyForm(*pform_exp.left(prettyForm('\N{N-ARY CIRCLED TIMES OPERATOR}'))) + else: + pform_exp = prettyForm(*pform_exp.left(prettyForm('x'))) + pform_base = printer._print(self.base, *args) + return pform_base**pform_exp + + def _latex(self, printer, *args): + base = printer._print(self.base, *args) + exp = printer._print(self.exp, *args) + return r'{%s}^{\otimes %s}' % (base, exp) diff --git a/lib/python3.10/site-packages/sympy/physics/quantum/identitysearch.py b/lib/python3.10/site-packages/sympy/physics/quantum/identitysearch.py new file mode 100644 index 0000000000000000000000000000000000000000..9a178e9b808450b7ce91175600d6b393fc9797d6 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/physics/quantum/identitysearch.py @@ -0,0 +1,853 @@ +from collections import deque +from sympy.core.random import randint + +from sympy.external import import_module +from sympy.core.basic import Basic +from sympy.core.mul import Mul +from sympy.core.numbers import Number, equal_valued +from sympy.core.power import Pow +from sympy.core.singleton import S +from sympy.physics.quantum.represent import represent +from sympy.physics.quantum.dagger import Dagger + +__all__ = [ + # Public interfaces + 'generate_gate_rules', + 'generate_equivalent_ids', + 'GateIdentity', + 'bfs_identity_search', + 'random_identity_search', + + # "Private" functions + 'is_scalar_sparse_matrix', + 'is_scalar_nonsparse_matrix', + 'is_degenerate', + 'is_reducible', +] + +np = import_module('numpy') +scipy = import_module('scipy', import_kwargs={'fromlist': ['sparse']}) + + +def is_scalar_sparse_matrix(circuit, nqubits, identity_only, eps=1e-11): + """Checks if a given scipy.sparse matrix is a scalar matrix. + + A scalar matrix is such that B = bI, where B is the scalar + matrix, b is some scalar multiple, and I is the identity + matrix. A scalar matrix would have only the element b along + it's main diagonal and zeroes elsewhere. + + Parameters + ========== + + circuit : Gate tuple + Sequence of quantum gates representing a quantum circuit + nqubits : int + Number of qubits in the circuit + identity_only : bool + Check for only identity matrices + eps : number + The tolerance value for zeroing out elements in the matrix. + Values in the range [-eps, +eps] will be changed to a zero. + """ + + if not np or not scipy: + pass + + matrix = represent(Mul(*circuit), nqubits=nqubits, + format='scipy.sparse') + + # In some cases, represent returns a 1D scalar value in place + # of a multi-dimensional scalar matrix + if (isinstance(matrix, int)): + return matrix == 1 if identity_only else True + + # If represent returns a matrix, check if the matrix is diagonal + # and if every item along the diagonal is the same + else: + # Due to floating pointing operations, must zero out + # elements that are "very" small in the dense matrix + # See parameter for default value. + + # Get the ndarray version of the dense matrix + dense_matrix = matrix.todense().getA() + # Since complex values can't be compared, must split + # the matrix into real and imaginary components + # Find the real values in between -eps and eps + bool_real = np.logical_and(dense_matrix.real > -eps, + dense_matrix.real < eps) + # Find the imaginary values between -eps and eps + bool_imag = np.logical_and(dense_matrix.imag > -eps, + dense_matrix.imag < eps) + # Replaces values between -eps and eps with 0 + corrected_real = np.where(bool_real, 0.0, dense_matrix.real) + corrected_imag = np.where(bool_imag, 0.0, dense_matrix.imag) + # Convert the matrix with real values into imaginary values + corrected_imag = corrected_imag * complex(1j) + # Recombine the real and imaginary components + corrected_dense = corrected_real + corrected_imag + + # Check if it's diagonal + row_indices = corrected_dense.nonzero()[0] + col_indices = corrected_dense.nonzero()[1] + # Check if the rows indices and columns indices are the same + # If they match, then matrix only contains elements along diagonal + bool_indices = row_indices == col_indices + is_diagonal = bool_indices.all() + + first_element = corrected_dense[0][0] + # If the first element is a zero, then can't rescale matrix + # and definitely not diagonal + if (first_element == 0.0 + 0.0j): + return False + + # The dimensions of the dense matrix should still + # be 2^nqubits if there are elements all along the + # the main diagonal + trace_of_corrected = (corrected_dense/first_element).trace() + expected_trace = pow(2, nqubits) + has_correct_trace = trace_of_corrected == expected_trace + + # If only looking for identity matrices + # first element must be a 1 + real_is_one = abs(first_element.real - 1.0) < eps + imag_is_zero = abs(first_element.imag) < eps + is_one = real_is_one and imag_is_zero + is_identity = is_one if identity_only else True + return bool(is_diagonal and has_correct_trace and is_identity) + + +def is_scalar_nonsparse_matrix(circuit, nqubits, identity_only, eps=None): + """Checks if a given circuit, in matrix form, is equivalent to + a scalar value. + + Parameters + ========== + + circuit : Gate tuple + Sequence of quantum gates representing a quantum circuit + nqubits : int + Number of qubits in the circuit + identity_only : bool + Check for only identity matrices + eps : number + This argument is ignored. It is just for signature compatibility with + is_scalar_sparse_matrix. + + Note: Used in situations when is_scalar_sparse_matrix has bugs + """ + + matrix = represent(Mul(*circuit), nqubits=nqubits) + + # In some cases, represent returns a 1D scalar value in place + # of a multi-dimensional scalar matrix + if (isinstance(matrix, Number)): + return matrix == 1 if identity_only else True + + # If represent returns a matrix, check if the matrix is diagonal + # and if every item along the diagonal is the same + else: + # Added up the diagonal elements + matrix_trace = matrix.trace() + # Divide the trace by the first element in the matrix + # if matrix is not required to be the identity matrix + adjusted_matrix_trace = (matrix_trace/matrix[0] + if not identity_only + else matrix_trace) + + is_identity = equal_valued(matrix[0], 1) if identity_only else True + + has_correct_trace = adjusted_matrix_trace == pow(2, nqubits) + + # The matrix is scalar if it's diagonal and the adjusted trace + # value is equal to 2^nqubits + return bool( + matrix.is_diagonal() and has_correct_trace and is_identity) + +if np and scipy: + is_scalar_matrix = is_scalar_sparse_matrix +else: + is_scalar_matrix = is_scalar_nonsparse_matrix + + +def _get_min_qubits(a_gate): + if isinstance(a_gate, Pow): + return a_gate.base.min_qubits + else: + return a_gate.min_qubits + + +def ll_op(left, right): + """Perform a LL operation. + + A LL operation multiplies both left and right circuits + with the dagger of the left circuit's leftmost gate, and + the dagger is multiplied on the left side of both circuits. + + If a LL is possible, it returns the new gate rule as a + 2-tuple (LHS, RHS), where LHS is the left circuit and + and RHS is the right circuit of the new rule. + If a LL is not possible, None is returned. + + Parameters + ========== + + left : Gate tuple + The left circuit of a gate rule expression. + right : Gate tuple + The right circuit of a gate rule expression. + + Examples + ======== + + Generate a new gate rule using a LL operation: + + >>> from sympy.physics.quantum.identitysearch import ll_op + >>> from sympy.physics.quantum.gate import X, Y, Z + >>> x = X(0); y = Y(0); z = Z(0) + >>> ll_op((x, y, z), ()) + ((Y(0), Z(0)), (X(0),)) + + >>> ll_op((y, z), (x,)) + ((Z(0),), (Y(0), X(0))) + """ + + if (len(left) > 0): + ll_gate = left[0] + ll_gate_is_unitary = is_scalar_matrix( + (Dagger(ll_gate), ll_gate), _get_min_qubits(ll_gate), True) + + if (len(left) > 0 and ll_gate_is_unitary): + # Get the new left side w/o the leftmost gate + new_left = left[1:len(left)] + # Add the leftmost gate to the left position on the right side + new_right = (Dagger(ll_gate),) + right + # Return the new gate rule + return (new_left, new_right) + + return None + + +def lr_op(left, right): + """Perform a LR operation. + + A LR operation multiplies both left and right circuits + with the dagger of the left circuit's rightmost gate, and + the dagger is multiplied on the right side of both circuits. + + If a LR is possible, it returns the new gate rule as a + 2-tuple (LHS, RHS), where LHS is the left circuit and + and RHS is the right circuit of the new rule. + If a LR is not possible, None is returned. + + Parameters + ========== + + left : Gate tuple + The left circuit of a gate rule expression. + right : Gate tuple + The right circuit of a gate rule expression. + + Examples + ======== + + Generate a new gate rule using a LR operation: + + >>> from sympy.physics.quantum.identitysearch import lr_op + >>> from sympy.physics.quantum.gate import X, Y, Z + >>> x = X(0); y = Y(0); z = Z(0) + >>> lr_op((x, y, z), ()) + ((X(0), Y(0)), (Z(0),)) + + >>> lr_op((x, y), (z,)) + ((X(0),), (Z(0), Y(0))) + """ + + if (len(left) > 0): + lr_gate = left[len(left) - 1] + lr_gate_is_unitary = is_scalar_matrix( + (Dagger(lr_gate), lr_gate), _get_min_qubits(lr_gate), True) + + if (len(left) > 0 and lr_gate_is_unitary): + # Get the new left side w/o the rightmost gate + new_left = left[0:len(left) - 1] + # Add the rightmost gate to the right position on the right side + new_right = right + (Dagger(lr_gate),) + # Return the new gate rule + return (new_left, new_right) + + return None + + +def rl_op(left, right): + """Perform a RL operation. + + A RL operation multiplies both left and right circuits + with the dagger of the right circuit's leftmost gate, and + the dagger is multiplied on the left side of both circuits. + + If a RL is possible, it returns the new gate rule as a + 2-tuple (LHS, RHS), where LHS is the left circuit and + and RHS is the right circuit of the new rule. + If a RL is not possible, None is returned. + + Parameters + ========== + + left : Gate tuple + The left circuit of a gate rule expression. + right : Gate tuple + The right circuit of a gate rule expression. + + Examples + ======== + + Generate a new gate rule using a RL operation: + + >>> from sympy.physics.quantum.identitysearch import rl_op + >>> from sympy.physics.quantum.gate import X, Y, Z + >>> x = X(0); y = Y(0); z = Z(0) + >>> rl_op((x,), (y, z)) + ((Y(0), X(0)), (Z(0),)) + + >>> rl_op((x, y), (z,)) + ((Z(0), X(0), Y(0)), ()) + """ + + if (len(right) > 0): + rl_gate = right[0] + rl_gate_is_unitary = is_scalar_matrix( + (Dagger(rl_gate), rl_gate), _get_min_qubits(rl_gate), True) + + if (len(right) > 0 and rl_gate_is_unitary): + # Get the new right side w/o the leftmost gate + new_right = right[1:len(right)] + # Add the leftmost gate to the left position on the left side + new_left = (Dagger(rl_gate),) + left + # Return the new gate rule + return (new_left, new_right) + + return None + + +def rr_op(left, right): + """Perform a RR operation. + + A RR operation multiplies both left and right circuits + with the dagger of the right circuit's rightmost gate, and + the dagger is multiplied on the right side of both circuits. + + If a RR is possible, it returns the new gate rule as a + 2-tuple (LHS, RHS), where LHS is the left circuit and + and RHS is the right circuit of the new rule. + If a RR is not possible, None is returned. + + Parameters + ========== + + left : Gate tuple + The left circuit of a gate rule expression. + right : Gate tuple + The right circuit of a gate rule expression. + + Examples + ======== + + Generate a new gate rule using a RR operation: + + >>> from sympy.physics.quantum.identitysearch import rr_op + >>> from sympy.physics.quantum.gate import X, Y, Z + >>> x = X(0); y = Y(0); z = Z(0) + >>> rr_op((x, y), (z,)) + ((X(0), Y(0), Z(0)), ()) + + >>> rr_op((x,), (y, z)) + ((X(0), Z(0)), (Y(0),)) + """ + + if (len(right) > 0): + rr_gate = right[len(right) - 1] + rr_gate_is_unitary = is_scalar_matrix( + (Dagger(rr_gate), rr_gate), _get_min_qubits(rr_gate), True) + + if (len(right) > 0 and rr_gate_is_unitary): + # Get the new right side w/o the rightmost gate + new_right = right[0:len(right) - 1] + # Add the rightmost gate to the right position on the right side + new_left = left + (Dagger(rr_gate),) + # Return the new gate rule + return (new_left, new_right) + + return None + + +def generate_gate_rules(gate_seq, return_as_muls=False): + """Returns a set of gate rules. Each gate rules is represented + as a 2-tuple of tuples or Muls. An empty tuple represents an arbitrary + scalar value. + + This function uses the four operations (LL, LR, RL, RR) + to generate the gate rules. + + A gate rule is an expression such as ABC = D or AB = CD, where + A, B, C, and D are gates. Each value on either side of the + equal sign represents a circuit. The four operations allow + one to find a set of equivalent circuits from a gate identity. + The letters denoting the operation tell the user what + activities to perform on each expression. The first letter + indicates which side of the equal sign to focus on. The + second letter indicates which gate to focus on given the + side. Once this information is determined, the inverse + of the gate is multiplied on both circuits to create a new + gate rule. + + For example, given the identity, ABCD = 1, a LL operation + means look at the left value and multiply both left sides by the + inverse of the leftmost gate A. If A is Hermitian, the inverse + of A is still A. The resulting new rule is BCD = A. + + The following is a summary of the four operations. Assume + that in the examples, all gates are Hermitian. + + LL : left circuit, left multiply + ABCD = E -> AABCD = AE -> BCD = AE + LR : left circuit, right multiply + ABCD = E -> ABCDD = ED -> ABC = ED + RL : right circuit, left multiply + ABC = ED -> EABC = EED -> EABC = D + RR : right circuit, right multiply + AB = CD -> ABD = CDD -> ABD = C + + The number of gate rules generated is n*(n+1), where n + is the number of gates in the sequence (unproven). + + Parameters + ========== + + gate_seq : Gate tuple, Mul, or Number + A variable length tuple or Mul of Gates whose product is equal to + a scalar matrix + return_as_muls : bool + True to return a set of Muls; False to return a set of tuples + + Examples + ======== + + Find the gate rules of the current circuit using tuples: + + >>> from sympy.physics.quantum.identitysearch import generate_gate_rules + >>> from sympy.physics.quantum.gate import X, Y, Z + >>> x = X(0); y = Y(0); z = Z(0) + >>> generate_gate_rules((x, x)) + {((X(0),), (X(0),)), ((X(0), X(0)), ())} + + >>> generate_gate_rules((x, y, z)) + {((), (X(0), Z(0), Y(0))), ((), (Y(0), X(0), Z(0))), + ((), (Z(0), Y(0), X(0))), ((X(0),), (Z(0), Y(0))), + ((Y(0),), (X(0), Z(0))), ((Z(0),), (Y(0), X(0))), + ((X(0), Y(0)), (Z(0),)), ((Y(0), Z(0)), (X(0),)), + ((Z(0), X(0)), (Y(0),)), ((X(0), Y(0), Z(0)), ()), + ((Y(0), Z(0), X(0)), ()), ((Z(0), X(0), Y(0)), ())} + + Find the gate rules of the current circuit using Muls: + + >>> generate_gate_rules(x*x, return_as_muls=True) + {(1, 1)} + + >>> generate_gate_rules(x*y*z, return_as_muls=True) + {(1, X(0)*Z(0)*Y(0)), (1, Y(0)*X(0)*Z(0)), + (1, Z(0)*Y(0)*X(0)), (X(0)*Y(0), Z(0)), + (Y(0)*Z(0), X(0)), (Z(0)*X(0), Y(0)), + (X(0)*Y(0)*Z(0), 1), (Y(0)*Z(0)*X(0), 1), + (Z(0)*X(0)*Y(0), 1), (X(0), Z(0)*Y(0)), + (Y(0), X(0)*Z(0)), (Z(0), Y(0)*X(0))} + """ + + if isinstance(gate_seq, Number): + if return_as_muls: + return {(S.One, S.One)} + else: + return {((), ())} + + elif isinstance(gate_seq, Mul): + gate_seq = gate_seq.args + + # Each item in queue is a 3-tuple: + # i) first item is the left side of an equality + # ii) second item is the right side of an equality + # iii) third item is the number of operations performed + # The argument, gate_seq, will start on the left side, and + # the right side will be empty, implying the presence of an + # identity. + queue = deque() + # A set of gate rules + rules = set() + # Maximum number of operations to perform + max_ops = len(gate_seq) + + def process_new_rule(new_rule, ops): + if new_rule is not None: + new_left, new_right = new_rule + + if new_rule not in rules and (new_right, new_left) not in rules: + rules.add(new_rule) + # If haven't reached the max limit on operations + if ops + 1 < max_ops: + queue.append(new_rule + (ops + 1,)) + + queue.append((gate_seq, (), 0)) + rules.add((gate_seq, ())) + + while len(queue) > 0: + left, right, ops = queue.popleft() + + # Do a LL + new_rule = ll_op(left, right) + process_new_rule(new_rule, ops) + # Do a LR + new_rule = lr_op(left, right) + process_new_rule(new_rule, ops) + # Do a RL + new_rule = rl_op(left, right) + process_new_rule(new_rule, ops) + # Do a RR + new_rule = rr_op(left, right) + process_new_rule(new_rule, ops) + + if return_as_muls: + # Convert each rule as tuples into a rule as muls + mul_rules = set() + for rule in rules: + left, right = rule + mul_rules.add((Mul(*left), Mul(*right))) + + rules = mul_rules + + return rules + + +def generate_equivalent_ids(gate_seq, return_as_muls=False): + """Returns a set of equivalent gate identities. + + A gate identity is a quantum circuit such that the product + of the gates in the circuit is equal to a scalar value. + For example, XYZ = i, where X, Y, Z are the Pauli gates and + i is the imaginary value, is considered a gate identity. + + This function uses the four operations (LL, LR, RL, RR) + to generate the gate rules and, subsequently, to locate equivalent + gate identities. + + Note that all equivalent identities are reachable in n operations + from the starting gate identity, where n is the number of gates + in the sequence. + + The max number of gate identities is 2n, where n is the number + of gates in the sequence (unproven). + + Parameters + ========== + + gate_seq : Gate tuple, Mul, or Number + A variable length tuple or Mul of Gates whose product is equal to + a scalar matrix. + return_as_muls: bool + True to return as Muls; False to return as tuples + + Examples + ======== + + Find equivalent gate identities from the current circuit with tuples: + + >>> from sympy.physics.quantum.identitysearch import generate_equivalent_ids + >>> from sympy.physics.quantum.gate import X, Y, Z + >>> x = X(0); y = Y(0); z = Z(0) + >>> generate_equivalent_ids((x, x)) + {(X(0), X(0))} + + >>> generate_equivalent_ids((x, y, z)) + {(X(0), Y(0), Z(0)), (X(0), Z(0), Y(0)), (Y(0), X(0), Z(0)), + (Y(0), Z(0), X(0)), (Z(0), X(0), Y(0)), (Z(0), Y(0), X(0))} + + Find equivalent gate identities from the current circuit with Muls: + + >>> generate_equivalent_ids(x*x, return_as_muls=True) + {1} + + >>> generate_equivalent_ids(x*y*z, return_as_muls=True) + {X(0)*Y(0)*Z(0), X(0)*Z(0)*Y(0), Y(0)*X(0)*Z(0), + Y(0)*Z(0)*X(0), Z(0)*X(0)*Y(0), Z(0)*Y(0)*X(0)} + """ + + if isinstance(gate_seq, Number): + return {S.One} + elif isinstance(gate_seq, Mul): + gate_seq = gate_seq.args + + # Filter through the gate rules and keep the rules + # with an empty tuple either on the left or right side + + # A set of equivalent gate identities + eq_ids = set() + + gate_rules = generate_gate_rules(gate_seq) + for rule in gate_rules: + l, r = rule + if l == (): + eq_ids.add(r) + elif r == (): + eq_ids.add(l) + + if return_as_muls: + convert_to_mul = lambda id_seq: Mul(*id_seq) + eq_ids = set(map(convert_to_mul, eq_ids)) + + return eq_ids + + +class GateIdentity(Basic): + """Wrapper class for circuits that reduce to a scalar value. + + A gate identity is a quantum circuit such that the product + of the gates in the circuit is equal to a scalar value. + For example, XYZ = i, where X, Y, Z are the Pauli gates and + i is the imaginary value, is considered a gate identity. + + Parameters + ========== + + args : Gate tuple + A variable length tuple of Gates that form an identity. + + Examples + ======== + + Create a GateIdentity and look at its attributes: + + >>> from sympy.physics.quantum.identitysearch import GateIdentity + >>> from sympy.physics.quantum.gate import X, Y, Z + >>> x = X(0); y = Y(0); z = Z(0) + >>> an_identity = GateIdentity(x, y, z) + >>> an_identity.circuit + X(0)*Y(0)*Z(0) + + >>> an_identity.equivalent_ids + {(X(0), Y(0), Z(0)), (X(0), Z(0), Y(0)), (Y(0), X(0), Z(0)), + (Y(0), Z(0), X(0)), (Z(0), X(0), Y(0)), (Z(0), Y(0), X(0))} + """ + + def __new__(cls, *args): + # args should be a tuple - a variable length argument list + obj = Basic.__new__(cls, *args) + obj._circuit = Mul(*args) + obj._rules = generate_gate_rules(args) + obj._eq_ids = generate_equivalent_ids(args) + + return obj + + @property + def circuit(self): + return self._circuit + + @property + def gate_rules(self): + return self._rules + + @property + def equivalent_ids(self): + return self._eq_ids + + @property + def sequence(self): + return self.args + + def __str__(self): + """Returns the string of gates in a tuple.""" + return str(self.circuit) + + +def is_degenerate(identity_set, gate_identity): + """Checks if a gate identity is a permutation of another identity. + + Parameters + ========== + + identity_set : set + A Python set with GateIdentity objects. + gate_identity : GateIdentity + The GateIdentity to check for existence in the set. + + Examples + ======== + + Check if the identity is a permutation of another identity: + + >>> from sympy.physics.quantum.identitysearch import ( + ... GateIdentity, is_degenerate) + >>> from sympy.physics.quantum.gate import X, Y, Z + >>> x = X(0); y = Y(0); z = Z(0) + >>> an_identity = GateIdentity(x, y, z) + >>> id_set = {an_identity} + >>> another_id = (y, z, x) + >>> is_degenerate(id_set, another_id) + True + + >>> another_id = (x, x) + >>> is_degenerate(id_set, another_id) + False + """ + + # For now, just iteratively go through the set and check if the current + # gate_identity is a permutation of an identity in the set + for an_id in identity_set: + if (gate_identity in an_id.equivalent_ids): + return True + return False + + +def is_reducible(circuit, nqubits, begin, end): + """Determines if a circuit is reducible by checking + if its subcircuits are scalar values. + + Parameters + ========== + + circuit : Gate tuple + A tuple of Gates representing a circuit. The circuit to check + if a gate identity is contained in a subcircuit. + nqubits : int + The number of qubits the circuit operates on. + begin : int + The leftmost gate in the circuit to include in a subcircuit. + end : int + The rightmost gate in the circuit to include in a subcircuit. + + Examples + ======== + + Check if the circuit can be reduced: + + >>> from sympy.physics.quantum.identitysearch import is_reducible + >>> from sympy.physics.quantum.gate import X, Y, Z + >>> x = X(0); y = Y(0); z = Z(0) + >>> is_reducible((x, y, z), 1, 0, 3) + True + + Check if an interval in the circuit can be reduced: + + >>> is_reducible((x, y, z), 1, 1, 3) + False + + >>> is_reducible((x, y, y), 1, 1, 3) + True + """ + + current_circuit = () + # Start from the gate at "end" and go down to almost the gate at "begin" + for ndx in reversed(range(begin, end)): + next_gate = circuit[ndx] + current_circuit = (next_gate,) + current_circuit + + # If a circuit as a matrix is equivalent to a scalar value + if (is_scalar_matrix(current_circuit, nqubits, False)): + return True + + return False + + +def bfs_identity_search(gate_list, nqubits, max_depth=None, + identity_only=False): + """Constructs a set of gate identities from the list of possible gates. + + Performs a breadth first search over the space of gate identities. + This allows the finding of the shortest gate identities first. + + Parameters + ========== + + gate_list : list, Gate + A list of Gates from which to search for gate identities. + nqubits : int + The number of qubits the quantum circuit operates on. + max_depth : int + The longest quantum circuit to construct from gate_list. + identity_only : bool + True to search for gate identities that reduce to identity; + False to search for gate identities that reduce to a scalar. + + Examples + ======== + + Find a list of gate identities: + + >>> from sympy.physics.quantum.identitysearch import bfs_identity_search + >>> from sympy.physics.quantum.gate import X, Y, Z + >>> x = X(0); y = Y(0); z = Z(0) + >>> bfs_identity_search([x], 1, max_depth=2) + {GateIdentity(X(0), X(0))} + + >>> bfs_identity_search([x, y, z], 1) + {GateIdentity(X(0), X(0)), GateIdentity(Y(0), Y(0)), + GateIdentity(Z(0), Z(0)), GateIdentity(X(0), Y(0), Z(0))} + + Find a list of identities that only equal to 1: + + >>> bfs_identity_search([x, y, z], 1, identity_only=True) + {GateIdentity(X(0), X(0)), GateIdentity(Y(0), Y(0)), + GateIdentity(Z(0), Z(0))} + """ + + if max_depth is None or max_depth <= 0: + max_depth = len(gate_list) + + id_only = identity_only + + # Start with an empty sequence (implicitly contains an IdentityGate) + queue = deque([()]) + + # Create an empty set of gate identities + ids = set() + + # Begin searching for gate identities in given space. + while (len(queue) > 0): + current_circuit = queue.popleft() + + for next_gate in gate_list: + new_circuit = current_circuit + (next_gate,) + + # Determines if a (strict) subcircuit is a scalar matrix + circuit_reducible = is_reducible(new_circuit, nqubits, + 1, len(new_circuit)) + + # In many cases when the matrix is a scalar value, + # the evaluated matrix will actually be an integer + if (is_scalar_matrix(new_circuit, nqubits, id_only) and + not is_degenerate(ids, new_circuit) and + not circuit_reducible): + ids.add(GateIdentity(*new_circuit)) + + elif (len(new_circuit) < max_depth and + not circuit_reducible): + queue.append(new_circuit) + + return ids + + +def random_identity_search(gate_list, numgates, nqubits): + """Randomly selects numgates from gate_list and checks if it is + a gate identity. + + If the circuit is a gate identity, the circuit is returned; + Otherwise, None is returned. + """ + + gate_size = len(gate_list) + circuit = () + + for i in range(numgates): + next_gate = gate_list[randint(0, gate_size - 1)] + circuit = circuit + (next_gate,) + + is_scalar = is_scalar_matrix(circuit, nqubits, False) + + return circuit if is_scalar else None diff --git a/lib/python3.10/site-packages/sympy/physics/quantum/innerproduct.py b/lib/python3.10/site-packages/sympy/physics/quantum/innerproduct.py new file mode 100644 index 0000000000000000000000000000000000000000..1b712f2db9a864807f64cb9cc8fc26e0189cef8e --- /dev/null +++ b/lib/python3.10/site-packages/sympy/physics/quantum/innerproduct.py @@ -0,0 +1,137 @@ +"""Symbolic inner product.""" + +from sympy.core.expr import Expr +from sympy.functions.elementary.complexes import conjugate +from sympy.printing.pretty.stringpict import prettyForm +from sympy.physics.quantum.dagger import Dagger +from sympy.physics.quantum.state import KetBase, BraBase + +__all__ = [ + 'InnerProduct' +] + + +# InnerProduct is not an QExpr because it is really just a regular commutative +# number. We have gone back and forth about this, but we gain a lot by having +# it subclass Expr. The main challenges were getting Dagger to work +# (we use _eval_conjugate) and represent (we can use atoms and subs). Having +# it be an Expr, mean that there are no commutative QExpr subclasses, +# which simplifies the design of everything. + +class InnerProduct(Expr): + """An unevaluated inner product between a Bra and a Ket [1]. + + Parameters + ========== + + bra : BraBase or subclass + The bra on the left side of the inner product. + ket : KetBase or subclass + The ket on the right side of the inner product. + + Examples + ======== + + Create an InnerProduct and check its properties: + + >>> from sympy.physics.quantum import Bra, Ket + >>> b = Bra('b') + >>> k = Ket('k') + >>> ip = b*k + >>> ip + + >>> ip.bra + >> ip.ket + |k> + + In simple products of kets and bras inner products will be automatically + identified and created:: + + >>> b*k + + + But in more complex expressions, there is ambiguity in whether inner or + outer products should be created:: + + >>> k*b*k*b + |k>*>> k*(b*k)*b + *|k>* moved to the left of the expression + because inner products are commutative complex numbers. + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Inner_product + """ + is_complex = True + + def __new__(cls, bra, ket): + if not isinstance(ket, KetBase): + raise TypeError('KetBase subclass expected, got: %r' % ket) + if not isinstance(bra, BraBase): + raise TypeError('BraBase subclass expected, got: %r' % ket) + obj = Expr.__new__(cls, bra, ket) + return obj + + @property + def bra(self): + return self.args[0] + + @property + def ket(self): + return self.args[1] + + def _eval_conjugate(self): + return InnerProduct(Dagger(self.ket), Dagger(self.bra)) + + def _sympyrepr(self, printer, *args): + return '%s(%s,%s)' % (self.__class__.__name__, + printer._print(self.bra, *args), printer._print(self.ket, *args)) + + def _sympystr(self, printer, *args): + sbra = printer._print(self.bra) + sket = printer._print(self.ket) + return '%s|%s' % (sbra[:-1], sket[1:]) + + def _pretty(self, printer, *args): + # Print state contents + bra = self.bra._print_contents_pretty(printer, *args) + ket = self.ket._print_contents_pretty(printer, *args) + # Print brackets + height = max(bra.height(), ket.height()) + use_unicode = printer._use_unicode + lbracket, _ = self.bra._pretty_brackets(height, use_unicode) + cbracket, rbracket = self.ket._pretty_brackets(height, use_unicode) + # Build innerproduct + pform = prettyForm(*bra.left(lbracket)) + pform = prettyForm(*pform.right(cbracket)) + pform = prettyForm(*pform.right(ket)) + pform = prettyForm(*pform.right(rbracket)) + return pform + + def _latex(self, printer, *args): + bra_label = self.bra._print_contents_latex(printer, *args) + ket = printer._print(self.ket, *args) + return r'\left\langle %s \right. %s' % (bra_label, ket) + + def doit(self, **hints): + try: + r = self.ket._eval_innerproduct(self.bra, **hints) + except NotImplementedError: + try: + r = conjugate( + self.bra.dual._eval_innerproduct(self.ket.dual, **hints) + ) + except NotImplementedError: + r = None + if r is not None: + return r + return self diff --git a/lib/python3.10/site-packages/sympy/physics/quantum/matrixcache.py b/lib/python3.10/site-packages/sympy/physics/quantum/matrixcache.py new file mode 100644 index 0000000000000000000000000000000000000000..3cfab3c3490c909966d8a56af395ffa578724ea7 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/physics/quantum/matrixcache.py @@ -0,0 +1,103 @@ +"""A cache for storing small matrices in multiple formats.""" + +from sympy.core.numbers import (I, Rational, pi) +from sympy.core.power import Pow +from sympy.functions.elementary.exponential import exp +from sympy.matrices.dense import Matrix + +from sympy.physics.quantum.matrixutils import ( + to_sympy, to_numpy, to_scipy_sparse +) + + +class MatrixCache: + """A cache for small matrices in different formats. + + This class takes small matrices in the standard ``sympy.Matrix`` format, + and then converts these to both ``numpy.matrix`` and + ``scipy.sparse.csr_matrix`` matrices. These matrices are then stored for + future recovery. + """ + + def __init__(self, dtype='complex'): + self._cache = {} + self.dtype = dtype + + def cache_matrix(self, name, m): + """Cache a matrix by its name. + + Parameters + ---------- + name : str + A descriptive name for the matrix, like "identity2". + m : list of lists + The raw matrix data as a SymPy Matrix. + """ + try: + self._sympy_matrix(name, m) + except ImportError: + pass + try: + self._numpy_matrix(name, m) + except ImportError: + pass + try: + self._scipy_sparse_matrix(name, m) + except ImportError: + pass + + def get_matrix(self, name, format): + """Get a cached matrix by name and format. + + Parameters + ---------- + name : str + A descriptive name for the matrix, like "identity2". + format : str + The format desired ('sympy', 'numpy', 'scipy.sparse') + """ + m = self._cache.get((name, format)) + if m is not None: + return m + raise NotImplementedError( + 'Matrix with name %s and format %s is not available.' % + (name, format) + ) + + def _store_matrix(self, name, format, m): + self._cache[(name, format)] = m + + def _sympy_matrix(self, name, m): + self._store_matrix(name, 'sympy', to_sympy(m)) + + def _numpy_matrix(self, name, m): + m = to_numpy(m, dtype=self.dtype) + self._store_matrix(name, 'numpy', m) + + def _scipy_sparse_matrix(self, name, m): + # TODO: explore different sparse formats. But sparse.kron will use + # coo in most cases, so we use that here. + m = to_scipy_sparse(m, dtype=self.dtype) + self._store_matrix(name, 'scipy.sparse', m) + + +sqrt2_inv = Pow(2, Rational(-1, 2), evaluate=False) + +# Save the common matrices that we will need +matrix_cache = MatrixCache() +matrix_cache.cache_matrix('eye2', Matrix([[1, 0], [0, 1]])) +matrix_cache.cache_matrix('op11', Matrix([[0, 0], [0, 1]])) # |1><1| +matrix_cache.cache_matrix('op00', Matrix([[1, 0], [0, 0]])) # |0><0| +matrix_cache.cache_matrix('op10', Matrix([[0, 0], [1, 0]])) # |1><0| +matrix_cache.cache_matrix('op01', Matrix([[0, 1], [0, 0]])) # |0><1| +matrix_cache.cache_matrix('X', Matrix([[0, 1], [1, 0]])) +matrix_cache.cache_matrix('Y', Matrix([[0, -I], [I, 0]])) +matrix_cache.cache_matrix('Z', Matrix([[1, 0], [0, -1]])) +matrix_cache.cache_matrix('S', Matrix([[1, 0], [0, I]])) +matrix_cache.cache_matrix('T', Matrix([[1, 0], [0, exp(I*pi/4)]])) +matrix_cache.cache_matrix('H', sqrt2_inv*Matrix([[1, 1], [1, -1]])) +matrix_cache.cache_matrix('Hsqrt2', Matrix([[1, 1], [1, -1]])) +matrix_cache.cache_matrix( + 'SWAP', Matrix([[1, 0, 0, 0], [0, 0, 1, 0], [0, 1, 0, 0], [0, 0, 0, 1]])) +matrix_cache.cache_matrix('ZX', sqrt2_inv*Matrix([[1, 1], [1, -1]])) +matrix_cache.cache_matrix('ZY', Matrix([[I, 0], [0, -I]])) diff --git a/lib/python3.10/site-packages/sympy/physics/quantum/matrixutils.py b/lib/python3.10/site-packages/sympy/physics/quantum/matrixutils.py new file mode 100644 index 0000000000000000000000000000000000000000..236b38668e10a8ce3574b390b885d269c1f96f64 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/physics/quantum/matrixutils.py @@ -0,0 +1,272 @@ +"""Utilities to deal with sympy.Matrix, numpy and scipy.sparse.""" + +from sympy.core.expr import Expr +from sympy.core.numbers import I +from sympy.core.singleton import S +from sympy.matrices.matrixbase import MatrixBase +from sympy.matrices import eye, zeros +from sympy.external import import_module + +__all__ = [ + 'numpy_ndarray', + 'scipy_sparse_matrix', + 'sympy_to_numpy', + 'sympy_to_scipy_sparse', + 'numpy_to_sympy', + 'scipy_sparse_to_sympy', + 'flatten_scalar', + 'matrix_dagger', + 'to_sympy', + 'to_numpy', + 'to_scipy_sparse', + 'matrix_tensor_product', + 'matrix_zeros' +] + +# Conditionally define the base classes for numpy and scipy.sparse arrays +# for use in isinstance tests. + +np = import_module('numpy') +if not np: + class numpy_ndarray: + pass +else: + numpy_ndarray = np.ndarray # type: ignore + +scipy = import_module('scipy', import_kwargs={'fromlist': ['sparse']}) +if not scipy: + class scipy_sparse_matrix: + pass + sparse = None +else: + sparse = scipy.sparse + scipy_sparse_matrix = sparse.spmatrix # type: ignore + + +def sympy_to_numpy(m, **options): + """Convert a SymPy Matrix/complex number to a numpy matrix or scalar.""" + if not np: + raise ImportError + dtype = options.get('dtype', 'complex') + if isinstance(m, MatrixBase): + return np.array(m.tolist(), dtype=dtype) + elif isinstance(m, Expr): + if m.is_Number or m.is_NumberSymbol or m == I: + return complex(m) + raise TypeError('Expected MatrixBase or complex scalar, got: %r' % m) + + +def sympy_to_scipy_sparse(m, **options): + """Convert a SymPy Matrix/complex number to a numpy matrix or scalar.""" + if not np or not sparse: + raise ImportError + dtype = options.get('dtype', 'complex') + if isinstance(m, MatrixBase): + return sparse.csr_matrix(np.array(m.tolist(), dtype=dtype)) + elif isinstance(m, Expr): + if m.is_Number or m.is_NumberSymbol or m == I: + return complex(m) + raise TypeError('Expected MatrixBase or complex scalar, got: %r' % m) + + +def scipy_sparse_to_sympy(m, **options): + """Convert a scipy.sparse matrix to a SymPy matrix.""" + return MatrixBase(m.todense()) + + +def numpy_to_sympy(m, **options): + """Convert a numpy matrix to a SymPy matrix.""" + return MatrixBase(m) + + +def to_sympy(m, **options): + """Convert a numpy/scipy.sparse matrix to a SymPy matrix.""" + if isinstance(m, MatrixBase): + return m + elif isinstance(m, numpy_ndarray): + return numpy_to_sympy(m) + elif isinstance(m, scipy_sparse_matrix): + return scipy_sparse_to_sympy(m) + elif isinstance(m, Expr): + return m + raise TypeError('Expected sympy/numpy/scipy.sparse matrix, got: %r' % m) + + +def to_numpy(m, **options): + """Convert a sympy/scipy.sparse matrix to a numpy matrix.""" + dtype = options.get('dtype', 'complex') + if isinstance(m, (MatrixBase, Expr)): + return sympy_to_numpy(m, dtype=dtype) + elif isinstance(m, numpy_ndarray): + return m + elif isinstance(m, scipy_sparse_matrix): + return m.todense() + raise TypeError('Expected sympy/numpy/scipy.sparse matrix, got: %r' % m) + + +def to_scipy_sparse(m, **options): + """Convert a sympy/numpy matrix to a scipy.sparse matrix.""" + dtype = options.get('dtype', 'complex') + if isinstance(m, (MatrixBase, Expr)): + return sympy_to_scipy_sparse(m, dtype=dtype) + elif isinstance(m, numpy_ndarray): + if not sparse: + raise ImportError + return sparse.csr_matrix(m) + elif isinstance(m, scipy_sparse_matrix): + return m + raise TypeError('Expected sympy/numpy/scipy.sparse matrix, got: %r' % m) + + +def flatten_scalar(e): + """Flatten a 1x1 matrix to a scalar, return larger matrices unchanged.""" + if isinstance(e, MatrixBase): + if e.shape == (1, 1): + e = e[0] + if isinstance(e, (numpy_ndarray, scipy_sparse_matrix)): + if e.shape == (1, 1): + e = complex(e[0, 0]) + return e + + +def matrix_dagger(e): + """Return the dagger of a sympy/numpy/scipy.sparse matrix.""" + if isinstance(e, MatrixBase): + return e.H + elif isinstance(e, (numpy_ndarray, scipy_sparse_matrix)): + return e.conjugate().transpose() + raise TypeError('Expected sympy/numpy/scipy.sparse matrix, got: %r' % e) + + +# TODO: Move this into sympy.matricies. +def _sympy_tensor_product(*matrices): + """Compute the kronecker product of a sequence of SymPy Matrices. + """ + from sympy.matrices.expressions.kronecker import matrix_kronecker_product + + return matrix_kronecker_product(*matrices) + + +def _numpy_tensor_product(*product): + """numpy version of tensor product of multiple arguments.""" + if not np: + raise ImportError + answer = product[0] + for item in product[1:]: + answer = np.kron(answer, item) + return answer + + +def _scipy_sparse_tensor_product(*product): + """scipy.sparse version of tensor product of multiple arguments.""" + if not sparse: + raise ImportError + answer = product[0] + for item in product[1:]: + answer = sparse.kron(answer, item) + # The final matrices will just be multiplied, so csr is a good final + # sparse format. + return sparse.csr_matrix(answer) + + +def matrix_tensor_product(*product): + """Compute the matrix tensor product of sympy/numpy/scipy.sparse matrices.""" + if isinstance(product[0], MatrixBase): + return _sympy_tensor_product(*product) + elif isinstance(product[0], numpy_ndarray): + return _numpy_tensor_product(*product) + elif isinstance(product[0], scipy_sparse_matrix): + return _scipy_sparse_tensor_product(*product) + + +def _numpy_eye(n): + """numpy version of complex eye.""" + if not np: + raise ImportError + return np.array(np.eye(n, dtype='complex')) + + +def _scipy_sparse_eye(n): + """scipy.sparse version of complex eye.""" + if not sparse: + raise ImportError + return sparse.eye(n, n, dtype='complex') + + +def matrix_eye(n, **options): + """Get the version of eye and tensor_product for a given format.""" + format = options.get('format', 'sympy') + if format == 'sympy': + return eye(n) + elif format == 'numpy': + return _numpy_eye(n) + elif format == 'scipy.sparse': + return _scipy_sparse_eye(n) + raise NotImplementedError('Invalid format: %r' % format) + + +def _numpy_zeros(m, n, **options): + """numpy version of zeros.""" + dtype = options.get('dtype', 'float64') + if not np: + raise ImportError + return np.zeros((m, n), dtype=dtype) + + +def _scipy_sparse_zeros(m, n, **options): + """scipy.sparse version of zeros.""" + spmatrix = options.get('spmatrix', 'csr') + dtype = options.get('dtype', 'float64') + if not sparse: + raise ImportError + if spmatrix == 'lil': + return sparse.lil_matrix((m, n), dtype=dtype) + elif spmatrix == 'csr': + return sparse.csr_matrix((m, n), dtype=dtype) + + +def matrix_zeros(m, n, **options): + """"Get a zeros matrix for a given format.""" + format = options.get('format', 'sympy') + if format == 'sympy': + return zeros(m, n) + elif format == 'numpy': + return _numpy_zeros(m, n, **options) + elif format == 'scipy.sparse': + return _scipy_sparse_zeros(m, n, **options) + raise NotImplementedError('Invaild format: %r' % format) + + +def _numpy_matrix_to_zero(e): + """Convert a numpy zero matrix to the zero scalar.""" + if not np: + raise ImportError + test = np.zeros_like(e) + if np.allclose(e, test): + return 0.0 + else: + return e + + +def _scipy_sparse_matrix_to_zero(e): + """Convert a scipy.sparse zero matrix to the zero scalar.""" + if not np: + raise ImportError + edense = e.todense() + test = np.zeros_like(edense) + if np.allclose(edense, test): + return 0.0 + else: + return e + + +def matrix_to_zero(e): + """Convert a zero matrix to the scalar zero.""" + if isinstance(e, MatrixBase): + if zeros(*e.shape) == e: + e = S.Zero + elif isinstance(e, numpy_ndarray): + e = _numpy_matrix_to_zero(e) + elif isinstance(e, scipy_sparse_matrix): + e = _scipy_sparse_matrix_to_zero(e) + return e diff --git a/lib/python3.10/site-packages/sympy/physics/quantum/operator.py b/lib/python3.10/site-packages/sympy/physics/quantum/operator.py new file mode 100644 index 0000000000000000000000000000000000000000..8c540dc016fc1a1043f3c25acf71ae0e1996e1c6 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/physics/quantum/operator.py @@ -0,0 +1,657 @@ +"""Quantum mechanical operators. + +TODO: + +* Fix early 0 in apply_operators. +* Debug and test apply_operators. +* Get cse working with classes in this file. +* Doctests and documentation of special methods for InnerProduct, Commutator, + AntiCommutator, represent, apply_operators. +""" +from typing import Optional + +from sympy.core.add import Add +from sympy.core.expr import Expr +from sympy.core.function import (Derivative, expand) +from sympy.core.mul import Mul +from sympy.core.numbers import oo +from sympy.core.singleton import S +from sympy.printing.pretty.stringpict import prettyForm +from sympy.physics.quantum.dagger import Dagger +from sympy.physics.quantum.qexpr import QExpr, dispatch_method +from sympy.matrices import eye + +__all__ = [ + 'Operator', + 'HermitianOperator', + 'UnitaryOperator', + 'IdentityOperator', + 'OuterProduct', + 'DifferentialOperator' +] + +#----------------------------------------------------------------------------- +# Operators and outer products +#----------------------------------------------------------------------------- + + +class Operator(QExpr): + """Base class for non-commuting quantum operators. + + An operator maps between quantum states [1]_. In quantum mechanics, + observables (including, but not limited to, measured physical values) are + represented as Hermitian operators [2]_. + + Parameters + ========== + + args : tuple + The list of numbers or parameters that uniquely specify the + operator. For time-dependent operators, this will include the time. + + Examples + ======== + + Create an operator and examine its attributes:: + + >>> from sympy.physics.quantum import Operator + >>> from sympy import I + >>> A = Operator('A') + >>> A + A + >>> A.hilbert_space + H + >>> A.label + (A,) + >>> A.is_commutative + False + + Create another operator and do some arithmetic operations:: + + >>> B = Operator('B') + >>> C = 2*A*A + I*B + >>> C + 2*A**2 + I*B + + Operators do not commute:: + + >>> A.is_commutative + False + >>> B.is_commutative + False + >>> A*B == B*A + False + + Polymonials of operators respect the commutation properties:: + + >>> e = (A+B)**3 + >>> e.expand() + A*B*A + A*B**2 + A**2*B + A**3 + B*A*B + B*A**2 + B**2*A + B**3 + + Operator inverses are handle symbolically:: + + >>> A.inv() + A**(-1) + >>> A*A.inv() + 1 + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Operator_%28physics%29 + .. [2] https://en.wikipedia.org/wiki/Observable + """ + is_hermitian: Optional[bool] = None + is_unitary: Optional[bool] = None + @classmethod + def default_args(self): + return ("O",) + + #------------------------------------------------------------------------- + # Printing + #------------------------------------------------------------------------- + + _label_separator = ',' + + def _print_operator_name(self, printer, *args): + return self.__class__.__name__ + + _print_operator_name_latex = _print_operator_name + + def _print_operator_name_pretty(self, printer, *args): + return prettyForm(self.__class__.__name__) + + def _print_contents(self, printer, *args): + if len(self.label) == 1: + return self._print_label(printer, *args) + else: + return '%s(%s)' % ( + self._print_operator_name(printer, *args), + self._print_label(printer, *args) + ) + + def _print_contents_pretty(self, printer, *args): + if len(self.label) == 1: + return self._print_label_pretty(printer, *args) + else: + pform = self._print_operator_name_pretty(printer, *args) + label_pform = self._print_label_pretty(printer, *args) + label_pform = prettyForm( + *label_pform.parens(left='(', right=')') + ) + pform = prettyForm(*pform.right(label_pform)) + return pform + + def _print_contents_latex(self, printer, *args): + if len(self.label) == 1: + return self._print_label_latex(printer, *args) + else: + return r'%s\left(%s\right)' % ( + self._print_operator_name_latex(printer, *args), + self._print_label_latex(printer, *args) + ) + + #------------------------------------------------------------------------- + # _eval_* methods + #------------------------------------------------------------------------- + + def _eval_commutator(self, other, **options): + """Evaluate [self, other] if known, return None if not known.""" + return dispatch_method(self, '_eval_commutator', other, **options) + + def _eval_anticommutator(self, other, **options): + """Evaluate [self, other] if known.""" + return dispatch_method(self, '_eval_anticommutator', other, **options) + + #------------------------------------------------------------------------- + # Operator application + #------------------------------------------------------------------------- + + def _apply_operator(self, ket, **options): + return dispatch_method(self, '_apply_operator', ket, **options) + + def _apply_from_right_to(self, bra, **options): + return None + + def matrix_element(self, *args): + raise NotImplementedError('matrix_elements is not defined') + + def inverse(self): + return self._eval_inverse() + + inv = inverse + + def _eval_inverse(self): + return self**(-1) + + def __mul__(self, other): + + if isinstance(other, IdentityOperator): + return self + + return Mul(self, other) + + +class HermitianOperator(Operator): + """A Hermitian operator that satisfies H == Dagger(H). + + Parameters + ========== + + args : tuple + The list of numbers or parameters that uniquely specify the + operator. For time-dependent operators, this will include the time. + + Examples + ======== + + >>> from sympy.physics.quantum import Dagger, HermitianOperator + >>> H = HermitianOperator('H') + >>> Dagger(H) + H + """ + + is_hermitian = True + + def _eval_inverse(self): + if isinstance(self, UnitaryOperator): + return self + else: + return Operator._eval_inverse(self) + + def _eval_power(self, exp): + if isinstance(self, UnitaryOperator): + # so all eigenvalues of self are 1 or -1 + if exp.is_even: + from sympy.core.singleton import S + return S.One # is identity, see Issue 24153. + elif exp.is_odd: + return self + # No simplification in all other cases + return Operator._eval_power(self, exp) + + +class UnitaryOperator(Operator): + """A unitary operator that satisfies U*Dagger(U) == 1. + + Parameters + ========== + + args : tuple + The list of numbers or parameters that uniquely specify the + operator. For time-dependent operators, this will include the time. + + Examples + ======== + + >>> from sympy.physics.quantum import Dagger, UnitaryOperator + >>> U = UnitaryOperator('U') + >>> U*Dagger(U) + 1 + """ + is_unitary = True + def _eval_adjoint(self): + return self._eval_inverse() + + +class IdentityOperator(Operator): + """An identity operator I that satisfies op * I == I * op == op for any + operator op. + + Parameters + ========== + + N : Integer + Optional parameter that specifies the dimension of the Hilbert space + of operator. This is used when generating a matrix representation. + + Examples + ======== + + >>> from sympy.physics.quantum import IdentityOperator + >>> IdentityOperator() + I + """ + is_hermitian = True + is_unitary = True + @property + def dimension(self): + return self.N + + @classmethod + def default_args(self): + return (oo,) + + def __init__(self, *args, **hints): + if not len(args) in (0, 1): + raise ValueError('0 or 1 parameters expected, got %s' % args) + + self.N = args[0] if (len(args) == 1 and args[0]) else oo + + def _eval_commutator(self, other, **hints): + return S.Zero + + def _eval_anticommutator(self, other, **hints): + return 2 * other + + def _eval_inverse(self): + return self + + def _eval_adjoint(self): + return self + + def _apply_operator(self, ket, **options): + return ket + + def _apply_from_right_to(self, bra, **options): + return bra + + def _eval_power(self, exp): + return self + + def _print_contents(self, printer, *args): + return 'I' + + def _print_contents_pretty(self, printer, *args): + return prettyForm('I') + + def _print_contents_latex(self, printer, *args): + return r'{\mathcal{I}}' + + def __mul__(self, other): + + if isinstance(other, (Operator, Dagger)): + return other + + return Mul(self, other) + + def _represent_default_basis(self, **options): + if not self.N or self.N == oo: + raise NotImplementedError('Cannot represent infinite dimensional' + + ' identity operator as a matrix') + + format = options.get('format', 'sympy') + if format != 'sympy': + raise NotImplementedError('Representation in format ' + + '%s not implemented.' % format) + + return eye(self.N) + + +class OuterProduct(Operator): + """An unevaluated outer product between a ket and bra. + + This constructs an outer product between any subclass of ``KetBase`` and + ``BraBase`` as ``|a>>> from sympy.physics.quantum import Ket, Bra, OuterProduct, Dagger + >>> from sympy.physics.quantum import Operator + + >>> k = Ket('k') + >>> b = Bra('b') + >>> op = OuterProduct(k, b) + >>> op + |k>>> op.hilbert_space + H + >>> op.ket + |k> + >>> op.bra + >> Dagger(op) + |b>>> k*b + |k>>> A = Operator('A') + >>> A*k*b + A*|k>*>> A*(k*b) + A*|k>>> from sympy import Derivative, Function, Symbol + >>> from sympy.physics.quantum.operator import DifferentialOperator + >>> from sympy.physics.quantum.state import Wavefunction + >>> from sympy.physics.quantum.qapply import qapply + >>> f = Function('f') + >>> x = Symbol('x') + >>> d = DifferentialOperator(1/x*Derivative(f(x), x), f(x)) + >>> w = Wavefunction(x**2, x) + >>> d.function + f(x) + >>> d.variables + (x,) + >>> qapply(d*w) + Wavefunction(2, x) + + """ + + @property + def variables(self): + """ + Returns the variables with which the function in the specified + arbitrary expression is evaluated + + Examples + ======== + + >>> from sympy.physics.quantum.operator import DifferentialOperator + >>> from sympy import Symbol, Function, Derivative + >>> x = Symbol('x') + >>> f = Function('f') + >>> d = DifferentialOperator(1/x*Derivative(f(x), x), f(x)) + >>> d.variables + (x,) + >>> y = Symbol('y') + >>> d = DifferentialOperator(Derivative(f(x, y), x) + + ... Derivative(f(x, y), y), f(x, y)) + >>> d.variables + (x, y) + """ + + return self.args[-1].args + + @property + def function(self): + """ + Returns the function which is to be replaced with the Wavefunction + + Examples + ======== + + >>> from sympy.physics.quantum.operator import DifferentialOperator + >>> from sympy import Function, Symbol, Derivative + >>> x = Symbol('x') + >>> f = Function('f') + >>> d = DifferentialOperator(Derivative(f(x), x), f(x)) + >>> d.function + f(x) + >>> y = Symbol('y') + >>> d = DifferentialOperator(Derivative(f(x, y), x) + + ... Derivative(f(x, y), y), f(x, y)) + >>> d.function + f(x, y) + """ + + return self.args[-1] + + @property + def expr(self): + """ + Returns the arbitrary expression which is to have the Wavefunction + substituted into it + + Examples + ======== + + >>> from sympy.physics.quantum.operator import DifferentialOperator + >>> from sympy import Function, Symbol, Derivative + >>> x = Symbol('x') + >>> f = Function('f') + >>> d = DifferentialOperator(Derivative(f(x), x), f(x)) + >>> d.expr + Derivative(f(x), x) + >>> y = Symbol('y') + >>> d = DifferentialOperator(Derivative(f(x, y), x) + + ... Derivative(f(x, y), y), f(x, y)) + >>> d.expr + Derivative(f(x, y), x) + Derivative(f(x, y), y) + """ + + return self.args[0] + + @property + def free_symbols(self): + """ + Return the free symbols of the expression. + """ + + return self.expr.free_symbols + + def _apply_operator_Wavefunction(self, func, **options): + from sympy.physics.quantum.state import Wavefunction + var = self.variables + wf_vars = func.args[1:] + + f = self.function + new_expr = self.expr.subs(f, func(*var)) + new_expr = new_expr.doit() + + return Wavefunction(new_expr, *wf_vars) + + def _eval_derivative(self, symbol): + new_expr = Derivative(self.expr, symbol) + return DifferentialOperator(new_expr, self.args[-1]) + + #------------------------------------------------------------------------- + # Printing + #------------------------------------------------------------------------- + + def _print(self, printer, *args): + return '%s(%s)' % ( + self._print_operator_name(printer, *args), + self._print_label(printer, *args) + ) + + def _print_pretty(self, printer, *args): + pform = self._print_operator_name_pretty(printer, *args) + label_pform = self._print_label_pretty(printer, *args) + label_pform = prettyForm( + *label_pform.parens(left='(', right=')') + ) + pform = prettyForm(*pform.right(label_pform)) + return pform diff --git a/lib/python3.10/site-packages/sympy/physics/quantum/operatorordering.py b/lib/python3.10/site-packages/sympy/physics/quantum/operatorordering.py new file mode 100644 index 0000000000000000000000000000000000000000..d6ba3dd83b4b79b773793b0094e636cc8a901f44 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/physics/quantum/operatorordering.py @@ -0,0 +1,290 @@ +"""Functions for reordering operator expressions.""" + +import warnings + +from sympy.core.add import Add +from sympy.core.mul import Mul +from sympy.core.numbers import Integer +from sympy.core.power import Pow +from sympy.physics.quantum import Commutator, AntiCommutator +from sympy.physics.quantum.boson import BosonOp +from sympy.physics.quantum.fermion import FermionOp + +__all__ = [ + 'normal_order', + 'normal_ordered_form' +] + + +def _expand_powers(factors): + """ + Helper function for normal_ordered_form and normal_order: Expand a + power expression to a multiplication expression so that that the + expression can be handled by the normal ordering functions. + """ + + new_factors = [] + for factor in factors.args: + if (isinstance(factor, Pow) + and isinstance(factor.args[1], Integer) + and factor.args[1] > 0): + for n in range(factor.args[1]): + new_factors.append(factor.args[0]) + else: + new_factors.append(factor) + + return new_factors + +def _normal_ordered_form_factor(product, independent=False, recursive_limit=10, + _recursive_depth=0): + """ + Helper function for normal_ordered_form_factor: Write multiplication + expression with bosonic or fermionic operators on normally ordered form, + using the bosonic and fermionic commutation relations. The resulting + operator expression is equivalent to the argument, but will in general be + a sum of operator products instead of a simple product. + """ + + factors = _expand_powers(product) + + new_factors = [] + n = 0 + while n < len(factors) - 1: + current, next = factors[n], factors[n + 1] + if any(not isinstance(f, (FermionOp, BosonOp)) for f in (current, next)): + new_factors.append(current) + n += 1 + continue + + key_1 = (current.is_annihilation, str(current.name)) + key_2 = (next.is_annihilation, str(next.name)) + + if key_1 <= key_2: + new_factors.append(current) + n += 1 + continue + + n += 2 + if current.is_annihilation and not next.is_annihilation: + if isinstance(current, BosonOp) and isinstance(next, BosonOp): + if current.args[0] != next.args[0]: + if independent: + c = 0 + else: + c = Commutator(current, next) + new_factors.append(next * current + c) + else: + new_factors.append(next * current + 1) + elif isinstance(current, FermionOp) and isinstance(next, FermionOp): + if current.args[0] != next.args[0]: + if independent: + c = 0 + else: + c = AntiCommutator(current, next) + new_factors.append(-next * current + c) + else: + new_factors.append(-next * current + 1) + elif (current.is_annihilation == next.is_annihilation and + isinstance(current, FermionOp) and isinstance(next, FermionOp)): + new_factors.append(-next * current) + else: + new_factors.append(next * current) + + if n == len(factors) - 1: + new_factors.append(factors[-1]) + + if new_factors == factors: + return product + else: + expr = Mul(*new_factors).expand() + return normal_ordered_form(expr, + recursive_limit=recursive_limit, + _recursive_depth=_recursive_depth + 1, + independent=independent) + + +def _normal_ordered_form_terms(expr, independent=False, recursive_limit=10, + _recursive_depth=0): + """ + Helper function for normal_ordered_form: loop through each term in an + addition expression and call _normal_ordered_form_factor to perform the + factor to an normally ordered expression. + """ + + new_terms = [] + for term in expr.args: + if isinstance(term, Mul): + new_term = _normal_ordered_form_factor( + term, recursive_limit=recursive_limit, + _recursive_depth=_recursive_depth, independent=independent) + new_terms.append(new_term) + else: + new_terms.append(term) + + return Add(*new_terms) + + +def normal_ordered_form(expr, independent=False, recursive_limit=10, + _recursive_depth=0): + """Write an expression with bosonic or fermionic operators on normal + ordered form, where each term is normally ordered. Note that this + normal ordered form is equivalent to the original expression. + + Parameters + ========== + + expr : expression + The expression write on normal ordered form. + independent : bool (default False) + Whether to consider operator with different names as operating in + different Hilbert spaces. If False, the (anti-)commutation is left + explicit. + recursive_limit : int (default 10) + The number of allowed recursive applications of the function. + + Examples + ======== + + >>> from sympy.physics.quantum import Dagger + >>> from sympy.physics.quantum.boson import BosonOp + >>> from sympy.physics.quantum.operatorordering import normal_ordered_form + >>> a = BosonOp("a") + >>> normal_ordered_form(a * Dagger(a)) + 1 + Dagger(a)*a + """ + + if _recursive_depth > recursive_limit: + warnings.warn("Too many recursions, aborting") + return expr + + if isinstance(expr, Add): + return _normal_ordered_form_terms(expr, + recursive_limit=recursive_limit, + _recursive_depth=_recursive_depth, + independent=independent) + elif isinstance(expr, Mul): + return _normal_ordered_form_factor(expr, + recursive_limit=recursive_limit, + _recursive_depth=_recursive_depth, + independent=independent) + else: + return expr + + +def _normal_order_factor(product, recursive_limit=10, _recursive_depth=0): + """ + Helper function for normal_order: Normal order a multiplication expression + with bosonic or fermionic operators. In general the resulting operator + expression will not be equivalent to original product. + """ + + factors = _expand_powers(product) + + n = 0 + new_factors = [] + while n < len(factors) - 1: + + if (isinstance(factors[n], BosonOp) and + factors[n].is_annihilation): + # boson + if not isinstance(factors[n + 1], BosonOp): + new_factors.append(factors[n]) + else: + if factors[n + 1].is_annihilation: + new_factors.append(factors[n]) + else: + if factors[n].args[0] != factors[n + 1].args[0]: + new_factors.append(factors[n + 1] * factors[n]) + else: + new_factors.append(factors[n + 1] * factors[n]) + n += 1 + + elif (isinstance(factors[n], FermionOp) and + factors[n].is_annihilation): + # fermion + if not isinstance(factors[n + 1], FermionOp): + new_factors.append(factors[n]) + else: + if factors[n + 1].is_annihilation: + new_factors.append(factors[n]) + else: + if factors[n].args[0] != factors[n + 1].args[0]: + new_factors.append(-factors[n + 1] * factors[n]) + else: + new_factors.append(-factors[n + 1] * factors[n]) + n += 1 + + else: + new_factors.append(factors[n]) + + n += 1 + + if n == len(factors) - 1: + new_factors.append(factors[-1]) + + if new_factors == factors: + return product + else: + expr = Mul(*new_factors).expand() + return normal_order(expr, + recursive_limit=recursive_limit, + _recursive_depth=_recursive_depth + 1) + + +def _normal_order_terms(expr, recursive_limit=10, _recursive_depth=0): + """ + Helper function for normal_order: look through each term in an addition + expression and call _normal_order_factor to perform the normal ordering + on the factors. + """ + + new_terms = [] + for term in expr.args: + if isinstance(term, Mul): + new_term = _normal_order_factor(term, + recursive_limit=recursive_limit, + _recursive_depth=_recursive_depth) + new_terms.append(new_term) + else: + new_terms.append(term) + + return Add(*new_terms) + + +def normal_order(expr, recursive_limit=10, _recursive_depth=0): + """Normal order an expression with bosonic or fermionic operators. Note + that this normal order is not equivalent to the original expression, but + the creation and annihilation operators in each term in expr is reordered + so that the expression becomes normal ordered. + + Parameters + ========== + + expr : expression + The expression to normal order. + + recursive_limit : int (default 10) + The number of allowed recursive applications of the function. + + Examples + ======== + + >>> from sympy.physics.quantum import Dagger + >>> from sympy.physics.quantum.boson import BosonOp + >>> from sympy.physics.quantum.operatorordering import normal_order + >>> a = BosonOp("a") + >>> normal_order(a * Dagger(a)) + Dagger(a)*a + """ + if _recursive_depth > recursive_limit: + warnings.warn("Too many recursions, aborting") + return expr + + if isinstance(expr, Add): + return _normal_order_terms(expr, recursive_limit=recursive_limit, + _recursive_depth=_recursive_depth) + elif isinstance(expr, Mul): + return _normal_order_factor(expr, recursive_limit=recursive_limit, + _recursive_depth=_recursive_depth) + else: + return expr diff --git a/lib/python3.10/site-packages/sympy/physics/quantum/operatorset.py b/lib/python3.10/site-packages/sympy/physics/quantum/operatorset.py new file mode 100644 index 0000000000000000000000000000000000000000..bf32bcabbe5d33381dff0b94a9b130375032adef --- /dev/null +++ b/lib/python3.10/site-packages/sympy/physics/quantum/operatorset.py @@ -0,0 +1,279 @@ +""" A module for mapping operators to their corresponding eigenstates +and vice versa + +It contains a global dictionary with eigenstate-operator pairings. +If a new state-operator pair is created, this dictionary should be +updated as well. + +It also contains functions operators_to_state and state_to_operators +for mapping between the two. These can handle both classes and +instances of operators and states. See the individual function +descriptions for details. + +TODO List: +- Update the dictionary with a complete list of state-operator pairs +""" + +from sympy.physics.quantum.cartesian import (XOp, YOp, ZOp, XKet, PxOp, PxKet, + PositionKet3D) +from sympy.physics.quantum.operator import Operator +from sympy.physics.quantum.state import StateBase, BraBase, Ket +from sympy.physics.quantum.spin import (JxOp, JyOp, JzOp, J2Op, JxKet, JyKet, + JzKet) + +__all__ = [ + 'operators_to_state', + 'state_to_operators' +] + +#state_mapping stores the mappings between states and their associated +#operators or tuples of operators. This should be updated when new +#classes are written! Entries are of the form PxKet : PxOp or +#something like 3DKet : (ROp, ThetaOp, PhiOp) + +#frozenset is used so that the reverse mapping can be made +#(regular sets are not hashable because they are mutable +state_mapping = { JxKet: frozenset((J2Op, JxOp)), + JyKet: frozenset((J2Op, JyOp)), + JzKet: frozenset((J2Op, JzOp)), + Ket: Operator, + PositionKet3D: frozenset((XOp, YOp, ZOp)), + PxKet: PxOp, + XKet: XOp } + +op_mapping = {v: k for k, v in state_mapping.items()} + + +def operators_to_state(operators, **options): + """ Returns the eigenstate of the given operator or set of operators + + A global function for mapping operator classes to their associated + states. It takes either an Operator or a set of operators and + returns the state associated with these. + + This function can handle both instances of a given operator or + just the class itself (i.e. both XOp() and XOp) + + There are multiple use cases to consider: + + 1) A class or set of classes is passed: First, we try to + instantiate default instances for these operators. If this fails, + then the class is simply returned. If we succeed in instantiating + default instances, then we try to call state._operators_to_state + on the operator instances. If this fails, the class is returned. + Otherwise, the instance returned by _operators_to_state is returned. + + 2) An instance or set of instances is passed: In this case, + state._operators_to_state is called on the instances passed. If + this fails, a state class is returned. If the method returns an + instance, that instance is returned. + + In both cases, if the operator class or set does not exist in the + state_mapping dictionary, None is returned. + + Parameters + ========== + + arg: Operator or set + The class or instance of the operator or set of operators + to be mapped to a state + + Examples + ======== + + >>> from sympy.physics.quantum.cartesian import XOp, PxOp + >>> from sympy.physics.quantum.operatorset import operators_to_state + >>> from sympy.physics.quantum.operator import Operator + >>> operators_to_state(XOp) + |x> + >>> operators_to_state(XOp()) + |x> + >>> operators_to_state(PxOp) + |px> + >>> operators_to_state(PxOp()) + |px> + >>> operators_to_state(Operator) + |psi> + >>> operators_to_state(Operator()) + |psi> + """ + + if not (isinstance(operators, (Operator, set)) or issubclass(operators, Operator)): + raise NotImplementedError("Argument is not an Operator or a set!") + + if isinstance(operators, set): + for s in operators: + if not (isinstance(s, Operator) + or issubclass(s, Operator)): + raise NotImplementedError("Set is not all Operators!") + + ops = frozenset(operators) + + if ops in op_mapping: # ops is a list of classes in this case + #Try to get an object from default instances of the + #operators...if this fails, return the class + try: + op_instances = [op() for op in ops] + ret = _get_state(op_mapping[ops], set(op_instances), **options) + except NotImplementedError: + ret = op_mapping[ops] + + return ret + else: + tmp = [type(o) for o in ops] + classes = frozenset(tmp) + + if classes in op_mapping: + ret = _get_state(op_mapping[classes], ops, **options) + else: + ret = None + + return ret + else: + if operators in op_mapping: + try: + op_instance = operators() + ret = _get_state(op_mapping[operators], op_instance, **options) + except NotImplementedError: + ret = op_mapping[operators] + + return ret + elif type(operators) in op_mapping: + return _get_state(op_mapping[type(operators)], operators, **options) + else: + return None + + +def state_to_operators(state, **options): + """ Returns the operator or set of operators corresponding to the + given eigenstate + + A global function for mapping state classes to their associated + operators or sets of operators. It takes either a state class + or instance. + + This function can handle both instances of a given state or just + the class itself (i.e. both XKet() and XKet) + + There are multiple use cases to consider: + + 1) A state class is passed: In this case, we first try + instantiating a default instance of the class. If this succeeds, + then we try to call state._state_to_operators on that instance. + If the creation of the default instance or if the calling of + _state_to_operators fails, then either an operator class or set of + operator classes is returned. Otherwise, the appropriate + operator instances are returned. + + 2) A state instance is returned: Here, state._state_to_operators + is called for the instance. If this fails, then a class or set of + operator classes is returned. Otherwise, the instances are returned. + + In either case, if the state's class does not exist in + state_mapping, None is returned. + + Parameters + ========== + + arg: StateBase class or instance (or subclasses) + The class or instance of the state to be mapped to an + operator or set of operators + + Examples + ======== + + >>> from sympy.physics.quantum.cartesian import XKet, PxKet, XBra, PxBra + >>> from sympy.physics.quantum.operatorset import state_to_operators + >>> from sympy.physics.quantum.state import Ket, Bra + >>> state_to_operators(XKet) + X + >>> state_to_operators(XKet()) + X + >>> state_to_operators(PxKet) + Px + >>> state_to_operators(PxKet()) + Px + >>> state_to_operators(PxBra) + Px + >>> state_to_operators(XBra) + X + >>> state_to_operators(Ket) + O + >>> state_to_operators(Bra) + O + """ + + if not (isinstance(state, StateBase) or issubclass(state, StateBase)): + raise NotImplementedError("Argument is not a state!") + + if state in state_mapping: # state is a class + state_inst = _make_default(state) + try: + ret = _get_ops(state_inst, + _make_set(state_mapping[state]), **options) + except (NotImplementedError, TypeError): + ret = state_mapping[state] + elif type(state) in state_mapping: + ret = _get_ops(state, + _make_set(state_mapping[type(state)]), **options) + elif isinstance(state, BraBase) and state.dual_class() in state_mapping: + ret = _get_ops(state, + _make_set(state_mapping[state.dual_class()])) + elif issubclass(state, BraBase) and state.dual_class() in state_mapping: + state_inst = _make_default(state) + try: + ret = _get_ops(state_inst, + _make_set(state_mapping[state.dual_class()])) + except (NotImplementedError, TypeError): + ret = state_mapping[state.dual_class()] + else: + ret = None + + return _make_set(ret) + + +def _make_default(expr): + # XXX: Catching TypeError like this is a bad way of distinguishing between + # classes and instances. The logic using this function should be rewritten + # somehow. + try: + ret = expr() + except TypeError: + ret = expr + + return ret + + +def _get_state(state_class, ops, **options): + # Try to get a state instance from the operator INSTANCES. + # If this fails, get the class + try: + ret = state_class._operators_to_state(ops, **options) + except NotImplementedError: + ret = _make_default(state_class) + + return ret + + +def _get_ops(state_inst, op_classes, **options): + # Try to get operator instances from the state INSTANCE. + # If this fails, just return the classes + try: + ret = state_inst._state_to_operators(op_classes, **options) + except NotImplementedError: + if isinstance(op_classes, (set, tuple, frozenset)): + ret = tuple(_make_default(x) for x in op_classes) + else: + ret = _make_default(op_classes) + + if isinstance(ret, set) and len(ret) == 1: + return ret[0] + + return ret + + +def _make_set(ops): + if isinstance(ops, (tuple, list, frozenset)): + return set(ops) + else: + return ops diff --git a/lib/python3.10/site-packages/sympy/physics/quantum/pauli.py b/lib/python3.10/site-packages/sympy/physics/quantum/pauli.py new file mode 100644 index 0000000000000000000000000000000000000000..89762ed2b38e1c5df3775714ee08d3700df0fa65 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/physics/quantum/pauli.py @@ -0,0 +1,675 @@ +"""Pauli operators and states""" + +from sympy.core.add import Add +from sympy.core.mul import Mul +from sympy.core.numbers import I +from sympy.core.power import Pow +from sympy.core.singleton import S +from sympy.functions.elementary.exponential import exp +from sympy.physics.quantum import Operator, Ket, Bra +from sympy.physics.quantum import ComplexSpace +from sympy.matrices import Matrix +from sympy.functions.special.tensor_functions import KroneckerDelta + +__all__ = [ + 'SigmaX', 'SigmaY', 'SigmaZ', 'SigmaMinus', 'SigmaPlus', 'SigmaZKet', + 'SigmaZBra', 'qsimplify_pauli' +] + + +class SigmaOpBase(Operator): + """Pauli sigma operator, base class""" + + @property + def name(self): + return self.args[0] + + @property + def use_name(self): + return bool(self.args[0]) is not False + + @classmethod + def default_args(self): + return (False,) + + def __new__(cls, *args, **hints): + return Operator.__new__(cls, *args, **hints) + + def _eval_commutator_BosonOp(self, other, **hints): + return S.Zero + + +class SigmaX(SigmaOpBase): + """Pauli sigma x operator + + Parameters + ========== + + name : str + An optional string that labels the operator. Pauli operators with + different names commute. + + Examples + ======== + + >>> from sympy.physics.quantum import represent + >>> from sympy.physics.quantum.pauli import SigmaX + >>> sx = SigmaX() + >>> sx + SigmaX() + >>> represent(sx) + Matrix([ + [0, 1], + [1, 0]]) + """ + + def __new__(cls, *args, **hints): + return SigmaOpBase.__new__(cls, *args, **hints) + + def _eval_commutator_SigmaY(self, other, **hints): + if self.name != other.name: + return S.Zero + else: + return 2 * I * SigmaZ(self.name) + + def _eval_commutator_SigmaZ(self, other, **hints): + if self.name != other.name: + return S.Zero + else: + return - 2 * I * SigmaY(self.name) + + def _eval_commutator_BosonOp(self, other, **hints): + return S.Zero + + def _eval_anticommutator_SigmaY(self, other, **hints): + return S.Zero + + def _eval_anticommutator_SigmaZ(self, other, **hints): + return S.Zero + + def _eval_adjoint(self): + return self + + def _print_contents_latex(self, printer, *args): + if self.use_name: + return r'{\sigma_x^{(%s)}}' % str(self.name) + else: + return r'{\sigma_x}' + + def _print_contents(self, printer, *args): + return 'SigmaX()' + + def _eval_power(self, e): + if e.is_Integer and e.is_positive: + return SigmaX(self.name).__pow__(int(e) % 2) + + def _represent_default_basis(self, **options): + format = options.get('format', 'sympy') + if format == 'sympy': + return Matrix([[0, 1], [1, 0]]) + else: + raise NotImplementedError('Representation in format ' + + format + ' not implemented.') + + +class SigmaY(SigmaOpBase): + """Pauli sigma y operator + + Parameters + ========== + + name : str + An optional string that labels the operator. Pauli operators with + different names commute. + + Examples + ======== + + >>> from sympy.physics.quantum import represent + >>> from sympy.physics.quantum.pauli import SigmaY + >>> sy = SigmaY() + >>> sy + SigmaY() + >>> represent(sy) + Matrix([ + [0, -I], + [I, 0]]) + """ + + def __new__(cls, *args, **hints): + return SigmaOpBase.__new__(cls, *args) + + def _eval_commutator_SigmaZ(self, other, **hints): + if self.name != other.name: + return S.Zero + else: + return 2 * I * SigmaX(self.name) + + def _eval_commutator_SigmaX(self, other, **hints): + if self.name != other.name: + return S.Zero + else: + return - 2 * I * SigmaZ(self.name) + + def _eval_anticommutator_SigmaX(self, other, **hints): + return S.Zero + + def _eval_anticommutator_SigmaZ(self, other, **hints): + return S.Zero + + def _eval_adjoint(self): + return self + + def _print_contents_latex(self, printer, *args): + if self.use_name: + return r'{\sigma_y^{(%s)}}' % str(self.name) + else: + return r'{\sigma_y}' + + def _print_contents(self, printer, *args): + return 'SigmaY()' + + def _eval_power(self, e): + if e.is_Integer and e.is_positive: + return SigmaY(self.name).__pow__(int(e) % 2) + + def _represent_default_basis(self, **options): + format = options.get('format', 'sympy') + if format == 'sympy': + return Matrix([[0, -I], [I, 0]]) + else: + raise NotImplementedError('Representation in format ' + + format + ' not implemented.') + + +class SigmaZ(SigmaOpBase): + """Pauli sigma z operator + + Parameters + ========== + + name : str + An optional string that labels the operator. Pauli operators with + different names commute. + + Examples + ======== + + >>> from sympy.physics.quantum import represent + >>> from sympy.physics.quantum.pauli import SigmaZ + >>> sz = SigmaZ() + >>> sz ** 3 + SigmaZ() + >>> represent(sz) + Matrix([ + [1, 0], + [0, -1]]) + """ + + def __new__(cls, *args, **hints): + return SigmaOpBase.__new__(cls, *args) + + def _eval_commutator_SigmaX(self, other, **hints): + if self.name != other.name: + return S.Zero + else: + return 2 * I * SigmaY(self.name) + + def _eval_commutator_SigmaY(self, other, **hints): + if self.name != other.name: + return S.Zero + else: + return - 2 * I * SigmaX(self.name) + + def _eval_anticommutator_SigmaX(self, other, **hints): + return S.Zero + + def _eval_anticommutator_SigmaY(self, other, **hints): + return S.Zero + + def _eval_adjoint(self): + return self + + def _print_contents_latex(self, printer, *args): + if self.use_name: + return r'{\sigma_z^{(%s)}}' % str(self.name) + else: + return r'{\sigma_z}' + + def _print_contents(self, printer, *args): + return 'SigmaZ()' + + def _eval_power(self, e): + if e.is_Integer and e.is_positive: + return SigmaZ(self.name).__pow__(int(e) % 2) + + def _represent_default_basis(self, **options): + format = options.get('format', 'sympy') + if format == 'sympy': + return Matrix([[1, 0], [0, -1]]) + else: + raise NotImplementedError('Representation in format ' + + format + ' not implemented.') + + +class SigmaMinus(SigmaOpBase): + """Pauli sigma minus operator + + Parameters + ========== + + name : str + An optional string that labels the operator. Pauli operators with + different names commute. + + Examples + ======== + + >>> from sympy.physics.quantum import represent, Dagger + >>> from sympy.physics.quantum.pauli import SigmaMinus + >>> sm = SigmaMinus() + >>> sm + SigmaMinus() + >>> Dagger(sm) + SigmaPlus() + >>> represent(sm) + Matrix([ + [0, 0], + [1, 0]]) + """ + + def __new__(cls, *args, **hints): + return SigmaOpBase.__new__(cls, *args) + + def _eval_commutator_SigmaX(self, other, **hints): + if self.name != other.name: + return S.Zero + else: + return -SigmaZ(self.name) + + def _eval_commutator_SigmaY(self, other, **hints): + if self.name != other.name: + return S.Zero + else: + return I * SigmaZ(self.name) + + def _eval_commutator_SigmaZ(self, other, **hints): + return 2 * self + + def _eval_commutator_SigmaMinus(self, other, **hints): + return SigmaZ(self.name) + + def _eval_anticommutator_SigmaZ(self, other, **hints): + return S.Zero + + def _eval_anticommutator_SigmaX(self, other, **hints): + return S.One + + def _eval_anticommutator_SigmaY(self, other, **hints): + return I * S.NegativeOne + + def _eval_anticommutator_SigmaPlus(self, other, **hints): + return S.One + + def _eval_adjoint(self): + return SigmaPlus(self.name) + + def _eval_power(self, e): + if e.is_Integer and e.is_positive: + return S.Zero + + def _print_contents_latex(self, printer, *args): + if self.use_name: + return r'{\sigma_-^{(%s)}}' % str(self.name) + else: + return r'{\sigma_-}' + + def _print_contents(self, printer, *args): + return 'SigmaMinus()' + + def _represent_default_basis(self, **options): + format = options.get('format', 'sympy') + if format == 'sympy': + return Matrix([[0, 0], [1, 0]]) + else: + raise NotImplementedError('Representation in format ' + + format + ' not implemented.') + + +class SigmaPlus(SigmaOpBase): + """Pauli sigma plus operator + + Parameters + ========== + + name : str + An optional string that labels the operator. Pauli operators with + different names commute. + + Examples + ======== + + >>> from sympy.physics.quantum import represent, Dagger + >>> from sympy.physics.quantum.pauli import SigmaPlus + >>> sp = SigmaPlus() + >>> sp + SigmaPlus() + >>> Dagger(sp) + SigmaMinus() + >>> represent(sp) + Matrix([ + [0, 1], + [0, 0]]) + """ + + def __new__(cls, *args, **hints): + return SigmaOpBase.__new__(cls, *args) + + def _eval_commutator_SigmaX(self, other, **hints): + if self.name != other.name: + return S.Zero + else: + return SigmaZ(self.name) + + def _eval_commutator_SigmaY(self, other, **hints): + if self.name != other.name: + return S.Zero + else: + return I * SigmaZ(self.name) + + def _eval_commutator_SigmaZ(self, other, **hints): + if self.name != other.name: + return S.Zero + else: + return -2 * self + + def _eval_commutator_SigmaMinus(self, other, **hints): + return SigmaZ(self.name) + + def _eval_anticommutator_SigmaZ(self, other, **hints): + return S.Zero + + def _eval_anticommutator_SigmaX(self, other, **hints): + return S.One + + def _eval_anticommutator_SigmaY(self, other, **hints): + return I + + def _eval_anticommutator_SigmaMinus(self, other, **hints): + return S.One + + def _eval_adjoint(self): + return SigmaMinus(self.name) + + def _eval_mul(self, other): + return self * other + + def _eval_power(self, e): + if e.is_Integer and e.is_positive: + return S.Zero + + def _print_contents_latex(self, printer, *args): + if self.use_name: + return r'{\sigma_+^{(%s)}}' % str(self.name) + else: + return r'{\sigma_+}' + + def _print_contents(self, printer, *args): + return 'SigmaPlus()' + + def _represent_default_basis(self, **options): + format = options.get('format', 'sympy') + if format == 'sympy': + return Matrix([[0, 1], [0, 0]]) + else: + raise NotImplementedError('Representation in format ' + + format + ' not implemented.') + + +class SigmaZKet(Ket): + """Ket for a two-level system quantum system. + + Parameters + ========== + + n : Number + The state number (0 or 1). + + """ + + def __new__(cls, n): + if n not in (0, 1): + raise ValueError("n must be 0 or 1") + return Ket.__new__(cls, n) + + @property + def n(self): + return self.label[0] + + @classmethod + def dual_class(self): + return SigmaZBra + + @classmethod + def _eval_hilbert_space(cls, label): + return ComplexSpace(2) + + def _eval_innerproduct_SigmaZBra(self, bra, **hints): + return KroneckerDelta(self.n, bra.n) + + def _apply_from_right_to_SigmaZ(self, op, **options): + if self.n == 0: + return self + else: + return S.NegativeOne * self + + def _apply_from_right_to_SigmaX(self, op, **options): + return SigmaZKet(1) if self.n == 0 else SigmaZKet(0) + + def _apply_from_right_to_SigmaY(self, op, **options): + return I * SigmaZKet(1) if self.n == 0 else (-I) * SigmaZKet(0) + + def _apply_from_right_to_SigmaMinus(self, op, **options): + if self.n == 0: + return SigmaZKet(1) + else: + return S.Zero + + def _apply_from_right_to_SigmaPlus(self, op, **options): + if self.n == 0: + return S.Zero + else: + return SigmaZKet(0) + + def _represent_default_basis(self, **options): + format = options.get('format', 'sympy') + if format == 'sympy': + return Matrix([[1], [0]]) if self.n == 0 else Matrix([[0], [1]]) + else: + raise NotImplementedError('Representation in format ' + + format + ' not implemented.') + + +class SigmaZBra(Bra): + """Bra for a two-level quantum system. + + Parameters + ========== + + n : Number + The state number (0 or 1). + + """ + + def __new__(cls, n): + if n not in (0, 1): + raise ValueError("n must be 0 or 1") + return Bra.__new__(cls, n) + + @property + def n(self): + return self.label[0] + + @classmethod + def dual_class(self): + return SigmaZKet + + +def _qsimplify_pauli_product(a, b): + """ + Internal helper function for simplifying products of Pauli operators. + """ + if not (isinstance(a, SigmaOpBase) and isinstance(b, SigmaOpBase)): + return Mul(a, b) + + if a.name != b.name: + # Pauli matrices with different labels commute; sort by name + if a.name < b.name: + return Mul(a, b) + else: + return Mul(b, a) + + elif isinstance(a, SigmaX): + + if isinstance(b, SigmaX): + return S.One + + if isinstance(b, SigmaY): + return I * SigmaZ(a.name) + + if isinstance(b, SigmaZ): + return - I * SigmaY(a.name) + + if isinstance(b, SigmaMinus): + return (S.Half + SigmaZ(a.name)/2) + + if isinstance(b, SigmaPlus): + return (S.Half - SigmaZ(a.name)/2) + + elif isinstance(a, SigmaY): + + if isinstance(b, SigmaX): + return - I * SigmaZ(a.name) + + if isinstance(b, SigmaY): + return S.One + + if isinstance(b, SigmaZ): + return I * SigmaX(a.name) + + if isinstance(b, SigmaMinus): + return -I * (S.One + SigmaZ(a.name))/2 + + if isinstance(b, SigmaPlus): + return I * (S.One - SigmaZ(a.name))/2 + + elif isinstance(a, SigmaZ): + + if isinstance(b, SigmaX): + return I * SigmaY(a.name) + + if isinstance(b, SigmaY): + return - I * SigmaX(a.name) + + if isinstance(b, SigmaZ): + return S.One + + if isinstance(b, SigmaMinus): + return - SigmaMinus(a.name) + + if isinstance(b, SigmaPlus): + return SigmaPlus(a.name) + + elif isinstance(a, SigmaMinus): + + if isinstance(b, SigmaX): + return (S.One - SigmaZ(a.name))/2 + + if isinstance(b, SigmaY): + return - I * (S.One - SigmaZ(a.name))/2 + + if isinstance(b, SigmaZ): + # (SigmaX(a.name) - I * SigmaY(a.name))/2 + return SigmaMinus(b.name) + + if isinstance(b, SigmaMinus): + return S.Zero + + if isinstance(b, SigmaPlus): + return S.Half - SigmaZ(a.name)/2 + + elif isinstance(a, SigmaPlus): + + if isinstance(b, SigmaX): + return (S.One + SigmaZ(a.name))/2 + + if isinstance(b, SigmaY): + return I * (S.One + SigmaZ(a.name))/2 + + if isinstance(b, SigmaZ): + #-(SigmaX(a.name) + I * SigmaY(a.name))/2 + return -SigmaPlus(a.name) + + if isinstance(b, SigmaMinus): + return (S.One + SigmaZ(a.name))/2 + + if isinstance(b, SigmaPlus): + return S.Zero + + else: + return a * b + + +def qsimplify_pauli(e): + """ + Simplify an expression that includes products of pauli operators. + + Parameters + ========== + + e : expression + An expression that contains products of Pauli operators that is + to be simplified. + + Examples + ======== + + >>> from sympy.physics.quantum.pauli import SigmaX, SigmaY + >>> from sympy.physics.quantum.pauli import qsimplify_pauli + >>> sx, sy = SigmaX(), SigmaY() + >>> sx * sy + SigmaX()*SigmaY() + >>> qsimplify_pauli(sx * sy) + I*SigmaZ() + """ + if isinstance(e, Operator): + return e + + if isinstance(e, (Add, Pow, exp)): + t = type(e) + return t(*(qsimplify_pauli(arg) for arg in e.args)) + + if isinstance(e, Mul): + + c, nc = e.args_cnc() + + nc_s = [] + while nc: + curr = nc.pop(0) + + while (len(nc) and + isinstance(curr, SigmaOpBase) and + isinstance(nc[0], SigmaOpBase) and + curr.name == nc[0].name): + + x = nc.pop(0) + y = _qsimplify_pauli_product(curr, x) + c1, nc1 = y.args_cnc() + curr = Mul(*nc1) + c = c + c1 + + nc_s.append(curr) + + return Mul(*c) * Mul(*nc_s) + + return e diff --git a/lib/python3.10/site-packages/sympy/physics/quantum/piab.py b/lib/python3.10/site-packages/sympy/physics/quantum/piab.py new file mode 100644 index 0000000000000000000000000000000000000000..f8ac8135ee03e640f745070602c7dd8ca20f2767 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/physics/quantum/piab.py @@ -0,0 +1,72 @@ +"""1D quantum particle in a box.""" + +from sympy.core.numbers import pi +from sympy.core.singleton import S +from sympy.core.symbol import Symbol +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import sin +from sympy.sets.sets import Interval + +from sympy.physics.quantum.operator import HermitianOperator +from sympy.physics.quantum.state import Ket, Bra +from sympy.physics.quantum.constants import hbar +from sympy.functions.special.tensor_functions import KroneckerDelta +from sympy.physics.quantum.hilbert import L2 + +m = Symbol('m') +L = Symbol('L') + + +__all__ = [ + 'PIABHamiltonian', + 'PIABKet', + 'PIABBra' +] + + +class PIABHamiltonian(HermitianOperator): + """Particle in a box Hamiltonian operator.""" + + @classmethod + def _eval_hilbert_space(cls, label): + return L2(Interval(S.NegativeInfinity, S.Infinity)) + + def _apply_operator_PIABKet(self, ket, **options): + n = ket.label[0] + return (n**2*pi**2*hbar**2)/(2*m*L**2)*ket + + +class PIABKet(Ket): + """Particle in a box eigenket.""" + + @classmethod + def _eval_hilbert_space(cls, args): + return L2(Interval(S.NegativeInfinity, S.Infinity)) + + @classmethod + def dual_class(self): + return PIABBra + + def _represent_default_basis(self, **options): + return self._represent_XOp(None, **options) + + def _represent_XOp(self, basis, **options): + x = Symbol('x') + n = Symbol('n') + subs_info = options.get('subs', {}) + return sqrt(2/L)*sin(n*pi*x/L).subs(subs_info) + + def _eval_innerproduct_PIABBra(self, bra): + return KroneckerDelta(bra.label[0], self.label[0]) + + +class PIABBra(Bra): + """Particle in a box eigenbra.""" + + @classmethod + def _eval_hilbert_space(cls, label): + return L2(Interval(S.NegativeInfinity, S.Infinity)) + + @classmethod + def dual_class(self): + return PIABKet diff --git a/lib/python3.10/site-packages/sympy/physics/quantum/qapply.py b/lib/python3.10/site-packages/sympy/physics/quantum/qapply.py new file mode 100644 index 0000000000000000000000000000000000000000..2109ed1abc1abf302f6a79bf4d1ade6e2d55d7c6 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/physics/quantum/qapply.py @@ -0,0 +1,212 @@ +"""Logic for applying operators to states. + +Todo: +* Sometimes the final result needs to be expanded, we should do this by hand. +""" + +from sympy.core.add import Add +from sympy.core.mul import Mul +from sympy.core.power import Pow +from sympy.core.singleton import S +from sympy.core.sympify import sympify + +from sympy.physics.quantum.anticommutator import AntiCommutator +from sympy.physics.quantum.commutator import Commutator +from sympy.physics.quantum.dagger import Dagger +from sympy.physics.quantum.innerproduct import InnerProduct +from sympy.physics.quantum.operator import OuterProduct, Operator +from sympy.physics.quantum.state import State, KetBase, BraBase, Wavefunction +from sympy.physics.quantum.tensorproduct import TensorProduct + +__all__ = [ + 'qapply' +] + + +#----------------------------------------------------------------------------- +# Main code +#----------------------------------------------------------------------------- + +def qapply(e, **options): + """Apply operators to states in a quantum expression. + + Parameters + ========== + + e : Expr + The expression containing operators and states. This expression tree + will be walked to find operators acting on states symbolically. + options : dict + A dict of key/value pairs that determine how the operator actions + are carried out. + + The following options are valid: + + * ``dagger``: try to apply Dagger operators to the left + (default: False). + * ``ip_doit``: call ``.doit()`` in inner products when they are + encountered (default: True). + + Returns + ======= + + e : Expr + The original expression, but with the operators applied to states. + + Examples + ======== + + >>> from sympy.physics.quantum import qapply, Ket, Bra + >>> b = Bra('b') + >>> k = Ket('k') + >>> A = k * b + >>> A + |k>>> qapply(A * b.dual / (b * b.dual)) + |k> + >>> qapply(k.dual * A / (k.dual * k), dagger=True) + >> qapply(k.dual * A / (k.dual * k)) + + """ + from sympy.physics.quantum.density import Density + + dagger = options.get('dagger', False) + + if e == 0: + return S.Zero + + # This may be a bit aggressive but ensures that everything gets expanded + # to its simplest form before trying to apply operators. This includes + # things like (A+B+C)*|a> and A*(|a>+|b>) and all Commutators and + # TensorProducts. The only problem with this is that if we can't apply + # all the Operators, we have just expanded everything. + # TODO: don't expand the scalars in front of each Mul. + e = e.expand(commutator=True, tensorproduct=True) + + # If we just have a raw ket, return it. + if isinstance(e, KetBase): + return e + + # We have an Add(a, b, c, ...) and compute + # Add(qapply(a), qapply(b), ...) + elif isinstance(e, Add): + result = 0 + for arg in e.args: + result += qapply(arg, **options) + return result.expand() + + # For a Density operator call qapply on its state + elif isinstance(e, Density): + new_args = [(qapply(state, **options), prob) for (state, + prob) in e.args] + return Density(*new_args) + + # For a raw TensorProduct, call qapply on its args. + elif isinstance(e, TensorProduct): + return TensorProduct(*[qapply(t, **options) for t in e.args]) + + # For a Pow, call qapply on its base. + elif isinstance(e, Pow): + return qapply(e.base, **options)**e.exp + + # We have a Mul where there might be actual operators to apply to kets. + elif isinstance(e, Mul): + c_part, nc_part = e.args_cnc() + c_mul = Mul(*c_part) + nc_mul = Mul(*nc_part) + if isinstance(nc_mul, Mul): + result = c_mul*qapply_Mul(nc_mul, **options) + else: + result = c_mul*qapply(nc_mul, **options) + if result == e and dagger: + return Dagger(qapply_Mul(Dagger(e), **options)) + else: + return result + + # In all other cases (State, Operator, Pow, Commutator, InnerProduct, + # OuterProduct) we won't ever have operators to apply to kets. + else: + return e + + +def qapply_Mul(e, **options): + + ip_doit = options.get('ip_doit', True) + + args = list(e.args) + + # If we only have 0 or 1 args, we have nothing to do and return. + if len(args) <= 1 or not isinstance(e, Mul): + return e + rhs = args.pop() + lhs = args.pop() + + # Make sure we have two non-commutative objects before proceeding. + if (not isinstance(rhs, Wavefunction) and sympify(rhs).is_commutative) or \ + (not isinstance(lhs, Wavefunction) and sympify(lhs).is_commutative): + return e + + # For a Pow with an integer exponent, apply one of them and reduce the + # exponent by one. + if isinstance(lhs, Pow) and lhs.exp.is_Integer: + args.append(lhs.base**(lhs.exp - 1)) + lhs = lhs.base + + # Pull OuterProduct apart + if isinstance(lhs, OuterProduct): + args.append(lhs.ket) + lhs = lhs.bra + + # Call .doit() on Commutator/AntiCommutator. + if isinstance(lhs, (Commutator, AntiCommutator)): + comm = lhs.doit() + if isinstance(comm, Add): + return qapply( + e.func(*(args + [comm.args[0], rhs])) + + e.func(*(args + [comm.args[1], rhs])), + **options + ) + else: + return qapply(e.func(*args)*comm*rhs, **options) + + # Apply tensor products of operators to states + if isinstance(lhs, TensorProduct) and all(isinstance(arg, (Operator, State, Mul, Pow)) or arg == 1 for arg in lhs.args) and \ + isinstance(rhs, TensorProduct) and all(isinstance(arg, (Operator, State, Mul, Pow)) or arg == 1 for arg in rhs.args) and \ + len(lhs.args) == len(rhs.args): + result = TensorProduct(*[qapply(lhs.args[n]*rhs.args[n], **options) for n in range(len(lhs.args))]).expand(tensorproduct=True) + return qapply_Mul(e.func(*args), **options)*result + + # Now try to actually apply the operator and build an inner product. + try: + result = lhs._apply_operator(rhs, **options) + except NotImplementedError: + result = None + + if result is None: + _apply_right = getattr(rhs, '_apply_from_right_to', None) + if _apply_right is not None: + try: + result = _apply_right(lhs, **options) + except NotImplementedError: + result = None + + if result is None: + if isinstance(lhs, BraBase) and isinstance(rhs, KetBase): + result = InnerProduct(lhs, rhs) + if ip_doit: + result = result.doit() + + # TODO: I may need to expand before returning the final result. + if result == 0: + return S.Zero + elif result is None: + if len(args) == 0: + # We had two args to begin with so args=[]. + return e + else: + return qapply_Mul(e.func(*(args + [lhs])), **options)*rhs + elif isinstance(result, InnerProduct): + return result*qapply_Mul(e.func(*args), **options) + else: # result is a scalar times a Mul, Add or TensorProduct + return qapply(e.func(*args)*result, **options) diff --git a/lib/python3.10/site-packages/sympy/physics/quantum/qasm.py b/lib/python3.10/site-packages/sympy/physics/quantum/qasm.py new file mode 100644 index 0000000000000000000000000000000000000000..39b49d9a67399114e7d03f12148854b2e41b0b26 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/physics/quantum/qasm.py @@ -0,0 +1,224 @@ +""" + +qasm.py - Functions to parse a set of qasm commands into a SymPy Circuit. + +Examples taken from Chuang's page: https://web.archive.org/web/20220120121541/https://www.media.mit.edu/quanta/qasm2circ/ + +The code returns a circuit and an associated list of labels. + +>>> from sympy.physics.quantum.qasm import Qasm +>>> q = Qasm('qubit q0', 'qubit q1', 'h q0', 'cnot q0,q1') +>>> q.get_circuit() +CNOT(1,0)*H(1) + +>>> q = Qasm('qubit q0', 'qubit q1', 'cnot q0,q1', 'cnot q1,q0', 'cnot q0,q1') +>>> q.get_circuit() +CNOT(1,0)*CNOT(0,1)*CNOT(1,0) +""" + +__all__ = [ + 'Qasm', + ] + +from math import prod + +from sympy.physics.quantum.gate import H, CNOT, X, Z, CGate, CGateS, SWAP, S, T,CPHASE +from sympy.physics.quantum.circuitplot import Mz + +def read_qasm(lines): + return Qasm(*lines.splitlines()) + +def read_qasm_file(filename): + return Qasm(*open(filename).readlines()) + +def flip_index(i, n): + """Reorder qubit indices from largest to smallest. + + >>> from sympy.physics.quantum.qasm import flip_index + >>> flip_index(0, 2) + 1 + >>> flip_index(1, 2) + 0 + """ + return n-i-1 + +def trim(line): + """Remove everything following comment # characters in line. + + >>> from sympy.physics.quantum.qasm import trim + >>> trim('nothing happens here') + 'nothing happens here' + >>> trim('something #happens here') + 'something ' + """ + if '#' not in line: + return line + return line.split('#')[0] + +def get_index(target, labels): + """Get qubit labels from the rest of the line,and return indices + + >>> from sympy.physics.quantum.qasm import get_index + >>> get_index('q0', ['q0', 'q1']) + 1 + >>> get_index('q1', ['q0', 'q1']) + 0 + """ + nq = len(labels) + return flip_index(labels.index(target), nq) + +def get_indices(targets, labels): + return [get_index(t, labels) for t in targets] + +def nonblank(args): + for line in args: + line = trim(line) + if line.isspace(): + continue + yield line + return + +def fullsplit(line): + words = line.split() + rest = ' '.join(words[1:]) + return fixcommand(words[0]), [s.strip() for s in rest.split(',')] + +def fixcommand(c): + """Fix Qasm command names. + + Remove all of forbidden characters from command c, and + replace 'def' with 'qdef'. + """ + forbidden_characters = ['-'] + c = c.lower() + for char in forbidden_characters: + c = c.replace(char, '') + if c == 'def': + return 'qdef' + return c + +def stripquotes(s): + """Replace explicit quotes in a string. + + >>> from sympy.physics.quantum.qasm import stripquotes + >>> stripquotes("'S'") == 'S' + True + >>> stripquotes('"S"') == 'S' + True + >>> stripquotes('S') == 'S' + True + """ + s = s.replace('"', '') # Remove second set of quotes? + s = s.replace("'", '') + return s + +class Qasm: + """Class to form objects from Qasm lines + + >>> from sympy.physics.quantum.qasm import Qasm + >>> q = Qasm('qubit q0', 'qubit q1', 'h q0', 'cnot q0,q1') + >>> q.get_circuit() + CNOT(1,0)*H(1) + >>> q = Qasm('qubit q0', 'qubit q1', 'cnot q0,q1', 'cnot q1,q0', 'cnot q0,q1') + >>> q.get_circuit() + CNOT(1,0)*CNOT(0,1)*CNOT(1,0) + """ + def __init__(self, *args, **kwargs): + self.defs = {} + self.circuit = [] + self.labels = [] + self.inits = {} + self.add(*args) + self.kwargs = kwargs + + def add(self, *lines): + for line in nonblank(lines): + command, rest = fullsplit(line) + if self.defs.get(command): #defs come first, since you can override built-in + function = self.defs.get(command) + indices = self.indices(rest) + if len(indices) == 1: + self.circuit.append(function(indices[0])) + else: + self.circuit.append(function(indices[:-1], indices[-1])) + elif hasattr(self, command): + function = getattr(self, command) + function(*rest) + else: + print("Function %s not defined. Skipping" % command) + + def get_circuit(self): + return prod(reversed(self.circuit)) + + def get_labels(self): + return list(reversed(self.labels)) + + def plot(self): + from sympy.physics.quantum.circuitplot import CircuitPlot + circuit, labels = self.get_circuit(), self.get_labels() + CircuitPlot(circuit, len(labels), labels=labels, inits=self.inits) + + def qubit(self, arg, init=None): + self.labels.append(arg) + if init: self.inits[arg] = init + + def indices(self, args): + return get_indices(args, self.labels) + + def index(self, arg): + return get_index(arg, self.labels) + + def nop(self, *args): + pass + + def x(self, arg): + self.circuit.append(X(self.index(arg))) + + def z(self, arg): + self.circuit.append(Z(self.index(arg))) + + def h(self, arg): + self.circuit.append(H(self.index(arg))) + + def s(self, arg): + self.circuit.append(S(self.index(arg))) + + def t(self, arg): + self.circuit.append(T(self.index(arg))) + + def measure(self, arg): + self.circuit.append(Mz(self.index(arg))) + + def cnot(self, a1, a2): + self.circuit.append(CNOT(*self.indices([a1, a2]))) + + def swap(self, a1, a2): + self.circuit.append(SWAP(*self.indices([a1, a2]))) + + def cphase(self, a1, a2): + self.circuit.append(CPHASE(*self.indices([a1, a2]))) + + def toffoli(self, a1, a2, a3): + i1, i2, i3 = self.indices([a1, a2, a3]) + self.circuit.append(CGateS((i1, i2), X(i3))) + + def cx(self, a1, a2): + fi, fj = self.indices([a1, a2]) + self.circuit.append(CGate(fi, X(fj))) + + def cz(self, a1, a2): + fi, fj = self.indices([a1, a2]) + self.circuit.append(CGate(fi, Z(fj))) + + def defbox(self, *args): + print("defbox not supported yet. Skipping: ", args) + + def qdef(self, name, ncontrols, symbol): + from sympy.physics.quantum.circuitplot import CreateOneQubitGate, CreateCGate + ncontrols = int(ncontrols) + command = fixcommand(name) + symbol = stripquotes(symbol) + if ncontrols > 0: + self.defs[command] = CreateCGate(symbol) + else: + self.defs[command] = CreateOneQubitGate(symbol) diff --git a/lib/python3.10/site-packages/sympy/physics/quantum/qexpr.py b/lib/python3.10/site-packages/sympy/physics/quantum/qexpr.py new file mode 100644 index 0000000000000000000000000000000000000000..13f7f70294c5a2fcdeda007a199a87f5a3022f79 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/physics/quantum/qexpr.py @@ -0,0 +1,413 @@ +from sympy.core.expr import Expr +from sympy.core.symbol import Symbol +from sympy.core.sympify import sympify +from sympy.matrices.dense import Matrix +from sympy.printing.pretty.stringpict import prettyForm +from sympy.core.containers import Tuple +from sympy.utilities.iterables import is_sequence + +from sympy.physics.quantum.dagger import Dagger +from sympy.physics.quantum.matrixutils import ( + numpy_ndarray, scipy_sparse_matrix, + to_sympy, to_numpy, to_scipy_sparse +) + +__all__ = [ + 'QuantumError', + 'QExpr' +] + + +#----------------------------------------------------------------------------- +# Error handling +#----------------------------------------------------------------------------- + +class QuantumError(Exception): + pass + + +def _qsympify_sequence(seq): + """Convert elements of a sequence to standard form. + + This is like sympify, but it performs special logic for arguments passed + to QExpr. The following conversions are done: + + * (list, tuple, Tuple) => _qsympify_sequence each element and convert + sequence to a Tuple. + * basestring => Symbol + * Matrix => Matrix + * other => sympify + + Strings are passed to Symbol, not sympify to make sure that variables like + 'pi' are kept as Symbols, not the SymPy built-in number subclasses. + + Examples + ======== + + >>> from sympy.physics.quantum.qexpr import _qsympify_sequence + >>> _qsympify_sequence((1,2,[3,4,[1,]])) + (1, 2, (3, 4, (1,))) + + """ + + return tuple(__qsympify_sequence_helper(seq)) + + +def __qsympify_sequence_helper(seq): + """ + Helper function for _qsympify_sequence + This function does the actual work. + """ + #base case. If not a list, do Sympification + if not is_sequence(seq): + if isinstance(seq, Matrix): + return seq + elif isinstance(seq, str): + return Symbol(seq) + else: + return sympify(seq) + + # base condition, when seq is QExpr and also + # is iterable. + if isinstance(seq, QExpr): + return seq + + #if list, recurse on each item in the list + result = [__qsympify_sequence_helper(item) for item in seq] + + return Tuple(*result) + + +#----------------------------------------------------------------------------- +# Basic Quantum Expression from which all objects descend +#----------------------------------------------------------------------------- + +class QExpr(Expr): + """A base class for all quantum object like operators and states.""" + + # In sympy, slots are for instance attributes that are computed + # dynamically by the __new__ method. They are not part of args, but they + # derive from args. + + # The Hilbert space a quantum Object belongs to. + __slots__ = ('hilbert_space', ) + + is_commutative = False + + # The separator used in printing the label. + _label_separator = '' + + @property + def free_symbols(self): + return {self} + + def __new__(cls, *args, **kwargs): + """Construct a new quantum object. + + Parameters + ========== + + args : tuple + The list of numbers or parameters that uniquely specify the + quantum object. For a state, this will be its symbol or its + set of quantum numbers. + + Examples + ======== + + >>> from sympy.physics.quantum.qexpr import QExpr + >>> q = QExpr(0) + >>> q + 0 + >>> q.label + (0,) + >>> q.hilbert_space + H + >>> q.args + (0,) + >>> q.is_commutative + False + """ + + # First compute args and call Expr.__new__ to create the instance + args = cls._eval_args(args, **kwargs) + if len(args) == 0: + args = cls._eval_args(tuple(cls.default_args()), **kwargs) + inst = Expr.__new__(cls, *args) + # Now set the slots on the instance + inst.hilbert_space = cls._eval_hilbert_space(args) + return inst + + @classmethod + def _new_rawargs(cls, hilbert_space, *args, **old_assumptions): + """Create new instance of this class with hilbert_space and args. + + This is used to bypass the more complex logic in the ``__new__`` + method in cases where you already have the exact ``hilbert_space`` + and ``args``. This should be used when you are positive these + arguments are valid, in their final, proper form and want to optimize + the creation of the object. + """ + + obj = Expr.__new__(cls, *args, **old_assumptions) + obj.hilbert_space = hilbert_space + return obj + + #------------------------------------------------------------------------- + # Properties + #------------------------------------------------------------------------- + + @property + def label(self): + """The label is the unique set of identifiers for the object. + + Usually, this will include all of the information about the state + *except* the time (in the case of time-dependent objects). + + This must be a tuple, rather than a Tuple. + """ + if len(self.args) == 0: # If there is no label specified, return the default + return self._eval_args(list(self.default_args())) + else: + return self.args + + @property + def is_symbolic(self): + return True + + @classmethod + def default_args(self): + """If no arguments are specified, then this will return a default set + of arguments to be run through the constructor. + + NOTE: Any classes that override this MUST return a tuple of arguments. + Should be overridden by subclasses to specify the default arguments for kets and operators + """ + raise NotImplementedError("No default arguments for this class!") + + #------------------------------------------------------------------------- + # _eval_* methods + #------------------------------------------------------------------------- + + def _eval_adjoint(self): + obj = Expr._eval_adjoint(self) + if obj is None: + obj = Expr.__new__(Dagger, self) + if isinstance(obj, QExpr): + obj.hilbert_space = self.hilbert_space + return obj + + @classmethod + def _eval_args(cls, args): + """Process the args passed to the __new__ method. + + This simply runs args through _qsympify_sequence. + """ + return _qsympify_sequence(args) + + @classmethod + def _eval_hilbert_space(cls, args): + """Compute the Hilbert space instance from the args. + """ + from sympy.physics.quantum.hilbert import HilbertSpace + return HilbertSpace() + + #------------------------------------------------------------------------- + # Printing + #------------------------------------------------------------------------- + + # Utilities for printing: these operate on raw SymPy objects + + def _print_sequence(self, seq, sep, printer, *args): + result = [] + for item in seq: + result.append(printer._print(item, *args)) + return sep.join(result) + + def _print_sequence_pretty(self, seq, sep, printer, *args): + pform = printer._print(seq[0], *args) + for item in seq[1:]: + pform = prettyForm(*pform.right(sep)) + pform = prettyForm(*pform.right(printer._print(item, *args))) + return pform + + # Utilities for printing: these operate prettyForm objects + + def _print_subscript_pretty(self, a, b): + top = prettyForm(*b.left(' '*a.width())) + bot = prettyForm(*a.right(' '*b.width())) + return prettyForm(binding=prettyForm.POW, *bot.below(top)) + + def _print_superscript_pretty(self, a, b): + return a**b + + def _print_parens_pretty(self, pform, left='(', right=')'): + return prettyForm(*pform.parens(left=left, right=right)) + + # Printing of labels (i.e. args) + + def _print_label(self, printer, *args): + """Prints the label of the QExpr + + This method prints self.label, using self._label_separator to separate + the elements. This method should not be overridden, instead, override + _print_contents to change printing behavior. + """ + return self._print_sequence( + self.label, self._label_separator, printer, *args + ) + + def _print_label_repr(self, printer, *args): + return self._print_sequence( + self.label, ',', printer, *args + ) + + def _print_label_pretty(self, printer, *args): + return self._print_sequence_pretty( + self.label, self._label_separator, printer, *args + ) + + def _print_label_latex(self, printer, *args): + return self._print_sequence( + self.label, self._label_separator, printer, *args + ) + + # Printing of contents (default to label) + + def _print_contents(self, printer, *args): + """Printer for contents of QExpr + + Handles the printing of any unique identifying contents of a QExpr to + print as its contents, such as any variables or quantum numbers. The + default is to print the label, which is almost always the args. This + should not include printing of any brackets or parentheses. + """ + return self._print_label(printer, *args) + + def _print_contents_pretty(self, printer, *args): + return self._print_label_pretty(printer, *args) + + def _print_contents_latex(self, printer, *args): + return self._print_label_latex(printer, *args) + + # Main printing methods + + def _sympystr(self, printer, *args): + """Default printing behavior of QExpr objects + + Handles the default printing of a QExpr. To add other things to the + printing of the object, such as an operator name to operators or + brackets to states, the class should override the _print/_pretty/_latex + functions directly and make calls to _print_contents where appropriate. + This allows things like InnerProduct to easily control its printing the + printing of contents. + """ + return self._print_contents(printer, *args) + + def _sympyrepr(self, printer, *args): + classname = self.__class__.__name__ + label = self._print_label_repr(printer, *args) + return '%s(%s)' % (classname, label) + + def _pretty(self, printer, *args): + pform = self._print_contents_pretty(printer, *args) + return pform + + def _latex(self, printer, *args): + return self._print_contents_latex(printer, *args) + + #------------------------------------------------------------------------- + # Represent + #------------------------------------------------------------------------- + + def _represent_default_basis(self, **options): + raise NotImplementedError('This object does not have a default basis') + + def _represent(self, *, basis=None, **options): + """Represent this object in a given basis. + + This method dispatches to the actual methods that perform the + representation. Subclases of QExpr should define various methods to + determine how the object will be represented in various bases. The + format of these methods is:: + + def _represent_BasisName(self, basis, **options): + + Thus to define how a quantum object is represented in the basis of + the operator Position, you would define:: + + def _represent_Position(self, basis, **options): + + Usually, basis object will be instances of Operator subclasses, but + there is a chance we will relax this in the future to accommodate other + types of basis sets that are not associated with an operator. + + If the ``format`` option is given it can be ("sympy", "numpy", + "scipy.sparse"). This will ensure that any matrices that result from + representing the object are returned in the appropriate matrix format. + + Parameters + ========== + + basis : Operator + The Operator whose basis functions will be used as the basis for + representation. + options : dict + A dictionary of key/value pairs that give options and hints for + the representation, such as the number of basis functions to + be used. + """ + if basis is None: + result = self._represent_default_basis(**options) + else: + result = dispatch_method(self, '_represent', basis, **options) + + # If we get a matrix representation, convert it to the right format. + format = options.get('format', 'sympy') + result = self._format_represent(result, format) + return result + + def _format_represent(self, result, format): + if format == 'sympy' and not isinstance(result, Matrix): + return to_sympy(result) + elif format == 'numpy' and not isinstance(result, numpy_ndarray): + return to_numpy(result) + elif format == 'scipy.sparse' and \ + not isinstance(result, scipy_sparse_matrix): + return to_scipy_sparse(result) + + return result + + +def split_commutative_parts(e): + """Split into commutative and non-commutative parts.""" + c_part, nc_part = e.args_cnc() + c_part = list(c_part) + return c_part, nc_part + + +def split_qexpr_parts(e): + """Split an expression into Expr and noncommutative QExpr parts.""" + expr_part = [] + qexpr_part = [] + for arg in e.args: + if not isinstance(arg, QExpr): + expr_part.append(arg) + else: + qexpr_part.append(arg) + return expr_part, qexpr_part + + +def dispatch_method(self, basename, arg, **options): + """Dispatch a method to the proper handlers.""" + method_name = '%s_%s' % (basename, arg.__class__.__name__) + if hasattr(self, method_name): + f = getattr(self, method_name) + # This can raise and we will allow it to propagate. + result = f(arg, **options) + if result is not None: + return result + raise NotImplementedError( + "%s.%s cannot handle: %r" % + (self.__class__.__name__, basename, arg) + ) diff --git a/lib/python3.10/site-packages/sympy/physics/quantum/qft.py b/lib/python3.10/site-packages/sympy/physics/quantum/qft.py new file mode 100644 index 0000000000000000000000000000000000000000..c6a3fa4539267f7bb6cf015521007e292b3d4cfd --- /dev/null +++ b/lib/python3.10/site-packages/sympy/physics/quantum/qft.py @@ -0,0 +1,215 @@ +"""An implementation of qubits and gates acting on them. + +Todo: + +* Update docstrings. +* Update tests. +* Implement apply using decompose. +* Implement represent using decompose or something smarter. For this to + work we first have to implement represent for SWAP. +* Decide if we want upper index to be inclusive in the constructor. +* Fix the printing of Rk gates in plotting. +""" + +from sympy.core.expr import Expr +from sympy.core.numbers import (I, Integer, pi) +from sympy.core.symbol import Symbol +from sympy.functions.elementary.exponential import exp +from sympy.matrices.dense import Matrix +from sympy.functions import sqrt + +from sympy.physics.quantum.qapply import qapply +from sympy.physics.quantum.qexpr import QuantumError, QExpr +from sympy.matrices import eye +from sympy.physics.quantum.tensorproduct import matrix_tensor_product + +from sympy.physics.quantum.gate import ( + Gate, HadamardGate, SwapGate, OneQubitGate, CGate, PhaseGate, TGate, ZGate +) + +from sympy.functions.elementary.complexes import sign + +__all__ = [ + 'QFT', + 'IQFT', + 'RkGate', + 'Rk' +] + +#----------------------------------------------------------------------------- +# Fourier stuff +#----------------------------------------------------------------------------- + + +class RkGate(OneQubitGate): + """This is the R_k gate of the QTF.""" + gate_name = 'Rk' + gate_name_latex = 'R' + + def __new__(cls, *args): + if len(args) != 2: + raise QuantumError( + 'Rk gates only take two arguments, got: %r' % args + ) + # For small k, Rk gates simplify to other gates, using these + # substitutions give us familiar results for the QFT for small numbers + # of qubits. + target = args[0] + k = args[1] + if k == 1: + return ZGate(target) + elif k == 2: + return PhaseGate(target) + elif k == 3: + return TGate(target) + args = cls._eval_args(args) + inst = Expr.__new__(cls, *args) + inst.hilbert_space = cls._eval_hilbert_space(args) + return inst + + @classmethod + def _eval_args(cls, args): + # Fall back to this, because Gate._eval_args assumes that args is + # all targets and can't contain duplicates. + return QExpr._eval_args(args) + + @property + def k(self): + return self.label[1] + + @property + def targets(self): + return self.label[:1] + + @property + def gate_name_plot(self): + return r'$%s_%s$' % (self.gate_name_latex, str(self.k)) + + def get_target_matrix(self, format='sympy'): + if format == 'sympy': + return Matrix([[1, 0], [0, exp(sign(self.k)*Integer(2)*pi*I/(Integer(2)**abs(self.k)))]]) + raise NotImplementedError( + 'Invalid format for the R_k gate: %r' % format) + + +Rk = RkGate + + +class Fourier(Gate): + """Superclass of Quantum Fourier and Inverse Quantum Fourier Gates.""" + + @classmethod + def _eval_args(self, args): + if len(args) != 2: + raise QuantumError( + 'QFT/IQFT only takes two arguments, got: %r' % args + ) + if args[0] >= args[1]: + raise QuantumError("Start must be smaller than finish") + return Gate._eval_args(args) + + def _represent_default_basis(self, **options): + return self._represent_ZGate(None, **options) + + def _represent_ZGate(self, basis, **options): + """ + Represents the (I)QFT In the Z Basis + """ + nqubits = options.get('nqubits', 0) + if nqubits == 0: + raise QuantumError( + 'The number of qubits must be given as nqubits.') + if nqubits < self.min_qubits: + raise QuantumError( + 'The number of qubits %r is too small for the gate.' % nqubits + ) + size = self.size + omega = self.omega + + #Make a matrix that has the basic Fourier Transform Matrix + arrayFT = [[omega**( + i*j % size)/sqrt(size) for i in range(size)] for j in range(size)] + matrixFT = Matrix(arrayFT) + + #Embed the FT Matrix in a higher space, if necessary + if self.label[0] != 0: + matrixFT = matrix_tensor_product(eye(2**self.label[0]), matrixFT) + if self.min_qubits < nqubits: + matrixFT = matrix_tensor_product( + matrixFT, eye(2**(nqubits - self.min_qubits))) + + return matrixFT + + @property + def targets(self): + return range(self.label[0], self.label[1]) + + @property + def min_qubits(self): + return self.label[1] + + @property + def size(self): + """Size is the size of the QFT matrix""" + return 2**(self.label[1] - self.label[0]) + + @property + def omega(self): + return Symbol('omega') + + +class QFT(Fourier): + """The forward quantum Fourier transform.""" + + gate_name = 'QFT' + gate_name_latex = 'QFT' + + def decompose(self): + """Decomposes QFT into elementary gates.""" + start = self.label[0] + finish = self.label[1] + circuit = 1 + for level in reversed(range(start, finish)): + circuit = HadamardGate(level)*circuit + for i in range(level - start): + circuit = CGate(level - i - 1, RkGate(level, i + 2))*circuit + for i in range((finish - start)//2): + circuit = SwapGate(i + start, finish - i - 1)*circuit + return circuit + + def _apply_operator_Qubit(self, qubits, **options): + return qapply(self.decompose()*qubits) + + def _eval_inverse(self): + return IQFT(*self.args) + + @property + def omega(self): + return exp(2*pi*I/self.size) + + +class IQFT(Fourier): + """The inverse quantum Fourier transform.""" + + gate_name = 'IQFT' + gate_name_latex = '{QFT^{-1}}' + + def decompose(self): + """Decomposes IQFT into elementary gates.""" + start = self.args[0] + finish = self.args[1] + circuit = 1 + for i in range((finish - start)//2): + circuit = SwapGate(i + start, finish - i - 1)*circuit + for level in range(start, finish): + for i in reversed(range(level - start)): + circuit = CGate(level - i - 1, RkGate(level, -i - 2))*circuit + circuit = HadamardGate(level)*circuit + return circuit + + def _eval_inverse(self): + return QFT(*self.args) + + @property + def omega(self): + return exp(-2*pi*I/self.size) diff --git a/lib/python3.10/site-packages/sympy/physics/quantum/qubit.py b/lib/python3.10/site-packages/sympy/physics/quantum/qubit.py new file mode 100644 index 0000000000000000000000000000000000000000..fb75b4c496b5b6292a8383c169ba63ec6d3cbb56 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/physics/quantum/qubit.py @@ -0,0 +1,811 @@ +"""Qubits for quantum computing. + +Todo: +* Finish implementing measurement logic. This should include POVM. +* Update docstrings. +* Update tests. +""" + + +import math + +from sympy.core.add import Add +from sympy.core.mul import Mul +from sympy.core.numbers import Integer +from sympy.core.power import Pow +from sympy.core.singleton import S +from sympy.functions.elementary.complexes import conjugate +from sympy.functions.elementary.exponential import log +from sympy.core.basic import _sympify +from sympy.external.gmpy import SYMPY_INTS +from sympy.matrices import Matrix, zeros +from sympy.printing.pretty.stringpict import prettyForm + +from sympy.physics.quantum.hilbert import ComplexSpace +from sympy.physics.quantum.state import Ket, Bra, State + +from sympy.physics.quantum.qexpr import QuantumError +from sympy.physics.quantum.represent import represent +from sympy.physics.quantum.matrixutils import ( + numpy_ndarray, scipy_sparse_matrix +) +from mpmath.libmp.libintmath import bitcount + +__all__ = [ + 'Qubit', + 'QubitBra', + 'IntQubit', + 'IntQubitBra', + 'qubit_to_matrix', + 'matrix_to_qubit', + 'matrix_to_density', + 'measure_all', + 'measure_partial', + 'measure_partial_oneshot', + 'measure_all_oneshot' +] + +#----------------------------------------------------------------------------- +# Qubit Classes +#----------------------------------------------------------------------------- + + +class QubitState(State): + """Base class for Qubit and QubitBra.""" + + #------------------------------------------------------------------------- + # Initialization/creation + #------------------------------------------------------------------------- + + @classmethod + def _eval_args(cls, args): + # If we are passed a QubitState or subclass, we just take its qubit + # values directly. + if len(args) == 1 and isinstance(args[0], QubitState): + return args[0].qubit_values + + # Turn strings into tuple of strings + if len(args) == 1 and isinstance(args[0], str): + args = tuple( S.Zero if qb == "0" else S.One for qb in args[0]) + else: + args = tuple( S.Zero if qb == "0" else S.One if qb == "1" else qb for qb in args) + args = tuple(_sympify(arg) for arg in args) + + # Validate input (must have 0 or 1 input) + for element in args: + if element not in (S.Zero, S.One): + raise ValueError( + "Qubit values must be 0 or 1, got: %r" % element) + return args + + @classmethod + def _eval_hilbert_space(cls, args): + return ComplexSpace(2)**len(args) + + #------------------------------------------------------------------------- + # Properties + #------------------------------------------------------------------------- + + @property + def dimension(self): + """The number of Qubits in the state.""" + return len(self.qubit_values) + + @property + def nqubits(self): + return self.dimension + + @property + def qubit_values(self): + """Returns the values of the qubits as a tuple.""" + return self.label + + #------------------------------------------------------------------------- + # Special methods + #------------------------------------------------------------------------- + + def __len__(self): + return self.dimension + + def __getitem__(self, bit): + return self.qubit_values[int(self.dimension - bit - 1)] + + #------------------------------------------------------------------------- + # Utility methods + #------------------------------------------------------------------------- + + def flip(self, *bits): + """Flip the bit(s) given.""" + newargs = list(self.qubit_values) + for i in bits: + bit = int(self.dimension - i - 1) + if newargs[bit] == 1: + newargs[bit] = 0 + else: + newargs[bit] = 1 + return self.__class__(*tuple(newargs)) + + +class Qubit(QubitState, Ket): + """A multi-qubit ket in the computational (z) basis. + + We use the normal convention that the least significant qubit is on the + right, so ``|00001>`` has a 1 in the least significant qubit. + + Parameters + ========== + + values : list, str + The qubit values as a list of ints ([0,0,0,1,1,]) or a string ('011'). + + Examples + ======== + + Create a qubit in a couple of different ways and look at their attributes: + + >>> from sympy.physics.quantum.qubit import Qubit + >>> Qubit(0,0,0) + |000> + >>> q = Qubit('0101') + >>> q + |0101> + + >>> q.nqubits + 4 + >>> len(q) + 4 + >>> q.dimension + 4 + >>> q.qubit_values + (0, 1, 0, 1) + + We can flip the value of an individual qubit: + + >>> q.flip(1) + |0111> + + We can take the dagger of a Qubit to get a bra: + + >>> from sympy.physics.quantum.dagger import Dagger + >>> Dagger(q) + <0101| + >>> type(Dagger(q)) + + + Inner products work as expected: + + >>> ip = Dagger(q)*q + >>> ip + <0101|0101> + >>> ip.doit() + 1 + """ + + @classmethod + def dual_class(self): + return QubitBra + + def _eval_innerproduct_QubitBra(self, bra, **hints): + if self.label == bra.label: + return S.One + else: + return S.Zero + + def _represent_default_basis(self, **options): + return self._represent_ZGate(None, **options) + + def _represent_ZGate(self, basis, **options): + """Represent this qubits in the computational basis (ZGate). + """ + _format = options.get('format', 'sympy') + n = 1 + definite_state = 0 + for it in reversed(self.qubit_values): + definite_state += n*it + n = n*2 + result = [0]*(2**self.dimension) + result[int(definite_state)] = 1 + if _format == 'sympy': + return Matrix(result) + elif _format == 'numpy': + import numpy as np + return np.array(result, dtype='complex').transpose() + elif _format == 'scipy.sparse': + from scipy import sparse + return sparse.csr_matrix(result, dtype='complex').transpose() + + def _eval_trace(self, bra, **kwargs): + indices = kwargs.get('indices', []) + + #sort index list to begin trace from most-significant + #qubit + sorted_idx = list(indices) + if len(sorted_idx) == 0: + sorted_idx = list(range(0, self.nqubits)) + sorted_idx.sort() + + #trace out for each of index + new_mat = self*bra + for i in range(len(sorted_idx) - 1, -1, -1): + # start from tracing out from leftmost qubit + new_mat = self._reduced_density(new_mat, int(sorted_idx[i])) + + if (len(sorted_idx) == self.nqubits): + #in case full trace was requested + return new_mat[0] + else: + return matrix_to_density(new_mat) + + def _reduced_density(self, matrix, qubit, **options): + """Compute the reduced density matrix by tracing out one qubit. + The qubit argument should be of type Python int, since it is used + in bit operations + """ + def find_index_that_is_projected(j, k, qubit): + bit_mask = 2**qubit - 1 + return ((j >> qubit) << (1 + qubit)) + (j & bit_mask) + (k << qubit) + + old_matrix = represent(matrix, **options) + old_size = old_matrix.cols + #we expect the old_size to be even + new_size = old_size//2 + new_matrix = Matrix().zeros(new_size) + + for i in range(new_size): + for j in range(new_size): + for k in range(2): + col = find_index_that_is_projected(j, k, qubit) + row = find_index_that_is_projected(i, k, qubit) + new_matrix[i, j] += old_matrix[row, col] + + return new_matrix + + +class QubitBra(QubitState, Bra): + """A multi-qubit bra in the computational (z) basis. + + We use the normal convention that the least significant qubit is on the + right, so ``|00001>`` has a 1 in the least significant qubit. + + Parameters + ========== + + values : list, str + The qubit values as a list of ints ([0,0,0,1,1,]) or a string ('011'). + + See also + ======== + + Qubit: Examples using qubits + + """ + @classmethod + def dual_class(self): + return Qubit + + +class IntQubitState(QubitState): + """A base class for qubits that work with binary representations.""" + + @classmethod + def _eval_args(cls, args, nqubits=None): + # The case of a QubitState instance + if len(args) == 1 and isinstance(args[0], QubitState): + return QubitState._eval_args(args) + # otherwise, args should be integer + elif not all(isinstance(a, (int, Integer)) for a in args): + raise ValueError('values must be integers, got (%s)' % (tuple(type(a) for a in args),)) + # use nqubits if specified + if nqubits is not None: + if not isinstance(nqubits, (int, Integer)): + raise ValueError('nqubits must be an integer, got (%s)' % type(nqubits)) + if len(args) != 1: + raise ValueError( + 'too many positional arguments (%s). should be (number, nqubits=n)' % (args,)) + return cls._eval_args_with_nqubits(args[0], nqubits) + # For a single argument, we construct the binary representation of + # that integer with the minimal number of bits. + if len(args) == 1 and args[0] > 1: + #rvalues is the minimum number of bits needed to express the number + rvalues = reversed(range(bitcount(abs(args[0])))) + qubit_values = [(args[0] >> i) & 1 for i in rvalues] + return QubitState._eval_args(qubit_values) + # For two numbers, the second number is the number of bits + # on which it is expressed, so IntQubit(0,5) == |00000>. + elif len(args) == 2 and args[1] > 1: + return cls._eval_args_with_nqubits(args[0], args[1]) + else: + return QubitState._eval_args(args) + + @classmethod + def _eval_args_with_nqubits(cls, number, nqubits): + need = bitcount(abs(number)) + if nqubits < need: + raise ValueError( + 'cannot represent %s with %s bits' % (number, nqubits)) + qubit_values = [(number >> i) & 1 for i in reversed(range(nqubits))] + return QubitState._eval_args(qubit_values) + + def as_int(self): + """Return the numerical value of the qubit.""" + number = 0 + n = 1 + for i in reversed(self.qubit_values): + number += n*i + n = n << 1 + return number + + def _print_label(self, printer, *args): + return str(self.as_int()) + + def _print_label_pretty(self, printer, *args): + label = self._print_label(printer, *args) + return prettyForm(label) + + _print_label_repr = _print_label + _print_label_latex = _print_label + + +class IntQubit(IntQubitState, Qubit): + """A qubit ket that store integers as binary numbers in qubit values. + + The differences between this class and ``Qubit`` are: + + * The form of the constructor. + * The qubit values are printed as their corresponding integer, rather + than the raw qubit values. The internal storage format of the qubit + values in the same as ``Qubit``. + + Parameters + ========== + + values : int, tuple + If a single argument, the integer we want to represent in the qubit + values. This integer will be represented using the fewest possible + number of qubits. + If a pair of integers and the second value is more than one, the first + integer gives the integer to represent in binary form and the second + integer gives the number of qubits to use. + List of zeros and ones is also accepted to generate qubit by bit pattern. + + nqubits : int + The integer that represents the number of qubits. + This number should be passed with keyword ``nqubits=N``. + You can use this in order to avoid ambiguity of Qubit-style tuple of bits. + Please see the example below for more details. + + Examples + ======== + + Create a qubit for the integer 5: + + >>> from sympy.physics.quantum.qubit import IntQubit + >>> from sympy.physics.quantum.qubit import Qubit + >>> q = IntQubit(5) + >>> q + |5> + + We can also create an ``IntQubit`` by passing a ``Qubit`` instance. + + >>> q = IntQubit(Qubit('101')) + >>> q + |5> + >>> q.as_int() + 5 + >>> q.nqubits + 3 + >>> q.qubit_values + (1, 0, 1) + + We can go back to the regular qubit form. + + >>> Qubit(q) + |101> + + Please note that ``IntQubit`` also accepts a ``Qubit``-style list of bits. + So, the code below yields qubits 3, not a single bit ``1``. + + >>> IntQubit(1, 1) + |3> + + To avoid ambiguity, use ``nqubits`` parameter. + Use of this keyword is recommended especially when you provide the values by variables. + + >>> IntQubit(1, nqubits=1) + |1> + >>> a = 1 + >>> IntQubit(a, nqubits=1) + |1> + """ + @classmethod + def dual_class(self): + return IntQubitBra + + def _eval_innerproduct_IntQubitBra(self, bra, **hints): + return Qubit._eval_innerproduct_QubitBra(self, bra) + +class IntQubitBra(IntQubitState, QubitBra): + """A qubit bra that store integers as binary numbers in qubit values.""" + + @classmethod + def dual_class(self): + return IntQubit + + +#----------------------------------------------------------------------------- +# Qubit <---> Matrix conversion functions +#----------------------------------------------------------------------------- + + +def matrix_to_qubit(matrix): + """Convert from the matrix repr. to a sum of Qubit objects. + + Parameters + ---------- + matrix : Matrix, numpy.matrix, scipy.sparse + The matrix to build the Qubit representation of. This works with + SymPy matrices, numpy matrices and scipy.sparse sparse matrices. + + Examples + ======== + + Represent a state and then go back to its qubit form: + + >>> from sympy.physics.quantum.qubit import matrix_to_qubit, Qubit + >>> from sympy.physics.quantum.represent import represent + >>> q = Qubit('01') + >>> matrix_to_qubit(represent(q)) + |01> + """ + # Determine the format based on the type of the input matrix + format = 'sympy' + if isinstance(matrix, numpy_ndarray): + format = 'numpy' + if isinstance(matrix, scipy_sparse_matrix): + format = 'scipy.sparse' + + # Make sure it is of correct dimensions for a Qubit-matrix representation. + # This logic should work with sympy, numpy or scipy.sparse matrices. + if matrix.shape[0] == 1: + mlistlen = matrix.shape[1] + nqubits = log(mlistlen, 2) + ket = False + cls = QubitBra + elif matrix.shape[1] == 1: + mlistlen = matrix.shape[0] + nqubits = log(mlistlen, 2) + ket = True + cls = Qubit + else: + raise QuantumError( + 'Matrix must be a row/column vector, got %r' % matrix + ) + if not isinstance(nqubits, Integer): + raise QuantumError('Matrix must be a row/column vector of size ' + '2**nqubits, got: %r' % matrix) + # Go through each item in matrix, if element is non-zero, make it into a + # Qubit item times the element. + result = 0 + for i in range(mlistlen): + if ket: + element = matrix[i, 0] + else: + element = matrix[0, i] + if format in ('numpy', 'scipy.sparse'): + element = complex(element) + if element != 0.0: + # Form Qubit array; 0 in bit-locations where i is 0, 1 in + # bit-locations where i is 1 + qubit_array = [int(i & (1 << x) != 0) for x in range(nqubits)] + qubit_array.reverse() + result = result + element*cls(*qubit_array) + + # If SymPy simplified by pulling out a constant coefficient, undo that. + if isinstance(result, (Mul, Add, Pow)): + result = result.expand() + + return result + + +def matrix_to_density(mat): + """ + Works by finding the eigenvectors and eigenvalues of the matrix. + We know we can decompose rho by doing: + sum(EigenVal*|Eigenvect>>> from sympy.physics.quantum.qubit import Qubit, measure_all + >>> from sympy.physics.quantum.gate import H + >>> from sympy.physics.quantum.qapply import qapply + + >>> c = H(0)*H(1)*Qubit('00') + >>> c + H(0)*H(1)*|00> + >>> q = qapply(c) + >>> measure_all(q) + [(|00>, 1/4), (|01>, 1/4), (|10>, 1/4), (|11>, 1/4)] + """ + m = qubit_to_matrix(qubit, format) + + if format == 'sympy': + results = [] + + if normalize: + m = m.normalized() + + size = max(m.shape) # Max of shape to account for bra or ket + nqubits = int(math.log(size)/math.log(2)) + for i in range(size): + if m[i] != 0.0: + results.append( + (Qubit(IntQubit(i, nqubits=nqubits)), m[i]*conjugate(m[i])) + ) + return results + else: + raise NotImplementedError( + "This function cannot handle non-SymPy matrix formats yet" + ) + + +def measure_partial(qubit, bits, format='sympy', normalize=True): + """Perform a partial ensemble measure on the specified qubits. + + Parameters + ========== + + qubits : Qubit + The qubit to measure. This can be any Qubit or a linear combination + of them. + bits : tuple + The qubits to measure. + format : str + The format of the intermediate matrices to use. Possible values are + ('sympy','numpy','scipy.sparse'). Currently only 'sympy' is + implemented. + + Returns + ======= + + result : list + A list that consists of primitive states and their probabilities. + + Examples + ======== + + >>> from sympy.physics.quantum.qubit import Qubit, measure_partial + >>> from sympy.physics.quantum.gate import H + >>> from sympy.physics.quantum.qapply import qapply + + >>> c = H(0)*H(1)*Qubit('00') + >>> c + H(0)*H(1)*|00> + >>> q = qapply(c) + >>> measure_partial(q, (0,)) + [(sqrt(2)*|00>/2 + sqrt(2)*|10>/2, 1/2), (sqrt(2)*|01>/2 + sqrt(2)*|11>/2, 1/2)] + """ + m = qubit_to_matrix(qubit, format) + + if isinstance(bits, (SYMPY_INTS, Integer)): + bits = (int(bits),) + + if format == 'sympy': + if normalize: + m = m.normalized() + + possible_outcomes = _get_possible_outcomes(m, bits) + + # Form output from function. + output = [] + for outcome in possible_outcomes: + # Calculate probability of finding the specified bits with + # given values. + prob_of_outcome = 0 + prob_of_outcome += (outcome.H*outcome)[0] + + # If the output has a chance, append it to output with found + # probability. + if prob_of_outcome != 0: + if normalize: + next_matrix = matrix_to_qubit(outcome.normalized()) + else: + next_matrix = matrix_to_qubit(outcome) + + output.append(( + next_matrix, + prob_of_outcome + )) + + return output + else: + raise NotImplementedError( + "This function cannot handle non-SymPy matrix formats yet" + ) + + +def measure_partial_oneshot(qubit, bits, format='sympy'): + """Perform a partial oneshot measurement on the specified qubits. + + A oneshot measurement is equivalent to performing a measurement on a + quantum system. This type of measurement does not return the probabilities + like an ensemble measurement does, but rather returns *one* of the + possible resulting states. The exact state that is returned is determined + by picking a state randomly according to the ensemble probabilities. + + Parameters + ---------- + qubits : Qubit + The qubit to measure. This can be any Qubit or a linear combination + of them. + bits : tuple + The qubits to measure. + format : str + The format of the intermediate matrices to use. Possible values are + ('sympy','numpy','scipy.sparse'). Currently only 'sympy' is + implemented. + + Returns + ------- + result : Qubit + The qubit that the system collapsed to upon measurement. + """ + import random + m = qubit_to_matrix(qubit, format) + + if format == 'sympy': + m = m.normalized() + possible_outcomes = _get_possible_outcomes(m, bits) + + # Form output from function + random_number = random.random() + total_prob = 0 + for outcome in possible_outcomes: + # Calculate probability of finding the specified bits + # with given values + total_prob += (outcome.H*outcome)[0] + if total_prob >= random_number: + return matrix_to_qubit(outcome.normalized()) + else: + raise NotImplementedError( + "This function cannot handle non-SymPy matrix formats yet" + ) + + +def _get_possible_outcomes(m, bits): + """Get the possible states that can be produced in a measurement. + + Parameters + ---------- + m : Matrix + The matrix representing the state of the system. + bits : tuple, list + Which bits will be measured. + + Returns + ------- + result : list + The list of possible states which can occur given this measurement. + These are un-normalized so we can derive the probability of finding + this state by taking the inner product with itself + """ + + # This is filled with loads of dirty binary tricks...You have been warned + + size = max(m.shape) # Max of shape to account for bra or ket + nqubits = int(math.log2(size) + .1) # Number of qubits possible + + # Make the output states and put in output_matrices, nothing in them now. + # Each state will represent a possible outcome of the measurement + # Thus, output_matrices[0] is the matrix which we get when all measured + # bits return 0. and output_matrices[1] is the matrix for only the 0th + # bit being true + output_matrices = [] + for i in range(1 << len(bits)): + output_matrices.append(zeros(2**nqubits, 1)) + + # Bitmasks will help sort how to determine possible outcomes. + # When the bit mask is and-ed with a matrix-index, + # it will determine which state that index belongs to + bit_masks = [] + for bit in bits: + bit_masks.append(1 << bit) + + # Make possible outcome states + for i in range(2**nqubits): + trueness = 0 # This tells us to which output_matrix this value belongs + # Find trueness + for j in range(len(bit_masks)): + if i & bit_masks[j]: + trueness += j + 1 + # Put the value in the correct output matrix + output_matrices[trueness][i] = m[i] + return output_matrices + + +def measure_all_oneshot(qubit, format='sympy'): + """Perform a oneshot ensemble measurement on all qubits. + + A oneshot measurement is equivalent to performing a measurement on a + quantum system. This type of measurement does not return the probabilities + like an ensemble measurement does, but rather returns *one* of the + possible resulting states. The exact state that is returned is determined + by picking a state randomly according to the ensemble probabilities. + + Parameters + ---------- + qubits : Qubit + The qubit to measure. This can be any Qubit or a linear combination + of them. + format : str + The format of the intermediate matrices to use. Possible values are + ('sympy','numpy','scipy.sparse'). Currently only 'sympy' is + implemented. + + Returns + ------- + result : Qubit + The qubit that the system collapsed to upon measurement. + """ + import random + m = qubit_to_matrix(qubit) + + if format == 'sympy': + m = m.normalized() + random_number = random.random() + total = 0 + result = 0 + for i in m: + total += i*i.conjugate() + if total > random_number: + break + result += 1 + return Qubit(IntQubit(result, int(math.log2(max(m.shape)) + .1))) + else: + raise NotImplementedError( + "This function cannot handle non-SymPy matrix formats yet" + ) diff --git a/lib/python3.10/site-packages/sympy/physics/quantum/represent.py b/lib/python3.10/site-packages/sympy/physics/quantum/represent.py new file mode 100644 index 0000000000000000000000000000000000000000..cfb0ea6275716d31066ad40cb820d27086bc1f50 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/physics/quantum/represent.py @@ -0,0 +1,574 @@ +"""Logic for representing operators in state in various bases. + +TODO: + +* Get represent working with continuous hilbert spaces. +* Document default basis functionality. +""" + +from sympy.core.add import Add +from sympy.core.expr import Expr +from sympy.core.mul import Mul +from sympy.core.numbers import I +from sympy.core.power import Pow +from sympy.integrals.integrals import integrate +from sympy.physics.quantum.dagger import Dagger +from sympy.physics.quantum.commutator import Commutator +from sympy.physics.quantum.anticommutator import AntiCommutator +from sympy.physics.quantum.innerproduct import InnerProduct +from sympy.physics.quantum.qexpr import QExpr +from sympy.physics.quantum.tensorproduct import TensorProduct +from sympy.physics.quantum.matrixutils import flatten_scalar +from sympy.physics.quantum.state import KetBase, BraBase, StateBase +from sympy.physics.quantum.operator import Operator, OuterProduct +from sympy.physics.quantum.qapply import qapply +from sympy.physics.quantum.operatorset import operators_to_state, state_to_operators + +__all__ = [ + 'represent', + 'rep_innerproduct', + 'rep_expectation', + 'integrate_result', + 'get_basis', + 'enumerate_states' +] + +#----------------------------------------------------------------------------- +# Represent +#----------------------------------------------------------------------------- + + +def _sympy_to_scalar(e): + """Convert from a SymPy scalar to a Python scalar.""" + if isinstance(e, Expr): + if e.is_Integer: + return int(e) + elif e.is_Float: + return float(e) + elif e.is_Rational: + return float(e) + elif e.is_Number or e.is_NumberSymbol or e == I: + return complex(e) + raise TypeError('Expected number, got: %r' % e) + + +def represent(expr, **options): + """Represent the quantum expression in the given basis. + + In quantum mechanics abstract states and operators can be represented in + various basis sets. Under this operation the follow transforms happen: + + * Ket -> column vector or function + * Bra -> row vector of function + * Operator -> matrix or differential operator + + This function is the top-level interface for this action. + + This function walks the SymPy expression tree looking for ``QExpr`` + instances that have a ``_represent`` method. This method is then called + and the object is replaced by the representation returned by this method. + By default, the ``_represent`` method will dispatch to other methods + that handle the representation logic for a particular basis set. The + naming convention for these methods is the following:: + + def _represent_FooBasis(self, e, basis, **options) + + This function will have the logic for representing instances of its class + in the basis set having a class named ``FooBasis``. + + Parameters + ========== + + expr : Expr + The expression to represent. + basis : Operator, basis set + An object that contains the information about the basis set. If an + operator is used, the basis is assumed to be the orthonormal + eigenvectors of that operator. In general though, the basis argument + can be any object that contains the basis set information. + options : dict + Key/value pairs of options that are passed to the underlying method + that finds the representation. These options can be used to + control how the representation is done. For example, this is where + the size of the basis set would be set. + + Returns + ======= + + e : Expr + The SymPy expression of the represented quantum expression. + + Examples + ======== + + Here we subclass ``Operator`` and ``Ket`` to create the z-spin operator + and its spin 1/2 up eigenstate. By defining the ``_represent_SzOp`` + method, the ket can be represented in the z-spin basis. + + >>> from sympy.physics.quantum import Operator, represent, Ket + >>> from sympy import Matrix + + >>> class SzUpKet(Ket): + ... def _represent_SzOp(self, basis, **options): + ... return Matrix([1,0]) + ... + >>> class SzOp(Operator): + ... pass + ... + >>> sz = SzOp('Sz') + >>> up = SzUpKet('up') + >>> represent(up, basis=sz) + Matrix([ + [1], + [0]]) + + Here we see an example of representations in a continuous + basis. We see that the result of representing various combinations + of cartesian position operators and kets give us continuous + expressions involving DiracDelta functions. + + >>> from sympy.physics.quantum.cartesian import XOp, XKet, XBra + >>> X = XOp() + >>> x = XKet() + >>> y = XBra('y') + >>> represent(X*x) + x*DiracDelta(x - x_2) + >>> represent(X*x*y) + x*DiracDelta(x - x_3)*DiracDelta(x_1 - y) + + """ + + format = options.get('format', 'sympy') + if format == 'numpy': + import numpy as np + if isinstance(expr, QExpr) and not isinstance(expr, OuterProduct): + options['replace_none'] = False + temp_basis = get_basis(expr, **options) + if temp_basis is not None: + options['basis'] = temp_basis + try: + return expr._represent(**options) + except NotImplementedError as strerr: + #If no _represent_FOO method exists, map to the + #appropriate basis state and try + #the other methods of representation + options['replace_none'] = True + + if isinstance(expr, (KetBase, BraBase)): + try: + return rep_innerproduct(expr, **options) + except NotImplementedError: + raise NotImplementedError(strerr) + elif isinstance(expr, Operator): + try: + return rep_expectation(expr, **options) + except NotImplementedError: + raise NotImplementedError(strerr) + else: + raise NotImplementedError(strerr) + elif isinstance(expr, Add): + result = represent(expr.args[0], **options) + for args in expr.args[1:]: + # scipy.sparse doesn't support += so we use plain = here. + result = result + represent(args, **options) + return result + elif isinstance(expr, Pow): + base, exp = expr.as_base_exp() + if format in ('numpy', 'scipy.sparse'): + exp = _sympy_to_scalar(exp) + base = represent(base, **options) + # scipy.sparse doesn't support negative exponents + # and warns when inverting a matrix in csr format. + if format == 'scipy.sparse' and exp < 0: + from scipy.sparse.linalg import inv + exp = - exp + base = inv(base.tocsc()).tocsr() + if format == 'numpy': + return np.linalg.matrix_power(base, exp) + return base ** exp + elif isinstance(expr, TensorProduct): + new_args = [represent(arg, **options) for arg in expr.args] + return TensorProduct(*new_args) + elif isinstance(expr, Dagger): + return Dagger(represent(expr.args[0], **options)) + elif isinstance(expr, Commutator): + A = expr.args[0] + B = expr.args[1] + return represent(Mul(A, B) - Mul(B, A), **options) + elif isinstance(expr, AntiCommutator): + A = expr.args[0] + B = expr.args[1] + return represent(Mul(A, B) + Mul(B, A), **options) + elif isinstance(expr, InnerProduct): + return represent(Mul(expr.bra, expr.ket), **options) + elif not isinstance(expr, (Mul, OuterProduct)): + # For numpy and scipy.sparse, we can only handle numerical prefactors. + if format in ('numpy', 'scipy.sparse'): + return _sympy_to_scalar(expr) + return expr + + if not isinstance(expr, (Mul, OuterProduct)): + raise TypeError('Mul expected, got: %r' % expr) + + if "index" in options: + options["index"] += 1 + else: + options["index"] = 1 + + if "unities" not in options: + options["unities"] = [] + + result = represent(expr.args[-1], **options) + last_arg = expr.args[-1] + + for arg in reversed(expr.args[:-1]): + if isinstance(last_arg, Operator): + options["index"] += 1 + options["unities"].append(options["index"]) + elif isinstance(last_arg, BraBase) and isinstance(arg, KetBase): + options["index"] += 1 + elif isinstance(last_arg, KetBase) and isinstance(arg, Operator): + options["unities"].append(options["index"]) + elif isinstance(last_arg, KetBase) and isinstance(arg, BraBase): + options["unities"].append(options["index"]) + + next_arg = represent(arg, **options) + if format == 'numpy' and isinstance(next_arg, np.ndarray): + # Must use np.matmult to "matrix multiply" two np.ndarray + result = np.matmul(next_arg, result) + else: + result = next_arg*result + last_arg = arg + + # All three matrix formats create 1 by 1 matrices when inner products of + # vectors are taken. In these cases, we simply return a scalar. + result = flatten_scalar(result) + + result = integrate_result(expr, result, **options) + + return result + + +def rep_innerproduct(expr, **options): + """ + Returns an innerproduct like representation (e.g. ````) for the + given state. + + Attempts to calculate inner product with a bra from the specified + basis. Should only be passed an instance of KetBase or BraBase + + Parameters + ========== + + expr : KetBase or BraBase + The expression to be represented + + Examples + ======== + + >>> from sympy.physics.quantum.represent import rep_innerproduct + >>> from sympy.physics.quantum.cartesian import XOp, XKet, PxOp, PxKet + >>> rep_innerproduct(XKet()) + DiracDelta(x - x_1) + >>> rep_innerproduct(XKet(), basis=PxOp()) + sqrt(2)*exp(-I*px_1*x/hbar)/(2*sqrt(hbar)*sqrt(pi)) + >>> rep_innerproduct(PxKet(), basis=XOp()) + sqrt(2)*exp(I*px*x_1/hbar)/(2*sqrt(hbar)*sqrt(pi)) + + """ + + if not isinstance(expr, (KetBase, BraBase)): + raise TypeError("expr passed is not a Bra or Ket") + + basis = get_basis(expr, **options) + + if not isinstance(basis, StateBase): + raise NotImplementedError("Can't form this representation!") + + if "index" not in options: + options["index"] = 1 + + basis_kets = enumerate_states(basis, options["index"], 2) + + if isinstance(expr, BraBase): + bra = expr + ket = (basis_kets[1] if basis_kets[0].dual == expr else basis_kets[0]) + else: + bra = (basis_kets[1].dual if basis_kets[0] + == expr else basis_kets[0].dual) + ket = expr + + prod = InnerProduct(bra, ket) + result = prod.doit() + + format = options.get('format', 'sympy') + return expr._format_represent(result, format) + + +def rep_expectation(expr, **options): + """ + Returns an ```` type representation for the given operator. + + Parameters + ========== + + expr : Operator + Operator to be represented in the specified basis + + Examples + ======== + + >>> from sympy.physics.quantum.cartesian import XOp, PxOp, PxKet + >>> from sympy.physics.quantum.represent import rep_expectation + >>> rep_expectation(XOp()) + x_1*DiracDelta(x_1 - x_2) + >>> rep_expectation(XOp(), basis=PxOp()) + + >>> rep_expectation(XOp(), basis=PxKet()) + + + """ + + if "index" not in options: + options["index"] = 1 + + if not isinstance(expr, Operator): + raise TypeError("The passed expression is not an operator") + + basis_state = get_basis(expr, **options) + + if basis_state is None or not isinstance(basis_state, StateBase): + raise NotImplementedError("Could not get basis kets for this operator") + + basis_kets = enumerate_states(basis_state, options["index"], 2) + + bra = basis_kets[1].dual + ket = basis_kets[0] + + return qapply(bra*expr*ket) + + +def integrate_result(orig_expr, result, **options): + """ + Returns the result of integrating over any unities ``(|x>>> from sympy import symbols, DiracDelta + >>> from sympy.physics.quantum.represent import integrate_result + >>> from sympy.physics.quantum.cartesian import XOp, XKet + >>> x_ket = XKet() + >>> X_op = XOp() + >>> x, x_1, x_2 = symbols('x, x_1, x_2') + >>> integrate_result(X_op*x_ket, x*DiracDelta(x-x_1)*DiracDelta(x_1-x_2)) + x*DiracDelta(x - x_1)*DiracDelta(x_1 - x_2) + >>> integrate_result(X_op*x_ket, x*DiracDelta(x-x_1)*DiracDelta(x_1-x_2), + ... unities=[1]) + x*DiracDelta(x - x_2) + + """ + if not isinstance(result, Expr): + return result + + options['replace_none'] = True + if "basis" not in options: + arg = orig_expr.args[-1] + options["basis"] = get_basis(arg, **options) + elif not isinstance(options["basis"], StateBase): + options["basis"] = get_basis(orig_expr, **options) + + basis = options.pop("basis", None) + + if basis is None: + return result + + unities = options.pop("unities", []) + + if len(unities) == 0: + return result + + kets = enumerate_states(basis, unities) + coords = [k.label[0] for k in kets] + + for coord in coords: + if coord in result.free_symbols: + #TODO: Add support for sets of operators + basis_op = state_to_operators(basis) + start = basis_op.hilbert_space.interval.start + end = basis_op.hilbert_space.interval.end + result = integrate(result, (coord, start, end)) + + return result + + +def get_basis(expr, *, basis=None, replace_none=True, **options): + """ + Returns a basis state instance corresponding to the basis specified in + options=s. If no basis is specified, the function tries to form a default + basis state of the given expression. + + There are three behaviors: + + 1. The basis specified in options is already an instance of StateBase. If + this is the case, it is simply returned. If the class is specified but + not an instance, a default instance is returned. + + 2. The basis specified is an operator or set of operators. If this + is the case, the operator_to_state mapping method is used. + + 3. No basis is specified. If expr is a state, then a default instance of + its class is returned. If expr is an operator, then it is mapped to the + corresponding state. If it is neither, then we cannot obtain the basis + state. + + If the basis cannot be mapped, then it is not changed. + + This will be called from within represent, and represent will + only pass QExpr's. + + TODO (?): Support for Muls and other types of expressions? + + Parameters + ========== + + expr : Operator or StateBase + Expression whose basis is sought + + Examples + ======== + + >>> from sympy.physics.quantum.represent import get_basis + >>> from sympy.physics.quantum.cartesian import XOp, XKet, PxOp, PxKet + >>> x = XKet() + >>> X = XOp() + >>> get_basis(x) + |x> + >>> get_basis(X) + |x> + >>> get_basis(x, basis=PxOp()) + |px> + >>> get_basis(x, basis=PxKet) + |px> + + """ + + if basis is None and not replace_none: + return None + + if basis is None: + if isinstance(expr, KetBase): + return _make_default(expr.__class__) + elif isinstance(expr, BraBase): + return _make_default(expr.dual_class()) + elif isinstance(expr, Operator): + state_inst = operators_to_state(expr) + return (state_inst if state_inst is not None else None) + else: + return None + elif (isinstance(basis, Operator) or + (not isinstance(basis, StateBase) and issubclass(basis, Operator))): + state = operators_to_state(basis) + if state is None: + return None + elif isinstance(state, StateBase): + return state + else: + return _make_default(state) + elif isinstance(basis, StateBase): + return basis + elif issubclass(basis, StateBase): + return _make_default(basis) + else: + return None + + +def _make_default(expr): + # XXX: Catching TypeError like this is a bad way of distinguishing + # instances from classes. The logic using this function should be + # rewritten somehow. + try: + expr = expr() + except TypeError: + return expr + + return expr + + +def enumerate_states(*args, **options): + """ + Returns instances of the given state with dummy indices appended + + Operates in two different modes: + + 1. Two arguments are passed to it. The first is the base state which is to + be indexed, and the second argument is a list of indices to append. + + 2. Three arguments are passed. The first is again the base state to be + indexed. The second is the start index for counting. The final argument + is the number of kets you wish to receive. + + Tries to call state._enumerate_state. If this fails, returns an empty list + + Parameters + ========== + + args : list + See list of operation modes above for explanation + + Examples + ======== + + >>> from sympy.physics.quantum.cartesian import XBra, XKet + >>> from sympy.physics.quantum.represent import enumerate_states + >>> test = XKet('foo') + >>> enumerate_states(test, 1, 3) + [|foo_1>, |foo_2>, |foo_3>] + >>> test2 = XBra('bar') + >>> enumerate_states(test2, [4, 5, 10]) + [>> from sympy.physics.quantum.sho1d import RaisingOp + >>> from sympy.physics.quantum import Dagger + + >>> ad = RaisingOp('a') + >>> ad.rewrite('xp').doit() + sqrt(2)*(m*omega*X - I*Px)/(2*sqrt(hbar)*sqrt(m*omega)) + + >>> Dagger(ad) + a + + Taking the commutator of a^dagger with other Operators: + + >>> from sympy.physics.quantum import Commutator + >>> from sympy.physics.quantum.sho1d import RaisingOp, LoweringOp + >>> from sympy.physics.quantum.sho1d import NumberOp + + >>> ad = RaisingOp('a') + >>> a = LoweringOp('a') + >>> N = NumberOp('N') + >>> Commutator(ad, a).doit() + -1 + >>> Commutator(ad, N).doit() + -RaisingOp(a) + + Apply a^dagger to a state: + + >>> from sympy.physics.quantum import qapply + >>> from sympy.physics.quantum.sho1d import RaisingOp, SHOKet + + >>> ad = RaisingOp('a') + >>> k = SHOKet('k') + >>> qapply(ad*k) + sqrt(k + 1)*|k + 1> + + Matrix Representation + + >>> from sympy.physics.quantum.sho1d import RaisingOp + >>> from sympy.physics.quantum.represent import represent + >>> ad = RaisingOp('a') + >>> represent(ad, basis=N, ndim=4, format='sympy') + Matrix([ + [0, 0, 0, 0], + [1, 0, 0, 0], + [0, sqrt(2), 0, 0], + [0, 0, sqrt(3), 0]]) + + """ + + def _eval_rewrite_as_xp(self, *args, **kwargs): + return (S.One/sqrt(Integer(2)*hbar*m*omega))*( + S.NegativeOne*I*Px + m*omega*X) + + def _eval_adjoint(self): + return LoweringOp(*self.args) + + def _eval_commutator_LoweringOp(self, other): + return S.NegativeOne + + def _eval_commutator_NumberOp(self, other): + return S.NegativeOne*self + + def _apply_operator_SHOKet(self, ket, **options): + temp = ket.n + S.One + return sqrt(temp)*SHOKet(temp) + + def _represent_default_basis(self, **options): + return self._represent_NumberOp(None, **options) + + def _represent_XOp(self, basis, **options): + # This logic is good but the underlying position + # representation logic is broken. + # temp = self.rewrite('xp').doit() + # result = represent(temp, basis=X) + # return result + raise NotImplementedError('Position representation is not implemented') + + def _represent_NumberOp(self, basis, **options): + ndim_info = options.get('ndim', 4) + format = options.get('format','sympy') + matrix = matrix_zeros(ndim_info, ndim_info, **options) + for i in range(ndim_info - 1): + value = sqrt(i + 1) + if format == 'scipy.sparse': + value = float(value) + matrix[i + 1, i] = value + if format == 'scipy.sparse': + matrix = matrix.tocsr() + return matrix + + #-------------------------------------------------------------------------- + # Printing Methods + #-------------------------------------------------------------------------- + + def _print_contents(self, printer, *args): + arg0 = printer._print(self.args[0], *args) + return '%s(%s)' % (self.__class__.__name__, arg0) + + def _print_contents_pretty(self, printer, *args): + from sympy.printing.pretty.stringpict import prettyForm + pform = printer._print(self.args[0], *args) + pform = pform**prettyForm('\N{DAGGER}') + return pform + + def _print_contents_latex(self, printer, *args): + arg = printer._print(self.args[0]) + return '%s^{\\dagger}' % arg + +class LoweringOp(SHOOp): + """The Lowering Operator or 'a'. + + When 'a' acts on a state it lowers the state up by one. Taking + the adjoint of 'a' returns a^dagger, the Raising Operator. 'a' + can be rewritten in terms of position and momentum. We can + represent 'a' as a matrix, which will be its default basis. + + Parameters + ========== + + args : tuple + The list of numbers or parameters that uniquely specify the + operator. + + Examples + ======== + + Create a Lowering Operator and rewrite it in terms of position and + momentum, and show that taking its adjoint returns a^dagger: + + >>> from sympy.physics.quantum.sho1d import LoweringOp + >>> from sympy.physics.quantum import Dagger + + >>> a = LoweringOp('a') + >>> a.rewrite('xp').doit() + sqrt(2)*(m*omega*X + I*Px)/(2*sqrt(hbar)*sqrt(m*omega)) + + >>> Dagger(a) + RaisingOp(a) + + Taking the commutator of 'a' with other Operators: + + >>> from sympy.physics.quantum import Commutator + >>> from sympy.physics.quantum.sho1d import LoweringOp, RaisingOp + >>> from sympy.physics.quantum.sho1d import NumberOp + + >>> a = LoweringOp('a') + >>> ad = RaisingOp('a') + >>> N = NumberOp('N') + >>> Commutator(a, ad).doit() + 1 + >>> Commutator(a, N).doit() + a + + Apply 'a' to a state: + + >>> from sympy.physics.quantum import qapply + >>> from sympy.physics.quantum.sho1d import LoweringOp, SHOKet + + >>> a = LoweringOp('a') + >>> k = SHOKet('k') + >>> qapply(a*k) + sqrt(k)*|k - 1> + + Taking 'a' of the lowest state will return 0: + + >>> from sympy.physics.quantum import qapply + >>> from sympy.physics.quantum.sho1d import LoweringOp, SHOKet + + >>> a = LoweringOp('a') + >>> k = SHOKet(0) + >>> qapply(a*k) + 0 + + Matrix Representation + + >>> from sympy.physics.quantum.sho1d import LoweringOp + >>> from sympy.physics.quantum.represent import represent + >>> a = LoweringOp('a') + >>> represent(a, basis=N, ndim=4, format='sympy') + Matrix([ + [0, 1, 0, 0], + [0, 0, sqrt(2), 0], + [0, 0, 0, sqrt(3)], + [0, 0, 0, 0]]) + + """ + + def _eval_rewrite_as_xp(self, *args, **kwargs): + return (S.One/sqrt(Integer(2)*hbar*m*omega))*( + I*Px + m*omega*X) + + def _eval_adjoint(self): + return RaisingOp(*self.args) + + def _eval_commutator_RaisingOp(self, other): + return S.One + + def _eval_commutator_NumberOp(self, other): + return self + + def _apply_operator_SHOKet(self, ket, **options): + temp = ket.n - Integer(1) + if ket.n is S.Zero: + return S.Zero + else: + return sqrt(ket.n)*SHOKet(temp) + + def _represent_default_basis(self, **options): + return self._represent_NumberOp(None, **options) + + def _represent_XOp(self, basis, **options): + # This logic is good but the underlying position + # representation logic is broken. + # temp = self.rewrite('xp').doit() + # result = represent(temp, basis=X) + # return result + raise NotImplementedError('Position representation is not implemented') + + def _represent_NumberOp(self, basis, **options): + ndim_info = options.get('ndim', 4) + format = options.get('format', 'sympy') + matrix = matrix_zeros(ndim_info, ndim_info, **options) + for i in range(ndim_info - 1): + value = sqrt(i + 1) + if format == 'scipy.sparse': + value = float(value) + matrix[i,i + 1] = value + if format == 'scipy.sparse': + matrix = matrix.tocsr() + return matrix + + +class NumberOp(SHOOp): + """The Number Operator is simply a^dagger*a + + It is often useful to write a^dagger*a as simply the Number Operator + because the Number Operator commutes with the Hamiltonian. And can be + expressed using the Number Operator. Also the Number Operator can be + applied to states. We can represent the Number Operator as a matrix, + which will be its default basis. + + Parameters + ========== + + args : tuple + The list of numbers or parameters that uniquely specify the + operator. + + Examples + ======== + + Create a Number Operator and rewrite it in terms of the ladder + operators, position and momentum operators, and Hamiltonian: + + >>> from sympy.physics.quantum.sho1d import NumberOp + + >>> N = NumberOp('N') + >>> N.rewrite('a').doit() + RaisingOp(a)*a + >>> N.rewrite('xp').doit() + -1/2 + (m**2*omega**2*X**2 + Px**2)/(2*hbar*m*omega) + >>> N.rewrite('H').doit() + -1/2 + H/(hbar*omega) + + Take the Commutator of the Number Operator with other Operators: + + >>> from sympy.physics.quantum import Commutator + >>> from sympy.physics.quantum.sho1d import NumberOp, Hamiltonian + >>> from sympy.physics.quantum.sho1d import RaisingOp, LoweringOp + + >>> N = NumberOp('N') + >>> H = Hamiltonian('H') + >>> ad = RaisingOp('a') + >>> a = LoweringOp('a') + >>> Commutator(N,H).doit() + 0 + >>> Commutator(N,ad).doit() + RaisingOp(a) + >>> Commutator(N,a).doit() + -a + + Apply the Number Operator to a state: + + >>> from sympy.physics.quantum import qapply + >>> from sympy.physics.quantum.sho1d import NumberOp, SHOKet + + >>> N = NumberOp('N') + >>> k = SHOKet('k') + >>> qapply(N*k) + k*|k> + + Matrix Representation + + >>> from sympy.physics.quantum.sho1d import NumberOp + >>> from sympy.physics.quantum.represent import represent + >>> N = NumberOp('N') + >>> represent(N, basis=N, ndim=4, format='sympy') + Matrix([ + [0, 0, 0, 0], + [0, 1, 0, 0], + [0, 0, 2, 0], + [0, 0, 0, 3]]) + + """ + + def _eval_rewrite_as_a(self, *args, **kwargs): + return ad*a + + def _eval_rewrite_as_xp(self, *args, **kwargs): + return (S.One/(Integer(2)*m*hbar*omega))*(Px**2 + ( + m*omega*X)**2) - S.Half + + def _eval_rewrite_as_H(self, *args, **kwargs): + return H/(hbar*omega) - S.Half + + def _apply_operator_SHOKet(self, ket, **options): + return ket.n*ket + + def _eval_commutator_Hamiltonian(self, other): + return S.Zero + + def _eval_commutator_RaisingOp(self, other): + return other + + def _eval_commutator_LoweringOp(self, other): + return S.NegativeOne*other + + def _represent_default_basis(self, **options): + return self._represent_NumberOp(None, **options) + + def _represent_XOp(self, basis, **options): + # This logic is good but the underlying position + # representation logic is broken. + # temp = self.rewrite('xp').doit() + # result = represent(temp, basis=X) + # return result + raise NotImplementedError('Position representation is not implemented') + + def _represent_NumberOp(self, basis, **options): + ndim_info = options.get('ndim', 4) + format = options.get('format', 'sympy') + matrix = matrix_zeros(ndim_info, ndim_info, **options) + for i in range(ndim_info): + value = i + if format == 'scipy.sparse': + value = float(value) + matrix[i,i] = value + if format == 'scipy.sparse': + matrix = matrix.tocsr() + return matrix + + +class Hamiltonian(SHOOp): + """The Hamiltonian Operator. + + The Hamiltonian is used to solve the time-independent Schrodinger + equation. The Hamiltonian can be expressed using the ladder operators, + as well as by position and momentum. We can represent the Hamiltonian + Operator as a matrix, which will be its default basis. + + Parameters + ========== + + args : tuple + The list of numbers or parameters that uniquely specify the + operator. + + Examples + ======== + + Create a Hamiltonian Operator and rewrite it in terms of the ladder + operators, position and momentum, and the Number Operator: + + >>> from sympy.physics.quantum.sho1d import Hamiltonian + + >>> H = Hamiltonian('H') + >>> H.rewrite('a').doit() + hbar*omega*(1/2 + RaisingOp(a)*a) + >>> H.rewrite('xp').doit() + (m**2*omega**2*X**2 + Px**2)/(2*m) + >>> H.rewrite('N').doit() + hbar*omega*(1/2 + N) + + Take the Commutator of the Hamiltonian and the Number Operator: + + >>> from sympy.physics.quantum import Commutator + >>> from sympy.physics.quantum.sho1d import Hamiltonian, NumberOp + + >>> H = Hamiltonian('H') + >>> N = NumberOp('N') + >>> Commutator(H,N).doit() + 0 + + Apply the Hamiltonian Operator to a state: + + >>> from sympy.physics.quantum import qapply + >>> from sympy.physics.quantum.sho1d import Hamiltonian, SHOKet + + >>> H = Hamiltonian('H') + >>> k = SHOKet('k') + >>> qapply(H*k) + hbar*k*omega*|k> + hbar*omega*|k>/2 + + Matrix Representation + + >>> from sympy.physics.quantum.sho1d import Hamiltonian + >>> from sympy.physics.quantum.represent import represent + + >>> H = Hamiltonian('H') + >>> represent(H, basis=N, ndim=4, format='sympy') + Matrix([ + [hbar*omega/2, 0, 0, 0], + [ 0, 3*hbar*omega/2, 0, 0], + [ 0, 0, 5*hbar*omega/2, 0], + [ 0, 0, 0, 7*hbar*omega/2]]) + + """ + + def _eval_rewrite_as_a(self, *args, **kwargs): + return hbar*omega*(ad*a + S.Half) + + def _eval_rewrite_as_xp(self, *args, **kwargs): + return (S.One/(Integer(2)*m))*(Px**2 + (m*omega*X)**2) + + def _eval_rewrite_as_N(self, *args, **kwargs): + return hbar*omega*(N + S.Half) + + def _apply_operator_SHOKet(self, ket, **options): + return (hbar*omega*(ket.n + S.Half))*ket + + def _eval_commutator_NumberOp(self, other): + return S.Zero + + def _represent_default_basis(self, **options): + return self._represent_NumberOp(None, **options) + + def _represent_XOp(self, basis, **options): + # This logic is good but the underlying position + # representation logic is broken. + # temp = self.rewrite('xp').doit() + # result = represent(temp, basis=X) + # return result + raise NotImplementedError('Position representation is not implemented') + + def _represent_NumberOp(self, basis, **options): + ndim_info = options.get('ndim', 4) + format = options.get('format', 'sympy') + matrix = matrix_zeros(ndim_info, ndim_info, **options) + for i in range(ndim_info): + value = i + S.Half + if format == 'scipy.sparse': + value = float(value) + matrix[i,i] = value + if format == 'scipy.sparse': + matrix = matrix.tocsr() + return hbar*omega*matrix + +#------------------------------------------------------------------------------ + +class SHOState(State): + """State class for SHO states""" + + @classmethod + def _eval_hilbert_space(cls, label): + return ComplexSpace(S.Infinity) + + @property + def n(self): + return self.args[0] + + +class SHOKet(SHOState, Ket): + """1D eigenket. + + Inherits from SHOState and Ket. + + Parameters + ========== + + args : tuple + The list of numbers or parameters that uniquely specify the ket + This is usually its quantum numbers or its symbol. + + Examples + ======== + + Ket's know about their associated bra: + + >>> from sympy.physics.quantum.sho1d import SHOKet + + >>> k = SHOKet('k') + >>> k.dual + >> k.dual_class() + + + Take the Inner Product with a bra: + + >>> from sympy.physics.quantum import InnerProduct + >>> from sympy.physics.quantum.sho1d import SHOKet, SHOBra + + >>> k = SHOKet('k') + >>> b = SHOBra('b') + >>> InnerProduct(b,k).doit() + KroneckerDelta(b, k) + + Vector representation of a numerical state ket: + + >>> from sympy.physics.quantum.sho1d import SHOKet, NumberOp + >>> from sympy.physics.quantum.represent import represent + + >>> k = SHOKet(3) + >>> N = NumberOp('N') + >>> represent(k, basis=N, ndim=4) + Matrix([ + [0], + [0], + [0], + [1]]) + + """ + + @classmethod + def dual_class(self): + return SHOBra + + def _eval_innerproduct_SHOBra(self, bra, **hints): + result = KroneckerDelta(self.n, bra.n) + return result + + def _represent_default_basis(self, **options): + return self._represent_NumberOp(None, **options) + + def _represent_NumberOp(self, basis, **options): + ndim_info = options.get('ndim', 4) + format = options.get('format', 'sympy') + options['spmatrix'] = 'lil' + vector = matrix_zeros(ndim_info, 1, **options) + if isinstance(self.n, Integer): + if self.n >= ndim_info: + return ValueError("N-Dimension too small") + if format == 'scipy.sparse': + vector[int(self.n), 0] = 1.0 + vector = vector.tocsr() + elif format == 'numpy': + vector[int(self.n), 0] = 1.0 + else: + vector[self.n, 0] = S.One + return vector + else: + return ValueError("Not Numerical State") + + +class SHOBra(SHOState, Bra): + """A time-independent Bra in SHO. + + Inherits from SHOState and Bra. + + Parameters + ========== + + args : tuple + The list of numbers or parameters that uniquely specify the ket + This is usually its quantum numbers or its symbol. + + Examples + ======== + + Bra's know about their associated ket: + + >>> from sympy.physics.quantum.sho1d import SHOBra + + >>> b = SHOBra('b') + >>> b.dual + |b> + >>> b.dual_class() + + + Vector representation of a numerical state bra: + + >>> from sympy.physics.quantum.sho1d import SHOBra, NumberOp + >>> from sympy.physics.quantum.represent import represent + + >>> b = SHOBra(3) + >>> N = NumberOp('N') + >>> represent(b, basis=N, ndim=4) + Matrix([[0, 0, 0, 1]]) + + """ + + @classmethod + def dual_class(self): + return SHOKet + + def _represent_default_basis(self, **options): + return self._represent_NumberOp(None, **options) + + def _represent_NumberOp(self, basis, **options): + ndim_info = options.get('ndim', 4) + format = options.get('format', 'sympy') + options['spmatrix'] = 'lil' + vector = matrix_zeros(1, ndim_info, **options) + if isinstance(self.n, Integer): + if self.n >= ndim_info: + return ValueError("N-Dimension too small") + if format == 'scipy.sparse': + vector[0, int(self.n)] = 1.0 + vector = vector.tocsr() + elif format == 'numpy': + vector[0, int(self.n)] = 1.0 + else: + vector[0, self.n] = S.One + return vector + else: + return ValueError("Not Numerical State") + + +ad = RaisingOp('a') +a = LoweringOp('a') +H = Hamiltonian('H') +N = NumberOp('N') +omega = Symbol('omega') +m = Symbol('m') diff --git a/lib/python3.10/site-packages/sympy/physics/quantum/shor.py b/lib/python3.10/site-packages/sympy/physics/quantum/shor.py new file mode 100644 index 0000000000000000000000000000000000000000..fc9e55229d74634bdb82efc03c2d1649e088efb3 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/physics/quantum/shor.py @@ -0,0 +1,173 @@ +"""Shor's algorithm and helper functions. + +Todo: + +* Get the CMod gate working again using the new Gate API. +* Fix everything. +* Update docstrings and reformat. +""" + +import math +import random + +from sympy.core.mul import Mul +from sympy.core.singleton import S +from sympy.functions.elementary.exponential import log +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.core.intfunc import igcd +from sympy.ntheory import continued_fraction_periodic as continued_fraction +from sympy.utilities.iterables import variations + +from sympy.physics.quantum.gate import Gate +from sympy.physics.quantum.qubit import Qubit, measure_partial_oneshot +from sympy.physics.quantum.qapply import qapply +from sympy.physics.quantum.qft import QFT +from sympy.physics.quantum.qexpr import QuantumError + + +class OrderFindingException(QuantumError): + pass + + +class CMod(Gate): + """A controlled mod gate. + + This is black box controlled Mod function for use by shor's algorithm. + TODO: implement a decompose property that returns how to do this in terms + of elementary gates + """ + + @classmethod + def _eval_args(cls, args): + # t = args[0] + # a = args[1] + # N = args[2] + raise NotImplementedError('The CMod gate has not been completed.') + + @property + def t(self): + """Size of 1/2 input register. First 1/2 holds output.""" + return self.label[0] + + @property + def a(self): + """Base of the controlled mod function.""" + return self.label[1] + + @property + def N(self): + """N is the type of modular arithmetic we are doing.""" + return self.label[2] + + def _apply_operator_Qubit(self, qubits, **options): + """ + This directly calculates the controlled mod of the second half of + the register and puts it in the second + This will look pretty when we get Tensor Symbolically working + """ + n = 1 + k = 0 + # Determine the value stored in high memory. + for i in range(self.t): + k += n*qubits[self.t + i] + n *= 2 + + # The value to go in low memory will be out. + out = int(self.a**k % self.N) + + # Create array for new qbit-ket which will have high memory unaffected + outarray = list(qubits.args[0][:self.t]) + + # Place out in low memory + for i in reversed(range(self.t)): + outarray.append((out >> i) & 1) + + return Qubit(*outarray) + + +def shor(N): + """This function implements Shor's factoring algorithm on the Integer N + + The algorithm starts by picking a random number (a) and seeing if it is + coprime with N. If it is not, then the gcd of the two numbers is a factor + and we are done. Otherwise, it begins the period_finding subroutine which + finds the period of a in modulo N arithmetic. This period, if even, can + be used to calculate factors by taking a**(r/2)-1 and a**(r/2)+1. + These values are returned. + """ + a = random.randrange(N - 2) + 2 + if igcd(N, a) != 1: + return igcd(N, a) + r = period_find(a, N) + if r % 2 == 1: + shor(N) + answer = (igcd(a**(r/2) - 1, N), igcd(a**(r/2) + 1, N)) + return answer + + +def getr(x, y, N): + fraction = continued_fraction(x, y) + # Now convert into r + total = ratioize(fraction, N) + return total + + +def ratioize(list, N): + if list[0] > N: + return S.Zero + if len(list) == 1: + return list[0] + return list[0] + ratioize(list[1:], N) + + +def period_find(a, N): + """Finds the period of a in modulo N arithmetic + + This is quantum part of Shor's algorithm. It takes two registers, + puts first in superposition of states with Hadamards so: ``|k>|0>`` + with k being all possible choices. It then does a controlled mod and + a QFT to determine the order of a. + """ + epsilon = .5 + # picks out t's such that maintains accuracy within epsilon + t = int(2*math.ceil(log(N, 2))) + # make the first half of register be 0's |000...000> + start = [0 for x in range(t)] + # Put second half into superposition of states so we have |1>x|0> + |2>x|0> + ... |k>x>|0> + ... + |2**n-1>x|0> + factor = 1/sqrt(2**t) + qubits = 0 + for arr in variations(range(2), t, repetition=True): + qbitArray = list(arr) + start + qubits = qubits + Qubit(*qbitArray) + circuit = (factor*qubits).expand() + # Controlled second half of register so that we have: + # |1>x|a**1 %N> + |2>x|a**2 %N> + ... + |k>x|a**k %N >+ ... + |2**n-1=k>x|a**k % n> + circuit = CMod(t, a, N)*circuit + # will measure first half of register giving one of the a**k%N's + + circuit = qapply(circuit) + for i in range(t): + circuit = measure_partial_oneshot(circuit, i) + # Now apply Inverse Quantum Fourier Transform on the second half of the register + + circuit = qapply(QFT(t, t*2).decompose()*circuit, floatingPoint=True) + for i in range(t): + circuit = measure_partial_oneshot(circuit, i + t) + if isinstance(circuit, Qubit): + register = circuit + elif isinstance(circuit, Mul): + register = circuit.args[-1] + else: + register = circuit.args[-1].args[-1] + + n = 1 + answer = 0 + for i in range(len(register)/2): + answer += n*register[i + t] + n = n << 1 + if answer == 0: + raise OrderFindingException( + "Order finder returned 0. Happens with chance %f" % epsilon) + #turn answer into r using continued fractions + g = getr(answer, 2**t, N) + return g diff --git a/lib/python3.10/site-packages/sympy/physics/quantum/spin.py b/lib/python3.10/site-packages/sympy/physics/quantum/spin.py new file mode 100644 index 0000000000000000000000000000000000000000..6c568d36c57be38702b770f6fa95f4dc6a00ed15 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/physics/quantum/spin.py @@ -0,0 +1,2150 @@ +"""Quantum mechanical angular momemtum.""" + +from sympy.concrete.summations import Sum +from sympy.core.add import Add +from sympy.core.containers import Tuple +from sympy.core.expr import Expr +from sympy.core.numbers import int_valued +from sympy.core.mul import Mul +from sympy.core.numbers import (I, Integer, Rational, pi) +from sympy.core.singleton import S +from sympy.core.symbol import (Dummy, symbols) +from sympy.core.sympify import sympify +from sympy.functions.combinatorial.factorials import (binomial, factorial) +from sympy.functions.elementary.exponential import exp +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import (cos, sin) +from sympy.simplify.simplify import simplify +from sympy.matrices import zeros +from sympy.printing.pretty.stringpict import prettyForm, stringPict +from sympy.printing.pretty.pretty_symbology import pretty_symbol + +from sympy.physics.quantum.qexpr import QExpr +from sympy.physics.quantum.operator import (HermitianOperator, Operator, + UnitaryOperator) +from sympy.physics.quantum.state import Bra, Ket, State +from sympy.functions.special.tensor_functions import KroneckerDelta +from sympy.physics.quantum.constants import hbar +from sympy.physics.quantum.hilbert import ComplexSpace, DirectSumHilbertSpace +from sympy.physics.quantum.tensorproduct import TensorProduct +from sympy.physics.quantum.cg import CG +from sympy.physics.quantum.qapply import qapply + + +__all__ = [ + 'm_values', + 'Jplus', + 'Jminus', + 'Jx', + 'Jy', + 'Jz', + 'J2', + 'Rotation', + 'WignerD', + 'JxKet', + 'JxBra', + 'JyKet', + 'JyBra', + 'JzKet', + 'JzBra', + 'JzOp', + 'J2Op', + 'JxKetCoupled', + 'JxBraCoupled', + 'JyKetCoupled', + 'JyBraCoupled', + 'JzKetCoupled', + 'JzBraCoupled', + 'couple', + 'uncouple' +] + + +def m_values(j): + j = sympify(j) + size = 2*j + 1 + if not size.is_Integer or not size > 0: + raise ValueError( + 'Only integer or half-integer values allowed for j, got: : %r' % j + ) + return size, [j - i for i in range(int(2*j + 1))] + + +#----------------------------------------------------------------------------- +# Spin Operators +#----------------------------------------------------------------------------- + + +class SpinOpBase: + """Base class for spin operators.""" + + @classmethod + def _eval_hilbert_space(cls, label): + # We consider all j values so our space is infinite. + return ComplexSpace(S.Infinity) + + @property + def name(self): + return self.args[0] + + def _print_contents(self, printer, *args): + return '%s%s' % (self.name, self._coord) + + def _print_contents_pretty(self, printer, *args): + a = stringPict(str(self.name)) + b = stringPict(self._coord) + return self._print_subscript_pretty(a, b) + + def _print_contents_latex(self, printer, *args): + return r'%s_%s' % ((self.name, self._coord)) + + def _represent_base(self, basis, **options): + j = options.get('j', S.Half) + size, mvals = m_values(j) + result = zeros(size, size) + for p in range(size): + for q in range(size): + me = self.matrix_element(j, mvals[p], j, mvals[q]) + result[p, q] = me + return result + + def _apply_op(self, ket, orig_basis, **options): + state = ket.rewrite(self.basis) + # If the state has only one term + if isinstance(state, State): + ret = (hbar*state.m)*state + # state is a linear combination of states + elif isinstance(state, Sum): + ret = self._apply_operator_Sum(state, **options) + else: + ret = qapply(self*state) + if ret == self*state: + raise NotImplementedError + return ret.rewrite(orig_basis) + + def _apply_operator_JxKet(self, ket, **options): + return self._apply_op(ket, 'Jx', **options) + + def _apply_operator_JxKetCoupled(self, ket, **options): + return self._apply_op(ket, 'Jx', **options) + + def _apply_operator_JyKet(self, ket, **options): + return self._apply_op(ket, 'Jy', **options) + + def _apply_operator_JyKetCoupled(self, ket, **options): + return self._apply_op(ket, 'Jy', **options) + + def _apply_operator_JzKet(self, ket, **options): + return self._apply_op(ket, 'Jz', **options) + + def _apply_operator_JzKetCoupled(self, ket, **options): + return self._apply_op(ket, 'Jz', **options) + + def _apply_operator_TensorProduct(self, tp, **options): + # Uncoupling operator is only easily found for coordinate basis spin operators + # TODO: add methods for uncoupling operators + if not isinstance(self, (JxOp, JyOp, JzOp)): + raise NotImplementedError + result = [] + for n in range(len(tp.args)): + arg = [] + arg.extend(tp.args[:n]) + arg.append(self._apply_operator(tp.args[n])) + arg.extend(tp.args[n + 1:]) + result.append(tp.__class__(*arg)) + return Add(*result).expand() + + # TODO: move this to qapply_Mul + def _apply_operator_Sum(self, s, **options): + new_func = qapply(self*s.function) + if new_func == self*s.function: + raise NotImplementedError + return Sum(new_func, *s.limits) + + def _eval_trace(self, **options): + #TODO: use options to use different j values + #For now eval at default basis + + # is it efficient to represent each time + # to do a trace? + return self._represent_default_basis().trace() + + +class JplusOp(SpinOpBase, Operator): + """The J+ operator.""" + + _coord = '+' + + basis = 'Jz' + + def _eval_commutator_JminusOp(self, other): + return 2*hbar*JzOp(self.name) + + def _apply_operator_JzKet(self, ket, **options): + j = ket.j + m = ket.m + if m.is_Number and j.is_Number: + if m >= j: + return S.Zero + return hbar*sqrt(j*(j + S.One) - m*(m + S.One))*JzKet(j, m + S.One) + + def _apply_operator_JzKetCoupled(self, ket, **options): + j = ket.j + m = ket.m + jn = ket.jn + coupling = ket.coupling + if m.is_Number and j.is_Number: + if m >= j: + return S.Zero + return hbar*sqrt(j*(j + S.One) - m*(m + S.One))*JzKetCoupled(j, m + S.One, jn, coupling) + + def matrix_element(self, j, m, jp, mp): + result = hbar*sqrt(j*(j + S.One) - mp*(mp + S.One)) + result *= KroneckerDelta(m, mp + 1) + result *= KroneckerDelta(j, jp) + return result + + def _represent_default_basis(self, **options): + return self._represent_JzOp(None, **options) + + def _represent_JzOp(self, basis, **options): + return self._represent_base(basis, **options) + + def _eval_rewrite_as_xyz(self, *args, **kwargs): + return JxOp(args[0]) + I*JyOp(args[0]) + + +class JminusOp(SpinOpBase, Operator): + """The J- operator.""" + + _coord = '-' + + basis = 'Jz' + + def _apply_operator_JzKet(self, ket, **options): + j = ket.j + m = ket.m + if m.is_Number and j.is_Number: + if m <= -j: + return S.Zero + return hbar*sqrt(j*(j + S.One) - m*(m - S.One))*JzKet(j, m - S.One) + + def _apply_operator_JzKetCoupled(self, ket, **options): + j = ket.j + m = ket.m + jn = ket.jn + coupling = ket.coupling + if m.is_Number and j.is_Number: + if m <= -j: + return S.Zero + return hbar*sqrt(j*(j + S.One) - m*(m - S.One))*JzKetCoupled(j, m - S.One, jn, coupling) + + def matrix_element(self, j, m, jp, mp): + result = hbar*sqrt(j*(j + S.One) - mp*(mp - S.One)) + result *= KroneckerDelta(m, mp - 1) + result *= KroneckerDelta(j, jp) + return result + + def _represent_default_basis(self, **options): + return self._represent_JzOp(None, **options) + + def _represent_JzOp(self, basis, **options): + return self._represent_base(basis, **options) + + def _eval_rewrite_as_xyz(self, *args, **kwargs): + return JxOp(args[0]) - I*JyOp(args[0]) + + +class JxOp(SpinOpBase, HermitianOperator): + """The Jx operator.""" + + _coord = 'x' + + basis = 'Jx' + + def _eval_commutator_JyOp(self, other): + return I*hbar*JzOp(self.name) + + def _eval_commutator_JzOp(self, other): + return -I*hbar*JyOp(self.name) + + def _apply_operator_JzKet(self, ket, **options): + jp = JplusOp(self.name)._apply_operator_JzKet(ket, **options) + jm = JminusOp(self.name)._apply_operator_JzKet(ket, **options) + return (jp + jm)/Integer(2) + + def _apply_operator_JzKetCoupled(self, ket, **options): + jp = JplusOp(self.name)._apply_operator_JzKetCoupled(ket, **options) + jm = JminusOp(self.name)._apply_operator_JzKetCoupled(ket, **options) + return (jp + jm)/Integer(2) + + def _represent_default_basis(self, **options): + return self._represent_JzOp(None, **options) + + def _represent_JzOp(self, basis, **options): + jp = JplusOp(self.name)._represent_JzOp(basis, **options) + jm = JminusOp(self.name)._represent_JzOp(basis, **options) + return (jp + jm)/Integer(2) + + def _eval_rewrite_as_plusminus(self, *args, **kwargs): + return (JplusOp(args[0]) + JminusOp(args[0]))/2 + + +class JyOp(SpinOpBase, HermitianOperator): + """The Jy operator.""" + + _coord = 'y' + + basis = 'Jy' + + def _eval_commutator_JzOp(self, other): + return I*hbar*JxOp(self.name) + + def _eval_commutator_JxOp(self, other): + return -I*hbar*J2Op(self.name) + + def _apply_operator_JzKet(self, ket, **options): + jp = JplusOp(self.name)._apply_operator_JzKet(ket, **options) + jm = JminusOp(self.name)._apply_operator_JzKet(ket, **options) + return (jp - jm)/(Integer(2)*I) + + def _apply_operator_JzKetCoupled(self, ket, **options): + jp = JplusOp(self.name)._apply_operator_JzKetCoupled(ket, **options) + jm = JminusOp(self.name)._apply_operator_JzKetCoupled(ket, **options) + return (jp - jm)/(Integer(2)*I) + + def _represent_default_basis(self, **options): + return self._represent_JzOp(None, **options) + + def _represent_JzOp(self, basis, **options): + jp = JplusOp(self.name)._represent_JzOp(basis, **options) + jm = JminusOp(self.name)._represent_JzOp(basis, **options) + return (jp - jm)/(Integer(2)*I) + + def _eval_rewrite_as_plusminus(self, *args, **kwargs): + return (JplusOp(args[0]) - JminusOp(args[0]))/(2*I) + + +class JzOp(SpinOpBase, HermitianOperator): + """The Jz operator.""" + + _coord = 'z' + + basis = 'Jz' + + def _eval_commutator_JxOp(self, other): + return I*hbar*JyOp(self.name) + + def _eval_commutator_JyOp(self, other): + return -I*hbar*JxOp(self.name) + + def _eval_commutator_JplusOp(self, other): + return hbar*JplusOp(self.name) + + def _eval_commutator_JminusOp(self, other): + return -hbar*JminusOp(self.name) + + def matrix_element(self, j, m, jp, mp): + result = hbar*mp + result *= KroneckerDelta(m, mp) + result *= KroneckerDelta(j, jp) + return result + + def _represent_default_basis(self, **options): + return self._represent_JzOp(None, **options) + + def _represent_JzOp(self, basis, **options): + return self._represent_base(basis, **options) + + +class J2Op(SpinOpBase, HermitianOperator): + """The J^2 operator.""" + + _coord = '2' + + def _eval_commutator_JxOp(self, other): + return S.Zero + + def _eval_commutator_JyOp(self, other): + return S.Zero + + def _eval_commutator_JzOp(self, other): + return S.Zero + + def _eval_commutator_JplusOp(self, other): + return S.Zero + + def _eval_commutator_JminusOp(self, other): + return S.Zero + + def _apply_operator_JxKet(self, ket, **options): + j = ket.j + return hbar**2*j*(j + 1)*ket + + def _apply_operator_JxKetCoupled(self, ket, **options): + j = ket.j + return hbar**2*j*(j + 1)*ket + + def _apply_operator_JyKet(self, ket, **options): + j = ket.j + return hbar**2*j*(j + 1)*ket + + def _apply_operator_JyKetCoupled(self, ket, **options): + j = ket.j + return hbar**2*j*(j + 1)*ket + + def _apply_operator_JzKet(self, ket, **options): + j = ket.j + return hbar**2*j*(j + 1)*ket + + def _apply_operator_JzKetCoupled(self, ket, **options): + j = ket.j + return hbar**2*j*(j + 1)*ket + + def matrix_element(self, j, m, jp, mp): + result = (hbar**2)*j*(j + 1) + result *= KroneckerDelta(m, mp) + result *= KroneckerDelta(j, jp) + return result + + def _represent_default_basis(self, **options): + return self._represent_JzOp(None, **options) + + def _represent_JzOp(self, basis, **options): + return self._represent_base(basis, **options) + + def _print_contents_pretty(self, printer, *args): + a = prettyForm(str(self.name)) + b = prettyForm('2') + return a**b + + def _print_contents_latex(self, printer, *args): + return r'%s^2' % str(self.name) + + def _eval_rewrite_as_xyz(self, *args, **kwargs): + return JxOp(args[0])**2 + JyOp(args[0])**2 + JzOp(args[0])**2 + + def _eval_rewrite_as_plusminus(self, *args, **kwargs): + a = args[0] + return JzOp(a)**2 + \ + S.Half*(JplusOp(a)*JminusOp(a) + JminusOp(a)*JplusOp(a)) + + +class Rotation(UnitaryOperator): + """Wigner D operator in terms of Euler angles. + + Defines the rotation operator in terms of the Euler angles defined by + the z-y-z convention for a passive transformation. That is the coordinate + axes are rotated first about the z-axis, giving the new x'-y'-z' axes. Then + this new coordinate system is rotated about the new y'-axis, giving new + x''-y''-z'' axes. Then this new coordinate system is rotated about the + z''-axis. Conventions follow those laid out in [1]_. + + Parameters + ========== + + alpha : Number, Symbol + First Euler Angle + beta : Number, Symbol + Second Euler angle + gamma : Number, Symbol + Third Euler angle + + Examples + ======== + + A simple example rotation operator: + + >>> from sympy import pi + >>> from sympy.physics.quantum.spin import Rotation + >>> Rotation(pi, 0, pi/2) + R(pi,0,pi/2) + + With symbolic Euler angles and calculating the inverse rotation operator: + + >>> from sympy import symbols + >>> a, b, c = symbols('a b c') + >>> Rotation(a, b, c) + R(a,b,c) + >>> Rotation(a, b, c).inverse() + R(-c,-b,-a) + + See Also + ======== + + WignerD: Symbolic Wigner-D function + D: Wigner-D function + d: Wigner small-d function + + References + ========== + + .. [1] Varshalovich, D A, Quantum Theory of Angular Momentum. 1988. + """ + + @classmethod + def _eval_args(cls, args): + args = QExpr._eval_args(args) + if len(args) != 3: + raise ValueError('3 Euler angles required, got: %r' % args) + return args + + @classmethod + def _eval_hilbert_space(cls, label): + # We consider all j values so our space is infinite. + return ComplexSpace(S.Infinity) + + @property + def alpha(self): + return self.label[0] + + @property + def beta(self): + return self.label[1] + + @property + def gamma(self): + return self.label[2] + + def _print_operator_name(self, printer, *args): + return 'R' + + def _print_operator_name_pretty(self, printer, *args): + if printer._use_unicode: + return prettyForm('\N{SCRIPT CAPITAL R}' + ' ') + else: + return prettyForm("R ") + + def _print_operator_name_latex(self, printer, *args): + return r'\mathcal{R}' + + def _eval_inverse(self): + return Rotation(-self.gamma, -self.beta, -self.alpha) + + @classmethod + def D(cls, j, m, mp, alpha, beta, gamma): + """Wigner D-function. + + Returns an instance of the WignerD class corresponding to the Wigner-D + function specified by the parameters. + + Parameters + =========== + + j : Number + Total angular momentum + m : Number + Eigenvalue of angular momentum along axis after rotation + mp : Number + Eigenvalue of angular momentum along rotated axis + alpha : Number, Symbol + First Euler angle of rotation + beta : Number, Symbol + Second Euler angle of rotation + gamma : Number, Symbol + Third Euler angle of rotation + + Examples + ======== + + Return the Wigner-D matrix element for a defined rotation, both + numerical and symbolic: + + >>> from sympy.physics.quantum.spin import Rotation + >>> from sympy import pi, symbols + >>> alpha, beta, gamma = symbols('alpha beta gamma') + >>> Rotation.D(1, 1, 0,pi, pi/2,-pi) + WignerD(1, 1, 0, pi, pi/2, -pi) + + See Also + ======== + + WignerD: Symbolic Wigner-D function + + """ + return WignerD(j, m, mp, alpha, beta, gamma) + + @classmethod + def d(cls, j, m, mp, beta): + """Wigner small-d function. + + Returns an instance of the WignerD class corresponding to the Wigner-D + function specified by the parameters with the alpha and gamma angles + given as 0. + + Parameters + =========== + + j : Number + Total angular momentum + m : Number + Eigenvalue of angular momentum along axis after rotation + mp : Number + Eigenvalue of angular momentum along rotated axis + beta : Number, Symbol + Second Euler angle of rotation + + Examples + ======== + + Return the Wigner-D matrix element for a defined rotation, both + numerical and symbolic: + + >>> from sympy.physics.quantum.spin import Rotation + >>> from sympy import pi, symbols + >>> beta = symbols('beta') + >>> Rotation.d(1, 1, 0, pi/2) + WignerD(1, 1, 0, 0, pi/2, 0) + + See Also + ======== + + WignerD: Symbolic Wigner-D function + + """ + return WignerD(j, m, mp, 0, beta, 0) + + def matrix_element(self, j, m, jp, mp): + result = self.__class__.D( + jp, m, mp, self.alpha, self.beta, self.gamma + ) + result *= KroneckerDelta(j, jp) + return result + + def _represent_base(self, basis, **options): + j = sympify(options.get('j', S.Half)) + # TODO: move evaluation up to represent function/implement elsewhere + evaluate = sympify(options.get('doit')) + size, mvals = m_values(j) + result = zeros(size, size) + for p in range(size): + for q in range(size): + me = self.matrix_element(j, mvals[p], j, mvals[q]) + if evaluate: + result[p, q] = me.doit() + else: + result[p, q] = me + return result + + def _represent_default_basis(self, **options): + return self._represent_JzOp(None, **options) + + def _represent_JzOp(self, basis, **options): + return self._represent_base(basis, **options) + + def _apply_operator_uncoupled(self, state, ket, *, dummy=True, **options): + a = self.alpha + b = self.beta + g = self.gamma + j = ket.j + m = ket.m + if j.is_number: + s = [] + size = m_values(j) + sz = size[1] + for mp in sz: + r = Rotation.D(j, m, mp, a, b, g) + z = r.doit() + s.append(z*state(j, mp)) + return Add(*s) + else: + if dummy: + mp = Dummy('mp') + else: + mp = symbols('mp') + return Sum(Rotation.D(j, m, mp, a, b, g)*state(j, mp), (mp, -j, j)) + + def _apply_operator_JxKet(self, ket, **options): + return self._apply_operator_uncoupled(JxKet, ket, **options) + + def _apply_operator_JyKet(self, ket, **options): + return self._apply_operator_uncoupled(JyKet, ket, **options) + + def _apply_operator_JzKet(self, ket, **options): + return self._apply_operator_uncoupled(JzKet, ket, **options) + + def _apply_operator_coupled(self, state, ket, *, dummy=True, **options): + a = self.alpha + b = self.beta + g = self.gamma + j = ket.j + m = ket.m + jn = ket.jn + coupling = ket.coupling + if j.is_number: + s = [] + size = m_values(j) + sz = size[1] + for mp in sz: + r = Rotation.D(j, m, mp, a, b, g) + z = r.doit() + s.append(z*state(j, mp, jn, coupling)) + return Add(*s) + else: + if dummy: + mp = Dummy('mp') + else: + mp = symbols('mp') + return Sum(Rotation.D(j, m, mp, a, b, g)*state( + j, mp, jn, coupling), (mp, -j, j)) + + def _apply_operator_JxKetCoupled(self, ket, **options): + return self._apply_operator_coupled(JxKetCoupled, ket, **options) + + def _apply_operator_JyKetCoupled(self, ket, **options): + return self._apply_operator_coupled(JyKetCoupled, ket, **options) + + def _apply_operator_JzKetCoupled(self, ket, **options): + return self._apply_operator_coupled(JzKetCoupled, ket, **options) + +class WignerD(Expr): + r"""Wigner-D function + + The Wigner D-function gives the matrix elements of the rotation + operator in the jm-representation. For the Euler angles `\alpha`, + `\beta`, `\gamma`, the D-function is defined such that: + + .. math :: + = \delta_{jj'} D(j, m, m', \alpha, \beta, \gamma) + + Where the rotation operator is as defined by the Rotation class [1]_. + + The Wigner D-function defined in this way gives: + + .. math :: + D(j, m, m', \alpha, \beta, \gamma) = e^{-i m \alpha} d(j, m, m', \beta) e^{-i m' \gamma} + + Where d is the Wigner small-d function, which is given by Rotation.d. + + The Wigner small-d function gives the component of the Wigner + D-function that is determined by the second Euler angle. That is the + Wigner D-function is: + + .. math :: + D(j, m, m', \alpha, \beta, \gamma) = e^{-i m \alpha} d(j, m, m', \beta) e^{-i m' \gamma} + + Where d is the small-d function. The Wigner D-function is given by + Rotation.D. + + Note that to evaluate the D-function, the j, m and mp parameters must + be integer or half integer numbers. + + Parameters + ========== + + j : Number + Total angular momentum + m : Number + Eigenvalue of angular momentum along axis after rotation + mp : Number + Eigenvalue of angular momentum along rotated axis + alpha : Number, Symbol + First Euler angle of rotation + beta : Number, Symbol + Second Euler angle of rotation + gamma : Number, Symbol + Third Euler angle of rotation + + Examples + ======== + + Evaluate the Wigner-D matrix elements of a simple rotation: + + >>> from sympy.physics.quantum.spin import Rotation + >>> from sympy import pi + >>> rot = Rotation.D(1, 1, 0, pi, pi/2, 0) + >>> rot + WignerD(1, 1, 0, pi, pi/2, 0) + >>> rot.doit() + sqrt(2)/2 + + Evaluate the Wigner-d matrix elements of a simple rotation + + >>> rot = Rotation.d(1, 1, 0, pi/2) + >>> rot + WignerD(1, 1, 0, 0, pi/2, 0) + >>> rot.doit() + -sqrt(2)/2 + + See Also + ======== + + Rotation: Rotation operator + + References + ========== + + .. [1] Varshalovich, D A, Quantum Theory of Angular Momentum. 1988. + """ + + is_commutative = True + + def __new__(cls, *args, **hints): + if not len(args) == 6: + raise ValueError('6 parameters expected, got %s' % args) + args = sympify(args) + evaluate = hints.get('evaluate', False) + if evaluate: + return Expr.__new__(cls, *args)._eval_wignerd() + return Expr.__new__(cls, *args) + + @property + def j(self): + return self.args[0] + + @property + def m(self): + return self.args[1] + + @property + def mp(self): + return self.args[2] + + @property + def alpha(self): + return self.args[3] + + @property + def beta(self): + return self.args[4] + + @property + def gamma(self): + return self.args[5] + + def _latex(self, printer, *args): + if self.alpha == 0 and self.gamma == 0: + return r'd^{%s}_{%s,%s}\left(%s\right)' % \ + ( + printer._print(self.j), printer._print( + self.m), printer._print(self.mp), + printer._print(self.beta) ) + return r'D^{%s}_{%s,%s}\left(%s,%s,%s\right)' % \ + ( + printer._print( + self.j), printer._print(self.m), printer._print(self.mp), + printer._print(self.alpha), printer._print(self.beta), printer._print(self.gamma) ) + + def _pretty(self, printer, *args): + top = printer._print(self.j) + + bot = printer._print(self.m) + bot = prettyForm(*bot.right(',')) + bot = prettyForm(*bot.right(printer._print(self.mp))) + + pad = max(top.width(), bot.width()) + top = prettyForm(*top.left(' ')) + bot = prettyForm(*bot.left(' ')) + if pad > top.width(): + top = prettyForm(*top.right(' '*(pad - top.width()))) + if pad > bot.width(): + bot = prettyForm(*bot.right(' '*(pad - bot.width()))) + if self.alpha == 0 and self.gamma == 0: + args = printer._print(self.beta) + s = stringPict('d' + ' '*pad) + else: + args = printer._print(self.alpha) + args = prettyForm(*args.right(',')) + args = prettyForm(*args.right(printer._print(self.beta))) + args = prettyForm(*args.right(',')) + args = prettyForm(*args.right(printer._print(self.gamma))) + + s = stringPict('D' + ' '*pad) + + args = prettyForm(*args.parens()) + s = prettyForm(*s.above(top)) + s = prettyForm(*s.below(bot)) + s = prettyForm(*s.right(args)) + return s + + def doit(self, **hints): + hints['evaluate'] = True + return WignerD(*self.args, **hints) + + def _eval_wignerd(self): + j = self.j + m = self.m + mp = self.mp + alpha = self.alpha + beta = self.beta + gamma = self.gamma + if alpha == 0 and beta == 0 and gamma == 0: + return KroneckerDelta(m, mp) + if not j.is_number: + raise ValueError( + 'j parameter must be numerical to evaluate, got %s' % j) + r = 0 + if beta == pi/2: + # Varshalovich Equation (5), Section 4.16, page 113, setting + # alpha=gamma=0. + for k in range(2*j + 1): + if k > j + mp or k > j - m or k < mp - m: + continue + r += (S.NegativeOne)**k*binomial(j + mp, k)*binomial(j - mp, k + m - mp) + r *= (S.NegativeOne)**(m - mp) / 2**j*sqrt(factorial(j + m) * + factorial(j - m) / (factorial(j + mp)*factorial(j - mp))) + else: + # Varshalovich Equation(5), Section 4.7.2, page 87, where we set + # beta1=beta2=pi/2, and we get alpha=gamma=pi/2 and beta=phi+pi, + # then we use the Eq. (1), Section 4.4. page 79, to simplify: + # d(j, m, mp, beta+pi) = (-1)**(j-mp)*d(j, m, -mp, beta) + # This happens to be almost the same as in Eq.(10), Section 4.16, + # except that we need to substitute -mp for mp. + size, mvals = m_values(j) + for mpp in mvals: + r += Rotation.d(j, m, mpp, pi/2).doit()*(cos(-mpp*beta) + I*sin(-mpp*beta))*\ + Rotation.d(j, mpp, -mp, pi/2).doit() + # Empirical normalization factor so results match Varshalovich + # Tables 4.3-4.12 + # Note that this exact normalization does not follow from the + # above equations + r = r*I**(2*j - m - mp)*(-1)**(2*m) + # Finally, simplify the whole expression + r = simplify(r) + r *= exp(-I*m*alpha)*exp(-I*mp*gamma) + return r + + +Jx = JxOp('J') +Jy = JyOp('J') +Jz = JzOp('J') +J2 = J2Op('J') +Jplus = JplusOp('J') +Jminus = JminusOp('J') + + +#----------------------------------------------------------------------------- +# Spin States +#----------------------------------------------------------------------------- + + +class SpinState(State): + """Base class for angular momentum states.""" + + _label_separator = ',' + + def __new__(cls, j, m): + j = sympify(j) + m = sympify(m) + if j.is_number: + if 2*j != int(2*j): + raise ValueError( + 'j must be integer or half-integer, got: %s' % j) + if j < 0: + raise ValueError('j must be >= 0, got: %s' % j) + if m.is_number: + if 2*m != int(2*m): + raise ValueError( + 'm must be integer or half-integer, got: %s' % m) + if j.is_number and m.is_number: + if abs(m) > j: + raise ValueError('Allowed values for m are -j <= m <= j, got j, m: %s, %s' % (j, m)) + if int(j - m) != j - m: + raise ValueError('Both j and m must be integer or half-integer, got j, m: %s, %s' % (j, m)) + return State.__new__(cls, j, m) + + @property + def j(self): + return self.label[0] + + @property + def m(self): + return self.label[1] + + @classmethod + def _eval_hilbert_space(cls, label): + return ComplexSpace(2*label[0] + 1) + + def _represent_base(self, **options): + j = self.j + m = self.m + alpha = sympify(options.get('alpha', 0)) + beta = sympify(options.get('beta', 0)) + gamma = sympify(options.get('gamma', 0)) + size, mvals = m_values(j) + result = zeros(size, 1) + # breaks finding angles on L930 + for p, mval in enumerate(mvals): + if m.is_number: + result[p, 0] = Rotation.D( + self.j, mval, self.m, alpha, beta, gamma).doit() + else: + result[p, 0] = Rotation.D(self.j, mval, + self.m, alpha, beta, gamma) + return result + + def _eval_rewrite_as_Jx(self, *args, **options): + if isinstance(self, Bra): + return self._rewrite_basis(Jx, JxBra, **options) + return self._rewrite_basis(Jx, JxKet, **options) + + def _eval_rewrite_as_Jy(self, *args, **options): + if isinstance(self, Bra): + return self._rewrite_basis(Jy, JyBra, **options) + return self._rewrite_basis(Jy, JyKet, **options) + + def _eval_rewrite_as_Jz(self, *args, **options): + if isinstance(self, Bra): + return self._rewrite_basis(Jz, JzBra, **options) + return self._rewrite_basis(Jz, JzKet, **options) + + def _rewrite_basis(self, basis, evect, **options): + from sympy.physics.quantum.represent import represent + j = self.j + args = self.args[2:] + if j.is_number: + if isinstance(self, CoupledSpinState): + if j == int(j): + start = j**2 + else: + start = (2*j - 1)*(2*j + 1)/4 + else: + start = 0 + vect = represent(self, basis=basis, **options) + result = Add( + *[vect[start + i]*evect(j, j - i, *args) for i in range(2*j + 1)]) + if isinstance(self, CoupledSpinState) and options.get('coupled') is False: + return uncouple(result) + return result + else: + i = 0 + mi = symbols('mi') + # make sure not to introduce a symbol already in the state + while self.subs(mi, 0) != self: + i += 1 + mi = symbols('mi%d' % i) + break + # TODO: better way to get angles of rotation + if isinstance(self, CoupledSpinState): + test_args = (0, mi, (0, 0)) + else: + test_args = (0, mi) + if isinstance(self, Ket): + angles = represent( + self.__class__(*test_args), basis=basis)[0].args[3:6] + else: + angles = represent(self.__class__( + *test_args), basis=basis)[0].args[0].args[3:6] + if angles == (0, 0, 0): + return self + else: + state = evect(j, mi, *args) + lt = Rotation.D(j, mi, self.m, *angles) + return Sum(lt*state, (mi, -j, j)) + + def _eval_innerproduct_JxBra(self, bra, **hints): + result = KroneckerDelta(self.j, bra.j) + if bra.dual_class() is not self.__class__: + result *= self._represent_JxOp(None)[bra.j - bra.m] + else: + result *= KroneckerDelta( + self.j, bra.j)*KroneckerDelta(self.m, bra.m) + return result + + def _eval_innerproduct_JyBra(self, bra, **hints): + result = KroneckerDelta(self.j, bra.j) + if bra.dual_class() is not self.__class__: + result *= self._represent_JyOp(None)[bra.j - bra.m] + else: + result *= KroneckerDelta( + self.j, bra.j)*KroneckerDelta(self.m, bra.m) + return result + + def _eval_innerproduct_JzBra(self, bra, **hints): + result = KroneckerDelta(self.j, bra.j) + if bra.dual_class() is not self.__class__: + result *= self._represent_JzOp(None)[bra.j - bra.m] + else: + result *= KroneckerDelta( + self.j, bra.j)*KroneckerDelta(self.m, bra.m) + return result + + def _eval_trace(self, bra, **hints): + + # One way to implement this method is to assume the basis set k is + # passed. + # Then we can apply the discrete form of Trace formula here + # Tr(|i> + #then we do qapply() on each each inner product and sum over them. + + # OR + + # Inner product of |i>>> from sympy.physics.quantum.spin import JzKet, JxKet + >>> from sympy import symbols + >>> JzKet(1, 0) + |1,0> + >>> j, m = symbols('j m') + >>> JzKet(j, m) + |j,m> + + Rewriting the JzKet in terms of eigenkets of the Jx operator: + Note: that the resulting eigenstates are JxKet's + + >>> JzKet(1,1).rewrite("Jx") + |1,-1>/2 - sqrt(2)*|1,0>/2 + |1,1>/2 + + Get the vector representation of a state in terms of the basis elements + of the Jx operator: + + >>> from sympy.physics.quantum.represent import represent + >>> from sympy.physics.quantum.spin import Jx, Jz + >>> represent(JzKet(1,-1), basis=Jx) + Matrix([ + [ 1/2], + [sqrt(2)/2], + [ 1/2]]) + + Apply innerproducts between states: + + >>> from sympy.physics.quantum.innerproduct import InnerProduct + >>> from sympy.physics.quantum.spin import JxBra + >>> i = InnerProduct(JxBra(1,1), JzKet(1,1)) + >>> i + <1,1|1,1> + >>> i.doit() + 1/2 + + *Uncoupled States:* + + Define an uncoupled state as a TensorProduct between two Jz eigenkets: + + >>> from sympy.physics.quantum.tensorproduct import TensorProduct + >>> j1,m1,j2,m2 = symbols('j1 m1 j2 m2') + >>> TensorProduct(JzKet(1,0), JzKet(1,1)) + |1,0>x|1,1> + >>> TensorProduct(JzKet(j1,m1), JzKet(j2,m2)) + |j1,m1>x|j2,m2> + + A TensorProduct can be rewritten, in which case the eigenstates that make + up the tensor product is rewritten to the new basis: + + >>> TensorProduct(JzKet(1,1),JxKet(1,1)).rewrite('Jz') + |1,1>x|1,-1>/2 + sqrt(2)*|1,1>x|1,0>/2 + |1,1>x|1,1>/2 + + The represent method for TensorProduct's gives the vector representation of + the state. Note that the state in the product basis is the equivalent of the + tensor product of the vector representation of the component eigenstates: + + >>> represent(TensorProduct(JzKet(1,0),JzKet(1,1))) + Matrix([ + [0], + [0], + [0], + [1], + [0], + [0], + [0], + [0], + [0]]) + >>> represent(TensorProduct(JzKet(1,1),JxKet(1,1)), basis=Jz) + Matrix([ + [ 1/2], + [sqrt(2)/2], + [ 1/2], + [ 0], + [ 0], + [ 0], + [ 0], + [ 0], + [ 0]]) + + See Also + ======== + + JzKetCoupled: Coupled eigenstates + sympy.physics.quantum.tensorproduct.TensorProduct: Used to specify uncoupled states + uncouple: Uncouples states given coupling parameters + couple: Couples uncoupled states + + """ + + @classmethod + def dual_class(self): + return JzBra + + @classmethod + def coupled_class(self): + return JzKetCoupled + + def _represent_default_basis(self, **options): + return self._represent_JzOp(None, **options) + + def _represent_JxOp(self, basis, **options): + return self._represent_base(beta=pi*Rational(3, 2), **options) + + def _represent_JyOp(self, basis, **options): + return self._represent_base(alpha=pi*Rational(3, 2), beta=pi/2, gamma=pi/2, **options) + + def _represent_JzOp(self, basis, **options): + return self._represent_base(**options) + + +class JzBra(SpinState, Bra): + """Eigenbra of Jz. + + See the JzKet for the usage of spin eigenstates. + + See Also + ======== + + JzKet: Usage of spin states + + """ + + @classmethod + def dual_class(self): + return JzKet + + @classmethod + def coupled_class(self): + return JzBraCoupled + + +# Method used primarily to create coupled_n and coupled_jn by __new__ in +# CoupledSpinState +# This same method is also used by the uncouple method, and is separated from +# the CoupledSpinState class to maintain consistency in defining coupling +def _build_coupled(jcoupling, length): + n_list = [ [n + 1] for n in range(length) ] + coupled_jn = [] + coupled_n = [] + for n1, n2, j_new in jcoupling: + coupled_jn.append(j_new) + coupled_n.append( (n_list[n1 - 1], n_list[n2 - 1]) ) + n_sort = sorted(n_list[n1 - 1] + n_list[n2 - 1]) + n_list[n_sort[0] - 1] = n_sort + return coupled_n, coupled_jn + + +class CoupledSpinState(SpinState): + """Base class for coupled angular momentum states.""" + + def __new__(cls, j, m, jn, *jcoupling): + # Check j and m values using SpinState + SpinState(j, m) + # Build and check coupling scheme from arguments + if len(jcoupling) == 0: + # Use default coupling scheme + jcoupling = [] + for n in range(2, len(jn)): + jcoupling.append( (1, n, Add(*[jn[i] for i in range(n)])) ) + jcoupling.append( (1, len(jn), j) ) + elif len(jcoupling) == 1: + # Use specified coupling scheme + jcoupling = jcoupling[0] + else: + raise TypeError("CoupledSpinState only takes 3 or 4 arguments, got: %s" % (len(jcoupling) + 3) ) + # Check arguments have correct form + if not isinstance(jn, (list, tuple, Tuple)): + raise TypeError('jn must be Tuple, list or tuple, got %s' % + jn.__class__.__name__) + if not isinstance(jcoupling, (list, tuple, Tuple)): + raise TypeError('jcoupling must be Tuple, list or tuple, got %s' % + jcoupling.__class__.__name__) + if not all(isinstance(term, (list, tuple, Tuple)) for term in jcoupling): + raise TypeError( + 'All elements of jcoupling must be list, tuple or Tuple') + if not len(jn) - 1 == len(jcoupling): + raise ValueError('jcoupling must have length of %d, got %d' % + (len(jn) - 1, len(jcoupling))) + if not all(len(x) == 3 for x in jcoupling): + raise ValueError('All elements of jcoupling must have length 3') + # Build sympified args + j = sympify(j) + m = sympify(m) + jn = Tuple( *[sympify(ji) for ji in jn] ) + jcoupling = Tuple( *[Tuple(sympify( + n1), sympify(n2), sympify(ji)) for (n1, n2, ji) in jcoupling] ) + # Check values in coupling scheme give physical state + if any(2*ji != int(2*ji) for ji in jn if ji.is_number): + raise ValueError('All elements of jn must be integer or half-integer, got: %s' % jn) + if any(n1 != int(n1) or n2 != int(n2) for (n1, n2, _) in jcoupling): + raise ValueError('Indices in jcoupling must be integers') + if any(n1 < 1 or n2 < 1 or n1 > len(jn) or n2 > len(jn) for (n1, n2, _) in jcoupling): + raise ValueError('Indices must be between 1 and the number of coupled spin spaces') + if any(2*ji != int(2*ji) for (_, _, ji) in jcoupling if ji.is_number): + raise ValueError('All coupled j values in coupling scheme must be integer or half-integer') + coupled_n, coupled_jn = _build_coupled(jcoupling, len(jn)) + jvals = list(jn) + for n, (n1, n2) in enumerate(coupled_n): + j1 = jvals[min(n1) - 1] + j2 = jvals[min(n2) - 1] + j3 = coupled_jn[n] + if sympify(j1).is_number and sympify(j2).is_number and sympify(j3).is_number: + if j1 + j2 < j3: + raise ValueError('All couplings must have j1+j2 >= j3, ' + 'in coupling number %d got j1,j2,j3: %d,%d,%d' % (n + 1, j1, j2, j3)) + if abs(j1 - j2) > j3: + raise ValueError("All couplings must have |j1+j2| <= j3, " + "in coupling number %d got j1,j2,j3: %d,%d,%d" % (n + 1, j1, j2, j3)) + if int_valued(j1 + j2): + pass + jvals[min(n1 + n2) - 1] = j3 + if len(jcoupling) > 0 and jcoupling[-1][2] != j: + raise ValueError('Last j value coupled together must be the final j of the state') + # Return state + return State.__new__(cls, j, m, jn, jcoupling) + + def _print_label(self, printer, *args): + label = [printer._print(self.j), printer._print(self.m)] + for i, ji in enumerate(self.jn, start=1): + label.append('j%d=%s' % ( + i, printer._print(ji) + )) + for jn, (n1, n2) in zip(self.coupled_jn[:-1], self.coupled_n[:-1]): + label.append('j(%s)=%s' % ( + ','.join(str(i) for i in sorted(n1 + n2)), printer._print(jn) + )) + return ','.join(label) + + def _print_label_pretty(self, printer, *args): + label = [self.j, self.m] + for i, ji in enumerate(self.jn, start=1): + symb = 'j%d' % i + symb = pretty_symbol(symb) + symb = prettyForm(symb + '=') + item = prettyForm(*symb.right(printer._print(ji))) + label.append(item) + for jn, (n1, n2) in zip(self.coupled_jn[:-1], self.coupled_n[:-1]): + n = ','.join(pretty_symbol("j%d" % i)[-1] for i in sorted(n1 + n2)) + symb = prettyForm('j' + n + '=') + item = prettyForm(*symb.right(printer._print(jn))) + label.append(item) + return self._print_sequence_pretty( + label, self._label_separator, printer, *args + ) + + def _print_label_latex(self, printer, *args): + label = [ + printer._print(self.j, *args), + printer._print(self.m, *args) + ] + for i, ji in enumerate(self.jn, start=1): + label.append('j_{%d}=%s' % (i, printer._print(ji, *args)) ) + for jn, (n1, n2) in zip(self.coupled_jn[:-1], self.coupled_n[:-1]): + n = ','.join(str(i) for i in sorted(n1 + n2)) + label.append('j_{%s}=%s' % (n, printer._print(jn, *args)) ) + return self._label_separator.join(label) + + @property + def jn(self): + return self.label[2] + + @property + def coupling(self): + return self.label[3] + + @property + def coupled_jn(self): + return _build_coupled(self.label[3], len(self.label[2]))[1] + + @property + def coupled_n(self): + return _build_coupled(self.label[3], len(self.label[2]))[0] + + @classmethod + def _eval_hilbert_space(cls, label): + j = Add(*label[2]) + if j.is_number: + return DirectSumHilbertSpace(*[ ComplexSpace(x) for x in range(int(2*j + 1), 0, -2) ]) + else: + # TODO: Need hilbert space fix, see issue 5732 + # Desired behavior: + #ji = symbols('ji') + #ret = Sum(ComplexSpace(2*ji + 1), (ji, 0, j)) + # Temporary fix: + return ComplexSpace(2*j + 1) + + def _represent_coupled_base(self, **options): + evect = self.uncoupled_class() + if not self.j.is_number: + raise ValueError( + 'State must not have symbolic j value to represent') + if not self.hilbert_space.dimension.is_number: + raise ValueError( + 'State must not have symbolic j values to represent') + result = zeros(self.hilbert_space.dimension, 1) + if self.j == int(self.j): + start = self.j**2 + else: + start = (2*self.j - 1)*(1 + 2*self.j)/4 + result[start:start + 2*self.j + 1, 0] = evect( + self.j, self.m)._represent_base(**options) + return result + + def _eval_rewrite_as_Jx(self, *args, **options): + if isinstance(self, Bra): + return self._rewrite_basis(Jx, JxBraCoupled, **options) + return self._rewrite_basis(Jx, JxKetCoupled, **options) + + def _eval_rewrite_as_Jy(self, *args, **options): + if isinstance(self, Bra): + return self._rewrite_basis(Jy, JyBraCoupled, **options) + return self._rewrite_basis(Jy, JyKetCoupled, **options) + + def _eval_rewrite_as_Jz(self, *args, **options): + if isinstance(self, Bra): + return self._rewrite_basis(Jz, JzBraCoupled, **options) + return self._rewrite_basis(Jz, JzKetCoupled, **options) + + +class JxKetCoupled(CoupledSpinState, Ket): + """Coupled eigenket of Jx. + + See JzKetCoupled for the usage of coupled spin eigenstates. + + See Also + ======== + + JzKetCoupled: Usage of coupled spin states + + """ + + @classmethod + def dual_class(self): + return JxBraCoupled + + @classmethod + def uncoupled_class(self): + return JxKet + + def _represent_default_basis(self, **options): + return self._represent_JzOp(None, **options) + + def _represent_JxOp(self, basis, **options): + return self._represent_coupled_base(**options) + + def _represent_JyOp(self, basis, **options): + return self._represent_coupled_base(alpha=pi*Rational(3, 2), **options) + + def _represent_JzOp(self, basis, **options): + return self._represent_coupled_base(beta=pi/2, **options) + + +class JxBraCoupled(CoupledSpinState, Bra): + """Coupled eigenbra of Jx. + + See JzKetCoupled for the usage of coupled spin eigenstates. + + See Also + ======== + + JzKetCoupled: Usage of coupled spin states + + """ + + @classmethod + def dual_class(self): + return JxKetCoupled + + @classmethod + def uncoupled_class(self): + return JxBra + + +class JyKetCoupled(CoupledSpinState, Ket): + """Coupled eigenket of Jy. + + See JzKetCoupled for the usage of coupled spin eigenstates. + + See Also + ======== + + JzKetCoupled: Usage of coupled spin states + + """ + + @classmethod + def dual_class(self): + return JyBraCoupled + + @classmethod + def uncoupled_class(self): + return JyKet + + def _represent_default_basis(self, **options): + return self._represent_JzOp(None, **options) + + def _represent_JxOp(self, basis, **options): + return self._represent_coupled_base(gamma=pi/2, **options) + + def _represent_JyOp(self, basis, **options): + return self._represent_coupled_base(**options) + + def _represent_JzOp(self, basis, **options): + return self._represent_coupled_base(alpha=pi*Rational(3, 2), beta=-pi/2, gamma=pi/2, **options) + + +class JyBraCoupled(CoupledSpinState, Bra): + """Coupled eigenbra of Jy. + + See JzKetCoupled for the usage of coupled spin eigenstates. + + See Also + ======== + + JzKetCoupled: Usage of coupled spin states + + """ + + @classmethod + def dual_class(self): + return JyKetCoupled + + @classmethod + def uncoupled_class(self): + return JyBra + + +class JzKetCoupled(CoupledSpinState, Ket): + r"""Coupled eigenket of Jz + + Spin state that is an eigenket of Jz which represents the coupling of + separate spin spaces. + + The arguments for creating instances of JzKetCoupled are ``j``, ``m``, + ``jn`` and an optional ``jcoupling`` argument. The ``j`` and ``m`` options + are the total angular momentum quantum numbers, as used for normal states + (e.g. JzKet). + + The other required parameter in ``jn``, which is a tuple defining the `j_n` + angular momentum quantum numbers of the product spaces. So for example, if + a state represented the coupling of the product basis state + `\left|j_1,m_1\right\rangle\times\left|j_2,m_2\right\rangle`, the ``jn`` + for this state would be ``(j1,j2)``. + + The final option is ``jcoupling``, which is used to define how the spaces + specified by ``jn`` are coupled, which includes both the order these spaces + are coupled together and the quantum numbers that arise from these + couplings. The ``jcoupling`` parameter itself is a list of lists, such that + each of the sublists defines a single coupling between the spin spaces. If + there are N coupled angular momentum spaces, that is ``jn`` has N elements, + then there must be N-1 sublists. Each of these sublists making up the + ``jcoupling`` parameter have length 3. The first two elements are the + indices of the product spaces that are considered to be coupled together. + For example, if we want to couple `j_1` and `j_4`, the indices would be 1 + and 4. If a state has already been coupled, it is referenced by the + smallest index that is coupled, so if `j_2` and `j_4` has already been + coupled to some `j_{24}`, then this value can be coupled by referencing it + with index 2. The final element of the sublist is the quantum number of the + coupled state. So putting everything together, into a valid sublist for + ``jcoupling``, if `j_1` and `j_2` are coupled to an angular momentum space + with quantum number `j_{12}` with the value ``j12``, the sublist would be + ``(1,2,j12)``, N-1 of these sublists are used in the list for + ``jcoupling``. + + Note the ``jcoupling`` parameter is optional, if it is not specified, the + default coupling is taken. This default value is to coupled the spaces in + order and take the quantum number of the coupling to be the maximum value. + For example, if the spin spaces are `j_1`, `j_2`, `j_3`, `j_4`, then the + default coupling couples `j_1` and `j_2` to `j_{12}=j_1+j_2`, then, + `j_{12}` and `j_3` are coupled to `j_{123}=j_{12}+j_3`, and finally + `j_{123}` and `j_4` to `j=j_{123}+j_4`. The jcoupling value that would + correspond to this is: + + ``((1,2,j1+j2),(1,3,j1+j2+j3))`` + + Parameters + ========== + + args : tuple + The arguments that must be passed are ``j``, ``m``, ``jn``, and + ``jcoupling``. The ``j`` value is the total angular momentum. The ``m`` + value is the eigenvalue of the Jz spin operator. The ``jn`` list are + the j values of argular momentum spaces coupled together. The + ``jcoupling`` parameter is an optional parameter defining how the spaces + are coupled together. See the above description for how these coupling + parameters are defined. + + Examples + ======== + + Defining simple spin states, both numerical and symbolic: + + >>> from sympy.physics.quantum.spin import JzKetCoupled + >>> from sympy import symbols + >>> JzKetCoupled(1, 0, (1, 1)) + |1,0,j1=1,j2=1> + >>> j, m, j1, j2 = symbols('j m j1 j2') + >>> JzKetCoupled(j, m, (j1, j2)) + |j,m,j1=j1,j2=j2> + + Defining coupled spin states for more than 2 coupled spaces with various + coupling parameters: + + >>> JzKetCoupled(2, 1, (1, 1, 1)) + |2,1,j1=1,j2=1,j3=1,j(1,2)=2> + >>> JzKetCoupled(2, 1, (1, 1, 1), ((1,2,2),(1,3,2)) ) + |2,1,j1=1,j2=1,j3=1,j(1,2)=2> + >>> JzKetCoupled(2, 1, (1, 1, 1), ((2,3,1),(1,2,2)) ) + |2,1,j1=1,j2=1,j3=1,j(2,3)=1> + + Rewriting the JzKetCoupled in terms of eigenkets of the Jx operator: + Note: that the resulting eigenstates are JxKetCoupled + + >>> JzKetCoupled(1,1,(1,1)).rewrite("Jx") + |1,-1,j1=1,j2=1>/2 - sqrt(2)*|1,0,j1=1,j2=1>/2 + |1,1,j1=1,j2=1>/2 + + The rewrite method can be used to convert a coupled state to an uncoupled + state. This is done by passing coupled=False to the rewrite function: + + >>> JzKetCoupled(1, 0, (1, 1)).rewrite('Jz', coupled=False) + -sqrt(2)*|1,-1>x|1,1>/2 + sqrt(2)*|1,1>x|1,-1>/2 + + Get the vector representation of a state in terms of the basis elements + of the Jx operator: + + >>> from sympy.physics.quantum.represent import represent + >>> from sympy.physics.quantum.spin import Jx + >>> from sympy import S + >>> represent(JzKetCoupled(1,-1,(S(1)/2,S(1)/2)), basis=Jx) + Matrix([ + [ 0], + [ 1/2], + [sqrt(2)/2], + [ 1/2]]) + + See Also + ======== + + JzKet: Normal spin eigenstates + uncouple: Uncoupling of coupling spin states + couple: Coupling of uncoupled spin states + + """ + + @classmethod + def dual_class(self): + return JzBraCoupled + + @classmethod + def uncoupled_class(self): + return JzKet + + def _represent_default_basis(self, **options): + return self._represent_JzOp(None, **options) + + def _represent_JxOp(self, basis, **options): + return self._represent_coupled_base(beta=pi*Rational(3, 2), **options) + + def _represent_JyOp(self, basis, **options): + return self._represent_coupled_base(alpha=pi*Rational(3, 2), beta=pi/2, gamma=pi/2, **options) + + def _represent_JzOp(self, basis, **options): + return self._represent_coupled_base(**options) + + +class JzBraCoupled(CoupledSpinState, Bra): + """Coupled eigenbra of Jz. + + See the JzKetCoupled for the usage of coupled spin eigenstates. + + See Also + ======== + + JzKetCoupled: Usage of coupled spin states + + """ + + @classmethod + def dual_class(self): + return JzKetCoupled + + @classmethod + def uncoupled_class(self): + return JzBra + +#----------------------------------------------------------------------------- +# Coupling/uncoupling +#----------------------------------------------------------------------------- + + +def couple(expr, jcoupling_list=None): + """ Couple a tensor product of spin states + + This function can be used to couple an uncoupled tensor product of spin + states. All of the eigenstates to be coupled must be of the same class. It + will return a linear combination of eigenstates that are subclasses of + CoupledSpinState determined by Clebsch-Gordan angular momentum coupling + coefficients. + + Parameters + ========== + + expr : Expr + An expression involving TensorProducts of spin states to be coupled. + Each state must be a subclass of SpinState and they all must be the + same class. + + jcoupling_list : list or tuple + Elements of this list are sub-lists of length 2 specifying the order of + the coupling of the spin spaces. The length of this must be N-1, where N + is the number of states in the tensor product to be coupled. The + elements of this sublist are the same as the first two elements of each + sublist in the ``jcoupling`` parameter defined for JzKetCoupled. If this + parameter is not specified, the default value is taken, which couples + the first and second product basis spaces, then couples this new coupled + space to the third product space, etc + + Examples + ======== + + Couple a tensor product of numerical states for two spaces: + + >>> from sympy.physics.quantum.spin import JzKet, couple + >>> from sympy.physics.quantum.tensorproduct import TensorProduct + >>> couple(TensorProduct(JzKet(1,0), JzKet(1,1))) + -sqrt(2)*|1,1,j1=1,j2=1>/2 + sqrt(2)*|2,1,j1=1,j2=1>/2 + + + Numerical coupling of three spaces using the default coupling method, i.e. + first and second spaces couple, then this couples to the third space: + + >>> couple(TensorProduct(JzKet(1,1), JzKet(1,1), JzKet(1,0))) + sqrt(6)*|2,2,j1=1,j2=1,j3=1,j(1,2)=2>/3 + sqrt(3)*|3,2,j1=1,j2=1,j3=1,j(1,2)=2>/3 + + Perform this same coupling, but we define the coupling to first couple + the first and third spaces: + + >>> couple(TensorProduct(JzKet(1,1), JzKet(1,1), JzKet(1,0)), ((1,3),(1,2)) ) + sqrt(2)*|2,2,j1=1,j2=1,j3=1,j(1,3)=1>/2 - sqrt(6)*|2,2,j1=1,j2=1,j3=1,j(1,3)=2>/6 + sqrt(3)*|3,2,j1=1,j2=1,j3=1,j(1,3)=2>/3 + + Couple a tensor product of symbolic states: + + >>> from sympy import symbols + >>> j1,m1,j2,m2 = symbols('j1 m1 j2 m2') + >>> couple(TensorProduct(JzKet(j1,m1), JzKet(j2,m2))) + Sum(CG(j1, m1, j2, m2, j, m1 + m2)*|j,m1 + m2,j1=j1,j2=j2>, (j, m1 + m2, j1 + j2)) + + """ + a = expr.atoms(TensorProduct) + for tp in a: + # Allow other tensor products to be in expression + if not all(isinstance(state, SpinState) for state in tp.args): + continue + # If tensor product has all spin states, raise error for invalid tensor product state + if not all(state.__class__ is tp.args[0].__class__ for state in tp.args): + raise TypeError('All states must be the same basis') + expr = expr.subs(tp, _couple(tp, jcoupling_list)) + return expr + + +def _couple(tp, jcoupling_list): + states = tp.args + coupled_evect = states[0].coupled_class() + + # Define default coupling if none is specified + if jcoupling_list is None: + jcoupling_list = [] + for n in range(1, len(states)): + jcoupling_list.append( (1, n + 1) ) + + # Check jcoupling_list valid + if not len(jcoupling_list) == len(states) - 1: + raise TypeError('jcoupling_list must be length %d, got %d' % + (len(states) - 1, len(jcoupling_list))) + if not all( len(coupling) == 2 for coupling in jcoupling_list): + raise ValueError('Each coupling must define 2 spaces') + if any(n1 == n2 for n1, n2 in jcoupling_list): + raise ValueError('Spin spaces cannot couple to themselves') + if all(sympify(n1).is_number and sympify(n2).is_number for n1, n2 in jcoupling_list): + j_test = [0]*len(states) + for n1, n2 in jcoupling_list: + if j_test[n1 - 1] == -1 or j_test[n2 - 1] == -1: + raise ValueError('Spaces coupling j_n\'s are referenced by smallest n value') + j_test[max(n1, n2) - 1] = -1 + + # j values of states to be coupled together + jn = [state.j for state in states] + mn = [state.m for state in states] + + # Create coupling_list, which defines all the couplings between all + # the spaces from jcoupling_list + coupling_list = [] + n_list = [ [i + 1] for i in range(len(states)) ] + for j_coupling in jcoupling_list: + # Least n for all j_n which is coupled as first and second spaces + n1, n2 = j_coupling + # List of all n's coupled in first and second spaces + j1_n = list(n_list[n1 - 1]) + j2_n = list(n_list[n2 - 1]) + coupling_list.append( (j1_n, j2_n) ) + # Set new j_n to be coupling of all j_n in both first and second spaces + n_list[ min(n1, n2) - 1 ] = sorted(j1_n + j2_n) + + if all(state.j.is_number and state.m.is_number for state in states): + # Numerical coupling + # Iterate over difference between maximum possible j value of each coupling and the actual value + diff_max = [ Add( *[ jn[n - 1] - mn[n - 1] for n in coupling[0] + + coupling[1] ] ) for coupling in coupling_list ] + result = [] + for diff in range(diff_max[-1] + 1): + # Determine available configurations + n = len(coupling_list) + tot = binomial(diff + n - 1, diff) + + for config_num in range(tot): + diff_list = _confignum_to_difflist(config_num, diff, n) + + # Skip the configuration if non-physical + # This is a lazy check for physical states given the loose restrictions of diff_max + if any(d > m for d, m in zip(diff_list, diff_max)): + continue + + # Determine term + cg_terms = [] + coupled_j = list(jn) + jcoupling = [] + for (j1_n, j2_n), coupling_diff in zip(coupling_list, diff_list): + j1 = coupled_j[ min(j1_n) - 1 ] + j2 = coupled_j[ min(j2_n) - 1 ] + j3 = j1 + j2 - coupling_diff + coupled_j[ min(j1_n + j2_n) - 1 ] = j3 + m1 = Add( *[ mn[x - 1] for x in j1_n] ) + m2 = Add( *[ mn[x - 1] for x in j2_n] ) + m3 = m1 + m2 + cg_terms.append( (j1, m1, j2, m2, j3, m3) ) + jcoupling.append( (min(j1_n), min(j2_n), j3) ) + # Better checks that state is physical + if any(abs(term[5]) > term[4] for term in cg_terms): + continue + if any(term[0] + term[2] < term[4] for term in cg_terms): + continue + if any(abs(term[0] - term[2]) > term[4] for term in cg_terms): + continue + coeff = Mul( *[ CG(*term).doit() for term in cg_terms] ) + state = coupled_evect(j3, m3, jn, jcoupling) + result.append(coeff*state) + return Add(*result) + else: + # Symbolic coupling + cg_terms = [] + jcoupling = [] + sum_terms = [] + coupled_j = list(jn) + for j1_n, j2_n in coupling_list: + j1 = coupled_j[ min(j1_n) - 1 ] + j2 = coupled_j[ min(j2_n) - 1 ] + if len(j1_n + j2_n) == len(states): + j3 = symbols('j') + else: + j3_name = 'j' + ''.join(["%s" % n for n in j1_n + j2_n]) + j3 = symbols(j3_name) + coupled_j[ min(j1_n + j2_n) - 1 ] = j3 + m1 = Add( *[ mn[x - 1] for x in j1_n] ) + m2 = Add( *[ mn[x - 1] for x in j2_n] ) + m3 = m1 + m2 + cg_terms.append( (j1, m1, j2, m2, j3, m3) ) + jcoupling.append( (min(j1_n), min(j2_n), j3) ) + sum_terms.append((j3, m3, j1 + j2)) + coeff = Mul( *[ CG(*term) for term in cg_terms] ) + state = coupled_evect(j3, m3, jn, jcoupling) + return Sum(coeff*state, *sum_terms) + + +def uncouple(expr, jn=None, jcoupling_list=None): + """ Uncouple a coupled spin state + + Gives the uncoupled representation of a coupled spin state. Arguments must + be either a spin state that is a subclass of CoupledSpinState or a spin + state that is a subclass of SpinState and an array giving the j values + of the spaces that are to be coupled + + Parameters + ========== + + expr : Expr + The expression containing states that are to be coupled. If the states + are a subclass of SpinState, the ``jn`` and ``jcoupling`` parameters + must be defined. If the states are a subclass of CoupledSpinState, + ``jn`` and ``jcoupling`` will be taken from the state. + + jn : list or tuple + The list of the j-values that are coupled. If state is a + CoupledSpinState, this parameter is ignored. This must be defined if + state is not a subclass of CoupledSpinState. The syntax of this + parameter is the same as the ``jn`` parameter of JzKetCoupled. + + jcoupling_list : list or tuple + The list defining how the j-values are coupled together. If state is a + CoupledSpinState, this parameter is ignored. This must be defined if + state is not a subclass of CoupledSpinState. The syntax of this + parameter is the same as the ``jcoupling`` parameter of JzKetCoupled. + + Examples + ======== + + Uncouple a numerical state using a CoupledSpinState state: + + >>> from sympy.physics.quantum.spin import JzKetCoupled, uncouple + >>> from sympy import S + >>> uncouple(JzKetCoupled(1, 0, (S(1)/2, S(1)/2))) + sqrt(2)*|1/2,-1/2>x|1/2,1/2>/2 + sqrt(2)*|1/2,1/2>x|1/2,-1/2>/2 + + Perform the same calculation using a SpinState state: + + >>> from sympy.physics.quantum.spin import JzKet + >>> uncouple(JzKet(1, 0), (S(1)/2, S(1)/2)) + sqrt(2)*|1/2,-1/2>x|1/2,1/2>/2 + sqrt(2)*|1/2,1/2>x|1/2,-1/2>/2 + + Uncouple a numerical state of three coupled spaces using a CoupledSpinState state: + + >>> uncouple(JzKetCoupled(1, 1, (1, 1, 1), ((1,3,1),(1,2,1)) )) + |1,-1>x|1,1>x|1,1>/2 - |1,0>x|1,0>x|1,1>/2 + |1,1>x|1,0>x|1,0>/2 - |1,1>x|1,1>x|1,-1>/2 + + Perform the same calculation using a SpinState state: + + >>> uncouple(JzKet(1, 1), (1, 1, 1), ((1,3,1),(1,2,1)) ) + |1,-1>x|1,1>x|1,1>/2 - |1,0>x|1,0>x|1,1>/2 + |1,1>x|1,0>x|1,0>/2 - |1,1>x|1,1>x|1,-1>/2 + + Uncouple a symbolic state using a CoupledSpinState state: + + >>> from sympy import symbols + >>> j,m,j1,j2 = symbols('j m j1 j2') + >>> uncouple(JzKetCoupled(j, m, (j1, j2))) + Sum(CG(j1, m1, j2, m2, j, m)*|j1,m1>x|j2,m2>, (m1, -j1, j1), (m2, -j2, j2)) + + Perform the same calculation using a SpinState state + + >>> uncouple(JzKet(j, m), (j1, j2)) + Sum(CG(j1, m1, j2, m2, j, m)*|j1,m1>x|j2,m2>, (m1, -j1, j1), (m2, -j2, j2)) + + """ + a = expr.atoms(SpinState) + for state in a: + expr = expr.subs(state, _uncouple(state, jn, jcoupling_list)) + return expr + + +def _uncouple(state, jn, jcoupling_list): + if isinstance(state, CoupledSpinState): + jn = state.jn + coupled_n = state.coupled_n + coupled_jn = state.coupled_jn + evect = state.uncoupled_class() + elif isinstance(state, SpinState): + if jn is None: + raise ValueError("Must specify j-values for coupled state") + if not isinstance(jn, (list, tuple)): + raise TypeError("jn must be list or tuple") + if jcoupling_list is None: + # Use default + jcoupling_list = [] + for i in range(1, len(jn)): + jcoupling_list.append( + (1, 1 + i, Add(*[jn[j] for j in range(i + 1)])) ) + if not isinstance(jcoupling_list, (list, tuple)): + raise TypeError("jcoupling must be a list or tuple") + if not len(jcoupling_list) == len(jn) - 1: + raise ValueError("Must specify 2 fewer coupling terms than the number of j values") + coupled_n, coupled_jn = _build_coupled(jcoupling_list, len(jn)) + evect = state.__class__ + else: + raise TypeError("state must be a spin state") + j = state.j + m = state.m + coupling_list = [] + j_list = list(jn) + + # Create coupling, which defines all the couplings between all the spaces + for j3, (n1, n2) in zip(coupled_jn, coupled_n): + # j's which are coupled as first and second spaces + j1 = j_list[n1[0] - 1] + j2 = j_list[n2[0] - 1] + # Build coupling list + coupling_list.append( (n1, n2, j1, j2, j3) ) + # Set new value in j_list + j_list[min(n1 + n2) - 1] = j3 + + if j.is_number and m.is_number: + diff_max = [ 2*x for x in jn ] + diff = Add(*jn) - m + + n = len(jn) + tot = binomial(diff + n - 1, diff) + + result = [] + for config_num in range(tot): + diff_list = _confignum_to_difflist(config_num, diff, n) + if any(d > p for d, p in zip(diff_list, diff_max)): + continue + + cg_terms = [] + for coupling in coupling_list: + j1_n, j2_n, j1, j2, j3 = coupling + m1 = Add( *[ jn[x - 1] - diff_list[x - 1] for x in j1_n ] ) + m2 = Add( *[ jn[x - 1] - diff_list[x - 1] for x in j2_n ] ) + m3 = m1 + m2 + cg_terms.append( (j1, m1, j2, m2, j3, m3) ) + coeff = Mul( *[ CG(*term).doit() for term in cg_terms ] ) + state = TensorProduct( + *[ evect(j, j - d) for j, d in zip(jn, diff_list) ] ) + result.append(coeff*state) + return Add(*result) + else: + # Symbolic coupling + m_str = "m1:%d" % (len(jn) + 1) + mvals = symbols(m_str) + cg_terms = [(j1, Add(*[mvals[n - 1] for n in j1_n]), + j2, Add(*[mvals[n - 1] for n in j2_n]), + j3, Add(*[mvals[n - 1] for n in j1_n + j2_n])) for j1_n, j2_n, j1, j2, j3 in coupling_list[:-1] ] + cg_terms.append(*[(j1, Add(*[mvals[n - 1] for n in j1_n]), + j2, Add(*[mvals[n - 1] for n in j2_n]), + j, m) for j1_n, j2_n, j1, j2, j3 in [coupling_list[-1]] ]) + cg_coeff = Mul(*[CG(*cg_term) for cg_term in cg_terms]) + sum_terms = [ (m, -j, j) for j, m in zip(jn, mvals) ] + state = TensorProduct( *[ evect(j, m) for j, m in zip(jn, mvals) ] ) + return Sum(cg_coeff*state, *sum_terms) + + +def _confignum_to_difflist(config_num, diff, list_len): + # Determines configuration of diffs into list_len number of slots + diff_list = [] + for n in range(list_len): + prev_diff = diff + # Number of spots after current one + rem_spots = list_len - n - 1 + # Number of configurations of distributing diff among the remaining spots + rem_configs = binomial(diff + rem_spots - 1, diff) + while config_num >= rem_configs: + config_num -= rem_configs + diff -= 1 + rem_configs = binomial(diff + rem_spots - 1, diff) + diff_list.append(prev_diff - diff) + return diff_list diff --git a/lib/python3.10/site-packages/sympy/physics/quantum/state.py b/lib/python3.10/site-packages/sympy/physics/quantum/state.py new file mode 100644 index 0000000000000000000000000000000000000000..3688a54b4fd789d400980e76ae20d2036dd9b182 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/physics/quantum/state.py @@ -0,0 +1,1017 @@ +"""Dirac notation for states.""" + +from sympy.core.cache import cacheit +from sympy.core.containers import Tuple +from sympy.core.expr import Expr +from sympy.core.function import Function +from sympy.core.numbers import oo, equal_valued +from sympy.core.singleton import S +from sympy.functions.elementary.complexes import conjugate +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.integrals.integrals import integrate +from sympy.printing.pretty.stringpict import stringPict +from sympy.physics.quantum.qexpr import QExpr, dispatch_method + +__all__ = [ + 'KetBase', + 'BraBase', + 'StateBase', + 'State', + 'Ket', + 'Bra', + 'TimeDepState', + 'TimeDepBra', + 'TimeDepKet', + 'OrthogonalKet', + 'OrthogonalBra', + 'OrthogonalState', + 'Wavefunction' +] + + +#----------------------------------------------------------------------------- +# States, bras and kets. +#----------------------------------------------------------------------------- + +# ASCII brackets +_lbracket = "<" +_rbracket = ">" +_straight_bracket = "|" + + +# Unicode brackets +# MATHEMATICAL ANGLE BRACKETS +_lbracket_ucode = "\N{MATHEMATICAL LEFT ANGLE BRACKET}" +_rbracket_ucode = "\N{MATHEMATICAL RIGHT ANGLE BRACKET}" +# LIGHT VERTICAL BAR +_straight_bracket_ucode = "\N{LIGHT VERTICAL BAR}" + +# Other options for unicode printing of <, > and | for Dirac notation. + +# LEFT-POINTING ANGLE BRACKET +# _lbracket = "\u2329" +# _rbracket = "\u232A" + +# LEFT ANGLE BRACKET +# _lbracket = "\u3008" +# _rbracket = "\u3009" + +# VERTICAL LINE +# _straight_bracket = "\u007C" + + +class StateBase(QExpr): + """Abstract base class for general abstract states in quantum mechanics. + + All other state classes defined will need to inherit from this class. It + carries the basic structure for all other states such as dual, _eval_adjoint + and label. + + This is an abstract base class and you should not instantiate it directly, + instead use State. + """ + + @classmethod + def _operators_to_state(self, ops, **options): + """ Returns the eigenstate instance for the passed operators. + + This method should be overridden in subclasses. It will handle being + passed either an Operator instance or set of Operator instances. It + should return the corresponding state INSTANCE or simply raise a + NotImplementedError. See cartesian.py for an example. + """ + + raise NotImplementedError("Cannot map operators to states in this class. Method not implemented!") + + def _state_to_operators(self, op_classes, **options): + """ Returns the operators which this state instance is an eigenstate + of. + + This method should be overridden in subclasses. It will be called on + state instances and be passed the operator classes that we wish to make + into instances. The state instance will then transform the classes + appropriately, or raise a NotImplementedError if it cannot return + operator instances. See cartesian.py for examples, + """ + + raise NotImplementedError( + "Cannot map this state to operators. Method not implemented!") + + @property + def operators(self): + """Return the operator(s) that this state is an eigenstate of""" + from .operatorset import state_to_operators # import internally to avoid circular import errors + return state_to_operators(self) + + def _enumerate_state(self, num_states, **options): + raise NotImplementedError("Cannot enumerate this state!") + + def _represent_default_basis(self, **options): + return self._represent(basis=self.operators) + + def _apply_operator(self, op, **options): + return None + + #------------------------------------------------------------------------- + # Dagger/dual + #------------------------------------------------------------------------- + + @property + def dual(self): + """Return the dual state of this one.""" + return self.dual_class()._new_rawargs(self.hilbert_space, *self.args) + + @classmethod + def dual_class(self): + """Return the class used to construct the dual.""" + raise NotImplementedError( + 'dual_class must be implemented in a subclass' + ) + + def _eval_adjoint(self): + """Compute the dagger of this state using the dual.""" + return self.dual + + #------------------------------------------------------------------------- + # Printing + #------------------------------------------------------------------------- + + def _pretty_brackets(self, height, use_unicode=True): + # Return pretty printed brackets for the state + # Ideally, this could be done by pform.parens but it does not support the angled < and > + + # Setup for unicode vs ascii + if use_unicode: + lbracket, rbracket = getattr(self, 'lbracket_ucode', ""), getattr(self, 'rbracket_ucode', "") + slash, bslash, vert = '\N{BOX DRAWINGS LIGHT DIAGONAL UPPER RIGHT TO LOWER LEFT}', \ + '\N{BOX DRAWINGS LIGHT DIAGONAL UPPER LEFT TO LOWER RIGHT}', \ + '\N{BOX DRAWINGS LIGHT VERTICAL}' + else: + lbracket, rbracket = getattr(self, 'lbracket', ""), getattr(self, 'rbracket', "") + slash, bslash, vert = '/', '\\', '|' + + # If height is 1, just return brackets + if height == 1: + return stringPict(lbracket), stringPict(rbracket) + # Make height even + height += (height % 2) + + brackets = [] + for bracket in lbracket, rbracket: + # Create left bracket + if bracket in {_lbracket, _lbracket_ucode}: + bracket_args = [ ' ' * (height//2 - i - 1) + + slash for i in range(height // 2)] + bracket_args.extend( + [' ' * i + bslash for i in range(height // 2)]) + # Create right bracket + elif bracket in {_rbracket, _rbracket_ucode}: + bracket_args = [ ' ' * i + bslash for i in range(height // 2)] + bracket_args.extend([ ' ' * ( + height//2 - i - 1) + slash for i in range(height // 2)]) + # Create straight bracket + elif bracket in {_straight_bracket, _straight_bracket_ucode}: + bracket_args = [vert] * height + else: + raise ValueError(bracket) + brackets.append( + stringPict('\n'.join(bracket_args), baseline=height//2)) + return brackets + + def _sympystr(self, printer, *args): + contents = self._print_contents(printer, *args) + return '%s%s%s' % (getattr(self, 'lbracket', ""), contents, getattr(self, 'rbracket', "")) + + def _pretty(self, printer, *args): + from sympy.printing.pretty.stringpict import prettyForm + # Get brackets + pform = self._print_contents_pretty(printer, *args) + lbracket, rbracket = self._pretty_brackets( + pform.height(), printer._use_unicode) + # Put together state + pform = prettyForm(*pform.left(lbracket)) + pform = prettyForm(*pform.right(rbracket)) + return pform + + def _latex(self, printer, *args): + contents = self._print_contents_latex(printer, *args) + # The extra {} brackets are needed to get matplotlib's latex + # rendered to render this properly. + return '{%s%s%s}' % (getattr(self, 'lbracket_latex', ""), contents, getattr(self, 'rbracket_latex', "")) + + +class KetBase(StateBase): + """Base class for Kets. + + This class defines the dual property and the brackets for printing. This is + an abstract base class and you should not instantiate it directly, instead + use Ket. + """ + + lbracket = _straight_bracket + rbracket = _rbracket + lbracket_ucode = _straight_bracket_ucode + rbracket_ucode = _rbracket_ucode + lbracket_latex = r'\left|' + rbracket_latex = r'\right\rangle ' + + @classmethod + def default_args(self): + return ("psi",) + + @classmethod + def dual_class(self): + return BraBase + + def __mul__(self, other): + """KetBase*other""" + from sympy.physics.quantum.operator import OuterProduct + if isinstance(other, BraBase): + return OuterProduct(self, other) + else: + return Expr.__mul__(self, other) + + def __rmul__(self, other): + """other*KetBase""" + from sympy.physics.quantum.innerproduct import InnerProduct + if isinstance(other, BraBase): + return InnerProduct(other, self) + else: + return Expr.__rmul__(self, other) + + #------------------------------------------------------------------------- + # _eval_* methods + #------------------------------------------------------------------------- + + def _eval_innerproduct(self, bra, **hints): + """Evaluate the inner product between this ket and a bra. + + This is called to compute , where the ket is ``self``. + + This method will dispatch to sub-methods having the format:: + + ``def _eval_innerproduct_BraClass(self, **hints):`` + + Subclasses should define these methods (one for each BraClass) to + teach the ket how to take inner products with bras. + """ + return dispatch_method(self, '_eval_innerproduct', bra, **hints) + + def _apply_from_right_to(self, op, **options): + """Apply an Operator to this Ket as Operator*Ket + + This method will dispatch to methods having the format:: + + ``def _apply_from_right_to_OperatorName(op, **options):`` + + Subclasses should define these methods (one for each OperatorName) to + teach the Ket how to implement OperatorName*Ket + + Parameters + ========== + + op : Operator + The Operator that is acting on the Ket as op*Ket + options : dict + A dict of key/value pairs that control how the operator is applied + to the Ket. + """ + return dispatch_method(self, '_apply_from_right_to', op, **options) + + +class BraBase(StateBase): + """Base class for Bras. + + This class defines the dual property and the brackets for printing. This + is an abstract base class and you should not instantiate it directly, + instead use Bra. + """ + + lbracket = _lbracket + rbracket = _straight_bracket + lbracket_ucode = _lbracket_ucode + rbracket_ucode = _straight_bracket_ucode + lbracket_latex = r'\left\langle ' + rbracket_latex = r'\right|' + + @classmethod + def _operators_to_state(self, ops, **options): + state = self.dual_class()._operators_to_state(ops, **options) + return state.dual + + def _state_to_operators(self, op_classes, **options): + return self.dual._state_to_operators(op_classes, **options) + + def _enumerate_state(self, num_states, **options): + dual_states = self.dual._enumerate_state(num_states, **options) + return [x.dual for x in dual_states] + + @classmethod + def default_args(self): + return self.dual_class().default_args() + + @classmethod + def dual_class(self): + return KetBase + + def __mul__(self, other): + """BraBase*other""" + from sympy.physics.quantum.innerproduct import InnerProduct + if isinstance(other, KetBase): + return InnerProduct(self, other) + else: + return Expr.__mul__(self, other) + + def __rmul__(self, other): + """other*BraBase""" + from sympy.physics.quantum.operator import OuterProduct + if isinstance(other, KetBase): + return OuterProduct(other, self) + else: + return Expr.__rmul__(self, other) + + def _represent(self, **options): + """A default represent that uses the Ket's version.""" + from sympy.physics.quantum.dagger import Dagger + return Dagger(self.dual._represent(**options)) + + +class State(StateBase): + """General abstract quantum state used as a base class for Ket and Bra.""" + pass + + +class Ket(State, KetBase): + """A general time-independent Ket in quantum mechanics. + + Inherits from State and KetBase. This class should be used as the base + class for all physical, time-independent Kets in a system. This class + and its subclasses will be the main classes that users will use for + expressing Kets in Dirac notation [1]_. + + Parameters + ========== + + args : tuple + The list of numbers or parameters that uniquely specify the + ket. This will usually be its symbol or its quantum numbers. For + time-dependent state, this will include the time. + + Examples + ======== + + Create a simple Ket and looking at its properties:: + + >>> from sympy.physics.quantum import Ket + >>> from sympy import symbols, I + >>> k = Ket('psi') + >>> k + |psi> + >>> k.hilbert_space + H + >>> k.is_commutative + False + >>> k.label + (psi,) + + Ket's know about their associated bra:: + + >>> k.dual + >> k.dual_class() + + + Take a linear combination of two kets:: + + >>> k0 = Ket(0) + >>> k1 = Ket(1) + >>> 2*I*k0 - 4*k1 + 2*I*|0> - 4*|1> + + Compound labels are passed as tuples:: + + >>> n, m = symbols('n,m') + >>> k = Ket(n,m) + >>> k + |nm> + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Bra-ket_notation + """ + + @classmethod + def dual_class(self): + return Bra + + +class Bra(State, BraBase): + """A general time-independent Bra in quantum mechanics. + + Inherits from State and BraBase. A Bra is the dual of a Ket [1]_. This + class and its subclasses will be the main classes that users will use for + expressing Bras in Dirac notation. + + Parameters + ========== + + args : tuple + The list of numbers or parameters that uniquely specify the + ket. This will usually be its symbol or its quantum numbers. For + time-dependent state, this will include the time. + + Examples + ======== + + Create a simple Bra and look at its properties:: + + >>> from sympy.physics.quantum import Bra + >>> from sympy import symbols, I + >>> b = Bra('psi') + >>> b + >> b.hilbert_space + H + >>> b.is_commutative + False + + Bra's know about their dual Ket's:: + + >>> b.dual + |psi> + >>> b.dual_class() + + + Like Kets, Bras can have compound labels and be manipulated in a similar + manner:: + + >>> n, m = symbols('n,m') + >>> b = Bra(n,m) - I*Bra(m,n) + >>> b + -I*>> b.subs(n,m) + >> from sympy.physics.quantum import TimeDepKet + >>> k = TimeDepKet('psi', 't') + >>> k + |psi;t> + >>> k.time + t + >>> k.label + (psi,) + >>> k.hilbert_space + H + + TimeDepKets know about their dual bra:: + + >>> k.dual + >> k.dual_class() + + """ + + @classmethod + def dual_class(self): + return TimeDepBra + + +class TimeDepBra(TimeDepState, BraBase): + """General time-dependent Bra in quantum mechanics. + + This inherits from TimeDepState and BraBase and is the main class that + should be used for Bras that vary with time. Its dual is a TimeDepBra. + + Parameters + ========== + + args : tuple + The list of numbers or parameters that uniquely specify the ket. This + will usually be its symbol or its quantum numbers. For time-dependent + state, this will include the time as the final argument. + + Examples + ======== + + >>> from sympy.physics.quantum import TimeDepBra + >>> b = TimeDepBra('psi', 't') + >>> b + >> b.time + t + >>> b.label + (psi,) + >>> b.hilbert_space + H + >>> b.dual + |psi;t> + """ + + @classmethod + def dual_class(self): + return TimeDepKet + + +class OrthogonalState(State, StateBase): + """General abstract quantum state used as a base class for Ket and Bra.""" + pass + +class OrthogonalKet(OrthogonalState, KetBase): + """Orthogonal Ket in quantum mechanics. + + The inner product of two states with different labels will give zero, + states with the same label will give one. + + >>> from sympy.physics.quantum import OrthogonalBra, OrthogonalKet + >>> from sympy.abc import m, n + >>> (OrthogonalBra(n)*OrthogonalKet(n)).doit() + 1 + >>> (OrthogonalBra(n)*OrthogonalKet(n+1)).doit() + 0 + >>> (OrthogonalBra(n)*OrthogonalKet(m)).doit() + + """ + + @classmethod + def dual_class(self): + return OrthogonalBra + + def _eval_innerproduct(self, bra, **hints): + + if len(self.args) != len(bra.args): + raise ValueError('Cannot multiply a ket that has a different number of labels.') + + for arg, bra_arg in zip(self.args, bra.args): + diff = arg - bra_arg + diff = diff.expand() + + is_zero = diff.is_zero + + if is_zero is False: + return S.Zero # i.e. Integer(0) + + if is_zero is None: + return None + + return S.One # i.e. Integer(1) + + +class OrthogonalBra(OrthogonalState, BraBase): + """Orthogonal Bra in quantum mechanics. + """ + + @classmethod + def dual_class(self): + return OrthogonalKet + + +class Wavefunction(Function): + """Class for representations in continuous bases + + This class takes an expression and coordinates in its constructor. It can + be used to easily calculate normalizations and probabilities. + + Parameters + ========== + + expr : Expr + The expression representing the functional form of the w.f. + + coords : Symbol or tuple + The coordinates to be integrated over, and their bounds + + Examples + ======== + + Particle in a box, specifying bounds in the more primitive way of using + Piecewise: + + >>> from sympy import Symbol, Piecewise, pi, N + >>> from sympy.functions import sqrt, sin + >>> from sympy.physics.quantum.state import Wavefunction + >>> x = Symbol('x', real=True) + >>> n = 1 + >>> L = 1 + >>> g = Piecewise((0, x < 0), (0, x > L), (sqrt(2//L)*sin(n*pi*x/L), True)) + >>> f = Wavefunction(g, x) + >>> f.norm + 1 + >>> f.is_normalized + True + >>> p = f.prob() + >>> p(0) + 0 + >>> p(L) + 0 + >>> p(0.5) + 2 + >>> p(0.85*L) + 2*sin(0.85*pi)**2 + >>> N(p(0.85*L)) + 0.412214747707527 + + Additionally, you can specify the bounds of the function and the indices in + a more compact way: + + >>> from sympy import symbols, pi, diff + >>> from sympy.functions import sqrt, sin + >>> from sympy.physics.quantum.state import Wavefunction + >>> x, L = symbols('x,L', positive=True) + >>> n = symbols('n', integer=True, positive=True) + >>> g = sqrt(2/L)*sin(n*pi*x/L) + >>> f = Wavefunction(g, (x, 0, L)) + >>> f.norm + 1 + >>> f(L+1) + 0 + >>> f(L-1) + sqrt(2)*sin(pi*n*(L - 1)/L)/sqrt(L) + >>> f(-1) + 0 + >>> f(0.85) + sqrt(2)*sin(0.85*pi*n/L)/sqrt(L) + >>> f(0.85, n=1, L=1) + sqrt(2)*sin(0.85*pi) + >>> f.is_commutative + False + + All arguments are automatically sympified, so you can define the variables + as strings rather than symbols: + + >>> expr = x**2 + >>> f = Wavefunction(expr, 'x') + >>> type(f.variables[0]) + + + Derivatives of Wavefunctions will return Wavefunctions: + + >>> diff(f, x) + Wavefunction(2*x, x) + + """ + + #Any passed tuples for coordinates and their bounds need to be + #converted to Tuples before Function's constructor is called, to + #avoid errors from calling is_Float in the constructor + def __new__(cls, *args, **options): + new_args = [None for i in args] + ct = 0 + for arg in args: + if isinstance(arg, tuple): + new_args[ct] = Tuple(*arg) + else: + new_args[ct] = arg + ct += 1 + + return super().__new__(cls, *new_args, **options) + + def __call__(self, *args, **options): + var = self.variables + + if len(args) != len(var): + raise NotImplementedError( + "Incorrect number of arguments to function!") + + ct = 0 + #If the passed value is outside the specified bounds, return 0 + for v in var: + lower, upper = self.limits[v] + + #Do the comparison to limits only if the passed symbol is actually + #a symbol present in the limits; + #Had problems with a comparison of x > L + if isinstance(args[ct], Expr) and \ + not (lower in args[ct].free_symbols + or upper in args[ct].free_symbols): + continue + + if (args[ct] < lower) == True or (args[ct] > upper) == True: + return S.Zero + + ct += 1 + + expr = self.expr + + #Allows user to make a call like f(2, 4, m=1, n=1) + for symbol in list(expr.free_symbols): + if str(symbol) in options.keys(): + val = options[str(symbol)] + expr = expr.subs(symbol, val) + + return expr.subs(zip(var, args)) + + def _eval_derivative(self, symbol): + expr = self.expr + deriv = expr._eval_derivative(symbol) + + return Wavefunction(deriv, *self.args[1:]) + + def _eval_conjugate(self): + return Wavefunction(conjugate(self.expr), *self.args[1:]) + + def _eval_transpose(self): + return self + + @property + def free_symbols(self): + return self.expr.free_symbols + + @property + def is_commutative(self): + """ + Override Function's is_commutative so that order is preserved in + represented expressions + """ + return False + + @classmethod + def eval(self, *args): + return None + + @property + def variables(self): + """ + Return the coordinates which the wavefunction depends on + + Examples + ======== + + >>> from sympy.physics.quantum.state import Wavefunction + >>> from sympy import symbols + >>> x,y = symbols('x,y') + >>> f = Wavefunction(x*y, x, y) + >>> f.variables + (x, y) + >>> g = Wavefunction(x*y, x) + >>> g.variables + (x,) + + """ + var = [g[0] if isinstance(g, Tuple) else g for g in self._args[1:]] + return tuple(var) + + @property + def limits(self): + """ + Return the limits of the coordinates which the w.f. depends on If no + limits are specified, defaults to ``(-oo, oo)``. + + Examples + ======== + + >>> from sympy.physics.quantum.state import Wavefunction + >>> from sympy import symbols + >>> x, y = symbols('x, y') + >>> f = Wavefunction(x**2, (x, 0, 1)) + >>> f.limits + {x: (0, 1)} + >>> f = Wavefunction(x**2, x) + >>> f.limits + {x: (-oo, oo)} + >>> f = Wavefunction(x**2 + y**2, x, (y, -1, 2)) + >>> f.limits + {x: (-oo, oo), y: (-1, 2)} + + """ + limits = [(g[1], g[2]) if isinstance(g, Tuple) else (-oo, oo) + for g in self._args[1:]] + return dict(zip(self.variables, tuple(limits))) + + @property + def expr(self): + """ + Return the expression which is the functional form of the Wavefunction + + Examples + ======== + + >>> from sympy.physics.quantum.state import Wavefunction + >>> from sympy import symbols + >>> x, y = symbols('x, y') + >>> f = Wavefunction(x**2, x) + >>> f.expr + x**2 + + """ + return self._args[0] + + @property + def is_normalized(self): + """ + Returns true if the Wavefunction is properly normalized + + Examples + ======== + + >>> from sympy import symbols, pi + >>> from sympy.functions import sqrt, sin + >>> from sympy.physics.quantum.state import Wavefunction + >>> x, L = symbols('x,L', positive=True) + >>> n = symbols('n', integer=True, positive=True) + >>> g = sqrt(2/L)*sin(n*pi*x/L) + >>> f = Wavefunction(g, (x, 0, L)) + >>> f.is_normalized + True + + """ + + return equal_valued(self.norm, 1) + + @property # type: ignore + @cacheit + def norm(self): + """ + Return the normalization of the specified functional form. + + This function integrates over the coordinates of the Wavefunction, with + the bounds specified. + + Examples + ======== + + >>> from sympy import symbols, pi + >>> from sympy.functions import sqrt, sin + >>> from sympy.physics.quantum.state import Wavefunction + >>> x, L = symbols('x,L', positive=True) + >>> n = symbols('n', integer=True, positive=True) + >>> g = sqrt(2/L)*sin(n*pi*x/L) + >>> f = Wavefunction(g, (x, 0, L)) + >>> f.norm + 1 + >>> g = sin(n*pi*x/L) + >>> f = Wavefunction(g, (x, 0, L)) + >>> f.norm + sqrt(2)*sqrt(L)/2 + + """ + + exp = self.expr*conjugate(self.expr) + var = self.variables + limits = self.limits + + for v in var: + curr_limits = limits[v] + exp = integrate(exp, (v, curr_limits[0], curr_limits[1])) + + return sqrt(exp) + + def normalize(self): + """ + Return a normalized version of the Wavefunction + + Examples + ======== + + >>> from sympy import symbols, pi + >>> from sympy.functions import sin + >>> from sympy.physics.quantum.state import Wavefunction + >>> x = symbols('x', real=True) + >>> L = symbols('L', positive=True) + >>> n = symbols('n', integer=True, positive=True) + >>> g = sin(n*pi*x/L) + >>> f = Wavefunction(g, (x, 0, L)) + >>> f.normalize() + Wavefunction(sqrt(2)*sin(pi*n*x/L)/sqrt(L), (x, 0, L)) + + """ + const = self.norm + + if const is oo: + raise NotImplementedError("The function is not normalizable!") + else: + return Wavefunction((const)**(-1)*self.expr, *self.args[1:]) + + def prob(self): + r""" + Return the absolute magnitude of the w.f., `|\psi(x)|^2` + + Examples + ======== + + >>> from sympy import symbols, pi + >>> from sympy.functions import sin + >>> from sympy.physics.quantum.state import Wavefunction + >>> x, L = symbols('x,L', real=True) + >>> n = symbols('n', integer=True) + >>> g = sin(n*pi*x/L) + >>> f = Wavefunction(g, (x, 0, L)) + >>> f.prob() + Wavefunction(sin(pi*n*x/L)**2, x) + + """ + + return Wavefunction(self.expr*conjugate(self.expr), *self.variables) diff --git a/lib/python3.10/site-packages/sympy/physics/quantum/tensorproduct.py b/lib/python3.10/site-packages/sympy/physics/quantum/tensorproduct.py new file mode 100644 index 0000000000000000000000000000000000000000..334f2f66bf3e7a080f3cf6db61f8ddc48b6b67da --- /dev/null +++ b/lib/python3.10/site-packages/sympy/physics/quantum/tensorproduct.py @@ -0,0 +1,425 @@ +"""Abstract tensor product.""" + +from sympy.core.add import Add +from sympy.core.expr import Expr +from sympy.core.mul import Mul +from sympy.core.power import Pow +from sympy.core.sympify import sympify +from sympy.matrices.dense import DenseMatrix as Matrix +from sympy.matrices.immutable import ImmutableDenseMatrix as ImmutableMatrix +from sympy.printing.pretty.stringpict import prettyForm + +from sympy.physics.quantum.qexpr import QuantumError +from sympy.physics.quantum.dagger import Dagger +from sympy.physics.quantum.commutator import Commutator +from sympy.physics.quantum.anticommutator import AntiCommutator +from sympy.physics.quantum.state import Ket, Bra +from sympy.physics.quantum.matrixutils import ( + numpy_ndarray, + scipy_sparse_matrix, + matrix_tensor_product +) +from sympy.physics.quantum.trace import Tr + + +__all__ = [ + 'TensorProduct', + 'tensor_product_simp' +] + +#----------------------------------------------------------------------------- +# Tensor product +#----------------------------------------------------------------------------- + +_combined_printing = False + + +def combined_tensor_printing(combined): + """Set flag controlling whether tensor products of states should be + printed as a combined bra/ket or as an explicit tensor product of different + bra/kets. This is a global setting for all TensorProduct class instances. + + Parameters + ---------- + combine : bool + When true, tensor product states are combined into one ket/bra, and + when false explicit tensor product notation is used between each + ket/bra. + """ + global _combined_printing + _combined_printing = combined + + +class TensorProduct(Expr): + """The tensor product of two or more arguments. + + For matrices, this uses ``matrix_tensor_product`` to compute the Kronecker + or tensor product matrix. For other objects a symbolic ``TensorProduct`` + instance is returned. The tensor product is a non-commutative + multiplication that is used primarily with operators and states in quantum + mechanics. + + Currently, the tensor product distinguishes between commutative and + non-commutative arguments. Commutative arguments are assumed to be scalars + and are pulled out in front of the ``TensorProduct``. Non-commutative + arguments remain in the resulting ``TensorProduct``. + + Parameters + ========== + + args : tuple + A sequence of the objects to take the tensor product of. + + Examples + ======== + + Start with a simple tensor product of SymPy matrices:: + + >>> from sympy import Matrix + >>> from sympy.physics.quantum import TensorProduct + + >>> m1 = Matrix([[1,2],[3,4]]) + >>> m2 = Matrix([[1,0],[0,1]]) + >>> TensorProduct(m1, m2) + Matrix([ + [1, 0, 2, 0], + [0, 1, 0, 2], + [3, 0, 4, 0], + [0, 3, 0, 4]]) + >>> TensorProduct(m2, m1) + Matrix([ + [1, 2, 0, 0], + [3, 4, 0, 0], + [0, 0, 1, 2], + [0, 0, 3, 4]]) + + We can also construct tensor products of non-commutative symbols: + + >>> from sympy import Symbol + >>> A = Symbol('A',commutative=False) + >>> B = Symbol('B',commutative=False) + >>> tp = TensorProduct(A, B) + >>> tp + AxB + + We can take the dagger of a tensor product (note the order does NOT reverse + like the dagger of a normal product): + + >>> from sympy.physics.quantum import Dagger + >>> Dagger(tp) + Dagger(A)xDagger(B) + + Expand can be used to distribute a tensor product across addition: + + >>> C = Symbol('C',commutative=False) + >>> tp = TensorProduct(A+B,C) + >>> tp + (A + B)xC + >>> tp.expand(tensorproduct=True) + AxC + BxC + """ + is_commutative = False + + def __new__(cls, *args): + if isinstance(args[0], (Matrix, ImmutableMatrix, numpy_ndarray, + scipy_sparse_matrix)): + return matrix_tensor_product(*args) + c_part, new_args = cls.flatten(sympify(args)) + c_part = Mul(*c_part) + if len(new_args) == 0: + return c_part + elif len(new_args) == 1: + return c_part * new_args[0] + else: + tp = Expr.__new__(cls, *new_args) + return c_part * tp + + @classmethod + def flatten(cls, args): + # TODO: disallow nested TensorProducts. + c_part = [] + nc_parts = [] + for arg in args: + cp, ncp = arg.args_cnc() + c_part.extend(list(cp)) + nc_parts.append(Mul._from_args(ncp)) + return c_part, nc_parts + + def _eval_adjoint(self): + return TensorProduct(*[Dagger(i) for i in self.args]) + + def _eval_rewrite(self, rule, args, **hints): + return TensorProduct(*args).expand(tensorproduct=True) + + def _sympystr(self, printer, *args): + length = len(self.args) + s = '' + for i in range(length): + if isinstance(self.args[i], (Add, Pow, Mul)): + s = s + '(' + s = s + printer._print(self.args[i]) + if isinstance(self.args[i], (Add, Pow, Mul)): + s = s + ')' + if i != length - 1: + s = s + 'x' + return s + + def _pretty(self, printer, *args): + + if (_combined_printing and + (all(isinstance(arg, Ket) for arg in self.args) or + all(isinstance(arg, Bra) for arg in self.args))): + + length = len(self.args) + pform = printer._print('', *args) + for i in range(length): + next_pform = printer._print('', *args) + length_i = len(self.args[i].args) + for j in range(length_i): + part_pform = printer._print(self.args[i].args[j], *args) + next_pform = prettyForm(*next_pform.right(part_pform)) + if j != length_i - 1: + next_pform = prettyForm(*next_pform.right(', ')) + + if len(self.args[i].args) > 1: + next_pform = prettyForm( + *next_pform.parens(left='{', right='}')) + pform = prettyForm(*pform.right(next_pform)) + if i != length - 1: + pform = prettyForm(*pform.right(',' + ' ')) + + pform = prettyForm(*pform.left(self.args[0].lbracket)) + pform = prettyForm(*pform.right(self.args[0].rbracket)) + return pform + + length = len(self.args) + pform = printer._print('', *args) + for i in range(length): + next_pform = printer._print(self.args[i], *args) + if isinstance(self.args[i], (Add, Mul)): + next_pform = prettyForm( + *next_pform.parens(left='(', right=')') + ) + pform = prettyForm(*pform.right(next_pform)) + if i != length - 1: + if printer._use_unicode: + pform = prettyForm(*pform.right('\N{N-ARY CIRCLED TIMES OPERATOR}' + ' ')) + else: + pform = prettyForm(*pform.right('x' + ' ')) + return pform + + def _latex(self, printer, *args): + + if (_combined_printing and + (all(isinstance(arg, Ket) for arg in self.args) or + all(isinstance(arg, Bra) for arg in self.args))): + + def _label_wrap(label, nlabels): + return label if nlabels == 1 else r"\left\{%s\right\}" % label + + s = r", ".join([_label_wrap(arg._print_label_latex(printer, *args), + len(arg.args)) for arg in self.args]) + + return r"{%s%s%s}" % (self.args[0].lbracket_latex, s, + self.args[0].rbracket_latex) + + length = len(self.args) + s = '' + for i in range(length): + if isinstance(self.args[i], (Add, Mul)): + s = s + '\\left(' + # The extra {} brackets are needed to get matplotlib's latex + # rendered to render this properly. + s = s + '{' + printer._print(self.args[i], *args) + '}' + if isinstance(self.args[i], (Add, Mul)): + s = s + '\\right)' + if i != length - 1: + s = s + '\\otimes ' + return s + + def doit(self, **hints): + return TensorProduct(*[item.doit(**hints) for item in self.args]) + + def _eval_expand_tensorproduct(self, **hints): + """Distribute TensorProducts across addition.""" + args = self.args + add_args = [] + for i in range(len(args)): + if isinstance(args[i], Add): + for aa in args[i].args: + tp = TensorProduct(*args[:i] + (aa,) + args[i + 1:]) + c_part, nc_part = tp.args_cnc() + # Check for TensorProduct object: is the one object in nc_part, if any: + # (Note: any other object type to be expanded must be added here) + if len(nc_part) == 1 and isinstance(nc_part[0], TensorProduct): + nc_part = (nc_part[0]._eval_expand_tensorproduct(), ) + add_args.append(Mul(*c_part)*Mul(*nc_part)) + break + + if add_args: + return Add(*add_args) + else: + return self + + def _eval_trace(self, **kwargs): + indices = kwargs.get('indices', None) + exp = tensor_product_simp(self) + + if indices is None or len(indices) == 0: + return Mul(*[Tr(arg).doit() for arg in exp.args]) + else: + return Mul(*[Tr(value).doit() if idx in indices else value + for idx, value in enumerate(exp.args)]) + + +def tensor_product_simp_Mul(e): + """Simplify a Mul with TensorProducts. + + Current the main use of this is to simplify a ``Mul`` of ``TensorProduct``s + to a ``TensorProduct`` of ``Muls``. It currently only works for relatively + simple cases where the initial ``Mul`` only has scalars and raw + ``TensorProduct``s, not ``Add``, ``Pow``, ``Commutator``s of + ``TensorProduct``s. + + Parameters + ========== + + e : Expr + A ``Mul`` of ``TensorProduct``s to be simplified. + + Returns + ======= + + e : Expr + A ``TensorProduct`` of ``Mul``s. + + Examples + ======== + + This is an example of the type of simplification that this function + performs:: + + >>> from sympy.physics.quantum.tensorproduct import \ + tensor_product_simp_Mul, TensorProduct + >>> from sympy import Symbol + >>> A = Symbol('A',commutative=False) + >>> B = Symbol('B',commutative=False) + >>> C = Symbol('C',commutative=False) + >>> D = Symbol('D',commutative=False) + >>> e = TensorProduct(A,B)*TensorProduct(C,D) + >>> e + AxB*CxD + >>> tensor_product_simp_Mul(e) + (A*C)x(B*D) + + """ + # TODO: This won't work with Muls that have other composites of + # TensorProducts, like an Add, Commutator, etc. + # TODO: This only works for the equivalent of single Qbit gates. + if not isinstance(e, Mul): + return e + c_part, nc_part = e.args_cnc() + n_nc = len(nc_part) + if n_nc == 0: + return e + elif n_nc == 1: + if isinstance(nc_part[0], Pow): + return Mul(*c_part) * tensor_product_simp_Pow(nc_part[0]) + return e + elif e.has(TensorProduct): + current = nc_part[0] + if not isinstance(current, TensorProduct): + if isinstance(current, Pow): + if isinstance(current.base, TensorProduct): + current = tensor_product_simp_Pow(current) + else: + raise TypeError('TensorProduct expected, got: %r' % current) + n_terms = len(current.args) + new_args = list(current.args) + for next in nc_part[1:]: + # TODO: check the hilbert spaces of next and current here. + if isinstance(next, TensorProduct): + if n_terms != len(next.args): + raise QuantumError( + 'TensorProducts of different lengths: %r and %r' % + (current, next) + ) + for i in range(len(new_args)): + new_args[i] = new_args[i] * next.args[i] + else: + if isinstance(next, Pow): + if isinstance(next.base, TensorProduct): + new_tp = tensor_product_simp_Pow(next) + for i in range(len(new_args)): + new_args[i] = new_args[i] * new_tp.args[i] + else: + raise TypeError('TensorProduct expected, got: %r' % next) + else: + raise TypeError('TensorProduct expected, got: %r' % next) + current = next + return Mul(*c_part) * TensorProduct(*new_args) + elif e.has(Pow): + new_args = [ tensor_product_simp_Pow(nc) for nc in nc_part ] + return tensor_product_simp_Mul(Mul(*c_part) * TensorProduct(*new_args)) + else: + return e + +def tensor_product_simp_Pow(e): + """Evaluates ``Pow`` expressions whose base is ``TensorProduct``""" + if not isinstance(e, Pow): + return e + + if isinstance(e.base, TensorProduct): + return TensorProduct(*[ b**e.exp for b in e.base.args]) + else: + return e + +def tensor_product_simp(e, **hints): + """Try to simplify and combine TensorProducts. + + In general this will try to pull expressions inside of ``TensorProducts``. + It currently only works for relatively simple cases where the products have + only scalars, raw ``TensorProducts``, not ``Add``, ``Pow``, ``Commutators`` + of ``TensorProducts``. It is best to see what it does by showing examples. + + Examples + ======== + + >>> from sympy.physics.quantum import tensor_product_simp + >>> from sympy.physics.quantum import TensorProduct + >>> from sympy import Symbol + >>> A = Symbol('A',commutative=False) + >>> B = Symbol('B',commutative=False) + >>> C = Symbol('C',commutative=False) + >>> D = Symbol('D',commutative=False) + + First see what happens to products of tensor products: + + >>> e = TensorProduct(A,B)*TensorProduct(C,D) + >>> e + AxB*CxD + >>> tensor_product_simp(e) + (A*C)x(B*D) + + This is the core logic of this function, and it works inside, powers, sums, + commutators and anticommutators as well: + + >>> tensor_product_simp(e**2) + (A*C)x(B*D)**2 + + """ + if isinstance(e, Add): + return Add(*[tensor_product_simp(arg) for arg in e.args]) + elif isinstance(e, Pow): + if isinstance(e.base, TensorProduct): + return tensor_product_simp_Pow(e) + else: + return tensor_product_simp(e.base) ** e.exp + elif isinstance(e, Mul): + return tensor_product_simp_Mul(e) + elif isinstance(e, Commutator): + return Commutator(*[tensor_product_simp(arg) for arg in e.args]) + elif isinstance(e, AntiCommutator): + return AntiCommutator(*[tensor_product_simp(arg) for arg in e.args]) + else: + return e diff --git a/lib/python3.10/site-packages/sympy/physics/quantum/tests/__pycache__/__init__.cpython-310.pyc b/lib/python3.10/site-packages/sympy/physics/quantum/tests/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6dde88364f6989fbf46966a87603850e680e5e60 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/physics/quantum/tests/__pycache__/__init__.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/physics/quantum/tests/__pycache__/test_anticommutator.cpython-310.pyc b/lib/python3.10/site-packages/sympy/physics/quantum/tests/__pycache__/test_anticommutator.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f48f08bd3541e2e43ecf852a6de7e500d4ef8ce9 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/physics/quantum/tests/__pycache__/test_anticommutator.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/physics/quantum/tests/__pycache__/test_boson.cpython-310.pyc b/lib/python3.10/site-packages/sympy/physics/quantum/tests/__pycache__/test_boson.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7b2f0fb053a07d6230c4a1f0d1ddce21a9729b66 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/physics/quantum/tests/__pycache__/test_boson.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/physics/quantum/tests/__pycache__/test_cartesian.cpython-310.pyc b/lib/python3.10/site-packages/sympy/physics/quantum/tests/__pycache__/test_cartesian.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d35bd821ea8890b2a41aaa86a814efa297665ed7 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/physics/quantum/tests/__pycache__/test_cartesian.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/physics/quantum/tests/__pycache__/test_cg.cpython-310.pyc b/lib/python3.10/site-packages/sympy/physics/quantum/tests/__pycache__/test_cg.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0a74d139142afba76d4a743ff8a166cc9554684e Binary files /dev/null and b/lib/python3.10/site-packages/sympy/physics/quantum/tests/__pycache__/test_cg.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/physics/quantum/tests/__pycache__/test_circuitplot.cpython-310.pyc b/lib/python3.10/site-packages/sympy/physics/quantum/tests/__pycache__/test_circuitplot.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b0b6a181a3dc8ed08f3dccae21fa7b25698a428d Binary files /dev/null and b/lib/python3.10/site-packages/sympy/physics/quantum/tests/__pycache__/test_circuitplot.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/physics/quantum/tests/__pycache__/test_circuitutils.cpython-310.pyc b/lib/python3.10/site-packages/sympy/physics/quantum/tests/__pycache__/test_circuitutils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..23d16a66cd93ab26066369bd544f79725eaa63bf Binary files /dev/null and b/lib/python3.10/site-packages/sympy/physics/quantum/tests/__pycache__/test_circuitutils.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/physics/quantum/tests/__pycache__/test_commutator.cpython-310.pyc b/lib/python3.10/site-packages/sympy/physics/quantum/tests/__pycache__/test_commutator.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f839242ead24d49928f687155199f295920ec3c4 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/physics/quantum/tests/__pycache__/test_commutator.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/physics/quantum/tests/__pycache__/test_constants.cpython-310.pyc b/lib/python3.10/site-packages/sympy/physics/quantum/tests/__pycache__/test_constants.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e445ddbf8c55f1cc325e25a18e042aff8376fdab Binary files /dev/null and b/lib/python3.10/site-packages/sympy/physics/quantum/tests/__pycache__/test_constants.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/physics/quantum/tests/__pycache__/test_dagger.cpython-310.pyc b/lib/python3.10/site-packages/sympy/physics/quantum/tests/__pycache__/test_dagger.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..90e6dd8b0a1125c629bfdf4a1eeeb481c9abc4cc Binary files /dev/null and b/lib/python3.10/site-packages/sympy/physics/quantum/tests/__pycache__/test_dagger.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/physics/quantum/tests/__pycache__/test_density.cpython-310.pyc b/lib/python3.10/site-packages/sympy/physics/quantum/tests/__pycache__/test_density.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9d24e344acf413e490688272cc594ee85c6e0a23 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/physics/quantum/tests/__pycache__/test_density.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/physics/quantum/tests/__pycache__/test_fermion.cpython-310.pyc b/lib/python3.10/site-packages/sympy/physics/quantum/tests/__pycache__/test_fermion.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4b1023e19a223ce1c24ff10994cf8d4442b8ccca Binary files /dev/null and b/lib/python3.10/site-packages/sympy/physics/quantum/tests/__pycache__/test_fermion.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/physics/quantum/tests/__pycache__/test_gate.cpython-310.pyc b/lib/python3.10/site-packages/sympy/physics/quantum/tests/__pycache__/test_gate.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b198b245622a443435a746ac581f02f350f2d685 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/physics/quantum/tests/__pycache__/test_gate.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/physics/quantum/tests/__pycache__/test_grover.cpython-310.pyc b/lib/python3.10/site-packages/sympy/physics/quantum/tests/__pycache__/test_grover.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b20e1db28896a401ff3a4b3f11249a45bb347551 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/physics/quantum/tests/__pycache__/test_grover.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/physics/quantum/tests/__pycache__/test_hilbert.cpython-310.pyc b/lib/python3.10/site-packages/sympy/physics/quantum/tests/__pycache__/test_hilbert.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1ab1f2de2bf455625be5009521a60b59ea3eaee3 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/physics/quantum/tests/__pycache__/test_hilbert.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/physics/quantum/tests/__pycache__/test_identitysearch.cpython-310.pyc b/lib/python3.10/site-packages/sympy/physics/quantum/tests/__pycache__/test_identitysearch.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cc573455be794e973716f4b6f18a7df398c4e8fd Binary files /dev/null and b/lib/python3.10/site-packages/sympy/physics/quantum/tests/__pycache__/test_identitysearch.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/physics/quantum/tests/__pycache__/test_innerproduct.cpython-310.pyc b/lib/python3.10/site-packages/sympy/physics/quantum/tests/__pycache__/test_innerproduct.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1e39af5206dbfbab95878b7537579c6d72da2211 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/physics/quantum/tests/__pycache__/test_innerproduct.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/physics/quantum/tests/__pycache__/test_matrixutils.cpython-310.pyc b/lib/python3.10/site-packages/sympy/physics/quantum/tests/__pycache__/test_matrixutils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..757e3c12b126cb5b0bdb4ecaf07715e232f08aef Binary files /dev/null and b/lib/python3.10/site-packages/sympy/physics/quantum/tests/__pycache__/test_matrixutils.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/physics/quantum/tests/__pycache__/test_operator.cpython-310.pyc b/lib/python3.10/site-packages/sympy/physics/quantum/tests/__pycache__/test_operator.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9c994fa4bb30b3332a41df7a6fe5e5c8b79bc26c Binary files /dev/null and b/lib/python3.10/site-packages/sympy/physics/quantum/tests/__pycache__/test_operator.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/physics/quantum/tests/__pycache__/test_operatorordering.cpython-310.pyc b/lib/python3.10/site-packages/sympy/physics/quantum/tests/__pycache__/test_operatorordering.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2cf5a46f3f83d55cd9b2ade3ceb79e55a10e60e2 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/physics/quantum/tests/__pycache__/test_operatorordering.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/physics/quantum/tests/__pycache__/test_operatorset.cpython-310.pyc b/lib/python3.10/site-packages/sympy/physics/quantum/tests/__pycache__/test_operatorset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..004d842fc6a5b9384c634993286d1038d89a0272 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/physics/quantum/tests/__pycache__/test_operatorset.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/physics/quantum/tests/__pycache__/test_pauli.cpython-310.pyc b/lib/python3.10/site-packages/sympy/physics/quantum/tests/__pycache__/test_pauli.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2799994b26033a4568527258d43608ed4aae56e8 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/physics/quantum/tests/__pycache__/test_pauli.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/physics/quantum/tests/__pycache__/test_piab.cpython-310.pyc b/lib/python3.10/site-packages/sympy/physics/quantum/tests/__pycache__/test_piab.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0ac22909654752cb75cbff1618fb46038e642bf4 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/physics/quantum/tests/__pycache__/test_piab.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/physics/quantum/tests/__pycache__/test_printing.cpython-310.pyc b/lib/python3.10/site-packages/sympy/physics/quantum/tests/__pycache__/test_printing.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1d2f6f0b8f3ae3a2035f8e2bacd0705feceb5087 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/physics/quantum/tests/__pycache__/test_printing.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/physics/quantum/tests/__pycache__/test_qapply.cpython-310.pyc b/lib/python3.10/site-packages/sympy/physics/quantum/tests/__pycache__/test_qapply.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bbb92f882ba5a4ae4b46448ea33fa2472a150f3e Binary files /dev/null and b/lib/python3.10/site-packages/sympy/physics/quantum/tests/__pycache__/test_qapply.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/physics/quantum/tests/__pycache__/test_qasm.cpython-310.pyc b/lib/python3.10/site-packages/sympy/physics/quantum/tests/__pycache__/test_qasm.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6de000db594ee8ab4a01075d199ac2ec46c3fd3a Binary files /dev/null and b/lib/python3.10/site-packages/sympy/physics/quantum/tests/__pycache__/test_qasm.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/physics/quantum/tests/__pycache__/test_qexpr.cpython-310.pyc b/lib/python3.10/site-packages/sympy/physics/quantum/tests/__pycache__/test_qexpr.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9af3fe88ae36b512bd89d954bb2df4334e57862d Binary files /dev/null and b/lib/python3.10/site-packages/sympy/physics/quantum/tests/__pycache__/test_qexpr.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/physics/quantum/tests/__pycache__/test_qft.cpython-310.pyc b/lib/python3.10/site-packages/sympy/physics/quantum/tests/__pycache__/test_qft.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8c4d3650e14076ae9904ba85c6eb58ccd10c1191 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/physics/quantum/tests/__pycache__/test_qft.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/physics/quantum/tests/__pycache__/test_qubit.cpython-310.pyc b/lib/python3.10/site-packages/sympy/physics/quantum/tests/__pycache__/test_qubit.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cd657ec9d02735674ba79b28d2ae943dbdb76e9c Binary files /dev/null and b/lib/python3.10/site-packages/sympy/physics/quantum/tests/__pycache__/test_qubit.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/physics/quantum/tests/__pycache__/test_represent.cpython-310.pyc b/lib/python3.10/site-packages/sympy/physics/quantum/tests/__pycache__/test_represent.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c093fa12769efb6a93986a371ac33cebb8eb2169 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/physics/quantum/tests/__pycache__/test_represent.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/physics/quantum/tests/__pycache__/test_sho1d.cpython-310.pyc b/lib/python3.10/site-packages/sympy/physics/quantum/tests/__pycache__/test_sho1d.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..275b5cb0554b06d91eaffe80f7f53ff2af3dd727 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/physics/quantum/tests/__pycache__/test_sho1d.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/physics/quantum/tests/__pycache__/test_shor.cpython-310.pyc b/lib/python3.10/site-packages/sympy/physics/quantum/tests/__pycache__/test_shor.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2c52b09c9b290f5655464ffd7b32a63cd83fb307 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/physics/quantum/tests/__pycache__/test_shor.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/physics/quantum/tests/__pycache__/test_state.cpython-310.pyc b/lib/python3.10/site-packages/sympy/physics/quantum/tests/__pycache__/test_state.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fb0e2a9104c4be16b75cc65cc83f08be4c9a256e Binary files /dev/null and b/lib/python3.10/site-packages/sympy/physics/quantum/tests/__pycache__/test_state.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/physics/quantum/tests/__pycache__/test_tensorproduct.cpython-310.pyc b/lib/python3.10/site-packages/sympy/physics/quantum/tests/__pycache__/test_tensorproduct.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..53010c053e6b3043b93244d376a84184311d68f1 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/physics/quantum/tests/__pycache__/test_tensorproduct.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/physics/quantum/tests/__pycache__/test_trace.cpython-310.pyc b/lib/python3.10/site-packages/sympy/physics/quantum/tests/__pycache__/test_trace.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0422673aa2acd8708f576b9301aab3a7da3d57b5 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/physics/quantum/tests/__pycache__/test_trace.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/physics/quantum/tests/test_cg.py b/lib/python3.10/site-packages/sympy/physics/quantum/tests/test_cg.py new file mode 100644 index 0000000000000000000000000000000000000000..384512aaac7a8d984ff2a733e6349161dc9414a0 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/physics/quantum/tests/test_cg.py @@ -0,0 +1,183 @@ +from sympy.concrete.summations import Sum +from sympy.core.numbers import Rational +from sympy.core.singleton import S +from sympy.core.symbol import symbols +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.physics.quantum.cg import Wigner3j, Wigner6j, Wigner9j, CG, cg_simp +from sympy.functions.special.tensor_functions import KroneckerDelta + + +def test_cg_simp_add(): + j, m1, m1p, m2, m2p = symbols('j m1 m1p m2 m2p') + # Test Varshalovich 8.7.1 Eq 1 + a = CG(S.Half, S.Half, 0, 0, S.Half, S.Half) + b = CG(S.Half, Rational(-1, 2), 0, 0, S.Half, Rational(-1, 2)) + c = CG(1, 1, 0, 0, 1, 1) + d = CG(1, 0, 0, 0, 1, 0) + e = CG(1, -1, 0, 0, 1, -1) + assert cg_simp(a + b) == 2 + assert cg_simp(c + d + e) == 3 + assert cg_simp(a + b + c + d + e) == 5 + assert cg_simp(a + b + c) == 2 + c + assert cg_simp(2*a + b) == 2 + a + assert cg_simp(2*c + d + e) == 3 + c + assert cg_simp(5*a + 5*b) == 10 + assert cg_simp(5*c + 5*d + 5*e) == 15 + assert cg_simp(-a - b) == -2 + assert cg_simp(-c - d - e) == -3 + assert cg_simp(-6*a - 6*b) == -12 + assert cg_simp(-4*c - 4*d - 4*e) == -12 + a = CG(S.Half, S.Half, j, 0, S.Half, S.Half) + b = CG(S.Half, Rational(-1, 2), j, 0, S.Half, Rational(-1, 2)) + c = CG(1, 1, j, 0, 1, 1) + d = CG(1, 0, j, 0, 1, 0) + e = CG(1, -1, j, 0, 1, -1) + assert cg_simp(a + b) == 2*KroneckerDelta(j, 0) + assert cg_simp(c + d + e) == 3*KroneckerDelta(j, 0) + assert cg_simp(a + b + c + d + e) == 5*KroneckerDelta(j, 0) + assert cg_simp(a + b + c) == 2*KroneckerDelta(j, 0) + c + assert cg_simp(2*a + b) == 2*KroneckerDelta(j, 0) + a + assert cg_simp(2*c + d + e) == 3*KroneckerDelta(j, 0) + c + assert cg_simp(5*a + 5*b) == 10*KroneckerDelta(j, 0) + assert cg_simp(5*c + 5*d + 5*e) == 15*KroneckerDelta(j, 0) + assert cg_simp(-a - b) == -2*KroneckerDelta(j, 0) + assert cg_simp(-c - d - e) == -3*KroneckerDelta(j, 0) + assert cg_simp(-6*a - 6*b) == -12*KroneckerDelta(j, 0) + assert cg_simp(-4*c - 4*d - 4*e) == -12*KroneckerDelta(j, 0) + # Test Varshalovich 8.7.1 Eq 2 + a = CG(S.Half, S.Half, S.Half, Rational(-1, 2), 0, 0) + b = CG(S.Half, Rational(-1, 2), S.Half, S.Half, 0, 0) + c = CG(1, 1, 1, -1, 0, 0) + d = CG(1, 0, 1, 0, 0, 0) + e = CG(1, -1, 1, 1, 0, 0) + assert cg_simp(a - b) == sqrt(2) + assert cg_simp(c - d + e) == sqrt(3) + assert cg_simp(a - b + c - d + e) == sqrt(2) + sqrt(3) + assert cg_simp(a - b + c) == sqrt(2) + c + assert cg_simp(2*a - b) == sqrt(2) + a + assert cg_simp(2*c - d + e) == sqrt(3) + c + assert cg_simp(5*a - 5*b) == 5*sqrt(2) + assert cg_simp(5*c - 5*d + 5*e) == 5*sqrt(3) + assert cg_simp(-a + b) == -sqrt(2) + assert cg_simp(-c + d - e) == -sqrt(3) + assert cg_simp(-6*a + 6*b) == -6*sqrt(2) + assert cg_simp(-4*c + 4*d - 4*e) == -4*sqrt(3) + a = CG(S.Half, S.Half, S.Half, Rational(-1, 2), j, 0) + b = CG(S.Half, Rational(-1, 2), S.Half, S.Half, j, 0) + c = CG(1, 1, 1, -1, j, 0) + d = CG(1, 0, 1, 0, j, 0) + e = CG(1, -1, 1, 1, j, 0) + assert cg_simp(a - b) == sqrt(2)*KroneckerDelta(j, 0) + assert cg_simp(c - d + e) == sqrt(3)*KroneckerDelta(j, 0) + assert cg_simp(a - b + c - d + e) == sqrt( + 2)*KroneckerDelta(j, 0) + sqrt(3)*KroneckerDelta(j, 0) + assert cg_simp(a - b + c) == sqrt(2)*KroneckerDelta(j, 0) + c + assert cg_simp(2*a - b) == sqrt(2)*KroneckerDelta(j, 0) + a + assert cg_simp(2*c - d + e) == sqrt(3)*KroneckerDelta(j, 0) + c + assert cg_simp(5*a - 5*b) == 5*sqrt(2)*KroneckerDelta(j, 0) + assert cg_simp(5*c - 5*d + 5*e) == 5*sqrt(3)*KroneckerDelta(j, 0) + assert cg_simp(-a + b) == -sqrt(2)*KroneckerDelta(j, 0) + assert cg_simp(-c + d - e) == -sqrt(3)*KroneckerDelta(j, 0) + assert cg_simp(-6*a + 6*b) == -6*sqrt(2)*KroneckerDelta(j, 0) + assert cg_simp(-4*c + 4*d - 4*e) == -4*sqrt(3)*KroneckerDelta(j, 0) + # Test Varshalovich 8.7.2 Eq 9 + # alpha=alphap,beta=betap case + # numerical + a = CG(S.Half, S.Half, S.Half, Rational(-1, 2), 1, 0)**2 + b = CG(S.Half, S.Half, S.Half, Rational(-1, 2), 0, 0)**2 + c = CG(1, 0, 1, 1, 1, 1)**2 + d = CG(1, 0, 1, 1, 2, 1)**2 + assert cg_simp(a + b) == 1 + assert cg_simp(c + d) == 1 + assert cg_simp(a + b + c + d) == 2 + assert cg_simp(4*a + 4*b) == 4 + assert cg_simp(4*c + 4*d) == 4 + assert cg_simp(5*a + 3*b) == 3 + 2*a + assert cg_simp(5*c + 3*d) == 3 + 2*c + assert cg_simp(-a - b) == -1 + assert cg_simp(-c - d) == -1 + # symbolic + a = CG(S.Half, m1, S.Half, m2, 1, 1)**2 + b = CG(S.Half, m1, S.Half, m2, 1, 0)**2 + c = CG(S.Half, m1, S.Half, m2, 1, -1)**2 + d = CG(S.Half, m1, S.Half, m2, 0, 0)**2 + assert cg_simp(a + b + c + d) == 1 + assert cg_simp(4*a + 4*b + 4*c + 4*d) == 4 + assert cg_simp(3*a + 5*b + 3*c + 4*d) == 3 + 2*b + d + assert cg_simp(-a - b - c - d) == -1 + a = CG(1, m1, 1, m2, 2, 2)**2 + b = CG(1, m1, 1, m2, 2, 1)**2 + c = CG(1, m1, 1, m2, 2, 0)**2 + d = CG(1, m1, 1, m2, 2, -1)**2 + e = CG(1, m1, 1, m2, 2, -2)**2 + f = CG(1, m1, 1, m2, 1, 1)**2 + g = CG(1, m1, 1, m2, 1, 0)**2 + h = CG(1, m1, 1, m2, 1, -1)**2 + i = CG(1, m1, 1, m2, 0, 0)**2 + assert cg_simp(a + b + c + d + e + f + g + h + i) == 1 + assert cg_simp(4*(a + b + c + d + e + f + g + h + i)) == 4 + assert cg_simp(a + b + 2*c + d + 4*e + f + g + h + i) == 1 + c + 3*e + assert cg_simp(-a - b - c - d - e - f - g - h - i) == -1 + # alpha!=alphap or beta!=betap case + # numerical + a = CG(S.Half, S( + 1)/2, S.Half, Rational(-1, 2), 1, 0)*CG(S.Half, Rational(-1, 2), S.Half, S.Half, 1, 0) + b = CG(S.Half, S( + 1)/2, S.Half, Rational(-1, 2), 0, 0)*CG(S.Half, Rational(-1, 2), S.Half, S.Half, 0, 0) + c = CG(1, 1, 1, 0, 2, 1)*CG(1, 0, 1, 1, 2, 1) + d = CG(1, 1, 1, 0, 1, 1)*CG(1, 0, 1, 1, 1, 1) + assert cg_simp(a + b) == 0 + assert cg_simp(c + d) == 0 + # symbolic + a = CG(S.Half, m1, S.Half, m2, 1, 1)*CG(S.Half, m1p, S.Half, m2p, 1, 1) + b = CG(S.Half, m1, S.Half, m2, 1, 0)*CG(S.Half, m1p, S.Half, m2p, 1, 0) + c = CG(S.Half, m1, S.Half, m2, 1, -1)*CG(S.Half, m1p, S.Half, m2p, 1, -1) + d = CG(S.Half, m1, S.Half, m2, 0, 0)*CG(S.Half, m1p, S.Half, m2p, 0, 0) + assert cg_simp(a + b + c + d) == KroneckerDelta(m1, m1p)*KroneckerDelta(m2, m2p) + a = CG(1, m1, 1, m2, 2, 2)*CG(1, m1p, 1, m2p, 2, 2) + b = CG(1, m1, 1, m2, 2, 1)*CG(1, m1p, 1, m2p, 2, 1) + c = CG(1, m1, 1, m2, 2, 0)*CG(1, m1p, 1, m2p, 2, 0) + d = CG(1, m1, 1, m2, 2, -1)*CG(1, m1p, 1, m2p, 2, -1) + e = CG(1, m1, 1, m2, 2, -2)*CG(1, m1p, 1, m2p, 2, -2) + f = CG(1, m1, 1, m2, 1, 1)*CG(1, m1p, 1, m2p, 1, 1) + g = CG(1, m1, 1, m2, 1, 0)*CG(1, m1p, 1, m2p, 1, 0) + h = CG(1, m1, 1, m2, 1, -1)*CG(1, m1p, 1, m2p, 1, -1) + i = CG(1, m1, 1, m2, 0, 0)*CG(1, m1p, 1, m2p, 0, 0) + assert cg_simp( + a + b + c + d + e + f + g + h + i) == KroneckerDelta(m1, m1p)*KroneckerDelta(m2, m2p) + + +def test_cg_simp_sum(): + x, a, b, c, cp, alpha, beta, gamma, gammap = symbols( + 'x a b c cp alpha beta gamma gammap') + # Varshalovich 8.7.1 Eq 1 + assert cg_simp(x * Sum(CG(a, alpha, b, 0, a, alpha), (alpha, -a, a) + )) == x*(2*a + 1)*KroneckerDelta(b, 0) + assert cg_simp(x * Sum(CG(a, alpha, b, 0, a, alpha), (alpha, -a, a)) + CG(1, 0, 1, 0, 1, 0)) == x*(2*a + 1)*KroneckerDelta(b, 0) + CG(1, 0, 1, 0, 1, 0) + assert cg_simp(2 * Sum(CG(1, alpha, 0, 0, 1, alpha), (alpha, -1, 1))) == 6 + # Varshalovich 8.7.1 Eq 2 + assert cg_simp(x*Sum((-1)**(a - alpha) * CG(a, alpha, a, -alpha, c, + 0), (alpha, -a, a))) == x*sqrt(2*a + 1)*KroneckerDelta(c, 0) + assert cg_simp(3*Sum((-1)**(2 - alpha) * CG( + 2, alpha, 2, -alpha, 0, 0), (alpha, -2, 2))) == 3*sqrt(5) + # Varshalovich 8.7.2 Eq 4 + assert cg_simp(Sum(CG(a, alpha, b, beta, c, gamma)*CG(a, alpha, b, beta, cp, gammap), (alpha, -a, a), (beta, -b, b))) == KroneckerDelta(c, cp)*KroneckerDelta(gamma, gammap) + assert cg_simp(Sum(CG(a, alpha, b, beta, c, gamma)*CG(a, alpha, b, beta, c, gammap), (alpha, -a, a), (beta, -b, b))) == KroneckerDelta(gamma, gammap) + assert cg_simp(Sum(CG(a, alpha, b, beta, c, gamma)*CG(a, alpha, b, beta, cp, gamma), (alpha, -a, a), (beta, -b, b))) == KroneckerDelta(c, cp) + assert cg_simp(Sum(CG( + a, alpha, b, beta, c, gamma)**2, (alpha, -a, a), (beta, -b, b))) == 1 + assert cg_simp(Sum(CG(2, alpha, 1, beta, 2, gamma)*CG(2, alpha, 1, beta, 2, gammap), (alpha, -2, 2), (beta, -1, 1))) == KroneckerDelta(gamma, gammap) + + +def test_doit(): + assert Wigner3j(S.Half, Rational(-1, 2), S.Half, S.Half, 0, 0).doit() == -sqrt(2)/2 + assert Wigner3j(1/2,1/2,1/2,1/2,1/2,1/2).doit() == 0 + assert Wigner3j(9/2,9/2,9/2,9/2,9/2,9/2).doit() == 0 + assert Wigner6j(1, 2, 3, 2, 1, 2).doit() == sqrt(21)/105 + assert Wigner6j(3, 1, 2, 2, 2, 1).doit() == sqrt(21) / 105 + assert Wigner9j( + 2, 1, 1, Rational(3, 2), S.Half, 1, S.Half, S.Half, 0).doit() == sqrt(2)/12 + assert CG(S.Half, S.Half, S.Half, Rational(-1, 2), 1, 0).doit() == sqrt(2)/2 + # J minus M is not integer + assert Wigner3j(1, -1, S.Half, S.Half, 1, S.Half).doit() == 0 + assert CG(4, -1, S.Half, S.Half, 4, Rational(-1, 2)).doit() == 0 diff --git a/lib/python3.10/site-packages/sympy/physics/quantum/tests/test_circuitplot.py b/lib/python3.10/site-packages/sympy/physics/quantum/tests/test_circuitplot.py new file mode 100644 index 0000000000000000000000000000000000000000..fcc89f77047450ad3f8663f371f483654dc70ea9 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/physics/quantum/tests/test_circuitplot.py @@ -0,0 +1,69 @@ +from sympy.physics.quantum.circuitplot import labeller, render_label, Mz, CreateOneQubitGate,\ + CreateCGate +from sympy.physics.quantum.gate import CNOT, H, SWAP, CGate, S, T +from sympy.external import import_module +from sympy.testing.pytest import skip + +mpl = import_module('matplotlib') + +def test_render_label(): + assert render_label('q0') == r'$\left|q0\right\rangle$' + assert render_label('q0', {'q0': '0'}) == r'$\left|q0\right\rangle=\left|0\right\rangle$' + +def test_Mz(): + assert str(Mz(0)) == 'Mz(0)' + +def test_create1(): + Qgate = CreateOneQubitGate('Q') + assert str(Qgate(0)) == 'Q(0)' + +def test_createc(): + Qgate = CreateCGate('Q') + assert str(Qgate([1],0)) == 'C((1),Q(0))' + +def test_labeller(): + """Test the labeller utility""" + assert labeller(2) == ['q_1', 'q_0'] + assert labeller(3,'j') == ['j_2', 'j_1', 'j_0'] + +def test_cnot(): + """Test a simple cnot circuit. Right now this only makes sure the code doesn't + raise an exception, and some simple properties + """ + if not mpl: + skip("matplotlib not installed") + else: + from sympy.physics.quantum.circuitplot import CircuitPlot + + c = CircuitPlot(CNOT(1,0),2,labels=labeller(2)) + assert c.ngates == 2 + assert c.nqubits == 2 + assert c.labels == ['q_1', 'q_0'] + + c = CircuitPlot(CNOT(1,0),2) + assert c.ngates == 2 + assert c.nqubits == 2 + assert c.labels == [] + +def test_ex1(): + if not mpl: + skip("matplotlib not installed") + else: + from sympy.physics.quantum.circuitplot import CircuitPlot + + c = CircuitPlot(CNOT(1,0)*H(1),2,labels=labeller(2)) + assert c.ngates == 2 + assert c.nqubits == 2 + assert c.labels == ['q_1', 'q_0'] + +def test_ex4(): + if not mpl: + skip("matplotlib not installed") + else: + from sympy.physics.quantum.circuitplot import CircuitPlot + + c = CircuitPlot(SWAP(0,2)*H(0)* CGate((0,),S(1)) *H(1)*CGate((0,),T(2))\ + *CGate((1,),S(2))*H(2),3,labels=labeller(3,'j')) + assert c.ngates == 7 + assert c.nqubits == 3 + assert c.labels == ['j_2', 'j_1', 'j_0'] diff --git a/lib/python3.10/site-packages/sympy/physics/quantum/tests/test_circuitutils.py b/lib/python3.10/site-packages/sympy/physics/quantum/tests/test_circuitutils.py new file mode 100644 index 0000000000000000000000000000000000000000..8ea7232320417db8bf745871cff0e77aaf1901e7 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/physics/quantum/tests/test_circuitutils.py @@ -0,0 +1,402 @@ +from sympy.core.mul import Mul +from sympy.core.numbers import Integer +from sympy.core.symbol import Symbol +from sympy.utilities import numbered_symbols +from sympy.physics.quantum.gate import X, Y, Z, H, CNOT, CGate +from sympy.physics.quantum.identitysearch import bfs_identity_search +from sympy.physics.quantum.circuitutils import (kmp_table, find_subcircuit, + replace_subcircuit, convert_to_symbolic_indices, + convert_to_real_indices, random_reduce, random_insert, + flatten_ids) +from sympy.testing.pytest import slow + + +def create_gate_sequence(qubit=0): + gates = (X(qubit), Y(qubit), Z(qubit), H(qubit)) + return gates + + +def test_kmp_table(): + word = ('a', 'b', 'c', 'd', 'a', 'b', 'd') + expected_table = [-1, 0, 0, 0, 0, 1, 2] + assert expected_table == kmp_table(word) + + word = ('P', 'A', 'R', 'T', 'I', 'C', 'I', 'P', 'A', 'T', 'E', ' ', + 'I', 'N', ' ', 'P', 'A', 'R', 'A', 'C', 'H', 'U', 'T', 'E') + expected_table = [-1, 0, 0, 0, 0, 0, 0, 0, 1, 2, 0, 0, + 0, 0, 0, 0, 1, 2, 3, 0, 0, 0, 0, 0] + assert expected_table == kmp_table(word) + + x = X(0) + y = Y(0) + z = Z(0) + h = H(0) + word = (x, y, y, x, z) + expected_table = [-1, 0, 0, 0, 1] + assert expected_table == kmp_table(word) + + word = (x, x, y, h, z) + expected_table = [-1, 0, 1, 0, 0] + assert expected_table == kmp_table(word) + + +def test_find_subcircuit(): + x = X(0) + y = Y(0) + z = Z(0) + h = H(0) + x1 = X(1) + y1 = Y(1) + + i0 = Symbol('i0') + x_i0 = X(i0) + y_i0 = Y(i0) + z_i0 = Z(i0) + h_i0 = H(i0) + + circuit = (x, y, z) + + assert find_subcircuit(circuit, (x,)) == 0 + assert find_subcircuit(circuit, (x1,)) == -1 + assert find_subcircuit(circuit, (y,)) == 1 + assert find_subcircuit(circuit, (h,)) == -1 + assert find_subcircuit(circuit, Mul(x, h)) == -1 + assert find_subcircuit(circuit, Mul(x, y, z)) == 0 + assert find_subcircuit(circuit, Mul(y, z)) == 1 + assert find_subcircuit(Mul(*circuit), (x, y, z, h)) == -1 + assert find_subcircuit(Mul(*circuit), (z, y, x)) == -1 + assert find_subcircuit(circuit, (x,), start=2, end=1) == -1 + + circuit = (x, y, x, y, z) + assert find_subcircuit(Mul(*circuit), Mul(x, y, z)) == 2 + assert find_subcircuit(circuit, (x,), start=1) == 2 + assert find_subcircuit(circuit, (x, y), start=1, end=2) == -1 + assert find_subcircuit(Mul(*circuit), (x, y), start=1, end=3) == -1 + assert find_subcircuit(circuit, (x, y), start=1, end=4) == 2 + assert find_subcircuit(circuit, (x, y), start=2, end=4) == 2 + + circuit = (x, y, z, x1, x, y, z, h, x, y, x1, + x, y, z, h, y1, h) + assert find_subcircuit(circuit, (x, y, z, h, y1)) == 11 + + circuit = (x, y, x_i0, y_i0, z_i0, z) + assert find_subcircuit(circuit, (x_i0, y_i0, z_i0)) == 2 + + circuit = (x_i0, y_i0, z_i0, x_i0, y_i0, h_i0) + subcircuit = (x_i0, y_i0, z_i0) + result = find_subcircuit(circuit, subcircuit) + assert result == 0 + + +def test_replace_subcircuit(): + x = X(0) + y = Y(0) + z = Z(0) + h = H(0) + cnot = CNOT(1, 0) + cgate_z = CGate((0,), Z(1)) + + # Standard cases + circuit = (z, y, x, x) + remove = (z, y, x) + assert replace_subcircuit(circuit, Mul(*remove)) == (x,) + assert replace_subcircuit(circuit, remove + (x,)) == () + assert replace_subcircuit(circuit, remove, pos=1) == circuit + assert replace_subcircuit(circuit, remove, pos=0) == (x,) + assert replace_subcircuit(circuit, (x, x), pos=2) == (z, y) + assert replace_subcircuit(circuit, (h,)) == circuit + + circuit = (x, y, x, y, z) + remove = (x, y, z) + assert replace_subcircuit(Mul(*circuit), Mul(*remove)) == (x, y) + remove = (x, y, x, y) + assert replace_subcircuit(circuit, remove) == (z,) + + circuit = (x, h, cgate_z, h, cnot) + remove = (x, h, cgate_z) + assert replace_subcircuit(circuit, Mul(*remove), pos=-1) == (h, cnot) + assert replace_subcircuit(circuit, remove, pos=1) == circuit + remove = (h, h) + assert replace_subcircuit(circuit, remove) == circuit + remove = (h, cgate_z, h, cnot) + assert replace_subcircuit(circuit, remove) == (x,) + + replace = (h, x) + actual = replace_subcircuit(circuit, remove, + replace=replace) + assert actual == (x, h, x) + + circuit = (x, y, h, x, y, z) + remove = (x, y) + replace = (cnot, cgate_z) + actual = replace_subcircuit(circuit, remove, + replace=Mul(*replace)) + assert actual == (cnot, cgate_z, h, x, y, z) + + actual = replace_subcircuit(circuit, remove, + replace=replace, pos=1) + assert actual == (x, y, h, cnot, cgate_z, z) + + +def test_convert_to_symbolic_indices(): + (x, y, z, h) = create_gate_sequence() + + i0 = Symbol('i0') + exp_map = {i0: Integer(0)} + actual, act_map, sndx, gen = convert_to_symbolic_indices((x,)) + assert actual == (X(i0),) + assert act_map == exp_map + + expected = (X(i0), Y(i0), Z(i0), H(i0)) + exp_map = {i0: Integer(0)} + actual, act_map, sndx, gen = convert_to_symbolic_indices((x, y, z, h)) + assert actual == expected + assert exp_map == act_map + + (x1, y1, z1, h1) = create_gate_sequence(1) + i1 = Symbol('i1') + + expected = (X(i0), Y(i0), Z(i0), H(i0)) + exp_map = {i0: Integer(1)} + actual, act_map, sndx, gen = convert_to_symbolic_indices((x1, y1, z1, h1)) + assert actual == expected + assert act_map == exp_map + + expected = (X(i0), Y(i0), Z(i0), H(i0), X(i1), Y(i1), Z(i1), H(i1)) + exp_map = {i0: Integer(0), i1: Integer(1)} + actual, act_map, sndx, gen = convert_to_symbolic_indices((x, y, z, h, + x1, y1, z1, h1)) + assert actual == expected + assert act_map == exp_map + + exp_map = {i0: Integer(1), i1: Integer(0)} + actual, act_map, sndx, gen = convert_to_symbolic_indices(Mul(x1, y1, + z1, h1, x, y, z, h)) + assert actual == expected + assert act_map == exp_map + + expected = (X(i0), X(i1), Y(i0), Y(i1), Z(i0), Z(i1), H(i0), H(i1)) + exp_map = {i0: Integer(0), i1: Integer(1)} + actual, act_map, sndx, gen = convert_to_symbolic_indices(Mul(x, x1, + y, y1, z, z1, h, h1)) + assert actual == expected + assert act_map == exp_map + + exp_map = {i0: Integer(1), i1: Integer(0)} + actual, act_map, sndx, gen = convert_to_symbolic_indices((x1, x, y1, y, + z1, z, h1, h)) + assert actual == expected + assert act_map == exp_map + + cnot_10 = CNOT(1, 0) + cnot_01 = CNOT(0, 1) + cgate_z_10 = CGate(1, Z(0)) + cgate_z_01 = CGate(0, Z(1)) + + expected = (X(i0), X(i1), Y(i0), Y(i1), Z(i0), Z(i1), + H(i0), H(i1), CNOT(i1, i0), CNOT(i0, i1), + CGate(i1, Z(i0)), CGate(i0, Z(i1))) + exp_map = {i0: Integer(0), i1: Integer(1)} + args = (x, x1, y, y1, z, z1, h, h1, cnot_10, cnot_01, + cgate_z_10, cgate_z_01) + actual, act_map, sndx, gen = convert_to_symbolic_indices(args) + assert actual == expected + assert act_map == exp_map + + args = (x1, x, y1, y, z1, z, h1, h, cnot_10, cnot_01, + cgate_z_10, cgate_z_01) + expected = (X(i0), X(i1), Y(i0), Y(i1), Z(i0), Z(i1), + H(i0), H(i1), CNOT(i0, i1), CNOT(i1, i0), + CGate(i0, Z(i1)), CGate(i1, Z(i0))) + exp_map = {i0: Integer(1), i1: Integer(0)} + actual, act_map, sndx, gen = convert_to_symbolic_indices(args) + assert actual == expected + assert act_map == exp_map + + args = (cnot_10, h, cgate_z_01, h) + expected = (CNOT(i0, i1), H(i1), CGate(i1, Z(i0)), H(i1)) + exp_map = {i0: Integer(1), i1: Integer(0)} + actual, act_map, sndx, gen = convert_to_symbolic_indices(args) + assert actual == expected + assert act_map == exp_map + + args = (cnot_01, h1, cgate_z_10, h1) + exp_map = {i0: Integer(0), i1: Integer(1)} + actual, act_map, sndx, gen = convert_to_symbolic_indices(args) + assert actual == expected + assert act_map == exp_map + + args = (cnot_10, h1, cgate_z_01, h1) + expected = (CNOT(i0, i1), H(i0), CGate(i1, Z(i0)), H(i0)) + exp_map = {i0: Integer(1), i1: Integer(0)} + actual, act_map, sndx, gen = convert_to_symbolic_indices(args) + assert actual == expected + assert act_map == exp_map + + i2 = Symbol('i2') + ccgate_z = CGate(0, CGate(1, Z(2))) + ccgate_x = CGate(1, CGate(2, X(0))) + args = (ccgate_z, ccgate_x) + + expected = (CGate(i0, CGate(i1, Z(i2))), CGate(i1, CGate(i2, X(i0)))) + exp_map = {i0: Integer(0), i1: Integer(1), i2: Integer(2)} + actual, act_map, sndx, gen = convert_to_symbolic_indices(args) + assert actual == expected + assert act_map == exp_map + + ndx_map = {i0: Integer(0)} + index_gen = numbered_symbols(prefix='i', start=1) + actual, act_map, sndx, gen = convert_to_symbolic_indices(args, + qubit_map=ndx_map, + start=i0, + gen=index_gen) + assert actual == expected + assert act_map == exp_map + + i3 = Symbol('i3') + cgate_x0_c321 = CGate((3, 2, 1), X(0)) + exp_map = {i0: Integer(3), i1: Integer(2), + i2: Integer(1), i3: Integer(0)} + expected = (CGate((i0, i1, i2), X(i3)),) + args = (cgate_x0_c321,) + actual, act_map, sndx, gen = convert_to_symbolic_indices(args) + assert actual == expected + assert act_map == exp_map + + +def test_convert_to_real_indices(): + i0 = Symbol('i0') + i1 = Symbol('i1') + + (x, y, z, h) = create_gate_sequence() + + x_i0 = X(i0) + y_i0 = Y(i0) + z_i0 = Z(i0) + + qubit_map = {i0: 0} + args = (z_i0, y_i0, x_i0) + expected = (z, y, x) + actual = convert_to_real_indices(args, qubit_map) + assert actual == expected + + cnot_10 = CNOT(1, 0) + cnot_01 = CNOT(0, 1) + cgate_z_10 = CGate(1, Z(0)) + cgate_z_01 = CGate(0, Z(1)) + + cnot_i1_i0 = CNOT(i1, i0) + cnot_i0_i1 = CNOT(i0, i1) + cgate_z_i1_i0 = CGate(i1, Z(i0)) + + qubit_map = {i0: 0, i1: 1} + args = (cnot_i1_i0,) + expected = (cnot_10,) + actual = convert_to_real_indices(args, qubit_map) + assert actual == expected + + args = (cgate_z_i1_i0,) + expected = (cgate_z_10,) + actual = convert_to_real_indices(args, qubit_map) + assert actual == expected + + args = (cnot_i0_i1,) + expected = (cnot_01,) + actual = convert_to_real_indices(args, qubit_map) + assert actual == expected + + qubit_map = {i0: 1, i1: 0} + args = (cgate_z_i1_i0,) + expected = (cgate_z_01,) + actual = convert_to_real_indices(args, qubit_map) + assert actual == expected + + i2 = Symbol('i2') + ccgate_z = CGate(i0, CGate(i1, Z(i2))) + ccgate_x = CGate(i1, CGate(i2, X(i0))) + + qubit_map = {i0: 0, i1: 1, i2: 2} + args = (ccgate_z, ccgate_x) + expected = (CGate(0, CGate(1, Z(2))), CGate(1, CGate(2, X(0)))) + actual = convert_to_real_indices(Mul(*args), qubit_map) + assert actual == expected + + qubit_map = {i0: 1, i2: 0, i1: 2} + args = (ccgate_x, ccgate_z) + expected = (CGate(2, CGate(0, X(1))), CGate(1, CGate(2, Z(0)))) + actual = convert_to_real_indices(args, qubit_map) + assert actual == expected + + +@slow +def test_random_reduce(): + x = X(0) + y = Y(0) + z = Z(0) + h = H(0) + cnot = CNOT(1, 0) + cgate_z = CGate((0,), Z(1)) + + gate_list = [x, y, z] + ids = list(bfs_identity_search(gate_list, 1, max_depth=4)) + + circuit = (x, y, h, z, cnot) + assert random_reduce(circuit, []) == circuit + assert random_reduce(circuit, ids) == circuit + + seq = [2, 11, 9, 3, 5] + circuit = (x, y, z, x, y, h) + assert random_reduce(circuit, ids, seed=seq) == (x, y, h) + + circuit = (x, x, y, y, z, z) + assert random_reduce(circuit, ids, seed=seq) == (x, x, y, y) + + seq = [14, 13, 0] + assert random_reduce(circuit, ids, seed=seq) == (y, y, z, z) + + gate_list = [x, y, z, h, cnot, cgate_z] + ids = list(bfs_identity_search(gate_list, 2, max_depth=4)) + + seq = [25] + circuit = (x, y, z, y, h, y, h, cgate_z, h, cnot) + expected = (x, y, z, cgate_z, h, cnot) + assert random_reduce(circuit, ids, seed=seq) == expected + circuit = Mul(*circuit) + assert random_reduce(circuit, ids, seed=seq) == expected + + +@slow +def test_random_insert(): + x = X(0) + y = Y(0) + z = Z(0) + h = H(0) + cnot = CNOT(1, 0) + cgate_z = CGate((0,), Z(1)) + + choices = [(x, x)] + circuit = (y, y) + loc, choice = 0, 0 + actual = random_insert(circuit, choices, seed=[loc, choice]) + assert actual == (x, x, y, y) + + circuit = (x, y, z, h) + choices = [(h, h), (x, y, z)] + expected = (x, x, y, z, y, z, h) + loc, choice = 1, 1 + actual = random_insert(circuit, choices, seed=[loc, choice]) + assert actual == expected + + gate_list = [x, y, z, h, cnot, cgate_z] + ids = list(bfs_identity_search(gate_list, 2, max_depth=4)) + + eq_ids = flatten_ids(ids) + + circuit = (x, y, h, cnot, cgate_z) + expected = (x, z, x, z, x, y, h, cnot, cgate_z) + loc, choice = 1, 30 + actual = random_insert(circuit, eq_ids, seed=[loc, choice]) + assert actual == expected + circuit = Mul(*circuit) + actual = random_insert(circuit, eq_ids, seed=[loc, choice]) + assert actual == expected diff --git a/lib/python3.10/site-packages/sympy/physics/quantum/tests/test_commutator.py b/lib/python3.10/site-packages/sympy/physics/quantum/tests/test_commutator.py new file mode 100644 index 0000000000000000000000000000000000000000..04f45feddaca63d7306363a9235c63f534d11430 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/physics/quantum/tests/test_commutator.py @@ -0,0 +1,81 @@ +from sympy.core.numbers import Integer +from sympy.core.symbol import symbols + +from sympy.physics.quantum.dagger import Dagger +from sympy.physics.quantum.commutator import Commutator as Comm +from sympy.physics.quantum.operator import Operator + + +a, b, c = symbols('a,b,c') +n = symbols('n', integer=True) +A, B, C, D = symbols('A,B,C,D', commutative=False) + + +def test_commutator(): + c = Comm(A, B) + assert c.is_commutative is False + assert isinstance(c, Comm) + assert c.subs(A, C) == Comm(C, B) + + +def test_commutator_identities(): + assert Comm(a*A, b*B) == a*b*Comm(A, B) + assert Comm(A, A) == 0 + assert Comm(a, b) == 0 + assert Comm(A, B) == -Comm(B, A) + assert Comm(A, B).doit() == A*B - B*A + assert Comm(A, B*C).expand(commutator=True) == Comm(A, B)*C + B*Comm(A, C) + assert Comm(A*B, C*D).expand(commutator=True) == \ + A*C*Comm(B, D) + A*Comm(B, C)*D + C*Comm(A, D)*B + Comm(A, C)*D*B + assert Comm(A, B**2).expand(commutator=True) == Comm(A, B)*B + B*Comm(A, B) + assert Comm(A**2, C**2).expand(commutator=True) == \ + Comm(A*B, C*D).expand(commutator=True).replace(B, A).replace(D, C) == \ + A*C*Comm(A, C) + A*Comm(A, C)*C + C*Comm(A, C)*A + Comm(A, C)*C*A + assert Comm(A, C**-2).expand(commutator=True) == \ + Comm(A, (1/C)*(1/D)).expand(commutator=True).replace(D, C) + assert Comm(A + B, C + D).expand(commutator=True) == \ + Comm(A, C) + Comm(A, D) + Comm(B, C) + Comm(B, D) + assert Comm(A, B + C).expand(commutator=True) == Comm(A, B) + Comm(A, C) + assert Comm(A**n, B).expand(commutator=True) == Comm(A**n, B) + + e = Comm(A, Comm(B, C)) + Comm(B, Comm(C, A)) + Comm(C, Comm(A, B)) + assert e.doit().expand() == 0 + + +def test_commutator_dagger(): + comm = Comm(A*B, C) + assert Dagger(comm).expand(commutator=True) == \ + - Comm(Dagger(B), Dagger(C))*Dagger(A) - \ + Dagger(B)*Comm(Dagger(A), Dagger(C)) + + +class Foo(Operator): + + def _eval_commutator_Bar(self, bar): + return Integer(0) + + +class Bar(Operator): + pass + + +class Tam(Operator): + + def _eval_commutator_Foo(self, foo): + return Integer(1) + + +def test_eval_commutator(): + F = Foo('F') + B = Bar('B') + T = Tam('T') + assert Comm(F, B).doit() == 0 + assert Comm(B, F).doit() == 0 + assert Comm(F, T).doit() == -1 + assert Comm(T, F).doit() == 1 + assert Comm(B, T).doit() == B*T - T*B + assert Comm(F**2, B).expand(commutator=True).doit() == 0 + assert Comm(F**2, T).expand(commutator=True).doit() == -2*F + assert Comm(F, T**2).expand(commutator=True).doit() == -2*T + assert Comm(T**2, F).expand(commutator=True).doit() == 2*T + assert Comm(T**2, F**3).expand(commutator=True).doit() == 2*F*T*F + 2*F**2*T + 2*T*F**2 diff --git a/lib/python3.10/site-packages/sympy/physics/quantum/tests/test_constants.py b/lib/python3.10/site-packages/sympy/physics/quantum/tests/test_constants.py new file mode 100644 index 0000000000000000000000000000000000000000..48a773ea6b5afbaf956143b50b16b3b18aaf5beb --- /dev/null +++ b/lib/python3.10/site-packages/sympy/physics/quantum/tests/test_constants.py @@ -0,0 +1,13 @@ +from sympy.core.numbers import Float + +from sympy.physics.quantum.constants import hbar + + +def test_hbar(): + assert hbar.is_commutative is True + assert hbar.is_real is True + assert hbar.is_positive is True + assert hbar.is_negative is False + assert hbar.is_irrational is True + + assert hbar.evalf() == Float(1.05457162e-34) diff --git a/lib/python3.10/site-packages/sympy/physics/quantum/tests/test_dagger.py b/lib/python3.10/site-packages/sympy/physics/quantum/tests/test_dagger.py new file mode 100644 index 0000000000000000000000000000000000000000..0d379095deef60d4bc7fc90ac264e72e3ee74a11 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/physics/quantum/tests/test_dagger.py @@ -0,0 +1,102 @@ +from sympy.core.expr import Expr +from sympy.core.mul import Mul +from sympy.core.numbers import (I, Integer) +from sympy.core.symbol import symbols +from sympy.functions.elementary.complexes import conjugate +from sympy.matrices.dense import Matrix + +from sympy.physics.quantum.dagger import adjoint, Dagger +from sympy.external import import_module +from sympy.testing.pytest import skip +from sympy.physics.quantum.operator import Operator, IdentityOperator + + +def test_scalars(): + x = symbols('x', complex=True) + assert Dagger(x) == conjugate(x) + assert Dagger(I*x) == -I*conjugate(x) + + i = symbols('i', real=True) + assert Dagger(i) == i + + p = symbols('p') + assert isinstance(Dagger(p), adjoint) + + i = Integer(3) + assert Dagger(i) == i + + A = symbols('A', commutative=False) + assert Dagger(A).is_commutative is False + + +def test_matrix(): + x = symbols('x') + m = Matrix([[I, x*I], [2, 4]]) + assert Dagger(m) == m.H + + +def test_dagger_mul(): + O = Operator('O') + I = IdentityOperator() + assert Dagger(O)*O == Dagger(O)*O + assert Dagger(O)*O*I == Mul(Dagger(O), O)*I + assert Dagger(O)*Dagger(O) == Dagger(O)**2 + assert Dagger(O)*Dagger(I) == Dagger(O) + + +class Foo(Expr): + + def _eval_adjoint(self): + return I + + +def test_eval_adjoint(): + f = Foo() + d = Dagger(f) + assert d == I + +np = import_module('numpy') + + +def test_numpy_dagger(): + if not np: + skip("numpy not installed.") + + a = np.array([[1.0, 2.0j], [-1.0j, 2.0]]) + adag = a.copy().transpose().conjugate() + assert (Dagger(a) == adag).all() + + +scipy = import_module('scipy', import_kwargs={'fromlist': ['sparse']}) + + +def test_scipy_sparse_dagger(): + if not np: + skip("numpy not installed.") + if not scipy: + skip("scipy not installed.") + else: + sparse = scipy.sparse + + a = sparse.csr_matrix([[1.0 + 0.0j, 2.0j], [-1.0j, 2.0 + 0.0j]]) + adag = a.copy().transpose().conjugate() + assert np.linalg.norm((Dagger(a) - adag).todense()) == 0.0 + + +def test_unknown(): + """Check treatment of unknown objects. + Objects without adjoint or conjugate/transpose methods + are sympified and wrapped in dagger. + """ + x = symbols("x") + result = Dagger(x) + assert result.args == (x,) and isinstance(result, adjoint) + + +def test_unevaluated(): + """Check that evaluate=False returns unevaluated Dagger. + """ + x = symbols("x", real=True) + assert Dagger(x) == x + result = Dagger(x, evaluate=False) + assert result.args == (x,) and isinstance(result, adjoint) diff --git a/lib/python3.10/site-packages/sympy/physics/quantum/tests/test_density.py b/lib/python3.10/site-packages/sympy/physics/quantum/tests/test_density.py new file mode 100644 index 0000000000000000000000000000000000000000..399acce6e201b39f65ea674048198fd2f087b4d0 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/physics/quantum/tests/test_density.py @@ -0,0 +1,289 @@ +from sympy.core.numbers import Rational +from sympy.core.singleton import S +from sympy.core.symbol import symbols +from sympy.functions.elementary.exponential import log +from sympy.external import import_module +from sympy.physics.quantum.density import Density, entropy, fidelity +from sympy.physics.quantum.state import Ket, TimeDepKet +from sympy.physics.quantum.qubit import Qubit +from sympy.physics.quantum.represent import represent +from sympy.physics.quantum.dagger import Dagger +from sympy.physics.quantum.cartesian import XKet, PxKet, PxOp, XOp +from sympy.physics.quantum.spin import JzKet +from sympy.physics.quantum.operator import OuterProduct +from sympy.physics.quantum.trace import Tr +from sympy.functions import sqrt +from sympy.testing.pytest import raises +from sympy.physics.quantum.matrixutils import scipy_sparse_matrix +from sympy.physics.quantum.tensorproduct import TensorProduct + + +def test_eval_args(): + # check instance created + assert isinstance(Density([Ket(0), 0.5], [Ket(1), 0.5]), Density) + assert isinstance(Density([Qubit('00'), 1/sqrt(2)], + [Qubit('11'), 1/sqrt(2)]), Density) + + #test if Qubit object type preserved + d = Density([Qubit('00'), 1/sqrt(2)], [Qubit('11'), 1/sqrt(2)]) + for (state, prob) in d.args: + assert isinstance(state, Qubit) + + # check for value error, when prob is not provided + raises(ValueError, lambda: Density([Ket(0)], [Ket(1)])) + + +def test_doit(): + + x, y = symbols('x y') + A, B, C, D, E, F = symbols('A B C D E F', commutative=False) + d = Density([XKet(), 0.5], [PxKet(), 0.5]) + assert (0.5*(PxKet()*Dagger(PxKet())) + + 0.5*(XKet()*Dagger(XKet()))) == d.doit() + + # check for kets with expr in them + d_with_sym = Density([XKet(x*y), 0.5], [PxKet(x*y), 0.5]) + assert (0.5*(PxKet(x*y)*Dagger(PxKet(x*y))) + + 0.5*(XKet(x*y)*Dagger(XKet(x*y)))) == d_with_sym.doit() + + d = Density([(A + B)*C, 1.0]) + assert d.doit() == (1.0*A*C*Dagger(C)*Dagger(A) + + 1.0*A*C*Dagger(C)*Dagger(B) + + 1.0*B*C*Dagger(C)*Dagger(A) + + 1.0*B*C*Dagger(C)*Dagger(B)) + + # With TensorProducts as args + # Density with simple tensor products as args + t = TensorProduct(A, B, C) + d = Density([t, 1.0]) + assert d.doit() == \ + 1.0 * TensorProduct(A*Dagger(A), B*Dagger(B), C*Dagger(C)) + + # Density with multiple Tensorproducts as states + t2 = TensorProduct(A, B) + t3 = TensorProduct(C, D) + + d = Density([t2, 0.5], [t3, 0.5]) + assert d.doit() == (0.5 * TensorProduct(A*Dagger(A), B*Dagger(B)) + + 0.5 * TensorProduct(C*Dagger(C), D*Dagger(D))) + + #Density with mixed states + d = Density([t2 + t3, 1.0]) + assert d.doit() == (1.0 * TensorProduct(A*Dagger(A), B*Dagger(B)) + + 1.0 * TensorProduct(A*Dagger(C), B*Dagger(D)) + + 1.0 * TensorProduct(C*Dagger(A), D*Dagger(B)) + + 1.0 * TensorProduct(C*Dagger(C), D*Dagger(D))) + + #Density operators with spin states + tp1 = TensorProduct(JzKet(1, 1), JzKet(1, -1)) + d = Density([tp1, 1]) + + # full trace + t = Tr(d) + assert t.doit() == 1 + + #Partial trace on density operators with spin states + t = Tr(d, [0]) + assert t.doit() == JzKet(1, -1) * Dagger(JzKet(1, -1)) + t = Tr(d, [1]) + assert t.doit() == JzKet(1, 1) * Dagger(JzKet(1, 1)) + + # with another spin state + tp2 = TensorProduct(JzKet(S.Half, S.Half), JzKet(S.Half, Rational(-1, 2))) + d = Density([tp2, 1]) + + #full trace + t = Tr(d) + assert t.doit() == 1 + + #Partial trace on density operators with spin states + t = Tr(d, [0]) + assert t.doit() == JzKet(S.Half, Rational(-1, 2)) * Dagger(JzKet(S.Half, Rational(-1, 2))) + t = Tr(d, [1]) + assert t.doit() == JzKet(S.Half, S.Half) * Dagger(JzKet(S.Half, S.Half)) + + +def test_apply_op(): + d = Density([Ket(0), 0.5], [Ket(1), 0.5]) + assert d.apply_op(XOp()) == Density([XOp()*Ket(0), 0.5], + [XOp()*Ket(1), 0.5]) + + +def test_represent(): + x, y = symbols('x y') + d = Density([XKet(), 0.5], [PxKet(), 0.5]) + assert (represent(0.5*(PxKet()*Dagger(PxKet()))) + + represent(0.5*(XKet()*Dagger(XKet())))) == represent(d) + + # check for kets with expr in them + d_with_sym = Density([XKet(x*y), 0.5], [PxKet(x*y), 0.5]) + assert (represent(0.5*(PxKet(x*y)*Dagger(PxKet(x*y)))) + + represent(0.5*(XKet(x*y)*Dagger(XKet(x*y))))) == \ + represent(d_with_sym) + + # check when given explicit basis + assert (represent(0.5*(XKet()*Dagger(XKet())), basis=PxOp()) + + represent(0.5*(PxKet()*Dagger(PxKet())), basis=PxOp())) == \ + represent(d, basis=PxOp()) + + +def test_states(): + d = Density([Ket(0), 0.5], [Ket(1), 0.5]) + states = d.states() + assert states[0] == Ket(0) and states[1] == Ket(1) + + +def test_probs(): + d = Density([Ket(0), .75], [Ket(1), 0.25]) + probs = d.probs() + assert probs[0] == 0.75 and probs[1] == 0.25 + + #probs can be symbols + x, y = symbols('x y') + d = Density([Ket(0), x], [Ket(1), y]) + probs = d.probs() + assert probs[0] == x and probs[1] == y + + +def test_get_state(): + x, y = symbols('x y') + d = Density([Ket(0), x], [Ket(1), y]) + states = (d.get_state(0), d.get_state(1)) + assert states[0] == Ket(0) and states[1] == Ket(1) + + +def test_get_prob(): + x, y = symbols('x y') + d = Density([Ket(0), x], [Ket(1), y]) + probs = (d.get_prob(0), d.get_prob(1)) + assert probs[0] == x and probs[1] == y + + +def test_entropy(): + up = JzKet(S.Half, S.Half) + down = JzKet(S.Half, Rational(-1, 2)) + d = Density((up, S.Half), (down, S.Half)) + + # test for density object + ent = entropy(d) + assert entropy(d) == log(2)/2 + assert d.entropy() == log(2)/2 + + np = import_module('numpy', min_module_version='1.4.0') + if np: + #do this test only if 'numpy' is available on test machine + np_mat = represent(d, format='numpy') + ent = entropy(np_mat) + assert isinstance(np_mat, np.ndarray) + assert ent.real == 0.69314718055994529 + assert ent.imag == 0 + + scipy = import_module('scipy', import_kwargs={'fromlist': ['sparse']}) + if scipy and np: + #do this test only if numpy and scipy are available + mat = represent(d, format="scipy.sparse") + assert isinstance(mat, scipy_sparse_matrix) + assert ent.real == 0.69314718055994529 + assert ent.imag == 0 + + +def test_eval_trace(): + up = JzKet(S.Half, S.Half) + down = JzKet(S.Half, Rational(-1, 2)) + d = Density((up, 0.5), (down, 0.5)) + + t = Tr(d) + assert t.doit() == 1.0 + + #test dummy time dependent states + class TestTimeDepKet(TimeDepKet): + def _eval_trace(self, bra, **options): + return 1 + + x, t = symbols('x t') + k1 = TestTimeDepKet(0, 0.5) + k2 = TestTimeDepKet(0, 1) + d = Density([k1, 0.5], [k2, 0.5]) + assert d.doit() == (0.5 * OuterProduct(k1, k1.dual) + + 0.5 * OuterProduct(k2, k2.dual)) + + t = Tr(d) + assert t.doit() == 1.0 + + +def test_fidelity(): + #test with kets + up = JzKet(S.Half, S.Half) + down = JzKet(S.Half, Rational(-1, 2)) + updown = (S.One/sqrt(2))*up + (S.One/sqrt(2))*down + + #check with matrices + up_dm = represent(up * Dagger(up)) + down_dm = represent(down * Dagger(down)) + updown_dm = represent(updown * Dagger(updown)) + + assert abs(fidelity(up_dm, up_dm) - 1) < 1e-3 + assert fidelity(up_dm, down_dm) < 1e-3 + assert abs(fidelity(up_dm, updown_dm) - (S.One/sqrt(2))) < 1e-3 + assert abs(fidelity(updown_dm, down_dm) - (S.One/sqrt(2))) < 1e-3 + + #check with density + up_dm = Density([up, 1.0]) + down_dm = Density([down, 1.0]) + updown_dm = Density([updown, 1.0]) + + assert abs(fidelity(up_dm, up_dm) - 1) < 1e-3 + assert abs(fidelity(up_dm, down_dm)) < 1e-3 + assert abs(fidelity(up_dm, updown_dm) - (S.One/sqrt(2))) < 1e-3 + assert abs(fidelity(updown_dm, down_dm) - (S.One/sqrt(2))) < 1e-3 + + #check mixed states with density + updown2 = sqrt(3)/2*up + S.Half*down + d1 = Density([updown, 0.25], [updown2, 0.75]) + d2 = Density([updown, 0.75], [updown2, 0.25]) + assert abs(fidelity(d1, d2) - 0.991) < 1e-3 + assert abs(fidelity(d2, d1) - fidelity(d1, d2)) < 1e-3 + + #using qubits/density(pure states) + state1 = Qubit('0') + state2 = Qubit('1') + state3 = S.One/sqrt(2)*state1 + S.One/sqrt(2)*state2 + state4 = sqrt(Rational(2, 3))*state1 + S.One/sqrt(3)*state2 + + state1_dm = Density([state1, 1]) + state2_dm = Density([state2, 1]) + state3_dm = Density([state3, 1]) + + assert fidelity(state1_dm, state1_dm) == 1 + assert fidelity(state1_dm, state2_dm) == 0 + assert abs(fidelity(state1_dm, state3_dm) - 1/sqrt(2)) < 1e-3 + assert abs(fidelity(state3_dm, state2_dm) - 1/sqrt(2)) < 1e-3 + + #using qubits/density(mixed states) + d1 = Density([state3, 0.70], [state4, 0.30]) + d2 = Density([state3, 0.20], [state4, 0.80]) + assert abs(fidelity(d1, d1) - 1) < 1e-3 + assert abs(fidelity(d1, d2) - 0.996) < 1e-3 + assert abs(fidelity(d1, d2) - fidelity(d2, d1)) < 1e-3 + + #TODO: test for invalid arguments + # non-square matrix + mat1 = [[0, 0], + [0, 0], + [0, 0]] + + mat2 = [[0, 0], + [0, 0]] + raises(ValueError, lambda: fidelity(mat1, mat2)) + + # unequal dimensions + mat1 = [[0, 0], + [0, 0]] + mat2 = [[0, 0, 0], + [0, 0, 0], + [0, 0, 0]] + raises(ValueError, lambda: fidelity(mat1, mat2)) + + # unsupported data-type + x, y = 1, 2 # random values that is not a matrix + raises(ValueError, lambda: fidelity(x, y)) diff --git a/lib/python3.10/site-packages/sympy/physics/quantum/tests/test_fermion.py b/lib/python3.10/site-packages/sympy/physics/quantum/tests/test_fermion.py new file mode 100644 index 0000000000000000000000000000000000000000..061648c2d5578481196949c38e90ff169fcea972 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/physics/quantum/tests/test_fermion.py @@ -0,0 +1,62 @@ +from pytest import raises + +import sympy +from sympy.physics.quantum import Dagger, AntiCommutator, qapply +from sympy.physics.quantum.fermion import FermionOp +from sympy.physics.quantum.fermion import FermionFockKet, FermionFockBra +from sympy import Symbol + + +def test_fermionoperator(): + c = FermionOp('c') + d = FermionOp('d') + + assert isinstance(c, FermionOp) + assert isinstance(Dagger(c), FermionOp) + + assert c.is_annihilation + assert not Dagger(c).is_annihilation + + assert FermionOp("c") == FermionOp("c", True) + assert FermionOp("c") != FermionOp("d") + assert FermionOp("c", True) != FermionOp("c", False) + + assert AntiCommutator(c, Dagger(c)).doit() == 1 + + assert AntiCommutator(c, Dagger(d)).doit() == c * Dagger(d) + Dagger(d) * c + + +def test_fermion_states(): + c = FermionOp("c") + + # Fock states + assert (FermionFockBra(0) * FermionFockKet(1)).doit() == 0 + assert (FermionFockBra(1) * FermionFockKet(1)).doit() == 1 + + assert qapply(c * FermionFockKet(1)) == FermionFockKet(0) + assert qapply(c * FermionFockKet(0)) == 0 + + assert qapply(Dagger(c) * FermionFockKet(0)) == FermionFockKet(1) + assert qapply(Dagger(c) * FermionFockKet(1)) == 0 + + +def test_power(): + c = FermionOp("c") + assert c**0 == 1 + assert c**1 == c + assert c**2 == 0 + assert c**3 == 0 + assert Dagger(c)**1 == Dagger(c) + assert Dagger(c)**2 == 0 + + assert (c**Symbol('a')).func == sympy.core.power.Pow + assert (c**Symbol('a')).args == (c, Symbol('a')) + + with raises(ValueError): + c**-1 + + with raises(ValueError): + c**3.2 + + with raises(TypeError): + c**1j diff --git a/lib/python3.10/site-packages/sympy/physics/quantum/tests/test_gate.py b/lib/python3.10/site-packages/sympy/physics/quantum/tests/test_gate.py new file mode 100644 index 0000000000000000000000000000000000000000..2d7bf1d624faca8afe4b10699d23acc161ca0cdd --- /dev/null +++ b/lib/python3.10/site-packages/sympy/physics/quantum/tests/test_gate.py @@ -0,0 +1,360 @@ +from sympy.core.mul import Mul +from sympy.core.numbers import (I, Integer, Rational, pi) +from sympy.core.symbol import (Wild, symbols) +from sympy.functions.elementary.exponential import exp +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.matrices import Matrix, ImmutableMatrix + +from sympy.physics.quantum.gate import (XGate, YGate, ZGate, random_circuit, + CNOT, IdentityGate, H, X, Y, S, T, Z, SwapGate, gate_simp, gate_sort, + CNotGate, TGate, HadamardGate, PhaseGate, UGate, CGate) +from sympy.physics.quantum.commutator import Commutator +from sympy.physics.quantum.anticommutator import AntiCommutator +from sympy.physics.quantum.represent import represent +from sympy.physics.quantum.qapply import qapply +from sympy.physics.quantum.qubit import Qubit, IntQubit, qubit_to_matrix, \ + matrix_to_qubit +from sympy.physics.quantum.matrixutils import matrix_to_zero +from sympy.physics.quantum.matrixcache import sqrt2_inv +from sympy.physics.quantum import Dagger + + +def test_gate(): + """Test a basic gate.""" + h = HadamardGate(1) + assert h.min_qubits == 2 + assert h.nqubits == 1 + + i0 = Wild('i0') + i1 = Wild('i1') + h0_w1 = HadamardGate(i0) + h0_w2 = HadamardGate(i0) + h1_w1 = HadamardGate(i1) + + assert h0_w1 == h0_w2 + assert h0_w1 != h1_w1 + assert h1_w1 != h0_w2 + + cnot_10_w1 = CNOT(i1, i0) + cnot_10_w2 = CNOT(i1, i0) + cnot_01_w1 = CNOT(i0, i1) + + assert cnot_10_w1 == cnot_10_w2 + assert cnot_10_w1 != cnot_01_w1 + assert cnot_10_w2 != cnot_01_w1 + + +def test_UGate(): + a, b, c, d = symbols('a,b,c,d') + uMat = Matrix([[a, b], [c, d]]) + + # Test basic case where gate exists in 1-qubit space + u1 = UGate((0,), uMat) + assert represent(u1, nqubits=1) == uMat + assert qapply(u1*Qubit('0')) == a*Qubit('0') + c*Qubit('1') + assert qapply(u1*Qubit('1')) == b*Qubit('0') + d*Qubit('1') + + # Test case where gate exists in a larger space + u2 = UGate((1,), uMat) + u2Rep = represent(u2, nqubits=2) + for i in range(4): + assert u2Rep*qubit_to_matrix(IntQubit(i, 2)) == \ + qubit_to_matrix(qapply(u2*IntQubit(i, 2))) + + +def test_cgate(): + """Test the general CGate.""" + # Test single control functionality + CNOTMatrix = Matrix( + [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 0, 1], [0, 0, 1, 0]]) + assert represent(CGate(1, XGate(0)), nqubits=2) == CNOTMatrix + + # Test multiple control bit functionality + ToffoliGate = CGate((1, 2), XGate(0)) + assert represent(ToffoliGate, nqubits=3) == \ + Matrix( + [[1, 0, 0, 0, 0, 0, 0, 0], [0, 1, 0, 0, 0, 0, 0, 0], [0, 0, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 1, 0, 0, 0, 0], [0, 0, 0, 0, 1, 0, 0, 0], [0, 0, 0, 0, 0, + 1, 0, 0], [0, 0, 0, 0, 0, 0, 0, 1], + [0, 0, 0, 0, 0, 0, 1, 0]]) + + ToffoliGate = CGate((3, 0), XGate(1)) + assert qapply(ToffoliGate*Qubit('1001')) == \ + matrix_to_qubit(represent(ToffoliGate*Qubit('1001'), nqubits=4)) + assert qapply(ToffoliGate*Qubit('0000')) == \ + matrix_to_qubit(represent(ToffoliGate*Qubit('0000'), nqubits=4)) + + CYGate = CGate(1, YGate(0)) + CYGate_matrix = Matrix( + ((1, 0, 0, 0), (0, 1, 0, 0), (0, 0, 0, -I), (0, 0, I, 0))) + # Test 2 qubit controlled-Y gate decompose method. + assert represent(CYGate.decompose(), nqubits=2) == CYGate_matrix + + CZGate = CGate(0, ZGate(1)) + CZGate_matrix = Matrix( + ((1, 0, 0, 0), (0, 1, 0, 0), (0, 0, 1, 0), (0, 0, 0, -1))) + assert qapply(CZGate*Qubit('11')) == -Qubit('11') + assert matrix_to_qubit(represent(CZGate*Qubit('11'), nqubits=2)) == \ + -Qubit('11') + # Test 2 qubit controlled-Z gate decompose method. + assert represent(CZGate.decompose(), nqubits=2) == CZGate_matrix + + CPhaseGate = CGate(0, PhaseGate(1)) + assert qapply(CPhaseGate*Qubit('11')) == \ + I*Qubit('11') + assert matrix_to_qubit(represent(CPhaseGate*Qubit('11'), nqubits=2)) == \ + I*Qubit('11') + + # Test that the dagger, inverse, and power of CGate is evaluated properly + assert Dagger(CZGate) == CZGate + assert pow(CZGate, 1) == Dagger(CZGate) + assert Dagger(CZGate) == CZGate.inverse() + assert Dagger(CPhaseGate) != CPhaseGate + assert Dagger(CPhaseGate) == CPhaseGate.inverse() + assert Dagger(CPhaseGate) == pow(CPhaseGate, -1) + assert pow(CPhaseGate, -1) == CPhaseGate.inverse() + + +def test_UGate_CGate_combo(): + a, b, c, d = symbols('a,b,c,d') + uMat = Matrix([[a, b], [c, d]]) + cMat = Matrix([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, a, b], [0, 0, c, d]]) + + # Test basic case where gate exists in 1-qubit space. + u1 = UGate((0,), uMat) + cu1 = CGate(1, u1) + assert represent(cu1, nqubits=2) == cMat + assert qapply(cu1*Qubit('10')) == a*Qubit('10') + c*Qubit('11') + assert qapply(cu1*Qubit('11')) == b*Qubit('10') + d*Qubit('11') + assert qapply(cu1*Qubit('01')) == Qubit('01') + assert qapply(cu1*Qubit('00')) == Qubit('00') + + # Test case where gate exists in a larger space. + u2 = UGate((1,), uMat) + u2Rep = represent(u2, nqubits=2) + for i in range(4): + assert u2Rep*qubit_to_matrix(IntQubit(i, 2)) == \ + qubit_to_matrix(qapply(u2*IntQubit(i, 2))) + +def test_UGate_OneQubitGate_combo(): + v, w, f, g = symbols('v w f g') + uMat1 = ImmutableMatrix([[v, w], [f, g]]) + cMat1 = Matrix([[v, w + 1, 0, 0], [f + 1, g, 0, 0], [0, 0, v, w + 1], [0, 0, f + 1, g]]) + u1 = X(0) + UGate(0, uMat1) + assert represent(u1, nqubits=2) == cMat1 + + uMat2 = ImmutableMatrix([[1/sqrt(2), 1/sqrt(2)], [I/sqrt(2), -I/sqrt(2)]]) + cMat2_1 = Matrix([[Rational(1, 2) + I/2, Rational(1, 2) - I/2], + [Rational(1, 2) - I/2, Rational(1, 2) + I/2]]) + cMat2_2 = Matrix([[1, 0], [0, I]]) + u2 = UGate(0, uMat2) + assert represent(H(0)*u2, nqubits=1) == cMat2_1 + assert represent(u2*H(0), nqubits=1) == cMat2_2 + +def test_represent_hadamard(): + """Test the representation of the hadamard gate.""" + circuit = HadamardGate(0)*Qubit('00') + answer = represent(circuit, nqubits=2) + # Check that the answers are same to within an epsilon. + assert answer == Matrix([sqrt2_inv, sqrt2_inv, 0, 0]) + + +def test_represent_xgate(): + """Test the representation of the X gate.""" + circuit = XGate(0)*Qubit('00') + answer = represent(circuit, nqubits=2) + assert Matrix([0, 1, 0, 0]) == answer + + +def test_represent_ygate(): + """Test the representation of the Y gate.""" + circuit = YGate(0)*Qubit('00') + answer = represent(circuit, nqubits=2) + assert answer[0] == 0 and answer[1] == I and \ + answer[2] == 0 and answer[3] == 0 + + +def test_represent_zgate(): + """Test the representation of the Z gate.""" + circuit = ZGate(0)*Qubit('00') + answer = represent(circuit, nqubits=2) + assert Matrix([1, 0, 0, 0]) == answer + + +def test_represent_phasegate(): + """Test the representation of the S gate.""" + circuit = PhaseGate(0)*Qubit('01') + answer = represent(circuit, nqubits=2) + assert Matrix([0, I, 0, 0]) == answer + + +def test_represent_tgate(): + """Test the representation of the T gate.""" + circuit = TGate(0)*Qubit('01') + assert Matrix([0, exp(I*pi/4), 0, 0]) == represent(circuit, nqubits=2) + + +def test_compound_gates(): + """Test a compound gate representation.""" + circuit = YGate(0)*ZGate(0)*XGate(0)*HadamardGate(0)*Qubit('00') + answer = represent(circuit, nqubits=2) + assert Matrix([I/sqrt(2), I/sqrt(2), 0, 0]) == answer + + +def test_cnot_gate(): + """Test the CNOT gate.""" + circuit = CNotGate(1, 0) + assert represent(circuit, nqubits=2) == \ + Matrix([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 0, 1], [0, 0, 1, 0]]) + circuit = circuit*Qubit('111') + assert matrix_to_qubit(represent(circuit, nqubits=3)) == \ + qapply(circuit) + + circuit = CNotGate(1, 0) + assert Dagger(circuit) == circuit + assert Dagger(Dagger(circuit)) == circuit + assert circuit*circuit == 1 + + +def test_gate_sort(): + """Test gate_sort.""" + for g in (X, Y, Z, H, S, T): + assert gate_sort(g(2)*g(1)*g(0)) == g(0)*g(1)*g(2) + e = gate_sort(X(1)*H(0)**2*CNOT(0, 1)*X(1)*X(0)) + assert e == H(0)**2*CNOT(0, 1)*X(0)*X(1)**2 + assert gate_sort(Z(0)*X(0)) == -X(0)*Z(0) + assert gate_sort(Z(0)*X(0)**2) == X(0)**2*Z(0) + assert gate_sort(Y(0)*H(0)) == -H(0)*Y(0) + assert gate_sort(Y(0)*X(0)) == -X(0)*Y(0) + assert gate_sort(Z(0)*Y(0)) == -Y(0)*Z(0) + assert gate_sort(T(0)*S(0)) == S(0)*T(0) + assert gate_sort(Z(0)*S(0)) == S(0)*Z(0) + assert gate_sort(Z(0)*T(0)) == T(0)*Z(0) + assert gate_sort(Z(0)*CNOT(0, 1)) == CNOT(0, 1)*Z(0) + assert gate_sort(S(0)*CNOT(0, 1)) == CNOT(0, 1)*S(0) + assert gate_sort(T(0)*CNOT(0, 1)) == CNOT(0, 1)*T(0) + assert gate_sort(X(1)*CNOT(0, 1)) == CNOT(0, 1)*X(1) + # This takes a long time and should only be uncommented once in a while. + # nqubits = 5 + # ngates = 10 + # trials = 10 + # for i in range(trials): + # c = random_circuit(ngates, nqubits) + # assert represent(c, nqubits=nqubits) == \ + # represent(gate_sort(c), nqubits=nqubits) + + +def test_gate_simp(): + """Test gate_simp.""" + e = H(0)*X(1)*H(0)**2*CNOT(0, 1)*X(1)**3*X(0)*Z(3)**2*S(4)**3 + assert gate_simp(e) == H(0)*CNOT(0, 1)*S(4)*X(0)*Z(4) + assert gate_simp(X(0)*X(0)) == 1 + assert gate_simp(Y(0)*Y(0)) == 1 + assert gate_simp(Z(0)*Z(0)) == 1 + assert gate_simp(H(0)*H(0)) == 1 + assert gate_simp(T(0)*T(0)) == S(0) + assert gate_simp(S(0)*S(0)) == Z(0) + assert gate_simp(Integer(1)) == Integer(1) + assert gate_simp(X(0)**2 + Y(0)**2) == Integer(2) + + +def test_swap_gate(): + """Test the SWAP gate.""" + swap_gate_matrix = Matrix( + ((1, 0, 0, 0), (0, 0, 1, 0), (0, 1, 0, 0), (0, 0, 0, 1))) + assert represent(SwapGate(1, 0).decompose(), nqubits=2) == swap_gate_matrix + assert qapply(SwapGate(1, 3)*Qubit('0010')) == Qubit('1000') + nqubits = 4 + for i in range(nqubits): + for j in range(i): + assert represent(SwapGate(i, j), nqubits=nqubits) == \ + represent(SwapGate(i, j).decompose(), nqubits=nqubits) + + +def test_one_qubit_commutators(): + """Test single qubit gate commutation relations.""" + for g1 in (IdentityGate, X, Y, Z, H, T, S): + for g2 in (IdentityGate, X, Y, Z, H, T, S): + e = Commutator(g1(0), g2(0)) + a = matrix_to_zero(represent(e, nqubits=1, format='sympy')) + b = matrix_to_zero(represent(e.doit(), nqubits=1, format='sympy')) + assert a == b + + e = Commutator(g1(0), g2(1)) + assert e.doit() == 0 + + +def test_one_qubit_anticommutators(): + """Test single qubit gate anticommutation relations.""" + for g1 in (IdentityGate, X, Y, Z, H): + for g2 in (IdentityGate, X, Y, Z, H): + e = AntiCommutator(g1(0), g2(0)) + a = matrix_to_zero(represent(e, nqubits=1, format='sympy')) + b = matrix_to_zero(represent(e.doit(), nqubits=1, format='sympy')) + assert a == b + e = AntiCommutator(g1(0), g2(1)) + a = matrix_to_zero(represent(e, nqubits=2, format='sympy')) + b = matrix_to_zero(represent(e.doit(), nqubits=2, format='sympy')) + assert a == b + + +def test_cnot_commutators(): + """Test commutators of involving CNOT gates.""" + assert Commutator(CNOT(0, 1), Z(0)).doit() == 0 + assert Commutator(CNOT(0, 1), T(0)).doit() == 0 + assert Commutator(CNOT(0, 1), S(0)).doit() == 0 + assert Commutator(CNOT(0, 1), X(1)).doit() == 0 + assert Commutator(CNOT(0, 1), CNOT(0, 1)).doit() == 0 + assert Commutator(CNOT(0, 1), CNOT(0, 2)).doit() == 0 + assert Commutator(CNOT(0, 2), CNOT(0, 1)).doit() == 0 + assert Commutator(CNOT(1, 2), CNOT(1, 0)).doit() == 0 + + +def test_random_circuit(): + c = random_circuit(10, 3) + assert isinstance(c, Mul) + m = represent(c, nqubits=3) + assert m.shape == (8, 8) + assert isinstance(m, Matrix) + + +def test_hermitian_XGate(): + x = XGate(1, 2) + x_dagger = Dagger(x) + + assert (x == x_dagger) + + +def test_hermitian_YGate(): + y = YGate(1, 2) + y_dagger = Dagger(y) + + assert (y == y_dagger) + + +def test_hermitian_ZGate(): + z = ZGate(1, 2) + z_dagger = Dagger(z) + + assert (z == z_dagger) + + +def test_unitary_XGate(): + x = XGate(1, 2) + x_dagger = Dagger(x) + + assert (x*x_dagger == 1) + + +def test_unitary_YGate(): + y = YGate(1, 2) + y_dagger = Dagger(y) + + assert (y*y_dagger == 1) + + +def test_unitary_ZGate(): + z = ZGate(1, 2) + z_dagger = Dagger(z) + + assert (z*z_dagger == 1) diff --git a/lib/python3.10/site-packages/sympy/physics/quantum/tests/test_grover.py b/lib/python3.10/site-packages/sympy/physics/quantum/tests/test_grover.py new file mode 100644 index 0000000000000000000000000000000000000000..b93a5bc5e59380a993dc34e4a160e75f799b3493 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/physics/quantum/tests/test_grover.py @@ -0,0 +1,92 @@ +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.matrices.dense import Matrix +from sympy.physics.quantum.represent import represent +from sympy.physics.quantum.qapply import qapply +from sympy.physics.quantum.qubit import IntQubit +from sympy.physics.quantum.grover import (apply_grover, superposition_basis, + OracleGate, grover_iteration, WGate) + + +def return_one_on_two(qubits): + return qubits == IntQubit(2, qubits.nqubits) + + +def return_one_on_one(qubits): + return qubits == IntQubit(1, nqubits=qubits.nqubits) + + +def test_superposition_basis(): + nbits = 2 + first_half_state = IntQubit(0, nqubits=nbits)/2 + IntQubit(1, nqubits=nbits)/2 + second_half_state = IntQubit(2, nbits)/2 + IntQubit(3, nbits)/2 + assert first_half_state + second_half_state == superposition_basis(nbits) + + nbits = 3 + firstq = (1/sqrt(8))*IntQubit(0, nqubits=nbits) + (1/sqrt(8))*IntQubit(1, nqubits=nbits) + secondq = (1/sqrt(8))*IntQubit(2, nbits) + (1/sqrt(8))*IntQubit(3, nbits) + thirdq = (1/sqrt(8))*IntQubit(4, nbits) + (1/sqrt(8))*IntQubit(5, nbits) + fourthq = (1/sqrt(8))*IntQubit(6, nbits) + (1/sqrt(8))*IntQubit(7, nbits) + assert firstq + secondq + thirdq + fourthq == superposition_basis(nbits) + + +def test_OracleGate(): + v = OracleGate(1, lambda qubits: qubits == IntQubit(0)) + assert qapply(v*IntQubit(0)) == -IntQubit(0) + assert qapply(v*IntQubit(1)) == IntQubit(1) + + nbits = 2 + v = OracleGate(2, return_one_on_two) + assert qapply(v*IntQubit(0, nbits)) == IntQubit(0, nqubits=nbits) + assert qapply(v*IntQubit(1, nbits)) == IntQubit(1, nqubits=nbits) + assert qapply(v*IntQubit(2, nbits)) == -IntQubit(2, nbits) + assert qapply(v*IntQubit(3, nbits)) == IntQubit(3, nbits) + + assert represent(OracleGate(1, lambda qubits: qubits == IntQubit(0)), nqubits=1) == \ + Matrix([[-1, 0], [0, 1]]) + assert represent(v, nqubits=2) == Matrix([[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]]) + + +def test_WGate(): + nqubits = 2 + basis_states = superposition_basis(nqubits) + assert qapply(WGate(nqubits)*basis_states) == basis_states + + expected = ((2/sqrt(pow(2, nqubits)))*basis_states) - IntQubit(1, nqubits=nqubits) + assert qapply(WGate(nqubits)*IntQubit(1, nqubits=nqubits)) == expected + + +def test_grover_iteration_1(): + numqubits = 2 + basis_states = superposition_basis(numqubits) + v = OracleGate(numqubits, return_one_on_one) + expected = IntQubit(1, nqubits=numqubits) + assert qapply(grover_iteration(basis_states, v)) == expected + + +def test_grover_iteration_2(): + numqubits = 4 + basis_states = superposition_basis(numqubits) + v = OracleGate(numqubits, return_one_on_two) + # After (pi/4)sqrt(pow(2, n)), IntQubit(2) should have highest prob + # In this case, after around pi times (3 or 4) + iterated = grover_iteration(basis_states, v) + iterated = qapply(iterated) + iterated = grover_iteration(iterated, v) + iterated = qapply(iterated) + iterated = grover_iteration(iterated, v) + iterated = qapply(iterated) + # In this case, probability was highest after 3 iterations + # Probability of Qubit('0010') was 251/256 (3) vs 781/1024 (4) + # Ask about measurement + expected = (-13*basis_states)/64 + 264*IntQubit(2, numqubits)/256 + assert qapply(expected) == iterated + + +def test_grover(): + nqubits = 2 + assert apply_grover(return_one_on_one, nqubits) == IntQubit(1, nqubits=nqubits) + + nqubits = 4 + basis_states = superposition_basis(nqubits) + expected = (-13*basis_states)/64 + 264*IntQubit(2, nqubits)/256 + assert apply_grover(return_one_on_two, 4) == qapply(expected) diff --git a/lib/python3.10/site-packages/sympy/physics/quantum/tests/test_hilbert.py b/lib/python3.10/site-packages/sympy/physics/quantum/tests/test_hilbert.py new file mode 100644 index 0000000000000000000000000000000000000000..9a0e5c4187c6c62e14505efb1597a5cd63c23fea --- /dev/null +++ b/lib/python3.10/site-packages/sympy/physics/quantum/tests/test_hilbert.py @@ -0,0 +1,110 @@ +from sympy.physics.quantum.hilbert import ( + HilbertSpace, ComplexSpace, L2, FockSpace, TensorProductHilbertSpace, + DirectSumHilbertSpace, TensorPowerHilbertSpace +) + +from sympy.core.numbers import oo +from sympy.core.symbol import Symbol +from sympy.printing.repr import srepr +from sympy.printing.str import sstr +from sympy.sets.sets import Interval + + +def test_hilbert_space(): + hs = HilbertSpace() + assert isinstance(hs, HilbertSpace) + assert sstr(hs) == 'H' + assert srepr(hs) == 'HilbertSpace()' + + +def test_complex_space(): + c1 = ComplexSpace(2) + assert isinstance(c1, ComplexSpace) + assert c1.dimension == 2 + assert sstr(c1) == 'C(2)' + assert srepr(c1) == 'ComplexSpace(Integer(2))' + + n = Symbol('n') + c2 = ComplexSpace(n) + assert isinstance(c2, ComplexSpace) + assert c2.dimension == n + assert sstr(c2) == 'C(n)' + assert srepr(c2) == "ComplexSpace(Symbol('n'))" + assert c2.subs(n, 2) == ComplexSpace(2) + + +def test_L2(): + b1 = L2(Interval(-oo, 1)) + assert isinstance(b1, L2) + assert b1.dimension is oo + assert b1.interval == Interval(-oo, 1) + + x = Symbol('x', real=True) + y = Symbol('y', real=True) + b2 = L2(Interval(x, y)) + assert b2.dimension is oo + assert b2.interval == Interval(x, y) + assert b2.subs(x, -1) == L2(Interval(-1, y)) + + +def test_fock_space(): + f1 = FockSpace() + f2 = FockSpace() + assert isinstance(f1, FockSpace) + assert f1.dimension is oo + assert f1 == f2 + + +def test_tensor_product(): + n = Symbol('n') + hs1 = ComplexSpace(2) + hs2 = ComplexSpace(n) + + h = hs1*hs2 + assert isinstance(h, TensorProductHilbertSpace) + assert h.dimension == 2*n + assert h.spaces == (hs1, hs2) + + h = hs2*hs2 + assert isinstance(h, TensorPowerHilbertSpace) + assert h.base == hs2 + assert h.exp == 2 + assert h.dimension == n**2 + + f = FockSpace() + h = hs1*hs2*f + assert h.dimension is oo + + +def test_tensor_power(): + n = Symbol('n') + hs1 = ComplexSpace(2) + hs2 = ComplexSpace(n) + + h = hs1**2 + assert isinstance(h, TensorPowerHilbertSpace) + assert h.base == hs1 + assert h.exp == 2 + assert h.dimension == 4 + + h = hs2**3 + assert isinstance(h, TensorPowerHilbertSpace) + assert h.base == hs2 + assert h.exp == 3 + assert h.dimension == n**3 + + +def test_direct_sum(): + n = Symbol('n') + hs1 = ComplexSpace(2) + hs2 = ComplexSpace(n) + + h = hs1 + hs2 + assert isinstance(h, DirectSumHilbertSpace) + assert h.dimension == 2 + n + assert h.spaces == (hs1, hs2) + + f = FockSpace() + h = hs1 + f + hs2 + assert h.dimension is oo + assert h.spaces == (hs1, f, hs2) diff --git a/lib/python3.10/site-packages/sympy/physics/quantum/tests/test_identitysearch.py b/lib/python3.10/site-packages/sympy/physics/quantum/tests/test_identitysearch.py new file mode 100644 index 0000000000000000000000000000000000000000..8747b1f9d9630e699695f67734333f9d61581fb8 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/physics/quantum/tests/test_identitysearch.py @@ -0,0 +1,492 @@ +from sympy.external import import_module +from sympy.core.mul import Mul +from sympy.core.numbers import Integer +from sympy.physics.quantum.dagger import Dagger +from sympy.physics.quantum.gate import (X, Y, Z, H, CNOT, + IdentityGate, CGate, PhaseGate, TGate) +from sympy.physics.quantum.identitysearch import (generate_gate_rules, + generate_equivalent_ids, GateIdentity, bfs_identity_search, + is_scalar_sparse_matrix, + is_scalar_nonsparse_matrix, is_degenerate, is_reducible) +from sympy.testing.pytest import skip + + +def create_gate_sequence(qubit=0): + gates = (X(qubit), Y(qubit), Z(qubit), H(qubit)) + return gates + + +def test_generate_gate_rules_1(): + # Test with tuples + (x, y, z, h) = create_gate_sequence() + ph = PhaseGate(0) + cgate_t = CGate(0, TGate(1)) + + assert generate_gate_rules((x,)) == {((x,), ())} + + gate_rules = {((x, x), ()), + ((x,), (x,))} + assert generate_gate_rules((x, x)) == gate_rules + + gate_rules = {((x, y, x), ()), + ((y, x, x), ()), + ((x, x, y), ()), + ((y, x), (x,)), + ((x, y), (x,)), + ((y,), (x, x))} + assert generate_gate_rules((x, y, x)) == gate_rules + + gate_rules = {((x, y, z), ()), ((y, z, x), ()), ((z, x, y), ()), + ((), (x, z, y)), ((), (y, x, z)), ((), (z, y, x)), + ((x,), (z, y)), ((y, z), (x,)), ((y,), (x, z)), + ((z, x), (y,)), ((z,), (y, x)), ((x, y), (z,))} + actual = generate_gate_rules((x, y, z)) + assert actual == gate_rules + + gate_rules = { + ((), (h, z, y, x)), ((), (x, h, z, y)), ((), (y, x, h, z)), + ((), (z, y, x, h)), ((h,), (z, y, x)), ((x,), (h, z, y)), + ((y,), (x, h, z)), ((z,), (y, x, h)), ((h, x), (z, y)), + ((x, y), (h, z)), ((y, z), (x, h)), ((z, h), (y, x)), + ((h, x, y), (z,)), ((x, y, z), (h,)), ((y, z, h), (x,)), + ((z, h, x), (y,)), ((h, x, y, z), ()), ((x, y, z, h), ()), + ((y, z, h, x), ()), ((z, h, x, y), ())} + actual = generate_gate_rules((x, y, z, h)) + assert actual == gate_rules + + gate_rules = {((), (cgate_t**(-1), ph**(-1), x)), + ((), (ph**(-1), x, cgate_t**(-1))), + ((), (x, cgate_t**(-1), ph**(-1))), + ((cgate_t,), (ph**(-1), x)), + ((ph,), (x, cgate_t**(-1))), + ((x,), (cgate_t**(-1), ph**(-1))), + ((cgate_t, x), (ph**(-1),)), + ((ph, cgate_t), (x,)), + ((x, ph), (cgate_t**(-1),)), + ((cgate_t, x, ph), ()), + ((ph, cgate_t, x), ()), + ((x, ph, cgate_t), ())} + actual = generate_gate_rules((x, ph, cgate_t)) + assert actual == gate_rules + + gate_rules = {(Integer(1), cgate_t**(-1)*ph**(-1)*x), + (Integer(1), ph**(-1)*x*cgate_t**(-1)), + (Integer(1), x*cgate_t**(-1)*ph**(-1)), + (cgate_t, ph**(-1)*x), + (ph, x*cgate_t**(-1)), + (x, cgate_t**(-1)*ph**(-1)), + (cgate_t*x, ph**(-1)), + (ph*cgate_t, x), + (x*ph, cgate_t**(-1)), + (cgate_t*x*ph, Integer(1)), + (ph*cgate_t*x, Integer(1)), + (x*ph*cgate_t, Integer(1))} + actual = generate_gate_rules((x, ph, cgate_t), return_as_muls=True) + assert actual == gate_rules + + +def test_generate_gate_rules_2(): + # Test with Muls + (x, y, z, h) = create_gate_sequence() + ph = PhaseGate(0) + cgate_t = CGate(0, TGate(1)) + + # Note: 1 (type int) is not the same as 1 (type One) + expected = {(x, Integer(1))} + assert generate_gate_rules((x,), return_as_muls=True) == expected + + expected = {(Integer(1), Integer(1))} + assert generate_gate_rules(x*x, return_as_muls=True) == expected + + expected = {((), ())} + assert generate_gate_rules(x*x, return_as_muls=False) == expected + + gate_rules = {(x*y*x, Integer(1)), + (y, Integer(1)), + (y*x, x), + (x*y, x)} + assert generate_gate_rules(x*y*x, return_as_muls=True) == gate_rules + + gate_rules = {(x*y*z, Integer(1)), + (y*z*x, Integer(1)), + (z*x*y, Integer(1)), + (Integer(1), x*z*y), + (Integer(1), y*x*z), + (Integer(1), z*y*x), + (x, z*y), + (y*z, x), + (y, x*z), + (z*x, y), + (z, y*x), + (x*y, z)} + actual = generate_gate_rules(x*y*z, return_as_muls=True) + assert actual == gate_rules + + gate_rules = {(Integer(1), h*z*y*x), + (Integer(1), x*h*z*y), + (Integer(1), y*x*h*z), + (Integer(1), z*y*x*h), + (h, z*y*x), (x, h*z*y), + (y, x*h*z), (z, y*x*h), + (h*x, z*y), (z*h, y*x), + (x*y, h*z), (y*z, x*h), + (h*x*y, z), (x*y*z, h), + (y*z*h, x), (z*h*x, y), + (h*x*y*z, Integer(1)), + (x*y*z*h, Integer(1)), + (y*z*h*x, Integer(1)), + (z*h*x*y, Integer(1))} + actual = generate_gate_rules(x*y*z*h, return_as_muls=True) + assert actual == gate_rules + + gate_rules = {(Integer(1), cgate_t**(-1)*ph**(-1)*x), + (Integer(1), ph**(-1)*x*cgate_t**(-1)), + (Integer(1), x*cgate_t**(-1)*ph**(-1)), + (cgate_t, ph**(-1)*x), + (ph, x*cgate_t**(-1)), + (x, cgate_t**(-1)*ph**(-1)), + (cgate_t*x, ph**(-1)), + (ph*cgate_t, x), + (x*ph, cgate_t**(-1)), + (cgate_t*x*ph, Integer(1)), + (ph*cgate_t*x, Integer(1)), + (x*ph*cgate_t, Integer(1))} + actual = generate_gate_rules(x*ph*cgate_t, return_as_muls=True) + assert actual == gate_rules + + gate_rules = {((), (cgate_t**(-1), ph**(-1), x)), + ((), (ph**(-1), x, cgate_t**(-1))), + ((), (x, cgate_t**(-1), ph**(-1))), + ((cgate_t,), (ph**(-1), x)), + ((ph,), (x, cgate_t**(-1))), + ((x,), (cgate_t**(-1), ph**(-1))), + ((cgate_t, x), (ph**(-1),)), + ((ph, cgate_t), (x,)), + ((x, ph), (cgate_t**(-1),)), + ((cgate_t, x, ph), ()), + ((ph, cgate_t, x), ()), + ((x, ph, cgate_t), ())} + actual = generate_gate_rules(x*ph*cgate_t) + assert actual == gate_rules + + +def test_generate_equivalent_ids_1(): + # Test with tuples + (x, y, z, h) = create_gate_sequence() + + assert generate_equivalent_ids((x,)) == {(x,)} + assert generate_equivalent_ids((x, x)) == {(x, x)} + assert generate_equivalent_ids((x, y)) == {(x, y), (y, x)} + + gate_seq = (x, y, z) + gate_ids = {(x, y, z), (y, z, x), (z, x, y), (z, y, x), + (y, x, z), (x, z, y)} + assert generate_equivalent_ids(gate_seq) == gate_ids + + gate_ids = {Mul(x, y, z), Mul(y, z, x), Mul(z, x, y), + Mul(z, y, x), Mul(y, x, z), Mul(x, z, y)} + assert generate_equivalent_ids(gate_seq, return_as_muls=True) == gate_ids + + gate_seq = (x, y, z, h) + gate_ids = {(x, y, z, h), (y, z, h, x), + (h, x, y, z), (h, z, y, x), + (z, y, x, h), (y, x, h, z), + (z, h, x, y), (x, h, z, y)} + assert generate_equivalent_ids(gate_seq) == gate_ids + + gate_seq = (x, y, x, y) + gate_ids = {(x, y, x, y), (y, x, y, x)} + assert generate_equivalent_ids(gate_seq) == gate_ids + + cgate_y = CGate((1,), y) + gate_seq = (y, cgate_y, y, cgate_y) + gate_ids = {(y, cgate_y, y, cgate_y), (cgate_y, y, cgate_y, y)} + assert generate_equivalent_ids(gate_seq) == gate_ids + + cnot = CNOT(1, 0) + cgate_z = CGate((0,), Z(1)) + gate_seq = (cnot, h, cgate_z, h) + gate_ids = {(cnot, h, cgate_z, h), (h, cgate_z, h, cnot), + (h, cnot, h, cgate_z), (cgate_z, h, cnot, h)} + assert generate_equivalent_ids(gate_seq) == gate_ids + + +def test_generate_equivalent_ids_2(): + # Test with Muls + (x, y, z, h) = create_gate_sequence() + + assert generate_equivalent_ids((x,), return_as_muls=True) == {x} + + gate_ids = {Integer(1)} + assert generate_equivalent_ids(x*x, return_as_muls=True) == gate_ids + + gate_ids = {x*y, y*x} + assert generate_equivalent_ids(x*y, return_as_muls=True) == gate_ids + + gate_ids = {(x, y), (y, x)} + assert generate_equivalent_ids(x*y) == gate_ids + + circuit = Mul(*(x, y, z)) + gate_ids = {x*y*z, y*z*x, z*x*y, z*y*x, + y*x*z, x*z*y} + assert generate_equivalent_ids(circuit, return_as_muls=True) == gate_ids + + circuit = Mul(*(x, y, z, h)) + gate_ids = {x*y*z*h, y*z*h*x, + h*x*y*z, h*z*y*x, + z*y*x*h, y*x*h*z, + z*h*x*y, x*h*z*y} + assert generate_equivalent_ids(circuit, return_as_muls=True) == gate_ids + + circuit = Mul(*(x, y, x, y)) + gate_ids = {x*y*x*y, y*x*y*x} + assert generate_equivalent_ids(circuit, return_as_muls=True) == gate_ids + + cgate_y = CGate((1,), y) + circuit = Mul(*(y, cgate_y, y, cgate_y)) + gate_ids = {y*cgate_y*y*cgate_y, cgate_y*y*cgate_y*y} + assert generate_equivalent_ids(circuit, return_as_muls=True) == gate_ids + + cnot = CNOT(1, 0) + cgate_z = CGate((0,), Z(1)) + circuit = Mul(*(cnot, h, cgate_z, h)) + gate_ids = {cnot*h*cgate_z*h, h*cgate_z*h*cnot, + h*cnot*h*cgate_z, cgate_z*h*cnot*h} + assert generate_equivalent_ids(circuit, return_as_muls=True) == gate_ids + + +def test_is_scalar_nonsparse_matrix(): + numqubits = 2 + id_only = False + + id_gate = (IdentityGate(1),) + actual = is_scalar_nonsparse_matrix(id_gate, numqubits, id_only) + assert actual is True + + x0 = X(0) + xx_circuit = (x0, x0) + actual = is_scalar_nonsparse_matrix(xx_circuit, numqubits, id_only) + assert actual is True + + x1 = X(1) + y1 = Y(1) + xy_circuit = (x1, y1) + actual = is_scalar_nonsparse_matrix(xy_circuit, numqubits, id_only) + assert actual is False + + z1 = Z(1) + xyz_circuit = (x1, y1, z1) + actual = is_scalar_nonsparse_matrix(xyz_circuit, numqubits, id_only) + assert actual is True + + cnot = CNOT(1, 0) + cnot_circuit = (cnot, cnot) + actual = is_scalar_nonsparse_matrix(cnot_circuit, numqubits, id_only) + assert actual is True + + h = H(0) + hh_circuit = (h, h) + actual = is_scalar_nonsparse_matrix(hh_circuit, numqubits, id_only) + assert actual is True + + h1 = H(1) + xhzh_circuit = (x1, h1, z1, h1) + actual = is_scalar_nonsparse_matrix(xhzh_circuit, numqubits, id_only) + assert actual is True + + id_only = True + actual = is_scalar_nonsparse_matrix(xhzh_circuit, numqubits, id_only) + assert actual is True + actual = is_scalar_nonsparse_matrix(xyz_circuit, numqubits, id_only) + assert actual is False + actual = is_scalar_nonsparse_matrix(cnot_circuit, numqubits, id_only) + assert actual is True + actual = is_scalar_nonsparse_matrix(hh_circuit, numqubits, id_only) + assert actual is True + + +def test_is_scalar_sparse_matrix(): + np = import_module('numpy') + if not np: + skip("numpy not installed.") + + scipy = import_module('scipy', import_kwargs={'fromlist': ['sparse']}) + if not scipy: + skip("scipy not installed.") + + numqubits = 2 + id_only = False + + id_gate = (IdentityGate(1),) + assert is_scalar_sparse_matrix(id_gate, numqubits, id_only) is True + + x0 = X(0) + xx_circuit = (x0, x0) + assert is_scalar_sparse_matrix(xx_circuit, numqubits, id_only) is True + + x1 = X(1) + y1 = Y(1) + xy_circuit = (x1, y1) + assert is_scalar_sparse_matrix(xy_circuit, numqubits, id_only) is False + + z1 = Z(1) + xyz_circuit = (x1, y1, z1) + assert is_scalar_sparse_matrix(xyz_circuit, numqubits, id_only) is True + + cnot = CNOT(1, 0) + cnot_circuit = (cnot, cnot) + assert is_scalar_sparse_matrix(cnot_circuit, numqubits, id_only) is True + + h = H(0) + hh_circuit = (h, h) + assert is_scalar_sparse_matrix(hh_circuit, numqubits, id_only) is True + + # NOTE: + # The elements of the sparse matrix for the following circuit + # is actually 1.0000000000000002+0.0j. + h1 = H(1) + xhzh_circuit = (x1, h1, z1, h1) + assert is_scalar_sparse_matrix(xhzh_circuit, numqubits, id_only) is True + + id_only = True + assert is_scalar_sparse_matrix(xhzh_circuit, numqubits, id_only) is True + assert is_scalar_sparse_matrix(xyz_circuit, numqubits, id_only) is False + assert is_scalar_sparse_matrix(cnot_circuit, numqubits, id_only) is True + assert is_scalar_sparse_matrix(hh_circuit, numqubits, id_only) is True + + +def test_is_degenerate(): + (x, y, z, h) = create_gate_sequence() + + gate_id = GateIdentity(x, y, z) + ids = {gate_id} + + another_id = (z, y, x) + assert is_degenerate(ids, another_id) is True + + +def test_is_reducible(): + nqubits = 2 + (x, y, z, h) = create_gate_sequence() + + circuit = (x, y, y) + assert is_reducible(circuit, nqubits, 1, 3) is True + + circuit = (x, y, x) + assert is_reducible(circuit, nqubits, 1, 3) is False + + circuit = (x, y, y, x) + assert is_reducible(circuit, nqubits, 0, 4) is True + + circuit = (x, y, y, x) + assert is_reducible(circuit, nqubits, 1, 3) is True + + circuit = (x, y, z, y, y) + assert is_reducible(circuit, nqubits, 1, 5) is True + + +def test_bfs_identity_search(): + assert bfs_identity_search([], 1) == set() + + (x, y, z, h) = create_gate_sequence() + + gate_list = [x] + id_set = {GateIdentity(x, x)} + assert bfs_identity_search(gate_list, 1, max_depth=2) == id_set + + # Set should not contain degenerate quantum circuits + gate_list = [x, y, z] + id_set = {GateIdentity(x, x), + GateIdentity(y, y), + GateIdentity(z, z), + GateIdentity(x, y, z)} + assert bfs_identity_search(gate_list, 1) == id_set + + id_set = {GateIdentity(x, x), + GateIdentity(y, y), + GateIdentity(z, z), + GateIdentity(x, y, z), + GateIdentity(x, y, x, y), + GateIdentity(x, z, x, z), + GateIdentity(y, z, y, z)} + assert bfs_identity_search(gate_list, 1, max_depth=4) == id_set + assert bfs_identity_search(gate_list, 1, max_depth=5) == id_set + + gate_list = [x, y, z, h] + id_set = {GateIdentity(x, x), + GateIdentity(y, y), + GateIdentity(z, z), + GateIdentity(h, h), + GateIdentity(x, y, z), + GateIdentity(x, y, x, y), + GateIdentity(x, z, x, z), + GateIdentity(x, h, z, h), + GateIdentity(y, z, y, z), + GateIdentity(y, h, y, h)} + assert bfs_identity_search(gate_list, 1) == id_set + + id_set = {GateIdentity(x, x), + GateIdentity(y, y), + GateIdentity(z, z), + GateIdentity(h, h)} + assert id_set == bfs_identity_search(gate_list, 1, max_depth=3, + identity_only=True) + + id_set = {GateIdentity(x, x), + GateIdentity(y, y), + GateIdentity(z, z), + GateIdentity(h, h), + GateIdentity(x, y, z), + GateIdentity(x, y, x, y), + GateIdentity(x, z, x, z), + GateIdentity(x, h, z, h), + GateIdentity(y, z, y, z), + GateIdentity(y, h, y, h), + GateIdentity(x, y, h, x, h), + GateIdentity(x, z, h, y, h), + GateIdentity(y, z, h, z, h)} + assert bfs_identity_search(gate_list, 1, max_depth=5) == id_set + + id_set = {GateIdentity(x, x), + GateIdentity(y, y), + GateIdentity(z, z), + GateIdentity(h, h), + GateIdentity(x, h, z, h)} + assert id_set == bfs_identity_search(gate_list, 1, max_depth=4, + identity_only=True) + + cnot = CNOT(1, 0) + gate_list = [x, cnot] + id_set = {GateIdentity(x, x), + GateIdentity(cnot, cnot), + GateIdentity(x, cnot, x, cnot)} + assert bfs_identity_search(gate_list, 2, max_depth=4) == id_set + + cgate_x = CGate((1,), x) + gate_list = [x, cgate_x] + id_set = {GateIdentity(x, x), + GateIdentity(cgate_x, cgate_x), + GateIdentity(x, cgate_x, x, cgate_x)} + assert bfs_identity_search(gate_list, 2, max_depth=4) == id_set + + cgate_z = CGate((0,), Z(1)) + gate_list = [cnot, cgate_z, h] + id_set = {GateIdentity(h, h), + GateIdentity(cgate_z, cgate_z), + GateIdentity(cnot, cnot), + GateIdentity(cnot, h, cgate_z, h)} + assert bfs_identity_search(gate_list, 2, max_depth=4) == id_set + + s = PhaseGate(0) + t = TGate(0) + gate_list = [s, t] + id_set = {GateIdentity(s, s, s, s)} + assert bfs_identity_search(gate_list, 1, max_depth=4) == id_set + + +def test_bfs_identity_search_xfail(): + s = PhaseGate(0) + t = TGate(0) + gate_list = [Dagger(s), t] + id_set = {GateIdentity(Dagger(s), t, t)} + assert bfs_identity_search(gate_list, 1, max_depth=3) == id_set diff --git a/lib/python3.10/site-packages/sympy/physics/quantum/tests/test_innerproduct.py b/lib/python3.10/site-packages/sympy/physics/quantum/tests/test_innerproduct.py new file mode 100644 index 0000000000000000000000000000000000000000..2632031f8a9a9ec65dfab6d834eb704a00b621d3 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/physics/quantum/tests/test_innerproduct.py @@ -0,0 +1,71 @@ +from sympy.core.numbers import (I, Integer) + +from sympy.physics.quantum.innerproduct import InnerProduct +from sympy.physics.quantum.dagger import Dagger +from sympy.physics.quantum.state import Bra, Ket, StateBase + + +def test_innerproduct(): + k = Ket('k') + b = Bra('b') + ip = InnerProduct(b, k) + assert isinstance(ip, InnerProduct) + assert ip.bra == b + assert ip.ket == k + assert b*k == InnerProduct(b, k) + assert k*(b*k)*b == k*InnerProduct(b, k)*b + assert InnerProduct(b, k).subs(b, Dagger(k)) == Dagger(k)*k + + +def test_innerproduct_dagger(): + k = Ket('k') + b = Bra('b') + ip = b*k + assert Dagger(ip) == Dagger(k)*Dagger(b) + + +class FooState(StateBase): + pass + + +class FooKet(Ket, FooState): + + @classmethod + def dual_class(self): + return FooBra + + def _eval_innerproduct_FooBra(self, bra): + return Integer(1) + + def _eval_innerproduct_BarBra(self, bra): + return I + + +class FooBra(Bra, FooState): + @classmethod + def dual_class(self): + return FooKet + + +class BarState(StateBase): + pass + + +class BarKet(Ket, BarState): + @classmethod + def dual_class(self): + return BarBra + + +class BarBra(Bra, BarState): + @classmethod + def dual_class(self): + return BarKet + + +def test_doit(): + f = FooKet('foo') + b = BarBra('bar') + assert InnerProduct(b, f).doit() == I + assert InnerProduct(Dagger(f), Dagger(b)).doit() == -I + assert InnerProduct(Dagger(f), f).doit() == Integer(1) diff --git a/lib/python3.10/site-packages/sympy/physics/quantum/trace.py b/lib/python3.10/site-packages/sympy/physics/quantum/trace.py new file mode 100644 index 0000000000000000000000000000000000000000..03ab18f78a1bfcf5bfcd679f00eac8685144fd8c --- /dev/null +++ b/lib/python3.10/site-packages/sympy/physics/quantum/trace.py @@ -0,0 +1,230 @@ +from sympy.core.add import Add +from sympy.core.containers import Tuple +from sympy.core.expr import Expr +from sympy.core.mul import Mul +from sympy.core.power import Pow +from sympy.core.sorting import default_sort_key +from sympy.core.sympify import sympify +from sympy.matrices import Matrix + + +def _is_scalar(e): + """ Helper method used in Tr""" + + # sympify to set proper attributes + e = sympify(e) + if isinstance(e, Expr): + if (e.is_Integer or e.is_Float or + e.is_Rational or e.is_Number or + (e.is_Symbol and e.is_commutative) + ): + return True + + return False + + +def _cycle_permute(l): + """ Cyclic permutations based on canonical ordering + + Explanation + =========== + + This method does the sort based ascii values while + a better approach would be to used lexicographic sort. + + TODO: Handle condition such as symbols have subscripts/superscripts + in case of lexicographic sort + + """ + + if len(l) == 1: + return l + + min_item = min(l, key=default_sort_key) + indices = [i for i, x in enumerate(l) if x == min_item] + + le = list(l) + le.extend(l) # duplicate and extend string for easy processing + + # adding the first min_item index back for easier looping + indices.append(len(l) + indices[0]) + + # create sublist of items with first item as min_item and last_item + # in each of the sublist is item just before the next occurrence of + # minitem in the cycle formed. + sublist = [[le[indices[i]:indices[i + 1]]] for i in + range(len(indices) - 1)] + + # we do comparison of strings by comparing elements + # in each sublist + idx = sublist.index(min(sublist)) + ordered_l = le[indices[idx]:indices[idx] + len(l)] + + return ordered_l + + +def _rearrange_args(l): + """ this just moves the last arg to first position + to enable expansion of args + A,B,A ==> A**2,B + """ + if len(l) == 1: + return l + + x = list(l[-1:]) + x.extend(l[0:-1]) + return Mul(*x).args + + +class Tr(Expr): + """ Generic Trace operation than can trace over: + + a) SymPy matrix + b) operators + c) outer products + + Parameters + ========== + o : operator, matrix, expr + i : tuple/list indices (optional) + + Examples + ======== + + # TODO: Need to handle printing + + a) Trace(A+B) = Tr(A) + Tr(B) + b) Trace(scalar*Operator) = scalar*Trace(Operator) + + >>> from sympy.physics.quantum.trace import Tr + >>> from sympy import symbols, Matrix + >>> a, b = symbols('a b', commutative=True) + >>> A, B = symbols('A B', commutative=False) + >>> Tr(a*A,[2]) + a*Tr(A) + >>> m = Matrix([[1,2],[1,1]]) + >>> Tr(m) + 2 + + """ + def __new__(cls, *args): + """ Construct a Trace object. + + Parameters + ========== + args = SymPy expression + indices = tuple/list if indices, optional + + """ + + # expect no indices,int or a tuple/list/Tuple + if (len(args) == 2): + if not isinstance(args[1], (list, Tuple, tuple)): + indices = Tuple(args[1]) + else: + indices = Tuple(*args[1]) + + expr = args[0] + elif (len(args) == 1): + indices = Tuple() + expr = args[0] + else: + raise ValueError("Arguments to Tr should be of form " + "(expr[, [indices]])") + + if isinstance(expr, Matrix): + return expr.trace() + elif hasattr(expr, 'trace') and callable(expr.trace): + #for any objects that have trace() defined e.g numpy + return expr.trace() + elif isinstance(expr, Add): + return Add(*[Tr(arg, indices) for arg in expr.args]) + elif isinstance(expr, Mul): + c_part, nc_part = expr.args_cnc() + if len(nc_part) == 0: + return Mul(*c_part) + else: + obj = Expr.__new__(cls, Mul(*nc_part), indices ) + #this check is needed to prevent cached instances + #being returned even if len(c_part)==0 + return Mul(*c_part)*obj if len(c_part) > 0 else obj + elif isinstance(expr, Pow): + if (_is_scalar(expr.args[0]) and + _is_scalar(expr.args[1])): + return expr + else: + return Expr.__new__(cls, expr, indices) + else: + if (_is_scalar(expr)): + return expr + + return Expr.__new__(cls, expr, indices) + + @property + def kind(self): + expr = self.args[0] + expr_kind = expr.kind + return expr_kind.element_kind + + def doit(self, **hints): + """ Perform the trace operation. + + #TODO: Current version ignores the indices set for partial trace. + + >>> from sympy.physics.quantum.trace import Tr + >>> from sympy.physics.quantum.operator import OuterProduct + >>> from sympy.physics.quantum.spin import JzKet, JzBra + >>> t = Tr(OuterProduct(JzKet(1,1), JzBra(1,1))) + >>> t.doit() + 1 + + """ + if hasattr(self.args[0], '_eval_trace'): + return self.args[0]._eval_trace(indices=self.args[1]) + + return self + + @property + def is_number(self): + # TODO : improve this implementation + return True + + #TODO: Review if the permute method is needed + # and if it needs to return a new instance + def permute(self, pos): + """ Permute the arguments cyclically. + + Parameters + ========== + + pos : integer, if positive, shift-right, else shift-left + + Examples + ======== + + >>> from sympy.physics.quantum.trace import Tr + >>> from sympy import symbols + >>> A, B, C, D = symbols('A B C D', commutative=False) + >>> t = Tr(A*B*C*D) + >>> t.permute(2) + Tr(C*D*A*B) + >>> t.permute(-2) + Tr(C*D*A*B) + + """ + if pos > 0: + pos = pos % len(self.args[0].args) + else: + pos = -(abs(pos) % len(self.args[0].args)) + + args = list(self.args[0].args[-pos:] + self.args[0].args[0:-pos]) + + return Tr(Mul(*(args))) + + def _hashable_content(self): + if isinstance(self.args[0], Mul): + args = _cycle_permute(_rearrange_args(self.args[0].args)) + else: + args = [self.args[0]] + + return tuple(args) + (self.args[1], ) diff --git a/lib/python3.10/site-packages/sympy/physics/tests/__init__.py b/lib/python3.10/site-packages/sympy/physics/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/lib/python3.10/site-packages/sympy/physics/tests/test_clebsch_gordan.py b/lib/python3.10/site-packages/sympy/physics/tests/test_clebsch_gordan.py new file mode 100644 index 0000000000000000000000000000000000000000..68bfa0ac94df04a6bd2acab7396a1ebdcd778938 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/physics/tests/test_clebsch_gordan.py @@ -0,0 +1,198 @@ +from sympy.core.numbers import (I, pi, Rational) +from sympy.core.singleton import S +from sympy.core.symbol import symbols +from sympy.functions.elementary.exponential import exp +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import (cos, sin) +from sympy.functions.special.spherical_harmonics import Ynm +from sympy.matrices.dense import Matrix +from sympy.physics.wigner import (clebsch_gordan, wigner_9j, wigner_6j, gaunt, + real_gaunt, racah, dot_rot_grad_Ynm, wigner_3j, wigner_d_small, wigner_d) +from sympy.testing.pytest import raises + +# for test cases, refer : https://en.wikipedia.org/wiki/Table_of_Clebsch%E2%80%93Gordan_coefficients + +def test_clebsch_gordan_docs(): + assert clebsch_gordan(Rational(3, 2), S.Half, 2, Rational(3, 2), S.Half, 2) == 1 + assert clebsch_gordan(Rational(3, 2), S.Half, 1, Rational(3, 2), Rational(-1, 2), 1) == sqrt(3)/2 + assert clebsch_gordan(Rational(3, 2), S.Half, 1, Rational(-1, 2), S.Half, 0) == -sqrt(2)/2 + + +def test_clebsch_gordan(): + # Argument order: (j_1, j_2, j, m_1, m_2, m) + + h = S.One + k = S.Half + l = Rational(3, 2) + i = Rational(-1, 2) + n = Rational(7, 2) + p = Rational(5, 2) + assert clebsch_gordan(k, k, 1, k, k, 1) == 1 + assert clebsch_gordan(k, k, 1, k, k, 0) == 0 + assert clebsch_gordan(k, k, 1, i, i, -1) == 1 + assert clebsch_gordan(k, k, 1, k, i, 0) == sqrt(2)/2 + assert clebsch_gordan(k, k, 0, k, i, 0) == sqrt(2)/2 + assert clebsch_gordan(k, k, 1, i, k, 0) == sqrt(2)/2 + assert clebsch_gordan(k, k, 0, i, k, 0) == -sqrt(2)/2 + assert clebsch_gordan(h, k, l, 1, k, l) == 1 + assert clebsch_gordan(h, k, l, 1, i, k) == 1/sqrt(3) + assert clebsch_gordan(h, k, k, 1, i, k) == sqrt(2)/sqrt(3) + assert clebsch_gordan(h, k, k, 0, k, k) == -1/sqrt(3) + assert clebsch_gordan(h, k, l, 0, k, k) == sqrt(2)/sqrt(3) + assert clebsch_gordan(h, h, S(2), 1, 1, S(2)) == 1 + assert clebsch_gordan(h, h, S(2), 1, 0, 1) == 1/sqrt(2) + assert clebsch_gordan(h, h, S(2), 0, 1, 1) == 1/sqrt(2) + assert clebsch_gordan(h, h, 1, 1, 0, 1) == 1/sqrt(2) + assert clebsch_gordan(h, h, 1, 0, 1, 1) == -1/sqrt(2) + assert clebsch_gordan(l, l, S(3), l, l, S(3)) == 1 + assert clebsch_gordan(l, l, S(2), l, k, S(2)) == 1/sqrt(2) + assert clebsch_gordan(l, l, S(3), l, k, S(2)) == 1/sqrt(2) + assert clebsch_gordan(S(2), S(2), S(4), S(2), S(2), S(4)) == 1 + assert clebsch_gordan(S(2), S(2), S(3), S(2), 1, S(3)) == 1/sqrt(2) + assert clebsch_gordan(S(2), S(2), S(3), 1, 1, S(2)) == 0 + assert clebsch_gordan(p, h, n, p, 1, n) == 1 + assert clebsch_gordan(p, h, p, p, 0, p) == sqrt(5)/sqrt(7) + assert clebsch_gordan(p, h, l, k, 1, l) == 1/sqrt(15) + + +def test_wigner(): + def tn(a, b): + return (a - b).n(64) < S('1e-64') + assert tn(wigner_9j(1, 1, 1, 1, 1, 1, 1, 1, 0, prec=64), Rational(1, 18)) + assert wigner_9j(3, 3, 2, 3, 3, 2, 3, 3, 2) == 3221*sqrt( + 70)/(246960*sqrt(105)) - 365/(3528*sqrt(70)*sqrt(105)) + assert wigner_6j(5, 5, 5, 5, 5, 5) == Rational(1, 52) + assert tn(wigner_6j(8, 8, 8, 8, 8, 8, prec=64), Rational(-12219, 965770)) + # regression test for #8747 + half = S.Half + assert wigner_9j(0, 0, 0, 0, half, half, 0, half, half) == half + assert (wigner_9j(3, 5, 4, + 7 * half, 5 * half, 4, + 9 * half, 9 * half, 0) + == -sqrt(Rational(361, 205821000))) + assert (wigner_9j(1, 4, 3, + 5 * half, 4, 5 * half, + 5 * half, 2, 7 * half) + == -sqrt(Rational(3971, 373403520))) + assert (wigner_9j(4, 9 * half, 5 * half, + 2, 4, 4, + 5, 7 * half, 7 * half) + == -sqrt(Rational(3481, 5042614500))) + + +def test_gaunt(): + def tn(a, b): + return (a - b).n(64) < S('1e-64') + assert gaunt(1, 0, 1, 1, 0, -1) == -1/(2*sqrt(pi)) + assert isinstance(gaunt(1, 1, 0, -1, 1, 0).args[0], Rational) + assert isinstance(gaunt(0, 1, 1, 0, -1, 1).args[0], Rational) + + assert tn(gaunt( + 10, 10, 12, 9, 3, -12, prec=64), (Rational(-98, 62031)) * sqrt(6279)/sqrt(pi)) + def gaunt_ref(l1, l2, l3, m1, m2, m3): + return ( + sqrt((2 * l1 + 1) * (2 * l2 + 1) * (2 * l3 + 1) / (4 * pi)) * + wigner_3j(l1, l2, l3, 0, 0, 0) * + wigner_3j(l1, l2, l3, m1, m2, m3) + ) + threshold = 1e-10 + l_max = 3 + l3_max = 24 + for l1 in range(l_max + 1): + for l2 in range(l_max + 1): + for l3 in range(l3_max + 1): + for m1 in range(-l1, l1 + 1): + for m2 in range(-l2, l2 + 1): + for m3 in range(-l3, l3 + 1): + args = l1, l2, l3, m1, m2, m3 + g = gaunt(*args) + g0 = gaunt_ref(*args) + assert abs(g - g0) < threshold + if m1 + m2 + m3 != 0: + assert abs(g) < threshold + if (l1 + l2 + l3) % 2: + assert abs(g) < threshold + assert gaunt(1, 1, 0, 0, 2, -2) is S.Zero + + +def test_realgaunt(): + # All non-zero values corresponding to l values from 0 to 2 + for l in range(3): + for m in range(-l, l+1): + assert real_gaunt(0, l, l, 0, m, m) == 1/(2*sqrt(pi)) + assert real_gaunt(1, 1, 2, 0, 0, 0) == sqrt(5)/(5*sqrt(pi)) + assert real_gaunt(1, 1, 2, 1, 1, 0) == -sqrt(5)/(10*sqrt(pi)) + assert real_gaunt(2, 2, 2, 0, 0, 0) == sqrt(5)/(7*sqrt(pi)) + assert real_gaunt(2, 2, 2, 0, 2, 2) == -sqrt(5)/(7*sqrt(pi)) + assert real_gaunt(2, 2, 2, -2, -2, 0) == -sqrt(5)/(7*sqrt(pi)) + assert real_gaunt(1, 1, 2, -1, 0, -1) == sqrt(15)/(10*sqrt(pi)) + assert real_gaunt(1, 1, 2, 0, 1, 1) == sqrt(15)/(10*sqrt(pi)) + assert real_gaunt(1, 1, 2, 1, 1, 2) == sqrt(15)/(10*sqrt(pi)) + assert real_gaunt(1, 1, 2, -1, 1, -2) == -sqrt(15)/(10*sqrt(pi)) + assert real_gaunt(1, 1, 2, -1, -1, 2) == -sqrt(15)/(10*sqrt(pi)) + assert real_gaunt(2, 2, 2, 0, 1, 1) == sqrt(5)/(14*sqrt(pi)) + assert real_gaunt(2, 2, 2, 1, 1, 2) == sqrt(15)/(14*sqrt(pi)) + assert real_gaunt(2, 2, 2, -1, -1, 2) == -sqrt(15)/(14*sqrt(pi)) + + assert real_gaunt(-2, -2, -2, -2, -2, 0) is S.Zero # m test + assert real_gaunt(-2, 1, 0, 1, 1, 1) is S.Zero # l test + assert real_gaunt(-2, -1, -2, -1, -1, 0) is S.Zero # m and l test + assert real_gaunt(-2, -2, -2, -2, -2, -2) is S.Zero # m and k test + assert real_gaunt(-2, -1, -2, -1, -1, -1) is S.Zero # m, l and k test + + x = symbols('x', integer=True) + v = [0]*6 + for i in range(len(v)): + v[i] = x # non literal ints fail + raises(ValueError, lambda: real_gaunt(*v)) + v[i] = 0 + + +def test_racah(): + assert racah(3,3,3,3,3,3) == Rational(-1,14) + assert racah(2,2,2,2,2,2) == Rational(-3,70) + assert racah(7,8,7,1,7,7, prec=4).is_Float + assert racah(5.5,7.5,9.5,6.5,8,9) == -719*sqrt(598)/1158924 + assert abs(racah(5.5,7.5,9.5,6.5,8,9, prec=4) - (-0.01517)) < S('1e-4') + + +def test_dot_rota_grad_SH(): + theta, phi = symbols("theta phi") + assert dot_rot_grad_Ynm(1, 1, 1, 1, 1, 0) != \ + sqrt(30)*Ynm(2, 2, 1, 0)/(10*sqrt(pi)) + assert dot_rot_grad_Ynm(1, 1, 1, 1, 1, 0).doit() == \ + sqrt(30)*Ynm(2, 2, 1, 0)/(10*sqrt(pi)) + assert dot_rot_grad_Ynm(1, 5, 1, 1, 1, 2) != \ + 0 + assert dot_rot_grad_Ynm(1, 5, 1, 1, 1, 2).doit() == \ + 0 + assert dot_rot_grad_Ynm(3, 3, 3, 3, theta, phi).doit() == \ + 15*sqrt(3003)*Ynm(6, 6, theta, phi)/(143*sqrt(pi)) + assert dot_rot_grad_Ynm(3, 3, 1, 1, theta, phi).doit() == \ + sqrt(3)*Ynm(4, 4, theta, phi)/sqrt(pi) + assert dot_rot_grad_Ynm(3, 2, 2, 0, theta, phi).doit() == \ + 3*sqrt(55)*Ynm(5, 2, theta, phi)/(11*sqrt(pi)) + assert dot_rot_grad_Ynm(3, 2, 3, 2, theta, phi).doit().expand() == \ + -sqrt(70)*Ynm(4, 4, theta, phi)/(11*sqrt(pi)) + \ + 45*sqrt(182)*Ynm(6, 4, theta, phi)/(143*sqrt(pi)) + + +def test_wigner_d(): + half = S(1)/2 + assert wigner_d_small(half, 0) == Matrix([[1, 0], [0, 1]]) + assert wigner_d_small(half, pi/2) == Matrix([[1, 1], [-1, 1]])/sqrt(2) + assert wigner_d_small(half, pi) == Matrix([[0, 1], [-1, 0]]) + + alpha, beta, gamma = symbols("alpha, beta, gamma", real=True) + D = wigner_d(half, alpha, beta, gamma) + assert D[0, 0] == exp(I*alpha/2)*exp(I*gamma/2)*cos(beta/2) + assert D[0, 1] == exp(I*alpha/2)*exp(-I*gamma/2)*sin(beta/2) + assert D[1, 0] == -exp(-I*alpha/2)*exp(I*gamma/2)*sin(beta/2) + assert D[1, 1] == exp(-I*alpha/2)*exp(-I*gamma/2)*cos(beta/2) + + # Test Y_{n mi}(g*x)=\sum_{mj}D^n_{mi mj}*Y_{n mj}(x) + theta, phi = symbols("theta phi", real=True) + v = Matrix([Ynm(1, mj, theta, phi) for mj in range(1, -2, -1)]) + w = wigner_d(1, -pi/2, pi/2, -pi/2)@v.subs({theta: pi/4, phi: pi}) + w_ = v.subs({theta: pi/2, phi: pi/4}) + assert w.expand(func=True).as_real_imag() == w_.expand(func=True).as_real_imag() diff --git a/lib/python3.10/site-packages/sympy/physics/tests/test_hydrogen.py b/lib/python3.10/site-packages/sympy/physics/tests/test_hydrogen.py new file mode 100644 index 0000000000000000000000000000000000000000..eb11744dd8e731f24fcd6f6be2a92ada4fffc554 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/physics/tests/test_hydrogen.py @@ -0,0 +1,126 @@ +from sympy.core.numbers import (I, Rational, oo, pi) +from sympy.core.singleton import S +from sympy.core.symbol import symbols +from sympy.functions.elementary.exponential import exp +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import (cos, sin) +from sympy.integrals.integrals import integrate +from sympy.simplify.simplify import simplify +from sympy.physics.hydrogen import R_nl, E_nl, E_nl_dirac, Psi_nlm +from sympy.testing.pytest import raises + +n, r, Z = symbols('n r Z') + + +def feq(a, b, max_relative_error=1e-12, max_absolute_error=1e-12): + a = float(a) + b = float(b) + # if the numbers are close enough (absolutely), then they are equal + if abs(a - b) < max_absolute_error: + return True + # if not, they can still be equal if their relative error is small + if abs(b) > abs(a): + relative_error = abs((a - b)/b) + else: + relative_error = abs((a - b)/a) + return relative_error <= max_relative_error + + +def test_wavefunction(): + a = 1/Z + R = { + (1, 0): 2*sqrt(1/a**3) * exp(-r/a), + (2, 0): sqrt(1/(2*a**3)) * exp(-r/(2*a)) * (1 - r/(2*a)), + (2, 1): S.Half * sqrt(1/(6*a**3)) * exp(-r/(2*a)) * r/a, + (3, 0): Rational(2, 3) * sqrt(1/(3*a**3)) * exp(-r/(3*a)) * + (1 - 2*r/(3*a) + Rational(2, 27) * (r/a)**2), + (3, 1): Rational(4, 27) * sqrt(2/(3*a**3)) * exp(-r/(3*a)) * + (1 - r/(6*a)) * r/a, + (3, 2): Rational(2, 81) * sqrt(2/(15*a**3)) * exp(-r/(3*a)) * (r/a)**2, + (4, 0): Rational(1, 4) * sqrt(1/a**3) * exp(-r/(4*a)) * + (1 - 3*r/(4*a) + Rational(1, 8) * (r/a)**2 - Rational(1, 192) * (r/a)**3), + (4, 1): Rational(1, 16) * sqrt(5/(3*a**3)) * exp(-r/(4*a)) * + (1 - r/(4*a) + Rational(1, 80) * (r/a)**2) * (r/a), + (4, 2): Rational(1, 64) * sqrt(1/(5*a**3)) * exp(-r/(4*a)) * + (1 - r/(12*a)) * (r/a)**2, + (4, 3): Rational(1, 768) * sqrt(1/(35*a**3)) * exp(-r/(4*a)) * (r/a)**3, + } + for n, l in R: + assert simplify(R_nl(n, l, r, Z) - R[(n, l)]) == 0 + + +def test_norm(): + # Maximum "n" which is tested: + n_max = 2 # it works, but is slow, for n_max > 2 + for n in range(n_max + 1): + for l in range(n): + assert integrate(R_nl(n, l, r)**2 * r**2, (r, 0, oo)) == 1 + +def test_psi_nlm(): + r=S('r') + phi=S('phi') + theta=S('theta') + assert (Psi_nlm(1, 0, 0, r, phi, theta) == exp(-r) / sqrt(pi)) + assert (Psi_nlm(2, 1, -1, r, phi, theta)) == S.Half * exp(-r / (2)) * r \ + * (sin(theta) * exp(-I * phi) / (4 * sqrt(pi))) + assert (Psi_nlm(3, 2, 1, r, phi, theta, 2) == -sqrt(2) * sin(theta) \ + * exp(I * phi) * cos(theta) / (4 * sqrt(pi)) * S(2) / 81 \ + * sqrt(2 * 2 ** 3) * exp(-2 * r / (3)) * (r * 2) ** 2) + +def test_hydrogen_energies(): + assert E_nl(n, Z) == -Z**2/(2*n**2) + assert E_nl(n) == -1/(2*n**2) + + assert E_nl(1, 47) == -S(47)**2/(2*1**2) + assert E_nl(2, 47) == -S(47)**2/(2*2**2) + + assert E_nl(1) == -S.One/(2*1**2) + assert E_nl(2) == -S.One/(2*2**2) + assert E_nl(3) == -S.One/(2*3**2) + assert E_nl(4) == -S.One/(2*4**2) + assert E_nl(100) == -S.One/(2*100**2) + + raises(ValueError, lambda: E_nl(0)) + + +def test_hydrogen_energies_relat(): + # First test exact formulas for small "c" so that we get nice expressions: + assert E_nl_dirac(2, 0, Z=1, c=1) == 1/sqrt(2) - 1 + assert simplify(E_nl_dirac(2, 0, Z=1, c=2) - ( (8*sqrt(3) + 16) + / sqrt(16*sqrt(3) + 32) - 4)) == 0 + assert simplify(E_nl_dirac(2, 0, Z=1, c=3) - ( (54*sqrt(2) + 81) + / sqrt(108*sqrt(2) + 162) - 9)) == 0 + + # Now test for almost the correct speed of light, without floating point + # numbers: + assert simplify(E_nl_dirac(2, 0, Z=1, c=137) - ( (352275361 + 10285412 * + sqrt(1173)) / sqrt(704550722 + 20570824 * sqrt(1173)) - 18769)) == 0 + assert simplify(E_nl_dirac(2, 0, Z=82, c=137) - ( (352275361 + 2571353 * + sqrt(12045)) / sqrt(704550722 + 5142706*sqrt(12045)) - 18769)) == 0 + + # Test using exact speed of light, and compare against the nonrelativistic + # energies: + for n in range(1, 5): + for l in range(n): + assert feq(E_nl_dirac(n, l), E_nl(n), 1e-5, 1e-5) + if l > 0: + assert feq(E_nl_dirac(n, l, False), E_nl(n), 1e-5, 1e-5) + + Z = 2 + for n in range(1, 5): + for l in range(n): + assert feq(E_nl_dirac(n, l, Z=Z), E_nl(n, Z), 1e-4, 1e-4) + if l > 0: + assert feq(E_nl_dirac(n, l, False, Z), E_nl(n, Z), 1e-4, 1e-4) + + Z = 3 + for n in range(1, 5): + for l in range(n): + assert feq(E_nl_dirac(n, l, Z=Z), E_nl(n, Z), 1e-3, 1e-3) + if l > 0: + assert feq(E_nl_dirac(n, l, False, Z), E_nl(n, Z), 1e-3, 1e-3) + + # Test the exceptions: + raises(ValueError, lambda: E_nl_dirac(0, 0)) + raises(ValueError, lambda: E_nl_dirac(1, -1)) + raises(ValueError, lambda: E_nl_dirac(1, 0, False)) diff --git a/lib/python3.10/site-packages/sympy/physics/tests/test_paulialgebra.py b/lib/python3.10/site-packages/sympy/physics/tests/test_paulialgebra.py new file mode 100644 index 0000000000000000000000000000000000000000..f773470a1802f2864b79f56d38be1de030ff86dc --- /dev/null +++ b/lib/python3.10/site-packages/sympy/physics/tests/test_paulialgebra.py @@ -0,0 +1,57 @@ +from sympy.core.numbers import I +from sympy.core.symbol import symbols +from sympy.physics.paulialgebra import Pauli +from sympy.testing.pytest import XFAIL +from sympy.physics.quantum import TensorProduct + +sigma1 = Pauli(1) +sigma2 = Pauli(2) +sigma3 = Pauli(3) + +tau1 = symbols("tau1", commutative = False) + + +def test_Pauli(): + + assert sigma1 == sigma1 + assert sigma1 != sigma2 + + assert sigma1*sigma2 == I*sigma3 + assert sigma3*sigma1 == I*sigma2 + assert sigma2*sigma3 == I*sigma1 + + assert sigma1*sigma1 == 1 + assert sigma2*sigma2 == 1 + assert sigma3*sigma3 == 1 + + assert sigma1**0 == 1 + assert sigma1**1 == sigma1 + assert sigma1**2 == 1 + assert sigma1**3 == sigma1 + assert sigma1**4 == 1 + + assert sigma3**2 == 1 + + assert sigma1*2*sigma1 == 2 + + +def test_evaluate_pauli_product(): + from sympy.physics.paulialgebra import evaluate_pauli_product + + assert evaluate_pauli_product(I*sigma2*sigma3) == -sigma1 + + # Check issue 6471 + assert evaluate_pauli_product(-I*4*sigma1*sigma2) == 4*sigma3 + + assert evaluate_pauli_product( + 1 + I*sigma1*sigma2*sigma1*sigma2 + \ + I*sigma1*sigma2*tau1*sigma1*sigma3 + \ + ((tau1**2).subs(tau1, I*sigma1)) + \ + sigma3*((tau1**2).subs(tau1, I*sigma1)) + \ + TensorProduct(I*sigma1*sigma2*sigma1*sigma2, 1) + ) == 1 -I + I*sigma3*tau1*sigma2 - 1 - sigma3 - I*TensorProduct(1,1) + + +@XFAIL +def test_Pauli_should_work(): + assert sigma1*sigma3*sigma1 == -sigma3 diff --git a/lib/python3.10/site-packages/sympy/physics/tests/test_physics_matrices.py b/lib/python3.10/site-packages/sympy/physics/tests/test_physics_matrices.py new file mode 100644 index 0000000000000000000000000000000000000000..14fa47668d0760826e0354c8cafae787a24256eb --- /dev/null +++ b/lib/python3.10/site-packages/sympy/physics/tests/test_physics_matrices.py @@ -0,0 +1,84 @@ +from sympy.physics.matrices import msigma, mgamma, minkowski_tensor, pat_matrix, mdft +from sympy.core.numbers import (I, Rational) +from sympy.core.singleton import S +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.matrices.dense import (Matrix, eye, zeros) +from sympy.testing.pytest import warns_deprecated_sympy + + +def test_parallel_axis_theorem(): + # This tests the parallel axis theorem matrix by comparing to test + # matrices. + + # First case, 1 in all directions. + mat1 = Matrix(((2, -1, -1), (-1, 2, -1), (-1, -1, 2))) + assert pat_matrix(1, 1, 1, 1) == mat1 + assert pat_matrix(2, 1, 1, 1) == 2*mat1 + + # Second case, 1 in x, 0 in all others + mat2 = Matrix(((0, 0, 0), (0, 1, 0), (0, 0, 1))) + assert pat_matrix(1, 1, 0, 0) == mat2 + assert pat_matrix(2, 1, 0, 0) == 2*mat2 + + # Third case, 1 in y, 0 in all others + mat3 = Matrix(((1, 0, 0), (0, 0, 0), (0, 0, 1))) + assert pat_matrix(1, 0, 1, 0) == mat3 + assert pat_matrix(2, 0, 1, 0) == 2*mat3 + + # Fourth case, 1 in z, 0 in all others + mat4 = Matrix(((1, 0, 0), (0, 1, 0), (0, 0, 0))) + assert pat_matrix(1, 0, 0, 1) == mat4 + assert pat_matrix(2, 0, 0, 1) == 2*mat4 + + +def test_Pauli(): + #this and the following test are testing both Pauli and Dirac matrices + #and also that the general Matrix class works correctly in a real world + #situation + sigma1 = msigma(1) + sigma2 = msigma(2) + sigma3 = msigma(3) + + assert sigma1 == sigma1 + assert sigma1 != sigma2 + + # sigma*I -> I*sigma (see #354) + assert sigma1*sigma2 == sigma3*I + assert sigma3*sigma1 == sigma2*I + assert sigma2*sigma3 == sigma1*I + + assert sigma1*sigma1 == eye(2) + assert sigma2*sigma2 == eye(2) + assert sigma3*sigma3 == eye(2) + + assert sigma1*2*sigma1 == 2*eye(2) + assert sigma1*sigma3*sigma1 == -sigma3 + + +def test_Dirac(): + gamma0 = mgamma(0) + gamma1 = mgamma(1) + gamma2 = mgamma(2) + gamma3 = mgamma(3) + gamma5 = mgamma(5) + + # gamma*I -> I*gamma (see #354) + assert gamma5 == gamma0 * gamma1 * gamma2 * gamma3 * I + assert gamma1 * gamma2 + gamma2 * gamma1 == zeros(4) + assert gamma0 * gamma0 == eye(4) * minkowski_tensor[0, 0] + assert gamma2 * gamma2 != eye(4) * minkowski_tensor[0, 0] + assert gamma2 * gamma2 == eye(4) * minkowski_tensor[2, 2] + + assert mgamma(5, True) == \ + mgamma(0, True)*mgamma(1, True)*mgamma(2, True)*mgamma(3, True)*I + +def test_mdft(): + with warns_deprecated_sympy(): + assert mdft(1) == Matrix([[1]]) + with warns_deprecated_sympy(): + assert mdft(2) == 1/sqrt(2)*Matrix([[1,1],[1,-1]]) + with warns_deprecated_sympy(): + assert mdft(4) == Matrix([[S.Half, S.Half, S.Half, S.Half], + [S.Half, -I/2, Rational(-1,2), I/2], + [S.Half, Rational(-1,2), S.Half, Rational(-1,2)], + [S.Half, I/2, Rational(-1,2), -I/2]]) diff --git a/lib/python3.10/site-packages/sympy/physics/tests/test_pring.py b/lib/python3.10/site-packages/sympy/physics/tests/test_pring.py new file mode 100644 index 0000000000000000000000000000000000000000..ed7398eac4a8bb1cd4af810825caf3fcefb5f18f --- /dev/null +++ b/lib/python3.10/site-packages/sympy/physics/tests/test_pring.py @@ -0,0 +1,41 @@ +from sympy.physics.pring import wavefunction, energy +from sympy.core.numbers import (I, pi) +from sympy.functions.elementary.exponential import exp +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.integrals.integrals import integrate +from sympy.simplify.simplify import simplify +from sympy.abc import m, x, r +from sympy.physics.quantum.constants import hbar + + +def test_wavefunction(): + Psi = { + 0: (1/sqrt(2 * pi)), + 1: (1/sqrt(2 * pi)) * exp(I * x), + 2: (1/sqrt(2 * pi)) * exp(2 * I * x), + 3: (1/sqrt(2 * pi)) * exp(3 * I * x) + } + for n in Psi: + assert simplify(wavefunction(n, x) - Psi[n]) == 0 + + +def test_norm(n=1): + # Maximum "n" which is tested: + for i in range(n + 1): + assert integrate( + wavefunction(i, x) * wavefunction(-i, x), (x, 0, 2 * pi)) == 1 + + +def test_orthogonality(n=1): + # Maximum "n" which is tested: + for i in range(n + 1): + for j in range(i+1, n+1): + assert integrate( + wavefunction(i, x) * wavefunction(j, x), (x, 0, 2 * pi)) == 0 + + +def test_energy(n=1): + # Maximum "n" which is tested: + for i in range(n+1): + assert simplify( + energy(i, m, r) - ((i**2 * hbar**2) / (2 * m * r**2))) == 0 diff --git a/lib/python3.10/site-packages/sympy/physics/tests/test_qho_1d.py b/lib/python3.10/site-packages/sympy/physics/tests/test_qho_1d.py new file mode 100644 index 0000000000000000000000000000000000000000..34e52c9e3a721496fc61f7d2b31414db15caa7a8 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/physics/tests/test_qho_1d.py @@ -0,0 +1,50 @@ +from sympy.core.numbers import (Rational, oo, pi) +from sympy.core.singleton import S +from sympy.core.symbol import Symbol +from sympy.functions.elementary.exponential import exp +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.integrals.integrals import integrate +from sympy.simplify.simplify import simplify +from sympy.abc import omega, m, x +from sympy.physics.qho_1d import psi_n, E_n, coherent_state +from sympy.physics.quantum.constants import hbar + +nu = m * omega / hbar + + +def test_wavefunction(): + Psi = { + 0: (nu/pi)**Rational(1, 4) * exp(-nu * x**2 /2), + 1: (nu/pi)**Rational(1, 4) * sqrt(2*nu) * x * exp(-nu * x**2 /2), + 2: (nu/pi)**Rational(1, 4) * (2 * nu * x**2 - 1)/sqrt(2) * exp(-nu * x**2 /2), + 3: (nu/pi)**Rational(1, 4) * sqrt(nu/3) * (2 * nu * x**3 - 3 * x) * exp(-nu * x**2 /2) + } + for n in Psi: + assert simplify(psi_n(n, x, m, omega) - Psi[n]) == 0 + + +def test_norm(n=1): + # Maximum "n" which is tested: + for i in range(n + 1): + assert integrate(psi_n(i, x, 1, 1)**2, (x, -oo, oo)) == 1 + + +def test_orthogonality(n=1): + # Maximum "n" which is tested: + for i in range(n + 1): + for j in range(i + 1, n + 1): + assert integrate( + psi_n(i, x, 1, 1)*psi_n(j, x, 1, 1), (x, -oo, oo)) == 0 + + +def test_energies(n=1): + # Maximum "n" which is tested: + for i in range(n + 1): + assert E_n(i, omega) == hbar * omega * (i + S.Half) + +def test_coherent_state(n=10): + # Maximum "n" which is tested: + # test whether coherent state is the eigenstate of annihilation operator + alpha = Symbol("alpha") + for i in range(n + 1): + assert simplify(sqrt(n + 1) * coherent_state(n + 1, alpha)) == simplify(alpha * coherent_state(n, alpha)) diff --git a/lib/python3.10/site-packages/sympy/physics/tests/test_secondquant.py b/lib/python3.10/site-packages/sympy/physics/tests/test_secondquant.py new file mode 100644 index 0000000000000000000000000000000000000000..dc9f4a499a7bee96d5fb5c76e83d84a72db5db8a --- /dev/null +++ b/lib/python3.10/site-packages/sympy/physics/tests/test_secondquant.py @@ -0,0 +1,1280 @@ +from sympy.physics.secondquant import ( + Dagger, Bd, VarBosonicBasis, BBra, B, BKet, FixedBosonicBasis, + matrix_rep, apply_operators, InnerProduct, Commutator, KroneckerDelta, + AnnihilateBoson, CreateBoson, BosonicOperator, + F, Fd, FKet, BosonState, CreateFermion, AnnihilateFermion, + evaluate_deltas, AntiSymmetricTensor, contraction, NO, wicks, + PermutationOperator, simplify_index_permutations, + _sort_anticommuting_fermions, _get_ordered_dummies, + substitute_dummies, FockStateBosonKet, + ContractionAppliesOnlyToFermions +) + +from sympy.concrete.summations import Sum +from sympy.core.function import (Function, expand) +from sympy.core.numbers import (I, Rational) +from sympy.core.singleton import S +from sympy.core.symbol import (Dummy, Symbol, symbols) +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.printing.repr import srepr +from sympy.simplify.simplify import simplify + +from sympy.testing.pytest import slow, raises +from sympy.printing.latex import latex + + +def test_PermutationOperator(): + p, q, r, s = symbols('p,q,r,s') + f, g, h, i = map(Function, 'fghi') + P = PermutationOperator + assert P(p, q).get_permuted(f(p)*g(q)) == -f(q)*g(p) + assert P(p, q).get_permuted(f(p, q)) == -f(q, p) + assert P(p, q).get_permuted(f(p)) == f(p) + expr = (f(p)*g(q)*h(r)*i(s) + - f(q)*g(p)*h(r)*i(s) + - f(p)*g(q)*h(s)*i(r) + + f(q)*g(p)*h(s)*i(r)) + perms = [P(p, q), P(r, s)] + assert (simplify_index_permutations(expr, perms) == + P(p, q)*P(r, s)*f(p)*g(q)*h(r)*i(s)) + assert latex(P(p, q)) == 'P(pq)' + + +def test_index_permutations_with_dummies(): + a, b, c, d = symbols('a b c d') + p, q, r, s = symbols('p q r s', cls=Dummy) + f, g = map(Function, 'fg') + P = PermutationOperator + + # No dummy substitution necessary + expr = f(a, b, p, q) - f(b, a, p, q) + assert simplify_index_permutations( + expr, [P(a, b)]) == P(a, b)*f(a, b, p, q) + + # Cases where dummy substitution is needed + expected = P(a, b)*substitute_dummies(f(a, b, p, q)) + + expr = f(a, b, p, q) - f(b, a, q, p) + result = simplify_index_permutations(expr, [P(a, b)]) + assert expected == substitute_dummies(result) + + expr = f(a, b, q, p) - f(b, a, p, q) + result = simplify_index_permutations(expr, [P(a, b)]) + assert expected == substitute_dummies(result) + + # A case where nothing can be done + expr = f(a, b, q, p) - g(b, a, p, q) + result = simplify_index_permutations(expr, [P(a, b)]) + assert expr == result + + +def test_dagger(): + i, j, n, m = symbols('i,j,n,m') + assert Dagger(1) == 1 + assert Dagger(1.0) == 1.0 + assert Dagger(2*I) == -2*I + assert Dagger(S.Half*I/3.0) == I*Rational(-1, 2)/3.0 + assert Dagger(BKet([n])) == BBra([n]) + assert Dagger(B(0)) == Bd(0) + assert Dagger(Bd(0)) == B(0) + assert Dagger(B(n)) == Bd(n) + assert Dagger(Bd(n)) == B(n) + assert Dagger(B(0) + B(1)) == Bd(0) + Bd(1) + assert Dagger(n*m) == Dagger(n)*Dagger(m) # n, m commute + assert Dagger(B(n)*B(m)) == Bd(m)*Bd(n) + assert Dagger(B(n)**10) == Dagger(B(n))**10 + assert Dagger('a') == Dagger(Symbol('a')) + assert Dagger(Dagger('a')) == Symbol('a') + + +def test_operator(): + i, j = symbols('i,j') + o = BosonicOperator(i) + assert o.state == i + assert o.is_symbolic + o = BosonicOperator(1) + assert o.state == 1 + assert not o.is_symbolic + + +def test_create(): + i, j, n, m = symbols('i,j,n,m') + o = Bd(i) + assert latex(o) == "{b^\\dagger_{i}}" + assert isinstance(o, CreateBoson) + o = o.subs(i, j) + assert o.atoms(Symbol) == {j} + o = Bd(0) + assert o.apply_operator(BKet([n])) == sqrt(n + 1)*BKet([n + 1]) + o = Bd(n) + assert o.apply_operator(BKet([n])) == o*BKet([n]) + + +def test_annihilate(): + i, j, n, m = symbols('i,j,n,m') + o = B(i) + assert latex(o) == "b_{i}" + assert isinstance(o, AnnihilateBoson) + o = o.subs(i, j) + assert o.atoms(Symbol) == {j} + o = B(0) + assert o.apply_operator(BKet([n])) == sqrt(n)*BKet([n - 1]) + o = B(n) + assert o.apply_operator(BKet([n])) == o*BKet([n]) + + +def test_basic_state(): + i, j, n, m = symbols('i,j,n,m') + s = BosonState([0, 1, 2, 3, 4]) + assert len(s) == 5 + assert s.args[0] == tuple(range(5)) + assert s.up(0) == BosonState([1, 1, 2, 3, 4]) + assert s.down(4) == BosonState([0, 1, 2, 3, 3]) + for i in range(5): + assert s.up(i).down(i) == s + assert s.down(0) == 0 + for i in range(5): + assert s[i] == i + s = BosonState([n, m]) + assert s.down(0) == BosonState([n - 1, m]) + assert s.up(0) == BosonState([n + 1, m]) + + +def test_basic_apply(): + n = symbols("n") + e = B(0)*BKet([n]) + assert apply_operators(e) == sqrt(n)*BKet([n - 1]) + e = Bd(0)*BKet([n]) + assert apply_operators(e) == sqrt(n + 1)*BKet([n + 1]) + + +def test_complex_apply(): + n, m = symbols("n,m") + o = Bd(0)*B(0)*Bd(1)*B(0) + e = apply_operators(o*BKet([n, m])) + answer = sqrt(n)*sqrt(m + 1)*(-1 + n)*BKet([-1 + n, 1 + m]) + assert expand(e) == expand(answer) + + +def test_number_operator(): + n = symbols("n") + o = Bd(0)*B(0) + e = apply_operators(o*BKet([n])) + assert e == n*BKet([n]) + + +def test_inner_product(): + i, j, k, l = symbols('i,j,k,l') + s1 = BBra([0]) + s2 = BKet([1]) + assert InnerProduct(s1, Dagger(s1)) == 1 + assert InnerProduct(s1, s2) == 0 + s1 = BBra([i, j]) + s2 = BKet([k, l]) + r = InnerProduct(s1, s2) + assert r == KroneckerDelta(i, k)*KroneckerDelta(j, l) + + +def test_symbolic_matrix_elements(): + n, m = symbols('n,m') + s1 = BBra([n]) + s2 = BKet([m]) + o = B(0) + e = apply_operators(s1*o*s2) + assert e == sqrt(m)*KroneckerDelta(n, m - 1) + + +def test_matrix_elements(): + b = VarBosonicBasis(5) + o = B(0) + m = matrix_rep(o, b) + for i in range(4): + assert m[i, i + 1] == sqrt(i + 1) + o = Bd(0) + m = matrix_rep(o, b) + for i in range(4): + assert m[i + 1, i] == sqrt(i + 1) + + +def test_fixed_bosonic_basis(): + b = FixedBosonicBasis(2, 2) + # assert b == [FockState((2, 0)), FockState((1, 1)), FockState((0, 2))] + state = b.state(1) + assert state == FockStateBosonKet((1, 1)) + assert b.index(state) == 1 + assert b.state(1) == b[1] + assert len(b) == 3 + assert str(b) == '[FockState((2, 0)), FockState((1, 1)), FockState((0, 2))]' + assert repr(b) == '[FockState((2, 0)), FockState((1, 1)), FockState((0, 2))]' + assert srepr(b) == '[FockState((2, 0)), FockState((1, 1)), FockState((0, 2))]' + + +@slow +def test_sho(): + n, m = symbols('n,m') + h_n = Bd(n)*B(n)*(n + S.Half) + H = Sum(h_n, (n, 0, 5)) + o = H.doit(deep=False) + b = FixedBosonicBasis(2, 6) + m = matrix_rep(o, b) + # We need to double check these energy values to make sure that they + # are correct and have the proper degeneracies! + diag = [1, 2, 3, 3, 4, 5, 4, 5, 6, 7, 5, 6, 7, 8, 9, 6, 7, 8, 9, 10, 11] + for i in range(len(diag)): + assert diag[i] == m[i, i] + + +def test_commutation(): + n, m = symbols("n,m", above_fermi=True) + c = Commutator(B(0), Bd(0)) + assert c == 1 + c = Commutator(Bd(0), B(0)) + assert c == -1 + c = Commutator(B(n), Bd(0)) + assert c == KroneckerDelta(n, 0) + c = Commutator(B(0), B(0)) + assert c == 0 + c = Commutator(B(0), Bd(0)) + e = simplify(apply_operators(c*BKet([n]))) + assert e == BKet([n]) + c = Commutator(B(0), B(1)) + e = simplify(apply_operators(c*BKet([n, m]))) + assert e == 0 + + c = Commutator(F(m), Fd(m)) + assert c == +1 - 2*NO(Fd(m)*F(m)) + c = Commutator(Fd(m), F(m)) + assert c.expand() == -1 + 2*NO(Fd(m)*F(m)) + + C = Commutator + X, Y, Z = symbols('X,Y,Z', commutative=False) + assert C(C(X, Y), Z) != 0 + assert C(C(X, Z), Y) != 0 + assert C(Y, C(X, Z)) != 0 + + i, j, k, l = symbols('i,j,k,l', below_fermi=True) + a, b, c, d = symbols('a,b,c,d', above_fermi=True) + p, q, r, s = symbols('p,q,r,s') + D = KroneckerDelta + + assert C(Fd(a), F(i)) == -2*NO(F(i)*Fd(a)) + assert C(Fd(j), NO(Fd(a)*F(i))).doit(wicks=True) == -D(j, i)*Fd(a) + assert C(Fd(a)*F(i), Fd(b)*F(j)).doit(wicks=True) == 0 + + c1 = Commutator(F(a), Fd(a)) + assert Commutator.eval(c1, c1) == 0 + c = Commutator(Fd(a)*F(i),Fd(b)*F(j)) + assert latex(c) == r'\left[{a^\dagger_{a}} a_{i},{a^\dagger_{b}} a_{j}\right]' + assert repr(c) == 'Commutator(CreateFermion(a)*AnnihilateFermion(i),CreateFermion(b)*AnnihilateFermion(j))' + assert str(c) == '[CreateFermion(a)*AnnihilateFermion(i),CreateFermion(b)*AnnihilateFermion(j)]' + + +def test_create_f(): + i, j, n, m = symbols('i,j,n,m') + o = Fd(i) + assert isinstance(o, CreateFermion) + o = o.subs(i, j) + assert o.atoms(Symbol) == {j} + o = Fd(1) + assert o.apply_operator(FKet([n])) == FKet([1, n]) + assert o.apply_operator(FKet([n])) == -FKet([n, 1]) + o = Fd(n) + assert o.apply_operator(FKet([])) == FKet([n]) + + vacuum = FKet([], fermi_level=4) + assert vacuum == FKet([], fermi_level=4) + + i, j, k, l = symbols('i,j,k,l', below_fermi=True) + a, b, c, d = symbols('a,b,c,d', above_fermi=True) + p, q, r, s = symbols('p,q,r,s') + + assert Fd(i).apply_operator(FKet([i, j, k], 4)) == FKet([j, k], 4) + assert Fd(a).apply_operator(FKet([i, b, k], 4)) == FKet([a, i, b, k], 4) + + assert Dagger(B(p)).apply_operator(q) == q*CreateBoson(p) + assert repr(Fd(p)) == 'CreateFermion(p)' + assert srepr(Fd(p)) == "CreateFermion(Symbol('p'))" + assert latex(Fd(p)) == r'{a^\dagger_{p}}' + + +def test_annihilate_f(): + i, j, n, m = symbols('i,j,n,m') + o = F(i) + assert isinstance(o, AnnihilateFermion) + o = o.subs(i, j) + assert o.atoms(Symbol) == {j} + o = F(1) + assert o.apply_operator(FKet([1, n])) == FKet([n]) + assert o.apply_operator(FKet([n, 1])) == -FKet([n]) + o = F(n) + assert o.apply_operator(FKet([n])) == FKet([]) + + i, j, k, l = symbols('i,j,k,l', below_fermi=True) + a, b, c, d = symbols('a,b,c,d', above_fermi=True) + p, q, r, s = symbols('p,q,r,s') + assert F(i).apply_operator(FKet([i, j, k], 4)) == 0 + assert F(a).apply_operator(FKet([i, b, k], 4)) == 0 + assert F(l).apply_operator(FKet([i, j, k], 3)) == 0 + assert F(l).apply_operator(FKet([i, j, k], 4)) == FKet([l, i, j, k], 4) + assert str(F(p)) == 'f(p)' + assert repr(F(p)) == 'AnnihilateFermion(p)' + assert srepr(F(p)) == "AnnihilateFermion(Symbol('p'))" + assert latex(F(p)) == 'a_{p}' + + +def test_create_b(): + i, j, n, m = symbols('i,j,n,m') + o = Bd(i) + assert isinstance(o, CreateBoson) + o = o.subs(i, j) + assert o.atoms(Symbol) == {j} + o = Bd(0) + assert o.apply_operator(BKet([n])) == sqrt(n + 1)*BKet([n + 1]) + o = Bd(n) + assert o.apply_operator(BKet([n])) == o*BKet([n]) + + +def test_annihilate_b(): + i, j, n, m = symbols('i,j,n,m') + o = B(i) + assert isinstance(o, AnnihilateBoson) + o = o.subs(i, j) + assert o.atoms(Symbol) == {j} + o = B(0) + + +def test_wicks(): + p, q, r, s = symbols('p,q,r,s', above_fermi=True) + + # Testing for particles only + + str = F(p)*Fd(q) + assert wicks(str) == NO(F(p)*Fd(q)) + KroneckerDelta(p, q) + str = Fd(p)*F(q) + assert wicks(str) == NO(Fd(p)*F(q)) + + str = F(p)*Fd(q)*F(r)*Fd(s) + nstr = wicks(str) + fasit = NO( + KroneckerDelta(p, q)*KroneckerDelta(r, s) + + KroneckerDelta(p, q)*AnnihilateFermion(r)*CreateFermion(s) + + KroneckerDelta(r, s)*AnnihilateFermion(p)*CreateFermion(q) + - KroneckerDelta(p, s)*AnnihilateFermion(r)*CreateFermion(q) + - AnnihilateFermion(p)*AnnihilateFermion(r)*CreateFermion(q)*CreateFermion(s)) + assert nstr == fasit + + assert (p*q*nstr).expand() == wicks(p*q*str) + assert (nstr*p*q*2).expand() == wicks(str*p*q*2) + + # Testing CC equations particles and holes + i, j, k, l = symbols('i j k l', below_fermi=True, cls=Dummy) + a, b, c, d = symbols('a b c d', above_fermi=True, cls=Dummy) + p, q, r, s = symbols('p q r s', cls=Dummy) + + assert (wicks(F(a)*NO(F(i)*F(j))*Fd(b)) == + NO(F(a)*F(i)*F(j)*Fd(b)) + + KroneckerDelta(a, b)*NO(F(i)*F(j))) + assert (wicks(F(a)*NO(F(i)*F(j)*F(k))*Fd(b)) == + NO(F(a)*F(i)*F(j)*F(k)*Fd(b)) - + KroneckerDelta(a, b)*NO(F(i)*F(j)*F(k))) + + expr = wicks(Fd(i)*NO(Fd(j)*F(k))*F(l)) + assert (expr == + -KroneckerDelta(i, k)*NO(Fd(j)*F(l)) - + KroneckerDelta(j, l)*NO(Fd(i)*F(k)) - + KroneckerDelta(i, k)*KroneckerDelta(j, l) + + KroneckerDelta(i, l)*NO(Fd(j)*F(k)) + + NO(Fd(i)*Fd(j)*F(k)*F(l))) + expr = wicks(F(a)*NO(F(b)*Fd(c))*Fd(d)) + assert (expr == + -KroneckerDelta(a, c)*NO(F(b)*Fd(d)) - + KroneckerDelta(b, d)*NO(F(a)*Fd(c)) - + KroneckerDelta(a, c)*KroneckerDelta(b, d) + + KroneckerDelta(a, d)*NO(F(b)*Fd(c)) + + NO(F(a)*F(b)*Fd(c)*Fd(d))) + + +def test_NO(): + i, j, k, l = symbols('i j k l', below_fermi=True) + a, b, c, d = symbols('a b c d', above_fermi=True) + p, q, r, s = symbols('p q r s', cls=Dummy) + + assert (NO(Fd(p)*F(q) + Fd(a)*F(b)) == + NO(Fd(p)*F(q)) + NO(Fd(a)*F(b))) + assert (NO(Fd(i)*NO(F(j)*Fd(a))) == + NO(Fd(i)*F(j)*Fd(a))) + assert NO(1) == 1 + assert NO(i) == i + assert (NO(Fd(a)*Fd(b)*(F(c) + F(d))) == + NO(Fd(a)*Fd(b)*F(c)) + + NO(Fd(a)*Fd(b)*F(d))) + + assert NO(Fd(a)*F(b))._remove_brackets() == Fd(a)*F(b) + assert NO(F(j)*Fd(i))._remove_brackets() == F(j)*Fd(i) + + assert (NO(Fd(p)*F(q)).subs(Fd(p), Fd(a) + Fd(i)) == + NO(Fd(a)*F(q)) + NO(Fd(i)*F(q))) + assert (NO(Fd(p)*F(q)).subs(F(q), F(a) + F(i)) == + NO(Fd(p)*F(a)) + NO(Fd(p)*F(i))) + + expr = NO(Fd(p)*F(q))._remove_brackets() + assert wicks(expr) == NO(expr) + + assert NO(Fd(a)*F(b)) == - NO(F(b)*Fd(a)) + + no = NO(Fd(a)*F(i)*F(b)*Fd(j)) + l1 = list(no.iter_q_creators()) + assert l1 == [0, 1] + l2 = list(no.iter_q_annihilators()) + assert l2 == [3, 2] + no = NO(Fd(a)*Fd(i)) + assert no.has_q_creators == 1 + assert no.has_q_annihilators == -1 + assert str(no) == ':CreateFermion(a)*CreateFermion(i):' + assert repr(no) == 'NO(CreateFermion(a)*CreateFermion(i))' + assert latex(no) == r'\left\{{a^\dagger_{a}} {a^\dagger_{i}}\right\}' + raises(NotImplementedError, lambda: NO(Bd(p)*F(q))) + + +def test_sorting(): + i, j = symbols('i,j', below_fermi=True) + a, b = symbols('a,b', above_fermi=True) + p, q = symbols('p,q') + + # p, q + assert _sort_anticommuting_fermions([Fd(p), F(q)]) == ([Fd(p), F(q)], 0) + assert _sort_anticommuting_fermions([F(p), Fd(q)]) == ([Fd(q), F(p)], 1) + + # i, p + assert _sort_anticommuting_fermions([F(p), Fd(i)]) == ([F(p), Fd(i)], 0) + assert _sort_anticommuting_fermions([Fd(i), F(p)]) == ([F(p), Fd(i)], 1) + assert _sort_anticommuting_fermions([Fd(p), Fd(i)]) == ([Fd(p), Fd(i)], 0) + assert _sort_anticommuting_fermions([Fd(i), Fd(p)]) == ([Fd(p), Fd(i)], 1) + assert _sort_anticommuting_fermions([F(p), F(i)]) == ([F(i), F(p)], 1) + assert _sort_anticommuting_fermions([F(i), F(p)]) == ([F(i), F(p)], 0) + assert _sort_anticommuting_fermions([Fd(p), F(i)]) == ([F(i), Fd(p)], 1) + assert _sort_anticommuting_fermions([F(i), Fd(p)]) == ([F(i), Fd(p)], 0) + + # a, p + assert _sort_anticommuting_fermions([F(p), Fd(a)]) == ([Fd(a), F(p)], 1) + assert _sort_anticommuting_fermions([Fd(a), F(p)]) == ([Fd(a), F(p)], 0) + assert _sort_anticommuting_fermions([Fd(p), Fd(a)]) == ([Fd(a), Fd(p)], 1) + assert _sort_anticommuting_fermions([Fd(a), Fd(p)]) == ([Fd(a), Fd(p)], 0) + assert _sort_anticommuting_fermions([F(p), F(a)]) == ([F(p), F(a)], 0) + assert _sort_anticommuting_fermions([F(a), F(p)]) == ([F(p), F(a)], 1) + assert _sort_anticommuting_fermions([Fd(p), F(a)]) == ([Fd(p), F(a)], 0) + assert _sort_anticommuting_fermions([F(a), Fd(p)]) == ([Fd(p), F(a)], 1) + + # i, a + assert _sort_anticommuting_fermions([F(i), Fd(j)]) == ([F(i), Fd(j)], 0) + assert _sort_anticommuting_fermions([Fd(j), F(i)]) == ([F(i), Fd(j)], 1) + assert _sort_anticommuting_fermions([Fd(a), Fd(i)]) == ([Fd(a), Fd(i)], 0) + assert _sort_anticommuting_fermions([Fd(i), Fd(a)]) == ([Fd(a), Fd(i)], 1) + assert _sort_anticommuting_fermions([F(a), F(i)]) == ([F(i), F(a)], 1) + assert _sort_anticommuting_fermions([F(i), F(a)]) == ([F(i), F(a)], 0) + + +def test_contraction(): + i, j, k, l = symbols('i,j,k,l', below_fermi=True) + a, b, c, d = symbols('a,b,c,d', above_fermi=True) + p, q, r, s = symbols('p,q,r,s') + assert contraction(Fd(i), F(j)) == KroneckerDelta(i, j) + assert contraction(F(a), Fd(b)) == KroneckerDelta(a, b) + assert contraction(F(a), Fd(i)) == 0 + assert contraction(Fd(a), F(i)) == 0 + assert contraction(F(i), Fd(a)) == 0 + assert contraction(Fd(i), F(a)) == 0 + assert contraction(Fd(i), F(p)) == KroneckerDelta(i, p) + restr = evaluate_deltas(contraction(Fd(p), F(q))) + assert restr.is_only_below_fermi + restr = evaluate_deltas(contraction(F(p), Fd(q))) + assert restr.is_only_above_fermi + raises(ContractionAppliesOnlyToFermions, lambda: contraction(B(a), Fd(b))) + + +def test_evaluate_deltas(): + i, j, k = symbols('i,j,k') + + r = KroneckerDelta(i, j) * KroneckerDelta(j, k) + assert evaluate_deltas(r) == KroneckerDelta(i, k) + + r = KroneckerDelta(i, 0) * KroneckerDelta(j, k) + assert evaluate_deltas(r) == KroneckerDelta(i, 0) * KroneckerDelta(j, k) + + r = KroneckerDelta(1, j) * KroneckerDelta(j, k) + assert evaluate_deltas(r) == KroneckerDelta(1, k) + + r = KroneckerDelta(j, 2) * KroneckerDelta(k, j) + assert evaluate_deltas(r) == KroneckerDelta(2, k) + + r = KroneckerDelta(i, 0) * KroneckerDelta(i, j) * KroneckerDelta(j, 1) + assert evaluate_deltas(r) == 0 + + r = (KroneckerDelta(0, i) * KroneckerDelta(0, j) + * KroneckerDelta(1, j) * KroneckerDelta(1, j)) + assert evaluate_deltas(r) == 0 + + +def test_Tensors(): + i, j, k, l = symbols('i j k l', below_fermi=True, cls=Dummy) + a, b, c, d = symbols('a b c d', above_fermi=True, cls=Dummy) + p, q, r, s = symbols('p q r s') + + AT = AntiSymmetricTensor + assert AT('t', (a, b), (i, j)) == -AT('t', (b, a), (i, j)) + assert AT('t', (a, b), (i, j)) == AT('t', (b, a), (j, i)) + assert AT('t', (a, b), (i, j)) == -AT('t', (a, b), (j, i)) + assert AT('t', (a, a), (i, j)) == 0 + assert AT('t', (a, b), (i, i)) == 0 + assert AT('t', (a, b, c), (i, j)) == -AT('t', (b, a, c), (i, j)) + assert AT('t', (a, b, c), (i, j, k)) == AT('t', (b, a, c), (i, k, j)) + + tabij = AT('t', (a, b), (i, j)) + assert tabij.has(a) + assert tabij.has(b) + assert tabij.has(i) + assert tabij.has(j) + assert tabij.subs(b, c) == AT('t', (a, c), (i, j)) + assert (2*tabij).subs(i, c) == 2*AT('t', (a, b), (c, j)) + assert tabij.symbol == Symbol('t') + assert latex(tabij) == '{t^{ab}_{ij}}' + assert str(tabij) == 't((_a, _b),(_i, _j))' + + assert AT('t', (a, a), (i, j)).subs(a, b) == AT('t', (b, b), (i, j)) + assert AT('t', (a, i), (a, j)).subs(a, b) == AT('t', (b, i), (b, j)) + + +def test_fully_contracted(): + i, j, k, l = symbols('i j k l', below_fermi=True) + a, b, c, d = symbols('a b c d', above_fermi=True) + p, q, r, s = symbols('p q r s', cls=Dummy) + + Fock = (AntiSymmetricTensor('f', (p,), (q,))* + NO(Fd(p)*F(q))) + V = (AntiSymmetricTensor('v', (p, q), (r, s))* + NO(Fd(p)*Fd(q)*F(s)*F(r)))/4 + + Fai = wicks(NO(Fd(i)*F(a))*Fock, + keep_only_fully_contracted=True, + simplify_kronecker_deltas=True) + assert Fai == AntiSymmetricTensor('f', (a,), (i,)) + Vabij = wicks(NO(Fd(i)*Fd(j)*F(b)*F(a))*V, + keep_only_fully_contracted=True, + simplify_kronecker_deltas=True) + assert Vabij == AntiSymmetricTensor('v', (a, b), (i, j)) + + +def test_substitute_dummies_without_dummies(): + i, j = symbols('i,j') + assert substitute_dummies(att(i, j) + 2) == att(i, j) + 2 + assert substitute_dummies(att(i, j) + 1) == att(i, j) + 1 + + +def test_substitute_dummies_NO_operator(): + i, j = symbols('i j', cls=Dummy) + assert substitute_dummies(att(i, j)*NO(Fd(i)*F(j)) + - att(j, i)*NO(Fd(j)*F(i))) == 0 + + +def test_substitute_dummies_SQ_operator(): + i, j = symbols('i j', cls=Dummy) + assert substitute_dummies(att(i, j)*Fd(i)*F(j) + - att(j, i)*Fd(j)*F(i)) == 0 + + +def test_substitute_dummies_new_indices(): + i, j = symbols('i j', below_fermi=True, cls=Dummy) + a, b = symbols('a b', above_fermi=True, cls=Dummy) + p, q = symbols('p q', cls=Dummy) + f = Function('f') + assert substitute_dummies(f(i, a, p) - f(j, b, q), new_indices=True) == 0 + + +def test_substitute_dummies_substitution_order(): + i, j, k, l = symbols('i j k l', below_fermi=True, cls=Dummy) + f = Function('f') + from sympy.utilities.iterables import variations + for permut in variations([i, j, k, l], 4): + assert substitute_dummies(f(*permut) - f(i, j, k, l)) == 0 + + +def test_dummy_order_inner_outer_lines_VT1T1T1(): + ii = symbols('i', below_fermi=True) + aa = symbols('a', above_fermi=True) + k, l = symbols('k l', below_fermi=True, cls=Dummy) + c, d = symbols('c d', above_fermi=True, cls=Dummy) + + v = Function('v') + t = Function('t') + dums = _get_ordered_dummies + + # Coupled-Cluster T1 terms with V*T1*T1*T1 + # t^{a}_{k} t^{c}_{i} t^{d}_{l} v^{lk}_{dc} + exprs = [ + # permut v and t <=> swapping internal lines, equivalent + # irrespective of symmetries in v + v(k, l, c, d)*t(c, ii)*t(d, l)*t(aa, k), + v(l, k, c, d)*t(c, ii)*t(d, k)*t(aa, l), + v(k, l, d, c)*t(d, ii)*t(c, l)*t(aa, k), + v(l, k, d, c)*t(d, ii)*t(c, k)*t(aa, l), + ] + for permut in exprs[1:]: + assert dums(exprs[0]) != dums(permut) + assert substitute_dummies(exprs[0]) == substitute_dummies(permut) + + +def test_dummy_order_inner_outer_lines_VT1T1T1T1(): + ii, jj = symbols('i j', below_fermi=True) + aa, bb = symbols('a b', above_fermi=True) + k, l = symbols('k l', below_fermi=True, cls=Dummy) + c, d = symbols('c d', above_fermi=True, cls=Dummy) + + v = Function('v') + t = Function('t') + dums = _get_ordered_dummies + + # Coupled-Cluster T2 terms with V*T1*T1*T1*T1 + exprs = [ + # permut t <=> swapping external lines, not equivalent + # except if v has certain symmetries. + v(k, l, c, d)*t(c, ii)*t(d, jj)*t(aa, k)*t(bb, l), + v(k, l, c, d)*t(c, jj)*t(d, ii)*t(aa, k)*t(bb, l), + v(k, l, c, d)*t(c, ii)*t(d, jj)*t(bb, k)*t(aa, l), + v(k, l, c, d)*t(c, jj)*t(d, ii)*t(bb, k)*t(aa, l), + ] + for permut in exprs[1:]: + assert dums(exprs[0]) != dums(permut) + assert substitute_dummies(exprs[0]) != substitute_dummies(permut) + exprs = [ + # permut v <=> swapping external lines, not equivalent + # except if v has certain symmetries. + # + # Note that in contrast to above, these permutations have identical + # dummy order. That is because the proximity to external indices + # has higher influence on the canonical dummy ordering than the + # position of a dummy on the factors. In fact, the terms here are + # similar in structure as the result of the dummy substitutions above. + v(k, l, c, d)*t(c, ii)*t(d, jj)*t(aa, k)*t(bb, l), + v(l, k, c, d)*t(c, ii)*t(d, jj)*t(aa, k)*t(bb, l), + v(k, l, d, c)*t(c, ii)*t(d, jj)*t(aa, k)*t(bb, l), + v(l, k, d, c)*t(c, ii)*t(d, jj)*t(aa, k)*t(bb, l), + ] + for permut in exprs[1:]: + assert dums(exprs[0]) == dums(permut) + assert substitute_dummies(exprs[0]) != substitute_dummies(permut) + exprs = [ + # permut t and v <=> swapping internal lines, equivalent. + # Canonical dummy order is different, and a consistent + # substitution reveals the equivalence. + v(k, l, c, d)*t(c, ii)*t(d, jj)*t(aa, k)*t(bb, l), + v(k, l, d, c)*t(c, jj)*t(d, ii)*t(aa, k)*t(bb, l), + v(l, k, c, d)*t(c, ii)*t(d, jj)*t(bb, k)*t(aa, l), + v(l, k, d, c)*t(c, jj)*t(d, ii)*t(bb, k)*t(aa, l), + ] + for permut in exprs[1:]: + assert dums(exprs[0]) != dums(permut) + assert substitute_dummies(exprs[0]) == substitute_dummies(permut) + + +def test_get_subNO(): + p, q, r = symbols('p,q,r') + assert NO(F(p)*F(q)*F(r)).get_subNO(1) == NO(F(p)*F(r)) + assert NO(F(p)*F(q)*F(r)).get_subNO(0) == NO(F(q)*F(r)) + assert NO(F(p)*F(q)*F(r)).get_subNO(2) == NO(F(p)*F(q)) + + +def test_equivalent_internal_lines_VT1T1(): + i, j, k, l = symbols('i j k l', below_fermi=True, cls=Dummy) + a, b, c, d = symbols('a b c d', above_fermi=True, cls=Dummy) + + v = Function('v') + t = Function('t') + dums = _get_ordered_dummies + + exprs = [ # permute v. Different dummy order. Not equivalent. + v(i, j, a, b)*t(a, i)*t(b, j), + v(j, i, a, b)*t(a, i)*t(b, j), + v(i, j, b, a)*t(a, i)*t(b, j), + ] + for permut in exprs[1:]: + assert dums(exprs[0]) != dums(permut) + assert substitute_dummies(exprs[0]) != substitute_dummies(permut) + + exprs = [ # permute v. Different dummy order. Equivalent + v(i, j, a, b)*t(a, i)*t(b, j), + v(j, i, b, a)*t(a, i)*t(b, j), + ] + for permut in exprs[1:]: + assert dums(exprs[0]) != dums(permut) + assert substitute_dummies(exprs[0]) == substitute_dummies(permut) + + exprs = [ # permute t. Same dummy order, not equivalent. + v(i, j, a, b)*t(a, i)*t(b, j), + v(i, j, a, b)*t(b, i)*t(a, j), + ] + for permut in exprs[1:]: + assert dums(exprs[0]) == dums(permut) + assert substitute_dummies(exprs[0]) != substitute_dummies(permut) + + exprs = [ # permute v and t. Different dummy order, equivalent + v(i, j, a, b)*t(a, i)*t(b, j), + v(j, i, a, b)*t(a, j)*t(b, i), + v(i, j, b, a)*t(b, i)*t(a, j), + v(j, i, b, a)*t(b, j)*t(a, i), + ] + for permut in exprs[1:]: + assert dums(exprs[0]) != dums(permut) + assert substitute_dummies(exprs[0]) == substitute_dummies(permut) + + +def test_equivalent_internal_lines_VT2conjT2(): + # this diagram requires special handling in TCE + i, j, k, l, m, n = symbols('i j k l m n', below_fermi=True, cls=Dummy) + a, b, c, d, e, f = symbols('a b c d e f', above_fermi=True, cls=Dummy) + p1, p2, p3, p4 = symbols('p1 p2 p3 p4', above_fermi=True, cls=Dummy) + h1, h2, h3, h4 = symbols('h1 h2 h3 h4', below_fermi=True, cls=Dummy) + + from sympy.utilities.iterables import variations + + v = Function('v') + t = Function('t') + dums = _get_ordered_dummies + + # v(abcd)t(abij)t(ijcd) + template = v(p1, p2, p3, p4)*t(p1, p2, i, j)*t(i, j, p3, p4) + permutator = variations([a, b, c, d], 4) + base = template.subs(zip([p1, p2, p3, p4], next(permutator))) + for permut in permutator: + subslist = zip([p1, p2, p3, p4], permut) + expr = template.subs(subslist) + assert dums(base) != dums(expr) + assert substitute_dummies(expr) == substitute_dummies(base) + template = v(p1, p2, p3, p4)*t(p1, p2, j, i)*t(j, i, p3, p4) + permutator = variations([a, b, c, d], 4) + base = template.subs(zip([p1, p2, p3, p4], next(permutator))) + for permut in permutator: + subslist = zip([p1, p2, p3, p4], permut) + expr = template.subs(subslist) + assert dums(base) != dums(expr) + assert substitute_dummies(expr) == substitute_dummies(base) + + # v(abcd)t(abij)t(jicd) + template = v(p1, p2, p3, p4)*t(p1, p2, i, j)*t(j, i, p3, p4) + permutator = variations([a, b, c, d], 4) + base = template.subs(zip([p1, p2, p3, p4], next(permutator))) + for permut in permutator: + subslist = zip([p1, p2, p3, p4], permut) + expr = template.subs(subslist) + assert dums(base) != dums(expr) + assert substitute_dummies(expr) == substitute_dummies(base) + template = v(p1, p2, p3, p4)*t(p1, p2, j, i)*t(i, j, p3, p4) + permutator = variations([a, b, c, d], 4) + base = template.subs(zip([p1, p2, p3, p4], next(permutator))) + for permut in permutator: + subslist = zip([p1, p2, p3, p4], permut) + expr = template.subs(subslist) + assert dums(base) != dums(expr) + assert substitute_dummies(expr) == substitute_dummies(base) + + +def test_equivalent_internal_lines_VT2conjT2_ambiguous_order(): + # These diagrams invokes _determine_ambiguous() because the + # dummies can not be ordered unambiguously by the key alone + i, j, k, l, m, n = symbols('i j k l m n', below_fermi=True, cls=Dummy) + a, b, c, d, e, f = symbols('a b c d e f', above_fermi=True, cls=Dummy) + p1, p2, p3, p4 = symbols('p1 p2 p3 p4', above_fermi=True, cls=Dummy) + h1, h2, h3, h4 = symbols('h1 h2 h3 h4', below_fermi=True, cls=Dummy) + + from sympy.utilities.iterables import variations + + v = Function('v') + t = Function('t') + dums = _get_ordered_dummies + + # v(abcd)t(abij)t(cdij) + template = v(p1, p2, p3, p4)*t(p1, p2, i, j)*t(p3, p4, i, j) + permutator = variations([a, b, c, d], 4) + base = template.subs(zip([p1, p2, p3, p4], next(permutator))) + for permut in permutator: + subslist = zip([p1, p2, p3, p4], permut) + expr = template.subs(subslist) + assert dums(base) != dums(expr) + assert substitute_dummies(expr) == substitute_dummies(base) + template = v(p1, p2, p3, p4)*t(p1, p2, j, i)*t(p3, p4, i, j) + permutator = variations([a, b, c, d], 4) + base = template.subs(zip([p1, p2, p3, p4], next(permutator))) + for permut in permutator: + subslist = zip([p1, p2, p3, p4], permut) + expr = template.subs(subslist) + assert dums(base) != dums(expr) + assert substitute_dummies(expr) == substitute_dummies(base) + + +def test_equivalent_internal_lines_VT2(): + i, j, k, l = symbols('i j k l', below_fermi=True, cls=Dummy) + a, b, c, d = symbols('a b c d', above_fermi=True, cls=Dummy) + + v = Function('v') + t = Function('t') + dums = _get_ordered_dummies + exprs = [ + # permute v. Same dummy order, not equivalent. + # + # This test show that the dummy order may not be sensitive to all + # index permutations. The following expressions have identical + # structure as the resulting terms from of the dummy substitutions + # in the test above. Here, all expressions have the same dummy + # order, so they cannot be simplified by means of dummy + # substitution. In order to simplify further, it is necessary to + # exploit symmetries in the objects, for instance if t or v is + # antisymmetric. + v(i, j, a, b)*t(a, b, i, j), + v(j, i, a, b)*t(a, b, i, j), + v(i, j, b, a)*t(a, b, i, j), + v(j, i, b, a)*t(a, b, i, j), + ] + for permut in exprs[1:]: + assert dums(exprs[0]) == dums(permut) + assert substitute_dummies(exprs[0]) != substitute_dummies(permut) + + exprs = [ + # permute t. + v(i, j, a, b)*t(a, b, i, j), + v(i, j, a, b)*t(b, a, i, j), + v(i, j, a, b)*t(a, b, j, i), + v(i, j, a, b)*t(b, a, j, i), + ] + for permut in exprs[1:]: + assert dums(exprs[0]) != dums(permut) + assert substitute_dummies(exprs[0]) != substitute_dummies(permut) + + exprs = [ # permute v and t. Relabelling of dummies should be equivalent. + v(i, j, a, b)*t(a, b, i, j), + v(j, i, a, b)*t(a, b, j, i), + v(i, j, b, a)*t(b, a, i, j), + v(j, i, b, a)*t(b, a, j, i), + ] + for permut in exprs[1:]: + assert dums(exprs[0]) != dums(permut) + assert substitute_dummies(exprs[0]) == substitute_dummies(permut) + + +def test_internal_external_VT2T2(): + ii, jj = symbols('i j', below_fermi=True) + aa, bb = symbols('a b', above_fermi=True) + k, l = symbols('k l', below_fermi=True, cls=Dummy) + c, d = symbols('c d', above_fermi=True, cls=Dummy) + + v = Function('v') + t = Function('t') + dums = _get_ordered_dummies + + exprs = [ + v(k, l, c, d)*t(aa, c, ii, k)*t(bb, d, jj, l), + v(l, k, c, d)*t(aa, c, ii, l)*t(bb, d, jj, k), + v(k, l, d, c)*t(aa, d, ii, k)*t(bb, c, jj, l), + v(l, k, d, c)*t(aa, d, ii, l)*t(bb, c, jj, k), + ] + for permut in exprs[1:]: + assert dums(exprs[0]) != dums(permut) + assert substitute_dummies(exprs[0]) == substitute_dummies(permut) + exprs = [ + v(k, l, c, d)*t(aa, c, ii, k)*t(d, bb, jj, l), + v(l, k, c, d)*t(aa, c, ii, l)*t(d, bb, jj, k), + v(k, l, d, c)*t(aa, d, ii, k)*t(c, bb, jj, l), + v(l, k, d, c)*t(aa, d, ii, l)*t(c, bb, jj, k), + ] + for permut in exprs[1:]: + assert dums(exprs[0]) != dums(permut) + assert substitute_dummies(exprs[0]) == substitute_dummies(permut) + exprs = [ + v(k, l, c, d)*t(c, aa, ii, k)*t(bb, d, jj, l), + v(l, k, c, d)*t(c, aa, ii, l)*t(bb, d, jj, k), + v(k, l, d, c)*t(d, aa, ii, k)*t(bb, c, jj, l), + v(l, k, d, c)*t(d, aa, ii, l)*t(bb, c, jj, k), + ] + for permut in exprs[1:]: + assert dums(exprs[0]) != dums(permut) + assert substitute_dummies(exprs[0]) == substitute_dummies(permut) + + +def test_internal_external_pqrs(): + ii, jj = symbols('i j') + aa, bb = symbols('a b') + k, l = symbols('k l', cls=Dummy) + c, d = symbols('c d', cls=Dummy) + + v = Function('v') + t = Function('t') + dums = _get_ordered_dummies + + exprs = [ + v(k, l, c, d)*t(aa, c, ii, k)*t(bb, d, jj, l), + v(l, k, c, d)*t(aa, c, ii, l)*t(bb, d, jj, k), + v(k, l, d, c)*t(aa, d, ii, k)*t(bb, c, jj, l), + v(l, k, d, c)*t(aa, d, ii, l)*t(bb, c, jj, k), + ] + for permut in exprs[1:]: + assert dums(exprs[0]) != dums(permut) + assert substitute_dummies(exprs[0]) == substitute_dummies(permut) + + +def test_dummy_order_well_defined(): + aa, bb = symbols('a b', above_fermi=True) + k, l, m = symbols('k l m', below_fermi=True, cls=Dummy) + c, d = symbols('c d', above_fermi=True, cls=Dummy) + p, q = symbols('p q', cls=Dummy) + + A = Function('A') + B = Function('B') + C = Function('C') + dums = _get_ordered_dummies + + # We go through all key components in the order of increasing priority, + # and consider only fully orderable expressions. Non-orderable expressions + # are tested elsewhere. + + # pos in first factor determines sort order + assert dums(A(k, l)*B(l, k)) == [k, l] + assert dums(A(l, k)*B(l, k)) == [l, k] + assert dums(A(k, l)*B(k, l)) == [k, l] + assert dums(A(l, k)*B(k, l)) == [l, k] + + # factors involving the index + assert dums(A(k, l)*B(l, m)*C(k, m)) == [l, k, m] + assert dums(A(k, l)*B(l, m)*C(m, k)) == [l, k, m] + assert dums(A(l, k)*B(l, m)*C(k, m)) == [l, k, m] + assert dums(A(l, k)*B(l, m)*C(m, k)) == [l, k, m] + assert dums(A(k, l)*B(m, l)*C(k, m)) == [l, k, m] + assert dums(A(k, l)*B(m, l)*C(m, k)) == [l, k, m] + assert dums(A(l, k)*B(m, l)*C(k, m)) == [l, k, m] + assert dums(A(l, k)*B(m, l)*C(m, k)) == [l, k, m] + + # same, but with factor order determined by non-dummies + assert dums(A(k, aa, l)*A(l, bb, m)*A(bb, k, m)) == [l, k, m] + assert dums(A(k, aa, l)*A(l, bb, m)*A(bb, m, k)) == [l, k, m] + assert dums(A(k, aa, l)*A(m, bb, l)*A(bb, k, m)) == [l, k, m] + assert dums(A(k, aa, l)*A(m, bb, l)*A(bb, m, k)) == [l, k, m] + assert dums(A(l, aa, k)*A(l, bb, m)*A(bb, k, m)) == [l, k, m] + assert dums(A(l, aa, k)*A(l, bb, m)*A(bb, m, k)) == [l, k, m] + assert dums(A(l, aa, k)*A(m, bb, l)*A(bb, k, m)) == [l, k, m] + assert dums(A(l, aa, k)*A(m, bb, l)*A(bb, m, k)) == [l, k, m] + + # index range + assert dums(A(p, c, k)*B(p, c, k)) == [k, c, p] + assert dums(A(p, k, c)*B(p, c, k)) == [k, c, p] + assert dums(A(c, k, p)*B(p, c, k)) == [k, c, p] + assert dums(A(c, p, k)*B(p, c, k)) == [k, c, p] + assert dums(A(k, c, p)*B(p, c, k)) == [k, c, p] + assert dums(A(k, p, c)*B(p, c, k)) == [k, c, p] + assert dums(B(p, c, k)*A(p, c, k)) == [k, c, p] + assert dums(B(p, k, c)*A(p, c, k)) == [k, c, p] + assert dums(B(c, k, p)*A(p, c, k)) == [k, c, p] + assert dums(B(c, p, k)*A(p, c, k)) == [k, c, p] + assert dums(B(k, c, p)*A(p, c, k)) == [k, c, p] + assert dums(B(k, p, c)*A(p, c, k)) == [k, c, p] + + +def test_dummy_order_ambiguous(): + aa, bb = symbols('a b', above_fermi=True) + i, j, k, l, m = symbols('i j k l m', below_fermi=True, cls=Dummy) + a, b, c, d, e = symbols('a b c d e', above_fermi=True, cls=Dummy) + p, q = symbols('p q', cls=Dummy) + p1, p2, p3, p4 = symbols('p1 p2 p3 p4', above_fermi=True, cls=Dummy) + p5, p6, p7, p8 = symbols('p5 p6 p7 p8', above_fermi=True, cls=Dummy) + h1, h2, h3, h4 = symbols('h1 h2 h3 h4', below_fermi=True, cls=Dummy) + h5, h6, h7, h8 = symbols('h5 h6 h7 h8', below_fermi=True, cls=Dummy) + + A = Function('A') + B = Function('B') + + from sympy.utilities.iterables import variations + + # A*A*A*A*B -- ordering of p5 and p4 is used to figure out the rest + template = A(p1, p2)*A(p4, p1)*A(p2, p3)*A(p3, p5)*B(p5, p4) + permutator = variations([a, b, c, d, e], 5) + base = template.subs(zip([p1, p2, p3, p4, p5], next(permutator))) + for permut in permutator: + subslist = zip([p1, p2, p3, p4, p5], permut) + expr = template.subs(subslist) + assert substitute_dummies(expr) == substitute_dummies(base) + + # A*A*A*A*A -- an arbitrary index is assigned and the rest are figured out + template = A(p1, p2)*A(p4, p1)*A(p2, p3)*A(p3, p5)*A(p5, p4) + permutator = variations([a, b, c, d, e], 5) + base = template.subs(zip([p1, p2, p3, p4, p5], next(permutator))) + for permut in permutator: + subslist = zip([p1, p2, p3, p4, p5], permut) + expr = template.subs(subslist) + assert substitute_dummies(expr) == substitute_dummies(base) + + # A*A*A -- ordering of p5 and p4 is used to figure out the rest + template = A(p1, p2, p4, p1)*A(p2, p3, p3, p5)*A(p5, p4) + permutator = variations([a, b, c, d, e], 5) + base = template.subs(zip([p1, p2, p3, p4, p5], next(permutator))) + for permut in permutator: + subslist = zip([p1, p2, p3, p4, p5], permut) + expr = template.subs(subslist) + assert substitute_dummies(expr) == substitute_dummies(base) + + +def atv(*args): + return AntiSymmetricTensor('v', args[:2], args[2:] ) + + +def att(*args): + if len(args) == 4: + return AntiSymmetricTensor('t', args[:2], args[2:] ) + elif len(args) == 2: + return AntiSymmetricTensor('t', (args[0],), (args[1],)) + + +def test_dummy_order_inner_outer_lines_VT1T1T1_AT(): + ii = symbols('i', below_fermi=True) + aa = symbols('a', above_fermi=True) + k, l = symbols('k l', below_fermi=True, cls=Dummy) + c, d = symbols('c d', above_fermi=True, cls=Dummy) + + # Coupled-Cluster T1 terms with V*T1*T1*T1 + # t^{a}_{k} t^{c}_{i} t^{d}_{l} v^{lk}_{dc} + exprs = [ + # permut v and t <=> swapping internal lines, equivalent + # irrespective of symmetries in v + atv(k, l, c, d)*att(c, ii)*att(d, l)*att(aa, k), + atv(l, k, c, d)*att(c, ii)*att(d, k)*att(aa, l), + atv(k, l, d, c)*att(d, ii)*att(c, l)*att(aa, k), + atv(l, k, d, c)*att(d, ii)*att(c, k)*att(aa, l), + ] + for permut in exprs[1:]: + assert substitute_dummies(exprs[0]) == substitute_dummies(permut) + + +def test_dummy_order_inner_outer_lines_VT1T1T1T1_AT(): + ii, jj = symbols('i j', below_fermi=True) + aa, bb = symbols('a b', above_fermi=True) + k, l = symbols('k l', below_fermi=True, cls=Dummy) + c, d = symbols('c d', above_fermi=True, cls=Dummy) + + # Coupled-Cluster T2 terms with V*T1*T1*T1*T1 + # non-equivalent substitutions (change of sign) + exprs = [ + # permut t <=> swapping external lines + atv(k, l, c, d)*att(c, ii)*att(d, jj)*att(aa, k)*att(bb, l), + atv(k, l, c, d)*att(c, jj)*att(d, ii)*att(aa, k)*att(bb, l), + atv(k, l, c, d)*att(c, ii)*att(d, jj)*att(bb, k)*att(aa, l), + ] + for permut in exprs[1:]: + assert substitute_dummies(exprs[0]) == -substitute_dummies(permut) + + # equivalent substitutions + exprs = [ + atv(k, l, c, d)*att(c, ii)*att(d, jj)*att(aa, k)*att(bb, l), + # permut t <=> swapping external lines + atv(k, l, c, d)*att(c, jj)*att(d, ii)*att(bb, k)*att(aa, l), + ] + for permut in exprs[1:]: + assert substitute_dummies(exprs[0]) == substitute_dummies(permut) + + +def test_equivalent_internal_lines_VT1T1_AT(): + i, j, k, l = symbols('i j k l', below_fermi=True, cls=Dummy) + a, b, c, d = symbols('a b c d', above_fermi=True, cls=Dummy) + + exprs = [ # permute v. Different dummy order. Not equivalent. + atv(i, j, a, b)*att(a, i)*att(b, j), + atv(j, i, a, b)*att(a, i)*att(b, j), + atv(i, j, b, a)*att(a, i)*att(b, j), + ] + for permut in exprs[1:]: + assert substitute_dummies(exprs[0]) != substitute_dummies(permut) + + exprs = [ # permute v. Different dummy order. Equivalent + atv(i, j, a, b)*att(a, i)*att(b, j), + atv(j, i, b, a)*att(a, i)*att(b, j), + ] + for permut in exprs[1:]: + assert substitute_dummies(exprs[0]) == substitute_dummies(permut) + + exprs = [ # permute t. Same dummy order, not equivalent. + atv(i, j, a, b)*att(a, i)*att(b, j), + atv(i, j, a, b)*att(b, i)*att(a, j), + ] + for permut in exprs[1:]: + assert substitute_dummies(exprs[0]) != substitute_dummies(permut) + + exprs = [ # permute v and t. Different dummy order, equivalent + atv(i, j, a, b)*att(a, i)*att(b, j), + atv(j, i, a, b)*att(a, j)*att(b, i), + atv(i, j, b, a)*att(b, i)*att(a, j), + atv(j, i, b, a)*att(b, j)*att(a, i), + ] + for permut in exprs[1:]: + assert substitute_dummies(exprs[0]) == substitute_dummies(permut) + + +def test_equivalent_internal_lines_VT2conjT2_AT(): + # this diagram requires special handling in TCE + i, j, k, l, m, n = symbols('i j k l m n', below_fermi=True, cls=Dummy) + a, b, c, d, e, f = symbols('a b c d e f', above_fermi=True, cls=Dummy) + p1, p2, p3, p4 = symbols('p1 p2 p3 p4', above_fermi=True, cls=Dummy) + h1, h2, h3, h4 = symbols('h1 h2 h3 h4', below_fermi=True, cls=Dummy) + + from sympy.utilities.iterables import variations + + # atv(abcd)att(abij)att(ijcd) + template = atv(p1, p2, p3, p4)*att(p1, p2, i, j)*att(i, j, p3, p4) + permutator = variations([a, b, c, d], 4) + base = template.subs(zip([p1, p2, p3, p4], next(permutator))) + for permut in permutator: + subslist = zip([p1, p2, p3, p4], permut) + expr = template.subs(subslist) + assert substitute_dummies(expr) == substitute_dummies(base) + template = atv(p1, p2, p3, p4)*att(p1, p2, j, i)*att(j, i, p3, p4) + permutator = variations([a, b, c, d], 4) + base = template.subs(zip([p1, p2, p3, p4], next(permutator))) + for permut in permutator: + subslist = zip([p1, p2, p3, p4], permut) + expr = template.subs(subslist) + assert substitute_dummies(expr) == substitute_dummies(base) + + # atv(abcd)att(abij)att(jicd) + template = atv(p1, p2, p3, p4)*att(p1, p2, i, j)*att(j, i, p3, p4) + permutator = variations([a, b, c, d], 4) + base = template.subs(zip([p1, p2, p3, p4], next(permutator))) + for permut in permutator: + subslist = zip([p1, p2, p3, p4], permut) + expr = template.subs(subslist) + assert substitute_dummies(expr) == substitute_dummies(base) + template = atv(p1, p2, p3, p4)*att(p1, p2, j, i)*att(i, j, p3, p4) + permutator = variations([a, b, c, d], 4) + base = template.subs(zip([p1, p2, p3, p4], next(permutator))) + for permut in permutator: + subslist = zip([p1, p2, p3, p4], permut) + expr = template.subs(subslist) + assert substitute_dummies(expr) == substitute_dummies(base) + + +def test_equivalent_internal_lines_VT2conjT2_ambiguous_order_AT(): + # These diagrams invokes _determine_ambiguous() because the + # dummies can not be ordered unambiguously by the key alone + i, j, k, l, m, n = symbols('i j k l m n', below_fermi=True, cls=Dummy) + a, b, c, d, e, f = symbols('a b c d e f', above_fermi=True, cls=Dummy) + p1, p2, p3, p4 = symbols('p1 p2 p3 p4', above_fermi=True, cls=Dummy) + h1, h2, h3, h4 = symbols('h1 h2 h3 h4', below_fermi=True, cls=Dummy) + + from sympy.utilities.iterables import variations + + # atv(abcd)att(abij)att(cdij) + template = atv(p1, p2, p3, p4)*att(p1, p2, i, j)*att(p3, p4, i, j) + permutator = variations([a, b, c, d], 4) + base = template.subs(zip([p1, p2, p3, p4], next(permutator))) + for permut in permutator: + subslist = zip([p1, p2, p3, p4], permut) + expr = template.subs(subslist) + assert substitute_dummies(expr) == substitute_dummies(base) + template = atv(p1, p2, p3, p4)*att(p1, p2, j, i)*att(p3, p4, i, j) + permutator = variations([a, b, c, d], 4) + base = template.subs(zip([p1, p2, p3, p4], next(permutator))) + for permut in permutator: + subslist = zip([p1, p2, p3, p4], permut) + expr = template.subs(subslist) + assert substitute_dummies(expr) == substitute_dummies(base) + + +def test_equivalent_internal_lines_VT2_AT(): + i, j, k, l = symbols('i j k l', below_fermi=True, cls=Dummy) + a, b, c, d = symbols('a b c d', above_fermi=True, cls=Dummy) + + exprs = [ + # permute v. Same dummy order, not equivalent. + atv(i, j, a, b)*att(a, b, i, j), + atv(j, i, a, b)*att(a, b, i, j), + atv(i, j, b, a)*att(a, b, i, j), + ] + for permut in exprs[1:]: + assert substitute_dummies(exprs[0]) != substitute_dummies(permut) + + exprs = [ + # permute t. + atv(i, j, a, b)*att(a, b, i, j), + atv(i, j, a, b)*att(b, a, i, j), + atv(i, j, a, b)*att(a, b, j, i), + ] + for permut in exprs[1:]: + assert substitute_dummies(exprs[0]) != substitute_dummies(permut) + + exprs = [ # permute v and t. Relabelling of dummies should be equivalent. + atv(i, j, a, b)*att(a, b, i, j), + atv(j, i, a, b)*att(a, b, j, i), + atv(i, j, b, a)*att(b, a, i, j), + atv(j, i, b, a)*att(b, a, j, i), + ] + for permut in exprs[1:]: + assert substitute_dummies(exprs[0]) == substitute_dummies(permut) + + +def test_internal_external_VT2T2_AT(): + ii, jj = symbols('i j', below_fermi=True) + aa, bb = symbols('a b', above_fermi=True) + k, l = symbols('k l', below_fermi=True, cls=Dummy) + c, d = symbols('c d', above_fermi=True, cls=Dummy) + + exprs = [ + atv(k, l, c, d)*att(aa, c, ii, k)*att(bb, d, jj, l), + atv(l, k, c, d)*att(aa, c, ii, l)*att(bb, d, jj, k), + atv(k, l, d, c)*att(aa, d, ii, k)*att(bb, c, jj, l), + atv(l, k, d, c)*att(aa, d, ii, l)*att(bb, c, jj, k), + ] + for permut in exprs[1:]: + assert substitute_dummies(exprs[0]) == substitute_dummies(permut) + exprs = [ + atv(k, l, c, d)*att(aa, c, ii, k)*att(d, bb, jj, l), + atv(l, k, c, d)*att(aa, c, ii, l)*att(d, bb, jj, k), + atv(k, l, d, c)*att(aa, d, ii, k)*att(c, bb, jj, l), + atv(l, k, d, c)*att(aa, d, ii, l)*att(c, bb, jj, k), + ] + for permut in exprs[1:]: + assert substitute_dummies(exprs[0]) == substitute_dummies(permut) + exprs = [ + atv(k, l, c, d)*att(c, aa, ii, k)*att(bb, d, jj, l), + atv(l, k, c, d)*att(c, aa, ii, l)*att(bb, d, jj, k), + atv(k, l, d, c)*att(d, aa, ii, k)*att(bb, c, jj, l), + atv(l, k, d, c)*att(d, aa, ii, l)*att(bb, c, jj, k), + ] + for permut in exprs[1:]: + assert substitute_dummies(exprs[0]) == substitute_dummies(permut) + + +def test_internal_external_pqrs_AT(): + ii, jj = symbols('i j') + aa, bb = symbols('a b') + k, l = symbols('k l', cls=Dummy) + c, d = symbols('c d', cls=Dummy) + + exprs = [ + atv(k, l, c, d)*att(aa, c, ii, k)*att(bb, d, jj, l), + atv(l, k, c, d)*att(aa, c, ii, l)*att(bb, d, jj, k), + atv(k, l, d, c)*att(aa, d, ii, k)*att(bb, c, jj, l), + atv(l, k, d, c)*att(aa, d, ii, l)*att(bb, c, jj, k), + ] + for permut in exprs[1:]: + assert substitute_dummies(exprs[0]) == substitute_dummies(permut) + + +def test_issue_19661(): + a = Symbol('0') + assert latex(Commutator(Bd(a)**2, B(a)) + ) == '- \\left[b_{0},{b^\\dagger_{0}}^{2}\\right]' + + +def test_canonical_ordering_AntiSymmetricTensor(): + v = symbols("v") + + c, d = symbols(('c','d'), above_fermi=True, + cls=Dummy) + k, l = symbols(('k','l'), below_fermi=True, + cls=Dummy) + + # formerly, the left gave either the left or the right + assert AntiSymmetricTensor(v, (k, l), (d, c) + ) == -AntiSymmetricTensor(v, (l, k), (d, c)) diff --git a/lib/python3.10/site-packages/sympy/physics/tests/test_sho.py b/lib/python3.10/site-packages/sympy/physics/tests/test_sho.py new file mode 100644 index 0000000000000000000000000000000000000000..7248838b4bb9ad280fd4211bbe208063b65adcf5 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/physics/tests/test_sho.py @@ -0,0 +1,21 @@ +from sympy.core import symbols, Rational, Function, diff +from sympy.physics.sho import R_nl, E_nl +from sympy.simplify.simplify import simplify + + +def test_sho_R_nl(): + omega, r = symbols('omega r') + l = symbols('l', integer=True) + u = Function('u') + + # check that it obeys the Schrodinger equation + for n in range(5): + schreq = ( -diff(u(r), r, 2)/2 + ((l*(l + 1))/(2*r**2) + + omega**2*r**2/2 - E_nl(n, l, omega))*u(r) ) + result = schreq.subs(u(r), r*R_nl(n, l, omega/2, r)) + assert simplify(result.doit()) == 0 + + +def test_energy(): + n, l, hw = symbols('n l hw') + assert simplify(E_nl(n, l, hw) - (2*n + l + Rational(3, 2))*hw) == 0 diff --git a/lib/python3.10/site-packages/sympy/physics/units/__init__.py b/lib/python3.10/site-packages/sympy/physics/units/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..bf17c7f3051b03d9c0fc794d9d79885c94cc878e --- /dev/null +++ b/lib/python3.10/site-packages/sympy/physics/units/__init__.py @@ -0,0 +1,453 @@ +# isort:skip_file +""" +Dimensional analysis and unit systems. + +This module defines dimension/unit systems and physical quantities. It is +based on a group-theoretical construction where dimensions are represented as +vectors (coefficients being the exponents), and units are defined as a dimension +to which we added a scale. + +Quantities are built from a factor and a unit, and are the basic objects that +one will use when doing computations. + +All objects except systems and prefixes can be used in SymPy expressions. +Note that as part of a CAS, various objects do not combine automatically +under operations. + +Details about the implementation can be found in the documentation, and we +will not repeat all the explanations we gave there concerning our approach. +Ideas about future developments can be found on the `Github wiki +`_, and you should consult +this page if you are willing to help. + +Useful functions: + +- ``find_unit``: easily lookup pre-defined units. +- ``convert_to(expr, newunit)``: converts an expression into the same + expression expressed in another unit. + +""" + +from .dimensions import Dimension, DimensionSystem +from .unitsystem import UnitSystem +from .util import convert_to +from .quantities import Quantity + +from .definitions.dimension_definitions import ( + amount_of_substance, acceleration, action, area, + capacitance, charge, conductance, current, energy, + force, frequency, impedance, inductance, length, + luminous_intensity, magnetic_density, + magnetic_flux, mass, momentum, power, pressure, temperature, time, + velocity, voltage, volume +) + +Unit = Quantity + +speed = velocity +luminosity = luminous_intensity +magnetic_flux_density = magnetic_density +amount = amount_of_substance + +from .prefixes import ( + # 10-power based: + yotta, + zetta, + exa, + peta, + tera, + giga, + mega, + kilo, + hecto, + deca, + deci, + centi, + milli, + micro, + nano, + pico, + femto, + atto, + zepto, + yocto, + # 2-power based: + kibi, + mebi, + gibi, + tebi, + pebi, + exbi, +) + +from .definitions import ( + percent, percents, + permille, + rad, radian, radians, + deg, degree, degrees, + sr, steradian, steradians, + mil, angular_mil, angular_mils, + m, meter, meters, + kg, kilogram, kilograms, + s, second, seconds, + A, ampere, amperes, + K, kelvin, kelvins, + mol, mole, moles, + cd, candela, candelas, + g, gram, grams, + mg, milligram, milligrams, + ug, microgram, micrograms, + t, tonne, metric_ton, + newton, newtons, N, + joule, joules, J, + watt, watts, W, + pascal, pascals, Pa, pa, + hertz, hz, Hz, + coulomb, coulombs, C, + volt, volts, v, V, + ohm, ohms, + siemens, S, mho, mhos, + farad, farads, F, + henry, henrys, H, + tesla, teslas, T, + weber, webers, Wb, wb, + optical_power, dioptre, D, + lux, lx, + katal, kat, + gray, Gy, + becquerel, Bq, + km, kilometer, kilometers, + dm, decimeter, decimeters, + cm, centimeter, centimeters, + mm, millimeter, millimeters, + um, micrometer, micrometers, micron, microns, + nm, nanometer, nanometers, + pm, picometer, picometers, + ft, foot, feet, + inch, inches, + yd, yard, yards, + mi, mile, miles, + nmi, nautical_mile, nautical_miles, + angstrom, angstroms, + ha, hectare, + l, L, liter, liters, + dl, dL, deciliter, deciliters, + cl, cL, centiliter, centiliters, + ml, mL, milliliter, milliliters, + ms, millisecond, milliseconds, + us, microsecond, microseconds, + ns, nanosecond, nanoseconds, + ps, picosecond, picoseconds, + minute, minutes, + h, hour, hours, + day, days, + anomalistic_year, anomalistic_years, + sidereal_year, sidereal_years, + tropical_year, tropical_years, + common_year, common_years, + julian_year, julian_years, + draconic_year, draconic_years, + gaussian_year, gaussian_years, + full_moon_cycle, full_moon_cycles, + year, years, + G, gravitational_constant, + c, speed_of_light, + elementary_charge, + hbar, + planck, + eV, electronvolt, electronvolts, + avogadro_number, + avogadro, avogadro_constant, + boltzmann, boltzmann_constant, + stefan, stefan_boltzmann_constant, + R, molar_gas_constant, + faraday_constant, + josephson_constant, + von_klitzing_constant, + Da, dalton, amu, amus, atomic_mass_unit, atomic_mass_constant, + me, electron_rest_mass, + gee, gees, acceleration_due_to_gravity, + u0, magnetic_constant, vacuum_permeability, + e0, electric_constant, vacuum_permittivity, + Z0, vacuum_impedance, + coulomb_constant, electric_force_constant, + atmosphere, atmospheres, atm, + kPa, + bar, bars, + pound, pounds, + psi, + dHg0, + mmHg, torr, + mmu, mmus, milli_mass_unit, + quart, quarts, + ly, lightyear, lightyears, + au, astronomical_unit, astronomical_units, + planck_mass, + planck_time, + planck_temperature, + planck_length, + planck_charge, + planck_area, + planck_volume, + planck_momentum, + planck_energy, + planck_force, + planck_power, + planck_density, + planck_energy_density, + planck_intensity, + planck_angular_frequency, + planck_pressure, + planck_current, + planck_voltage, + planck_impedance, + planck_acceleration, + bit, bits, + byte, + kibibyte, kibibytes, + mebibyte, mebibytes, + gibibyte, gibibytes, + tebibyte, tebibytes, + pebibyte, pebibytes, + exbibyte, exbibytes, +) + +from .systems import ( + mks, mksa, si +) + + +def find_unit(quantity, unit_system="SI"): + """ + Return a list of matching units or dimension names. + + - If ``quantity`` is a string -- units/dimensions containing the string + `quantity`. + - If ``quantity`` is a unit or dimension -- units having matching base + units or dimensions. + + Examples + ======== + + >>> from sympy.physics import units as u + >>> u.find_unit('charge') + ['C', 'coulomb', 'coulombs', 'planck_charge', 'elementary_charge'] + >>> u.find_unit(u.charge) + ['C', 'coulomb', 'coulombs', 'planck_charge', 'elementary_charge'] + >>> u.find_unit("ampere") + ['ampere', 'amperes'] + >>> u.find_unit('angstrom') + ['angstrom', 'angstroms'] + >>> u.find_unit('volt') + ['volt', 'volts', 'electronvolt', 'electronvolts', 'planck_voltage'] + >>> u.find_unit(u.inch**3)[:9] + ['L', 'l', 'cL', 'cl', 'dL', 'dl', 'mL', 'ml', 'liter'] + """ + unit_system = UnitSystem.get_unit_system(unit_system) + + import sympy.physics.units as u + rv = [] + if isinstance(quantity, str): + rv = [i for i in dir(u) if quantity in i and isinstance(getattr(u, i), Quantity)] + dim = getattr(u, quantity) + if isinstance(dim, Dimension): + rv.extend(find_unit(dim)) + else: + for i in sorted(dir(u)): + other = getattr(u, i) + if not isinstance(other, Quantity): + continue + if isinstance(quantity, Quantity): + if quantity.dimension == other.dimension: + rv.append(str(i)) + elif isinstance(quantity, Dimension): + if other.dimension == quantity: + rv.append(str(i)) + elif other.dimension == Dimension(unit_system.get_dimensional_expr(quantity)): + rv.append(str(i)) + return sorted(set(rv), key=lambda x: (len(x), x)) + +# NOTE: the old units module had additional variables: +# 'density', 'illuminance', 'resistance'. +# They were not dimensions, but units (old Unit class). + +__all__ = [ + 'Dimension', 'DimensionSystem', + 'UnitSystem', + 'convert_to', + 'Quantity', + + 'amount_of_substance', 'acceleration', 'action', 'area', + 'capacitance', 'charge', 'conductance', 'current', 'energy', + 'force', 'frequency', 'impedance', 'inductance', 'length', + 'luminous_intensity', 'magnetic_density', + 'magnetic_flux', 'mass', 'momentum', 'power', 'pressure', 'temperature', 'time', + 'velocity', 'voltage', 'volume', + + 'Unit', + + 'speed', + 'luminosity', + 'magnetic_flux_density', + 'amount', + + 'yotta', + 'zetta', + 'exa', + 'peta', + 'tera', + 'giga', + 'mega', + 'kilo', + 'hecto', + 'deca', + 'deci', + 'centi', + 'milli', + 'micro', + 'nano', + 'pico', + 'femto', + 'atto', + 'zepto', + 'yocto', + + 'kibi', + 'mebi', + 'gibi', + 'tebi', + 'pebi', + 'exbi', + + 'percent', 'percents', + 'permille', + 'rad', 'radian', 'radians', + 'deg', 'degree', 'degrees', + 'sr', 'steradian', 'steradians', + 'mil', 'angular_mil', 'angular_mils', + 'm', 'meter', 'meters', + 'kg', 'kilogram', 'kilograms', + 's', 'second', 'seconds', + 'A', 'ampere', 'amperes', + 'K', 'kelvin', 'kelvins', + 'mol', 'mole', 'moles', + 'cd', 'candela', 'candelas', + 'g', 'gram', 'grams', + 'mg', 'milligram', 'milligrams', + 'ug', 'microgram', 'micrograms', + 't', 'tonne', 'metric_ton', + 'newton', 'newtons', 'N', + 'joule', 'joules', 'J', + 'watt', 'watts', 'W', + 'pascal', 'pascals', 'Pa', 'pa', + 'hertz', 'hz', 'Hz', + 'coulomb', 'coulombs', 'C', + 'volt', 'volts', 'v', 'V', + 'ohm', 'ohms', + 'siemens', 'S', 'mho', 'mhos', + 'farad', 'farads', 'F', + 'henry', 'henrys', 'H', + 'tesla', 'teslas', 'T', + 'weber', 'webers', 'Wb', 'wb', + 'optical_power', 'dioptre', 'D', + 'lux', 'lx', + 'katal', 'kat', + 'gray', 'Gy', + 'becquerel', 'Bq', + 'km', 'kilometer', 'kilometers', + 'dm', 'decimeter', 'decimeters', + 'cm', 'centimeter', 'centimeters', + 'mm', 'millimeter', 'millimeters', + 'um', 'micrometer', 'micrometers', 'micron', 'microns', + 'nm', 'nanometer', 'nanometers', + 'pm', 'picometer', 'picometers', + 'ft', 'foot', 'feet', + 'inch', 'inches', + 'yd', 'yard', 'yards', + 'mi', 'mile', 'miles', + 'nmi', 'nautical_mile', 'nautical_miles', + 'angstrom', 'angstroms', + 'ha', 'hectare', + 'l', 'L', 'liter', 'liters', + 'dl', 'dL', 'deciliter', 'deciliters', + 'cl', 'cL', 'centiliter', 'centiliters', + 'ml', 'mL', 'milliliter', 'milliliters', + 'ms', 'millisecond', 'milliseconds', + 'us', 'microsecond', 'microseconds', + 'ns', 'nanosecond', 'nanoseconds', + 'ps', 'picosecond', 'picoseconds', + 'minute', 'minutes', + 'h', 'hour', 'hours', + 'day', 'days', + 'anomalistic_year', 'anomalistic_years', + 'sidereal_year', 'sidereal_years', + 'tropical_year', 'tropical_years', + 'common_year', 'common_years', + 'julian_year', 'julian_years', + 'draconic_year', 'draconic_years', + 'gaussian_year', 'gaussian_years', + 'full_moon_cycle', 'full_moon_cycles', + 'year', 'years', + 'G', 'gravitational_constant', + 'c', 'speed_of_light', + 'elementary_charge', + 'hbar', + 'planck', + 'eV', 'electronvolt', 'electronvolts', + 'avogadro_number', + 'avogadro', 'avogadro_constant', + 'boltzmann', 'boltzmann_constant', + 'stefan', 'stefan_boltzmann_constant', + 'R', 'molar_gas_constant', + 'faraday_constant', + 'josephson_constant', + 'von_klitzing_constant', + 'Da', 'dalton', 'amu', 'amus', 'atomic_mass_unit', 'atomic_mass_constant', + 'me', 'electron_rest_mass', + 'gee', 'gees', 'acceleration_due_to_gravity', + 'u0', 'magnetic_constant', 'vacuum_permeability', + 'e0', 'electric_constant', 'vacuum_permittivity', + 'Z0', 'vacuum_impedance', + 'coulomb_constant', 'electric_force_constant', + 'atmosphere', 'atmospheres', 'atm', + 'kPa', + 'bar', 'bars', + 'pound', 'pounds', + 'psi', + 'dHg0', + 'mmHg', 'torr', + 'mmu', 'mmus', 'milli_mass_unit', + 'quart', 'quarts', + 'ly', 'lightyear', 'lightyears', + 'au', 'astronomical_unit', 'astronomical_units', + 'planck_mass', + 'planck_time', + 'planck_temperature', + 'planck_length', + 'planck_charge', + 'planck_area', + 'planck_volume', + 'planck_momentum', + 'planck_energy', + 'planck_force', + 'planck_power', + 'planck_density', + 'planck_energy_density', + 'planck_intensity', + 'planck_angular_frequency', + 'planck_pressure', + 'planck_current', + 'planck_voltage', + 'planck_impedance', + 'planck_acceleration', + 'bit', 'bits', + 'byte', + 'kibibyte', 'kibibytes', + 'mebibyte', 'mebibytes', + 'gibibyte', 'gibibytes', + 'tebibyte', 'tebibytes', + 'pebibyte', 'pebibytes', + 'exbibyte', 'exbibytes', + + 'mks', 'mksa', 'si', +] diff --git a/lib/python3.10/site-packages/sympy/physics/units/dimensions.py b/lib/python3.10/site-packages/sympy/physics/units/dimensions.py new file mode 100644 index 0000000000000000000000000000000000000000..951f8ab17a58606e74f1c14ffa312fedbeebb4b6 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/physics/units/dimensions.py @@ -0,0 +1,590 @@ +""" +Definition of physical dimensions. + +Unit systems will be constructed on top of these dimensions. + +Most of the examples in the doc use MKS system and are presented from the +computer point of view: from a human point, adding length to time is not legal +in MKS but it is in natural system; for a computer in natural system there is +no time dimension (but a velocity dimension instead) - in the basis - so the +question of adding time to length has no meaning. +""" + +from __future__ import annotations + +import collections +from functools import reduce + +from sympy.core.basic import Basic +from sympy.core.containers import (Dict, Tuple) +from sympy.core.singleton import S +from sympy.core.sorting import default_sort_key +from sympy.core.symbol import Symbol +from sympy.core.sympify import sympify +from sympy.matrices.dense import Matrix +from sympy.functions.elementary.trigonometric import TrigonometricFunction +from sympy.core.expr import Expr +from sympy.core.power import Pow + + +class _QuantityMapper: + + _quantity_scale_factors_global: dict[Expr, Expr] = {} + _quantity_dimensional_equivalence_map_global: dict[Expr, Expr] = {} + _quantity_dimension_global: dict[Expr, Expr] = {} + + def __init__(self, *args, **kwargs): + self._quantity_dimension_map = {} + self._quantity_scale_factors = {} + + def set_quantity_dimension(self, quantity, dimension): + """ + Set the dimension for the quantity in a unit system. + + If this relation is valid in every unit system, use + ``quantity.set_global_dimension(dimension)`` instead. + """ + from sympy.physics.units import Quantity + dimension = sympify(dimension) + if not isinstance(dimension, Dimension): + if dimension == 1: + dimension = Dimension(1) + else: + raise ValueError("expected dimension or 1") + elif isinstance(dimension, Quantity): + dimension = self.get_quantity_dimension(dimension) + self._quantity_dimension_map[quantity] = dimension + + def set_quantity_scale_factor(self, quantity, scale_factor): + """ + Set the scale factor of a quantity relative to another quantity. + + It should be used only once per quantity to just one other quantity, + the algorithm will then be able to compute the scale factors to all + other quantities. + + In case the scale factor is valid in every unit system, please use + ``quantity.set_global_relative_scale_factor(scale_factor)`` instead. + """ + from sympy.physics.units import Quantity + from sympy.physics.units.prefixes import Prefix + scale_factor = sympify(scale_factor) + # replace all prefixes by their ratio to canonical units: + scale_factor = scale_factor.replace( + lambda x: isinstance(x, Prefix), + lambda x: x.scale_factor + ) + # replace all quantities by their ratio to canonical units: + scale_factor = scale_factor.replace( + lambda x: isinstance(x, Quantity), + lambda x: self.get_quantity_scale_factor(x) + ) + self._quantity_scale_factors[quantity] = scale_factor + + def get_quantity_dimension(self, unit): + from sympy.physics.units import Quantity + # First look-up the local dimension map, then the global one: + if unit in self._quantity_dimension_map: + return self._quantity_dimension_map[unit] + if unit in self._quantity_dimension_global: + return self._quantity_dimension_global[unit] + if unit in self._quantity_dimensional_equivalence_map_global: + dep_unit = self._quantity_dimensional_equivalence_map_global[unit] + if isinstance(dep_unit, Quantity): + return self.get_quantity_dimension(dep_unit) + else: + return Dimension(self.get_dimensional_expr(dep_unit)) + if isinstance(unit, Quantity): + return Dimension(unit.name) + else: + return Dimension(1) + + def get_quantity_scale_factor(self, unit): + if unit in self._quantity_scale_factors: + return self._quantity_scale_factors[unit] + if unit in self._quantity_scale_factors_global: + mul_factor, other_unit = self._quantity_scale_factors_global[unit] + return mul_factor*self.get_quantity_scale_factor(other_unit) + return S.One + + +class Dimension(Expr): + """ + This class represent the dimension of a physical quantities. + + The ``Dimension`` constructor takes as parameters a name and an optional + symbol. + + For example, in classical mechanics we know that time is different from + temperature and dimensions make this difference (but they do not provide + any measure of these quantites. + + >>> from sympy.physics.units import Dimension + >>> length = Dimension('length') + >>> length + Dimension(length) + >>> time = Dimension('time') + >>> time + Dimension(time) + + Dimensions can be composed using multiplication, division and + exponentiation (by a number) to give new dimensions. Addition and + subtraction is defined only when the two objects are the same dimension. + + >>> velocity = length / time + >>> velocity + Dimension(length/time) + + It is possible to use a dimension system object to get the dimensionsal + dependencies of a dimension, for example the dimension system used by the + SI units convention can be used: + + >>> from sympy.physics.units.systems.si import dimsys_SI + >>> dimsys_SI.get_dimensional_dependencies(velocity) + {Dimension(length, L): 1, Dimension(time, T): -1} + >>> length + length + Dimension(length) + >>> l2 = length**2 + >>> l2 + Dimension(length**2) + >>> dimsys_SI.get_dimensional_dependencies(l2) + {Dimension(length, L): 2} + + """ + + _op_priority = 13.0 + + # XXX: This doesn't seem to be used anywhere... + _dimensional_dependencies = {} # type: ignore + + is_commutative = True + is_number = False + # make sqrt(M**2) --> M + is_positive = True + is_real = True + + def __new__(cls, name, symbol=None): + + if isinstance(name, str): + name = Symbol(name) + else: + name = sympify(name) + + if not isinstance(name, Expr): + raise TypeError("Dimension name needs to be a valid math expression") + + if isinstance(symbol, str): + symbol = Symbol(symbol) + elif symbol is not None: + assert isinstance(symbol, Symbol) + + obj = Expr.__new__(cls, name) + + obj._name = name + obj._symbol = symbol + return obj + + @property + def name(self): + return self._name + + @property + def symbol(self): + return self._symbol + + def __str__(self): + """ + Display the string representation of the dimension. + """ + if self.symbol is None: + return "Dimension(%s)" % (self.name) + else: + return "Dimension(%s, %s)" % (self.name, self.symbol) + + def __repr__(self): + return self.__str__() + + def __neg__(self): + return self + + def __add__(self, other): + from sympy.physics.units.quantities import Quantity + other = sympify(other) + if isinstance(other, Basic): + if other.has(Quantity): + raise TypeError("cannot sum dimension and quantity") + if isinstance(other, Dimension) and self == other: + return self + return super().__add__(other) + return self + + def __radd__(self, other): + return self.__add__(other) + + def __sub__(self, other): + # there is no notion of ordering (or magnitude) among dimension, + # subtraction is equivalent to addition when the operation is legal + return self + other + + def __rsub__(self, other): + # there is no notion of ordering (or magnitude) among dimension, + # subtraction is equivalent to addition when the operation is legal + return self + other + + def __pow__(self, other): + return self._eval_power(other) + + def _eval_power(self, other): + other = sympify(other) + return Dimension(self.name**other) + + def __mul__(self, other): + from sympy.physics.units.quantities import Quantity + if isinstance(other, Basic): + if other.has(Quantity): + raise TypeError("cannot sum dimension and quantity") + if isinstance(other, Dimension): + return Dimension(self.name*other.name) + if not other.free_symbols: # other.is_number cannot be used + return self + return super().__mul__(other) + return self + + def __rmul__(self, other): + return self.__mul__(other) + + def __truediv__(self, other): + return self*Pow(other, -1) + + def __rtruediv__(self, other): + return other * pow(self, -1) + + @classmethod + def _from_dimensional_dependencies(cls, dependencies): + return reduce(lambda x, y: x * y, ( + d**e for d, e in dependencies.items() + ), 1) + + def has_integer_powers(self, dim_sys): + """ + Check if the dimension object has only integer powers. + + All the dimension powers should be integers, but rational powers may + appear in intermediate steps. This method may be used to check that the + final result is well-defined. + """ + + return all(dpow.is_Integer for dpow in dim_sys.get_dimensional_dependencies(self).values()) + + +# Create dimensions according to the base units in MKSA. +# For other unit systems, they can be derived by transforming the base +# dimensional dependency dictionary. + + +class DimensionSystem(Basic, _QuantityMapper): + r""" + DimensionSystem represents a coherent set of dimensions. + + The constructor takes three parameters: + + - base dimensions; + - derived dimensions: these are defined in terms of the base dimensions + (for example velocity is defined from the division of length by time); + - dependency of dimensions: how the derived dimensions depend + on the base dimensions. + + Optionally either the ``derived_dims`` or the ``dimensional_dependencies`` + may be omitted. + """ + + def __new__(cls, base_dims, derived_dims=(), dimensional_dependencies={}): + dimensional_dependencies = dict(dimensional_dependencies) + + def parse_dim(dim): + if isinstance(dim, str): + dim = Dimension(Symbol(dim)) + elif isinstance(dim, Dimension): + pass + elif isinstance(dim, Symbol): + dim = Dimension(dim) + else: + raise TypeError("%s wrong type" % dim) + return dim + + base_dims = [parse_dim(i) for i in base_dims] + derived_dims = [parse_dim(i) for i in derived_dims] + + for dim in base_dims: + if (dim in dimensional_dependencies + and (len(dimensional_dependencies[dim]) != 1 or + dimensional_dependencies[dim].get(dim, None) != 1)): + raise IndexError("Repeated value in base dimensions") + dimensional_dependencies[dim] = Dict({dim: 1}) + + def parse_dim_name(dim): + if isinstance(dim, Dimension): + return dim + elif isinstance(dim, str): + return Dimension(Symbol(dim)) + elif isinstance(dim, Symbol): + return Dimension(dim) + else: + raise TypeError("unrecognized type %s for %s" % (type(dim), dim)) + + for dim in dimensional_dependencies.keys(): + dim = parse_dim(dim) + if (dim not in derived_dims) and (dim not in base_dims): + derived_dims.append(dim) + + def parse_dict(d): + return Dict({parse_dim_name(i): j for i, j in d.items()}) + + # Make sure everything is a SymPy type: + dimensional_dependencies = {parse_dim_name(i): parse_dict(j) for i, j in + dimensional_dependencies.items()} + + for dim in derived_dims: + if dim in base_dims: + raise ValueError("Dimension %s both in base and derived" % dim) + if dim not in dimensional_dependencies: + # TODO: should this raise a warning? + dimensional_dependencies[dim] = Dict({dim: 1}) + + base_dims.sort(key=default_sort_key) + derived_dims.sort(key=default_sort_key) + + base_dims = Tuple(*base_dims) + derived_dims = Tuple(*derived_dims) + dimensional_dependencies = Dict({i: Dict(j) for i, j in dimensional_dependencies.items()}) + obj = Basic.__new__(cls, base_dims, derived_dims, dimensional_dependencies) + return obj + + @property + def base_dims(self): + return self.args[0] + + @property + def derived_dims(self): + return self.args[1] + + @property + def dimensional_dependencies(self): + return self.args[2] + + def _get_dimensional_dependencies_for_name(self, dimension): + if isinstance(dimension, str): + dimension = Dimension(Symbol(dimension)) + elif not isinstance(dimension, Dimension): + dimension = Dimension(dimension) + + if dimension.name.is_Symbol: + # Dimensions not included in the dependencies are considered + # as base dimensions: + return dict(self.dimensional_dependencies.get(dimension, {dimension: 1})) + + if dimension.name.is_number or dimension.name.is_NumberSymbol: + return {} + + get_for_name = self._get_dimensional_dependencies_for_name + + if dimension.name.is_Mul: + ret = collections.defaultdict(int) + dicts = [get_for_name(i) for i in dimension.name.args] + for d in dicts: + for k, v in d.items(): + ret[k] += v + return {k: v for (k, v) in ret.items() if v != 0} + + if dimension.name.is_Add: + dicts = [get_for_name(i) for i in dimension.name.args] + if all(d == dicts[0] for d in dicts[1:]): + return dicts[0] + raise TypeError("Only equivalent dimensions can be added or subtracted.") + + if dimension.name.is_Pow: + dim_base = get_for_name(dimension.name.base) + dim_exp = get_for_name(dimension.name.exp) + if dim_exp == {} or dimension.name.exp.is_Symbol: + return {k: v * dimension.name.exp for (k, v) in dim_base.items()} + else: + raise TypeError("The exponent for the power operator must be a Symbol or dimensionless.") + + if dimension.name.is_Function: + args = (Dimension._from_dimensional_dependencies( + get_for_name(arg)) for arg in dimension.name.args) + result = dimension.name.func(*args) + + dicts = [get_for_name(i) for i in dimension.name.args] + + if isinstance(result, Dimension): + return self.get_dimensional_dependencies(result) + elif result.func == dimension.name.func: + if isinstance(dimension.name, TrigonometricFunction): + if dicts[0] in ({}, {Dimension('angle'): 1}): + return {} + else: + raise TypeError("The input argument for the function {} must be dimensionless or have dimensions of angle.".format(dimension.func)) + else: + if all(item == {} for item in dicts): + return {} + else: + raise TypeError("The input arguments for the function {} must be dimensionless.".format(dimension.func)) + else: + return get_for_name(result) + + raise TypeError("Type {} not implemented for get_dimensional_dependencies".format(type(dimension.name))) + + def get_dimensional_dependencies(self, name, mark_dimensionless=False): + dimdep = self._get_dimensional_dependencies_for_name(name) + if mark_dimensionless and dimdep == {}: + return {Dimension(1): 1} + return dict(dimdep.items()) + + def equivalent_dims(self, dim1, dim2): + deps1 = self.get_dimensional_dependencies(dim1) + deps2 = self.get_dimensional_dependencies(dim2) + return deps1 == deps2 + + def extend(self, new_base_dims, new_derived_dims=(), new_dim_deps=None): + deps = dict(self.dimensional_dependencies) + if new_dim_deps: + deps.update(new_dim_deps) + + new_dim_sys = DimensionSystem( + tuple(self.base_dims) + tuple(new_base_dims), + tuple(self.derived_dims) + tuple(new_derived_dims), + deps + ) + new_dim_sys._quantity_dimension_map.update(self._quantity_dimension_map) + new_dim_sys._quantity_scale_factors.update(self._quantity_scale_factors) + return new_dim_sys + + def is_dimensionless(self, dimension): + """ + Check if the dimension object really has a dimension. + + A dimension should have at least one component with non-zero power. + """ + if dimension.name == 1: + return True + return self.get_dimensional_dependencies(dimension) == {} + + @property + def list_can_dims(self): + """ + Useless method, kept for compatibility with previous versions. + + DO NOT USE. + + List all canonical dimension names. + """ + dimset = set() + for i in self.base_dims: + dimset.update(set(self.get_dimensional_dependencies(i).keys())) + return tuple(sorted(dimset, key=str)) + + @property + def inv_can_transf_matrix(self): + """ + Useless method, kept for compatibility with previous versions. + + DO NOT USE. + + Compute the inverse transformation matrix from the base to the + canonical dimension basis. + + It corresponds to the matrix where columns are the vector of base + dimensions in canonical basis. + + This matrix will almost never be used because dimensions are always + defined with respect to the canonical basis, so no work has to be done + to get them in this basis. Nonetheless if this matrix is not square + (or not invertible) it means that we have chosen a bad basis. + """ + matrix = reduce(lambda x, y: x.row_join(y), + [self.dim_can_vector(d) for d in self.base_dims]) + return matrix + + @property + def can_transf_matrix(self): + """ + Useless method, kept for compatibility with previous versions. + + DO NOT USE. + + Return the canonical transformation matrix from the canonical to the + base dimension basis. + + It is the inverse of the matrix computed with inv_can_transf_matrix(). + """ + + #TODO: the inversion will fail if the system is inconsistent, for + # example if the matrix is not a square + return reduce(lambda x, y: x.row_join(y), + [self.dim_can_vector(d) for d in sorted(self.base_dims, key=str)] + ).inv() + + def dim_can_vector(self, dim): + """ + Useless method, kept for compatibility with previous versions. + + DO NOT USE. + + Dimensional representation in terms of the canonical base dimensions. + """ + + vec = [] + for d in self.list_can_dims: + vec.append(self.get_dimensional_dependencies(dim).get(d, 0)) + return Matrix(vec) + + def dim_vector(self, dim): + """ + Useless method, kept for compatibility with previous versions. + + DO NOT USE. + + + Vector representation in terms of the base dimensions. + """ + return self.can_transf_matrix * Matrix(self.dim_can_vector(dim)) + + def print_dim_base(self, dim): + """ + Give the string expression of a dimension in term of the basis symbols. + """ + dims = self.dim_vector(dim) + symbols = [i.symbol if i.symbol is not None else i.name for i in self.base_dims] + res = S.One + for (s, p) in zip(symbols, dims): + res *= s**p + return res + + @property + def dim(self): + """ + Useless method, kept for compatibility with previous versions. + + DO NOT USE. + + Give the dimension of the system. + + That is return the number of dimensions forming the basis. + """ + return len(self.base_dims) + + @property + def is_consistent(self): + """ + Useless method, kept for compatibility with previous versions. + + DO NOT USE. + + Check if the system is well defined. + """ + + # not enough or too many base dimensions compared to independent + # dimensions + # in vector language: the set of vectors do not form a basis + return self.inv_can_transf_matrix.is_square diff --git a/lib/python3.10/site-packages/sympy/physics/units/prefixes.py b/lib/python3.10/site-packages/sympy/physics/units/prefixes.py new file mode 100644 index 0000000000000000000000000000000000000000..44fd7cb9efe4b1d6307810af6b9cd140817126f9 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/physics/units/prefixes.py @@ -0,0 +1,219 @@ +""" +Module defining unit prefixe class and some constants. + +Constant dict for SI and binary prefixes are defined as PREFIXES and +BIN_PREFIXES. +""" +from sympy.core.expr import Expr +from sympy.core.sympify import sympify +from sympy.core.singleton import S + +class Prefix(Expr): + """ + This class represent prefixes, with their name, symbol and factor. + + Prefixes are used to create derived units from a given unit. They should + always be encapsulated into units. + + The factor is constructed from a base (default is 10) to some power, and + it gives the total multiple or fraction. For example the kilometer km + is constructed from the meter (factor 1) and the kilo (10 to the power 3, + i.e. 1000). The base can be changed to allow e.g. binary prefixes. + + A prefix multiplied by something will always return the product of this + other object times the factor, except if the other object: + + - is a prefix and they can be combined into a new prefix; + - defines multiplication with prefixes (which is the case for the Unit + class). + """ + _op_priority = 13.0 + is_commutative = True + + def __new__(cls, name, abbrev, exponent, base=sympify(10), latex_repr=None): + + name = sympify(name) + abbrev = sympify(abbrev) + exponent = sympify(exponent) + base = sympify(base) + + obj = Expr.__new__(cls, name, abbrev, exponent, base) + obj._name = name + obj._abbrev = abbrev + obj._scale_factor = base**exponent + obj._exponent = exponent + obj._base = base + obj._latex_repr = latex_repr + return obj + + @property + def name(self): + return self._name + + @property + def abbrev(self): + return self._abbrev + + @property + def scale_factor(self): + return self._scale_factor + + def _latex(self, printer): + if self._latex_repr is None: + return r'\text{%s}' % self._abbrev + return self._latex_repr + + @property + def base(self): + return self._base + + def __str__(self): + return str(self._abbrev) + + def __repr__(self): + if self.base == 10: + return "Prefix(%r, %r, %r)" % ( + str(self.name), str(self.abbrev), self._exponent) + else: + return "Prefix(%r, %r, %r, %r)" % ( + str(self.name), str(self.abbrev), self._exponent, self.base) + + def __mul__(self, other): + from sympy.physics.units import Quantity + if not isinstance(other, (Quantity, Prefix)): + return super().__mul__(other) + + fact = self.scale_factor * other.scale_factor + + if isinstance(other, Prefix): + if fact == 1: + return S.One + # simplify prefix + for p in PREFIXES: + if PREFIXES[p].scale_factor == fact: + return PREFIXES[p] + return fact + + return self.scale_factor * other + + def __truediv__(self, other): + if not hasattr(other, "scale_factor"): + return super().__truediv__(other) + + fact = self.scale_factor / other.scale_factor + + if fact == 1: + return S.One + elif isinstance(other, Prefix): + for p in PREFIXES: + if PREFIXES[p].scale_factor == fact: + return PREFIXES[p] + return fact + + return self.scale_factor / other + + def __rtruediv__(self, other): + if other == 1: + for p in PREFIXES: + if PREFIXES[p].scale_factor == 1 / self.scale_factor: + return PREFIXES[p] + return other / self.scale_factor + + +def prefix_unit(unit, prefixes): + """ + Return a list of all units formed by unit and the given prefixes. + + You can use the predefined PREFIXES or BIN_PREFIXES, but you can also + pass as argument a subdict of them if you do not want all prefixed units. + + >>> from sympy.physics.units.prefixes import (PREFIXES, + ... prefix_unit) + >>> from sympy.physics.units import m + >>> pref = {"m": PREFIXES["m"], "c": PREFIXES["c"], "d": PREFIXES["d"]} + >>> prefix_unit(m, pref) # doctest: +SKIP + [millimeter, centimeter, decimeter] + """ + + from sympy.physics.units.quantities import Quantity + from sympy.physics.units import UnitSystem + + prefixed_units = [] + + for prefix in prefixes.values(): + quantity = Quantity( + "%s%s" % (prefix.name, unit.name), + abbrev=("%s%s" % (prefix.abbrev, unit.abbrev)), + is_prefixed=True, + ) + UnitSystem._quantity_dimensional_equivalence_map_global[quantity] = unit + UnitSystem._quantity_scale_factors_global[quantity] = (prefix.scale_factor, unit) + prefixed_units.append(quantity) + + return prefixed_units + + +yotta = Prefix('yotta', 'Y', 24) +zetta = Prefix('zetta', 'Z', 21) +exa = Prefix('exa', 'E', 18) +peta = Prefix('peta', 'P', 15) +tera = Prefix('tera', 'T', 12) +giga = Prefix('giga', 'G', 9) +mega = Prefix('mega', 'M', 6) +kilo = Prefix('kilo', 'k', 3) +hecto = Prefix('hecto', 'h', 2) +deca = Prefix('deca', 'da', 1) +deci = Prefix('deci', 'd', -1) +centi = Prefix('centi', 'c', -2) +milli = Prefix('milli', 'm', -3) +micro = Prefix('micro', 'mu', -6, latex_repr=r"\mu") +nano = Prefix('nano', 'n', -9) +pico = Prefix('pico', 'p', -12) +femto = Prefix('femto', 'f', -15) +atto = Prefix('atto', 'a', -18) +zepto = Prefix('zepto', 'z', -21) +yocto = Prefix('yocto', 'y', -24) + + +# https://physics.nist.gov/cuu/Units/prefixes.html +PREFIXES = { + 'Y': yotta, + 'Z': zetta, + 'E': exa, + 'P': peta, + 'T': tera, + 'G': giga, + 'M': mega, + 'k': kilo, + 'h': hecto, + 'da': deca, + 'd': deci, + 'c': centi, + 'm': milli, + 'mu': micro, + 'n': nano, + 'p': pico, + 'f': femto, + 'a': atto, + 'z': zepto, + 'y': yocto, +} + + +kibi = Prefix('kibi', 'Y', 10, 2) +mebi = Prefix('mebi', 'Y', 20, 2) +gibi = Prefix('gibi', 'Y', 30, 2) +tebi = Prefix('tebi', 'Y', 40, 2) +pebi = Prefix('pebi', 'Y', 50, 2) +exbi = Prefix('exbi', 'Y', 60, 2) + + +# https://physics.nist.gov/cuu/Units/binary.html +BIN_PREFIXES = { + 'Ki': kibi, + 'Mi': mebi, + 'Gi': gibi, + 'Ti': tebi, + 'Pi': pebi, + 'Ei': exbi, +} diff --git a/lib/python3.10/site-packages/sympy/physics/units/quantities.py b/lib/python3.10/site-packages/sympy/physics/units/quantities.py new file mode 100644 index 0000000000000000000000000000000000000000..cc19e72aea83b5bd8ae7cf2f63dd49388a3815ee --- /dev/null +++ b/lib/python3.10/site-packages/sympy/physics/units/quantities.py @@ -0,0 +1,152 @@ +""" +Physical quantities. +""" + +from sympy.core.expr import AtomicExpr +from sympy.core.symbol import Symbol +from sympy.core.sympify import sympify +from sympy.physics.units.dimensions import _QuantityMapper +from sympy.physics.units.prefixes import Prefix + + +class Quantity(AtomicExpr): + """ + Physical quantity: can be a unit of measure, a constant or a generic quantity. + """ + + is_commutative = True + is_real = True + is_number = False + is_nonzero = True + is_physical_constant = False + _diff_wrt = True + + def __new__(cls, name, abbrev=None, + latex_repr=None, pretty_unicode_repr=None, + pretty_ascii_repr=None, mathml_presentation_repr=None, + is_prefixed=False, + **assumptions): + + if not isinstance(name, Symbol): + name = Symbol(name) + + if abbrev is None: + abbrev = name + elif isinstance(abbrev, str): + abbrev = Symbol(abbrev) + + # HACK: These are here purely for type checking. They actually get assigned below. + cls._is_prefixed = is_prefixed + + obj = AtomicExpr.__new__(cls, name, abbrev) + obj._name = name + obj._abbrev = abbrev + obj._latex_repr = latex_repr + obj._unicode_repr = pretty_unicode_repr + obj._ascii_repr = pretty_ascii_repr + obj._mathml_repr = mathml_presentation_repr + obj._is_prefixed = is_prefixed + return obj + + def set_global_dimension(self, dimension): + _QuantityMapper._quantity_dimension_global[self] = dimension + + def set_global_relative_scale_factor(self, scale_factor, reference_quantity): + """ + Setting a scale factor that is valid across all unit system. + """ + from sympy.physics.units import UnitSystem + scale_factor = sympify(scale_factor) + if isinstance(scale_factor, Prefix): + self._is_prefixed = True + # replace all prefixes by their ratio to canonical units: + scale_factor = scale_factor.replace( + lambda x: isinstance(x, Prefix), + lambda x: x.scale_factor + ) + scale_factor = sympify(scale_factor) + UnitSystem._quantity_scale_factors_global[self] = (scale_factor, reference_quantity) + UnitSystem._quantity_dimensional_equivalence_map_global[self] = reference_quantity + + @property + def name(self): + return self._name + + @property + def dimension(self): + from sympy.physics.units import UnitSystem + unit_system = UnitSystem.get_default_unit_system() + return unit_system.get_quantity_dimension(self) + + @property + def abbrev(self): + """ + Symbol representing the unit name. + + Prepend the abbreviation with the prefix symbol if it is defines. + """ + return self._abbrev + + @property + def scale_factor(self): + """ + Overall magnitude of the quantity as compared to the canonical units. + """ + from sympy.physics.units import UnitSystem + unit_system = UnitSystem.get_default_unit_system() + return unit_system.get_quantity_scale_factor(self) + + def _eval_is_positive(self): + return True + + def _eval_is_constant(self): + return True + + def _eval_Abs(self): + return self + + def _eval_subs(self, old, new): + if isinstance(new, Quantity) and self != old: + return self + + def _latex(self, printer): + if self._latex_repr: + return self._latex_repr + else: + return r'\text{{{}}}'.format(self.args[1] \ + if len(self.args) >= 2 else self.args[0]) + + def convert_to(self, other, unit_system="SI"): + """ + Convert the quantity to another quantity of same dimensions. + + Examples + ======== + + >>> from sympy.physics.units import speed_of_light, meter, second + >>> speed_of_light + speed_of_light + >>> speed_of_light.convert_to(meter/second) + 299792458*meter/second + + >>> from sympy.physics.units import liter + >>> liter.convert_to(meter**3) + meter**3/1000 + """ + from .util import convert_to + return convert_to(self, other, unit_system) + + @property + def free_symbols(self): + """Return free symbols from quantity.""" + return set() + + @property + def is_prefixed(self): + """Whether or not the quantity is prefixed. Eg. `kilogram` is prefixed, but `gram` is not.""" + return self._is_prefixed + +class PhysicalConstant(Quantity): + """Represents a physical constant, eg. `speed_of_light` or `avogadro_constant`.""" + + is_physical_constant = True diff --git a/lib/python3.10/site-packages/sympy/physics/units/unitsystem.py b/lib/python3.10/site-packages/sympy/physics/units/unitsystem.py new file mode 100644 index 0000000000000000000000000000000000000000..5705c821c217f781717f9dd5cad6f3c9c77b145f --- /dev/null +++ b/lib/python3.10/site-packages/sympy/physics/units/unitsystem.py @@ -0,0 +1,205 @@ +""" +Unit system for physical quantities; include definition of constants. +""" + +from typing import Dict as tDict, Set as tSet + +from sympy.core.add import Add +from sympy.core.function import (Derivative, Function) +from sympy.core.mul import Mul +from sympy.core.power import Pow +from sympy.core.singleton import S +from sympy.physics.units.dimensions import _QuantityMapper +from sympy.physics.units.quantities import Quantity + +from .dimensions import Dimension + + +class UnitSystem(_QuantityMapper): + """ + UnitSystem represents a coherent set of units. + + A unit system is basically a dimension system with notions of scales. Many + of the methods are defined in the same way. + + It is much better if all base units have a symbol. + """ + + _unit_systems = {} # type: tDict[str, UnitSystem] + + def __init__(self, base_units, units=(), name="", descr="", dimension_system=None, derived_units: tDict[Dimension, Quantity]={}): + + UnitSystem._unit_systems[name] = self + + self.name = name + self.descr = descr + + self._base_units = base_units + self._dimension_system = dimension_system + self._units = tuple(set(base_units) | set(units)) + self._base_units = tuple(base_units) + self._derived_units = derived_units + + super().__init__() + + def __str__(self): + """ + Return the name of the system. + + If it does not exist, then it makes a list of symbols (or names) of + the base dimensions. + """ + + if self.name != "": + return self.name + else: + return "UnitSystem((%s))" % ", ".join( + str(d) for d in self._base_units) + + def __repr__(self): + return '' % repr(self._base_units) + + def extend(self, base, units=(), name="", description="", dimension_system=None, derived_units: tDict[Dimension, Quantity]={}): + """Extend the current system into a new one. + + Take the base and normal units of the current system to merge + them to the base and normal units given in argument. + If not provided, name and description are overridden by empty strings. + """ + + base = self._base_units + tuple(base) + units = self._units + tuple(units) + + return UnitSystem(base, units, name, description, dimension_system, {**self._derived_units, **derived_units}) + + def get_dimension_system(self): + return self._dimension_system + + def get_quantity_dimension(self, unit): + qdm = self.get_dimension_system()._quantity_dimension_map + if unit in qdm: + return qdm[unit] + return super().get_quantity_dimension(unit) + + def get_quantity_scale_factor(self, unit): + qsfm = self.get_dimension_system()._quantity_scale_factors + if unit in qsfm: + return qsfm[unit] + return super().get_quantity_scale_factor(unit) + + @staticmethod + def get_unit_system(unit_system): + if isinstance(unit_system, UnitSystem): + return unit_system + + if unit_system not in UnitSystem._unit_systems: + raise ValueError( + "Unit system is not supported. Currently" + "supported unit systems are {}".format( + ", ".join(sorted(UnitSystem._unit_systems)) + ) + ) + + return UnitSystem._unit_systems[unit_system] + + @staticmethod + def get_default_unit_system(): + return UnitSystem._unit_systems["SI"] + + @property + def dim(self): + """ + Give the dimension of the system. + + That is return the number of units forming the basis. + """ + return len(self._base_units) + + @property + def is_consistent(self): + """ + Check if the underlying dimension system is consistent. + """ + # test is performed in DimensionSystem + return self.get_dimension_system().is_consistent + + @property + def derived_units(self) -> tDict[Dimension, Quantity]: + return self._derived_units + + def get_dimensional_expr(self, expr): + from sympy.physics.units import Quantity + if isinstance(expr, Mul): + return Mul(*[self.get_dimensional_expr(i) for i in expr.args]) + elif isinstance(expr, Pow): + return self.get_dimensional_expr(expr.base) ** expr.exp + elif isinstance(expr, Add): + return self.get_dimensional_expr(expr.args[0]) + elif isinstance(expr, Derivative): + dim = self.get_dimensional_expr(expr.expr) + for independent, count in expr.variable_count: + dim /= self.get_dimensional_expr(independent)**count + return dim + elif isinstance(expr, Function): + args = [self.get_dimensional_expr(arg) for arg in expr.args] + if all(i == 1 for i in args): + return S.One + return expr.func(*args) + elif isinstance(expr, Quantity): + return self.get_quantity_dimension(expr).name + return S.One + + def _collect_factor_and_dimension(self, expr): + """ + Return tuple with scale factor expression and dimension expression. + """ + from sympy.physics.units import Quantity + if isinstance(expr, Quantity): + return expr.scale_factor, expr.dimension + elif isinstance(expr, Mul): + factor = 1 + dimension = Dimension(1) + for arg in expr.args: + arg_factor, arg_dim = self._collect_factor_and_dimension(arg) + factor *= arg_factor + dimension *= arg_dim + return factor, dimension + elif isinstance(expr, Pow): + factor, dim = self._collect_factor_and_dimension(expr.base) + exp_factor, exp_dim = self._collect_factor_and_dimension(expr.exp) + if self.get_dimension_system().is_dimensionless(exp_dim): + exp_dim = 1 + return factor ** exp_factor, dim ** (exp_factor * exp_dim) + elif isinstance(expr, Add): + factor, dim = self._collect_factor_and_dimension(expr.args[0]) + for addend in expr.args[1:]: + addend_factor, addend_dim = \ + self._collect_factor_and_dimension(addend) + if not self.get_dimension_system().equivalent_dims(dim, addend_dim): + raise ValueError( + 'Dimension of "{}" is {}, ' + 'but it should be {}'.format( + addend, addend_dim, dim)) + factor += addend_factor + return factor, dim + elif isinstance(expr, Derivative): + factor, dim = self._collect_factor_and_dimension(expr.args[0]) + for independent, count in expr.variable_count: + ifactor, idim = self._collect_factor_and_dimension(independent) + factor /= ifactor**count + dim /= idim**count + return factor, dim + elif isinstance(expr, Function): + fds = [self._collect_factor_and_dimension(arg) for arg in expr.args] + dims = [Dimension(1) if self.get_dimension_system().is_dimensionless(d[1]) else d[1] for d in fds] + return (expr.func(*(f[0] for f in fds)), *dims) + elif isinstance(expr, Dimension): + return S.One, expr + else: + return expr, Dimension(1) + + def get_units_non_prefixed(self) -> tSet[Quantity]: + """ + Return the units of the system that do not have a prefix. + """ + return set(filter(lambda u: not u.is_prefixed and not u.is_physical_constant, self._units)) diff --git a/lib/python3.10/site-packages/sympy/physics/units/util.py b/lib/python3.10/site-packages/sympy/physics/units/util.py new file mode 100644 index 0000000000000000000000000000000000000000..b3f5a004fe9fa93b11039c5a832829715a7d44c6 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/physics/units/util.py @@ -0,0 +1,265 @@ +""" +Several methods to simplify expressions involving unit objects. +""" +from functools import reduce +from collections.abc import Iterable +from typing import Optional + +from sympy import default_sort_key +from sympy.core.add import Add +from sympy.core.containers import Tuple +from sympy.core.mul import Mul +from sympy.core.power import Pow +from sympy.core.sorting import ordered +from sympy.core.sympify import sympify +from sympy.core.function import Function +from sympy.matrices.exceptions import NonInvertibleMatrixError +from sympy.physics.units.dimensions import Dimension, DimensionSystem +from sympy.physics.units.prefixes import Prefix +from sympy.physics.units.quantities import Quantity +from sympy.physics.units.unitsystem import UnitSystem +from sympy.utilities.iterables import sift + + +def _get_conversion_matrix_for_expr(expr, target_units, unit_system): + from sympy.matrices.dense import Matrix + + dimension_system = unit_system.get_dimension_system() + + expr_dim = Dimension(unit_system.get_dimensional_expr(expr)) + dim_dependencies = dimension_system.get_dimensional_dependencies(expr_dim, mark_dimensionless=True) + target_dims = [Dimension(unit_system.get_dimensional_expr(x)) for x in target_units] + canon_dim_units = [i for x in target_dims for i in dimension_system.get_dimensional_dependencies(x, mark_dimensionless=True)] + canon_expr_units = set(dim_dependencies) + + if not canon_expr_units.issubset(set(canon_dim_units)): + return None + + seen = set() + canon_dim_units = [i for i in canon_dim_units if not (i in seen or seen.add(i))] + + camat = Matrix([[dimension_system.get_dimensional_dependencies(i, mark_dimensionless=True).get(j, 0) for i in target_dims] for j in canon_dim_units]) + exprmat = Matrix([dim_dependencies.get(k, 0) for k in canon_dim_units]) + + try: + res_exponents = camat.solve(exprmat) + except NonInvertibleMatrixError: + return None + + return res_exponents + + +def convert_to(expr, target_units, unit_system="SI"): + """ + Convert ``expr`` to the same expression with all of its units and quantities + represented as factors of ``target_units``, whenever the dimension is compatible. + + ``target_units`` may be a single unit/quantity, or a collection of + units/quantities. + + Examples + ======== + + >>> from sympy.physics.units import speed_of_light, meter, gram, second, day + >>> from sympy.physics.units import mile, newton, kilogram, atomic_mass_constant + >>> from sympy.physics.units import kilometer, centimeter + >>> from sympy.physics.units import gravitational_constant, hbar + >>> from sympy.physics.units import convert_to + >>> convert_to(mile, kilometer) + 25146*kilometer/15625 + >>> convert_to(mile, kilometer).n() + 1.609344*kilometer + >>> convert_to(speed_of_light, meter/second) + 299792458*meter/second + >>> convert_to(day, second) + 86400*second + >>> 3*newton + 3*newton + >>> convert_to(3*newton, kilogram*meter/second**2) + 3*kilogram*meter/second**2 + >>> convert_to(atomic_mass_constant, gram) + 1.660539060e-24*gram + + Conversion to multiple units: + + >>> convert_to(speed_of_light, [meter, second]) + 299792458*meter/second + >>> convert_to(3*newton, [centimeter, gram, second]) + 300000*centimeter*gram/second**2 + + Conversion to Planck units: + + >>> convert_to(atomic_mass_constant, [gravitational_constant, speed_of_light, hbar]).n() + 7.62963087839509e-20*hbar**0.5*speed_of_light**0.5/gravitational_constant**0.5 + + """ + from sympy.physics.units import UnitSystem + unit_system = UnitSystem.get_unit_system(unit_system) + + if not isinstance(target_units, (Iterable, Tuple)): + target_units = [target_units] + + def handle_Adds(expr): + return Add.fromiter(convert_to(i, target_units, unit_system) + for i in expr.args) + + if isinstance(expr, Add): + return handle_Adds(expr) + elif isinstance(expr, Pow) and isinstance(expr.base, Add): + return handle_Adds(expr.base) ** expr.exp + + expr = sympify(expr) + target_units = sympify(target_units) + + if isinstance(expr, Function): + expr = expr.together() + + if not isinstance(expr, Quantity) and expr.has(Quantity): + expr = expr.replace(lambda x: isinstance(x, Quantity), + lambda x: x.convert_to(target_units, unit_system)) + + def get_total_scale_factor(expr): + if isinstance(expr, Mul): + return reduce(lambda x, y: x * y, + [get_total_scale_factor(i) for i in expr.args]) + elif isinstance(expr, Pow): + return get_total_scale_factor(expr.base) ** expr.exp + elif isinstance(expr, Quantity): + return unit_system.get_quantity_scale_factor(expr) + return expr + + depmat = _get_conversion_matrix_for_expr(expr, target_units, unit_system) + if depmat is None: + return expr + + expr_scale_factor = get_total_scale_factor(expr) + return expr_scale_factor * Mul.fromiter( + (1/get_total_scale_factor(u)*u)**p for u, p in + zip(target_units, depmat)) + + +def quantity_simplify(expr, across_dimensions: bool=False, unit_system=None): + """Return an equivalent expression in which prefixes are replaced + with numerical values and all units of a given dimension are the + unified in a canonical manner by default. `across_dimensions` allows + for units of different dimensions to be simplified together. + + `unit_system` must be specified if `across_dimensions` is True. + + Examples + ======== + + >>> from sympy.physics.units.util import quantity_simplify + >>> from sympy.physics.units.prefixes import kilo + >>> from sympy.physics.units import foot, inch, joule, coulomb + >>> quantity_simplify(kilo*foot*inch) + 250*foot**2/3 + >>> quantity_simplify(foot - 6*inch) + foot/2 + >>> quantity_simplify(5*joule/coulomb, across_dimensions=True, unit_system="SI") + 5*volt + """ + + if expr.is_Atom or not expr.has(Prefix, Quantity): + return expr + + # replace all prefixes with numerical values + p = expr.atoms(Prefix) + expr = expr.xreplace({p: p.scale_factor for p in p}) + + # replace all quantities of given dimension with a canonical + # quantity, chosen from those in the expression + d = sift(expr.atoms(Quantity), lambda i: i.dimension) + for k in d: + if len(d[k]) == 1: + continue + v = list(ordered(d[k])) + ref = v[0]/v[0].scale_factor + expr = expr.xreplace({vi: ref*vi.scale_factor for vi in v[1:]}) + + if across_dimensions: + # combine quantities of different dimensions into a single + # quantity that is equivalent to the original expression + + if unit_system is None: + raise ValueError("unit_system must be specified if across_dimensions is True") + + unit_system = UnitSystem.get_unit_system(unit_system) + dimension_system: DimensionSystem = unit_system.get_dimension_system() + dim_expr = unit_system.get_dimensional_expr(expr) + dim_deps = dimension_system.get_dimensional_dependencies(dim_expr, mark_dimensionless=True) + + target_dimension: Optional[Dimension] = None + for ds_dim, ds_dim_deps in dimension_system.dimensional_dependencies.items(): + if ds_dim_deps == dim_deps: + target_dimension = ds_dim + break + + if target_dimension is None: + # if we can't find a target dimension, we can't do anything. unsure how to handle this case. + return expr + + target_unit = unit_system.derived_units.get(target_dimension) + if target_unit: + expr = convert_to(expr, target_unit, unit_system) + + return expr + + +def check_dimensions(expr, unit_system="SI"): + """Return expr if units in addends have the same + base dimensions, else raise a ValueError.""" + # the case of adding a number to a dimensional quantity + # is ignored for the sake of SymPy core routines, so this + # function will raise an error now if such an addend is + # found. + # Also, when doing substitutions, multiplicative constants + # might be introduced, so remove those now + + from sympy.physics.units import UnitSystem + unit_system = UnitSystem.get_unit_system(unit_system) + + def addDict(dict1, dict2): + """Merge dictionaries by adding values of common keys and + removing keys with value of 0.""" + dict3 = {**dict1, **dict2} + for key, value in dict3.items(): + if key in dict1 and key in dict2: + dict3[key] = value + dict1[key] + return {key:val for key, val in dict3.items() if val != 0} + + adds = expr.atoms(Add) + DIM_OF = unit_system.get_dimension_system().get_dimensional_dependencies + for a in adds: + deset = set() + for ai in a.args: + if ai.is_number: + deset.add(()) + continue + dims = [] + skip = False + dimdict = {} + for i in Mul.make_args(ai): + if i.has(Quantity): + i = Dimension(unit_system.get_dimensional_expr(i)) + if i.has(Dimension): + dimdict = addDict(dimdict, DIM_OF(i)) + elif i.free_symbols: + skip = True + break + dims.extend(dimdict.items()) + if not skip: + deset.add(tuple(sorted(dims, key=default_sort_key))) + if len(deset) > 1: + raise ValueError( + "addends have incompatible dimensions: {}".format(deset)) + + # clear multiplicative constants on Dimensions which may be + # left after substitution + reps = {} + for m in expr.atoms(Mul): + if any(isinstance(i, Dimension) for i in m.args): + reps[m] = m.func(*[ + i for i in m.args if not i.is_number]) + + return expr.xreplace(reps) diff --git a/lib/python3.10/site-packages/sympy/physics/vector/__init__.py b/lib/python3.10/site-packages/sympy/physics/vector/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e714852064c0b940ebda2e5fe7a08faf13f07ed0 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/physics/vector/__init__.py @@ -0,0 +1,36 @@ +__all__ = [ + 'CoordinateSym', 'ReferenceFrame', + + 'Dyadic', + + 'Vector', + + 'Point', + + 'cross', 'dot', 'express', 'time_derivative', 'outer', + 'kinematic_equations', 'get_motion_params', 'partial_velocity', + 'dynamicsymbols', + + 'vprint', 'vsstrrepr', 'vsprint', 'vpprint', 'vlatex', 'init_vprinting', + + 'curl', 'divergence', 'gradient', 'is_conservative', 'is_solenoidal', + 'scalar_potential', 'scalar_potential_difference', + +] +from .frame import CoordinateSym, ReferenceFrame + +from .dyadic import Dyadic + +from .vector import Vector + +from .point import Point + +from .functions import (cross, dot, express, time_derivative, outer, + kinematic_equations, get_motion_params, partial_velocity, + dynamicsymbols) + +from .printing import (vprint, vsstrrepr, vsprint, vpprint, vlatex, + init_vprinting) + +from .fieldfunctions import (curl, divergence, gradient, is_conservative, + is_solenoidal, scalar_potential, scalar_potential_difference) diff --git a/lib/python3.10/site-packages/sympy/physics/vector/dyadic.py b/lib/python3.10/site-packages/sympy/physics/vector/dyadic.py new file mode 100644 index 0000000000000000000000000000000000000000..514ab2312ba6e40c87a64c3a45f411e5d12a8015 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/physics/vector/dyadic.py @@ -0,0 +1,545 @@ +from sympy import sympify, Add, ImmutableMatrix as Matrix +from sympy.core.evalf import EvalfMixin +from sympy.printing.defaults import Printable + +from mpmath.libmp.libmpf import prec_to_dps + + +__all__ = ['Dyadic'] + + +class Dyadic(Printable, EvalfMixin): + """A Dyadic object. + + See: + https://en.wikipedia.org/wiki/Dyadic_tensor + Kane, T., Levinson, D. Dynamics Theory and Applications. 1985 McGraw-Hill + + A more powerful way to represent a rigid body's inertia. While it is more + complex, by choosing Dyadic components to be in body fixed basis vectors, + the resulting matrix is equivalent to the inertia tensor. + + """ + + is_number = False + + def __init__(self, inlist): + """ + Just like Vector's init, you should not call this unless creating a + zero dyadic. + + zd = Dyadic(0) + + Stores a Dyadic as a list of lists; the inner list has the measure + number and the two unit vectors; the outerlist holds each unique + unit vector pair. + + """ + + self.args = [] + if inlist == 0: + inlist = [] + while len(inlist) != 0: + added = 0 + for i, v in enumerate(self.args): + if ((str(inlist[0][1]) == str(self.args[i][1])) and + (str(inlist[0][2]) == str(self.args[i][2]))): + self.args[i] = (self.args[i][0] + inlist[0][0], + inlist[0][1], inlist[0][2]) + inlist.remove(inlist[0]) + added = 1 + break + if added != 1: + self.args.append(inlist[0]) + inlist.remove(inlist[0]) + i = 0 + # This code is to remove empty parts from the list + while i < len(self.args): + if ((self.args[i][0] == 0) | (self.args[i][1] == 0) | + (self.args[i][2] == 0)): + self.args.remove(self.args[i]) + i -= 1 + i += 1 + + @property + def func(self): + """Returns the class Dyadic. """ + return Dyadic + + def __add__(self, other): + """The add operator for Dyadic. """ + other = _check_dyadic(other) + return Dyadic(self.args + other.args) + + __radd__ = __add__ + + def __mul__(self, other): + """Multiplies the Dyadic by a sympifyable expression. + + Parameters + ========== + + other : Sympafiable + The scalar to multiply this Dyadic with + + Examples + ======== + + >>> from sympy.physics.vector import ReferenceFrame, outer + >>> N = ReferenceFrame('N') + >>> d = outer(N.x, N.x) + >>> 5 * d + 5*(N.x|N.x) + + """ + newlist = list(self.args) + other = sympify(other) + for i, v in enumerate(newlist): + newlist[i] = (other * newlist[i][0], newlist[i][1], + newlist[i][2]) + return Dyadic(newlist) + + __rmul__ = __mul__ + + def dot(self, other): + """The inner product operator for a Dyadic and a Dyadic or Vector. + + Parameters + ========== + + other : Dyadic or Vector + The other Dyadic or Vector to take the inner product with + + Examples + ======== + + >>> from sympy.physics.vector import ReferenceFrame, outer + >>> N = ReferenceFrame('N') + >>> D1 = outer(N.x, N.y) + >>> D2 = outer(N.y, N.y) + >>> D1.dot(D2) + (N.x|N.y) + >>> D1.dot(N.y) + N.x + + """ + from sympy.physics.vector.vector import Vector, _check_vector + if isinstance(other, Dyadic): + other = _check_dyadic(other) + ol = Dyadic(0) + for v in self.args: + for v2 in other.args: + ol += v[0] * v2[0] * (v[2].dot(v2[1])) * (v[1].outer(v2[2])) + else: + other = _check_vector(other) + ol = Vector(0) + for v in self.args: + ol += v[0] * v[1] * (v[2].dot(other)) + return ol + + # NOTE : supports non-advertised Dyadic & Dyadic, Dyadic & Vector notation + __and__ = dot + + def __truediv__(self, other): + """Divides the Dyadic by a sympifyable expression. """ + return self.__mul__(1 / other) + + def __eq__(self, other): + """Tests for equality. + + Is currently weak; needs stronger comparison testing + + """ + + if other == 0: + other = Dyadic(0) + other = _check_dyadic(other) + if (self.args == []) and (other.args == []): + return True + elif (self.args == []) or (other.args == []): + return False + return set(self.args) == set(other.args) + + def __ne__(self, other): + return not self == other + + def __neg__(self): + return self * -1 + + def _latex(self, printer): + ar = self.args # just to shorten things + if len(ar) == 0: + return str(0) + ol = [] # output list, to be concatenated to a string + for i, v in enumerate(ar): + # if the coef of the dyadic is 1, we skip the 1 + if ar[i][0] == 1: + ol.append(' + ' + printer._print(ar[i][1]) + r"\otimes " + + printer._print(ar[i][2])) + # if the coef of the dyadic is -1, we skip the 1 + elif ar[i][0] == -1: + ol.append(' - ' + + printer._print(ar[i][1]) + + r"\otimes " + + printer._print(ar[i][2])) + # If the coefficient of the dyadic is not 1 or -1, + # we might wrap it in parentheses, for readability. + elif ar[i][0] != 0: + arg_str = printer._print(ar[i][0]) + if isinstance(ar[i][0], Add): + arg_str = '(%s)' % arg_str + if arg_str.startswith('-'): + arg_str = arg_str[1:] + str_start = ' - ' + else: + str_start = ' + ' + ol.append(str_start + arg_str + printer._print(ar[i][1]) + + r"\otimes " + printer._print(ar[i][2])) + outstr = ''.join(ol) + if outstr.startswith(' + '): + outstr = outstr[3:] + elif outstr.startswith(' '): + outstr = outstr[1:] + return outstr + + def _pretty(self, printer): + e = self + + class Fake: + baseline = 0 + + def render(self, *args, **kwargs): + ar = e.args # just to shorten things + mpp = printer + if len(ar) == 0: + return str(0) + bar = "\N{CIRCLED TIMES}" if printer._use_unicode else "|" + ol = [] # output list, to be concatenated to a string + for i, v in enumerate(ar): + # if the coef of the dyadic is 1, we skip the 1 + if ar[i][0] == 1: + ol.extend([" + ", + mpp.doprint(ar[i][1]), + bar, + mpp.doprint(ar[i][2])]) + + # if the coef of the dyadic is -1, we skip the 1 + elif ar[i][0] == -1: + ol.extend([" - ", + mpp.doprint(ar[i][1]), + bar, + mpp.doprint(ar[i][2])]) + + # If the coefficient of the dyadic is not 1 or -1, + # we might wrap it in parentheses, for readability. + elif ar[i][0] != 0: + if isinstance(ar[i][0], Add): + arg_str = mpp._print( + ar[i][0]).parens()[0] + else: + arg_str = mpp.doprint(ar[i][0]) + if arg_str.startswith("-"): + arg_str = arg_str[1:] + str_start = " - " + else: + str_start = " + " + ol.extend([str_start, arg_str, " ", + mpp.doprint(ar[i][1]), + bar, + mpp.doprint(ar[i][2])]) + + outstr = "".join(ol) + if outstr.startswith(" + "): + outstr = outstr[3:] + elif outstr.startswith(" "): + outstr = outstr[1:] + return outstr + return Fake() + + def __rsub__(self, other): + return (-1 * self) + other + + def _sympystr(self, printer): + """Printing method. """ + ar = self.args # just to shorten things + if len(ar) == 0: + return printer._print(0) + ol = [] # output list, to be concatenated to a string + for i, v in enumerate(ar): + # if the coef of the dyadic is 1, we skip the 1 + if ar[i][0] == 1: + ol.append(' + (' + printer._print(ar[i][1]) + '|' + + printer._print(ar[i][2]) + ')') + # if the coef of the dyadic is -1, we skip the 1 + elif ar[i][0] == -1: + ol.append(' - (' + printer._print(ar[i][1]) + '|' + + printer._print(ar[i][2]) + ')') + # If the coefficient of the dyadic is not 1 or -1, + # we might wrap it in parentheses, for readability. + elif ar[i][0] != 0: + arg_str = printer._print(ar[i][0]) + if isinstance(ar[i][0], Add): + arg_str = "(%s)" % arg_str + if arg_str[0] == '-': + arg_str = arg_str[1:] + str_start = ' - ' + else: + str_start = ' + ' + ol.append(str_start + arg_str + '*(' + + printer._print(ar[i][1]) + + '|' + printer._print(ar[i][2]) + ')') + outstr = ''.join(ol) + if outstr.startswith(' + '): + outstr = outstr[3:] + elif outstr.startswith(' '): + outstr = outstr[1:] + return outstr + + def __sub__(self, other): + """The subtraction operator. """ + return self.__add__(other * -1) + + def cross(self, other): + """Returns the dyadic resulting from the dyadic vector cross product: + Dyadic x Vector. + + Parameters + ========== + other : Vector + Vector to cross with. + + Examples + ======== + >>> from sympy.physics.vector import ReferenceFrame, outer, cross + >>> N = ReferenceFrame('N') + >>> d = outer(N.x, N.x) + >>> cross(d, N.y) + (N.x|N.z) + + """ + from sympy.physics.vector.vector import _check_vector + other = _check_vector(other) + ol = Dyadic(0) + for v in self.args: + ol += v[0] * (v[1].outer((v[2].cross(other)))) + return ol + + # NOTE : supports non-advertised Dyadic ^ Vector notation + __xor__ = cross + + def express(self, frame1, frame2=None): + """Expresses this Dyadic in alternate frame(s) + + The first frame is the list side expression, the second frame is the + right side; if Dyadic is in form A.x|B.y, you can express it in two + different frames. If no second frame is given, the Dyadic is + expressed in only one frame. + + Calls the global express function + + Parameters + ========== + + frame1 : ReferenceFrame + The frame to express the left side of the Dyadic in + frame2 : ReferenceFrame + If provided, the frame to express the right side of the Dyadic in + + Examples + ======== + + >>> from sympy.physics.vector import ReferenceFrame, outer, dynamicsymbols + >>> from sympy.physics.vector import init_vprinting + >>> init_vprinting(pretty_print=False) + >>> N = ReferenceFrame('N') + >>> q = dynamicsymbols('q') + >>> B = N.orientnew('B', 'Axis', [q, N.z]) + >>> d = outer(N.x, N.x) + >>> d.express(B, N) + cos(q)*(B.x|N.x) - sin(q)*(B.y|N.x) + + """ + from sympy.physics.vector.functions import express + return express(self, frame1, frame2) + + def to_matrix(self, reference_frame, second_reference_frame=None): + """Returns the matrix form of the dyadic with respect to one or two + reference frames. + + Parameters + ---------- + reference_frame : ReferenceFrame + The reference frame that the rows and columns of the matrix + correspond to. If a second reference frame is provided, this + only corresponds to the rows of the matrix. + second_reference_frame : ReferenceFrame, optional, default=None + The reference frame that the columns of the matrix correspond + to. + + Returns + ------- + matrix : ImmutableMatrix, shape(3,3) + The matrix that gives the 2D tensor form. + + Examples + ======== + + >>> from sympy import symbols, trigsimp + >>> from sympy.physics.vector import ReferenceFrame + >>> from sympy.physics.mechanics import inertia + >>> Ixx, Iyy, Izz, Ixy, Iyz, Ixz = symbols('Ixx, Iyy, Izz, Ixy, Iyz, Ixz') + >>> N = ReferenceFrame('N') + >>> inertia_dyadic = inertia(N, Ixx, Iyy, Izz, Ixy, Iyz, Ixz) + >>> inertia_dyadic.to_matrix(N) + Matrix([ + [Ixx, Ixy, Ixz], + [Ixy, Iyy, Iyz], + [Ixz, Iyz, Izz]]) + >>> beta = symbols('beta') + >>> A = N.orientnew('A', 'Axis', (beta, N.x)) + >>> trigsimp(inertia_dyadic.to_matrix(A)) + Matrix([ + [ Ixx, Ixy*cos(beta) + Ixz*sin(beta), -Ixy*sin(beta) + Ixz*cos(beta)], + [ Ixy*cos(beta) + Ixz*sin(beta), Iyy*cos(2*beta)/2 + Iyy/2 + Iyz*sin(2*beta) - Izz*cos(2*beta)/2 + Izz/2, -Iyy*sin(2*beta)/2 + Iyz*cos(2*beta) + Izz*sin(2*beta)/2], + [-Ixy*sin(beta) + Ixz*cos(beta), -Iyy*sin(2*beta)/2 + Iyz*cos(2*beta) + Izz*sin(2*beta)/2, -Iyy*cos(2*beta)/2 + Iyy/2 - Iyz*sin(2*beta) + Izz*cos(2*beta)/2 + Izz/2]]) + + """ + + if second_reference_frame is None: + second_reference_frame = reference_frame + + return Matrix([i.dot(self).dot(j) for i in reference_frame for j in + second_reference_frame]).reshape(3, 3) + + def doit(self, **hints): + """Calls .doit() on each term in the Dyadic""" + return sum([Dyadic([(v[0].doit(**hints), v[1], v[2])]) + for v in self.args], Dyadic(0)) + + def dt(self, frame): + """Take the time derivative of this Dyadic in a frame. + + This function calls the global time_derivative method + + Parameters + ========== + + frame : ReferenceFrame + The frame to take the time derivative in + + Examples + ======== + + >>> from sympy.physics.vector import ReferenceFrame, outer, dynamicsymbols + >>> from sympy.physics.vector import init_vprinting + >>> init_vprinting(pretty_print=False) + >>> N = ReferenceFrame('N') + >>> q = dynamicsymbols('q') + >>> B = N.orientnew('B', 'Axis', [q, N.z]) + >>> d = outer(N.x, N.x) + >>> d.dt(B) + - q'*(N.y|N.x) - q'*(N.x|N.y) + + """ + from sympy.physics.vector.functions import time_derivative + return time_derivative(self, frame) + + def simplify(self): + """Returns a simplified Dyadic.""" + out = Dyadic(0) + for v in self.args: + out += Dyadic([(v[0].simplify(), v[1], v[2])]) + return out + + def subs(self, *args, **kwargs): + """Substitution on the Dyadic. + + Examples + ======== + + >>> from sympy.physics.vector import ReferenceFrame + >>> from sympy import Symbol + >>> N = ReferenceFrame('N') + >>> s = Symbol('s') + >>> a = s*(N.x|N.x) + >>> a.subs({s: 2}) + 2*(N.x|N.x) + + """ + + return sum([Dyadic([(v[0].subs(*args, **kwargs), v[1], v[2])]) + for v in self.args], Dyadic(0)) + + def applyfunc(self, f): + """Apply a function to each component of a Dyadic.""" + if not callable(f): + raise TypeError("`f` must be callable.") + + out = Dyadic(0) + for a, b, c in self.args: + out += f(a) * (b.outer(c)) + return out + + def _eval_evalf(self, prec): + if not self.args: + return self + new_args = [] + dps = prec_to_dps(prec) + for inlist in self.args: + new_inlist = list(inlist) + new_inlist[0] = inlist[0].evalf(n=dps) + new_args.append(tuple(new_inlist)) + return Dyadic(new_args) + + def xreplace(self, rule): + """ + Replace occurrences of objects within the measure numbers of the + Dyadic. + + Parameters + ========== + + rule : dict-like + Expresses a replacement rule. + + Returns + ======= + + Dyadic + Result of the replacement. + + Examples + ======== + + >>> from sympy import symbols, pi + >>> from sympy.physics.vector import ReferenceFrame, outer + >>> N = ReferenceFrame('N') + >>> D = outer(N.x, N.x) + >>> x, y, z = symbols('x y z') + >>> ((1 + x*y) * D).xreplace({x: pi}) + (pi*y + 1)*(N.x|N.x) + >>> ((1 + x*y) * D).xreplace({x: pi, y: 2}) + (1 + 2*pi)*(N.x|N.x) + + Replacements occur only if an entire node in the expression tree is + matched: + + >>> ((x*y + z) * D).xreplace({x*y: pi}) + (z + pi)*(N.x|N.x) + >>> ((x*y*z) * D).xreplace({x*y: pi}) + x*y*z*(N.x|N.x) + + """ + + new_args = [] + for inlist in self.args: + new_inlist = list(inlist) + new_inlist[0] = new_inlist[0].xreplace(rule) + new_args.append(tuple(new_inlist)) + return Dyadic(new_args) + + +def _check_dyadic(other): + if not isinstance(other, Dyadic): + raise TypeError('A Dyadic must be supplied') + return other diff --git a/lib/python3.10/site-packages/sympy/physics/vector/fieldfunctions.py b/lib/python3.10/site-packages/sympy/physics/vector/fieldfunctions.py new file mode 100644 index 0000000000000000000000000000000000000000..74169921c587385323b9080709999c65c6ca0843 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/physics/vector/fieldfunctions.py @@ -0,0 +1,313 @@ +from sympy.core.function import diff +from sympy.core.singleton import S +from sympy.integrals.integrals import integrate +from sympy.physics.vector import Vector, express +from sympy.physics.vector.frame import _check_frame +from sympy.physics.vector.vector import _check_vector + + +__all__ = ['curl', 'divergence', 'gradient', 'is_conservative', + 'is_solenoidal', 'scalar_potential', + 'scalar_potential_difference'] + + +def curl(vect, frame): + """ + Returns the curl of a vector field computed wrt the coordinate + symbols of the given frame. + + Parameters + ========== + + vect : Vector + The vector operand + + frame : ReferenceFrame + The reference frame to calculate the curl in + + Examples + ======== + + >>> from sympy.physics.vector import ReferenceFrame + >>> from sympy.physics.vector import curl + >>> R = ReferenceFrame('R') + >>> v1 = R[1]*R[2]*R.x + R[0]*R[2]*R.y + R[0]*R[1]*R.z + >>> curl(v1, R) + 0 + >>> v2 = R[0]*R[1]*R[2]*R.x + >>> curl(v2, R) + R_x*R_y*R.y - R_x*R_z*R.z + + """ + + _check_vector(vect) + if vect == 0: + return Vector(0) + vect = express(vect, frame, variables=True) + # A mechanical approach to avoid looping overheads + vectx = vect.dot(frame.x) + vecty = vect.dot(frame.y) + vectz = vect.dot(frame.z) + outvec = Vector(0) + outvec += (diff(vectz, frame[1]) - diff(vecty, frame[2])) * frame.x + outvec += (diff(vectx, frame[2]) - diff(vectz, frame[0])) * frame.y + outvec += (diff(vecty, frame[0]) - diff(vectx, frame[1])) * frame.z + return outvec + + +def divergence(vect, frame): + """ + Returns the divergence of a vector field computed wrt the coordinate + symbols of the given frame. + + Parameters + ========== + + vect : Vector + The vector operand + + frame : ReferenceFrame + The reference frame to calculate the divergence in + + Examples + ======== + + >>> from sympy.physics.vector import ReferenceFrame + >>> from sympy.physics.vector import divergence + >>> R = ReferenceFrame('R') + >>> v1 = R[0]*R[1]*R[2] * (R.x+R.y+R.z) + >>> divergence(v1, R) + R_x*R_y + R_x*R_z + R_y*R_z + >>> v2 = 2*R[1]*R[2]*R.y + >>> divergence(v2, R) + 2*R_z + + """ + + _check_vector(vect) + if vect == 0: + return S.Zero + vect = express(vect, frame, variables=True) + vectx = vect.dot(frame.x) + vecty = vect.dot(frame.y) + vectz = vect.dot(frame.z) + out = S.Zero + out += diff(vectx, frame[0]) + out += diff(vecty, frame[1]) + out += diff(vectz, frame[2]) + return out + + +def gradient(scalar, frame): + """ + Returns the vector gradient of a scalar field computed wrt the + coordinate symbols of the given frame. + + Parameters + ========== + + scalar : sympifiable + The scalar field to take the gradient of + + frame : ReferenceFrame + The frame to calculate the gradient in + + Examples + ======== + + >>> from sympy.physics.vector import ReferenceFrame + >>> from sympy.physics.vector import gradient + >>> R = ReferenceFrame('R') + >>> s1 = R[0]*R[1]*R[2] + >>> gradient(s1, R) + R_y*R_z*R.x + R_x*R_z*R.y + R_x*R_y*R.z + >>> s2 = 5*R[0]**2*R[2] + >>> gradient(s2, R) + 10*R_x*R_z*R.x + 5*R_x**2*R.z + + """ + + _check_frame(frame) + outvec = Vector(0) + scalar = express(scalar, frame, variables=True) + for i, x in enumerate(frame): + outvec += diff(scalar, frame[i]) * x + return outvec + + +def is_conservative(field): + """ + Checks if a field is conservative. + + Parameters + ========== + + field : Vector + The field to check for conservative property + + Examples + ======== + + >>> from sympy.physics.vector import ReferenceFrame + >>> from sympy.physics.vector import is_conservative + >>> R = ReferenceFrame('R') + >>> is_conservative(R[1]*R[2]*R.x + R[0]*R[2]*R.y + R[0]*R[1]*R.z) + True + >>> is_conservative(R[2] * R.y) + False + + """ + + # Field is conservative irrespective of frame + # Take the first frame in the result of the separate method of Vector + if field == Vector(0): + return True + frame = list(field.separate())[0] + return curl(field, frame).simplify() == Vector(0) + + +def is_solenoidal(field): + """ + Checks if a field is solenoidal. + + Parameters + ========== + + field : Vector + The field to check for solenoidal property + + Examples + ======== + + >>> from sympy.physics.vector import ReferenceFrame + >>> from sympy.physics.vector import is_solenoidal + >>> R = ReferenceFrame('R') + >>> is_solenoidal(R[1]*R[2]*R.x + R[0]*R[2]*R.y + R[0]*R[1]*R.z) + True + >>> is_solenoidal(R[1] * R.y) + False + + """ + + # Field is solenoidal irrespective of frame + # Take the first frame in the result of the separate method in Vector + if field == Vector(0): + return True + frame = list(field.separate())[0] + return divergence(field, frame).simplify() is S.Zero + + +def scalar_potential(field, frame): + """ + Returns the scalar potential function of a field in a given frame + (without the added integration constant). + + Parameters + ========== + + field : Vector + The vector field whose scalar potential function is to be + calculated + + frame : ReferenceFrame + The frame to do the calculation in + + Examples + ======== + + >>> from sympy.physics.vector import ReferenceFrame + >>> from sympy.physics.vector import scalar_potential, gradient + >>> R = ReferenceFrame('R') + >>> scalar_potential(R.z, R) == R[2] + True + >>> scalar_field = 2*R[0]**2*R[1]*R[2] + >>> grad_field = gradient(scalar_field, R) + >>> scalar_potential(grad_field, R) + 2*R_x**2*R_y*R_z + + """ + + # Check whether field is conservative + if not is_conservative(field): + raise ValueError("Field is not conservative") + if field == Vector(0): + return S.Zero + # Express the field exntirely in frame + # Substitute coordinate variables also + _check_frame(frame) + field = express(field, frame, variables=True) + # Make a list of dimensions of the frame + dimensions = list(frame) + # Calculate scalar potential function + temp_function = integrate(field.dot(dimensions[0]), frame[0]) + for i, dim in enumerate(dimensions[1:]): + partial_diff = diff(temp_function, frame[i + 1]) + partial_diff = field.dot(dim) - partial_diff + temp_function += integrate(partial_diff, frame[i + 1]) + return temp_function + + +def scalar_potential_difference(field, frame, point1, point2, origin): + """ + Returns the scalar potential difference between two points in a + certain frame, wrt a given field. + + If a scalar field is provided, its values at the two points are + considered. If a conservative vector field is provided, the values + of its scalar potential function at the two points are used. + + Returns (potential at position 2) - (potential at position 1) + + Parameters + ========== + + field : Vector/sympyfiable + The field to calculate wrt + + frame : ReferenceFrame + The frame to do the calculations in + + point1 : Point + The initial Point in given frame + + position2 : Point + The second Point in the given frame + + origin : Point + The Point to use as reference point for position vector + calculation + + Examples + ======== + + >>> from sympy.physics.vector import ReferenceFrame, Point + >>> from sympy.physics.vector import scalar_potential_difference + >>> R = ReferenceFrame('R') + >>> O = Point('O') + >>> P = O.locatenew('P', R[0]*R.x + R[1]*R.y + R[2]*R.z) + >>> vectfield = 4*R[0]*R[1]*R.x + 2*R[0]**2*R.y + >>> scalar_potential_difference(vectfield, R, O, P, O) + 2*R_x**2*R_y + >>> Q = O.locatenew('O', 3*R.x + R.y + 2*R.z) + >>> scalar_potential_difference(vectfield, R, P, Q, O) + -2*R_x**2*R_y + 18 + + """ + + _check_frame(frame) + if isinstance(field, Vector): + # Get the scalar potential function + scalar_fn = scalar_potential(field, frame) + else: + # Field is a scalar + scalar_fn = field + # Express positions in required frame + position1 = express(point1.pos_from(origin), frame, variables=True) + position2 = express(point2.pos_from(origin), frame, variables=True) + # Get the two positions as substitution dicts for coordinate variables + subs_dict1 = {} + subs_dict2 = {} + for i, x in enumerate(frame): + subs_dict1[frame[i]] = x.dot(position1) + subs_dict2[frame[i]] = x.dot(position2) + return scalar_fn.subs(subs_dict2) - scalar_fn.subs(subs_dict1) diff --git a/lib/python3.10/site-packages/sympy/physics/vector/frame.py b/lib/python3.10/site-packages/sympy/physics/vector/frame.py new file mode 100644 index 0000000000000000000000000000000000000000..dd0945fbf38d67b020dd88b67ae23984fa263ce1 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/physics/vector/frame.py @@ -0,0 +1,1575 @@ +from sympy import (diff, expand, sin, cos, sympify, eye, zeros, + ImmutableMatrix as Matrix, MatrixBase) +from sympy.core.symbol import Symbol +from sympy.simplify.trigsimp import trigsimp +from sympy.physics.vector.vector import Vector, _check_vector +from sympy.utilities.misc import translate + +from warnings import warn + +__all__ = ['CoordinateSym', 'ReferenceFrame'] + + +class CoordinateSym(Symbol): + """ + A coordinate symbol/base scalar associated wrt a Reference Frame. + + Ideally, users should not instantiate this class. Instances of + this class must only be accessed through the corresponding frame + as 'frame[index]'. + + CoordinateSyms having the same frame and index parameters are equal + (even though they may be instantiated separately). + + Parameters + ========== + + name : string + The display name of the CoordinateSym + + frame : ReferenceFrame + The reference frame this base scalar belongs to + + index : 0, 1 or 2 + The index of the dimension denoted by this coordinate variable + + Examples + ======== + + >>> from sympy.physics.vector import ReferenceFrame, CoordinateSym + >>> A = ReferenceFrame('A') + >>> A[1] + A_y + >>> type(A[0]) + + >>> a_y = CoordinateSym('a_y', A, 1) + >>> a_y == A[1] + True + + """ + + def __new__(cls, name, frame, index): + # We can't use the cached Symbol.__new__ because this class depends on + # frame and index, which are not passed to Symbol.__xnew__. + assumptions = {} + super()._sanitize(assumptions, cls) + obj = super().__xnew__(cls, name, **assumptions) + _check_frame(frame) + if index not in range(0, 3): + raise ValueError("Invalid index specified") + obj._id = (frame, index) + return obj + + def __getnewargs_ex__(self): + return (self.name, *self._id), {} + + @property + def frame(self): + return self._id[0] + + def __eq__(self, other): + # Check if the other object is a CoordinateSym of the same frame and + # same index + if isinstance(other, CoordinateSym): + if other._id == self._id: + return True + return False + + def __ne__(self, other): + return not self == other + + def __hash__(self): + return (self._id[0].__hash__(), self._id[1]).__hash__() + + +class ReferenceFrame: + """A reference frame in classical mechanics. + + ReferenceFrame is a class used to represent a reference frame in classical + mechanics. It has a standard basis of three unit vectors in the frame's + x, y, and z directions. + + It also can have a rotation relative to a parent frame; this rotation is + defined by a direction cosine matrix relating this frame's basis vectors to + the parent frame's basis vectors. It can also have an angular velocity + vector, defined in another frame. + + """ + _count = 0 + + def __init__(self, name, indices=None, latexs=None, variables=None): + """ReferenceFrame initialization method. + + A ReferenceFrame has a set of orthonormal basis vectors, along with + orientations relative to other ReferenceFrames and angular velocities + relative to other ReferenceFrames. + + Parameters + ========== + + indices : tuple of str + Enables the reference frame's basis unit vectors to be accessed by + Python's square bracket indexing notation using the provided three + indice strings and alters the printing of the unit vectors to + reflect this choice. + latexs : tuple of str + Alters the LaTeX printing of the reference frame's basis unit + vectors to the provided three valid LaTeX strings. + + Examples + ======== + + >>> from sympy.physics.vector import ReferenceFrame, vlatex + >>> N = ReferenceFrame('N') + >>> N.x + N.x + >>> O = ReferenceFrame('O', indices=('1', '2', '3')) + >>> O.x + O['1'] + >>> O['1'] + O['1'] + >>> P = ReferenceFrame('P', latexs=('A1', 'A2', 'A3')) + >>> vlatex(P.x) + 'A1' + + ``symbols()`` can be used to create multiple Reference Frames in one + step, for example: + + >>> from sympy.physics.vector import ReferenceFrame + >>> from sympy import symbols + >>> A, B, C = symbols('A B C', cls=ReferenceFrame) + >>> D, E = symbols('D E', cls=ReferenceFrame, indices=('1', '2', '3')) + >>> A[0] + A_x + >>> D.x + D['1'] + >>> E.y + E['2'] + >>> type(A) == type(D) + True + + Unit dyads for the ReferenceFrame can be accessed through the attributes ``xx``, ``xy``, etc. For example: + + >>> from sympy.physics.vector import ReferenceFrame + >>> N = ReferenceFrame('N') + >>> N.yz + (N.y|N.z) + >>> N.zx + (N.z|N.x) + >>> P = ReferenceFrame('P', indices=['1', '2', '3']) + >>> P.xx + (P['1']|P['1']) + >>> P.zy + (P['3']|P['2']) + + Unit dyadic is also accessible via the ``u`` attribute: + + >>> from sympy.physics.vector import ReferenceFrame + >>> N = ReferenceFrame('N') + >>> N.u + (N.x|N.x) + (N.y|N.y) + (N.z|N.z) + >>> P = ReferenceFrame('P', indices=['1', '2', '3']) + >>> P.u + (P['1']|P['1']) + (P['2']|P['2']) + (P['3']|P['3']) + + """ + + if not isinstance(name, str): + raise TypeError('Need to supply a valid name') + # The if statements below are for custom printing of basis-vectors for + # each frame. + # First case, when custom indices are supplied + if indices is not None: + if not isinstance(indices, (tuple, list)): + raise TypeError('Supply the indices as a list') + if len(indices) != 3: + raise ValueError('Supply 3 indices') + for i in indices: + if not isinstance(i, str): + raise TypeError('Indices must be strings') + self.str_vecs = [(name + '[\'' + indices[0] + '\']'), + (name + '[\'' + indices[1] + '\']'), + (name + '[\'' + indices[2] + '\']')] + self.pretty_vecs = [(name.lower() + "_" + indices[0]), + (name.lower() + "_" + indices[1]), + (name.lower() + "_" + indices[2])] + self.latex_vecs = [(r"\mathbf{\hat{%s}_{%s}}" % (name.lower(), + indices[0])), + (r"\mathbf{\hat{%s}_{%s}}" % (name.lower(), + indices[1])), + (r"\mathbf{\hat{%s}_{%s}}" % (name.lower(), + indices[2]))] + self.indices = indices + # Second case, when no custom indices are supplied + else: + self.str_vecs = [(name + '.x'), (name + '.y'), (name + '.z')] + self.pretty_vecs = [name.lower() + "_x", + name.lower() + "_y", + name.lower() + "_z"] + self.latex_vecs = [(r"\mathbf{\hat{%s}_x}" % name.lower()), + (r"\mathbf{\hat{%s}_y}" % name.lower()), + (r"\mathbf{\hat{%s}_z}" % name.lower())] + self.indices = ['x', 'y', 'z'] + # Different step, for custom latex basis vectors + if latexs is not None: + if not isinstance(latexs, (tuple, list)): + raise TypeError('Supply the indices as a list') + if len(latexs) != 3: + raise ValueError('Supply 3 indices') + for i in latexs: + if not isinstance(i, str): + raise TypeError('Latex entries must be strings') + self.latex_vecs = latexs + self.name = name + self._var_dict = {} + # The _dcm_dict dictionary will only store the dcms of adjacent + # parent-child relationships. The _dcm_cache dictionary will store + # calculated dcm along with all content of _dcm_dict for faster + # retrieval of dcms. + self._dcm_dict = {} + self._dcm_cache = {} + self._ang_vel_dict = {} + self._ang_acc_dict = {} + self._dlist = [self._dcm_dict, self._ang_vel_dict, self._ang_acc_dict] + self._cur = 0 + self._x = Vector([(Matrix([1, 0, 0]), self)]) + self._y = Vector([(Matrix([0, 1, 0]), self)]) + self._z = Vector([(Matrix([0, 0, 1]), self)]) + # Associate coordinate symbols wrt this frame + if variables is not None: + if not isinstance(variables, (tuple, list)): + raise TypeError('Supply the variable names as a list/tuple') + if len(variables) != 3: + raise ValueError('Supply 3 variable names') + for i in variables: + if not isinstance(i, str): + raise TypeError('Variable names must be strings') + else: + variables = [name + '_x', name + '_y', name + '_z'] + self.varlist = (CoordinateSym(variables[0], self, 0), + CoordinateSym(variables[1], self, 1), + CoordinateSym(variables[2], self, 2)) + ReferenceFrame._count += 1 + self.index = ReferenceFrame._count + + def __getitem__(self, ind): + """ + Returns basis vector for the provided index, if the index is a string. + + If the index is a number, returns the coordinate variable correspon- + -ding to that index. + """ + if not isinstance(ind, str): + if ind < 3: + return self.varlist[ind] + else: + raise ValueError("Invalid index provided") + if self.indices[0] == ind: + return self.x + if self.indices[1] == ind: + return self.y + if self.indices[2] == ind: + return self.z + else: + raise ValueError('Not a defined index') + + def __iter__(self): + return iter([self.x, self.y, self.z]) + + def __str__(self): + """Returns the name of the frame. """ + return self.name + + __repr__ = __str__ + + def _dict_list(self, other, num): + """Returns an inclusive list of reference frames that connect this + reference frame to the provided reference frame. + + Parameters + ========== + other : ReferenceFrame + The other reference frame to look for a connecting relationship to. + num : integer + ``0``, ``1``, and ``2`` will look for orientation, angular + velocity, and angular acceleration relationships between the two + frames, respectively. + + Returns + ======= + list + Inclusive list of reference frames that connect this reference + frame to the other reference frame. + + Examples + ======== + + >>> from sympy.physics.vector import ReferenceFrame + >>> A = ReferenceFrame('A') + >>> B = ReferenceFrame('B') + >>> C = ReferenceFrame('C') + >>> D = ReferenceFrame('D') + >>> B.orient_axis(A, A.x, 1.0) + >>> C.orient_axis(B, B.x, 1.0) + >>> D.orient_axis(C, C.x, 1.0) + >>> D._dict_list(A, 0) + [D, C, B, A] + + Raises + ====== + + ValueError + When no path is found between the two reference frames or ``num`` + is an incorrect value. + + """ + + connect_type = {0: 'orientation', + 1: 'angular velocity', + 2: 'angular acceleration'} + + if num not in connect_type.keys(): + raise ValueError('Valid values for num are 0, 1, or 2.') + + possible_connecting_paths = [[self]] + oldlist = [[]] + while possible_connecting_paths != oldlist: + oldlist = possible_connecting_paths[:] # make a copy + for frame_list in possible_connecting_paths: + frames_adjacent_to_last = frame_list[-1]._dlist[num].keys() + for adjacent_frame in frames_adjacent_to_last: + if adjacent_frame not in frame_list: + connecting_path = frame_list + [adjacent_frame] + if connecting_path not in possible_connecting_paths: + possible_connecting_paths.append(connecting_path) + + for connecting_path in oldlist: + if connecting_path[-1] != other: + possible_connecting_paths.remove(connecting_path) + possible_connecting_paths.sort(key=len) + + if len(possible_connecting_paths) != 0: + return possible_connecting_paths[0] # selects the shortest path + + msg = 'No connecting {} path found between {} and {}.' + raise ValueError(msg.format(connect_type[num], self.name, other.name)) + + def _w_diff_dcm(self, otherframe): + """Angular velocity from time differentiating the DCM. """ + from sympy.physics.vector.functions import dynamicsymbols + dcm2diff = otherframe.dcm(self) + diffed = dcm2diff.diff(dynamicsymbols._t) + angvelmat = diffed * dcm2diff.T + w1 = trigsimp(expand(angvelmat[7]), recursive=True) + w2 = trigsimp(expand(angvelmat[2]), recursive=True) + w3 = trigsimp(expand(angvelmat[3]), recursive=True) + return Vector([(Matrix([w1, w2, w3]), otherframe)]) + + def variable_map(self, otherframe): + """ + Returns a dictionary which expresses the coordinate variables + of this frame in terms of the variables of otherframe. + + If Vector.simp is True, returns a simplified version of the mapped + values. Else, returns them without simplification. + + Simplification of the expressions may take time. + + Parameters + ========== + + otherframe : ReferenceFrame + The other frame to map the variables to + + Examples + ======== + + >>> from sympy.physics.vector import ReferenceFrame, dynamicsymbols + >>> A = ReferenceFrame('A') + >>> q = dynamicsymbols('q') + >>> B = A.orientnew('B', 'Axis', [q, A.z]) + >>> A.variable_map(B) + {A_x: B_x*cos(q(t)) - B_y*sin(q(t)), A_y: B_x*sin(q(t)) + B_y*cos(q(t)), A_z: B_z} + + """ + + _check_frame(otherframe) + if (otherframe, Vector.simp) in self._var_dict: + return self._var_dict[(otherframe, Vector.simp)] + else: + vars_matrix = self.dcm(otherframe) * Matrix(otherframe.varlist) + mapping = {} + for i, x in enumerate(self): + if Vector.simp: + mapping[self.varlist[i]] = trigsimp(vars_matrix[i], + method='fu') + else: + mapping[self.varlist[i]] = vars_matrix[i] + self._var_dict[(otherframe, Vector.simp)] = mapping + return mapping + + def ang_acc_in(self, otherframe): + """Returns the angular acceleration Vector of the ReferenceFrame. + + Effectively returns the Vector: + + ``N_alpha_B`` + + which represent the angular acceleration of B in N, where B is self, + and N is otherframe. + + Parameters + ========== + + otherframe : ReferenceFrame + The ReferenceFrame which the angular acceleration is returned in. + + Examples + ======== + + >>> from sympy.physics.vector import ReferenceFrame + >>> N = ReferenceFrame('N') + >>> A = ReferenceFrame('A') + >>> V = 10 * N.x + >>> A.set_ang_acc(N, V) + >>> A.ang_acc_in(N) + 10*N.x + + """ + + _check_frame(otherframe) + if otherframe in self._ang_acc_dict: + return self._ang_acc_dict[otherframe] + else: + return self.ang_vel_in(otherframe).dt(otherframe) + + def ang_vel_in(self, otherframe): + """Returns the angular velocity Vector of the ReferenceFrame. + + Effectively returns the Vector: + + ^N omega ^B + + which represent the angular velocity of B in N, where B is self, and + N is otherframe. + + Parameters + ========== + + otherframe : ReferenceFrame + The ReferenceFrame which the angular velocity is returned in. + + Examples + ======== + + >>> from sympy.physics.vector import ReferenceFrame + >>> N = ReferenceFrame('N') + >>> A = ReferenceFrame('A') + >>> V = 10 * N.x + >>> A.set_ang_vel(N, V) + >>> A.ang_vel_in(N) + 10*N.x + + """ + + _check_frame(otherframe) + flist = self._dict_list(otherframe, 1) + outvec = Vector(0) + for i in range(len(flist) - 1): + outvec += flist[i]._ang_vel_dict[flist[i + 1]] + return outvec + + def dcm(self, otherframe): + r"""Returns the direction cosine matrix of this reference frame + relative to the provided reference frame. + + The returned matrix can be used to express the orthogonal unit vectors + of this frame in terms of the orthogonal unit vectors of + ``otherframe``. + + Parameters + ========== + + otherframe : ReferenceFrame + The reference frame which the direction cosine matrix of this frame + is formed relative to. + + Examples + ======== + + The following example rotates the reference frame A relative to N by a + simple rotation and then calculates the direction cosine matrix of N + relative to A. + + >>> from sympy import symbols, sin, cos + >>> from sympy.physics.vector import ReferenceFrame + >>> q1 = symbols('q1') + >>> N = ReferenceFrame('N') + >>> A = ReferenceFrame('A') + >>> A.orient_axis(N, q1, N.x) + >>> N.dcm(A) + Matrix([ + [1, 0, 0], + [0, cos(q1), -sin(q1)], + [0, sin(q1), cos(q1)]]) + + The second row of the above direction cosine matrix represents the + ``N.y`` unit vector in N expressed in A. Like so: + + >>> Ny = 0*A.x + cos(q1)*A.y - sin(q1)*A.z + + Thus, expressing ``N.y`` in A should return the same result: + + >>> N.y.express(A) + cos(q1)*A.y - sin(q1)*A.z + + Notes + ===== + + It is important to know what form of the direction cosine matrix is + returned. If ``B.dcm(A)`` is called, it means the "direction cosine + matrix of B rotated relative to A". This is the matrix + :math:`{}^B\mathbf{C}^A` shown in the following relationship: + + .. math:: + + \begin{bmatrix} + \hat{\mathbf{b}}_1 \\ + \hat{\mathbf{b}}_2 \\ + \hat{\mathbf{b}}_3 + \end{bmatrix} + = + {}^B\mathbf{C}^A + \begin{bmatrix} + \hat{\mathbf{a}}_1 \\ + \hat{\mathbf{a}}_2 \\ + \hat{\mathbf{a}}_3 + \end{bmatrix}. + + :math:`{}^B\mathbf{C}^A` is the matrix that expresses the B unit + vectors in terms of the A unit vectors. + + """ + + _check_frame(otherframe) + # Check if the dcm wrt that frame has already been calculated + if otherframe in self._dcm_cache: + return self._dcm_cache[otherframe] + flist = self._dict_list(otherframe, 0) + outdcm = eye(3) + for i in range(len(flist) - 1): + outdcm = outdcm * flist[i]._dcm_dict[flist[i + 1]] + # After calculation, store the dcm in dcm cache for faster future + # retrieval + self._dcm_cache[otherframe] = outdcm + otherframe._dcm_cache[self] = outdcm.T + return outdcm + + def _dcm(self, parent, parent_orient): + # If parent.oreint(self) is already defined,then + # update the _dcm_dict of parent while over write + # all content of self._dcm_dict and self._dcm_cache + # with new dcm relation. + # Else update _dcm_cache and _dcm_dict of both + # self and parent. + frames = self._dcm_cache.keys() + dcm_dict_del = [] + dcm_cache_del = [] + if parent in frames: + for frame in frames: + if frame in self._dcm_dict: + dcm_dict_del += [frame] + dcm_cache_del += [frame] + # Reset the _dcm_cache of this frame, and remove it from the + # _dcm_caches of the frames it is linked to. Also remove it from + # the _dcm_dict of its parent + for frame in dcm_dict_del: + del frame._dcm_dict[self] + for frame in dcm_cache_del: + del frame._dcm_cache[self] + # Reset the _dcm_dict + self._dcm_dict = self._dlist[0] = {} + # Reset the _dcm_cache + self._dcm_cache = {} + + else: + # Check for loops and raise warning accordingly. + visited = [] + queue = list(frames) + cont = True # Flag to control queue loop. + while queue and cont: + node = queue.pop(0) + if node not in visited: + visited.append(node) + neighbors = node._dcm_dict.keys() + for neighbor in neighbors: + if neighbor == parent: + warn('Loops are defined among the orientation of ' + 'frames. This is likely not desired and may ' + 'cause errors in your calculations.') + cont = False + break + queue.append(neighbor) + + # Add the dcm relationship to _dcm_dict + self._dcm_dict.update({parent: parent_orient.T}) + parent._dcm_dict.update({self: parent_orient}) + # Update the dcm cache + self._dcm_cache.update({parent: parent_orient.T}) + parent._dcm_cache.update({self: parent_orient}) + + def orient_axis(self, parent, axis, angle): + """Sets the orientation of this reference frame with respect to a + parent reference frame by rotating through an angle about an axis fixed + in the parent reference frame. + + Parameters + ========== + + parent : ReferenceFrame + Reference frame that this reference frame will be rotated relative + to. + axis : Vector + Vector fixed in the parent frame about about which this frame is + rotated. It need not be a unit vector and the rotation follows the + right hand rule. + angle : sympifiable + Angle in radians by which it the frame is to be rotated. + + Warns + ====== + + UserWarning + If the orientation creates a kinematic loop. + + Examples + ======== + + Setup variables for the examples: + + >>> from sympy import symbols + >>> from sympy.physics.vector import ReferenceFrame + >>> q1 = symbols('q1') + >>> N = ReferenceFrame('N') + >>> B = ReferenceFrame('B') + >>> B.orient_axis(N, N.x, q1) + + The ``orient_axis()`` method generates a direction cosine matrix and + its transpose which defines the orientation of B relative to N and vice + versa. Once orient is called, ``dcm()`` outputs the appropriate + direction cosine matrix: + + >>> B.dcm(N) + Matrix([ + [1, 0, 0], + [0, cos(q1), sin(q1)], + [0, -sin(q1), cos(q1)]]) + >>> N.dcm(B) + Matrix([ + [1, 0, 0], + [0, cos(q1), -sin(q1)], + [0, sin(q1), cos(q1)]]) + + The following two lines show that the sense of the rotation can be + defined by negating the vector direction or the angle. Both lines + produce the same result. + + >>> B.orient_axis(N, -N.x, q1) + >>> B.orient_axis(N, N.x, -q1) + + """ + + from sympy.physics.vector.functions import dynamicsymbols + _check_frame(parent) + + if not isinstance(axis, Vector) and isinstance(angle, Vector): + axis, angle = angle, axis + + axis = _check_vector(axis) + theta = sympify(angle) + + if not axis.dt(parent) == 0: + raise ValueError('Axis cannot be time-varying.') + unit_axis = axis.express(parent).normalize() + unit_col = unit_axis.args[0][0] + parent_orient_axis = ( + (eye(3) - unit_col * unit_col.T) * cos(theta) + + Matrix([[0, -unit_col[2], unit_col[1]], + [unit_col[2], 0, -unit_col[0]], + [-unit_col[1], unit_col[0], 0]]) * + sin(theta) + unit_col * unit_col.T) + + self._dcm(parent, parent_orient_axis) + + thetad = (theta).diff(dynamicsymbols._t) + wvec = thetad*axis.express(parent).normalize() + self._ang_vel_dict.update({parent: wvec}) + parent._ang_vel_dict.update({self: -wvec}) + self._var_dict = {} + + def orient_explicit(self, parent, dcm): + """Sets the orientation of this reference frame relative to another (parent) reference frame + using a direction cosine matrix that describes the rotation from the parent to the child. + + Parameters + ========== + + parent : ReferenceFrame + Reference frame that this reference frame will be rotated relative + to. + dcm : Matrix, shape(3, 3) + Direction cosine matrix that specifies the relative rotation + between the two reference frames. + + Warns + ====== + + UserWarning + If the orientation creates a kinematic loop. + + Examples + ======== + + Setup variables for the examples: + + >>> from sympy import symbols, Matrix, sin, cos + >>> from sympy.physics.vector import ReferenceFrame + >>> q1 = symbols('q1') + >>> A = ReferenceFrame('A') + >>> B = ReferenceFrame('B') + >>> N = ReferenceFrame('N') + + A simple rotation of ``A`` relative to ``N`` about ``N.x`` is defined + by the following direction cosine matrix: + + >>> dcm = Matrix([[1, 0, 0], + ... [0, cos(q1), -sin(q1)], + ... [0, sin(q1), cos(q1)]]) + >>> A.orient_explicit(N, dcm) + >>> A.dcm(N) + Matrix([ + [1, 0, 0], + [0, cos(q1), sin(q1)], + [0, -sin(q1), cos(q1)]]) + + This is equivalent to using ``orient_axis()``: + + >>> B.orient_axis(N, N.x, q1) + >>> B.dcm(N) + Matrix([ + [1, 0, 0], + [0, cos(q1), sin(q1)], + [0, -sin(q1), cos(q1)]]) + + **Note carefully that** ``N.dcm(B)`` **(the transpose) would be passed + into** ``orient_explicit()`` **for** ``A.dcm(N)`` **to match** + ``B.dcm(N)``: + + >>> A.orient_explicit(N, N.dcm(B)) + >>> A.dcm(N) + Matrix([ + [1, 0, 0], + [0, cos(q1), sin(q1)], + [0, -sin(q1), cos(q1)]]) + + """ + _check_frame(parent) + # amounts must be a Matrix type object + # (e.g. sympy.matrices.dense.MutableDenseMatrix). + if not isinstance(dcm, MatrixBase): + raise TypeError("Amounts must be a SymPy Matrix type object.") + + self.orient_dcm(parent, dcm.T) + + def orient_dcm(self, parent, dcm): + """Sets the orientation of this reference frame relative to another (parent) reference frame + using a direction cosine matrix that describes the rotation from the child to the parent. + + Parameters + ========== + + parent : ReferenceFrame + Reference frame that this reference frame will be rotated relative + to. + dcm : Matrix, shape(3, 3) + Direction cosine matrix that specifies the relative rotation + between the two reference frames. + + Warns + ====== + + UserWarning + If the orientation creates a kinematic loop. + + Examples + ======== + + Setup variables for the examples: + + >>> from sympy import symbols, Matrix, sin, cos + >>> from sympy.physics.vector import ReferenceFrame + >>> q1 = symbols('q1') + >>> A = ReferenceFrame('A') + >>> B = ReferenceFrame('B') + >>> N = ReferenceFrame('N') + + A simple rotation of ``A`` relative to ``N`` about ``N.x`` is defined + by the following direction cosine matrix: + + >>> dcm = Matrix([[1, 0, 0], + ... [0, cos(q1), sin(q1)], + ... [0, -sin(q1), cos(q1)]]) + >>> A.orient_dcm(N, dcm) + >>> A.dcm(N) + Matrix([ + [1, 0, 0], + [0, cos(q1), sin(q1)], + [0, -sin(q1), cos(q1)]]) + + This is equivalent to using ``orient_axis()``: + + >>> B.orient_axis(N, N.x, q1) + >>> B.dcm(N) + Matrix([ + [1, 0, 0], + [0, cos(q1), sin(q1)], + [0, -sin(q1), cos(q1)]]) + + """ + + _check_frame(parent) + # amounts must be a Matrix type object + # (e.g. sympy.matrices.dense.MutableDenseMatrix). + if not isinstance(dcm, MatrixBase): + raise TypeError("Amounts must be a SymPy Matrix type object.") + + self._dcm(parent, dcm.T) + + wvec = self._w_diff_dcm(parent) + self._ang_vel_dict.update({parent: wvec}) + parent._ang_vel_dict.update({self: -wvec}) + self._var_dict = {} + + def _rot(self, axis, angle): + """DCM for simple axis 1,2,or 3 rotations.""" + if axis == 1: + return Matrix([[1, 0, 0], + [0, cos(angle), -sin(angle)], + [0, sin(angle), cos(angle)]]) + elif axis == 2: + return Matrix([[cos(angle), 0, sin(angle)], + [0, 1, 0], + [-sin(angle), 0, cos(angle)]]) + elif axis == 3: + return Matrix([[cos(angle), -sin(angle), 0], + [sin(angle), cos(angle), 0], + [0, 0, 1]]) + + def _parse_consecutive_rotations(self, angles, rotation_order): + """Helper for orient_body_fixed and orient_space_fixed. + + Parameters + ========== + angles : 3-tuple of sympifiable + Three angles in radians used for the successive rotations. + rotation_order : 3 character string or 3 digit integer + Order of the rotations. The order can be specified by the strings + ``'XZX'``, ``'131'``, or the integer ``131``. There are 12 unique + valid rotation orders. + + Returns + ======= + + amounts : list + List of sympifiables corresponding to the rotation angles. + rot_order : list + List of integers corresponding to the axis of rotation. + rot_matrices : list + List of DCM around the given axis with corresponding magnitude. + + """ + amounts = list(angles) + for i, v in enumerate(amounts): + if not isinstance(v, Vector): + amounts[i] = sympify(v) + + approved_orders = ('123', '231', '312', '132', '213', '321', '121', + '131', '212', '232', '313', '323', '') + # make sure XYZ => 123 + rot_order = translate(str(rotation_order), 'XYZxyz', '123123') + if rot_order not in approved_orders: + raise TypeError('The rotation order is not a valid order.') + + rot_order = [int(r) for r in rot_order] + if not (len(amounts) == 3 & len(rot_order) == 3): + raise TypeError('Body orientation takes 3 values & 3 orders') + rot_matrices = [self._rot(order, amount) + for (order, amount) in zip(rot_order, amounts)] + return amounts, rot_order, rot_matrices + + def orient_body_fixed(self, parent, angles, rotation_order): + """Rotates this reference frame relative to the parent reference frame + by right hand rotating through three successive body fixed simple axis + rotations. Each subsequent axis of rotation is about the "body fixed" + unit vectors of a new intermediate reference frame. This type of + rotation is also referred to rotating through the `Euler and Tait-Bryan + Angles`_. + + .. _Euler and Tait-Bryan Angles: https://en.wikipedia.org/wiki/Euler_angles + + The computed angular velocity in this method is by default expressed in + the child's frame, so it is most preferable to use ``u1 * child.x + u2 * + child.y + u3 * child.z`` as generalized speeds. + + Parameters + ========== + + parent : ReferenceFrame + Reference frame that this reference frame will be rotated relative + to. + angles : 3-tuple of sympifiable + Three angles in radians used for the successive rotations. + rotation_order : 3 character string or 3 digit integer + Order of the rotations about each intermediate reference frames' + unit vectors. The Euler rotation about the X, Z', X'' axes can be + specified by the strings ``'XZX'``, ``'131'``, or the integer + ``131``. There are 12 unique valid rotation orders (6 Euler and 6 + Tait-Bryan): zxz, xyx, yzy, zyz, xzx, yxy, xyz, yzx, zxy, xzy, zyx, + and yxz. + + Warns + ====== + + UserWarning + If the orientation creates a kinematic loop. + + Examples + ======== + + Setup variables for the examples: + + >>> from sympy import symbols + >>> from sympy.physics.vector import ReferenceFrame + >>> q1, q2, q3 = symbols('q1, q2, q3') + >>> N = ReferenceFrame('N') + >>> B = ReferenceFrame('B') + >>> B1 = ReferenceFrame('B1') + >>> B2 = ReferenceFrame('B2') + >>> B3 = ReferenceFrame('B3') + + For example, a classic Euler Angle rotation can be done by: + + >>> B.orient_body_fixed(N, (q1, q2, q3), 'XYX') + >>> B.dcm(N) + Matrix([ + [ cos(q2), sin(q1)*sin(q2), -sin(q2)*cos(q1)], + [sin(q2)*sin(q3), -sin(q1)*sin(q3)*cos(q2) + cos(q1)*cos(q3), sin(q1)*cos(q3) + sin(q3)*cos(q1)*cos(q2)], + [sin(q2)*cos(q3), -sin(q1)*cos(q2)*cos(q3) - sin(q3)*cos(q1), -sin(q1)*sin(q3) + cos(q1)*cos(q2)*cos(q3)]]) + + This rotates reference frame B relative to reference frame N through + ``q1`` about ``N.x``, then rotates B again through ``q2`` about + ``B.y``, and finally through ``q3`` about ``B.x``. It is equivalent to + three successive ``orient_axis()`` calls: + + >>> B1.orient_axis(N, N.x, q1) + >>> B2.orient_axis(B1, B1.y, q2) + >>> B3.orient_axis(B2, B2.x, q3) + >>> B3.dcm(N) + Matrix([ + [ cos(q2), sin(q1)*sin(q2), -sin(q2)*cos(q1)], + [sin(q2)*sin(q3), -sin(q1)*sin(q3)*cos(q2) + cos(q1)*cos(q3), sin(q1)*cos(q3) + sin(q3)*cos(q1)*cos(q2)], + [sin(q2)*cos(q3), -sin(q1)*cos(q2)*cos(q3) - sin(q3)*cos(q1), -sin(q1)*sin(q3) + cos(q1)*cos(q2)*cos(q3)]]) + + Acceptable rotation orders are of length 3, expressed in as a string + ``'XYZ'`` or ``'123'`` or integer ``123``. Rotations about an axis + twice in a row are prohibited. + + >>> B.orient_body_fixed(N, (q1, q2, 0), 'ZXZ') + >>> B.orient_body_fixed(N, (q1, q2, 0), '121') + >>> B.orient_body_fixed(N, (q1, q2, q3), 123) + + """ + from sympy.physics.vector.functions import dynamicsymbols + + _check_frame(parent) + + amounts, rot_order, rot_matrices = self._parse_consecutive_rotations( + angles, rotation_order) + self._dcm(parent, rot_matrices[0] * rot_matrices[1] * rot_matrices[2]) + + rot_vecs = [zeros(3, 1) for _ in range(3)] + for i, order in enumerate(rot_order): + rot_vecs[i][order - 1] = amounts[i].diff(dynamicsymbols._t) + u1, u2, u3 = rot_vecs[2] + rot_matrices[2].T * ( + rot_vecs[1] + rot_matrices[1].T * rot_vecs[0]) + wvec = u1 * self.x + u2 * self.y + u3 * self.z # There is a double - + self._ang_vel_dict.update({parent: wvec}) + parent._ang_vel_dict.update({self: -wvec}) + self._var_dict = {} + + def orient_space_fixed(self, parent, angles, rotation_order): + """Rotates this reference frame relative to the parent reference frame + by right hand rotating through three successive space fixed simple axis + rotations. Each subsequent axis of rotation is about the "space fixed" + unit vectors of the parent reference frame. + + The computed angular velocity in this method is by default expressed in + the child's frame, so it is most preferable to use ``u1 * child.x + u2 * + child.y + u3 * child.z`` as generalized speeds. + + Parameters + ========== + parent : ReferenceFrame + Reference frame that this reference frame will be rotated relative + to. + angles : 3-tuple of sympifiable + Three angles in radians used for the successive rotations. + rotation_order : 3 character string or 3 digit integer + Order of the rotations about the parent reference frame's unit + vectors. The order can be specified by the strings ``'XZX'``, + ``'131'``, or the integer ``131``. There are 12 unique valid + rotation orders. + + Warns + ====== + + UserWarning + If the orientation creates a kinematic loop. + + Examples + ======== + + Setup variables for the examples: + + >>> from sympy import symbols + >>> from sympy.physics.vector import ReferenceFrame + >>> q1, q2, q3 = symbols('q1, q2, q3') + >>> N = ReferenceFrame('N') + >>> B = ReferenceFrame('B') + >>> B1 = ReferenceFrame('B1') + >>> B2 = ReferenceFrame('B2') + >>> B3 = ReferenceFrame('B3') + + >>> B.orient_space_fixed(N, (q1, q2, q3), '312') + >>> B.dcm(N) + Matrix([ + [ sin(q1)*sin(q2)*sin(q3) + cos(q1)*cos(q3), sin(q1)*cos(q2), sin(q1)*sin(q2)*cos(q3) - sin(q3)*cos(q1)], + [-sin(q1)*cos(q3) + sin(q2)*sin(q3)*cos(q1), cos(q1)*cos(q2), sin(q1)*sin(q3) + sin(q2)*cos(q1)*cos(q3)], + [ sin(q3)*cos(q2), -sin(q2), cos(q2)*cos(q3)]]) + + is equivalent to: + + >>> B1.orient_axis(N, N.z, q1) + >>> B2.orient_axis(B1, N.x, q2) + >>> B3.orient_axis(B2, N.y, q3) + >>> B3.dcm(N).simplify() + Matrix([ + [ sin(q1)*sin(q2)*sin(q3) + cos(q1)*cos(q3), sin(q1)*cos(q2), sin(q1)*sin(q2)*cos(q3) - sin(q3)*cos(q1)], + [-sin(q1)*cos(q3) + sin(q2)*sin(q3)*cos(q1), cos(q1)*cos(q2), sin(q1)*sin(q3) + sin(q2)*cos(q1)*cos(q3)], + [ sin(q3)*cos(q2), -sin(q2), cos(q2)*cos(q3)]]) + + It is worth noting that space-fixed and body-fixed rotations are + related by the order of the rotations, i.e. the reverse order of body + fixed will give space fixed and vice versa. + + >>> B.orient_space_fixed(N, (q1, q2, q3), '231') + >>> B.dcm(N) + Matrix([ + [cos(q1)*cos(q2), sin(q1)*sin(q3) + sin(q2)*cos(q1)*cos(q3), -sin(q1)*cos(q3) + sin(q2)*sin(q3)*cos(q1)], + [ -sin(q2), cos(q2)*cos(q3), sin(q3)*cos(q2)], + [sin(q1)*cos(q2), sin(q1)*sin(q2)*cos(q3) - sin(q3)*cos(q1), sin(q1)*sin(q2)*sin(q3) + cos(q1)*cos(q3)]]) + + >>> B.orient_body_fixed(N, (q3, q2, q1), '132') + >>> B.dcm(N) + Matrix([ + [cos(q1)*cos(q2), sin(q1)*sin(q3) + sin(q2)*cos(q1)*cos(q3), -sin(q1)*cos(q3) + sin(q2)*sin(q3)*cos(q1)], + [ -sin(q2), cos(q2)*cos(q3), sin(q3)*cos(q2)], + [sin(q1)*cos(q2), sin(q1)*sin(q2)*cos(q3) - sin(q3)*cos(q1), sin(q1)*sin(q2)*sin(q3) + cos(q1)*cos(q3)]]) + + """ + from sympy.physics.vector.functions import dynamicsymbols + + _check_frame(parent) + + amounts, rot_order, rot_matrices = self._parse_consecutive_rotations( + angles, rotation_order) + self._dcm(parent, rot_matrices[2] * rot_matrices[1] * rot_matrices[0]) + + rot_vecs = [zeros(3, 1) for _ in range(3)] + for i, order in enumerate(rot_order): + rot_vecs[i][order - 1] = amounts[i].diff(dynamicsymbols._t) + u1, u2, u3 = rot_vecs[0] + rot_matrices[0].T * ( + rot_vecs[1] + rot_matrices[1].T * rot_vecs[2]) + wvec = u1 * self.x + u2 * self.y + u3 * self.z # There is a double - + self._ang_vel_dict.update({parent: wvec}) + parent._ang_vel_dict.update({self: -wvec}) + self._var_dict = {} + + def orient_quaternion(self, parent, numbers): + """Sets the orientation of this reference frame relative to a parent + reference frame via an orientation quaternion. An orientation + quaternion is defined as a finite rotation a unit vector, ``(lambda_x, + lambda_y, lambda_z)``, by an angle ``theta``. The orientation + quaternion is described by four parameters: + + - ``q0 = cos(theta/2)`` + - ``q1 = lambda_x*sin(theta/2)`` + - ``q2 = lambda_y*sin(theta/2)`` + - ``q3 = lambda_z*sin(theta/2)`` + + See `Quaternions and Spatial Rotation + `_ on + Wikipedia for more information. + + Parameters + ========== + parent : ReferenceFrame + Reference frame that this reference frame will be rotated relative + to. + numbers : 4-tuple of sympifiable + The four quaternion scalar numbers as defined above: ``q0``, + ``q1``, ``q2``, ``q3``. + + Warns + ====== + + UserWarning + If the orientation creates a kinematic loop. + + Examples + ======== + + Setup variables for the examples: + + >>> from sympy import symbols + >>> from sympy.physics.vector import ReferenceFrame + >>> q0, q1, q2, q3 = symbols('q0 q1 q2 q3') + >>> N = ReferenceFrame('N') + >>> B = ReferenceFrame('B') + + Set the orientation: + + >>> B.orient_quaternion(N, (q0, q1, q2, q3)) + >>> B.dcm(N) + Matrix([ + [q0**2 + q1**2 - q2**2 - q3**2, 2*q0*q3 + 2*q1*q2, -2*q0*q2 + 2*q1*q3], + [ -2*q0*q3 + 2*q1*q2, q0**2 - q1**2 + q2**2 - q3**2, 2*q0*q1 + 2*q2*q3], + [ 2*q0*q2 + 2*q1*q3, -2*q0*q1 + 2*q2*q3, q0**2 - q1**2 - q2**2 + q3**2]]) + + """ + + from sympy.physics.vector.functions import dynamicsymbols + _check_frame(parent) + + numbers = list(numbers) + for i, v in enumerate(numbers): + if not isinstance(v, Vector): + numbers[i] = sympify(v) + + if not (isinstance(numbers, (list, tuple)) & (len(numbers) == 4)): + raise TypeError('Amounts are a list or tuple of length 4') + q0, q1, q2, q3 = numbers + parent_orient_quaternion = ( + Matrix([[q0**2 + q1**2 - q2**2 - q3**2, + 2 * (q1 * q2 - q0 * q3), + 2 * (q0 * q2 + q1 * q3)], + [2 * (q1 * q2 + q0 * q3), + q0**2 - q1**2 + q2**2 - q3**2, + 2 * (q2 * q3 - q0 * q1)], + [2 * (q1 * q3 - q0 * q2), + 2 * (q0 * q1 + q2 * q3), + q0**2 - q1**2 - q2**2 + q3**2]])) + + self._dcm(parent, parent_orient_quaternion) + + t = dynamicsymbols._t + q0, q1, q2, q3 = numbers + q0d = diff(q0, t) + q1d = diff(q1, t) + q2d = diff(q2, t) + q3d = diff(q3, t) + w1 = 2 * (q1d * q0 + q2d * q3 - q3d * q2 - q0d * q1) + w2 = 2 * (q2d * q0 + q3d * q1 - q1d * q3 - q0d * q2) + w3 = 2 * (q3d * q0 + q1d * q2 - q2d * q1 - q0d * q3) + wvec = Vector([(Matrix([w1, w2, w3]), self)]) + + self._ang_vel_dict.update({parent: wvec}) + parent._ang_vel_dict.update({self: -wvec}) + self._var_dict = {} + + def orient(self, parent, rot_type, amounts, rot_order=''): + """Sets the orientation of this reference frame relative to another + (parent) reference frame. + + .. note:: It is now recommended to use the ``.orient_axis, + .orient_body_fixed, .orient_space_fixed, .orient_quaternion`` + methods for the different rotation types. + + Parameters + ========== + + parent : ReferenceFrame + Reference frame that this reference frame will be rotated relative + to. + rot_type : str + The method used to generate the direction cosine matrix. Supported + methods are: + + - ``'Axis'``: simple rotations about a single common axis + - ``'DCM'``: for setting the direction cosine matrix directly + - ``'Body'``: three successive rotations about new intermediate + axes, also called "Euler and Tait-Bryan angles" + - ``'Space'``: three successive rotations about the parent + frames' unit vectors + - ``'Quaternion'``: rotations defined by four parameters which + result in a singularity free direction cosine matrix + + amounts : + Expressions defining the rotation angles or direction cosine + matrix. These must match the ``rot_type``. See examples below for + details. The input types are: + + - ``'Axis'``: 2-tuple (expr/sym/func, Vector) + - ``'DCM'``: Matrix, shape(3,3) + - ``'Body'``: 3-tuple of expressions, symbols, or functions + - ``'Space'``: 3-tuple of expressions, symbols, or functions + - ``'Quaternion'``: 4-tuple of expressions, symbols, or + functions + + rot_order : str or int, optional + If applicable, the order of the successive of rotations. The string + ``'123'`` and integer ``123`` are equivalent, for example. Required + for ``'Body'`` and ``'Space'``. + + Warns + ====== + + UserWarning + If the orientation creates a kinematic loop. + + """ + + _check_frame(parent) + + approved_orders = ('123', '231', '312', '132', '213', '321', '121', + '131', '212', '232', '313', '323', '') + rot_order = translate(str(rot_order), 'XYZxyz', '123123') + rot_type = rot_type.upper() + + if rot_order not in approved_orders: + raise TypeError('The supplied order is not an approved type') + + if rot_type == 'AXIS': + self.orient_axis(parent, amounts[1], amounts[0]) + + elif rot_type == 'DCM': + self.orient_explicit(parent, amounts) + + elif rot_type == 'BODY': + self.orient_body_fixed(parent, amounts, rot_order) + + elif rot_type == 'SPACE': + self.orient_space_fixed(parent, amounts, rot_order) + + elif rot_type == 'QUATERNION': + self.orient_quaternion(parent, amounts) + + else: + raise NotImplementedError('That is not an implemented rotation') + + def orientnew(self, newname, rot_type, amounts, rot_order='', + variables=None, indices=None, latexs=None): + r"""Returns a new reference frame oriented with respect to this + reference frame. + + See ``ReferenceFrame.orient()`` for detailed examples of how to orient + reference frames. + + Parameters + ========== + + newname : str + Name for the new reference frame. + rot_type : str + The method used to generate the direction cosine matrix. Supported + methods are: + + - ``'Axis'``: simple rotations about a single common axis + - ``'DCM'``: for setting the direction cosine matrix directly + - ``'Body'``: three successive rotations about new intermediate + axes, also called "Euler and Tait-Bryan angles" + - ``'Space'``: three successive rotations about the parent + frames' unit vectors + - ``'Quaternion'``: rotations defined by four parameters which + result in a singularity free direction cosine matrix + + amounts : + Expressions defining the rotation angles or direction cosine + matrix. These must match the ``rot_type``. See examples below for + details. The input types are: + + - ``'Axis'``: 2-tuple (expr/sym/func, Vector) + - ``'DCM'``: Matrix, shape(3,3) + - ``'Body'``: 3-tuple of expressions, symbols, or functions + - ``'Space'``: 3-tuple of expressions, symbols, or functions + - ``'Quaternion'``: 4-tuple of expressions, symbols, or + functions + + rot_order : str or int, optional + If applicable, the order of the successive of rotations. The string + ``'123'`` and integer ``123`` are equivalent, for example. Required + for ``'Body'`` and ``'Space'``. + indices : tuple of str + Enables the reference frame's basis unit vectors to be accessed by + Python's square bracket indexing notation using the provided three + indice strings and alters the printing of the unit vectors to + reflect this choice. + latexs : tuple of str + Alters the LaTeX printing of the reference frame's basis unit + vectors to the provided three valid LaTeX strings. + + Examples + ======== + + >>> from sympy import symbols + >>> from sympy.physics.vector import ReferenceFrame, vlatex + >>> q0, q1, q2, q3 = symbols('q0 q1 q2 q3') + >>> N = ReferenceFrame('N') + + Create a new reference frame A rotated relative to N through a simple + rotation. + + >>> A = N.orientnew('A', 'Axis', (q0, N.x)) + + Create a new reference frame B rotated relative to N through body-fixed + rotations. + + >>> B = N.orientnew('B', 'Body', (q1, q2, q3), '123') + + Create a new reference frame C rotated relative to N through a simple + rotation with unique indices and LaTeX printing. + + >>> C = N.orientnew('C', 'Axis', (q0, N.x), indices=('1', '2', '3'), + ... latexs=(r'\hat{\mathbf{c}}_1',r'\hat{\mathbf{c}}_2', + ... r'\hat{\mathbf{c}}_3')) + >>> C['1'] + C['1'] + >>> print(vlatex(C['1'])) + \hat{\mathbf{c}}_1 + + """ + + newframe = self.__class__(newname, variables=variables, + indices=indices, latexs=latexs) + + approved_orders = ('123', '231', '312', '132', '213', '321', '121', + '131', '212', '232', '313', '323', '') + rot_order = translate(str(rot_order), 'XYZxyz', '123123') + rot_type = rot_type.upper() + + if rot_order not in approved_orders: + raise TypeError('The supplied order is not an approved type') + + if rot_type == 'AXIS': + newframe.orient_axis(self, amounts[1], amounts[0]) + + elif rot_type == 'DCM': + newframe.orient_explicit(self, amounts) + + elif rot_type == 'BODY': + newframe.orient_body_fixed(self, amounts, rot_order) + + elif rot_type == 'SPACE': + newframe.orient_space_fixed(self, amounts, rot_order) + + elif rot_type == 'QUATERNION': + newframe.orient_quaternion(self, amounts) + + else: + raise NotImplementedError('That is not an implemented rotation') + return newframe + + def set_ang_acc(self, otherframe, value): + """Define the angular acceleration Vector in a ReferenceFrame. + + Defines the angular acceleration of this ReferenceFrame, in another. + Angular acceleration can be defined with respect to multiple different + ReferenceFrames. Care must be taken to not create loops which are + inconsistent. + + Parameters + ========== + + otherframe : ReferenceFrame + A ReferenceFrame to define the angular acceleration in + value : Vector + The Vector representing angular acceleration + + Examples + ======== + + >>> from sympy.physics.vector import ReferenceFrame + >>> N = ReferenceFrame('N') + >>> A = ReferenceFrame('A') + >>> V = 10 * N.x + >>> A.set_ang_acc(N, V) + >>> A.ang_acc_in(N) + 10*N.x + + """ + + if value == 0: + value = Vector(0) + value = _check_vector(value) + _check_frame(otherframe) + self._ang_acc_dict.update({otherframe: value}) + otherframe._ang_acc_dict.update({self: -value}) + + def set_ang_vel(self, otherframe, value): + """Define the angular velocity vector in a ReferenceFrame. + + Defines the angular velocity of this ReferenceFrame, in another. + Angular velocity can be defined with respect to multiple different + ReferenceFrames. Care must be taken to not create loops which are + inconsistent. + + Parameters + ========== + + otherframe : ReferenceFrame + A ReferenceFrame to define the angular velocity in + value : Vector + The Vector representing angular velocity + + Examples + ======== + + >>> from sympy.physics.vector import ReferenceFrame + >>> N = ReferenceFrame('N') + >>> A = ReferenceFrame('A') + >>> V = 10 * N.x + >>> A.set_ang_vel(N, V) + >>> A.ang_vel_in(N) + 10*N.x + + """ + + if value == 0: + value = Vector(0) + value = _check_vector(value) + _check_frame(otherframe) + self._ang_vel_dict.update({otherframe: value}) + otherframe._ang_vel_dict.update({self: -value}) + + @property + def x(self): + """The basis Vector for the ReferenceFrame, in the x direction. """ + return self._x + + @property + def y(self): + """The basis Vector for the ReferenceFrame, in the y direction. """ + return self._y + + @property + def z(self): + """The basis Vector for the ReferenceFrame, in the z direction. """ + return self._z + + @property + def xx(self): + """Unit dyad of basis Vectors x and x for the ReferenceFrame.""" + return Vector.outer(self.x, self.x) + + @property + def xy(self): + """Unit dyad of basis Vectors x and y for the ReferenceFrame.""" + return Vector.outer(self.x, self.y) + + @property + def xz(self): + """Unit dyad of basis Vectors x and z for the ReferenceFrame.""" + return Vector.outer(self.x, self.z) + + @property + def yx(self): + """Unit dyad of basis Vectors y and x for the ReferenceFrame.""" + return Vector.outer(self.y, self.x) + + @property + def yy(self): + """Unit dyad of basis Vectors y and y for the ReferenceFrame.""" + return Vector.outer(self.y, self.y) + + @property + def yz(self): + """Unit dyad of basis Vectors y and z for the ReferenceFrame.""" + return Vector.outer(self.y, self.z) + + @property + def zx(self): + """Unit dyad of basis Vectors z and x for the ReferenceFrame.""" + return Vector.outer(self.z, self.x) + + @property + def zy(self): + """Unit dyad of basis Vectors z and y for the ReferenceFrame.""" + return Vector.outer(self.z, self.y) + + @property + def zz(self): + """Unit dyad of basis Vectors z and z for the ReferenceFrame.""" + return Vector.outer(self.z, self.z) + + @property + def u(self): + """Unit dyadic for the ReferenceFrame.""" + return self.xx + self.yy + self.zz + + def partial_velocity(self, frame, *gen_speeds): + """Returns the partial angular velocities of this frame in the given + frame with respect to one or more provided generalized speeds. + + Parameters + ========== + frame : ReferenceFrame + The frame with which the angular velocity is defined in. + gen_speeds : functions of time + The generalized speeds. + + Returns + ======= + partial_velocities : tuple of Vector + The partial angular velocity vectors corresponding to the provided + generalized speeds. + + Examples + ======== + + >>> from sympy.physics.vector import ReferenceFrame, dynamicsymbols + >>> N = ReferenceFrame('N') + >>> A = ReferenceFrame('A') + >>> u1, u2 = dynamicsymbols('u1, u2') + >>> A.set_ang_vel(N, u1 * A.x + u2 * N.y) + >>> A.partial_velocity(N, u1) + A.x + >>> A.partial_velocity(N, u1, u2) + (A.x, N.y) + + """ + + from sympy.physics.vector.functions import partial_velocity + + vel = self.ang_vel_in(frame) + partials = partial_velocity([vel], gen_speeds, frame)[0] + + if len(partials) == 1: + return partials[0] + else: + return tuple(partials) + + +def _check_frame(other): + from .vector import VectorTypeError + if not isinstance(other, ReferenceFrame): + raise VectorTypeError(other, ReferenceFrame('A')) diff --git a/lib/python3.10/site-packages/sympy/physics/vector/functions.py b/lib/python3.10/site-packages/sympy/physics/vector/functions.py new file mode 100644 index 0000000000000000000000000000000000000000..6775b4b23bb376992d6a9e7651ba73a951c84287 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/physics/vector/functions.py @@ -0,0 +1,650 @@ +from functools import reduce + +from sympy import (sympify, diff, sin, cos, Matrix, symbols, + Function, S, Symbol, linear_eq_to_matrix) +from sympy.integrals.integrals import integrate +from sympy.simplify.trigsimp import trigsimp +from .vector import Vector, _check_vector +from .frame import CoordinateSym, _check_frame +from .dyadic import Dyadic +from .printing import vprint, vsprint, vpprint, vlatex, init_vprinting +from sympy.utilities.iterables import iterable +from sympy.utilities.misc import translate + +__all__ = ['cross', 'dot', 'express', 'time_derivative', 'outer', + 'kinematic_equations', 'get_motion_params', 'partial_velocity', + 'dynamicsymbols', 'vprint', 'vsprint', 'vpprint', 'vlatex', + 'init_vprinting'] + + +def cross(vec1, vec2): + """Cross product convenience wrapper for Vector.cross(): \n""" + if not isinstance(vec1, (Vector, Dyadic)): + raise TypeError('Cross product is between two vectors') + return vec1 ^ vec2 + + +cross.__doc__ += Vector.cross.__doc__ # type: ignore + + +def dot(vec1, vec2): + """Dot product convenience wrapper for Vector.dot(): \n""" + if not isinstance(vec1, (Vector, Dyadic)): + raise TypeError('Dot product is between two vectors') + return vec1 & vec2 + + +dot.__doc__ += Vector.dot.__doc__ # type: ignore + + +def express(expr, frame, frame2=None, variables=False): + """ + Global function for 'express' functionality. + + Re-expresses a Vector, scalar(sympyfiable) or Dyadic in given frame. + + Refer to the local methods of Vector and Dyadic for details. + If 'variables' is True, then the coordinate variables (CoordinateSym + instances) of other frames present in the vector/scalar field or + dyadic expression are also substituted in terms of the base scalars of + this frame. + + Parameters + ========== + + expr : Vector/Dyadic/scalar(sympyfiable) + The expression to re-express in ReferenceFrame 'frame' + + frame: ReferenceFrame + The reference frame to express expr in + + frame2 : ReferenceFrame + The other frame required for re-expression(only for Dyadic expr) + + variables : boolean + Specifies whether to substitute the coordinate variables present + in expr, in terms of those of frame + + Examples + ======== + + >>> from sympy.physics.vector import ReferenceFrame, outer, dynamicsymbols + >>> from sympy.physics.vector import init_vprinting + >>> init_vprinting(pretty_print=False) + >>> N = ReferenceFrame('N') + >>> q = dynamicsymbols('q') + >>> B = N.orientnew('B', 'Axis', [q, N.z]) + >>> d = outer(N.x, N.x) + >>> from sympy.physics.vector import express + >>> express(d, B, N) + cos(q)*(B.x|N.x) - sin(q)*(B.y|N.x) + >>> express(B.x, N) + cos(q)*N.x + sin(q)*N.y + >>> express(N[0], B, variables=True) + B_x*cos(q) - B_y*sin(q) + + """ + + _check_frame(frame) + + if expr == 0: + return expr + + if isinstance(expr, Vector): + # Given expr is a Vector + if variables: + # If variables attribute is True, substitute the coordinate + # variables in the Vector + frame_list = [x[-1] for x in expr.args] + subs_dict = {} + for f in frame_list: + subs_dict.update(f.variable_map(frame)) + expr = expr.subs(subs_dict) + # Re-express in this frame + outvec = Vector([]) + for v in expr.args: + if v[1] != frame: + temp = frame.dcm(v[1]) * v[0] + if Vector.simp: + temp = temp.applyfunc(lambda x: + trigsimp(x, method='fu')) + outvec += Vector([(temp, frame)]) + else: + outvec += Vector([v]) + return outvec + + if isinstance(expr, Dyadic): + if frame2 is None: + frame2 = frame + _check_frame(frame2) + ol = Dyadic(0) + for v in expr.args: + ol += express(v[0], frame, variables=variables) * \ + (express(v[1], frame, variables=variables) | + express(v[2], frame2, variables=variables)) + return ol + + else: + if variables: + # Given expr is a scalar field + frame_set = set() + expr = sympify(expr) + # Substitute all the coordinate variables + for x in expr.free_symbols: + if isinstance(x, CoordinateSym) and x.frame != frame: + frame_set.add(x.frame) + subs_dict = {} + for f in frame_set: + subs_dict.update(f.variable_map(frame)) + return expr.subs(subs_dict) + return expr + + +def time_derivative(expr, frame, order=1): + """ + Calculate the time derivative of a vector/scalar field function + or dyadic expression in given frame. + + References + ========== + + https://en.wikipedia.org/wiki/Rotating_reference_frame#Time_derivatives_in_the_two_frames + + Parameters + ========== + + expr : Vector/Dyadic/sympifyable + The expression whose time derivative is to be calculated + + frame : ReferenceFrame + The reference frame to calculate the time derivative in + + order : integer + The order of the derivative to be calculated + + Examples + ======== + + >>> from sympy.physics.vector import ReferenceFrame, dynamicsymbols + >>> from sympy.physics.vector import init_vprinting + >>> init_vprinting(pretty_print=False) + >>> from sympy import Symbol + >>> q1 = Symbol('q1') + >>> u1 = dynamicsymbols('u1') + >>> N = ReferenceFrame('N') + >>> A = N.orientnew('A', 'Axis', [q1, N.x]) + >>> v = u1 * N.x + >>> A.set_ang_vel(N, 10*A.x) + >>> from sympy.physics.vector import time_derivative + >>> time_derivative(v, N) + u1'*N.x + >>> time_derivative(u1*A[0], N) + N_x*u1' + >>> B = N.orientnew('B', 'Axis', [u1, N.z]) + >>> from sympy.physics.vector import outer + >>> d = outer(N.x, N.x) + >>> time_derivative(d, B) + - u1'*(N.y|N.x) - u1'*(N.x|N.y) + + """ + + t = dynamicsymbols._t + _check_frame(frame) + + if order == 0: + return expr + if order % 1 != 0 or order < 0: + raise ValueError("Unsupported value of order entered") + + if isinstance(expr, Vector): + outlist = [] + for v in expr.args: + if v[1] == frame: + outlist += [(express(v[0], frame, variables=True).diff(t), + frame)] + else: + outlist += (time_derivative(Vector([v]), v[1]) + + (v[1].ang_vel_in(frame) ^ Vector([v]))).args + outvec = Vector(outlist) + return time_derivative(outvec, frame, order - 1) + + if isinstance(expr, Dyadic): + ol = Dyadic(0) + for v in expr.args: + ol += (v[0].diff(t) * (v[1] | v[2])) + ol += (v[0] * (time_derivative(v[1], frame) | v[2])) + ol += (v[0] * (v[1] | time_derivative(v[2], frame))) + return time_derivative(ol, frame, order - 1) + + else: + return diff(express(expr, frame, variables=True), t, order) + + +def outer(vec1, vec2): + """Outer product convenience wrapper for Vector.outer():\n""" + if not isinstance(vec1, Vector): + raise TypeError('Outer product is between two Vectors') + return vec1.outer(vec2) + + +outer.__doc__ += Vector.outer.__doc__ # type: ignore + + +def kinematic_equations(speeds, coords, rot_type, rot_order=''): + """Gives equations relating the qdot's to u's for a rotation type. + + Supply rotation type and order as in orient. Speeds are assumed to be + body-fixed; if we are defining the orientation of B in A using by rot_type, + the angular velocity of B in A is assumed to be in the form: speed[0]*B.x + + speed[1]*B.y + speed[2]*B.z + + Parameters + ========== + + speeds : list of length 3 + The body fixed angular velocity measure numbers. + coords : list of length 3 or 4 + The coordinates used to define the orientation of the two frames. + rot_type : str + The type of rotation used to create the equations. Body, Space, or + Quaternion only + rot_order : str or int + If applicable, the order of a series of rotations. + + Examples + ======== + + >>> from sympy.physics.vector import dynamicsymbols + >>> from sympy.physics.vector import kinematic_equations, vprint + >>> u1, u2, u3 = dynamicsymbols('u1 u2 u3') + >>> q1, q2, q3 = dynamicsymbols('q1 q2 q3') + >>> vprint(kinematic_equations([u1,u2,u3], [q1,q2,q3], 'body', '313'), + ... order=None) + [-(u1*sin(q3) + u2*cos(q3))/sin(q2) + q1', -u1*cos(q3) + u2*sin(q3) + q2', (u1*sin(q3) + u2*cos(q3))*cos(q2)/sin(q2) - u3 + q3'] + + """ + + # Code below is checking and sanitizing input + approved_orders = ('123', '231', '312', '132', '213', '321', '121', '131', + '212', '232', '313', '323', '1', '2', '3', '') + # make sure XYZ => 123 and rot_type is in lower case + rot_order = translate(str(rot_order), 'XYZxyz', '123123') + rot_type = rot_type.lower() + + if not isinstance(speeds, (list, tuple)): + raise TypeError('Need to supply speeds in a list') + if len(speeds) != 3: + raise TypeError('Need to supply 3 body-fixed speeds') + if not isinstance(coords, (list, tuple)): + raise TypeError('Need to supply coordinates in a list') + if rot_type in ['body', 'space']: + if rot_order not in approved_orders: + raise ValueError('Not an acceptable rotation order') + if len(coords) != 3: + raise ValueError('Need 3 coordinates for body or space') + # Actual hard-coded kinematic differential equations + w1, w2, w3 = speeds + if w1 == w2 == w3 == 0: + return [S.Zero]*3 + q1, q2, q3 = coords + q1d, q2d, q3d = [diff(i, dynamicsymbols._t) for i in coords] + s1, s2, s3 = [sin(q1), sin(q2), sin(q3)] + c1, c2, c3 = [cos(q1), cos(q2), cos(q3)] + if rot_type == 'body': + if rot_order == '123': + return [q1d - (w1 * c3 - w2 * s3) / c2, q2d - w1 * s3 - w2 * + c3, q3d - (-w1 * c3 + w2 * s3) * s2 / c2 - w3] + if rot_order == '231': + return [q1d - (w2 * c3 - w3 * s3) / c2, q2d - w2 * s3 - w3 * + c3, q3d - w1 - (- w2 * c3 + w3 * s3) * s2 / c2] + if rot_order == '312': + return [q1d - (-w1 * s3 + w3 * c3) / c2, q2d - w1 * c3 - w3 * + s3, q3d - (w1 * s3 - w3 * c3) * s2 / c2 - w2] + if rot_order == '132': + return [q1d - (w1 * c3 + w3 * s3) / c2, q2d + w1 * s3 - w3 * + c3, q3d - (w1 * c3 + w3 * s3) * s2 / c2 - w2] + if rot_order == '213': + return [q1d - (w1 * s3 + w2 * c3) / c2, q2d - w1 * c3 + w2 * + s3, q3d - (w1 * s3 + w2 * c3) * s2 / c2 - w3] + if rot_order == '321': + return [q1d - (w2 * s3 + w3 * c3) / c2, q2d - w2 * c3 + w3 * + s3, q3d - w1 - (w2 * s3 + w3 * c3) * s2 / c2] + if rot_order == '121': + return [q1d - (w2 * s3 + w3 * c3) / s2, q2d - w2 * c3 + w3 * + s3, q3d - w1 + (w2 * s3 + w3 * c3) * c2 / s2] + if rot_order == '131': + return [q1d - (-w2 * c3 + w3 * s3) / s2, q2d - w2 * s3 - w3 * + c3, q3d - w1 - (w2 * c3 - w3 * s3) * c2 / s2] + if rot_order == '212': + return [q1d - (w1 * s3 - w3 * c3) / s2, q2d - w1 * c3 - w3 * + s3, q3d - (-w1 * s3 + w3 * c3) * c2 / s2 - w2] + if rot_order == '232': + return [q1d - (w1 * c3 + w3 * s3) / s2, q2d + w1 * s3 - w3 * + c3, q3d + (w1 * c3 + w3 * s3) * c2 / s2 - w2] + if rot_order == '313': + return [q1d - (w1 * s3 + w2 * c3) / s2, q2d - w1 * c3 + w2 * + s3, q3d + (w1 * s3 + w2 * c3) * c2 / s2 - w3] + if rot_order == '323': + return [q1d - (-w1 * c3 + w2 * s3) / s2, q2d - w1 * s3 - w2 * + c3, q3d - (w1 * c3 - w2 * s3) * c2 / s2 - w3] + if rot_type == 'space': + if rot_order == '123': + return [q1d - w1 - (w2 * s1 + w3 * c1) * s2 / c2, q2d - w2 * + c1 + w3 * s1, q3d - (w2 * s1 + w3 * c1) / c2] + if rot_order == '231': + return [q1d - (w1 * c1 + w3 * s1) * s2 / c2 - w2, q2d + w1 * + s1 - w3 * c1, q3d - (w1 * c1 + w3 * s1) / c2] + if rot_order == '312': + return [q1d - (w1 * s1 + w2 * c1) * s2 / c2 - w3, q2d - w1 * + c1 + w2 * s1, q3d - (w1 * s1 + w2 * c1) / c2] + if rot_order == '132': + return [q1d - w1 - (-w2 * c1 + w3 * s1) * s2 / c2, q2d - w2 * + s1 - w3 * c1, q3d - (w2 * c1 - w3 * s1) / c2] + if rot_order == '213': + return [q1d - (w1 * s1 - w3 * c1) * s2 / c2 - w2, q2d - w1 * + c1 - w3 * s1, q3d - (-w1 * s1 + w3 * c1) / c2] + if rot_order == '321': + return [q1d - (-w1 * c1 + w2 * s1) * s2 / c2 - w3, q2d - w1 * + s1 - w2 * c1, q3d - (w1 * c1 - w2 * s1) / c2] + if rot_order == '121': + return [q1d - w1 + (w2 * s1 + w3 * c1) * c2 / s2, q2d - w2 * + c1 + w3 * s1, q3d - (w2 * s1 + w3 * c1) / s2] + if rot_order == '131': + return [q1d - w1 - (w2 * c1 - w3 * s1) * c2 / s2, q2d - w2 * + s1 - w3 * c1, q3d - (-w2 * c1 + w3 * s1) / s2] + if rot_order == '212': + return [q1d - (-w1 * s1 + w3 * c1) * c2 / s2 - w2, q2d - w1 * + c1 - w3 * s1, q3d - (w1 * s1 - w3 * c1) / s2] + if rot_order == '232': + return [q1d + (w1 * c1 + w3 * s1) * c2 / s2 - w2, q2d + w1 * + s1 - w3 * c1, q3d - (w1 * c1 + w3 * s1) / s2] + if rot_order == '313': + return [q1d + (w1 * s1 + w2 * c1) * c2 / s2 - w3, q2d - w1 * + c1 + w2 * s1, q3d - (w1 * s1 + w2 * c1) / s2] + if rot_order == '323': + return [q1d - (w1 * c1 - w2 * s1) * c2 / s2 - w3, q2d - w1 * + s1 - w2 * c1, q3d - (-w1 * c1 + w2 * s1) / s2] + elif rot_type == 'quaternion': + if rot_order != '': + raise ValueError('Cannot have rotation order for quaternion') + if len(coords) != 4: + raise ValueError('Need 4 coordinates for quaternion') + # Actual hard-coded kinematic differential equations + e0, e1, e2, e3 = coords + w = Matrix(speeds + [0]) + E = Matrix([[e0, -e3, e2, e1], + [e3, e0, -e1, e2], + [-e2, e1, e0, e3], + [-e1, -e2, -e3, e0]]) + edots = Matrix([diff(i, dynamicsymbols._t) for i in [e1, e2, e3, e0]]) + return list(edots.T - 0.5 * w.T * E.T) + else: + raise ValueError('Not an approved rotation type for this function') + + +def get_motion_params(frame, **kwargs): + """ + Returns the three motion parameters - (acceleration, velocity, and + position) as vectorial functions of time in the given frame. + + If a higher order differential function is provided, the lower order + functions are used as boundary conditions. For example, given the + acceleration, the velocity and position parameters are taken as + boundary conditions. + + The values of time at which the boundary conditions are specified + are taken from timevalue1(for position boundary condition) and + timevalue2(for velocity boundary condition). + + If any of the boundary conditions are not provided, they are taken + to be zero by default (zero vectors, in case of vectorial inputs). If + the boundary conditions are also functions of time, they are converted + to constants by substituting the time values in the dynamicsymbols._t + time Symbol. + + This function can also be used for calculating rotational motion + parameters. Have a look at the Parameters and Examples for more clarity. + + Parameters + ========== + + frame : ReferenceFrame + The frame to express the motion parameters in + + acceleration : Vector + Acceleration of the object/frame as a function of time + + velocity : Vector + Velocity as function of time or as boundary condition + of velocity at time = timevalue1 + + position : Vector + Velocity as function of time or as boundary condition + of velocity at time = timevalue1 + + timevalue1 : sympyfiable + Value of time for position boundary condition + + timevalue2 : sympyfiable + Value of time for velocity boundary condition + + Examples + ======== + + >>> from sympy.physics.vector import ReferenceFrame, get_motion_params, dynamicsymbols + >>> from sympy.physics.vector import init_vprinting + >>> init_vprinting(pretty_print=False) + >>> from sympy import symbols + >>> R = ReferenceFrame('R') + >>> v1, v2, v3 = dynamicsymbols('v1 v2 v3') + >>> v = v1*R.x + v2*R.y + v3*R.z + >>> get_motion_params(R, position = v) + (v1''*R.x + v2''*R.y + v3''*R.z, v1'*R.x + v2'*R.y + v3'*R.z, v1*R.x + v2*R.y + v3*R.z) + >>> a, b, c = symbols('a b c') + >>> v = a*R.x + b*R.y + c*R.z + >>> get_motion_params(R, velocity = v) + (0, a*R.x + b*R.y + c*R.z, a*t*R.x + b*t*R.y + c*t*R.z) + >>> parameters = get_motion_params(R, acceleration = v) + >>> parameters[1] + a*t*R.x + b*t*R.y + c*t*R.z + >>> parameters[2] + a*t**2/2*R.x + b*t**2/2*R.y + c*t**2/2*R.z + + """ + + def _process_vector_differential(vectdiff, condition, variable, ordinate, + frame): + """ + Helper function for get_motion methods. Finds derivative of vectdiff + wrt variable, and its integral using the specified boundary condition + at value of variable = ordinate. + Returns a tuple of - (derivative, function and integral) wrt vectdiff + + """ + + # Make sure boundary condition is independent of 'variable' + if condition != 0: + condition = express(condition, frame, variables=True) + # Special case of vectdiff == 0 + if vectdiff == Vector(0): + return (0, 0, condition) + # Express vectdiff completely in condition's frame to give vectdiff1 + vectdiff1 = express(vectdiff, frame) + # Find derivative of vectdiff + vectdiff2 = time_derivative(vectdiff, frame) + # Integrate and use boundary condition + vectdiff0 = Vector(0) + lims = (variable, ordinate, variable) + for dim in frame: + function1 = vectdiff1.dot(dim) + abscissa = dim.dot(condition).subs({variable: ordinate}) + # Indefinite integral of 'function1' wrt 'variable', using + # the given initial condition (ordinate, abscissa). + vectdiff0 += (integrate(function1, lims) + abscissa) * dim + # Return tuple + return (vectdiff2, vectdiff, vectdiff0) + + _check_frame(frame) + # Decide mode of operation based on user's input + if 'acceleration' in kwargs: + mode = 2 + elif 'velocity' in kwargs: + mode = 1 + else: + mode = 0 + # All the possible parameters in kwargs + # Not all are required for every case + # If not specified, set to default values(may or may not be used in + # calculations) + conditions = ['acceleration', 'velocity', 'position', + 'timevalue', 'timevalue1', 'timevalue2'] + for i, x in enumerate(conditions): + if x not in kwargs: + if i < 3: + kwargs[x] = Vector(0) + else: + kwargs[x] = S.Zero + elif i < 3: + _check_vector(kwargs[x]) + else: + kwargs[x] = sympify(kwargs[x]) + if mode == 2: + vel = _process_vector_differential(kwargs['acceleration'], + kwargs['velocity'], + dynamicsymbols._t, + kwargs['timevalue2'], frame)[2] + pos = _process_vector_differential(vel, kwargs['position'], + dynamicsymbols._t, + kwargs['timevalue1'], frame)[2] + return (kwargs['acceleration'], vel, pos) + elif mode == 1: + return _process_vector_differential(kwargs['velocity'], + kwargs['position'], + dynamicsymbols._t, + kwargs['timevalue1'], frame) + else: + vel = time_derivative(kwargs['position'], frame) + acc = time_derivative(vel, frame) + return (acc, vel, kwargs['position']) + + +def partial_velocity(vel_vecs, gen_speeds, frame): + """Returns a list of partial velocities with respect to the provided + generalized speeds in the given reference frame for each of the supplied + velocity vectors. + + The output is a list of lists. The outer list has a number of elements + equal to the number of supplied velocity vectors. The inner lists are, for + each velocity vector, the partial derivatives of that velocity vector with + respect to the generalized speeds supplied. + + Parameters + ========== + + vel_vecs : iterable + An iterable of velocity vectors (angular or linear). + gen_speeds : iterable + An iterable of generalized speeds. + frame : ReferenceFrame + The reference frame that the partial derivatives are going to be taken + in. + + Examples + ======== + + >>> from sympy.physics.vector import Point, ReferenceFrame + >>> from sympy.physics.vector import dynamicsymbols + >>> from sympy.physics.vector import partial_velocity + >>> u = dynamicsymbols('u') + >>> N = ReferenceFrame('N') + >>> P = Point('P') + >>> P.set_vel(N, u * N.x) + >>> vel_vecs = [P.vel(N)] + >>> gen_speeds = [u] + >>> partial_velocity(vel_vecs, gen_speeds, N) + [[N.x]] + + """ + + if not iterable(vel_vecs): + raise TypeError('Velocity vectors must be contained in an iterable.') + + if not iterable(gen_speeds): + raise TypeError('Generalized speeds must be contained in an iterable') + + vec_partials = [] + gen_speeds = list(gen_speeds) + for vel in vel_vecs: + partials = [Vector(0) for _ in gen_speeds] + for components, ref in vel.args: + mat, _ = linear_eq_to_matrix(components, gen_speeds) + for i in range(len(gen_speeds)): + for dim, direction in enumerate(ref): + if mat[dim, i] != 0: + partials[i] += direction * mat[dim, i] + + vec_partials.append(partials) + + return vec_partials + + +def dynamicsymbols(names, level=0, **assumptions): + """Uses symbols and Function for functions of time. + + Creates a SymPy UndefinedFunction, which is then initialized as a function + of a variable, the default being Symbol('t'). + + Parameters + ========== + + names : str + Names of the dynamic symbols you want to create; works the same way as + inputs to symbols + level : int + Level of differentiation of the returned function; d/dt once of t, + twice of t, etc. + assumptions : + - real(bool) : This is used to set the dynamicsymbol as real, + by default is False. + - positive(bool) : This is used to set the dynamicsymbol as positive, + by default is False. + - commutative(bool) : This is used to set the commutative property of + a dynamicsymbol, by default is True. + - integer(bool) : This is used to set the dynamicsymbol as integer, + by default is False. + + Examples + ======== + + >>> from sympy.physics.vector import dynamicsymbols + >>> from sympy import diff, Symbol + >>> q1 = dynamicsymbols('q1') + >>> q1 + q1(t) + >>> q2 = dynamicsymbols('q2', real=True) + >>> q2.is_real + True + >>> q3 = dynamicsymbols('q3', positive=True) + >>> q3.is_positive + True + >>> q4, q5 = dynamicsymbols('q4,q5', commutative=False) + >>> bool(q4*q5 != q5*q4) + True + >>> q6 = dynamicsymbols('q6', integer=True) + >>> q6.is_integer + True + >>> diff(q1, Symbol('t')) + Derivative(q1(t), t) + + """ + esses = symbols(names, cls=Function, **assumptions) + t = dynamicsymbols._t + if iterable(esses): + esses = [reduce(diff, [t] * level, e(t)) for e in esses] + return esses + else: + return reduce(diff, [t] * level, esses(t)) + + +dynamicsymbols._t = Symbol('t') # type: ignore +dynamicsymbols._str = '\'' # type: ignore diff --git a/lib/python3.10/site-packages/sympy/physics/vector/point.py b/lib/python3.10/site-packages/sympy/physics/vector/point.py new file mode 100644 index 0000000000000000000000000000000000000000..39d61aa89694fbd7e3eca21b15fc658b379782fb --- /dev/null +++ b/lib/python3.10/site-packages/sympy/physics/vector/point.py @@ -0,0 +1,635 @@ +from .vector import Vector, _check_vector +from .frame import _check_frame +from warnings import warn +from sympy.utilities.misc import filldedent + +__all__ = ['Point'] + + +class Point: + """This object represents a point in a dynamic system. + + It stores the: position, velocity, and acceleration of a point. + The position is a vector defined as the vector distance from a parent + point to this point. + + Parameters + ========== + + name : string + The display name of the Point + + Examples + ======== + + >>> from sympy.physics.vector import Point, ReferenceFrame, dynamicsymbols + >>> from sympy.physics.vector import init_vprinting + >>> init_vprinting(pretty_print=False) + >>> N = ReferenceFrame('N') + >>> O = Point('O') + >>> P = Point('P') + >>> u1, u2, u3 = dynamicsymbols('u1 u2 u3') + >>> O.set_vel(N, u1 * N.x + u2 * N.y + u3 * N.z) + >>> O.acc(N) + u1'*N.x + u2'*N.y + u3'*N.z + + ``symbols()`` can be used to create multiple Points in a single step, for + example: + + >>> from sympy.physics.vector import Point, ReferenceFrame, dynamicsymbols + >>> from sympy.physics.vector import init_vprinting + >>> init_vprinting(pretty_print=False) + >>> from sympy import symbols + >>> N = ReferenceFrame('N') + >>> u1, u2 = dynamicsymbols('u1 u2') + >>> A, B = symbols('A B', cls=Point) + >>> type(A) + + >>> A.set_vel(N, u1 * N.x + u2 * N.y) + >>> B.set_vel(N, u2 * N.x + u1 * N.y) + >>> A.acc(N) - B.acc(N) + (u1' - u2')*N.x + (-u1' + u2')*N.y + + """ + + def __init__(self, name): + """Initialization of a Point object. """ + self.name = name + self._pos_dict = {} + self._vel_dict = {} + self._acc_dict = {} + self._pdlist = [self._pos_dict, self._vel_dict, self._acc_dict] + + def __str__(self): + return self.name + + __repr__ = __str__ + + def _check_point(self, other): + if not isinstance(other, Point): + raise TypeError('A Point must be supplied') + + def _pdict_list(self, other, num): + """Returns a list of points that gives the shortest path with respect + to position, velocity, or acceleration from this point to the provided + point. + + Parameters + ========== + other : Point + A point that may be related to this point by position, velocity, or + acceleration. + num : integer + 0 for searching the position tree, 1 for searching the velocity + tree, and 2 for searching the acceleration tree. + + Returns + ======= + list of Points + A sequence of points from self to other. + + Notes + ===== + + It is not clear if num = 1 or num = 2 actually works because the keys + to ``_vel_dict`` and ``_acc_dict`` are :class:`ReferenceFrame` objects + which do not have the ``_pdlist`` attribute. + + """ + outlist = [[self]] + oldlist = [[]] + while outlist != oldlist: + oldlist = outlist[:] + for v in outlist: + templist = v[-1]._pdlist[num].keys() + for v2 in templist: + if not v.__contains__(v2): + littletemplist = v + [v2] + if not outlist.__contains__(littletemplist): + outlist.append(littletemplist) + for v in oldlist: + if v[-1] != other: + outlist.remove(v) + outlist.sort(key=len) + if len(outlist) != 0: + return outlist[0] + raise ValueError('No Connecting Path found between ' + other.name + + ' and ' + self.name) + + def a1pt_theory(self, otherpoint, outframe, interframe): + """Sets the acceleration of this point with the 1-point theory. + + The 1-point theory for point acceleration looks like this: + + ^N a^P = ^B a^P + ^N a^O + ^N alpha^B x r^OP + ^N omega^B x (^N omega^B + x r^OP) + 2 ^N omega^B x ^B v^P + + where O is a point fixed in B, P is a point moving in B, and B is + rotating in frame N. + + Parameters + ========== + + otherpoint : Point + The first point of the 1-point theory (O) + outframe : ReferenceFrame + The frame we want this point's acceleration defined in (N) + fixedframe : ReferenceFrame + The intermediate frame in this calculation (B) + + Examples + ======== + + >>> from sympy.physics.vector import Point, ReferenceFrame + >>> from sympy.physics.vector import dynamicsymbols + >>> from sympy.physics.vector import init_vprinting + >>> init_vprinting(pretty_print=False) + >>> q = dynamicsymbols('q') + >>> q2 = dynamicsymbols('q2') + >>> qd = dynamicsymbols('q', 1) + >>> q2d = dynamicsymbols('q2', 1) + >>> N = ReferenceFrame('N') + >>> B = ReferenceFrame('B') + >>> B.set_ang_vel(N, 5 * B.y) + >>> O = Point('O') + >>> P = O.locatenew('P', q * B.x + q2 * B.y) + >>> P.set_vel(B, qd * B.x + q2d * B.y) + >>> O.set_vel(N, 0) + >>> P.a1pt_theory(O, N, B) + (-25*q + q'')*B.x + q2''*B.y - 10*q'*B.z + + """ + + _check_frame(outframe) + _check_frame(interframe) + self._check_point(otherpoint) + dist = self.pos_from(otherpoint) + v = self.vel(interframe) + a1 = otherpoint.acc(outframe) + a2 = self.acc(interframe) + omega = interframe.ang_vel_in(outframe) + alpha = interframe.ang_acc_in(outframe) + self.set_acc(outframe, a2 + 2 * (omega.cross(v)) + a1 + + (alpha.cross(dist)) + (omega.cross(omega.cross(dist)))) + return self.acc(outframe) + + def a2pt_theory(self, otherpoint, outframe, fixedframe): + """Sets the acceleration of this point with the 2-point theory. + + The 2-point theory for point acceleration looks like this: + + ^N a^P = ^N a^O + ^N alpha^B x r^OP + ^N omega^B x (^N omega^B x r^OP) + + where O and P are both points fixed in frame B, which is rotating in + frame N. + + Parameters + ========== + + otherpoint : Point + The first point of the 2-point theory (O) + outframe : ReferenceFrame + The frame we want this point's acceleration defined in (N) + fixedframe : ReferenceFrame + The frame in which both points are fixed (B) + + Examples + ======== + + >>> from sympy.physics.vector import Point, ReferenceFrame, dynamicsymbols + >>> from sympy.physics.vector import init_vprinting + >>> init_vprinting(pretty_print=False) + >>> q = dynamicsymbols('q') + >>> qd = dynamicsymbols('q', 1) + >>> N = ReferenceFrame('N') + >>> B = N.orientnew('B', 'Axis', [q, N.z]) + >>> O = Point('O') + >>> P = O.locatenew('P', 10 * B.x) + >>> O.set_vel(N, 5 * N.x) + >>> P.a2pt_theory(O, N, B) + - 10*q'**2*B.x + 10*q''*B.y + + """ + + _check_frame(outframe) + _check_frame(fixedframe) + self._check_point(otherpoint) + dist = self.pos_from(otherpoint) + a = otherpoint.acc(outframe) + omega = fixedframe.ang_vel_in(outframe) + alpha = fixedframe.ang_acc_in(outframe) + self.set_acc(outframe, a + (alpha.cross(dist)) + + (omega.cross(omega.cross(dist)))) + return self.acc(outframe) + + def acc(self, frame): + """The acceleration Vector of this Point in a ReferenceFrame. + + Parameters + ========== + + frame : ReferenceFrame + The frame in which the returned acceleration vector will be defined + in. + + Examples + ======== + + >>> from sympy.physics.vector import Point, ReferenceFrame + >>> N = ReferenceFrame('N') + >>> p1 = Point('p1') + >>> p1.set_acc(N, 10 * N.x) + >>> p1.acc(N) + 10*N.x + + """ + + _check_frame(frame) + if not (frame in self._acc_dict): + if self.vel(frame) != 0: + return (self._vel_dict[frame]).dt(frame) + else: + return Vector(0) + return self._acc_dict[frame] + + def locatenew(self, name, value): + """Creates a new point with a position defined from this point. + + Parameters + ========== + + name : str + The name for the new point + value : Vector + The position of the new point relative to this point + + Examples + ======== + + >>> from sympy.physics.vector import ReferenceFrame, Point + >>> N = ReferenceFrame('N') + >>> P1 = Point('P1') + >>> P2 = P1.locatenew('P2', 10 * N.x) + + """ + + if not isinstance(name, str): + raise TypeError('Must supply a valid name') + if value == 0: + value = Vector(0) + value = _check_vector(value) + p = Point(name) + p.set_pos(self, value) + self.set_pos(p, -value) + return p + + def pos_from(self, otherpoint): + """Returns a Vector distance between this Point and the other Point. + + Parameters + ========== + + otherpoint : Point + The otherpoint we are locating this one relative to + + Examples + ======== + + >>> from sympy.physics.vector import Point, ReferenceFrame + >>> N = ReferenceFrame('N') + >>> p1 = Point('p1') + >>> p2 = Point('p2') + >>> p1.set_pos(p2, 10 * N.x) + >>> p1.pos_from(p2) + 10*N.x + + """ + + outvec = Vector(0) + plist = self._pdict_list(otherpoint, 0) + for i in range(len(plist) - 1): + outvec += plist[i]._pos_dict[plist[i + 1]] + return outvec + + def set_acc(self, frame, value): + """Used to set the acceleration of this Point in a ReferenceFrame. + + Parameters + ========== + + frame : ReferenceFrame + The frame in which this point's acceleration is defined + value : Vector + The vector value of this point's acceleration in the frame + + Examples + ======== + + >>> from sympy.physics.vector import Point, ReferenceFrame + >>> N = ReferenceFrame('N') + >>> p1 = Point('p1') + >>> p1.set_acc(N, 10 * N.x) + >>> p1.acc(N) + 10*N.x + + """ + + if value == 0: + value = Vector(0) + value = _check_vector(value) + _check_frame(frame) + self._acc_dict.update({frame: value}) + + def set_pos(self, otherpoint, value): + """Used to set the position of this point w.r.t. another point. + + Parameters + ========== + + otherpoint : Point + The other point which this point's location is defined relative to + value : Vector + The vector which defines the location of this point + + Examples + ======== + + >>> from sympy.physics.vector import Point, ReferenceFrame + >>> N = ReferenceFrame('N') + >>> p1 = Point('p1') + >>> p2 = Point('p2') + >>> p1.set_pos(p2, 10 * N.x) + >>> p1.pos_from(p2) + 10*N.x + + """ + + if value == 0: + value = Vector(0) + value = _check_vector(value) + self._check_point(otherpoint) + self._pos_dict.update({otherpoint: value}) + otherpoint._pos_dict.update({self: -value}) + + def set_vel(self, frame, value): + """Sets the velocity Vector of this Point in a ReferenceFrame. + + Parameters + ========== + + frame : ReferenceFrame + The frame in which this point's velocity is defined + value : Vector + The vector value of this point's velocity in the frame + + Examples + ======== + + >>> from sympy.physics.vector import Point, ReferenceFrame + >>> N = ReferenceFrame('N') + >>> p1 = Point('p1') + >>> p1.set_vel(N, 10 * N.x) + >>> p1.vel(N) + 10*N.x + + """ + + if value == 0: + value = Vector(0) + value = _check_vector(value) + _check_frame(frame) + self._vel_dict.update({frame: value}) + + def v1pt_theory(self, otherpoint, outframe, interframe): + """Sets the velocity of this point with the 1-point theory. + + The 1-point theory for point velocity looks like this: + + ^N v^P = ^B v^P + ^N v^O + ^N omega^B x r^OP + + where O is a point fixed in B, P is a point moving in B, and B is + rotating in frame N. + + Parameters + ========== + + otherpoint : Point + The first point of the 1-point theory (O) + outframe : ReferenceFrame + The frame we want this point's velocity defined in (N) + interframe : ReferenceFrame + The intermediate frame in this calculation (B) + + Examples + ======== + + >>> from sympy.physics.vector import Point, ReferenceFrame + >>> from sympy.physics.vector import dynamicsymbols + >>> from sympy.physics.vector import init_vprinting + >>> init_vprinting(pretty_print=False) + >>> q = dynamicsymbols('q') + >>> q2 = dynamicsymbols('q2') + >>> qd = dynamicsymbols('q', 1) + >>> q2d = dynamicsymbols('q2', 1) + >>> N = ReferenceFrame('N') + >>> B = ReferenceFrame('B') + >>> B.set_ang_vel(N, 5 * B.y) + >>> O = Point('O') + >>> P = O.locatenew('P', q * B.x + q2 * B.y) + >>> P.set_vel(B, qd * B.x + q2d * B.y) + >>> O.set_vel(N, 0) + >>> P.v1pt_theory(O, N, B) + q'*B.x + q2'*B.y - 5*q*B.z + + """ + + _check_frame(outframe) + _check_frame(interframe) + self._check_point(otherpoint) + dist = self.pos_from(otherpoint) + v1 = self.vel(interframe) + v2 = otherpoint.vel(outframe) + omega = interframe.ang_vel_in(outframe) + self.set_vel(outframe, v1 + v2 + (omega.cross(dist))) + return self.vel(outframe) + + def v2pt_theory(self, otherpoint, outframe, fixedframe): + """Sets the velocity of this point with the 2-point theory. + + The 2-point theory for point velocity looks like this: + + ^N v^P = ^N v^O + ^N omega^B x r^OP + + where O and P are both points fixed in frame B, which is rotating in + frame N. + + Parameters + ========== + + otherpoint : Point + The first point of the 2-point theory (O) + outframe : ReferenceFrame + The frame we want this point's velocity defined in (N) + fixedframe : ReferenceFrame + The frame in which both points are fixed (B) + + Examples + ======== + + >>> from sympy.physics.vector import Point, ReferenceFrame, dynamicsymbols + >>> from sympy.physics.vector import init_vprinting + >>> init_vprinting(pretty_print=False) + >>> q = dynamicsymbols('q') + >>> qd = dynamicsymbols('q', 1) + >>> N = ReferenceFrame('N') + >>> B = N.orientnew('B', 'Axis', [q, N.z]) + >>> O = Point('O') + >>> P = O.locatenew('P', 10 * B.x) + >>> O.set_vel(N, 5 * N.x) + >>> P.v2pt_theory(O, N, B) + 5*N.x + 10*q'*B.y + + """ + + _check_frame(outframe) + _check_frame(fixedframe) + self._check_point(otherpoint) + dist = self.pos_from(otherpoint) + v = otherpoint.vel(outframe) + omega = fixedframe.ang_vel_in(outframe) + self.set_vel(outframe, v + (omega.cross(dist))) + return self.vel(outframe) + + def vel(self, frame): + """The velocity Vector of this Point in the ReferenceFrame. + + Parameters + ========== + + frame : ReferenceFrame + The frame in which the returned velocity vector will be defined in + + Examples + ======== + + >>> from sympy.physics.vector import Point, ReferenceFrame, dynamicsymbols + >>> N = ReferenceFrame('N') + >>> p1 = Point('p1') + >>> p1.set_vel(N, 10 * N.x) + >>> p1.vel(N) + 10*N.x + + Velocities will be automatically calculated if possible, otherwise a + ``ValueError`` will be returned. If it is possible to calculate + multiple different velocities from the relative points, the points + defined most directly relative to this point will be used. In the case + of inconsistent relative positions of points, incorrect velocities may + be returned. It is up to the user to define prior relative positions + and velocities of points in a self-consistent way. + + >>> p = Point('p') + >>> q = dynamicsymbols('q') + >>> p.set_vel(N, 10 * N.x) + >>> p2 = Point('p2') + >>> p2.set_pos(p, q*N.x) + >>> p2.vel(N) + (Derivative(q(t), t) + 10)*N.x + + """ + + _check_frame(frame) + if not (frame in self._vel_dict): + valid_neighbor_found = False + is_cyclic = False + visited = [] + queue = [self] + candidate_neighbor = [] + while queue: # BFS to find nearest point + node = queue.pop(0) + if node not in visited: + visited.append(node) + for neighbor, neighbor_pos in node._pos_dict.items(): + if neighbor in visited: + continue + try: + # Checks if pos vector is valid + neighbor_pos.express(frame) + except ValueError: + continue + if neighbor in queue: + is_cyclic = True + try: + # Checks if point has its vel defined in req frame + neighbor_velocity = neighbor._vel_dict[frame] + except KeyError: + queue.append(neighbor) + continue + candidate_neighbor.append(neighbor) + if not valid_neighbor_found: + self.set_vel(frame, self.pos_from(neighbor).dt(frame) + neighbor_velocity) + valid_neighbor_found = True + if is_cyclic: + warn(filldedent(""" + Kinematic loops are defined among the positions of points. This + is likely not desired and may cause errors in your calculations. + """)) + if len(candidate_neighbor) > 1: + warn(filldedent(f""" + Velocity of {self.name} automatically calculated based on point + {candidate_neighbor[0].name} but it is also possible from + points(s): {str(candidate_neighbor[1:])}. Velocities from these + points are not necessarily the same. This may cause errors in + your calculations.""")) + if valid_neighbor_found: + return self._vel_dict[frame] + else: + raise ValueError(filldedent(f""" + Velocity of point {self.name} has not been defined in + ReferenceFrame {frame.name}.""")) + + return self._vel_dict[frame] + + def partial_velocity(self, frame, *gen_speeds): + """Returns the partial velocities of the linear velocity vector of this + point in the given frame with respect to one or more provided + generalized speeds. + + Parameters + ========== + frame : ReferenceFrame + The frame with which the velocity is defined in. + gen_speeds : functions of time + The generalized speeds. + + Returns + ======= + partial_velocities : tuple of Vector + The partial velocity vectors corresponding to the provided + generalized speeds. + + Examples + ======== + + >>> from sympy.physics.vector import ReferenceFrame, Point + >>> from sympy.physics.vector import dynamicsymbols + >>> N = ReferenceFrame('N') + >>> A = ReferenceFrame('A') + >>> p = Point('p') + >>> u1, u2 = dynamicsymbols('u1, u2') + >>> p.set_vel(N, u1 * N.x + u2 * A.y) + >>> p.partial_velocity(N, u1) + N.x + >>> p.partial_velocity(N, u1, u2) + (N.x, A.y) + + """ + + from sympy.physics.vector.functions import partial_velocity + + vel = self.vel(frame) + partials = partial_velocity([vel], gen_speeds, frame)[0] + + if len(partials) == 1: + return partials[0] + else: + return tuple(partials) diff --git a/lib/python3.10/site-packages/sympy/physics/vector/printing.py b/lib/python3.10/site-packages/sympy/physics/vector/printing.py new file mode 100644 index 0000000000000000000000000000000000000000..2b589f673329e1e598b9b568fba6c07b8abe67bc --- /dev/null +++ b/lib/python3.10/site-packages/sympy/physics/vector/printing.py @@ -0,0 +1,371 @@ +from sympy.core.function import Derivative +from sympy.core.function import UndefinedFunction, AppliedUndef +from sympy.core.symbol import Symbol +from sympy.interactive.printing import init_printing +from sympy.printing.latex import LatexPrinter +from sympy.printing.pretty.pretty import PrettyPrinter +from sympy.printing.pretty.pretty_symbology import center_accent +from sympy.printing.str import StrPrinter +from sympy.printing.precedence import PRECEDENCE + +__all__ = ['vprint', 'vsstrrepr', 'vsprint', 'vpprint', 'vlatex', + 'init_vprinting'] + + +class VectorStrPrinter(StrPrinter): + """String Printer for vector expressions. """ + + def _print_Derivative(self, e): + from sympy.physics.vector.functions import dynamicsymbols + t = dynamicsymbols._t + if (bool(sum(i == t for i in e.variables)) & + isinstance(type(e.args[0]), UndefinedFunction)): + ol = str(e.args[0].func) + for i, v in enumerate(e.variables): + ol += dynamicsymbols._str + return ol + else: + return StrPrinter().doprint(e) + + def _print_Function(self, e): + from sympy.physics.vector.functions import dynamicsymbols + t = dynamicsymbols._t + if isinstance(type(e), UndefinedFunction): + return StrPrinter().doprint(e).replace("(%s)" % t, '') + return e.func.__name__ + "(%s)" % self.stringify(e.args, ", ") + + +class VectorStrReprPrinter(VectorStrPrinter): + """String repr printer for vector expressions.""" + def _print_str(self, s): + return repr(s) + + +class VectorLatexPrinter(LatexPrinter): + """Latex Printer for vector expressions. """ + + def _print_Function(self, expr, exp=None): + from sympy.physics.vector.functions import dynamicsymbols + func = expr.func.__name__ + t = dynamicsymbols._t + + if (hasattr(self, '_print_' + func) and not + isinstance(type(expr), UndefinedFunction)): + return getattr(self, '_print_' + func)(expr, exp) + elif isinstance(type(expr), UndefinedFunction) and (expr.args == (t,)): + # treat this function like a symbol + expr = Symbol(func) + if exp is not None: + # copied from LatexPrinter._helper_print_standard_power, which + # we can't call because we only have exp as a string. + base = self.parenthesize(expr, PRECEDENCE['Pow']) + base = self.parenthesize_super(base) + return r"%s^{%s}" % (base, exp) + else: + return super()._print(expr) + else: + return super()._print_Function(expr, exp) + + def _print_Derivative(self, der_expr): + from sympy.physics.vector.functions import dynamicsymbols + # make sure it is in the right form + der_expr = der_expr.doit() + if not isinstance(der_expr, Derivative): + return r"\left(%s\right)" % self.doprint(der_expr) + + # check if expr is a dynamicsymbol + t = dynamicsymbols._t + expr = der_expr.expr + red = expr.atoms(AppliedUndef) + syms = der_expr.variables + test1 = not all(True for i in red if i.free_symbols == {t}) + test2 = not all(t == i for i in syms) + if test1 or test2: + return super()._print_Derivative(der_expr) + + # done checking + dots = len(syms) + base = self._print_Function(expr) + base_split = base.split('_', 1) + base = base_split[0] + if dots == 1: + base = r"\dot{%s}" % base + elif dots == 2: + base = r"\ddot{%s}" % base + elif dots == 3: + base = r"\dddot{%s}" % base + elif dots == 4: + base = r"\ddddot{%s}" % base + else: # Fallback to standard printing + return super()._print_Derivative(der_expr) + if len(base_split) != 1: + base += '_' + base_split[1] + return base + + +class VectorPrettyPrinter(PrettyPrinter): + """Pretty Printer for vectorialexpressions. """ + + def _print_Derivative(self, deriv): + from sympy.physics.vector.functions import dynamicsymbols + # XXX use U('PARTIAL DIFFERENTIAL') here ? + t = dynamicsymbols._t + dot_i = 0 + syms = list(reversed(deriv.variables)) + + while len(syms) > 0: + if syms[-1] == t: + syms.pop() + dot_i += 1 + else: + return super()._print_Derivative(deriv) + + if not (isinstance(type(deriv.expr), UndefinedFunction) and + (deriv.expr.args == (t,))): + return super()._print_Derivative(deriv) + else: + pform = self._print_Function(deriv.expr) + + # the following condition would happen with some sort of non-standard + # dynamic symbol I guess, so we'll just print the SymPy way + if len(pform.picture) > 1: + return super()._print_Derivative(deriv) + + # There are only special symbols up to fourth-order derivatives + if dot_i >= 5: + return super()._print_Derivative(deriv) + + # Deal with special symbols + dots = {0: "", + 1: "\N{COMBINING DOT ABOVE}", + 2: "\N{COMBINING DIAERESIS}", + 3: "\N{COMBINING THREE DOTS ABOVE}", + 4: "\N{COMBINING FOUR DOTS ABOVE}"} + + d = pform.__dict__ + # if unicode is false then calculate number of apostrophes needed and + # add to output + if not self._use_unicode: + apostrophes = "" + for i in range(0, dot_i): + apostrophes += "'" + d['picture'][0] += apostrophes + "(t)" + else: + d['picture'] = [center_accent(d['picture'][0], dots[dot_i])] + return pform + + def _print_Function(self, e): + from sympy.physics.vector.functions import dynamicsymbols + t = dynamicsymbols._t + # XXX works only for applied functions + func = e.func + args = e.args + func_name = func.__name__ + pform = self._print_Symbol(Symbol(func_name)) + # If this function is an Undefined function of t, it is probably a + # dynamic symbol, so we'll skip the (t). The rest of the code is + # identical to the normal PrettyPrinter code + if not (isinstance(func, UndefinedFunction) and (args == (t,))): + return super()._print_Function(e) + return pform + + +def vprint(expr, **settings): + r"""Function for printing of expressions generated in the + sympy.physics vector package. + + Extends SymPy's StrPrinter, takes the same setting accepted by SymPy's + :func:`~.sstr`, and is equivalent to ``print(sstr(foo))``. + + Parameters + ========== + + expr : valid SymPy object + SymPy expression to print. + settings : args + Same as the settings accepted by SymPy's sstr(). + + Examples + ======== + + >>> from sympy.physics.vector import vprint, dynamicsymbols + >>> u1 = dynamicsymbols('u1') + >>> print(u1) + u1(t) + >>> vprint(u1) + u1 + + """ + + outstr = vsprint(expr, **settings) + + import builtins + if (outstr != 'None'): + builtins._ = outstr + print(outstr) + + +def vsstrrepr(expr, **settings): + """Function for displaying expression representation's with vector + printing enabled. + + Parameters + ========== + + expr : valid SymPy object + SymPy expression to print. + settings : args + Same as the settings accepted by SymPy's sstrrepr(). + + """ + p = VectorStrReprPrinter(settings) + return p.doprint(expr) + + +def vsprint(expr, **settings): + r"""Function for displaying expressions generated in the + sympy.physics vector package. + + Returns the output of vprint() as a string. + + Parameters + ========== + + expr : valid SymPy object + SymPy expression to print + settings : args + Same as the settings accepted by SymPy's sstr(). + + Examples + ======== + + >>> from sympy.physics.vector import vsprint, dynamicsymbols + >>> u1, u2 = dynamicsymbols('u1 u2') + >>> u2d = dynamicsymbols('u2', level=1) + >>> print("%s = %s" % (u1, u2 + u2d)) + u1(t) = u2(t) + Derivative(u2(t), t) + >>> print("%s = %s" % (vsprint(u1), vsprint(u2 + u2d))) + u1 = u2 + u2' + + """ + + string_printer = VectorStrPrinter(settings) + return string_printer.doprint(expr) + + +def vpprint(expr, **settings): + r"""Function for pretty printing of expressions generated in the + sympy.physics vector package. + + Mainly used for expressions not inside a vector; the output of running + scripts and generating equations of motion. Takes the same options as + SymPy's :func:`~.pretty_print`; see that function for more information. + + Parameters + ========== + + expr : valid SymPy object + SymPy expression to pretty print + settings : args + Same as those accepted by SymPy's pretty_print. + + + """ + + pp = VectorPrettyPrinter(settings) + + # Note that this is copied from sympy.printing.pretty.pretty_print: + + # XXX: this is an ugly hack, but at least it works + use_unicode = pp._settings['use_unicode'] + from sympy.printing.pretty.pretty_symbology import pretty_use_unicode + uflag = pretty_use_unicode(use_unicode) + + try: + return pp.doprint(expr) + finally: + pretty_use_unicode(uflag) + + +def vlatex(expr, **settings): + r"""Function for printing latex representation of sympy.physics.vector + objects. + + For latex representation of Vectors, Dyadics, and dynamicsymbols. Takes the + same options as SymPy's :func:`~.latex`; see that function for more + information; + + Parameters + ========== + + expr : valid SymPy object + SymPy expression to represent in LaTeX form + settings : args + Same as latex() + + Examples + ======== + + >>> from sympy.physics.vector import vlatex, ReferenceFrame, dynamicsymbols + >>> N = ReferenceFrame('N') + >>> q1, q2 = dynamicsymbols('q1 q2') + >>> q1d, q2d = dynamicsymbols('q1 q2', 1) + >>> q1dd, q2dd = dynamicsymbols('q1 q2', 2) + >>> vlatex(N.x + N.y) + '\\mathbf{\\hat{n}_x} + \\mathbf{\\hat{n}_y}' + >>> vlatex(q1 + q2) + 'q_{1} + q_{2}' + >>> vlatex(q1d) + '\\dot{q}_{1}' + >>> vlatex(q1 * q2d) + 'q_{1} \\dot{q}_{2}' + >>> vlatex(q1dd * q1 / q1d) + '\\frac{q_{1} \\ddot{q}_{1}}{\\dot{q}_{1}}' + + """ + latex_printer = VectorLatexPrinter(settings) + + return latex_printer.doprint(expr) + + +def init_vprinting(**kwargs): + """Initializes time derivative printing for all SymPy objects, i.e. any + functions of time will be displayed in a more compact notation. The main + benefit of this is for printing of time derivatives; instead of + displaying as ``Derivative(f(t),t)``, it will display ``f'``. This is + only actually needed for when derivatives are present and are not in a + physics.vector.Vector or physics.vector.Dyadic object. This function is a + light wrapper to :func:`~.init_printing`. Any keyword + arguments for it are valid here. + + {0} + + Examples + ======== + + >>> from sympy import Function, symbols + >>> t, x = symbols('t, x') + >>> omega = Function('omega') + >>> omega(x).diff() + Derivative(omega(x), x) + >>> omega(t).diff() + Derivative(omega(t), t) + + Now use the string printer: + + >>> from sympy.physics.vector import init_vprinting + >>> init_vprinting(pretty_print=False) + >>> omega(x).diff() + Derivative(omega(x), x) + >>> omega(t).diff() + omega' + + """ + kwargs['str_printer'] = vsstrrepr + kwargs['pretty_printer'] = vpprint + kwargs['latex_printer'] = vlatex + init_printing(**kwargs) + + +params = init_printing.__doc__.split('Examples\n ========')[0] # type: ignore +init_vprinting.__doc__ = init_vprinting.__doc__.format(params) # type: ignore diff --git a/lib/python3.10/site-packages/sympy/physics/vector/vector.py b/lib/python3.10/site-packages/sympy/physics/vector/vector.py new file mode 100644 index 0000000000000000000000000000000000000000..d27f709353b909c1eb4584495e76b91b1a18af66 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/physics/vector/vector.py @@ -0,0 +1,806 @@ +from sympy import (S, sympify, expand, sqrt, Add, zeros, acos, + ImmutableMatrix as Matrix, simplify) +from sympy.simplify.trigsimp import trigsimp +from sympy.printing.defaults import Printable +from sympy.utilities.misc import filldedent +from sympy.core.evalf import EvalfMixin + +from mpmath.libmp.libmpf import prec_to_dps + + +__all__ = ['Vector'] + + +class Vector(Printable, EvalfMixin): + """The class used to define vectors. + + It along with ReferenceFrame are the building blocks of describing a + classical mechanics system in PyDy and sympy.physics.vector. + + Attributes + ========== + + simp : Boolean + Let certain methods use trigsimp on their outputs + + """ + + simp = False + is_number = False + + def __init__(self, inlist): + """This is the constructor for the Vector class. You should not be + calling this, it should only be used by other functions. You should be + treating Vectors like you would with if you were doing the math by + hand, and getting the first 3 from the standard basis vectors from a + ReferenceFrame. + + The only exception is to create a zero vector: + zv = Vector(0) + + """ + + self.args = [] + if inlist == 0: + inlist = [] + if isinstance(inlist, dict): + d = inlist + else: + d = {} + for inp in inlist: + if inp[1] in d: + d[inp[1]] += inp[0] + else: + d[inp[1]] = inp[0] + + for k, v in d.items(): + if v != Matrix([0, 0, 0]): + self.args.append((v, k)) + + @property + def func(self): + """Returns the class Vector. """ + return Vector + + def __hash__(self): + return hash(tuple(self.args)) + + def __add__(self, other): + """The add operator for Vector. """ + if other == 0: + return self + other = _check_vector(other) + return Vector(self.args + other.args) + + def dot(self, other): + """Dot product of two vectors. + + Returns a scalar, the dot product of the two Vectors + + Parameters + ========== + + other : Vector + The Vector which we are dotting with + + Examples + ======== + + >>> from sympy.physics.vector import ReferenceFrame, dot + >>> from sympy import symbols + >>> q1 = symbols('q1') + >>> N = ReferenceFrame('N') + >>> dot(N.x, N.x) + 1 + >>> dot(N.x, N.y) + 0 + >>> A = N.orientnew('A', 'Axis', [q1, N.x]) + >>> dot(N.y, A.y) + cos(q1) + + """ + + from sympy.physics.vector.dyadic import Dyadic, _check_dyadic + if isinstance(other, Dyadic): + other = _check_dyadic(other) + ol = Vector(0) + for v in other.args: + ol += v[0] * v[2] * (v[1].dot(self)) + return ol + other = _check_vector(other) + out = S.Zero + for v1 in self.args: + for v2 in other.args: + out += ((v2[0].T) * (v2[1].dcm(v1[1])) * (v1[0]))[0] + if Vector.simp: + return trigsimp(out, recursive=True) + else: + return out + + def __truediv__(self, other): + """This uses mul and inputs self and 1 divided by other. """ + return self.__mul__(S.One / other) + + def __eq__(self, other): + """Tests for equality. + + It is very import to note that this is only as good as the SymPy + equality test; False does not always mean they are not equivalent + Vectors. + If other is 0, and self is empty, returns True. + If other is 0 and self is not empty, returns False. + If none of the above, only accepts other as a Vector. + + """ + + if other == 0: + other = Vector(0) + try: + other = _check_vector(other) + except TypeError: + return False + if (self.args == []) and (other.args == []): + return True + elif (self.args == []) or (other.args == []): + return False + + frame = self.args[0][1] + for v in frame: + if expand((self - other).dot(v)) != 0: + return False + return True + + def __mul__(self, other): + """Multiplies the Vector by a sympifyable expression. + + Parameters + ========== + + other : Sympifyable + The scalar to multiply this Vector with + + Examples + ======== + + >>> from sympy.physics.vector import ReferenceFrame + >>> from sympy import Symbol + >>> N = ReferenceFrame('N') + >>> b = Symbol('b') + >>> V = 10 * b * N.x + >>> print(V) + 10*b*N.x + + """ + + newlist = list(self.args) + other = sympify(other) + for i, v in enumerate(newlist): + newlist[i] = (other * newlist[i][0], newlist[i][1]) + return Vector(newlist) + + def __neg__(self): + return self * -1 + + def outer(self, other): + """Outer product between two Vectors. + + A rank increasing operation, which returns a Dyadic from two Vectors + + Parameters + ========== + + other : Vector + The Vector to take the outer product with + + Examples + ======== + + >>> from sympy.physics.vector import ReferenceFrame, outer + >>> N = ReferenceFrame('N') + >>> outer(N.x, N.x) + (N.x|N.x) + + """ + + from sympy.physics.vector.dyadic import Dyadic + other = _check_vector(other) + ol = Dyadic(0) + for v in self.args: + for v2 in other.args: + # it looks this way because if we are in the same frame and + # use the enumerate function on the same frame in a nested + # fashion, then bad things happen + ol += Dyadic([(v[0][0] * v2[0][0], v[1].x, v2[1].x)]) + ol += Dyadic([(v[0][0] * v2[0][1], v[1].x, v2[1].y)]) + ol += Dyadic([(v[0][0] * v2[0][2], v[1].x, v2[1].z)]) + ol += Dyadic([(v[0][1] * v2[0][0], v[1].y, v2[1].x)]) + ol += Dyadic([(v[0][1] * v2[0][1], v[1].y, v2[1].y)]) + ol += Dyadic([(v[0][1] * v2[0][2], v[1].y, v2[1].z)]) + ol += Dyadic([(v[0][2] * v2[0][0], v[1].z, v2[1].x)]) + ol += Dyadic([(v[0][2] * v2[0][1], v[1].z, v2[1].y)]) + ol += Dyadic([(v[0][2] * v2[0][2], v[1].z, v2[1].z)]) + return ol + + def _latex(self, printer): + """Latex Printing method. """ + + ar = self.args # just to shorten things + if len(ar) == 0: + return str(0) + ol = [] # output list, to be concatenated to a string + for i, v in enumerate(ar): + for j in 0, 1, 2: + # if the coef of the basis vector is 1, we skip the 1 + if ar[i][0][j] == 1: + ol.append(' + ' + ar[i][1].latex_vecs[j]) + # if the coef of the basis vector is -1, we skip the 1 + elif ar[i][0][j] == -1: + ol.append(' - ' + ar[i][1].latex_vecs[j]) + elif ar[i][0][j] != 0: + # If the coefficient of the basis vector is not 1 or -1; + # also, we might wrap it in parentheses, for readability. + arg_str = printer._print(ar[i][0][j]) + if isinstance(ar[i][0][j], Add): + arg_str = "(%s)" % arg_str + if arg_str[0] == '-': + arg_str = arg_str[1:] + str_start = ' - ' + else: + str_start = ' + ' + ol.append(str_start + arg_str + ar[i][1].latex_vecs[j]) + outstr = ''.join(ol) + if outstr.startswith(' + '): + outstr = outstr[3:] + elif outstr.startswith(' '): + outstr = outstr[1:] + return outstr + + def _pretty(self, printer): + """Pretty Printing method. """ + from sympy.printing.pretty.stringpict import prettyForm + + terms = [] + + def juxtapose(a, b): + pa = printer._print(a) + pb = printer._print(b) + if a.is_Add: + pa = prettyForm(*pa.parens()) + return printer._print_seq([pa, pb], delimiter=' ') + + for M, N in self.args: + for i in range(3): + if M[i] == 0: + continue + elif M[i] == 1: + terms.append(prettyForm(N.pretty_vecs[i])) + elif M[i] == -1: + terms.append(prettyForm("-1") * prettyForm(N.pretty_vecs[i])) + else: + terms.append(juxtapose(M[i], N.pretty_vecs[i])) + + if terms: + pretty_result = prettyForm.__add__(*terms) + else: + pretty_result = prettyForm("0") + + return pretty_result + + def __rsub__(self, other): + return (-1 * self) + other + + def _sympystr(self, printer, order=True): + """Printing method. """ + if not order or len(self.args) == 1: + ar = list(self.args) + elif len(self.args) == 0: + return printer._print(0) + else: + d = {v[1]: v[0] for v in self.args} + keys = sorted(d.keys(), key=lambda x: x.index) + ar = [] + for key in keys: + ar.append((d[key], key)) + ol = [] # output list, to be concatenated to a string + for i, v in enumerate(ar): + for j in 0, 1, 2: + # if the coef of the basis vector is 1, we skip the 1 + if ar[i][0][j] == 1: + ol.append(' + ' + ar[i][1].str_vecs[j]) + # if the coef of the basis vector is -1, we skip the 1 + elif ar[i][0][j] == -1: + ol.append(' - ' + ar[i][1].str_vecs[j]) + elif ar[i][0][j] != 0: + # If the coefficient of the basis vector is not 1 or -1; + # also, we might wrap it in parentheses, for readability. + arg_str = printer._print(ar[i][0][j]) + if isinstance(ar[i][0][j], Add): + arg_str = "(%s)" % arg_str + if arg_str[0] == '-': + arg_str = arg_str[1:] + str_start = ' - ' + else: + str_start = ' + ' + ol.append(str_start + arg_str + '*' + ar[i][1].str_vecs[j]) + outstr = ''.join(ol) + if outstr.startswith(' + '): + outstr = outstr[3:] + elif outstr.startswith(' '): + outstr = outstr[1:] + return outstr + + def __sub__(self, other): + """The subtraction operator. """ + return self.__add__(other * -1) + + def cross(self, other): + """The cross product operator for two Vectors. + + Returns a Vector, expressed in the same ReferenceFrames as self. + + Parameters + ========== + + other : Vector + The Vector which we are crossing with + + Examples + ======== + + >>> from sympy import symbols + >>> from sympy.physics.vector import ReferenceFrame, cross + >>> q1 = symbols('q1') + >>> N = ReferenceFrame('N') + >>> cross(N.x, N.y) + N.z + >>> A = ReferenceFrame('A') + >>> A.orient_axis(N, q1, N.x) + >>> cross(A.x, N.y) + N.z + >>> cross(N.y, A.x) + - sin(q1)*A.y - cos(q1)*A.z + + """ + + from sympy.physics.vector.dyadic import Dyadic, _check_dyadic + if isinstance(other, Dyadic): + other = _check_dyadic(other) + ol = Dyadic(0) + for i, v in enumerate(other.args): + ol += v[0] * ((self.cross(v[1])).outer(v[2])) + return ol + other = _check_vector(other) + if other.args == []: + return Vector(0) + + def _det(mat): + """This is needed as a little method for to find the determinant + of a list in python; needs to work for a 3x3 list. + SymPy's Matrix will not take in Vector, so need a custom function. + You should not be calling this. + + """ + + return (mat[0][0] * (mat[1][1] * mat[2][2] - mat[1][2] * mat[2][1]) + + mat[0][1] * (mat[1][2] * mat[2][0] - mat[1][0] * + mat[2][2]) + mat[0][2] * (mat[1][0] * mat[2][1] - + mat[1][1] * mat[2][0])) + + outlist = [] + ar = other.args # For brevity + for i, v in enumerate(ar): + tempx = v[1].x + tempy = v[1].y + tempz = v[1].z + tempm = ([[tempx, tempy, tempz], + [self.dot(tempx), self.dot(tempy), self.dot(tempz)], + [Vector([ar[i]]).dot(tempx), Vector([ar[i]]).dot(tempy), + Vector([ar[i]]).dot(tempz)]]) + outlist += _det(tempm).args + return Vector(outlist) + + __radd__ = __add__ + __rmul__ = __mul__ + + def separate(self): + """ + The constituents of this vector in different reference frames, + as per its definition. + + Returns a dict mapping each ReferenceFrame to the corresponding + constituent Vector. + + Examples + ======== + + >>> from sympy.physics.vector import ReferenceFrame + >>> R1 = ReferenceFrame('R1') + >>> R2 = ReferenceFrame('R2') + >>> v = R1.x + R2.x + >>> v.separate() == {R1: R1.x, R2: R2.x} + True + + """ + + components = {} + for x in self.args: + components[x[1]] = Vector([x]) + return components + + def __and__(self, other): + return self.dot(other) + __and__.__doc__ = dot.__doc__ + __rand__ = __and__ + + def __xor__(self, other): + return self.cross(other) + __xor__.__doc__ = cross.__doc__ + + def __or__(self, other): + return self.outer(other) + __or__.__doc__ = outer.__doc__ + + def diff(self, var, frame, var_in_dcm=True): + """Returns the partial derivative of the vector with respect to a + variable in the provided reference frame. + + Parameters + ========== + var : Symbol + What the partial derivative is taken with respect to. + frame : ReferenceFrame + The reference frame that the partial derivative is taken in. + var_in_dcm : boolean + If true, the differentiation algorithm assumes that the variable + may be present in any of the direction cosine matrices that relate + the frame to the frames of any component of the vector. But if it + is known that the variable is not present in the direction cosine + matrices, false can be set to skip full reexpression in the desired + frame. + + Examples + ======== + + >>> from sympy import Symbol + >>> from sympy.physics.vector import dynamicsymbols, ReferenceFrame + >>> from sympy.physics.vector import init_vprinting + >>> init_vprinting(pretty_print=False) + >>> t = Symbol('t') + >>> q1 = dynamicsymbols('q1') + >>> N = ReferenceFrame('N') + >>> A = N.orientnew('A', 'Axis', [q1, N.y]) + >>> A.x.diff(t, N) + - sin(q1)*q1'*N.x - cos(q1)*q1'*N.z + >>> A.x.diff(t, N).express(A).simplify() + - q1'*A.z + >>> B = ReferenceFrame('B') + >>> u1, u2 = dynamicsymbols('u1, u2') + >>> v = u1 * A.x + u2 * B.y + >>> v.diff(u2, N, var_in_dcm=False) + B.y + + """ + + from sympy.physics.vector.frame import _check_frame + + _check_frame(frame) + var = sympify(var) + + inlist = [] + + for vector_component in self.args: + measure_number = vector_component[0] + component_frame = vector_component[1] + if component_frame == frame: + inlist += [(measure_number.diff(var), frame)] + else: + # If the direction cosine matrix relating the component frame + # with the derivative frame does not contain the variable. + if not var_in_dcm or (frame.dcm(component_frame).diff(var) == + zeros(3, 3)): + inlist += [(measure_number.diff(var), component_frame)] + else: # else express in the frame + reexp_vec_comp = Vector([vector_component]).express(frame) + deriv = reexp_vec_comp.args[0][0].diff(var) + inlist += Vector([(deriv, frame)]).args + + return Vector(inlist) + + def express(self, otherframe, variables=False): + """ + Returns a Vector equivalent to this one, expressed in otherframe. + Uses the global express method. + + Parameters + ========== + + otherframe : ReferenceFrame + The frame for this Vector to be described in + + variables : boolean + If True, the coordinate symbols(if present) in this Vector + are re-expressed in terms otherframe + + Examples + ======== + + >>> from sympy.physics.vector import ReferenceFrame, dynamicsymbols + >>> from sympy.physics.vector import init_vprinting + >>> init_vprinting(pretty_print=False) + >>> q1 = dynamicsymbols('q1') + >>> N = ReferenceFrame('N') + >>> A = N.orientnew('A', 'Axis', [q1, N.y]) + >>> A.x.express(N) + cos(q1)*N.x - sin(q1)*N.z + + """ + from sympy.physics.vector import express + return express(self, otherframe, variables=variables) + + def to_matrix(self, reference_frame): + """Returns the matrix form of the vector with respect to the given + frame. + + Parameters + ---------- + reference_frame : ReferenceFrame + The reference frame that the rows of the matrix correspond to. + + Returns + ------- + matrix : ImmutableMatrix, shape(3,1) + The matrix that gives the 1D vector. + + Examples + ======== + + >>> from sympy import symbols + >>> from sympy.physics.vector import ReferenceFrame + >>> a, b, c = symbols('a, b, c') + >>> N = ReferenceFrame('N') + >>> vector = a * N.x + b * N.y + c * N.z + >>> vector.to_matrix(N) + Matrix([ + [a], + [b], + [c]]) + >>> beta = symbols('beta') + >>> A = N.orientnew('A', 'Axis', (beta, N.x)) + >>> vector.to_matrix(A) + Matrix([ + [ a], + [ b*cos(beta) + c*sin(beta)], + [-b*sin(beta) + c*cos(beta)]]) + + """ + + return Matrix([self.dot(unit_vec) for unit_vec in + reference_frame]).reshape(3, 1) + + def doit(self, **hints): + """Calls .doit() on each term in the Vector""" + d = {} + for v in self.args: + d[v[1]] = v[0].applyfunc(lambda x: x.doit(**hints)) + return Vector(d) + + def dt(self, otherframe): + """ + Returns a Vector which is the time derivative of + the self Vector, taken in frame otherframe. + + Calls the global time_derivative method + + Parameters + ========== + + otherframe : ReferenceFrame + The frame to calculate the time derivative in + + """ + from sympy.physics.vector import time_derivative + return time_derivative(self, otherframe) + + def simplify(self): + """Returns a simplified Vector.""" + d = {} + for v in self.args: + d[v[1]] = simplify(v[0]) + return Vector(d) + + def subs(self, *args, **kwargs): + """Substitution on the Vector. + + Examples + ======== + + >>> from sympy.physics.vector import ReferenceFrame + >>> from sympy import Symbol + >>> N = ReferenceFrame('N') + >>> s = Symbol('s') + >>> a = N.x * s + >>> a.subs({s: 2}) + 2*N.x + + """ + + d = {} + for v in self.args: + d[v[1]] = v[0].subs(*args, **kwargs) + return Vector(d) + + def magnitude(self): + """Returns the magnitude (Euclidean norm) of self. + + Warnings + ======== + + Python ignores the leading negative sign so that might + give wrong results. + ``-A.x.magnitude()`` would be treated as ``-(A.x.magnitude())``, + instead of ``(-A.x).magnitude()``. + + """ + return sqrt(self.dot(self)) + + def normalize(self): + """Returns a Vector of magnitude 1, codirectional with self.""" + return Vector(self.args + []) / self.magnitude() + + def applyfunc(self, f): + """Apply a function to each component of a vector.""" + if not callable(f): + raise TypeError("`f` must be callable.") + + d = {} + for v in self.args: + d[v[1]] = v[0].applyfunc(f) + return Vector(d) + + def angle_between(self, vec): + """ + Returns the smallest angle between Vector 'vec' and self. + + Parameter + ========= + + vec : Vector + The Vector between which angle is needed. + + Examples + ======== + + >>> from sympy.physics.vector import ReferenceFrame + >>> A = ReferenceFrame("A") + >>> v1 = A.x + >>> v2 = A.y + >>> v1.angle_between(v2) + pi/2 + + >>> v3 = A.x + A.y + A.z + >>> v1.angle_between(v3) + acos(sqrt(3)/3) + + Warnings + ======== + + Python ignores the leading negative sign so that might give wrong + results. ``-A.x.angle_between()`` would be treated as + ``-(A.x.angle_between())``, instead of ``(-A.x).angle_between()``. + + """ + + vec1 = self.normalize() + vec2 = vec.normalize() + angle = acos(vec1.dot(vec2)) + return angle + + def free_symbols(self, reference_frame): + """Returns the free symbols in the measure numbers of the vector + expressed in the given reference frame. + + Parameters + ========== + reference_frame : ReferenceFrame + The frame with respect to which the free symbols of the given + vector is to be determined. + + Returns + ======= + set of Symbol + set of symbols present in the measure numbers of + ``reference_frame``. + + """ + + return self.to_matrix(reference_frame).free_symbols + + def free_dynamicsymbols(self, reference_frame): + """Returns the free dynamic symbols (functions of time ``t``) in the + measure numbers of the vector expressed in the given reference frame. + + Parameters + ========== + reference_frame : ReferenceFrame + The frame with respect to which the free dynamic symbols of the + given vector is to be determined. + + Returns + ======= + set + Set of functions of time ``t``, e.g. + ``Function('f')(me.dynamicsymbols._t)``. + + """ + # TODO : Circular dependency if imported at top. Should move + # find_dynamicsymbols into physics.vector.functions. + from sympy.physics.mechanics.functions import find_dynamicsymbols + + return find_dynamicsymbols(self, reference_frame=reference_frame) + + def _eval_evalf(self, prec): + if not self.args: + return self + new_args = [] + dps = prec_to_dps(prec) + for mat, frame in self.args: + new_args.append([mat.evalf(n=dps), frame]) + return Vector(new_args) + + def xreplace(self, rule): + """Replace occurrences of objects within the measure numbers of the + vector. + + Parameters + ========== + + rule : dict-like + Expresses a replacement rule. + + Returns + ======= + + Vector + Result of the replacement. + + Examples + ======== + + >>> from sympy import symbols, pi + >>> from sympy.physics.vector import ReferenceFrame + >>> A = ReferenceFrame('A') + >>> x, y, z = symbols('x y z') + >>> ((1 + x*y) * A.x).xreplace({x: pi}) + (pi*y + 1)*A.x + >>> ((1 + x*y) * A.x).xreplace({x: pi, y: 2}) + (1 + 2*pi)*A.x + + Replacements occur only if an entire node in the expression tree is + matched: + + >>> ((x*y + z) * A.x).xreplace({x*y: pi}) + (z + pi)*A.x + >>> ((x*y*z) * A.x).xreplace({x*y: pi}) + x*y*z*A.x + + """ + + new_args = [] + for mat, frame in self.args: + mat = mat.xreplace(rule) + new_args.append([mat, frame]) + return Vector(new_args) + + +class VectorTypeError(TypeError): + + def __init__(self, other, want): + msg = filldedent("Expected an instance of %s, but received object " + "'%s' of %s." % (type(want), other, type(other))) + super().__init__(msg) + + +def _check_vector(other): + if not isinstance(other, Vector): + raise TypeError('A Vector must be supplied') + return other diff --git a/lib/python3.10/site-packages/sympy/plotting/__pycache__/__init__.cpython-310.pyc b/lib/python3.10/site-packages/sympy/plotting/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b4fa4f5e3c0086e9fb09b9ce7a0d0841a95294ec Binary files /dev/null and b/lib/python3.10/site-packages/sympy/plotting/__pycache__/__init__.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/plotting/__pycache__/experimental_lambdify.cpython-310.pyc b/lib/python3.10/site-packages/sympy/plotting/__pycache__/experimental_lambdify.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..acc490faaebf0b4d7b291ee6d1e3fa5bc4ff3464 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/plotting/__pycache__/experimental_lambdify.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/plotting/__pycache__/plot.cpython-310.pyc b/lib/python3.10/site-packages/sympy/plotting/__pycache__/plot.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e8ee32a3296eefcf3adbcd5eda8499e71d300542 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/plotting/__pycache__/plot.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/plotting/__pycache__/plot_implicit.cpython-310.pyc b/lib/python3.10/site-packages/sympy/plotting/__pycache__/plot_implicit.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d1509465d441f8edf268cfeaa0dd85e9ec03778b Binary files /dev/null and b/lib/python3.10/site-packages/sympy/plotting/__pycache__/plot_implicit.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/plotting/__pycache__/plotgrid.cpython-310.pyc b/lib/python3.10/site-packages/sympy/plotting/__pycache__/plotgrid.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b4423dbcb66d5b2dfb2b77564d23a597e2caa331 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/plotting/__pycache__/plotgrid.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/plotting/__pycache__/series.cpython-310.pyc b/lib/python3.10/site-packages/sympy/plotting/__pycache__/series.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1731cbf9ac8de7337425dc6b5d4c2f1130572f6c Binary files /dev/null and b/lib/python3.10/site-packages/sympy/plotting/__pycache__/series.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/plotting/__pycache__/textplot.cpython-310.pyc b/lib/python3.10/site-packages/sympy/plotting/__pycache__/textplot.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..00f3afa6710c9d01116eb159f5a251d7015aafc2 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/plotting/__pycache__/textplot.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/plotting/__pycache__/utils.cpython-310.pyc b/lib/python3.10/site-packages/sympy/plotting/__pycache__/utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..24fc037a4e689876dc18b20644b789df6f843eb8 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/plotting/__pycache__/utils.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/plotting/backends/__init__.py b/lib/python3.10/site-packages/sympy/plotting/backends/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/lib/python3.10/site-packages/sympy/plotting/backends/__pycache__/__init__.cpython-310.pyc b/lib/python3.10/site-packages/sympy/plotting/backends/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7c81314f6b8022c9addb3c038f654f16f2452492 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/plotting/backends/__pycache__/__init__.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/plotting/backends/__pycache__/base_backend.cpython-310.pyc b/lib/python3.10/site-packages/sympy/plotting/backends/__pycache__/base_backend.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..48444bcd5efd487b38c4b700c8959ca29ae94aca Binary files /dev/null and b/lib/python3.10/site-packages/sympy/plotting/backends/__pycache__/base_backend.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/plotting/backends/base_backend.py b/lib/python3.10/site-packages/sympy/plotting/backends/base_backend.py new file mode 100644 index 0000000000000000000000000000000000000000..fbc18e9e973adef5a1107f89adf7977115bea98b --- /dev/null +++ b/lib/python3.10/site-packages/sympy/plotting/backends/base_backend.py @@ -0,0 +1,419 @@ +from sympy.plotting.series import BaseSeries, GenericDataSeries +from sympy.utilities.exceptions import sympy_deprecation_warning +from sympy.utilities.iterables import is_sequence + + +__doctest_requires__ = { + ('Plot.append', 'Plot.extend'): ['matplotlib'], +} + + +# Global variable +# Set to False when running tests / doctests so that the plots don't show. +_show = True + +def unset_show(): + """ + Disable show(). For use in the tests. + """ + global _show + _show = False + + +def _deprecation_msg_m_a_r_f(attr): + sympy_deprecation_warning( + f"The `{attr}` property is deprecated. The `{attr}` keyword " + "argument should be passed to a plotting function, which generates " + "the appropriate data series. If needed, index the plot object to " + "retrieve a specific data series.", + deprecated_since_version="1.13", + active_deprecations_target="deprecated-markers-annotations-fill-rectangles", + stacklevel=4) + + +def _create_generic_data_series(**kwargs): + keywords = ["annotations", "markers", "fill", "rectangles"] + series = [] + for kw in keywords: + dictionaries = kwargs.pop(kw, []) + if dictionaries is None: + dictionaries = [] + if isinstance(dictionaries, dict): + dictionaries = [dictionaries] + for d in dictionaries: + args = d.pop("args", []) + series.append(GenericDataSeries(kw, *args, **d)) + return series + + +class Plot: + """Base class for all backends. A backend represents the plotting library, + which implements the necessary functionalities in order to use SymPy + plotting functions. + + For interactive work the function :func:`plot()` is better suited. + + This class permits the plotting of SymPy expressions using numerous + backends (:external:mod:`matplotlib`, textplot, the old pyglet module for SymPy, Google + charts api, etc). + + The figure can contain an arbitrary number of plots of SymPy expressions, + lists of coordinates of points, etc. Plot has a private attribute _series that + contains all data series to be plotted (expressions for lines or surfaces, + lists of points, etc (all subclasses of BaseSeries)). Those data series are + instances of classes not imported by ``from sympy import *``. + + The customization of the figure is on two levels. Global options that + concern the figure as a whole (e.g. title, xlabel, scale, etc) and + per-data series options (e.g. name) and aesthetics (e.g. color, point shape, + line type, etc.). + + The difference between options and aesthetics is that an aesthetic can be + a function of the coordinates (or parameters in a parametric plot). The + supported values for an aesthetic are: + + - None (the backend uses default values) + - a constant + - a function of one variable (the first coordinate or parameter) + - a function of two variables (the first and second coordinate or parameters) + - a function of three variables (only in nonparametric 3D plots) + + Their implementation depends on the backend so they may not work in some + backends. + + If the plot is parametric and the arity of the aesthetic function permits + it the aesthetic is calculated over parameters and not over coordinates. + If the arity does not permit calculation over parameters the calculation is + done over coordinates. + + Only cartesian coordinates are supported for the moment, but you can use + the parametric plots to plot in polar, spherical and cylindrical + coordinates. + + The arguments for the constructor Plot must be subclasses of BaseSeries. + + Any global option can be specified as a keyword argument. + + The global options for a figure are: + + - title : str + - xlabel : str or Symbol + - ylabel : str or Symbol + - zlabel : str or Symbol + - legend : bool + - xscale : {'linear', 'log'} + - yscale : {'linear', 'log'} + - axis : bool + - axis_center : tuple of two floats or {'center', 'auto'} + - xlim : tuple of two floats + - ylim : tuple of two floats + - aspect_ratio : tuple of two floats or {'auto'} + - autoscale : bool + - margin : float in [0, 1] + - backend : {'default', 'matplotlib', 'text'} or a subclass of BaseBackend + - size : optional tuple of two floats, (width, height); default: None + + The per data series options and aesthetics are: + There are none in the base series. See below for options for subclasses. + + Some data series support additional aesthetics or options: + + :class:`~.LineOver1DRangeSeries`, :class:`~.Parametric2DLineSeries`, and + :class:`~.Parametric3DLineSeries` support the following: + + Aesthetics: + + - line_color : string, or float, or function, optional + Specifies the color for the plot, which depends on the backend being + used. + + For example, if ``MatplotlibBackend`` is being used, then + Matplotlib string colors are acceptable (``"red"``, ``"r"``, + ``"cyan"``, ``"c"``, ...). + Alternatively, we can use a float number, 0 < color < 1, wrapped in a + string (for example, ``line_color="0.5"``) to specify grayscale colors. + Alternatively, We can specify a function returning a single + float value: this will be used to apply a color-loop (for example, + ``line_color=lambda x: math.cos(x)``). + + Note that by setting line_color, it would be applied simultaneously + to all the series. + + Options: + + - label : str + - steps : bool + - integers_only : bool + + :class:`~.SurfaceOver2DRangeSeries` and :class:`~.ParametricSurfaceSeries` + support the following: + + Aesthetics: + + - surface_color : function which returns a float. + + Notes + ===== + + How the plotting module works: + + 1. Whenever a plotting function is called, the provided expressions are + processed and a list of instances of the + :class:`~sympy.plotting.series.BaseSeries` class is created, containing + the necessary information to plot the expressions + (e.g. the expression, ranges, series name, ...). Eventually, these + objects will generate the numerical data to be plotted. + 2. A subclass of :class:`~.Plot` class is instantiaed (referred to as + backend, from now on), which stores the list of series and the main + attributes of the plot (e.g. axis labels, title, ...). + The backend implements the logic to generate the actual figure with + some plotting library. + 3. When the ``show`` command is executed, series are processed one by one + to generate numerical data and add it to the figure. The backend is also + going to set the axis labels, title, ..., according to the values stored + in the Plot instance. + + The backend should check if it supports the data series that it is given + (e.g. :class:`TextBackend` supports only + :class:`~sympy.plotting.series.LineOver1DRangeSeries`). + + It is the backend responsibility to know how to use the class of data series + that it's given. Note that the current implementation of the ``*Series`` + classes is "matplotlib-centric": the numerical data returned by the + ``get_points`` and ``get_meshes`` methods is meant to be used directly by + Matplotlib. Therefore, the new backend will have to pre-process the + numerical data to make it compatible with the chosen plotting library. + Keep in mind that future SymPy versions may improve the ``*Series`` classes + in order to return numerical data "non-matplotlib-centric", hence if you code + a new backend you have the responsibility to check if its working on each + SymPy release. + + Please explore the :class:`MatplotlibBackend` source code to understand + how a backend should be coded. + + In order to be used by SymPy plotting functions, a backend must implement + the following methods: + + * show(self): used to loop over the data series, generate the numerical + data, plot it and set the axis labels, title, ... + * save(self, path): used to save the current plot to the specified file + path. + * close(self): used to close the current plot backend (note: some plotting + library does not support this functionality. In that case, just raise a + warning). + """ + + def __init__(self, *args, + title=None, xlabel=None, ylabel=None, zlabel=None, aspect_ratio='auto', + xlim=None, ylim=None, axis_center='auto', axis=True, + xscale='linear', yscale='linear', legend=False, autoscale=True, + margin=0, annotations=None, markers=None, rectangles=None, + fill=None, backend='default', size=None, **kwargs): + + # Options for the graph as a whole. + # The possible values for each option are described in the docstring of + # Plot. They are based purely on convention, no checking is done. + self.title = title + self.xlabel = xlabel + self.ylabel = ylabel + self.zlabel = zlabel + self.aspect_ratio = aspect_ratio + self.axis_center = axis_center + self.axis = axis + self.xscale = xscale + self.yscale = yscale + self.legend = legend + self.autoscale = autoscale + self.margin = margin + self._annotations = annotations + self._markers = markers + self._rectangles = rectangles + self._fill = fill + + # Contains the data objects to be plotted. The backend should be smart + # enough to iterate over this list. + self._series = [] + self._series.extend(args) + self._series.extend(_create_generic_data_series( + annotations=annotations, markers=markers, rectangles=rectangles, + fill=fill)) + + is_real = \ + lambda lim: all(getattr(i, 'is_real', True) for i in lim) + is_finite = \ + lambda lim: all(getattr(i, 'is_finite', True) for i in lim) + + # reduce code repetition + def check_and_set(t_name, t): + if t: + if not is_real(t): + raise ValueError( + "All numbers from {}={} must be real".format(t_name, t)) + if not is_finite(t): + raise ValueError( + "All numbers from {}={} must be finite".format(t_name, t)) + setattr(self, t_name, (float(t[0]), float(t[1]))) + + self.xlim = None + check_and_set("xlim", xlim) + self.ylim = None + check_and_set("ylim", ylim) + self.size = None + check_and_set("size", size) + + @property + def _backend(self): + return self + + @property + def backend(self): + return type(self) + + def __str__(self): + series_strs = [('[%d]: ' % i) + str(s) + for i, s in enumerate(self._series)] + return 'Plot object containing:\n' + '\n'.join(series_strs) + + def __getitem__(self, index): + return self._series[index] + + def __setitem__(self, index, *args): + if len(args) == 1 and isinstance(args[0], BaseSeries): + self._series[index] = args + + def __delitem__(self, index): + del self._series[index] + + def append(self, arg): + """Adds an element from a plot's series to an existing plot. + + Examples + ======== + + Consider two ``Plot`` objects, ``p1`` and ``p2``. To add the + second plot's first series object to the first, use the + ``append`` method, like so: + + .. plot:: + :format: doctest + :include-source: True + + >>> from sympy import symbols + >>> from sympy.plotting import plot + >>> x = symbols('x') + >>> p1 = plot(x*x, show=False) + >>> p2 = plot(x, show=False) + >>> p1.append(p2[0]) + >>> p1 + Plot object containing: + [0]: cartesian line: x**2 for x over (-10.0, 10.0) + [1]: cartesian line: x for x over (-10.0, 10.0) + >>> p1.show() + + See Also + ======== + + extend + + """ + if isinstance(arg, BaseSeries): + self._series.append(arg) + else: + raise TypeError('Must specify element of plot to append.') + + def extend(self, arg): + """Adds all series from another plot. + + Examples + ======== + + Consider two ``Plot`` objects, ``p1`` and ``p2``. To add the + second plot to the first, use the ``extend`` method, like so: + + .. plot:: + :format: doctest + :include-source: True + + >>> from sympy import symbols + >>> from sympy.plotting import plot + >>> x = symbols('x') + >>> p1 = plot(x**2, show=False) + >>> p2 = plot(x, -x, show=False) + >>> p1.extend(p2) + >>> p1 + Plot object containing: + [0]: cartesian line: x**2 for x over (-10.0, 10.0) + [1]: cartesian line: x for x over (-10.0, 10.0) + [2]: cartesian line: -x for x over (-10.0, 10.0) + >>> p1.show() + + """ + if isinstance(arg, Plot): + self._series.extend(arg._series) + elif is_sequence(arg): + self._series.extend(arg) + else: + raise TypeError('Expecting Plot or sequence of BaseSeries') + + def show(self): + raise NotImplementedError + + def save(self, path): + raise NotImplementedError + + def close(self): + raise NotImplementedError + + # deprecations + + @property + def markers(self): + """.. deprecated:: 1.13""" + _deprecation_msg_m_a_r_f("markers") + return self._markers + + @markers.setter + def markers(self, v): + """.. deprecated:: 1.13""" + _deprecation_msg_m_a_r_f("markers") + self._series.extend(_create_generic_data_series(markers=v)) + self._markers = v + + @property + def annotations(self): + """.. deprecated:: 1.13""" + _deprecation_msg_m_a_r_f("annotations") + return self._annotations + + @annotations.setter + def annotations(self, v): + """.. deprecated:: 1.13""" + _deprecation_msg_m_a_r_f("annotations") + self._series.extend(_create_generic_data_series(annotations=v)) + self._annotations = v + + @property + def rectangles(self): + """.. deprecated:: 1.13""" + _deprecation_msg_m_a_r_f("rectangles") + return self._rectangles + + @rectangles.setter + def rectangles(self, v): + """.. deprecated:: 1.13""" + _deprecation_msg_m_a_r_f("rectangles") + self._series.extend(_create_generic_data_series(rectangles=v)) + self._rectangles = v + + @property + def fill(self): + """.. deprecated:: 1.13""" + _deprecation_msg_m_a_r_f("fill") + return self._fill + + @fill.setter + def fill(self, v): + """.. deprecated:: 1.13""" + _deprecation_msg_m_a_r_f("fill") + self._series.extend(_create_generic_data_series(fill=v)) + self._fill = v diff --git a/lib/python3.10/site-packages/sympy/plotting/backends/matplotlibbackend/__init__.py b/lib/python3.10/site-packages/sympy/plotting/backends/matplotlibbackend/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8623940dadb9272730fdeccc1668374781c2e5cf --- /dev/null +++ b/lib/python3.10/site-packages/sympy/plotting/backends/matplotlibbackend/__init__.py @@ -0,0 +1,5 @@ +from sympy.plotting.backends.matplotlibbackend.matplotlib import ( + MatplotlibBackend, _matplotlib_list +) + +__all__ = ["MatplotlibBackend", "_matplotlib_list"] diff --git a/lib/python3.10/site-packages/sympy/plotting/backends/matplotlibbackend/__pycache__/__init__.cpython-310.pyc b/lib/python3.10/site-packages/sympy/plotting/backends/matplotlibbackend/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e500154bfd24c09bfd05dac73029412567a983fa Binary files /dev/null and b/lib/python3.10/site-packages/sympy/plotting/backends/matplotlibbackend/__pycache__/__init__.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/plotting/backends/matplotlibbackend/__pycache__/matplotlib.cpython-310.pyc b/lib/python3.10/site-packages/sympy/plotting/backends/matplotlibbackend/__pycache__/matplotlib.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..92677193fde6cc061564757b2003f4b8ff08325e Binary files /dev/null and b/lib/python3.10/site-packages/sympy/plotting/backends/matplotlibbackend/__pycache__/matplotlib.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/plotting/backends/matplotlibbackend/matplotlib.py b/lib/python3.10/site-packages/sympy/plotting/backends/matplotlibbackend/matplotlib.py new file mode 100644 index 0000000000000000000000000000000000000000..f598a10a7cd17d40e18d1438e8c6bb174071d0a6 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/plotting/backends/matplotlibbackend/matplotlib.py @@ -0,0 +1,318 @@ +from collections.abc import Callable +from sympy.core.basic import Basic +from sympy.external import import_module +import sympy.plotting.backends.base_backend as base_backend +from sympy.printing.latex import latex + + +# N.B. +# When changing the minimum module version for matplotlib, please change +# the same in the `SymPyDocTestFinder`` in `sympy/testing/runtests.py` + + +def _str_or_latex(label): + if isinstance(label, Basic): + return latex(label, mode='inline') + return str(label) + + +def _matplotlib_list(interval_list): + """ + Returns lists for matplotlib ``fill`` command from a list of bounding + rectangular intervals + """ + xlist = [] + ylist = [] + if len(interval_list): + for intervals in interval_list: + intervalx = intervals[0] + intervaly = intervals[1] + xlist.extend([intervalx.start, intervalx.start, + intervalx.end, intervalx.end, None]) + ylist.extend([intervaly.start, intervaly.end, + intervaly.end, intervaly.start, None]) + else: + #XXX Ugly hack. Matplotlib does not accept empty lists for ``fill`` + xlist.extend((None, None, None, None)) + ylist.extend((None, None, None, None)) + return xlist, ylist + + +# Don't have to check for the success of importing matplotlib in each case; +# we will only be using this backend if we can successfully import matploblib +class MatplotlibBackend(base_backend.Plot): + """ This class implements the functionalities to use Matplotlib with SymPy + plotting functions. + """ + + def __init__(self, *series, **kwargs): + super().__init__(*series, **kwargs) + self.matplotlib = import_module('matplotlib', + import_kwargs={'fromlist': ['pyplot', 'cm', 'collections']}, + min_module_version='1.1.0', catch=(RuntimeError,)) + self.plt = self.matplotlib.pyplot + self.cm = self.matplotlib.cm + self.LineCollection = self.matplotlib.collections.LineCollection + self.aspect = kwargs.get('aspect_ratio', 'auto') + if self.aspect != 'auto': + self.aspect = float(self.aspect[1]) / self.aspect[0] + # PlotGrid can provide its figure and axes to be populated with + # the data from the series. + self._plotgrid_fig = kwargs.pop("fig", None) + self._plotgrid_ax = kwargs.pop("ax", None) + + def _create_figure(self): + def set_spines(ax): + ax.spines['left'].set_position('zero') + ax.spines['right'].set_color('none') + ax.spines['bottom'].set_position('zero') + ax.spines['top'].set_color('none') + ax.xaxis.set_ticks_position('bottom') + ax.yaxis.set_ticks_position('left') + + if self._plotgrid_fig is not None: + self.fig = self._plotgrid_fig + self.ax = self._plotgrid_ax + if not any(s.is_3D for s in self._series): + set_spines(self.ax) + else: + self.fig = self.plt.figure(figsize=self.size) + if any(s.is_3D for s in self._series): + self.ax = self.fig.add_subplot(1, 1, 1, projection="3d") + else: + self.ax = self.fig.add_subplot(1, 1, 1) + set_spines(self.ax) + + @staticmethod + def get_segments(x, y, z=None): + """ Convert two list of coordinates to a list of segments to be used + with Matplotlib's :external:class:`~matplotlib.collections.LineCollection`. + + Parameters + ========== + x : list + List of x-coordinates + + y : list + List of y-coordinates + + z : list + List of z-coordinates for a 3D line. + """ + np = import_module('numpy') + if z is not None: + dim = 3 + points = (x, y, z) + else: + dim = 2 + points = (x, y) + points = np.ma.array(points).T.reshape(-1, 1, dim) + return np.ma.concatenate([points[:-1], points[1:]], axis=1) + + def _process_series(self, series, ax): + np = import_module('numpy') + mpl_toolkits = import_module( + 'mpl_toolkits', import_kwargs={'fromlist': ['mplot3d']}) + + # XXX Workaround for matplotlib issue + # https://github.com/matplotlib/matplotlib/issues/17130 + xlims, ylims, zlims = [], [], [] + + for s in series: + # Create the collections + if s.is_2Dline: + if s.is_parametric: + x, y, param = s.get_data() + else: + x, y = s.get_data() + if (isinstance(s.line_color, (int, float)) or + callable(s.line_color)): + segments = self.get_segments(x, y) + collection = self.LineCollection(segments) + collection.set_array(s.get_color_array()) + ax.add_collection(collection) + else: + lbl = _str_or_latex(s.label) + line, = ax.plot(x, y, label=lbl, color=s.line_color) + elif s.is_contour: + ax.contour(*s.get_data()) + elif s.is_3Dline: + x, y, z, param = s.get_data() + if (isinstance(s.line_color, (int, float)) or + callable(s.line_color)): + art3d = mpl_toolkits.mplot3d.art3d + segments = self.get_segments(x, y, z) + collection = art3d.Line3DCollection(segments) + collection.set_array(s.get_color_array()) + ax.add_collection(collection) + else: + lbl = _str_or_latex(s.label) + ax.plot(x, y, z, label=lbl, color=s.line_color) + + xlims.append(s._xlim) + ylims.append(s._ylim) + zlims.append(s._zlim) + elif s.is_3Dsurface: + if s.is_parametric: + x, y, z, u, v = s.get_data() + else: + x, y, z = s.get_data() + collection = ax.plot_surface(x, y, z, + cmap=getattr(self.cm, 'viridis', self.cm.jet), + rstride=1, cstride=1, linewidth=0.1) + if isinstance(s.surface_color, (float, int, Callable)): + color_array = s.get_color_array() + color_array = color_array.reshape(color_array.size) + collection.set_array(color_array) + else: + collection.set_color(s.surface_color) + + xlims.append(s._xlim) + ylims.append(s._ylim) + zlims.append(s._zlim) + elif s.is_implicit: + points = s.get_data() + if len(points) == 2: + # interval math plotting + x, y = _matplotlib_list(points[0]) + ax.fill(x, y, facecolor=s.line_color, edgecolor='None') + else: + # use contourf or contour depending on whether it is + # an inequality or equality. + # XXX: ``contour`` plots multiple lines. Should be fixed. + ListedColormap = self.matplotlib.colors.ListedColormap + colormap = ListedColormap(["white", s.line_color]) + xarray, yarray, zarray, plot_type = points + if plot_type == 'contour': + ax.contour(xarray, yarray, zarray, cmap=colormap) + else: + ax.contourf(xarray, yarray, zarray, cmap=colormap) + elif s.is_generic: + if s.type == "markers": + # s.rendering_kw["color"] = s.line_color + ax.plot(*s.args, **s.rendering_kw) + elif s.type == "annotations": + ax.annotate(*s.args, **s.rendering_kw) + elif s.type == "fill": + # s.rendering_kw["color"] = s.line_color + ax.fill_between(*s.args, **s.rendering_kw) + elif s.type == "rectangles": + # s.rendering_kw["color"] = s.line_color + ax.add_patch( + self.matplotlib.patches.Rectangle( + *s.args, **s.rendering_kw)) + else: + raise NotImplementedError( + '{} is not supported in the SymPy plotting module ' + 'with matplotlib backend. Please report this issue.' + .format(ax)) + + Axes3D = mpl_toolkits.mplot3d.Axes3D + if not isinstance(ax, Axes3D): + ax.autoscale_view( + scalex=ax.get_autoscalex_on(), + scaley=ax.get_autoscaley_on()) + else: + # XXX Workaround for matplotlib issue + # https://github.com/matplotlib/matplotlib/issues/17130 + if xlims: + xlims = np.array(xlims) + xlim = (np.amin(xlims[:, 0]), np.amax(xlims[:, 1])) + ax.set_xlim(xlim) + else: + ax.set_xlim([0, 1]) + + if ylims: + ylims = np.array(ylims) + ylim = (np.amin(ylims[:, 0]), np.amax(ylims[:, 1])) + ax.set_ylim(ylim) + else: + ax.set_ylim([0, 1]) + + if zlims: + zlims = np.array(zlims) + zlim = (np.amin(zlims[:, 0]), np.amax(zlims[:, 1])) + ax.set_zlim(zlim) + else: + ax.set_zlim([0, 1]) + + # Set global options. + # TODO The 3D stuff + # XXX The order of those is important. + if self.xscale and not isinstance(ax, Axes3D): + ax.set_xscale(self.xscale) + if self.yscale and not isinstance(ax, Axes3D): + ax.set_yscale(self.yscale) + if not isinstance(ax, Axes3D) or self.matplotlib.__version__ >= '1.2.0': # XXX in the distant future remove this check + ax.set_autoscale_on(self.autoscale) + if self.axis_center: + val = self.axis_center + if isinstance(ax, Axes3D): + pass + elif val == 'center': + ax.spines['left'].set_position('center') + ax.spines['bottom'].set_position('center') + elif val == 'auto': + xl, xh = ax.get_xlim() + yl, yh = ax.get_ylim() + pos_left = ('data', 0) if xl*xh <= 0 else 'center' + pos_bottom = ('data', 0) if yl*yh <= 0 else 'center' + ax.spines['left'].set_position(pos_left) + ax.spines['bottom'].set_position(pos_bottom) + else: + ax.spines['left'].set_position(('data', val[0])) + ax.spines['bottom'].set_position(('data', val[1])) + if not self.axis: + ax.set_axis_off() + if self.legend: + if ax.legend(): + ax.legend_.set_visible(self.legend) + if self.margin: + ax.set_xmargin(self.margin) + ax.set_ymargin(self.margin) + if self.title: + ax.set_title(self.title) + if self.xlabel: + xlbl = _str_or_latex(self.xlabel) + ax.set_xlabel(xlbl, position=(1, 0)) + if self.ylabel: + ylbl = _str_or_latex(self.ylabel) + ax.set_ylabel(ylbl, position=(0, 1)) + if isinstance(ax, Axes3D) and self.zlabel: + zlbl = _str_or_latex(self.zlabel) + ax.set_zlabel(zlbl, position=(0, 1)) + + # xlim and ylim should always be set at last so that plot limits + # doesn't get altered during the process. + if self.xlim: + ax.set_xlim(self.xlim) + if self.ylim: + ax.set_ylim(self.ylim) + self.ax.set_aspect(self.aspect) + + + def process_series(self): + """ + Iterates over every ``Plot`` object and further calls + _process_series() + """ + self._create_figure() + self._process_series(self._series, self.ax) + + def show(self): + self.process_series() + #TODO after fixing https://github.com/ipython/ipython/issues/1255 + # you can uncomment the next line and remove the pyplot.show() call + #self.fig.show() + if base_backend._show: + self.fig.tight_layout() + self.plt.show() + else: + self.close() + + def save(self, path): + self.process_series() + self.fig.savefig(path) + + def close(self): + self.plt.close(self.fig) diff --git a/lib/python3.10/site-packages/sympy/plotting/backends/textbackend/__init__.py b/lib/python3.10/site-packages/sympy/plotting/backends/textbackend/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ca4685e4b7790653a97b712c27b240ade5bb481a --- /dev/null +++ b/lib/python3.10/site-packages/sympy/plotting/backends/textbackend/__init__.py @@ -0,0 +1,3 @@ +from sympy.plotting.backends.textbackend.text import TextBackend + +__all__ = ["TextBackend"] diff --git a/lib/python3.10/site-packages/sympy/plotting/backends/textbackend/__pycache__/__init__.cpython-310.pyc b/lib/python3.10/site-packages/sympy/plotting/backends/textbackend/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..99e533c52bbec107409e5053c7c036d92fde00c0 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/plotting/backends/textbackend/__pycache__/__init__.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/plotting/backends/textbackend/__pycache__/text.cpython-310.pyc b/lib/python3.10/site-packages/sympy/plotting/backends/textbackend/__pycache__/text.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e6c6d17c468d5c270f1cac0511c45a7cf9f1cefb Binary files /dev/null and b/lib/python3.10/site-packages/sympy/plotting/backends/textbackend/__pycache__/text.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/plotting/backends/textbackend/text.py b/lib/python3.10/site-packages/sympy/plotting/backends/textbackend/text.py new file mode 100644 index 0000000000000000000000000000000000000000..0917ec78b3463a929c373c98fdd279d84ce4c9e5 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/plotting/backends/textbackend/text.py @@ -0,0 +1,24 @@ +import sympy.plotting.backends.base_backend as base_backend +from sympy.plotting.series import LineOver1DRangeSeries +from sympy.plotting.textplot import textplot + + +class TextBackend(base_backend.Plot): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def show(self): + if not base_backend._show: + return + if len(self._series) != 1: + raise ValueError( + 'The TextBackend supports only one graph per Plot.') + elif not isinstance(self._series[0], LineOver1DRangeSeries): + raise ValueError( + 'The TextBackend supports only expressions over a 1D range') + else: + ser = self._series[0] + textplot(ser.expr, ser.start, ser.end) + + def close(self): + pass diff --git a/lib/python3.10/site-packages/sympy/plotting/intervalmath/__init__.py b/lib/python3.10/site-packages/sympy/plotting/intervalmath/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..fb9a6a57f94e931f0c5f5b3dda7b0b6fd31841f4 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/plotting/intervalmath/__init__.py @@ -0,0 +1,12 @@ +from .interval_arithmetic import interval +from .lib_interval import (Abs, exp, log, log10, sin, cos, tan, sqrt, + imin, imax, sinh, cosh, tanh, acosh, asinh, atanh, + asin, acos, atan, ceil, floor, And, Or) + +__all__ = [ + 'interval', + + 'Abs', 'exp', 'log', 'log10', 'sin', 'cos', 'tan', 'sqrt', 'imin', 'imax', + 'sinh', 'cosh', 'tanh', 'acosh', 'asinh', 'atanh', 'asin', 'acos', 'atan', + 'ceil', 'floor', 'And', 'Or', +] diff --git a/lib/python3.10/site-packages/sympy/plotting/intervalmath/__pycache__/__init__.cpython-310.pyc b/lib/python3.10/site-packages/sympy/plotting/intervalmath/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4044c39dbab953f75c514dd220aaf7e5d5ca5f60 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/plotting/intervalmath/__pycache__/__init__.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/plotting/intervalmath/__pycache__/interval_arithmetic.cpython-310.pyc b/lib/python3.10/site-packages/sympy/plotting/intervalmath/__pycache__/interval_arithmetic.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2296b58d23a82aa1211917690dc728447aa2fe03 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/plotting/intervalmath/__pycache__/interval_arithmetic.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/plotting/intervalmath/__pycache__/interval_membership.cpython-310.pyc b/lib/python3.10/site-packages/sympy/plotting/intervalmath/__pycache__/interval_membership.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dcd7a830763277a649c8b48ee8c9e05f143f274d Binary files /dev/null and b/lib/python3.10/site-packages/sympy/plotting/intervalmath/__pycache__/interval_membership.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/plotting/intervalmath/__pycache__/lib_interval.cpython-310.pyc b/lib/python3.10/site-packages/sympy/plotting/intervalmath/__pycache__/lib_interval.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fe0fe6d76188b39f0dbf8febb741af7054e2c440 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/plotting/intervalmath/__pycache__/lib_interval.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/plotting/intervalmath/interval_arithmetic.py b/lib/python3.10/site-packages/sympy/plotting/intervalmath/interval_arithmetic.py new file mode 100644 index 0000000000000000000000000000000000000000..fc5c0e2ef118c7cf4f80de53a3590de11130410e --- /dev/null +++ b/lib/python3.10/site-packages/sympy/plotting/intervalmath/interval_arithmetic.py @@ -0,0 +1,413 @@ +""" +Interval Arithmetic for plotting. +This module does not implement interval arithmetic accurately and +hence cannot be used for purposes other than plotting. If you want +to use interval arithmetic, use mpmath's interval arithmetic. + +The module implements interval arithmetic using numpy and +python floating points. The rounding up and down is not handled +and hence this is not an accurate implementation of interval +arithmetic. + +The module uses numpy for speed which cannot be achieved with mpmath. +""" + +# Q: Why use numpy? Why not simply use mpmath's interval arithmetic? +# A: mpmath's interval arithmetic simulates a floating point unit +# and hence is slow, while numpy evaluations are orders of magnitude +# faster. + +# Q: Why create a separate class for intervals? Why not use SymPy's +# Interval Sets? +# A: The functionalities that will be required for plotting is quite +# different from what Interval Sets implement. + +# Q: Why is rounding up and down according to IEEE754 not handled? +# A: It is not possible to do it in both numpy and python. An external +# library has to used, which defeats the whole purpose i.e., speed. Also +# rounding is handled for very few functions in those libraries. + +# Q Will my plots be affected? +# A It will not affect most of the plots. The interval arithmetic +# module based suffers the same problems as that of floating point +# arithmetic. + +from sympy.core.numbers import int_valued +from sympy.core.logic import fuzzy_and +from sympy.simplify.simplify import nsimplify + +from .interval_membership import intervalMembership + + +class interval: + """ Represents an interval containing floating points as start and + end of the interval + The is_valid variable tracks whether the interval obtained as the + result of the function is in the domain and is continuous. + - True: Represents the interval result of a function is continuous and + in the domain of the function. + - False: The interval argument of the function was not in the domain of + the function, hence the is_valid of the result interval is False + - None: The function was not continuous over the interval or + the function's argument interval is partly in the domain of the + function + + A comparison between an interval and a real number, or a + comparison between two intervals may return ``intervalMembership`` + of two 3-valued logic values. + """ + + def __init__(self, *args, is_valid=True, **kwargs): + self.is_valid = is_valid + if len(args) == 1: + if isinstance(args[0], interval): + self.start, self.end = args[0].start, args[0].end + else: + self.start = float(args[0]) + self.end = float(args[0]) + elif len(args) == 2: + if args[0] < args[1]: + self.start = float(args[0]) + self.end = float(args[1]) + else: + self.start = float(args[1]) + self.end = float(args[0]) + + else: + raise ValueError("interval takes a maximum of two float values " + "as arguments") + + @property + def mid(self): + return (self.start + self.end) / 2.0 + + @property + def width(self): + return self.end - self.start + + def __repr__(self): + return "interval(%f, %f)" % (self.start, self.end) + + def __str__(self): + return "[%f, %f]" % (self.start, self.end) + + def __lt__(self, other): + if isinstance(other, (int, float)): + if self.end < other: + return intervalMembership(True, self.is_valid) + elif self.start > other: + return intervalMembership(False, self.is_valid) + else: + return intervalMembership(None, self.is_valid) + + elif isinstance(other, interval): + valid = fuzzy_and([self.is_valid, other.is_valid]) + if self.end < other. start: + return intervalMembership(True, valid) + if self.start > other.end: + return intervalMembership(False, valid) + return intervalMembership(None, valid) + else: + return NotImplemented + + def __gt__(self, other): + if isinstance(other, (int, float)): + if self.start > other: + return intervalMembership(True, self.is_valid) + elif self.end < other: + return intervalMembership(False, self.is_valid) + else: + return intervalMembership(None, self.is_valid) + elif isinstance(other, interval): + return other.__lt__(self) + else: + return NotImplemented + + def __eq__(self, other): + if isinstance(other, (int, float)): + if self.start == other and self.end == other: + return intervalMembership(True, self.is_valid) + if other in self: + return intervalMembership(None, self.is_valid) + else: + return intervalMembership(False, self.is_valid) + + if isinstance(other, interval): + valid = fuzzy_and([self.is_valid, other.is_valid]) + if self.start == other.start and self.end == other.end: + return intervalMembership(True, valid) + elif self.__lt__(other)[0] is not None: + return intervalMembership(False, valid) + else: + return intervalMembership(None, valid) + else: + return NotImplemented + + def __ne__(self, other): + if isinstance(other, (int, float)): + if self.start == other and self.end == other: + return intervalMembership(False, self.is_valid) + if other in self: + return intervalMembership(None, self.is_valid) + else: + return intervalMembership(True, self.is_valid) + + if isinstance(other, interval): + valid = fuzzy_and([self.is_valid, other.is_valid]) + if self.start == other.start and self.end == other.end: + return intervalMembership(False, valid) + if not self.__lt__(other)[0] is None: + return intervalMembership(True, valid) + return intervalMembership(None, valid) + else: + return NotImplemented + + def __le__(self, other): + if isinstance(other, (int, float)): + if self.end <= other: + return intervalMembership(True, self.is_valid) + if self.start > other: + return intervalMembership(False, self.is_valid) + else: + return intervalMembership(None, self.is_valid) + + if isinstance(other, interval): + valid = fuzzy_and([self.is_valid, other.is_valid]) + if self.end <= other.start: + return intervalMembership(True, valid) + if self.start > other.end: + return intervalMembership(False, valid) + return intervalMembership(None, valid) + else: + return NotImplemented + + def __ge__(self, other): + if isinstance(other, (int, float)): + if self.start >= other: + return intervalMembership(True, self.is_valid) + elif self.end < other: + return intervalMembership(False, self.is_valid) + else: + return intervalMembership(None, self.is_valid) + elif isinstance(other, interval): + return other.__le__(self) + + def __add__(self, other): + if isinstance(other, (int, float)): + if self.is_valid: + return interval(self.start + other, self.end + other) + else: + start = self.start + other + end = self.end + other + return interval(start, end, is_valid=self.is_valid) + + elif isinstance(other, interval): + start = self.start + other.start + end = self.end + other.end + valid = fuzzy_and([self.is_valid, other.is_valid]) + return interval(start, end, is_valid=valid) + else: + return NotImplemented + + __radd__ = __add__ + + def __sub__(self, other): + if isinstance(other, (int, float)): + start = self.start - other + end = self.end - other + return interval(start, end, is_valid=self.is_valid) + + elif isinstance(other, interval): + start = self.start - other.end + end = self.end - other.start + valid = fuzzy_and([self.is_valid, other.is_valid]) + return interval(start, end, is_valid=valid) + else: + return NotImplemented + + def __rsub__(self, other): + if isinstance(other, (int, float)): + start = other - self.end + end = other - self.start + return interval(start, end, is_valid=self.is_valid) + elif isinstance(other, interval): + return other.__sub__(self) + else: + return NotImplemented + + def __neg__(self): + if self.is_valid: + return interval(-self.end, -self.start) + else: + return interval(-self.end, -self.start, is_valid=self.is_valid) + + def __mul__(self, other): + if isinstance(other, interval): + if self.is_valid is False or other.is_valid is False: + return interval(-float('inf'), float('inf'), is_valid=False) + elif self.is_valid is None or other.is_valid is None: + return interval(-float('inf'), float('inf'), is_valid=None) + else: + inters = [] + inters.append(self.start * other.start) + inters.append(self.end * other.start) + inters.append(self.start * other.end) + inters.append(self.end * other.end) + start = min(inters) + end = max(inters) + return interval(start, end) + elif isinstance(other, (int, float)): + return interval(self.start*other, self.end*other, is_valid=self.is_valid) + else: + return NotImplemented + + __rmul__ = __mul__ + + def __contains__(self, other): + if isinstance(other, (int, float)): + return self.start <= other and self.end >= other + else: + return self.start <= other.start and other.end <= self.end + + def __rtruediv__(self, other): + if isinstance(other, (int, float)): + other = interval(other) + return other.__truediv__(self) + elif isinstance(other, interval): + return other.__truediv__(self) + else: + return NotImplemented + + def __truediv__(self, other): + # Both None and False are handled + if not self.is_valid: + # Don't divide as the value is not valid + return interval(-float('inf'), float('inf'), is_valid=self.is_valid) + if isinstance(other, (int, float)): + if other == 0: + # Divide by zero encountered. valid nowhere + return interval(-float('inf'), float('inf'), is_valid=False) + else: + return interval(self.start / other, self.end / other) + + elif isinstance(other, interval): + if other.is_valid is False or self.is_valid is False: + return interval(-float('inf'), float('inf'), is_valid=False) + elif other.is_valid is None or self.is_valid is None: + return interval(-float('inf'), float('inf'), is_valid=None) + else: + # denominator contains both signs, i.e. being divided by zero + # return the whole real line with is_valid = None + if 0 in other: + return interval(-float('inf'), float('inf'), is_valid=None) + + # denominator negative + this = self + if other.end < 0: + this = -this + other = -other + + # denominator positive + inters = [] + inters.append(this.start / other.start) + inters.append(this.end / other.start) + inters.append(this.start / other.end) + inters.append(this.end / other.end) + start = max(inters) + end = min(inters) + return interval(start, end) + else: + return NotImplemented + + def __pow__(self, other): + # Implements only power to an integer. + from .lib_interval import exp, log + if not self.is_valid: + return self + if isinstance(other, interval): + return exp(other * log(self)) + elif isinstance(other, (float, int)): + if other < 0: + return 1 / self.__pow__(abs(other)) + else: + if int_valued(other): + return _pow_int(self, other) + else: + return _pow_float(self, other) + else: + return NotImplemented + + def __rpow__(self, other): + if isinstance(other, (float, int)): + if not self.is_valid: + #Don't do anything + return self + elif other < 0: + if self.width > 0: + return interval(-float('inf'), float('inf'), is_valid=False) + else: + power_rational = nsimplify(self.start) + num, denom = power_rational.as_numer_denom() + if denom % 2 == 0: + return interval(-float('inf'), float('inf'), + is_valid=False) + else: + start = -abs(other)**self.start + end = start + return interval(start, end) + else: + return interval(other**self.start, other**self.end) + elif isinstance(other, interval): + return other.__pow__(self) + else: + return NotImplemented + + def __hash__(self): + return hash((self.is_valid, self.start, self.end)) + + +def _pow_float(inter, power): + """Evaluates an interval raised to a floating point.""" + power_rational = nsimplify(power) + num, denom = power_rational.as_numer_denom() + if num % 2 == 0: + start = abs(inter.start)**power + end = abs(inter.end)**power + if start < 0: + ret = interval(0, max(start, end)) + else: + ret = interval(start, end) + return ret + elif denom % 2 == 0: + if inter.end < 0: + return interval(-float('inf'), float('inf'), is_valid=False) + elif inter.start < 0: + return interval(0, inter.end**power, is_valid=None) + else: + return interval(inter.start**power, inter.end**power) + else: + if inter.start < 0: + start = -abs(inter.start)**power + else: + start = inter.start**power + + if inter.end < 0: + end = -abs(inter.end)**power + else: + end = inter.end**power + + return interval(start, end, is_valid=inter.is_valid) + + +def _pow_int(inter, power): + """Evaluates an interval raised to an integer power""" + power = int(power) + if power & 1: + return interval(inter.start**power, inter.end**power) + else: + if inter.start < 0 and inter.end > 0: + start = 0 + end = max(inter.start**power, inter.end**power) + return interval(start, end) + else: + return interval(inter.start**power, inter.end**power) diff --git a/lib/python3.10/site-packages/sympy/plotting/intervalmath/interval_membership.py b/lib/python3.10/site-packages/sympy/plotting/intervalmath/interval_membership.py new file mode 100644 index 0000000000000000000000000000000000000000..c4887c2d96f0d006b95a8e207a4f4a75940aec23 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/plotting/intervalmath/interval_membership.py @@ -0,0 +1,78 @@ +from sympy.core.logic import fuzzy_and, fuzzy_or, fuzzy_not, fuzzy_xor + + +class intervalMembership: + """Represents a boolean expression returned by the comparison of + the interval object. + + Parameters + ========== + + (a, b) : (bool, bool) + The first value determines the comparison as follows: + - True: If the comparison is True throughout the intervals. + - False: If the comparison is False throughout the intervals. + - None: If the comparison is True for some part of the intervals. + + The second value is determined as follows: + - True: If both the intervals in comparison are valid. + - False: If at least one of the intervals is False, else + - None + """ + def __init__(self, a, b): + self._wrapped = (a, b) + + def __getitem__(self, i): + try: + return self._wrapped[i] + except IndexError: + raise IndexError( + "{} must be a valid indexing for the 2-tuple." + .format(i)) + + def __len__(self): + return 2 + + def __iter__(self): + return iter(self._wrapped) + + def __str__(self): + return "intervalMembership({}, {})".format(*self) + __repr__ = __str__ + + def __and__(self, other): + if not isinstance(other, intervalMembership): + raise ValueError( + "The comparison is not supported for {}.".format(other)) + + a1, b1 = self + a2, b2 = other + return intervalMembership(fuzzy_and([a1, a2]), fuzzy_and([b1, b2])) + + def __or__(self, other): + if not isinstance(other, intervalMembership): + raise ValueError( + "The comparison is not supported for {}.".format(other)) + + a1, b1 = self + a2, b2 = other + return intervalMembership(fuzzy_or([a1, a2]), fuzzy_and([b1, b2])) + + def __invert__(self): + a, b = self + return intervalMembership(fuzzy_not(a), b) + + def __xor__(self, other): + if not isinstance(other, intervalMembership): + raise ValueError( + "The comparison is not supported for {}.".format(other)) + + a1, b1 = self + a2, b2 = other + return intervalMembership(fuzzy_xor([a1, a2]), fuzzy_and([b1, b2])) + + def __eq__(self, other): + return self._wrapped == other + + def __ne__(self, other): + return self._wrapped != other diff --git a/lib/python3.10/site-packages/sympy/plotting/intervalmath/lib_interval.py b/lib/python3.10/site-packages/sympy/plotting/intervalmath/lib_interval.py new file mode 100644 index 0000000000000000000000000000000000000000..7549a05820d747ce057892f8df1fbcbc61cc3f43 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/plotting/intervalmath/lib_interval.py @@ -0,0 +1,452 @@ +""" The module contains implemented functions for interval arithmetic.""" +from functools import reduce + +from sympy.plotting.intervalmath import interval +from sympy.external import import_module + + +def Abs(x): + if isinstance(x, (int, float)): + return interval(abs(x)) + elif isinstance(x, interval): + if x.start < 0 and x.end > 0: + return interval(0, max(abs(x.start), abs(x.end)), is_valid=x.is_valid) + else: + return interval(abs(x.start), abs(x.end)) + else: + raise NotImplementedError + +#Monotonic + + +def exp(x): + """evaluates the exponential of an interval""" + np = import_module('numpy') + if isinstance(x, (int, float)): + return interval(np.exp(x), np.exp(x)) + elif isinstance(x, interval): + return interval(np.exp(x.start), np.exp(x.end), is_valid=x.is_valid) + else: + raise NotImplementedError + + +#Monotonic +def log(x): + """evaluates the natural logarithm of an interval""" + np = import_module('numpy') + if isinstance(x, (int, float)): + if x <= 0: + return interval(-np.inf, np.inf, is_valid=False) + else: + return interval(np.log(x)) + elif isinstance(x, interval): + if not x.is_valid: + return interval(-np.inf, np.inf, is_valid=x.is_valid) + elif x.end <= 0: + return interval(-np.inf, np.inf, is_valid=False) + elif x.start <= 0: + return interval(-np.inf, np.inf, is_valid=None) + + return interval(np.log(x.start), np.log(x.end)) + else: + raise NotImplementedError + + +#Monotonic +def log10(x): + """evaluates the logarithm to the base 10 of an interval""" + np = import_module('numpy') + if isinstance(x, (int, float)): + if x <= 0: + return interval(-np.inf, np.inf, is_valid=False) + else: + return interval(np.log10(x)) + elif isinstance(x, interval): + if not x.is_valid: + return interval(-np.inf, np.inf, is_valid=x.is_valid) + elif x.end <= 0: + return interval(-np.inf, np.inf, is_valid=False) + elif x.start <= 0: + return interval(-np.inf, np.inf, is_valid=None) + return interval(np.log10(x.start), np.log10(x.end)) + else: + raise NotImplementedError + + +#Monotonic +def atan(x): + """evaluates the tan inverse of an interval""" + np = import_module('numpy') + if isinstance(x, (int, float)): + return interval(np.arctan(x)) + elif isinstance(x, interval): + start = np.arctan(x.start) + end = np.arctan(x.end) + return interval(start, end, is_valid=x.is_valid) + else: + raise NotImplementedError + + +#periodic +def sin(x): + """evaluates the sine of an interval""" + np = import_module('numpy') + if isinstance(x, (int, float)): + return interval(np.sin(x)) + elif isinstance(x, interval): + if not x.is_valid: + return interval(-1, 1, is_valid=x.is_valid) + na, __ = divmod(x.start, np.pi / 2.0) + nb, __ = divmod(x.end, np.pi / 2.0) + start = min(np.sin(x.start), np.sin(x.end)) + end = max(np.sin(x.start), np.sin(x.end)) + if nb - na > 4: + return interval(-1, 1, is_valid=x.is_valid) + elif na == nb: + return interval(start, end, is_valid=x.is_valid) + else: + if (na - 1) // 4 != (nb - 1) // 4: + #sin has max + end = 1 + if (na - 3) // 4 != (nb - 3) // 4: + #sin has min + start = -1 + return interval(start, end) + else: + raise NotImplementedError + + +#periodic +def cos(x): + """Evaluates the cos of an interval""" + np = import_module('numpy') + if isinstance(x, (int, float)): + return interval(np.sin(x)) + elif isinstance(x, interval): + if not (np.isfinite(x.start) and np.isfinite(x.end)): + return interval(-1, 1, is_valid=x.is_valid) + na, __ = divmod(x.start, np.pi / 2.0) + nb, __ = divmod(x.end, np.pi / 2.0) + start = min(np.cos(x.start), np.cos(x.end)) + end = max(np.cos(x.start), np.cos(x.end)) + if nb - na > 4: + #differ more than 2*pi + return interval(-1, 1, is_valid=x.is_valid) + elif na == nb: + #in the same quadarant + return interval(start, end, is_valid=x.is_valid) + else: + if (na) // 4 != (nb) // 4: + #cos has max + end = 1 + if (na - 2) // 4 != (nb - 2) // 4: + #cos has min + start = -1 + return interval(start, end, is_valid=x.is_valid) + else: + raise NotImplementedError + + +def tan(x): + """Evaluates the tan of an interval""" + return sin(x) / cos(x) + + +#Monotonic +def sqrt(x): + """Evaluates the square root of an interval""" + np = import_module('numpy') + if isinstance(x, (int, float)): + if x > 0: + return interval(np.sqrt(x)) + else: + return interval(-np.inf, np.inf, is_valid=False) + elif isinstance(x, interval): + #Outside the domain + if x.end < 0: + return interval(-np.inf, np.inf, is_valid=False) + #Partially outside the domain + elif x.start < 0: + return interval(-np.inf, np.inf, is_valid=None) + else: + return interval(np.sqrt(x.start), np.sqrt(x.end), + is_valid=x.is_valid) + else: + raise NotImplementedError + + +def imin(*args): + """Evaluates the minimum of a list of intervals""" + np = import_module('numpy') + if not all(isinstance(arg, (int, float, interval)) for arg in args): + return NotImplementedError + else: + new_args = [a for a in args if isinstance(a, (int, float)) + or a.is_valid] + if len(new_args) == 0: + if all(a.is_valid is False for a in args): + return interval(-np.inf, np.inf, is_valid=False) + else: + return interval(-np.inf, np.inf, is_valid=None) + start_array = [a if isinstance(a, (int, float)) else a.start + for a in new_args] + + end_array = [a if isinstance(a, (int, float)) else a.end + for a in new_args] + return interval(min(start_array), min(end_array)) + + +def imax(*args): + """Evaluates the maximum of a list of intervals""" + np = import_module('numpy') + if not all(isinstance(arg, (int, float, interval)) for arg in args): + return NotImplementedError + else: + new_args = [a for a in args if isinstance(a, (int, float)) + or a.is_valid] + if len(new_args) == 0: + if all(a.is_valid is False for a in args): + return interval(-np.inf, np.inf, is_valid=False) + else: + return interval(-np.inf, np.inf, is_valid=None) + start_array = [a if isinstance(a, (int, float)) else a.start + for a in new_args] + + end_array = [a if isinstance(a, (int, float)) else a.end + for a in new_args] + + return interval(max(start_array), max(end_array)) + + +#Monotonic +def sinh(x): + """Evaluates the hyperbolic sine of an interval""" + np = import_module('numpy') + if isinstance(x, (int, float)): + return interval(np.sinh(x), np.sinh(x)) + elif isinstance(x, interval): + return interval(np.sinh(x.start), np.sinh(x.end), is_valid=x.is_valid) + else: + raise NotImplementedError + + +def cosh(x): + """Evaluates the hyperbolic cos of an interval""" + np = import_module('numpy') + if isinstance(x, (int, float)): + return interval(np.cosh(x), np.cosh(x)) + elif isinstance(x, interval): + #both signs + if x.start < 0 and x.end > 0: + end = max(np.cosh(x.start), np.cosh(x.end)) + return interval(1, end, is_valid=x.is_valid) + else: + #Monotonic + start = np.cosh(x.start) + end = np.cosh(x.end) + return interval(start, end, is_valid=x.is_valid) + else: + raise NotImplementedError + + +#Monotonic +def tanh(x): + """Evaluates the hyperbolic tan of an interval""" + np = import_module('numpy') + if isinstance(x, (int, float)): + return interval(np.tanh(x), np.tanh(x)) + elif isinstance(x, interval): + return interval(np.tanh(x.start), np.tanh(x.end), is_valid=x.is_valid) + else: + raise NotImplementedError + + +def asin(x): + """Evaluates the inverse sine of an interval""" + np = import_module('numpy') + if isinstance(x, (int, float)): + #Outside the domain + if abs(x) > 1: + return interval(-np.inf, np.inf, is_valid=False) + else: + return interval(np.arcsin(x), np.arcsin(x)) + elif isinstance(x, interval): + #Outside the domain + if x.is_valid is False or x.start > 1 or x.end < -1: + return interval(-np.inf, np.inf, is_valid=False) + #Partially outside the domain + elif x.start < -1 or x.end > 1: + return interval(-np.inf, np.inf, is_valid=None) + else: + start = np.arcsin(x.start) + end = np.arcsin(x.end) + return interval(start, end, is_valid=x.is_valid) + + +def acos(x): + """Evaluates the inverse cos of an interval""" + np = import_module('numpy') + if isinstance(x, (int, float)): + if abs(x) > 1: + #Outside the domain + return interval(-np.inf, np.inf, is_valid=False) + else: + return interval(np.arccos(x), np.arccos(x)) + elif isinstance(x, interval): + #Outside the domain + if x.is_valid is False or x.start > 1 or x.end < -1: + return interval(-np.inf, np.inf, is_valid=False) + #Partially outside the domain + elif x.start < -1 or x.end > 1: + return interval(-np.inf, np.inf, is_valid=None) + else: + start = np.arccos(x.start) + end = np.arccos(x.end) + return interval(start, end, is_valid=x.is_valid) + + +def ceil(x): + """Evaluates the ceiling of an interval""" + np = import_module('numpy') + if isinstance(x, (int, float)): + return interval(np.ceil(x)) + elif isinstance(x, interval): + if x.is_valid is False: + return interval(-np.inf, np.inf, is_valid=False) + else: + start = np.ceil(x.start) + end = np.ceil(x.end) + #Continuous over the interval + if start == end: + return interval(start, end, is_valid=x.is_valid) + else: + #Not continuous over the interval + return interval(start, end, is_valid=None) + else: + return NotImplementedError + + +def floor(x): + """Evaluates the floor of an interval""" + np = import_module('numpy') + if isinstance(x, (int, float)): + return interval(np.floor(x)) + elif isinstance(x, interval): + if x.is_valid is False: + return interval(-np.inf, np.inf, is_valid=False) + else: + start = np.floor(x.start) + end = np.floor(x.end) + #continuous over the argument + if start == end: + return interval(start, end, is_valid=x.is_valid) + else: + #not continuous over the interval + return interval(start, end, is_valid=None) + else: + return NotImplementedError + + +def acosh(x): + """Evaluates the inverse hyperbolic cosine of an interval""" + np = import_module('numpy') + if isinstance(x, (int, float)): + #Outside the domain + if x < 1: + return interval(-np.inf, np.inf, is_valid=False) + else: + return interval(np.arccosh(x)) + elif isinstance(x, interval): + #Outside the domain + if x.end < 1: + return interval(-np.inf, np.inf, is_valid=False) + #Partly outside the domain + elif x.start < 1: + return interval(-np.inf, np.inf, is_valid=None) + else: + start = np.arccosh(x.start) + end = np.arccosh(x.end) + return interval(start, end, is_valid=x.is_valid) + else: + return NotImplementedError + + +#Monotonic +def asinh(x): + """Evaluates the inverse hyperbolic sine of an interval""" + np = import_module('numpy') + if isinstance(x, (int, float)): + return interval(np.arcsinh(x)) + elif isinstance(x, interval): + start = np.arcsinh(x.start) + end = np.arcsinh(x.end) + return interval(start, end, is_valid=x.is_valid) + else: + return NotImplementedError + + +def atanh(x): + """Evaluates the inverse hyperbolic tangent of an interval""" + np = import_module('numpy') + if isinstance(x, (int, float)): + #Outside the domain + if abs(x) >= 1: + return interval(-np.inf, np.inf, is_valid=False) + else: + return interval(np.arctanh(x)) + elif isinstance(x, interval): + #outside the domain + if x.is_valid is False or x.start >= 1 or x.end <= -1: + return interval(-np.inf, np.inf, is_valid=False) + #partly outside the domain + elif x.start <= -1 or x.end >= 1: + return interval(-np.inf, np.inf, is_valid=None) + else: + start = np.arctanh(x.start) + end = np.arctanh(x.end) + return interval(start, end, is_valid=x.is_valid) + else: + return NotImplementedError + + +#Three valued logic for interval plotting. + +def And(*args): + """Defines the three valued ``And`` behaviour for a 2-tuple of + three valued logic values""" + def reduce_and(cmp_intervala, cmp_intervalb): + if cmp_intervala[0] is False or cmp_intervalb[0] is False: + first = False + elif cmp_intervala[0] is None or cmp_intervalb[0] is None: + first = None + else: + first = True + if cmp_intervala[1] is False or cmp_intervalb[1] is False: + second = False + elif cmp_intervala[1] is None or cmp_intervalb[1] is None: + second = None + else: + second = True + return (first, second) + return reduce(reduce_and, args) + + +def Or(*args): + """Defines the three valued ``Or`` behaviour for a 2-tuple of + three valued logic values""" + def reduce_or(cmp_intervala, cmp_intervalb): + if cmp_intervala[0] is True or cmp_intervalb[0] is True: + first = True + elif cmp_intervala[0] is None or cmp_intervalb[0] is None: + first = None + else: + first = False + + if cmp_intervala[1] is True or cmp_intervalb[1] is True: + second = True + elif cmp_intervala[1] is None or cmp_intervalb[1] is None: + second = None + else: + second = False + return (first, second) + return reduce(reduce_or, args) diff --git a/lib/python3.10/site-packages/sympy/plotting/intervalmath/tests/__init__.py b/lib/python3.10/site-packages/sympy/plotting/intervalmath/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/lib/python3.10/site-packages/sympy/plotting/intervalmath/tests/__pycache__/__init__.cpython-310.pyc b/lib/python3.10/site-packages/sympy/plotting/intervalmath/tests/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2117f7de25dc3f9a25833c3a7869f53d25289f6c Binary files /dev/null and b/lib/python3.10/site-packages/sympy/plotting/intervalmath/tests/__pycache__/__init__.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/plotting/intervalmath/tests/__pycache__/test_interval_functions.cpython-310.pyc b/lib/python3.10/site-packages/sympy/plotting/intervalmath/tests/__pycache__/test_interval_functions.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..102c8d49d4269f05945700912a5b0826c635fe7a Binary files /dev/null and b/lib/python3.10/site-packages/sympy/plotting/intervalmath/tests/__pycache__/test_interval_functions.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/plotting/intervalmath/tests/__pycache__/test_interval_membership.cpython-310.pyc b/lib/python3.10/site-packages/sympy/plotting/intervalmath/tests/__pycache__/test_interval_membership.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fdae85a17765df27228ecdbc1f1943bcf4c772eb Binary files /dev/null and b/lib/python3.10/site-packages/sympy/plotting/intervalmath/tests/__pycache__/test_interval_membership.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/plotting/intervalmath/tests/__pycache__/test_intervalmath.cpython-310.pyc b/lib/python3.10/site-packages/sympy/plotting/intervalmath/tests/__pycache__/test_intervalmath.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f5ec70cf05cc9ccf1a5662e427c02b02e954f62e Binary files /dev/null and b/lib/python3.10/site-packages/sympy/plotting/intervalmath/tests/__pycache__/test_intervalmath.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/plotting/intervalmath/tests/test_interval_functions.py b/lib/python3.10/site-packages/sympy/plotting/intervalmath/tests/test_interval_functions.py new file mode 100644 index 0000000000000000000000000000000000000000..861c3660df024d3fbec788a027708348e9929655 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/plotting/intervalmath/tests/test_interval_functions.py @@ -0,0 +1,415 @@ +from sympy.external import import_module +from sympy.plotting.intervalmath import ( + Abs, acos, acosh, And, asin, asinh, atan, atanh, ceil, cos, cosh, + exp, floor, imax, imin, interval, log, log10, Or, sin, sinh, sqrt, + tan, tanh, +) + +np = import_module('numpy') +if not np: + disabled = True + + +#requires Numpy. Hence included in interval_functions + + +def test_interval_pow(): + a = 2**interval(1, 2) == interval(2, 4) + assert a == (True, True) + a = interval(1, 2)**interval(1, 2) == interval(1, 4) + assert a == (True, True) + a = interval(-1, 1)**interval(0.5, 2) + assert a.is_valid is None + a = interval(-2, -1) ** interval(1, 2) + assert a.is_valid is False + a = interval(-2, -1) ** (1.0 / 2) + assert a.is_valid is False + a = interval(-1, 1)**(1.0 / 2) + assert a.is_valid is None + a = interval(-1, 1)**(1.0 / 3) == interval(-1, 1) + assert a == (True, True) + a = interval(-1, 1)**2 == interval(0, 1) + assert a == (True, True) + a = interval(-1, 1) ** (1.0 / 29) == interval(-1, 1) + assert a == (True, True) + a = -2**interval(1, 1) == interval(-2, -2) + assert a == (True, True) + + a = interval(1, 2, is_valid=False)**2 + assert a.is_valid is False + + a = (-3)**interval(1, 2) + assert a.is_valid is False + a = (-4)**interval(0.5, 0.5) + assert a.is_valid is False + assert ((-3)**interval(1, 1) == interval(-3, -3)) == (True, True) + + a = interval(8, 64)**(2.0 / 3) + assert abs(a.start - 4) < 1e-10 # eps + assert abs(a.end - 16) < 1e-10 + a = interval(-8, 64)**(2.0 / 3) + assert abs(a.start - 4) < 1e-10 # eps + assert abs(a.end - 16) < 1e-10 + + +def test_exp(): + a = exp(interval(-np.inf, 0)) + assert a.start == np.exp(-np.inf) + assert a.end == np.exp(0) + a = exp(interval(1, 2)) + assert a.start == np.exp(1) + assert a.end == np.exp(2) + a = exp(1) + assert a.start == np.exp(1) + assert a.end == np.exp(1) + + +def test_log(): + a = log(interval(1, 2)) + assert a.start == 0 + assert a.end == np.log(2) + a = log(interval(-1, 1)) + assert a.is_valid is None + a = log(interval(-3, -1)) + assert a.is_valid is False + a = log(-3) + assert a.is_valid is False + a = log(2) + assert a.start == np.log(2) + assert a.end == np.log(2) + + +def test_log10(): + a = log10(interval(1, 2)) + assert a.start == 0 + assert a.end == np.log10(2) + a = log10(interval(-1, 1)) + assert a.is_valid is None + a = log10(interval(-3, -1)) + assert a.is_valid is False + a = log10(-3) + assert a.is_valid is False + a = log10(2) + assert a.start == np.log10(2) + assert a.end == np.log10(2) + + +def test_atan(): + a = atan(interval(0, 1)) + assert a.start == np.arctan(0) + assert a.end == np.arctan(1) + a = atan(1) + assert a.start == np.arctan(1) + assert a.end == np.arctan(1) + + +def test_sin(): + a = sin(interval(0, np.pi / 4)) + assert a.start == np.sin(0) + assert a.end == np.sin(np.pi / 4) + + a = sin(interval(-np.pi / 4, np.pi / 4)) + assert a.start == np.sin(-np.pi / 4) + assert a.end == np.sin(np.pi / 4) + + a = sin(interval(np.pi / 4, 3 * np.pi / 4)) + assert a.start == np.sin(np.pi / 4) + assert a.end == 1 + + a = sin(interval(7 * np.pi / 6, 7 * np.pi / 4)) + assert a.start == -1 + assert a.end == np.sin(7 * np.pi / 6) + + a = sin(interval(0, 3 * np.pi)) + assert a.start == -1 + assert a.end == 1 + + a = sin(interval(np.pi / 3, 7 * np.pi / 4)) + assert a.start == -1 + assert a.end == 1 + + a = sin(np.pi / 4) + assert a.start == np.sin(np.pi / 4) + assert a.end == np.sin(np.pi / 4) + + a = sin(interval(1, 2, is_valid=False)) + assert a.is_valid is False + + +def test_cos(): + a = cos(interval(0, np.pi / 4)) + assert a.start == np.cos(np.pi / 4) + assert a.end == 1 + + a = cos(interval(-np.pi / 4, np.pi / 4)) + assert a.start == np.cos(-np.pi / 4) + assert a.end == 1 + + a = cos(interval(np.pi / 4, 3 * np.pi / 4)) + assert a.start == np.cos(3 * np.pi / 4) + assert a.end == np.cos(np.pi / 4) + + a = cos(interval(3 * np.pi / 4, 5 * np.pi / 4)) + assert a.start == -1 + assert a.end == np.cos(3 * np.pi / 4) + + a = cos(interval(0, 3 * np.pi)) + assert a.start == -1 + assert a.end == 1 + + a = cos(interval(- np.pi / 3, 5 * np.pi / 4)) + assert a.start == -1 + assert a.end == 1 + + a = cos(interval(1, 2, is_valid=False)) + assert a.is_valid is False + + +def test_tan(): + a = tan(interval(0, np.pi / 4)) + assert a.start == 0 + # must match lib_interval definition of tan: + assert a.end == np.sin(np.pi / 4)/np.cos(np.pi / 4) + + a = tan(interval(np.pi / 4, 3 * np.pi / 4)) + #discontinuity + assert a.is_valid is None + + +def test_sqrt(): + a = sqrt(interval(1, 4)) + assert a.start == 1 + assert a.end == 2 + + a = sqrt(interval(0.01, 1)) + assert a.start == np.sqrt(0.01) + assert a.end == 1 + + a = sqrt(interval(-1, 1)) + assert a.is_valid is None + + a = sqrt(interval(-3, -1)) + assert a.is_valid is False + + a = sqrt(4) + assert (a == interval(2, 2)) == (True, True) + + a = sqrt(-3) + assert a.is_valid is False + + +def test_imin(): + a = imin(interval(1, 3), interval(2, 5), interval(-1, 3)) + assert a.start == -1 + assert a.end == 3 + + a = imin(-2, interval(1, 4)) + assert a.start == -2 + assert a.end == -2 + + a = imin(5, interval(3, 4), interval(-2, 2, is_valid=False)) + assert a.start == 3 + assert a.end == 4 + + +def test_imax(): + a = imax(interval(-2, 2), interval(2, 7), interval(-3, 9)) + assert a.start == 2 + assert a.end == 9 + + a = imax(8, interval(1, 4)) + assert a.start == 8 + assert a.end == 8 + + a = imax(interval(1, 2), interval(3, 4), interval(-2, 2, is_valid=False)) + assert a.start == 3 + assert a.end == 4 + + +def test_sinh(): + a = sinh(interval(-1, 1)) + assert a.start == np.sinh(-1) + assert a.end == np.sinh(1) + + a = sinh(1) + assert a.start == np.sinh(1) + assert a.end == np.sinh(1) + + +def test_cosh(): + a = cosh(interval(1, 2)) + assert a.start == np.cosh(1) + assert a.end == np.cosh(2) + a = cosh(interval(-2, -1)) + assert a.start == np.cosh(-1) + assert a.end == np.cosh(-2) + + a = cosh(interval(-2, 1)) + assert a.start == 1 + assert a.end == np.cosh(-2) + + a = cosh(1) + assert a.start == np.cosh(1) + assert a.end == np.cosh(1) + + +def test_tanh(): + a = tanh(interval(-3, 3)) + assert a.start == np.tanh(-3) + assert a.end == np.tanh(3) + + a = tanh(3) + assert a.start == np.tanh(3) + assert a.end == np.tanh(3) + + +def test_asin(): + a = asin(interval(-0.5, 0.5)) + assert a.start == np.arcsin(-0.5) + assert a.end == np.arcsin(0.5) + + a = asin(interval(-1.5, 1.5)) + assert a.is_valid is None + a = asin(interval(-2, -1.5)) + assert a.is_valid is False + + a = asin(interval(0, 2)) + assert a.is_valid is None + + a = asin(interval(2, 5)) + assert a.is_valid is False + + a = asin(0.5) + assert a.start == np.arcsin(0.5) + assert a.end == np.arcsin(0.5) + + a = asin(1.5) + assert a.is_valid is False + + +def test_acos(): + a = acos(interval(-0.5, 0.5)) + assert a.start == np.arccos(0.5) + assert a.end == np.arccos(-0.5) + + a = acos(interval(-1.5, 1.5)) + assert a.is_valid is None + a = acos(interval(-2, -1.5)) + assert a.is_valid is False + + a = acos(interval(0, 2)) + assert a.is_valid is None + + a = acos(interval(2, 5)) + assert a.is_valid is False + + a = acos(0.5) + assert a.start == np.arccos(0.5) + assert a.end == np.arccos(0.5) + + a = acos(1.5) + assert a.is_valid is False + + +def test_ceil(): + a = ceil(interval(0.2, 0.5)) + assert a.start == 1 + assert a.end == 1 + + a = ceil(interval(0.5, 1.5)) + assert a.start == 1 + assert a.end == 2 + assert a.is_valid is None + + a = ceil(interval(-5, 5)) + assert a.is_valid is None + + a = ceil(5.4) + assert a.start == 6 + assert a.end == 6 + + +def test_floor(): + a = floor(interval(0.2, 0.5)) + assert a.start == 0 + assert a.end == 0 + + a = floor(interval(0.5, 1.5)) + assert a.start == 0 + assert a.end == 1 + assert a.is_valid is None + + a = floor(interval(-5, 5)) + assert a.is_valid is None + + a = floor(5.4) + assert a.start == 5 + assert a.end == 5 + + +def test_asinh(): + a = asinh(interval(1, 2)) + assert a.start == np.arcsinh(1) + assert a.end == np.arcsinh(2) + + a = asinh(0.5) + assert a.start == np.arcsinh(0.5) + assert a.end == np.arcsinh(0.5) + + +def test_acosh(): + a = acosh(interval(3, 5)) + assert a.start == np.arccosh(3) + assert a.end == np.arccosh(5) + + a = acosh(interval(0, 3)) + assert a.is_valid is None + a = acosh(interval(-3, 0.5)) + assert a.is_valid is False + + a = acosh(0.5) + assert a.is_valid is False + + a = acosh(2) + assert a.start == np.arccosh(2) + assert a.end == np.arccosh(2) + + +def test_atanh(): + a = atanh(interval(-0.5, 0.5)) + assert a.start == np.arctanh(-0.5) + assert a.end == np.arctanh(0.5) + + a = atanh(interval(0, 3)) + assert a.is_valid is None + + a = atanh(interval(-3, -2)) + assert a.is_valid is False + + a = atanh(0.5) + assert a.start == np.arctanh(0.5) + assert a.end == np.arctanh(0.5) + + a = atanh(1.5) + assert a.is_valid is False + + +def test_Abs(): + assert (Abs(interval(-0.5, 0.5)) == interval(0, 0.5)) == (True, True) + assert (Abs(interval(-3, -2)) == interval(2, 3)) == (True, True) + assert (Abs(-3) == interval(3, 3)) == (True, True) + + +def test_And(): + args = [(True, True), (True, False), (True, None)] + assert And(*args) == (True, False) + + args = [(False, True), (None, None), (True, True)] + assert And(*args) == (False, None) + + +def test_Or(): + args = [(True, True), (True, False), (False, None)] + assert Or(*args) == (True, True) + args = [(None, None), (False, None), (False, False)] + assert Or(*args) == (None, None) diff --git a/lib/python3.10/site-packages/sympy/plotting/intervalmath/tests/test_interval_membership.py b/lib/python3.10/site-packages/sympy/plotting/intervalmath/tests/test_interval_membership.py new file mode 100644 index 0000000000000000000000000000000000000000..7b7f23680d60a64a6257a84c2476e31a8b5dfce8 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/plotting/intervalmath/tests/test_interval_membership.py @@ -0,0 +1,150 @@ +from sympy.core.symbol import Symbol +from sympy.plotting.intervalmath import interval +from sympy.plotting.intervalmath.interval_membership import intervalMembership +from sympy.plotting.experimental_lambdify import experimental_lambdify +from sympy.testing.pytest import raises + + +def test_creation(): + assert intervalMembership(True, True) + raises(TypeError, lambda: intervalMembership(True)) + raises(TypeError, lambda: intervalMembership(True, True, True)) + + +def test_getitem(): + a = intervalMembership(True, False) + assert a[0] is True + assert a[1] is False + raises(IndexError, lambda: a[2]) + + +def test_str(): + a = intervalMembership(True, False) + assert str(a) == 'intervalMembership(True, False)' + assert repr(a) == 'intervalMembership(True, False)' + + +def test_equivalence(): + a = intervalMembership(True, True) + b = intervalMembership(True, False) + assert (a == b) is False + assert (a != b) is True + + a = intervalMembership(True, False) + b = intervalMembership(True, False) + assert (a == b) is True + assert (a != b) is False + + +def test_not(): + x = Symbol('x') + + r1 = x > -1 + r2 = x <= -1 + + i = interval + + f1 = experimental_lambdify((x,), r1) + f2 = experimental_lambdify((x,), r2) + + tt = i(-0.1, 0.1, is_valid=True) + tn = i(-0.1, 0.1, is_valid=None) + tf = i(-0.1, 0.1, is_valid=False) + + assert f1(tt) == ~f2(tt) + assert f1(tn) == ~f2(tn) + assert f1(tf) == ~f2(tf) + + nt = i(0.9, 1.1, is_valid=True) + nn = i(0.9, 1.1, is_valid=None) + nf = i(0.9, 1.1, is_valid=False) + + assert f1(nt) == ~f2(nt) + assert f1(nn) == ~f2(nn) + assert f1(nf) == ~f2(nf) + + ft = i(1.9, 2.1, is_valid=True) + fn = i(1.9, 2.1, is_valid=None) + ff = i(1.9, 2.1, is_valid=False) + + assert f1(ft) == ~f2(ft) + assert f1(fn) == ~f2(fn) + assert f1(ff) == ~f2(ff) + + +def test_boolean(): + # There can be 9*9 test cases in full mapping of the cartesian product. + # But we only consider 3*3 cases for simplicity. + s = [ + intervalMembership(False, False), + intervalMembership(None, None), + intervalMembership(True, True) + ] + + # Reduced tests for 'And' + a1 = [ + intervalMembership(False, False), + intervalMembership(False, False), + intervalMembership(False, False), + intervalMembership(False, False), + intervalMembership(None, None), + intervalMembership(None, None), + intervalMembership(False, False), + intervalMembership(None, None), + intervalMembership(True, True) + ] + a1_iter = iter(a1) + for i in range(len(s)): + for j in range(len(s)): + assert s[i] & s[j] == next(a1_iter) + + # Reduced tests for 'Or' + a1 = [ + intervalMembership(False, False), + intervalMembership(None, False), + intervalMembership(True, False), + intervalMembership(None, False), + intervalMembership(None, None), + intervalMembership(True, None), + intervalMembership(True, False), + intervalMembership(True, None), + intervalMembership(True, True) + ] + a1_iter = iter(a1) + for i in range(len(s)): + for j in range(len(s)): + assert s[i] | s[j] == next(a1_iter) + + # Reduced tests for 'Xor' + a1 = [ + intervalMembership(False, False), + intervalMembership(None, False), + intervalMembership(True, False), + intervalMembership(None, False), + intervalMembership(None, None), + intervalMembership(None, None), + intervalMembership(True, False), + intervalMembership(None, None), + intervalMembership(False, True) + ] + a1_iter = iter(a1) + for i in range(len(s)): + for j in range(len(s)): + assert s[i] ^ s[j] == next(a1_iter) + + # Reduced tests for 'Not' + a1 = [ + intervalMembership(True, False), + intervalMembership(None, None), + intervalMembership(False, True) + ] + a1_iter = iter(a1) + for i in range(len(s)): + assert ~s[i] == next(a1_iter) + + +def test_boolean_errors(): + a = intervalMembership(True, True) + raises(ValueError, lambda: a & 1) + raises(ValueError, lambda: a | 1) + raises(ValueError, lambda: a ^ 1) diff --git a/lib/python3.10/site-packages/sympy/plotting/intervalmath/tests/test_intervalmath.py b/lib/python3.10/site-packages/sympy/plotting/intervalmath/tests/test_intervalmath.py new file mode 100644 index 0000000000000000000000000000000000000000..e30f217a44b4ea795270c0e2c66b6813b05e63ea --- /dev/null +++ b/lib/python3.10/site-packages/sympy/plotting/intervalmath/tests/test_intervalmath.py @@ -0,0 +1,213 @@ +from sympy.plotting.intervalmath import interval +from sympy.testing.pytest import raises + + +def test_interval(): + assert (interval(1, 1) == interval(1, 1, is_valid=True)) == (True, True) + assert (interval(1, 1) == interval(1, 1, is_valid=False)) == (True, False) + assert (interval(1, 1) == interval(1, 1, is_valid=None)) == (True, None) + assert (interval(1, 1.5) == interval(1, 2)) == (None, True) + assert (interval(0, 1) == interval(2, 3)) == (False, True) + assert (interval(0, 1) == interval(1, 2)) == (None, True) + assert (interval(1, 2) != interval(1, 2)) == (False, True) + assert (interval(1, 3) != interval(2, 3)) == (None, True) + assert (interval(1, 3) != interval(-5, -3)) == (True, True) + assert ( + interval(1, 3, is_valid=False) != interval(-5, -3)) == (True, False) + assert (interval(1, 3, is_valid=None) != interval(-5, 3)) == (None, None) + assert (interval(4, 4) != 4) == (False, True) + assert (interval(1, 1) == 1) == (True, True) + assert (interval(1, 3, is_valid=False) == interval(1, 3)) == (True, False) + assert (interval(1, 3, is_valid=None) == interval(1, 3)) == (True, None) + inter = interval(-5, 5) + assert (interval(inter) == interval(-5, 5)) == (True, True) + assert inter.width == 10 + assert 0 in inter + assert -5 in inter + assert 5 in inter + assert interval(0, 3) in inter + assert interval(-6, 2) not in inter + assert -5.05 not in inter + assert 5.3 not in inter + interb = interval(-float('inf'), float('inf')) + assert 0 in inter + assert inter in interb + assert interval(0, float('inf')) in interb + assert interval(-float('inf'), 5) in interb + assert interval(-1e50, 1e50) in interb + assert ( + -interval(-1, -2, is_valid=False) == interval(1, 2)) == (True, False) + raises(ValueError, lambda: interval(1, 2, 3)) + + +def test_interval_add(): + assert (interval(1, 2) + interval(2, 3) == interval(3, 5)) == (True, True) + assert (1 + interval(1, 2) == interval(2, 3)) == (True, True) + assert (interval(1, 2) + 1 == interval(2, 3)) == (True, True) + compare = (1 + interval(0, float('inf')) == interval(1, float('inf'))) + assert compare == (True, True) + a = 1 + interval(2, 5, is_valid=False) + assert a.is_valid is False + a = 1 + interval(2, 5, is_valid=None) + assert a.is_valid is None + a = interval(2, 5, is_valid=False) + interval(3, 5, is_valid=None) + assert a.is_valid is False + a = interval(3, 5) + interval(-1, 1, is_valid=None) + assert a.is_valid is None + a = interval(2, 5, is_valid=False) + 1 + assert a.is_valid is False + + +def test_interval_sub(): + assert (interval(1, 2) - interval(1, 5) == interval(-4, 1)) == (True, True) + assert (interval(1, 2) - 1 == interval(0, 1)) == (True, True) + assert (1 - interval(1, 2) == interval(-1, 0)) == (True, True) + a = 1 - interval(1, 2, is_valid=False) + assert a.is_valid is False + a = interval(1, 4, is_valid=None) - 1 + assert a.is_valid is None + a = interval(1, 3, is_valid=False) - interval(1, 3) + assert a.is_valid is False + a = interval(1, 3, is_valid=None) - interval(1, 3) + assert a.is_valid is None + + +def test_interval_inequality(): + assert (interval(1, 2) < interval(3, 4)) == (True, True) + assert (interval(1, 2) < interval(2, 4)) == (None, True) + assert (interval(1, 2) < interval(-2, 0)) == (False, True) + assert (interval(1, 2) <= interval(2, 4)) == (True, True) + assert (interval(1, 2) <= interval(1.5, 6)) == (None, True) + assert (interval(2, 3) <= interval(1, 2)) == (None, True) + assert (interval(2, 3) <= interval(1, 1.5)) == (False, True) + assert ( + interval(1, 2, is_valid=False) <= interval(-2, 0)) == (False, False) + assert (interval(1, 2, is_valid=None) <= interval(-2, 0)) == (False, None) + assert (interval(1, 2) <= 1.5) == (None, True) + assert (interval(1, 2) <= 3) == (True, True) + assert (interval(1, 2) <= 0) == (False, True) + assert (interval(5, 8) > interval(2, 3)) == (True, True) + assert (interval(2, 5) > interval(1, 3)) == (None, True) + assert (interval(2, 3) > interval(3.1, 5)) == (False, True) + + assert (interval(-1, 1) == 0) == (None, True) + assert (interval(-1, 1) == 2) == (False, True) + assert (interval(-1, 1) != 0) == (None, True) + assert (interval(-1, 1) != 2) == (True, True) + + assert (interval(3, 5) > 2) == (True, True) + assert (interval(3, 5) < 2) == (False, True) + assert (interval(1, 5) < 2) == (None, True) + assert (interval(1, 5) > 2) == (None, True) + assert (interval(0, 1) > 2) == (False, True) + assert (interval(1, 2) >= interval(0, 1)) == (True, True) + assert (interval(1, 2) >= interval(0, 1.5)) == (None, True) + assert (interval(1, 2) >= interval(3, 4)) == (False, True) + assert (interval(1, 2) >= 0) == (True, True) + assert (interval(1, 2) >= 1.2) == (None, True) + assert (interval(1, 2) >= 3) == (False, True) + assert (2 > interval(0, 1)) == (True, True) + a = interval(-1, 1, is_valid=False) < interval(2, 5, is_valid=None) + assert a == (True, False) + a = interval(-1, 1, is_valid=None) < interval(2, 5, is_valid=False) + assert a == (True, False) + a = interval(-1, 1, is_valid=None) < interval(2, 5, is_valid=None) + assert a == (True, None) + a = interval(-1, 1, is_valid=False) > interval(-5, -2, is_valid=None) + assert a == (True, False) + a = interval(-1, 1, is_valid=None) > interval(-5, -2, is_valid=False) + assert a == (True, False) + a = interval(-1, 1, is_valid=None) > interval(-5, -2, is_valid=None) + assert a == (True, None) + + +def test_interval_mul(): + assert ( + interval(1, 5) * interval(2, 10) == interval(2, 50)) == (True, True) + a = interval(-1, 1) * interval(2, 10) == interval(-10, 10) + assert a == (True, True) + + a = interval(-1, 1) * interval(-5, 3) == interval(-5, 5) + assert a == (True, True) + + assert (interval(1, 3) * 2 == interval(2, 6)) == (True, True) + assert (3 * interval(-1, 2) == interval(-3, 6)) == (True, True) + + a = 3 * interval(1, 2, is_valid=False) + assert a.is_valid is False + + a = 3 * interval(1, 2, is_valid=None) + assert a.is_valid is None + + a = interval(1, 5, is_valid=False) * interval(1, 2, is_valid=None) + assert a.is_valid is False + + +def test_interval_div(): + div = interval(1, 2, is_valid=False) / 3 + assert div == interval(-float('inf'), float('inf'), is_valid=False) + + div = interval(1, 2, is_valid=None) / 3 + assert div == interval(-float('inf'), float('inf'), is_valid=None) + + div = 3 / interval(1, 2, is_valid=None) + assert div == interval(-float('inf'), float('inf'), is_valid=None) + a = interval(1, 2) / 0 + assert a.is_valid is False + a = interval(0.5, 1) / interval(-1, 0) + assert a.is_valid is None + a = interval(0, 1) / interval(0, 1) + assert a.is_valid is None + + a = interval(-1, 1) / interval(-1, 1) + assert a.is_valid is None + + a = interval(-1, 2) / interval(0.5, 1) == interval(-2.0, 4.0) + assert a == (True, True) + a = interval(0, 1) / interval(0.5, 1) == interval(0.0, 2.0) + assert a == (True, True) + a = interval(-1, 0) / interval(0.5, 1) == interval(-2.0, 0.0) + assert a == (True, True) + a = interval(-0.5, -0.25) / interval(0.5, 1) == interval(-1.0, -0.25) + assert a == (True, True) + a = interval(0.5, 1) / interval(0.5, 1) == interval(0.5, 2.0) + assert a == (True, True) + a = interval(0.5, 4) / interval(0.5, 1) == interval(0.5, 8.0) + assert a == (True, True) + a = interval(-1, -0.5) / interval(0.5, 1) == interval(-2.0, -0.5) + assert a == (True, True) + a = interval(-4, -0.5) / interval(0.5, 1) == interval(-8.0, -0.5) + assert a == (True, True) + a = interval(-1, 2) / interval(-2, -0.5) == interval(-4.0, 2.0) + assert a == (True, True) + a = interval(0, 1) / interval(-2, -0.5) == interval(-2.0, 0.0) + assert a == (True, True) + a = interval(-1, 0) / interval(-2, -0.5) == interval(0.0, 2.0) + assert a == (True, True) + a = interval(-0.5, -0.25) / interval(-2, -0.5) == interval(0.125, 1.0) + assert a == (True, True) + a = interval(0.5, 1) / interval(-2, -0.5) == interval(-2.0, -0.25) + assert a == (True, True) + a = interval(0.5, 4) / interval(-2, -0.5) == interval(-8.0, -0.25) + assert a == (True, True) + a = interval(-1, -0.5) / interval(-2, -0.5) == interval(0.25, 2.0) + assert a == (True, True) + a = interval(-4, -0.5) / interval(-2, -0.5) == interval(0.25, 8.0) + assert a == (True, True) + a = interval(-5, 5, is_valid=False) / 2 + assert a.is_valid is False + +def test_hashable(): + ''' + test that interval objects are hashable. + this is required in order to be able to put them into the cache, which + appears to be necessary for plotting in py3k. For details, see: + + https://github.com/sympy/sympy/pull/2101 + https://github.com/sympy/sympy/issues/6533 + ''' + hash(interval(1, 1)) + hash(interval(1, 1, is_valid=True)) + hash(interval(-4, -0.5)) + hash(interval(-2, -0.5)) + hash(interval(0.25, 8.0)) diff --git a/lib/python3.10/site-packages/sympy/plotting/pygletplot/__init__.py b/lib/python3.10/site-packages/sympy/plotting/pygletplot/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..cd86a505d8c4b8026bd91cde27d441e00223a8bc --- /dev/null +++ b/lib/python3.10/site-packages/sympy/plotting/pygletplot/__init__.py @@ -0,0 +1,138 @@ +"""Plotting module that can plot 2D and 3D functions +""" + +from sympy.utilities.decorator import doctest_depends_on + +@doctest_depends_on(modules=('pyglet',)) +def PygletPlot(*args, **kwargs): + """ + + Plot Examples + ============= + + See examples/advanced/pyglet_plotting.py for many more examples. + + >>> from sympy.plotting.pygletplot import PygletPlot as Plot + >>> from sympy.abc import x, y, z + + >>> Plot(x*y**3-y*x**3) + [0]: -x**3*y + x*y**3, 'mode=cartesian' + + >>> p = Plot() + >>> p[1] = x*y + >>> p[1].color = z, (0.4,0.4,0.9), (0.9,0.4,0.4) + + >>> p = Plot() + >>> p[1] = x**2+y**2 + >>> p[2] = -x**2-y**2 + + + Variable Intervals + ================== + + The basic format is [var, min, max, steps], but the + syntax is flexible and arguments left out are taken + from the defaults for the current coordinate mode: + + >>> Plot(x**2) # implies [x,-5,5,100] + [0]: x**2, 'mode=cartesian' + + >>> Plot(x**2, [], []) # [x,-1,1,40], [y,-1,1,40] + [0]: x**2, 'mode=cartesian' + >>> Plot(x**2-y**2, [100], [100]) # [x,-1,1,100], [y,-1,1,100] + [0]: x**2 - y**2, 'mode=cartesian' + >>> Plot(x**2, [x,-13,13,100]) + [0]: x**2, 'mode=cartesian' + >>> Plot(x**2, [-13,13]) # [x,-13,13,100] + [0]: x**2, 'mode=cartesian' + >>> Plot(x**2, [x,-13,13]) # [x,-13,13,100] + [0]: x**2, 'mode=cartesian' + >>> Plot(1*x, [], [x], mode='cylindrical') + ... # [unbound_theta,0,2*Pi,40], [x,-1,1,20] + [0]: x, 'mode=cartesian' + + + Coordinate Modes + ================ + + Plot supports several curvilinear coordinate modes, and + they independent for each plotted function. You can specify + a coordinate mode explicitly with the 'mode' named argument, + but it can be automatically determined for Cartesian or + parametric plots, and therefore must only be specified for + polar, cylindrical, and spherical modes. + + Specifically, Plot(function arguments) and Plot[n] = + (function arguments) will interpret your arguments as a + Cartesian plot if you provide one function and a parametric + plot if you provide two or three functions. Similarly, the + arguments will be interpreted as a curve if one variable is + used, and a surface if two are used. + + Supported mode names by number of variables: + + 1: parametric, cartesian, polar + 2: parametric, cartesian, cylindrical = polar, spherical + + >>> Plot(1, mode='spherical') + + + Calculator-like Interface + ========================= + + >>> p = Plot(visible=False) + >>> f = x**2 + >>> p[1] = f + >>> p[2] = f.diff(x) + >>> p[3] = f.diff(x).diff(x) + >>> p + [1]: x**2, 'mode=cartesian' + [2]: 2*x, 'mode=cartesian' + [3]: 2, 'mode=cartesian' + >>> p.show() + >>> p.clear() + >>> p + + >>> p[1] = x**2+y**2 + >>> p[1].style = 'solid' + >>> p[2] = -x**2-y**2 + >>> p[2].style = 'wireframe' + >>> p[1].color = z, (0.4,0.4,0.9), (0.9,0.4,0.4) + >>> p[1].style = 'both' + >>> p[2].style = 'both' + >>> p.close() + + + Plot Window Keyboard Controls + ============================= + + Screen Rotation: + X,Y axis Arrow Keys, A,S,D,W, Numpad 4,6,8,2 + Z axis Q,E, Numpad 7,9 + + Model Rotation: + Z axis Z,C, Numpad 1,3 + + Zoom: R,F, PgUp,PgDn, Numpad +,- + + Reset Camera: X, Numpad 5 + + Camera Presets: + XY F1 + XZ F2 + YZ F3 + Perspective F4 + + Sensitivity Modifier: SHIFT + + Axes Toggle: + Visible F5 + Colors F6 + + Close Window: ESCAPE + + ============================= + """ + + from sympy.plotting.pygletplot.plot import PygletPlot + return PygletPlot(*args, **kwargs) diff --git a/lib/python3.10/site-packages/sympy/plotting/pygletplot/__pycache__/__init__.cpython-310.pyc b/lib/python3.10/site-packages/sympy/plotting/pygletplot/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0b1293a10277bef193a7a23512eab61687f35124 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/plotting/pygletplot/__pycache__/__init__.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/plotting/pygletplot/__pycache__/color_scheme.cpython-310.pyc b/lib/python3.10/site-packages/sympy/plotting/pygletplot/__pycache__/color_scheme.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ba1144324fee64c1ee045198621ea2b2a45cd924 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/plotting/pygletplot/__pycache__/color_scheme.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/plotting/pygletplot/__pycache__/managed_window.cpython-310.pyc b/lib/python3.10/site-packages/sympy/plotting/pygletplot/__pycache__/managed_window.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..68fce397cfbb82c5565bd37488d31279cadefa16 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/plotting/pygletplot/__pycache__/managed_window.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/plotting/pygletplot/__pycache__/plot.cpython-310.pyc b/lib/python3.10/site-packages/sympy/plotting/pygletplot/__pycache__/plot.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fc84ca00736ffb5ba9bbe649d96cbc61fa635908 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/plotting/pygletplot/__pycache__/plot.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/plotting/pygletplot/__pycache__/plot_axes.cpython-310.pyc b/lib/python3.10/site-packages/sympy/plotting/pygletplot/__pycache__/plot_axes.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..93a66d3b25960c99f6886a675a4f58ab2fe7ce2c Binary files /dev/null and b/lib/python3.10/site-packages/sympy/plotting/pygletplot/__pycache__/plot_axes.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/plotting/pygletplot/__pycache__/plot_camera.cpython-310.pyc b/lib/python3.10/site-packages/sympy/plotting/pygletplot/__pycache__/plot_camera.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..43aa9fad638a6e5e2fa189478164001e27a236fa Binary files /dev/null and b/lib/python3.10/site-packages/sympy/plotting/pygletplot/__pycache__/plot_camera.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/plotting/pygletplot/__pycache__/plot_controller.cpython-310.pyc b/lib/python3.10/site-packages/sympy/plotting/pygletplot/__pycache__/plot_controller.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c2601788bd76de16895e3366b2d1c6e9c53da088 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/plotting/pygletplot/__pycache__/plot_controller.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/plotting/pygletplot/__pycache__/plot_curve.cpython-310.pyc b/lib/python3.10/site-packages/sympy/plotting/pygletplot/__pycache__/plot_curve.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b5f98fc9622bb315f46b4af53df5a667743f685e Binary files /dev/null and b/lib/python3.10/site-packages/sympy/plotting/pygletplot/__pycache__/plot_curve.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/plotting/pygletplot/__pycache__/plot_interval.cpython-310.pyc b/lib/python3.10/site-packages/sympy/plotting/pygletplot/__pycache__/plot_interval.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5e229266a3b8f53853999f38b3019d7aa6689a54 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/plotting/pygletplot/__pycache__/plot_interval.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/plotting/pygletplot/__pycache__/plot_mode.cpython-310.pyc b/lib/python3.10/site-packages/sympy/plotting/pygletplot/__pycache__/plot_mode.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e7415cfa72b1e6b4b1067ce9d3b1215dda28bd34 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/plotting/pygletplot/__pycache__/plot_mode.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/plotting/pygletplot/__pycache__/plot_mode_base.cpython-310.pyc b/lib/python3.10/site-packages/sympy/plotting/pygletplot/__pycache__/plot_mode_base.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7742e04343dcf56464186d51ea8dda543a5d12c4 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/plotting/pygletplot/__pycache__/plot_mode_base.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/plotting/pygletplot/__pycache__/plot_modes.cpython-310.pyc b/lib/python3.10/site-packages/sympy/plotting/pygletplot/__pycache__/plot_modes.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7380bd23806f49e4fea959f943aaed96a560deb7 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/plotting/pygletplot/__pycache__/plot_modes.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/plotting/pygletplot/__pycache__/plot_object.cpython-310.pyc b/lib/python3.10/site-packages/sympy/plotting/pygletplot/__pycache__/plot_object.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e485692d9a9eb78d4478fbb2bb3bacb3ebe594db Binary files /dev/null and b/lib/python3.10/site-packages/sympy/plotting/pygletplot/__pycache__/plot_object.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/plotting/pygletplot/__pycache__/plot_rotation.cpython-310.pyc b/lib/python3.10/site-packages/sympy/plotting/pygletplot/__pycache__/plot_rotation.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5393c57df823cf80d089d5fce5c6c41405656bc7 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/plotting/pygletplot/__pycache__/plot_rotation.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/plotting/pygletplot/__pycache__/plot_surface.cpython-310.pyc b/lib/python3.10/site-packages/sympy/plotting/pygletplot/__pycache__/plot_surface.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5aecf21fcad8130a060f21260f08624f3d28a6d1 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/plotting/pygletplot/__pycache__/plot_surface.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/plotting/pygletplot/__pycache__/plot_window.cpython-310.pyc b/lib/python3.10/site-packages/sympy/plotting/pygletplot/__pycache__/plot_window.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7764fa1c11d987d1e9975a8cd6d184c1a6f07b88 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/plotting/pygletplot/__pycache__/plot_window.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/plotting/pygletplot/__pycache__/util.cpython-310.pyc b/lib/python3.10/site-packages/sympy/plotting/pygletplot/__pycache__/util.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..eb50bf577f526e24a4f2227c6582eb9f0210ce3e Binary files /dev/null and b/lib/python3.10/site-packages/sympy/plotting/pygletplot/__pycache__/util.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/plotting/pygletplot/color_scheme.py b/lib/python3.10/site-packages/sympy/plotting/pygletplot/color_scheme.py new file mode 100644 index 0000000000000000000000000000000000000000..613e777a7f45f54349c47d272aa6d1c157bcd117 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/plotting/pygletplot/color_scheme.py @@ -0,0 +1,336 @@ +from sympy.core.basic import Basic +from sympy.core.symbol import (Symbol, symbols) +from sympy.utilities.lambdify import lambdify +from .util import interpolate, rinterpolate, create_bounds, update_bounds +from sympy.utilities.iterables import sift + + +class ColorGradient: + colors = [0.4, 0.4, 0.4], [0.9, 0.9, 0.9] + intervals = 0.0, 1.0 + + def __init__(self, *args): + if len(args) == 2: + self.colors = list(args) + self.intervals = [0.0, 1.0] + elif len(args) > 0: + if len(args) % 2 != 0: + raise ValueError("len(args) should be even") + self.colors = [args[i] for i in range(1, len(args), 2)] + self.intervals = [args[i] for i in range(0, len(args), 2)] + assert len(self.colors) == len(self.intervals) + + def copy(self): + c = ColorGradient() + c.colors = [e[::] for e in self.colors] + c.intervals = self.intervals[::] + return c + + def _find_interval(self, v): + m = len(self.intervals) + i = 0 + while i < m - 1 and self.intervals[i] <= v: + i += 1 + return i + + def _interpolate_axis(self, axis, v): + i = self._find_interval(v) + v = rinterpolate(self.intervals[i - 1], self.intervals[i], v) + return interpolate(self.colors[i - 1][axis], self.colors[i][axis], v) + + def __call__(self, r, g, b): + c = self._interpolate_axis + return c(0, r), c(1, g), c(2, b) + +default_color_schemes = {} # defined at the bottom of this file + + +class ColorScheme: + + def __init__(self, *args, **kwargs): + self.args = args + self.f, self.gradient = None, ColorGradient() + + if len(args) == 1 and not isinstance(args[0], Basic) and callable(args[0]): + self.f = args[0] + elif len(args) == 1 and isinstance(args[0], str): + if args[0] in default_color_schemes: + cs = default_color_schemes[args[0]] + self.f, self.gradient = cs.f, cs.gradient.copy() + else: + self.f = lambdify('x,y,z,u,v', args[0]) + else: + self.f, self.gradient = self._interpret_args(args) + self._test_color_function() + if not isinstance(self.gradient, ColorGradient): + raise ValueError("Color gradient not properly initialized. " + "(Not a ColorGradient instance.)") + + def _interpret_args(self, args): + f, gradient = None, self.gradient + atoms, lists = self._sort_args(args) + s = self._pop_symbol_list(lists) + s = self._fill_in_vars(s) + + # prepare the error message for lambdification failure + f_str = ', '.join(str(fa) for fa in atoms) + s_str = (str(sa) for sa in s) + s_str = ', '.join(sa for sa in s_str if sa.find('unbound') < 0) + f_error = ValueError("Could not interpret arguments " + "%s as functions of %s." % (f_str, s_str)) + + # try to lambdify args + if len(atoms) == 1: + fv = atoms[0] + try: + f = lambdify(s, [fv, fv, fv]) + except TypeError: + raise f_error + + elif len(atoms) == 3: + fr, fg, fb = atoms + try: + f = lambdify(s, [fr, fg, fb]) + except TypeError: + raise f_error + + else: + raise ValueError("A ColorScheme must provide 1 or 3 " + "functions in x, y, z, u, and/or v.") + + # try to intrepret any given color information + if len(lists) == 0: + gargs = [] + + elif len(lists) == 1: + gargs = lists[0] + + elif len(lists) == 2: + try: + (r1, g1, b1), (r2, g2, b2) = lists + except TypeError: + raise ValueError("If two color arguments are given, " + "they must be given in the format " + "(r1, g1, b1), (r2, g2, b2).") + gargs = lists + + elif len(lists) == 3: + try: + (r1, r2), (g1, g2), (b1, b2) = lists + except Exception: + raise ValueError("If three color arguments are given, " + "they must be given in the format " + "(r1, r2), (g1, g2), (b1, b2). To create " + "a multi-step gradient, use the syntax " + "[0, colorStart, step1, color1, ..., 1, " + "colorEnd].") + gargs = [[r1, g1, b1], [r2, g2, b2]] + + else: + raise ValueError("Don't know what to do with collection " + "arguments %s." % (', '.join(str(l) for l in lists))) + + if gargs: + try: + gradient = ColorGradient(*gargs) + except Exception as ex: + raise ValueError(("Could not initialize a gradient " + "with arguments %s. Inner " + "exception: %s") % (gargs, str(ex))) + + return f, gradient + + def _pop_symbol_list(self, lists): + symbol_lists = [] + for l in lists: + mark = True + for s in l: + if s is not None and not isinstance(s, Symbol): + mark = False + break + if mark: + lists.remove(l) + symbol_lists.append(l) + if len(symbol_lists) == 1: + return symbol_lists[0] + elif len(symbol_lists) == 0: + return [] + else: + raise ValueError("Only one list of Symbols " + "can be given for a color scheme.") + + def _fill_in_vars(self, args): + defaults = symbols('x,y,z,u,v') + v_error = ValueError("Could not find what to plot.") + if len(args) == 0: + return defaults + if not isinstance(args, (tuple, list)): + raise v_error + if len(args) == 0: + return defaults + for s in args: + if s is not None and not isinstance(s, Symbol): + raise v_error + # when vars are given explicitly, any vars + # not given are marked 'unbound' as to not + # be accidentally used in an expression + vars = [Symbol('unbound%i' % (i)) for i in range(1, 6)] + # interpret as t + if len(args) == 1: + vars[3] = args[0] + # interpret as u,v + elif len(args) == 2: + if args[0] is not None: + vars[3] = args[0] + if args[1] is not None: + vars[4] = args[1] + # interpret as x,y,z + elif len(args) >= 3: + # allow some of x,y,z to be + # left unbound if not given + if args[0] is not None: + vars[0] = args[0] + if args[1] is not None: + vars[1] = args[1] + if args[2] is not None: + vars[2] = args[2] + # interpret the rest as t + if len(args) >= 4: + vars[3] = args[3] + # ...or u,v + if len(args) >= 5: + vars[4] = args[4] + return vars + + def _sort_args(self, args): + lists, atoms = sift(args, + lambda a: isinstance(a, (tuple, list)), binary=True) + return atoms, lists + + def _test_color_function(self): + if not callable(self.f): + raise ValueError("Color function is not callable.") + try: + result = self.f(0, 0, 0, 0, 0) + if len(result) != 3: + raise ValueError("length should be equal to 3") + except TypeError: + raise ValueError("Color function needs to accept x,y,z,u,v, " + "as arguments even if it doesn't use all of them.") + except AssertionError: + raise ValueError("Color function needs to return 3-tuple r,g,b.") + except Exception: + pass # color function probably not valid at 0,0,0,0,0 + + def __call__(self, x, y, z, u, v): + try: + return self.f(x, y, z, u, v) + except Exception: + return None + + def apply_to_curve(self, verts, u_set, set_len=None, inc_pos=None): + """ + Apply this color scheme to a + set of vertices over a single + independent variable u. + """ + bounds = create_bounds() + cverts = [] + if callable(set_len): + set_len(len(u_set)*2) + # calculate f() = r,g,b for each vert + # and find the min and max for r,g,b + for _u in range(len(u_set)): + if verts[_u] is None: + cverts.append(None) + else: + x, y, z = verts[_u] + u, v = u_set[_u], None + c = self(x, y, z, u, v) + if c is not None: + c = list(c) + update_bounds(bounds, c) + cverts.append(c) + if callable(inc_pos): + inc_pos() + # scale and apply gradient + for _u in range(len(u_set)): + if cverts[_u] is not None: + for _c in range(3): + # scale from [f_min, f_max] to [0,1] + cverts[_u][_c] = rinterpolate(bounds[_c][0], bounds[_c][1], + cverts[_u][_c]) + # apply gradient + cverts[_u] = self.gradient(*cverts[_u]) + if callable(inc_pos): + inc_pos() + return cverts + + def apply_to_surface(self, verts, u_set, v_set, set_len=None, inc_pos=None): + """ + Apply this color scheme to a + set of vertices over two + independent variables u and v. + """ + bounds = create_bounds() + cverts = [] + if callable(set_len): + set_len(len(u_set)*len(v_set)*2) + # calculate f() = r,g,b for each vert + # and find the min and max for r,g,b + for _u in range(len(u_set)): + column = [] + for _v in range(len(v_set)): + if verts[_u][_v] is None: + column.append(None) + else: + x, y, z = verts[_u][_v] + u, v = u_set[_u], v_set[_v] + c = self(x, y, z, u, v) + if c is not None: + c = list(c) + update_bounds(bounds, c) + column.append(c) + if callable(inc_pos): + inc_pos() + cverts.append(column) + # scale and apply gradient + for _u in range(len(u_set)): + for _v in range(len(v_set)): + if cverts[_u][_v] is not None: + # scale from [f_min, f_max] to [0,1] + for _c in range(3): + cverts[_u][_v][_c] = rinterpolate(bounds[_c][0], + bounds[_c][1], cverts[_u][_v][_c]) + # apply gradient + cverts[_u][_v] = self.gradient(*cverts[_u][_v]) + if callable(inc_pos): + inc_pos() + return cverts + + def str_base(self): + return ", ".join(str(a) for a in self.args) + + def __repr__(self): + return "%s" % (self.str_base()) + + +x, y, z, t, u, v = symbols('x,y,z,t,u,v') + +default_color_schemes['rainbow'] = ColorScheme(z, y, x) +default_color_schemes['zfade'] = ColorScheme(z, (0.4, 0.4, 0.97), + (0.97, 0.4, 0.4), (None, None, z)) +default_color_schemes['zfade3'] = ColorScheme(z, (None, None, z), + [0.00, (0.2, 0.2, 1.0), + 0.35, (0.2, 0.8, 0.4), + 0.50, (0.3, 0.9, 0.3), + 0.65, (0.4, 0.8, 0.2), + 1.00, (1.0, 0.2, 0.2)]) + +default_color_schemes['zfade4'] = ColorScheme(z, (None, None, z), + [0.0, (0.3, 0.3, 1.0), + 0.30, (0.3, 1.0, 0.3), + 0.55, (0.95, 1.0, 0.2), + 0.65, (1.0, 0.95, 0.2), + 0.85, (1.0, 0.7, 0.2), + 1.0, (1.0, 0.3, 0.2)]) diff --git a/lib/python3.10/site-packages/sympy/plotting/pygletplot/managed_window.py b/lib/python3.10/site-packages/sympy/plotting/pygletplot/managed_window.py new file mode 100644 index 0000000000000000000000000000000000000000..81fa2541b4dd9e13534aabfd2a11bf88c479daf8 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/plotting/pygletplot/managed_window.py @@ -0,0 +1,106 @@ +from pyglet.window import Window +from pyglet.clock import Clock + +from threading import Thread, Lock + +gl_lock = Lock() + + +class ManagedWindow(Window): + """ + A pyglet window with an event loop which executes automatically + in a separate thread. Behavior is added by creating a subclass + which overrides setup, update, and/or draw. + """ + fps_limit = 30 + default_win_args = {"width": 600, + "height": 500, + "vsync": False, + "resizable": True} + + def __init__(self, **win_args): + """ + It is best not to override this function in the child + class, unless you need to take additional arguments. + Do any OpenGL initialization calls in setup(). + """ + + # check if this is run from the doctester + if win_args.get('runfromdoctester', False): + return + + self.win_args = dict(self.default_win_args, **win_args) + self.Thread = Thread(target=self.__event_loop__) + self.Thread.start() + + def __event_loop__(self, **win_args): + """ + The event loop thread function. Do not override or call + directly (it is called by __init__). + """ + gl_lock.acquire() + try: + try: + super().__init__(**self.win_args) + self.switch_to() + self.setup() + except Exception as e: + print("Window initialization failed: %s" % (str(e))) + self.has_exit = True + finally: + gl_lock.release() + + clock = Clock() + clock.fps_limit = self.fps_limit + while not self.has_exit: + dt = clock.tick() + gl_lock.acquire() + try: + try: + self.switch_to() + self.dispatch_events() + self.clear() + self.update(dt) + self.draw() + self.flip() + except Exception as e: + print("Uncaught exception in event loop: %s" % str(e)) + self.has_exit = True + finally: + gl_lock.release() + super().close() + + def close(self): + """ + Closes the window. + """ + self.has_exit = True + + def setup(self): + """ + Called once before the event loop begins. + Override this method in a child class. This + is the best place to put things like OpenGL + initialization calls. + """ + pass + + def update(self, dt): + """ + Called before draw during each iteration of + the event loop. dt is the elapsed time in + seconds since the last update. OpenGL rendering + calls are best put in draw() rather than here. + """ + pass + + def draw(self): + """ + Called after update during each iteration of + the event loop. Put OpenGL rendering calls + here. + """ + pass + +if __name__ == '__main__': + ManagedWindow() diff --git a/lib/python3.10/site-packages/sympy/plotting/pygletplot/plot.py b/lib/python3.10/site-packages/sympy/plotting/pygletplot/plot.py new file mode 100644 index 0000000000000000000000000000000000000000..8c3dd3c8d4ce6c660cc07f93a55029eef98e55a2 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/plotting/pygletplot/plot.py @@ -0,0 +1,464 @@ +from threading import RLock + +# it is sufficient to import "pyglet" here once +try: + import pyglet.gl as pgl +except ImportError: + raise ImportError("pyglet is required for plotting.\n " + "visit https://pyglet.org/") + +from sympy.core.numbers import Integer +from sympy.external.gmpy import SYMPY_INTS +from sympy.geometry.entity import GeometryEntity +from sympy.plotting.pygletplot.plot_axes import PlotAxes +from sympy.plotting.pygletplot.plot_mode import PlotMode +from sympy.plotting.pygletplot.plot_object import PlotObject +from sympy.plotting.pygletplot.plot_window import PlotWindow +from sympy.plotting.pygletplot.util import parse_option_string +from sympy.utilities.decorator import doctest_depends_on +from sympy.utilities.iterables import is_sequence + +from time import sleep +from os import getcwd, listdir + +import ctypes + +@doctest_depends_on(modules=('pyglet',)) +class PygletPlot: + """ + Plot Examples + ============= + + See examples/advanced/pyglet_plotting.py for many more examples. + + >>> from sympy.plotting.pygletplot import PygletPlot as Plot + >>> from sympy.abc import x, y, z + + >>> Plot(x*y**3-y*x**3) + [0]: -x**3*y + x*y**3, 'mode=cartesian' + + >>> p = Plot() + >>> p[1] = x*y + >>> p[1].color = z, (0.4,0.4,0.9), (0.9,0.4,0.4) + + >>> p = Plot() + >>> p[1] = x**2+y**2 + >>> p[2] = -x**2-y**2 + + + Variable Intervals + ================== + + The basic format is [var, min, max, steps], but the + syntax is flexible and arguments left out are taken + from the defaults for the current coordinate mode: + + >>> Plot(x**2) # implies [x,-5,5,100] + [0]: x**2, 'mode=cartesian' + >>> Plot(x**2, [], []) # [x,-1,1,40], [y,-1,1,40] + [0]: x**2, 'mode=cartesian' + >>> Plot(x**2-y**2, [100], [100]) # [x,-1,1,100], [y,-1,1,100] + [0]: x**2 - y**2, 'mode=cartesian' + >>> Plot(x**2, [x,-13,13,100]) + [0]: x**2, 'mode=cartesian' + >>> Plot(x**2, [-13,13]) # [x,-13,13,100] + [0]: x**2, 'mode=cartesian' + >>> Plot(x**2, [x,-13,13]) # [x,-13,13,10] + [0]: x**2, 'mode=cartesian' + >>> Plot(1*x, [], [x], mode='cylindrical') + ... # [unbound_theta,0,2*Pi,40], [x,-1,1,20] + [0]: x, 'mode=cartesian' + + + Coordinate Modes + ================ + + Plot supports several curvilinear coordinate modes, and + they independent for each plotted function. You can specify + a coordinate mode explicitly with the 'mode' named argument, + but it can be automatically determined for Cartesian or + parametric plots, and therefore must only be specified for + polar, cylindrical, and spherical modes. + + Specifically, Plot(function arguments) and Plot[n] = + (function arguments) will interpret your arguments as a + Cartesian plot if you provide one function and a parametric + plot if you provide two or three functions. Similarly, the + arguments will be interpreted as a curve if one variable is + used, and a surface if two are used. + + Supported mode names by number of variables: + + 1: parametric, cartesian, polar + 2: parametric, cartesian, cylindrical = polar, spherical + + >>> Plot(1, mode='spherical') + + + Calculator-like Interface + ========================= + + >>> p = Plot(visible=False) + >>> f = x**2 + >>> p[1] = f + >>> p[2] = f.diff(x) + >>> p[3] = f.diff(x).diff(x) + >>> p + [1]: x**2, 'mode=cartesian' + [2]: 2*x, 'mode=cartesian' + [3]: 2, 'mode=cartesian' + >>> p.show() + >>> p.clear() + >>> p + + >>> p[1] = x**2+y**2 + >>> p[1].style = 'solid' + >>> p[2] = -x**2-y**2 + >>> p[2].style = 'wireframe' + >>> p[1].color = z, (0.4,0.4,0.9), (0.9,0.4,0.4) + >>> p[1].style = 'both' + >>> p[2].style = 'both' + >>> p.close() + + + Plot Window Keyboard Controls + ============================= + + Screen Rotation: + X,Y axis Arrow Keys, A,S,D,W, Numpad 4,6,8,2 + Z axis Q,E, Numpad 7,9 + + Model Rotation: + Z axis Z,C, Numpad 1,3 + + Zoom: R,F, PgUp,PgDn, Numpad +,- + + Reset Camera: X, Numpad 5 + + Camera Presets: + XY F1 + XZ F2 + YZ F3 + Perspective F4 + + Sensitivity Modifier: SHIFT + + Axes Toggle: + Visible F5 + Colors F6 + + Close Window: ESCAPE + + ============================= + + """ + + @doctest_depends_on(modules=('pyglet',)) + def __init__(self, *fargs, **win_args): + """ + Positional Arguments + ==================== + + Any given positional arguments are used to + initialize a plot function at index 1. In + other words... + + >>> from sympy.plotting.pygletplot import PygletPlot as Plot + >>> from sympy.abc import x + >>> p = Plot(x**2, visible=False) + + ...is equivalent to... + + >>> p = Plot(visible=False) + >>> p[1] = x**2 + + Note that in earlier versions of the plotting + module, you were able to specify multiple + functions in the initializer. This functionality + has been dropped in favor of better automatic + plot plot_mode detection. + + + Named Arguments + =============== + + axes + An option string of the form + "key1=value1; key2 = value2" which + can use the following options: + + style = ordinate + none OR frame OR box OR ordinate + + stride = 0.25 + val OR (val_x, val_y, val_z) + + overlay = True (draw on top of plot) + True OR False + + colored = False (False uses Black, + True uses colors + R,G,B = X,Y,Z) + True OR False + + label_axes = False (display axis names + at endpoints) + True OR False + + visible = True (show immediately + True OR False + + + The following named arguments are passed as + arguments to window initialization: + + antialiasing = True + True OR False + + ortho = False + True OR False + + invert_mouse_zoom = False + True OR False + + """ + # Register the plot modes + from . import plot_modes # noqa + + self._win_args = win_args + self._window = None + + self._render_lock = RLock() + + self._functions = {} + self._pobjects = [] + self._screenshot = ScreenShot(self) + + axe_options = parse_option_string(win_args.pop('axes', '')) + self.axes = PlotAxes(**axe_options) + self._pobjects.append(self.axes) + + self[0] = fargs + if win_args.get('visible', True): + self.show() + + ## Window Interfaces + + def show(self): + """ + Creates and displays a plot window, or activates it + (gives it focus) if it has already been created. + """ + if self._window and not self._window.has_exit: + self._window.activate() + else: + self._win_args['visible'] = True + self.axes.reset_resources() + + #if hasattr(self, '_doctest_depends_on'): + # self._win_args['runfromdoctester'] = True + + self._window = PlotWindow(self, **self._win_args) + + def close(self): + """ + Closes the plot window. + """ + if self._window: + self._window.close() + + def saveimage(self, outfile=None, format='', size=(600, 500)): + """ + Saves a screen capture of the plot window to an + image file. + + If outfile is given, it can either be a path + or a file object. Otherwise a png image will + be saved to the current working directory. + If the format is omitted, it is determined from + the filename extension. + """ + self._screenshot.save(outfile, format, size) + + ## Function List Interfaces + + def clear(self): + """ + Clears the function list of this plot. + """ + self._render_lock.acquire() + self._functions = {} + self.adjust_all_bounds() + self._render_lock.release() + + def __getitem__(self, i): + """ + Returns the function at position i in the + function list. + """ + return self._functions[i] + + def __setitem__(self, i, args): + """ + Parses and adds a PlotMode to the function + list. + """ + if not (isinstance(i, (SYMPY_INTS, Integer)) and i >= 0): + raise ValueError("Function index must " + "be an integer >= 0.") + + if isinstance(args, PlotObject): + f = args + else: + if (not is_sequence(args)) or isinstance(args, GeometryEntity): + args = [args] + if len(args) == 0: + return # no arguments given + kwargs = {"bounds_callback": self.adjust_all_bounds} + f = PlotMode(*args, **kwargs) + + if f: + self._render_lock.acquire() + self._functions[i] = f + self._render_lock.release() + else: + raise ValueError("Failed to parse '%s'." + % ', '.join(str(a) for a in args)) + + def __delitem__(self, i): + """ + Removes the function in the function list at + position i. + """ + self._render_lock.acquire() + del self._functions[i] + self.adjust_all_bounds() + self._render_lock.release() + + def firstavailableindex(self): + """ + Returns the first unused index in the function list. + """ + i = 0 + self._render_lock.acquire() + while i in self._functions: + i += 1 + self._render_lock.release() + return i + + def append(self, *args): + """ + Parses and adds a PlotMode to the function + list at the first available index. + """ + self.__setitem__(self.firstavailableindex(), args) + + def __len__(self): + """ + Returns the number of functions in the function list. + """ + return len(self._functions) + + def __iter__(self): + """ + Allows iteration of the function list. + """ + return self._functions.itervalues() + + def __repr__(self): + return str(self) + + def __str__(self): + """ + Returns a string containing a new-line separated + list of the functions in the function list. + """ + s = "" + if len(self._functions) == 0: + s += "" + else: + self._render_lock.acquire() + s += "\n".join(["%s[%i]: %s" % ("", i, str(self._functions[i])) + for i in self._functions]) + self._render_lock.release() + return s + + def adjust_all_bounds(self): + self._render_lock.acquire() + self.axes.reset_bounding_box() + for f in self._functions: + self.axes.adjust_bounds(self._functions[f].bounds) + self._render_lock.release() + + def wait_for_calculations(self): + sleep(0) + self._render_lock.acquire() + for f in self._functions: + a = self._functions[f]._get_calculating_verts + b = self._functions[f]._get_calculating_cverts + while a() or b(): + sleep(0) + self._render_lock.release() + +class ScreenShot: + def __init__(self, plot): + self._plot = plot + self.screenshot_requested = False + self.outfile = None + self.format = '' + self.invisibleMode = False + self.flag = 0 + + def __bool__(self): + return self.screenshot_requested + + def _execute_saving(self): + if self.flag < 3: + self.flag += 1 + return + + size_x, size_y = self._plot._window.get_size() + size = size_x*size_y*4*ctypes.sizeof(ctypes.c_ubyte) + image = ctypes.create_string_buffer(size) + pgl.glReadPixels(0, 0, size_x, size_y, pgl.GL_RGBA, pgl.GL_UNSIGNED_BYTE, image) + from PIL import Image + im = Image.frombuffer('RGBA', (size_x, size_y), + image.raw, 'raw', 'RGBA', 0, 1) + im.transpose(Image.FLIP_TOP_BOTTOM).save(self.outfile, self.format) + + self.flag = 0 + self.screenshot_requested = False + if self.invisibleMode: + self._plot._window.close() + + def save(self, outfile=None, format='', size=(600, 500)): + self.outfile = outfile + self.format = format + self.size = size + self.screenshot_requested = True + + if not self._plot._window or self._plot._window.has_exit: + self._plot._win_args['visible'] = False + + self._plot._win_args['width'] = size[0] + self._plot._win_args['height'] = size[1] + + self._plot.axes.reset_resources() + self._plot._window = PlotWindow(self._plot, **self._plot._win_args) + self.invisibleMode = True + + if self.outfile is None: + self.outfile = self._create_unique_path() + print(self.outfile) + + def _create_unique_path(self): + cwd = getcwd() + l = listdir(cwd) + path = '' + i = 0 + while True: + if not 'plot_%s.png' % i in l: + path = cwd + '/plot_%s.png' % i + break + i += 1 + return path diff --git a/lib/python3.10/site-packages/sympy/plotting/pygletplot/plot_axes.py b/lib/python3.10/site-packages/sympy/plotting/pygletplot/plot_axes.py new file mode 100644 index 0000000000000000000000000000000000000000..ae26fb0b2fa64e7f7318c51ce3fe5afaa276b48e --- /dev/null +++ b/lib/python3.10/site-packages/sympy/plotting/pygletplot/plot_axes.py @@ -0,0 +1,251 @@ +import pyglet.gl as pgl +from pyglet import font + +from sympy.core import S +from sympy.plotting.pygletplot.plot_object import PlotObject +from sympy.plotting.pygletplot.util import billboard_matrix, dot_product, \ + get_direction_vectors, strided_range, vec_mag, vec_sub +from sympy.utilities.iterables import is_sequence + + +class PlotAxes(PlotObject): + + def __init__(self, *args, + style='', none=None, frame=None, box=None, ordinate=None, + stride=0.25, + visible='', overlay='', colored='', label_axes='', label_ticks='', + tick_length=0.1, + font_face='Arial', font_size=28, + **kwargs): + # initialize style parameter + style = style.lower() + + # allow alias kwargs to override style kwarg + if none is not None: + style = 'none' + if frame is not None: + style = 'frame' + if box is not None: + style = 'box' + if ordinate is not None: + style = 'ordinate' + + if style in ['', 'ordinate']: + self._render_object = PlotAxesOrdinate(self) + elif style in ['frame', 'box']: + self._render_object = PlotAxesFrame(self) + elif style in ['none']: + self._render_object = None + else: + raise ValueError(("Unrecognized axes style %s.") % (style)) + + # initialize stride parameter + try: + stride = eval(stride) + except TypeError: + pass + if is_sequence(stride): + if len(stride) != 3: + raise ValueError("length should be equal to 3") + self._stride = stride + else: + self._stride = [stride, stride, stride] + self._tick_length = float(tick_length) + + # setup bounding box and ticks + self._origin = [0, 0, 0] + self.reset_bounding_box() + + def flexible_boolean(input, default): + if input in [True, False]: + return input + if input in ('f', 'F', 'false', 'False'): + return False + if input in ('t', 'T', 'true', 'True'): + return True + return default + + # initialize remaining parameters + self.visible = flexible_boolean(kwargs, True) + self._overlay = flexible_boolean(overlay, True) + self._colored = flexible_boolean(colored, False) + self._label_axes = flexible_boolean(label_axes, False) + self._label_ticks = flexible_boolean(label_ticks, True) + + # setup label font + self.font_face = font_face + self.font_size = font_size + + # this is also used to reinit the + # font on window close/reopen + self.reset_resources() + + def reset_resources(self): + self.label_font = None + + def reset_bounding_box(self): + self._bounding_box = [[None, None], [None, None], [None, None]] + self._axis_ticks = [[], [], []] + + def draw(self): + if self._render_object: + pgl.glPushAttrib(pgl.GL_ENABLE_BIT | pgl.GL_POLYGON_BIT | pgl.GL_DEPTH_BUFFER_BIT) + if self._overlay: + pgl.glDisable(pgl.GL_DEPTH_TEST) + self._render_object.draw() + pgl.glPopAttrib() + + def adjust_bounds(self, child_bounds): + b = self._bounding_box + c = child_bounds + for i in range(3): + if abs(c[i][0]) is S.Infinity or abs(c[i][1]) is S.Infinity: + continue + b[i][0] = c[i][0] if b[i][0] is None else min([b[i][0], c[i][0]]) + b[i][1] = c[i][1] if b[i][1] is None else max([b[i][1], c[i][1]]) + self._bounding_box = b + self._recalculate_axis_ticks(i) + + def _recalculate_axis_ticks(self, axis): + b = self._bounding_box + if b[axis][0] is None or b[axis][1] is None: + self._axis_ticks[axis] = [] + else: + self._axis_ticks[axis] = strided_range(b[axis][0], b[axis][1], + self._stride[axis]) + + def toggle_visible(self): + self.visible = not self.visible + + def toggle_colors(self): + self._colored = not self._colored + + +class PlotAxesBase(PlotObject): + + def __init__(self, parent_axes): + self._p = parent_axes + + def draw(self): + color = [([0.2, 0.1, 0.3], [0.2, 0.1, 0.3], [0.2, 0.1, 0.3]), + ([0.9, 0.3, 0.5], [0.5, 1.0, 0.5], [0.3, 0.3, 0.9])][self._p._colored] + self.draw_background(color) + self.draw_axis(2, color[2]) + self.draw_axis(1, color[1]) + self.draw_axis(0, color[0]) + + def draw_background(self, color): + pass # optional + + def draw_axis(self, axis, color): + raise NotImplementedError() + + def draw_text(self, text, position, color, scale=1.0): + if len(color) == 3: + color = (color[0], color[1], color[2], 1.0) + + if self._p.label_font is None: + self._p.label_font = font.load(self._p.font_face, + self._p.font_size, + bold=True, italic=False) + + label = font.Text(self._p.label_font, text, + color=color, + valign=font.Text.BASELINE, + halign=font.Text.CENTER) + + pgl.glPushMatrix() + pgl.glTranslatef(*position) + billboard_matrix() + scale_factor = 0.005 * scale + pgl.glScalef(scale_factor, scale_factor, scale_factor) + pgl.glColor4f(0, 0, 0, 0) + label.draw() + pgl.glPopMatrix() + + def draw_line(self, v, color): + o = self._p._origin + pgl.glBegin(pgl.GL_LINES) + pgl.glColor3f(*color) + pgl.glVertex3f(v[0][0] + o[0], v[0][1] + o[1], v[0][2] + o[2]) + pgl.glVertex3f(v[1][0] + o[0], v[1][1] + o[1], v[1][2] + o[2]) + pgl.glEnd() + + +class PlotAxesOrdinate(PlotAxesBase): + + def __init__(self, parent_axes): + super().__init__(parent_axes) + + def draw_axis(self, axis, color): + ticks = self._p._axis_ticks[axis] + radius = self._p._tick_length / 2.0 + if len(ticks) < 2: + return + + # calculate the vector for this axis + axis_lines = [[0, 0, 0], [0, 0, 0]] + axis_lines[0][axis], axis_lines[1][axis] = ticks[0], ticks[-1] + axis_vector = vec_sub(axis_lines[1], axis_lines[0]) + + # calculate angle to the z direction vector + pos_z = get_direction_vectors()[2] + d = abs(dot_product(axis_vector, pos_z)) + d = d / vec_mag(axis_vector) + + # don't draw labels if we're looking down the axis + labels_visible = abs(d - 1.0) > 0.02 + + # draw the ticks and labels + for tick in ticks: + self.draw_tick_line(axis, color, radius, tick, labels_visible) + + # draw the axis line and labels + self.draw_axis_line(axis, color, ticks[0], ticks[-1], labels_visible) + + def draw_axis_line(self, axis, color, a_min, a_max, labels_visible): + axis_line = [[0, 0, 0], [0, 0, 0]] + axis_line[0][axis], axis_line[1][axis] = a_min, a_max + self.draw_line(axis_line, color) + if labels_visible: + self.draw_axis_line_labels(axis, color, axis_line) + + def draw_axis_line_labels(self, axis, color, axis_line): + if not self._p._label_axes: + return + axis_labels = [axis_line[0][::], axis_line[1][::]] + axis_labels[0][axis] -= 0.3 + axis_labels[1][axis] += 0.3 + a_str = ['X', 'Y', 'Z'][axis] + self.draw_text("-" + a_str, axis_labels[0], color) + self.draw_text("+" + a_str, axis_labels[1], color) + + def draw_tick_line(self, axis, color, radius, tick, labels_visible): + tick_axis = {0: 1, 1: 0, 2: 1}[axis] + tick_line = [[0, 0, 0], [0, 0, 0]] + tick_line[0][axis] = tick_line[1][axis] = tick + tick_line[0][tick_axis], tick_line[1][tick_axis] = -radius, radius + self.draw_line(tick_line, color) + if labels_visible: + self.draw_tick_line_label(axis, color, radius, tick) + + def draw_tick_line_label(self, axis, color, radius, tick): + if not self._p._label_axes: + return + tick_label_vector = [0, 0, 0] + tick_label_vector[axis] = tick + tick_label_vector[{0: 1, 1: 0, 2: 1}[axis]] = [-1, 1, 1][ + axis] * radius * 3.5 + self.draw_text(str(tick), tick_label_vector, color, scale=0.5) + + +class PlotAxesFrame(PlotAxesBase): + + def __init__(self, parent_axes): + super().__init__(parent_axes) + + def draw_background(self, color): + pass + + def draw_axis(self, axis, color): + raise NotImplementedError() diff --git a/lib/python3.10/site-packages/sympy/plotting/pygletplot/plot_camera.py b/lib/python3.10/site-packages/sympy/plotting/pygletplot/plot_camera.py new file mode 100644 index 0000000000000000000000000000000000000000..43598debac252ffd22beb8690fef30745259c634 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/plotting/pygletplot/plot_camera.py @@ -0,0 +1,124 @@ +import pyglet.gl as pgl +from sympy.plotting.pygletplot.plot_rotation import get_spherical_rotatation +from sympy.plotting.pygletplot.util import get_model_matrix, model_to_screen, \ + screen_to_model, vec_subs + + +class PlotCamera: + + min_dist = 0.05 + max_dist = 500.0 + + min_ortho_dist = 100.0 + max_ortho_dist = 10000.0 + + _default_dist = 6.0 + _default_ortho_dist = 600.0 + + rot_presets = { + 'xy': (0, 0, 0), + 'xz': (-90, 0, 0), + 'yz': (0, 90, 0), + 'perspective': (-45, 0, -45) + } + + def __init__(self, window, ortho=False): + self.window = window + self.axes = self.window.plot.axes + self.ortho = ortho + self.reset() + + def init_rot_matrix(self): + pgl.glPushMatrix() + pgl.glLoadIdentity() + self._rot = get_model_matrix() + pgl.glPopMatrix() + + def set_rot_preset(self, preset_name): + self.init_rot_matrix() + if preset_name not in self.rot_presets: + raise ValueError( + "%s is not a valid rotation preset." % preset_name) + r = self.rot_presets[preset_name] + self.euler_rotate(r[0], 1, 0, 0) + self.euler_rotate(r[1], 0, 1, 0) + self.euler_rotate(r[2], 0, 0, 1) + + def reset(self): + self._dist = 0.0 + self._x, self._y = 0.0, 0.0 + self._rot = None + if self.ortho: + self._dist = self._default_ortho_dist + else: + self._dist = self._default_dist + self.init_rot_matrix() + + def mult_rot_matrix(self, rot): + pgl.glPushMatrix() + pgl.glLoadMatrixf(rot) + pgl.glMultMatrixf(self._rot) + self._rot = get_model_matrix() + pgl.glPopMatrix() + + def setup_projection(self): + pgl.glMatrixMode(pgl.GL_PROJECTION) + pgl.glLoadIdentity() + if self.ortho: + # yep, this is pseudo ortho (don't tell anyone) + pgl.gluPerspective( + 0.3, float(self.window.width)/float(self.window.height), + self.min_ortho_dist - 0.01, self.max_ortho_dist + 0.01) + else: + pgl.gluPerspective( + 30.0, float(self.window.width)/float(self.window.height), + self.min_dist - 0.01, self.max_dist + 0.01) + pgl.glMatrixMode(pgl.GL_MODELVIEW) + + def _get_scale(self): + return 1.0, 1.0, 1.0 + + def apply_transformation(self): + pgl.glLoadIdentity() + pgl.glTranslatef(self._x, self._y, -self._dist) + if self._rot is not None: + pgl.glMultMatrixf(self._rot) + pgl.glScalef(*self._get_scale()) + + def spherical_rotate(self, p1, p2, sensitivity=1.0): + mat = get_spherical_rotatation(p1, p2, self.window.width, + self.window.height, sensitivity) + if mat is not None: + self.mult_rot_matrix(mat) + + def euler_rotate(self, angle, x, y, z): + pgl.glPushMatrix() + pgl.glLoadMatrixf(self._rot) + pgl.glRotatef(angle, x, y, z) + self._rot = get_model_matrix() + pgl.glPopMatrix() + + def zoom_relative(self, clicks, sensitivity): + + if self.ortho: + dist_d = clicks * sensitivity * 50.0 + min_dist = self.min_ortho_dist + max_dist = self.max_ortho_dist + else: + dist_d = clicks * sensitivity + min_dist = self.min_dist + max_dist = self.max_dist + + new_dist = (self._dist - dist_d) + if (clicks < 0 and new_dist < max_dist) or new_dist > min_dist: + self._dist = new_dist + + def mouse_translate(self, x, y, dx, dy): + pgl.glPushMatrix() + pgl.glLoadIdentity() + pgl.glTranslatef(0, 0, -self._dist) + z = model_to_screen(0, 0, 0)[2] + d = vec_subs(screen_to_model(x, y, z), screen_to_model(x - dx, y - dy, z)) + pgl.glPopMatrix() + self._x += d[0] + self._y += d[1] diff --git a/lib/python3.10/site-packages/sympy/plotting/pygletplot/plot_controller.py b/lib/python3.10/site-packages/sympy/plotting/pygletplot/plot_controller.py new file mode 100644 index 0000000000000000000000000000000000000000..aa7e01e6fd17fddf07b733442208a0a4c9d87d5b --- /dev/null +++ b/lib/python3.10/site-packages/sympy/plotting/pygletplot/plot_controller.py @@ -0,0 +1,218 @@ +from pyglet.window import key +from pyglet.window.mouse import LEFT, RIGHT, MIDDLE +from sympy.plotting.pygletplot.util import get_direction_vectors, get_basis_vectors + + +class PlotController: + + normal_mouse_sensitivity = 4.0 + modified_mouse_sensitivity = 1.0 + + normal_key_sensitivity = 160.0 + modified_key_sensitivity = 40.0 + + keymap = { + key.LEFT: 'left', + key.A: 'left', + key.NUM_4: 'left', + + key.RIGHT: 'right', + key.D: 'right', + key.NUM_6: 'right', + + key.UP: 'up', + key.W: 'up', + key.NUM_8: 'up', + + key.DOWN: 'down', + key.S: 'down', + key.NUM_2: 'down', + + key.Z: 'rotate_z_neg', + key.NUM_1: 'rotate_z_neg', + + key.C: 'rotate_z_pos', + key.NUM_3: 'rotate_z_pos', + + key.Q: 'spin_left', + key.NUM_7: 'spin_left', + key.E: 'spin_right', + key.NUM_9: 'spin_right', + + key.X: 'reset_camera', + key.NUM_5: 'reset_camera', + + key.NUM_ADD: 'zoom_in', + key.PAGEUP: 'zoom_in', + key.R: 'zoom_in', + + key.NUM_SUBTRACT: 'zoom_out', + key.PAGEDOWN: 'zoom_out', + key.F: 'zoom_out', + + key.RSHIFT: 'modify_sensitivity', + key.LSHIFT: 'modify_sensitivity', + + key.F1: 'rot_preset_xy', + key.F2: 'rot_preset_xz', + key.F3: 'rot_preset_yz', + key.F4: 'rot_preset_perspective', + + key.F5: 'toggle_axes', + key.F6: 'toggle_axe_colors', + + key.F8: 'save_image' + } + + def __init__(self, window, *, invert_mouse_zoom=False, **kwargs): + self.invert_mouse_zoom = invert_mouse_zoom + self.window = window + self.camera = window.camera + self.action = { + # Rotation around the view Y (up) vector + 'left': False, + 'right': False, + # Rotation around the view X vector + 'up': False, + 'down': False, + # Rotation around the view Z vector + 'spin_left': False, + 'spin_right': False, + # Rotation around the model Z vector + 'rotate_z_neg': False, + 'rotate_z_pos': False, + # Reset to the default rotation + 'reset_camera': False, + # Performs camera z-translation + 'zoom_in': False, + 'zoom_out': False, + # Use alternative sensitivity (speed) + 'modify_sensitivity': False, + # Rotation presets + 'rot_preset_xy': False, + 'rot_preset_xz': False, + 'rot_preset_yz': False, + 'rot_preset_perspective': False, + # axes + 'toggle_axes': False, + 'toggle_axe_colors': False, + # screenshot + 'save_image': False + } + + def update(self, dt): + z = 0 + if self.action['zoom_out']: + z -= 1 + if self.action['zoom_in']: + z += 1 + if z != 0: + self.camera.zoom_relative(z/10.0, self.get_key_sensitivity()/10.0) + + dx, dy, dz = 0, 0, 0 + if self.action['left']: + dx -= 1 + if self.action['right']: + dx += 1 + if self.action['up']: + dy -= 1 + if self.action['down']: + dy += 1 + if self.action['spin_left']: + dz += 1 + if self.action['spin_right']: + dz -= 1 + + if not self.is_2D(): + if dx != 0: + self.camera.euler_rotate(dx*dt*self.get_key_sensitivity(), + *(get_direction_vectors()[1])) + if dy != 0: + self.camera.euler_rotate(dy*dt*self.get_key_sensitivity(), + *(get_direction_vectors()[0])) + if dz != 0: + self.camera.euler_rotate(dz*dt*self.get_key_sensitivity(), + *(get_direction_vectors()[2])) + else: + self.camera.mouse_translate(0, 0, dx*dt*self.get_key_sensitivity(), + -dy*dt*self.get_key_sensitivity()) + + rz = 0 + if self.action['rotate_z_neg'] and not self.is_2D(): + rz -= 1 + if self.action['rotate_z_pos'] and not self.is_2D(): + rz += 1 + + if rz != 0: + self.camera.euler_rotate(rz*dt*self.get_key_sensitivity(), + *(get_basis_vectors()[2])) + + if self.action['reset_camera']: + self.camera.reset() + + if self.action['rot_preset_xy']: + self.camera.set_rot_preset('xy') + if self.action['rot_preset_xz']: + self.camera.set_rot_preset('xz') + if self.action['rot_preset_yz']: + self.camera.set_rot_preset('yz') + if self.action['rot_preset_perspective']: + self.camera.set_rot_preset('perspective') + + if self.action['toggle_axes']: + self.action['toggle_axes'] = False + self.camera.axes.toggle_visible() + + if self.action['toggle_axe_colors']: + self.action['toggle_axe_colors'] = False + self.camera.axes.toggle_colors() + + if self.action['save_image']: + self.action['save_image'] = False + self.window.plot.saveimage() + + return True + + def get_mouse_sensitivity(self): + if self.action['modify_sensitivity']: + return self.modified_mouse_sensitivity + else: + return self.normal_mouse_sensitivity + + def get_key_sensitivity(self): + if self.action['modify_sensitivity']: + return self.modified_key_sensitivity + else: + return self.normal_key_sensitivity + + def on_key_press(self, symbol, modifiers): + if symbol in self.keymap: + self.action[self.keymap[symbol]] = True + + def on_key_release(self, symbol, modifiers): + if symbol in self.keymap: + self.action[self.keymap[symbol]] = False + + def on_mouse_drag(self, x, y, dx, dy, buttons, modifiers): + if buttons & LEFT: + if self.is_2D(): + self.camera.mouse_translate(x, y, dx, dy) + else: + self.camera.spherical_rotate((x - dx, y - dy), (x, y), + self.get_mouse_sensitivity()) + if buttons & MIDDLE: + self.camera.zoom_relative([1, -1][self.invert_mouse_zoom]*dy, + self.get_mouse_sensitivity()/20.0) + if buttons & RIGHT: + self.camera.mouse_translate(x, y, dx, dy) + + def on_mouse_scroll(self, x, y, dx, dy): + self.camera.zoom_relative([1, -1][self.invert_mouse_zoom]*dy, + self.get_mouse_sensitivity()) + + def is_2D(self): + functions = self.window.plot._functions + for i in functions: + if len(functions[i].i_vars) > 1 or len(functions[i].d_vars) > 2: + return False + return True diff --git a/lib/python3.10/site-packages/sympy/plotting/pygletplot/plot_curve.py b/lib/python3.10/site-packages/sympy/plotting/pygletplot/plot_curve.py new file mode 100644 index 0000000000000000000000000000000000000000..6b97dac843f58c76694d424f0b0b7e3499ba5202 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/plotting/pygletplot/plot_curve.py @@ -0,0 +1,82 @@ +import pyglet.gl as pgl +from sympy.core import S +from sympy.plotting.pygletplot.plot_mode_base import PlotModeBase + + +class PlotCurve(PlotModeBase): + + style_override = 'wireframe' + + def _on_calculate_verts(self): + self.t_interval = self.intervals[0] + self.t_set = list(self.t_interval.frange()) + self.bounds = [[S.Infinity, S.NegativeInfinity, 0], + [S.Infinity, S.NegativeInfinity, 0], + [S.Infinity, S.NegativeInfinity, 0]] + evaluate = self._get_evaluator() + + self._calculating_verts_pos = 0.0 + self._calculating_verts_len = float(self.t_interval.v_len) + + self.verts = [] + b = self.bounds + for t in self.t_set: + try: + _e = evaluate(t) # calculate vertex + except (NameError, ZeroDivisionError): + _e = None + if _e is not None: # update bounding box + for axis in range(3): + b[axis][0] = min([b[axis][0], _e[axis]]) + b[axis][1] = max([b[axis][1], _e[axis]]) + self.verts.append(_e) + self._calculating_verts_pos += 1.0 + + for axis in range(3): + b[axis][2] = b[axis][1] - b[axis][0] + if b[axis][2] == 0.0: + b[axis][2] = 1.0 + + self.push_wireframe(self.draw_verts(False)) + + def _on_calculate_cverts(self): + if not self.verts or not self.color: + return + + def set_work_len(n): + self._calculating_cverts_len = float(n) + + def inc_work_pos(): + self._calculating_cverts_pos += 1.0 + set_work_len(1) + self._calculating_cverts_pos = 0 + self.cverts = self.color.apply_to_curve(self.verts, + self.t_set, + set_len=set_work_len, + inc_pos=inc_work_pos) + self.push_wireframe(self.draw_verts(True)) + + def calculate_one_cvert(self, t): + vert = self.verts[t] + return self.color(vert[0], vert[1], vert[2], + self.t_set[t], None) + + def draw_verts(self, use_cverts): + def f(): + pgl.glBegin(pgl.GL_LINE_STRIP) + for t in range(len(self.t_set)): + p = self.verts[t] + if p is None: + pgl.glEnd() + pgl.glBegin(pgl.GL_LINE_STRIP) + continue + if use_cverts: + c = self.cverts[t] + if c is None: + c = (0, 0, 0) + pgl.glColor3f(*c) + else: + pgl.glColor3f(*self.default_wireframe_color) + pgl.glVertex3f(*p) + pgl.glEnd() + return f diff --git a/lib/python3.10/site-packages/sympy/plotting/pygletplot/plot_interval.py b/lib/python3.10/site-packages/sympy/plotting/pygletplot/plot_interval.py new file mode 100644 index 0000000000000000000000000000000000000000..085ab096915bbc4a3761b71736b4dd14f1ff779f --- /dev/null +++ b/lib/python3.10/site-packages/sympy/plotting/pygletplot/plot_interval.py @@ -0,0 +1,181 @@ +from sympy.core.singleton import S +from sympy.core.symbol import Symbol +from sympy.core.sympify import sympify +from sympy.core.numbers import Integer + + +class PlotInterval: + """ + """ + _v, _v_min, _v_max, _v_steps = None, None, None, None + + def require_all_args(f): + def check(self, *args, **kwargs): + for g in [self._v, self._v_min, self._v_max, self._v_steps]: + if g is None: + raise ValueError("PlotInterval is incomplete.") + return f(self, *args, **kwargs) + return check + + def __init__(self, *args): + if len(args) == 1: + if isinstance(args[0], PlotInterval): + self.fill_from(args[0]) + return + elif isinstance(args[0], str): + try: + args = eval(args[0]) + except TypeError: + s_eval_error = "Could not interpret string %s." + raise ValueError(s_eval_error % (args[0])) + elif isinstance(args[0], (tuple, list)): + args = args[0] + else: + raise ValueError("Not an interval.") + if not isinstance(args, (tuple, list)) or len(args) > 4: + f_error = "PlotInterval must be a tuple or list of length 4 or less." + raise ValueError(f_error) + + args = list(args) + if len(args) > 0 and (args[0] is None or isinstance(args[0], Symbol)): + self.v = args.pop(0) + if len(args) in [2, 3]: + self.v_min = args.pop(0) + self.v_max = args.pop(0) + if len(args) == 1: + self.v_steps = args.pop(0) + elif len(args) == 1: + self.v_steps = args.pop(0) + + def get_v(self): + return self._v + + def set_v(self, v): + if v is None: + self._v = None + return + if not isinstance(v, Symbol): + raise ValueError("v must be a SymPy Symbol.") + self._v = v + + def get_v_min(self): + return self._v_min + + def set_v_min(self, v_min): + if v_min is None: + self._v_min = None + return + try: + self._v_min = sympify(v_min) + float(self._v_min.evalf()) + except TypeError: + raise ValueError("v_min could not be interpreted as a number.") + + def get_v_max(self): + return self._v_max + + def set_v_max(self, v_max): + if v_max is None: + self._v_max = None + return + try: + self._v_max = sympify(v_max) + float(self._v_max.evalf()) + except TypeError: + raise ValueError("v_max could not be interpreted as a number.") + + def get_v_steps(self): + return self._v_steps + + def set_v_steps(self, v_steps): + if v_steps is None: + self._v_steps = None + return + if isinstance(v_steps, int): + v_steps = Integer(v_steps) + elif not isinstance(v_steps, Integer): + raise ValueError("v_steps must be an int or SymPy Integer.") + if v_steps <= S.Zero: + raise ValueError("v_steps must be positive.") + self._v_steps = v_steps + + @require_all_args + def get_v_len(self): + return self.v_steps + 1 + + v = property(get_v, set_v) + v_min = property(get_v_min, set_v_min) + v_max = property(get_v_max, set_v_max) + v_steps = property(get_v_steps, set_v_steps) + v_len = property(get_v_len) + + def fill_from(self, b): + if b.v is not None: + self.v = b.v + if b.v_min is not None: + self.v_min = b.v_min + if b.v_max is not None: + self.v_max = b.v_max + if b.v_steps is not None: + self.v_steps = b.v_steps + + @staticmethod + def try_parse(*args): + """ + Returns a PlotInterval if args can be interpreted + as such, otherwise None. + """ + if len(args) == 1 and isinstance(args[0], PlotInterval): + return args[0] + try: + return PlotInterval(*args) + except ValueError: + return None + + def _str_base(self): + return ",".join([str(self.v), str(self.v_min), + str(self.v_max), str(self.v_steps)]) + + def __repr__(self): + """ + A string representing the interval in class constructor form. + """ + return "PlotInterval(%s)" % (self._str_base()) + + def __str__(self): + """ + A string representing the interval in list form. + """ + return "[%s]" % (self._str_base()) + + @require_all_args + def assert_complete(self): + pass + + @require_all_args + def vrange(self): + """ + Yields v_steps+1 SymPy numbers ranging from + v_min to v_max. + """ + d = (self.v_max - self.v_min) / self.v_steps + for i in range(self.v_steps + 1): + a = self.v_min + (d * Integer(i)) + yield a + + @require_all_args + def vrange2(self): + """ + Yields v_steps pairs of SymPy numbers ranging from + (v_min, v_min + step) to (v_max - step, v_max). + """ + d = (self.v_max - self.v_min) / self.v_steps + a = self.v_min + (d * S.Zero) + for i in range(self.v_steps): + b = self.v_min + (d * Integer(i + 1)) + yield a, b + a = b + + def frange(self): + for i in self.vrange(): + yield float(i.evalf()) diff --git a/lib/python3.10/site-packages/sympy/plotting/pygletplot/plot_mode.py b/lib/python3.10/site-packages/sympy/plotting/pygletplot/plot_mode.py new file mode 100644 index 0000000000000000000000000000000000000000..f4ee00db9177b98b3259438949836fe5b69416c2 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/plotting/pygletplot/plot_mode.py @@ -0,0 +1,400 @@ +from .plot_interval import PlotInterval +from .plot_object import PlotObject +from .util import parse_option_string +from sympy.core.symbol import Symbol +from sympy.core.sympify import sympify +from sympy.geometry.entity import GeometryEntity +from sympy.utilities.iterables import is_sequence + + +class PlotMode(PlotObject): + """ + Grandparent class for plotting + modes. Serves as interface for + registration, lookup, and init + of modes. + + To create a new plot mode, + inherit from PlotModeBase + or one of its children, such + as PlotSurface or PlotCurve. + """ + + ## Class-level attributes + ## used to register and lookup + ## plot modes. See PlotModeBase + ## for descriptions and usage. + + i_vars, d_vars = '', '' + intervals = [] + aliases = [] + is_default = False + + ## Draw is the only method here which + ## is meant to be overridden in child + ## classes, and PlotModeBase provides + ## a base implementation. + def draw(self): + raise NotImplementedError() + + ## Everything else in this file has to + ## do with registration and retrieval + ## of plot modes. This is where I've + ## hidden much of the ugliness of automatic + ## plot mode divination... + + ## Plot mode registry data structures + _mode_alias_list = [] + _mode_map = { + 1: {1: {}, 2: {}}, + 2: {1: {}, 2: {}}, + 3: {1: {}, 2: {}}, + } # [d][i][alias_str]: class + _mode_default_map = { + 1: {}, + 2: {}, + 3: {}, + } # [d][i]: class + _i_var_max, _d_var_max = 2, 3 + + def __new__(cls, *args, **kwargs): + """ + This is the function which interprets + arguments given to Plot.__init__ and + Plot.__setattr__. Returns an initialized + instance of the appropriate child class. + """ + + newargs, newkwargs = PlotMode._extract_options(args, kwargs) + mode_arg = newkwargs.get('mode', '') + + # Interpret the arguments + d_vars, intervals = PlotMode._interpret_args(newargs) + i_vars = PlotMode._find_i_vars(d_vars, intervals) + i, d = max([len(i_vars), len(intervals)]), len(d_vars) + + # Find the appropriate mode + subcls = PlotMode._get_mode(mode_arg, i, d) + + # Create the object + o = object.__new__(subcls) + + # Do some setup for the mode instance + o.d_vars = d_vars + o._fill_i_vars(i_vars) + o._fill_intervals(intervals) + o.options = newkwargs + + return o + + @staticmethod + def _get_mode(mode_arg, i_var_count, d_var_count): + """ + Tries to return an appropriate mode class. + Intended to be called only by __new__. + + mode_arg + Can be a string or a class. If it is a + PlotMode subclass, it is simply returned. + If it is a string, it can an alias for + a mode or an empty string. In the latter + case, we try to find a default mode for + the i_var_count and d_var_count. + + i_var_count + The number of independent variables + needed to evaluate the d_vars. + + d_var_count + The number of dependent variables; + usually the number of functions to + be evaluated in plotting. + + For example, a Cartesian function y = f(x) has + one i_var (x) and one d_var (y). A parametric + form x,y,z = f(u,v), f(u,v), f(u,v) has two + two i_vars (u,v) and three d_vars (x,y,z). + """ + # if the mode_arg is simply a PlotMode class, + # check that the mode supports the numbers + # of independent and dependent vars, then + # return it + try: + m = None + if issubclass(mode_arg, PlotMode): + m = mode_arg + except TypeError: + pass + if m: + if not m._was_initialized: + raise ValueError(("To use unregistered plot mode %s " + "you must first call %s._init_mode().") + % (m.__name__, m.__name__)) + if d_var_count != m.d_var_count: + raise ValueError(("%s can only plot functions " + "with %i dependent variables.") + % (m.__name__, + m.d_var_count)) + if i_var_count > m.i_var_count: + raise ValueError(("%s cannot plot functions " + "with more than %i independent " + "variables.") + % (m.__name__, + m.i_var_count)) + return m + # If it is a string, there are two possibilities. + if isinstance(mode_arg, str): + i, d = i_var_count, d_var_count + if i > PlotMode._i_var_max: + raise ValueError(var_count_error(True, True)) + if d > PlotMode._d_var_max: + raise ValueError(var_count_error(False, True)) + # If the string is '', try to find a suitable + # default mode + if not mode_arg: + return PlotMode._get_default_mode(i, d) + # Otherwise, interpret the string as a mode + # alias (e.g. 'cartesian', 'parametric', etc) + else: + return PlotMode._get_aliased_mode(mode_arg, i, d) + else: + raise ValueError("PlotMode argument must be " + "a class or a string") + + @staticmethod + def _get_default_mode(i, d, i_vars=-1): + if i_vars == -1: + i_vars = i + try: + return PlotMode._mode_default_map[d][i] + except KeyError: + # Keep looking for modes in higher i var counts + # which support the given d var count until we + # reach the max i_var count. + if i < PlotMode._i_var_max: + return PlotMode._get_default_mode(i + 1, d, i_vars) + else: + raise ValueError(("Couldn't find a default mode " + "for %i independent and %i " + "dependent variables.") % (i_vars, d)) + + @staticmethod + def _get_aliased_mode(alias, i, d, i_vars=-1): + if i_vars == -1: + i_vars = i + if alias not in PlotMode._mode_alias_list: + raise ValueError(("Couldn't find a mode called" + " %s. Known modes: %s.") + % (alias, ", ".join(PlotMode._mode_alias_list))) + try: + return PlotMode._mode_map[d][i][alias] + except TypeError: + # Keep looking for modes in higher i var counts + # which support the given d var count and alias + # until we reach the max i_var count. + if i < PlotMode._i_var_max: + return PlotMode._get_aliased_mode(alias, i + 1, d, i_vars) + else: + raise ValueError(("Couldn't find a %s mode " + "for %i independent and %i " + "dependent variables.") + % (alias, i_vars, d)) + + @classmethod + def _register(cls): + """ + Called once for each user-usable plot mode. + For Cartesian2D, it is invoked after the + class definition: Cartesian2D._register() + """ + name = cls.__name__ + cls._init_mode() + + try: + i, d = cls.i_var_count, cls.d_var_count + # Add the mode to _mode_map under all + # given aliases + for a in cls.aliases: + if a not in PlotMode._mode_alias_list: + # Also track valid aliases, so + # we can quickly know when given + # an invalid one in _get_mode. + PlotMode._mode_alias_list.append(a) + PlotMode._mode_map[d][i][a] = cls + if cls.is_default: + # If this mode was marked as the + # default for this d,i combination, + # also set that. + PlotMode._mode_default_map[d][i] = cls + + except Exception as e: + raise RuntimeError(("Failed to register " + "plot mode %s. Reason: %s") + % (name, (str(e)))) + + @classmethod + def _init_mode(cls): + """ + Initializes the plot mode based on + the 'mode-specific parameters' above. + Only intended to be called by + PlotMode._register(). To use a mode without + registering it, you can directly call + ModeSubclass._init_mode(). + """ + def symbols_list(symbol_str): + return [Symbol(s) for s in symbol_str] + + # Convert the vars strs into + # lists of symbols. + cls.i_vars = symbols_list(cls.i_vars) + cls.d_vars = symbols_list(cls.d_vars) + + # Var count is used often, calculate + # it once here + cls.i_var_count = len(cls.i_vars) + cls.d_var_count = len(cls.d_vars) + + if cls.i_var_count > PlotMode._i_var_max: + raise ValueError(var_count_error(True, False)) + if cls.d_var_count > PlotMode._d_var_max: + raise ValueError(var_count_error(False, False)) + + # Try to use first alias as primary_alias + if len(cls.aliases) > 0: + cls.primary_alias = cls.aliases[0] + else: + cls.primary_alias = cls.__name__ + + di = cls.intervals + if len(di) != cls.i_var_count: + raise ValueError("Plot mode must provide a " + "default interval for each i_var.") + for i in range(cls.i_var_count): + # default intervals must be given [min,max,steps] + # (no var, but they must be in the same order as i_vars) + if len(di[i]) != 3: + raise ValueError("length should be equal to 3") + + # Initialize an incomplete interval, + # to later be filled with a var when + # the mode is instantiated. + di[i] = PlotInterval(None, *di[i]) + + # To prevent people from using modes + # without these required fields set up. + cls._was_initialized = True + + _was_initialized = False + + ## Initializer Helper Methods + + @staticmethod + def _find_i_vars(functions, intervals): + i_vars = [] + + # First, collect i_vars in the + # order they are given in any + # intervals. + for i in intervals: + if i.v is None: + continue + elif i.v in i_vars: + raise ValueError(("Multiple intervals given " + "for %s.") % (str(i.v))) + i_vars.append(i.v) + + # Then, find any remaining + # i_vars in given functions + # (aka d_vars) + for f in functions: + for a in f.free_symbols: + if a not in i_vars: + i_vars.append(a) + + return i_vars + + def _fill_i_vars(self, i_vars): + # copy default i_vars + self.i_vars = [Symbol(str(i)) for i in self.i_vars] + # replace with given i_vars + for i in range(len(i_vars)): + self.i_vars[i] = i_vars[i] + + def _fill_intervals(self, intervals): + # copy default intervals + self.intervals = [PlotInterval(i) for i in self.intervals] + # track i_vars used so far + v_used = [] + # fill copy of default + # intervals with given info + for i in range(len(intervals)): + self.intervals[i].fill_from(intervals[i]) + if self.intervals[i].v is not None: + v_used.append(self.intervals[i].v) + # Find any orphan intervals and + # assign them i_vars + for i in range(len(self.intervals)): + if self.intervals[i].v is None: + u = [v for v in self.i_vars if v not in v_used] + if len(u) == 0: + raise ValueError("length should not be equal to 0") + self.intervals[i].v = u[0] + v_used.append(u[0]) + + @staticmethod + def _interpret_args(args): + interval_wrong_order = "PlotInterval %s was given before any function(s)." + interpret_error = "Could not interpret %s as a function or interval." + + functions, intervals = [], [] + if isinstance(args[0], GeometryEntity): + for coords in list(args[0].arbitrary_point()): + functions.append(coords) + intervals.append(PlotInterval.try_parse(args[0].plot_interval())) + else: + for a in args: + i = PlotInterval.try_parse(a) + if i is not None: + if len(functions) == 0: + raise ValueError(interval_wrong_order % (str(i))) + else: + intervals.append(i) + else: + if is_sequence(a, include=str): + raise ValueError(interpret_error % (str(a))) + try: + f = sympify(a) + functions.append(f) + except TypeError: + raise ValueError(interpret_error % str(a)) + + return functions, intervals + + @staticmethod + def _extract_options(args, kwargs): + newkwargs, newargs = {}, [] + for a in args: + if isinstance(a, str): + newkwargs = dict(newkwargs, **parse_option_string(a)) + else: + newargs.append(a) + newkwargs = dict(newkwargs, **kwargs) + return newargs, newkwargs + + +def var_count_error(is_independent, is_plotting): + """ + Used to format an error message which differs + slightly in 4 places. + """ + if is_plotting: + v = "Plotting" + else: + v = "Registering plot modes" + if is_independent: + n, s = PlotMode._i_var_max, "independent" + else: + n, s = PlotMode._d_var_max, "dependent" + return ("%s with more than %i %s variables " + "is not supported.") % (v, n, s) diff --git a/lib/python3.10/site-packages/sympy/plotting/pygletplot/plot_mode_base.py b/lib/python3.10/site-packages/sympy/plotting/pygletplot/plot_mode_base.py new file mode 100644 index 0000000000000000000000000000000000000000..2c6503650afda122e271bdecb2365c8fa20f2376 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/plotting/pygletplot/plot_mode_base.py @@ -0,0 +1,378 @@ +import pyglet.gl as pgl +from sympy.core import S +from sympy.plotting.pygletplot.color_scheme import ColorScheme +from sympy.plotting.pygletplot.plot_mode import PlotMode +from sympy.utilities.iterables import is_sequence +from time import sleep +from threading import Thread, Event, RLock +import warnings + + +class PlotModeBase(PlotMode): + """ + Intended parent class for plotting + modes. Provides base functionality + in conjunction with its parent, + PlotMode. + """ + + ## + ## Class-Level Attributes + ## + + """ + The following attributes are meant + to be set at the class level, and serve + as parameters to the plot mode registry + (in PlotMode). See plot_modes.py for + concrete examples. + """ + + """ + i_vars + 'x' for Cartesian2D + 'xy' for Cartesian3D + etc. + + d_vars + 'y' for Cartesian2D + 'r' for Polar + etc. + """ + i_vars, d_vars = '', '' + + """ + intervals + Default intervals for each i_var, and in the + same order. Specified [min, max, steps]. + No variable can be given (it is bound later). + """ + intervals = [] + + """ + aliases + A list of strings which can be used to + access this mode. + 'cartesian' for Cartesian2D and Cartesian3D + 'polar' for Polar + 'cylindrical', 'polar' for Cylindrical + + Note that _init_mode chooses the first alias + in the list as the mode's primary_alias, which + will be displayed to the end user in certain + contexts. + """ + aliases = [] + + """ + is_default + Whether to set this mode as the default + for arguments passed to PlotMode() containing + the same number of d_vars as this mode and + at most the same number of i_vars. + """ + is_default = False + + """ + All of the above attributes are defined in PlotMode. + The following ones are specific to PlotModeBase. + """ + + """ + A list of the render styles. Do not modify. + """ + styles = {'wireframe': 1, 'solid': 2, 'both': 3} + + """ + style_override + Always use this style if not blank. + """ + style_override = '' + + """ + default_wireframe_color + default_solid_color + Can be used when color is None or being calculated. + Used by PlotCurve and PlotSurface, but not anywhere + in PlotModeBase. + """ + + default_wireframe_color = (0.85, 0.85, 0.85) + default_solid_color = (0.6, 0.6, 0.9) + default_rot_preset = 'xy' + + ## + ## Instance-Level Attributes + ## + + ## 'Abstract' member functions + def _get_evaluator(self): + if self.use_lambda_eval: + try: + e = self._get_lambda_evaluator() + return e + except Exception: + warnings.warn("\nWarning: creating lambda evaluator failed. " + "Falling back on SymPy subs evaluator.") + return self._get_sympy_evaluator() + + def _get_sympy_evaluator(self): + raise NotImplementedError() + + def _get_lambda_evaluator(self): + raise NotImplementedError() + + def _on_calculate_verts(self): + raise NotImplementedError() + + def _on_calculate_cverts(self): + raise NotImplementedError() + + ## Base member functions + def __init__(self, *args, bounds_callback=None, **kwargs): + self.verts = [] + self.cverts = [] + self.bounds = [[S.Infinity, S.NegativeInfinity, 0], + [S.Infinity, S.NegativeInfinity, 0], + [S.Infinity, S.NegativeInfinity, 0]] + self.cbounds = [[S.Infinity, S.NegativeInfinity, 0], + [S.Infinity, S.NegativeInfinity, 0], + [S.Infinity, S.NegativeInfinity, 0]] + + self._draw_lock = RLock() + + self._calculating_verts = Event() + self._calculating_cverts = Event() + self._calculating_verts_pos = 0.0 + self._calculating_verts_len = 0.0 + self._calculating_cverts_pos = 0.0 + self._calculating_cverts_len = 0.0 + + self._max_render_stack_size = 3 + self._draw_wireframe = [-1] + self._draw_solid = [-1] + + self._style = None + self._color = None + + self.predraw = [] + self.postdraw = [] + + self.use_lambda_eval = self.options.pop('use_sympy_eval', None) is None + self.style = self.options.pop('style', '') + self.color = self.options.pop('color', 'rainbow') + self.bounds_callback = bounds_callback + + self._on_calculate() + + def synchronized(f): + def w(self, *args, **kwargs): + self._draw_lock.acquire() + try: + r = f(self, *args, **kwargs) + return r + finally: + self._draw_lock.release() + return w + + @synchronized + def push_wireframe(self, function): + """ + Push a function which performs gl commands + used to build a display list. (The list is + built outside of the function) + """ + assert callable(function) + self._draw_wireframe.append(function) + if len(self._draw_wireframe) > self._max_render_stack_size: + del self._draw_wireframe[1] # leave marker element + + @synchronized + def push_solid(self, function): + """ + Push a function which performs gl commands + used to build a display list. (The list is + built outside of the function) + """ + assert callable(function) + self._draw_solid.append(function) + if len(self._draw_solid) > self._max_render_stack_size: + del self._draw_solid[1] # leave marker element + + def _create_display_list(self, function): + dl = pgl.glGenLists(1) + pgl.glNewList(dl, pgl.GL_COMPILE) + function() + pgl.glEndList() + return dl + + def _render_stack_top(self, render_stack): + top = render_stack[-1] + if top == -1: + return -1 # nothing to display + elif callable(top): + dl = self._create_display_list(top) + render_stack[-1] = (dl, top) + return dl # display newly added list + elif len(top) == 2: + if pgl.GL_TRUE == pgl.glIsList(top[0]): + return top[0] # display stored list + dl = self._create_display_list(top[1]) + render_stack[-1] = (dl, top[1]) + return dl # display regenerated list + + def _draw_solid_display_list(self, dl): + pgl.glPushAttrib(pgl.GL_ENABLE_BIT | pgl.GL_POLYGON_BIT) + pgl.glPolygonMode(pgl.GL_FRONT_AND_BACK, pgl.GL_FILL) + pgl.glCallList(dl) + pgl.glPopAttrib() + + def _draw_wireframe_display_list(self, dl): + pgl.glPushAttrib(pgl.GL_ENABLE_BIT | pgl.GL_POLYGON_BIT) + pgl.glPolygonMode(pgl.GL_FRONT_AND_BACK, pgl.GL_LINE) + pgl.glEnable(pgl.GL_POLYGON_OFFSET_LINE) + pgl.glPolygonOffset(-0.005, -50.0) + pgl.glCallList(dl) + pgl.glPopAttrib() + + @synchronized + def draw(self): + for f in self.predraw: + if callable(f): + f() + if self.style_override: + style = self.styles[self.style_override] + else: + style = self.styles[self._style] + # Draw solid component if style includes solid + if style & 2: + dl = self._render_stack_top(self._draw_solid) + if dl > 0 and pgl.GL_TRUE == pgl.glIsList(dl): + self._draw_solid_display_list(dl) + # Draw wireframe component if style includes wireframe + if style & 1: + dl = self._render_stack_top(self._draw_wireframe) + if dl > 0 and pgl.GL_TRUE == pgl.glIsList(dl): + self._draw_wireframe_display_list(dl) + for f in self.postdraw: + if callable(f): + f() + + def _on_change_color(self, color): + Thread(target=self._calculate_cverts).start() + + def _on_calculate(self): + Thread(target=self._calculate_all).start() + + def _calculate_all(self): + self._calculate_verts() + self._calculate_cverts() + + def _calculate_verts(self): + if self._calculating_verts.is_set(): + return + self._calculating_verts.set() + try: + self._on_calculate_verts() + finally: + self._calculating_verts.clear() + if callable(self.bounds_callback): + self.bounds_callback() + + def _calculate_cverts(self): + if self._calculating_verts.is_set(): + return + while self._calculating_cverts.is_set(): + sleep(0) # wait for previous calculation + self._calculating_cverts.set() + try: + self._on_calculate_cverts() + finally: + self._calculating_cverts.clear() + + def _get_calculating_verts(self): + return self._calculating_verts.is_set() + + def _get_calculating_verts_pos(self): + return self._calculating_verts_pos + + def _get_calculating_verts_len(self): + return self._calculating_verts_len + + def _get_calculating_cverts(self): + return self._calculating_cverts.is_set() + + def _get_calculating_cverts_pos(self): + return self._calculating_cverts_pos + + def _get_calculating_cverts_len(self): + return self._calculating_cverts_len + + ## Property handlers + def _get_style(self): + return self._style + + @synchronized + def _set_style(self, v): + if v is None: + return + if v == '': + step_max = 0 + for i in self.intervals: + if i.v_steps is None: + continue + step_max = max([step_max, int(i.v_steps)]) + v = ['both', 'solid'][step_max > 40] + if v not in self.styles: + raise ValueError("v should be there in self.styles") + if v == self._style: + return + self._style = v + + def _get_color(self): + return self._color + + @synchronized + def _set_color(self, v): + try: + if v is not None: + if is_sequence(v): + v = ColorScheme(*v) + else: + v = ColorScheme(v) + if repr(v) == repr(self._color): + return + self._on_change_color(v) + self._color = v + except Exception as e: + raise RuntimeError("Color change failed. " + "Reason: %s" % (str(e))) + + style = property(_get_style, _set_style) + color = property(_get_color, _set_color) + + calculating_verts = property(_get_calculating_verts) + calculating_verts_pos = property(_get_calculating_verts_pos) + calculating_verts_len = property(_get_calculating_verts_len) + + calculating_cverts = property(_get_calculating_cverts) + calculating_cverts_pos = property(_get_calculating_cverts_pos) + calculating_cverts_len = property(_get_calculating_cverts_len) + + ## String representations + + def __str__(self): + f = ", ".join(str(d) for d in self.d_vars) + o = "'mode=%s'" % (self.primary_alias) + return ", ".join([f, o]) + + def __repr__(self): + f = ", ".join(str(d) for d in self.d_vars) + i = ", ".join(str(i) for i in self.intervals) + d = [('mode', self.primary_alias), + ('color', str(self.color)), + ('style', str(self.style))] + + o = "'%s'" % ("; ".join("%s=%s" % (k, v) + for k, v in d if v != 'None')) + return ", ".join([f, i, o]) diff --git a/lib/python3.10/site-packages/sympy/plotting/pygletplot/plot_modes.py b/lib/python3.10/site-packages/sympy/plotting/pygletplot/plot_modes.py new file mode 100644 index 0000000000000000000000000000000000000000..e78e0b4ce291b071f684fa3ffc02f456dffe0023 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/plotting/pygletplot/plot_modes.py @@ -0,0 +1,209 @@ +from sympy.utilities.lambdify import lambdify +from sympy.core.numbers import pi +from sympy.functions import sin, cos +from sympy.plotting.pygletplot.plot_curve import PlotCurve +from sympy.plotting.pygletplot.plot_surface import PlotSurface + +from math import sin as p_sin +from math import cos as p_cos + + +def float_vec3(f): + def inner(*args): + v = f(*args) + return float(v[0]), float(v[1]), float(v[2]) + return inner + + +class Cartesian2D(PlotCurve): + i_vars, d_vars = 'x', 'y' + intervals = [[-5, 5, 100]] + aliases = ['cartesian'] + is_default = True + + def _get_sympy_evaluator(self): + fy = self.d_vars[0] + x = self.t_interval.v + + @float_vec3 + def e(_x): + return (_x, fy.subs(x, _x), 0.0) + return e + + def _get_lambda_evaluator(self): + fy = self.d_vars[0] + x = self.t_interval.v + return lambdify([x], [x, fy, 0.0]) + + +class Cartesian3D(PlotSurface): + i_vars, d_vars = 'xy', 'z' + intervals = [[-1, 1, 40], [-1, 1, 40]] + aliases = ['cartesian', 'monge'] + is_default = True + + def _get_sympy_evaluator(self): + fz = self.d_vars[0] + x = self.u_interval.v + y = self.v_interval.v + + @float_vec3 + def e(_x, _y): + return (_x, _y, fz.subs(x, _x).subs(y, _y)) + return e + + def _get_lambda_evaluator(self): + fz = self.d_vars[0] + x = self.u_interval.v + y = self.v_interval.v + return lambdify([x, y], [x, y, fz]) + + +class ParametricCurve2D(PlotCurve): + i_vars, d_vars = 't', 'xy' + intervals = [[0, 2*pi, 100]] + aliases = ['parametric'] + is_default = True + + def _get_sympy_evaluator(self): + fx, fy = self.d_vars + t = self.t_interval.v + + @float_vec3 + def e(_t): + return (fx.subs(t, _t), fy.subs(t, _t), 0.0) + return e + + def _get_lambda_evaluator(self): + fx, fy = self.d_vars + t = self.t_interval.v + return lambdify([t], [fx, fy, 0.0]) + + +class ParametricCurve3D(PlotCurve): + i_vars, d_vars = 't', 'xyz' + intervals = [[0, 2*pi, 100]] + aliases = ['parametric'] + is_default = True + + def _get_sympy_evaluator(self): + fx, fy, fz = self.d_vars + t = self.t_interval.v + + @float_vec3 + def e(_t): + return (fx.subs(t, _t), fy.subs(t, _t), fz.subs(t, _t)) + return e + + def _get_lambda_evaluator(self): + fx, fy, fz = self.d_vars + t = self.t_interval.v + return lambdify([t], [fx, fy, fz]) + + +class ParametricSurface(PlotSurface): + i_vars, d_vars = 'uv', 'xyz' + intervals = [[-1, 1, 40], [-1, 1, 40]] + aliases = ['parametric'] + is_default = True + + def _get_sympy_evaluator(self): + fx, fy, fz = self.d_vars + u = self.u_interval.v + v = self.v_interval.v + + @float_vec3 + def e(_u, _v): + return (fx.subs(u, _u).subs(v, _v), + fy.subs(u, _u).subs(v, _v), + fz.subs(u, _u).subs(v, _v)) + return e + + def _get_lambda_evaluator(self): + fx, fy, fz = self.d_vars + u = self.u_interval.v + v = self.v_interval.v + return lambdify([u, v], [fx, fy, fz]) + + +class Polar(PlotCurve): + i_vars, d_vars = 't', 'r' + intervals = [[0, 2*pi, 100]] + aliases = ['polar'] + is_default = False + + def _get_sympy_evaluator(self): + fr = self.d_vars[0] + t = self.t_interval.v + + def e(_t): + _r = float(fr.subs(t, _t)) + return (_r*p_cos(_t), _r*p_sin(_t), 0.0) + return e + + def _get_lambda_evaluator(self): + fr = self.d_vars[0] + t = self.t_interval.v + fx, fy = fr*cos(t), fr*sin(t) + return lambdify([t], [fx, fy, 0.0]) + + +class Cylindrical(PlotSurface): + i_vars, d_vars = 'th', 'r' + intervals = [[0, 2*pi, 40], [-1, 1, 20]] + aliases = ['cylindrical', 'polar'] + is_default = False + + def _get_sympy_evaluator(self): + fr = self.d_vars[0] + t = self.u_interval.v + h = self.v_interval.v + + def e(_t, _h): + _r = float(fr.subs(t, _t).subs(h, _h)) + return (_r*p_cos(_t), _r*p_sin(_t), _h) + return e + + def _get_lambda_evaluator(self): + fr = self.d_vars[0] + t = self.u_interval.v + h = self.v_interval.v + fx, fy = fr*cos(t), fr*sin(t) + return lambdify([t, h], [fx, fy, h]) + + +class Spherical(PlotSurface): + i_vars, d_vars = 'tp', 'r' + intervals = [[0, 2*pi, 40], [0, pi, 20]] + aliases = ['spherical'] + is_default = False + + def _get_sympy_evaluator(self): + fr = self.d_vars[0] + t = self.u_interval.v + p = self.v_interval.v + + def e(_t, _p): + _r = float(fr.subs(t, _t).subs(p, _p)) + return (_r*p_cos(_t)*p_sin(_p), + _r*p_sin(_t)*p_sin(_p), + _r*p_cos(_p)) + return e + + def _get_lambda_evaluator(self): + fr = self.d_vars[0] + t = self.u_interval.v + p = self.v_interval.v + fx = fr * cos(t) * sin(p) + fy = fr * sin(t) * sin(p) + fz = fr * cos(p) + return lambdify([t, p], [fx, fy, fz]) + +Cartesian2D._register() +Cartesian3D._register() +ParametricCurve2D._register() +ParametricCurve3D._register() +ParametricSurface._register() +Polar._register() +Cylindrical._register() +Spherical._register() diff --git a/lib/python3.10/site-packages/sympy/plotting/pygletplot/plot_object.py b/lib/python3.10/site-packages/sympy/plotting/pygletplot/plot_object.py new file mode 100644 index 0000000000000000000000000000000000000000..e51040fb8b1a52c49d849b96692f6c0dba329d75 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/plotting/pygletplot/plot_object.py @@ -0,0 +1,17 @@ +class PlotObject: + """ + Base class for objects which can be displayed in + a Plot. + """ + visible = True + + def _draw(self): + if self.visible: + self.draw() + + def draw(self): + """ + OpenGL rendering code for the plot object. + Override in base class. + """ + pass diff --git a/lib/python3.10/site-packages/sympy/plotting/pygletplot/plot_rotation.py b/lib/python3.10/site-packages/sympy/plotting/pygletplot/plot_rotation.py new file mode 100644 index 0000000000000000000000000000000000000000..7f568964997634121946a761b55ef0f916ac633f --- /dev/null +++ b/lib/python3.10/site-packages/sympy/plotting/pygletplot/plot_rotation.py @@ -0,0 +1,68 @@ +try: + from ctypes import c_float +except ImportError: + pass + +import pyglet.gl as pgl +from math import sqrt as _sqrt, acos as _acos + + +def cross(a, b): + return (a[1] * b[2] - a[2] * b[1], + a[2] * b[0] - a[0] * b[2], + a[0] * b[1] - a[1] * b[0]) + + +def dot(a, b): + return a[0] * b[0] + a[1] * b[1] + a[2] * b[2] + + +def mag(a): + return _sqrt(a[0]**2 + a[1]**2 + a[2]**2) + + +def norm(a): + m = mag(a) + return (a[0] / m, a[1] / m, a[2] / m) + + +def get_sphere_mapping(x, y, width, height): + x = min([max([x, 0]), width]) + y = min([max([y, 0]), height]) + + sr = _sqrt((width/2)**2 + (height/2)**2) + sx = ((x - width / 2) / sr) + sy = ((y - height / 2) / sr) + + sz = 1.0 - sx**2 - sy**2 + + if sz > 0.0: + sz = _sqrt(sz) + return (sx, sy, sz) + else: + sz = 0 + return norm((sx, sy, sz)) + +rad2deg = 180.0 / 3.141592 + + +def get_spherical_rotatation(p1, p2, width, height, theta_multiplier): + v1 = get_sphere_mapping(p1[0], p1[1], width, height) + v2 = get_sphere_mapping(p2[0], p2[1], width, height) + + d = min(max([dot(v1, v2), -1]), 1) + + if abs(d - 1.0) < 0.000001: + return None + + raxis = norm( cross(v1, v2) ) + rtheta = theta_multiplier * rad2deg * _acos(d) + + pgl.glPushMatrix() + pgl.glLoadIdentity() + pgl.glRotatef(rtheta, *raxis) + mat = (c_float*16)() + pgl.glGetFloatv(pgl.GL_MODELVIEW_MATRIX, mat) + pgl.glPopMatrix() + + return mat diff --git a/lib/python3.10/site-packages/sympy/plotting/pygletplot/plot_surface.py b/lib/python3.10/site-packages/sympy/plotting/pygletplot/plot_surface.py new file mode 100644 index 0000000000000000000000000000000000000000..ed421eebb441d193f4d9b763f56e146c11e5a42c --- /dev/null +++ b/lib/python3.10/site-packages/sympy/plotting/pygletplot/plot_surface.py @@ -0,0 +1,102 @@ +import pyglet.gl as pgl + +from sympy.core import S +from sympy.plotting.pygletplot.plot_mode_base import PlotModeBase + + +class PlotSurface(PlotModeBase): + + default_rot_preset = 'perspective' + + def _on_calculate_verts(self): + self.u_interval = self.intervals[0] + self.u_set = list(self.u_interval.frange()) + self.v_interval = self.intervals[1] + self.v_set = list(self.v_interval.frange()) + self.bounds = [[S.Infinity, S.NegativeInfinity, 0], + [S.Infinity, S.NegativeInfinity, 0], + [S.Infinity, S.NegativeInfinity, 0]] + evaluate = self._get_evaluator() + + self._calculating_verts_pos = 0.0 + self._calculating_verts_len = float( + self.u_interval.v_len*self.v_interval.v_len) + + verts = [] + b = self.bounds + for u in self.u_set: + column = [] + for v in self.v_set: + try: + _e = evaluate(u, v) # calculate vertex + except ZeroDivisionError: + _e = None + if _e is not None: # update bounding box + for axis in range(3): + b[axis][0] = min([b[axis][0], _e[axis]]) + b[axis][1] = max([b[axis][1], _e[axis]]) + column.append(_e) + self._calculating_verts_pos += 1.0 + + verts.append(column) + for axis in range(3): + b[axis][2] = b[axis][1] - b[axis][0] + if b[axis][2] == 0.0: + b[axis][2] = 1.0 + + self.verts = verts + self.push_wireframe(self.draw_verts(False, False)) + self.push_solid(self.draw_verts(False, True)) + + def _on_calculate_cverts(self): + if not self.verts or not self.color: + return + + def set_work_len(n): + self._calculating_cverts_len = float(n) + + def inc_work_pos(): + self._calculating_cverts_pos += 1.0 + set_work_len(1) + self._calculating_cverts_pos = 0 + self.cverts = self.color.apply_to_surface(self.verts, + self.u_set, + self.v_set, + set_len=set_work_len, + inc_pos=inc_work_pos) + self.push_solid(self.draw_verts(True, True)) + + def calculate_one_cvert(self, u, v): + vert = self.verts[u][v] + return self.color(vert[0], vert[1], vert[2], + self.u_set[u], self.v_set[v]) + + def draw_verts(self, use_cverts, use_solid_color): + def f(): + for u in range(1, len(self.u_set)): + pgl.glBegin(pgl.GL_QUAD_STRIP) + for v in range(len(self.v_set)): + pa = self.verts[u - 1][v] + pb = self.verts[u][v] + if pa is None or pb is None: + pgl.glEnd() + pgl.glBegin(pgl.GL_QUAD_STRIP) + continue + if use_cverts: + ca = self.cverts[u - 1][v] + cb = self.cverts[u][v] + if ca is None: + ca = (0, 0, 0) + if cb is None: + cb = (0, 0, 0) + else: + if use_solid_color: + ca = cb = self.default_solid_color + else: + ca = cb = self.default_wireframe_color + pgl.glColor3f(*ca) + pgl.glVertex3f(*pa) + pgl.glColor3f(*cb) + pgl.glVertex3f(*pb) + pgl.glEnd() + return f diff --git a/lib/python3.10/site-packages/sympy/plotting/pygletplot/plot_window.py b/lib/python3.10/site-packages/sympy/plotting/pygletplot/plot_window.py new file mode 100644 index 0000000000000000000000000000000000000000..d9df4cc453acb05d7c2d871e9e8efeb36905de5d --- /dev/null +++ b/lib/python3.10/site-packages/sympy/plotting/pygletplot/plot_window.py @@ -0,0 +1,144 @@ +from time import perf_counter + + +import pyglet.gl as pgl + +from sympy.plotting.pygletplot.managed_window import ManagedWindow +from sympy.plotting.pygletplot.plot_camera import PlotCamera +from sympy.plotting.pygletplot.plot_controller import PlotController + + +class PlotWindow(ManagedWindow): + + def __init__(self, plot, antialiasing=True, ortho=False, + invert_mouse_zoom=False, linewidth=1.5, caption="SymPy Plot", + **kwargs): + """ + Named Arguments + =============== + + antialiasing = True + True OR False + ortho = False + True OR False + invert_mouse_zoom = False + True OR False + """ + self.plot = plot + + self.camera = None + self._calculating = False + + self.antialiasing = antialiasing + self.ortho = ortho + self.invert_mouse_zoom = invert_mouse_zoom + self.linewidth = linewidth + self.title = caption + self.last_caption_update = 0 + self.caption_update_interval = 0.2 + self.drawing_first_object = True + + super().__init__(**kwargs) + + def setup(self): + self.camera = PlotCamera(self, ortho=self.ortho) + self.controller = PlotController(self, + invert_mouse_zoom=self.invert_mouse_zoom) + self.push_handlers(self.controller) + + pgl.glClearColor(1.0, 1.0, 1.0, 0.0) + pgl.glClearDepth(1.0) + + pgl.glDepthFunc(pgl.GL_LESS) + pgl.glEnable(pgl.GL_DEPTH_TEST) + + pgl.glEnable(pgl.GL_LINE_SMOOTH) + pgl.glShadeModel(pgl.GL_SMOOTH) + pgl.glLineWidth(self.linewidth) + + pgl.glEnable(pgl.GL_BLEND) + pgl.glBlendFunc(pgl.GL_SRC_ALPHA, pgl.GL_ONE_MINUS_SRC_ALPHA) + + if self.antialiasing: + pgl.glHint(pgl.GL_LINE_SMOOTH_HINT, pgl.GL_NICEST) + pgl.glHint(pgl.GL_POLYGON_SMOOTH_HINT, pgl.GL_NICEST) + + self.camera.setup_projection() + + def on_resize(self, w, h): + super().on_resize(w, h) + if self.camera is not None: + self.camera.setup_projection() + + def update(self, dt): + self.controller.update(dt) + + def draw(self): + self.plot._render_lock.acquire() + self.camera.apply_transformation() + + calc_verts_pos, calc_verts_len = 0, 0 + calc_cverts_pos, calc_cverts_len = 0, 0 + + should_update_caption = (perf_counter() - self.last_caption_update > + self.caption_update_interval) + + if len(self.plot._functions.values()) == 0: + self.drawing_first_object = True + + iterfunctions = iter(self.plot._functions.values()) + + for r in iterfunctions: + if self.drawing_first_object: + self.camera.set_rot_preset(r.default_rot_preset) + self.drawing_first_object = False + + pgl.glPushMatrix() + r._draw() + pgl.glPopMatrix() + + # might as well do this while we are + # iterating and have the lock rather + # than locking and iterating twice + # per frame: + + if should_update_caption: + try: + if r.calculating_verts: + calc_verts_pos += r.calculating_verts_pos + calc_verts_len += r.calculating_verts_len + if r.calculating_cverts: + calc_cverts_pos += r.calculating_cverts_pos + calc_cverts_len += r.calculating_cverts_len + except ValueError: + pass + + for r in self.plot._pobjects: + pgl.glPushMatrix() + r._draw() + pgl.glPopMatrix() + + if should_update_caption: + self.update_caption(calc_verts_pos, calc_verts_len, + calc_cverts_pos, calc_cverts_len) + self.last_caption_update = perf_counter() + + if self.plot._screenshot: + self.plot._screenshot._execute_saving() + + self.plot._render_lock.release() + + def update_caption(self, calc_verts_pos, calc_verts_len, + calc_cverts_pos, calc_cverts_len): + caption = self.title + if calc_verts_len or calc_cverts_len: + caption += " (calculating" + if calc_verts_len > 0: + p = (calc_verts_pos / calc_verts_len) * 100 + caption += " vertices %i%%" % (p) + if calc_cverts_len > 0: + p = (calc_cverts_pos / calc_cverts_len) * 100 + caption += " colors %i%%" % (p) + caption += ")" + if self.caption != caption: + self.set_caption(caption) diff --git a/lib/python3.10/site-packages/sympy/plotting/pygletplot/tests/__init__.py b/lib/python3.10/site-packages/sympy/plotting/pygletplot/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/lib/python3.10/site-packages/sympy/plotting/pygletplot/tests/__pycache__/__init__.cpython-310.pyc b/lib/python3.10/site-packages/sympy/plotting/pygletplot/tests/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cfebd3bc8eb863849d75649d76cb87aac01da206 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/plotting/pygletplot/tests/__pycache__/__init__.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/plotting/pygletplot/tests/__pycache__/test_plotting.cpython-310.pyc b/lib/python3.10/site-packages/sympy/plotting/pygletplot/tests/__pycache__/test_plotting.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..df1c39434169f5862f8cf53f4907397b7f955e45 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/plotting/pygletplot/tests/__pycache__/test_plotting.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/plotting/pygletplot/tests/test_plotting.py b/lib/python3.10/site-packages/sympy/plotting/pygletplot/tests/test_plotting.py new file mode 100644 index 0000000000000000000000000000000000000000..ddc4aaf3621a8c9056ce0d81c89ca6a0a681bbdb --- /dev/null +++ b/lib/python3.10/site-packages/sympy/plotting/pygletplot/tests/test_plotting.py @@ -0,0 +1,88 @@ +from sympy.external.importtools import import_module + +disabled = False + +# if pyglet.gl fails to import, e.g. opengl is missing, we disable the tests +pyglet_gl = import_module("pyglet.gl", catch=(OSError,)) +pyglet_window = import_module("pyglet.window", catch=(OSError,)) +if not pyglet_gl or not pyglet_window: + disabled = True + + +from sympy.core.symbol import symbols +from sympy.functions.elementary.exponential import log +from sympy.functions.elementary.trigonometric import (cos, sin) +x, y, z = symbols('x, y, z') + + +def test_plot_2d(): + from sympy.plotting.pygletplot import PygletPlot + p = PygletPlot(x, [x, -5, 5, 4], visible=False) + p.wait_for_calculations() + + +def test_plot_2d_discontinuous(): + from sympy.plotting.pygletplot import PygletPlot + p = PygletPlot(1/x, [x, -1, 1, 2], visible=False) + p.wait_for_calculations() + + +def test_plot_3d(): + from sympy.plotting.pygletplot import PygletPlot + p = PygletPlot(x*y, [x, -5, 5, 5], [y, -5, 5, 5], visible=False) + p.wait_for_calculations() + + +def test_plot_3d_discontinuous(): + from sympy.plotting.pygletplot import PygletPlot + p = PygletPlot(1/x, [x, -3, 3, 6], [y, -1, 1, 1], visible=False) + p.wait_for_calculations() + + +def test_plot_2d_polar(): + from sympy.plotting.pygletplot import PygletPlot + p = PygletPlot(1/x, [x, -1, 1, 4], 'mode=polar', visible=False) + p.wait_for_calculations() + + +def test_plot_3d_cylinder(): + from sympy.plotting.pygletplot import PygletPlot + p = PygletPlot( + 1/y, [x, 0, 6.282, 4], [y, -1, 1, 4], 'mode=polar;style=solid', + visible=False) + p.wait_for_calculations() + + +def test_plot_3d_spherical(): + from sympy.plotting.pygletplot import PygletPlot + p = PygletPlot( + 1, [x, 0, 6.282, 4], [y, 0, 3.141, + 4], 'mode=spherical;style=wireframe', + visible=False) + p.wait_for_calculations() + + +def test_plot_2d_parametric(): + from sympy.plotting.pygletplot import PygletPlot + p = PygletPlot(sin(x), cos(x), [x, 0, 6.282, 4], visible=False) + p.wait_for_calculations() + + +def test_plot_3d_parametric(): + from sympy.plotting.pygletplot import PygletPlot + p = PygletPlot(sin(x), cos(x), x/5.0, [x, 0, 6.282, 4], visible=False) + p.wait_for_calculations() + + +def _test_plot_log(): + from sympy.plotting.pygletplot import PygletPlot + p = PygletPlot(log(x), [x, 0, 6.282, 4], 'mode=polar', visible=False) + p.wait_for_calculations() + + +def test_plot_integral(): + # Make sure it doesn't treat x as an independent variable + from sympy.plotting.pygletplot import PygletPlot + from sympy.integrals.integrals import Integral + p = PygletPlot(Integral(z*x, (x, 1, z), (z, 1, y)), visible=False) + p.wait_for_calculations() diff --git a/lib/python3.10/site-packages/sympy/plotting/pygletplot/util.py b/lib/python3.10/site-packages/sympy/plotting/pygletplot/util.py new file mode 100644 index 0000000000000000000000000000000000000000..43b882ca18274dcdb273cf35680016453db3c698 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/plotting/pygletplot/util.py @@ -0,0 +1,188 @@ +try: + from ctypes import c_float, c_int, c_double +except ImportError: + pass + +import pyglet.gl as pgl +from sympy.core import S + + +def get_model_matrix(array_type=c_float, glGetMethod=pgl.glGetFloatv): + """ + Returns the current modelview matrix. + """ + m = (array_type*16)() + glGetMethod(pgl.GL_MODELVIEW_MATRIX, m) + return m + + +def get_projection_matrix(array_type=c_float, glGetMethod=pgl.glGetFloatv): + """ + Returns the current modelview matrix. + """ + m = (array_type*16)() + glGetMethod(pgl.GL_PROJECTION_MATRIX, m) + return m + + +def get_viewport(): + """ + Returns the current viewport. + """ + m = (c_int*4)() + pgl.glGetIntegerv(pgl.GL_VIEWPORT, m) + return m + + +def get_direction_vectors(): + m = get_model_matrix() + return ((m[0], m[4], m[8]), + (m[1], m[5], m[9]), + (m[2], m[6], m[10])) + + +def get_view_direction_vectors(): + m = get_model_matrix() + return ((m[0], m[1], m[2]), + (m[4], m[5], m[6]), + (m[8], m[9], m[10])) + + +def get_basis_vectors(): + return ((1, 0, 0), (0, 1, 0), (0, 0, 1)) + + +def screen_to_model(x, y, z): + m = get_model_matrix(c_double, pgl.glGetDoublev) + p = get_projection_matrix(c_double, pgl.glGetDoublev) + w = get_viewport() + mx, my, mz = c_double(), c_double(), c_double() + pgl.gluUnProject(x, y, z, m, p, w, mx, my, mz) + return float(mx.value), float(my.value), float(mz.value) + + +def model_to_screen(x, y, z): + m = get_model_matrix(c_double, pgl.glGetDoublev) + p = get_projection_matrix(c_double, pgl.glGetDoublev) + w = get_viewport() + mx, my, mz = c_double(), c_double(), c_double() + pgl.gluProject(x, y, z, m, p, w, mx, my, mz) + return float(mx.value), float(my.value), float(mz.value) + + +def vec_subs(a, b): + return tuple(a[i] - b[i] for i in range(len(a))) + + +def billboard_matrix(): + """ + Removes rotational components of + current matrix so that primitives + are always drawn facing the viewer. + + |1|0|0|x| + |0|1|0|x| + |0|0|1|x| (x means left unchanged) + |x|x|x|x| + """ + m = get_model_matrix() + # XXX: for i in range(11): m[i] = i ? + m[0] = 1 + m[1] = 0 + m[2] = 0 + m[4] = 0 + m[5] = 1 + m[6] = 0 + m[8] = 0 + m[9] = 0 + m[10] = 1 + pgl.glLoadMatrixf(m) + + +def create_bounds(): + return [[S.Infinity, S.NegativeInfinity, 0], + [S.Infinity, S.NegativeInfinity, 0], + [S.Infinity, S.NegativeInfinity, 0]] + + +def update_bounds(b, v): + if v is None: + return + for axis in range(3): + b[axis][0] = min([b[axis][0], v[axis]]) + b[axis][1] = max([b[axis][1], v[axis]]) + + +def interpolate(a_min, a_max, a_ratio): + return a_min + a_ratio * (a_max - a_min) + + +def rinterpolate(a_min, a_max, a_value): + a_range = a_max - a_min + if a_max == a_min: + a_range = 1.0 + return (a_value - a_min) / float(a_range) + + +def interpolate_color(color1, color2, ratio): + return tuple(interpolate(color1[i], color2[i], ratio) for i in range(3)) + + +def scale_value(v, v_min, v_len): + return (v - v_min) / v_len + + +def scale_value_list(flist): + v_min, v_max = min(flist), max(flist) + v_len = v_max - v_min + return [scale_value(f, v_min, v_len) for f in flist] + + +def strided_range(r_min, r_max, stride, max_steps=50): + o_min, o_max = r_min, r_max + if abs(r_min - r_max) < 0.001: + return [] + try: + range(int(r_min - r_max)) + except (TypeError, OverflowError): + return [] + if r_min > r_max: + raise ValueError("r_min cannot be greater than r_max") + r_min_s = (r_min % stride) + r_max_s = stride - (r_max % stride) + if abs(r_max_s - stride) < 0.001: + r_max_s = 0.0 + r_min -= r_min_s + r_max += r_max_s + r_steps = int((r_max - r_min)/stride) + if max_steps and r_steps > max_steps: + return strided_range(o_min, o_max, stride*2) + return [r_min] + [r_min + e*stride for e in range(1, r_steps + 1)] + [r_max] + + +def parse_option_string(s): + if not isinstance(s, str): + return None + options = {} + for token in s.split(';'): + pieces = token.split('=') + if len(pieces) == 1: + option, value = pieces[0], "" + elif len(pieces) == 2: + option, value = pieces + else: + raise ValueError("Plot option string '%s' is malformed." % (s)) + options[option.strip()] = value.strip() + return options + + +def dot_product(v1, v2): + return sum(v1[i]*v2[i] for i in range(3)) + + +def vec_sub(v1, v2): + return tuple(v1[i] - v2[i] for i in range(3)) + + +def vec_mag(v): + return sum(v[i]**2 for i in range(3))**(0.5) diff --git a/lib/python3.10/site-packages/sympy/plotting/tests/__init__.py b/lib/python3.10/site-packages/sympy/plotting/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/lib/python3.10/site-packages/sympy/plotting/tests/__pycache__/__init__.cpython-310.pyc b/lib/python3.10/site-packages/sympy/plotting/tests/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0945116f86f47a2de8841fb148019f4d60fff525 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/plotting/tests/__pycache__/__init__.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/plotting/tests/__pycache__/test_experimental_lambdify.cpython-310.pyc b/lib/python3.10/site-packages/sympy/plotting/tests/__pycache__/test_experimental_lambdify.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ce95e3420e5414af9502491bfe80eeb0187fceb1 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/plotting/tests/__pycache__/test_experimental_lambdify.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/plotting/tests/__pycache__/test_plot.cpython-310.pyc b/lib/python3.10/site-packages/sympy/plotting/tests/__pycache__/test_plot.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5118f257905639ae855270d4d1778811fe0fb473 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/plotting/tests/__pycache__/test_plot.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/plotting/tests/__pycache__/test_plot_implicit.cpython-310.pyc b/lib/python3.10/site-packages/sympy/plotting/tests/__pycache__/test_plot_implicit.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..67ff2ba879247291c876b551d627db8b91afb27b Binary files /dev/null and b/lib/python3.10/site-packages/sympy/plotting/tests/__pycache__/test_plot_implicit.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/plotting/tests/__pycache__/test_series.cpython-310.pyc b/lib/python3.10/site-packages/sympy/plotting/tests/__pycache__/test_series.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2700a241bae10c57fca0bb26d890f469a0e3e855 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/plotting/tests/__pycache__/test_series.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/plotting/tests/__pycache__/test_textplot.cpython-310.pyc b/lib/python3.10/site-packages/sympy/plotting/tests/__pycache__/test_textplot.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a40cc9b0b14d40a621614b50739be93b64236e8f Binary files /dev/null and b/lib/python3.10/site-packages/sympy/plotting/tests/__pycache__/test_textplot.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/plotting/tests/__pycache__/test_utils.cpython-310.pyc b/lib/python3.10/site-packages/sympy/plotting/tests/__pycache__/test_utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f830570ca4c70069fbaded90a1bff520e56a3386 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/plotting/tests/__pycache__/test_utils.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/plotting/tests/test_experimental_lambdify.py b/lib/python3.10/site-packages/sympy/plotting/tests/test_experimental_lambdify.py new file mode 100644 index 0000000000000000000000000000000000000000..95839d668762be7be94d0de5092594306ceeadbd --- /dev/null +++ b/lib/python3.10/site-packages/sympy/plotting/tests/test_experimental_lambdify.py @@ -0,0 +1,77 @@ +from sympy.core.symbol import symbols, Symbol +from sympy.functions import Max +from sympy.plotting.experimental_lambdify import experimental_lambdify +from sympy.plotting.intervalmath.interval_arithmetic import \ + interval, intervalMembership + + +# Tests for exception handling in experimental_lambdify +def test_experimental_lambify(): + x = Symbol('x') + f = experimental_lambdify([x], Max(x, 5)) + # XXX should f be tested? If f(2) is attempted, an + # error is raised because a complex produced during wrapping of the arg + # is being compared with an int. + assert Max(2, 5) == 5 + assert Max(5, 7) == 7 + + x = Symbol('x-3') + f = experimental_lambdify([x], x + 1) + assert f(1) == 2 + + +def test_composite_boolean_region(): + x, y = symbols('x y') + + r1 = (x - 1)**2 + y**2 < 2 + r2 = (x + 1)**2 + y**2 < 2 + + f = experimental_lambdify((x, y), r1 & r2) + a = (interval(-0.1, 0.1), interval(-0.1, 0.1)) + assert f(*a) == intervalMembership(True, True) + a = (interval(-1.1, -0.9), interval(-0.1, 0.1)) + assert f(*a) == intervalMembership(False, True) + a = (interval(0.9, 1.1), interval(-0.1, 0.1)) + assert f(*a) == intervalMembership(False, True) + a = (interval(-0.1, 0.1), interval(1.9, 2.1)) + assert f(*a) == intervalMembership(False, True) + + f = experimental_lambdify((x, y), r1 | r2) + a = (interval(-0.1, 0.1), interval(-0.1, 0.1)) + assert f(*a) == intervalMembership(True, True) + a = (interval(-1.1, -0.9), interval(-0.1, 0.1)) + assert f(*a) == intervalMembership(True, True) + a = (interval(0.9, 1.1), interval(-0.1, 0.1)) + assert f(*a) == intervalMembership(True, True) + a = (interval(-0.1, 0.1), interval(1.9, 2.1)) + assert f(*a) == intervalMembership(False, True) + + f = experimental_lambdify((x, y), r1 & ~r2) + a = (interval(-0.1, 0.1), interval(-0.1, 0.1)) + assert f(*a) == intervalMembership(False, True) + a = (interval(-1.1, -0.9), interval(-0.1, 0.1)) + assert f(*a) == intervalMembership(False, True) + a = (interval(0.9, 1.1), interval(-0.1, 0.1)) + assert f(*a) == intervalMembership(True, True) + a = (interval(-0.1, 0.1), interval(1.9, 2.1)) + assert f(*a) == intervalMembership(False, True) + + f = experimental_lambdify((x, y), ~r1 & r2) + a = (interval(-0.1, 0.1), interval(-0.1, 0.1)) + assert f(*a) == intervalMembership(False, True) + a = (interval(-1.1, -0.9), interval(-0.1, 0.1)) + assert f(*a) == intervalMembership(True, True) + a = (interval(0.9, 1.1), interval(-0.1, 0.1)) + assert f(*a) == intervalMembership(False, True) + a = (interval(-0.1, 0.1), interval(1.9, 2.1)) + assert f(*a) == intervalMembership(False, True) + + f = experimental_lambdify((x, y), ~r1 & ~r2) + a = (interval(-0.1, 0.1), interval(-0.1, 0.1)) + assert f(*a) == intervalMembership(False, True) + a = (interval(-1.1, -0.9), interval(-0.1, 0.1)) + assert f(*a) == intervalMembership(False, True) + a = (interval(0.9, 1.1), interval(-0.1, 0.1)) + assert f(*a) == intervalMembership(False, True) + a = (interval(-0.1, 0.1), interval(1.9, 2.1)) + assert f(*a) == intervalMembership(True, True) diff --git a/lib/python3.10/site-packages/sympy/plotting/tests/test_plot.py b/lib/python3.10/site-packages/sympy/plotting/tests/test_plot.py new file mode 100644 index 0000000000000000000000000000000000000000..bf09e825e7444cfdaf42e8c419dc50170168365b --- /dev/null +++ b/lib/python3.10/site-packages/sympy/plotting/tests/test_plot.py @@ -0,0 +1,1343 @@ +import os +from tempfile import TemporaryDirectory +import pytest +from sympy.concrete.summations import Sum +from sympy.core.numbers import (I, oo, pi) +from sympy.core.relational import Ne +from sympy.core.symbol import Symbol, symbols +from sympy.functions.elementary.exponential import (LambertW, exp, exp_polar, log) +from sympy.functions.elementary.miscellaneous import (real_root, sqrt) +from sympy.functions.elementary.piecewise import Piecewise +from sympy.functions.elementary.trigonometric import (cos, sin) +from sympy.functions.elementary.miscellaneous import Min +from sympy.functions.special.hyper import meijerg +from sympy.integrals.integrals import Integral +from sympy.logic.boolalg import And +from sympy.core.singleton import S +from sympy.core.sympify import sympify +from sympy.external import import_module +from sympy.plotting.plot import ( + Plot, plot, plot_parametric, plot3d_parametric_line, plot3d, + plot3d_parametric_surface) +from sympy.plotting.plot import ( + unset_show, plot_contour, PlotGrid, MatplotlibBackend, TextBackend) +from sympy.plotting.series import ( + LineOver1DRangeSeries, Parametric2DLineSeries, Parametric3DLineSeries, + ParametricSurfaceSeries, SurfaceOver2DRangeSeries) +from sympy.testing.pytest import skip, warns, raises, warns_deprecated_sympy +from sympy.utilities import lambdify as lambdify_ +from sympy.utilities.exceptions import ignore_warnings + +unset_show() + + +matplotlib = import_module( + 'matplotlib', min_module_version='1.1.0', catch=(RuntimeError,)) + + +class DummyBackendNotOk(Plot): + """ Used to verify if users can create their own backends. + This backend is meant to raise NotImplementedError for methods `show`, + `save`, `close`. + """ + def __new__(cls, *args, **kwargs): + return object.__new__(cls) + + +class DummyBackendOk(Plot): + """ Used to verify if users can create their own backends. + This backend is meant to pass all tests. + """ + def __new__(cls, *args, **kwargs): + return object.__new__(cls) + + def show(self): + pass + + def save(self): + pass + + def close(self): + pass + +def test_basic_plotting_backend(): + x = Symbol('x') + plot(x, (x, 0, 3), backend='text') + plot(x**2 + 1, (x, 0, 3), backend='text') + +@pytest.mark.parametrize("adaptive", [True, False]) +def test_plot_and_save_1(adaptive): + if not matplotlib: + skip("Matplotlib not the default backend") + + x = Symbol('x') + y = Symbol('y') + + with TemporaryDirectory(prefix='sympy_') as tmpdir: + ### + # Examples from the 'introduction' notebook + ### + p = plot(x, legend=True, label='f1', adaptive=adaptive, n=10) + p = plot(x*sin(x), x*cos(x), label='f2', adaptive=adaptive, n=10) + p.extend(p) + p[0].line_color = lambda a: a + p[1].line_color = 'b' + p.title = 'Big title' + p.xlabel = 'the x axis' + p[1].label = 'straight line' + p.legend = True + p.aspect_ratio = (1, 1) + p.xlim = (-15, 20) + filename = 'test_basic_options_and_colors.png' + p.save(os.path.join(tmpdir, filename)) + p._backend.close() + + p.extend(plot(x + 1, adaptive=adaptive, n=10)) + p.append(plot(x + 3, x**2, adaptive=adaptive, n=10)[1]) + filename = 'test_plot_extend_append.png' + p.save(os.path.join(tmpdir, filename)) + + p[2] = plot(x**2, (x, -2, 3), adaptive=adaptive, n=10) + filename = 'test_plot_setitem.png' + p.save(os.path.join(tmpdir, filename)) + p._backend.close() + + p = plot(sin(x), (x, -2*pi, 4*pi), adaptive=adaptive, n=10) + filename = 'test_line_explicit.png' + p.save(os.path.join(tmpdir, filename)) + p._backend.close() + + p = plot(sin(x), adaptive=adaptive, n=10) + filename = 'test_line_default_range.png' + p.save(os.path.join(tmpdir, filename)) + p._backend.close() + + p = plot((x**2, (x, -5, 5)), (x**3, (x, -3, 3)), adaptive=adaptive, n=10) + filename = 'test_line_multiple_range.png' + p.save(os.path.join(tmpdir, filename)) + p._backend.close() + + raises(ValueError, lambda: plot(x, y)) + + #Piecewise plots + p = plot(Piecewise((1, x > 0), (0, True)), (x, -1, 1), adaptive=adaptive, n=10) + filename = 'test_plot_piecewise.png' + p.save(os.path.join(tmpdir, filename)) + p._backend.close() + + p = plot(Piecewise((x, x < 1), (x**2, True)), (x, -3, 3), adaptive=adaptive, n=10) + filename = 'test_plot_piecewise_2.png' + p.save(os.path.join(tmpdir, filename)) + p._backend.close() + + # test issue 7471 + p1 = plot(x, adaptive=adaptive, n=10) + p2 = plot(3, adaptive=adaptive, n=10) + p1.extend(p2) + filename = 'test_horizontal_line.png' + p.save(os.path.join(tmpdir, filename)) + p._backend.close() + + # test issue 10925 + f = Piecewise((-1, x < -1), (x, And(-1 <= x, x < 0)), \ + (x**2, And(0 <= x, x < 1)), (x**3, x >= 1)) + p = plot(f, (x, -3, 3), adaptive=adaptive, n=10) + filename = 'test_plot_piecewise_3.png' + p.save(os.path.join(tmpdir, filename)) + p._backend.close() + + +@pytest.mark.parametrize("adaptive", [True, False]) +def test_plot_and_save_2(adaptive): + if not matplotlib: + skip("Matplotlib not the default backend") + + x = Symbol('x') + y = Symbol('y') + z = Symbol('z') + + with TemporaryDirectory(prefix='sympy_') as tmpdir: + #parametric 2d plots. + #Single plot with default range. + p = plot_parametric(sin(x), cos(x), adaptive=adaptive, n=10) + filename = 'test_parametric.png' + p.save(os.path.join(tmpdir, filename)) + p._backend.close() + + #Single plot with range. + p = plot_parametric( + sin(x), cos(x), (x, -5, 5), legend=True, label='parametric_plot', + adaptive=adaptive, n=10) + filename = 'test_parametric_range.png' + p.save(os.path.join(tmpdir, filename)) + p._backend.close() + + #Multiple plots with same range. + p = plot_parametric((sin(x), cos(x)), (x, sin(x)), + adaptive=adaptive, n=10) + filename = 'test_parametric_multiple.png' + p.save(os.path.join(tmpdir, filename)) + p._backend.close() + + #Multiple plots with different ranges. + p = plot_parametric( + (sin(x), cos(x), (x, -3, 3)), (x, sin(x), (x, -5, 5)), + adaptive=adaptive, n=10) + filename = 'test_parametric_multiple_ranges.png' + p.save(os.path.join(tmpdir, filename)) + p._backend.close() + + #depth of recursion specified. + p = plot_parametric(x, sin(x), depth=13, + adaptive=adaptive, n=10) + filename = 'test_recursion_depth.png' + p.save(os.path.join(tmpdir, filename)) + p._backend.close() + + #No adaptive sampling. + p = plot_parametric(cos(x), sin(x), adaptive=False, n=500) + filename = 'test_adaptive.png' + p.save(os.path.join(tmpdir, filename)) + p._backend.close() + + #3d parametric plots + p = plot3d_parametric_line( + sin(x), cos(x), x, legend=True, label='3d_parametric_plot', + adaptive=adaptive, n=10) + filename = 'test_3d_line.png' + p.save(os.path.join(tmpdir, filename)) + p._backend.close() + + p = plot3d_parametric_line( + (sin(x), cos(x), x, (x, -5, 5)), (cos(x), sin(x), x, (x, -3, 3)), + adaptive=adaptive, n=10) + filename = 'test_3d_line_multiple.png' + p.save(os.path.join(tmpdir, filename)) + p._backend.close() + + p = plot3d_parametric_line(sin(x), cos(x), x, n=30, + adaptive=adaptive) + filename = 'test_3d_line_points.png' + p.save(os.path.join(tmpdir, filename)) + p._backend.close() + + # 3d surface single plot. + p = plot3d(x * y, adaptive=adaptive, n=10) + filename = 'test_surface.png' + p.save(os.path.join(tmpdir, filename)) + p._backend.close() + + # Multiple 3D plots with same range. + p = plot3d(-x * y, x * y, (x, -5, 5), adaptive=adaptive, n=10) + filename = 'test_surface_multiple.png' + p.save(os.path.join(tmpdir, filename)) + p._backend.close() + + # Multiple 3D plots with different ranges. + p = plot3d( + (x * y, (x, -3, 3), (y, -3, 3)), (-x * y, (x, -3, 3), (y, -3, 3)), + adaptive=adaptive, n=10) + filename = 'test_surface_multiple_ranges.png' + p.save(os.path.join(tmpdir, filename)) + p._backend.close() + + # Single Parametric 3D plot + p = plot3d_parametric_surface(sin(x + y), cos(x - y), x - y, + adaptive=adaptive, n=10) + filename = 'test_parametric_surface.png' + p.save(os.path.join(tmpdir, filename)) + p._backend.close() + + # Multiple Parametric 3D plots. + p = plot3d_parametric_surface( + (x*sin(z), x*cos(z), z, (x, -5, 5), (z, -5, 5)), + (sin(x + y), cos(x - y), x - y, (x, -5, 5), (y, -5, 5)), + adaptive=adaptive, n=10) + filename = 'test_parametric_surface.png' + p.save(os.path.join(tmpdir, filename)) + p._backend.close() + + # Single Contour plot. + p = plot_contour(sin(x)*sin(y), (x, -5, 5), (y, -5, 5), + adaptive=adaptive, n=10) + filename = 'test_contour_plot.png' + p.save(os.path.join(tmpdir, filename)) + p._backend.close() + + # Multiple Contour plots with same range. + p = plot_contour(x**2 + y**2, x**3 + y**3, (x, -5, 5), (y, -5, 5), + adaptive=adaptive, n=10) + filename = 'test_contour_plot.png' + p.save(os.path.join(tmpdir, filename)) + p._backend.close() + + # Multiple Contour plots with different range. + p = plot_contour( + (x**2 + y**2, (x, -5, 5), (y, -5, 5)), + (x**3 + y**3, (x, -3, 3), (y, -3, 3)), + adaptive=adaptive, n=10) + filename = 'test_contour_plot.png' + p.save(os.path.join(tmpdir, filename)) + p._backend.close() + + +@pytest.mark.parametrize("adaptive", [True, False]) +def test_plot_and_save_3(adaptive): + if not matplotlib: + skip("Matplotlib not the default backend") + + x = Symbol('x') + y = Symbol('y') + z = Symbol('z') + + with TemporaryDirectory(prefix='sympy_') as tmpdir: + ### + # Examples from the 'colors' notebook + ### + + p = plot(sin(x), adaptive=adaptive, n=10) + p[0].line_color = lambda a: a + filename = 'test_colors_line_arity1.png' + p.save(os.path.join(tmpdir, filename)) + + p[0].line_color = lambda a, b: b + filename = 'test_colors_line_arity2.png' + p.save(os.path.join(tmpdir, filename)) + p._backend.close() + + p = plot(x*sin(x), x*cos(x), (x, 0, 10), adaptive=adaptive, n=10) + p[0].line_color = lambda a: a + filename = 'test_colors_param_line_arity1.png' + p.save(os.path.join(tmpdir, filename)) + + p[0].line_color = lambda a, b: a + filename = 'test_colors_param_line_arity1.png' + p.save(os.path.join(tmpdir, filename)) + + p[0].line_color = lambda a, b: b + filename = 'test_colors_param_line_arity2b.png' + p.save(os.path.join(tmpdir, filename)) + p._backend.close() + + p = plot3d_parametric_line( + sin(x) + 0.1*sin(x)*cos(7*x), + cos(x) + 0.1*cos(x)*cos(7*x), + 0.1*sin(7*x), + (x, 0, 2*pi), adaptive=adaptive, n=10) + p[0].line_color = lambdify_(x, sin(4*x)) + filename = 'test_colors_3d_line_arity1.png' + p.save(os.path.join(tmpdir, filename)) + p[0].line_color = lambda a, b: b + filename = 'test_colors_3d_line_arity2.png' + p.save(os.path.join(tmpdir, filename)) + p[0].line_color = lambda a, b, c: c + filename = 'test_colors_3d_line_arity3.png' + p.save(os.path.join(tmpdir, filename)) + p._backend.close() + + p = plot3d(sin(x)*y, (x, 0, 6*pi), (y, -5, 5), adaptive=adaptive, n=10) + p[0].surface_color = lambda a: a + filename = 'test_colors_surface_arity1.png' + p.save(os.path.join(tmpdir, filename)) + p[0].surface_color = lambda a, b: b + filename = 'test_colors_surface_arity2.png' + p.save(os.path.join(tmpdir, filename)) + p[0].surface_color = lambda a, b, c: c + filename = 'test_colors_surface_arity3a.png' + p.save(os.path.join(tmpdir, filename)) + p[0].surface_color = lambdify_((x, y, z), sqrt((x - 3*pi)**2 + y**2)) + filename = 'test_colors_surface_arity3b.png' + p.save(os.path.join(tmpdir, filename)) + p._backend.close() + + p = plot3d_parametric_surface(x * cos(4 * y), x * sin(4 * y), y, + (x, -1, 1), (y, -1, 1), adaptive=adaptive, n=10) + p[0].surface_color = lambda a: a + filename = 'test_colors_param_surf_arity1.png' + p.save(os.path.join(tmpdir, filename)) + p[0].surface_color = lambda a, b: a*b + filename = 'test_colors_param_surf_arity2.png' + p.save(os.path.join(tmpdir, filename)) + p[0].surface_color = lambdify_((x, y, z), sqrt(x**2 + y**2 + z**2)) + filename = 'test_colors_param_surf_arity3.png' + p.save(os.path.join(tmpdir, filename)) + p._backend.close() + + +@pytest.mark.parametrize("adaptive", [True]) +def test_plot_and_save_4(adaptive): + if not matplotlib: + skip("Matplotlib not the default backend") + + x = Symbol('x') + y = Symbol('y') + + ### + # Examples from the 'advanced' notebook + ### + + with TemporaryDirectory(prefix='sympy_') as tmpdir: + i = Integral(log((sin(x)**2 + 1)*sqrt(x**2 + 1)), (x, 0, y)) + p = plot(i, (y, 1, 5), adaptive=adaptive, n=10, force_real_eval=True) + filename = 'test_advanced_integral.png' + p.save(os.path.join(tmpdir, filename)) + p._backend.close() + + +@pytest.mark.parametrize("adaptive", [True, False]) +def test_plot_and_save_5(adaptive): + if not matplotlib: + skip("Matplotlib not the default backend") + + x = Symbol('x') + y = Symbol('y') + + with TemporaryDirectory(prefix='sympy_') as tmpdir: + s = Sum(1/x**y, (x, 1, oo)) + p = plot(s, (y, 2, 10), adaptive=adaptive, n=10) + filename = 'test_advanced_inf_sum.png' + p.save(os.path.join(tmpdir, filename)) + p._backend.close() + + p = plot(Sum(1/x, (x, 1, y)), (y, 2, 10), show=False, + adaptive=adaptive, n=10) + p[0].only_integers = True + p[0].steps = True + filename = 'test_advanced_fin_sum.png' + + # XXX: This should be fixed in experimental_lambdify or by using + # ordinary lambdify so that it doesn't warn. The error results from + # passing an array of values as the integration limit. + # + # UserWarning: The evaluation of the expression is problematic. We are + # trying a failback method that may still work. Please report this as a + # bug. + with ignore_warnings(UserWarning): + p.save(os.path.join(tmpdir, filename)) + + p._backend.close() + + +@pytest.mark.parametrize("adaptive", [True, False]) +def test_plot_and_save_6(adaptive): + if not matplotlib: + skip("Matplotlib not the default backend") + + x = Symbol('x') + + with TemporaryDirectory(prefix='sympy_') as tmpdir: + filename = 'test.png' + ### + # Test expressions that can not be translated to np and generate complex + # results. + ### + p = plot(sin(x) + I*cos(x)) + p.save(os.path.join(tmpdir, filename)) + + with ignore_warnings(RuntimeWarning): + p = plot(sqrt(sqrt(-x))) + p.save(os.path.join(tmpdir, filename)) + + p = plot(LambertW(x)) + p.save(os.path.join(tmpdir, filename)) + p = plot(sqrt(LambertW(x))) + p.save(os.path.join(tmpdir, filename)) + + #Characteristic function of a StudentT distribution with nu=10 + x1 = 5 * x**2 * exp_polar(-I*pi)/2 + m1 = meijerg(((1 / 2,), ()), ((5, 0, 1 / 2), ()), x1) + x2 = 5*x**2 * exp_polar(I*pi)/2 + m2 = meijerg(((1/2,), ()), ((5, 0, 1/2), ()), x2) + expr = (m1 + m2) / (48 * pi) + with warns( + UserWarning, + match="The evaluation with NumPy/SciPy failed", + test_stacklevel=False, + ): + p = plot(expr, (x, 1e-6, 1e-2), adaptive=adaptive, n=10) + p.save(os.path.join(tmpdir, filename)) + + +@pytest.mark.parametrize("adaptive", [True, False]) +def test_plotgrid_and_save(adaptive): + if not matplotlib: + skip("Matplotlib not the default backend") + + x = Symbol('x') + y = Symbol('y') + + with TemporaryDirectory(prefix='sympy_') as tmpdir: + p1 = plot(x, adaptive=adaptive, n=10) + p2 = plot_parametric((sin(x), cos(x)), (x, sin(x)), show=False, + adaptive=adaptive, n=10) + p3 = plot_parametric( + cos(x), sin(x), adaptive=adaptive, n=10, show=False) + p4 = plot3d_parametric_line(sin(x), cos(x), x, show=False, + adaptive=adaptive, n=10) + # symmetric grid + p = PlotGrid(2, 2, p1, p2, p3, p4) + filename = 'test_grid1.png' + p.save(os.path.join(tmpdir, filename)) + p._backend.close() + + # grid size greater than the number of subplots + p = PlotGrid(3, 4, p1, p2, p3, p4) + filename = 'test_grid2.png' + p.save(os.path.join(tmpdir, filename)) + p._backend.close() + + p5 = plot(cos(x),(x, -pi, pi), show=False, adaptive=adaptive, n=10) + p5[0].line_color = lambda a: a + p6 = plot(Piecewise((1, x > 0), (0, True)), (x, -1, 1), show=False, + adaptive=adaptive, n=10) + p7 = plot_contour( + (x**2 + y**2, (x, -5, 5), (y, -5, 5)), + (x**3 + y**3, (x, -3, 3), (y, -3, 3)), show=False, + adaptive=adaptive, n=10) + # unsymmetric grid (subplots in one line) + p = PlotGrid(1, 3, p5, p6, p7) + filename = 'test_grid3.png' + p.save(os.path.join(tmpdir, filename)) + p._backend.close() + + +@pytest.mark.parametrize("adaptive", [True, False]) +def test_append_issue_7140(adaptive): + if not matplotlib: + skip("Matplotlib not the default backend") + + x = Symbol('x') + p1 = plot(x, adaptive=adaptive, n=10) + p2 = plot(x**2, adaptive=adaptive, n=10) + plot(x + 2, adaptive=adaptive, n=10) + + # append a series + p2.append(p1[0]) + assert len(p2._series) == 2 + + with raises(TypeError): + p1.append(p2) + + with raises(TypeError): + p1.append(p2._series) + + +@pytest.mark.parametrize("adaptive", [True, False]) +def test_issue_15265(adaptive): + if not matplotlib: + skip("Matplotlib not the default backend") + + x = Symbol('x') + eqn = sin(x) + + p = plot(eqn, xlim=(-S.Pi, S.Pi), ylim=(-1, 1), adaptive=adaptive, n=10) + p._backend.close() + + p = plot(eqn, xlim=(-1, 1), ylim=(-S.Pi, S.Pi), adaptive=adaptive, n=10) + p._backend.close() + + p = plot(eqn, xlim=(-1, 1), adaptive=adaptive, n=10, + ylim=(sympify('-3.14'), sympify('3.14'))) + p._backend.close() + + p = plot(eqn, adaptive=adaptive, n=10, + xlim=(sympify('-3.14'), sympify('3.14')), ylim=(-1, 1)) + p._backend.close() + + raises(ValueError, + lambda: plot(eqn, adaptive=adaptive, n=10, + xlim=(-S.ImaginaryUnit, 1), ylim=(-1, 1))) + + raises(ValueError, + lambda: plot(eqn, adaptive=adaptive, n=10, + xlim=(-1, 1), ylim=(-1, S.ImaginaryUnit))) + + raises(ValueError, + lambda: plot(eqn, adaptive=adaptive, n=10, + xlim=(S.NegativeInfinity, 1), ylim=(-1, 1))) + + raises(ValueError, + lambda: plot(eqn, adaptive=adaptive, n=10, + xlim=(-1, 1), ylim=(-1, S.Infinity))) + + +def test_empty_Plot(): + if not matplotlib: + skip("Matplotlib not the default backend") + + # No exception showing an empty plot + plot() + # Plot is only a base class: doesn't implement any logic for showing + # images + p = Plot() + raises(NotImplementedError, lambda: p.show()) + + +@pytest.mark.parametrize("adaptive", [True, False]) +def test_issue_17405(adaptive): + if not matplotlib: + skip("Matplotlib not the default backend") + + x = Symbol('x') + f = x**0.3 - 10*x**3 + x**2 + p = plot(f, (x, -10, 10), adaptive=adaptive, n=30, show=False) + # Random number of segments, probably more than 100, but we want to see + # that there are segments generated, as opposed to when the bug was present + + # RuntimeWarning: invalid value encountered in double_scalars + with ignore_warnings(RuntimeWarning): + assert len(p[0].get_data()[0]) >= 30 + + +@pytest.mark.parametrize("adaptive", [True, False]) +def test_logplot_PR_16796(adaptive): + if not matplotlib: + skip("Matplotlib not the default backend") + + x = Symbol('x') + p = plot(x, (x, .001, 100), adaptive=adaptive, n=30, + xscale='log', show=False) + # Random number of segments, probably more than 100, but we want to see + # that there are segments generated, as opposed to when the bug was present + assert len(p[0].get_data()[0]) >= 30 + assert p[0].end == 100.0 + assert p[0].start == .001 + + +@pytest.mark.parametrize("adaptive", [True, False]) +def test_issue_16572(adaptive): + if not matplotlib: + skip("Matplotlib not the default backend") + + x = Symbol('x') + p = plot(LambertW(x), show=False, adaptive=adaptive, n=30) + # Random number of segments, probably more than 50, but we want to see + # that there are segments generated, as opposed to when the bug was present + assert len(p[0].get_data()[0]) >= 30 + + +@pytest.mark.parametrize("adaptive", [True, False]) +def test_issue_11865(adaptive): + if not matplotlib: + skip("Matplotlib not the default backend") + + k = Symbol('k', integer=True) + f = Piecewise((-I*exp(I*pi*k)/k + I*exp(-I*pi*k)/k, Ne(k, 0)), (2*pi, True)) + p = plot(f, show=False, adaptive=adaptive, n=30) + # Random number of segments, probably more than 100, but we want to see + # that there are segments generated, as opposed to when the bug was present + # and that there are no exceptions. + assert len(p[0].get_data()[0]) >= 30 + + +def test_issue_11461(): + if not matplotlib: + skip("Matplotlib not the default backend") + + x = Symbol('x') + p = plot(real_root((log(x/(x-2))), 3), show=False, adaptive=True) + with warns( + RuntimeWarning, + match="invalid value encountered in", + test_stacklevel=False, + ): + # Random number of segments, probably more than 100, but we want to see + # that there are segments generated, as opposed to when the bug was present + # and that there are no exceptions. + assert len(p[0].get_data()[0]) >= 30 + + +@pytest.mark.parametrize("adaptive", [True, False]) +def test_issue_11764(adaptive): + if not matplotlib: + skip("Matplotlib not the default backend") + + x = Symbol('x') + p = plot_parametric(cos(x), sin(x), (x, 0, 2 * pi), + aspect_ratio=(1,1), show=False, adaptive=adaptive, n=30) + assert p.aspect_ratio == (1, 1) + # Random number of segments, probably more than 100, but we want to see + # that there are segments generated, as opposed to when the bug was present + assert len(p[0].get_data()[0]) >= 30 + + +@pytest.mark.parametrize("adaptive", [True, False]) +def test_issue_13516(adaptive): + if not matplotlib: + skip("Matplotlib not the default backend") + + x = Symbol('x') + + pm = plot(sin(x), backend="matplotlib", show=False, adaptive=adaptive, n=30) + assert pm.backend == MatplotlibBackend + assert len(pm[0].get_data()[0]) >= 30 + + pt = plot(sin(x), backend="text", show=False, adaptive=adaptive, n=30) + assert pt.backend == TextBackend + assert len(pt[0].get_data()[0]) >= 30 + + pd = plot(sin(x), backend="default", show=False, adaptive=adaptive, n=30) + assert pd.backend == MatplotlibBackend + assert len(pd[0].get_data()[0]) >= 30 + + p = plot(sin(x), show=False, adaptive=adaptive, n=30) + assert p.backend == MatplotlibBackend + assert len(p[0].get_data()[0]) >= 30 + + +@pytest.mark.parametrize("adaptive", [True, False]) +def test_plot_limits(adaptive): + if not matplotlib: + skip("Matplotlib not the default backend") + + x = Symbol('x') + p = plot(x, x**2, (x, -10, 10), adaptive=adaptive, n=10) + backend = p._backend + + xmin, xmax = backend.ax.get_xlim() + assert abs(xmin + 10) < 2 + assert abs(xmax - 10) < 2 + ymin, ymax = backend.ax.get_ylim() + assert abs(ymin + 10) < 10 + assert abs(ymax - 100) < 10 + + +@pytest.mark.parametrize("adaptive", [True, False]) +def test_plot3d_parametric_line_limits(adaptive): + if not matplotlib: + skip("Matplotlib not the default backend") + + x = Symbol('x') + + v1 = (2*cos(x), 2*sin(x), 2*x, (x, -5, 5)) + v2 = (sin(x), cos(x), x, (x, -5, 5)) + p = plot3d_parametric_line(v1, v2, adaptive=adaptive, n=60) + backend = p._backend + + xmin, xmax = backend.ax.get_xlim() + assert abs(xmin + 2) < 1e-2 + assert abs(xmax - 2) < 1e-2 + ymin, ymax = backend.ax.get_ylim() + assert abs(ymin + 2) < 1e-2 + assert abs(ymax - 2) < 1e-2 + zmin, zmax = backend.ax.get_zlim() + assert abs(zmin + 10) < 1e-2 + assert abs(zmax - 10) < 1e-2 + + p = plot3d_parametric_line(v2, v1, adaptive=adaptive, n=60) + backend = p._backend + + xmin, xmax = backend.ax.get_xlim() + assert abs(xmin + 2) < 1e-2 + assert abs(xmax - 2) < 1e-2 + ymin, ymax = backend.ax.get_ylim() + assert abs(ymin + 2) < 1e-2 + assert abs(ymax - 2) < 1e-2 + zmin, zmax = backend.ax.get_zlim() + assert abs(zmin + 10) < 1e-2 + assert abs(zmax - 10) < 1e-2 + + +@pytest.mark.parametrize("adaptive", [True, False]) +def test_plot_size(adaptive): + if not matplotlib: + skip("Matplotlib not the default backend") + + x = Symbol('x') + + p1 = plot(sin(x), backend="matplotlib", size=(8, 4), + adaptive=adaptive, n=10) + s1 = p1._backend.fig.get_size_inches() + assert (s1[0] == 8) and (s1[1] == 4) + p2 = plot(sin(x), backend="matplotlib", size=(5, 10), + adaptive=adaptive, n=10) + s2 = p2._backend.fig.get_size_inches() + assert (s2[0] == 5) and (s2[1] == 10) + p3 = PlotGrid(2, 1, p1, p2, size=(6, 2), + adaptive=adaptive, n=10) + s3 = p3._backend.fig.get_size_inches() + assert (s3[0] == 6) and (s3[1] == 2) + + with raises(ValueError): + plot(sin(x), backend="matplotlib", size=(-1, 3)) + + +def test_issue_20113(): + if not matplotlib: + skip("Matplotlib not the default backend") + + x = Symbol('x') + + # verify the capability to use custom backends + plot(sin(x), backend=Plot, show=False) + p2 = plot(sin(x), backend=MatplotlibBackend, show=False) + assert p2.backend == MatplotlibBackend + assert len(p2[0].get_data()[0]) >= 30 + p3 = plot(sin(x), backend=DummyBackendOk, show=False) + assert p3.backend == DummyBackendOk + assert len(p3[0].get_data()[0]) >= 30 + + # test for an improper coded backend + p4 = plot(sin(x), backend=DummyBackendNotOk, show=False) + assert p4.backend == DummyBackendNotOk + assert len(p4[0].get_data()[0]) >= 30 + with raises(NotImplementedError): + p4.show() + with raises(NotImplementedError): + p4.save("test/path") + with raises(NotImplementedError): + p4._backend.close() + + +def test_custom_coloring(): + x = Symbol('x') + y = Symbol('y') + plot(cos(x), line_color=lambda a: a) + plot(cos(x), line_color=1) + plot(cos(x), line_color="r") + plot_parametric(cos(x), sin(x), line_color=lambda a: a) + plot_parametric(cos(x), sin(x), line_color=1) + plot_parametric(cos(x), sin(x), line_color="r") + plot3d_parametric_line(cos(x), sin(x), x, line_color=lambda a: a) + plot3d_parametric_line(cos(x), sin(x), x, line_color=1) + plot3d_parametric_line(cos(x), sin(x), x, line_color="r") + plot3d_parametric_surface(cos(x + y), sin(x - y), x - y, + (x, -5, 5), (y, -5, 5), + surface_color=lambda a, b: a**2 + b**2) + plot3d_parametric_surface(cos(x + y), sin(x - y), x - y, + (x, -5, 5), (y, -5, 5), + surface_color=1) + plot3d_parametric_surface(cos(x + y), sin(x - y), x - y, + (x, -5, 5), (y, -5, 5), + surface_color="r") + plot3d(x*y, (x, -5, 5), (y, -5, 5), + surface_color=lambda a, b: a**2 + b**2) + plot3d(x*y, (x, -5, 5), (y, -5, 5), surface_color=1) + plot3d(x*y, (x, -5, 5), (y, -5, 5), surface_color="r") + + +@pytest.mark.parametrize("adaptive", [True, False]) +def test_deprecated_get_segments(adaptive): + if not matplotlib: + skip("Matplotlib not the default backend") + + x = Symbol('x') + f = sin(x) + p = plot(f, (x, -10, 10), show=False, adaptive=adaptive, n=10) + with warns_deprecated_sympy(): + p[0].get_segments() + + +@pytest.mark.parametrize("adaptive", [True, False]) +def test_generic_data_series(adaptive): + # verify that no errors are raised when generic data series are used + if not matplotlib: + skip("Matplotlib not the default backend") + + x = Symbol("x") + p = plot(x, + markers=[{"args":[[0, 1], [0, 1]], "marker": "*", "linestyle": "none"}], + annotations=[{"text": "test", "xy": (0, 0)}], + fill={"x": [0, 1, 2, 3], "y1": [0, 1, 2, 3]}, + rectangles=[{"xy": (0, 0), "width": 5, "height": 1}], + adaptive=adaptive, n=10) + assert len(p._backend.ax.collections) == 1 + assert len(p._backend.ax.patches) == 1 + assert len(p._backend.ax.lines) == 2 + assert len(p._backend.ax.texts) == 1 + + +def test_deprecated_markers_annotations_rectangles_fill(): + if not matplotlib: + skip("Matplotlib not the default backend") + + x = Symbol('x') + p = plot(sin(x), (x, -10, 10), show=False) + with warns_deprecated_sympy(): + p.markers = [{"args":[[0, 1], [0, 1]], "marker": "*", "linestyle": "none"}] + assert len(p._series) == 2 + with warns_deprecated_sympy(): + p.annotations = [{"text": "test", "xy": (0, 0)}] + assert len(p._series) == 3 + with warns_deprecated_sympy(): + p.fill = {"x": [0, 1, 2, 3], "y1": [0, 1, 2, 3]} + assert len(p._series) == 4 + with warns_deprecated_sympy(): + p.rectangles = [{"xy": (0, 0), "width": 5, "height": 1}] + assert len(p._series) == 5 + + +def test_back_compatibility(): + if not matplotlib: + skip("Matplotlib not the default backend") + + x = Symbol('x') + y = Symbol('y') + p = plot(sin(x), adaptive=False, n=5) + assert len(p[0].get_points()) == 2 + assert len(p[0].get_data()) == 2 + p = plot_parametric(cos(x), sin(x), (x, 0, 2), adaptive=False, n=5) + assert len(p[0].get_points()) == 2 + assert len(p[0].get_data()) == 3 + p = plot3d_parametric_line(cos(x), sin(x), x, (x, 0, 2), + adaptive=False, n=5) + assert len(p[0].get_points()) == 3 + assert len(p[0].get_data()) == 4 + p = plot3d(cos(x**2 + y**2), (x, -pi, pi), (y, -pi, pi), n=5) + assert len(p[0].get_meshes()) == 3 + assert len(p[0].get_data()) == 3 + p = plot_contour(cos(x**2 + y**2), (x, -pi, pi), (y, -pi, pi), n=5) + assert len(p[0].get_meshes()) == 3 + assert len(p[0].get_data()) == 3 + p = plot3d_parametric_surface(x * cos(y), x * sin(y), x * cos(4 * y) / 2, + (x, 0, pi), (y, 0, 2*pi), n=5) + assert len(p[0].get_meshes()) == 3 + assert len(p[0].get_data()) == 5 + + +def test_plot_arguments(): + ### Test arguments for plot() + if not matplotlib: + skip("Matplotlib not the default backend") + + x, y = symbols("x, y") + + # single expressions + p = plot(x + 1) + assert isinstance(p[0], LineOver1DRangeSeries) + assert p[0].expr == x + 1 + assert p[0].ranges == [(x, -10, 10)] + assert p[0].get_label(False) == "x + 1" + assert p[0].rendering_kw == {} + + # single expressions custom label + p = plot(x + 1, "label") + assert isinstance(p[0], LineOver1DRangeSeries) + assert p[0].expr == x + 1 + assert p[0].ranges == [(x, -10, 10)] + assert p[0].get_label(False) == "label" + assert p[0].rendering_kw == {} + + # single expressions with range + p = plot(x + 1, (x, -2, 2)) + assert p[0].ranges == [(x, -2, 2)] + + # single expressions with range, label and rendering-kw dictionary + p = plot(x + 1, (x, -2, 2), "test", {"color": "r"}) + assert p[0].get_label(False) == "test" + assert p[0].rendering_kw == {"color": "r"} + + # multiple expressions + p = plot(x + 1, x**2) + assert isinstance(p[0], LineOver1DRangeSeries) + assert p[0].expr == x + 1 + assert p[0].ranges == [(x, -10, 10)] + assert p[0].get_label(False) == "x + 1" + assert p[0].rendering_kw == {} + assert isinstance(p[1], LineOver1DRangeSeries) + assert p[1].expr == x**2 + assert p[1].ranges == [(x, -10, 10)] + assert p[1].get_label(False) == "x**2" + assert p[1].rendering_kw == {} + + # multiple expressions over the same range + p = plot(x + 1, x**2, (x, 0, 5)) + assert p[0].ranges == [(x, 0, 5)] + assert p[1].ranges == [(x, 0, 5)] + + # multiple expressions over the same range with the same rendering kws + p = plot(x + 1, x**2, (x, 0, 5), {"color": "r"}) + assert p[0].ranges == [(x, 0, 5)] + assert p[1].ranges == [(x, 0, 5)] + assert p[0].rendering_kw == {"color": "r"} + assert p[1].rendering_kw == {"color": "r"} + + # multiple expressions with different ranges, labels and rendering kws + p = plot( + (x + 1, (x, 0, 5)), + (x**2, (x, -2, 2), "test", {"color": "r"})) + assert isinstance(p[0], LineOver1DRangeSeries) + assert p[0].expr == x + 1 + assert p[0].ranges == [(x, 0, 5)] + assert p[0].get_label(False) == "x + 1" + assert p[0].rendering_kw == {} + assert isinstance(p[1], LineOver1DRangeSeries) + assert p[1].expr == x**2 + assert p[1].ranges == [(x, -2, 2)] + assert p[1].get_label(False) == "test" + assert p[1].rendering_kw == {"color": "r"} + + # single argument: lambda function + f = lambda t: t + p = plot(lambda t: t) + assert isinstance(p[0], LineOver1DRangeSeries) + assert callable(p[0].expr) + assert p[0].ranges[0][1:] == (-10, 10) + assert p[0].get_label(False) == "" + assert p[0].rendering_kw == {} + + # single argument: lambda function + custom range and label + p = plot(f, ("t", -5, 6), "test") + assert p[0].ranges[0][1:] == (-5, 6) + assert p[0].get_label(False) == "test" + + +def test_plot_parametric_arguments(): + ### Test arguments for plot_parametric() + if not matplotlib: + skip("Matplotlib not the default backend") + + x, y = symbols("x, y") + + # single parametric expression + p = plot_parametric(x + 1, x) + assert isinstance(p[0], Parametric2DLineSeries) + assert p[0].expr == (x + 1, x) + assert p[0].ranges == [(x, -10, 10)] + assert p[0].get_label(False) == "x" + assert p[0].rendering_kw == {} + + # single parametric expression with custom range, label and rendering kws + p = plot_parametric(x + 1, x, (x, -2, 2), "test", + {"cmap": "Reds"}) + assert p[0].expr == (x + 1, x) + assert p[0].ranges == [(x, -2, 2)] + assert p[0].get_label(False) == "test" + assert p[0].rendering_kw == {"cmap": "Reds"} + + p = plot_parametric((x + 1, x), (x, -2, 2), "test") + assert p[0].expr == (x + 1, x) + assert p[0].ranges == [(x, -2, 2)] + assert p[0].get_label(False) == "test" + assert p[0].rendering_kw == {} + + # multiple parametric expressions same symbol + p = plot_parametric((x + 1, x), (x ** 2, x + 1)) + assert p[0].expr == (x + 1, x) + assert p[0].ranges == [(x, -10, 10)] + assert p[0].get_label(False) == "x" + assert p[0].rendering_kw == {} + assert p[1].expr == (x ** 2, x + 1) + assert p[1].ranges == [(x, -10, 10)] + assert p[1].get_label(False) == "x" + assert p[1].rendering_kw == {} + + # multiple parametric expressions different symbols + p = plot_parametric((x + 1, x), (y ** 2, y + 1, "test")) + assert p[0].expr == (x + 1, x) + assert p[0].ranges == [(x, -10, 10)] + assert p[0].get_label(False) == "x" + assert p[0].rendering_kw == {} + assert p[1].expr == (y ** 2, y + 1) + assert p[1].ranges == [(y, -10, 10)] + assert p[1].get_label(False) == "test" + assert p[1].rendering_kw == {} + + # multiple parametric expressions same range + p = plot_parametric((x + 1, x), (x ** 2, x + 1), (x, -2, 2)) + assert p[0].expr == (x + 1, x) + assert p[0].ranges == [(x, -2, 2)] + assert p[0].get_label(False) == "x" + assert p[0].rendering_kw == {} + assert p[1].expr == (x ** 2, x + 1) + assert p[1].ranges == [(x, -2, 2)] + assert p[1].get_label(False) == "x" + assert p[1].rendering_kw == {} + + # multiple parametric expressions, custom ranges and labels + p = plot_parametric( + (x + 1, x, (x, -2, 2), "test1"), + (x ** 2, x + 1, (x, -3, 3), "test2", {"cmap": "Reds"})) + assert p[0].expr == (x + 1, x) + assert p[0].ranges == [(x, -2, 2)] + assert p[0].get_label(False) == "test1" + assert p[0].rendering_kw == {} + assert p[1].expr == (x ** 2, x + 1) + assert p[1].ranges == [(x, -3, 3)] + assert p[1].get_label(False) == "test2" + assert p[1].rendering_kw == {"cmap": "Reds"} + + # single argument: lambda function + fx = lambda t: t + fy = lambda t: 2 * t + p = plot_parametric(fx, fy) + assert all(callable(t) for t in p[0].expr) + assert p[0].ranges[0][1:] == (-10, 10) + assert "Dummy" in p[0].get_label(False) + assert p[0].rendering_kw == {} + + # single argument: lambda function + custom range + label + p = plot_parametric(fx, fy, ("t", 0, 2), "test") + assert all(callable(t) for t in p[0].expr) + assert p[0].ranges[0][1:] == (0, 2) + assert p[0].get_label(False) == "test" + assert p[0].rendering_kw == {} + + +def test_plot3d_parametric_line_arguments(): + ### Test arguments for plot3d_parametric_line() + if not matplotlib: + skip("Matplotlib not the default backend") + + x, y = symbols("x, y") + + # single parametric expression + p = plot3d_parametric_line(x + 1, x, sin(x)) + assert isinstance(p[0], Parametric3DLineSeries) + assert p[0].expr == (x + 1, x, sin(x)) + assert p[0].ranges == [(x, -10, 10)] + assert p[0].get_label(False) == "x" + assert p[0].rendering_kw == {} + + # single parametric expression with custom range, label and rendering kws + p = plot3d_parametric_line(x + 1, x, sin(x), (x, -2, 2), + "test", {"cmap": "Reds"}) + assert isinstance(p[0], Parametric3DLineSeries) + assert p[0].expr == (x + 1, x, sin(x)) + assert p[0].ranges == [(x, -2, 2)] + assert p[0].get_label(False) == "test" + assert p[0].rendering_kw == {"cmap": "Reds"} + + p = plot3d_parametric_line((x + 1, x, sin(x)), (x, -2, 2), "test") + assert p[0].expr == (x + 1, x, sin(x)) + assert p[0].ranges == [(x, -2, 2)] + assert p[0].get_label(False) == "test" + assert p[0].rendering_kw == {} + + # multiple parametric expression same symbol + p = plot3d_parametric_line( + (x + 1, x, sin(x)), (x ** 2, 1, cos(x), {"cmap": "Reds"})) + assert p[0].expr == (x + 1, x, sin(x)) + assert p[0].ranges == [(x, -10, 10)] + assert p[0].get_label(False) == "x" + assert p[0].rendering_kw == {} + assert p[1].expr == (x ** 2, 1, cos(x)) + assert p[1].ranges == [(x, -10, 10)] + assert p[1].get_label(False) == "x" + assert p[1].rendering_kw == {"cmap": "Reds"} + + # multiple parametric expression different symbols + p = plot3d_parametric_line((x + 1, x, sin(x)), (y ** 2, 1, cos(y))) + assert p[0].expr == (x + 1, x, sin(x)) + assert p[0].ranges == [(x, -10, 10)] + assert p[0].get_label(False) == "x" + assert p[0].rendering_kw == {} + assert p[1].expr == (y ** 2, 1, cos(y)) + assert p[1].ranges == [(y, -10, 10)] + assert p[1].get_label(False) == "y" + assert p[1].rendering_kw == {} + + # multiple parametric expression, custom ranges and labels + p = plot3d_parametric_line( + (x + 1, x, sin(x)), + (x ** 2, 1, cos(x), (x, -2, 2), "test", {"cmap": "Reds"})) + assert p[0].expr == (x + 1, x, sin(x)) + assert p[0].ranges == [(x, -10, 10)] + assert p[0].get_label(False) == "x" + assert p[0].rendering_kw == {} + assert p[1].expr == (x ** 2, 1, cos(x)) + assert p[1].ranges == [(x, -2, 2)] + assert p[1].get_label(False) == "test" + assert p[1].rendering_kw == {"cmap": "Reds"} + + # single argument: lambda function + fx = lambda t: t + fy = lambda t: 2 * t + fz = lambda t: 3 * t + p = plot3d_parametric_line(fx, fy, fz) + assert all(callable(t) for t in p[0].expr) + assert p[0].ranges[0][1:] == (-10, 10) + assert "Dummy" in p[0].get_label(False) + assert p[0].rendering_kw == {} + + # single argument: lambda function + custom range + label + p = plot3d_parametric_line(fx, fy, fz, ("t", 0, 2), "test") + assert all(callable(t) for t in p[0].expr) + assert p[0].ranges[0][1:] == (0, 2) + assert p[0].get_label(False) == "test" + assert p[0].rendering_kw == {} + + +def test_plot3d_plot_contour_arguments(): + ### Test arguments for plot3d() and plot_contour() + if not matplotlib: + skip("Matplotlib not the default backend") + + x, y = symbols("x, y") + + # single expression + p = plot3d(x + y) + assert isinstance(p[0], SurfaceOver2DRangeSeries) + assert p[0].expr == x + y + assert p[0].ranges[0] == (x, -10, 10) or (y, -10, 10) + assert p[0].ranges[1] == (x, -10, 10) or (y, -10, 10) + assert p[0].get_label(False) == "x + y" + assert p[0].rendering_kw == {} + + # single expression, custom range, label and rendering kws + p = plot3d(x + y, (x, -2, 2), "test", {"cmap": "Reds"}) + assert isinstance(p[0], SurfaceOver2DRangeSeries) + assert p[0].expr == x + y + assert p[0].ranges[0] == (x, -2, 2) + assert p[0].ranges[1] == (y, -10, 10) + assert p[0].get_label(False) == "test" + assert p[0].rendering_kw == {"cmap": "Reds"} + + p = plot3d(x + y, (x, -2, 2), (y, -4, 4), "test") + assert p[0].ranges[0] == (x, -2, 2) + assert p[0].ranges[1] == (y, -4, 4) + + # multiple expressions + p = plot3d(x + y, x * y) + assert p[0].expr == x + y + assert p[0].ranges[0] == (x, -10, 10) or (y, -10, 10) + assert p[0].ranges[1] == (x, -10, 10) or (y, -10, 10) + assert p[0].get_label(False) == "x + y" + assert p[0].rendering_kw == {} + assert p[1].expr == x * y + assert p[1].ranges[0] == (x, -10, 10) or (y, -10, 10) + assert p[1].ranges[1] == (x, -10, 10) or (y, -10, 10) + assert p[1].get_label(False) == "x*y" + assert p[1].rendering_kw == {} + + # multiple expressions, same custom ranges + p = plot3d(x + y, x * y, (x, -2, 2), (y, -4, 4)) + assert p[0].expr == x + y + assert p[0].ranges[0] == (x, -2, 2) + assert p[0].ranges[1] == (y, -4, 4) + assert p[0].get_label(False) == "x + y" + assert p[0].rendering_kw == {} + assert p[1].expr == x * y + assert p[1].ranges[0] == (x, -2, 2) + assert p[1].ranges[1] == (y, -4, 4) + assert p[1].get_label(False) == "x*y" + assert p[1].rendering_kw == {} + + # multiple expressions, custom ranges, labels and rendering kws + p = plot3d( + (x + y, (x, -2, 2), (y, -4, 4)), + (x * y, (x, -3, 3), (y, -6, 6), "test", {"cmap": "Reds"})) + assert p[0].expr == x + y + assert p[0].ranges[0] == (x, -2, 2) + assert p[0].ranges[1] == (y, -4, 4) + assert p[0].get_label(False) == "x + y" + assert p[0].rendering_kw == {} + assert p[1].expr == x * y + assert p[1].ranges[0] == (x, -3, 3) + assert p[1].ranges[1] == (y, -6, 6) + assert p[1].get_label(False) == "test" + assert p[1].rendering_kw == {"cmap": "Reds"} + + # single expression: lambda function + f = lambda x, y: x + y + p = plot3d(f) + assert callable(p[0].expr) + assert p[0].ranges[0][1:] == (-10, 10) + assert p[0].ranges[1][1:] == (-10, 10) + assert p[0].get_label(False) == "" + assert p[0].rendering_kw == {} + + # single expression: lambda function + custom ranges + label + p = plot3d(f, ("a", -5, 3), ("b", -2, 1), "test") + assert callable(p[0].expr) + assert p[0].ranges[0][1:] == (-5, 3) + assert p[0].ranges[1][1:] == (-2, 1) + assert p[0].get_label(False) == "test" + assert p[0].rendering_kw == {} + + # test issue 25818 + # single expression, custom range, min/max functions + p = plot3d(Min(x, y), (x, 0, 10), (y, 0, 10)) + assert isinstance(p[0], SurfaceOver2DRangeSeries) + assert p[0].expr == Min(x, y) + assert p[0].ranges[0] == (x, 0, 10) + assert p[0].ranges[1] == (y, 0, 10) + assert p[0].get_label(False) == "Min(x, y)" + assert p[0].rendering_kw == {} + + +def test_plot3d_parametric_surface_arguments(): + ### Test arguments for plot3d_parametric_surface() + if not matplotlib: + skip("Matplotlib not the default backend") + + x, y = symbols("x, y") + + # single parametric expression + p = plot3d_parametric_surface(x + y, cos(x + y), sin(x + y)) + assert isinstance(p[0], ParametricSurfaceSeries) + assert p[0].expr == (x + y, cos(x + y), sin(x + y)) + assert p[0].ranges[0] == (x, -10, 10) or (y, -10, 10) + assert p[0].ranges[1] == (x, -10, 10) or (y, -10, 10) + assert p[0].get_label(False) == "(x + y, cos(x + y), sin(x + y))" + assert p[0].rendering_kw == {} + + # single parametric expression, custom ranges, labels and rendering kws + p = plot3d_parametric_surface(x + y, cos(x + y), sin(x + y), + (x, -2, 2), (y, -4, 4), "test", {"cmap": "Reds"}) + assert isinstance(p[0], ParametricSurfaceSeries) + assert p[0].expr == (x + y, cos(x + y), sin(x + y)) + assert p[0].ranges[0] == (x, -2, 2) + assert p[0].ranges[1] == (y, -4, 4) + assert p[0].get_label(False) == "test" + assert p[0].rendering_kw == {"cmap": "Reds"} + + # multiple parametric expressions + p = plot3d_parametric_surface( + (x + y, cos(x + y), sin(x + y)), + (x - y, cos(x - y), sin(x - y), "test")) + assert p[0].expr == (x + y, cos(x + y), sin(x + y)) + assert p[0].ranges[0] == (x, -10, 10) or (y, -10, 10) + assert p[0].ranges[1] == (x, -10, 10) or (y, -10, 10) + assert p[0].get_label(False) == "(x + y, cos(x + y), sin(x + y))" + assert p[0].rendering_kw == {} + assert p[1].expr == (x - y, cos(x - y), sin(x - y)) + assert p[1].ranges[0] == (x, -10, 10) or (y, -10, 10) + assert p[1].ranges[1] == (x, -10, 10) or (y, -10, 10) + assert p[1].get_label(False) == "test" + assert p[1].rendering_kw == {} + + # multiple parametric expressions, custom ranges and labels + p = plot3d_parametric_surface( + (x + y, cos(x + y), sin(x + y), (x, -2, 2), "test"), + (x - y, cos(x - y), sin(x - y), (x, -3, 3), (y, -4, 4), + "test2", {"cmap": "Reds"})) + assert p[0].expr == (x + y, cos(x + y), sin(x + y)) + assert p[0].ranges[0] == (x, -2, 2) + assert p[0].ranges[1] == (y, -10, 10) + assert p[0].get_label(False) == "test" + assert p[0].rendering_kw == {} + assert p[1].expr == (x - y, cos(x - y), sin(x - y)) + assert p[1].ranges[0] == (x, -3, 3) + assert p[1].ranges[1] == (y, -4, 4) + assert p[1].get_label(False) == "test2" + assert p[1].rendering_kw == {"cmap": "Reds"} + + # lambda functions instead of symbolic expressions for a single 3D + # parametric surface + p = plot3d_parametric_surface( + lambda u, v: u, lambda u, v: v, lambda u, v: u + v, + ("u", 0, 2), ("v", -3, 4)) + assert all(callable(t) for t in p[0].expr) + assert p[0].ranges[0][1:] == (-0, 2) + assert p[0].ranges[1][1:] == (-3, 4) + assert p[0].get_label(False) == "" + assert p[0].rendering_kw == {} + + # lambda functions instead of symbolic expressions for multiple 3D + # parametric surfaces + p = plot3d_parametric_surface( + (lambda u, v: u, lambda u, v: v, lambda u, v: u + v, + ("u", 0, 2), ("v", -3, 4)), + (lambda u, v: v, lambda u, v: u, lambda u, v: u - v, + ("u", -2, 3), ("v", -4, 5), "test")) + assert all(callable(t) for t in p[0].expr) + assert p[0].ranges[0][1:] == (0, 2) + assert p[0].ranges[1][1:] == (-3, 4) + assert p[0].get_label(False) == "" + assert p[0].rendering_kw == {} + assert all(callable(t) for t in p[1].expr) + assert p[1].ranges[0][1:] == (-2, 3) + assert p[1].ranges[1][1:] == (-4, 5) + assert p[1].get_label(False) == "test" + assert p[1].rendering_kw == {} diff --git a/lib/python3.10/site-packages/sympy/plotting/tests/test_plot_implicit.py b/lib/python3.10/site-packages/sympy/plotting/tests/test_plot_implicit.py new file mode 100644 index 0000000000000000000000000000000000000000..73c7b186c83f0b64d5f6f4cc5cd9f6a08efef43a --- /dev/null +++ b/lib/python3.10/site-packages/sympy/plotting/tests/test_plot_implicit.py @@ -0,0 +1,146 @@ +from sympy.core.numbers import (I, pi) +from sympy.core.relational import Eq +from sympy.core.symbol import (Symbol, symbols) +from sympy.functions.elementary.complexes import re +from sympy.functions.elementary.exponential import exp +from sympy.functions.elementary.trigonometric import (cos, sin, tan) +from sympy.logic.boolalg import (And, Or) +from sympy.plotting.plot_implicit import plot_implicit +from sympy.plotting.plot import unset_show +from tempfile import NamedTemporaryFile, mkdtemp +from sympy.testing.pytest import skip, warns, XFAIL +from sympy.external import import_module +from sympy.testing.tmpfiles import TmpFileManager + +import os + +#Set plots not to show +unset_show() + +def tmp_file(dir=None, name=''): + return NamedTemporaryFile( + suffix='.png', dir=dir, delete=False).name + +def plot_and_save(expr, *args, name='', dir=None, **kwargs): + p = plot_implicit(expr, *args, **kwargs) + p.save(tmp_file(dir=dir, name=name)) + # Close the plot to avoid a warning from matplotlib + p._backend.close() + +def plot_implicit_tests(name): + temp_dir = mkdtemp() + TmpFileManager.tmp_folder(temp_dir) + x = Symbol('x') + y = Symbol('y') + #implicit plot tests + plot_and_save(Eq(y, cos(x)), (x, -5, 5), (y, -2, 2), name=name, dir=temp_dir) + plot_and_save(Eq(y**2, x**3 - x), (x, -5, 5), + (y, -4, 4), name=name, dir=temp_dir) + plot_and_save(y > 1 / x, (x, -5, 5), + (y, -2, 2), name=name, dir=temp_dir) + plot_and_save(y < 1 / tan(x), (x, -5, 5), + (y, -2, 2), name=name, dir=temp_dir) + plot_and_save(y >= 2 * sin(x) * cos(x), (x, -5, 5), + (y, -2, 2), name=name, dir=temp_dir) + plot_and_save(y <= x**2, (x, -3, 3), + (y, -1, 5), name=name, dir=temp_dir) + + #Test all input args for plot_implicit + plot_and_save(Eq(y**2, x**3 - x), dir=temp_dir) + plot_and_save(Eq(y**2, x**3 - x), adaptive=False, dir=temp_dir) + plot_and_save(Eq(y**2, x**3 - x), adaptive=False, n=500, dir=temp_dir) + plot_and_save(y > x, (x, -5, 5), dir=temp_dir) + plot_and_save(And(y > exp(x), y > x + 2), dir=temp_dir) + plot_and_save(Or(y > x, y > -x), dir=temp_dir) + plot_and_save(x**2 - 1, (x, -5, 5), dir=temp_dir) + plot_and_save(x**2 - 1, dir=temp_dir) + plot_and_save(y > x, depth=-5, dir=temp_dir) + plot_and_save(y > x, depth=5, dir=temp_dir) + plot_and_save(y > cos(x), adaptive=False, dir=temp_dir) + plot_and_save(y < cos(x), adaptive=False, dir=temp_dir) + plot_and_save(And(y > cos(x), Or(y > x, Eq(y, x))), dir=temp_dir) + plot_and_save(y - cos(pi / x), dir=temp_dir) + + plot_and_save(x**2 - 1, title='An implicit plot', dir=temp_dir) + +@XFAIL +def test_no_adaptive_meshing(): + matplotlib = import_module('matplotlib', min_module_version='1.1.0', catch=(RuntimeError,)) + if matplotlib: + try: + temp_dir = mkdtemp() + TmpFileManager.tmp_folder(temp_dir) + x = Symbol('x') + y = Symbol('y') + # Test plots which cannot be rendered using the adaptive algorithm + + # This works, but it triggers a deprecation warning from sympify(). The + # code needs to be updated to detect if interval math is supported without + # relying on random AttributeErrors. + with warns(UserWarning, match="Adaptive meshing could not be applied"): + plot_and_save(Eq(y, re(cos(x) + I*sin(x))), name='test', dir=temp_dir) + finally: + TmpFileManager.cleanup() + else: + skip("Matplotlib not the default backend") +def test_line_color(): + x, y = symbols('x, y') + p = plot_implicit(x**2 + y**2 - 1, line_color="green", show=False) + assert p._series[0].line_color == "green" + p = plot_implicit(x**2 + y**2 - 1, line_color='r', show=False) + assert p._series[0].line_color == "r" + +def test_matplotlib(): + matplotlib = import_module('matplotlib', min_module_version='1.1.0', catch=(RuntimeError,)) + if matplotlib: + try: + plot_implicit_tests('test') + test_line_color() + finally: + TmpFileManager.cleanup() + else: + skip("Matplotlib not the default backend") + + +def test_region_and(): + matplotlib = import_module('matplotlib', min_module_version='1.1.0', catch=(RuntimeError,)) + if not matplotlib: + skip("Matplotlib not the default backend") + + from matplotlib.testing.compare import compare_images + test_directory = os.path.dirname(os.path.abspath(__file__)) + + try: + temp_dir = mkdtemp() + TmpFileManager.tmp_folder(temp_dir) + + x, y = symbols('x y') + + r1 = (x - 1)**2 + y**2 < 2 + r2 = (x + 1)**2 + y**2 < 2 + + test_filename = tmp_file(dir=temp_dir, name="test_region_and") + cmp_filename = os.path.join(test_directory, "test_region_and.png") + p = plot_implicit(r1 & r2, x, y) + p.save(test_filename) + compare_images(cmp_filename, test_filename, 0.005) + + test_filename = tmp_file(dir=temp_dir, name="test_region_or") + cmp_filename = os.path.join(test_directory, "test_region_or.png") + p = plot_implicit(r1 | r2, x, y) + p.save(test_filename) + compare_images(cmp_filename, test_filename, 0.005) + + test_filename = tmp_file(dir=temp_dir, name="test_region_not") + cmp_filename = os.path.join(test_directory, "test_region_not.png") + p = plot_implicit(~r1, x, y) + p.save(test_filename) + compare_images(cmp_filename, test_filename, 0.005) + + test_filename = tmp_file(dir=temp_dir, name="test_region_xor") + cmp_filename = os.path.join(test_directory, "test_region_xor.png") + p = plot_implicit(r1 ^ r2, x, y) + p.save(test_filename) + compare_images(cmp_filename, test_filename, 0.005) + finally: + TmpFileManager.cleanup() diff --git a/lib/python3.10/site-packages/sympy/plotting/tests/test_region_and.png b/lib/python3.10/site-packages/sympy/plotting/tests/test_region_and.png new file mode 100644 index 0000000000000000000000000000000000000000..61dda4c2054e5e4bd5018cb84af86a832e81886a Binary files /dev/null and b/lib/python3.10/site-packages/sympy/plotting/tests/test_region_and.png differ diff --git a/lib/python3.10/site-packages/sympy/plotting/tests/test_region_not.png b/lib/python3.10/site-packages/sympy/plotting/tests/test_region_not.png new file mode 100644 index 0000000000000000000000000000000000000000..29d3d47b5a95346cb7c44655c12a2a63e6c7a857 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/plotting/tests/test_region_not.png differ diff --git a/lib/python3.10/site-packages/sympy/plotting/tests/test_region_or.png b/lib/python3.10/site-packages/sympy/plotting/tests/test_region_or.png new file mode 100644 index 0000000000000000000000000000000000000000..8a6329dd8dd368c37e431a7741e0869ec84f8f68 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/plotting/tests/test_region_or.png differ diff --git a/lib/python3.10/site-packages/sympy/plotting/tests/test_region_xor.png b/lib/python3.10/site-packages/sympy/plotting/tests/test_region_xor.png new file mode 100644 index 0000000000000000000000000000000000000000..1a48862909d3ad09a5f4d306bf6c8f96117d080c Binary files /dev/null and b/lib/python3.10/site-packages/sympy/plotting/tests/test_region_xor.png differ diff --git a/lib/python3.10/site-packages/sympy/plotting/tests/test_series.py b/lib/python3.10/site-packages/sympy/plotting/tests/test_series.py new file mode 100644 index 0000000000000000000000000000000000000000..e23aa719153d20dc9b9e911e5cee097f0ea56211 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/plotting/tests/test_series.py @@ -0,0 +1,1771 @@ +from sympy import ( + latex, exp, symbols, I, pi, sin, cos, tan, log, sqrt, + re, im, arg, frac, Sum, S, Abs, lambdify, + Function, dsolve, Eq, floor, Tuple +) +from sympy.external import import_module +from sympy.plotting.series import ( + LineOver1DRangeSeries, Parametric2DLineSeries, Parametric3DLineSeries, + SurfaceOver2DRangeSeries, ContourSeries, ParametricSurfaceSeries, + ImplicitSeries, _set_discretization_points, List2DSeries +) +from sympy.testing.pytest import raises, warns, XFAIL, skip, ignore_warnings + +np = import_module('numpy') + + +def test_adaptive(): + # verify that adaptive-related keywords produces the expected results + if not np: + skip("numpy not installed.") + + x, y = symbols("x, y") + + s1 = LineOver1DRangeSeries(sin(x), (x, -10, 10), "", adaptive=True, + depth=2) + x1, _ = s1.get_data() + s2 = LineOver1DRangeSeries(sin(x), (x, -10, 10), "", adaptive=True, + depth=5) + x2, _ = s2.get_data() + s3 = LineOver1DRangeSeries(sin(x), (x, -10, 10), "", adaptive=True) + x3, _ = s3.get_data() + assert len(x1) < len(x2) < len(x3) + + s1 = Parametric2DLineSeries(cos(x), sin(x), (x, 0, 2*pi), + adaptive=True, depth=2) + x1, _, _, = s1.get_data() + s2 = Parametric2DLineSeries(cos(x), sin(x), (x, 0, 2*pi), + adaptive=True, depth=5) + x2, _, _ = s2.get_data() + s3 = Parametric2DLineSeries(cos(x), sin(x), (x, 0, 2*pi), + adaptive=True) + x3, _, _ = s3.get_data() + assert len(x1) < len(x2) < len(x3) + + +def test_detect_poles(): + if not np: + skip("numpy not installed.") + + x, u = symbols("x, u") + + s1 = LineOver1DRangeSeries(tan(x), (x, -pi, pi), + adaptive=False, n=1000, detect_poles=False) + xx1, yy1 = s1.get_data() + s2 = LineOver1DRangeSeries(tan(x), (x, -pi, pi), + adaptive=False, n=1000, detect_poles=True, eps=0.01) + xx2, yy2 = s2.get_data() + # eps is too small: doesn't detect any poles + s3 = LineOver1DRangeSeries(tan(x), (x, -pi, pi), + adaptive=False, n=1000, detect_poles=True, eps=1e-06) + xx3, yy3 = s3.get_data() + s4 = LineOver1DRangeSeries(tan(x), (x, -pi, pi), + adaptive=False, n=1000, detect_poles="symbolic") + xx4, yy4 = s4.get_data() + + assert np.allclose(xx1, xx2) and np.allclose(xx1, xx3) and np.allclose(xx1, xx4) + assert not np.any(np.isnan(yy1)) + assert not np.any(np.isnan(yy3)) + assert np.any(np.isnan(yy2)) + assert np.any(np.isnan(yy4)) + assert len(s2.poles_locations) == len(s3.poles_locations) == 0 + assert len(s4.poles_locations) == 2 + assert np.allclose(np.abs(s4.poles_locations), np.pi / 2) + + with warns( + UserWarning, + match="NumPy is unable to evaluate with complex numbers some of", + test_stacklevel=False, + ): + s1 = LineOver1DRangeSeries(frac(x), (x, -10, 10), + adaptive=False, n=1000, detect_poles=False) + s2 = LineOver1DRangeSeries(frac(x), (x, -10, 10), + adaptive=False, n=1000, detect_poles=True, eps=0.05) + s3 = LineOver1DRangeSeries(frac(x), (x, -10, 10), + adaptive=False, n=1000, detect_poles="symbolic") + xx1, yy1 = s1.get_data() + xx2, yy2 = s2.get_data() + xx3, yy3 = s3.get_data() + assert np.allclose(xx1, xx2) and np.allclose(xx1, xx3) + assert not np.any(np.isnan(yy1)) + assert np.any(np.isnan(yy2)) and np.any(np.isnan(yy2)) + assert not np.allclose(yy1, yy2, equal_nan=True) + # The poles below are actually step discontinuities. + assert len(s3.poles_locations) == 21 + + s1 = LineOver1DRangeSeries(tan(u * x), (x, -pi, pi), params={u: 1}, + adaptive=False, n=1000, detect_poles=False) + xx1, yy1 = s1.get_data() + s2 = LineOver1DRangeSeries(tan(u * x), (x, -pi, pi), params={u: 1}, + adaptive=False, n=1000, detect_poles=True, eps=0.01) + xx2, yy2 = s2.get_data() + # eps is too small: doesn't detect any poles + s3 = LineOver1DRangeSeries(tan(u * x), (x, -pi, pi), params={u: 1}, + adaptive=False, n=1000, detect_poles=True, eps=1e-06) + xx3, yy3 = s3.get_data() + s4 = LineOver1DRangeSeries(tan(u * x), (x, -pi, pi), params={u: 1}, + adaptive=False, n=1000, detect_poles="symbolic") + xx4, yy4 = s4.get_data() + + assert np.allclose(xx1, xx2) and np.allclose(xx1, xx3) and np.allclose(xx1, xx4) + assert not np.any(np.isnan(yy1)) + assert not np.any(np.isnan(yy3)) + assert np.any(np.isnan(yy2)) + assert np.any(np.isnan(yy4)) + assert len(s2.poles_locations) == len(s3.poles_locations) == 0 + assert len(s4.poles_locations) == 2 + assert np.allclose(np.abs(s4.poles_locations), np.pi / 2) + + with warns( + UserWarning, + match="NumPy is unable to evaluate with complex numbers some of", + test_stacklevel=False, + ): + u, v = symbols("u, v", real=True) + n = S(1) / 3 + f = (u + I * v)**n + r, i = re(f), im(f) + s1 = Parametric2DLineSeries(r.subs(u, -2), i.subs(u, -2), (v, -2, 2), + adaptive=False, n=1000, detect_poles=False) + s2 = Parametric2DLineSeries(r.subs(u, -2), i.subs(u, -2), (v, -2, 2), + adaptive=False, n=1000, detect_poles=True) + with ignore_warnings(RuntimeWarning): + xx1, yy1, pp1 = s1.get_data() + assert not np.isnan(yy1).any() + xx2, yy2, pp2 = s2.get_data() + assert np.isnan(yy2).any() + + with warns( + UserWarning, + match="NumPy is unable to evaluate with complex numbers some of", + test_stacklevel=False, + ): + f = (x * u + x * I * v)**n + r, i = re(f), im(f) + s1 = Parametric2DLineSeries(r.subs(u, -2), i.subs(u, -2), + (v, -2, 2), params={x: 1}, + adaptive=False, n1=1000, detect_poles=False) + s2 = Parametric2DLineSeries(r.subs(u, -2), i.subs(u, -2), + (v, -2, 2), params={x: 1}, + adaptive=False, n1=1000, detect_poles=True) + with ignore_warnings(RuntimeWarning): + xx1, yy1, pp1 = s1.get_data() + assert not np.isnan(yy1).any() + xx2, yy2, pp2 = s2.get_data() + assert np.isnan(yy2).any() + + +def test_number_discretization_points(): + # verify that the different ways to set the number of discretization + # points are consistent with each other. + if not np: + skip("numpy not installed.") + + x, y, z = symbols("x:z") + + for pt in [LineOver1DRangeSeries, Parametric2DLineSeries, + Parametric3DLineSeries]: + kw1 = _set_discretization_points({"n": 10}, pt) + kw2 = _set_discretization_points({"n": [10, 20, 30]}, pt) + kw3 = _set_discretization_points({"n1": 10}, pt) + assert all(("n1" in kw) and kw["n1"] == 10 for kw in [kw1, kw2, kw3]) + + for pt in [SurfaceOver2DRangeSeries, ContourSeries, ParametricSurfaceSeries, + ImplicitSeries]: + kw1 = _set_discretization_points({"n": 10}, pt) + kw2 = _set_discretization_points({"n": [10, 20, 30]}, pt) + kw3 = _set_discretization_points({"n1": 10, "n2": 20}, pt) + assert kw1["n1"] == kw1["n2"] == 10 + assert all((kw["n1"] == 10) and (kw["n2"] == 20) for kw in [kw2, kw3]) + + # verify that line-related series can deal with large float number of + # discretization points + LineOver1DRangeSeries(cos(x), (x, -5, 5), adaptive=False, n=1e04).get_data() + + +def test_list2dseries(): + if not np: + skip("numpy not installed.") + + xx = np.linspace(-3, 3, 10) + yy1 = np.cos(xx) + yy2 = np.linspace(-3, 3, 20) + + # same number of elements: everything is fine + s = List2DSeries(xx, yy1) + assert not s.is_parametric + # different number of elements: error + raises(ValueError, lambda: List2DSeries(xx, yy2)) + + # no color func: returns only x, y components and s in not parametric + s = List2DSeries(xx, yy1) + xxs, yys = s.get_data() + assert np.allclose(xx, xxs) + assert np.allclose(yy1, yys) + assert not s.is_parametric + + +def test_interactive_vs_noninteractive(): + # verify that if a *Series class receives a `params` dictionary, it sets + # is_interactive=True + x, y, z, u, v = symbols("x, y, z, u, v") + + s = LineOver1DRangeSeries(cos(x), (x, -5, 5)) + assert not s.is_interactive + s = LineOver1DRangeSeries(u * cos(x), (x, -5, 5), params={u: 1}) + assert s.is_interactive + + s = Parametric2DLineSeries(cos(x), sin(x), (x, -5, 5)) + assert not s.is_interactive + s = Parametric2DLineSeries(u * cos(x), u * sin(x), (x, -5, 5), + params={u: 1}) + assert s.is_interactive + + s = Parametric3DLineSeries(cos(x), sin(x), x, (x, -5, 5)) + assert not s.is_interactive + s = Parametric3DLineSeries(u * cos(x), u * sin(x), x, (x, -5, 5), + params={u: 1}) + assert s.is_interactive + + s = SurfaceOver2DRangeSeries(cos(x * y), (x, -5, 5), (y, -5, 5)) + assert not s.is_interactive + s = SurfaceOver2DRangeSeries(u * cos(x * y), (x, -5, 5), (y, -5, 5), + params={u: 1}) + assert s.is_interactive + + s = ContourSeries(cos(x * y), (x, -5, 5), (y, -5, 5)) + assert not s.is_interactive + s = ContourSeries(u * cos(x * y), (x, -5, 5), (y, -5, 5), + params={u: 1}) + assert s.is_interactive + + s = ParametricSurfaceSeries(u * cos(v), v * sin(u), u + v, + (u, -5, 5), (v, -5, 5)) + assert not s.is_interactive + s = ParametricSurfaceSeries(u * cos(v * x), v * sin(u), u + v, + (u, -5, 5), (v, -5, 5), params={x: 1}) + assert s.is_interactive + + +def test_lin_log_scale(): + # Verify that data series create the correct spacing in the data. + if not np: + skip("numpy not installed.") + + x, y, z = symbols("x, y, z") + + s = LineOver1DRangeSeries(x, (x, 1, 10), adaptive=False, n=50, + xscale="linear") + xx, _ = s.get_data() + assert np.isclose(xx[1] - xx[0], xx[-1] - xx[-2]) + + s = LineOver1DRangeSeries(x, (x, 1, 10), adaptive=False, n=50, + xscale="log") + xx, _ = s.get_data() + assert not np.isclose(xx[1] - xx[0], xx[-1] - xx[-2]) + + s = Parametric2DLineSeries( + cos(x), sin(x), (x, pi / 2, 1.5 * pi), adaptive=False, n=50, + xscale="linear") + _, _, param = s.get_data() + assert np.isclose(param[1] - param[0], param[-1] - param[-2]) + + s = Parametric2DLineSeries( + cos(x), sin(x), (x, pi / 2, 1.5 * pi), adaptive=False, n=50, + xscale="log") + _, _, param = s.get_data() + assert not np.isclose(param[1] - param[0], param[-1] - param[-2]) + + s = Parametric3DLineSeries( + cos(x), sin(x), x, (x, pi / 2, 1.5 * pi), adaptive=False, n=50, + xscale="linear") + _, _, _, param = s.get_data() + assert np.isclose(param[1] - param[0], param[-1] - param[-2]) + + s = Parametric3DLineSeries( + cos(x), sin(x), x, (x, pi / 2, 1.5 * pi), adaptive=False, n=50, + xscale="log") + _, _, _, param = s.get_data() + assert not np.isclose(param[1] - param[0], param[-1] - param[-2]) + + s = SurfaceOver2DRangeSeries( + cos(x ** 2 + y ** 2), (x, 1, 5), (y, 1, 5), n=10, + xscale="linear", yscale="linear") + xx, yy, _ = s.get_data() + assert np.isclose(xx[0, 1] - xx[0, 0], xx[0, -1] - xx[0, -2]) + assert np.isclose(yy[1, 0] - yy[0, 0], yy[-1, 0] - yy[-2, 0]) + + s = SurfaceOver2DRangeSeries( + cos(x ** 2 + y ** 2), (x, 1, 5), (y, 1, 5), n=10, + xscale="log", yscale="log") + xx, yy, _ = s.get_data() + assert not np.isclose(xx[0, 1] - xx[0, 0], xx[0, -1] - xx[0, -2]) + assert not np.isclose(yy[1, 0] - yy[0, 0], yy[-1, 0] - yy[-2, 0]) + + s = ImplicitSeries( + cos(x ** 2 + y ** 2) > 0, (x, 1, 5), (y, 1, 5), + n1=10, n2=10, xscale="linear", yscale="linear", adaptive=False) + xx, yy, _, _ = s.get_data() + assert np.isclose(xx[0, 1] - xx[0, 0], xx[0, -1] - xx[0, -2]) + assert np.isclose(yy[1, 0] - yy[0, 0], yy[-1, 0] - yy[-2, 0]) + + s = ImplicitSeries( + cos(x ** 2 + y ** 2) > 0, (x, 1, 5), (y, 1, 5), + n=10, xscale="log", yscale="log", adaptive=False) + xx, yy, _, _ = s.get_data() + assert not np.isclose(xx[0, 1] - xx[0, 0], xx[0, -1] - xx[0, -2]) + assert not np.isclose(yy[1, 0] - yy[0, 0], yy[-1, 0] - yy[-2, 0]) + + +def test_rendering_kw(): + # verify that each series exposes the `rendering_kw` attribute + if not np: + skip("numpy not installed.") + + u, v, x, y, z = symbols("u, v, x:z") + + s = List2DSeries([1, 2, 3], [4, 5, 6]) + assert isinstance(s.rendering_kw, dict) + + s = LineOver1DRangeSeries(1, (x, -5, 5)) + assert isinstance(s.rendering_kw, dict) + + s = Parametric2DLineSeries(sin(x), cos(x), (x, 0, pi)) + assert isinstance(s.rendering_kw, dict) + + s = Parametric3DLineSeries(cos(x), sin(x), x, (x, 0, 2 * pi)) + assert isinstance(s.rendering_kw, dict) + + s = SurfaceOver2DRangeSeries(x + y, (x, -2, 2), (y, -3, 3)) + assert isinstance(s.rendering_kw, dict) + + s = ContourSeries(x + y, (x, -2, 2), (y, -3, 3)) + assert isinstance(s.rendering_kw, dict) + + s = ParametricSurfaceSeries(1, x, y, (x, 0, 1), (y, 0, 1)) + assert isinstance(s.rendering_kw, dict) + + +def test_data_shape(): + # Verify that the series produces the correct data shape when the input + # expression is a number. + if not np: + skip("numpy not installed.") + + u, x, y, z = symbols("u, x:z") + + # scalar expression: it should return a numpy ones array + s = LineOver1DRangeSeries(1, (x, -5, 5)) + xx, yy = s.get_data() + assert len(xx) == len(yy) + assert np.all(yy == 1) + + s = LineOver1DRangeSeries(1, (x, -5, 5), adaptive=False, n=10) + xx, yy = s.get_data() + assert len(xx) == len(yy) == 10 + assert np.all(yy == 1) + + s = Parametric2DLineSeries(sin(x), 1, (x, 0, pi)) + xx, yy, param = s.get_data() + assert (len(xx) == len(yy)) and (len(xx) == len(param)) + assert np.all(yy == 1) + + s = Parametric2DLineSeries(1, sin(x), (x, 0, pi)) + xx, yy, param = s.get_data() + assert (len(xx) == len(yy)) and (len(xx) == len(param)) + assert np.all(xx == 1) + + s = Parametric2DLineSeries(sin(x), 1, (x, 0, pi), adaptive=False) + xx, yy, param = s.get_data() + assert (len(xx) == len(yy)) and (len(xx) == len(param)) + assert np.all(yy == 1) + + s = Parametric2DLineSeries(1, sin(x), (x, 0, pi), adaptive=False) + xx, yy, param = s.get_data() + assert (len(xx) == len(yy)) and (len(xx) == len(param)) + assert np.all(xx == 1) + + s = Parametric3DLineSeries(cos(x), sin(x), 1, (x, 0, 2 * pi)) + xx, yy, zz, param = s.get_data() + assert (len(xx) == len(yy)) and (len(xx) == len(zz)) and (len(xx) == len(param)) + assert np.all(zz == 1) + + s = Parametric3DLineSeries(cos(x), 1, x, (x, 0, 2 * pi)) + xx, yy, zz, param = s.get_data() + assert (len(xx) == len(yy)) and (len(xx) == len(zz)) and (len(xx) == len(param)) + assert np.all(yy == 1) + + s = Parametric3DLineSeries(1, sin(x), x, (x, 0, 2 * pi)) + xx, yy, zz, param = s.get_data() + assert (len(xx) == len(yy)) and (len(xx) == len(zz)) and (len(xx) == len(param)) + assert np.all(xx == 1) + + s = SurfaceOver2DRangeSeries(1, (x, -2, 2), (y, -3, 3)) + xx, yy, zz = s.get_data() + assert (xx.shape == yy.shape) and (xx.shape == zz.shape) + assert np.all(zz == 1) + + s = ParametricSurfaceSeries(1, x, y, (x, 0, 1), (y, 0, 1)) + xx, yy, zz, uu, vv = s.get_data() + assert xx.shape == yy.shape == zz.shape == uu.shape == vv.shape + assert np.all(xx == 1) + + s = ParametricSurfaceSeries(1, 1, y, (x, 0, 1), (y, 0, 1)) + xx, yy, zz, uu, vv = s.get_data() + assert xx.shape == yy.shape == zz.shape == uu.shape == vv.shape + assert np.all(yy == 1) + + s = ParametricSurfaceSeries(x, 1, 1, (x, 0, 1), (y, 0, 1)) + xx, yy, zz, uu, vv = s.get_data() + assert xx.shape == yy.shape == zz.shape == uu.shape == vv.shape + assert np.all(zz == 1) + + +def test_only_integers(): + if not np: + skip("numpy not installed.") + + x, y, u, v = symbols("x, y, u, v") + + s = LineOver1DRangeSeries(sin(x), (x, -5.5, 4.5), "", + adaptive=False, only_integers=True) + xx, _ = s.get_data() + assert len(xx) == 10 + assert xx[0] == -5 and xx[-1] == 4 + + s = Parametric2DLineSeries(cos(x), sin(x), (x, 0, 2 * pi), "", + adaptive=False, only_integers=True) + _, _, p = s.get_data() + assert len(p) == 7 + assert p[0] == 0 and p[-1] == 6 + + s = Parametric3DLineSeries(cos(x), sin(x), x, (x, 0, 2 * pi), "", + adaptive=False, only_integers=True) + _, _, _, p = s.get_data() + assert len(p) == 7 + assert p[0] == 0 and p[-1] == 6 + + s = SurfaceOver2DRangeSeries(cos(x**2 + y**2), (x, -5.5, 5.5), + (y, -3.5, 3.5), "", + adaptive=False, only_integers=True) + xx, yy, _ = s.get_data() + assert xx.shape == yy.shape == (7, 11) + assert np.allclose(xx[:, 0] - (-5) * np.ones(7), 0) + assert np.allclose(xx[0, :] - np.linspace(-5, 5, 11), 0) + assert np.allclose(yy[:, 0] - np.linspace(-3, 3, 7), 0) + assert np.allclose(yy[0, :] - (-3) * np.ones(11), 0) + + r = 2 + sin(7 * u + 5 * v) + expr = ( + r * cos(u) * sin(v), + r * sin(u) * sin(v), + r * cos(v) + ) + s = ParametricSurfaceSeries(*expr, (u, 0, 2 * pi), (v, 0, pi), "", + adaptive=False, only_integers=True) + xx, yy, zz, uu, vv = s.get_data() + assert xx.shape == yy.shape == zz.shape == uu.shape == vv.shape == (4, 7) + + # only_integers also works with scalar expressions + s = LineOver1DRangeSeries(1, (x, -5.5, 4.5), "", + adaptive=False, only_integers=True) + xx, _ = s.get_data() + assert len(xx) == 10 + assert xx[0] == -5 and xx[-1] == 4 + + s = Parametric2DLineSeries(cos(x), 1, (x, 0, 2 * pi), "", + adaptive=False, only_integers=True) + _, _, p = s.get_data() + assert len(p) == 7 + assert p[0] == 0 and p[-1] == 6 + + s = SurfaceOver2DRangeSeries(1, (x, -5.5, 5.5), (y, -3.5, 3.5), "", + adaptive=False, only_integers=True) + xx, yy, _ = s.get_data() + assert xx.shape == yy.shape == (7, 11) + assert np.allclose(xx[:, 0] - (-5) * np.ones(7), 0) + assert np.allclose(xx[0, :] - np.linspace(-5, 5, 11), 0) + assert np.allclose(yy[:, 0] - np.linspace(-3, 3, 7), 0) + assert np.allclose(yy[0, :] - (-3) * np.ones(11), 0) + + r = 2 + sin(7 * u + 5 * v) + expr = ( + r * cos(u) * sin(v), + 1, + r * cos(v) + ) + s = ParametricSurfaceSeries(*expr, (u, 0, 2 * pi), (v, 0, pi), "", + adaptive=False, only_integers=True) + xx, yy, zz, uu, vv = s.get_data() + assert xx.shape == yy.shape == zz.shape == uu.shape == vv.shape == (4, 7) + + +def test_is_point_is_filled(): + # verify that `is_point` and `is_filled` are attributes and that they + # they receive the correct values + if not np: + skip("numpy not installed.") + + x, u = symbols("x, u") + + s = LineOver1DRangeSeries(cos(x), (x, -5, 5), "", + is_point=False, is_filled=True) + assert (not s.is_point) and s.is_filled + s = LineOver1DRangeSeries(cos(x), (x, -5, 5), "", + is_point=True, is_filled=False) + assert s.is_point and (not s.is_filled) + + s = List2DSeries([0, 1, 2], [3, 4, 5], + is_point=False, is_filled=True) + assert (not s.is_point) and s.is_filled + s = List2DSeries([0, 1, 2], [3, 4, 5], + is_point=True, is_filled=False) + assert s.is_point and (not s.is_filled) + + s = Parametric2DLineSeries(cos(x), sin(x), (x, -5, 5), + is_point=False, is_filled=True) + assert (not s.is_point) and s.is_filled + s = Parametric2DLineSeries(cos(x), sin(x), (x, -5, 5), + is_point=True, is_filled=False) + assert s.is_point and (not s.is_filled) + + s = Parametric3DLineSeries(cos(x), sin(x), x, (x, -5, 5), + is_point=False, is_filled=True) + assert (not s.is_point) and s.is_filled + s = Parametric3DLineSeries(cos(x), sin(x), x, (x, -5, 5), + is_point=True, is_filled=False) + assert s.is_point and (not s.is_filled) + + +def test_is_filled_2d(): + # verify that the is_filled attribute is exposed by the following series + x, y = symbols("x, y") + + expr = cos(x**2 + y**2) + ranges = (x, -2, 2), (y, -2, 2) + + s = ContourSeries(expr, *ranges) + assert s.is_filled + s = ContourSeries(expr, *ranges, is_filled=True) + assert s.is_filled + s = ContourSeries(expr, *ranges, is_filled=False) + assert not s.is_filled + + +def test_steps(): + if not np: + skip("numpy not installed.") + + x, u = symbols("x, u") + + def do_test(s1, s2): + if (not s1.is_parametric) and s1.is_2Dline: + xx1, _ = s1.get_data() + xx2, _ = s2.get_data() + elif s1.is_parametric and s1.is_2Dline: + xx1, _, _ = s1.get_data() + xx2, _, _ = s2.get_data() + elif (not s1.is_parametric) and s1.is_3Dline: + xx1, _, _ = s1.get_data() + xx2, _, _ = s2.get_data() + else: + xx1, _, _, _ = s1.get_data() + xx2, _, _, _ = s2.get_data() + assert len(xx1) != len(xx2) + + s1 = LineOver1DRangeSeries(cos(x), (x, -5, 5), "", + adaptive=False, n=40, steps=False) + s2 = LineOver1DRangeSeries(cos(x), (x, -5, 5), "", + adaptive=False, n=40, steps=True) + do_test(s1, s2) + + s1 = List2DSeries([0, 1, 2], [3, 4, 5], steps=False) + s2 = List2DSeries([0, 1, 2], [3, 4, 5], steps=True) + do_test(s1, s2) + + s1 = Parametric2DLineSeries(cos(x), sin(x), (x, -5, 5), + adaptive=False, n=40, steps=False) + s2 = Parametric2DLineSeries(cos(x), sin(x), (x, -5, 5), + adaptive=False, n=40, steps=True) + do_test(s1, s2) + + s1 = Parametric3DLineSeries(cos(x), sin(x), x, (x, -5, 5), + adaptive=False, n=40, steps=False) + s2 = Parametric3DLineSeries(cos(x), sin(x), x, (x, -5, 5), + adaptive=False, n=40, steps=True) + do_test(s1, s2) + + +def test_interactive_data(): + # verify that InteractiveSeries produces the same numerical data as their + # corresponding non-interactive series. + if not np: + skip("numpy not installed.") + + u, x, y, z = symbols("u, x:z") + + def do_test(data1, data2): + assert len(data1) == len(data2) + for d1, d2 in zip(data1, data2): + assert np.allclose(d1, d2) + + s1 = LineOver1DRangeSeries(u * cos(x), (x, -5, 5), params={u: 1}, n=50) + s2 = LineOver1DRangeSeries(cos(x), (x, -5, 5), adaptive=False, n=50) + do_test(s1.get_data(), s2.get_data()) + + s1 = Parametric2DLineSeries( + u * cos(x), u * sin(x), (x, -5, 5), params={u: 1}, n=50) + s2 = Parametric2DLineSeries(cos(x), sin(x), (x, -5, 5), + adaptive=False, n=50) + do_test(s1.get_data(), s2.get_data()) + + s1 = Parametric3DLineSeries( + u * cos(x), u * sin(x), u * x, (x, -5, 5), + params={u: 1}, n=50) + s2 = Parametric3DLineSeries(cos(x), sin(x), x, (x, -5, 5), + adaptive=False, n=50) + do_test(s1.get_data(), s2.get_data()) + + s1 = SurfaceOver2DRangeSeries( + u * cos(x ** 2 + y ** 2), (x, -3, 3), (y, -3, 3), + params={u: 1}, n1=50, n2=50,) + s2 = SurfaceOver2DRangeSeries( + cos(x ** 2 + y ** 2), (x, -3, 3), (y, -3, 3), + adaptive=False, n1=50, n2=50) + do_test(s1.get_data(), s2.get_data()) + + s1 = ParametricSurfaceSeries( + u * cos(x + y), sin(x + y), x - y, (x, -3, 3), (y, -3, 3), + params={u: 1}, n1=50, n2=50,) + s2 = ParametricSurfaceSeries( + cos(x + y), sin(x + y), x - y, (x, -3, 3), (y, -3, 3), + adaptive=False, n1=50, n2=50,) + do_test(s1.get_data(), s2.get_data()) + + # real part of a complex function evaluated over a real line with numpy + expr = re((z ** 2 + 1) / (z ** 2 - 1)) + s1 = LineOver1DRangeSeries(u * expr, (z, -3, 3), adaptive=False, n=50, + modules=None, params={u: 1}) + s2 = LineOver1DRangeSeries(expr, (z, -3, 3), adaptive=False, n=50, + modules=None) + do_test(s1.get_data(), s2.get_data()) + + # real part of a complex function evaluated over a real line with mpmath + expr = re((z ** 2 + 1) / (z ** 2 - 1)) + s1 = LineOver1DRangeSeries(u * expr, (z, -3, 3), n=50, modules="mpmath", + params={u: 1}) + s2 = LineOver1DRangeSeries(expr, (z, -3, 3), + adaptive=False, n=50, modules="mpmath") + do_test(s1.get_data(), s2.get_data()) + + +def test_list2dseries_interactive(): + if not np: + skip("numpy not installed.") + + x, y, u = symbols("x, y, u") + + s = List2DSeries([1, 2, 3], [1, 2, 3]) + assert not s.is_interactive + + # symbolic expressions as coordinates, but no ``params`` + raises(ValueError, lambda: List2DSeries([cos(x)], [sin(x)])) + + # too few parameters + raises(ValueError, + lambda: List2DSeries([cos(x), y], [sin(x), 2], params={u: 1})) + + s = List2DSeries([cos(x)], [sin(x)], params={x: 1}) + assert s.is_interactive + + s = List2DSeries([x, 2, 3, 4], [4, 3, 2, x], params={x: 3}) + xx, yy = s.get_data() + assert np.allclose(xx, [3, 2, 3, 4]) + assert np.allclose(yy, [4, 3, 2, 3]) + assert not s.is_parametric + + # numeric lists + params is present -> interactive series and + # lists are converted to Tuple. + s = List2DSeries([1, 2, 3], [1, 2, 3], params={x: 1}) + assert s.is_interactive + assert isinstance(s.list_x, Tuple) + assert isinstance(s.list_y, Tuple) + + +def test_mpmath(): + # test that the argument of complex functions evaluated with mpmath + # might be different than the one computed with Numpy (different + # behaviour at branch cuts) + if not np: + skip("numpy not installed.") + + z, u = symbols("z, u") + + s1 = LineOver1DRangeSeries(im(sqrt(-z)), (z, 1e-03, 5), + adaptive=True, modules=None, force_real_eval=True) + s2 = LineOver1DRangeSeries(im(sqrt(-z)), (z, 1e-03, 5), + adaptive=True, modules="mpmath", force_real_eval=True) + xx1, yy1 = s1.get_data() + xx2, yy2 = s2.get_data() + assert np.all(yy1 < 0) + assert np.all(yy2 > 0) + + s1 = LineOver1DRangeSeries(im(sqrt(-z)), (z, -5, 5), + adaptive=False, n=20, modules=None, force_real_eval=True) + s2 = LineOver1DRangeSeries(im(sqrt(-z)), (z, -5, 5), + adaptive=False, n=20, modules="mpmath", force_real_eval=True) + xx1, yy1 = s1.get_data() + xx2, yy2 = s2.get_data() + assert np.allclose(xx1, xx2) + assert not np.allclose(yy1, yy2) + + +def test_str(): + u, x, y, z = symbols("u, x:z") + + s = LineOver1DRangeSeries(cos(x), (x, -4, 3)) + assert str(s) == "cartesian line: cos(x) for x over (-4.0, 3.0)" + + d = {"return": "real"} + s = LineOver1DRangeSeries(cos(x), (x, -4, 3), **d) + assert str(s) == "cartesian line: re(cos(x)) for x over (-4.0, 3.0)" + + d = {"return": "imag"} + s = LineOver1DRangeSeries(cos(x), (x, -4, 3), **d) + assert str(s) == "cartesian line: im(cos(x)) for x over (-4.0, 3.0)" + + d = {"return": "abs"} + s = LineOver1DRangeSeries(cos(x), (x, -4, 3), **d) + assert str(s) == "cartesian line: abs(cos(x)) for x over (-4.0, 3.0)" + + d = {"return": "arg"} + s = LineOver1DRangeSeries(cos(x), (x, -4, 3), **d) + assert str(s) == "cartesian line: arg(cos(x)) for x over (-4.0, 3.0)" + + s = LineOver1DRangeSeries(cos(u * x), (x, -4, 3), params={u: 1}) + assert str(s) == "interactive cartesian line: cos(u*x) for x over (-4.0, 3.0) and parameters (u,)" + + s = LineOver1DRangeSeries(cos(u * x), (x, -u, 3*y), params={u: 1, y: 1}) + assert str(s) == "interactive cartesian line: cos(u*x) for x over (-u, 3*y) and parameters (u, y)" + + s = Parametric2DLineSeries(cos(x), sin(x), (x, -4, 3)) + assert str(s) == "parametric cartesian line: (cos(x), sin(x)) for x over (-4.0, 3.0)" + + s = Parametric2DLineSeries(cos(u * x), sin(x), (x, -4, 3), params={u: 1}) + assert str(s) == "interactive parametric cartesian line: (cos(u*x), sin(x)) for x over (-4.0, 3.0) and parameters (u,)" + + s = Parametric2DLineSeries(cos(u * x), sin(x), (x, -u, 3*y), params={u: 1, y:1}) + assert str(s) == "interactive parametric cartesian line: (cos(u*x), sin(x)) for x over (-u, 3*y) and parameters (u, y)" + + s = Parametric3DLineSeries(cos(x), sin(x), x, (x, -4, 3)) + assert str(s) == "3D parametric cartesian line: (cos(x), sin(x), x) for x over (-4.0, 3.0)" + + s = Parametric3DLineSeries(cos(u*x), sin(x), x, (x, -4, 3), params={u: 1}) + assert str(s) == "interactive 3D parametric cartesian line: (cos(u*x), sin(x), x) for x over (-4.0, 3.0) and parameters (u,)" + + s = Parametric3DLineSeries(cos(u*x), sin(x), x, (x, -u, 3*y), params={u: 1, y: 1}) + assert str(s) == "interactive 3D parametric cartesian line: (cos(u*x), sin(x), x) for x over (-u, 3*y) and parameters (u, y)" + + s = SurfaceOver2DRangeSeries(cos(x * y), (x, -4, 3), (y, -2, 5)) + assert str(s) == "cartesian surface: cos(x*y) for x over (-4.0, 3.0) and y over (-2.0, 5.0)" + + s = SurfaceOver2DRangeSeries(cos(u * x * y), (x, -4, 3), (y, -2, 5), params={u: 1}) + assert str(s) == "interactive cartesian surface: cos(u*x*y) for x over (-4.0, 3.0) and y over (-2.0, 5.0) and parameters (u,)" + + s = SurfaceOver2DRangeSeries(cos(u * x * y), (x, -4*u, 3), (y, -2, 5*u), params={u: 1}) + assert str(s) == "interactive cartesian surface: cos(u*x*y) for x over (-4*u, 3.0) and y over (-2.0, 5*u) and parameters (u,)" + + s = ContourSeries(cos(x * y), (x, -4, 3), (y, -2, 5)) + assert str(s) == "contour: cos(x*y) for x over (-4.0, 3.0) and y over (-2.0, 5.0)" + + s = ContourSeries(cos(u * x * y), (x, -4, 3), (y, -2, 5), params={u: 1}) + assert str(s) == "interactive contour: cos(u*x*y) for x over (-4.0, 3.0) and y over (-2.0, 5.0) and parameters (u,)" + + s = ParametricSurfaceSeries(cos(x * y), sin(x * y), x * y, + (x, -4, 3), (y, -2, 5)) + assert str(s) == "parametric cartesian surface: (cos(x*y), sin(x*y), x*y) for x over (-4.0, 3.0) and y over (-2.0, 5.0)" + + s = ParametricSurfaceSeries(cos(u * x * y), sin(x * y), x * y, + (x, -4, 3), (y, -2, 5), params={u: 1}) + assert str(s) == "interactive parametric cartesian surface: (cos(u*x*y), sin(x*y), x*y) for x over (-4.0, 3.0) and y over (-2.0, 5.0) and parameters (u,)" + + s = ImplicitSeries(x < y, (x, -5, 4), (y, -3, 2)) + assert str(s) == "Implicit expression: x < y for x over (-5.0, 4.0) and y over (-3.0, 2.0)" + + +def test_use_cm(): + # verify that the `use_cm` attribute is implemented. + if not np: + skip("numpy not installed.") + + u, x, y, z = symbols("u, x:z") + + s = List2DSeries([1, 2, 3, 4], [5, 6, 7, 8], use_cm=True) + assert s.use_cm + s = List2DSeries([1, 2, 3, 4], [5, 6, 7, 8], use_cm=False) + assert not s.use_cm + + s = Parametric2DLineSeries(cos(x), sin(x), (x, -4, 3), use_cm=True) + assert s.use_cm + s = Parametric2DLineSeries(cos(x), sin(x), (x, -4, 3), use_cm=False) + assert not s.use_cm + + s = Parametric3DLineSeries(cos(x), sin(x), x, (x, -4, 3), + use_cm=True) + assert s.use_cm + s = Parametric3DLineSeries(cos(x), sin(x), x, (x, -4, 3), + use_cm=False) + assert not s.use_cm + + s = SurfaceOver2DRangeSeries(cos(x * y), (x, -4, 3), (y, -2, 5), + use_cm=True) + assert s.use_cm + s = SurfaceOver2DRangeSeries(cos(x * y), (x, -4, 3), (y, -2, 5), + use_cm=False) + assert not s.use_cm + + s = ParametricSurfaceSeries(cos(x * y), sin(x * y), x * y, + (x, -4, 3), (y, -2, 5), use_cm=True) + assert s.use_cm + s = ParametricSurfaceSeries(cos(x * y), sin(x * y), x * y, + (x, -4, 3), (y, -2, 5), use_cm=False) + assert not s.use_cm + + +def test_surface_use_cm(): + # verify that SurfaceOver2DRangeSeries and ParametricSurfaceSeries get + # the same value for use_cm + + x, y, u, v = symbols("x, y, u, v") + + # they read the same value from default settings + s1 = SurfaceOver2DRangeSeries(cos(x**2 + y**2), (x, -2, 2), (y, -2, 2)) + s2 = ParametricSurfaceSeries(u * cos(v), u * sin(v), u, + (u, 0, 1), (v, 0 , 2*pi)) + assert s1.use_cm == s2.use_cm + + # they get the same value + s1 = SurfaceOver2DRangeSeries(cos(x**2 + y**2), (x, -2, 2), (y, -2, 2), + use_cm=False) + s2 = ParametricSurfaceSeries(u * cos(v), u * sin(v), u, + (u, 0, 1), (v, 0 , 2*pi), use_cm=False) + assert s1.use_cm == s2.use_cm + + # they get the same value + s1 = SurfaceOver2DRangeSeries(cos(x**2 + y**2), (x, -2, 2), (y, -2, 2), + use_cm=True) + s2 = ParametricSurfaceSeries(u * cos(v), u * sin(v), u, + (u, 0, 1), (v, 0 , 2*pi), use_cm=True) + assert s1.use_cm == s2.use_cm + + +def test_sums(): + # test that data series are able to deal with sums + if not np: + skip("numpy not installed.") + + x, y, u = symbols("x, y, u") + + def do_test(data1, data2): + assert len(data1) == len(data2) + for d1, d2 in zip(data1, data2): + assert np.allclose(d1, d2) + + s = LineOver1DRangeSeries(Sum(1 / x ** y, (x, 1, 1000)), (y, 2, 10), + adaptive=False, only_integers=True) + xx, yy = s.get_data() + + s1 = LineOver1DRangeSeries(Sum(1 / x, (x, 1, y)), (y, 2, 10), + adaptive=False, only_integers=True) + xx1, yy1 = s1.get_data() + + s2 = LineOver1DRangeSeries(Sum(u / x, (x, 1, y)), (y, 2, 10), + params={u: 1}, only_integers=True) + xx2, yy2 = s2.get_data() + xx1 = xx1.astype(float) + xx2 = xx2.astype(float) + do_test([xx1, yy1], [xx2, yy2]) + + s = LineOver1DRangeSeries(Sum(1 / x, (x, 1, y)), (y, 2, 10), + adaptive=True) + with warns( + UserWarning, + match="The evaluation with NumPy/SciPy failed", + test_stacklevel=False, + ): + raises(TypeError, lambda: s.get_data()) + + +def test_apply_transforms(): + # verify that transformation functions get applied to the output + # of data series + if not np: + skip("numpy not installed.") + + x, y, z, u, v = symbols("x:z, u, v") + + s1 = LineOver1DRangeSeries(cos(x), (x, -2*pi, 2*pi), adaptive=False, n=10) + s2 = LineOver1DRangeSeries(cos(x), (x, -2*pi, 2*pi), adaptive=False, n=10, + tx=np.rad2deg) + s3 = LineOver1DRangeSeries(cos(x), (x, -2*pi, 2*pi), adaptive=False, n=10, + ty=np.rad2deg) + s4 = LineOver1DRangeSeries(cos(x), (x, -2*pi, 2*pi), adaptive=False, n=10, + tx=np.rad2deg, ty=np.rad2deg) + + x1, y1 = s1.get_data() + x2, y2 = s2.get_data() + x3, y3 = s3.get_data() + x4, y4 = s4.get_data() + assert np.isclose(x1[0], -2*np.pi) and np.isclose(x1[-1], 2*np.pi) + assert (y1.min() < -0.9) and (y1.max() > 0.9) + assert np.isclose(x2[0], -360) and np.isclose(x2[-1], 360) + assert (y2.min() < -0.9) and (y2.max() > 0.9) + assert np.isclose(x3[0], -2*np.pi) and np.isclose(x3[-1], 2*np.pi) + assert (y3.min() < -52) and (y3.max() > 52) + assert np.isclose(x4[0], -360) and np.isclose(x4[-1], 360) + assert (y4.min() < -52) and (y4.max() > 52) + + xx = np.linspace(-2*np.pi, 2*np.pi, 10) + yy = np.cos(xx) + s1 = List2DSeries(xx, yy) + s2 = List2DSeries(xx, yy, tx=np.rad2deg, ty=np.rad2deg) + x1, y1 = s1.get_data() + x2, y2 = s2.get_data() + assert np.isclose(x1[0], -2*np.pi) and np.isclose(x1[-1], 2*np.pi) + assert (y1.min() < -0.9) and (y1.max() > 0.9) + assert np.isclose(x2[0], -360) and np.isclose(x2[-1], 360) + assert (y2.min() < -52) and (y2.max() > 52) + + s1 = Parametric2DLineSeries( + sin(x), cos(x), (x, -pi, pi), adaptive=False, n=10) + s2 = Parametric2DLineSeries( + sin(x), cos(x), (x, -pi, pi), adaptive=False, n=10, + tx=np.rad2deg, ty=np.rad2deg, tp=np.rad2deg) + x1, y1, a1 = s1.get_data() + x2, y2, a2 = s2.get_data() + assert np.allclose(x1, np.deg2rad(x2)) + assert np.allclose(y1, np.deg2rad(y2)) + assert np.allclose(a1, np.deg2rad(a2)) + + s1 = Parametric3DLineSeries( + sin(x), cos(x), x, (x, -pi, pi), adaptive=False, n=10) + s2 = Parametric3DLineSeries( + sin(x), cos(x), x, (x, -pi, pi), adaptive=False, n=10, tp=np.rad2deg) + x1, y1, z1, a1 = s1.get_data() + x2, y2, z2, a2 = s2.get_data() + assert np.allclose(x1, x2) + assert np.allclose(y1, y2) + assert np.allclose(z1, z2) + assert np.allclose(a1, np.deg2rad(a2)) + + s1 = SurfaceOver2DRangeSeries( + cos(x**2 + y**2), (x, -2*pi, 2*pi), (y, -2*pi, 2*pi), + adaptive=False, n1=10, n2=10) + s2 = SurfaceOver2DRangeSeries( + cos(x**2 + y**2), (x, -2*pi, 2*pi), (y, -2*pi, 2*pi), + adaptive=False, n1=10, n2=10, + tx=np.rad2deg, ty=lambda x: 2*x, tz=lambda x: 3*x) + x1, y1, z1 = s1.get_data() + x2, y2, z2 = s2.get_data() + assert np.allclose(x1, np.deg2rad(x2)) + assert np.allclose(y1, y2 / 2) + assert np.allclose(z1, z2 / 3) + + s1 = ParametricSurfaceSeries( + u + v, u - v, u * v, (u, 0, 2*pi), (v, 0, pi), + adaptive=False, n1=10, n2=10) + s2 = ParametricSurfaceSeries( + u + v, u - v, u * v, (u, 0, 2*pi), (v, 0, pi), + adaptive=False, n1=10, n2=10, + tx=np.rad2deg, ty=lambda x: 2*x, tz=lambda x: 3*x) + x1, y1, z1, u1, v1 = s1.get_data() + x2, y2, z2, u2, v2 = s2.get_data() + assert np.allclose(x1, np.deg2rad(x2)) + assert np.allclose(y1, y2 / 2) + assert np.allclose(z1, z2 / 3) + assert np.allclose(u1, u2) + assert np.allclose(v1, v2) + + +def test_series_labels(): + # verify that series return the correct label, depending on the plot + # type and input arguments. If the user set custom label on a data series, + # it should returned un-modified. + if not np: + skip("numpy not installed.") + + x, y, z, u, v = symbols("x, y, z, u, v") + wrapper = "$%s$" + + expr = cos(x) + s1 = LineOver1DRangeSeries(expr, (x, -2, 2), None) + s2 = LineOver1DRangeSeries(expr, (x, -2, 2), "test") + assert s1.get_label(False) == str(expr) + assert s1.get_label(True) == wrapper % latex(expr) + assert s2.get_label(False) == "test" + assert s2.get_label(True) == "test" + + s1 = List2DSeries([0, 1, 2, 3], [0, 1, 2, 3], "test") + assert s1.get_label(False) == "test" + assert s1.get_label(True) == "test" + + expr = (cos(x), sin(x)) + s1 = Parametric2DLineSeries(*expr, (x, -2, 2), None, use_cm=True) + s2 = Parametric2DLineSeries(*expr, (x, -2, 2), "test", use_cm=True) + s3 = Parametric2DLineSeries(*expr, (x, -2, 2), None, use_cm=False) + s4 = Parametric2DLineSeries(*expr, (x, -2, 2), "test", use_cm=False) + assert s1.get_label(False) == "x" + assert s1.get_label(True) == wrapper % "x" + assert s2.get_label(False) == "test" + assert s2.get_label(True) == "test" + assert s3.get_label(False) == str(expr) + assert s3.get_label(True) == wrapper % latex(expr) + assert s4.get_label(False) == "test" + assert s4.get_label(True) == "test" + + expr = (cos(x), sin(x), x) + s1 = Parametric3DLineSeries(*expr, (x, -2, 2), None, use_cm=True) + s2 = Parametric3DLineSeries(*expr, (x, -2, 2), "test", use_cm=True) + s3 = Parametric3DLineSeries(*expr, (x, -2, 2), None, use_cm=False) + s4 = Parametric3DLineSeries(*expr, (x, -2, 2), "test", use_cm=False) + assert s1.get_label(False) == "x" + assert s1.get_label(True) == wrapper % "x" + assert s2.get_label(False) == "test" + assert s2.get_label(True) == "test" + assert s3.get_label(False) == str(expr) + assert s3.get_label(True) == wrapper % latex(expr) + assert s4.get_label(False) == "test" + assert s4.get_label(True) == "test" + + expr = cos(x**2 + y**2) + s1 = SurfaceOver2DRangeSeries(expr, (x, -2, 2), (y, -2, 2), None) + s2 = SurfaceOver2DRangeSeries(expr, (x, -2, 2), (y, -2, 2), "test") + assert s1.get_label(False) == str(expr) + assert s1.get_label(True) == wrapper % latex(expr) + assert s2.get_label(False) == "test" + assert s2.get_label(True) == "test" + + expr = (cos(x - y), sin(x + y), x - y) + s1 = ParametricSurfaceSeries(*expr, (x, -2, 2), (y, -2, 2), None) + s2 = ParametricSurfaceSeries(*expr, (x, -2, 2), (y, -2, 2), "test") + assert s1.get_label(False) == str(expr) + assert s1.get_label(True) == wrapper % latex(expr) + assert s2.get_label(False) == "test" + assert s2.get_label(True) == "test" + + expr = Eq(cos(x - y), 0) + s1 = ImplicitSeries(expr, (x, -10, 10), (y, -10, 10), None) + s2 = ImplicitSeries(expr, (x, -10, 10), (y, -10, 10), "test") + assert s1.get_label(False) == str(expr) + assert s1.get_label(True) == wrapper % latex(expr) + assert s2.get_label(False) == "test" + assert s2.get_label(True) == "test" + + +def test_is_polar_2d_parametric(): + # verify that Parametric2DLineSeries isable to apply polar discretization, + # which is used when polar_plot is executed with polar_axis=True + if not np: + skip("numpy not installed.") + + t, u = symbols("t u") + + # NOTE: a sufficiently big n must be provided, or else tests + # are going to fail + # No colormap + f = sin(4 * t) + s1 = Parametric2DLineSeries(f * cos(t), f * sin(t), (t, 0, 2*pi), + adaptive=False, n=10, is_polar=False, use_cm=False) + x1, y1, p1 = s1.get_data() + s2 = Parametric2DLineSeries(f * cos(t), f * sin(t), (t, 0, 2*pi), + adaptive=False, n=10, is_polar=True, use_cm=False) + th, r, p2 = s2.get_data() + assert (not np.allclose(x1, th)) and (not np.allclose(y1, r)) + assert np.allclose(p1, p2) + + # With colormap + s3 = Parametric2DLineSeries(f * cos(t), f * sin(t), (t, 0, 2*pi), + adaptive=False, n=10, is_polar=False, color_func=lambda t: 2*t) + x3, y3, p3 = s3.get_data() + s4 = Parametric2DLineSeries(f * cos(t), f * sin(t), (t, 0, 2*pi), + adaptive=False, n=10, is_polar=True, color_func=lambda t: 2*t) + th4, r4, p4 = s4.get_data() + assert np.allclose(p3, p4) and (not np.allclose(p1, p3)) + assert np.allclose(x3, x1) and np.allclose(y3, y1) + assert np.allclose(th4, th) and np.allclose(r4, r) + + +def test_is_polar_3d(): + # verify that SurfaceOver2DRangeSeries is able to apply + # polar discretization + if not np: + skip("numpy not installed.") + + x, y, t = symbols("x, y, t") + expr = (x**2 - 1)**2 + s1 = SurfaceOver2DRangeSeries(expr, (x, 0, 1.5), (y, 0, 2 * pi), + n=10, adaptive=False, is_polar=False) + s2 = SurfaceOver2DRangeSeries(expr, (x, 0, 1.5), (y, 0, 2 * pi), + n=10, adaptive=False, is_polar=True) + x1, y1, z1 = s1.get_data() + x2, y2, z2 = s2.get_data() + x22, y22 = x1 * np.cos(y1), x1 * np.sin(y1) + assert np.allclose(x2, x22) + assert np.allclose(y2, y22) + + +def test_color_func(): + # verify that eval_color_func produces the expected results in order to + # maintain back compatibility with the old sympy.plotting module + if not np: + skip("numpy not installed.") + + x, y, z, u, v = symbols("x, y, z, u, v") + + # color func: returns x, y, color and s is parametric + xx = np.linspace(-3, 3, 10) + yy1 = np.cos(xx) + s = List2DSeries(xx, yy1, color_func=lambda x, y: 2 * x, use_cm=True) + xxs, yys, col = s.get_data() + assert np.allclose(xx, xxs) + assert np.allclose(yy1, yys) + assert np.allclose(2 * xx, col) + assert s.is_parametric + + s = List2DSeries(xx, yy1, color_func=lambda x, y: 2 * x, use_cm=False) + assert len(s.get_data()) == 2 + assert not s.is_parametric + + s = Parametric2DLineSeries(cos(x), sin(x), (x, 0, 2*pi), + adaptive=False, n=10, color_func=lambda t: t) + xx, yy, col = s.get_data() + assert (not np.allclose(xx, col)) and (not np.allclose(yy, col)) + s = Parametric2DLineSeries(cos(x), sin(x), (x, 0, 2*pi), + adaptive=False, n=10, color_func=lambda x, y: x * y) + xx, yy, col = s.get_data() + assert np.allclose(col, xx * yy) + s = Parametric2DLineSeries(cos(x), sin(x), (x, 0, 2*pi), + adaptive=False, n=10, color_func=lambda x, y, t: x * y * t) + xx, yy, col = s.get_data() + assert np.allclose(col, xx * yy * np.linspace(0, 2*np.pi, 10)) + + s = Parametric3DLineSeries(cos(x), sin(x), x, (x, 0, 2*pi), + adaptive=False, n=10, color_func=lambda t: t) + xx, yy, zz, col = s.get_data() + assert (not np.allclose(xx, col)) and (not np.allclose(yy, col)) + s = Parametric3DLineSeries(cos(x), sin(x), x, (x, 0, 2*pi), + adaptive=False, n=10, color_func=lambda x, y, z: x * y * z) + xx, yy, zz, col = s.get_data() + assert np.allclose(col, xx * yy * zz) + s = Parametric3DLineSeries(cos(x), sin(x), x, (x, 0, 2*pi), + adaptive=False, n=10, color_func=lambda x, y, z, t: x * y * z * t) + xx, yy, zz, col = s.get_data() + assert np.allclose(col, xx * yy * zz * np.linspace(0, 2*np.pi, 10)) + + s = SurfaceOver2DRangeSeries(cos(x**2 + y**2), (x, -2, 2), (y, -2, 2), + adaptive=False, n1=10, n2=10, color_func=lambda x: x) + xx, yy, zz = s.get_data() + col = s.eval_color_func(xx, yy, zz) + assert np.allclose(xx, col) + s = SurfaceOver2DRangeSeries(cos(x**2 + y**2), (x, -2, 2), (y, -2, 2), + adaptive=False, n1=10, n2=10, color_func=lambda x, y: x * y) + xx, yy, zz = s.get_data() + col = s.eval_color_func(xx, yy, zz) + assert np.allclose(xx * yy, col) + s = SurfaceOver2DRangeSeries(cos(x**2 + y**2), (x, -2, 2), (y, -2, 2), + adaptive=False, n1=10, n2=10, color_func=lambda x, y, z: x * y * z) + xx, yy, zz = s.get_data() + col = s.eval_color_func(xx, yy, zz) + assert np.allclose(xx * yy * zz, col) + + s = ParametricSurfaceSeries(1, x, y, (x, 0, 1), (y, 0, 1), adaptive=False, + n1=10, n2=10, color_func=lambda u:u) + xx, yy, zz, uu, vv = s.get_data() + col = s.eval_color_func(xx, yy, zz, uu, vv) + assert np.allclose(uu, col) + s = ParametricSurfaceSeries(1, x, y, (x, 0, 1), (y, 0, 1), adaptive=False, + n1=10, n2=10, color_func=lambda u, v: u * v) + xx, yy, zz, uu, vv = s.get_data() + col = s.eval_color_func(xx, yy, zz, uu, vv) + assert np.allclose(uu * vv, col) + s = ParametricSurfaceSeries(1, x, y, (x, 0, 1), (y, 0, 1), adaptive=False, + n1=10, n2=10, color_func=lambda x, y, z: x * y * z) + xx, yy, zz, uu, vv = s.get_data() + col = s.eval_color_func(xx, yy, zz, uu, vv) + assert np.allclose(xx * yy * zz, col) + s = ParametricSurfaceSeries(1, x, y, (x, 0, 1), (y, 0, 1), adaptive=False, + n1=10, n2=10, color_func=lambda x, y, z, u, v: x * y * z * u * v) + xx, yy, zz, uu, vv = s.get_data() + col = s.eval_color_func(xx, yy, zz, uu, vv) + assert np.allclose(xx * yy * zz * uu * vv, col) + + # Interactive Series + s = List2DSeries([0, 1, 2, x], [x, 2, 3, 4], + color_func=lambda x, y: 2 * x, params={x: 1}, use_cm=True) + xx, yy, col = s.get_data() + assert np.allclose(xx, [0, 1, 2, 1]) + assert np.allclose(yy, [1, 2, 3, 4]) + assert np.allclose(2 * xx, col) + assert s.is_parametric and s.use_cm + + s = List2DSeries([0, 1, 2, x], [x, 2, 3, 4], + color_func=lambda x, y: 2 * x, params={x: 1}, use_cm=False) + assert len(s.get_data()) == 2 + assert not s.is_parametric + + +def test_color_func_scalar_val(): + # verify that eval_color_func returns a numpy array even when color_func + # evaluates to a scalar value + if not np: + skip("numpy not installed.") + + x, y = symbols("x, y") + + s = Parametric2DLineSeries(cos(x), sin(x), (x, 0, 2*pi), + adaptive=False, n=10, color_func=lambda t: 1) + xx, yy, col = s.get_data() + assert np.allclose(col, np.ones(xx.shape)) + + s = Parametric3DLineSeries(cos(x), sin(x), x, (x, 0, 2*pi), + adaptive=False, n=10, color_func=lambda t: 1) + xx, yy, zz, col = s.get_data() + assert np.allclose(col, np.ones(xx.shape)) + + s = SurfaceOver2DRangeSeries(cos(x**2 + y**2), (x, -2, 2), (y, -2, 2), + adaptive=False, n1=10, n2=10, color_func=lambda x: 1) + xx, yy, zz = s.get_data() + assert np.allclose(s.eval_color_func(xx), np.ones(xx.shape)) + + s = ParametricSurfaceSeries(1, x, y, (x, 0, 1), (y, 0, 1), adaptive=False, + n1=10, n2=10, color_func=lambda u: 1) + xx, yy, zz, uu, vv = s.get_data() + col = s.eval_color_func(xx, yy, zz, uu, vv) + assert np.allclose(col, np.ones(xx.shape)) + + +def test_color_func_expression(): + # verify that color_func is able to deal with instances of Expr: they will + # be lambdified with the same signature used for the main expression. + if not np: + skip("numpy not installed.") + + x, y = symbols("x, y") + + s1 = Parametric2DLineSeries(cos(x), sin(x), (x, 0, 2*pi), + color_func=sin(x), adaptive=False, n=10, use_cm=True) + s2 = Parametric2DLineSeries(cos(x), sin(x), (x, 0, 2*pi), + color_func=lambda x: np.cos(x), adaptive=False, n=10, use_cm=True) + # the following statement should not raise errors + d1 = s1.get_data() + assert callable(s1.color_func) + d2 = s2.get_data() + assert not np.allclose(d1[-1], d2[-1]) + + s = SurfaceOver2DRangeSeries(cos(x**2 + y**2), (x, -pi, pi), (y, -pi, pi), + color_func=sin(x**2 + y**2), adaptive=False, n1=5, n2=5) + # the following statement should not raise errors + s.get_data() + assert callable(s.color_func) + + xx = [1, 2, 3, 4, 5] + yy = [1, 2, 3, 4, 5] + raises(TypeError, + lambda : List2DSeries(xx, yy, use_cm=True, color_func=sin(x))) + + +def test_line_surface_color(): + # verify the back-compatibility with the old sympy.plotting module. + # By setting line_color or surface_color to be a callable, it will set + # the color_func attribute. + + x, y, z = symbols("x, y, z") + + s = LineOver1DRangeSeries(sin(x), (x, -5, 5), adaptive=False, n=10, + line_color=lambda x: x) + assert (s.line_color is None) and callable(s.color_func) + + s = Parametric2DLineSeries(cos(x), sin(x), (x, 0, 2*pi), + adaptive=False, n=10, line_color=lambda t: t) + assert (s.line_color is None) and callable(s.color_func) + + s = SurfaceOver2DRangeSeries(cos(x**2 + y**2), (x, -2, 2), (y, -2, 2), + n1=10, n2=10, surface_color=lambda x: x) + assert (s.surface_color is None) and callable(s.color_func) + + +def test_complex_adaptive_false(): + # verify that series with adaptive=False is evaluated with discretized + # ranges of type complex. + if not np: + skip("numpy not installed.") + + x, y, u = symbols("x y u") + + def do_test(data1, data2): + assert len(data1) == len(data2) + for d1, d2 in zip(data1, data2): + assert np.allclose(d1, d2) + + expr1 = sqrt(x) * exp(-x**2) + expr2 = sqrt(u * x) * exp(-x**2) + s1 = LineOver1DRangeSeries(im(expr1), (x, -5, 5), adaptive=False, n=10) + s2 = LineOver1DRangeSeries(im(expr2), (x, -5, 5), + adaptive=False, n=10, params={u: 1}) + data1 = s1.get_data() + data2 = s2.get_data() + + do_test(data1, data2) + assert (not np.allclose(data1[1], 0)) and (not np.allclose(data2[1], 0)) + + s1 = Parametric2DLineSeries(re(expr1), im(expr1), (x, -pi, pi), + adaptive=False, n=10) + s2 = Parametric2DLineSeries(re(expr2), im(expr2), (x, -pi, pi), + adaptive=False, n=10, params={u: 1}) + data1 = s1.get_data() + data2 = s2.get_data() + do_test(data1, data2) + assert (not np.allclose(data1[1], 0)) and (not np.allclose(data2[1], 0)) + + s1 = SurfaceOver2DRangeSeries(im(expr1), (x, -5, 5), (y, -10, 10), + adaptive=False, n1=30, n2=3) + s2 = SurfaceOver2DRangeSeries(im(expr2), (x, -5, 5), (y, -10, 10), + adaptive=False, n1=30, n2=3, params={u: 1}) + data1 = s1.get_data() + data2 = s2.get_data() + do_test(data1, data2) + assert (not np.allclose(data1[1], 0)) and (not np.allclose(data2[1], 0)) + + +def test_expr_is_lambda_function(): + # verify that when a numpy function is provided, the series will be able + # to evaluate it. Also, label should be empty in order to prevent some + # backend from crashing. + if not np: + skip("numpy not installed.") + + f = lambda x: np.cos(x) + s1 = LineOver1DRangeSeries(f, ("x", -5, 5), adaptive=True, depth=3) + s1.get_data() + s2 = LineOver1DRangeSeries(f, ("x", -5, 5), adaptive=False, n=10) + s2.get_data() + assert s1.label == s2.label == "" + + fx = lambda x: np.cos(x) + fy = lambda x: np.sin(x) + s1 = Parametric2DLineSeries(fx, fy, ("x", 0, 2*pi), + adaptive=True, adaptive_goal=0.1) + s1.get_data() + s2 = Parametric2DLineSeries(fx, fy, ("x", 0, 2*pi), + adaptive=False, n=10) + s2.get_data() + assert s1.label == s2.label == "" + + fz = lambda x: x + s1 = Parametric3DLineSeries(fx, fy, fz, ("x", 0, 2*pi), + adaptive=True, adaptive_goal=0.1) + s1.get_data() + s2 = Parametric3DLineSeries(fx, fy, fz, ("x", 0, 2*pi), + adaptive=False, n=10) + s2.get_data() + assert s1.label == s2.label == "" + + f = lambda x, y: np.cos(x**2 + y**2) + s1 = SurfaceOver2DRangeSeries(f, ("a", -2, 2), ("b", -3, 3), + adaptive=False, n1=10, n2=10) + s1.get_data() + s2 = ContourSeries(f, ("a", -2, 2), ("b", -3, 3), + adaptive=False, n1=10, n2=10) + s2.get_data() + assert s1.label == s2.label == "" + + fx = lambda u, v: np.cos(u + v) + fy = lambda u, v: np.sin(u - v) + fz = lambda u, v: u * v + s1 = ParametricSurfaceSeries(fx, fy, fz, ("u", 0, pi), ("v", 0, 2*pi), + adaptive=False, n1=10, n2=10) + s1.get_data() + assert s1.label == "" + + raises(TypeError, lambda: List2DSeries(lambda t: t, lambda t: t)) + raises(TypeError, lambda : ImplicitSeries(lambda t: np.sin(t), + ("x", -5, 5), ("y", -6, 6))) + + +def test_show_in_legend_lines(): + # verify that lines series correctly set the show_in_legend attribute + x, u = symbols("x, u") + + s = LineOver1DRangeSeries(cos(x), (x, -2, 2), "test", show_in_legend=True) + assert s.show_in_legend + s = LineOver1DRangeSeries(cos(x), (x, -2, 2), "test", show_in_legend=False) + assert not s.show_in_legend + + s = Parametric2DLineSeries(cos(x), sin(x), (x, 0, 1), "test", + show_in_legend=True) + assert s.show_in_legend + s = Parametric2DLineSeries(cos(x), sin(x), (x, 0, 1), "test", + show_in_legend=False) + assert not s.show_in_legend + + s = Parametric3DLineSeries(cos(x), sin(x), x, (x, 0, 1), "test", + show_in_legend=True) + assert s.show_in_legend + s = Parametric3DLineSeries(cos(x), sin(x), x, (x, 0, 1), "test", + show_in_legend=False) + assert not s.show_in_legend + + +@XFAIL +def test_particular_case_1_with_adaptive_true(): + # Verify that symbolic expressions and numerical lambda functions are + # evaluated with the same algorithm. + if not np: + skip("numpy not installed.") + + # NOTE: xfail because sympy's adaptive algorithm is not deterministic + + def do_test(a, b): + with warns( + RuntimeWarning, + match="invalid value encountered in scalar power", + test_stacklevel=False, + ): + d1 = a.get_data() + d2 = b.get_data() + for t, v in zip(d1, d2): + assert np.allclose(t, v) + + n = symbols("n") + a = S(2) / 3 + epsilon = 0.01 + xn = (n**3 + n**2)**(S(1)/3) - (n**3 - n**2)**(S(1)/3) + expr = Abs(xn - a) - epsilon + math_func = lambdify([n], expr) + s1 = LineOver1DRangeSeries(expr, (n, -10, 10), "", + adaptive=True, depth=3) + s2 = LineOver1DRangeSeries(math_func, ("n", -10, 10), "", + adaptive=True, depth=3) + do_test(s1, s2) + + +def test_particular_case_1_with_adaptive_false(): + # Verify that symbolic expressions and numerical lambda functions are + # evaluated with the same algorithm. In particular, uniform evaluation + # is going to use np.vectorize, which correctly evaluates the following + # mathematical function. + if not np: + skip("numpy not installed.") + + def do_test(a, b): + d1 = a.get_data() + d2 = b.get_data() + for t, v in zip(d1, d2): + assert np.allclose(t, v) + + n = symbols("n") + a = S(2) / 3 + epsilon = 0.01 + xn = (n**3 + n**2)**(S(1)/3) - (n**3 - n**2)**(S(1)/3) + expr = Abs(xn - a) - epsilon + math_func = lambdify([n], expr) + + s3 = LineOver1DRangeSeries(expr, (n, -10, 10), "", + adaptive=False, n=10) + s4 = LineOver1DRangeSeries(math_func, ("n", -10, 10), "", + adaptive=False, n=10) + do_test(s3, s4) + + +def test_complex_params_number_eval(): + # The main expression contains terms like sqrt(xi - 1), with + # parameter (0 <= xi <= 1). + # There shouldn't be any NaN values on the output. + if not np: + skip("numpy not installed.") + + xi, wn, x0, v0, t = symbols("xi, omega_n, x0, v0, t") + x = Function("x")(t) + eq = x.diff(t, 2) + 2 * xi * wn * x.diff(t) + wn**2 * x + sol = dsolve(eq, x, ics={x.subs(t, 0): x0, x.diff(t).subs(t, 0): v0}) + params = { + wn: 0.5, + xi: 0.25, + x0: 0.45, + v0: 0.0 + } + s = LineOver1DRangeSeries(sol.rhs, (t, 0, 100), adaptive=False, n=5, + params=params) + x, y = s.get_data() + assert not np.isnan(x).any() + assert not np.isnan(y).any() + + + # Fourier Series of a sawtooth wave + # The main expression contains a Sum with a symbolic upper range. + # The lambdified code looks like: + # sum(blablabla for for n in range(1, m+1)) + # But range requires integer numbers, whereas per above example, the series + # casts parameters to complex. Verify that the series is able to detect + # upper bounds in summations and cast it to int in order to get successfull + # evaluation + x, T, n, m = symbols("x, T, n, m") + fs = S(1) / 2 - (1 / pi) * Sum(sin(2 * n * pi * x / T) / n, (n, 1, m)) + params = { + T: 4.5, + m: 5 + } + s = LineOver1DRangeSeries(fs, (x, 0, 10), adaptive=False, n=5, + params=params) + x, y = s.get_data() + assert not np.isnan(x).any() + assert not np.isnan(y).any() + + +def test_complex_range_line_plot_1(): + # verify that univariate functions are evaluated with a complex + # data range (with zero imaginary part). There shouln't be any + # NaN value in the output. + if not np: + skip("numpy not installed.") + + x, u = symbols("x, u") + expr1 = im(sqrt(x) * exp(-x**2)) + expr2 = im(sqrt(u * x) * exp(-x**2)) + s1 = LineOver1DRangeSeries(expr1, (x, -10, 10), adaptive=True, + adaptive_goal=0.1) + s2 = LineOver1DRangeSeries(expr1, (x, -10, 10), adaptive=False, n=30) + s3 = LineOver1DRangeSeries(expr2, (x, -10, 10), adaptive=False, n=30, + params={u: 1}) + + with ignore_warnings(RuntimeWarning): + data1 = s1.get_data() + data2 = s2.get_data() + data3 = s3.get_data() + + assert not np.isnan(data1[1]).any() + assert not np.isnan(data2[1]).any() + assert not np.isnan(data3[1]).any() + assert np.allclose(data2[0], data3[0]) and np.allclose(data2[1], data3[1]) + + +@XFAIL +def test_complex_range_line_plot_2(): + # verify that univariate functions are evaluated with a complex + # data range (with non-zero imaginary part). There shouln't be any + # NaN value in the output. + if not np: + skip("numpy not installed.") + + # NOTE: xfail because sympy's adaptive algorithm is unable to deal with + # complex number. + + x, u = symbols("x, u") + + # adaptive and uniform meshing should produce the same data. + # because of the adaptive nature, just compare the first and last points + # of both series. + s1 = LineOver1DRangeSeries(abs(sqrt(x)), (x, -5-2j, 5-2j), adaptive=True) + s2 = LineOver1DRangeSeries(abs(sqrt(x)), (x, -5-2j, 5-2j), adaptive=False, + n=10) + with warns( + RuntimeWarning, + match="invalid value encountered in sqrt", + test_stacklevel=False, + ): + d1 = s1.get_data() + d2 = s2.get_data() + xx1 = [d1[0][0], d1[0][-1]] + xx2 = [d2[0][0], d2[0][-1]] + yy1 = [d1[1][0], d1[1][-1]] + yy2 = [d2[1][0], d2[1][-1]] + assert np.allclose(xx1, xx2) + assert np.allclose(yy1, yy2) + + +def test_force_real_eval(): + # verify that force_real_eval=True produces inconsistent results when + # compared with evaluation of complex domain. + if not np: + skip("numpy not installed.") + + x = symbols("x") + + expr = im(sqrt(x) * exp(-x**2)) + s1 = LineOver1DRangeSeries(expr, (x, -10, 10), adaptive=False, n=10, + force_real_eval=False) + s2 = LineOver1DRangeSeries(expr, (x, -10, 10), adaptive=False, n=10, + force_real_eval=True) + d1 = s1.get_data() + with ignore_warnings(RuntimeWarning): + d2 = s2.get_data() + assert not np.allclose(d1[1], 0) + assert np.allclose(d2[1], 0) + + +def test_contour_series_show_clabels(): + # verify that a contour series has the abiliy to set the visibility of + # labels to contour lines + + x, y = symbols("x, y") + s = ContourSeries(cos(x*y), (x, -2, 2), (y, -2, 2)) + assert s.show_clabels + + s = ContourSeries(cos(x*y), (x, -2, 2), (y, -2, 2), clabels=True) + assert s.show_clabels + + s = ContourSeries(cos(x*y), (x, -2, 2), (y, -2, 2), clabels=False) + assert not s.show_clabels + + +def test_LineOver1DRangeSeries_complex_range(): + # verify that LineOver1DRangeSeries can accept a complex range + # if the imaginary part of the start and end values are the same + + x = symbols("x") + + LineOver1DRangeSeries(sqrt(x), (x, -10, 10)) + LineOver1DRangeSeries(sqrt(x), (x, -10-2j, 10-2j)) + raises(ValueError, + lambda : LineOver1DRangeSeries(sqrt(x), (x, -10-2j, 10+2j))) + + +def test_symbolic_plotting_ranges(): + # verify that data series can use symbolic plotting ranges + if not np: + skip("numpy not installed.") + + x, y, z, a, b = symbols("x, y, z, a, b") + + def do_test(s1, s2, new_params): + d1 = s1.get_data() + d2 = s2.get_data() + for u, v in zip(d1, d2): + assert np.allclose(u, v) + s2.params = new_params + d2 = s2.get_data() + for u, v in zip(d1, d2): + assert not np.allclose(u, v) + + s1 = LineOver1DRangeSeries(sin(x), (x, 0, 1), adaptive=False, n=10) + s2 = LineOver1DRangeSeries(sin(x), (x, a, b), params={a: 0, b: 1}, + adaptive=False, n=10) + do_test(s1, s2, {a: 0.5, b: 1.5}) + + # missing a parameter + raises(ValueError, + lambda : LineOver1DRangeSeries(sin(x), (x, a, b), params={a: 1}, n=10)) + + s1 = Parametric2DLineSeries(cos(x), sin(x), (x, 0, 1), adaptive=False, n=10) + s2 = Parametric2DLineSeries(cos(x), sin(x), (x, a, b), params={a: 0, b: 1}, + adaptive=False, n=10) + do_test(s1, s2, {a: 0.5, b: 1.5}) + + # missing a parameter + raises(ValueError, + lambda : Parametric2DLineSeries(cos(x), sin(x), (x, a, b), + params={a: 0}, adaptive=False, n=10)) + + s1 = Parametric3DLineSeries(cos(x), sin(x), x, (x, 0, 1), + adaptive=False, n=10) + s2 = Parametric3DLineSeries(cos(x), sin(x), x, (x, a, b), + params={a: 0, b: 1}, adaptive=False, n=10) + do_test(s1, s2, {a: 0.5, b: 1.5}) + + # missing a parameter + raises(ValueError, + lambda : Parametric3DLineSeries(cos(x), sin(x), x, (x, a, b), + params={a: 0}, adaptive=False, n=10)) + + s1 = SurfaceOver2DRangeSeries(cos(x**2 + y**2), (x, -pi, pi), (y, -pi, pi), + adaptive=False, n1=5, n2=5) + s2 = SurfaceOver2DRangeSeries(cos(x**2 + y**2), (x, -pi * a, pi * a), + (y, -pi * b, pi * b), params={a: 1, b: 1}, + adaptive=False, n1=5, n2=5) + do_test(s1, s2, {a: 0.5, b: 1.5}) + + # missing a parameter + raises(ValueError, + lambda : SurfaceOver2DRangeSeries(cos(x**2 + y**2), + (x, -pi * a, pi * a), (y, -pi * b, pi * b), params={a: 1}, + adaptive=False, n1=5, n2=5)) + # one range symbol is included into another range's minimum or maximum val + raises(ValueError, + lambda : SurfaceOver2DRangeSeries(cos(x**2 + y**2), + (x, -pi * a + y, pi * a), (y, -pi * b, pi * b), params={a: 1}, + adaptive=False, n1=5, n2=5)) + + s1 = ParametricSurfaceSeries( + cos(x - y), sin(x + y), x - y, (x, -2, 2), (y, -2, 2), n1=5, n2=5) + s2 = ParametricSurfaceSeries( + cos(x - y), sin(x + y), x - y, (x, -2 * a, 2), (y, -2, 2 * b), + params={a: 1, b: 1}, n1=5, n2=5) + do_test(s1, s2, {a: 0.5, b: 1.5}) + + # missing a parameter + raises(ValueError, + lambda : ParametricSurfaceSeries( + cos(x - y), sin(x + y), x - y, (x, -2 * a, 2), (y, -2, 2 * b), + params={a: 1}, n1=5, n2=5)) + + +def test_exclude_points(): + # verify that exclude works as expected + if not np: + skip("numpy not installed.") + + x = symbols("x") + + expr = (floor(x) + S.Half) / (1 - (x - S.Half)**2) + + with warns( + UserWarning, + match="NumPy is unable to evaluate with complex numbers some", + test_stacklevel=False, + ): + s = LineOver1DRangeSeries(expr, (x, -3.5, 3.5), adaptive=False, n=100, + exclude=list(range(-3, 4))) + xx, yy = s.get_data() + assert not np.isnan(xx).any() + assert np.count_nonzero(np.isnan(yy)) == 7 + assert len(xx) > 100 + + e1 = log(floor(x)) * cos(x) + e2 = log(floor(x)) * sin(x) + with warns( + UserWarning, + match="NumPy is unable to evaluate with complex numbers some", + test_stacklevel=False, + ): + s = Parametric2DLineSeries(e1, e2, (x, 1, 12), adaptive=False, n=100, + exclude=list(range(1, 13))) + xx, yy, pp = s.get_data() + assert not np.isnan(pp).any() + assert np.count_nonzero(np.isnan(xx)) == 11 + assert np.count_nonzero(np.isnan(yy)) == 11 + assert len(xx) > 100 + + +def test_unwrap(): + # verify that unwrap works as expected + if not np: + skip("numpy not installed.") + + x, y = symbols("x, y") + expr = 1 / (x**3 + 2*x**2 + x) + expr = arg(expr.subs(x, I*y*2*pi)) + s1 = LineOver1DRangeSeries(expr, (y, 1e-05, 1e05), xscale="log", + adaptive=False, n=10, unwrap=False) + s2 = LineOver1DRangeSeries(expr, (y, 1e-05, 1e05), xscale="log", + adaptive=False, n=10, unwrap=True) + s3 = LineOver1DRangeSeries(expr, (y, 1e-05, 1e05), xscale="log", + adaptive=False, n=10, unwrap={"period": 4}) + x1, y1 = s1.get_data() + x2, y2 = s2.get_data() + x3, y3 = s3.get_data() + assert np.allclose(x1, x2) + # there must not be nan values in the results of these evaluations + assert all(not np.isnan(t).any() for t in [y1, y2, y3]) + assert not np.allclose(y1, y2) + assert not np.allclose(y1, y3) + assert not np.allclose(y2, y3) diff --git a/lib/python3.10/site-packages/sympy/plotting/tests/test_textplot.py b/lib/python3.10/site-packages/sympy/plotting/tests/test_textplot.py new file mode 100644 index 0000000000000000000000000000000000000000..928085c627e5230f2ac4a8ce0bbac5354ab35d51 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/plotting/tests/test_textplot.py @@ -0,0 +1,203 @@ +from sympy.core.singleton import S +from sympy.core.symbol import Symbol +from sympy.functions.elementary.exponential import log +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import sin +from sympy.plotting.textplot import textplot_str + +from sympy.utilities.exceptions import ignore_warnings + + +def test_axes_alignment(): + x = Symbol('x') + lines = [ + ' 1 | ..', + ' | ... ', + ' | .. ', + ' | ... ', + ' | ... ', + ' | .. ', + ' | ... ', + ' | ... ', + ' | .. ', + ' | ... ', + ' 0 |--------------------------...--------------------------', + ' | ... ', + ' | .. ', + ' | ... ', + ' | ... ', + ' | .. ', + ' | ... ', + ' | ... ', + ' | .. ', + ' | ... ', + ' -1 |_______________________________________________________', + ' -1 0 1' + ] + assert lines == list(textplot_str(x, -1, 1)) + + lines = [ + ' 1 | ..', + ' | .... ', + ' | ... ', + ' | ... ', + ' | .... ', + ' | ... ', + ' | ... ', + ' | .... ', + ' 0 |--------------------------...--------------------------', + ' | .... ', + ' | ... ', + ' | ... ', + ' | .... ', + ' | ... ', + ' | ... ', + ' | .... ', + ' -1 |_______________________________________________________', + ' -1 0 1' + ] + assert lines == list(textplot_str(x, -1, 1, H=17)) + + +def test_singularity(): + x = Symbol('x') + lines = [ + ' 54 | . ', + ' | ', + ' | ', + ' | ', + ' | ',' | ', + ' | ', + ' | ', + ' | ', + ' | ', + ' 27.5 |--.----------------------------------------------------', + ' | ', + ' | ', + ' | ', + ' | . ', + ' | \\ ', + ' | \\ ', + ' | .. ', + ' | ... ', + ' | ............. ', + ' 1 |_______________________________________________________', + ' 0 0.5 1' + ] + assert lines == list(textplot_str(1/x, 0, 1)) + + lines = [ + ' 0 | ......', + ' | ........ ', + ' | ........ ', + ' | ...... ', + ' | ..... ', + ' | .... ', + ' | ... ', + ' | .. ', + ' | ... ', + ' | / ', + ' -2 |-------..----------------------------------------------', + ' | / ', + ' | / ', + ' | / ', + ' | . ', + ' | ', + ' | . ', + ' | ', + ' | ', + ' | ', + ' -4 |_______________________________________________________', + ' 0 0.5 1' + ] + # RuntimeWarning: divide by zero encountered in log + with ignore_warnings(RuntimeWarning): + assert lines == list(textplot_str(log(x), 0, 1)) + + +def test_sinc(): + x = Symbol('x') + lines = [ + ' 1 | . . ', + ' | . . ', + ' | ', + ' | . . ', + ' | ', + ' | . . ', + ' | ', + ' | ', + ' | . . ', + ' | ', + ' 0.4 |-------------------------------------------------------', + ' | . . ', + ' | ', + ' | . . ', + ' | ', + ' | ..... ..... ', + ' | .. \\ . . / .. ', + ' | / \\ / \\ ', + ' |/ \\ . . / \\', + ' | \\ / \\ / ', + ' -0.2 |_______________________________________________________', + ' -10 0 10' + ] + # RuntimeWarning: invalid value encountered in double_scalars + with ignore_warnings(RuntimeWarning): + assert lines == list(textplot_str(sin(x)/x, -10, 10)) + + +def test_imaginary(): + x = Symbol('x') + lines = [ + ' 1 | ..', + ' | .. ', + ' | ... ', + ' | .. ', + ' | .. ', + ' | .. ', + ' | .. ', + ' | .. ', + ' | .. ', + ' | / ', + ' 0.5 |----------------------------------/--------------------', + ' | .. ', + ' | / ', + ' | . ', + ' | ', + ' | . ', + ' | . ', + ' | ', + ' | ', + ' | ', + ' 0 |_______________________________________________________', + ' -1 0 1' + ] + # RuntimeWarning: invalid value encountered in sqrt + with ignore_warnings(RuntimeWarning): + assert list(textplot_str(sqrt(x), -1, 1)) == lines + + lines = [ + ' 1 | ', + ' | ', + ' | ', + ' | ', + ' | ', + ' | ', + ' | ', + ' | ', + ' | ', + ' | ', + ' 0 |-------------------------------------------------------', + ' | ', + ' | ', + ' | ', + ' | ', + ' | ', + ' | ', + ' | ', + ' | ', + ' | ', + ' -1 |_______________________________________________________', + ' -1 0 1' + ] + assert list(textplot_str(S.ImaginaryUnit, -1, 1)) == lines diff --git a/lib/python3.10/site-packages/sympy/plotting/tests/test_utils.py b/lib/python3.10/site-packages/sympy/plotting/tests/test_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..4206a8b001319552c2e2be1aeb46057e6f708912 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/plotting/tests/test_utils.py @@ -0,0 +1,110 @@ +from pytest import raises +from sympy import ( + symbols, Expr, Tuple, Integer, cos, solveset, FiniteSet, ImageSet) +from sympy.plotting.utils import ( + _create_ranges, _plot_sympify, extract_solution) +from sympy.physics.mechanics import ReferenceFrame, Vector as MechVector +from sympy.vector import CoordSys3D, Vector + + +def test_plot_sympify(): + x, y = symbols("x, y") + + # argument is already sympified + args = x + y + r = _plot_sympify(args) + assert r == args + + # one argument needs to be sympified + args = (x + y, 1) + r = _plot_sympify(args) + assert isinstance(r, (list, tuple, Tuple)) and len(r) == 2 + assert isinstance(r[0], Expr) + assert isinstance(r[1], Integer) + + # string and dict should not be sympified + args = (x + y, (x, 0, 1), "str", 1, {1: 1, 2: 2.0}) + r = _plot_sympify(args) + assert isinstance(r, (list, tuple, Tuple)) and len(r) == 5 + assert isinstance(r[0], Expr) + assert isinstance(r[1], Tuple) + assert isinstance(r[2], str) + assert isinstance(r[3], Integer) + assert isinstance(r[4], dict) and isinstance(r[4][1], int) and isinstance(r[4][2], float) + + # nested arguments containing strings + args = ((x + y, (y, 0, 1), "a"), (x + 1, (x, 0, 1), "$f_{1}$")) + r = _plot_sympify(args) + assert isinstance(r, (list, tuple, Tuple)) and len(r) == 2 + assert isinstance(r[0], Tuple) + assert isinstance(r[0][1], Tuple) + assert isinstance(r[0][1][1], Integer) + assert isinstance(r[0][2], str) + assert isinstance(r[1], Tuple) + assert isinstance(r[1][1], Tuple) + assert isinstance(r[1][1][1], Integer) + assert isinstance(r[1][2], str) + + # vectors from sympy.physics.vectors module are not sympified + # vectors from sympy.vectors are sympified + # in both cases, no error should be raised + R = ReferenceFrame("R") + v1 = 2 * R.x + R.y + C = CoordSys3D("C") + v2 = 2 * C.i + C.j + args = (v1, v2) + r = _plot_sympify(args) + assert isinstance(r, (list, tuple, Tuple)) and len(r) == 2 + assert isinstance(v1, MechVector) + assert isinstance(v2, Vector) + + +def test_create_ranges(): + x, y = symbols("x, y") + + # user don't provide any range -> return a default range + r = _create_ranges({x}, [], 1) + assert isinstance(r, (list, tuple, Tuple)) and len(r) == 1 + assert isinstance(r[0], (Tuple, tuple)) + assert r[0] == (x, -10, 10) + + r = _create_ranges({x, y}, [], 2) + assert isinstance(r, (list, tuple, Tuple)) and len(r) == 2 + assert isinstance(r[0], (Tuple, tuple)) + assert isinstance(r[1], (Tuple, tuple)) + assert r[0] == (x, -10, 10) or (y, -10, 10) + assert r[1] == (y, -10, 10) or (x, -10, 10) + assert r[0] != r[1] + + # not enough ranges provided by the user -> create default ranges + r = _create_ranges( + {x, y}, + [ + (x, 0, 1), + ], + 2, + ) + assert isinstance(r, (list, tuple, Tuple)) and len(r) == 2 + assert isinstance(r[0], (Tuple, tuple)) + assert isinstance(r[1], (Tuple, tuple)) + assert r[0] == (x, 0, 1) or (y, -10, 10) + assert r[1] == (y, -10, 10) or (x, 0, 1) + assert r[0] != r[1] + + # too many free symbols + raises(ValueError, lambda: _create_ranges({x, y}, [], 1)) + raises(ValueError, lambda: _create_ranges({x, y}, [(x, 0, 5), (y, 0, 1)], 1)) + + +def test_extract_solution(): + x = symbols("x") + + sol = solveset(cos(10 * x)) + assert sol.has(ImageSet) + res = extract_solution(sol) + assert len(res) == 20 + assert isinstance(res, FiniteSet) + + res = extract_solution(sol, 20) + assert len(res) == 40 + assert isinstance(res, FiniteSet) diff --git a/lib/python3.10/site-packages/sympy/polys/domains/__init__.py b/lib/python3.10/site-packages/sympy/polys/domains/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c6839b4494afd0ee0c0ecd9ddee65d1afbdc6b53 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/polys/domains/__init__.py @@ -0,0 +1,57 @@ +"""Implementation of mathematical domains. """ + +__all__ = [ + 'Domain', 'FiniteField', 'IntegerRing', 'RationalField', 'RealField', + 'ComplexField', 'AlgebraicField', 'PolynomialRing', 'FractionField', + 'ExpressionDomain', 'PythonRational', + + 'GF', 'FF', 'ZZ', 'QQ', 'ZZ_I', 'QQ_I', 'RR', 'CC', 'EX', 'EXRAW', +] + +from .domain import Domain +from .finitefield import FiniteField, FF, GF +from .integerring import IntegerRing, ZZ +from .rationalfield import RationalField, QQ +from .algebraicfield import AlgebraicField +from .gaussiandomains import ZZ_I, QQ_I +from .realfield import RealField, RR +from .complexfield import ComplexField, CC +from .polynomialring import PolynomialRing +from .fractionfield import FractionField +from .expressiondomain import ExpressionDomain, EX +from .expressionrawdomain import EXRAW +from .pythonrational import PythonRational + + +# This is imported purely for backwards compatibility because some parts of +# the codebase used to import this from here and it's possible that downstream +# does as well: +from sympy.external.gmpy import GROUND_TYPES # noqa: F401 + +# +# The rest of these are obsolete and provided only for backwards +# compatibility: +# + +from .pythonfinitefield import PythonFiniteField +from .gmpyfinitefield import GMPYFiniteField +from .pythonintegerring import PythonIntegerRing +from .gmpyintegerring import GMPYIntegerRing +from .pythonrationalfield import PythonRationalField +from .gmpyrationalfield import GMPYRationalField + +FF_python = PythonFiniteField +FF_gmpy = GMPYFiniteField + +ZZ_python = PythonIntegerRing +ZZ_gmpy = GMPYIntegerRing + +QQ_python = PythonRationalField +QQ_gmpy = GMPYRationalField + +__all__.extend(( + 'PythonFiniteField', 'GMPYFiniteField', 'PythonIntegerRing', + 'GMPYIntegerRing', 'PythonRational', 'GMPYRationalField', + + 'FF_python', 'FF_gmpy', 'ZZ_python', 'ZZ_gmpy', 'QQ_python', 'QQ_gmpy', +)) diff --git a/lib/python3.10/site-packages/sympy/polys/domains/__pycache__/__init__.cpython-310.pyc b/lib/python3.10/site-packages/sympy/polys/domains/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a5e5cef67c89226d1d68579552009799b08516f5 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/polys/domains/__pycache__/__init__.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/polys/domains/__pycache__/algebraicfield.cpython-310.pyc b/lib/python3.10/site-packages/sympy/polys/domains/__pycache__/algebraicfield.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3e839d26e2db20e657aa67d5c8e9675b4f576c19 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/polys/domains/__pycache__/algebraicfield.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/polys/domains/__pycache__/characteristiczero.cpython-310.pyc b/lib/python3.10/site-packages/sympy/polys/domains/__pycache__/characteristiczero.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..38a8076ed64a97156512690e46eb0d9657c81fec Binary files /dev/null and b/lib/python3.10/site-packages/sympy/polys/domains/__pycache__/characteristiczero.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/polys/domains/__pycache__/complexfield.cpython-310.pyc b/lib/python3.10/site-packages/sympy/polys/domains/__pycache__/complexfield.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9d5c68d016fa1cb9ebd13af28ecab9866f9c1a4d Binary files /dev/null and b/lib/python3.10/site-packages/sympy/polys/domains/__pycache__/complexfield.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/polys/domains/__pycache__/gmpyintegerring.cpython-310.pyc b/lib/python3.10/site-packages/sympy/polys/domains/__pycache__/gmpyintegerring.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a865f318644ae6e9330cea7803e80bc17e67c29d Binary files /dev/null and b/lib/python3.10/site-packages/sympy/polys/domains/__pycache__/gmpyintegerring.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/polys/domains/__pycache__/gmpyrationalfield.cpython-310.pyc b/lib/python3.10/site-packages/sympy/polys/domains/__pycache__/gmpyrationalfield.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ff17c98f97511ed756598262e68973bdfeed438e Binary files /dev/null and b/lib/python3.10/site-packages/sympy/polys/domains/__pycache__/gmpyrationalfield.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/polys/domains/__pycache__/groundtypes.cpython-310.pyc b/lib/python3.10/site-packages/sympy/polys/domains/__pycache__/groundtypes.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..10213b4da5d41fa5c05b983f5c0215abe7f19a2d Binary files /dev/null and b/lib/python3.10/site-packages/sympy/polys/domains/__pycache__/groundtypes.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/polys/domains/__pycache__/integerring.cpython-310.pyc b/lib/python3.10/site-packages/sympy/polys/domains/__pycache__/integerring.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..31552772dbc09cdc0fd2d9207d1c6b250b9f8bbe Binary files /dev/null and b/lib/python3.10/site-packages/sympy/polys/domains/__pycache__/integerring.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/polys/domains/__pycache__/modularinteger.cpython-310.pyc b/lib/python3.10/site-packages/sympy/polys/domains/__pycache__/modularinteger.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..450b4abda8ab7fe0b50684ce9d220465b9d4df3f Binary files /dev/null and b/lib/python3.10/site-packages/sympy/polys/domains/__pycache__/modularinteger.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/polys/domains/__pycache__/mpelements.cpython-310.pyc b/lib/python3.10/site-packages/sympy/polys/domains/__pycache__/mpelements.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..aff2247cf2405c3d98cef663b0df53c22e67fae1 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/polys/domains/__pycache__/mpelements.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/polys/domains/__pycache__/old_fractionfield.cpython-310.pyc b/lib/python3.10/site-packages/sympy/polys/domains/__pycache__/old_fractionfield.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1ca566b1872ca9d22a87f8e1a367584c0966a6a9 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/polys/domains/__pycache__/old_fractionfield.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/polys/domains/__pycache__/old_polynomialring.cpython-310.pyc b/lib/python3.10/site-packages/sympy/polys/domains/__pycache__/old_polynomialring.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bb0d15c7a8fd9f5781567fc24cf8bc35a9be9298 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/polys/domains/__pycache__/old_polynomialring.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/polys/domains/__pycache__/polynomialring.cpython-310.pyc b/lib/python3.10/site-packages/sympy/polys/domains/__pycache__/polynomialring.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d8b487444e0b0c8658cb1c158aa5f14191443ac8 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/polys/domains/__pycache__/polynomialring.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/polys/domains/__pycache__/pythonfinitefield.cpython-310.pyc b/lib/python3.10/site-packages/sympy/polys/domains/__pycache__/pythonfinitefield.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6295bcc63b9d558d32b0244d70bc2290d6a5f89c Binary files /dev/null and b/lib/python3.10/site-packages/sympy/polys/domains/__pycache__/pythonfinitefield.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/polys/domains/__pycache__/pythonintegerring.cpython-310.pyc b/lib/python3.10/site-packages/sympy/polys/domains/__pycache__/pythonintegerring.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2e17f445bb3818c5c8278c2f1b4f54e3e2beb350 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/polys/domains/__pycache__/pythonintegerring.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/polys/domains/__pycache__/pythonrational.cpython-310.pyc b/lib/python3.10/site-packages/sympy/polys/domains/__pycache__/pythonrational.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4ed8a5eed049a21c4a9c898a28aa419e709f8f80 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/polys/domains/__pycache__/pythonrational.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/polys/domains/__pycache__/pythonrationalfield.cpython-310.pyc b/lib/python3.10/site-packages/sympy/polys/domains/__pycache__/pythonrationalfield.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b903d30149c64789e0b83d4888e2a4010e9b2cf5 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/polys/domains/__pycache__/pythonrationalfield.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/polys/domains/__pycache__/quotientring.cpython-310.pyc b/lib/python3.10/site-packages/sympy/polys/domains/__pycache__/quotientring.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9a20546378aa183354177d8ae4f7e72c8304aa20 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/polys/domains/__pycache__/quotientring.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/polys/domains/__pycache__/rationalfield.cpython-310.pyc b/lib/python3.10/site-packages/sympy/polys/domains/__pycache__/rationalfield.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..eb6b42ea0a34559d1514c766ad5f07a923a798ff Binary files /dev/null and b/lib/python3.10/site-packages/sympy/polys/domains/__pycache__/rationalfield.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/polys/domains/__pycache__/realfield.cpython-310.pyc b/lib/python3.10/site-packages/sympy/polys/domains/__pycache__/realfield.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8782c3a279f44e16e0bf418190490e36dfef190a Binary files /dev/null and b/lib/python3.10/site-packages/sympy/polys/domains/__pycache__/realfield.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/polys/domains/__pycache__/ring.cpython-310.pyc b/lib/python3.10/site-packages/sympy/polys/domains/__pycache__/ring.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9d274c1dc668489fd22077dacc10c4237b864818 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/polys/domains/__pycache__/ring.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/polys/domains/__pycache__/simpledomain.cpython-310.pyc b/lib/python3.10/site-packages/sympy/polys/domains/__pycache__/simpledomain.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ac69da17514438661bcc9e864f9f4899985fb5ec Binary files /dev/null and b/lib/python3.10/site-packages/sympy/polys/domains/__pycache__/simpledomain.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/polys/domains/characteristiczero.py b/lib/python3.10/site-packages/sympy/polys/domains/characteristiczero.py new file mode 100644 index 0000000000000000000000000000000000000000..755a354bea9594b9e8f73256c448b3debae037b2 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/polys/domains/characteristiczero.py @@ -0,0 +1,15 @@ +"""Implementation of :class:`CharacteristicZero` class. """ + + +from sympy.polys.domains.domain import Domain +from sympy.utilities import public + +@public +class CharacteristicZero(Domain): + """Domain that has infinite number of elements. """ + + has_CharacteristicZero = True + + def characteristic(self): + """Return the characteristic of this domain. """ + return 0 diff --git a/lib/python3.10/site-packages/sympy/polys/domains/complexfield.py b/lib/python3.10/site-packages/sympy/polys/domains/complexfield.py new file mode 100644 index 0000000000000000000000000000000000000000..4642b20249bee3db36123d0c15af064496673d50 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/polys/domains/complexfield.py @@ -0,0 +1,185 @@ +"""Implementation of :class:`ComplexField` class. """ + + +from sympy.external.gmpy import SYMPY_INTS +from sympy.core.numbers import Float, I +from sympy.polys.domains.characteristiczero import CharacteristicZero +from sympy.polys.domains.field import Field +from sympy.polys.domains.gaussiandomains import QQ_I +from sympy.polys.domains.mpelements import MPContext +from sympy.polys.domains.simpledomain import SimpleDomain +from sympy.polys.polyerrors import DomainError, CoercionFailed +from sympy.utilities import public + +@public +class ComplexField(Field, CharacteristicZero, SimpleDomain): + """Complex numbers up to the given precision. """ + + rep = 'CC' + + is_ComplexField = is_CC = True + + is_Exact = False + is_Numerical = True + + has_assoc_Ring = False + has_assoc_Field = True + + _default_precision = 53 + + @property + def has_default_precision(self): + return self.precision == self._default_precision + + @property + def precision(self): + return self._context.prec + + @property + def dps(self): + return self._context.dps + + @property + def tolerance(self): + return self._context.tolerance + + def __init__(self, prec=_default_precision, dps=None, tol=None): + context = MPContext(prec, dps, tol, False) + context._parent = self + self._context = context + + self._dtype = context.mpc + self.zero = self.dtype(0) + self.one = self.dtype(1) + + @property + def tp(self): + # XXX: Domain treats tp as an alis of dtype. Here we need to two + # separate things: dtype is a callable to make/convert instances. + # We use tp with isinstance to check if an object is an instance + # of the domain already. + return self._dtype + + def dtype(self, x, y=0): + # XXX: This is needed because mpmath does not recognise fmpz. + # It might be better to add conversion routines to mpmath and if that + # happens then this can be removed. + if isinstance(x, SYMPY_INTS): + x = int(x) + if isinstance(y, SYMPY_INTS): + y = int(y) + return self._dtype(x, y) + + def __eq__(self, other): + return (isinstance(other, ComplexField) + and self.precision == other.precision + and self.tolerance == other.tolerance) + + def __hash__(self): + return hash((self.__class__.__name__, self._dtype, self.precision, self.tolerance)) + + def to_sympy(self, element): + """Convert ``element`` to SymPy number. """ + return Float(element.real, self.dps) + I*Float(element.imag, self.dps) + + def from_sympy(self, expr): + """Convert SymPy's number to ``dtype``. """ + number = expr.evalf(n=self.dps) + real, imag = number.as_real_imag() + + if real.is_Number and imag.is_Number: + return self.dtype(real, imag) + else: + raise CoercionFailed("expected complex number, got %s" % expr) + + def from_ZZ(self, element, base): + return self.dtype(element) + + def from_ZZ_gmpy(self, element, base): + return self.dtype(int(element)) + + def from_ZZ_python(self, element, base): + return self.dtype(element) + + def from_QQ(self, element, base): + return self.dtype(int(element.numerator)) / int(element.denominator) + + def from_QQ_python(self, element, base): + return self.dtype(element.numerator) / element.denominator + + def from_QQ_gmpy(self, element, base): + return self.dtype(int(element.numerator)) / int(element.denominator) + + def from_GaussianIntegerRing(self, element, base): + return self.dtype(int(element.x), int(element.y)) + + def from_GaussianRationalField(self, element, base): + x = element.x + y = element.y + return (self.dtype(int(x.numerator)) / int(x.denominator) + + self.dtype(0, int(y.numerator)) / int(y.denominator)) + + def from_AlgebraicField(self, element, base): + return self.from_sympy(base.to_sympy(element).evalf(self.dps)) + + def from_RealField(self, element, base): + return self.dtype(element) + + def from_ComplexField(self, element, base): + if self == base: + return element + else: + return self.dtype(element) + + def get_ring(self): + """Returns a ring associated with ``self``. """ + raise DomainError("there is no ring associated with %s" % self) + + def get_exact(self): + """Returns an exact domain associated with ``self``. """ + return QQ_I + + def is_negative(self, element): + """Returns ``False`` for any ``ComplexElement``. """ + return False + + def is_positive(self, element): + """Returns ``False`` for any ``ComplexElement``. """ + return False + + def is_nonnegative(self, element): + """Returns ``False`` for any ``ComplexElement``. """ + return False + + def is_nonpositive(self, element): + """Returns ``False`` for any ``ComplexElement``. """ + return False + + def gcd(self, a, b): + """Returns GCD of ``a`` and ``b``. """ + return self.one + + def lcm(self, a, b): + """Returns LCM of ``a`` and ``b``. """ + return a*b + + def almosteq(self, a, b, tolerance=None): + """Check if ``a`` and ``b`` are almost equal. """ + return self._context.almosteq(a, b, tolerance) + + def is_square(self, a): + """Returns ``True``. Every complex number has a complex square root.""" + return True + + def exsqrt(self, a): + r"""Returns the principal complex square root of ``a``. + + Explanation + =========== + The argument of the principal square root is always within + $(-\frac{\pi}{2}, \frac{\pi}{2}]$. The square root may be + slightly inaccurate due to floating point rounding error. + """ + return a ** 0.5 + +CC = ComplexField() diff --git a/lib/python3.10/site-packages/sympy/polys/domains/compositedomain.py b/lib/python3.10/site-packages/sympy/polys/domains/compositedomain.py new file mode 100644 index 0000000000000000000000000000000000000000..a8f63ba7bb86b1d69493b77bfa8c7f33652adbbf --- /dev/null +++ b/lib/python3.10/site-packages/sympy/polys/domains/compositedomain.py @@ -0,0 +1,52 @@ +"""Implementation of :class:`CompositeDomain` class. """ + + +from sympy.polys.domains.domain import Domain +from sympy.polys.polyerrors import GeneratorsError + +from sympy.utilities import public + +@public +class CompositeDomain(Domain): + """Base class for composite domains, e.g. ZZ[x], ZZ(X). """ + + is_Composite = True + + gens, ngens, symbols, domain = [None]*4 + + def inject(self, *symbols): + """Inject generators into this domain. """ + if not (set(self.symbols) & set(symbols)): + return self.__class__(self.domain, self.symbols + symbols, self.order) + else: + raise GeneratorsError("common generators in %s and %s" % (self.symbols, symbols)) + + def drop(self, *symbols): + """Drop generators from this domain. """ + symset = set(symbols) + newsyms = tuple(s for s in self.symbols if s not in symset) + domain = self.domain.drop(*symbols) + if not newsyms: + return domain + else: + return self.__class__(domain, newsyms, self.order) + + def set_domain(self, domain): + """Set the ground domain of this domain. """ + return self.__class__(domain, self.symbols, self.order) + + @property + def is_Exact(self): + """Returns ``True`` if this domain is exact. """ + return self.domain.is_Exact + + def get_exact(self): + """Returns an exact version of this domain. """ + return self.set_domain(self.domain.get_exact()) + + @property + def has_CharacteristicZero(self): + return self.domain.has_CharacteristicZero + + def characteristic(self): + return self.domain.characteristic() diff --git a/lib/python3.10/site-packages/sympy/polys/domains/domain.py b/lib/python3.10/site-packages/sympy/polys/domains/domain.py new file mode 100644 index 0000000000000000000000000000000000000000..a1136d35b13c0fe1d318b957a42a9ef82cdd15dd --- /dev/null +++ b/lib/python3.10/site-packages/sympy/polys/domains/domain.py @@ -0,0 +1,1372 @@ +"""Implementation of :class:`Domain` class. """ + +from __future__ import annotations +from typing import Any + +from sympy.core.numbers import AlgebraicNumber +from sympy.core import Basic, sympify +from sympy.core.sorting import ordered +from sympy.external.gmpy import GROUND_TYPES +from sympy.polys.domains.domainelement import DomainElement +from sympy.polys.orderings import lex +from sympy.polys.polyerrors import UnificationFailed, CoercionFailed, DomainError +from sympy.polys.polyutils import _unify_gens, _not_a_coeff +from sympy.utilities import public +from sympy.utilities.iterables import is_sequence + + +@public +class Domain: + """Superclass for all domains in the polys domains system. + + See :ref:`polys-domainsintro` for an introductory explanation of the + domains system. + + The :py:class:`~.Domain` class is an abstract base class for all of the + concrete domain types. There are many different :py:class:`~.Domain` + subclasses each of which has an associated ``dtype`` which is a class + representing the elements of the domain. The coefficients of a + :py:class:`~.Poly` are elements of a domain which must be a subclass of + :py:class:`~.Domain`. + + Examples + ======== + + The most common example domains are the integers :ref:`ZZ` and the + rationals :ref:`QQ`. + + >>> from sympy import Poly, symbols, Domain + >>> x, y = symbols('x, y') + >>> p = Poly(x**2 + y) + >>> p + Poly(x**2 + y, x, y, domain='ZZ') + >>> p.domain + ZZ + >>> isinstance(p.domain, Domain) + True + >>> Poly(x**2 + y/2) + Poly(x**2 + 1/2*y, x, y, domain='QQ') + + The domains can be used directly in which case the domain object e.g. + (:ref:`ZZ` or :ref:`QQ`) can be used as a constructor for elements of + ``dtype``. + + >>> from sympy import ZZ, QQ + >>> ZZ(2) + 2 + >>> ZZ.dtype # doctest: +SKIP + + >>> type(ZZ(2)) # doctest: +SKIP + + >>> QQ(1, 2) + 1/2 + >>> type(QQ(1, 2)) # doctest: +SKIP + + + The corresponding domain elements can be used with the arithmetic + operations ``+,-,*,**`` and depending on the domain some combination of + ``/,//,%`` might be usable. For example in :ref:`ZZ` both ``//`` (floor + division) and ``%`` (modulo division) can be used but ``/`` (true + division) cannot. Since :ref:`QQ` is a :py:class:`~.Field` its elements + can be used with ``/`` but ``//`` and ``%`` should not be used. Some + domains have a :py:meth:`~.Domain.gcd` method. + + >>> ZZ(2) + ZZ(3) + 5 + >>> ZZ(5) // ZZ(2) + 2 + >>> ZZ(5) % ZZ(2) + 1 + >>> QQ(1, 2) / QQ(2, 3) + 3/4 + >>> ZZ.gcd(ZZ(4), ZZ(2)) + 2 + >>> QQ.gcd(QQ(2,7), QQ(5,3)) + 1/21 + >>> ZZ.is_Field + False + >>> QQ.is_Field + True + + There are also many other domains including: + + 1. :ref:`GF(p)` for finite fields of prime order. + 2. :ref:`RR` for real (floating point) numbers. + 3. :ref:`CC` for complex (floating point) numbers. + 4. :ref:`QQ(a)` for algebraic number fields. + 5. :ref:`K[x]` for polynomial rings. + 6. :ref:`K(x)` for rational function fields. + 7. :ref:`EX` for arbitrary expressions. + + Each domain is represented by a domain object and also an implementation + class (``dtype``) for the elements of the domain. For example the + :ref:`K[x]` domains are represented by a domain object which is an + instance of :py:class:`~.PolynomialRing` and the elements are always + instances of :py:class:`~.PolyElement`. The implementation class + represents particular types of mathematical expressions in a way that is + more efficient than a normal SymPy expression which is of type + :py:class:`~.Expr`. The domain methods :py:meth:`~.Domain.from_sympy` and + :py:meth:`~.Domain.to_sympy` are used to convert from :py:class:`~.Expr` + to a domain element and vice versa. + + >>> from sympy import Symbol, ZZ, Expr + >>> x = Symbol('x') + >>> K = ZZ[x] # polynomial ring domain + >>> K + ZZ[x] + >>> type(K) # class of the domain + + >>> K.dtype # class of the elements + + >>> p_expr = x**2 + 1 # Expr + >>> p_expr + x**2 + 1 + >>> type(p_expr) + + >>> isinstance(p_expr, Expr) + True + >>> p_domain = K.from_sympy(p_expr) + >>> p_domain # domain element + x**2 + 1 + >>> type(p_domain) + + >>> K.to_sympy(p_domain) == p_expr + True + + The :py:meth:`~.Domain.convert_from` method is used to convert domain + elements from one domain to another. + + >>> from sympy import ZZ, QQ + >>> ez = ZZ(2) + >>> eq = QQ.convert_from(ez, ZZ) + >>> type(ez) # doctest: +SKIP + + >>> type(eq) # doctest: +SKIP + + + Elements from different domains should not be mixed in arithmetic or other + operations: they should be converted to a common domain first. The domain + method :py:meth:`~.Domain.unify` is used to find a domain that can + represent all the elements of two given domains. + + >>> from sympy import ZZ, QQ, symbols + >>> x, y = symbols('x, y') + >>> ZZ.unify(QQ) + QQ + >>> ZZ[x].unify(QQ) + QQ[x] + >>> ZZ[x].unify(QQ[y]) + QQ[x,y] + + If a domain is a :py:class:`~.Ring` then is might have an associated + :py:class:`~.Field` and vice versa. The :py:meth:`~.Domain.get_field` and + :py:meth:`~.Domain.get_ring` methods will find or create the associated + domain. + + >>> from sympy import ZZ, QQ, Symbol + >>> x = Symbol('x') + >>> ZZ.has_assoc_Field + True + >>> ZZ.get_field() + QQ + >>> QQ.has_assoc_Ring + True + >>> QQ.get_ring() + ZZ + >>> K = QQ[x] + >>> K + QQ[x] + >>> K.get_field() + QQ(x) + + See also + ======== + + DomainElement: abstract base class for domain elements + construct_domain: construct a minimal domain for some expressions + + """ + + dtype: type | None = None + """The type (class) of the elements of this :py:class:`~.Domain`: + + >>> from sympy import ZZ, QQ, Symbol + >>> ZZ.dtype + + >>> z = ZZ(2) + >>> z + 2 + >>> type(z) + + >>> type(z) == ZZ.dtype + True + + Every domain has an associated **dtype** ("datatype") which is the + class of the associated domain elements. + + See also + ======== + + of_type + """ + + zero: Any = None + """The zero element of the :py:class:`~.Domain`: + + >>> from sympy import QQ + >>> QQ.zero + 0 + >>> QQ.of_type(QQ.zero) + True + + See also + ======== + + of_type + one + """ + + one: Any = None + """The one element of the :py:class:`~.Domain`: + + >>> from sympy import QQ + >>> QQ.one + 1 + >>> QQ.of_type(QQ.one) + True + + See also + ======== + + of_type + zero + """ + + is_Ring = False + """Boolean flag indicating if the domain is a :py:class:`~.Ring`. + + >>> from sympy import ZZ + >>> ZZ.is_Ring + True + + Basically every :py:class:`~.Domain` represents a ring so this flag is + not that useful. + + See also + ======== + + is_PID + is_Field + get_ring + has_assoc_Ring + """ + + is_Field = False + """Boolean flag indicating if the domain is a :py:class:`~.Field`. + + >>> from sympy import ZZ, QQ + >>> ZZ.is_Field + False + >>> QQ.is_Field + True + + See also + ======== + + is_PID + is_Ring + get_field + has_assoc_Field + """ + + has_assoc_Ring = False + """Boolean flag indicating if the domain has an associated + :py:class:`~.Ring`. + + >>> from sympy import QQ + >>> QQ.has_assoc_Ring + True + >>> QQ.get_ring() + ZZ + + See also + ======== + + is_Field + get_ring + """ + + has_assoc_Field = False + """Boolean flag indicating if the domain has an associated + :py:class:`~.Field`. + + >>> from sympy import ZZ + >>> ZZ.has_assoc_Field + True + >>> ZZ.get_field() + QQ + + See also + ======== + + is_Field + get_field + """ + + is_FiniteField = is_FF = False + is_IntegerRing = is_ZZ = False + is_RationalField = is_QQ = False + is_GaussianRing = is_ZZ_I = False + is_GaussianField = is_QQ_I = False + is_RealField = is_RR = False + is_ComplexField = is_CC = False + is_AlgebraicField = is_Algebraic = False + is_PolynomialRing = is_Poly = False + is_FractionField = is_Frac = False + is_SymbolicDomain = is_EX = False + is_SymbolicRawDomain = is_EXRAW = False + is_FiniteExtension = False + + is_Exact = True + is_Numerical = False + + is_Simple = False + is_Composite = False + + is_PID = False + """Boolean flag indicating if the domain is a `principal ideal domain`_. + + >>> from sympy import ZZ + >>> ZZ.has_assoc_Field + True + >>> ZZ.get_field() + QQ + + .. _principal ideal domain: https://en.wikipedia.org/wiki/Principal_ideal_domain + + See also + ======== + + is_Field + get_field + """ + + has_CharacteristicZero = False + + rep: str | None = None + alias: str | None = None + + def __init__(self): + raise NotImplementedError + + def __str__(self): + return self.rep + + def __repr__(self): + return str(self) + + def __hash__(self): + return hash((self.__class__.__name__, self.dtype)) + + def new(self, *args): + return self.dtype(*args) + + @property + def tp(self): + """Alias for :py:attr:`~.Domain.dtype`""" + return self.dtype + + def __call__(self, *args): + """Construct an element of ``self`` domain from ``args``. """ + return self.new(*args) + + def normal(self, *args): + return self.dtype(*args) + + def convert_from(self, element, base): + """Convert ``element`` to ``self.dtype`` given the base domain. """ + if base.alias is not None: + method = "from_" + base.alias + else: + method = "from_" + base.__class__.__name__ + + _convert = getattr(self, method) + + if _convert is not None: + result = _convert(element, base) + + if result is not None: + return result + + raise CoercionFailed("Cannot convert %s of type %s from %s to %s" % (element, type(element), base, self)) + + def convert(self, element, base=None): + """Convert ``element`` to ``self.dtype``. """ + + if base is not None: + if _not_a_coeff(element): + raise CoercionFailed('%s is not in any domain' % element) + return self.convert_from(element, base) + + if self.of_type(element): + return element + + if _not_a_coeff(element): + raise CoercionFailed('%s is not in any domain' % element) + + from sympy.polys.domains import ZZ, QQ, RealField, ComplexField + + if ZZ.of_type(element): + return self.convert_from(element, ZZ) + + if isinstance(element, int): + return self.convert_from(ZZ(element), ZZ) + + if GROUND_TYPES != 'python': + if isinstance(element, ZZ.tp): + return self.convert_from(element, ZZ) + if isinstance(element, QQ.tp): + return self.convert_from(element, QQ) + + if isinstance(element, float): + parent = RealField(tol=False) + return self.convert_from(parent(element), parent) + + if isinstance(element, complex): + parent = ComplexField(tol=False) + return self.convert_from(parent(element), parent) + + if isinstance(element, DomainElement): + return self.convert_from(element, element.parent()) + + # TODO: implement this in from_ methods + if self.is_Numerical and getattr(element, 'is_ground', False): + return self.convert(element.LC()) + + if isinstance(element, Basic): + try: + return self.from_sympy(element) + except (TypeError, ValueError): + pass + else: # TODO: remove this branch + if not is_sequence(element): + try: + element = sympify(element, strict=True) + if isinstance(element, Basic): + return self.from_sympy(element) + except (TypeError, ValueError): + pass + + raise CoercionFailed("Cannot convert %s of type %s to %s" % (element, type(element), self)) + + def of_type(self, element): + """Check if ``a`` is of type ``dtype``. """ + return isinstance(element, self.tp) # XXX: this isn't correct, e.g. PolyElement + + def __contains__(self, a): + """Check if ``a`` belongs to this domain. """ + try: + if _not_a_coeff(a): + raise CoercionFailed + self.convert(a) # this might raise, too + except CoercionFailed: + return False + + return True + + def to_sympy(self, a): + """Convert domain element *a* to a SymPy expression (Expr). + + Explanation + =========== + + Convert a :py:class:`~.Domain` element *a* to :py:class:`~.Expr`. Most + public SymPy functions work with objects of type :py:class:`~.Expr`. + The elements of a :py:class:`~.Domain` have a different internal + representation. It is not possible to mix domain elements with + :py:class:`~.Expr` so each domain has :py:meth:`~.Domain.to_sympy` and + :py:meth:`~.Domain.from_sympy` methods to convert its domain elements + to and from :py:class:`~.Expr`. + + Parameters + ========== + + a: domain element + An element of this :py:class:`~.Domain`. + + Returns + ======= + + expr: Expr + A normal SymPy expression of type :py:class:`~.Expr`. + + Examples + ======== + + Construct an element of the :ref:`QQ` domain and then convert it to + :py:class:`~.Expr`. + + >>> from sympy import QQ, Expr + >>> q_domain = QQ(2) + >>> q_domain + 2 + >>> q_expr = QQ.to_sympy(q_domain) + >>> q_expr + 2 + + Although the printed forms look similar these objects are not of the + same type. + + >>> isinstance(q_domain, Expr) + False + >>> isinstance(q_expr, Expr) + True + + Construct an element of :ref:`K[x]` and convert to + :py:class:`~.Expr`. + + >>> from sympy import Symbol + >>> x = Symbol('x') + >>> K = QQ[x] + >>> x_domain = K.gens[0] # generator x as a domain element + >>> p_domain = x_domain**2/3 + 1 + >>> p_domain + 1/3*x**2 + 1 + >>> p_expr = K.to_sympy(p_domain) + >>> p_expr + x**2/3 + 1 + + The :py:meth:`~.Domain.from_sympy` method is used for the opposite + conversion from a normal SymPy expression to a domain element. + + >>> p_domain == p_expr + False + >>> K.from_sympy(p_expr) == p_domain + True + >>> K.to_sympy(p_domain) == p_expr + True + >>> K.from_sympy(K.to_sympy(p_domain)) == p_domain + True + >>> K.to_sympy(K.from_sympy(p_expr)) == p_expr + True + + The :py:meth:`~.Domain.from_sympy` method makes it easier to construct + domain elements interactively. + + >>> from sympy import Symbol + >>> x = Symbol('x') + >>> K = QQ[x] + >>> K.from_sympy(x**2/3 + 1) + 1/3*x**2 + 1 + + See also + ======== + + from_sympy + convert_from + """ + raise NotImplementedError + + def from_sympy(self, a): + """Convert a SymPy expression to an element of this domain. + + Explanation + =========== + + See :py:meth:`~.Domain.to_sympy` for explanation and examples. + + Parameters + ========== + + expr: Expr + A normal SymPy expression of type :py:class:`~.Expr`. + + Returns + ======= + + a: domain element + An element of this :py:class:`~.Domain`. + + See also + ======== + + to_sympy + convert_from + """ + raise NotImplementedError + + def sum(self, args): + return sum(args, start=self.zero) + + def from_FF(K1, a, K0): + """Convert ``ModularInteger(int)`` to ``dtype``. """ + return None + + def from_FF_python(K1, a, K0): + """Convert ``ModularInteger(int)`` to ``dtype``. """ + return None + + def from_ZZ_python(K1, a, K0): + """Convert a Python ``int`` object to ``dtype``. """ + return None + + def from_QQ_python(K1, a, K0): + """Convert a Python ``Fraction`` object to ``dtype``. """ + return None + + def from_FF_gmpy(K1, a, K0): + """Convert ``ModularInteger(mpz)`` to ``dtype``. """ + return None + + def from_ZZ_gmpy(K1, a, K0): + """Convert a GMPY ``mpz`` object to ``dtype``. """ + return None + + def from_QQ_gmpy(K1, a, K0): + """Convert a GMPY ``mpq`` object to ``dtype``. """ + return None + + def from_RealField(K1, a, K0): + """Convert a real element object to ``dtype``. """ + return None + + def from_ComplexField(K1, a, K0): + """Convert a complex element to ``dtype``. """ + return None + + def from_AlgebraicField(K1, a, K0): + """Convert an algebraic number to ``dtype``. """ + return None + + def from_PolynomialRing(K1, a, K0): + """Convert a polynomial to ``dtype``. """ + if a.is_ground: + return K1.convert(a.LC, K0.dom) + + def from_FractionField(K1, a, K0): + """Convert a rational function to ``dtype``. """ + return None + + def from_MonogenicFiniteExtension(K1, a, K0): + """Convert an ``ExtensionElement`` to ``dtype``. """ + return K1.convert_from(a.rep, K0.ring) + + def from_ExpressionDomain(K1, a, K0): + """Convert a ``EX`` object to ``dtype``. """ + return K1.from_sympy(a.ex) + + def from_ExpressionRawDomain(K1, a, K0): + """Convert a ``EX`` object to ``dtype``. """ + return K1.from_sympy(a) + + def from_GlobalPolynomialRing(K1, a, K0): + """Convert a polynomial to ``dtype``. """ + if a.degree() <= 0: + return K1.convert(a.LC(), K0.dom) + + def from_GeneralizedPolynomialRing(K1, a, K0): + return K1.from_FractionField(a, K0) + + def unify_with_symbols(K0, K1, symbols): + if (K0.is_Composite and (set(K0.symbols) & set(symbols))) or (K1.is_Composite and (set(K1.symbols) & set(symbols))): + raise UnificationFailed("Cannot unify %s with %s, given %s generators" % (K0, K1, tuple(symbols))) + + return K0.unify(K1) + + def unify_composite(K0, K1): + """Unify two domains where at least one is composite.""" + K0_ground = K0.dom if K0.is_Composite else K0 + K1_ground = K1.dom if K1.is_Composite else K1 + + K0_symbols = K0.symbols if K0.is_Composite else () + K1_symbols = K1.symbols if K1.is_Composite else () + + domain = K0_ground.unify(K1_ground) + symbols = _unify_gens(K0_symbols, K1_symbols) + order = K0.order if K0.is_Composite else K1.order + + # E.g. ZZ[x].unify(QQ.frac_field(x)) -> ZZ.frac_field(x) + if ((K0.is_FractionField and K1.is_PolynomialRing or + K1.is_FractionField and K0.is_PolynomialRing) and + (not K0_ground.is_Field or not K1_ground.is_Field) and domain.is_Field + and domain.has_assoc_Ring): + domain = domain.get_ring() + + if K0.is_Composite and (not K1.is_Composite or K0.is_FractionField or K1.is_PolynomialRing): + cls = K0.__class__ + else: + cls = K1.__class__ + + # Here cls might be PolynomialRing, FractionField, GlobalPolynomialRing + # (dense/old Polynomialring) or dense/old FractionField. + + from sympy.polys.domains.old_polynomialring import GlobalPolynomialRing + if cls == GlobalPolynomialRing: + return cls(domain, symbols) + + return cls(domain, symbols, order) + + def unify(K0, K1, symbols=None): + """ + Construct a minimal domain that contains elements of ``K0`` and ``K1``. + + Known domains (from smallest to largest): + + - ``GF(p)`` + - ``ZZ`` + - ``QQ`` + - ``RR(prec, tol)`` + - ``CC(prec, tol)`` + - ``ALG(a, b, c)`` + - ``K[x, y, z]`` + - ``K(x, y, z)`` + - ``EX`` + + """ + if symbols is not None: + return K0.unify_with_symbols(K1, symbols) + + if K0 == K1: + return K0 + + if not (K0.has_CharacteristicZero and K1.has_CharacteristicZero): + # Reject unification of domains with different characteristics. + if K0.characteristic() != K1.characteristic(): + raise UnificationFailed("Cannot unify %s with %s" % (K0, K1)) + + # We do not get here if K0 == K1. The two domains have the same + # characteristic but are unequal so at least one is composite and + # we are unifying something like GF(3).unify(GF(3)[x]). + return K0.unify_composite(K1) + + # From here we know both domains have characteristic zero and it can be + # acceptable to fall back on EX. + + if K0.is_EXRAW: + return K0 + if K1.is_EXRAW: + return K1 + + if K0.is_EX: + return K0 + if K1.is_EX: + return K1 + + if K0.is_FiniteExtension or K1.is_FiniteExtension: + if K1.is_FiniteExtension: + K0, K1 = K1, K0 + if K1.is_FiniteExtension: + # Unifying two extensions. + # Try to ensure that K0.unify(K1) == K1.unify(K0) + if list(ordered([K0.modulus, K1.modulus]))[1] == K0.modulus: + K0, K1 = K1, K0 + return K1.set_domain(K0) + else: + # Drop the generator from other and unify with the base domain + K1 = K1.drop(K0.symbol) + K1 = K0.domain.unify(K1) + return K0.set_domain(K1) + + if K0.is_Composite or K1.is_Composite: + return K0.unify_composite(K1) + + def mkinexact(cls, K0, K1): + prec = max(K0.precision, K1.precision) + tol = max(K0.tolerance, K1.tolerance) + return cls(prec=prec, tol=tol) + + if K1.is_ComplexField: + K0, K1 = K1, K0 + if K0.is_ComplexField: + if K1.is_ComplexField or K1.is_RealField: + return mkinexact(K0.__class__, K0, K1) + else: + return K0 + + if K1.is_RealField: + K0, K1 = K1, K0 + if K0.is_RealField: + if K1.is_RealField: + return mkinexact(K0.__class__, K0, K1) + elif K1.is_GaussianRing or K1.is_GaussianField: + from sympy.polys.domains.complexfield import ComplexField + return ComplexField(prec=K0.precision, tol=K0.tolerance) + else: + return K0 + + if K1.is_AlgebraicField: + K0, K1 = K1, K0 + if K0.is_AlgebraicField: + if K1.is_GaussianRing: + K1 = K1.get_field() + if K1.is_GaussianField: + K1 = K1.as_AlgebraicField() + if K1.is_AlgebraicField: + return K0.__class__(K0.dom.unify(K1.dom), *_unify_gens(K0.orig_ext, K1.orig_ext)) + else: + return K0 + + if K0.is_GaussianField: + return K0 + if K1.is_GaussianField: + return K1 + + if K0.is_GaussianRing: + if K1.is_RationalField: + K0 = K0.get_field() + return K0 + if K1.is_GaussianRing: + if K0.is_RationalField: + K1 = K1.get_field() + return K1 + + if K0.is_RationalField: + return K0 + if K1.is_RationalField: + return K1 + + if K0.is_IntegerRing: + return K0 + if K1.is_IntegerRing: + return K1 + + from sympy.polys.domains import EX + return EX + + def __eq__(self, other): + """Returns ``True`` if two domains are equivalent. """ + # XXX: Remove this. + return isinstance(other, Domain) and self.dtype == other.dtype + + def __ne__(self, other): + """Returns ``False`` if two domains are equivalent. """ + return not self == other + + def map(self, seq): + """Rersively apply ``self`` to all elements of ``seq``. """ + result = [] + + for elt in seq: + if isinstance(elt, list): + result.append(self.map(elt)) + else: + result.append(self(elt)) + + return result + + def get_ring(self): + """Returns a ring associated with ``self``. """ + raise DomainError('there is no ring associated with %s' % self) + + def get_field(self): + """Returns a field associated with ``self``. """ + raise DomainError('there is no field associated with %s' % self) + + def get_exact(self): + """Returns an exact domain associated with ``self``. """ + return self + + def __getitem__(self, symbols): + """The mathematical way to make a polynomial ring. """ + if hasattr(symbols, '__iter__'): + return self.poly_ring(*symbols) + else: + return self.poly_ring(symbols) + + def poly_ring(self, *symbols, order=lex): + """Returns a polynomial ring, i.e. `K[X]`. """ + from sympy.polys.domains.polynomialring import PolynomialRing + return PolynomialRing(self, symbols, order) + + def frac_field(self, *symbols, order=lex): + """Returns a fraction field, i.e. `K(X)`. """ + from sympy.polys.domains.fractionfield import FractionField + return FractionField(self, symbols, order) + + def old_poly_ring(self, *symbols, **kwargs): + """Returns a polynomial ring, i.e. `K[X]`. """ + from sympy.polys.domains.old_polynomialring import PolynomialRing + return PolynomialRing(self, *symbols, **kwargs) + + def old_frac_field(self, *symbols, **kwargs): + """Returns a fraction field, i.e. `K(X)`. """ + from sympy.polys.domains.old_fractionfield import FractionField + return FractionField(self, *symbols, **kwargs) + + def algebraic_field(self, *extension, alias=None): + r"""Returns an algebraic field, i.e. `K(\alpha, \ldots)`. """ + raise DomainError("Cannot create algebraic field over %s" % self) + + def alg_field_from_poly(self, poly, alias=None, root_index=-1): + r""" + Convenience method to construct an algebraic extension on a root of a + polynomial, chosen by root index. + + Parameters + ========== + + poly : :py:class:`~.Poly` + The polynomial whose root generates the extension. + alias : str, optional (default=None) + Symbol name for the generator of the extension. + E.g. "alpha" or "theta". + root_index : int, optional (default=-1) + Specifies which root of the polynomial is desired. The ordering is + as defined by the :py:class:`~.ComplexRootOf` class. The default of + ``-1`` selects the most natural choice in the common cases of + quadratic and cyclotomic fields (the square root on the positive + real or imaginary axis, resp. $\mathrm{e}^{2\pi i/n}$). + + Examples + ======== + + >>> from sympy import QQ, Poly + >>> from sympy.abc import x + >>> f = Poly(x**2 - 2) + >>> K = QQ.alg_field_from_poly(f) + >>> K.ext.minpoly == f + True + >>> g = Poly(8*x**3 - 6*x - 1) + >>> L = QQ.alg_field_from_poly(g, "alpha") + >>> L.ext.minpoly == g + True + >>> L.to_sympy(L([1, 1, 1])) + alpha**2 + alpha + 1 + + """ + from sympy.polys.rootoftools import CRootOf + root = CRootOf(poly, root_index) + alpha = AlgebraicNumber(root, alias=alias) + return self.algebraic_field(alpha, alias=alias) + + def cyclotomic_field(self, n, ss=False, alias="zeta", gen=None, root_index=-1): + r""" + Convenience method to construct a cyclotomic field. + + Parameters + ========== + + n : int + Construct the nth cyclotomic field. + ss : boolean, optional (default=False) + If True, append *n* as a subscript on the alias string. + alias : str, optional (default="zeta") + Symbol name for the generator. + gen : :py:class:`~.Symbol`, optional (default=None) + Desired variable for the cyclotomic polynomial that defines the + field. If ``None``, a dummy variable will be used. + root_index : int, optional (default=-1) + Specifies which root of the polynomial is desired. The ordering is + as defined by the :py:class:`~.ComplexRootOf` class. The default of + ``-1`` selects the root $\mathrm{e}^{2\pi i/n}$. + + Examples + ======== + + >>> from sympy import QQ, latex + >>> K = QQ.cyclotomic_field(5) + >>> K.to_sympy(K([-1, 1])) + 1 - zeta + >>> L = QQ.cyclotomic_field(7, True) + >>> a = L.to_sympy(L([-1, 1])) + >>> print(a) + 1 - zeta7 + >>> print(latex(a)) + 1 - \zeta_{7} + + """ + from sympy.polys.specialpolys import cyclotomic_poly + if ss: + alias += str(n) + return self.alg_field_from_poly(cyclotomic_poly(n, gen), alias=alias, + root_index=root_index) + + def inject(self, *symbols): + """Inject generators into this domain. """ + raise NotImplementedError + + def drop(self, *symbols): + """Drop generators from this domain. """ + if self.is_Simple: + return self + raise NotImplementedError # pragma: no cover + + def is_zero(self, a): + """Returns True if ``a`` is zero. """ + return not a + + def is_one(self, a): + """Returns True if ``a`` is one. """ + return a == self.one + + def is_positive(self, a): + """Returns True if ``a`` is positive. """ + return a > 0 + + def is_negative(self, a): + """Returns True if ``a`` is negative. """ + return a < 0 + + def is_nonpositive(self, a): + """Returns True if ``a`` is non-positive. """ + return a <= 0 + + def is_nonnegative(self, a): + """Returns True if ``a`` is non-negative. """ + return a >= 0 + + def canonical_unit(self, a): + if self.is_negative(a): + return -self.one + else: + return self.one + + def abs(self, a): + """Absolute value of ``a``, implies ``__abs__``. """ + return abs(a) + + def neg(self, a): + """Returns ``a`` negated, implies ``__neg__``. """ + return -a + + def pos(self, a): + """Returns ``a`` positive, implies ``__pos__``. """ + return +a + + def add(self, a, b): + """Sum of ``a`` and ``b``, implies ``__add__``. """ + return a + b + + def sub(self, a, b): + """Difference of ``a`` and ``b``, implies ``__sub__``. """ + return a - b + + def mul(self, a, b): + """Product of ``a`` and ``b``, implies ``__mul__``. """ + return a * b + + def pow(self, a, b): + """Raise ``a`` to power ``b``, implies ``__pow__``. """ + return a ** b + + def exquo(self, a, b): + """Exact quotient of *a* and *b*. Analogue of ``a / b``. + + Explanation + =========== + + This is essentially the same as ``a / b`` except that an error will be + raised if the division is inexact (if there is any remainder) and the + result will always be a domain element. When working in a + :py:class:`~.Domain` that is not a :py:class:`~.Field` (e.g. :ref:`ZZ` + or :ref:`K[x]`) ``exquo`` should be used instead of ``/``. + + The key invariant is that if ``q = K.exquo(a, b)`` (and ``exquo`` does + not raise an exception) then ``a == b*q``. + + Examples + ======== + + We can use ``K.exquo`` instead of ``/`` for exact division. + + >>> from sympy import ZZ + >>> ZZ.exquo(ZZ(4), ZZ(2)) + 2 + >>> ZZ.exquo(ZZ(5), ZZ(2)) + Traceback (most recent call last): + ... + ExactQuotientFailed: 2 does not divide 5 in ZZ + + Over a :py:class:`~.Field` such as :ref:`QQ`, division (with nonzero + divisor) is always exact so in that case ``/`` can be used instead of + :py:meth:`~.Domain.exquo`. + + >>> from sympy import QQ + >>> QQ.exquo(QQ(5), QQ(2)) + 5/2 + >>> QQ(5) / QQ(2) + 5/2 + + Parameters + ========== + + a: domain element + The dividend + b: domain element + The divisor + + Returns + ======= + + q: domain element + The exact quotient + + Raises + ====== + + ExactQuotientFailed: if exact division is not possible. + ZeroDivisionError: when the divisor is zero. + + See also + ======== + + quo: Analogue of ``a // b`` + rem: Analogue of ``a % b`` + div: Analogue of ``divmod(a, b)`` + + Notes + ===== + + Since the default :py:attr:`~.Domain.dtype` for :ref:`ZZ` is ``int`` + (or ``mpz``) division as ``a / b`` should not be used as it would give + a ``float`` which is not a domain element. + + >>> ZZ(4) / ZZ(2) # doctest: +SKIP + 2.0 + >>> ZZ(5) / ZZ(2) # doctest: +SKIP + 2.5 + + On the other hand with `SYMPY_GROUND_TYPES=flint` elements of :ref:`ZZ` + are ``flint.fmpz`` and division would raise an exception: + + >>> ZZ(4) / ZZ(2) # doctest: +SKIP + Traceback (most recent call last): + ... + TypeError: unsupported operand type(s) for /: 'fmpz' and 'fmpz' + + Using ``/`` with :ref:`ZZ` will lead to incorrect results so + :py:meth:`~.Domain.exquo` should be used instead. + + """ + raise NotImplementedError + + def quo(self, a, b): + """Quotient of *a* and *b*. Analogue of ``a // b``. + + ``K.quo(a, b)`` is equivalent to ``K.div(a, b)[0]``. See + :py:meth:`~.Domain.div` for more explanation. + + See also + ======== + + rem: Analogue of ``a % b`` + div: Analogue of ``divmod(a, b)`` + exquo: Analogue of ``a / b`` + """ + raise NotImplementedError + + def rem(self, a, b): + """Modulo division of *a* and *b*. Analogue of ``a % b``. + + ``K.rem(a, b)`` is equivalent to ``K.div(a, b)[1]``. See + :py:meth:`~.Domain.div` for more explanation. + + See also + ======== + + quo: Analogue of ``a // b`` + div: Analogue of ``divmod(a, b)`` + exquo: Analogue of ``a / b`` + """ + raise NotImplementedError + + def div(self, a, b): + """Quotient and remainder for *a* and *b*. Analogue of ``divmod(a, b)`` + + Explanation + =========== + + This is essentially the same as ``divmod(a, b)`` except that is more + consistent when working over some :py:class:`~.Field` domains such as + :ref:`QQ`. When working over an arbitrary :py:class:`~.Domain` the + :py:meth:`~.Domain.div` method should be used instead of ``divmod``. + + The key invariant is that if ``q, r = K.div(a, b)`` then + ``a == b*q + r``. + + The result of ``K.div(a, b)`` is the same as the tuple + ``(K.quo(a, b), K.rem(a, b))`` except that if both quotient and + remainder are needed then it is more efficient to use + :py:meth:`~.Domain.div`. + + Examples + ======== + + We can use ``K.div`` instead of ``divmod`` for floor division and + remainder. + + >>> from sympy import ZZ, QQ + >>> ZZ.div(ZZ(5), ZZ(2)) + (2, 1) + + If ``K`` is a :py:class:`~.Field` then the division is always exact + with a remainder of :py:attr:`~.Domain.zero`. + + >>> QQ.div(QQ(5), QQ(2)) + (5/2, 0) + + Parameters + ========== + + a: domain element + The dividend + b: domain element + The divisor + + Returns + ======= + + (q, r): tuple of domain elements + The quotient and remainder + + Raises + ====== + + ZeroDivisionError: when the divisor is zero. + + See also + ======== + + quo: Analogue of ``a // b`` + rem: Analogue of ``a % b`` + exquo: Analogue of ``a / b`` + + Notes + ===== + + If ``gmpy`` is installed then the ``gmpy.mpq`` type will be used as + the :py:attr:`~.Domain.dtype` for :ref:`QQ`. The ``gmpy.mpq`` type + defines ``divmod`` in a way that is undesirable so + :py:meth:`~.Domain.div` should be used instead of ``divmod``. + + >>> a = QQ(1) + >>> b = QQ(3, 2) + >>> a # doctest: +SKIP + mpq(1,1) + >>> b # doctest: +SKIP + mpq(3,2) + >>> divmod(a, b) # doctest: +SKIP + (mpz(0), mpq(1,1)) + >>> QQ.div(a, b) # doctest: +SKIP + (mpq(2,3), mpq(0,1)) + + Using ``//`` or ``%`` with :ref:`QQ` will lead to incorrect results so + :py:meth:`~.Domain.div` should be used instead. + + """ + raise NotImplementedError + + def invert(self, a, b): + """Returns inversion of ``a mod b``, implies something. """ + raise NotImplementedError + + def revert(self, a): + """Returns ``a**(-1)`` if possible. """ + raise NotImplementedError + + def numer(self, a): + """Returns numerator of ``a``. """ + raise NotImplementedError + + def denom(self, a): + """Returns denominator of ``a``. """ + raise NotImplementedError + + def half_gcdex(self, a, b): + """Half extended GCD of ``a`` and ``b``. """ + s, t, h = self.gcdex(a, b) + return s, h + + def gcdex(self, a, b): + """Extended GCD of ``a`` and ``b``. """ + raise NotImplementedError + + def cofactors(self, a, b): + """Returns GCD and cofactors of ``a`` and ``b``. """ + gcd = self.gcd(a, b) + cfa = self.quo(a, gcd) + cfb = self.quo(b, gcd) + return gcd, cfa, cfb + + def gcd(self, a, b): + """Returns GCD of ``a`` and ``b``. """ + raise NotImplementedError + + def lcm(self, a, b): + """Returns LCM of ``a`` and ``b``. """ + raise NotImplementedError + + def log(self, a, b): + """Returns b-base logarithm of ``a``. """ + raise NotImplementedError + + def sqrt(self, a): + """Returns a (possibly inexact) square root of ``a``. + + Explanation + =========== + There is no universal definition of "inexact square root" for all + domains. It is not recommended to implement this method for domains + other then :ref:`ZZ`. + + See also + ======== + exsqrt + """ + raise NotImplementedError + + def is_square(self, a): + """Returns whether ``a`` is a square in the domain. + + Explanation + =========== + Returns ``True`` if there is an element ``b`` in the domain such that + ``b * b == a``, otherwise returns ``False``. For inexact domains like + :ref:`RR` and :ref:`CC`, a tiny difference in this equality can be + tolerated. + + See also + ======== + exsqrt + """ + raise NotImplementedError + + def exsqrt(self, a): + """Principal square root of a within the domain if ``a`` is square. + + Explanation + =========== + The implementation of this method should return an element ``b`` in the + domain such that ``b * b == a``, or ``None`` if there is no such ``b``. + For inexact domains like :ref:`RR` and :ref:`CC`, a tiny difference in + this equality can be tolerated. The choice of a "principal" square root + should follow a consistent rule whenever possible. + + See also + ======== + sqrt, is_square + """ + raise NotImplementedError + + def evalf(self, a, prec=None, **options): + """Returns numerical approximation of ``a``. """ + return self.to_sympy(a).evalf(prec, **options) + + n = evalf + + def real(self, a): + return a + + def imag(self, a): + return self.zero + + def almosteq(self, a, b, tolerance=None): + """Check if ``a`` and ``b`` are almost equal. """ + return a == b + + def characteristic(self): + """Return the characteristic of this domain. """ + raise NotImplementedError('characteristic()') + + +__all__ = ['Domain'] diff --git a/lib/python3.10/site-packages/sympy/polys/domains/domainelement.py b/lib/python3.10/site-packages/sympy/polys/domains/domainelement.py new file mode 100644 index 0000000000000000000000000000000000000000..b1033e86a7edcbffa633efd65ca7ced48f3b1f1a --- /dev/null +++ b/lib/python3.10/site-packages/sympy/polys/domains/domainelement.py @@ -0,0 +1,38 @@ +"""Trait for implementing domain elements. """ + + +from sympy.utilities import public + +@public +class DomainElement: + """ + Represents an element of a domain. + + Mix in this trait into a class whose instances should be recognized as + elements of a domain. Method ``parent()`` gives that domain. + """ + + __slots__ = () + + def parent(self): + """Get the domain associated with ``self`` + + Examples + ======== + + >>> from sympy import ZZ, symbols + >>> x, y = symbols('x, y') + >>> K = ZZ[x,y] + >>> p = K(x)**2 + K(y)**2 + >>> p + x**2 + y**2 + >>> p.parent() + ZZ[x,y] + + Notes + ===== + + This is used by :py:meth:`~.Domain.convert` to identify the domain + associated with a domain element. + """ + raise NotImplementedError("abstract method") diff --git a/lib/python3.10/site-packages/sympy/polys/domains/expressiondomain.py b/lib/python3.10/site-packages/sympy/polys/domains/expressiondomain.py new file mode 100644 index 0000000000000000000000000000000000000000..26cd5aa5bf34985f885093be227df6aa9b35d36c --- /dev/null +++ b/lib/python3.10/site-packages/sympy/polys/domains/expressiondomain.py @@ -0,0 +1,278 @@ +"""Implementation of :class:`ExpressionDomain` class. """ + + +from sympy.core import sympify, SympifyError +from sympy.polys.domains.domainelement import DomainElement +from sympy.polys.domains.characteristiczero import CharacteristicZero +from sympy.polys.domains.field import Field +from sympy.polys.domains.simpledomain import SimpleDomain +from sympy.polys.polyutils import PicklableWithSlots +from sympy.utilities import public + +eflags = {"deep": False, "mul": True, "power_exp": False, "power_base": False, + "basic": False, "multinomial": False, "log": False} + +@public +class ExpressionDomain(Field, CharacteristicZero, SimpleDomain): + """A class for arbitrary expressions. """ + + is_SymbolicDomain = is_EX = True + + class Expression(DomainElement, PicklableWithSlots): + """An arbitrary expression. """ + + __slots__ = ('ex',) + + def __init__(self, ex): + if not isinstance(ex, self.__class__): + self.ex = sympify(ex) + else: + self.ex = ex.ex + + def __repr__(f): + return 'EX(%s)' % repr(f.ex) + + def __str__(f): + return 'EX(%s)' % str(f.ex) + + def __hash__(self): + return hash((self.__class__.__name__, self.ex)) + + def parent(self): + return EX + + def as_expr(f): + return f.ex + + def numer(f): + return f.__class__(f.ex.as_numer_denom()[0]) + + def denom(f): + return f.__class__(f.ex.as_numer_denom()[1]) + + def simplify(f, ex): + return f.__class__(ex.cancel().expand(**eflags)) + + def __abs__(f): + return f.__class__(abs(f.ex)) + + def __neg__(f): + return f.__class__(-f.ex) + + def _to_ex(f, g): + try: + return f.__class__(g) + except SympifyError: + return None + + def __lt__(f, g): + return f.ex.sort_key() < g.ex.sort_key() + + def __add__(f, g): + g = f._to_ex(g) + + if g is None: + return NotImplemented + elif g == EX.zero: + return f + elif f == EX.zero: + return g + else: + return f.simplify(f.ex + g.ex) + + def __radd__(f, g): + return f.simplify(f.__class__(g).ex + f.ex) + + def __sub__(f, g): + g = f._to_ex(g) + + if g is None: + return NotImplemented + elif g == EX.zero: + return f + elif f == EX.zero: + return -g + else: + return f.simplify(f.ex - g.ex) + + def __rsub__(f, g): + return f.simplify(f.__class__(g).ex - f.ex) + + def __mul__(f, g): + g = f._to_ex(g) + + if g is None: + return NotImplemented + + if EX.zero in (f, g): + return EX.zero + elif f.ex.is_Number and g.ex.is_Number: + return f.__class__(f.ex*g.ex) + + return f.simplify(f.ex*g.ex) + + def __rmul__(f, g): + return f.simplify(f.__class__(g).ex*f.ex) + + def __pow__(f, n): + n = f._to_ex(n) + + if n is not None: + return f.simplify(f.ex**n.ex) + else: + return NotImplemented + + def __truediv__(f, g): + g = f._to_ex(g) + + if g is not None: + return f.simplify(f.ex/g.ex) + else: + return NotImplemented + + def __rtruediv__(f, g): + return f.simplify(f.__class__(g).ex/f.ex) + + def __eq__(f, g): + return f.ex == f.__class__(g).ex + + def __ne__(f, g): + return not f == g + + def __bool__(f): + return not f.ex.is_zero + + def gcd(f, g): + from sympy.polys import gcd + return f.__class__(gcd(f.ex, f.__class__(g).ex)) + + def lcm(f, g): + from sympy.polys import lcm + return f.__class__(lcm(f.ex, f.__class__(g).ex)) + + dtype = Expression + + zero = Expression(0) + one = Expression(1) + + rep = 'EX' + + has_assoc_Ring = False + has_assoc_Field = True + + def __init__(self): + pass + + def __eq__(self, other): + if isinstance(other, ExpressionDomain): + return True + else: + return NotImplemented + + def __hash__(self): + return hash("EX") + + def to_sympy(self, a): + """Convert ``a`` to a SymPy object. """ + return a.as_expr() + + def from_sympy(self, a): + """Convert SymPy's expression to ``dtype``. """ + return self.dtype(a) + + def from_ZZ(K1, a, K0): + """Convert a Python ``int`` object to ``dtype``. """ + return K1(K0.to_sympy(a)) + + def from_ZZ_python(K1, a, K0): + """Convert a Python ``int`` object to ``dtype``. """ + return K1(K0.to_sympy(a)) + + def from_QQ(K1, a, K0): + """Convert a Python ``Fraction`` object to ``dtype``. """ + return K1(K0.to_sympy(a)) + + def from_QQ_python(K1, a, K0): + """Convert a Python ``Fraction`` object to ``dtype``. """ + return K1(K0.to_sympy(a)) + + def from_ZZ_gmpy(K1, a, K0): + """Convert a GMPY ``mpz`` object to ``dtype``. """ + return K1(K0.to_sympy(a)) + + def from_QQ_gmpy(K1, a, K0): + """Convert a GMPY ``mpq`` object to ``dtype``. """ + return K1(K0.to_sympy(a)) + + def from_GaussianIntegerRing(K1, a, K0): + """Convert a ``GaussianRational`` object to ``dtype``. """ + return K1(K0.to_sympy(a)) + + def from_GaussianRationalField(K1, a, K0): + """Convert a ``GaussianRational`` object to ``dtype``. """ + return K1(K0.to_sympy(a)) + + def from_AlgebraicField(K1, a, K0): + """Convert an ``ANP`` object to ``dtype``. """ + return K1(K0.to_sympy(a)) + + def from_RealField(K1, a, K0): + """Convert a mpmath ``mpf`` object to ``dtype``. """ + return K1(K0.to_sympy(a)) + + def from_ComplexField(K1, a, K0): + """Convert a mpmath ``mpc`` object to ``dtype``. """ + return K1(K0.to_sympy(a)) + + def from_PolynomialRing(K1, a, K0): + """Convert a ``DMP`` object to ``dtype``. """ + return K1(K0.to_sympy(a)) + + def from_FractionField(K1, a, K0): + """Convert a ``DMF`` object to ``dtype``. """ + return K1(K0.to_sympy(a)) + + def from_ExpressionDomain(K1, a, K0): + """Convert a ``EX`` object to ``dtype``. """ + return a + + def get_ring(self): + """Returns a ring associated with ``self``. """ + return self # XXX: EX is not a ring but we don't have much choice here. + + def get_field(self): + """Returns a field associated with ``self``. """ + return self + + def is_positive(self, a): + """Returns True if ``a`` is positive. """ + return a.ex.as_coeff_mul()[0].is_positive + + def is_negative(self, a): + """Returns True if ``a`` is negative. """ + return a.ex.could_extract_minus_sign() + + def is_nonpositive(self, a): + """Returns True if ``a`` is non-positive. """ + return a.ex.as_coeff_mul()[0].is_nonpositive + + def is_nonnegative(self, a): + """Returns True if ``a`` is non-negative. """ + return a.ex.as_coeff_mul()[0].is_nonnegative + + def numer(self, a): + """Returns numerator of ``a``. """ + return a.numer() + + def denom(self, a): + """Returns denominator of ``a``. """ + return a.denom() + + def gcd(self, a, b): + return self(1) + + def lcm(self, a, b): + return a.lcm(b) + + +EX = ExpressionDomain() diff --git a/lib/python3.10/site-packages/sympy/polys/domains/expressionrawdomain.py b/lib/python3.10/site-packages/sympy/polys/domains/expressionrawdomain.py new file mode 100644 index 0000000000000000000000000000000000000000..9811ca26c965197a13f56ab8266ad744e4571560 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/polys/domains/expressionrawdomain.py @@ -0,0 +1,57 @@ +"""Implementation of :class:`ExpressionRawDomain` class. """ + + +from sympy.core import Expr, S, sympify, Add +from sympy.polys.domains.characteristiczero import CharacteristicZero +from sympy.polys.domains.field import Field +from sympy.polys.domains.simpledomain import SimpleDomain +from sympy.polys.polyerrors import CoercionFailed +from sympy.utilities import public + + +@public +class ExpressionRawDomain(Field, CharacteristicZero, SimpleDomain): + """A class for arbitrary expressions but without automatic simplification. """ + + is_SymbolicRawDomain = is_EXRAW = True + + dtype = Expr + + zero = S.Zero + one = S.One + + rep = 'EXRAW' + + has_assoc_Ring = False + has_assoc_Field = True + + def __init__(self): + pass + + @classmethod + def new(self, a): + return sympify(a) + + def to_sympy(self, a): + """Convert ``a`` to a SymPy object. """ + return a + + def from_sympy(self, a): + """Convert SymPy's expression to ``dtype``. """ + if not isinstance(a, Expr): + raise CoercionFailed(f"Expecting an Expr instance but found: {type(a).__name__}") + return a + + def convert_from(self, a, K): + """Convert a domain element from another domain to EXRAW""" + return K.to_sympy(a) + + def get_field(self): + """Returns a field associated with ``self``. """ + return self + + def sum(self, items): + return Add(*items) + + +EXRAW = ExpressionRawDomain() diff --git a/lib/python3.10/site-packages/sympy/polys/domains/field.py b/lib/python3.10/site-packages/sympy/polys/domains/field.py new file mode 100644 index 0000000000000000000000000000000000000000..33c1314dee45d0155e116118c912c961cb61281f --- /dev/null +++ b/lib/python3.10/site-packages/sympy/polys/domains/field.py @@ -0,0 +1,104 @@ +"""Implementation of :class:`Field` class. """ + + +from sympy.polys.domains.ring import Ring +from sympy.polys.polyerrors import NotReversible, DomainError +from sympy.utilities import public + +@public +class Field(Ring): + """Represents a field domain. """ + + is_Field = True + is_PID = True + + def get_ring(self): + """Returns a ring associated with ``self``. """ + raise DomainError('there is no ring associated with %s' % self) + + def get_field(self): + """Returns a field associated with ``self``. """ + return self + + def exquo(self, a, b): + """Exact quotient of ``a`` and ``b``, implies ``__truediv__``. """ + return a / b + + def quo(self, a, b): + """Quotient of ``a`` and ``b``, implies ``__truediv__``. """ + return a / b + + def rem(self, a, b): + """Remainder of ``a`` and ``b``, implies nothing. """ + return self.zero + + def div(self, a, b): + """Division of ``a`` and ``b``, implies ``__truediv__``. """ + return a / b, self.zero + + def gcd(self, a, b): + """ + Returns GCD of ``a`` and ``b``. + + This definition of GCD over fields allows to clear denominators + in `primitive()`. + + Examples + ======== + + >>> from sympy.polys.domains import QQ + >>> from sympy import S, gcd, primitive + >>> from sympy.abc import x + + >>> QQ.gcd(QQ(2, 3), QQ(4, 9)) + 2/9 + >>> gcd(S(2)/3, S(4)/9) + 2/9 + >>> primitive(2*x/3 + S(4)/9) + (2/9, 3*x + 2) + + """ + try: + ring = self.get_ring() + except DomainError: + return self.one + + p = ring.gcd(self.numer(a), self.numer(b)) + q = ring.lcm(self.denom(a), self.denom(b)) + + return self.convert(p, ring)/q + + def lcm(self, a, b): + """ + Returns LCM of ``a`` and ``b``. + + >>> from sympy.polys.domains import QQ + >>> from sympy import S, lcm + + >>> QQ.lcm(QQ(2, 3), QQ(4, 9)) + 4/3 + >>> lcm(S(2)/3, S(4)/9) + 4/3 + + """ + + try: + ring = self.get_ring() + except DomainError: + return a*b + + p = ring.lcm(self.numer(a), self.numer(b)) + q = ring.gcd(self.denom(a), self.denom(b)) + + return self.convert(p, ring)/q + + def revert(self, a): + """Returns ``a**(-1)`` if possible. """ + if a: + return 1/a + else: + raise NotReversible('zero is not reversible') + + def is_unit(self, a): + """Return true if ``a`` is a invertible""" + return bool(a) diff --git a/lib/python3.10/site-packages/sympy/polys/domains/finitefield.py b/lib/python3.10/site-packages/sympy/polys/domains/finitefield.py new file mode 100644 index 0000000000000000000000000000000000000000..92ecbaeb52dd7f49ebf81cb993a71a2cad817f52 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/polys/domains/finitefield.py @@ -0,0 +1,328 @@ +"""Implementation of :class:`FiniteField` class. """ + +import operator + +from sympy.external.gmpy import GROUND_TYPES +from sympy.utilities.decorator import doctest_depends_on + +from sympy.core.numbers import int_valued +from sympy.polys.domains.field import Field + +from sympy.polys.domains.modularinteger import ModularIntegerFactory +from sympy.polys.domains.simpledomain import SimpleDomain +from sympy.polys.galoistools import gf_zassenhaus, gf_irred_p_rabin +from sympy.polys.polyerrors import CoercionFailed +from sympy.utilities import public +from sympy.polys.domains.groundtypes import SymPyInteger + + +if GROUND_TYPES == 'flint': + __doctest_skip__ = ['FiniteField'] + + +if GROUND_TYPES == 'flint': + import flint + # Don't use python-flint < 0.5.0 because nmod was missing some features in + # previous versions of python-flint and fmpz_mod was not yet added. + _major, _minor, *_ = flint.__version__.split('.') + if (int(_major), int(_minor)) < (0, 5): + flint = None +else: + flint = None + + +def _modular_int_factory(mod, dom, symmetric, self): + + # Use flint if available + if flint is not None: + + nmod = flint.nmod + fmpz_mod_ctx = flint.fmpz_mod_ctx + index = operator.index + + try: + mod = dom.convert(mod) + except CoercionFailed: + raise ValueError('modulus must be an integer, got %s' % mod) + + # mod might be e.g. Integer + try: + fmpz_mod_ctx(mod) + except TypeError: + mod = index(mod) + + # flint's nmod is only for moduli up to 2^64-1 (on a 64-bit machine) + try: + nmod(0, mod) + except OverflowError: + # Use fmpz_mod + fctx = fmpz_mod_ctx(mod) + + def ctx(x): + try: + return fctx(x) + except TypeError: + # x might be Integer + return fctx(index(x)) + else: + # Use nmod + def ctx(x): + try: + return nmod(x, mod) + except TypeError: + return nmod(index(x), mod) + + return ctx + + # Use the Python implementation + return ModularIntegerFactory(mod, dom, symmetric, self) + + +@public +@doctest_depends_on(modules=['python', 'gmpy']) +class FiniteField(Field, SimpleDomain): + r"""Finite field of prime order :ref:`GF(p)` + + A :ref:`GF(p)` domain represents a `finite field`_ `\mathbb{F}_p` of prime + order as :py:class:`~.Domain` in the domain system (see + :ref:`polys-domainsintro`). + + A :py:class:`~.Poly` created from an expression with integer + coefficients will have the domain :ref:`ZZ`. However, if the ``modulus=p`` + option is given then the domain will be a finite field instead. + + >>> from sympy import Poly, Symbol + >>> x = Symbol('x') + >>> p = Poly(x**2 + 1) + >>> p + Poly(x**2 + 1, x, domain='ZZ') + >>> p.domain + ZZ + >>> p2 = Poly(x**2 + 1, modulus=2) + >>> p2 + Poly(x**2 + 1, x, modulus=2) + >>> p2.domain + GF(2) + + It is possible to factorise a polynomial over :ref:`GF(p)` using the + modulus argument to :py:func:`~.factor` or by specifying the domain + explicitly. The domain can also be given as a string. + + >>> from sympy import factor, GF + >>> factor(x**2 + 1) + x**2 + 1 + >>> factor(x**2 + 1, modulus=2) + (x + 1)**2 + >>> factor(x**2 + 1, domain=GF(2)) + (x + 1)**2 + >>> factor(x**2 + 1, domain='GF(2)') + (x + 1)**2 + + It is also possible to use :ref:`GF(p)` with the :py:func:`~.cancel` + and :py:func:`~.gcd` functions. + + >>> from sympy import cancel, gcd + >>> cancel((x**2 + 1)/(x + 1)) + (x**2 + 1)/(x + 1) + >>> cancel((x**2 + 1)/(x + 1), domain=GF(2)) + x + 1 + >>> gcd(x**2 + 1, x + 1) + 1 + >>> gcd(x**2 + 1, x + 1, domain=GF(2)) + x + 1 + + When using the domain directly :ref:`GF(p)` can be used as a constructor + to create instances which then support the operations ``+,-,*,**,/`` + + >>> from sympy import GF + >>> K = GF(5) + >>> K + GF(5) + >>> x = K(3) + >>> y = K(2) + >>> x + 3 mod 5 + >>> y + 2 mod 5 + >>> x * y + 1 mod 5 + >>> x / y + 4 mod 5 + + Notes + ===== + + It is also possible to create a :ref:`GF(p)` domain of **non-prime** + order but the resulting ring is **not** a field: it is just the ring of + the integers modulo ``n``. + + >>> K = GF(9) + >>> z = K(3) + >>> z + 3 mod 9 + >>> z**2 + 0 mod 9 + + It would be good to have a proper implementation of prime power fields + (``GF(p**n)``) but these are not yet implemented in SymPY. + + .. _finite field: https://en.wikipedia.org/wiki/Finite_field + """ + + rep = 'FF' + alias = 'FF' + + is_FiniteField = is_FF = True + is_Numerical = True + + has_assoc_Ring = False + has_assoc_Field = True + + dom = None + mod = None + + def __init__(self, mod, symmetric=True): + from sympy.polys.domains import ZZ + dom = ZZ + + if mod <= 0: + raise ValueError('modulus must be a positive integer, got %s' % mod) + + self.dtype = _modular_int_factory(mod, dom, symmetric, self) + self.zero = self.dtype(0) + self.one = self.dtype(1) + self.dom = dom + self.mod = mod + self.sym = symmetric + self._tp = type(self.zero) + + @property + def tp(self): + return self._tp + + def __str__(self): + return 'GF(%s)' % self.mod + + def __hash__(self): + return hash((self.__class__.__name__, self.dtype, self.mod, self.dom)) + + def __eq__(self, other): + """Returns ``True`` if two domains are equivalent. """ + return isinstance(other, FiniteField) and \ + self.mod == other.mod and self.dom == other.dom + + def characteristic(self): + """Return the characteristic of this domain. """ + return self.mod + + def get_field(self): + """Returns a field associated with ``self``. """ + return self + + def to_sympy(self, a): + """Convert ``a`` to a SymPy object. """ + return SymPyInteger(self.to_int(a)) + + def from_sympy(self, a): + """Convert SymPy's Integer to SymPy's ``Integer``. """ + if a.is_Integer: + return self.dtype(self.dom.dtype(int(a))) + elif int_valued(a): + return self.dtype(self.dom.dtype(int(a))) + else: + raise CoercionFailed("expected an integer, got %s" % a) + + def to_int(self, a): + """Convert ``val`` to a Python ``int`` object. """ + aval = int(a) + if self.sym and aval > self.mod // 2: + aval -= self.mod + return aval + + def is_positive(self, a): + """Returns True if ``a`` is positive. """ + return bool(a) + + def is_nonnegative(self, a): + """Returns True if ``a`` is non-negative. """ + return True + + def is_negative(self, a): + """Returns True if ``a`` is negative. """ + return False + + def is_nonpositive(self, a): + """Returns True if ``a`` is non-positive. """ + return not a + + def from_FF(K1, a, K0=None): + """Convert ``ModularInteger(int)`` to ``dtype``. """ + return K1.dtype(K1.dom.from_ZZ(int(a), K0.dom)) + + def from_FF_python(K1, a, K0=None): + """Convert ``ModularInteger(int)`` to ``dtype``. """ + return K1.dtype(K1.dom.from_ZZ_python(int(a), K0.dom)) + + def from_ZZ(K1, a, K0=None): + """Convert Python's ``int`` to ``dtype``. """ + return K1.dtype(K1.dom.from_ZZ_python(a, K0)) + + def from_ZZ_python(K1, a, K0=None): + """Convert Python's ``int`` to ``dtype``. """ + return K1.dtype(K1.dom.from_ZZ_python(a, K0)) + + def from_QQ(K1, a, K0=None): + """Convert Python's ``Fraction`` to ``dtype``. """ + if a.denominator == 1: + return K1.from_ZZ_python(a.numerator) + + def from_QQ_python(K1, a, K0=None): + """Convert Python's ``Fraction`` to ``dtype``. """ + if a.denominator == 1: + return K1.from_ZZ_python(a.numerator) + + def from_FF_gmpy(K1, a, K0=None): + """Convert ``ModularInteger(mpz)`` to ``dtype``. """ + return K1.dtype(K1.dom.from_ZZ_gmpy(a.val, K0.dom)) + + def from_ZZ_gmpy(K1, a, K0=None): + """Convert GMPY's ``mpz`` to ``dtype``. """ + return K1.dtype(K1.dom.from_ZZ_gmpy(a, K0)) + + def from_QQ_gmpy(K1, a, K0=None): + """Convert GMPY's ``mpq`` to ``dtype``. """ + if a.denominator == 1: + return K1.from_ZZ_gmpy(a.numerator) + + def from_RealField(K1, a, K0): + """Convert mpmath's ``mpf`` to ``dtype``. """ + p, q = K0.to_rational(a) + + if q == 1: + return K1.dtype(K1.dom.dtype(p)) + + def is_square(self, a): + """Returns True if ``a`` is a quadratic residue modulo p. """ + # a is not a square <=> x**2-a is irreducible + poly = [int(x) for x in [self.one, self.zero, -a]] + return not gf_irred_p_rabin(poly, self.mod, self.dom) + + def exsqrt(self, a): + """Square root modulo p of ``a`` if it is a quadratic residue. + + Explanation + =========== + Always returns the square root that is no larger than ``p // 2``. + """ + # x**2-a is not square-free if a=0 or the field is characteristic 2 + if self.mod == 2 or a == 0: + return a + # Otherwise, use square-free factorization routine to factorize x**2-a + poly = [int(x) for x in [self.one, self.zero, -a]] + for factor in gf_zassenhaus(poly, self.mod, self.dom): + if len(factor) == 2 and factor[1] <= self.mod // 2: + return self.dtype(factor[1]) + return None + + +FF = GF = FiniteField diff --git a/lib/python3.10/site-packages/sympy/polys/domains/fractionfield.py b/lib/python3.10/site-packages/sympy/polys/domains/fractionfield.py new file mode 100644 index 0000000000000000000000000000000000000000..47bc25436b8e30f6a02506dc237bcf1791de487c --- /dev/null +++ b/lib/python3.10/site-packages/sympy/polys/domains/fractionfield.py @@ -0,0 +1,177 @@ +"""Implementation of :class:`FractionField` class. """ + + +from sympy.polys.domains.compositedomain import CompositeDomain +from sympy.polys.domains.field import Field +from sympy.polys.polyerrors import CoercionFailed, GeneratorsError +from sympy.utilities import public + +@public +class FractionField(Field, CompositeDomain): + """A class for representing multivariate rational function fields. """ + + is_FractionField = is_Frac = True + + has_assoc_Ring = True + has_assoc_Field = True + + def __init__(self, domain_or_field, symbols=None, order=None): + from sympy.polys.fields import FracField + + if isinstance(domain_or_field, FracField) and symbols is None and order is None: + field = domain_or_field + else: + field = FracField(symbols, domain_or_field, order) + + self.field = field + self.dtype = field.dtype + + self.gens = field.gens + self.ngens = field.ngens + self.symbols = field.symbols + self.domain = field.domain + + # TODO: remove this + self.dom = self.domain + + def new(self, element): + return self.field.field_new(element) + + @property + def zero(self): + return self.field.zero + + @property + def one(self): + return self.field.one + + @property + def order(self): + return self.field.order + + def __str__(self): + return str(self.domain) + '(' + ','.join(map(str, self.symbols)) + ')' + + def __hash__(self): + return hash((self.__class__.__name__, self.dtype.field, self.domain, self.symbols)) + + def __eq__(self, other): + """Returns ``True`` if two domains are equivalent. """ + return isinstance(other, FractionField) and \ + (self.dtype.field, self.domain, self.symbols) ==\ + (other.dtype.field, other.domain, other.symbols) + + def to_sympy(self, a): + """Convert ``a`` to a SymPy object. """ + return a.as_expr() + + def from_sympy(self, a): + """Convert SymPy's expression to ``dtype``. """ + return self.field.from_expr(a) + + def from_ZZ(K1, a, K0): + """Convert a Python ``int`` object to ``dtype``. """ + return K1(K1.domain.convert(a, K0)) + + def from_ZZ_python(K1, a, K0): + """Convert a Python ``int`` object to ``dtype``. """ + return K1(K1.domain.convert(a, K0)) + + def from_QQ(K1, a, K0): + """Convert a Python ``Fraction`` object to ``dtype``. """ + dom = K1.domain + conv = dom.convert_from + if dom.is_ZZ: + return K1(conv(K0.numer(a), K0)) / K1(conv(K0.denom(a), K0)) + else: + return K1(conv(a, K0)) + + def from_QQ_python(K1, a, K0): + """Convert a Python ``Fraction`` object to ``dtype``. """ + return K1(K1.domain.convert(a, K0)) + + def from_ZZ_gmpy(K1, a, K0): + """Convert a GMPY ``mpz`` object to ``dtype``. """ + return K1(K1.domain.convert(a, K0)) + + def from_QQ_gmpy(K1, a, K0): + """Convert a GMPY ``mpq`` object to ``dtype``. """ + return K1(K1.domain.convert(a, K0)) + + def from_GaussianRationalField(K1, a, K0): + """Convert a ``GaussianRational`` object to ``dtype``. """ + return K1(K1.domain.convert(a, K0)) + + def from_GaussianIntegerRing(K1, a, K0): + """Convert a ``GaussianInteger`` object to ``dtype``. """ + return K1(K1.domain.convert(a, K0)) + + def from_RealField(K1, a, K0): + """Convert a mpmath ``mpf`` object to ``dtype``. """ + return K1(K1.domain.convert(a, K0)) + + def from_ComplexField(K1, a, K0): + """Convert a mpmath ``mpf`` object to ``dtype``. """ + return K1(K1.domain.convert(a, K0)) + + def from_AlgebraicField(K1, a, K0): + """Convert an algebraic number to ``dtype``. """ + if K1.domain != K0: + a = K1.domain.convert_from(a, K0) + if a is not None: + return K1.new(a) + + def from_PolynomialRing(K1, a, K0): + """Convert a polynomial to ``dtype``. """ + if a.is_ground: + return K1.convert_from(a.coeff(1), K0.domain) + try: + return K1.new(a.set_ring(K1.field.ring)) + except (CoercionFailed, GeneratorsError): + # XXX: We get here if K1=ZZ(x,y) and K0=QQ[x,y] + # and the poly a in K0 has non-integer coefficients. + # It seems that K1.new can handle this but K1.new doesn't work + # when K0.domain is an algebraic field... + try: + return K1.new(a) + except (CoercionFailed, GeneratorsError): + return None + + def from_FractionField(K1, a, K0): + """Convert a rational function to ``dtype``. """ + try: + return a.set_field(K1.field) + except (CoercionFailed, GeneratorsError): + return None + + def get_ring(self): + """Returns a field associated with ``self``. """ + return self.field.to_ring().to_domain() + + def is_positive(self, a): + """Returns True if ``LC(a)`` is positive. """ + return self.domain.is_positive(a.numer.LC) + + def is_negative(self, a): + """Returns True if ``LC(a)`` is negative. """ + return self.domain.is_negative(a.numer.LC) + + def is_nonpositive(self, a): + """Returns True if ``LC(a)`` is non-positive. """ + return self.domain.is_nonpositive(a.numer.LC) + + def is_nonnegative(self, a): + """Returns True if ``LC(a)`` is non-negative. """ + return self.domain.is_nonnegative(a.numer.LC) + + def numer(self, a): + """Returns numerator of ``a``. """ + return a.numer + + def denom(self, a): + """Returns denominator of ``a``. """ + return a.denom + + def factorial(self, a): + """Returns factorial of ``a``. """ + return self.dtype(self.domain.factorial(a)) diff --git a/lib/python3.10/site-packages/sympy/polys/domains/gaussiandomains.py b/lib/python3.10/site-packages/sympy/polys/domains/gaussiandomains.py new file mode 100644 index 0000000000000000000000000000000000000000..bf3df50d5de65da0ac22b6f00d364d44e0cc28c7 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/polys/domains/gaussiandomains.py @@ -0,0 +1,686 @@ +"""Domains of Gaussian type.""" + +from sympy.core.numbers import I +from sympy.polys.polyerrors import CoercionFailed +from sympy.polys.domains.integerring import ZZ +from sympy.polys.domains.rationalfield import QQ +from sympy.polys.domains.algebraicfield import AlgebraicField +from sympy.polys.domains.domain import Domain +from sympy.polys.domains.domainelement import DomainElement +from sympy.polys.domains.field import Field +from sympy.polys.domains.ring import Ring + + +class GaussianElement(DomainElement): + """Base class for elements of Gaussian type domains.""" + base: Domain + _parent: Domain + + __slots__ = ('x', 'y') + + def __new__(cls, x, y=0): + conv = cls.base.convert + return cls.new(conv(x), conv(y)) + + @classmethod + def new(cls, x, y): + """Create a new GaussianElement of the same domain.""" + obj = super().__new__(cls) + obj.x = x + obj.y = y + return obj + + def parent(self): + """The domain that this is an element of (ZZ_I or QQ_I)""" + return self._parent + + def __hash__(self): + return hash((self.x, self.y)) + + def __eq__(self, other): + if isinstance(other, self.__class__): + return self.x == other.x and self.y == other.y + else: + return NotImplemented + + def __lt__(self, other): + if not isinstance(other, GaussianElement): + return NotImplemented + return [self.y, self.x] < [other.y, other.x] + + def __pos__(self): + return self + + def __neg__(self): + return self.new(-self.x, -self.y) + + def __repr__(self): + return "%s(%s, %s)" % (self._parent.rep, self.x, self.y) + + def __str__(self): + return str(self._parent.to_sympy(self)) + + @classmethod + def _get_xy(cls, other): + if not isinstance(other, cls): + try: + other = cls._parent.convert(other) + except CoercionFailed: + return None, None + return other.x, other.y + + def __add__(self, other): + x, y = self._get_xy(other) + if x is not None: + return self.new(self.x + x, self.y + y) + else: + return NotImplemented + + __radd__ = __add__ + + def __sub__(self, other): + x, y = self._get_xy(other) + if x is not None: + return self.new(self.x - x, self.y - y) + else: + return NotImplemented + + def __rsub__(self, other): + x, y = self._get_xy(other) + if x is not None: + return self.new(x - self.x, y - self.y) + else: + return NotImplemented + + def __mul__(self, other): + x, y = self._get_xy(other) + if x is not None: + return self.new(self.x*x - self.y*y, self.x*y + self.y*x) + else: + return NotImplemented + + __rmul__ = __mul__ + + def __pow__(self, exp): + if exp == 0: + return self.new(1, 0) + if exp < 0: + self, exp = 1/self, -exp + if exp == 1: + return self + pow2 = self + prod = self if exp % 2 else self._parent.one + exp //= 2 + while exp: + pow2 *= pow2 + if exp % 2: + prod *= pow2 + exp //= 2 + return prod + + def __bool__(self): + return bool(self.x) or bool(self.y) + + def quadrant(self): + """Return quadrant index 0-3. + + 0 is included in quadrant 0. + """ + if self.y > 0: + return 0 if self.x > 0 else 1 + elif self.y < 0: + return 2 if self.x < 0 else 3 + else: + return 0 if self.x >= 0 else 2 + + def __rdivmod__(self, other): + try: + other = self._parent.convert(other) + except CoercionFailed: + return NotImplemented + else: + return other.__divmod__(self) + + def __rtruediv__(self, other): + try: + other = QQ_I.convert(other) + except CoercionFailed: + return NotImplemented + else: + return other.__truediv__(self) + + def __floordiv__(self, other): + qr = self.__divmod__(other) + return qr if qr is NotImplemented else qr[0] + + def __rfloordiv__(self, other): + qr = self.__rdivmod__(other) + return qr if qr is NotImplemented else qr[0] + + def __mod__(self, other): + qr = self.__divmod__(other) + return qr if qr is NotImplemented else qr[1] + + def __rmod__(self, other): + qr = self.__rdivmod__(other) + return qr if qr is NotImplemented else qr[1] + + +class GaussianInteger(GaussianElement): + """Gaussian integer: domain element for :ref:`ZZ_I` + + >>> from sympy import ZZ_I + >>> z = ZZ_I(2, 3) + >>> z + (2 + 3*I) + >>> type(z) + + """ + base = ZZ + + def __truediv__(self, other): + """Return a Gaussian rational.""" + return QQ_I.convert(self)/other + + def __divmod__(self, other): + if not other: + raise ZeroDivisionError('divmod({}, 0)'.format(self)) + x, y = self._get_xy(other) + if x is None: + return NotImplemented + + # multiply self and other by x - I*y + # self/other == (a + I*b)/c + a, b = self.x*x + self.y*y, -self.x*y + self.y*x + c = x*x + y*y + + # find integers qx and qy such that + # |a - qx*c| <= c/2 and |b - qy*c| <= c/2 + qx = (2*a + c) // (2*c) # -c <= 2*a - qx*2*c < c + qy = (2*b + c) // (2*c) + + q = GaussianInteger(qx, qy) + # |self/other - q| < 1 since + # |a/c - qx|**2 + |b/c - qy|**2 <= 1/4 + 1/4 < 1 + + return q, self - q*other # |r| < |other| + + +class GaussianRational(GaussianElement): + """Gaussian rational: domain element for :ref:`QQ_I` + + >>> from sympy import QQ_I, QQ + >>> z = QQ_I(QQ(2, 3), QQ(4, 5)) + >>> z + (2/3 + 4/5*I) + >>> type(z) + + """ + base = QQ + + def __truediv__(self, other): + """Return a Gaussian rational.""" + if not other: + raise ZeroDivisionError('{} / 0'.format(self)) + x, y = self._get_xy(other) + if x is None: + return NotImplemented + c = x*x + y*y + + return GaussianRational((self.x*x + self.y*y)/c, + (-self.x*y + self.y*x)/c) + + def __divmod__(self, other): + try: + other = self._parent.convert(other) + except CoercionFailed: + return NotImplemented + if not other: + raise ZeroDivisionError('{} % 0'.format(self)) + else: + return self/other, QQ_I.zero + + +class GaussianDomain(): + """Base class for Gaussian domains.""" + dom = None # type: Domain + + is_Numerical = True + is_Exact = True + + has_assoc_Ring = True + has_assoc_Field = True + + def to_sympy(self, a): + """Convert ``a`` to a SymPy object. """ + conv = self.dom.to_sympy + return conv(a.x) + I*conv(a.y) + + def from_sympy(self, a): + """Convert a SymPy object to ``self.dtype``.""" + r, b = a.as_coeff_Add() + x = self.dom.from_sympy(r) # may raise CoercionFailed + if not b: + return self.new(x, 0) + r, b = b.as_coeff_Mul() + y = self.dom.from_sympy(r) + if b is I: + return self.new(x, y) + else: + raise CoercionFailed("{} is not Gaussian".format(a)) + + def inject(self, *gens): + """Inject generators into this domain. """ + return self.poly_ring(*gens) + + def canonical_unit(self, d): + unit = self.units[-d.quadrant()] # - for inverse power + return unit + + def is_negative(self, element): + """Returns ``False`` for any ``GaussianElement``. """ + return False + + def is_positive(self, element): + """Returns ``False`` for any ``GaussianElement``. """ + return False + + def is_nonnegative(self, element): + """Returns ``False`` for any ``GaussianElement``. """ + return False + + def is_nonpositive(self, element): + """Returns ``False`` for any ``GaussianElement``. """ + return False + + def from_ZZ_gmpy(K1, a, K0): + """Convert a GMPY mpz to ``self.dtype``.""" + return K1(a) + + def from_ZZ(K1, a, K0): + """Convert a ZZ_python element to ``self.dtype``.""" + return K1(a) + + def from_ZZ_python(K1, a, K0): + """Convert a ZZ_python element to ``self.dtype``.""" + return K1(a) + + def from_QQ(K1, a, K0): + """Convert a GMPY mpq to ``self.dtype``.""" + return K1(a) + + def from_QQ_gmpy(K1, a, K0): + """Convert a GMPY mpq to ``self.dtype``.""" + return K1(a) + + def from_QQ_python(K1, a, K0): + """Convert a QQ_python element to ``self.dtype``.""" + return K1(a) + + def from_AlgebraicField(K1, a, K0): + """Convert an element from ZZ or QQ to ``self.dtype``.""" + if K0.ext.args[0] == I: + return K1.from_sympy(K0.to_sympy(a)) + + +class GaussianIntegerRing(GaussianDomain, Ring): + r"""Ring of Gaussian integers ``ZZ_I`` + + The :ref:`ZZ_I` domain represents the `Gaussian integers`_ `\mathbb{Z}[i]` + as a :py:class:`~.Domain` in the domain system (see + :ref:`polys-domainsintro`). + + By default a :py:class:`~.Poly` created from an expression with + coefficients that are combinations of integers and ``I`` (`\sqrt{-1}`) + will have the domain :ref:`ZZ_I`. + + >>> from sympy import Poly, Symbol, I + >>> x = Symbol('x') + >>> p = Poly(x**2 + I) + >>> p + Poly(x**2 + I, x, domain='ZZ_I') + >>> p.domain + ZZ_I + + The :ref:`ZZ_I` domain can be used to factorise polynomials that are + reducible over the Gaussian integers. + + >>> from sympy import factor + >>> factor(x**2 + 1) + x**2 + 1 + >>> factor(x**2 + 1, domain='ZZ_I') + (x - I)*(x + I) + + The corresponding `field of fractions`_ is the domain of the Gaussian + rationals :ref:`QQ_I`. Conversely :ref:`ZZ_I` is the `ring of integers`_ + of :ref:`QQ_I`. + + >>> from sympy import ZZ_I, QQ_I + >>> ZZ_I.get_field() + QQ_I + >>> QQ_I.get_ring() + ZZ_I + + When using the domain directly :ref:`ZZ_I` can be used as a constructor. + + >>> ZZ_I(3, 4) + (3 + 4*I) + >>> ZZ_I(5) + (5 + 0*I) + + The domain elements of :ref:`ZZ_I` are instances of + :py:class:`~.GaussianInteger` which support the rings operations + ``+,-,*,**``. + + >>> z1 = ZZ_I(5, 1) + >>> z2 = ZZ_I(2, 3) + >>> z1 + (5 + 1*I) + >>> z2 + (2 + 3*I) + >>> z1 + z2 + (7 + 4*I) + >>> z1 * z2 + (7 + 17*I) + >>> z1 ** 2 + (24 + 10*I) + + Both floor (``//``) and modulo (``%``) division work with + :py:class:`~.GaussianInteger` (see the :py:meth:`~.Domain.div` method). + + >>> z3, z4 = ZZ_I(5), ZZ_I(1, 3) + >>> z3 // z4 # floor division + (1 + -1*I) + >>> z3 % z4 # modulo division (remainder) + (1 + -2*I) + >>> (z3//z4)*z4 + z3%z4 == z3 + True + + True division (``/``) in :ref:`ZZ_I` gives an element of :ref:`QQ_I`. The + :py:meth:`~.Domain.exquo` method can be used to divide in :ref:`ZZ_I` when + exact division is possible. + + >>> z1 / z2 + (1 + -1*I) + >>> ZZ_I.exquo(z1, z2) + (1 + -1*I) + >>> z3 / z4 + (1/2 + -3/2*I) + >>> ZZ_I.exquo(z3, z4) + Traceback (most recent call last): + ... + ExactQuotientFailed: (1 + 3*I) does not divide (5 + 0*I) in ZZ_I + + The :py:meth:`~.Domain.gcd` method can be used to compute the `gcd`_ of any + two elements. + + >>> ZZ_I.gcd(ZZ_I(10), ZZ_I(2)) + (2 + 0*I) + >>> ZZ_I.gcd(ZZ_I(5), ZZ_I(2, 1)) + (2 + 1*I) + + .. _Gaussian integers: https://en.wikipedia.org/wiki/Gaussian_integer + .. _gcd: https://en.wikipedia.org/wiki/Greatest_common_divisor + + """ + dom = ZZ + dtype = GaussianInteger + zero = dtype(ZZ(0), ZZ(0)) + one = dtype(ZZ(1), ZZ(0)) + imag_unit = dtype(ZZ(0), ZZ(1)) + units = (one, imag_unit, -one, -imag_unit) # powers of i + + rep = 'ZZ_I' + + is_GaussianRing = True + is_ZZ_I = True + + def __init__(self): # override Domain.__init__ + """For constructing ZZ_I.""" + + def __eq__(self, other): + """Returns ``True`` if two domains are equivalent. """ + if isinstance(other, GaussianIntegerRing): + return True + else: + return NotImplemented + + def __hash__(self): + """Compute hash code of ``self``. """ + return hash('ZZ_I') + + @property + def has_CharacteristicZero(self): + return True + + def characteristic(self): + return 0 + + def get_ring(self): + """Returns a ring associated with ``self``. """ + return self + + def get_field(self): + """Returns a field associated with ``self``. """ + return QQ_I + + def normalize(self, d, *args): + """Return first quadrant element associated with ``d``. + + Also multiply the other arguments by the same power of i. + """ + unit = self.canonical_unit(d) + d *= unit + args = tuple(a*unit for a in args) + return (d,) + args if args else d + + def gcd(self, a, b): + """Greatest common divisor of a and b over ZZ_I.""" + while b: + a, b = b, a % b + return self.normalize(a) + + def lcm(self, a, b): + """Least common multiple of a and b over ZZ_I.""" + return (a * b) // self.gcd(a, b) + + def from_GaussianIntegerRing(K1, a, K0): + """Convert a ZZ_I element to ZZ_I.""" + return a + + def from_GaussianRationalField(K1, a, K0): + """Convert a QQ_I element to ZZ_I.""" + return K1.new(ZZ.convert(a.x), ZZ.convert(a.y)) + +ZZ_I = GaussianInteger._parent = GaussianIntegerRing() + + +class GaussianRationalField(GaussianDomain, Field): + r"""Field of Gaussian rationals ``QQ_I`` + + The :ref:`QQ_I` domain represents the `Gaussian rationals`_ `\mathbb{Q}(i)` + as a :py:class:`~.Domain` in the domain system (see + :ref:`polys-domainsintro`). + + By default a :py:class:`~.Poly` created from an expression with + coefficients that are combinations of rationals and ``I`` (`\sqrt{-1}`) + will have the domain :ref:`QQ_I`. + + >>> from sympy import Poly, Symbol, I + >>> x = Symbol('x') + >>> p = Poly(x**2 + I/2) + >>> p + Poly(x**2 + I/2, x, domain='QQ_I') + >>> p.domain + QQ_I + + The polys option ``gaussian=True`` can be used to specify that the domain + should be :ref:`QQ_I` even if the coefficients do not contain ``I`` or are + all integers. + + >>> Poly(x**2) + Poly(x**2, x, domain='ZZ') + >>> Poly(x**2 + I) + Poly(x**2 + I, x, domain='ZZ_I') + >>> Poly(x**2/2) + Poly(1/2*x**2, x, domain='QQ') + >>> Poly(x**2, gaussian=True) + Poly(x**2, x, domain='QQ_I') + >>> Poly(x**2 + I, gaussian=True) + Poly(x**2 + I, x, domain='QQ_I') + >>> Poly(x**2/2, gaussian=True) + Poly(1/2*x**2, x, domain='QQ_I') + + The :ref:`QQ_I` domain can be used to factorise polynomials that are + reducible over the Gaussian rationals. + + >>> from sympy import factor, QQ_I + >>> factor(x**2/4 + 1) + (x**2 + 4)/4 + >>> factor(x**2/4 + 1, domain='QQ_I') + (x - 2*I)*(x + 2*I)/4 + >>> factor(x**2/4 + 1, domain=QQ_I) + (x - 2*I)*(x + 2*I)/4 + + It is also possible to specify the :ref:`QQ_I` domain explicitly with + polys functions like :py:func:`~.apart`. + + >>> from sympy import apart + >>> apart(1/(1 + x**2)) + 1/(x**2 + 1) + >>> apart(1/(1 + x**2), domain=QQ_I) + I/(2*(x + I)) - I/(2*(x - I)) + + The corresponding `ring of integers`_ is the domain of the Gaussian + integers :ref:`ZZ_I`. Conversely :ref:`QQ_I` is the `field of fractions`_ + of :ref:`ZZ_I`. + + >>> from sympy import ZZ_I, QQ_I, QQ + >>> ZZ_I.get_field() + QQ_I + >>> QQ_I.get_ring() + ZZ_I + + When using the domain directly :ref:`QQ_I` can be used as a constructor. + + >>> QQ_I(3, 4) + (3 + 4*I) + >>> QQ_I(5) + (5 + 0*I) + >>> QQ_I(QQ(2, 3), QQ(4, 5)) + (2/3 + 4/5*I) + + The domain elements of :ref:`QQ_I` are instances of + :py:class:`~.GaussianRational` which support the field operations + ``+,-,*,**,/``. + + >>> z1 = QQ_I(5, 1) + >>> z2 = QQ_I(2, QQ(1, 2)) + >>> z1 + (5 + 1*I) + >>> z2 + (2 + 1/2*I) + >>> z1 + z2 + (7 + 3/2*I) + >>> z1 * z2 + (19/2 + 9/2*I) + >>> z2 ** 2 + (15/4 + 2*I) + + True division (``/``) in :ref:`QQ_I` gives an element of :ref:`QQ_I` and + is always exact. + + >>> z1 / z2 + (42/17 + -2/17*I) + >>> QQ_I.exquo(z1, z2) + (42/17 + -2/17*I) + >>> z1 == (z1/z2)*z2 + True + + Both floor (``//``) and modulo (``%``) division can be used with + :py:class:`~.GaussianRational` (see :py:meth:`~.Domain.div`) + but division is always exact so there is no remainder. + + >>> z1 // z2 + (42/17 + -2/17*I) + >>> z1 % z2 + (0 + 0*I) + >>> QQ_I.div(z1, z2) + ((42/17 + -2/17*I), (0 + 0*I)) + >>> (z1//z2)*z2 + z1%z2 == z1 + True + + .. _Gaussian rationals: https://en.wikipedia.org/wiki/Gaussian_rational + """ + dom = QQ + dtype = GaussianRational + zero = dtype(QQ(0), QQ(0)) + one = dtype(QQ(1), QQ(0)) + imag_unit = dtype(QQ(0), QQ(1)) + units = (one, imag_unit, -one, -imag_unit) # powers of i + + rep = 'QQ_I' + + is_GaussianField = True + is_QQ_I = True + + def __init__(self): # override Domain.__init__ + """For constructing QQ_I.""" + + def __eq__(self, other): + """Returns ``True`` if two domains are equivalent. """ + if isinstance(other, GaussianRationalField): + return True + else: + return NotImplemented + + def __hash__(self): + """Compute hash code of ``self``. """ + return hash('QQ_I') + + @property + def has_CharacteristicZero(self): + return True + + def characteristic(self): + return 0 + + def get_ring(self): + """Returns a ring associated with ``self``. """ + return ZZ_I + + def get_field(self): + """Returns a field associated with ``self``. """ + return self + + def as_AlgebraicField(self): + """Get equivalent domain as an ``AlgebraicField``. """ + return AlgebraicField(self.dom, I) + + def numer(self, a): + """Get the numerator of ``a``.""" + ZZ_I = self.get_ring() + return ZZ_I.convert(a * self.denom(a)) + + def denom(self, a): + """Get the denominator of ``a``.""" + ZZ = self.dom.get_ring() + QQ = self.dom + ZZ_I = self.get_ring() + denom_ZZ = ZZ.lcm(QQ.denom(a.x), QQ.denom(a.y)) + return ZZ_I(denom_ZZ, ZZ.zero) + + def from_GaussianIntegerRing(K1, a, K0): + """Convert a ZZ_I element to QQ_I.""" + return K1.new(a.x, a.y) + + def from_GaussianRationalField(K1, a, K0): + """Convert a QQ_I element to QQ_I.""" + return a + + def from_ComplexField(K1, a, K0): + """Convert a ComplexField element to QQ_I.""" + return K1.new(QQ.convert(a.real), QQ.convert(a.imag)) + + +QQ_I = GaussianRational._parent = GaussianRationalField() diff --git a/lib/python3.10/site-packages/sympy/polys/domains/gmpyfinitefield.py b/lib/python3.10/site-packages/sympy/polys/domains/gmpyfinitefield.py new file mode 100644 index 0000000000000000000000000000000000000000..2e8315a29eca8160102d66b83d953caf998b0fd7 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/polys/domains/gmpyfinitefield.py @@ -0,0 +1,16 @@ +"""Implementation of :class:`GMPYFiniteField` class. """ + + +from sympy.polys.domains.finitefield import FiniteField +from sympy.polys.domains.gmpyintegerring import GMPYIntegerRing + +from sympy.utilities import public + +@public +class GMPYFiniteField(FiniteField): + """Finite field based on GMPY integers. """ + + alias = 'FF_gmpy' + + def __init__(self, mod, symmetric=True): + super().__init__(mod, GMPYIntegerRing(), symmetric) diff --git a/lib/python3.10/site-packages/sympy/polys/domains/gmpyintegerring.py b/lib/python3.10/site-packages/sympy/polys/domains/gmpyintegerring.py new file mode 100644 index 0000000000000000000000000000000000000000..f132bbe5aff7a4164a09b9b90f00ae5f140cbd03 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/polys/domains/gmpyintegerring.py @@ -0,0 +1,105 @@ +"""Implementation of :class:`GMPYIntegerRing` class. """ + + +from sympy.polys.domains.groundtypes import ( + GMPYInteger, SymPyInteger, + factorial as gmpy_factorial, + gmpy_gcdex, gmpy_gcd, gmpy_lcm, sqrt as gmpy_sqrt, +) +from sympy.core.numbers import int_valued +from sympy.polys.domains.integerring import IntegerRing +from sympy.polys.polyerrors import CoercionFailed +from sympy.utilities import public + +@public +class GMPYIntegerRing(IntegerRing): + """Integer ring based on GMPY's ``mpz`` type. + + This will be the implementation of :ref:`ZZ` if ``gmpy`` or ``gmpy2`` is + installed. Elements will be of type ``gmpy.mpz``. + """ + + dtype = GMPYInteger + zero = dtype(0) + one = dtype(1) + tp = type(one) + alias = 'ZZ_gmpy' + + def __init__(self): + """Allow instantiation of this domain. """ + + def to_sympy(self, a): + """Convert ``a`` to a SymPy object. """ + return SymPyInteger(int(a)) + + def from_sympy(self, a): + """Convert SymPy's Integer to ``dtype``. """ + if a.is_Integer: + return GMPYInteger(a.p) + elif int_valued(a): + return GMPYInteger(int(a)) + else: + raise CoercionFailed("expected an integer, got %s" % a) + + def from_FF_python(K1, a, K0): + """Convert ``ModularInteger(int)`` to GMPY's ``mpz``. """ + return K0.to_int(a) + + def from_ZZ_python(K1, a, K0): + """Convert Python's ``int`` to GMPY's ``mpz``. """ + return GMPYInteger(a) + + def from_QQ(K1, a, K0): + """Convert Python's ``Fraction`` to GMPY's ``mpz``. """ + if a.denominator == 1: + return GMPYInteger(a.numerator) + + def from_QQ_python(K1, a, K0): + """Convert Python's ``Fraction`` to GMPY's ``mpz``. """ + if a.denominator == 1: + return GMPYInteger(a.numerator) + + def from_FF_gmpy(K1, a, K0): + """Convert ``ModularInteger(mpz)`` to GMPY's ``mpz``. """ + return K0.to_int(a) + + def from_ZZ_gmpy(K1, a, K0): + """Convert GMPY's ``mpz`` to GMPY's ``mpz``. """ + return a + + def from_QQ_gmpy(K1, a, K0): + """Convert GMPY ``mpq`` to GMPY's ``mpz``. """ + if a.denominator == 1: + return a.numerator + + def from_RealField(K1, a, K0): + """Convert mpmath's ``mpf`` to GMPY's ``mpz``. """ + p, q = K0.to_rational(a) + + if q == 1: + return GMPYInteger(p) + + def from_GaussianIntegerRing(K1, a, K0): + if a.y == 0: + return a.x + + def gcdex(self, a, b): + """Compute extended GCD of ``a`` and ``b``. """ + h, s, t = gmpy_gcdex(a, b) + return s, t, h + + def gcd(self, a, b): + """Compute GCD of ``a`` and ``b``. """ + return gmpy_gcd(a, b) + + def lcm(self, a, b): + """Compute LCM of ``a`` and ``b``. """ + return gmpy_lcm(a, b) + + def sqrt(self, a): + """Compute square root of ``a``. """ + return gmpy_sqrt(a) + + def factorial(self, a): + """Compute factorial of ``a``. """ + return gmpy_factorial(a) diff --git a/lib/python3.10/site-packages/sympy/polys/domains/gmpyrationalfield.py b/lib/python3.10/site-packages/sympy/polys/domains/gmpyrationalfield.py new file mode 100644 index 0000000000000000000000000000000000000000..10bae5b2b7b476f96ba06f637c549ee4afff4c6d --- /dev/null +++ b/lib/python3.10/site-packages/sympy/polys/domains/gmpyrationalfield.py @@ -0,0 +1,100 @@ +"""Implementation of :class:`GMPYRationalField` class. """ + + +from sympy.polys.domains.groundtypes import ( + GMPYRational, SymPyRational, + gmpy_numer, gmpy_denom, factorial as gmpy_factorial, +) +from sympy.polys.domains.rationalfield import RationalField +from sympy.polys.polyerrors import CoercionFailed +from sympy.utilities import public + +@public +class GMPYRationalField(RationalField): + """Rational field based on GMPY's ``mpq`` type. + + This will be the implementation of :ref:`QQ` if ``gmpy`` or ``gmpy2`` is + installed. Elements will be of type ``gmpy.mpq``. + """ + + dtype = GMPYRational + zero = dtype(0) + one = dtype(1) + tp = type(one) + alias = 'QQ_gmpy' + + def __init__(self): + pass + + def get_ring(self): + """Returns ring associated with ``self``. """ + from sympy.polys.domains import GMPYIntegerRing + return GMPYIntegerRing() + + def to_sympy(self, a): + """Convert ``a`` to a SymPy object. """ + return SymPyRational(int(gmpy_numer(a)), + int(gmpy_denom(a))) + + def from_sympy(self, a): + """Convert SymPy's Integer to ``dtype``. """ + if a.is_Rational: + return GMPYRational(a.p, a.q) + elif a.is_Float: + from sympy.polys.domains import RR + return GMPYRational(*map(int, RR.to_rational(a))) + else: + raise CoercionFailed("expected ``Rational`` object, got %s" % a) + + def from_ZZ_python(K1, a, K0): + """Convert a Python ``int`` object to ``dtype``. """ + return GMPYRational(a) + + def from_QQ_python(K1, a, K0): + """Convert a Python ``Fraction`` object to ``dtype``. """ + return GMPYRational(a.numerator, a.denominator) + + def from_ZZ_gmpy(K1, a, K0): + """Convert a GMPY ``mpz`` object to ``dtype``. """ + return GMPYRational(a) + + def from_QQ_gmpy(K1, a, K0): + """Convert a GMPY ``mpq`` object to ``dtype``. """ + return a + + def from_GaussianRationalField(K1, a, K0): + """Convert a ``GaussianElement`` object to ``dtype``. """ + if a.y == 0: + return GMPYRational(a.x) + + def from_RealField(K1, a, K0): + """Convert a mpmath ``mpf`` object to ``dtype``. """ + return GMPYRational(*map(int, K0.to_rational(a))) + + def exquo(self, a, b): + """Exact quotient of ``a`` and ``b``, implies ``__truediv__``. """ + return GMPYRational(a) / GMPYRational(b) + + def quo(self, a, b): + """Quotient of ``a`` and ``b``, implies ``__truediv__``. """ + return GMPYRational(a) / GMPYRational(b) + + def rem(self, a, b): + """Remainder of ``a`` and ``b``, implies nothing. """ + return self.zero + + def div(self, a, b): + """Division of ``a`` and ``b``, implies ``__truediv__``. """ + return GMPYRational(a) / GMPYRational(b), self.zero + + def numer(self, a): + """Returns numerator of ``a``. """ + return a.numerator + + def denom(self, a): + """Returns denominator of ``a``. """ + return a.denominator + + def factorial(self, a): + """Returns factorial of ``a``. """ + return GMPYRational(gmpy_factorial(int(a))) diff --git a/lib/python3.10/site-packages/sympy/polys/domains/groundtypes.py b/lib/python3.10/site-packages/sympy/polys/domains/groundtypes.py new file mode 100644 index 0000000000000000000000000000000000000000..1d50cf912a998767c4a52c5a2f3aab825e072aec --- /dev/null +++ b/lib/python3.10/site-packages/sympy/polys/domains/groundtypes.py @@ -0,0 +1,99 @@ +"""Ground types for various mathematical domains in SymPy. """ + +import builtins +from sympy.external.gmpy import GROUND_TYPES, factorial, sqrt, is_square, sqrtrem + +PythonInteger = builtins.int +PythonReal = builtins.float +PythonComplex = builtins.complex + +from .pythonrational import PythonRational + +from sympy.core.intfunc import ( + igcdex as python_gcdex, + igcd2 as python_gcd, + ilcm as python_lcm, +) + +from sympy.core.numbers import (Float as SymPyReal, Integer as SymPyInteger, Rational as SymPyRational) + + +class _GMPYInteger: + def __init__(self, obj): + pass + +class _GMPYRational: + def __init__(self, obj): + pass + + +if GROUND_TYPES == 'gmpy': + + from gmpy2 import ( + mpz as GMPYInteger, + mpq as GMPYRational, + numer as gmpy_numer, + denom as gmpy_denom, + gcdext as gmpy_gcdex, + gcd as gmpy_gcd, + lcm as gmpy_lcm, + qdiv as gmpy_qdiv, + ) + gcdex = gmpy_gcdex + gcd = gmpy_gcd + lcm = gmpy_lcm + +elif GROUND_TYPES == 'flint': + + from flint import fmpz as _fmpz + + GMPYInteger = _GMPYInteger + GMPYRational = _GMPYRational + gmpy_numer = None + gmpy_denom = None + gmpy_gcdex = None + gmpy_gcd = None + gmpy_lcm = None + gmpy_qdiv = None + + def gcd(a, b): + return a.gcd(b) + + def gcdex(a, b): + x, y, g = python_gcdex(a, b) + return _fmpz(x), _fmpz(y), _fmpz(g) + + def lcm(a, b): + return a.lcm(b) + +else: + GMPYInteger = _GMPYInteger + GMPYRational = _GMPYRational + gmpy_numer = None + gmpy_denom = None + gmpy_gcdex = None + gmpy_gcd = None + gmpy_lcm = None + gmpy_qdiv = None + gcdex = python_gcdex + gcd = python_gcd + lcm = python_lcm + + +__all__ = [ + 'PythonInteger', 'PythonReal', 'PythonComplex', + + 'PythonRational', + + 'python_gcdex', 'python_gcd', 'python_lcm', + + 'SymPyReal', 'SymPyInteger', 'SymPyRational', + + 'GMPYInteger', 'GMPYRational', 'gmpy_numer', + 'gmpy_denom', 'gmpy_gcdex', 'gmpy_gcd', 'gmpy_lcm', + 'gmpy_qdiv', + + 'factorial', 'sqrt', 'is_square', 'sqrtrem', + + 'GMPYInteger', 'GMPYRational', +] diff --git a/lib/python3.10/site-packages/sympy/polys/domains/integerring.py b/lib/python3.10/site-packages/sympy/polys/domains/integerring.py new file mode 100644 index 0000000000000000000000000000000000000000..65eaa9631cfdf138997a4ebdb362c4233fb098fb --- /dev/null +++ b/lib/python3.10/site-packages/sympy/polys/domains/integerring.py @@ -0,0 +1,276 @@ +"""Implementation of :class:`IntegerRing` class. """ + +from sympy.external.gmpy import MPZ, GROUND_TYPES + +from sympy.core.numbers import int_valued +from sympy.polys.domains.groundtypes import ( + SymPyInteger, + factorial, + gcdex, gcd, lcm, sqrt, is_square, sqrtrem, +) + +from sympy.polys.domains.characteristiczero import CharacteristicZero +from sympy.polys.domains.ring import Ring +from sympy.polys.domains.simpledomain import SimpleDomain +from sympy.polys.polyerrors import CoercionFailed +from sympy.utilities import public + +import math + +@public +class IntegerRing(Ring, CharacteristicZero, SimpleDomain): + r"""The domain ``ZZ`` representing the integers `\mathbb{Z}`. + + The :py:class:`IntegerRing` class represents the ring of integers as a + :py:class:`~.Domain` in the domain system. :py:class:`IntegerRing` is a + super class of :py:class:`PythonIntegerRing` and + :py:class:`GMPYIntegerRing` one of which will be the implementation for + :ref:`ZZ` depending on whether or not ``gmpy`` or ``gmpy2`` is installed. + + See also + ======== + + Domain + """ + + rep = 'ZZ' + alias = 'ZZ' + dtype = MPZ + zero = dtype(0) + one = dtype(1) + tp = type(one) + + + is_IntegerRing = is_ZZ = True + is_Numerical = True + is_PID = True + + has_assoc_Ring = True + has_assoc_Field = True + + def __init__(self): + """Allow instantiation of this domain. """ + + def __eq__(self, other): + """Returns ``True`` if two domains are equivalent. """ + if isinstance(other, IntegerRing): + return True + else: + return NotImplemented + + def __hash__(self): + """Compute a hash value for this domain. """ + return hash('ZZ') + + def to_sympy(self, a): + """Convert ``a`` to a SymPy object. """ + return SymPyInteger(int(a)) + + def from_sympy(self, a): + """Convert SymPy's Integer to ``dtype``. """ + if a.is_Integer: + return MPZ(a.p) + elif int_valued(a): + return MPZ(int(a)) + else: + raise CoercionFailed("expected an integer, got %s" % a) + + def get_field(self): + r"""Return the associated field of fractions :ref:`QQ` + + Returns + ======= + + :ref:`QQ`: + The associated field of fractions :ref:`QQ`, a + :py:class:`~.Domain` representing the rational numbers + `\mathbb{Q}`. + + Examples + ======== + + >>> from sympy import ZZ + >>> ZZ.get_field() + QQ + """ + from sympy.polys.domains import QQ + return QQ + + def algebraic_field(self, *extension, alias=None): + r"""Returns an algebraic field, i.e. `\mathbb{Q}(\alpha, \ldots)`. + + Parameters + ========== + + *extension : One or more :py:class:`~.Expr`. + Generators of the extension. These should be expressions that are + algebraic over `\mathbb{Q}`. + + alias : str, :py:class:`~.Symbol`, None, optional (default=None) + If provided, this will be used as the alias symbol for the + primitive element of the returned :py:class:`~.AlgebraicField`. + + Returns + ======= + + :py:class:`~.AlgebraicField` + A :py:class:`~.Domain` representing the algebraic field extension. + + Examples + ======== + + >>> from sympy import ZZ, sqrt + >>> ZZ.algebraic_field(sqrt(2)) + QQ + """ + return self.get_field().algebraic_field(*extension, alias=alias) + + def from_AlgebraicField(K1, a, K0): + """Convert a :py:class:`~.ANP` object to :ref:`ZZ`. + + See :py:meth:`~.Domain.convert`. + """ + if a.is_ground: + return K1.convert(a.LC(), K0.dom) + + def log(self, a, b): + r"""Logarithm of *a* to the base *b*. + + Parameters + ========== + + a: number + b: number + + Returns + ======= + + $\\lfloor\log(a, b)\\rfloor$: + Floor of the logarithm of *a* to the base *b* + + Examples + ======== + + >>> from sympy import ZZ + >>> ZZ.log(ZZ(8), ZZ(2)) + 3 + >>> ZZ.log(ZZ(9), ZZ(2)) + 3 + + Notes + ===== + + This function uses ``math.log`` which is based on ``float`` so it will + fail for large integer arguments. + """ + return self.dtype(int(math.log(int(a), b))) + + def from_FF(K1, a, K0): + """Convert ``ModularInteger(int)`` to GMPY's ``mpz``. """ + return MPZ(K0.to_int(a)) + + def from_FF_python(K1, a, K0): + """Convert ``ModularInteger(int)`` to GMPY's ``mpz``. """ + return MPZ(K0.to_int(a)) + + def from_ZZ(K1, a, K0): + """Convert Python's ``int`` to GMPY's ``mpz``. """ + return MPZ(a) + + def from_ZZ_python(K1, a, K0): + """Convert Python's ``int`` to GMPY's ``mpz``. """ + return MPZ(a) + + def from_QQ(K1, a, K0): + """Convert Python's ``Fraction`` to GMPY's ``mpz``. """ + if a.denominator == 1: + return MPZ(a.numerator) + + def from_QQ_python(K1, a, K0): + """Convert Python's ``Fraction`` to GMPY's ``mpz``. """ + if a.denominator == 1: + return MPZ(a.numerator) + + def from_FF_gmpy(K1, a, K0): + """Convert ``ModularInteger(mpz)`` to GMPY's ``mpz``. """ + return MPZ(K0.to_int(a)) + + def from_ZZ_gmpy(K1, a, K0): + """Convert GMPY's ``mpz`` to GMPY's ``mpz``. """ + return a + + def from_QQ_gmpy(K1, a, K0): + """Convert GMPY ``mpq`` to GMPY's ``mpz``. """ + if a.denominator == 1: + return a.numerator + + def from_RealField(K1, a, K0): + """Convert mpmath's ``mpf`` to GMPY's ``mpz``. """ + p, q = K0.to_rational(a) + + if q == 1: + # XXX: If MPZ is flint.fmpz and p is a gmpy2.mpz, then we need + # to convert via int because fmpz and mpz do not know about each + # other. + return MPZ(int(p)) + + def from_GaussianIntegerRing(K1, a, K0): + if a.y == 0: + return a.x + + def from_EX(K1, a, K0): + """Convert ``Expression`` to GMPY's ``mpz``. """ + if a.is_Integer: + return K1.from_sympy(a) + + def gcdex(self, a, b): + """Compute extended GCD of ``a`` and ``b``. """ + h, s, t = gcdex(a, b) + # XXX: This conditional logic should be handled somewhere else. + if GROUND_TYPES == 'gmpy': + return s, t, h + else: + return h, s, t + + def gcd(self, a, b): + """Compute GCD of ``a`` and ``b``. """ + return gcd(a, b) + + def lcm(self, a, b): + """Compute LCM of ``a`` and ``b``. """ + return lcm(a, b) + + def sqrt(self, a): + """Compute square root of ``a``. """ + return sqrt(a) + + def is_square(self, a): + """Return ``True`` if ``a`` is a square. + + Explanation + =========== + An integer is a square if and only if there exists an integer + ``b`` such that ``b * b == a``. + """ + return is_square(a) + + def exsqrt(self, a): + """Non-negative square root of ``a`` if ``a`` is a square. + + See also + ======== + is_square + """ + if a < 0: + return None + root, rem = sqrtrem(a) + if rem != 0: + return None + return root + + def factorial(self, a): + """Compute factorial of ``a``. """ + return factorial(a) + + +ZZ = IntegerRing() diff --git a/lib/python3.10/site-packages/sympy/polys/domains/modularinteger.py b/lib/python3.10/site-packages/sympy/polys/domains/modularinteger.py new file mode 100644 index 0000000000000000000000000000000000000000..39a0237563c69a77e4736466d1ebcaa7ca39485f --- /dev/null +++ b/lib/python3.10/site-packages/sympy/polys/domains/modularinteger.py @@ -0,0 +1,237 @@ +"""Implementation of :class:`ModularInteger` class. """ + +from __future__ import annotations +from typing import Any + +import operator + +from sympy.polys.polyutils import PicklableWithSlots +from sympy.polys.polyerrors import CoercionFailed +from sympy.polys.domains.domainelement import DomainElement + +from sympy.utilities import public +from sympy.utilities.exceptions import sympy_deprecation_warning + +@public +class ModularInteger(PicklableWithSlots, DomainElement): + """A class representing a modular integer. """ + + mod, dom, sym, _parent = None, None, None, None + + __slots__ = ('val',) + + def parent(self): + return self._parent + + def __init__(self, val): + if isinstance(val, self.__class__): + self.val = val.val % self.mod + else: + self.val = self.dom.convert(val) % self.mod + + def modulus(self): + return self.mod + + def __hash__(self): + return hash((self.val, self.mod)) + + def __repr__(self): + return "%s(%s)" % (self.__class__.__name__, self.val) + + def __str__(self): + return "%s mod %s" % (self.val, self.mod) + + def __int__(self): + return int(self.val) + + def to_int(self): + + sympy_deprecation_warning( + """ModularInteger.to_int() is deprecated. + + Use int(a) or K = GF(p) and K.to_int(a) instead of a.to_int(). + """, + deprecated_since_version="1.13", + active_deprecations_target="modularinteger-to-int", + ) + + if self.sym: + if self.val <= self.mod // 2: + return self.val + else: + return self.val - self.mod + else: + return self.val + + def __pos__(self): + return self + + def __neg__(self): + return self.__class__(-self.val) + + @classmethod + def _get_val(cls, other): + if isinstance(other, cls): + return other.val + else: + try: + return cls.dom.convert(other) + except CoercionFailed: + return None + + def __add__(self, other): + val = self._get_val(other) + + if val is not None: + return self.__class__(self.val + val) + else: + return NotImplemented + + def __radd__(self, other): + return self.__add__(other) + + def __sub__(self, other): + val = self._get_val(other) + + if val is not None: + return self.__class__(self.val - val) + else: + return NotImplemented + + def __rsub__(self, other): + return (-self).__add__(other) + + def __mul__(self, other): + val = self._get_val(other) + + if val is not None: + return self.__class__(self.val * val) + else: + return NotImplemented + + def __rmul__(self, other): + return self.__mul__(other) + + def __truediv__(self, other): + val = self._get_val(other) + + if val is not None: + return self.__class__(self.val * self._invert(val)) + else: + return NotImplemented + + def __rtruediv__(self, other): + return self.invert().__mul__(other) + + def __mod__(self, other): + val = self._get_val(other) + + if val is not None: + return self.__class__(self.val % val) + else: + return NotImplemented + + def __rmod__(self, other): + val = self._get_val(other) + + if val is not None: + return self.__class__(val % self.val) + else: + return NotImplemented + + def __pow__(self, exp): + if not exp: + return self.__class__(self.dom.one) + + if exp < 0: + val, exp = self.invert().val, -exp + else: + val = self.val + + return self.__class__(pow(val, int(exp), self.mod)) + + def _compare(self, other, op): + val = self._get_val(other) + + if val is None: + return NotImplemented + + return op(self.val, val % self.mod) + + def _compare_deprecated(self, other, op): + val = self._get_val(other) + + if val is None: + return NotImplemented + + sympy_deprecation_warning( + """Ordered comparisons with modular integers are deprecated. + + Use e.g. int(a) < int(b) instead of a < b. + """, + deprecated_since_version="1.13", + active_deprecations_target="modularinteger-compare", + stacklevel=4, + ) + + return op(self.val, val % self.mod) + + def __eq__(self, other): + return self._compare(other, operator.eq) + + def __ne__(self, other): + return self._compare(other, operator.ne) + + def __lt__(self, other): + return self._compare_deprecated(other, operator.lt) + + def __le__(self, other): + return self._compare_deprecated(other, operator.le) + + def __gt__(self, other): + return self._compare_deprecated(other, operator.gt) + + def __ge__(self, other): + return self._compare_deprecated(other, operator.ge) + + def __bool__(self): + return bool(self.val) + + @classmethod + def _invert(cls, value): + return cls.dom.invert(value, cls.mod) + + def invert(self): + return self.__class__(self._invert(self.val)) + +_modular_integer_cache: dict[tuple[Any, Any, Any], type[ModularInteger]] = {} + +def ModularIntegerFactory(_mod, _dom, _sym, parent): + """Create custom class for specific integer modulus.""" + try: + _mod = _dom.convert(_mod) + except CoercionFailed: + ok = False + else: + ok = True + + if not ok or _mod < 1: + raise ValueError("modulus must be a positive integer, got %s" % _mod) + + key = _mod, _dom, _sym + + try: + cls = _modular_integer_cache[key] + except KeyError: + class cls(ModularInteger): + mod, dom, sym = _mod, _dom, _sym + _parent = parent + + if _sym: + cls.__name__ = "SymmetricModularIntegerMod%s" % _mod + else: + cls.__name__ = "ModularIntegerMod%s" % _mod + + _modular_integer_cache[key] = cls + + return cls diff --git a/lib/python3.10/site-packages/sympy/polys/domains/mpelements.py b/lib/python3.10/site-packages/sympy/polys/domains/mpelements.py new file mode 100644 index 0000000000000000000000000000000000000000..3652c268d714093027a194c30c7ecd5bc680601b --- /dev/null +++ b/lib/python3.10/site-packages/sympy/polys/domains/mpelements.py @@ -0,0 +1,177 @@ +"""Real and complex elements. """ + + +from sympy.external.gmpy import MPQ +from sympy.polys.domains.domainelement import DomainElement +from sympy.utilities import public + +from mpmath.ctx_mp_python import PythonMPContext, _mpf, _mpc, _constant +from mpmath.libmp import (MPZ_ONE, fzero, fone, finf, fninf, fnan, + round_nearest, mpf_mul, repr_dps, int_types, + from_int, from_float, from_str, to_rational) + + +@public +class RealElement(_mpf, DomainElement): + """An element of a real domain. """ + + __slots__ = ('__mpf__',) + + def _set_mpf(self, val): + self.__mpf__ = val + + _mpf_ = property(lambda self: self.__mpf__, _set_mpf) + + def parent(self): + return self.context._parent + +@public +class ComplexElement(_mpc, DomainElement): + """An element of a complex domain. """ + + __slots__ = ('__mpc__',) + + def _set_mpc(self, val): + self.__mpc__ = val + + _mpc_ = property(lambda self: self.__mpc__, _set_mpc) + + def parent(self): + return self.context._parent + +new = object.__new__ + +@public +class MPContext(PythonMPContext): + + def __init__(ctx, prec=53, dps=None, tol=None, real=False): + ctx._prec_rounding = [prec, round_nearest] + + if dps is None: + ctx._set_prec(prec) + else: + ctx._set_dps(dps) + + ctx.mpf = RealElement + ctx.mpc = ComplexElement + ctx.mpf._ctxdata = [ctx.mpf, new, ctx._prec_rounding] + ctx.mpc._ctxdata = [ctx.mpc, new, ctx._prec_rounding] + + if real: + ctx.mpf.context = ctx + else: + ctx.mpc.context = ctx + + ctx.constant = _constant + ctx.constant._ctxdata = [ctx.mpf, new, ctx._prec_rounding] + ctx.constant.context = ctx + + ctx.types = [ctx.mpf, ctx.mpc, ctx.constant] + ctx.trap_complex = True + ctx.pretty = True + + if tol is None: + ctx.tol = ctx._make_tol() + elif tol is False: + ctx.tol = fzero + else: + ctx.tol = ctx._convert_tol(tol) + + ctx.tolerance = ctx.make_mpf(ctx.tol) + + if not ctx.tolerance: + ctx.max_denom = 1000000 + else: + ctx.max_denom = int(1/ctx.tolerance) + + ctx.zero = ctx.make_mpf(fzero) + ctx.one = ctx.make_mpf(fone) + ctx.j = ctx.make_mpc((fzero, fone)) + ctx.inf = ctx.make_mpf(finf) + ctx.ninf = ctx.make_mpf(fninf) + ctx.nan = ctx.make_mpf(fnan) + + def _make_tol(ctx): + hundred = (0, 25, 2, 5) + eps = (0, MPZ_ONE, 1-ctx.prec, 1) + return mpf_mul(hundred, eps) + + def make_tol(ctx): + return ctx.make_mpf(ctx._make_tol()) + + def _convert_tol(ctx, tol): + if isinstance(tol, int_types): + return from_int(tol) + if isinstance(tol, float): + return from_float(tol) + if hasattr(tol, "_mpf_"): + return tol._mpf_ + prec, rounding = ctx._prec_rounding + if isinstance(tol, str): + return from_str(tol, prec, rounding) + raise ValueError("expected a real number, got %s" % tol) + + def _convert_fallback(ctx, x, strings): + raise TypeError("cannot create mpf from " + repr(x)) + + @property + def _repr_digits(ctx): + return repr_dps(ctx._prec) + + @property + def _str_digits(ctx): + return ctx._dps + + def to_rational(ctx, s, limit=True): + p, q = to_rational(s._mpf_) + + # Needed for GROUND_TYPES=flint if gmpy2 is installed because mpmath's + # to_rational() function returns a gmpy2.mpz instance and if MPQ is + # flint.fmpq then MPQ(p, q) will fail. + p = int(p) + + if not limit or q <= ctx.max_denom: + return p, q + + p0, q0, p1, q1 = 0, 1, 1, 0 + n, d = p, q + + while True: + a = n//d + q2 = q0 + a*q1 + if q2 > ctx.max_denom: + break + p0, q0, p1, q1 = p1, q1, p0 + a*p1, q2 + n, d = d, n - a*d + + k = (ctx.max_denom - q0)//q1 + + number = MPQ(p, q) + bound1 = MPQ(p0 + k*p1, q0 + k*q1) + bound2 = MPQ(p1, q1) + + if not bound2 or not bound1: + return p, q + elif abs(bound2 - number) <= abs(bound1 - number): + return bound2.numerator, bound2.denominator + else: + return bound1.numerator, bound1.denominator + + def almosteq(ctx, s, t, rel_eps=None, abs_eps=None): + t = ctx.convert(t) + if abs_eps is None and rel_eps is None: + rel_eps = abs_eps = ctx.tolerance or ctx.make_tol() + if abs_eps is None: + abs_eps = ctx.convert(rel_eps) + elif rel_eps is None: + rel_eps = ctx.convert(abs_eps) + diff = abs(s-t) + if diff <= abs_eps: + return True + abss = abs(s) + abst = abs(t) + if abss < abst: + err = diff/abst + else: + err = diff/abss + return err <= rel_eps diff --git a/lib/python3.10/site-packages/sympy/polys/domains/old_fractionfield.py b/lib/python3.10/site-packages/sympy/polys/domains/old_fractionfield.py new file mode 100644 index 0000000000000000000000000000000000000000..25d849c39e45259728479ab0305d4956053ae743 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/polys/domains/old_fractionfield.py @@ -0,0 +1,188 @@ +"""Implementation of :class:`FractionField` class. """ + + +from sympy.polys.domains.field import Field +from sympy.polys.domains.compositedomain import CompositeDomain +from sympy.polys.polyclasses import DMF +from sympy.polys.polyerrors import GeneratorsNeeded +from sympy.polys.polyutils import dict_from_basic, basic_from_dict, _dict_reorder +from sympy.utilities import public + +@public +class FractionField(Field, CompositeDomain): + """A class for representing rational function fields. """ + + dtype = DMF + is_FractionField = is_Frac = True + + has_assoc_Ring = True + has_assoc_Field = True + + def __init__(self, dom, *gens): + if not gens: + raise GeneratorsNeeded("generators not specified") + + lev = len(gens) - 1 + self.ngens = len(gens) + + self.zero = self.dtype.zero(lev, dom) + self.one = self.dtype.one(lev, dom) + + self.domain = self.dom = dom + self.symbols = self.gens = gens + + def set_domain(self, dom): + """Make a new fraction field with given domain. """ + return self.__class__(dom, *self.gens) + + def new(self, element): + return self.dtype(element, self.dom, len(self.gens) - 1) + + def __str__(self): + return str(self.dom) + '(' + ','.join(map(str, self.gens)) + ')' + + def __hash__(self): + return hash((self.__class__.__name__, self.dtype, self.dom, self.gens)) + + def __eq__(self, other): + """Returns ``True`` if two domains are equivalent. """ + return isinstance(other, FractionField) and \ + self.dtype == other.dtype and self.dom == other.dom and self.gens == other.gens + + def to_sympy(self, a): + """Convert ``a`` to a SymPy object. """ + return (basic_from_dict(a.numer().to_sympy_dict(), *self.gens) / + basic_from_dict(a.denom().to_sympy_dict(), *self.gens)) + + def from_sympy(self, a): + """Convert SymPy's expression to ``dtype``. """ + p, q = a.as_numer_denom() + + num, _ = dict_from_basic(p, gens=self.gens) + den, _ = dict_from_basic(q, gens=self.gens) + + for k, v in num.items(): + num[k] = self.dom.from_sympy(v) + + for k, v in den.items(): + den[k] = self.dom.from_sympy(v) + + return self((num, den)).cancel() + + def from_ZZ(K1, a, K0): + """Convert a Python ``int`` object to ``dtype``. """ + return K1(K1.dom.convert(a, K0)) + + def from_ZZ_python(K1, a, K0): + """Convert a Python ``int`` object to ``dtype``. """ + return K1(K1.dom.convert(a, K0)) + + def from_QQ_python(K1, a, K0): + """Convert a Python ``Fraction`` object to ``dtype``. """ + return K1(K1.dom.convert(a, K0)) + + def from_ZZ_gmpy(K1, a, K0): + """Convert a GMPY ``mpz`` object to ``dtype``. """ + return K1(K1.dom.convert(a, K0)) + + def from_QQ_gmpy(K1, a, K0): + """Convert a GMPY ``mpq`` object to ``dtype``. """ + return K1(K1.dom.convert(a, K0)) + + def from_RealField(K1, a, K0): + """Convert a mpmath ``mpf`` object to ``dtype``. """ + return K1(K1.dom.convert(a, K0)) + + def from_GlobalPolynomialRing(K1, a, K0): + """Convert a ``DMF`` object to ``dtype``. """ + if K1.gens == K0.gens: + if K1.dom == K0.dom: + return K1(a.to_list()) + else: + return K1(a.convert(K1.dom).to_list()) + else: + monoms, coeffs = _dict_reorder(a.to_dict(), K0.gens, K1.gens) + + if K1.dom != K0.dom: + coeffs = [ K1.dom.convert(c, K0.dom) for c in coeffs ] + + return K1(dict(zip(monoms, coeffs))) + + def from_FractionField(K1, a, K0): + """ + Convert a fraction field element to another fraction field. + + Examples + ======== + + >>> from sympy.polys.polyclasses import DMF + >>> from sympy.polys.domains import ZZ, QQ + >>> from sympy.abc import x + + >>> f = DMF(([ZZ(1), ZZ(2)], [ZZ(1), ZZ(1)]), ZZ) + + >>> QQx = QQ.old_frac_field(x) + >>> ZZx = ZZ.old_frac_field(x) + + >>> QQx.from_FractionField(f, ZZx) + DMF([1, 2], [1, 1], QQ) + + """ + if K1.gens == K0.gens: + if K1.dom == K0.dom: + return a + else: + return K1((a.numer().convert(K1.dom).to_list(), + a.denom().convert(K1.dom).to_list())) + elif set(K0.gens).issubset(K1.gens): + nmonoms, ncoeffs = _dict_reorder( + a.numer().to_dict(), K0.gens, K1.gens) + dmonoms, dcoeffs = _dict_reorder( + a.denom().to_dict(), K0.gens, K1.gens) + + if K1.dom != K0.dom: + ncoeffs = [ K1.dom.convert(c, K0.dom) for c in ncoeffs ] + dcoeffs = [ K1.dom.convert(c, K0.dom) for c in dcoeffs ] + + return K1((dict(zip(nmonoms, ncoeffs)), dict(zip(dmonoms, dcoeffs)))) + + def get_ring(self): + """Returns a ring associated with ``self``. """ + from sympy.polys.domains import PolynomialRing + return PolynomialRing(self.dom, *self.gens) + + def poly_ring(self, *gens): + """Returns a polynomial ring, i.e. `K[X]`. """ + raise NotImplementedError('nested domains not allowed') + + def frac_field(self, *gens): + """Returns a fraction field, i.e. `K(X)`. """ + raise NotImplementedError('nested domains not allowed') + + def is_positive(self, a): + """Returns True if ``a`` is positive. """ + return self.dom.is_positive(a.numer().LC()) + + def is_negative(self, a): + """Returns True if ``a`` is negative. """ + return self.dom.is_negative(a.numer().LC()) + + def is_nonpositive(self, a): + """Returns True if ``a`` is non-positive. """ + return self.dom.is_nonpositive(a.numer().LC()) + + def is_nonnegative(self, a): + """Returns True if ``a`` is non-negative. """ + return self.dom.is_nonnegative(a.numer().LC()) + + def numer(self, a): + """Returns numerator of ``a``. """ + return a.numer() + + def denom(self, a): + """Returns denominator of ``a``. """ + return a.denom() + + def factorial(self, a): + """Returns factorial of ``a``. """ + return self.dtype(self.dom.factorial(a)) diff --git a/lib/python3.10/site-packages/sympy/polys/domains/old_polynomialring.py b/lib/python3.10/site-packages/sympy/polys/domains/old_polynomialring.py new file mode 100644 index 0000000000000000000000000000000000000000..c29a4529aac3c64b29d8c670ac45b6c100294ced --- /dev/null +++ b/lib/python3.10/site-packages/sympy/polys/domains/old_polynomialring.py @@ -0,0 +1,490 @@ +"""Implementation of :class:`PolynomialRing` class. """ + + +from sympy.polys.agca.modules import FreeModulePolyRing +from sympy.polys.domains.compositedomain import CompositeDomain +from sympy.polys.domains.old_fractionfield import FractionField +from sympy.polys.domains.ring import Ring +from sympy.polys.orderings import monomial_key, build_product_order +from sympy.polys.polyclasses import DMP, DMF +from sympy.polys.polyerrors import (GeneratorsNeeded, PolynomialError, + CoercionFailed, ExactQuotientFailed, NotReversible) +from sympy.polys.polyutils import dict_from_basic, basic_from_dict, _dict_reorder +from sympy.utilities import public +from sympy.utilities.iterables import iterable + + +@public +class PolynomialRingBase(Ring, CompositeDomain): + """ + Base class for generalized polynomial rings. + + This base class should be used for uniform access to generalized polynomial + rings. Subclasses only supply information about the element storage etc. + + Do not instantiate. + """ + + has_assoc_Ring = True + has_assoc_Field = True + + default_order = "grevlex" + + def __init__(self, dom, *gens, **opts): + if not gens: + raise GeneratorsNeeded("generators not specified") + + lev = len(gens) - 1 + self.ngens = len(gens) + + self.zero = self.dtype.zero(lev, dom) + self.one = self.dtype.one(lev, dom) + + self.domain = self.dom = dom + self.symbols = self.gens = gens + # NOTE 'order' may not be set if inject was called through CompositeDomain + self.order = opts.get('order', monomial_key(self.default_order)) + + def set_domain(self, dom): + """Return a new polynomial ring with given domain. """ + return self.__class__(dom, *self.gens, order=self.order) + + def new(self, element): + return self.dtype(element, self.dom, len(self.gens) - 1) + + def _ground_new(self, element): + return self.one.ground_new(element) + + def _from_dict(self, element): + return DMP.from_dict(element, len(self.gens) - 1, self.dom) + + def __str__(self): + s_order = str(self.order) + orderstr = ( + " order=" + s_order) if s_order != self.default_order else "" + return str(self.dom) + '[' + ','.join(map(str, self.gens)) + orderstr + ']' + + def __hash__(self): + return hash((self.__class__.__name__, self.dtype, self.dom, + self.gens, self.order)) + + def __eq__(self, other): + """Returns ``True`` if two domains are equivalent. """ + return isinstance(other, PolynomialRingBase) and \ + self.dtype == other.dtype and self.dom == other.dom and \ + self.gens == other.gens and self.order == other.order + + def from_ZZ(K1, a, K0): + """Convert a Python ``int`` object to ``dtype``. """ + return K1._ground_new(K1.dom.convert(a, K0)) + + def from_ZZ_python(K1, a, K0): + """Convert a Python ``int`` object to ``dtype``. """ + return K1._ground_new(K1.dom.convert(a, K0)) + + def from_QQ(K1, a, K0): + """Convert a Python ``Fraction`` object to ``dtype``. """ + return K1._ground_new(K1.dom.convert(a, K0)) + + def from_QQ_python(K1, a, K0): + """Convert a Python ``Fraction`` object to ``dtype``. """ + return K1._ground_new(K1.dom.convert(a, K0)) + + def from_ZZ_gmpy(K1, a, K0): + """Convert a GMPY ``mpz`` object to ``dtype``. """ + return K1._ground_new(K1.dom.convert(a, K0)) + + def from_QQ_gmpy(K1, a, K0): + """Convert a GMPY ``mpq`` object to ``dtype``. """ + return K1._ground_new(K1.dom.convert(a, K0)) + + def from_RealField(K1, a, K0): + """Convert a mpmath ``mpf`` object to ``dtype``. """ + return K1._ground_new(K1.dom.convert(a, K0)) + + def from_AlgebraicField(K1, a, K0): + """Convert a ``ANP`` object to ``dtype``. """ + if K1.dom == K0: + return K1._ground_new(a) + + def from_PolynomialRing(K1, a, K0): + """Convert a ``PolyElement`` object to ``dtype``. """ + if K1.gens == K0.symbols: + if K1.dom == K0.dom: + return K1(dict(a)) # set the correct ring + else: + convert_dom = lambda c: K1.dom.convert_from(c, K0.dom) + return K1._from_dict({m: convert_dom(c) for m, c in a.items()}) + else: + monoms, coeffs = _dict_reorder(a.to_dict(), K0.symbols, K1.gens) + + if K1.dom != K0.dom: + coeffs = [ K1.dom.convert(c, K0.dom) for c in coeffs ] + + return K1._from_dict(dict(zip(monoms, coeffs))) + + def from_GlobalPolynomialRing(K1, a, K0): + """Convert a ``DMP`` object to ``dtype``. """ + if K1.gens == K0.gens: + if K1.dom != K0.dom: + a = a.convert(K1.dom) + return K1(a.to_list()) + else: + monoms, coeffs = _dict_reorder(a.to_dict(), K0.gens, K1.gens) + + if K1.dom != K0.dom: + coeffs = [ K1.dom.convert(c, K0.dom) for c in coeffs ] + + return K1(dict(zip(monoms, coeffs))) + + def get_field(self): + """Returns a field associated with ``self``. """ + return FractionField(self.dom, *self.gens) + + def poly_ring(self, *gens): + """Returns a polynomial ring, i.e. ``K[X]``. """ + raise NotImplementedError('nested domains not allowed') + + def frac_field(self, *gens): + """Returns a fraction field, i.e. ``K(X)``. """ + raise NotImplementedError('nested domains not allowed') + + def revert(self, a): + try: + return self.exquo(self.one, a) + except (ExactQuotientFailed, ZeroDivisionError): + raise NotReversible('%s is not a unit' % a) + + def gcdex(self, a, b): + """Extended GCD of ``a`` and ``b``. """ + return a.gcdex(b) + + def gcd(self, a, b): + """Returns GCD of ``a`` and ``b``. """ + return a.gcd(b) + + def lcm(self, a, b): + """Returns LCM of ``a`` and ``b``. """ + return a.lcm(b) + + def factorial(self, a): + """Returns factorial of ``a``. """ + return self.dtype(self.dom.factorial(a)) + + def _vector_to_sdm(self, v, order): + """ + For internal use by the modules class. + + Convert an iterable of elements of this ring into a sparse distributed + module element. + """ + raise NotImplementedError + + def _sdm_to_dics(self, s, n): + """Helper for _sdm_to_vector.""" + from sympy.polys.distributedmodules import sdm_to_dict + dic = sdm_to_dict(s) + res = [{} for _ in range(n)] + for k, v in dic.items(): + res[k[0]][k[1:]] = v + return res + + def _sdm_to_vector(self, s, n): + """ + For internal use by the modules class. + + Convert a sparse distributed module into a list of length ``n``. + + Examples + ======== + + >>> from sympy import QQ, ilex + >>> from sympy.abc import x, y + >>> R = QQ.old_poly_ring(x, y, order=ilex) + >>> L = [((1, 1, 1), QQ(1)), ((0, 1, 0), QQ(1)), ((0, 0, 1), QQ(2))] + >>> R._sdm_to_vector(L, 2) + [DMF([[1], [2, 0]], [[1]], QQ), DMF([[1, 0], []], [[1]], QQ)] + """ + dics = self._sdm_to_dics(s, n) + # NOTE this works for global and local rings! + return [self(x) for x in dics] + + def free_module(self, rank): + """ + Generate a free module of rank ``rank`` over ``self``. + + Examples + ======== + + >>> from sympy.abc import x + >>> from sympy import QQ + >>> QQ.old_poly_ring(x).free_module(2) + QQ[x]**2 + """ + return FreeModulePolyRing(self, rank) + + +def _vector_to_sdm_helper(v, order): + """Helper method for common code in Global and Local poly rings.""" + from sympy.polys.distributedmodules import sdm_from_dict + d = {} + for i, e in enumerate(v): + for key, value in e.to_dict().items(): + d[(i,) + key] = value + return sdm_from_dict(d, order) + + +@public +class GlobalPolynomialRing(PolynomialRingBase): + """A true polynomial ring, with objects DMP. """ + + is_PolynomialRing = is_Poly = True + dtype = DMP + + def new(self, element): + if isinstance(element, dict): + return DMP.from_dict(element, len(self.gens) - 1, self.dom) + elif element in self.dom: + return self._ground_new(self.dom.convert(element)) + else: + return self.dtype(element, self.dom, len(self.gens) - 1) + + def from_FractionField(K1, a, K0): + """ + Convert a ``DMF`` object to ``DMP``. + + Examples + ======== + + >>> from sympy.polys.polyclasses import DMP, DMF + >>> from sympy.polys.domains import ZZ + >>> from sympy.abc import x + + >>> f = DMF(([ZZ(1), ZZ(1)], [ZZ(1)]), ZZ) + >>> K = ZZ.old_frac_field(x) + + >>> F = ZZ.old_poly_ring(x).from_FractionField(f, K) + + >>> F == DMP([ZZ(1), ZZ(1)], ZZ) + True + >>> type(F) # doctest: +SKIP + + + """ + if a.denom().is_one: + return K1.from_GlobalPolynomialRing(a.numer(), K0) + + def to_sympy(self, a): + """Convert ``a`` to a SymPy object. """ + return basic_from_dict(a.to_sympy_dict(), *self.gens) + + def from_sympy(self, a): + """Convert SymPy's expression to ``dtype``. """ + try: + rep, _ = dict_from_basic(a, gens=self.gens) + except PolynomialError: + raise CoercionFailed("Cannot convert %s to type %s" % (a, self)) + + for k, v in rep.items(): + rep[k] = self.dom.from_sympy(v) + + return DMP.from_dict(rep, self.ngens - 1, self.dom) + + def is_positive(self, a): + """Returns True if ``LC(a)`` is positive. """ + return self.dom.is_positive(a.LC()) + + def is_negative(self, a): + """Returns True if ``LC(a)`` is negative. """ + return self.dom.is_negative(a.LC()) + + def is_nonpositive(self, a): + """Returns True if ``LC(a)`` is non-positive. """ + return self.dom.is_nonpositive(a.LC()) + + def is_nonnegative(self, a): + """Returns True if ``LC(a)`` is non-negative. """ + return self.dom.is_nonnegative(a.LC()) + + def _vector_to_sdm(self, v, order): + """ + Examples + ======== + + >>> from sympy import lex, QQ + >>> from sympy.abc import x, y + >>> R = QQ.old_poly_ring(x, y) + >>> f = R.convert(x + 2*y) + >>> g = R.convert(x * y) + >>> R._vector_to_sdm([f, g], lex) + [((1, 1, 1), 1), ((0, 1, 0), 1), ((0, 0, 1), 2)] + """ + return _vector_to_sdm_helper(v, order) + + +class GeneralizedPolynomialRing(PolynomialRingBase): + """A generalized polynomial ring, with objects DMF. """ + + dtype = DMF + + def new(self, a): + """Construct an element of ``self`` domain from ``a``. """ + res = self.dtype(a, self.dom, len(self.gens) - 1) + + # make sure res is actually in our ring + if res.denom().terms(order=self.order)[0][0] != (0,)*len(self.gens): + from sympy.printing.str import sstr + raise CoercionFailed("denominator %s not allowed in %s" + % (sstr(res), self)) + return res + + def __contains__(self, a): + try: + a = self.convert(a) + except CoercionFailed: + return False + return a.denom().terms(order=self.order)[0][0] == (0,)*len(self.gens) + + def to_sympy(self, a): + """Convert ``a`` to a SymPy object. """ + return (basic_from_dict(a.numer().to_sympy_dict(), *self.gens) / + basic_from_dict(a.denom().to_sympy_dict(), *self.gens)) + + def from_sympy(self, a): + """Convert SymPy's expression to ``dtype``. """ + p, q = a.as_numer_denom() + + num, _ = dict_from_basic(p, gens=self.gens) + den, _ = dict_from_basic(q, gens=self.gens) + + for k, v in num.items(): + num[k] = self.dom.from_sympy(v) + + for k, v in den.items(): + den[k] = self.dom.from_sympy(v) + + return self((num, den)).cancel() + + def exquo(self, a, b): + """Exact quotient of ``a`` and ``b``. """ + # Elements are DMF that will always divide (except 0). The result is + # not guaranteed to be in this ring, so we have to check that. + r = a / b + + try: + r = self.new((r.num, r.den)) + except CoercionFailed: + raise ExactQuotientFailed(a, b, self) + + return r + + def from_FractionField(K1, a, K0): + dmf = K1.get_field().from_FractionField(a, K0) + return K1((dmf.num, dmf.den)) + + def _vector_to_sdm(self, v, order): + """ + Turn an iterable into a sparse distributed module. + + Note that the vector is multiplied by a unit first to make all entries + polynomials. + + Examples + ======== + + >>> from sympy import ilex, QQ + >>> from sympy.abc import x, y + >>> R = QQ.old_poly_ring(x, y, order=ilex) + >>> f = R.convert((x + 2*y) / (1 + x)) + >>> g = R.convert(x * y) + >>> R._vector_to_sdm([f, g], ilex) + [((0, 0, 1), 2), ((0, 1, 0), 1), ((1, 1, 1), 1), ((1, + 2, 1), 1)] + """ + # NOTE this is quite inefficient... + u = self.one.numer() + for x in v: + u *= x.denom() + return _vector_to_sdm_helper([x.numer()*u/x.denom() for x in v], order) + + +@public +def PolynomialRing(dom, *gens, **opts): + r""" + Create a generalized multivariate polynomial ring. + + A generalized polynomial ring is defined by a ground field `K`, a set + of generators (typically `x_1, \ldots, x_n`) and a monomial order `<`. + The monomial order can be global, local or mixed. In any case it induces + a total ordering on the monomials, and there exists for every (non-zero) + polynomial `f \in K[x_1, \ldots, x_n]` a well-defined "leading monomial" + `LM(f) = LM(f, >)`. One can then define a multiplicative subset + `S = S_> = \{f \in K[x_1, \ldots, x_n] | LM(f) = 1\}`. The generalized + polynomial ring corresponding to the monomial order is + `R = S^{-1}K[x_1, \ldots, x_n]`. + + If `>` is a so-called global order, that is `1` is the smallest monomial, + then we just have `S = K` and `R = K[x_1, \ldots, x_n]`. + + Examples + ======== + + A few examples may make this clearer. + + >>> from sympy.abc import x, y + >>> from sympy import QQ + + Our first ring uses global lexicographic order. + + >>> R1 = QQ.old_poly_ring(x, y, order=(("lex", x, y),)) + + The second ring uses local lexicographic order. Note that when using a + single (non-product) order, you can just specify the name and omit the + variables: + + >>> R2 = QQ.old_poly_ring(x, y, order="ilex") + + The third and fourth rings use a mixed orders: + + >>> o1 = (("ilex", x), ("lex", y)) + >>> o2 = (("lex", x), ("ilex", y)) + >>> R3 = QQ.old_poly_ring(x, y, order=o1) + >>> R4 = QQ.old_poly_ring(x, y, order=o2) + + We will investigate what elements of `K(x, y)` are contained in the various + rings. + + >>> L = [x, 1/x, y/(1 + x), 1/(1 + y), 1/(1 + x*y)] + >>> test = lambda R: [f in R for f in L] + + The first ring is just `K[x, y]`: + + >>> test(R1) + [True, False, False, False, False] + + The second ring is R1 localised at the maximal ideal (x, y): + + >>> test(R2) + [True, False, True, True, True] + + The third ring is R1 localised at the prime ideal (x): + + >>> test(R3) + [True, False, True, False, True] + + Finally the fourth ring is R1 localised at `S = K[x, y] \setminus yK[y]`: + + >>> test(R4) + [True, False, False, True, False] + """ + + order = opts.get("order", GeneralizedPolynomialRing.default_order) + if iterable(order): + order = build_product_order(order, gens) + order = monomial_key(order) + opts['order'] = order + + if order.is_global: + return GlobalPolynomialRing(dom, *gens, **opts) + else: + return GeneralizedPolynomialRing(dom, *gens, **opts) diff --git a/lib/python3.10/site-packages/sympy/polys/domains/polynomialring.py b/lib/python3.10/site-packages/sympy/polys/domains/polynomialring.py new file mode 100644 index 0000000000000000000000000000000000000000..bad73208f866c33c7ffcbffab2b7e9eed97c94ec --- /dev/null +++ b/lib/python3.10/site-packages/sympy/polys/domains/polynomialring.py @@ -0,0 +1,199 @@ +"""Implementation of :class:`PolynomialRing` class. """ + + +from sympy.polys.domains.ring import Ring +from sympy.polys.domains.compositedomain import CompositeDomain + +from sympy.polys.polyerrors import CoercionFailed, GeneratorsError +from sympy.utilities import public + +@public +class PolynomialRing(Ring, CompositeDomain): + """A class for representing multivariate polynomial rings. """ + + is_PolynomialRing = is_Poly = True + + has_assoc_Ring = True + has_assoc_Field = True + + def __init__(self, domain_or_ring, symbols=None, order=None): + from sympy.polys.rings import PolyRing + + if isinstance(domain_or_ring, PolyRing) and symbols is None and order is None: + ring = domain_or_ring + else: + ring = PolyRing(symbols, domain_or_ring, order) + + self.ring = ring + self.dtype = ring.dtype + + self.gens = ring.gens + self.ngens = ring.ngens + self.symbols = ring.symbols + self.domain = ring.domain + + + if symbols: + if ring.domain.is_Field and ring.domain.is_Exact and len(symbols)==1: + self.is_PID = True + + # TODO: remove this + self.dom = self.domain + + def new(self, element): + return self.ring.ring_new(element) + + @property + def zero(self): + return self.ring.zero + + @property + def one(self): + return self.ring.one + + @property + def order(self): + return self.ring.order + + def __str__(self): + return str(self.domain) + '[' + ','.join(map(str, self.symbols)) + ']' + + def __hash__(self): + return hash((self.__class__.__name__, self.dtype.ring, self.domain, self.symbols)) + + def __eq__(self, other): + """Returns `True` if two domains are equivalent. """ + return isinstance(other, PolynomialRing) and \ + (self.dtype.ring, self.domain, self.symbols) == \ + (other.dtype.ring, other.domain, other.symbols) + + def is_unit(self, a): + """Returns ``True`` if ``a`` is a unit of ``self``""" + if not a.is_ground: + return False + K = self.domain + return K.is_unit(K.convert_from(a, self)) + + def canonical_unit(self, a): + u = self.domain.canonical_unit(a.LC) + return self.ring.ground_new(u) + + def to_sympy(self, a): + """Convert `a` to a SymPy object. """ + return a.as_expr() + + def from_sympy(self, a): + """Convert SymPy's expression to `dtype`. """ + return self.ring.from_expr(a) + + def from_ZZ(K1, a, K0): + """Convert a Python `int` object to `dtype`. """ + return K1(K1.domain.convert(a, K0)) + + def from_ZZ_python(K1, a, K0): + """Convert a Python `int` object to `dtype`. """ + return K1(K1.domain.convert(a, K0)) + + def from_QQ(K1, a, K0): + """Convert a Python `Fraction` object to `dtype`. """ + return K1(K1.domain.convert(a, K0)) + + def from_QQ_python(K1, a, K0): + """Convert a Python `Fraction` object to `dtype`. """ + return K1(K1.domain.convert(a, K0)) + + def from_ZZ_gmpy(K1, a, K0): + """Convert a GMPY `mpz` object to `dtype`. """ + return K1(K1.domain.convert(a, K0)) + + def from_QQ_gmpy(K1, a, K0): + """Convert a GMPY `mpq` object to `dtype`. """ + return K1(K1.domain.convert(a, K0)) + + def from_GaussianIntegerRing(K1, a, K0): + """Convert a `GaussianInteger` object to `dtype`. """ + return K1(K1.domain.convert(a, K0)) + + def from_GaussianRationalField(K1, a, K0): + """Convert a `GaussianRational` object to `dtype`. """ + return K1(K1.domain.convert(a, K0)) + + def from_RealField(K1, a, K0): + """Convert a mpmath `mpf` object to `dtype`. """ + return K1(K1.domain.convert(a, K0)) + + def from_ComplexField(K1, a, K0): + """Convert a mpmath `mpf` object to `dtype`. """ + return K1(K1.domain.convert(a, K0)) + + def from_AlgebraicField(K1, a, K0): + """Convert an algebraic number to ``dtype``. """ + if K1.domain != K0: + a = K1.domain.convert_from(a, K0) + if a is not None: + return K1.new(a) + + def from_PolynomialRing(K1, a, K0): + """Convert a polynomial to ``dtype``. """ + try: + return a.set_ring(K1.ring) + except (CoercionFailed, GeneratorsError): + return None + + def from_FractionField(K1, a, K0): + """Convert a rational function to ``dtype``. """ + if K1.domain == K0: + return K1.ring.from_list([a]) + + q, r = K0.numer(a).div(K0.denom(a)) + + if r.is_zero: + return K1.from_PolynomialRing(q, K0.field.ring.to_domain()) + else: + return None + + def from_GlobalPolynomialRing(K1, a, K0): + """Convert from old poly ring to ``dtype``. """ + if K1.symbols == K0.gens: + ad = a.to_dict() + if K1.domain != K0.domain: + ad = {m: K1.domain.convert(c) for m, c in ad.items()} + return K1(ad) + elif a.is_ground and K0.domain == K1: + return K1.convert_from(a.to_list()[0], K0.domain) + + def get_field(self): + """Returns a field associated with `self`. """ + return self.ring.to_field().to_domain() + + def is_positive(self, a): + """Returns True if `LC(a)` is positive. """ + return self.domain.is_positive(a.LC) + + def is_negative(self, a): + """Returns True if `LC(a)` is negative. """ + return self.domain.is_negative(a.LC) + + def is_nonpositive(self, a): + """Returns True if `LC(a)` is non-positive. """ + return self.domain.is_nonpositive(a.LC) + + def is_nonnegative(self, a): + """Returns True if `LC(a)` is non-negative. """ + return self.domain.is_nonnegative(a.LC) + + def gcdex(self, a, b): + """Extended GCD of `a` and `b`. """ + return a.gcdex(b) + + def gcd(self, a, b): + """Returns GCD of `a` and `b`. """ + return a.gcd(b) + + def lcm(self, a, b): + """Returns LCM of `a` and `b`. """ + return a.lcm(b) + + def factorial(self, a): + """Returns factorial of `a`. """ + return self.dtype(self.domain.factorial(a)) diff --git a/lib/python3.10/site-packages/sympy/polys/domains/pythonfinitefield.py b/lib/python3.10/site-packages/sympy/polys/domains/pythonfinitefield.py new file mode 100644 index 0000000000000000000000000000000000000000..44baa4f6d1b43317283041206eaa43e06a5cc8db --- /dev/null +++ b/lib/python3.10/site-packages/sympy/polys/domains/pythonfinitefield.py @@ -0,0 +1,16 @@ +"""Implementation of :class:`PythonFiniteField` class. """ + + +from sympy.polys.domains.finitefield import FiniteField +from sympy.polys.domains.pythonintegerring import PythonIntegerRing + +from sympy.utilities import public + +@public +class PythonFiniteField(FiniteField): + """Finite field based on Python's integers. """ + + alias = 'FF_python' + + def __init__(self, mod, symmetric=True): + super().__init__(mod, PythonIntegerRing(), symmetric) diff --git a/lib/python3.10/site-packages/sympy/polys/domains/pythonintegerring.py b/lib/python3.10/site-packages/sympy/polys/domains/pythonintegerring.py new file mode 100644 index 0000000000000000000000000000000000000000..81ee9637a4ebcfaf3c5f11d12c18265305984c25 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/polys/domains/pythonintegerring.py @@ -0,0 +1,98 @@ +"""Implementation of :class:`PythonIntegerRing` class. """ + + +from sympy.core.numbers import int_valued +from sympy.polys.domains.groundtypes import ( + PythonInteger, SymPyInteger, sqrt as python_sqrt, + factorial as python_factorial, python_gcdex, python_gcd, python_lcm, +) +from sympy.polys.domains.integerring import IntegerRing +from sympy.polys.polyerrors import CoercionFailed +from sympy.utilities import public + +@public +class PythonIntegerRing(IntegerRing): + """Integer ring based on Python's ``int`` type. + + This will be used as :ref:`ZZ` if ``gmpy`` and ``gmpy2`` are not + installed. Elements are instances of the standard Python ``int`` type. + """ + + dtype = PythonInteger + zero = dtype(0) + one = dtype(1) + alias = 'ZZ_python' + + def __init__(self): + """Allow instantiation of this domain. """ + + def to_sympy(self, a): + """Convert ``a`` to a SymPy object. """ + return SymPyInteger(a) + + def from_sympy(self, a): + """Convert SymPy's Integer to ``dtype``. """ + if a.is_Integer: + return PythonInteger(a.p) + elif int_valued(a): + return PythonInteger(int(a)) + else: + raise CoercionFailed("expected an integer, got %s" % a) + + def from_FF_python(K1, a, K0): + """Convert ``ModularInteger(int)`` to Python's ``int``. """ + return K0.to_int(a) + + def from_ZZ_python(K1, a, K0): + """Convert Python's ``int`` to Python's ``int``. """ + return a + + def from_QQ(K1, a, K0): + """Convert Python's ``Fraction`` to Python's ``int``. """ + if a.denominator == 1: + return a.numerator + + def from_QQ_python(K1, a, K0): + """Convert Python's ``Fraction`` to Python's ``int``. """ + if a.denominator == 1: + return a.numerator + + def from_FF_gmpy(K1, a, K0): + """Convert ``ModularInteger(mpz)`` to Python's ``int``. """ + return PythonInteger(K0.to_int(a)) + + def from_ZZ_gmpy(K1, a, K0): + """Convert GMPY's ``mpz`` to Python's ``int``. """ + return PythonInteger(a) + + def from_QQ_gmpy(K1, a, K0): + """Convert GMPY's ``mpq`` to Python's ``int``. """ + if a.denom() == 1: + return PythonInteger(a.numer()) + + def from_RealField(K1, a, K0): + """Convert mpmath's ``mpf`` to Python's ``int``. """ + p, q = K0.to_rational(a) + + if q == 1: + return PythonInteger(p) + + def gcdex(self, a, b): + """Compute extended GCD of ``a`` and ``b``. """ + return python_gcdex(a, b) + + def gcd(self, a, b): + """Compute GCD of ``a`` and ``b``. """ + return python_gcd(a, b) + + def lcm(self, a, b): + """Compute LCM of ``a`` and ``b``. """ + return python_lcm(a, b) + + def sqrt(self, a): + """Compute square root of ``a``. """ + return python_sqrt(a) + + def factorial(self, a): + """Compute factorial of ``a``. """ + return python_factorial(a) diff --git a/lib/python3.10/site-packages/sympy/polys/domains/pythonrational.py b/lib/python3.10/site-packages/sympy/polys/domains/pythonrational.py new file mode 100644 index 0000000000000000000000000000000000000000..87b56d6c929c3ce3ce153dce7b3c210821d706a0 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/polys/domains/pythonrational.py @@ -0,0 +1,22 @@ +""" +Rational number type based on Python integers. + +The PythonRational class from here has been moved to +sympy.external.pythonmpq + +This module is just left here for backwards compatibility. +""" + + +from sympy.core.numbers import Rational +from sympy.core.sympify import _sympy_converter +from sympy.utilities import public +from sympy.external.pythonmpq import PythonMPQ + + +PythonRational = public(PythonMPQ) + + +def sympify_pythonrational(arg): + return Rational(arg.numerator, arg.denominator) +_sympy_converter[PythonRational] = sympify_pythonrational diff --git a/lib/python3.10/site-packages/sympy/polys/domains/pythonrationalfield.py b/lib/python3.10/site-packages/sympy/polys/domains/pythonrationalfield.py new file mode 100644 index 0000000000000000000000000000000000000000..51afaef636f000855d51a69fb93eb416ae1e5347 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/polys/domains/pythonrationalfield.py @@ -0,0 +1,73 @@ +"""Implementation of :class:`PythonRationalField` class. """ + + +from sympy.polys.domains.groundtypes import PythonInteger, PythonRational, SymPyRational +from sympy.polys.domains.rationalfield import RationalField +from sympy.polys.polyerrors import CoercionFailed +from sympy.utilities import public + +@public +class PythonRationalField(RationalField): + """Rational field based on :ref:`MPQ`. + + This will be used as :ref:`QQ` if ``gmpy`` and ``gmpy2`` are not + installed. Elements are instances of :ref:`MPQ`. + """ + + dtype = PythonRational + zero = dtype(0) + one = dtype(1) + alias = 'QQ_python' + + def __init__(self): + pass + + def get_ring(self): + """Returns ring associated with ``self``. """ + from sympy.polys.domains import PythonIntegerRing + return PythonIntegerRing() + + def to_sympy(self, a): + """Convert `a` to a SymPy object. """ + return SymPyRational(a.numerator, a.denominator) + + def from_sympy(self, a): + """Convert SymPy's Rational to `dtype`. """ + if a.is_Rational: + return PythonRational(a.p, a.q) + elif a.is_Float: + from sympy.polys.domains import RR + p, q = RR.to_rational(a) + return PythonRational(int(p), int(q)) + else: + raise CoercionFailed("expected `Rational` object, got %s" % a) + + def from_ZZ_python(K1, a, K0): + """Convert a Python `int` object to `dtype`. """ + return PythonRational(a) + + def from_QQ_python(K1, a, K0): + """Convert a Python `Fraction` object to `dtype`. """ + return a + + def from_ZZ_gmpy(K1, a, K0): + """Convert a GMPY `mpz` object to `dtype`. """ + return PythonRational(PythonInteger(a)) + + def from_QQ_gmpy(K1, a, K0): + """Convert a GMPY `mpq` object to `dtype`. """ + return PythonRational(PythonInteger(a.numer()), + PythonInteger(a.denom())) + + def from_RealField(K1, a, K0): + """Convert a mpmath `mpf` object to `dtype`. """ + p, q = K0.to_rational(a) + return PythonRational(int(p), int(q)) + + def numer(self, a): + """Returns numerator of `a`. """ + return a.numerator + + def denom(self, a): + """Returns denominator of `a`. """ + return a.denominator diff --git a/lib/python3.10/site-packages/sympy/polys/domains/quotientring.py b/lib/python3.10/site-packages/sympy/polys/domains/quotientring.py new file mode 100644 index 0000000000000000000000000000000000000000..7e8abf6b210a5627c9c139e41248637c9b88931f --- /dev/null +++ b/lib/python3.10/site-packages/sympy/polys/domains/quotientring.py @@ -0,0 +1,202 @@ +"""Implementation of :class:`QuotientRing` class.""" + + +from sympy.polys.agca.modules import FreeModuleQuotientRing +from sympy.polys.domains.ring import Ring +from sympy.polys.polyerrors import NotReversible, CoercionFailed +from sympy.utilities import public + +# TODO +# - successive quotients (when quotient ideals are implemented) +# - poly rings over quotients? +# - division by non-units in integral domains? + +@public +class QuotientRingElement: + """ + Class representing elements of (commutative) quotient rings. + + Attributes: + + - ring - containing ring + - data - element of ring.ring (i.e. base ring) representing self + """ + + def __init__(self, ring, data): + self.ring = ring + self.data = data + + def __str__(self): + from sympy.printing.str import sstr + data = self.ring.ring.to_sympy(self.data) + return sstr(data) + " + " + str(self.ring.base_ideal) + + __repr__ = __str__ + + def __bool__(self): + return not self.ring.is_zero(self) + + def __add__(self, om): + if not isinstance(om, self.__class__) or om.ring != self.ring: + try: + om = self.ring.convert(om) + except (NotImplementedError, CoercionFailed): + return NotImplemented + return self.ring(self.data + om.data) + + __radd__ = __add__ + + def __neg__(self): + return self.ring(self.data*self.ring.ring.convert(-1)) + + def __sub__(self, om): + return self.__add__(-om) + + def __rsub__(self, om): + return (-self).__add__(om) + + def __mul__(self, o): + if not isinstance(o, self.__class__): + try: + o = self.ring.convert(o) + except (NotImplementedError, CoercionFailed): + return NotImplemented + return self.ring(self.data*o.data) + + __rmul__ = __mul__ + + def __rtruediv__(self, o): + return self.ring.revert(self)*o + + def __truediv__(self, o): + if not isinstance(o, self.__class__): + try: + o = self.ring.convert(o) + except (NotImplementedError, CoercionFailed): + return NotImplemented + return self.ring.revert(o)*self + + def __pow__(self, oth): + if oth < 0: + return self.ring.revert(self) ** -oth + return self.ring(self.data ** oth) + + def __eq__(self, om): + if not isinstance(om, self.__class__) or om.ring != self.ring: + return False + return self.ring.is_zero(self - om) + + def __ne__(self, om): + return not self == om + + +class QuotientRing(Ring): + """ + Class representing (commutative) quotient rings. + + You should not usually instantiate this by hand, instead use the constructor + from the base ring in the construction. + + >>> from sympy.abc import x + >>> from sympy import QQ + >>> I = QQ.old_poly_ring(x).ideal(x**3 + 1) + >>> QQ.old_poly_ring(x).quotient_ring(I) + QQ[x]/ + + Shorter versions are possible: + + >>> QQ.old_poly_ring(x)/I + QQ[x]/ + + >>> QQ.old_poly_ring(x)/[x**3 + 1] + QQ[x]/ + + Attributes: + + - ring - the base ring + - base_ideal - the ideal used to form the quotient + """ + + has_assoc_Ring = True + has_assoc_Field = False + dtype = QuotientRingElement + + def __init__(self, ring, ideal): + if not ideal.ring == ring: + raise ValueError('Ideal must belong to %s, got %s' % (ring, ideal)) + self.ring = ring + self.base_ideal = ideal + self.zero = self(self.ring.zero) + self.one = self(self.ring.one) + + def __str__(self): + return str(self.ring) + "/" + str(self.base_ideal) + + def __hash__(self): + return hash((self.__class__.__name__, self.dtype, self.ring, self.base_ideal)) + + def new(self, a): + """Construct an element of ``self`` domain from ``a``. """ + if not isinstance(a, self.ring.dtype): + a = self.ring(a) + # TODO optionally disable reduction? + return self.dtype(self, self.base_ideal.reduce_element(a)) + + def __eq__(self, other): + """Returns ``True`` if two domains are equivalent. """ + return isinstance(other, QuotientRing) and \ + self.ring == other.ring and self.base_ideal == other.base_ideal + + def from_ZZ(K1, a, K0): + """Convert a Python ``int`` object to ``dtype``. """ + return K1(K1.ring.convert(a, K0)) + + from_ZZ_python = from_ZZ + from_QQ_python = from_ZZ_python + from_ZZ_gmpy = from_ZZ_python + from_QQ_gmpy = from_ZZ_python + from_RealField = from_ZZ_python + from_GlobalPolynomialRing = from_ZZ_python + from_FractionField = from_ZZ_python + + def from_sympy(self, a): + return self(self.ring.from_sympy(a)) + + def to_sympy(self, a): + return self.ring.to_sympy(a.data) + + def from_QuotientRing(self, a, K0): + if K0 == self: + return a + + def poly_ring(self, *gens): + """Returns a polynomial ring, i.e. ``K[X]``. """ + raise NotImplementedError('nested domains not allowed') + + def frac_field(self, *gens): + """Returns a fraction field, i.e. ``K(X)``. """ + raise NotImplementedError('nested domains not allowed') + + def revert(self, a): + """ + Compute a**(-1), if possible. + """ + I = self.ring.ideal(a.data) + self.base_ideal + try: + return self(I.in_terms_of_generators(1)[0]) + except ValueError: # 1 not in I + raise NotReversible('%s not a unit in %r' % (a, self)) + + def is_zero(self, a): + return self.base_ideal.contains(a.data) + + def free_module(self, rank): + """ + Generate a free module of rank ``rank`` over ``self``. + + >>> from sympy.abc import x + >>> from sympy import QQ + >>> (QQ.old_poly_ring(x)/[x**2 + 1]).free_module(2) + (QQ[x]/)**2 + """ + return FreeModuleQuotientRing(self, rank) diff --git a/lib/python3.10/site-packages/sympy/polys/domains/rationalfield.py b/lib/python3.10/site-packages/sympy/polys/domains/rationalfield.py new file mode 100644 index 0000000000000000000000000000000000000000..6da570332de8a6d39a21bb3d57447670c7a98441 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/polys/domains/rationalfield.py @@ -0,0 +1,200 @@ +"""Implementation of :class:`RationalField` class. """ + + +from sympy.external.gmpy import MPQ + +from sympy.polys.domains.groundtypes import SymPyRational, is_square, sqrtrem + +from sympy.polys.domains.characteristiczero import CharacteristicZero +from sympy.polys.domains.field import Field +from sympy.polys.domains.simpledomain import SimpleDomain +from sympy.polys.polyerrors import CoercionFailed +from sympy.utilities import public + +@public +class RationalField(Field, CharacteristicZero, SimpleDomain): + r"""Abstract base class for the domain :ref:`QQ`. + + The :py:class:`RationalField` class represents the field of rational + numbers $\mathbb{Q}$ as a :py:class:`~.Domain` in the domain system. + :py:class:`RationalField` is a superclass of + :py:class:`PythonRationalField` and :py:class:`GMPYRationalField` one of + which will be the implementation for :ref:`QQ` depending on whether either + of ``gmpy`` or ``gmpy2`` is installed or not. + + See also + ======== + + Domain + """ + + rep = 'QQ' + alias = 'QQ' + + is_RationalField = is_QQ = True + is_Numerical = True + + has_assoc_Ring = True + has_assoc_Field = True + + dtype = MPQ + zero = dtype(0) + one = dtype(1) + tp = type(one) + + def __init__(self): + pass + + def __eq__(self, other): + """Returns ``True`` if two domains are equivalent. """ + if isinstance(other, RationalField): + return True + else: + return NotImplemented + + def __hash__(self): + """Returns hash code of ``self``. """ + return hash('QQ') + + def get_ring(self): + """Returns ring associated with ``self``. """ + from sympy.polys.domains import ZZ + return ZZ + + def to_sympy(self, a): + """Convert ``a`` to a SymPy object. """ + return SymPyRational(int(a.numerator), int(a.denominator)) + + def from_sympy(self, a): + """Convert SymPy's Integer to ``dtype``. """ + if a.is_Rational: + return MPQ(a.p, a.q) + elif a.is_Float: + from sympy.polys.domains import RR + return MPQ(*map(int, RR.to_rational(a))) + else: + raise CoercionFailed("expected `Rational` object, got %s" % a) + + def algebraic_field(self, *extension, alias=None): + r"""Returns an algebraic field, i.e. `\mathbb{Q}(\alpha, \ldots)`. + + Parameters + ========== + + *extension : One or more :py:class:`~.Expr` + Generators of the extension. These should be expressions that are + algebraic over `\mathbb{Q}`. + + alias : str, :py:class:`~.Symbol`, None, optional (default=None) + If provided, this will be used as the alias symbol for the + primitive element of the returned :py:class:`~.AlgebraicField`. + + Returns + ======= + + :py:class:`~.AlgebraicField` + A :py:class:`~.Domain` representing the algebraic field extension. + + Examples + ======== + + >>> from sympy import QQ, sqrt + >>> QQ.algebraic_field(sqrt(2)) + QQ + """ + from sympy.polys.domains import AlgebraicField + return AlgebraicField(self, *extension, alias=alias) + + def from_AlgebraicField(K1, a, K0): + """Convert a :py:class:`~.ANP` object to :ref:`QQ`. + + See :py:meth:`~.Domain.convert` + """ + if a.is_ground: + return K1.convert(a.LC(), K0.dom) + + def from_ZZ(K1, a, K0): + """Convert a Python ``int`` object to ``dtype``. """ + return MPQ(a) + + def from_ZZ_python(K1, a, K0): + """Convert a Python ``int`` object to ``dtype``. """ + return MPQ(a) + + def from_QQ(K1, a, K0): + """Convert a Python ``Fraction`` object to ``dtype``. """ + return MPQ(a.numerator, a.denominator) + + def from_QQ_python(K1, a, K0): + """Convert a Python ``Fraction`` object to ``dtype``. """ + return MPQ(a.numerator, a.denominator) + + def from_ZZ_gmpy(K1, a, K0): + """Convert a GMPY ``mpz`` object to ``dtype``. """ + return MPQ(a) + + def from_QQ_gmpy(K1, a, K0): + """Convert a GMPY ``mpq`` object to ``dtype``. """ + return a + + def from_GaussianRationalField(K1, a, K0): + """Convert a ``GaussianElement`` object to ``dtype``. """ + if a.y == 0: + return MPQ(a.x) + + def from_RealField(K1, a, K0): + """Convert a mpmath ``mpf`` object to ``dtype``. """ + return MPQ(*map(int, K0.to_rational(a))) + + def exquo(self, a, b): + """Exact quotient of ``a`` and ``b``, implies ``__truediv__``. """ + return MPQ(a) / MPQ(b) + + def quo(self, a, b): + """Quotient of ``a`` and ``b``, implies ``__truediv__``. """ + return MPQ(a) / MPQ(b) + + def rem(self, a, b): + """Remainder of ``a`` and ``b``, implies nothing. """ + return self.zero + + def div(self, a, b): + """Division of ``a`` and ``b``, implies ``__truediv__``. """ + return MPQ(a) / MPQ(b), self.zero + + def numer(self, a): + """Returns numerator of ``a``. """ + return a.numerator + + def denom(self, a): + """Returns denominator of ``a``. """ + return a.denominator + + def is_square(self, a): + """Return ``True`` if ``a`` is a square. + + Explanation + =========== + A rational number is a square if and only if there exists + a rational number ``b`` such that ``b * b == a``. + """ + return is_square(a.numerator) and is_square(a.denominator) + + def exsqrt(self, a): + """Non-negative square root of ``a`` if ``a`` is a square. + + See also + ======== + is_square + """ + if a.numerator < 0: # denominator is always positive + return None + p_sqrt, p_rem = sqrtrem(a.numerator) + if p_rem != 0: + return None + q_sqrt, q_rem = sqrtrem(a.denominator) + if q_rem != 0: + return None + return MPQ(p_sqrt, q_sqrt) + +QQ = RationalField() diff --git a/lib/python3.10/site-packages/sympy/polys/domains/realfield.py b/lib/python3.10/site-packages/sympy/polys/domains/realfield.py new file mode 100644 index 0000000000000000000000000000000000000000..754335f9ed0ee1fed660ea67dd54cb9ed25cc799 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/polys/domains/realfield.py @@ -0,0 +1,167 @@ +"""Implementation of :class:`RealField` class. """ + + +from sympy.external.gmpy import SYMPY_INTS +from sympy.core.numbers import Float +from sympy.polys.domains.field import Field +from sympy.polys.domains.simpledomain import SimpleDomain +from sympy.polys.domains.characteristiczero import CharacteristicZero +from sympy.polys.domains.mpelements import MPContext +from sympy.polys.polyerrors import CoercionFailed +from sympy.utilities import public + +@public +class RealField(Field, CharacteristicZero, SimpleDomain): + """Real numbers up to the given precision. """ + + rep = 'RR' + + is_RealField = is_RR = True + + is_Exact = False + is_Numerical = True + is_PID = False + + has_assoc_Ring = False + has_assoc_Field = True + + _default_precision = 53 + + @property + def has_default_precision(self): + return self.precision == self._default_precision + + @property + def precision(self): + return self._context.prec + + @property + def dps(self): + return self._context.dps + + @property + def tolerance(self): + return self._context.tolerance + + def __init__(self, prec=_default_precision, dps=None, tol=None): + context = MPContext(prec, dps, tol, True) + context._parent = self + self._context = context + + self._dtype = context.mpf + self.zero = self.dtype(0) + self.one = self.dtype(1) + + @property + def tp(self): + # XXX: Domain treats tp as an alis of dtype. Here we need to two + # separate things: dtype is a callable to make/convert instances. + # We use tp with isinstance to check if an object is an instance + # of the domain already. + return self._dtype + + def dtype(self, arg): + # XXX: This is needed because mpmath does not recognise fmpz. + # It might be better to add conversion routines to mpmath and if that + # happens then this can be removed. + if isinstance(arg, SYMPY_INTS): + arg = int(arg) + return self._dtype(arg) + + def __eq__(self, other): + return (isinstance(other, RealField) + and self.precision == other.precision + and self.tolerance == other.tolerance) + + def __hash__(self): + return hash((self.__class__.__name__, self._dtype, self.precision, self.tolerance)) + + def to_sympy(self, element): + """Convert ``element`` to SymPy number. """ + return Float(element, self.dps) + + def from_sympy(self, expr): + """Convert SymPy's number to ``dtype``. """ + number = expr.evalf(n=self.dps) + + if number.is_Number: + return self.dtype(number) + else: + raise CoercionFailed("expected real number, got %s" % expr) + + def from_ZZ(self, element, base): + return self.dtype(element) + + def from_ZZ_python(self, element, base): + return self.dtype(element) + + def from_ZZ_gmpy(self, element, base): + return self.dtype(int(element)) + + # XXX: We need to convert the denominators to int here because mpmath does + # not recognise mpz. Ideally mpmath would handle this and if it changed to + # do so then the calls to int here could be removed. + + def from_QQ(self, element, base): + return self.dtype(element.numerator) / int(element.denominator) + + def from_QQ_python(self, element, base): + return self.dtype(element.numerator) / int(element.denominator) + + def from_QQ_gmpy(self, element, base): + return self.dtype(int(element.numerator)) / int(element.denominator) + + def from_AlgebraicField(self, element, base): + return self.from_sympy(base.to_sympy(element).evalf(self.dps)) + + def from_RealField(self, element, base): + if self == base: + return element + else: + return self.dtype(element) + + def from_ComplexField(self, element, base): + if not element.imag: + return self.dtype(element.real) + + def to_rational(self, element, limit=True): + """Convert a real number to rational number. """ + return self._context.to_rational(element, limit) + + def get_ring(self): + """Returns a ring associated with ``self``. """ + return self + + def get_exact(self): + """Returns an exact domain associated with ``self``. """ + from sympy.polys.domains import QQ + return QQ + + def gcd(self, a, b): + """Returns GCD of ``a`` and ``b``. """ + return self.one + + def lcm(self, a, b): + """Returns LCM of ``a`` and ``b``. """ + return a*b + + def almosteq(self, a, b, tolerance=None): + """Check if ``a`` and ``b`` are almost equal. """ + return self._context.almosteq(a, b, tolerance) + + def is_square(self, a): + """Returns ``True`` if ``a >= 0`` and ``False`` otherwise. """ + return a >= 0 + + def exsqrt(self, a): + """Non-negative square root for ``a >= 0`` and ``None`` otherwise. + + Explanation + =========== + The square root may be slightly inaccurate due to floating point + rounding error. + """ + return a ** 0.5 if a >= 0 else None + + +RR = RealField() diff --git a/lib/python3.10/site-packages/sympy/polys/domains/ring.py b/lib/python3.10/site-packages/sympy/polys/domains/ring.py new file mode 100644 index 0000000000000000000000000000000000000000..c69e6944d8f51e4b319609368a476e6e847ae126 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/polys/domains/ring.py @@ -0,0 +1,118 @@ +"""Implementation of :class:`Ring` class. """ + + +from sympy.polys.domains.domain import Domain +from sympy.polys.polyerrors import ExactQuotientFailed, NotInvertible, NotReversible + +from sympy.utilities import public + +@public +class Ring(Domain): + """Represents a ring domain. """ + + is_Ring = True + + def get_ring(self): + """Returns a ring associated with ``self``. """ + return self + + def exquo(self, a, b): + """Exact quotient of ``a`` and ``b``, implies ``__floordiv__``. """ + if a % b: + raise ExactQuotientFailed(a, b, self) + else: + return a // b + + def quo(self, a, b): + """Quotient of ``a`` and ``b``, implies ``__floordiv__``. """ + return a // b + + def rem(self, a, b): + """Remainder of ``a`` and ``b``, implies ``__mod__``. """ + return a % b + + def div(self, a, b): + """Division of ``a`` and ``b``, implies ``__divmod__``. """ + return divmod(a, b) + + def invert(self, a, b): + """Returns inversion of ``a mod b``. """ + s, t, h = self.gcdex(a, b) + + if self.is_one(h): + return s % b + else: + raise NotInvertible("zero divisor") + + def revert(self, a): + """Returns ``a**(-1)`` if possible. """ + if self.is_one(a) or self.is_one(-a): + return a + else: + raise NotReversible('only units are reversible in a ring') + + def is_unit(self, a): + try: + self.revert(a) + return True + except NotReversible: + return False + + def numer(self, a): + """Returns numerator of ``a``. """ + return a + + def denom(self, a): + """Returns denominator of `a`. """ + return self.one + + def free_module(self, rank): + """ + Generate a free module of rank ``rank`` over self. + + >>> from sympy.abc import x + >>> from sympy import QQ + >>> QQ.old_poly_ring(x).free_module(2) + QQ[x]**2 + """ + raise NotImplementedError + + def ideal(self, *gens): + """ + Generate an ideal of ``self``. + + >>> from sympy.abc import x + >>> from sympy import QQ + >>> QQ.old_poly_ring(x).ideal(x**2) + + """ + from sympy.polys.agca.ideals import ModuleImplementedIdeal + return ModuleImplementedIdeal(self, self.free_module(1).submodule( + *[[x] for x in gens])) + + def quotient_ring(self, e): + """ + Form a quotient ring of ``self``. + + Here ``e`` can be an ideal or an iterable. + + >>> from sympy.abc import x + >>> from sympy import QQ + >>> QQ.old_poly_ring(x).quotient_ring(QQ.old_poly_ring(x).ideal(x**2)) + QQ[x]/ + >>> QQ.old_poly_ring(x).quotient_ring([x**2]) + QQ[x]/ + + The division operator has been overloaded for this: + + >>> QQ.old_poly_ring(x)/[x**2] + QQ[x]/ + """ + from sympy.polys.agca.ideals import Ideal + from sympy.polys.domains.quotientring import QuotientRing + if not isinstance(e, Ideal): + e = self.ideal(*e) + return QuotientRing(self, e) + + def __truediv__(self, e): + return self.quotient_ring(e) diff --git a/lib/python3.10/site-packages/sympy/polys/domains/simpledomain.py b/lib/python3.10/site-packages/sympy/polys/domains/simpledomain.py new file mode 100644 index 0000000000000000000000000000000000000000..88cf634555d8bd9229d7fc511af3cf96fececbb8 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/polys/domains/simpledomain.py @@ -0,0 +1,15 @@ +"""Implementation of :class:`SimpleDomain` class. """ + + +from sympy.polys.domains.domain import Domain +from sympy.utilities import public + +@public +class SimpleDomain(Domain): + """Base class for simple domains, e.g. ZZ, QQ. """ + + is_Simple = True + + def inject(self, *gens): + """Inject generators into this domain. """ + return self.poly_ring(*gens) diff --git a/lib/python3.10/site-packages/sympy/polys/domains/tests/__init__.py b/lib/python3.10/site-packages/sympy/polys/domains/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/lib/python3.10/site-packages/sympy/polys/domains/tests/test_domains.py b/lib/python3.10/site-packages/sympy/polys/domains/tests/test_domains.py new file mode 100644 index 0000000000000000000000000000000000000000..13fc7940fc280b17ac5593934446ee8778f7dafe --- /dev/null +++ b/lib/python3.10/site-packages/sympy/polys/domains/tests/test_domains.py @@ -0,0 +1,1416 @@ +"""Tests for classes defining properties of ground domains, e.g. ZZ, QQ, ZZ[x] ... """ + +from sympy.external.gmpy import GROUND_TYPES + +from sympy.core.numbers import (AlgebraicNumber, E, Float, I, Integer, + Rational, oo, pi, _illegal) +from sympy.core.singleton import S +from sympy.functions.elementary.exponential import exp +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import sin +from sympy.polys.polytools import Poly +from sympy.abc import x, y, z + +from sympy.polys.domains import (ZZ, QQ, RR, CC, FF, GF, EX, EXRAW, ZZ_gmpy, + ZZ_python, QQ_gmpy, QQ_python) +from sympy.polys.domains.algebraicfield import AlgebraicField +from sympy.polys.domains.gaussiandomains import ZZ_I, QQ_I +from sympy.polys.domains.polynomialring import PolynomialRing +from sympy.polys.domains.realfield import RealField + +from sympy.polys.numberfields.subfield import field_isomorphism +from sympy.polys.rings import ring +from sympy.polys.specialpolys import cyclotomic_poly +from sympy.polys.fields import field + +from sympy.polys.agca.extensions import FiniteExtension + +from sympy.polys.polyerrors import ( + UnificationFailed, + GeneratorsError, + CoercionFailed, + NotInvertible, + DomainError) + +from sympy.testing.pytest import raises, warns_deprecated_sympy + +from itertools import product + +ALG = QQ.algebraic_field(sqrt(2), sqrt(3)) + +def unify(K0, K1): + return K0.unify(K1) + +def test_Domain_unify(): + F3 = GF(3) + F5 = GF(5) + + assert unify(F3, F3) == F3 + raises(UnificationFailed, lambda: unify(F3, ZZ)) + raises(UnificationFailed, lambda: unify(F3, QQ)) + raises(UnificationFailed, lambda: unify(F3, ZZ_I)) + raises(UnificationFailed, lambda: unify(F3, QQ_I)) + raises(UnificationFailed, lambda: unify(F3, ALG)) + raises(UnificationFailed, lambda: unify(F3, RR)) + raises(UnificationFailed, lambda: unify(F3, CC)) + raises(UnificationFailed, lambda: unify(F3, ZZ[x])) + raises(UnificationFailed, lambda: unify(F3, ZZ.frac_field(x))) + raises(UnificationFailed, lambda: unify(F3, EX)) + + assert unify(F5, F5) == F5 + raises(UnificationFailed, lambda: unify(F5, F3)) + raises(UnificationFailed, lambda: unify(F5, F3[x])) + raises(UnificationFailed, lambda: unify(F5, F3.frac_field(x))) + + raises(UnificationFailed, lambda: unify(ZZ, F3)) + assert unify(ZZ, ZZ) == ZZ + assert unify(ZZ, QQ) == QQ + assert unify(ZZ, ALG) == ALG + assert unify(ZZ, RR) == RR + assert unify(ZZ, CC) == CC + assert unify(ZZ, ZZ[x]) == ZZ[x] + assert unify(ZZ, ZZ.frac_field(x)) == ZZ.frac_field(x) + assert unify(ZZ, EX) == EX + + raises(UnificationFailed, lambda: unify(QQ, F3)) + assert unify(QQ, ZZ) == QQ + assert unify(QQ, QQ) == QQ + assert unify(QQ, ALG) == ALG + assert unify(QQ, RR) == RR + assert unify(QQ, CC) == CC + assert unify(QQ, ZZ[x]) == QQ[x] + assert unify(QQ, ZZ.frac_field(x)) == QQ.frac_field(x) + assert unify(QQ, EX) == EX + + raises(UnificationFailed, lambda: unify(ZZ_I, F3)) + assert unify(ZZ_I, ZZ) == ZZ_I + assert unify(ZZ_I, ZZ_I) == ZZ_I + assert unify(ZZ_I, QQ) == QQ_I + assert unify(ZZ_I, ALG) == QQ.algebraic_field(I, sqrt(2), sqrt(3)) + assert unify(ZZ_I, RR) == CC + assert unify(ZZ_I, CC) == CC + assert unify(ZZ_I, ZZ[x]) == ZZ_I[x] + assert unify(ZZ_I, ZZ_I[x]) == ZZ_I[x] + assert unify(ZZ_I, ZZ.frac_field(x)) == ZZ_I.frac_field(x) + assert unify(ZZ_I, ZZ_I.frac_field(x)) == ZZ_I.frac_field(x) + assert unify(ZZ_I, EX) == EX + + raises(UnificationFailed, lambda: unify(QQ_I, F3)) + assert unify(QQ_I, ZZ) == QQ_I + assert unify(QQ_I, ZZ_I) == QQ_I + assert unify(QQ_I, QQ) == QQ_I + assert unify(QQ_I, ALG) == QQ.algebraic_field(I, sqrt(2), sqrt(3)) + assert unify(QQ_I, RR) == CC + assert unify(QQ_I, CC) == CC + assert unify(QQ_I, ZZ[x]) == QQ_I[x] + assert unify(QQ_I, ZZ_I[x]) == QQ_I[x] + assert unify(QQ_I, QQ[x]) == QQ_I[x] + assert unify(QQ_I, QQ_I[x]) == QQ_I[x] + assert unify(QQ_I, ZZ.frac_field(x)) == QQ_I.frac_field(x) + assert unify(QQ_I, ZZ_I.frac_field(x)) == QQ_I.frac_field(x) + assert unify(QQ_I, QQ.frac_field(x)) == QQ_I.frac_field(x) + assert unify(QQ_I, QQ_I.frac_field(x)) == QQ_I.frac_field(x) + assert unify(QQ_I, EX) == EX + + raises(UnificationFailed, lambda: unify(RR, F3)) + assert unify(RR, ZZ) == RR + assert unify(RR, QQ) == RR + assert unify(RR, ALG) == RR + assert unify(RR, RR) == RR + assert unify(RR, CC) == CC + assert unify(RR, ZZ[x]) == RR[x] + assert unify(RR, ZZ.frac_field(x)) == RR.frac_field(x) + assert unify(RR, EX) == EX + assert RR[x].unify(ZZ.frac_field(y)) == RR.frac_field(x, y) + + raises(UnificationFailed, lambda: unify(CC, F3)) + assert unify(CC, ZZ) == CC + assert unify(CC, QQ) == CC + assert unify(CC, ALG) == CC + assert unify(CC, RR) == CC + assert unify(CC, CC) == CC + assert unify(CC, ZZ[x]) == CC[x] + assert unify(CC, ZZ.frac_field(x)) == CC.frac_field(x) + assert unify(CC, EX) == EX + + raises(UnificationFailed, lambda: unify(ZZ[x], F3)) + assert unify(ZZ[x], ZZ) == ZZ[x] + assert unify(ZZ[x], QQ) == QQ[x] + assert unify(ZZ[x], ALG) == ALG[x] + assert unify(ZZ[x], RR) == RR[x] + assert unify(ZZ[x], CC) == CC[x] + assert unify(ZZ[x], ZZ[x]) == ZZ[x] + assert unify(ZZ[x], ZZ.frac_field(x)) == ZZ.frac_field(x) + assert unify(ZZ[x], EX) == EX + + raises(UnificationFailed, lambda: unify(ZZ.frac_field(x), F3)) + assert unify(ZZ.frac_field(x), ZZ) == ZZ.frac_field(x) + assert unify(ZZ.frac_field(x), QQ) == QQ.frac_field(x) + assert unify(ZZ.frac_field(x), ALG) == ALG.frac_field(x) + assert unify(ZZ.frac_field(x), RR) == RR.frac_field(x) + assert unify(ZZ.frac_field(x), CC) == CC.frac_field(x) + assert unify(ZZ.frac_field(x), ZZ[x]) == ZZ.frac_field(x) + assert unify(ZZ.frac_field(x), ZZ.frac_field(x)) == ZZ.frac_field(x) + assert unify(ZZ.frac_field(x), EX) == EX + + raises(UnificationFailed, lambda: unify(EX, F3)) + assert unify(EX, ZZ) == EX + assert unify(EX, QQ) == EX + assert unify(EX, ALG) == EX + assert unify(EX, RR) == EX + assert unify(EX, CC) == EX + assert unify(EX, ZZ[x]) == EX + assert unify(EX, ZZ.frac_field(x)) == EX + assert unify(EX, EX) == EX + +def test_Domain_unify_composite(): + assert unify(ZZ.poly_ring(x), ZZ) == ZZ.poly_ring(x) + assert unify(ZZ.poly_ring(x), QQ) == QQ.poly_ring(x) + assert unify(QQ.poly_ring(x), ZZ) == QQ.poly_ring(x) + assert unify(QQ.poly_ring(x), QQ) == QQ.poly_ring(x) + + assert unify(ZZ, ZZ.poly_ring(x)) == ZZ.poly_ring(x) + assert unify(QQ, ZZ.poly_ring(x)) == QQ.poly_ring(x) + assert unify(ZZ, QQ.poly_ring(x)) == QQ.poly_ring(x) + assert unify(QQ, QQ.poly_ring(x)) == QQ.poly_ring(x) + + assert unify(ZZ.poly_ring(x, y), ZZ) == ZZ.poly_ring(x, y) + assert unify(ZZ.poly_ring(x, y), QQ) == QQ.poly_ring(x, y) + assert unify(QQ.poly_ring(x, y), ZZ) == QQ.poly_ring(x, y) + assert unify(QQ.poly_ring(x, y), QQ) == QQ.poly_ring(x, y) + + assert unify(ZZ, ZZ.poly_ring(x, y)) == ZZ.poly_ring(x, y) + assert unify(QQ, ZZ.poly_ring(x, y)) == QQ.poly_ring(x, y) + assert unify(ZZ, QQ.poly_ring(x, y)) == QQ.poly_ring(x, y) + assert unify(QQ, QQ.poly_ring(x, y)) == QQ.poly_ring(x, y) + + assert unify(ZZ.frac_field(x), ZZ) == ZZ.frac_field(x) + assert unify(ZZ.frac_field(x), QQ) == QQ.frac_field(x) + assert unify(QQ.frac_field(x), ZZ) == QQ.frac_field(x) + assert unify(QQ.frac_field(x), QQ) == QQ.frac_field(x) + + assert unify(ZZ, ZZ.frac_field(x)) == ZZ.frac_field(x) + assert unify(QQ, ZZ.frac_field(x)) == QQ.frac_field(x) + assert unify(ZZ, QQ.frac_field(x)) == QQ.frac_field(x) + assert unify(QQ, QQ.frac_field(x)) == QQ.frac_field(x) + + assert unify(ZZ.frac_field(x, y), ZZ) == ZZ.frac_field(x, y) + assert unify(ZZ.frac_field(x, y), QQ) == QQ.frac_field(x, y) + assert unify(QQ.frac_field(x, y), ZZ) == QQ.frac_field(x, y) + assert unify(QQ.frac_field(x, y), QQ) == QQ.frac_field(x, y) + + assert unify(ZZ, ZZ.frac_field(x, y)) == ZZ.frac_field(x, y) + assert unify(QQ, ZZ.frac_field(x, y)) == QQ.frac_field(x, y) + assert unify(ZZ, QQ.frac_field(x, y)) == QQ.frac_field(x, y) + assert unify(QQ, QQ.frac_field(x, y)) == QQ.frac_field(x, y) + + assert unify(ZZ.poly_ring(x), ZZ.poly_ring(x)) == ZZ.poly_ring(x) + assert unify(ZZ.poly_ring(x), QQ.poly_ring(x)) == QQ.poly_ring(x) + assert unify(QQ.poly_ring(x), ZZ.poly_ring(x)) == QQ.poly_ring(x) + assert unify(QQ.poly_ring(x), QQ.poly_ring(x)) == QQ.poly_ring(x) + + assert unify(ZZ.poly_ring(x, y), ZZ.poly_ring(x)) == ZZ.poly_ring(x, y) + assert unify(ZZ.poly_ring(x, y), QQ.poly_ring(x)) == QQ.poly_ring(x, y) + assert unify(QQ.poly_ring(x, y), ZZ.poly_ring(x)) == QQ.poly_ring(x, y) + assert unify(QQ.poly_ring(x, y), QQ.poly_ring(x)) == QQ.poly_ring(x, y) + + assert unify(ZZ.poly_ring(x), ZZ.poly_ring(x, y)) == ZZ.poly_ring(x, y) + assert unify(ZZ.poly_ring(x), QQ.poly_ring(x, y)) == QQ.poly_ring(x, y) + assert unify(QQ.poly_ring(x), ZZ.poly_ring(x, y)) == QQ.poly_ring(x, y) + assert unify(QQ.poly_ring(x), QQ.poly_ring(x, y)) == QQ.poly_ring(x, y) + + assert unify(ZZ.poly_ring(x, y), ZZ.poly_ring(x, z)) == ZZ.poly_ring(x, y, z) + assert unify(ZZ.poly_ring(x, y), QQ.poly_ring(x, z)) == QQ.poly_ring(x, y, z) + assert unify(QQ.poly_ring(x, y), ZZ.poly_ring(x, z)) == QQ.poly_ring(x, y, z) + assert unify(QQ.poly_ring(x, y), QQ.poly_ring(x, z)) == QQ.poly_ring(x, y, z) + + assert unify(ZZ.frac_field(x), ZZ.frac_field(x)) == ZZ.frac_field(x) + assert unify(ZZ.frac_field(x), QQ.frac_field(x)) == QQ.frac_field(x) + assert unify(QQ.frac_field(x), ZZ.frac_field(x)) == QQ.frac_field(x) + assert unify(QQ.frac_field(x), QQ.frac_field(x)) == QQ.frac_field(x) + + assert unify(ZZ.frac_field(x, y), ZZ.frac_field(x)) == ZZ.frac_field(x, y) + assert unify(ZZ.frac_field(x, y), QQ.frac_field(x)) == QQ.frac_field(x, y) + assert unify(QQ.frac_field(x, y), ZZ.frac_field(x)) == QQ.frac_field(x, y) + assert unify(QQ.frac_field(x, y), QQ.frac_field(x)) == QQ.frac_field(x, y) + + assert unify(ZZ.frac_field(x), ZZ.frac_field(x, y)) == ZZ.frac_field(x, y) + assert unify(ZZ.frac_field(x), QQ.frac_field(x, y)) == QQ.frac_field(x, y) + assert unify(QQ.frac_field(x), ZZ.frac_field(x, y)) == QQ.frac_field(x, y) + assert unify(QQ.frac_field(x), QQ.frac_field(x, y)) == QQ.frac_field(x, y) + + assert unify(ZZ.frac_field(x, y), ZZ.frac_field(x, z)) == ZZ.frac_field(x, y, z) + assert unify(ZZ.frac_field(x, y), QQ.frac_field(x, z)) == QQ.frac_field(x, y, z) + assert unify(QQ.frac_field(x, y), ZZ.frac_field(x, z)) == QQ.frac_field(x, y, z) + assert unify(QQ.frac_field(x, y), QQ.frac_field(x, z)) == QQ.frac_field(x, y, z) + + assert unify(ZZ.poly_ring(x), ZZ.frac_field(x)) == ZZ.frac_field(x) + assert unify(ZZ.poly_ring(x), QQ.frac_field(x)) == ZZ.frac_field(x) + assert unify(QQ.poly_ring(x), ZZ.frac_field(x)) == ZZ.frac_field(x) + assert unify(QQ.poly_ring(x), QQ.frac_field(x)) == QQ.frac_field(x) + + assert unify(ZZ.poly_ring(x, y), ZZ.frac_field(x)) == ZZ.frac_field(x, y) + assert unify(ZZ.poly_ring(x, y), QQ.frac_field(x)) == ZZ.frac_field(x, y) + assert unify(QQ.poly_ring(x, y), ZZ.frac_field(x)) == ZZ.frac_field(x, y) + assert unify(QQ.poly_ring(x, y), QQ.frac_field(x)) == QQ.frac_field(x, y) + + assert unify(ZZ.poly_ring(x), ZZ.frac_field(x, y)) == ZZ.frac_field(x, y) + assert unify(ZZ.poly_ring(x), QQ.frac_field(x, y)) == ZZ.frac_field(x, y) + assert unify(QQ.poly_ring(x), ZZ.frac_field(x, y)) == ZZ.frac_field(x, y) + assert unify(QQ.poly_ring(x), QQ.frac_field(x, y)) == QQ.frac_field(x, y) + + assert unify(ZZ.poly_ring(x, y), ZZ.frac_field(x, z)) == ZZ.frac_field(x, y, z) + assert unify(ZZ.poly_ring(x, y), QQ.frac_field(x, z)) == ZZ.frac_field(x, y, z) + assert unify(QQ.poly_ring(x, y), ZZ.frac_field(x, z)) == ZZ.frac_field(x, y, z) + assert unify(QQ.poly_ring(x, y), QQ.frac_field(x, z)) == QQ.frac_field(x, y, z) + + assert unify(ZZ.frac_field(x), ZZ.poly_ring(x)) == ZZ.frac_field(x) + assert unify(ZZ.frac_field(x), QQ.poly_ring(x)) == ZZ.frac_field(x) + assert unify(QQ.frac_field(x), ZZ.poly_ring(x)) == ZZ.frac_field(x) + assert unify(QQ.frac_field(x), QQ.poly_ring(x)) == QQ.frac_field(x) + + assert unify(ZZ.frac_field(x, y), ZZ.poly_ring(x)) == ZZ.frac_field(x, y) + assert unify(ZZ.frac_field(x, y), QQ.poly_ring(x)) == ZZ.frac_field(x, y) + assert unify(QQ.frac_field(x, y), ZZ.poly_ring(x)) == ZZ.frac_field(x, y) + assert unify(QQ.frac_field(x, y), QQ.poly_ring(x)) == QQ.frac_field(x, y) + + assert unify(ZZ.frac_field(x), ZZ.poly_ring(x, y)) == ZZ.frac_field(x, y) + assert unify(ZZ.frac_field(x), QQ.poly_ring(x, y)) == ZZ.frac_field(x, y) + assert unify(QQ.frac_field(x), ZZ.poly_ring(x, y)) == ZZ.frac_field(x, y) + assert unify(QQ.frac_field(x), QQ.poly_ring(x, y)) == QQ.frac_field(x, y) + + assert unify(ZZ.frac_field(x, y), ZZ.poly_ring(x, z)) == ZZ.frac_field(x, y, z) + assert unify(ZZ.frac_field(x, y), QQ.poly_ring(x, z)) == ZZ.frac_field(x, y, z) + assert unify(QQ.frac_field(x, y), ZZ.poly_ring(x, z)) == ZZ.frac_field(x, y, z) + assert unify(QQ.frac_field(x, y), QQ.poly_ring(x, z)) == QQ.frac_field(x, y, z) + +def test_Domain_unify_algebraic(): + sqrt5 = QQ.algebraic_field(sqrt(5)) + sqrt7 = QQ.algebraic_field(sqrt(7)) + sqrt57 = QQ.algebraic_field(sqrt(5), sqrt(7)) + + assert sqrt5.unify(sqrt7) == sqrt57 + + assert sqrt5.unify(sqrt5[x, y]) == sqrt5[x, y] + assert sqrt5[x, y].unify(sqrt5) == sqrt5[x, y] + + assert sqrt5.unify(sqrt5.frac_field(x, y)) == sqrt5.frac_field(x, y) + assert sqrt5.frac_field(x, y).unify(sqrt5) == sqrt5.frac_field(x, y) + + assert sqrt5.unify(sqrt7[x, y]) == sqrt57[x, y] + assert sqrt5[x, y].unify(sqrt7) == sqrt57[x, y] + + assert sqrt5.unify(sqrt7.frac_field(x, y)) == sqrt57.frac_field(x, y) + assert sqrt5.frac_field(x, y).unify(sqrt7) == sqrt57.frac_field(x, y) + +def test_Domain_unify_FiniteExtension(): + KxZZ = FiniteExtension(Poly(x**2 - 2, x, domain=ZZ)) + KxQQ = FiniteExtension(Poly(x**2 - 2, x, domain=QQ)) + KxZZy = FiniteExtension(Poly(x**2 - 2, x, domain=ZZ[y])) + KxQQy = FiniteExtension(Poly(x**2 - 2, x, domain=QQ[y])) + + assert KxZZ.unify(KxZZ) == KxZZ + assert KxQQ.unify(KxQQ) == KxQQ + assert KxZZy.unify(KxZZy) == KxZZy + assert KxQQy.unify(KxQQy) == KxQQy + + assert KxZZ.unify(ZZ) == KxZZ + assert KxZZ.unify(QQ) == KxQQ + assert KxQQ.unify(ZZ) == KxQQ + assert KxQQ.unify(QQ) == KxQQ + + assert KxZZ.unify(ZZ[y]) == KxZZy + assert KxZZ.unify(QQ[y]) == KxQQy + assert KxQQ.unify(ZZ[y]) == KxQQy + assert KxQQ.unify(QQ[y]) == KxQQy + + assert KxZZy.unify(ZZ) == KxZZy + assert KxZZy.unify(QQ) == KxQQy + assert KxQQy.unify(ZZ) == KxQQy + assert KxQQy.unify(QQ) == KxQQy + + assert KxZZy.unify(ZZ[y]) == KxZZy + assert KxZZy.unify(QQ[y]) == KxQQy + assert KxQQy.unify(ZZ[y]) == KxQQy + assert KxQQy.unify(QQ[y]) == KxQQy + + K = FiniteExtension(Poly(x**2 - 2, x, domain=ZZ[y])) + assert K.unify(ZZ) == K + assert K.unify(ZZ[x]) == K + assert K.unify(ZZ[y]) == K + assert K.unify(ZZ[x, y]) == K + + Kz = FiniteExtension(Poly(x**2 - 2, x, domain=ZZ[y, z])) + assert K.unify(ZZ[z]) == Kz + assert K.unify(ZZ[x, z]) == Kz + assert K.unify(ZZ[y, z]) == Kz + assert K.unify(ZZ[x, y, z]) == Kz + + Kx = FiniteExtension(Poly(x**2 - 2, x, domain=ZZ)) + Ky = FiniteExtension(Poly(y**2 - 2, y, domain=ZZ)) + Kxy = FiniteExtension(Poly(y**2 - 2, y, domain=Kx)) + assert Kx.unify(Kx) == Kx + assert Ky.unify(Ky) == Ky + assert Kx.unify(Ky) == Kxy + assert Ky.unify(Kx) == Kxy + +def test_Domain_unify_with_symbols(): + raises(UnificationFailed, lambda: ZZ[x, y].unify_with_symbols(ZZ, (y, z))) + raises(UnificationFailed, lambda: ZZ.unify_with_symbols(ZZ[x, y], (y, z))) + +def test_Domain__contains__(): + assert (0 in EX) is True + assert (0 in ZZ) is True + assert (0 in QQ) is True + assert (0 in RR) is True + assert (0 in CC) is True + assert (0 in ALG) is True + assert (0 in ZZ[x, y]) is True + assert (0 in QQ[x, y]) is True + assert (0 in RR[x, y]) is True + + assert (-7 in EX) is True + assert (-7 in ZZ) is True + assert (-7 in QQ) is True + assert (-7 in RR) is True + assert (-7 in CC) is True + assert (-7 in ALG) is True + assert (-7 in ZZ[x, y]) is True + assert (-7 in QQ[x, y]) is True + assert (-7 in RR[x, y]) is True + + assert (17 in EX) is True + assert (17 in ZZ) is True + assert (17 in QQ) is True + assert (17 in RR) is True + assert (17 in CC) is True + assert (17 in ALG) is True + assert (17 in ZZ[x, y]) is True + assert (17 in QQ[x, y]) is True + assert (17 in RR[x, y]) is True + + assert (Rational(-1, 7) in EX) is True + assert (Rational(-1, 7) in ZZ) is False + assert (Rational(-1, 7) in QQ) is True + assert (Rational(-1, 7) in RR) is True + assert (Rational(-1, 7) in CC) is True + assert (Rational(-1, 7) in ALG) is True + assert (Rational(-1, 7) in ZZ[x, y]) is False + assert (Rational(-1, 7) in QQ[x, y]) is True + assert (Rational(-1, 7) in RR[x, y]) is True + + assert (Rational(3, 5) in EX) is True + assert (Rational(3, 5) in ZZ) is False + assert (Rational(3, 5) in QQ) is True + assert (Rational(3, 5) in RR) is True + assert (Rational(3, 5) in CC) is True + assert (Rational(3, 5) in ALG) is True + assert (Rational(3, 5) in ZZ[x, y]) is False + assert (Rational(3, 5) in QQ[x, y]) is True + assert (Rational(3, 5) in RR[x, y]) is True + + assert (3.0 in EX) is True + assert (3.0 in ZZ) is True + assert (3.0 in QQ) is True + assert (3.0 in RR) is True + assert (3.0 in CC) is True + assert (3.0 in ALG) is True + assert (3.0 in ZZ[x, y]) is True + assert (3.0 in QQ[x, y]) is True + assert (3.0 in RR[x, y]) is True + + assert (3.14 in EX) is True + assert (3.14 in ZZ) is False + assert (3.14 in QQ) is True + assert (3.14 in RR) is True + assert (3.14 in CC) is True + assert (3.14 in ALG) is True + assert (3.14 in ZZ[x, y]) is False + assert (3.14 in QQ[x, y]) is True + assert (3.14 in RR[x, y]) is True + + assert (oo in ALG) is False + assert (oo in ZZ[x, y]) is False + assert (oo in QQ[x, y]) is False + + assert (-oo in ZZ) is False + assert (-oo in QQ) is False + assert (-oo in ALG) is False + assert (-oo in ZZ[x, y]) is False + assert (-oo in QQ[x, y]) is False + + assert (sqrt(7) in EX) is True + assert (sqrt(7) in ZZ) is False + assert (sqrt(7) in QQ) is False + assert (sqrt(7) in RR) is True + assert (sqrt(7) in CC) is True + assert (sqrt(7) in ALG) is False + assert (sqrt(7) in ZZ[x, y]) is False + assert (sqrt(7) in QQ[x, y]) is False + assert (sqrt(7) in RR[x, y]) is True + + assert (2*sqrt(3) + 1 in EX) is True + assert (2*sqrt(3) + 1 in ZZ) is False + assert (2*sqrt(3) + 1 in QQ) is False + assert (2*sqrt(3) + 1 in RR) is True + assert (2*sqrt(3) + 1 in CC) is True + assert (2*sqrt(3) + 1 in ALG) is True + assert (2*sqrt(3) + 1 in ZZ[x, y]) is False + assert (2*sqrt(3) + 1 in QQ[x, y]) is False + assert (2*sqrt(3) + 1 in RR[x, y]) is True + + assert (sin(1) in EX) is True + assert (sin(1) in ZZ) is False + assert (sin(1) in QQ) is False + assert (sin(1) in RR) is True + assert (sin(1) in CC) is True + assert (sin(1) in ALG) is False + assert (sin(1) in ZZ[x, y]) is False + assert (sin(1) in QQ[x, y]) is False + assert (sin(1) in RR[x, y]) is True + + assert (x**2 + 1 in EX) is True + assert (x**2 + 1 in ZZ) is False + assert (x**2 + 1 in QQ) is False + assert (x**2 + 1 in RR) is False + assert (x**2 + 1 in CC) is False + assert (x**2 + 1 in ALG) is False + assert (x**2 + 1 in ZZ[x]) is True + assert (x**2 + 1 in QQ[x]) is True + assert (x**2 + 1 in RR[x]) is True + assert (x**2 + 1 in ZZ[x, y]) is True + assert (x**2 + 1 in QQ[x, y]) is True + assert (x**2 + 1 in RR[x, y]) is True + + assert (x**2 + y**2 in EX) is True + assert (x**2 + y**2 in ZZ) is False + assert (x**2 + y**2 in QQ) is False + assert (x**2 + y**2 in RR) is False + assert (x**2 + y**2 in CC) is False + assert (x**2 + y**2 in ALG) is False + assert (x**2 + y**2 in ZZ[x]) is False + assert (x**2 + y**2 in QQ[x]) is False + assert (x**2 + y**2 in RR[x]) is False + assert (x**2 + y**2 in ZZ[x, y]) is True + assert (x**2 + y**2 in QQ[x, y]) is True + assert (x**2 + y**2 in RR[x, y]) is True + + assert (Rational(3, 2)*x/(y + 1) - z in QQ[x, y, z]) is False + + +def test_issue_14433(): + assert (Rational(2, 3)*x in QQ.frac_field(1/x)) is True + assert (1/x in QQ.frac_field(x)) is True + assert ((x**2 + y**2) in QQ.frac_field(1/x, 1/y)) is True + assert ((x + y) in QQ.frac_field(1/x, y)) is True + assert ((x - y) in QQ.frac_field(x, 1/y)) is True + + +def test_Domain_get_ring(): + assert ZZ.has_assoc_Ring is True + assert QQ.has_assoc_Ring is True + assert ZZ[x].has_assoc_Ring is True + assert QQ[x].has_assoc_Ring is True + assert ZZ[x, y].has_assoc_Ring is True + assert QQ[x, y].has_assoc_Ring is True + assert ZZ.frac_field(x).has_assoc_Ring is True + assert QQ.frac_field(x).has_assoc_Ring is True + assert ZZ.frac_field(x, y).has_assoc_Ring is True + assert QQ.frac_field(x, y).has_assoc_Ring is True + + assert EX.has_assoc_Ring is False + assert RR.has_assoc_Ring is False + assert ALG.has_assoc_Ring is False + + assert ZZ.get_ring() == ZZ + assert QQ.get_ring() == ZZ + assert ZZ[x].get_ring() == ZZ[x] + assert QQ[x].get_ring() == QQ[x] + assert ZZ[x, y].get_ring() == ZZ[x, y] + assert QQ[x, y].get_ring() == QQ[x, y] + assert ZZ.frac_field(x).get_ring() == ZZ[x] + assert QQ.frac_field(x).get_ring() == QQ[x] + assert ZZ.frac_field(x, y).get_ring() == ZZ[x, y] + assert QQ.frac_field(x, y).get_ring() == QQ[x, y] + + assert EX.get_ring() == EX + + assert RR.get_ring() == RR + # XXX: This should also be like RR + raises(DomainError, lambda: ALG.get_ring()) + + +def test_Domain_get_field(): + assert EX.has_assoc_Field is True + assert ZZ.has_assoc_Field is True + assert QQ.has_assoc_Field is True + assert RR.has_assoc_Field is True + assert ALG.has_assoc_Field is True + assert ZZ[x].has_assoc_Field is True + assert QQ[x].has_assoc_Field is True + assert ZZ[x, y].has_assoc_Field is True + assert QQ[x, y].has_assoc_Field is True + + assert EX.get_field() == EX + assert ZZ.get_field() == QQ + assert QQ.get_field() == QQ + assert RR.get_field() == RR + assert ALG.get_field() == ALG + assert ZZ[x].get_field() == ZZ.frac_field(x) + assert QQ[x].get_field() == QQ.frac_field(x) + assert ZZ[x, y].get_field() == ZZ.frac_field(x, y) + assert QQ[x, y].get_field() == QQ.frac_field(x, y) + + +def test_Domain_set_domain(): + doms = [GF(5), ZZ, QQ, ALG, RR, CC, EX, ZZ[z], QQ[z], RR[z], CC[z], EX[z]] + for D1 in doms: + for D2 in doms: + assert D1[x].set_domain(D2) == D2[x] + assert D1[x, y].set_domain(D2) == D2[x, y] + assert D1.frac_field(x).set_domain(D2) == D2.frac_field(x) + assert D1.frac_field(x, y).set_domain(D2) == D2.frac_field(x, y) + assert D1.old_poly_ring(x).set_domain(D2) == D2.old_poly_ring(x) + assert D1.old_poly_ring(x, y).set_domain(D2) == D2.old_poly_ring(x, y) + assert D1.old_frac_field(x).set_domain(D2) == D2.old_frac_field(x) + assert D1.old_frac_field(x, y).set_domain(D2) == D2.old_frac_field(x, y) + + +def test_Domain_is_Exact(): + exact = [GF(5), ZZ, QQ, ALG, EX] + inexact = [RR, CC] + for D in exact + inexact: + for R in D, D[x], D.frac_field(x), D.old_poly_ring(x), D.old_frac_field(x): + if D in exact: + assert R.is_Exact is True + else: + assert R.is_Exact is False + + +def test_Domain_get_exact(): + assert EX.get_exact() == EX + assert ZZ.get_exact() == ZZ + assert QQ.get_exact() == QQ + assert RR.get_exact() == QQ + assert CC.get_exact() == QQ_I + assert ALG.get_exact() == ALG + assert ZZ[x].get_exact() == ZZ[x] + assert QQ[x].get_exact() == QQ[x] + assert RR[x].get_exact() == QQ[x] + assert CC[x].get_exact() == QQ_I[x] + assert ZZ[x, y].get_exact() == ZZ[x, y] + assert QQ[x, y].get_exact() == QQ[x, y] + assert RR[x, y].get_exact() == QQ[x, y] + assert CC[x, y].get_exact() == QQ_I[x, y] + assert ZZ.frac_field(x).get_exact() == ZZ.frac_field(x) + assert QQ.frac_field(x).get_exact() == QQ.frac_field(x) + assert RR.frac_field(x).get_exact() == QQ.frac_field(x) + assert CC.frac_field(x).get_exact() == QQ_I.frac_field(x) + assert ZZ.frac_field(x, y).get_exact() == ZZ.frac_field(x, y) + assert QQ.frac_field(x, y).get_exact() == QQ.frac_field(x, y) + assert RR.frac_field(x, y).get_exact() == QQ.frac_field(x, y) + assert CC.frac_field(x, y).get_exact() == QQ_I.frac_field(x, y) + assert ZZ.old_poly_ring(x).get_exact() == ZZ.old_poly_ring(x) + assert QQ.old_poly_ring(x).get_exact() == QQ.old_poly_ring(x) + assert RR.old_poly_ring(x).get_exact() == QQ.old_poly_ring(x) + assert CC.old_poly_ring(x).get_exact() == QQ_I.old_poly_ring(x) + assert ZZ.old_poly_ring(x, y).get_exact() == ZZ.old_poly_ring(x, y) + assert QQ.old_poly_ring(x, y).get_exact() == QQ.old_poly_ring(x, y) + assert RR.old_poly_ring(x, y).get_exact() == QQ.old_poly_ring(x, y) + assert CC.old_poly_ring(x, y).get_exact() == QQ_I.old_poly_ring(x, y) + assert ZZ.old_frac_field(x).get_exact() == ZZ.old_frac_field(x) + assert QQ.old_frac_field(x).get_exact() == QQ.old_frac_field(x) + assert RR.old_frac_field(x).get_exact() == QQ.old_frac_field(x) + assert CC.old_frac_field(x).get_exact() == QQ_I.old_frac_field(x) + assert ZZ.old_frac_field(x, y).get_exact() == ZZ.old_frac_field(x, y) + assert QQ.old_frac_field(x, y).get_exact() == QQ.old_frac_field(x, y) + assert RR.old_frac_field(x, y).get_exact() == QQ.old_frac_field(x, y) + assert CC.old_frac_field(x, y).get_exact() == QQ_I.old_frac_field(x, y) + + +def test_Domain_characteristic(): + for F, c in [(FF(3), 3), (FF(5), 5), (FF(7), 7)]: + for R in F, F[x], F.frac_field(x), F.old_poly_ring(x), F.old_frac_field(x): + assert R.has_CharacteristicZero is False + assert R.characteristic() == c + for D in ZZ, QQ, ZZ_I, QQ_I, ALG: + for R in D, D[x], D.frac_field(x), D.old_poly_ring(x), D.old_frac_field(x): + assert R.has_CharacteristicZero is True + assert R.characteristic() == 0 + + +def test_Domain_is_unit(): + nums = [-2, -1, 0, 1, 2] + invring = [False, True, False, True, False] + invfield = [True, True, False, True, True] + ZZx, QQx, QQxf = ZZ[x], QQ[x], QQ.frac_field(x) + assert [ZZ.is_unit(ZZ(n)) for n in nums] == invring + assert [QQ.is_unit(QQ(n)) for n in nums] == invfield + assert [ZZx.is_unit(ZZx(n)) for n in nums] == invring + assert [QQx.is_unit(QQx(n)) for n in nums] == invfield + assert [QQxf.is_unit(QQxf(n)) for n in nums] == invfield + assert ZZx.is_unit(ZZx(x)) is False + assert QQx.is_unit(QQx(x)) is False + assert QQxf.is_unit(QQxf(x)) is True + + +def test_Domain_convert(): + + def check_element(e1, e2, K1, K2, K3): + assert type(e1) is type(e2), '%s, %s: %s %s -> %s' % (e1, e2, K1, K2, K3) + assert e1 == e2, '%s, %s: %s %s -> %s' % (e1, e2, K1, K2, K3) + + def check_domains(K1, K2): + K3 = K1.unify(K2) + check_element(K3.convert_from(K1.one, K1), K3.one, K1, K2, K3) + check_element(K3.convert_from(K2.one, K2), K3.one, K1, K2, K3) + check_element(K3.convert_from(K1.zero, K1), K3.zero, K1, K2, K3) + check_element(K3.convert_from(K2.zero, K2), K3.zero, K1, K2, K3) + + def composite_domains(K): + domains = [ + K, + K[y], K[z], K[y, z], + K.frac_field(y), K.frac_field(z), K.frac_field(y, z), + # XXX: These should be tested and made to work... + # K.old_poly_ring(y), K.old_frac_field(y), + ] + return domains + + QQ2 = QQ.algebraic_field(sqrt(2)) + QQ3 = QQ.algebraic_field(sqrt(3)) + doms = [ZZ, QQ, QQ2, QQ3, QQ_I, ZZ_I, RR, CC] + + for i, K1 in enumerate(doms): + for K2 in doms[i:]: + for K3 in composite_domains(K1): + for K4 in composite_domains(K2): + check_domains(K3, K4) + + assert QQ.convert(10e-52) == QQ(1684996666696915, 1684996666696914987166688442938726917102321526408785780068975640576) + + R, xr = ring("x", ZZ) + assert ZZ.convert(xr - xr) == 0 + assert ZZ.convert(xr - xr, R.to_domain()) == 0 + + assert CC.convert(ZZ_I(1, 2)) == CC(1, 2) + assert CC.convert(QQ_I(1, 2)) == CC(1, 2) + + assert QQ.convert_from(RR(0.5), RR) == QQ(1, 2) + assert RR.convert_from(QQ(1, 2), QQ) == RR(0.5) + assert QQ_I.convert_from(CC(0.5, 0.75), CC) == QQ_I(QQ(1, 2), QQ(3, 4)) + assert CC.convert_from(QQ_I(QQ(1, 2), QQ(3, 4)), QQ_I) == CC(0.5, 0.75) + + K1 = QQ.frac_field(x) + K2 = ZZ.frac_field(x) + K3 = QQ[x] + K4 = ZZ[x] + Ks = [K1, K2, K3, K4] + for Ka, Kb in product(Ks, Ks): + assert Ka.convert_from(Kb.from_sympy(x), Kb) == Ka.from_sympy(x) + + assert K2.convert_from(QQ(1, 2), QQ) == K2(QQ(1, 2)) + + +def test_EX_convert(): + + elements = [ + (ZZ, ZZ(3)), + (QQ, QQ(1,2)), + (ZZ_I, ZZ_I(1,2)), + (QQ_I, QQ_I(1,2)), + (RR, RR(3)), + (CC, CC(1,2)), + (EX, EX(3)), + (EXRAW, EXRAW(3)), + (ALG, ALG.from_sympy(sqrt(2))), + ] + + for R, e in elements: + for EE in EX, EXRAW: + elem = EE.from_sympy(R.to_sympy(e)) + assert EE.convert_from(e, R) == elem + assert R.convert_from(elem, EE) == e + + +def test_GlobalPolynomialRing_convert(): + K1 = QQ.old_poly_ring(x) + K2 = QQ[x] + assert K1.convert(x) == K1.convert(K2.convert(x), K2) + assert K2.convert(x) == K2.convert(K1.convert(x), K1) + + K1 = QQ.old_poly_ring(x, y) + K2 = QQ[x] + assert K1.convert(x) == K1.convert(K2.convert(x), K2) + #assert K2.convert(x) == K2.convert(K1.convert(x), K1) + + K1 = ZZ.old_poly_ring(x, y) + K2 = QQ[x] + assert K1.convert(x) == K1.convert(K2.convert(x), K2) + #assert K2.convert(x) == K2.convert(K1.convert(x), K1) + + +def test_PolynomialRing__init(): + R, = ring("", ZZ) + assert ZZ.poly_ring() == R.to_domain() + + +def test_FractionField__init(): + F, = field("", ZZ) + assert ZZ.frac_field() == F.to_domain() + + +def test_FractionField_convert(): + K = QQ.frac_field(x) + assert K.convert(QQ(2, 3), QQ) == K.from_sympy(Rational(2, 3)) + K = QQ.frac_field(x) + assert K.convert(ZZ(2), ZZ) == K.from_sympy(Integer(2)) + + +def test_inject(): + assert ZZ.inject(x, y, z) == ZZ[x, y, z] + assert ZZ[x].inject(y, z) == ZZ[x, y, z] + assert ZZ.frac_field(x).inject(y, z) == ZZ.frac_field(x, y, z) + raises(GeneratorsError, lambda: ZZ[x].inject(x)) + + +def test_drop(): + assert ZZ.drop(x) == ZZ + assert ZZ[x].drop(x) == ZZ + assert ZZ[x, y].drop(x) == ZZ[y] + assert ZZ.frac_field(x).drop(x) == ZZ + assert ZZ.frac_field(x, y).drop(x) == ZZ.frac_field(y) + assert ZZ[x][y].drop(y) == ZZ[x] + assert ZZ[x][y].drop(x) == ZZ[y] + assert ZZ.frac_field(x)[y].drop(x) == ZZ[y] + assert ZZ.frac_field(x)[y].drop(y) == ZZ.frac_field(x) + Ky = FiniteExtension(Poly(x**2-1, x, domain=ZZ[y])) + K = FiniteExtension(Poly(x**2-1, x, domain=ZZ)) + assert Ky.drop(y) == K + raises(GeneratorsError, lambda: Ky.drop(x)) + + +def test_Domain_map(): + seq = ZZ.map([1, 2, 3, 4]) + + assert all(ZZ.of_type(elt) for elt in seq) + + seq = ZZ.map([[1, 2, 3, 4]]) + + assert all(ZZ.of_type(elt) for elt in seq[0]) and len(seq) == 1 + + +def test_Domain___eq__(): + assert (ZZ[x, y] == ZZ[x, y]) is True + assert (QQ[x, y] == QQ[x, y]) is True + + assert (ZZ[x, y] == QQ[x, y]) is False + assert (QQ[x, y] == ZZ[x, y]) is False + + assert (ZZ.frac_field(x, y) == ZZ.frac_field(x, y)) is True + assert (QQ.frac_field(x, y) == QQ.frac_field(x, y)) is True + + assert (ZZ.frac_field(x, y) == QQ.frac_field(x, y)) is False + assert (QQ.frac_field(x, y) == ZZ.frac_field(x, y)) is False + + assert RealField()[x] == RR[x] + + +def test_Domain__algebraic_field(): + alg = ZZ.algebraic_field(sqrt(2)) + assert alg.ext.minpoly == Poly(x**2 - 2) + assert alg.dom == QQ + + alg = QQ.algebraic_field(sqrt(2)) + assert alg.ext.minpoly == Poly(x**2 - 2) + assert alg.dom == QQ + + alg = alg.algebraic_field(sqrt(3)) + assert alg.ext.minpoly == Poly(x**4 - 10*x**2 + 1) + assert alg.dom == QQ + + +def test_Domain_alg_field_from_poly(): + f = Poly(x**2 - 2) + g = Poly(x**2 - 3) + h = Poly(x**4 - 10*x**2 + 1) + + alg = ZZ.alg_field_from_poly(f) + assert alg.ext.minpoly == f + assert alg.dom == QQ + + alg = QQ.alg_field_from_poly(f) + assert alg.ext.minpoly == f + assert alg.dom == QQ + + alg = alg.alg_field_from_poly(g) + assert alg.ext.minpoly == h + assert alg.dom == QQ + + +def test_Domain_cyclotomic_field(): + K = ZZ.cyclotomic_field(12) + assert K.ext.minpoly == Poly(cyclotomic_poly(12)) + assert K.dom == QQ + + F = QQ.cyclotomic_field(3) + assert F.ext.minpoly == Poly(cyclotomic_poly(3)) + assert F.dom == QQ + + E = F.cyclotomic_field(4) + assert field_isomorphism(E.ext, K.ext) is not None + assert E.dom == QQ + + +def test_PolynomialRing_from_FractionField(): + F, x,y = field("x,y", ZZ) + R, X,Y = ring("x,y", ZZ) + + f = (x**2 + y**2)/(x + 1) + g = (x**2 + y**2)/4 + h = x**2 + y**2 + + assert R.to_domain().from_FractionField(f, F.to_domain()) is None + assert R.to_domain().from_FractionField(g, F.to_domain()) == X**2/4 + Y**2/4 + assert R.to_domain().from_FractionField(h, F.to_domain()) == X**2 + Y**2 + + F, x,y = field("x,y", QQ) + R, X,Y = ring("x,y", QQ) + + f = (x**2 + y**2)/(x + 1) + g = (x**2 + y**2)/4 + h = x**2 + y**2 + + assert R.to_domain().from_FractionField(f, F.to_domain()) is None + assert R.to_domain().from_FractionField(g, F.to_domain()) == X**2/4 + Y**2/4 + assert R.to_domain().from_FractionField(h, F.to_domain()) == X**2 + Y**2 + + +def test_FractionField_from_PolynomialRing(): + R, x,y = ring("x,y", QQ) + F, X,Y = field("x,y", ZZ) + + f = 3*x**2 + 5*y**2 + g = x**2/3 + y**2/5 + + assert F.to_domain().from_PolynomialRing(f, R.to_domain()) == 3*X**2 + 5*Y**2 + assert F.to_domain().from_PolynomialRing(g, R.to_domain()) == (5*X**2 + 3*Y**2)/15 + + +def test_FF_of_type(): + # XXX: of_type is not very useful here because in the case of ground types + # = flint all elements are of type nmod. + assert FF(3).of_type(FF(3)(1)) is True + assert FF(5).of_type(FF(5)(3)) is True + + +def test___eq__(): + assert not QQ[x] == ZZ[x] + assert not QQ.frac_field(x) == ZZ.frac_field(x) + + +def test_RealField_from_sympy(): + assert RR.convert(S.Zero) == RR.dtype(0) + assert RR.convert(S(0.0)) == RR.dtype(0.0) + assert RR.convert(S.One) == RR.dtype(1) + assert RR.convert(S(1.0)) == RR.dtype(1.0) + assert RR.convert(sin(1)) == RR.dtype(sin(1).evalf()) + + +def test_not_in_any_domain(): + check = list(_illegal) + [x] + [ + float(i) for i in _illegal[:3]] + for dom in (ZZ, QQ, RR, CC, EX): + for i in check: + if i == x and dom == EX: + continue + assert i not in dom, (i, dom) + raises(CoercionFailed, lambda: dom.convert(i)) + + +def test_ModularInteger(): + F3 = FF(3) + + a = F3(0) + assert F3.of_type(a) and a == 0 + a = F3(1) + assert F3.of_type(a) and a == 1 + a = F3(2) + assert F3.of_type(a) and a == 2 + a = F3(3) + assert F3.of_type(a) and a == 0 + a = F3(4) + assert F3.of_type(a) and a == 1 + + a = F3(F3(0)) + assert F3.of_type(a) and a == 0 + a = F3(F3(1)) + assert F3.of_type(a) and a == 1 + a = F3(F3(2)) + assert F3.of_type(a) and a == 2 + a = F3(F3(3)) + assert F3.of_type(a) and a == 0 + a = F3(F3(4)) + assert F3.of_type(a) and a == 1 + + a = -F3(1) + assert F3.of_type(a) and a == 2 + a = -F3(2) + assert F3.of_type(a) and a == 1 + + a = 2 + F3(2) + assert F3.of_type(a) and a == 1 + a = F3(2) + 2 + assert F3.of_type(a) and a == 1 + a = F3(2) + F3(2) + assert F3.of_type(a) and a == 1 + a = F3(2) + F3(2) + assert F3.of_type(a) and a == 1 + + a = 3 - F3(2) + assert F3.of_type(a) and a == 1 + a = F3(3) - 2 + assert F3.of_type(a) and a == 1 + a = F3(3) - F3(2) + assert F3.of_type(a) and a == 1 + a = F3(3) - F3(2) + assert F3.of_type(a) and a == 1 + + a = 2*F3(2) + assert F3.of_type(a) and a == 1 + a = F3(2)*2 + assert F3.of_type(a) and a == 1 + a = F3(2)*F3(2) + assert F3.of_type(a) and a == 1 + a = F3(2)*F3(2) + assert F3.of_type(a) and a == 1 + + a = 2/F3(2) + assert F3.of_type(a) and a == 1 + a = F3(2)/2 + assert F3.of_type(a) and a == 1 + a = F3(2)/F3(2) + assert F3.of_type(a) and a == 1 + a = F3(2)/F3(2) + assert F3.of_type(a) and a == 1 + + a = F3(2)**0 + assert F3.of_type(a) and a == 1 + a = F3(2)**1 + assert F3.of_type(a) and a == 2 + a = F3(2)**2 + assert F3.of_type(a) and a == 1 + + F7 = FF(7) + + a = F7(3)**100000000000 + assert F7.of_type(a) and a == 4 + a = F7(3)**-100000000000 + assert F7.of_type(a) and a == 2 + + assert bool(F3(3)) is False + assert bool(F3(4)) is True + + F5 = FF(5) + + a = F5(1)**(-1) + assert F5.of_type(a) and a == 1 + a = F5(2)**(-1) + assert F5.of_type(a) and a == 3 + a = F5(3)**(-1) + assert F5.of_type(a) and a == 2 + a = F5(4)**(-1) + assert F5.of_type(a) and a == 4 + + if GROUND_TYPES != 'flint': + # XXX: This gives a core dump with python-flint... + raises(NotInvertible, lambda: F5(0)**(-1)) + raises(NotInvertible, lambda: F5(5)**(-1)) + + raises(ValueError, lambda: FF(0)) + raises(ValueError, lambda: FF(2.1)) + + for n1 in range(5): + for n2 in range(5): + if GROUND_TYPES != 'flint': + with warns_deprecated_sympy(): + assert (F5(n1) < F5(n2)) is (n1 < n2) + with warns_deprecated_sympy(): + assert (F5(n1) <= F5(n2)) is (n1 <= n2) + with warns_deprecated_sympy(): + assert (F5(n1) > F5(n2)) is (n1 > n2) + with warns_deprecated_sympy(): + assert (F5(n1) >= F5(n2)) is (n1 >= n2) + else: + raises(TypeError, lambda: F5(n1) < F5(n2)) + raises(TypeError, lambda: F5(n1) <= F5(n2)) + raises(TypeError, lambda: F5(n1) > F5(n2)) + raises(TypeError, lambda: F5(n1) >= F5(n2)) + + # https://github.com/sympy/sympy/issues/26789 + assert GF(Integer(5)) == F5 + assert F5(Integer(3)) == F5(3) + + +def test_QQ_int(): + assert int(QQ(2**2000, 3**1250)) == 455431 + assert int(QQ(2**100, 3)) == 422550200076076467165567735125 + + +def test_RR_double(): + assert RR(3.14) > 1e-50 + assert RR(1e-13) > 1e-50 + assert RR(1e-14) > 1e-50 + assert RR(1e-15) > 1e-50 + assert RR(1e-20) > 1e-50 + assert RR(1e-40) > 1e-50 + + +def test_RR_Float(): + f1 = Float("1.01") + f2 = Float("1.0000000000000000000001") + assert f1._prec == 53 + assert f2._prec == 80 + assert RR(f1)-1 > 1e-50 + assert RR(f2)-1 < 1e-50 # RR's precision is lower than f2's + + RR2 = RealField(prec=f2._prec) + assert RR2(f1)-1 > 1e-50 + assert RR2(f2)-1 > 1e-50 # RR's precision is equal to f2's + + +def test_CC_double(): + assert CC(3.14).real > 1e-50 + assert CC(1e-13).real > 1e-50 + assert CC(1e-14).real > 1e-50 + assert CC(1e-15).real > 1e-50 + assert CC(1e-20).real > 1e-50 + assert CC(1e-40).real > 1e-50 + + assert CC(3.14j).imag > 1e-50 + assert CC(1e-13j).imag > 1e-50 + assert CC(1e-14j).imag > 1e-50 + assert CC(1e-15j).imag > 1e-50 + assert CC(1e-20j).imag > 1e-50 + assert CC(1e-40j).imag > 1e-50 + + +def test_gaussian_domains(): + I = S.ImaginaryUnit + a, b, c, d = [ZZ_I.convert(x) for x in (5, 2 + I, 3 - I, 5 - 5*I)] + assert ZZ_I.gcd(a, b) == b + assert ZZ_I.gcd(a, c) == b + assert ZZ_I.lcm(a, b) == a + assert ZZ_I.lcm(a, c) == d + assert ZZ_I(3, 4) != QQ_I(3, 4) # XXX is this right or should QQ->ZZ if possible? + assert ZZ_I(3, 0) != 3 # and should this go to Integer? + assert QQ_I(S(3)/4, 0) != S(3)/4 # and this to Rational? + assert ZZ_I(0, 0).quadrant() == 0 + assert ZZ_I(-1, 0).quadrant() == 2 + + assert QQ_I.convert(QQ(3, 2)) == QQ_I(QQ(3, 2), QQ(0)) + assert QQ_I.convert(QQ(3, 2), QQ) == QQ_I(QQ(3, 2), QQ(0)) + + for G in (QQ_I, ZZ_I): + + q = G(3, 4) + assert str(q) == '3 + 4*I' + assert q.parent() == G + assert q._get_xy(pi) == (None, None) + assert q._get_xy(2) == (2, 0) + assert q._get_xy(2*I) == (0, 2) + + assert hash(q) == hash((3, 4)) + assert G(1, 2) == G(1, 2) + assert G(1, 2) != G(1, 3) + assert G(3, 0) == G(3) + + assert q + q == G(6, 8) + assert q - q == G(0, 0) + assert 3 - q == -q + 3 == G(0, -4) + assert 3 + q == q + 3 == G(6, 4) + assert q * q == G(-7, 24) + assert 3 * q == q * 3 == G(9, 12) + assert q ** 0 == G(1, 0) + assert q ** 1 == q + assert q ** 2 == q * q == G(-7, 24) + assert q ** 3 == q * q * q == G(-117, 44) + assert 1 / q == q ** -1 == QQ_I(S(3)/25, - S(4)/25) + assert q / 1 == QQ_I(3, 4) + assert q / 2 == QQ_I(S(3)/2, 2) + assert q/3 == QQ_I(1, S(4)/3) + assert 3/q == QQ_I(S(9)/25, -S(12)/25) + i, r = divmod(q, 2) + assert 2*i + r == q + i, r = divmod(2, q) + assert q*i + r == G(2, 0) + + raises(ZeroDivisionError, lambda: q % 0) + raises(ZeroDivisionError, lambda: q / 0) + raises(ZeroDivisionError, lambda: q // 0) + raises(ZeroDivisionError, lambda: divmod(q, 0)) + raises(ZeroDivisionError, lambda: divmod(q, 0)) + raises(TypeError, lambda: q + x) + raises(TypeError, lambda: q - x) + raises(TypeError, lambda: x + q) + raises(TypeError, lambda: x - q) + raises(TypeError, lambda: q * x) + raises(TypeError, lambda: x * q) + raises(TypeError, lambda: q / x) + raises(TypeError, lambda: x / q) + raises(TypeError, lambda: q // x) + raises(TypeError, lambda: x // q) + + assert G.from_sympy(S(2)) == G(2, 0) + assert G.to_sympy(G(2, 0)) == S(2) + raises(CoercionFailed, lambda: G.from_sympy(pi)) + + PR = G.inject(x) + assert isinstance(PR, PolynomialRing) + assert PR.domain == G + assert len(PR.gens) == 1 and PR.gens[0].as_expr() == x + + if G is QQ_I: + AF = G.as_AlgebraicField() + assert isinstance(AF, AlgebraicField) + assert AF.domain == QQ + assert AF.ext.args[0] == I + + for qi in [G(-1, 0), G(1, 0), G(0, -1), G(0, 1)]: + assert G.is_negative(qi) is False + assert G.is_positive(qi) is False + assert G.is_nonnegative(qi) is False + assert G.is_nonpositive(qi) is False + + domains = [ZZ, QQ, AlgebraicField(QQ, I)] + + # XXX: These domains are all obsolete because ZZ/QQ with MPZ/MPQ + # already use either gmpy, flint or python depending on the + # availability of these libraries. We can keep these tests for now but + # ideally we should remove these alternate domains entirely. + domains += [ZZ_python(), QQ_python()] + if GROUND_TYPES == 'gmpy': + domains += [ZZ_gmpy(), QQ_gmpy()] + + for K in domains: + assert G.convert(K(2)) == G(2, 0) + assert G.convert(K(2), K) == G(2, 0) + + for K in ZZ_I, QQ_I: + assert G.convert(K(1, 1)) == G(1, 1) + assert G.convert(K(1, 1), K) == G(1, 1) + + if G == ZZ_I: + assert repr(q) == 'ZZ_I(3, 4)' + assert q//3 == G(1, 1) + assert 12//q == G(1, -2) + assert 12 % q == G(1, 2) + assert q % 2 == G(-1, 0) + assert i == G(0, 0) + assert r == G(2, 0) + assert G.get_ring() == G + assert G.get_field() == QQ_I + else: + assert repr(q) == 'QQ_I(3, 4)' + assert G.get_ring() == ZZ_I + assert G.get_field() == G + assert q//3 == G(1, S(4)/3) + assert 12//q == G(S(36)/25, -S(48)/25) + assert 12 % q == G(0, 0) + assert q % 2 == G(0, 0) + assert i == G(S(6)/25, -S(8)/25), (G,i) + assert r == G(0, 0) + q2 = G(S(3)/2, S(5)/3) + assert G.numer(q2) == ZZ_I(9, 10) + assert G.denom(q2) == ZZ_I(6) + + +def test_EX_EXRAW(): + assert EXRAW.zero is S.Zero + assert EXRAW.one is S.One + + assert EX(1) == EX.Expression(1) + assert EX(1).ex is S.One + assert EXRAW(1) is S.One + + # EX has cancelling but EXRAW does not + assert 2*EX((x + y*x)/x) == EX(2 + 2*y) != 2*((x + y*x)/x) + assert 2*EXRAW((x + y*x)/x) == 2*((x + y*x)/x) != (1 + y) + + assert EXRAW.convert_from(EX(1), EX) is EXRAW.one + assert EX.convert_from(EXRAW(1), EXRAW) == EX.one + + assert EXRAW.from_sympy(S.One) is S.One + assert EXRAW.to_sympy(EXRAW.one) is S.One + raises(CoercionFailed, lambda: EXRAW.from_sympy([])) + + assert EXRAW.get_field() == EXRAW + + assert EXRAW.unify(EX) == EXRAW + assert EX.unify(EXRAW) == EXRAW + + +def test_EX_ordering(): + elements = [EX(1), EX(x), EX(3)] + assert sorted(elements) == [EX(1), EX(3), EX(x)] + + +def test_canonical_unit(): + + for K in [ZZ, QQ, RR]: # CC? + assert K.canonical_unit(K(2)) == K(1) + assert K.canonical_unit(K(-2)) == K(-1) + + for K in [ZZ_I, QQ_I]: + i = K.from_sympy(I) + assert K.canonical_unit(K(2)) == K(1) + assert K.canonical_unit(K(2)*i) == -i + assert K.canonical_unit(-K(2)) == K(-1) + assert K.canonical_unit(-K(2)*i) == i + + K = ZZ[x] + assert K.canonical_unit(K(x + 1)) == K(1) + assert K.canonical_unit(K(-x + 1)) == K(-1) + + K = ZZ_I[x] + assert K.canonical_unit(K.from_sympy(I*x)) == ZZ_I(0, -1) + + K = ZZ_I.frac_field(x, y) + i = K.from_sympy(I) + assert i / i == K.one + assert (K.one + i)/(i - K.one) == -i + + +def test_issue_18278(): + assert str(RR(2).parent()) == 'RR' + assert str(CC(2).parent()) == 'CC' + + +def test_Domain_is_negative(): + I = S.ImaginaryUnit + a, b = [CC.convert(x) for x in (2 + I, 5)] + assert CC.is_negative(a) == False + assert CC.is_negative(b) == False + + +def test_Domain_is_positive(): + I = S.ImaginaryUnit + a, b = [CC.convert(x) for x in (2 + I, 5)] + assert CC.is_positive(a) == False + assert CC.is_positive(b) == False + + +def test_Domain_is_nonnegative(): + I = S.ImaginaryUnit + a, b = [CC.convert(x) for x in (2 + I, 5)] + assert CC.is_nonnegative(a) == False + assert CC.is_nonnegative(b) == False + + +def test_Domain_is_nonpositive(): + I = S.ImaginaryUnit + a, b = [CC.convert(x) for x in (2 + I, 5)] + assert CC.is_nonpositive(a) == False + assert CC.is_nonpositive(b) == False + + +def test_exponential_domain(): + K = ZZ[E] + eK = K.from_sympy(E) + assert K.from_sympy(exp(3)) == eK ** 3 + assert K.convert(exp(3)) == eK ** 3 + + +def test_AlgebraicField_alias(): + # No default alias: + k = QQ.algebraic_field(sqrt(2)) + assert k.ext.alias is None + + # For a single extension, its alias is used: + alpha = AlgebraicNumber(sqrt(2), alias='alpha') + k = QQ.algebraic_field(alpha) + assert k.ext.alias.name == 'alpha' + + # Can override the alias of a single extension: + k = QQ.algebraic_field(alpha, alias='theta') + assert k.ext.alias.name == 'theta' + + # With multiple extensions, no default alias: + k = QQ.algebraic_field(sqrt(2), sqrt(3)) + assert k.ext.alias is None + + # With multiple extensions, no default alias, even if one of + # the extensions has one: + k = QQ.algebraic_field(alpha, sqrt(3)) + assert k.ext.alias is None + + # With multiple extensions, may set an alias: + k = QQ.algebraic_field(sqrt(2), sqrt(3), alias='theta') + assert k.ext.alias.name == 'theta' + + # Alias is passed to constructed field elements: + k = QQ.algebraic_field(alpha) + beta = k.to_alg_num(k([1, 2, 3])) + assert beta.alias is alpha.alias + + +def test_exsqrt(): + assert ZZ.is_square(ZZ(4)) is True + assert ZZ.exsqrt(ZZ(4)) == ZZ(2) + assert ZZ.is_square(ZZ(42)) is False + assert ZZ.exsqrt(ZZ(42)) is None + assert ZZ.is_square(ZZ(0)) is True + assert ZZ.exsqrt(ZZ(0)) == ZZ(0) + assert ZZ.is_square(ZZ(-1)) is False + assert ZZ.exsqrt(ZZ(-1)) is None + + assert QQ.is_square(QQ(9, 4)) is True + assert QQ.exsqrt(QQ(9, 4)) == QQ(3, 2) + assert QQ.is_square(QQ(18, 8)) is True + assert QQ.exsqrt(QQ(18, 8)) == QQ(3, 2) + assert QQ.is_square(QQ(-9, -4)) is True + assert QQ.exsqrt(QQ(-9, -4)) == QQ(3, 2) + assert QQ.is_square(QQ(11, 4)) is False + assert QQ.exsqrt(QQ(11, 4)) is None + assert QQ.is_square(QQ(9, 5)) is False + assert QQ.exsqrt(QQ(9, 5)) is None + assert QQ.is_square(QQ(4)) is True + assert QQ.exsqrt(QQ(4)) == QQ(2) + assert QQ.is_square(QQ(0)) is True + assert QQ.exsqrt(QQ(0)) == QQ(0) + assert QQ.is_square(QQ(-16, 9)) is False + assert QQ.exsqrt(QQ(-16, 9)) is None + + assert RR.is_square(RR(6.25)) is True + assert RR.exsqrt(RR(6.25)) == RR(2.5) + assert RR.is_square(RR(2)) is True + assert RR.almosteq(RR.exsqrt(RR(2)), RR(1.4142135623730951), tolerance=1e-15) + assert RR.is_square(RR(0)) is True + assert RR.exsqrt(RR(0)) == RR(0) + assert RR.is_square(RR(-1)) is False + assert RR.exsqrt(RR(-1)) is None + + assert CC.is_square(CC(2)) is True + assert CC.almosteq(CC.exsqrt(CC(2)), CC(1.4142135623730951), tolerance=1e-15) + assert CC.is_square(CC(0)) is True + assert CC.exsqrt(CC(0)) == CC(0) + assert CC.is_square(CC(-1)) is True + assert CC.exsqrt(CC(-1)) == CC(0, 1) + assert CC.is_square(CC(0, 2)) is True + assert CC.exsqrt(CC(0, 2)) == CC(1, 1) + assert CC.is_square(CC(-3, -4)) is True + assert CC.exsqrt(CC(-3, -4)) == CC(1, -2) + + F2 = FF(2) + assert F2.is_square(F2(1)) is True + assert F2.exsqrt(F2(1)) == F2(1) + assert F2.is_square(F2(0)) is True + assert F2.exsqrt(F2(0)) == F2(0) + + F7 = FF(7) + assert F7.is_square(F7(2)) is True + assert F7.exsqrt(F7(2)) == F7(3) + assert F7.is_square(F7(3)) is False + assert F7.exsqrt(F7(3)) is None + assert F7.is_square(F7(0)) is True + assert F7.exsqrt(F7(0)) == F7(0) diff --git a/lib/python3.10/site-packages/sympy/polys/domains/tests/test_polynomialring.py b/lib/python3.10/site-packages/sympy/polys/domains/tests/test_polynomialring.py new file mode 100644 index 0000000000000000000000000000000000000000..6cb1fdf3f9f9250518289019b0bb108047e8cb6c --- /dev/null +++ b/lib/python3.10/site-packages/sympy/polys/domains/tests/test_polynomialring.py @@ -0,0 +1,93 @@ +"""Tests for the PolynomialRing classes. """ + +from sympy.polys.domains import QQ, ZZ +from sympy.polys.polyerrors import ExactQuotientFailed, CoercionFailed, NotReversible + +from sympy.abc import x, y + +from sympy.testing.pytest import raises + + +def test_build_order(): + R = QQ.old_poly_ring(x, y, order=(("lex", x), ("ilex", y))) + assert R.order((1, 5)) == ((1,), (-5,)) + + +def test_globalring(): + Qxy = QQ.old_frac_field(x, y) + R = QQ.old_poly_ring(x, y) + X = R.convert(x) + Y = R.convert(y) + + assert x in R + assert 1/x not in R + assert 1/(1 + x) not in R + assert Y in R + assert X * (Y**2 + 1) == R.convert(x * (y**2 + 1)) + assert X + 1 == R.convert(x + 1) + raises(ExactQuotientFailed, lambda: X/Y) + raises(TypeError, lambda: x/Y) + raises(TypeError, lambda: X/y) + assert X**2 / X == X + + assert R.from_GlobalPolynomialRing(ZZ.old_poly_ring(x, y).convert(x), ZZ.old_poly_ring(x, y)) == X + assert R.from_FractionField(Qxy.convert(x), Qxy) == X + assert R.from_FractionField(Qxy.convert(x/y), Qxy) is None + + assert R._sdm_to_vector(R._vector_to_sdm([X, Y], R.order), 2) == [X, Y] + + +def test_localring(): + Qxy = QQ.old_frac_field(x, y) + R = QQ.old_poly_ring(x, y, order="ilex") + X = R.convert(x) + Y = R.convert(y) + + assert x in R + assert 1/x not in R + assert 1/(1 + x) in R + assert Y in R + assert X*(Y**2 + 1)/(1 + X) == R.convert(x*(y**2 + 1)/(1 + x)) + raises(TypeError, lambda: x/Y) + raises(TypeError, lambda: X/y) + assert X + 1 == R.convert(x + 1) + assert X**2 / X == X + + assert R.from_GlobalPolynomialRing(ZZ.old_poly_ring(x, y).convert(x), ZZ.old_poly_ring(x, y)) == X + assert R.from_FractionField(Qxy.convert(x), Qxy) == X + raises(CoercionFailed, lambda: R.from_FractionField(Qxy.convert(x/y), Qxy)) + raises(ExactQuotientFailed, lambda: R.exquo(X, Y)) + raises(NotReversible, lambda: R.revert(X)) + + assert R._sdm_to_vector( + R._vector_to_sdm([X/(X + 1), Y/(1 + X*Y)], R.order), 2) == \ + [X*(1 + X*Y), Y*(1 + X)] + + +def test_conversion(): + L = QQ.old_poly_ring(x, y, order="ilex") + G = QQ.old_poly_ring(x, y) + + assert L.convert(x) == L.convert(G.convert(x), G) + assert G.convert(x) == G.convert(L.convert(x), L) + raises(CoercionFailed, lambda: G.convert(L.convert(1/(1 + x)), L)) + + +def test_units(): + R = QQ.old_poly_ring(x) + assert R.is_unit(R.convert(1)) + assert R.is_unit(R.convert(2)) + assert not R.is_unit(R.convert(x)) + assert not R.is_unit(R.convert(1 + x)) + + R = QQ.old_poly_ring(x, order='ilex') + assert R.is_unit(R.convert(1)) + assert R.is_unit(R.convert(2)) + assert not R.is_unit(R.convert(x)) + assert R.is_unit(R.convert(1 + x)) + + R = ZZ.old_poly_ring(x) + assert R.is_unit(R.convert(1)) + assert not R.is_unit(R.convert(2)) + assert not R.is_unit(R.convert(x)) + assert not R.is_unit(R.convert(1 + x)) diff --git a/lib/python3.10/site-packages/sympy/polys/domains/tests/test_quotientring.py b/lib/python3.10/site-packages/sympy/polys/domains/tests/test_quotientring.py new file mode 100644 index 0000000000000000000000000000000000000000..aff167bdd72dc4400785efefef7b3e9057fd0727 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/polys/domains/tests/test_quotientring.py @@ -0,0 +1,52 @@ +"""Tests for quotient rings.""" + +from sympy.polys.domains.integerring import ZZ +from sympy.polys.domains.rationalfield import QQ +from sympy.abc import x, y + +from sympy.polys.polyerrors import NotReversible + +from sympy.testing.pytest import raises + + +def test_QuotientRingElement(): + R = QQ.old_poly_ring(x)/[x**10] + X = R.convert(x) + + assert X*(X + 1) == R.convert(x**2 + x) + assert X*x == R.convert(x**2) + assert x*X == R.convert(x**2) + assert X + x == R.convert(2*x) + assert x + X == 2*X + assert X**2 == R.convert(x**2) + assert 1/(1 - X) == R.convert(sum(x**i for i in range(10))) + assert X**10 == R.zero + assert X != x + + raises(NotReversible, lambda: 1/X) + + +def test_QuotientRing(): + I = QQ.old_poly_ring(x).ideal(x**2 + 1) + R = QQ.old_poly_ring(x)/I + + assert R == QQ.old_poly_ring(x)/[x**2 + 1] + assert R == QQ.old_poly_ring(x)/QQ.old_poly_ring(x).ideal(x**2 + 1) + assert R != QQ.old_poly_ring(x) + + assert R.convert(1)/x == -x + I + assert -1 + I == x**2 + I + assert R.convert(ZZ(1), ZZ) == 1 + I + assert R.convert(R.convert(x), R) == R.convert(x) + + X = R.convert(x) + Y = QQ.old_poly_ring(x).convert(x) + assert -1 + I == X**2 + I + assert -1 + I == Y**2 + I + assert R.to_sympy(X) == x + + raises(ValueError, lambda: QQ.old_poly_ring(x)/QQ.old_poly_ring(x, y).ideal(x)) + + R = QQ.old_poly_ring(x, order="ilex") + I = R.ideal(x) + assert R.convert(1) + I == (R/I).convert(1) diff --git a/lib/python3.10/site-packages/sympy/polys/matrices/__init__.py b/lib/python3.10/site-packages/sympy/polys/matrices/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e4ebc3d71ba3dac9ccc695d046d6b3d2ad940fa1 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/polys/matrices/__init__.py @@ -0,0 +1,15 @@ +""" + +sympy.polys.matrices package. + +The main export from this package is the DomainMatrix class which is a +lower-level implementation of matrices based on the polys Domains. This +implementation is typically a lot faster than SymPy's standard Matrix class +but is a work in progress and is still experimental. + +""" +from .domainmatrix import DomainMatrix, DM + +__all__ = [ + 'DomainMatrix', 'DM', +] diff --git a/lib/python3.10/site-packages/sympy/polys/matrices/__pycache__/__init__.cpython-310.pyc b/lib/python3.10/site-packages/sympy/polys/matrices/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..627c8b9542b55bea20b04aefd4b7b13107cfde2e Binary files /dev/null and b/lib/python3.10/site-packages/sympy/polys/matrices/__pycache__/__init__.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/polys/matrices/__pycache__/_dfm.cpython-310.pyc b/lib/python3.10/site-packages/sympy/polys/matrices/__pycache__/_dfm.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..12c87e001f56fc7dd26eb3dd579047c7d8804ab4 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/polys/matrices/__pycache__/_dfm.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/polys/matrices/__pycache__/_typing.cpython-310.pyc b/lib/python3.10/site-packages/sympy/polys/matrices/__pycache__/_typing.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..86fd69520b2dba710a3649a0b3af2369482dccfb Binary files /dev/null and b/lib/python3.10/site-packages/sympy/polys/matrices/__pycache__/_typing.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/polys/matrices/__pycache__/ddm.cpython-310.pyc b/lib/python3.10/site-packages/sympy/polys/matrices/__pycache__/ddm.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..71cb090d1ce4e1a24dbe5ce81a20a75ecf76b79a Binary files /dev/null and b/lib/python3.10/site-packages/sympy/polys/matrices/__pycache__/ddm.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/polys/matrices/__pycache__/dense.cpython-310.pyc b/lib/python3.10/site-packages/sympy/polys/matrices/__pycache__/dense.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..caae7435d730cc3f000ff44eed3fad442e9e22c7 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/polys/matrices/__pycache__/dense.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/polys/matrices/__pycache__/dfm.cpython-310.pyc b/lib/python3.10/site-packages/sympy/polys/matrices/__pycache__/dfm.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0beb11f1e707b346ad02ae719796fa3b9c0a07ee Binary files /dev/null and b/lib/python3.10/site-packages/sympy/polys/matrices/__pycache__/dfm.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/polys/matrices/__pycache__/domainscalar.cpython-310.pyc b/lib/python3.10/site-packages/sympy/polys/matrices/__pycache__/domainscalar.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..60912ab3cdd66aa0160525ab413bf0efa4c991b9 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/polys/matrices/__pycache__/domainscalar.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/polys/matrices/__pycache__/eigen.cpython-310.pyc b/lib/python3.10/site-packages/sympy/polys/matrices/__pycache__/eigen.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..97b894219a33270826bb3b66ac0cb59397bc9e43 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/polys/matrices/__pycache__/eigen.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/polys/matrices/__pycache__/exceptions.cpython-310.pyc b/lib/python3.10/site-packages/sympy/polys/matrices/__pycache__/exceptions.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..456499f9994206831e39f2d135b7df77fe9d9cd6 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/polys/matrices/__pycache__/exceptions.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/polys/matrices/__pycache__/linsolve.cpython-310.pyc b/lib/python3.10/site-packages/sympy/polys/matrices/__pycache__/linsolve.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..146c6d1a513051fcfd7008e990c76337d3ddc177 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/polys/matrices/__pycache__/linsolve.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/polys/matrices/__pycache__/lll.cpython-310.pyc b/lib/python3.10/site-packages/sympy/polys/matrices/__pycache__/lll.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7a19c6ff3e24760c11e4ea22c7fcd76612262673 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/polys/matrices/__pycache__/lll.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/polys/matrices/__pycache__/normalforms.cpython-310.pyc b/lib/python3.10/site-packages/sympy/polys/matrices/__pycache__/normalforms.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1bffa98f4dd95d3a5b260a22cbd29481e2abe9fc Binary files /dev/null and b/lib/python3.10/site-packages/sympy/polys/matrices/__pycache__/normalforms.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/polys/matrices/__pycache__/rref.cpython-310.pyc b/lib/python3.10/site-packages/sympy/polys/matrices/__pycache__/rref.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4e8aefda3c70f275723f95bb11ac20374dd07f3c Binary files /dev/null and b/lib/python3.10/site-packages/sympy/polys/matrices/__pycache__/rref.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/polys/matrices/__pycache__/sdm.cpython-310.pyc b/lib/python3.10/site-packages/sympy/polys/matrices/__pycache__/sdm.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4599ae4a439f32df7c0118d7c203f9bdc238c798 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/polys/matrices/__pycache__/sdm.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/polys/matrices/_dfm.py b/lib/python3.10/site-packages/sympy/polys/matrices/_dfm.py new file mode 100644 index 0000000000000000000000000000000000000000..d84fe136e6db146a2570739aaa968ae94473f665 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/polys/matrices/_dfm.py @@ -0,0 +1,898 @@ +# +# sympy.polys.matrices.dfm +# +# This modules defines the DFM class which is a wrapper for dense flint +# matrices as found in python-flint. +# +# As of python-flint 0.4.1 matrices over the following domains can be supported +# by python-flint: +# +# ZZ: flint.fmpz_mat +# QQ: flint.fmpq_mat +# GF(p): flint.nmod_mat (p prime and p < ~2**62) +# +# The underlying flint library has many more domains, but these are not yet +# supported by python-flint. +# +# The DFM class is a wrapper for the flint matrices and provides a common +# interface for all supported domains that is interchangeable with the DDM +# and SDM classes so that DomainMatrix can be used with any as its internal +# matrix representation. +# + +# TODO: +# +# Implement the following methods that are provided by python-flint: +# +# - hnf (Hermite normal form) +# - snf (Smith normal form) +# - minpoly +# - is_hnf +# - is_snf +# - rank +# +# The other types DDM and SDM do not have these methods and the algorithms +# for hnf, snf and rank are already implemented. Algorithms for minpoly, +# is_hnf and is_snf would need to be added. +# +# Add more methods to python-flint to expose more of Flint's functionality +# and also to make some of the above methods simpler or more efficient e.g. +# slicing, fancy indexing etc. + +from sympy.external.gmpy import GROUND_TYPES +from sympy.external.importtools import import_module +from sympy.utilities.decorator import doctest_depends_on + +from sympy.polys.domains import ZZ, QQ + +from .exceptions import ( + DMBadInputError, + DMDomainError, + DMNonSquareMatrixError, + DMNonInvertibleMatrixError, + DMRankError, + DMShapeError, + DMValueError, +) + + +if GROUND_TYPES != 'flint': + __doctest_skip__ = ['*'] + + +flint = import_module('flint') + + +__all__ = ['DFM'] + + +@doctest_depends_on(ground_types=['flint']) +class DFM: + """ + Dense FLINT matrix. This class is a wrapper for matrices from python-flint. + + >>> from sympy.polys.domains import ZZ + >>> from sympy.polys.matrices.dfm import DFM + >>> dfm = DFM([[ZZ(1), ZZ(2)], [ZZ(3), ZZ(4)]], (2, 2), ZZ) + >>> dfm + [[1, 2], [3, 4]] + >>> dfm.rep + [1, 2] + [3, 4] + >>> type(dfm.rep) # doctest: +SKIP + + + Usually, the DFM class is not instantiated directly, but is created as the + internal representation of :class:`~.DomainMatrix`. When + `SYMPY_GROUND_TYPES` is set to `flint` and `python-flint` is installed, the + :class:`DFM` class is used automatically as the internal representation of + :class:`~.DomainMatrix` in dense format if the domain is supported by + python-flint. + + >>> from sympy.polys.matrices.domainmatrix import DM + >>> dM = DM([[1, 2], [3, 4]], ZZ) + >>> dM.rep + [[1, 2], [3, 4]] + + A :class:`~.DomainMatrix` can be converted to :class:`DFM` by calling the + :meth:`to_dfm` method: + + >>> dM.to_dfm() + [[1, 2], [3, 4]] + + """ + + fmt = 'dense' + is_DFM = True + is_DDM = False + + def __new__(cls, rowslist, shape, domain): + """Construct from a nested list.""" + flint_mat = cls._get_flint_func(domain) + + if 0 not in shape: + try: + rep = flint_mat(rowslist) + except (ValueError, TypeError): + raise DMBadInputError(f"Input should be a list of list of {domain}") + else: + rep = flint_mat(*shape) + + return cls._new(rep, shape, domain) + + @classmethod + def _new(cls, rep, shape, domain): + """Internal constructor from a flint matrix.""" + cls._check(rep, shape, domain) + obj = object.__new__(cls) + obj.rep = rep + obj.shape = obj.rows, obj.cols = shape + obj.domain = domain + return obj + + def _new_rep(self, rep): + """Create a new DFM with the same shape and domain but a new rep.""" + return self._new(rep, self.shape, self.domain) + + @classmethod + def _check(cls, rep, shape, domain): + repshape = (rep.nrows(), rep.ncols()) + if repshape != shape: + raise DMBadInputError("Shape of rep does not match shape of DFM") + if domain == ZZ and not isinstance(rep, flint.fmpz_mat): + raise RuntimeError("Rep is not a flint.fmpz_mat") + elif domain == QQ and not isinstance(rep, flint.fmpq_mat): + raise RuntimeError("Rep is not a flint.fmpq_mat") + elif domain not in (ZZ, QQ): + raise NotImplementedError("Only ZZ and QQ are supported by DFM") + + @classmethod + def _supports_domain(cls, domain): + """Return True if the given domain is supported by DFM.""" + return domain in (ZZ, QQ) + + @classmethod + def _get_flint_func(cls, domain): + """Return the flint matrix class for the given domain.""" + if domain == ZZ: + return flint.fmpz_mat + elif domain == QQ: + return flint.fmpq_mat + else: + raise NotImplementedError("Only ZZ and QQ are supported by DFM") + + @property + def _func(self): + """Callable to create a flint matrix of the same domain.""" + return self._get_flint_func(self.domain) + + def __str__(self): + """Return ``str(self)``.""" + return str(self.to_ddm()) + + def __repr__(self): + """Return ``repr(self)``.""" + return f'DFM{repr(self.to_ddm())[3:]}' + + def __eq__(self, other): + """Return ``self == other``.""" + if not isinstance(other, DFM): + return NotImplemented + # Compare domains first because we do *not* want matrices with + # different domains to be equal but e.g. a flint fmpz_mat and fmpq_mat + # with the same entries will compare equal. + return self.domain == other.domain and self.rep == other.rep + + @classmethod + def from_list(cls, rowslist, shape, domain): + """Construct from a nested list.""" + return cls(rowslist, shape, domain) + + def to_list(self): + """Convert to a nested list.""" + return self.rep.tolist() + + def copy(self): + """Return a copy of self.""" + return self._new_rep(self._func(self.rep)) + + def to_ddm(self): + """Convert to a DDM.""" + return DDM.from_list(self.to_list(), self.shape, self.domain) + + def to_sdm(self): + """Convert to a SDM.""" + return SDM.from_list(self.to_list(), self.shape, self.domain) + + def to_dfm(self): + """Return self.""" + return self + + def to_dfm_or_ddm(self): + """ + Convert to a :class:`DFM`. + + This :class:`DFM` method exists to parallel the :class:`~.DDM` and + :class:`~.SDM` methods. For :class:`DFM` it will always return self. + + See Also + ======== + + to_ddm + to_sdm + sympy.polys.matrices.domainmatrix.DomainMatrix.to_dfm_or_ddm + """ + return self + + @classmethod + def from_ddm(cls, ddm): + """Convert from a DDM.""" + return cls.from_list(ddm.to_list(), ddm.shape, ddm.domain) + + @classmethod + def from_list_flat(cls, elements, shape, domain): + """Inverse of :meth:`to_list_flat`.""" + func = cls._get_flint_func(domain) + try: + rep = func(*shape, elements) + except ValueError: + raise DMBadInputError(f"Incorrect number of elements for shape {shape}") + except TypeError: + raise DMBadInputError(f"Input should be a list of {domain}") + return cls(rep, shape, domain) + + def to_list_flat(self): + """Convert to a flat list.""" + return self.rep.entries() + + def to_flat_nz(self): + """Convert to a flat list of non-zeros.""" + return self.to_ddm().to_flat_nz() + + @classmethod + def from_flat_nz(cls, elements, data, domain): + """Inverse of :meth:`to_flat_nz`.""" + return DDM.from_flat_nz(elements, data, domain).to_dfm() + + def to_dod(self): + """Convert to a DOD.""" + return self.to_ddm().to_dod() + + @classmethod + def from_dod(cls, dod, shape, domain): + """Inverse of :meth:`to_dod`.""" + return DDM.from_dod(dod, shape, domain).to_dfm() + + def to_dok(self): + """Convert to a DOK.""" + return self.to_ddm().to_dok() + + @classmethod + def from_dok(cls, dok, shape, domain): + """Inverse of :math:`to_dod`.""" + return DDM.from_dok(dok, shape, domain).to_dfm() + + def iter_values(self): + """Iterater over the non-zero values of the matrix.""" + m, n = self.shape + rep = self.rep + for i in range(m): + for j in range(n): + repij = rep[i, j] + if repij: + yield rep[i, j] + + def iter_items(self): + """Iterate over indices and values of nonzero elements of the matrix.""" + m, n = self.shape + rep = self.rep + for i in range(m): + for j in range(n): + repij = rep[i, j] + if repij: + yield ((i, j), repij) + + def convert_to(self, domain): + """Convert to a new domain.""" + if domain == self.domain: + return self.copy() + elif domain == QQ and self.domain == ZZ: + return self._new(flint.fmpq_mat(self.rep), self.shape, domain) + elif domain == ZZ and self.domain == QQ: + # XXX: python-flint has no fmpz_mat.from_fmpq_mat + return self.to_ddm().convert_to(domain).to_dfm() + else: + # It is the callers responsibility to convert to DDM before calling + # this method if the domain is not supported by DFM. + raise NotImplementedError("Only ZZ and QQ are supported by DFM") + + def getitem(self, i, j): + """Get the ``(i, j)``-th entry.""" + # XXX: flint matrices do not support negative indices + # XXX: They also raise ValueError instead of IndexError + m, n = self.shape + if i < 0: + i += m + if j < 0: + j += n + try: + return self.rep[i, j] + except ValueError: + raise IndexError(f"Invalid indices ({i}, {j}) for Matrix of shape {self.shape}") + + def setitem(self, i, j, value): + """Set the ``(i, j)``-th entry.""" + # XXX: flint matrices do not support negative indices + # XXX: They also raise ValueError instead of IndexError + m, n = self.shape + if i < 0: + i += m + if j < 0: + j += n + try: + self.rep[i, j] = value + except ValueError: + raise IndexError(f"Invalid indices ({i}, {j}) for Matrix of shape {self.shape}") + + def _extract(self, i_indices, j_indices): + """Extract a submatrix with no checking.""" + # Indices must be positive and in range. + M = self.rep + lol = [[M[i, j] for j in j_indices] for i in i_indices] + shape = (len(i_indices), len(j_indices)) + return self.from_list(lol, shape, self.domain) + + def extract(self, rowslist, colslist): + """Extract a submatrix.""" + # XXX: flint matrices do not support fancy indexing or negative indices + # + # Check and convert negative indices before calling _extract. + m, n = self.shape + + new_rows = [] + new_cols = [] + + for i in rowslist: + if i < 0: + i_pos = i + m + else: + i_pos = i + if not 0 <= i_pos < m: + raise IndexError(f"Invalid row index {i} for Matrix of shape {self.shape}") + new_rows.append(i_pos) + + for j in colslist: + if j < 0: + j_pos = j + n + else: + j_pos = j + if not 0 <= j_pos < n: + raise IndexError(f"Invalid column index {j} for Matrix of shape {self.shape}") + new_cols.append(j_pos) + + return self._extract(new_rows, new_cols) + + def extract_slice(self, rowslice, colslice): + """Slice a DFM.""" + # XXX: flint matrices do not support slicing + m, n = self.shape + i_indices = range(m)[rowslice] + j_indices = range(n)[colslice] + return self._extract(i_indices, j_indices) + + def neg(self): + """Negate a DFM matrix.""" + return self._new_rep(-self.rep) + + def add(self, other): + """Add two DFM matrices.""" + return self._new_rep(self.rep + other.rep) + + def sub(self, other): + """Subtract two DFM matrices.""" + return self._new_rep(self.rep - other.rep) + + def mul(self, other): + """Multiply a DFM matrix from the right by a scalar.""" + return self._new_rep(self.rep * other) + + def rmul(self, other): + """Multiply a DFM matrix from the left by a scalar.""" + return self._new_rep(other * self.rep) + + def mul_elementwise(self, other): + """Elementwise multiplication of two DFM matrices.""" + # XXX: flint matrices do not support elementwise multiplication + return self.to_ddm().mul_elementwise(other.to_ddm()).to_dfm() + + def matmul(self, other): + """Multiply two DFM matrices.""" + shape = (self.rows, other.cols) + return self._new(self.rep * other.rep, shape, self.domain) + + # XXX: For the most part DomainMatrix does not expect DDM, SDM, or DFM to + # have arithmetic operators defined. The only exception is negation. + # Perhaps that should be removed. + + def __neg__(self): + """Negate a DFM matrix.""" + return self.neg() + + @classmethod + def zeros(cls, shape, domain): + """Return a zero DFM matrix.""" + func = cls._get_flint_func(domain) + return cls._new(func(*shape), shape, domain) + + # XXX: flint matrices do not have anything like ones or eye + # In the methods below we convert to DDM and then back to DFM which is + # probably about as efficient as implementing these methods directly. + + @classmethod + def ones(cls, shape, domain): + """Return a one DFM matrix.""" + # XXX: flint matrices do not have anything like ones + return DDM.ones(shape, domain).to_dfm() + + @classmethod + def eye(cls, n, domain): + """Return the identity matrix of size n.""" + # XXX: flint matrices do not have anything like eye + return DDM.eye(n, domain).to_dfm() + + @classmethod + def diag(cls, elements, domain): + """Return a diagonal matrix.""" + return DDM.diag(elements, domain).to_dfm() + + def applyfunc(self, func, domain): + """Apply a function to each entry of a DFM matrix.""" + return self.to_ddm().applyfunc(func, domain).to_dfm() + + def transpose(self): + """Transpose a DFM matrix.""" + return self._new(self.rep.transpose(), (self.cols, self.rows), self.domain) + + def hstack(self, *others): + """Horizontally stack matrices.""" + return self.to_ddm().hstack(*[o.to_ddm() for o in others]).to_dfm() + + def vstack(self, *others): + """Vertically stack matrices.""" + return self.to_ddm().vstack(*[o.to_ddm() for o in others]).to_dfm() + + def diagonal(self): + """Return the diagonal of a DFM matrix.""" + M = self.rep + m, n = self.shape + return [M[i, i] for i in range(min(m, n))] + + def is_upper(self): + """Return ``True`` if the matrix is upper triangular.""" + M = self.rep + for i in range(self.rows): + for j in range(i): + if M[i, j]: + return False + return True + + def is_lower(self): + """Return ``True`` if the matrix is lower triangular.""" + M = self.rep + for i in range(self.rows): + for j in range(i + 1, self.cols): + if M[i, j]: + return False + return True + + def is_diagonal(self): + """Return ``True`` if the matrix is diagonal.""" + return self.is_upper() and self.is_lower() + + def is_zero_matrix(self): + """Return ``True`` if the matrix is the zero matrix.""" + M = self.rep + for i in range(self.rows): + for j in range(self.cols): + if M[i, j]: + return False + return True + + def nnz(self): + """Return the number of non-zero elements in the matrix.""" + return self.to_ddm().nnz() + + def scc(self): + """Return the strongly connected components of the matrix.""" + return self.to_ddm().scc() + + @doctest_depends_on(ground_types='flint') + def det(self): + """ + Compute the determinant of the matrix using FLINT. + + Examples + ======== + + >>> from sympy import Matrix + >>> M = Matrix([[1, 2], [3, 4]]) + >>> dfm = M.to_DM().to_dfm() + >>> dfm + [[1, 2], [3, 4]] + >>> dfm.det() + -2 + + Notes + ===== + + Calls the ``.det()`` method of the underlying FLINT matrix. + + For :ref:`ZZ` or :ref:`QQ` this calls ``fmpz_mat_det`` or + ``fmpq_mat_det`` respectively. + + At the time of writing the implementation of ``fmpz_mat_det`` uses one + of several algorithms depending on the size of the matrix and bit size + of the entries. The algorithms used are: + + - Cofactor for very small (up to 4x4) matrices. + - Bareiss for small (up to 25x25) matrices. + - Modular algorithms for larger matrices (up to 60x60) or for larger + matrices with large bit sizes. + - Modular "accelerated" for larger matrices (60x60 upwards) if the bit + size is smaller than the dimensions of the matrix. + + The implementation of ``fmpq_mat_det`` clears denominators from each + row (not the whole matrix) and then calls ``fmpz_mat_det`` and divides + by the product of the denominators. + + See Also + ======== + + sympy.polys.matrices.domainmatrix.DomainMatrix.det + Higher level interface to compute the determinant of a matrix. + """ + # XXX: At least the first three algorithms described above should also + # be implemented in the pure Python DDM and SDM classes which at the + # time of writng just use Bareiss for all matrices and domains. + # Probably in Python the thresholds would be different though. + return self.rep.det() + + @doctest_depends_on(ground_types='flint') + def charpoly(self): + """ + Compute the characteristic polynomial of the matrix using FLINT. + + Examples + ======== + + >>> from sympy import Matrix + >>> M = Matrix([[1, 2], [3, 4]]) + >>> dfm = M.to_DM().to_dfm() # need ground types = 'flint' + >>> dfm + [[1, 2], [3, 4]] + >>> dfm.charpoly() + [1, -5, -2] + + Notes + ===== + + Calls the ``.charpoly()`` method of the underlying FLINT matrix. + + For :ref:`ZZ` or :ref:`QQ` this calls ``fmpz_mat_charpoly`` or + ``fmpq_mat_charpoly`` respectively. + + At the time of writing the implementation of ``fmpq_mat_charpoly`` + clears a denominator from the whole matrix and then calls + ``fmpz_mat_charpoly``. The coefficients of the characteristic + polynomial are then multiplied by powers of the denominator. + + The ``fmpz_mat_charpoly`` method uses a modular algorithm with CRT + reconstruction. The modular algorithm uses ``nmod_mat_charpoly`` which + uses Berkowitz for small matrices and non-prime moduli or otherwise + the Danilevsky method. + + See Also + ======== + + sympy.polys.matrices.domainmatrix.DomainMatrix.charpoly + Higher level interface to compute the characteristic polynomial of + a matrix. + """ + # FLINT polynomial coefficients are in reverse order compared to SymPy. + return self.rep.charpoly().coeffs()[::-1] + + @doctest_depends_on(ground_types='flint') + def inv(self): + """ + Compute the inverse of a matrix using FLINT. + + Examples + ======== + + >>> from sympy import Matrix, QQ + >>> M = Matrix([[1, 2], [3, 4]]) + >>> dfm = M.to_DM().to_dfm().convert_to(QQ) + >>> dfm + [[1, 2], [3, 4]] + >>> dfm.inv() + [[-2, 1], [3/2, -1/2]] + >>> dfm.matmul(dfm.inv()) + [[1, 0], [0, 1]] + + Notes + ===== + + Calls the ``.inv()`` method of the underlying FLINT matrix. + + For now this will raise an error if the domain is :ref:`ZZ` but will + use the FLINT method for :ref:`QQ`. + + The FLINT methods for :ref:`ZZ` and :ref:`QQ` are ``fmpz_mat_inv`` and + ``fmpq_mat_inv`` respectively. The ``fmpz_mat_inv`` method computes an + inverse with denominator. This is implemented by calling + ``fmpz_mat_solve`` (see notes in :meth:`lu_solve` about the algorithm). + + The ``fmpq_mat_inv`` method clears denominators from each row and then + multiplies those into the rhs identity matrix before calling + ``fmpz_mat_solve``. + + See Also + ======== + + sympy.polys.matrices.domainmatrix.DomainMatrix.inv + Higher level method for computing the inverse of a matrix. + """ + # TODO: Implement similar algorithms for DDM and SDM. + # + # XXX: The flint fmpz_mat and fmpq_mat inv methods both return fmpq_mat + # by default. The fmpz_mat method has an optional argument to return + # fmpz_mat instead for unimodular matrices. + # + # The convention in DomainMatrix is to raise an error if the matrix is + # not over a field regardless of whether the matrix is invertible over + # its domain or over any associated field. Maybe DomainMatrix.inv + # should be changed to always return a matrix over an associated field + # except with a unimodular argument for returning an inverse over a + # ring if possible. + # + # For now we follow the existing DomainMatrix convention... + K = self.domain + m, n = self.shape + + if m != n: + raise DMNonSquareMatrixError("cannot invert a non-square matrix") + + if K == ZZ: + raise DMDomainError("field expected, got %s" % K) + elif K == QQ: + try: + return self._new_rep(self.rep.inv()) + except ZeroDivisionError: + raise DMNonInvertibleMatrixError("matrix is not invertible") + else: + # If more domains are added for DFM then we will need to consider + # what happens here. + raise NotImplementedError("DFM.inv() is not implemented for %s" % K) + + def lu(self): + """Return the LU decomposition of the matrix.""" + L, U, swaps = self.to_ddm().lu() + return L.to_dfm(), U.to_dfm(), swaps + + # XXX: The lu_solve function should be renamed to solve. Whether or not it + # uses an LU decomposition is an implementation detail. A method called + # lu_solve would make sense for a situation in which an LU decomposition is + # reused several times to solve iwth different rhs but that would imply a + # different call signature. + # + # The underlying python-flint method has an algorithm= argument so we could + # use that and have e.g. solve_lu and solve_modular or perhaps also a + # method= argument to choose between the two. Flint itself has more + # possible algorithms to choose from than are exposed by python-flint. + + @doctest_depends_on(ground_types='flint') + def lu_solve(self, rhs): + """ + Solve a matrix equation using FLINT. + + Examples + ======== + + >>> from sympy import Matrix, QQ + >>> M = Matrix([[1, 2], [3, 4]]) + >>> dfm = M.to_DM().to_dfm().convert_to(QQ) + >>> dfm + [[1, 2], [3, 4]] + >>> rhs = Matrix([1, 2]).to_DM().to_dfm().convert_to(QQ) + >>> dfm.lu_solve(rhs) + [[0], [1/2]] + + Notes + ===== + + Calls the ``.solve()`` method of the underlying FLINT matrix. + + For now this will raise an error if the domain is :ref:`ZZ` but will + use the FLINT method for :ref:`QQ`. + + The FLINT methods for :ref:`ZZ` and :ref:`QQ` are ``fmpz_mat_solve`` + and ``fmpq_mat_solve`` respectively. The ``fmpq_mat_solve`` method + uses one of two algorithms: + + - For small matrices (<25 rows) it clears denominators between the + matrix and rhs and uses ``fmpz_mat_solve``. + - For larger matrices it uses ``fmpq_mat_solve_dixon`` which is a + modular approach with CRT reconstruction over :ref:`QQ`. + + The ``fmpz_mat_solve`` method uses one of four algorithms: + + - For very small (<= 3x3) matrices it uses a Cramer's rule. + - For small (<= 15x15) matrices it uses a fraction-free LU solve. + - Otherwise it uses either Dixon or another multimodular approach. + + See Also + ======== + + sympy.polys.matrices.domainmatrix.DomainMatrix.lu_solve + Higher level interface to solve a matrix equation. + """ + if not self.domain == rhs.domain: + raise DMDomainError("Domains must match: %s != %s" % (self.domain, rhs.domain)) + + # XXX: As for inv we should consider whether to return a matrix over + # over an associated field or attempt to find a solution in the ring. + # For now we follow the existing DomainMatrix convention... + if not self.domain.is_Field: + raise DMDomainError("Field expected, got %s" % self.domain) + + m, n = self.shape + j, k = rhs.shape + if m != j: + raise DMShapeError("Matrix size mismatch: %s * %s vs %s * %s" % (m, n, j, k)) + sol_shape = (n, k) + + # XXX: The Flint solve method only handles square matrices. Probably + # Flint has functions that could be used to solve non-square systems + # but they are not exposed in python-flint yet. Alternatively we could + # put something here using the features that are available like rref. + if m != n: + return self.to_ddm().lu_solve(rhs.to_ddm()).to_dfm() + + try: + sol = self.rep.solve(rhs.rep) + except ZeroDivisionError: + raise DMNonInvertibleMatrixError("Matrix det == 0; not invertible.") + + return self._new(sol, sol_shape, self.domain) + + def nullspace(self): + """Return a basis for the nullspace of the matrix.""" + # Code to compute nullspace using flint: + # + # V, nullity = self.rep.nullspace() + # V_dfm = self._new_rep(V)._extract(range(self.rows), range(nullity)) + # + # XXX: That gives the nullspace but does not give us nonpivots. So we + # use the slower DDM method anyway. It would be better to change the + # signature of the nullspace method to not return nonpivots. + # + # XXX: Also python-flint exposes a nullspace method for fmpz_mat but + # not for fmpq_mat. This is the reverse of the situation for DDM etc + # which only allow nullspace over a field. The nullspace method for + # DDM, SDM etc should be changed to allow nullspace over ZZ as well. + # The DomainMatrix nullspace method does allow the domain to be a ring + # but does not directly call the lower-level nullspace methods and uses + # rref_den instead. Nullspace methods should also be added to all + # matrix types in python-flint. + ddm, nonpivots = self.to_ddm().nullspace() + return ddm.to_dfm(), nonpivots + + def nullspace_from_rref(self, pivots=None): + """Return a basis for the nullspace of the matrix.""" + # XXX: Use the flint nullspace method!!! + sdm, nonpivots = self.to_sdm().nullspace_from_rref(pivots=pivots) + return sdm.to_dfm(), nonpivots + + def particular(self): + """Return a particular solution to the system.""" + return self.to_ddm().particular().to_dfm() + + def _lll(self, transform=False, delta=0.99, eta=0.51, rep='zbasis', gram='approx'): + """Call the fmpz_mat.lll() method but check rank to avoid segfaults.""" + + # XXX: There are tests that pass e.g. QQ(5,6) for delta. That fails + # with a TypeError in flint because if QQ is fmpq then conversion with + # float fails. We handle that here but there are two better fixes: + # + # - Make python-flint's fmpq convert with float(x) + # - Change the tests because delta should just be a float. + + def to_float(x): + if QQ.of_type(x): + return float(x.numerator) / float(x.denominator) + else: + return float(x) + + delta = to_float(delta) + eta = to_float(eta) + + if not 0.25 < delta < 1: + raise DMValueError("delta must be between 0.25 and 1") + + # XXX: The flint lll method segfaults if the matrix is not full rank. + m, n = self.shape + if self.rep.rank() != m: + raise DMRankError("Matrix must have full row rank for Flint LLL.") + + # Actually call the flint method. + return self.rep.lll(transform=transform, delta=delta, eta=eta, rep=rep, gram=gram) + + @doctest_depends_on(ground_types='flint') + def lll(self, delta=0.75): + """Compute LLL-reduced basis using FLINT. + + See :meth:`lll_transform` for more information. + + Examples + ======== + + >>> from sympy import Matrix + >>> M = Matrix([[1, 2, 3], [4, 5, 6]]) + >>> M.to_DM().to_dfm().lll() + [[2, 1, 0], [-1, 1, 3]] + + See Also + ======== + + sympy.polys.matrices.domainmatrix.DomainMatrix.lll + Higher level interface to compute LLL-reduced basis. + lll_transform + Compute LLL-reduced basis and transform matrix. + """ + if self.domain != ZZ: + raise DMDomainError("ZZ expected, got %s" % self.domain) + elif self.rows > self.cols: + raise DMShapeError("Matrix must not have more rows than columns.") + + rep = self._lll(delta=delta) + return self._new_rep(rep) + + @doctest_depends_on(ground_types='flint') + def lll_transform(self, delta=0.75): + """Compute LLL-reduced basis and transform using FLINT. + + Examples + ======== + + >>> from sympy import Matrix + >>> M = Matrix([[1, 2, 3], [4, 5, 6]]).to_DM().to_dfm() + >>> M_lll, T = M.lll_transform() + >>> M_lll + [[2, 1, 0], [-1, 1, 3]] + >>> T + [[-2, 1], [3, -1]] + >>> T.matmul(M) == M_lll + True + + See Also + ======== + + sympy.polys.matrices.domainmatrix.DomainMatrix.lll + Higher level interface to compute LLL-reduced basis. + lll + Compute LLL-reduced basis without transform matrix. + """ + if self.domain != ZZ: + raise DMDomainError("ZZ expected, got %s" % self.domain) + elif self.rows > self.cols: + raise DMShapeError("Matrix must not have more rows than columns.") + + rep, T = self._lll(transform=True, delta=delta) + basis = self._new_rep(rep) + T_dfm = self._new(T, (self.rows, self.rows), self.domain) + return basis, T_dfm + + +# Avoid circular imports +from sympy.polys.matrices.ddm import DDM +from sympy.polys.matrices.ddm import SDM diff --git a/lib/python3.10/site-packages/sympy/polys/matrices/_typing.py b/lib/python3.10/site-packages/sympy/polys/matrices/_typing.py new file mode 100644 index 0000000000000000000000000000000000000000..fc7c3b601fe85d591ddf853acbf33f5bba64b11c --- /dev/null +++ b/lib/python3.10/site-packages/sympy/polys/matrices/_typing.py @@ -0,0 +1,16 @@ +from typing import TypeVar, Protocol + + +T = TypeVar('T') + + +class RingElement(Protocol): + """A ring element. + + Must support ``+``, ``-``, ``*``, ``**`` and ``-``. + """ + def __add__(self: T, other: T, /) -> T: ... + def __sub__(self: T, other: T, /) -> T: ... + def __mul__(self: T, other: T, /) -> T: ... + def __pow__(self: T, other: int, /) -> T: ... + def __neg__(self: T, /) -> T: ... diff --git a/lib/python3.10/site-packages/sympy/polys/matrices/ddm.py b/lib/python3.10/site-packages/sympy/polys/matrices/ddm.py new file mode 100644 index 0000000000000000000000000000000000000000..77d02b8ad5a0a5ced07d2144b4a11337a705cff8 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/polys/matrices/ddm.py @@ -0,0 +1,1029 @@ +""" + +Module for the DDM class. + +The DDM class is an internal representation used by DomainMatrix. The letters +DDM stand for Dense Domain Matrix. A DDM instance represents a matrix using +elements from a polynomial Domain (e.g. ZZ, QQ, ...) in a dense-matrix +representation. + +Basic usage: + + >>> from sympy import ZZ, QQ + >>> from sympy.polys.matrices.ddm import DDM + >>> A = DDM([[ZZ(0), ZZ(1)], [ZZ(-1), ZZ(0)]], (2, 2), ZZ) + >>> A.shape + (2, 2) + >>> A + [[0, 1], [-1, 0]] + >>> type(A) + + >>> A @ A + [[-1, 0], [0, -1]] + +The ddm_* functions are designed to operate on DDM as well as on an ordinary +list of lists: + + >>> from sympy.polys.matrices.dense import ddm_idet + >>> ddm_idet(A, QQ) + 1 + >>> ddm_idet([[0, 1], [-1, 0]], QQ) + 1 + >>> A + [[-1, 0], [0, -1]] + +Note that ddm_idet modifies the input matrix in-place. It is recommended to +use the DDM.det method as a friendlier interface to this instead which takes +care of copying the matrix: + + >>> B = DDM([[ZZ(0), ZZ(1)], [ZZ(-1), ZZ(0)]], (2, 2), ZZ) + >>> B.det() + 1 + +Normally DDM would not be used directly and is just part of the internal +representation of DomainMatrix which adds further functionality including e.g. +unifying domains. + +The dense format used by DDM is a list of lists of elements e.g. the 2x2 +identity matrix is like [[1, 0], [0, 1]]. The DDM class itself is a subclass +of list and its list items are plain lists. Elements are accessed as e.g. +ddm[i][j] where ddm[i] gives the ith row and ddm[i][j] gets the element in the +jth column of that row. Subclassing list makes e.g. iteration and indexing +very efficient. We do not override __getitem__ because it would lose that +benefit. + +The core routines are implemented by the ddm_* functions defined in dense.py. +Those functions are intended to be able to operate on a raw list-of-lists +representation of matrices with most functions operating in-place. The DDM +class takes care of copying etc and also stores a Domain object associated +with its elements. This makes it possible to implement things like A + B with +domain checking and also shape checking so that the list of lists +representation is friendlier. + +""" +from itertools import chain + +from sympy.external.gmpy import GROUND_TYPES +from sympy.utilities.decorator import doctest_depends_on + +from .exceptions import ( + DMBadInputError, + DMDomainError, + DMNonSquareMatrixError, + DMShapeError, +) + +from sympy.polys.domains import QQ + +from .dense import ( + ddm_transpose, + ddm_iadd, + ddm_isub, + ddm_ineg, + ddm_imul, + ddm_irmul, + ddm_imatmul, + ddm_irref, + ddm_irref_den, + ddm_idet, + ddm_iinv, + ddm_ilu_split, + ddm_ilu_solve, + ddm_berk, + ) + +from .lll import ddm_lll, ddm_lll_transform + + +if GROUND_TYPES != 'flint': + __doctest_skip__ = ['DDM.to_dfm', 'DDM.to_dfm_or_ddm'] + + +class DDM(list): + """Dense matrix based on polys domain elements + + This is a list subclass and is a wrapper for a list of lists that supports + basic matrix arithmetic +, -, *, **. + """ + + fmt = 'dense' + is_DFM = False + is_DDM = True + + def __init__(self, rowslist, shape, domain): + if not (isinstance(rowslist, list) and all(type(row) is list for row in rowslist)): + raise DMBadInputError("rowslist must be a list of lists") + m, n = shape + if len(rowslist) != m or any(len(row) != n for row in rowslist): + raise DMBadInputError("Inconsistent row-list/shape") + + super().__init__(rowslist) + self.shape = (m, n) + self.rows = m + self.cols = n + self.domain = domain + + def getitem(self, i, j): + return self[i][j] + + def setitem(self, i, j, value): + self[i][j] = value + + def extract_slice(self, slice1, slice2): + ddm = [row[slice2] for row in self[slice1]] + rows = len(ddm) + cols = len(ddm[0]) if ddm else len(range(self.shape[1])[slice2]) + return DDM(ddm, (rows, cols), self.domain) + + def extract(self, rows, cols): + ddm = [] + for i in rows: + rowi = self[i] + ddm.append([rowi[j] for j in cols]) + return DDM(ddm, (len(rows), len(cols)), self.domain) + + @classmethod + def from_list(cls, rowslist, shape, domain): + """ + Create a :class:`DDM` from a list of lists. + + Examples + ======== + + >>> from sympy import ZZ + >>> from sympy.polys.matrices.ddm import DDM + >>> A = DDM.from_list([[ZZ(0), ZZ(1)], [ZZ(-1), ZZ(0)]], (2, 2), ZZ) + >>> A + [[0, 1], [-1, 0]] + >>> A == DDM([[ZZ(0), ZZ(1)], [ZZ(-1), ZZ(0)]], (2, 2), ZZ) + True + + See Also + ======== + + from_list_flat + """ + return cls(rowslist, shape, domain) + + @classmethod + def from_ddm(cls, other): + return other.copy() + + def to_list(self): + """ + Convert to a list of lists. + + Examples + ======== + + >>> from sympy import QQ + >>> from sympy.polys.matrices.ddm import DDM + >>> A = DDM([[1, 2], [3, 4]], (2, 2), QQ) + >>> A.to_list() + [[1, 2], [3, 4]] + + See Also + ======== + + to_list_flat + sympy.polys.matrices.domainmatrix.DomainMatrix.to_list + """ + return list(self) + + def to_list_flat(self): + """ + Convert to a flat list of elements. + + Examples + ======== + + >>> from sympy import QQ + >>> from sympy.polys.matrices.ddm import DDM + >>> A = DDM([[1, 2], [3, 4]], (2, 2), QQ) + >>> A.to_list_flat() + [1, 2, 3, 4] + >>> A == DDM.from_list_flat(A.to_list_flat(), A.shape, A.domain) + True + + See Also + ======== + + sympy.polys.matrices.domainmatrix.DomainMatrix.to_list_flat + """ + flat = [] + for row in self: + flat.extend(row) + return flat + + @classmethod + def from_list_flat(cls, flat, shape, domain): + """ + Create a :class:`DDM` from a flat list of elements. + + Examples + ======== + + >>> from sympy import QQ + >>> from sympy.polys.matrices.ddm import DDM + >>> A = DDM.from_list_flat([1, 2, 3, 4], (2, 2), QQ) + >>> A + [[1, 2], [3, 4]] + >>> A == DDM.from_list_flat(A.to_list_flat(), A.shape, A.domain) + True + + See Also + ======== + + to_list_flat + sympy.polys.matrices.domainmatrix.DomainMatrix.from_list_flat + """ + assert type(flat) is list + rows, cols = shape + if not (len(flat) == rows*cols): + raise DMBadInputError("Inconsistent flat-list shape") + lol = [flat[i*cols:(i+1)*cols] for i in range(rows)] + return cls(lol, shape, domain) + + def flatiter(self): + return chain.from_iterable(self) + + def flat(self): + items = [] + for row in self: + items.extend(row) + return items + + def to_flat_nz(self): + """ + Convert to a flat list of nonzero elements and data. + + Explanation + =========== + + This is used to operate on a list of the elements of a matrix and then + reconstruct a matrix using :meth:`from_flat_nz`. Zero elements are + included in the list but that may change in the future. + + Examples + ======== + + >>> from sympy.polys.matrices.ddm import DDM + >>> from sympy import QQ + >>> A = DDM([[1, 2], [3, 4]], (2, 2), QQ) + >>> elements, data = A.to_flat_nz() + >>> elements + [1, 2, 3, 4] + >>> A == DDM.from_flat_nz(elements, data, A.domain) + True + + See Also + ======== + + from_flat_nz + sympy.polys.matrices.sdm.SDM.to_flat_nz + sympy.polys.matrices.domainmatrix.DomainMatrix.to_flat_nz + """ + return self.to_sdm().to_flat_nz() + + @classmethod + def from_flat_nz(cls, elements, data, domain): + """ + Reconstruct a :class:`DDM` after calling :meth:`to_flat_nz`. + + Examples + ======== + + >>> from sympy.polys.matrices.ddm import DDM + >>> from sympy import QQ + >>> A = DDM([[1, 2], [3, 4]], (2, 2), QQ) + >>> elements, data = A.to_flat_nz() + >>> elements + [1, 2, 3, 4] + >>> A == DDM.from_flat_nz(elements, data, A.domain) + True + + See Also + ======== + + to_flat_nz + sympy.polys.matrices.sdm.SDM.from_flat_nz + sympy.polys.matrices.domainmatrix.DomainMatrix.from_flat_nz + """ + return SDM.from_flat_nz(elements, data, domain).to_ddm() + + def to_dod(self): + """ + Convert to a dictionary of dictionaries (dod) format. + + Examples + ======== + + >>> from sympy.polys.matrices.ddm import DDM + >>> from sympy import QQ + >>> A = DDM([[1, 2], [3, 4]], (2, 2), QQ) + >>> A.to_dod() + {0: {0: 1, 1: 2}, 1: {0: 3, 1: 4}} + + See Also + ======== + + from_dod + sympy.polys.matrices.sdm.SDM.to_dod + sympy.polys.matrices.domainmatrix.DomainMatrix.to_dod + """ + dod = {} + for i, row in enumerate(self): + row = {j:e for j, e in enumerate(row) if e} + if row: + dod[i] = row + return dod + + @classmethod + def from_dod(cls, dod, shape, domain): + """ + Create a :class:`DDM` from a dictionary of dictionaries (dod) format. + + Examples + ======== + + >>> from sympy.polys.matrices.ddm import DDM + >>> from sympy import QQ + >>> dod = {0: {0: 1, 1: 2}, 1: {0: 3, 1: 4}} + >>> A = DDM.from_dod(dod, (2, 2), QQ) + >>> A + [[1, 2], [3, 4]] + + See Also + ======== + + to_dod + sympy.polys.matrices.sdm.SDM.from_dod + sympy.polys.matrices.domainmatrix.DomainMatrix.from_dod + """ + rows, cols = shape + lol = [[domain.zero] * cols for _ in range(rows)] + for i, row in dod.items(): + for j, element in row.items(): + lol[i][j] = element + return DDM(lol, shape, domain) + + def to_dok(self): + """ + Convert :class:`DDM` to dictionary of keys (dok) format. + + Examples + ======== + + >>> from sympy.polys.matrices.ddm import DDM + >>> from sympy import QQ + >>> A = DDM([[1, 2], [3, 4]], (2, 2), QQ) + >>> A.to_dok() + {(0, 0): 1, (0, 1): 2, (1, 0): 3, (1, 1): 4} + + See Also + ======== + + from_dok + sympy.polys.matrices.sdm.SDM.to_dok + sympy.polys.matrices.domainmatrix.DomainMatrix.to_dok + """ + dok = {} + for i, row in enumerate(self): + for j, element in enumerate(row): + if element: + dok[i, j] = element + return dok + + @classmethod + def from_dok(cls, dok, shape, domain): + """ + Create a :class:`DDM` from a dictionary of keys (dok) format. + + Examples + ======== + + >>> from sympy.polys.matrices.ddm import DDM + >>> from sympy import QQ + >>> dok = {(0, 0): 1, (0, 1): 2, (1, 0): 3, (1, 1): 4} + >>> A = DDM.from_dok(dok, (2, 2), QQ) + >>> A + [[1, 2], [3, 4]] + + See Also + ======== + + to_dok + sympy.polys.matrices.sdm.SDM.from_dok + sympy.polys.matrices.domainmatrix.DomainMatrix.from_dok + """ + rows, cols = shape + lol = [[domain.zero] * cols for _ in range(rows)] + for (i, j), element in dok.items(): + lol[i][j] = element + return DDM(lol, shape, domain) + + def iter_values(self): + """ + Iterater over the non-zero values of the matrix. + + Examples + ======== + + >>> from sympy.polys.matrices.ddm import DDM + >>> from sympy import QQ + >>> A = DDM([[QQ(1), QQ(0)], [QQ(3), QQ(4)]], (2, 2), QQ) + >>> list(A.iter_values()) + [1, 3, 4] + + See Also + ======== + + iter_items + to_list_flat + sympy.polys.matrices.domainmatrix.DomainMatrix.iter_values + """ + for row in self: + yield from filter(None, row) + + def iter_items(self): + """ + Iterate over indices and values of nonzero elements of the matrix. + + Examples + ======== + + >>> from sympy.polys.matrices.ddm import DDM + >>> from sympy import QQ + >>> A = DDM([[QQ(1), QQ(0)], [QQ(3), QQ(4)]], (2, 2), QQ) + >>> list(A.iter_items()) + [((0, 0), 1), ((1, 0), 3), ((1, 1), 4)] + + See Also + ======== + + iter_values + to_dok + sympy.polys.matrices.domainmatrix.DomainMatrix.iter_items + """ + for i, row in enumerate(self): + for j, element in enumerate(row): + if element: + yield (i, j), element + + def to_ddm(self): + """ + Convert to a :class:`DDM`. + + This just returns ``self`` but exists to parallel the corresponding + method in other matrix types like :class:`~.SDM`. + + See Also + ======== + + to_sdm + to_dfm + to_dfm_or_ddm + sympy.polys.matrices.sdm.SDM.to_ddm + sympy.polys.matrices.domainmatrix.DomainMatrix.to_ddm + """ + return self + + def to_sdm(self): + """ + Convert to a :class:`~.SDM`. + + Examples + ======== + + >>> from sympy.polys.matrices.ddm import DDM + >>> from sympy import QQ + >>> A = DDM([[1, 2], [3, 4]], (2, 2), QQ) + >>> A.to_sdm() + {0: {0: 1, 1: 2}, 1: {0: 3, 1: 4}} + >>> type(A.to_sdm()) + + + See Also + ======== + + SDM + sympy.polys.matrices.sdm.SDM.to_ddm + """ + return SDM.from_list(self, self.shape, self.domain) + + @doctest_depends_on(ground_types=['flint']) + def to_dfm(self): + """ + Convert to :class:`~.DDM` to :class:`~.DFM`. + + Examples + ======== + + >>> from sympy.polys.matrices.ddm import DDM + >>> from sympy import QQ + >>> A = DDM([[1, 2], [3, 4]], (2, 2), QQ) + >>> A.to_dfm() + [[1, 2], [3, 4]] + >>> type(A.to_dfm()) + + + See Also + ======== + + DFM + sympy.polys.matrices._dfm.DFM.to_ddm + """ + return DFM(list(self), self.shape, self.domain) + + @doctest_depends_on(ground_types=['flint']) + def to_dfm_or_ddm(self): + """ + Convert to :class:`~.DFM` if possible or otherwise return self. + + Examples + ======== + + >>> from sympy.polys.matrices.ddm import DDM + >>> from sympy import QQ + >>> A = DDM([[1, 2], [3, 4]], (2, 2), QQ) + >>> A.to_dfm_or_ddm() + [[1, 2], [3, 4]] + >>> type(A.to_dfm_or_ddm()) + + + See Also + ======== + + to_dfm + to_ddm + sympy.polys.matrices.domainmatrix.DomainMatrix.to_dfm_or_ddm + """ + if DFM._supports_domain(self.domain): + return self.to_dfm() + return self + + def convert_to(self, K): + Kold = self.domain + if K == Kold: + return self.copy() + rows = [[K.convert_from(e, Kold) for e in row] for row in self] + return DDM(rows, self.shape, K) + + def __str__(self): + rowsstr = ['[%s]' % ', '.join(map(str, row)) for row in self] + return '[%s]' % ', '.join(rowsstr) + + def __repr__(self): + cls = type(self).__name__ + rows = list.__repr__(self) + return '%s(%s, %s, %s)' % (cls, rows, self.shape, self.domain) + + def __eq__(self, other): + if not isinstance(other, DDM): + return False + return (super().__eq__(other) and self.domain == other.domain) + + def __ne__(self, other): + return not self.__eq__(other) + + @classmethod + def zeros(cls, shape, domain): + z = domain.zero + m, n = shape + rowslist = [[z] * n for _ in range(m)] + return DDM(rowslist, shape, domain) + + @classmethod + def ones(cls, shape, domain): + one = domain.one + m, n = shape + rowlist = [[one] * n for _ in range(m)] + return DDM(rowlist, shape, domain) + + @classmethod + def eye(cls, size, domain): + if isinstance(size, tuple): + m, n = size + elif isinstance(size, int): + m = n = size + one = domain.one + ddm = cls.zeros((m, n), domain) + for i in range(min(m, n)): + ddm[i][i] = one + return ddm + + def copy(self): + copyrows = [row[:] for row in self] + return DDM(copyrows, self.shape, self.domain) + + def transpose(self): + rows, cols = self.shape + if rows: + ddmT = ddm_transpose(self) + else: + ddmT = [[]] * cols + return DDM(ddmT, (cols, rows), self.domain) + + def __add__(a, b): + if not isinstance(b, DDM): + return NotImplemented + return a.add(b) + + def __sub__(a, b): + if not isinstance(b, DDM): + return NotImplemented + return a.sub(b) + + def __neg__(a): + return a.neg() + + def __mul__(a, b): + if b in a.domain: + return a.mul(b) + else: + return NotImplemented + + def __rmul__(a, b): + if b in a.domain: + return a.mul(b) + else: + return NotImplemented + + def __matmul__(a, b): + if isinstance(b, DDM): + return a.matmul(b) + else: + return NotImplemented + + @classmethod + def _check(cls, a, op, b, ashape, bshape): + if a.domain != b.domain: + msg = "Domain mismatch: %s %s %s" % (a.domain, op, b.domain) + raise DMDomainError(msg) + if ashape != bshape: + msg = "Shape mismatch: %s %s %s" % (a.shape, op, b.shape) + raise DMShapeError(msg) + + def add(a, b): + """a + b""" + a._check(a, '+', b, a.shape, b.shape) + c = a.copy() + ddm_iadd(c, b) + return c + + def sub(a, b): + """a - b""" + a._check(a, '-', b, a.shape, b.shape) + c = a.copy() + ddm_isub(c, b) + return c + + def neg(a): + """-a""" + b = a.copy() + ddm_ineg(b) + return b + + def mul(a, b): + c = a.copy() + ddm_imul(c, b) + return c + + def rmul(a, b): + c = a.copy() + ddm_irmul(c, b) + return c + + def matmul(a, b): + """a @ b (matrix product)""" + m, o = a.shape + o2, n = b.shape + a._check(a, '*', b, o, o2) + c = a.zeros((m, n), a.domain) + ddm_imatmul(c, a, b) + return c + + def mul_elementwise(a, b): + assert a.shape == b.shape + assert a.domain == b.domain + c = [[aij * bij for aij, bij in zip(ai, bi)] for ai, bi in zip(a, b)] + return DDM(c, a.shape, a.domain) + + def hstack(A, *B): + """Horizontally stacks :py:class:`~.DDM` matrices. + + Examples + ======== + + >>> from sympy import ZZ + >>> from sympy.polys.matrices.sdm import DDM + + >>> A = DDM([[ZZ(1), ZZ(2)], [ZZ(3), ZZ(4)]], (2, 2), ZZ) + >>> B = DDM([[ZZ(5), ZZ(6)], [ZZ(7), ZZ(8)]], (2, 2), ZZ) + >>> A.hstack(B) + [[1, 2, 5, 6], [3, 4, 7, 8]] + + >>> C = DDM([[ZZ(9), ZZ(10)], [ZZ(11), ZZ(12)]], (2, 2), ZZ) + >>> A.hstack(B, C) + [[1, 2, 5, 6, 9, 10], [3, 4, 7, 8, 11, 12]] + """ + Anew = list(A.copy()) + rows, cols = A.shape + domain = A.domain + + for Bk in B: + Bkrows, Bkcols = Bk.shape + assert Bkrows == rows + assert Bk.domain == domain + + cols += Bkcols + + for i, Bki in enumerate(Bk): + Anew[i].extend(Bki) + + return DDM(Anew, (rows, cols), A.domain) + + def vstack(A, *B): + """Vertically stacks :py:class:`~.DDM` matrices. + + Examples + ======== + + >>> from sympy import ZZ + >>> from sympy.polys.matrices.sdm import DDM + + >>> A = DDM([[ZZ(1), ZZ(2)], [ZZ(3), ZZ(4)]], (2, 2), ZZ) + >>> B = DDM([[ZZ(5), ZZ(6)], [ZZ(7), ZZ(8)]], (2, 2), ZZ) + >>> A.vstack(B) + [[1, 2], [3, 4], [5, 6], [7, 8]] + + >>> C = DDM([[ZZ(9), ZZ(10)], [ZZ(11), ZZ(12)]], (2, 2), ZZ) + >>> A.vstack(B, C) + [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12]] + """ + Anew = list(A.copy()) + rows, cols = A.shape + domain = A.domain + + for Bk in B: + Bkrows, Bkcols = Bk.shape + assert Bkcols == cols + assert Bk.domain == domain + + rows += Bkrows + + Anew.extend(Bk.copy()) + + return DDM(Anew, (rows, cols), A.domain) + + def applyfunc(self, func, domain): + elements = [list(map(func, row)) for row in self] + return DDM(elements, self.shape, domain) + + def nnz(a): + """Number of non-zero entries in :py:class:`~.DDM` matrix. + + See Also + ======== + + sympy.polys.matrices.domainmatrix.DomainMatrix.nnz + """ + return sum(sum(map(bool, row)) for row in a) + + def scc(a): + """Strongly connected components of a square matrix *a*. + + Examples + ======== + + >>> from sympy import ZZ + >>> from sympy.polys.matrices.sdm import DDM + >>> A = DDM([[ZZ(1), ZZ(0)], [ZZ(0), ZZ(1)]], (2, 2), ZZ) + >>> A.scc() + [[0], [1]] + + See also + ======== + + sympy.polys.matrices.domainmatrix.DomainMatrix.scc + + """ + return a.to_sdm().scc() + + @classmethod + def diag(cls, values, domain): + """Returns a square diagonal matrix with *values* on the diagonal. + + Examples + ======== + + >>> from sympy import ZZ + >>> from sympy.polys.matrices.sdm import DDM + >>> DDM.diag([ZZ(1), ZZ(2), ZZ(3)], ZZ) + [[1, 0, 0], [0, 2, 0], [0, 0, 3]] + + See also + ======== + + sympy.polys.matrices.domainmatrix.DomainMatrix.diag + """ + return SDM.diag(values, domain).to_ddm() + + def rref(a): + """Reduced-row echelon form of a and list of pivots. + + See Also + ======== + + sympy.polys.matrices.domainmatrix.DomainMatrix.rref + Higher level interface to this function. + sympy.polys.matrices.dense.ddm_irref + The underlying algorithm. + """ + b = a.copy() + K = a.domain + partial_pivot = K.is_RealField or K.is_ComplexField + pivots = ddm_irref(b, _partial_pivot=partial_pivot) + return b, pivots + + def rref_den(a): + """Reduced-row echelon form of a with denominator and list of pivots + + See Also + ======== + + sympy.polys.matrices.domainmatrix.DomainMatrix.rref_den + Higher level interface to this function. + sympy.polys.matrices.dense.ddm_irref_den + The underlying algorithm. + """ + b = a.copy() + K = a.domain + denom, pivots = ddm_irref_den(b, K) + return b, denom, pivots + + def nullspace(a): + """Returns a basis for the nullspace of a. + + The domain of the matrix must be a field. + + See Also + ======== + + rref + sympy.polys.matrices.domainmatrix.DomainMatrix.nullspace + """ + rref, pivots = a.rref() + return rref.nullspace_from_rref(pivots) + + def nullspace_from_rref(a, pivots=None): + """Compute the nullspace of a matrix from its rref. + + The domain of the matrix can be any domain. + + Returns a tuple (basis, nonpivots). + + See Also + ======== + + sympy.polys.matrices.domainmatrix.DomainMatrix.nullspace + The higher level interface to this function. + """ + m, n = a.shape + K = a.domain + + if pivots is None: + pivots = [] + last_pivot = -1 + for i in range(m): + ai = a[i] + for j in range(last_pivot+1, n): + if ai[j]: + last_pivot = j + pivots.append(j) + break + + if not pivots: + return (a.eye(n, K), list(range(n))) + + # After rref the pivots are all one but after rref_den they may not be. + pivot_val = a[0][pivots[0]] + + basis = [] + nonpivots = [] + for i in range(n): + if i in pivots: + continue + nonpivots.append(i) + vec = [pivot_val if i == j else K.zero for j in range(n)] + for ii, jj in enumerate(pivots): + vec[jj] -= a[ii][i] + basis.append(vec) + + basis_ddm = DDM(basis, (len(basis), n), K) + + return (basis_ddm, nonpivots) + + def particular(a): + return a.to_sdm().particular().to_ddm() + + def det(a): + """Determinant of a""" + m, n = a.shape + if m != n: + raise DMNonSquareMatrixError("Determinant of non-square matrix") + b = a.copy() + K = b.domain + deta = ddm_idet(b, K) + return deta + + def inv(a): + """Inverse of a""" + m, n = a.shape + if m != n: + raise DMNonSquareMatrixError("Determinant of non-square matrix") + ainv = a.copy() + K = a.domain + ddm_iinv(ainv, a, K) + return ainv + + def lu(a): + """L, U decomposition of a""" + m, n = a.shape + K = a.domain + + U = a.copy() + L = a.eye(m, K) + swaps = ddm_ilu_split(L, U, K) + + return L, U, swaps + + def lu_solve(a, b): + """x where a*x = b""" + m, n = a.shape + m2, o = b.shape + a._check(a, 'lu_solve', b, m, m2) + if not a.domain.is_Field: + raise DMDomainError("lu_solve requires a field") + + L, U, swaps = a.lu() + x = a.zeros((n, o), a.domain) + ddm_ilu_solve(x, L, U, swaps, b) + return x + + def charpoly(a): + """Coefficients of characteristic polynomial of a""" + K = a.domain + m, n = a.shape + if m != n: + raise DMNonSquareMatrixError("Charpoly of non-square matrix") + vec = ddm_berk(a, K) + coeffs = [vec[i][0] for i in range(n+1)] + return coeffs + + def is_zero_matrix(self): + """ + Says whether this matrix has all zero entries. + """ + zero = self.domain.zero + return all(Mij == zero for Mij in self.flatiter()) + + def is_upper(self): + """ + Says whether this matrix is upper-triangular. True can be returned + even if the matrix is not square. + """ + zero = self.domain.zero + return all(Mij == zero for i, Mi in enumerate(self) for Mij in Mi[:i]) + + def is_lower(self): + """ + Says whether this matrix is lower-triangular. True can be returned + even if the matrix is not square. + """ + zero = self.domain.zero + return all(Mij == zero for i, Mi in enumerate(self) for Mij in Mi[i+1:]) + + def is_diagonal(self): + """ + Says whether this matrix is diagonal. True can be returned even if + the matrix is not square. + """ + return self.is_upper() and self.is_lower() + + def diagonal(self): + """ + Returns a list of the elements from the diagonal of the matrix. + """ + m, n = self.shape + return [self[i][i] for i in range(min(m, n))] + + def lll(A, delta=QQ(3, 4)): + return ddm_lll(A, delta=delta) + + def lll_transform(A, delta=QQ(3, 4)): + return ddm_lll_transform(A, delta=delta) + + +from .sdm import SDM +from .dfm import DFM diff --git a/lib/python3.10/site-packages/sympy/polys/matrices/dense.py b/lib/python3.10/site-packages/sympy/polys/matrices/dense.py new file mode 100644 index 0000000000000000000000000000000000000000..47ab2d6897c6d9f3781af23ccb68f96f15c7e859 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/polys/matrices/dense.py @@ -0,0 +1,824 @@ +""" + +Module for the ddm_* routines for operating on a matrix in list of lists +matrix representation. + +These routines are used internally by the DDM class which also provides a +friendlier interface for them. The idea here is to implement core matrix +routines in a way that can be applied to any simple list representation +without the need to use any particular matrix class. For example we can +compute the RREF of a matrix like: + + >>> from sympy.polys.matrices.dense import ddm_irref + >>> M = [[1, 2, 3], [4, 5, 6]] + >>> pivots = ddm_irref(M) + >>> M + [[1.0, 0.0, -1.0], [0, 1.0, 2.0]] + +These are lower-level routines that work mostly in place.The routines at this +level should not need to know what the domain of the elements is but should +ideally document what operations they will use and what functions they need to +be provided with. + +The next-level up is the DDM class which uses these routines but wraps them up +with an interface that handles copying etc and keeps track of the Domain of +the elements of the matrix: + + >>> from sympy.polys.domains import QQ + >>> from sympy.polys.matrices.ddm import DDM + >>> M = DDM([[QQ(1), QQ(2), QQ(3)], [QQ(4), QQ(5), QQ(6)]], (2, 3), QQ) + >>> M + [[1, 2, 3], [4, 5, 6]] + >>> Mrref, pivots = M.rref() + >>> Mrref + [[1, 0, -1], [0, 1, 2]] + +""" +from __future__ import annotations +from operator import mul +from .exceptions import ( + DMShapeError, + DMDomainError, + DMNonInvertibleMatrixError, + DMNonSquareMatrixError, +) +from typing import Sequence, TypeVar +from sympy.polys.matrices._typing import RingElement + + +#: Type variable for the elements of the matrix +T = TypeVar('T') + +#: Type variable for the elements of the matrix that are in a ring +R = TypeVar('R', bound=RingElement) + + +def ddm_transpose(matrix: Sequence[Sequence[T]]) -> list[list[T]]: + """matrix transpose""" + return list(map(list, zip(*matrix))) + + +def ddm_iadd(a: list[list[R]], b: Sequence[Sequence[R]]) -> None: + """a += b""" + for ai, bi in zip(a, b): + for j, bij in enumerate(bi): + ai[j] += bij + + +def ddm_isub(a: list[list[R]], b: Sequence[Sequence[R]]) -> None: + """a -= b""" + for ai, bi in zip(a, b): + for j, bij in enumerate(bi): + ai[j] -= bij + + +def ddm_ineg(a: list[list[R]]) -> None: + """a <-- -a""" + for ai in a: + for j, aij in enumerate(ai): + ai[j] = -aij + + +def ddm_imul(a: list[list[R]], b: R) -> None: + """a <-- a*b""" + for ai in a: + for j, aij in enumerate(ai): + ai[j] = aij * b + + +def ddm_irmul(a: list[list[R]], b: R) -> None: + """a <-- b*a""" + for ai in a: + for j, aij in enumerate(ai): + ai[j] = b * aij + + +def ddm_imatmul( + a: list[list[R]], b: Sequence[Sequence[R]], c: Sequence[Sequence[R]] +) -> None: + """a += b @ c""" + cT = list(zip(*c)) + + for bi, ai in zip(b, a): + for j, cTj in enumerate(cT): + ai[j] = sum(map(mul, bi, cTj), ai[j]) + + +def ddm_irref(a, _partial_pivot=False): + """In-place reduced row echelon form of a matrix. + + Compute the reduced row echelon form of $a$. Modifies $a$ in place and + returns a list of the pivot columns. + + Uses naive Gauss-Jordan elimination in the ground domain which must be a + field. + + This routine is only really suitable for use with simple field domains like + :ref:`GF(p)`, :ref:`QQ` and :ref:`QQ(a)` although even for :ref:`QQ` with + larger matrices it is possibly more efficient to use fraction free + approaches. + + This method is not suitable for use with rational function fields + (:ref:`K(x)`) because the elements will blowup leading to costly gcd + operations. In this case clearing denominators and using fraction free + approaches is likely to be more efficient. + + For inexact numeric domains like :ref:`RR` and :ref:`CC` pass + ``_partial_pivot=True`` to use partial pivoting to control rounding errors. + + Examples + ======== + + >>> from sympy.polys.matrices.dense import ddm_irref + >>> from sympy import QQ + >>> M = [[QQ(1), QQ(2), QQ(3)], [QQ(4), QQ(5), QQ(6)]] + >>> pivots = ddm_irref(M) + >>> M + [[1, 0, -1], [0, 1, 2]] + >>> pivots + [0, 1] + + See Also + ======== + + sympy.polys.matrices.domainmatrix.DomainMatrix.rref + Higher level interface to this routine. + ddm_irref_den + The fraction free version of this routine. + sdm_irref + A sparse version of this routine. + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Row_echelon_form#Reduced_row_echelon_form + """ + # We compute aij**-1 below and then use multiplication instead of division + # in the innermost loop. The domain here is a field so either operation is + # defined. There are significant performance differences for some domains + # though. In the case of e.g. QQ or QQ(x) inversion is free but + # multiplication and division have the same cost so it makes no difference. + # In cases like GF(p), QQ, RR or CC though multiplication is + # faster than division so reusing a precomputed inverse for many + # multiplications can be a lot faster. The biggest win is QQ when + # deg(minpoly(a)) is large. + # + # With domains like QQ(x) this can perform badly for other reasons. + # Typically the initial matrix has simple denominators and the + # fraction-free approach with exquo (ddm_irref_den) will preserve that + # property throughout. The method here causes denominator blowup leading to + # expensive gcd reductions in the intermediate expressions. With many + # generators like QQ(x,y,z,...) this is extremely bad. + # + # TODO: Use a nontrivial pivoting strategy to control intermediate + # expression growth. Rearranging rows and/or columns could defer the most + # complicated elements until the end. If the first pivot is a + # complicated/large element then the first round of reduction will + # immediately introduce expression blowup across the whole matrix. + + # a is (m x n) + m = len(a) + if not m: + return [] + n = len(a[0]) + + i = 0 + pivots = [] + + for j in range(n): + # Proper pivoting should be used for all domains for performance + # reasons but it is only strictly needed for RR and CC (and possibly + # other domains like RR(x)). This path is used by DDM.rref() if the + # domain is RR or CC. It uses partial (row) pivoting based on the + # absolute value of the pivot candidates. + if _partial_pivot: + ip = max(range(i, m), key=lambda ip: abs(a[ip][j])) + a[i], a[ip] = a[ip], a[i] + + # pivot + aij = a[i][j] + + # zero-pivot + if not aij: + for ip in range(i+1, m): + aij = a[ip][j] + # row-swap + if aij: + a[i], a[ip] = a[ip], a[i] + break + else: + # next column + continue + + # normalise row + ai = a[i] + aijinv = aij**-1 + for l in range(j, n): + ai[l] *= aijinv # ai[j] = one + + # eliminate above and below to the right + for k, ak in enumerate(a): + if k == i or not ak[j]: + continue + akj = ak[j] + ak[j] -= akj # ak[j] = zero + for l in range(j+1, n): + ak[l] -= akj * ai[l] + + # next row + pivots.append(j) + i += 1 + + # no more rows? + if i >= m: + break + + return pivots + + +def ddm_irref_den(a, K): + """a <-- rref(a); return (den, pivots) + + Compute the fraction-free reduced row echelon form (RREF) of $a$. Modifies + $a$ in place and returns a tuple containing the denominator of the RREF and + a list of the pivot columns. + + Explanation + =========== + + The algorithm used is the fraction-free version of Gauss-Jordan elimination + described as FFGJ in [1]_. Here it is modified to handle zero or missing + pivots and to avoid redundant arithmetic. + + The domain $K$ must support exact division (``K.exquo``) but does not need + to be a field. This method is suitable for most exact rings and fields like + :ref:`ZZ`, :ref:`QQ` and :ref:`QQ(a)`. In the case of :ref:`QQ` or + :ref:`K(x)` it might be more efficient to clear denominators and use + :ref:`ZZ` or :ref:`K[x]` instead. + + For inexact domains like :ref:`RR` and :ref:`CC` use ``ddm_irref`` instead. + + Examples + ======== + + >>> from sympy.polys.matrices.dense import ddm_irref_den + >>> from sympy import ZZ, Matrix + >>> M = [[ZZ(1), ZZ(2), ZZ(3)], [ZZ(4), ZZ(5), ZZ(6)]] + >>> den, pivots = ddm_irref_den(M, ZZ) + >>> M + [[-3, 0, 3], [0, -3, -6]] + >>> den + -3 + >>> pivots + [0, 1] + >>> Matrix(M).rref()[0] + Matrix([ + [1, 0, -1], + [0, 1, 2]]) + + See Also + ======== + + ddm_irref + A version of this routine that uses field division. + sdm_irref + A sparse version of :func:`ddm_irref`. + sdm_rref_den + A sparse version of :func:`ddm_irref_den`. + sympy.polys.matrices.domainmatrix.DomainMatrix.rref_den + Higher level interface. + + References + ========== + + .. [1] Fraction-free algorithms for linear and polynomial equations. + George C. Nakos , Peter R. Turner , Robert M. Williams. + https://dl.acm.org/doi/10.1145/271130.271133 + """ + # + # A simpler presentation of this algorithm is given in [1]: + # + # Given an n x n matrix A and n x 1 matrix b: + # + # for i in range(n): + # if i != 0: + # d = a[i-1][i-1] + # for j in range(n): + # if j == i: + # continue + # b[j] = a[i][i]*b[j] - a[j][i]*b[i] + # for k in range(n): + # a[j][k] = a[i][i]*a[j][k] - a[j][i]*a[i][k] + # if i != 0: + # a[j][k] /= d + # + # Our version here is a bit more complicated because: + # + # 1. We use row-swaps to avoid zero pivots. + # 2. We allow for some columns to be missing pivots. + # 3. We avoid a lot of redundant arithmetic. + # + # TODO: Use a non-trivial pivoting strategy. Even just row swapping makes a + # big difference to performance if e.g. the upper-left entry of the matrix + # is a huge polynomial. + + # a is (m x n) + m = len(a) + if not m: + return K.one, [] + n = len(a[0]) + + d = None + pivots = [] + no_pivots = [] + + # i, j will be the row and column indices of the current pivot + i = 0 + for j in range(n): + # next pivot? + aij = a[i][j] + + # swap rows if zero + if not aij: + for ip in range(i+1, m): + aij = a[ip][j] + # row-swap + if aij: + a[i], a[ip] = a[ip], a[i] + break + else: + # go to next column + no_pivots.append(j) + continue + + # Now aij is the pivot and i,j are the row and column. We need to clear + # the column above and below but we also need to keep track of the + # denominator of the RREF which means also multiplying everything above + # and to the left by the current pivot aij and dividing by d (which we + # multiplied everything by in the previous iteration so this is an + # exact division). + # + # First handle the upper left corner which is usually already diagonal + # with all diagonal entries equal to the current denominator but there + # can be other non-zero entries in any column that has no pivot. + + # Update previous pivots in the matrix + if pivots: + pivot_val = aij * a[0][pivots[0]] + # Divide out the common factor + if d is not None: + pivot_val = K.exquo(pivot_val, d) + + # Could defer this until the end but it is pretty cheap and + # helps when debugging. + for ip, jp in enumerate(pivots): + a[ip][jp] = pivot_val + + # Update columns without pivots + for jnp in no_pivots: + for ip in range(i): + aijp = a[ip][jnp] + if aijp: + aijp *= aij + if d is not None: + aijp = K.exquo(aijp, d) + a[ip][jnp] = aijp + + # Eliminate above, below and to the right as in ordinary division free + # Gauss-Jordan elmination except also dividing out d from every entry. + + for jp, aj in enumerate(a): + + # Skip the current row + if jp == i: + continue + + # Eliminate to the right in all rows + for kp in range(j+1, n): + ajk = aij * aj[kp] - aj[j] * a[i][kp] + if d is not None: + ajk = K.exquo(ajk, d) + aj[kp] = ajk + + # Set to zero above and below the pivot + aj[j] = K.zero + + # next row + pivots.append(j) + i += 1 + + # no more rows left? + if i >= m: + break + + if not K.is_one(aij): + d = aij + else: + d = None + + if not pivots: + denom = K.one + else: + denom = a[0][pivots[0]] + + return denom, pivots + + +def ddm_idet(a, K): + """a <-- echelon(a); return det + + Explanation + =========== + + Compute the determinant of $a$ using the Bareiss fraction-free algorithm. + The matrix $a$ is modified in place. Its diagonal elements are the + determinants of the leading principal minors. The determinant of $a$ is + returned. + + The domain $K$ must support exact division (``K.exquo``). This method is + suitable for most exact rings and fields like :ref:`ZZ`, :ref:`QQ` and + :ref:`QQ(a)` but not for inexact domains like :ref:`RR` and :ref:`CC`. + + Examples + ======== + + >>> from sympy import ZZ + >>> from sympy.polys.matrices.ddm import ddm_idet + >>> a = [[ZZ(1), ZZ(2), ZZ(3)], [ZZ(4), ZZ(5), ZZ(6)], [ZZ(7), ZZ(8), ZZ(9)]] + >>> a + [[1, 2, 3], [4, 5, 6], [7, 8, 9]] + >>> ddm_idet(a, ZZ) + 0 + >>> a + [[1, 2, 3], [4, -3, -6], [7, -6, 0]] + >>> [a[i][i] for i in range(len(a))] + [1, -3, 0] + + See Also + ======== + + sympy.polys.matrices.domainmatrix.DomainMatrix.det + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Bareiss_algorithm + .. [2] https://www.math.usm.edu/perry/Research/Thesis_DRL.pdf + """ + # Bareiss algorithm + # https://www.math.usm.edu/perry/Research/Thesis_DRL.pdf + + # a is (m x n) + m = len(a) + if not m: + return K.one + n = len(a[0]) + + exquo = K.exquo + # uf keeps track of the sign change from row swaps + uf = K.one + + for k in range(n-1): + if not a[k][k]: + for i in range(k+1, n): + if a[i][k]: + a[k], a[i] = a[i], a[k] + uf = -uf + break + else: + return K.zero + + akkm1 = a[k-1][k-1] if k else K.one + + for i in range(k+1, n): + for j in range(k+1, n): + a[i][j] = exquo(a[i][j]*a[k][k] - a[i][k]*a[k][j], akkm1) + + return uf * a[-1][-1] + + +def ddm_iinv(ainv, a, K): + """ainv <-- inv(a) + + Compute the inverse of a matrix $a$ over a field $K$ using Gauss-Jordan + elimination. The result is stored in $ainv$. + + Uses division in the ground domain which should be an exact field. + + Examples + ======== + + >>> from sympy.polys.matrices.ddm import ddm_iinv, ddm_imatmul + >>> from sympy import QQ + >>> a = [[QQ(1), QQ(2)], [QQ(3), QQ(4)]] + >>> ainv = [[None, None], [None, None]] + >>> ddm_iinv(ainv, a, QQ) + >>> ainv + [[-2, 1], [3/2, -1/2]] + >>> result = [[QQ(0), QQ(0)], [QQ(0), QQ(0)]] + >>> ddm_imatmul(result, a, ainv) + >>> result + [[1, 0], [0, 1]] + + See Also + ======== + + ddm_irref: the underlying routine. + """ + if not K.is_Field: + raise DMDomainError('Not a field') + + # a is (m x n) + m = len(a) + if not m: + return + n = len(a[0]) + if m != n: + raise DMNonSquareMatrixError + + eye = [[K.one if i==j else K.zero for j in range(n)] for i in range(n)] + Aaug = [row + eyerow for row, eyerow in zip(a, eye)] + pivots = ddm_irref(Aaug) + if pivots != list(range(n)): + raise DMNonInvertibleMatrixError('Matrix det == 0; not invertible.') + ainv[:] = [row[n:] for row in Aaug] + + +def ddm_ilu_split(L, U, K): + """L, U <-- LU(U) + + Compute the LU decomposition of a matrix $L$ in place and store the lower + and upper triangular matrices in $L$ and $U$, respectively. Returns a list + of row swaps that were performed. + + Uses division in the ground domain which should be an exact field. + + Examples + ======== + + >>> from sympy.polys.matrices.ddm import ddm_ilu_split + >>> from sympy import QQ + >>> L = [[QQ(0), QQ(0)], [QQ(0), QQ(0)]] + >>> U = [[QQ(1), QQ(2)], [QQ(3), QQ(4)]] + >>> swaps = ddm_ilu_split(L, U, QQ) + >>> swaps + [] + >>> L + [[0, 0], [3, 0]] + >>> U + [[1, 2], [0, -2]] + + See Also + ======== + + ddm_ilu + ddm_ilu_solve + """ + m = len(U) + if not m: + return [] + n = len(U[0]) + + swaps = ddm_ilu(U) + + zeros = [K.zero] * min(m, n) + for i in range(1, m): + j = min(i, n) + L[i][:j] = U[i][:j] + U[i][:j] = zeros[:j] + + return swaps + + +def ddm_ilu(a): + """a <-- LU(a) + + Computes the LU decomposition of a matrix in place. Returns a list of + row swaps that were performed. + + Uses division in the ground domain which should be an exact field. + + This is only suitable for domains like :ref:`GF(p)`, :ref:`QQ`, :ref:`QQ_I` + and :ref:`QQ(a)`. With a rational function field like :ref:`K(x)` it is + better to clear denominators and use division-free algorithms. Pivoting is + used to avoid exact zeros but not for floating point accuracy so :ref:`RR` + and :ref:`CC` are not suitable (use :func:`ddm_irref` instead). + + Examples + ======== + + >>> from sympy.polys.matrices.dense import ddm_ilu + >>> from sympy import QQ + >>> a = [[QQ(1, 2), QQ(1, 3)], [QQ(1, 4), QQ(1, 5)]] + >>> swaps = ddm_ilu(a) + >>> swaps + [] + >>> a + [[1/2, 1/3], [1/2, 1/30]] + + The same example using ``Matrix``: + + >>> from sympy import Matrix, S + >>> M = Matrix([[S(1)/2, S(1)/3], [S(1)/4, S(1)/5]]) + >>> L, U, swaps = M.LUdecomposition() + >>> L + Matrix([ + [ 1, 0], + [1/2, 1]]) + >>> U + Matrix([ + [1/2, 1/3], + [ 0, 1/30]]) + >>> swaps + [] + + See Also + ======== + + ddm_irref + ddm_ilu_solve + sympy.matrices.matrixbase.MatrixBase.LUdecomposition + """ + m = len(a) + if not m: + return [] + n = len(a[0]) + + swaps = [] + + for i in range(min(m, n)): + if not a[i][i]: + for ip in range(i+1, m): + if a[ip][i]: + swaps.append((i, ip)) + a[i], a[ip] = a[ip], a[i] + break + else: + # M = Matrix([[1, 0, 0, 0], [0, 0, 0, 0], [0, 0, 1, 1], [0, 0, 1, 2]]) + continue + for j in range(i+1, m): + l_ji = a[j][i] / a[i][i] + a[j][i] = l_ji + for k in range(i+1, n): + a[j][k] -= l_ji * a[i][k] + + return swaps + + +def ddm_ilu_solve(x, L, U, swaps, b): + """x <-- solve(L*U*x = swaps(b)) + + Solve a linear system, $A*x = b$, given an LU factorization of $A$. + + Uses division in the ground domain which must be a field. + + Modifies $x$ in place. + + Examples + ======== + + Compute the LU decomposition of $A$ (in place): + + >>> from sympy import QQ + >>> from sympy.polys.matrices.dense import ddm_ilu, ddm_ilu_solve + >>> A = [[QQ(1), QQ(2)], [QQ(3), QQ(4)]] + >>> swaps = ddm_ilu(A) + >>> A + [[1, 2], [3, -2]] + >>> L = U = A + + Solve the linear system: + + >>> b = [[QQ(5)], [QQ(6)]] + >>> x = [[None], [None]] + >>> ddm_ilu_solve(x, L, U, swaps, b) + >>> x + [[-4], [9/2]] + + See Also + ======== + + ddm_ilu + Compute the LU decomposition of a matrix in place. + ddm_ilu_split + Compute the LU decomposition of a matrix and separate $L$ and $U$. + sympy.polys.matrices.domainmatrix.DomainMatrix.lu_solve + Higher level interface to this function. + """ + m = len(U) + if not m: + return + n = len(U[0]) + + m2 = len(b) + if not m2: + raise DMShapeError("Shape mismtch") + o = len(b[0]) + + if m != m2: + raise DMShapeError("Shape mismtch") + if m < n: + raise NotImplementedError("Underdetermined") + + if swaps: + b = [row[:] for row in b] + for i1, i2 in swaps: + b[i1], b[i2] = b[i2], b[i1] + + # solve Ly = b + y = [[None] * o for _ in range(m)] + for k in range(o): + for i in range(m): + rhs = b[i][k] + for j in range(i): + rhs -= L[i][j] * y[j][k] + y[i][k] = rhs + + if m > n: + for i in range(n, m): + for j in range(o): + if y[i][j]: + raise DMNonInvertibleMatrixError + + # Solve Ux = y + for k in range(o): + for i in reversed(range(n)): + if not U[i][i]: + raise DMNonInvertibleMatrixError + rhs = y[i][k] + for j in range(i+1, n): + rhs -= U[i][j] * x[j][k] + x[i][k] = rhs / U[i][i] + + +def ddm_berk(M, K): + """ + Berkowitz algorithm for computing the characteristic polynomial. + + Explanation + =========== + + The Berkowitz algorithm is a division-free algorithm for computing the + characteristic polynomial of a matrix over any commutative ring using only + arithmetic in the coefficient ring. + + Examples + ======== + + >>> from sympy import Matrix + >>> from sympy.polys.matrices.dense import ddm_berk + >>> from sympy.polys.domains import ZZ + >>> M = [[ZZ(1), ZZ(2)], [ZZ(3), ZZ(4)]] + >>> ddm_berk(M, ZZ) + [[1], [-5], [-2]] + >>> Matrix(M).charpoly() + PurePoly(lambda**2 - 5*lambda - 2, lambda, domain='ZZ') + + See Also + ======== + + sympy.polys.matrices.domainmatrix.DomainMatrix.charpoly + The high-level interface to this function. + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Samuelson%E2%80%93Berkowitz_algorithm + """ + m = len(M) + if not m: + return [[K.one]] + n = len(M[0]) + + if m != n: + raise DMShapeError("Not square") + + if n == 1: + return [[K.one], [-M[0][0]]] + + a = M[0][0] + R = [M[0][1:]] + C = [[row[0]] for row in M[1:]] + A = [row[1:] for row in M[1:]] + + q = ddm_berk(A, K) + + T = [[K.zero] * n for _ in range(n+1)] + for i in range(n): + T[i][i] = K.one + T[i+1][i] = -a + for i in range(2, n+1): + if i == 2: + AnC = C + else: + C = AnC + AnC = [[K.zero] for row in C] + ddm_imatmul(AnC, A, C) + RAnC = [[K.zero]] + ddm_imatmul(RAnC, R, AnC) + for j in range(0, n+1-i): + T[i+j][j] = -RAnC[0][0] + + qout = [[K.zero] for _ in range(n+1)] + ddm_imatmul(qout, T, q) + return qout diff --git a/lib/python3.10/site-packages/sympy/polys/matrices/dfm.py b/lib/python3.10/site-packages/sympy/polys/matrices/dfm.py new file mode 100644 index 0000000000000000000000000000000000000000..22938b7004654121f74b020bd6649bee84909e1e --- /dev/null +++ b/lib/python3.10/site-packages/sympy/polys/matrices/dfm.py @@ -0,0 +1,35 @@ +""" +sympy.polys.matrices.dfm + +Provides the :class:`DFM` class if ``GROUND_TYPES=flint'``. Otherwise, ``DFM`` +is a placeholder class that raises NotImplementedError when instantiated. +""" + +from sympy.external.gmpy import GROUND_TYPES + +if GROUND_TYPES == "flint": # pragma: no cover + # When python-flint is installed we will try to use it for dense matrices + # if the domain is supported by python-flint. + from ._dfm import DFM + +else: # pragma: no cover + # Other code should be able to import this and it should just present as a + # version of DFM that does not support any domains. + class DFM_dummy: + """ + Placeholder class for DFM when python-flint is not installed. + """ + def __init__(*args, **kwargs): + raise NotImplementedError("DFM requires GROUND_TYPES=flint.") + + @classmethod + def _supports_domain(cls, domain): + return False + + @classmethod + def _get_flint_func(cls, domain): + raise NotImplementedError("DFM requires GROUND_TYPES=flint.") + + # mypy really struggles with this kind of conditional type assignment. + # Maybe there is a better way to annotate this rather than type: ignore. + DFM = DFM_dummy # type: ignore diff --git a/lib/python3.10/site-packages/sympy/polys/matrices/domainmatrix.py b/lib/python3.10/site-packages/sympy/polys/matrices/domainmatrix.py new file mode 100644 index 0000000000000000000000000000000000000000..b91bef314d16364cdbfc18f5f7470d958468a23a --- /dev/null +++ b/lib/python3.10/site-packages/sympy/polys/matrices/domainmatrix.py @@ -0,0 +1,3850 @@ +""" + +Module for the DomainMatrix class. + +A DomainMatrix represents a matrix with elements that are in a particular +Domain. Each DomainMatrix internally wraps a DDM which is used for the +lower-level operations. The idea is that the DomainMatrix class provides the +convenience routines for converting between Expr and the poly domains as well +as unifying matrices with different domains. + +""" +from collections import Counter +from functools import reduce +from typing import Union as tUnion, Tuple as tTuple + +from sympy.external.gmpy import GROUND_TYPES +from sympy.utilities.decorator import doctest_depends_on + +from sympy.core.sympify import _sympify + +from ..domains import Domain + +from ..constructor import construct_domain + +from .exceptions import ( + DMFormatError, + DMBadInputError, + DMShapeError, + DMDomainError, + DMNotAField, + DMNonSquareMatrixError, + DMNonInvertibleMatrixError +) + +from .domainscalar import DomainScalar + +from sympy.polys.domains import ZZ, EXRAW, QQ + +from sympy.polys.densearith import dup_mul +from sympy.polys.densebasic import dup_convert +from sympy.polys.densetools import ( + dup_mul_ground, + dup_quo_ground, + dup_content, + dup_clear_denoms, + dup_primitive, + dup_transform, +) +from sympy.polys.factortools import dup_factor_list +from sympy.polys.polyutils import _sort_factors + +from .ddm import DDM + +from .sdm import SDM + +from .dfm import DFM + +from .rref import _dm_rref, _dm_rref_den + + +if GROUND_TYPES != 'flint': + __doctest_skip__ = ['DomainMatrix.to_dfm', 'DomainMatrix.to_dfm_or_ddm'] +else: + __doctest_skip__ = ['DomainMatrix.from_list'] + + +def DM(rows, domain): + """Convenient alias for DomainMatrix.from_list + + Examples + ======== + + >>> from sympy import ZZ + >>> from sympy.polys.matrices import DM + >>> DM([[1, 2], [3, 4]], ZZ) + DomainMatrix([[1, 2], [3, 4]], (2, 2), ZZ) + + See Also + ======== + + DomainMatrix.from_list + """ + return DomainMatrix.from_list(rows, domain) + + +class DomainMatrix: + r""" + Associate Matrix with :py:class:`~.Domain` + + Explanation + =========== + + DomainMatrix uses :py:class:`~.Domain` for its internal representation + which makes it faster than the SymPy Matrix class (currently) for many + common operations, but this advantage makes it not entirely compatible + with Matrix. DomainMatrix are analogous to numpy arrays with "dtype". + In the DomainMatrix, each element has a domain such as :ref:`ZZ` + or :ref:`QQ(a)`. + + + Examples + ======== + + Creating a DomainMatrix from the existing Matrix class: + + >>> from sympy import Matrix + >>> from sympy.polys.matrices import DomainMatrix + >>> Matrix1 = Matrix([ + ... [1, 2], + ... [3, 4]]) + >>> A = DomainMatrix.from_Matrix(Matrix1) + >>> A + DomainMatrix({0: {0: 1, 1: 2}, 1: {0: 3, 1: 4}}, (2, 2), ZZ) + + Directly forming a DomainMatrix: + + >>> from sympy import ZZ + >>> from sympy.polys.matrices import DomainMatrix + >>> A = DomainMatrix([ + ... [ZZ(1), ZZ(2)], + ... [ZZ(3), ZZ(4)]], (2, 2), ZZ) + >>> A + DomainMatrix([[1, 2], [3, 4]], (2, 2), ZZ) + + See Also + ======== + + DDM + SDM + Domain + Poly + + """ + rep: tUnion[SDM, DDM, DFM] + shape: tTuple[int, int] + domain: Domain + + def __new__(cls, rows, shape, domain, *, fmt=None): + """ + Creates a :py:class:`~.DomainMatrix`. + + Parameters + ========== + + rows : Represents elements of DomainMatrix as list of lists + shape : Represents dimension of DomainMatrix + domain : Represents :py:class:`~.Domain` of DomainMatrix + + Raises + ====== + + TypeError + If any of rows, shape and domain are not provided + + """ + if isinstance(rows, (DDM, SDM, DFM)): + raise TypeError("Use from_rep to initialise from SDM/DDM") + elif isinstance(rows, list): + rep = DDM(rows, shape, domain) + elif isinstance(rows, dict): + rep = SDM(rows, shape, domain) + else: + msg = "Input should be list-of-lists or dict-of-dicts" + raise TypeError(msg) + + if fmt is not None: + if fmt == 'sparse': + rep = rep.to_sdm() + elif fmt == 'dense': + rep = rep.to_ddm() + else: + raise ValueError("fmt should be 'sparse' or 'dense'") + + # Use python-flint for dense matrices if possible + if rep.fmt == 'dense' and DFM._supports_domain(domain): + rep = rep.to_dfm() + + return cls.from_rep(rep) + + def __reduce__(self): + rep = self.rep + if rep.fmt == 'dense': + arg = self.to_list() + elif rep.fmt == 'sparse': + arg = dict(rep) + else: + raise RuntimeError # pragma: no cover + args = (arg, rep.shape, rep.domain) + return (self.__class__, args) + + def __getitem__(self, key): + i, j = key + m, n = self.shape + if not (isinstance(i, slice) or isinstance(j, slice)): + return DomainScalar(self.rep.getitem(i, j), self.domain) + + if not isinstance(i, slice): + if not -m <= i < m: + raise IndexError("Row index out of range") + i = i % m + i = slice(i, i+1) + if not isinstance(j, slice): + if not -n <= j < n: + raise IndexError("Column index out of range") + j = j % n + j = slice(j, j+1) + + return self.from_rep(self.rep.extract_slice(i, j)) + + def getitem_sympy(self, i, j): + return self.domain.to_sympy(self.rep.getitem(i, j)) + + def extract(self, rowslist, colslist): + return self.from_rep(self.rep.extract(rowslist, colslist)) + + def __setitem__(self, key, value): + i, j = key + if not self.domain.of_type(value): + raise TypeError + if isinstance(i, int) and isinstance(j, int): + self.rep.setitem(i, j, value) + else: + raise NotImplementedError + + @classmethod + def from_rep(cls, rep): + """Create a new DomainMatrix efficiently from DDM/SDM. + + Examples + ======== + + Create a :py:class:`~.DomainMatrix` with an dense internal + representation as :py:class:`~.DDM`: + + >>> from sympy.polys.domains import ZZ + >>> from sympy.polys.matrices import DomainMatrix + >>> from sympy.polys.matrices.ddm import DDM + >>> drep = DDM([[ZZ(1), ZZ(2)], [ZZ(3), ZZ(4)]], (2, 2), ZZ) + >>> dM = DomainMatrix.from_rep(drep) + >>> dM + DomainMatrix([[1, 2], [3, 4]], (2, 2), ZZ) + + Create a :py:class:`~.DomainMatrix` with a sparse internal + representation as :py:class:`~.SDM`: + + >>> from sympy.polys.matrices import DomainMatrix + >>> from sympy.polys.matrices.sdm import SDM + >>> from sympy import ZZ + >>> drep = SDM({0:{1:ZZ(1)},1:{0:ZZ(2)}}, (2, 2), ZZ) + >>> dM = DomainMatrix.from_rep(drep) + >>> dM + DomainMatrix({0: {1: 1}, 1: {0: 2}}, (2, 2), ZZ) + + Parameters + ========== + + rep: SDM or DDM + The internal sparse or dense representation of the matrix. + + Returns + ======= + + DomainMatrix + A :py:class:`~.DomainMatrix` wrapping *rep*. + + Notes + ===== + + This takes ownership of rep as its internal representation. If rep is + being mutated elsewhere then a copy should be provided to + ``from_rep``. Only minimal verification or checking is done on *rep* + as this is supposed to be an efficient internal routine. + + """ + if not (isinstance(rep, (DDM, SDM)) or (DFM is not None and isinstance(rep, DFM))): + raise TypeError("rep should be of type DDM or SDM") + self = super().__new__(cls) + self.rep = rep + self.shape = rep.shape + self.domain = rep.domain + return self + + @classmethod + @doctest_depends_on(ground_types=['python', 'gmpy']) + def from_list(cls, rows, domain): + r""" + Convert a list of lists into a DomainMatrix + + Parameters + ========== + + rows: list of lists + Each element of the inner lists should be either the single arg, + or tuple of args, that would be passed to the domain constructor + in order to form an element of the domain. See examples. + + Returns + ======= + + DomainMatrix containing elements defined in rows + + Examples + ======== + + >>> from sympy.polys.matrices import DomainMatrix + >>> from sympy import FF, QQ, ZZ + >>> A = DomainMatrix.from_list([[1, 0, 1], [0, 0, 1]], ZZ) + >>> A + DomainMatrix([[1, 0, 1], [0, 0, 1]], (2, 3), ZZ) + >>> B = DomainMatrix.from_list([[1, 0, 1], [0, 0, 1]], FF(7)) + >>> B + DomainMatrix([[1 mod 7, 0 mod 7, 1 mod 7], [0 mod 7, 0 mod 7, 1 mod 7]], (2, 3), GF(7)) + >>> C = DomainMatrix.from_list([[(1, 2), (3, 1)], [(1, 4), (5, 1)]], QQ) + >>> C + DomainMatrix([[1/2, 3], [1/4, 5]], (2, 2), QQ) + + See Also + ======== + + from_list_sympy + + """ + nrows = len(rows) + ncols = 0 if not nrows else len(rows[0]) + conv = lambda e: domain(*e) if isinstance(e, tuple) else domain(e) + domain_rows = [[conv(e) for e in row] for row in rows] + return DomainMatrix(domain_rows, (nrows, ncols), domain) + + @classmethod + def from_list_sympy(cls, nrows, ncols, rows, **kwargs): + r""" + Convert a list of lists of Expr into a DomainMatrix using construct_domain + + Parameters + ========== + + nrows: number of rows + ncols: number of columns + rows: list of lists + + Returns + ======= + + DomainMatrix containing elements of rows + + Examples + ======== + + >>> from sympy.polys.matrices import DomainMatrix + >>> from sympy.abc import x, y, z + >>> A = DomainMatrix.from_list_sympy(1, 3, [[x, y, z]]) + >>> A + DomainMatrix([[x, y, z]], (1, 3), ZZ[x,y,z]) + + See Also + ======== + + sympy.polys.constructor.construct_domain, from_dict_sympy + + """ + assert len(rows) == nrows + assert all(len(row) == ncols for row in rows) + + items_sympy = [_sympify(item) for row in rows for item in row] + + domain, items_domain = cls.get_domain(items_sympy, **kwargs) + + domain_rows = [[items_domain[ncols*r + c] for c in range(ncols)] for r in range(nrows)] + + return DomainMatrix(domain_rows, (nrows, ncols), domain) + + @classmethod + def from_dict_sympy(cls, nrows, ncols, elemsdict, **kwargs): + """ + + Parameters + ========== + + nrows: number of rows + ncols: number of cols + elemsdict: dict of dicts containing non-zero elements of the DomainMatrix + + Returns + ======= + + DomainMatrix containing elements of elemsdict + + Examples + ======== + + >>> from sympy.polys.matrices import DomainMatrix + >>> from sympy.abc import x,y,z + >>> elemsdict = {0: {0:x}, 1:{1: y}, 2: {2: z}} + >>> A = DomainMatrix.from_dict_sympy(3, 3, elemsdict) + >>> A + DomainMatrix({0: {0: x}, 1: {1: y}, 2: {2: z}}, (3, 3), ZZ[x,y,z]) + + See Also + ======== + + from_list_sympy + + """ + if not all(0 <= r < nrows for r in elemsdict): + raise DMBadInputError("Row out of range") + if not all(0 <= c < ncols for row in elemsdict.values() for c in row): + raise DMBadInputError("Column out of range") + + items_sympy = [_sympify(item) for row in elemsdict.values() for item in row.values()] + domain, items_domain = cls.get_domain(items_sympy, **kwargs) + + idx = 0 + items_dict = {} + for i, row in elemsdict.items(): + items_dict[i] = {} + for j in row: + items_dict[i][j] = items_domain[idx] + idx += 1 + + return DomainMatrix(items_dict, (nrows, ncols), domain) + + @classmethod + def from_Matrix(cls, M, fmt='sparse',**kwargs): + r""" + Convert Matrix to DomainMatrix + + Parameters + ========== + + M: Matrix + + Returns + ======= + + Returns DomainMatrix with identical elements as M + + Examples + ======== + + >>> from sympy import Matrix + >>> from sympy.polys.matrices import DomainMatrix + >>> M = Matrix([ + ... [1.0, 3.4], + ... [2.4, 1]]) + >>> A = DomainMatrix.from_Matrix(M) + >>> A + DomainMatrix({0: {0: 1.0, 1: 3.4}, 1: {0: 2.4, 1: 1.0}}, (2, 2), RR) + + We can keep internal representation as ddm using fmt='dense' + >>> from sympy import Matrix, QQ + >>> from sympy.polys.matrices import DomainMatrix + >>> A = DomainMatrix.from_Matrix(Matrix([[QQ(1, 2), QQ(3, 4)], [QQ(0, 1), QQ(0, 1)]]), fmt='dense') + >>> A.rep + [[1/2, 3/4], [0, 0]] + + See Also + ======== + + Matrix + + """ + if fmt == 'dense': + return cls.from_list_sympy(*M.shape, M.tolist(), **kwargs) + + return cls.from_dict_sympy(*M.shape, M.todod(), **kwargs) + + @classmethod + def get_domain(cls, items_sympy, **kwargs): + K, items_K = construct_domain(items_sympy, **kwargs) + return K, items_K + + def choose_domain(self, **opts): + """Convert to a domain found by :func:`~.construct_domain`. + + Examples + ======== + + >>> from sympy import ZZ + >>> from sympy.polys.matrices import DM + >>> M = DM([[1, 2], [3, 4]], ZZ) + >>> M + DomainMatrix([[1, 2], [3, 4]], (2, 2), ZZ) + >>> M.choose_domain(field=True) + DomainMatrix([[1, 2], [3, 4]], (2, 2), QQ) + + >>> from sympy.abc import x + >>> M = DM([[1, x], [x**2, x**3]], ZZ[x]) + >>> M.choose_domain(field=True).domain + ZZ(x) + + Keyword arguments are passed to :func:`~.construct_domain`. + + See Also + ======== + + construct_domain + convert_to + """ + elements, data = self.to_sympy().to_flat_nz() + dom, elements_dom = construct_domain(elements, **opts) + return self.from_flat_nz(elements_dom, data, dom) + + def copy(self): + return self.from_rep(self.rep.copy()) + + def convert_to(self, K): + r""" + Change the domain of DomainMatrix to desired domain or field + + Parameters + ========== + + K : Represents the desired domain or field. + Alternatively, ``None`` may be passed, in which case this method + just returns a copy of this DomainMatrix. + + Returns + ======= + + DomainMatrix + DomainMatrix with the desired domain or field + + Examples + ======== + + >>> from sympy import ZZ, ZZ_I + >>> from sympy.polys.matrices import DomainMatrix + >>> A = DomainMatrix([ + ... [ZZ(1), ZZ(2)], + ... [ZZ(3), ZZ(4)]], (2, 2), ZZ) + + >>> A.convert_to(ZZ_I) + DomainMatrix([[1, 2], [3, 4]], (2, 2), ZZ_I) + + """ + if K == self.domain: + return self.copy() + + rep = self.rep + + # The DFM, DDM and SDM types do not do any implicit conversions so we + # manage switching between DDM and DFM here. + if rep.is_DFM and not DFM._supports_domain(K): + rep_K = rep.to_ddm().convert_to(K) + elif rep.is_DDM and DFM._supports_domain(K): + rep_K = rep.convert_to(K).to_dfm() + else: + rep_K = rep.convert_to(K) + + return self.from_rep(rep_K) + + def to_sympy(self): + return self.convert_to(EXRAW) + + def to_field(self): + r""" + Returns a DomainMatrix with the appropriate field + + Returns + ======= + + DomainMatrix + DomainMatrix with the appropriate field + + Examples + ======== + + >>> from sympy import ZZ + >>> from sympy.polys.matrices import DomainMatrix + >>> A = DomainMatrix([ + ... [ZZ(1), ZZ(2)], + ... [ZZ(3), ZZ(4)]], (2, 2), ZZ) + + >>> A.to_field() + DomainMatrix([[1, 2], [3, 4]], (2, 2), QQ) + + """ + K = self.domain.get_field() + return self.convert_to(K) + + def to_sparse(self): + """ + Return a sparse DomainMatrix representation of *self*. + + Examples + ======== + + >>> from sympy.polys.matrices import DomainMatrix + >>> from sympy import QQ + >>> A = DomainMatrix([[1, 0],[0, 2]], (2, 2), QQ) + >>> A.rep + [[1, 0], [0, 2]] + >>> B = A.to_sparse() + >>> B.rep + {0: {0: 1}, 1: {1: 2}} + """ + if self.rep.fmt == 'sparse': + return self + + return self.from_rep(self.rep.to_sdm()) + + def to_dense(self): + """ + Return a dense DomainMatrix representation of *self*. + + Examples + ======== + + >>> from sympy.polys.matrices import DomainMatrix + >>> from sympy import QQ + >>> A = DomainMatrix({0: {0: 1}, 1: {1: 2}}, (2, 2), QQ) + >>> A.rep + {0: {0: 1}, 1: {1: 2}} + >>> B = A.to_dense() + >>> B.rep + [[1, 0], [0, 2]] + + """ + rep = self.rep + + if rep.fmt == 'dense': + return self + + return self.from_rep(rep.to_dfm_or_ddm()) + + def to_ddm(self): + """ + Return a :class:`~.DDM` representation of *self*. + + Examples + ======== + + >>> from sympy.polys.matrices import DomainMatrix + >>> from sympy import QQ + >>> A = DomainMatrix({0: {0: 1}, 1: {1: 2}}, (2, 2), QQ) + >>> ddm = A.to_ddm() + >>> ddm + [[1, 0], [0, 2]] + >>> type(ddm) + + + See Also + ======== + + to_sdm + to_dense + sympy.polys.matrices.ddm.DDM.to_sdm + """ + return self.rep.to_ddm() + + def to_sdm(self): + """ + Return a :class:`~.SDM` representation of *self*. + + Examples + ======== + + >>> from sympy.polys.matrices import DomainMatrix + >>> from sympy import QQ + >>> A = DomainMatrix([[1, 0],[0, 2]], (2, 2), QQ) + >>> sdm = A.to_sdm() + >>> sdm + {0: {0: 1}, 1: {1: 2}} + >>> type(sdm) + + + See Also + ======== + + to_ddm + to_sparse + sympy.polys.matrices.sdm.SDM.to_ddm + """ + return self.rep.to_sdm() + + @doctest_depends_on(ground_types=['flint']) + def to_dfm(self): + """ + Return a :class:`~.DFM` representation of *self*. + + Examples + ======== + + >>> from sympy.polys.matrices import DomainMatrix + >>> from sympy import QQ + >>> A = DomainMatrix([[1, 0],[0, 2]], (2, 2), QQ) + >>> dfm = A.to_dfm() + >>> dfm + [[1, 0], [0, 2]] + >>> type(dfm) + + + See Also + ======== + + to_ddm + to_dense + DFM + """ + return self.rep.to_dfm() + + @doctest_depends_on(ground_types=['flint']) + def to_dfm_or_ddm(self): + """ + Return a :class:`~.DFM` or :class:`~.DDM` representation of *self*. + + Explanation + =========== + + The :class:`~.DFM` representation can only be used if the ground types + are ``flint`` and the ground domain is supported by ``python-flint``. + This method will return a :class:`~.DFM` representation if possible, + but will return a :class:`~.DDM` representation otherwise. + + Examples + ======== + + >>> from sympy.polys.matrices import DomainMatrix + >>> from sympy import QQ + >>> A = DomainMatrix([[1, 0],[0, 2]], (2, 2), QQ) + >>> dfm = A.to_dfm_or_ddm() + >>> dfm + [[1, 0], [0, 2]] + >>> type(dfm) # Depends on the ground domain and ground types + + + See Also + ======== + + to_ddm: Always return a :class:`~.DDM` representation. + to_dfm: Returns a :class:`~.DFM` representation or raise an error. + to_dense: Convert internally to a :class:`~.DFM` or :class:`~.DDM` + DFM: The :class:`~.DFM` dense FLINT matrix representation. + DDM: The Python :class:`~.DDM` dense domain matrix representation. + """ + return self.rep.to_dfm_or_ddm() + + @classmethod + def _unify_domain(cls, *matrices): + """Convert matrices to a common domain""" + domains = {matrix.domain for matrix in matrices} + if len(domains) == 1: + return matrices + domain = reduce(lambda x, y: x.unify(y), domains) + return tuple(matrix.convert_to(domain) for matrix in matrices) + + @classmethod + def _unify_fmt(cls, *matrices, fmt=None): + """Convert matrices to the same format. + + If all matrices have the same format, then return unmodified. + Otherwise convert both to the preferred format given as *fmt* which + should be 'dense' or 'sparse'. + """ + formats = {matrix.rep.fmt for matrix in matrices} + if len(formats) == 1: + return matrices + if fmt == 'sparse': + return tuple(matrix.to_sparse() for matrix in matrices) + elif fmt == 'dense': + return tuple(matrix.to_dense() for matrix in matrices) + else: + raise ValueError("fmt should be 'sparse' or 'dense'") + + def unify(self, *others, fmt=None): + """ + Unifies the domains and the format of self and other + matrices. + + Parameters + ========== + + others : DomainMatrix + + fmt: string 'dense', 'sparse' or `None` (default) + The preferred format to convert to if self and other are not + already in the same format. If `None` or not specified then no + conversion if performed. + + Returns + ======= + + Tuple[DomainMatrix] + Matrices with unified domain and format + + Examples + ======== + + Unify the domain of DomainMatrix that have different domains: + + >>> from sympy import ZZ, QQ + >>> from sympy.polys.matrices import DomainMatrix + >>> A = DomainMatrix([[ZZ(1), ZZ(2)]], (1, 2), ZZ) + >>> B = DomainMatrix([[QQ(1, 2), QQ(2)]], (1, 2), QQ) + >>> Aq, Bq = A.unify(B) + >>> Aq + DomainMatrix([[1, 2]], (1, 2), QQ) + >>> Bq + DomainMatrix([[1/2, 2]], (1, 2), QQ) + + Unify the format (dense or sparse): + + >>> A = DomainMatrix([[ZZ(1), ZZ(2)]], (1, 2), ZZ) + >>> B = DomainMatrix({0:{0: ZZ(1)}}, (2, 2), ZZ) + >>> B.rep + {0: {0: 1}} + + >>> A2, B2 = A.unify(B, fmt='dense') + >>> B2.rep + [[1, 0], [0, 0]] + + See Also + ======== + + convert_to, to_dense, to_sparse + + """ + matrices = (self,) + others + matrices = DomainMatrix._unify_domain(*matrices) + if fmt is not None: + matrices = DomainMatrix._unify_fmt(*matrices, fmt=fmt) + return matrices + + def to_Matrix(self): + r""" + Convert DomainMatrix to Matrix + + Returns + ======= + + Matrix + MutableDenseMatrix for the DomainMatrix + + Examples + ======== + + >>> from sympy import ZZ + >>> from sympy.polys.matrices import DomainMatrix + >>> A = DomainMatrix([ + ... [ZZ(1), ZZ(2)], + ... [ZZ(3), ZZ(4)]], (2, 2), ZZ) + + >>> A.to_Matrix() + Matrix([ + [1, 2], + [3, 4]]) + + See Also + ======== + + from_Matrix + + """ + from sympy.matrices.dense import MutableDenseMatrix + + # XXX: If the internal representation of RepMatrix changes then this + # might need to be changed also. + if self.domain in (ZZ, QQ, EXRAW): + if self.rep.fmt == "sparse": + rep = self.copy() + else: + rep = self.to_sparse() + else: + rep = self.convert_to(EXRAW).to_sparse() + + return MutableDenseMatrix._fromrep(rep) + + def to_list(self): + """ + Convert :class:`DomainMatrix` to list of lists. + + See Also + ======== + + from_list + to_list_flat + to_flat_nz + to_dok + """ + return self.rep.to_list() + + def to_list_flat(self): + """ + Convert :class:`DomainMatrix` to flat list. + + Examples + ======== + + >>> from sympy import ZZ + >>> from sympy.polys.matrices import DomainMatrix + >>> A = DomainMatrix([[ZZ(1), ZZ(2)], [ZZ(3), ZZ(4)]], (2, 2), ZZ) + >>> A.to_list_flat() + [1, 2, 3, 4] + + See Also + ======== + + from_list_flat + to_list + to_flat_nz + to_dok + """ + return self.rep.to_list_flat() + + @classmethod + def from_list_flat(cls, elements, shape, domain): + """ + Create :class:`DomainMatrix` from flat list. + + Examples + ======== + + >>> from sympy import ZZ + >>> from sympy.polys.matrices import DomainMatrix + >>> element_list = [ZZ(1), ZZ(2), ZZ(3), ZZ(4)] + >>> A = DomainMatrix.from_list_flat(element_list, (2, 2), ZZ) + >>> A + DomainMatrix([[1, 2], [3, 4]], (2, 2), ZZ) + >>> A == A.from_list_flat(A.to_list_flat(), A.shape, A.domain) + True + + See Also + ======== + + to_list_flat + """ + ddm = DDM.from_list_flat(elements, shape, domain) + return cls.from_rep(ddm.to_dfm_or_ddm()) + + def to_flat_nz(self): + """ + Convert :class:`DomainMatrix` to list of nonzero elements and data. + + Explanation + =========== + + Returns a tuple ``(elements, data)`` where ``elements`` is a list of + elements of the matrix with zeros possibly excluded. The matrix can be + reconstructed by passing these to :meth:`from_flat_nz`. The idea is to + be able to modify a flat list of the elements and then create a new + matrix of the same shape with the modified elements in the same + positions. + + The format of ``data`` differs depending on whether the underlying + representation is dense or sparse but either way it represents the + positions of the elements in the list in a way that + :meth:`from_flat_nz` can use to reconstruct the matrix. The + :meth:`from_flat_nz` method should be called on the same + :class:`DomainMatrix` that was used to call :meth:`to_flat_nz`. + + Examples + ======== + + >>> from sympy import ZZ + >>> from sympy.polys.matrices import DomainMatrix + >>> A = DomainMatrix([ + ... [ZZ(1), ZZ(2)], + ... [ZZ(3), ZZ(4)]], (2, 2), ZZ) + >>> elements, data = A.to_flat_nz() + >>> elements + [1, 2, 3, 4] + >>> A == A.from_flat_nz(elements, data, A.domain) + True + + Create a matrix with the elements doubled: + + >>> elements_doubled = [2*x for x in elements] + >>> A2 = A.from_flat_nz(elements_doubled, data, A.domain) + >>> A2 == 2*A + True + + See Also + ======== + + from_flat_nz + """ + return self.rep.to_flat_nz() + + def from_flat_nz(self, elements, data, domain): + """ + Reconstruct :class:`DomainMatrix` after calling :meth:`to_flat_nz`. + + See :meth:`to_flat_nz` for explanation. + + See Also + ======== + + to_flat_nz + """ + rep = self.rep.from_flat_nz(elements, data, domain) + return self.from_rep(rep) + + def to_dod(self): + """ + Convert :class:`DomainMatrix` to dictionary of dictionaries (dod) format. + + Explanation + =========== + + Returns a dictionary of dictionaries representing the matrix. + + Examples + ======== + + >>> from sympy import ZZ + >>> from sympy.polys.matrices import DM + >>> A = DM([[ZZ(1), ZZ(2), ZZ(0)], [ZZ(3), ZZ(0), ZZ(4)]], ZZ) + >>> A.to_dod() + {0: {0: 1, 1: 2}, 1: {0: 3, 2: 4}} + >>> A.to_sparse() == A.from_dod(A.to_dod(), A.shape, A.domain) + True + >>> A == A.from_dod_like(A.to_dod()) + True + + See Also + ======== + + from_dod + from_dod_like + to_dok + to_list + to_list_flat + to_flat_nz + sympy.matrices.matrixbase.MatrixBase.todod + """ + return self.rep.to_dod() + + @classmethod + def from_dod(cls, dod, shape, domain): + """ + Create sparse :class:`DomainMatrix` from dict of dict (dod) format. + + See :meth:`to_dod` for explanation. + + See Also + ======== + + to_dod + from_dod_like + """ + return cls.from_rep(SDM.from_dod(dod, shape, domain)) + + def from_dod_like(self, dod, domain=None): + """ + Create :class:`DomainMatrix` like ``self`` from dict of dict (dod) format. + + See :meth:`to_dod` for explanation. + + See Also + ======== + + to_dod + from_dod + """ + if domain is None: + domain = self.domain + return self.from_rep(self.rep.from_dod(dod, self.shape, domain)) + + def to_dok(self): + """ + Convert :class:`DomainMatrix` to dictionary of keys (dok) format. + + Examples + ======== + + >>> from sympy import ZZ + >>> from sympy.polys.matrices import DomainMatrix + >>> A = DomainMatrix([ + ... [ZZ(1), ZZ(0)], + ... [ZZ(0), ZZ(4)]], (2, 2), ZZ) + >>> A.to_dok() + {(0, 0): 1, (1, 1): 4} + + The matrix can be reconstructed by calling :meth:`from_dok` although + the reconstructed matrix will always be in sparse format: + + >>> A.to_sparse() == A.from_dok(A.to_dok(), A.shape, A.domain) + True + + See Also + ======== + + from_dok + to_list + to_list_flat + to_flat_nz + """ + return self.rep.to_dok() + + @classmethod + def from_dok(cls, dok, shape, domain): + """ + Create :class:`DomainMatrix` from dictionary of keys (dok) format. + + See :meth:`to_dok` for explanation. + + See Also + ======== + + to_dok + """ + return cls.from_rep(SDM.from_dok(dok, shape, domain)) + + def iter_values(self): + """ + Iterate over nonzero elements of the matrix. + + Examples + ======== + + >>> from sympy import ZZ + >>> from sympy.polys.matrices import DomainMatrix + >>> A = DomainMatrix([[ZZ(1), ZZ(0)], [ZZ(3), ZZ(4)]], (2, 2), ZZ) + >>> list(A.iter_values()) + [1, 3, 4] + + See Also + ======== + + iter_items + to_list_flat + sympy.matrices.matrixbase.MatrixBase.iter_values + """ + return self.rep.iter_values() + + def iter_items(self): + """ + Iterate over indices and values of nonzero elements of the matrix. + + Examples + ======== + + >>> from sympy import ZZ + >>> from sympy.polys.matrices import DomainMatrix + >>> A = DomainMatrix([[ZZ(1), ZZ(0)], [ZZ(3), ZZ(4)]], (2, 2), ZZ) + >>> list(A.iter_items()) + [((0, 0), 1), ((1, 0), 3), ((1, 1), 4)] + + See Also + ======== + + iter_values + to_dok + sympy.matrices.matrixbase.MatrixBase.iter_items + """ + return self.rep.iter_items() + + def nnz(self): + """ + Number of nonzero elements in the matrix. + + Examples + ======== + + >>> from sympy import ZZ + >>> from sympy.polys.matrices import DM + >>> A = DM([[1, 0], [0, 4]], ZZ) + >>> A.nnz() + 2 + """ + return self.rep.nnz() + + def __repr__(self): + return 'DomainMatrix(%s, %r, %r)' % (str(self.rep), self.shape, self.domain) + + def transpose(self): + """Matrix transpose of ``self``""" + return self.from_rep(self.rep.transpose()) + + def flat(self): + rows, cols = self.shape + return [self[i,j].element for i in range(rows) for j in range(cols)] + + @property + def is_zero_matrix(self): + return self.rep.is_zero_matrix() + + @property + def is_upper(self): + """ + Says whether this matrix is upper-triangular. True can be returned + even if the matrix is not square. + """ + return self.rep.is_upper() + + @property + def is_lower(self): + """ + Says whether this matrix is lower-triangular. True can be returned + even if the matrix is not square. + """ + return self.rep.is_lower() + + @property + def is_diagonal(self): + """ + True if the matrix is diagonal. + + Can return true for non-square matrices. A matrix is diagonal if + ``M[i,j] == 0`` whenever ``i != j``. + + Examples + ======== + + >>> from sympy import ZZ + >>> from sympy.polys.matrices import DM + >>> M = DM([[ZZ(1), ZZ(0)], [ZZ(0), ZZ(1)]], ZZ) + >>> M.is_diagonal + True + + See Also + ======== + + is_upper + is_lower + is_square + diagonal + """ + return self.rep.is_diagonal() + + def diagonal(self): + """ + Get the diagonal entries of the matrix as a list. + + Examples + ======== + + >>> from sympy import ZZ + >>> from sympy.polys.matrices import DM + >>> M = DM([[ZZ(1), ZZ(2)], [ZZ(3), ZZ(4)]], ZZ) + >>> M.diagonal() + [1, 4] + + See Also + ======== + + is_diagonal + diag + """ + return self.rep.diagonal() + + @property + def is_square(self): + """ + True if the matrix is square. + """ + return self.shape[0] == self.shape[1] + + def rank(self): + rref, pivots = self.rref() + return len(pivots) + + def hstack(A, *B): + r"""Horizontally stack the given matrices. + + Parameters + ========== + + B: DomainMatrix + Matrices to stack horizontally. + + Returns + ======= + + DomainMatrix + DomainMatrix by stacking horizontally. + + Examples + ======== + + >>> from sympy import ZZ + >>> from sympy.polys.matrices import DomainMatrix + + >>> A = DomainMatrix([[ZZ(1), ZZ(2)], [ZZ(3), ZZ(4)]], (2, 2), ZZ) + >>> B = DomainMatrix([[ZZ(5), ZZ(6)], [ZZ(7), ZZ(8)]], (2, 2), ZZ) + >>> A.hstack(B) + DomainMatrix([[1, 2, 5, 6], [3, 4, 7, 8]], (2, 4), ZZ) + + >>> C = DomainMatrix([[ZZ(9), ZZ(10)], [ZZ(11), ZZ(12)]], (2, 2), ZZ) + >>> A.hstack(B, C) + DomainMatrix([[1, 2, 5, 6, 9, 10], [3, 4, 7, 8, 11, 12]], (2, 6), ZZ) + + See Also + ======== + + unify + """ + A, *B = A.unify(*B, fmt=A.rep.fmt) + return DomainMatrix.from_rep(A.rep.hstack(*(Bk.rep for Bk in B))) + + def vstack(A, *B): + r"""Vertically stack the given matrices. + + Parameters + ========== + + B: DomainMatrix + Matrices to stack vertically. + + Returns + ======= + + DomainMatrix + DomainMatrix by stacking vertically. + + Examples + ======== + + >>> from sympy import ZZ + >>> from sympy.polys.matrices import DomainMatrix + + >>> A = DomainMatrix([[ZZ(1), ZZ(2)], [ZZ(3), ZZ(4)]], (2, 2), ZZ) + >>> B = DomainMatrix([[ZZ(5), ZZ(6)], [ZZ(7), ZZ(8)]], (2, 2), ZZ) + >>> A.vstack(B) + DomainMatrix([[1, 2], [3, 4], [5, 6], [7, 8]], (4, 2), ZZ) + + >>> C = DomainMatrix([[ZZ(9), ZZ(10)], [ZZ(11), ZZ(12)]], (2, 2), ZZ) + >>> A.vstack(B, C) + DomainMatrix([[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12]], (6, 2), ZZ) + + See Also + ======== + + unify + """ + A, *B = A.unify(*B, fmt='dense') + return DomainMatrix.from_rep(A.rep.vstack(*(Bk.rep for Bk in B))) + + def applyfunc(self, func, domain=None): + if domain is None: + domain = self.domain + return self.from_rep(self.rep.applyfunc(func, domain)) + + def __add__(A, B): + if not isinstance(B, DomainMatrix): + return NotImplemented + A, B = A.unify(B, fmt='dense') + return A.add(B) + + def __sub__(A, B): + if not isinstance(B, DomainMatrix): + return NotImplemented + A, B = A.unify(B, fmt='dense') + return A.sub(B) + + def __neg__(A): + return A.neg() + + def __mul__(A, B): + """A * B""" + if isinstance(B, DomainMatrix): + A, B = A.unify(B, fmt='dense') + return A.matmul(B) + elif B in A.domain: + return A.scalarmul(B) + elif isinstance(B, DomainScalar): + A, B = A.unify(B) + return A.scalarmul(B.element) + else: + return NotImplemented + + def __rmul__(A, B): + if B in A.domain: + return A.rscalarmul(B) + elif isinstance(B, DomainScalar): + A, B = A.unify(B) + return A.rscalarmul(B.element) + else: + return NotImplemented + + def __pow__(A, n): + """A ** n""" + if not isinstance(n, int): + return NotImplemented + return A.pow(n) + + def _check(a, op, b, ashape, bshape): + if a.domain != b.domain: + msg = "Domain mismatch: %s %s %s" % (a.domain, op, b.domain) + raise DMDomainError(msg) + if ashape != bshape: + msg = "Shape mismatch: %s %s %s" % (a.shape, op, b.shape) + raise DMShapeError(msg) + if a.rep.fmt != b.rep.fmt: + msg = "Format mismatch: %s %s %s" % (a.rep.fmt, op, b.rep.fmt) + raise DMFormatError(msg) + if type(a.rep) != type(b.rep): + msg = "Type mismatch: %s %s %s" % (type(a.rep), op, type(b.rep)) + raise DMFormatError(msg) + + def add(A, B): + r""" + Adds two DomainMatrix matrices of the same Domain + + Parameters + ========== + + A, B: DomainMatrix + matrices to add + + Returns + ======= + + DomainMatrix + DomainMatrix after Addition + + Raises + ====== + + DMShapeError + If the dimensions of the two DomainMatrix are not equal + + ValueError + If the domain of the two DomainMatrix are not same + + Examples + ======== + + >>> from sympy import ZZ + >>> from sympy.polys.matrices import DomainMatrix + >>> A = DomainMatrix([ + ... [ZZ(1), ZZ(2)], + ... [ZZ(3), ZZ(4)]], (2, 2), ZZ) + >>> B = DomainMatrix([ + ... [ZZ(4), ZZ(3)], + ... [ZZ(2), ZZ(1)]], (2, 2), ZZ) + + >>> A.add(B) + DomainMatrix([[5, 5], [5, 5]], (2, 2), ZZ) + + See Also + ======== + + sub, matmul + + """ + A._check('+', B, A.shape, B.shape) + return A.from_rep(A.rep.add(B.rep)) + + + def sub(A, B): + r""" + Subtracts two DomainMatrix matrices of the same Domain + + Parameters + ========== + + A, B: DomainMatrix + matrices to subtract + + Returns + ======= + + DomainMatrix + DomainMatrix after Subtraction + + Raises + ====== + + DMShapeError + If the dimensions of the two DomainMatrix are not equal + + ValueError + If the domain of the two DomainMatrix are not same + + Examples + ======== + + >>> from sympy import ZZ + >>> from sympy.polys.matrices import DomainMatrix + >>> A = DomainMatrix([ + ... [ZZ(1), ZZ(2)], + ... [ZZ(3), ZZ(4)]], (2, 2), ZZ) + >>> B = DomainMatrix([ + ... [ZZ(4), ZZ(3)], + ... [ZZ(2), ZZ(1)]], (2, 2), ZZ) + + >>> A.sub(B) + DomainMatrix([[-3, -1], [1, 3]], (2, 2), ZZ) + + See Also + ======== + + add, matmul + + """ + A._check('-', B, A.shape, B.shape) + return A.from_rep(A.rep.sub(B.rep)) + + def neg(A): + r""" + Returns the negative of DomainMatrix + + Parameters + ========== + + A : Represents a DomainMatrix + + Returns + ======= + + DomainMatrix + DomainMatrix after Negation + + Examples + ======== + + >>> from sympy import ZZ + >>> from sympy.polys.matrices import DomainMatrix + >>> A = DomainMatrix([ + ... [ZZ(1), ZZ(2)], + ... [ZZ(3), ZZ(4)]], (2, 2), ZZ) + + >>> A.neg() + DomainMatrix([[-1, -2], [-3, -4]], (2, 2), ZZ) + + """ + return A.from_rep(A.rep.neg()) + + def mul(A, b): + r""" + Performs term by term multiplication for the second DomainMatrix + w.r.t first DomainMatrix. Returns a DomainMatrix whose rows are + list of DomainMatrix matrices created after term by term multiplication. + + Parameters + ========== + + A, B: DomainMatrix + matrices to multiply term-wise + + Returns + ======= + + DomainMatrix + DomainMatrix after term by term multiplication + + Examples + ======== + + >>> from sympy import ZZ + >>> from sympy.polys.matrices import DomainMatrix + >>> A = DomainMatrix([ + ... [ZZ(1), ZZ(2)], + ... [ZZ(3), ZZ(4)]], (2, 2), ZZ) + >>> b = ZZ(2) + + >>> A.mul(b) + DomainMatrix([[2, 4], [6, 8]], (2, 2), ZZ) + + See Also + ======== + + matmul + + """ + return A.from_rep(A.rep.mul(b)) + + def rmul(A, b): + return A.from_rep(A.rep.rmul(b)) + + def matmul(A, B): + r""" + Performs matrix multiplication of two DomainMatrix matrices + + Parameters + ========== + + A, B: DomainMatrix + to multiply + + Returns + ======= + + DomainMatrix + DomainMatrix after multiplication + + Examples + ======== + + >>> from sympy import ZZ + >>> from sympy.polys.matrices import DomainMatrix + >>> A = DomainMatrix([ + ... [ZZ(1), ZZ(2)], + ... [ZZ(3), ZZ(4)]], (2, 2), ZZ) + >>> B = DomainMatrix([ + ... [ZZ(1), ZZ(1)], + ... [ZZ(0), ZZ(1)]], (2, 2), ZZ) + + >>> A.matmul(B) + DomainMatrix([[1, 3], [3, 7]], (2, 2), ZZ) + + See Also + ======== + + mul, pow, add, sub + + """ + + A._check('*', B, A.shape[1], B.shape[0]) + return A.from_rep(A.rep.matmul(B.rep)) + + def _scalarmul(A, lamda, reverse): + if lamda == A.domain.zero: + return DomainMatrix.zeros(A.shape, A.domain) + elif lamda == A.domain.one: + return A.copy() + elif reverse: + return A.rmul(lamda) + else: + return A.mul(lamda) + + def scalarmul(A, lamda): + return A._scalarmul(lamda, reverse=False) + + def rscalarmul(A, lamda): + return A._scalarmul(lamda, reverse=True) + + def mul_elementwise(A, B): + assert A.domain == B.domain + return A.from_rep(A.rep.mul_elementwise(B.rep)) + + def __truediv__(A, lamda): + """ Method for Scalar Division""" + if isinstance(lamda, int) or ZZ.of_type(lamda): + lamda = DomainScalar(ZZ(lamda), ZZ) + elif A.domain.is_Field and lamda in A.domain: + K = A.domain + lamda = DomainScalar(K.convert(lamda), K) + + if not isinstance(lamda, DomainScalar): + return NotImplemented + + A, lamda = A.to_field().unify(lamda) + if lamda.element == lamda.domain.zero: + raise ZeroDivisionError + if lamda.element == lamda.domain.one: + return A + + return A.mul(1 / lamda.element) + + def pow(A, n): + r""" + Computes A**n + + Parameters + ========== + + A : DomainMatrix + + n : exponent for A + + Returns + ======= + + DomainMatrix + DomainMatrix on computing A**n + + Raises + ====== + + NotImplementedError + if n is negative. + + Examples + ======== + + >>> from sympy import ZZ + >>> from sympy.polys.matrices import DomainMatrix + >>> A = DomainMatrix([ + ... [ZZ(1), ZZ(1)], + ... [ZZ(0), ZZ(1)]], (2, 2), ZZ) + + >>> A.pow(2) + DomainMatrix([[1, 2], [0, 1]], (2, 2), ZZ) + + See Also + ======== + + matmul + + """ + nrows, ncols = A.shape + if nrows != ncols: + raise DMNonSquareMatrixError('Power of a nonsquare matrix') + if n < 0: + raise NotImplementedError('Negative powers') + elif n == 0: + return A.eye(nrows, A.domain) + elif n == 1: + return A + elif n % 2 == 1: + return A * A**(n - 1) + else: + sqrtAn = A ** (n // 2) + return sqrtAn * sqrtAn + + def scc(self): + """Compute the strongly connected components of a DomainMatrix + + Explanation + =========== + + A square matrix can be considered as the adjacency matrix for a + directed graph where the row and column indices are the vertices. In + this graph if there is an edge from vertex ``i`` to vertex ``j`` if + ``M[i, j]`` is nonzero. This routine computes the strongly connected + components of that graph which are subsets of the rows and columns that + are connected by some nonzero element of the matrix. The strongly + connected components are useful because many operations such as the + determinant can be computed by working with the submatrices + corresponding to each component. + + Examples + ======== + + Find the strongly connected components of a matrix: + + >>> from sympy import ZZ + >>> from sympy.polys.matrices import DomainMatrix + >>> M = DomainMatrix([[ZZ(1), ZZ(0), ZZ(2)], + ... [ZZ(0), ZZ(3), ZZ(0)], + ... [ZZ(4), ZZ(6), ZZ(5)]], (3, 3), ZZ) + >>> M.scc() + [[1], [0, 2]] + + Compute the determinant from the components: + + >>> MM = M.to_Matrix() + >>> MM + Matrix([ + [1, 0, 2], + [0, 3, 0], + [4, 6, 5]]) + >>> MM[[1], [1]] + Matrix([[3]]) + >>> MM[[0, 2], [0, 2]] + Matrix([ + [1, 2], + [4, 5]]) + >>> MM.det() + -9 + >>> MM[[1], [1]].det() * MM[[0, 2], [0, 2]].det() + -9 + + The components are given in reverse topological order and represent a + permutation of the rows and columns that will bring the matrix into + block lower-triangular form: + + >>> MM[[1, 0, 2], [1, 0, 2]] + Matrix([ + [3, 0, 0], + [0, 1, 2], + [6, 4, 5]]) + + Returns + ======= + + List of lists of integers + Each list represents a strongly connected component. + + See also + ======== + + sympy.matrices.matrixbase.MatrixBase.strongly_connected_components + sympy.utilities.iterables.strongly_connected_components + + """ + if not self.is_square: + raise DMNonSquareMatrixError('Matrix must be square for scc') + + return self.rep.scc() + + def clear_denoms(self, convert=False): + """ + Clear denominators, but keep the domain unchanged. + + Examples + ======== + + >>> from sympy import QQ + >>> from sympy.polys.matrices import DM + >>> A = DM([[(1,2), (1,3)], [(1,4), (1,5)]], QQ) + >>> den, Anum = A.clear_denoms() + >>> den.to_sympy() + 60 + >>> Anum.to_Matrix() + Matrix([ + [30, 20], + [15, 12]]) + >>> den * A == Anum + True + + The numerator matrix will be in the same domain as the original matrix + unless ``convert`` is set to ``True``: + + >>> A.clear_denoms()[1].domain + QQ + >>> A.clear_denoms(convert=True)[1].domain + ZZ + + The denominator is always in the associated ring: + + >>> A.clear_denoms()[0].domain + ZZ + >>> A.domain.get_ring() + ZZ + + See Also + ======== + + sympy.polys.polytools.Poly.clear_denoms + clear_denoms_rowwise + """ + elems0, data = self.to_flat_nz() + + K0 = self.domain + K1 = K0.get_ring() if K0.has_assoc_Ring else K0 + + den, elems1 = dup_clear_denoms(elems0, K0, K1, convert=convert) + + if convert: + Kden, Knum = K1, K1 + else: + Kden, Knum = K1, K0 + + den = DomainScalar(den, Kden) + num = self.from_flat_nz(elems1, data, Knum) + + return den, num + + def clear_denoms_rowwise(self, convert=False): + """ + Clear denominators from each row of the matrix. + + Examples + ======== + + >>> from sympy import QQ + >>> from sympy.polys.matrices import DM + >>> A = DM([[(1,2), (1,3), (1,4)], [(1,5), (1,6), (1,7)]], QQ) + >>> den, Anum = A.clear_denoms_rowwise() + >>> den.to_Matrix() + Matrix([ + [12, 0], + [ 0, 210]]) + >>> Anum.to_Matrix() + Matrix([ + [ 6, 4, 3], + [42, 35, 30]]) + + The denominator matrix is a diagonal matrix with the denominators of + each row on the diagonal. The invariants are: + + >>> den * A == Anum + True + >>> A == den.to_field().inv() * Anum + True + + The numerator matrix will be in the same domain as the original matrix + unless ``convert`` is set to ``True``: + + >>> A.clear_denoms_rowwise()[1].domain + QQ + >>> A.clear_denoms_rowwise(convert=True)[1].domain + ZZ + + The domain of the denominator matrix is the associated ring: + + >>> A.clear_denoms_rowwise()[0].domain + ZZ + + See Also + ======== + + sympy.polys.polytools.Poly.clear_denoms + clear_denoms + """ + dod = self.to_dod() + + K0 = self.domain + K1 = K0.get_ring() if K0.has_assoc_Ring else K0 + + diagonals = [K0.one] * self.shape[0] + dod_num = {} + for i, rowi in dod.items(): + indices, elems = zip(*rowi.items()) + den, elems_num = dup_clear_denoms(elems, K0, K1, convert=convert) + rowi_num = dict(zip(indices, elems_num)) + diagonals[i] = den + dod_num[i] = rowi_num + + if convert: + Kden, Knum = K1, K1 + else: + Kden, Knum = K1, K0 + + den = self.diag(diagonals, Kden) + num = self.from_dod_like(dod_num, Knum) + + return den, num + + def cancel_denom(self, denom): + """ + Cancel factors between a matrix and a denominator. + + Returns a matrix and denominator on lowest terms. + + Requires ``gcd`` in the ground domain. + + Methods like :meth:`solve_den`, :meth:`inv_den` and :meth:`rref_den` + return a matrix and denominator but not necessarily on lowest terms. + Reduction to lowest terms without fractions can be performed with + :meth:`cancel_denom`. + + Examples + ======== + + >>> from sympy.polys.matrices import DM + >>> from sympy import ZZ + >>> M = DM([[2, 2, 0], + ... [0, 2, 2], + ... [0, 0, 2]], ZZ) + >>> Minv, den = M.inv_den() + >>> Minv.to_Matrix() + Matrix([ + [1, -1, 1], + [0, 1, -1], + [0, 0, 1]]) + >>> den + 2 + >>> Minv_reduced, den_reduced = Minv.cancel_denom(den) + >>> Minv_reduced.to_Matrix() + Matrix([ + [1, -1, 1], + [0, 1, -1], + [0, 0, 1]]) + >>> den_reduced + 2 + >>> Minv_reduced.to_field() / den_reduced == Minv.to_field() / den + True + + The denominator is made canonical with respect to units (e.g. a + negative denominator is made positive): + + >>> M = DM([[2, 2, 0]], ZZ) + >>> den = ZZ(-4) + >>> M.cancel_denom(den) + (DomainMatrix([[-1, -1, 0]], (1, 3), ZZ), 2) + + Any factor common to _all_ elements will be cancelled but there can + still be factors in common between _some_ elements of the matrix and + the denominator. To cancel factors between each element and the + denominator, use :meth:`cancel_denom_elementwise` or otherwise convert + to a field and use division: + + >>> M = DM([[4, 6]], ZZ) + >>> den = ZZ(12) + >>> M.cancel_denom(den) + (DomainMatrix([[2, 3]], (1, 2), ZZ), 6) + >>> numers, denoms = M.cancel_denom_elementwise(den) + >>> numers + DomainMatrix([[1, 1]], (1, 2), ZZ) + >>> denoms + DomainMatrix([[3, 2]], (1, 2), ZZ) + >>> M.to_field() / den + DomainMatrix([[1/3, 1/2]], (1, 2), QQ) + + See Also + ======== + + solve_den + inv_den + rref_den + cancel_denom_elementwise + """ + M = self + K = self.domain + + if K.is_zero(denom): + raise ZeroDivisionError('denominator is zero') + elif K.is_one(denom): + return (M.copy(), denom) + + elements, data = M.to_flat_nz() + + # First canonicalize the denominator (e.g. multiply by -1). + if K.is_negative(denom): + u = -K.one + else: + u = K.canonical_unit(denom) + + # Often after e.g. solve_den the denominator will be much more + # complicated than the elements of the numerator. Hopefully it will be + # quicker to find the gcd of the numerator and if there is no content + # then we do not need to look at the denominator at all. + content = dup_content(elements, K) + common = K.gcd(content, denom) + + if not K.is_one(content): + + common = K.gcd(content, denom) + + if not K.is_one(common): + elements = dup_quo_ground(elements, common, K) + denom = K.quo(denom, common) + + if not K.is_one(u): + elements = dup_mul_ground(elements, u, K) + denom = u * denom + elif K.is_one(common): + return (M.copy(), denom) + + M_cancelled = M.from_flat_nz(elements, data, K) + + return M_cancelled, denom + + def cancel_denom_elementwise(self, denom): + """ + Cancel factors between the elements of a matrix and a denominator. + + Returns a matrix of numerators and matrix of denominators. + + Requires ``gcd`` in the ground domain. + + Examples + ======== + + >>> from sympy.polys.matrices import DM + >>> from sympy import ZZ + >>> M = DM([[2, 3], [4, 12]], ZZ) + >>> denom = ZZ(6) + >>> numers, denoms = M.cancel_denom_elementwise(denom) + >>> numers.to_Matrix() + Matrix([ + [1, 1], + [2, 2]]) + >>> denoms.to_Matrix() + Matrix([ + [3, 2], + [3, 1]]) + >>> M_frac = (M.to_field() / denom).to_Matrix() + >>> M_frac + Matrix([ + [1/3, 1/2], + [2/3, 2]]) + >>> denoms_inverted = denoms.to_Matrix().applyfunc(lambda e: 1/e) + >>> numers.to_Matrix().multiply_elementwise(denoms_inverted) == M_frac + True + + Use :meth:`cancel_denom` to cancel factors between the matrix and the + denominator while preserving the form of a matrix with a scalar + denominator. + + See Also + ======== + + cancel_denom + """ + K = self.domain + M = self + + if K.is_zero(denom): + raise ZeroDivisionError('denominator is zero') + elif K.is_one(denom): + M_numers = M.copy() + M_denoms = M.ones(M.shape, M.domain) + return (M_numers, M_denoms) + + elements, data = M.to_flat_nz() + + cofactors = [K.cofactors(numer, denom) for numer in elements] + gcds, numers, denoms = zip(*cofactors) + + M_numers = M.from_flat_nz(list(numers), data, K) + M_denoms = M.from_flat_nz(list(denoms), data, K) + + return (M_numers, M_denoms) + + def content(self): + """ + Return the gcd of the elements of the matrix. + + Requires ``gcd`` in the ground domain. + + Examples + ======== + + >>> from sympy.polys.matrices import DM + >>> from sympy import ZZ + >>> M = DM([[2, 4], [4, 12]], ZZ) + >>> M.content() + 2 + + See Also + ======== + + primitive + cancel_denom + """ + K = self.domain + elements, _ = self.to_flat_nz() + return dup_content(elements, K) + + def primitive(self): + """ + Factor out gcd of the elements of a matrix. + + Requires ``gcd`` in the ground domain. + + Examples + ======== + + >>> from sympy.polys.matrices import DM + >>> from sympy import ZZ + >>> M = DM([[2, 4], [4, 12]], ZZ) + >>> content, M_primitive = M.primitive() + >>> content + 2 + >>> M_primitive + DomainMatrix([[1, 2], [2, 6]], (2, 2), ZZ) + >>> content * M_primitive == M + True + >>> M_primitive.content() == ZZ(1) + True + + See Also + ======== + + content + cancel_denom + """ + K = self.domain + elements, data = self.to_flat_nz() + content, prims = dup_primitive(elements, K) + M_primitive = self.from_flat_nz(prims, data, K) + return content, M_primitive + + def rref(self, *, method='auto'): + r""" + Returns reduced-row echelon form (RREF) and list of pivots. + + If the domain is not a field then it will be converted to a field. See + :meth:`rref_den` for the fraction-free version of this routine that + returns RREF with denominator instead. + + The domain must either be a field or have an associated fraction field + (see :meth:`to_field`). + + Examples + ======== + + >>> from sympy import QQ + >>> from sympy.polys.matrices import DomainMatrix + >>> A = DomainMatrix([ + ... [QQ(2), QQ(-1), QQ(0)], + ... [QQ(-1), QQ(2), QQ(-1)], + ... [QQ(0), QQ(0), QQ(2)]], (3, 3), QQ) + + >>> rref_matrix, rref_pivots = A.rref() + >>> rref_matrix + DomainMatrix([[1, 0, 0], [0, 1, 0], [0, 0, 1]], (3, 3), QQ) + >>> rref_pivots + (0, 1, 2) + + Parameters + ========== + + method : str, optional (default: 'auto') + The method to use to compute the RREF. The default is ``'auto'``, + which will attempt to choose the fastest method. The other options + are: + + - ``A.rref(method='GJ')`` uses Gauss-Jordan elimination with + division. If the domain is not a field then it will be converted + to a field with :meth:`to_field` first and RREF will be computed + by inverting the pivot elements in each row. This is most + efficient for very sparse matrices or for matrices whose elements + have complex denominators. + + - ``A.rref(method='FF')`` uses fraction-free Gauss-Jordan + elimination. Elimination is performed using exact division + (``exquo``) to control the growth of the coefficients. In this + case the current domain is always used for elimination but if + the domain is not a field then it will be converted to a field + at the end and divided by the denominator. This is most efficient + for dense matrices or for matrices with simple denominators. + + - ``A.rref(method='CD')`` clears the denominators before using + fraction-free Gauss-Jordan elimination in the assoicated ring. + This is most efficient for dense matrices with very simple + denominators. + + - ``A.rref(method='GJ_dense')``, ``A.rref(method='FF_dense')``, and + ``A.rref(method='CD_dense')`` are the same as the above methods + except that the dense implementations of the algorithms are used. + By default ``A.rref(method='auto')`` will usually choose the + sparse implementations for RREF. + + Regardless of which algorithm is used the returned matrix will + always have the same format (sparse or dense) as the input and its + domain will always be the field of fractions of the input domain. + + Returns + ======= + + (DomainMatrix, list) + reduced-row echelon form and list of pivots for the DomainMatrix + + See Also + ======== + + rref_den + RREF with denominator + sympy.polys.matrices.sdm.sdm_irref + Sparse implementation of ``method='GJ'``. + sympy.polys.matrices.sdm.sdm_rref_den + Sparse implementation of ``method='FF'`` and ``method='CD'``. + sympy.polys.matrices.dense.ddm_irref + Dense implementation of ``method='GJ'``. + sympy.polys.matrices.dense.ddm_irref_den + Dense implementation of ``method='FF'`` and ``method='CD'``. + clear_denoms + Clear denominators from a matrix, used by ``method='CD'`` and + by ``method='GJ'`` when the original domain is not a field. + + """ + return _dm_rref(self, method=method) + + def rref_den(self, *, method='auto', keep_domain=True): + r""" + Returns reduced-row echelon form with denominator and list of pivots. + + Requires exact division in the ground domain (``exquo``). + + Examples + ======== + + >>> from sympy import ZZ, QQ + >>> from sympy.polys.matrices import DomainMatrix + >>> A = DomainMatrix([ + ... [ZZ(2), ZZ(-1), ZZ(0)], + ... [ZZ(-1), ZZ(2), ZZ(-1)], + ... [ZZ(0), ZZ(0), ZZ(2)]], (3, 3), ZZ) + + >>> A_rref, denom, pivots = A.rref_den() + >>> A_rref + DomainMatrix([[6, 0, 0], [0, 6, 0], [0, 0, 6]], (3, 3), ZZ) + >>> denom + 6 + >>> pivots + (0, 1, 2) + >>> A_rref.to_field() / denom + DomainMatrix([[1, 0, 0], [0, 1, 0], [0, 0, 1]], (3, 3), QQ) + >>> A_rref.to_field() / denom == A.convert_to(QQ).rref()[0] + True + + Parameters + ========== + + method : str, optional (default: 'auto') + The method to use to compute the RREF. The default is ``'auto'``, + which will attempt to choose the fastest method. The other options + are: + + - ``A.rref(method='FF')`` uses fraction-free Gauss-Jordan + elimination. Elimination is performed using exact division + (``exquo``) to control the growth of the coefficients. In this + case the current domain is always used for elimination and the + result is always returned as a matrix over the current domain. + This is most efficient for dense matrices or for matrices with + simple denominators. + + - ``A.rref(method='CD')`` clears denominators before using + fraction-free Gauss-Jordan elimination in the assoicated ring. + The result will be converted back to the original domain unless + ``keep_domain=False`` is passed in which case the result will be + over the ring used for elimination. This is most efficient for + dense matrices with very simple denominators. + + - ``A.rref(method='GJ')`` uses Gauss-Jordan elimination with + division. If the domain is not a field then it will be converted + to a field with :meth:`to_field` first and RREF will be computed + by inverting the pivot elements in each row. The result is + converted back to the original domain by clearing denominators + unless ``keep_domain=False`` is passed in which case the result + will be over the field used for elimination. This is most + efficient for very sparse matrices or for matrices whose elements + have complex denominators. + + - ``A.rref(method='GJ_dense')``, ``A.rref(method='FF_dense')``, and + ``A.rref(method='CD_dense')`` are the same as the above methods + except that the dense implementations of the algorithms are used. + By default ``A.rref(method='auto')`` will usually choose the + sparse implementations for RREF. + + Regardless of which algorithm is used the returned matrix will + always have the same format (sparse or dense) as the input and if + ``keep_domain=True`` its domain will always be the same as the + input. + + keep_domain : bool, optional + If True (the default), the domain of the returned matrix and + denominator are the same as the domain of the input matrix. If + False, the domain of the returned matrix might be changed to an + associated ring or field if the algorithm used a different domain. + This is useful for efficiency if the caller does not need the + result to be in the original domain e.g. it avoids clearing + denominators in the case of ``A.rref(method='GJ')``. + + Returns + ======= + + (DomainMatrix, scalar, list) + Reduced-row echelon form, denominator and list of pivot indices. + + See Also + ======== + + rref + RREF without denominator for field domains. + sympy.polys.matrices.sdm.sdm_irref + Sparse implementation of ``method='GJ'``. + sympy.polys.matrices.sdm.sdm_rref_den + Sparse implementation of ``method='FF'`` and ``method='CD'``. + sympy.polys.matrices.dense.ddm_irref + Dense implementation of ``method='GJ'``. + sympy.polys.matrices.dense.ddm_irref_den + Dense implementation of ``method='FF'`` and ``method='CD'``. + clear_denoms + Clear denominators from a matrix, used by ``method='CD'``. + + """ + return _dm_rref_den(self, method=method, keep_domain=keep_domain) + + def columnspace(self): + r""" + Returns the columnspace for the DomainMatrix + + Returns + ======= + + DomainMatrix + The columns of this matrix form a basis for the columnspace. + + Examples + ======== + + >>> from sympy import QQ + >>> from sympy.polys.matrices import DomainMatrix + >>> A = DomainMatrix([ + ... [QQ(1), QQ(-1)], + ... [QQ(2), QQ(-2)]], (2, 2), QQ) + >>> A.columnspace() + DomainMatrix([[1], [2]], (2, 1), QQ) + + """ + if not self.domain.is_Field: + raise DMNotAField('Not a field') + rref, pivots = self.rref() + rows, cols = self.shape + return self.extract(range(rows), pivots) + + def rowspace(self): + r""" + Returns the rowspace for the DomainMatrix + + Returns + ======= + + DomainMatrix + The rows of this matrix form a basis for the rowspace. + + Examples + ======== + + >>> from sympy import QQ + >>> from sympy.polys.matrices import DomainMatrix + >>> A = DomainMatrix([ + ... [QQ(1), QQ(-1)], + ... [QQ(2), QQ(-2)]], (2, 2), QQ) + >>> A.rowspace() + DomainMatrix([[1, -1]], (1, 2), QQ) + + """ + if not self.domain.is_Field: + raise DMNotAField('Not a field') + rref, pivots = self.rref() + rows, cols = self.shape + return self.extract(range(len(pivots)), range(cols)) + + def nullspace(self, divide_last=False): + r""" + Returns the nullspace for the DomainMatrix + + Returns + ======= + + DomainMatrix + The rows of this matrix form a basis for the nullspace. + + Examples + ======== + + >>> from sympy import QQ + >>> from sympy.polys.matrices import DM + >>> A = DM([ + ... [QQ(2), QQ(-2)], + ... [QQ(4), QQ(-4)]], QQ) + >>> A.nullspace() + DomainMatrix([[1, 1]], (1, 2), QQ) + + The returned matrix is a basis for the nullspace: + + >>> A_null = A.nullspace().transpose() + >>> A * A_null + DomainMatrix([[0], [0]], (2, 1), QQ) + >>> rows, cols = A.shape + >>> nullity = rows - A.rank() + >>> A_null.shape == (cols, nullity) + True + + Nullspace can also be computed for non-field rings. If the ring is not + a field then division is not used. Setting ``divide_last`` to True will + raise an error in this case: + + >>> from sympy import ZZ + >>> B = DM([[6, -3], + ... [4, -2]], ZZ) + >>> B.nullspace() + DomainMatrix([[3, 6]], (1, 2), ZZ) + >>> B.nullspace(divide_last=True) + Traceback (most recent call last): + ... + DMNotAField: Cannot normalize vectors over a non-field + + Over a ring with ``gcd`` defined the nullspace can potentially be + reduced with :meth:`primitive`: + + >>> B.nullspace().primitive() + (3, DomainMatrix([[1, 2]], (1, 2), ZZ)) + + A matrix over a ring can often be normalized by converting it to a + field but it is often a bad idea to do so: + + >>> from sympy.abc import a, b, c + >>> from sympy import Matrix + >>> M = Matrix([[ a*b, b + c, c], + ... [ a - b, b*c, c**2], + ... [a*b + a - b, b*c + b + c, c**2 + c]]) + >>> M.to_DM().domain + ZZ[a,b,c] + >>> M.to_DM().nullspace().to_Matrix().transpose() + Matrix([ + [ c**3], + [ -a*b*c**2 + a*c - b*c], + [a*b**2*c - a*b - a*c + b**2 + b*c]]) + + The unnormalized form here is nicer than the normalized form that + spreads a large denominator throughout the matrix: + + >>> M.to_DM().to_field().nullspace(divide_last=True).to_Matrix().transpose() + Matrix([ + [ c**3/(a*b**2*c - a*b - a*c + b**2 + b*c)], + [(-a*b*c**2 + a*c - b*c)/(a*b**2*c - a*b - a*c + b**2 + b*c)], + [ 1]]) + + Parameters + ========== + + divide_last : bool, optional + If False (the default), the vectors are not normalized and the RREF + is computed using :meth:`rref_den` and the denominator is + discarded. If True, then each row is divided by its final element; + the domain must be a field in this case. + + See Also + ======== + + nullspace_from_rref + rref + rref_den + rowspace + """ + A = self + K = A.domain + + if divide_last and not K.is_Field: + raise DMNotAField("Cannot normalize vectors over a non-field") + + if divide_last: + A_rref, pivots = A.rref() + else: + A_rref, den, pivots = A.rref_den() + + # Ensure that the sign is canonical before discarding the + # denominator. Then M.nullspace().primitive() is canonical. + u = K.canonical_unit(den) + if u != K.one: + A_rref *= u + + A_null = A_rref.nullspace_from_rref(pivots) + + return A_null + + def nullspace_from_rref(self, pivots=None): + """ + Compute nullspace from rref and pivots. + + The domain of the matrix can be any domain. + + The matrix must be in reduced row echelon form already. Otherwise the + result will be incorrect. Use :meth:`rref` or :meth:`rref_den` first + to get the reduced row echelon form or use :meth:`nullspace` instead. + + See Also + ======== + + nullspace + rref + rref_den + sympy.polys.matrices.sdm.SDM.nullspace_from_rref + sympy.polys.matrices.ddm.DDM.nullspace_from_rref + """ + null_rep, nonpivots = self.rep.nullspace_from_rref(pivots) + return self.from_rep(null_rep) + + def inv(self): + r""" + Finds the inverse of the DomainMatrix if exists + + Returns + ======= + + DomainMatrix + DomainMatrix after inverse + + Raises + ====== + + ValueError + If the domain of DomainMatrix not a Field + + DMNonSquareMatrixError + If the DomainMatrix is not a not Square DomainMatrix + + Examples + ======== + + >>> from sympy import QQ + >>> from sympy.polys.matrices import DomainMatrix + >>> A = DomainMatrix([ + ... [QQ(2), QQ(-1), QQ(0)], + ... [QQ(-1), QQ(2), QQ(-1)], + ... [QQ(0), QQ(0), QQ(2)]], (3, 3), QQ) + >>> A.inv() + DomainMatrix([[2/3, 1/3, 1/6], [1/3, 2/3, 1/3], [0, 0, 1/2]], (3, 3), QQ) + + See Also + ======== + + neg + + """ + if not self.domain.is_Field: + raise DMNotAField('Not a field') + m, n = self.shape + if m != n: + raise DMNonSquareMatrixError + inv = self.rep.inv() + return self.from_rep(inv) + + def det(self): + r""" + Returns the determinant of a square :class:`DomainMatrix`. + + Returns + ======= + + determinant: DomainElement + Determinant of the matrix. + + Raises + ====== + + ValueError + If the domain of DomainMatrix is not a Field + + Examples + ======== + + >>> from sympy import ZZ + >>> from sympy.polys.matrices import DomainMatrix + >>> A = DomainMatrix([ + ... [ZZ(1), ZZ(2)], + ... [ZZ(3), ZZ(4)]], (2, 2), ZZ) + + >>> A.det() + -2 + + """ + m, n = self.shape + if m != n: + raise DMNonSquareMatrixError + return self.rep.det() + + def adj_det(self): + """ + Adjugate and determinant of a square :class:`DomainMatrix`. + + Returns + ======= + + (adjugate, determinant) : (DomainMatrix, DomainScalar) + The adjugate matrix and determinant of this matrix. + + Examples + ======== + + >>> from sympy import ZZ + >>> from sympy.polys.matrices import DM + >>> A = DM([ + ... [ZZ(1), ZZ(2)], + ... [ZZ(3), ZZ(4)]], ZZ) + >>> adjA, detA = A.adj_det() + >>> adjA + DomainMatrix([[4, -2], [-3, 1]], (2, 2), ZZ) + >>> detA + -2 + + See Also + ======== + + adjugate + Returns only the adjugate matrix. + det + Returns only the determinant. + inv_den + Returns a matrix/denominator pair representing the inverse matrix + but perhaps differing from the adjugate and determinant by a common + factor. + """ + m, n = self.shape + I_m = self.eye((m, m), self.domain) + adjA, detA = self.solve_den_charpoly(I_m, check=False) + if self.rep.fmt == "dense": + adjA = adjA.to_dense() + return adjA, detA + + def adjugate(self): + """ + Adjugate of a square :class:`DomainMatrix`. + + The adjugate matrix is the transpose of the cofactor matrix and is + related to the inverse by:: + + adj(A) = det(A) * A.inv() + + Unlike the inverse matrix the adjugate matrix can be computed and + expressed without division or fractions in the ground domain. + + Examples + ======== + + >>> from sympy import ZZ + >>> from sympy.polys.matrices import DM + >>> A = DM([[ZZ(1), ZZ(2)], [ZZ(3), ZZ(4)]], ZZ) + >>> A.adjugate() + DomainMatrix([[4, -2], [-3, 1]], (2, 2), ZZ) + + Returns + ======= + + DomainMatrix + The adjugate matrix of this matrix with the same domain. + + See Also + ======== + + adj_det + """ + adjA, detA = self.adj_det() + return adjA + + def inv_den(self, method=None): + """ + Return the inverse as a :class:`DomainMatrix` with denominator. + + Returns + ======= + + (inv, den) : (:class:`DomainMatrix`, :class:`~.DomainElement`) + The inverse matrix and its denominator. + + This is more or less equivalent to :meth:`adj_det` except that ``inv`` + and ``den`` are not guaranteed to be the adjugate and inverse. The + ratio ``inv/den`` is equivalent to ``adj/det`` but some factors + might be cancelled between ``inv`` and ``den``. In simple cases this + might just be a minus sign so that ``(inv, den) == (-adj, -det)`` but + factors more complicated than ``-1`` can also be cancelled. + Cancellation is not guaranteed to be complete so ``inv`` and ``den`` + may not be on lowest terms. The denominator ``den`` will be zero if and + only if the determinant is zero. + + If the actual adjugate and determinant are needed, use :meth:`adj_det` + instead. If the intention is to compute the inverse matrix or solve a + system of equations then :meth:`inv_den` is more efficient. + + Examples + ======== + + >>> from sympy import ZZ + >>> from sympy.polys.matrices import DomainMatrix + >>> A = DomainMatrix([ + ... [ZZ(2), ZZ(-1), ZZ(0)], + ... [ZZ(-1), ZZ(2), ZZ(-1)], + ... [ZZ(0), ZZ(0), ZZ(2)]], (3, 3), ZZ) + >>> Ainv, den = A.inv_den() + >>> den + 6 + >>> Ainv + DomainMatrix([[4, 2, 1], [2, 4, 2], [0, 0, 3]], (3, 3), ZZ) + >>> A * Ainv == den * A.eye(A.shape, A.domain).to_dense() + True + + Parameters + ========== + + method : str, optional + The method to use to compute the inverse. Can be one of ``None``, + ``'rref'`` or ``'charpoly'``. If ``None`` then the method is + chosen automatically (see :meth:`solve_den` for details). + + See Also + ======== + + inv + det + adj_det + solve_den + """ + I = self.eye(self.shape, self.domain) + return self.solve_den(I, method=method) + + def solve_den(self, b, method=None): + """ + Solve matrix equation $Ax = b$ without fractions in the ground domain. + + Examples + ======== + + Solve a matrix equation over the integers: + + >>> from sympy import ZZ + >>> from sympy.polys.matrices import DM + >>> A = DM([[ZZ(1), ZZ(2)], [ZZ(3), ZZ(4)]], ZZ) + >>> b = DM([[ZZ(5)], [ZZ(6)]], ZZ) + >>> xnum, xden = A.solve_den(b) + >>> xden + -2 + >>> xnum + DomainMatrix([[8], [-9]], (2, 1), ZZ) + >>> A * xnum == xden * b + True + + Solve a matrix equation over a polynomial ring: + + >>> from sympy import ZZ + >>> from sympy.abc import x, y, z, a, b + >>> R = ZZ[x, y, z, a, b] + >>> M = DM([[x*y, x*z], [y*z, x*z]], R) + >>> b = DM([[a], [b]], R) + >>> M.to_Matrix() + Matrix([ + [x*y, x*z], + [y*z, x*z]]) + >>> b.to_Matrix() + Matrix([ + [a], + [b]]) + >>> xnum, xden = M.solve_den(b) + >>> xden + x**2*y*z - x*y*z**2 + >>> xnum.to_Matrix() + Matrix([ + [ a*x*z - b*x*z], + [-a*y*z + b*x*y]]) + >>> M * xnum == xden * b + True + + The solution can be expressed over a fraction field which will cancel + gcds between the denominator and the elements of the numerator: + + >>> xsol = xnum.to_field() / xden + >>> xsol.to_Matrix() + Matrix([ + [ (a - b)/(x*y - y*z)], + [(-a*z + b*x)/(x**2*z - x*z**2)]]) + >>> (M * xsol).to_Matrix() == b.to_Matrix() + True + + When solving a large system of equations this cancellation step might + be a lot slower than :func:`solve_den` itself. The solution can also be + expressed as a ``Matrix`` without attempting any polynomial + cancellation between the numerator and denominator giving a less + simplified result more quickly: + + >>> xsol_uncancelled = xnum.to_Matrix() / xnum.domain.to_sympy(xden) + >>> xsol_uncancelled + Matrix([ + [ (a*x*z - b*x*z)/(x**2*y*z - x*y*z**2)], + [(-a*y*z + b*x*y)/(x**2*y*z - x*y*z**2)]]) + >>> from sympy import cancel + >>> cancel(xsol_uncancelled) == xsol.to_Matrix() + True + + Parameters + ========== + + self : :class:`DomainMatrix` + The ``m x n`` matrix $A$ in the equation $Ax = b$. Underdetermined + systems are not supported so ``m >= n``: $A$ should be square or + have more rows than columns. + b : :class:`DomainMatrix` + The ``n x m`` matrix $b$ for the rhs. + cp : list of :class:`~.DomainElement`, optional + The characteristic polynomial of the matrix $A$. If not given, it + will be computed using :meth:`charpoly`. + method: str, optional + The method to use for solving the system. Can be one of ``None``, + ``'charpoly'`` or ``'rref'``. If ``None`` (the default) then the + method will be chosen automatically. + + The ``charpoly`` method uses :meth:`solve_den_charpoly` and can + only be used if the matrix is square. This method is division free + and can be used with any domain. + + The ``rref`` method is fraction free but requires exact division + in the ground domain (``exquo``). This is also suitable for most + domains. This method can be used with overdetermined systems (more + equations than unknowns) but not underdetermined systems as a + unique solution is sought. + + Returns + ======= + + (xnum, xden) : (DomainMatrix, DomainElement) + The solution of the equation $Ax = b$ as a pair consisting of an + ``n x m`` matrix numerator ``xnum`` and a scalar denominator + ``xden``. + + The solution $x$ is given by ``x = xnum / xden``. The division free + invariant is ``A * xnum == xden * b``. If $A$ is square then the + denominator ``xden`` will be a divisor of the determinant $det(A)$. + + Raises + ====== + + DMNonInvertibleMatrixError + If the system $Ax = b$ does not have a unique solution. + + See Also + ======== + + solve_den_charpoly + solve_den_rref + inv_den + """ + m, n = self.shape + bm, bn = b.shape + + if m != bm: + raise DMShapeError("Matrix equation shape mismatch.") + + if method is None: + method = 'rref' + elif method == 'charpoly' and m != n: + raise DMNonSquareMatrixError("method='charpoly' requires a square matrix.") + + if method == 'charpoly': + xnum, xden = self.solve_den_charpoly(b) + elif method == 'rref': + xnum, xden = self.solve_den_rref(b) + else: + raise DMBadInputError("method should be 'rref' or 'charpoly'") + + return xnum, xden + + def solve_den_rref(self, b): + """ + Solve matrix equation $Ax = b$ using fraction-free RREF + + Solves the matrix equation $Ax = b$ for $x$ and returns the solution + as a numerator/denominator pair. + + Examples + ======== + + >>> from sympy import ZZ + >>> from sympy.polys.matrices import DM + >>> A = DM([[ZZ(1), ZZ(2)], [ZZ(3), ZZ(4)]], ZZ) + >>> b = DM([[ZZ(5)], [ZZ(6)]], ZZ) + >>> xnum, xden = A.solve_den_rref(b) + >>> xden + -2 + >>> xnum + DomainMatrix([[8], [-9]], (2, 1), ZZ) + >>> A * xnum == xden * b + True + + See Also + ======== + + solve_den + solve_den_charpoly + """ + A = self + m, n = A.shape + bm, bn = b.shape + + if m != bm: + raise DMShapeError("Matrix equation shape mismatch.") + + if m < n: + raise DMShapeError("Underdetermined matrix equation.") + + Aaug = A.hstack(b) + Aaug_rref, denom, pivots = Aaug.rref_den() + + # XXX: We check here if there are pivots after the last column. If + # there were than it possibly means that rref_den performed some + # unnecessary elimination. It would be better if rref methods had a + # parameter indicating how many columns should be used for elimination. + if len(pivots) != n or pivots and pivots[-1] >= n: + raise DMNonInvertibleMatrixError("Non-unique solution.") + + xnum = Aaug_rref[:n, n:] + xden = denom + + return xnum, xden + + def solve_den_charpoly(self, b, cp=None, check=True): + """ + Solve matrix equation $Ax = b$ using the characteristic polynomial. + + This method solves the square matrix equation $Ax = b$ for $x$ using + the characteristic polynomial without any division or fractions in the + ground domain. + + Examples + ======== + + Solve a matrix equation over the integers: + + >>> from sympy import ZZ + >>> from sympy.polys.matrices import DM + >>> A = DM([[ZZ(1), ZZ(2)], [ZZ(3), ZZ(4)]], ZZ) + >>> b = DM([[ZZ(5)], [ZZ(6)]], ZZ) + >>> xnum, detA = A.solve_den_charpoly(b) + >>> detA + -2 + >>> xnum + DomainMatrix([[8], [-9]], (2, 1), ZZ) + >>> A * xnum == detA * b + True + + Parameters + ========== + + self : DomainMatrix + The ``n x n`` matrix `A` in the equation `Ax = b`. Must be square + and invertible. + b : DomainMatrix + The ``n x m`` matrix `b` for the rhs. + cp : list, optional + The characteristic polynomial of the matrix `A` if known. If not + given, it will be computed using :meth:`charpoly`. + check : bool, optional + If ``True`` (the default) check that the determinant is not zero + and raise an error if it is. If ``False`` then if the determinant + is zero the return value will be equal to ``(A.adjugate()*b, 0)``. + + Returns + ======= + + (xnum, detA) : (DomainMatrix, DomainElement) + The solution of the equation `Ax = b` as a matrix numerator and + scalar denominator pair. The denominator is equal to the + determinant of `A` and the numerator is ``adj(A)*b``. + + The solution $x$ is given by ``x = xnum / detA``. The division free + invariant is ``A * xnum == detA * b``. + + If ``b`` is the identity matrix, then ``xnum`` is the adjugate matrix + and we have ``A * adj(A) == detA * I``. + + See Also + ======== + + solve_den + Main frontend for solving matrix equations with denominator. + solve_den_rref + Solve matrix equations using fraction-free RREF. + inv_den + Invert a matrix using the characteristic polynomial. + """ + A, b = self.unify(b) + m, n = self.shape + mb, nb = b.shape + + if m != n: + raise DMNonSquareMatrixError("Matrix must be square") + + if mb != m: + raise DMShapeError("Matrix and vector must have the same number of rows") + + f, detA = self.adj_poly_det(cp=cp) + + if check and not detA: + raise DMNonInvertibleMatrixError("Matrix is not invertible") + + # Compute adj(A)*b = det(A)*inv(A)*b using Horner's method without + # constructing inv(A) explicitly. + adjA_b = self.eval_poly_mul(f, b) + + return (adjA_b, detA) + + def adj_poly_det(self, cp=None): + """ + Return the polynomial $p$ such that $p(A) = adj(A)$ and also the + determinant of $A$. + + Examples + ======== + + >>> from sympy import QQ + >>> from sympy.polys.matrices import DM + >>> A = DM([[QQ(1), QQ(2)], [QQ(3), QQ(4)]], QQ) + >>> p, detA = A.adj_poly_det() + >>> p + [-1, 5] + >>> p_A = A.eval_poly(p) + >>> p_A + DomainMatrix([[4, -2], [-3, 1]], (2, 2), QQ) + >>> p[0]*A**1 + p[1]*A**0 == p_A + True + >>> p_A == A.adjugate() + True + >>> A * A.adjugate() == detA * A.eye(A.shape, A.domain).to_dense() + True + + See Also + ======== + + adjugate + eval_poly + adj_det + """ + + # Cayley-Hamilton says that a matrix satisfies its own minimal + # polynomial + # + # p[0]*A^n + p[1]*A^(n-1) + ... + p[n]*I = 0 + # + # with p[0]=1 and p[n]=(-1)^n*det(A) or + # + # det(A)*I = -(-1)^n*(p[0]*A^(n-1) + p[1]*A^(n-2) + ... + p[n-1]*A). + # + # Define a new polynomial f with f[i] = -(-1)^n*p[i] for i=0..n-1. Then + # + # det(A)*I = f[0]*A^n + f[1]*A^(n-1) + ... + f[n-1]*A. + # + # Multiplying on the right by inv(A) gives + # + # det(A)*inv(A) = f[0]*A^(n-1) + f[1]*A^(n-2) + ... + f[n-1]. + # + # So adj(A) = det(A)*inv(A) = f(A) + + A = self + m, n = self.shape + + if m != n: + raise DMNonSquareMatrixError("Matrix must be square") + + if cp is None: + cp = A.charpoly() + + if len(cp) % 2: + # n is even + detA = cp[-1] + f = [-cpi for cpi in cp[:-1]] + else: + # n is odd + detA = -cp[-1] + f = cp[:-1] + + return f, detA + + def eval_poly(self, p): + """ + Evaluate polynomial function of a matrix $p(A)$. + + Examples + ======== + + >>> from sympy import QQ + >>> from sympy.polys.matrices import DM + >>> A = DM([[QQ(1), QQ(2)], [QQ(3), QQ(4)]], QQ) + >>> p = [QQ(1), QQ(2), QQ(3)] + >>> p_A = A.eval_poly(p) + >>> p_A + DomainMatrix([[12, 14], [21, 33]], (2, 2), QQ) + >>> p_A == p[0]*A**2 + p[1]*A + p[2]*A**0 + True + + See Also + ======== + + eval_poly_mul + """ + A = self + m, n = A.shape + + if m != n: + raise DMNonSquareMatrixError("Matrix must be square") + + if not p: + return self.zeros(self.shape, self.domain) + elif len(p) == 1: + return p[0] * self.eye(self.shape, self.domain) + + # Evaluate p(A) using Horner's method: + # XXX: Use Paterson-Stockmeyer method? + I = A.eye(A.shape, A.domain) + p_A = p[0] * I + for pi in p[1:]: + p_A = A*p_A + pi*I + + return p_A + + def eval_poly_mul(self, p, B): + r""" + Evaluate polynomial matrix product $p(A) \times B$. + + Evaluate the polynomial matrix product $p(A) \times B$ using Horner's + method without creating the matrix $p(A)$ explicitly. If $B$ is a + column matrix then this method will only use matrix-vector multiplies + and no matrix-matrix multiplies are needed. + + If $B$ is square or wide or if $A$ can be represented in a simpler + domain than $B$ then it might be faster to evaluate $p(A)$ explicitly + (see :func:`eval_poly`) and then multiply with $B$. + + Examples + ======== + + >>> from sympy import QQ + >>> from sympy.polys.matrices import DM + >>> A = DM([[QQ(1), QQ(2)], [QQ(3), QQ(4)]], QQ) + >>> b = DM([[QQ(5)], [QQ(6)]], QQ) + >>> p = [QQ(1), QQ(2), QQ(3)] + >>> p_A_b = A.eval_poly_mul(p, b) + >>> p_A_b + DomainMatrix([[144], [303]], (2, 1), QQ) + >>> p_A_b == p[0]*A**2*b + p[1]*A*b + p[2]*b + True + >>> A.eval_poly_mul(p, b) == A.eval_poly(p)*b + True + + See Also + ======== + + eval_poly + solve_den_charpoly + """ + A = self + m, n = A.shape + mb, nb = B.shape + + if m != n: + raise DMNonSquareMatrixError("Matrix must be square") + + if mb != n: + raise DMShapeError("Matrices are not aligned") + + if A.domain != B.domain: + raise DMDomainError("Matrices must have the same domain") + + # Given a polynomial p(x) = p[0]*x^n + p[1]*x^(n-1) + ... + p[n-1] + # and matrices A and B we want to find + # + # p(A)*B = p[0]*A^n*B + p[1]*A^(n-1)*B + ... + p[n-1]*B + # + # Factoring out A term by term we get + # + # p(A)*B = A*(...A*(A*(A*(p[0]*B) + p[1]*B) + p[2]*B) + ...) + p[n-1]*B + # + # where each pair of brackets represents one iteration of the loop + # below starting from the innermost p[0]*B. If B is a column matrix + # then products like A*(...) are matrix-vector multiplies and products + # like p[i]*B are scalar-vector multiplies so there are no + # matrix-matrix multiplies. + + if not p: + return B.zeros(B.shape, B.domain, fmt=B.rep.fmt) + + p_A_B = p[0]*B + + for p_i in p[1:]: + p_A_B = A*p_A_B + p_i*B + + return p_A_B + + def lu(self): + r""" + Returns Lower and Upper decomposition of the DomainMatrix + + Returns + ======= + + (L, U, exchange) + L, U are Lower and Upper decomposition of the DomainMatrix, + exchange is the list of indices of rows exchanged in the + decomposition. + + Raises + ====== + + ValueError + If the domain of DomainMatrix not a Field + + Examples + ======== + + >>> from sympy import QQ + >>> from sympy.polys.matrices import DomainMatrix + >>> A = DomainMatrix([ + ... [QQ(1), QQ(-1)], + ... [QQ(2), QQ(-2)]], (2, 2), QQ) + >>> L, U, exchange = A.lu() + >>> L + DomainMatrix([[1, 0], [2, 1]], (2, 2), QQ) + >>> U + DomainMatrix([[1, -1], [0, 0]], (2, 2), QQ) + >>> exchange + [] + + See Also + ======== + + lu_solve + + """ + if not self.domain.is_Field: + raise DMNotAField('Not a field') + L, U, swaps = self.rep.lu() + return self.from_rep(L), self.from_rep(U), swaps + + def lu_solve(self, rhs): + r""" + Solver for DomainMatrix x in the A*x = B + + Parameters + ========== + + rhs : DomainMatrix B + + Returns + ======= + + DomainMatrix + x in A*x = B + + Raises + ====== + + DMShapeError + If the DomainMatrix A and rhs have different number of rows + + ValueError + If the domain of DomainMatrix A not a Field + + Examples + ======== + + >>> from sympy import QQ + >>> from sympy.polys.matrices import DomainMatrix + >>> A = DomainMatrix([ + ... [QQ(1), QQ(2)], + ... [QQ(3), QQ(4)]], (2, 2), QQ) + >>> B = DomainMatrix([ + ... [QQ(1), QQ(1)], + ... [QQ(0), QQ(1)]], (2, 2), QQ) + + >>> A.lu_solve(B) + DomainMatrix([[-2, -1], [3/2, 1]], (2, 2), QQ) + + See Also + ======== + + lu + + """ + if self.shape[0] != rhs.shape[0]: + raise DMShapeError("Shape") + if not self.domain.is_Field: + raise DMNotAField('Not a field') + sol = self.rep.lu_solve(rhs.rep) + return self.from_rep(sol) + + def _solve(A, b): + # XXX: Not sure about this method or its signature. It is just created + # because it is needed by the holonomic module. + if A.shape[0] != b.shape[0]: + raise DMShapeError("Shape") + if A.domain != b.domain or not A.domain.is_Field: + raise DMNotAField('Not a field') + Aaug = A.hstack(b) + Arref, pivots = Aaug.rref() + particular = Arref.from_rep(Arref.rep.particular()) + nullspace_rep, nonpivots = Arref[:,:-1].rep.nullspace() + nullspace = Arref.from_rep(nullspace_rep) + return particular, nullspace + + def charpoly(self): + r""" + Characteristic polynomial of a square matrix. + + Computes the characteristic polynomial in a fully expanded form using + division free arithmetic. If a factorization of the characteristic + polynomial is needed then it is more efficient to call + :meth:`charpoly_factor_list` than calling :meth:`charpoly` and then + factorizing the result. + + Returns + ======= + + list: list of DomainElement + coefficients of the characteristic polynomial + + Examples + ======== + + >>> from sympy import ZZ + >>> from sympy.polys.matrices import DomainMatrix + >>> A = DomainMatrix([ + ... [ZZ(1), ZZ(2)], + ... [ZZ(3), ZZ(4)]], (2, 2), ZZ) + + >>> A.charpoly() + [1, -5, -2] + + See Also + ======== + + charpoly_factor_list + Compute the factorisation of the characteristic polynomial. + charpoly_factor_blocks + A partial factorisation of the characteristic polynomial that can + be computed more efficiently than either the full factorisation or + the fully expanded polynomial. + """ + M = self + K = M.domain + + factors = M.charpoly_factor_blocks() + + cp = [K.one] + + for f, mult in factors: + for _ in range(mult): + cp = dup_mul(cp, f, K) + + return cp + + def charpoly_factor_list(self): + """ + Full factorization of the characteristic polynomial. + + Examples + ======== + + >>> from sympy.polys.matrices import DM + >>> from sympy import ZZ + >>> M = DM([[6, -1, 0, 0], + ... [9, 12, 0, 0], + ... [0, 0, 1, 2], + ... [0, 0, 5, 6]], ZZ) + + Compute the factorization of the characteristic polynomial: + + >>> M.charpoly_factor_list() + [([1, -9], 2), ([1, -7, -4], 1)] + + Use :meth:`charpoly` to get the unfactorized characteristic polynomial: + + >>> M.charpoly() + [1, -25, 203, -495, -324] + + The same calculations with ``Matrix``: + + >>> M.to_Matrix().charpoly().as_expr() + lambda**4 - 25*lambda**3 + 203*lambda**2 - 495*lambda - 324 + >>> M.to_Matrix().charpoly().as_expr().factor() + (lambda - 9)**2*(lambda**2 - 7*lambda - 4) + + Returns + ======= + + list: list of pairs (factor, multiplicity) + A full factorization of the characteristic polynomial. + + See Also + ======== + + charpoly + Expanded form of the characteristic polynomial. + charpoly_factor_blocks + A partial factorisation of the characteristic polynomial that can + be computed more efficiently. + """ + M = self + K = M.domain + + # It is more efficient to start from the partial factorization provided + # for free by M.charpoly_factor_blocks than the expanded M.charpoly. + factors = M.charpoly_factor_blocks() + + factors_irreducible = [] + + for factor_i, mult_i in factors: + + _, factors_list = dup_factor_list(factor_i, K) + + for factor_j, mult_j in factors_list: + factors_irreducible.append((factor_j, mult_i * mult_j)) + + return _collect_factors(factors_irreducible) + + def charpoly_factor_blocks(self): + """ + Partial factorisation of the characteristic polynomial. + + This factorisation arises from a block structure of the matrix (if any) + and so the factors are not guaranteed to be irreducible. The + :meth:`charpoly_factor_blocks` method is the most efficient way to get + a representation of the characteristic polynomial but the result is + neither fully expanded nor fully factored. + + Examples + ======== + + >>> from sympy.polys.matrices import DM + >>> from sympy import ZZ + >>> M = DM([[6, -1, 0, 0], + ... [9, 12, 0, 0], + ... [0, 0, 1, 2], + ... [0, 0, 5, 6]], ZZ) + + This computes a partial factorization using only the block structure of + the matrix to reveal factors: + + >>> M.charpoly_factor_blocks() + [([1, -18, 81], 1), ([1, -7, -4], 1)] + + These factors correspond to the two diagonal blocks in the matrix: + + >>> DM([[6, -1], [9, 12]], ZZ).charpoly() + [1, -18, 81] + >>> DM([[1, 2], [5, 6]], ZZ).charpoly() + [1, -7, -4] + + Use :meth:`charpoly_factor_list` to get a complete factorization into + irreducibles: + + >>> M.charpoly_factor_list() + [([1, -9], 2), ([1, -7, -4], 1)] + + Use :meth:`charpoly` to get the expanded characteristic polynomial: + + >>> M.charpoly() + [1, -25, 203, -495, -324] + + Returns + ======= + + list: list of pairs (factor, multiplicity) + A partial factorization of the characteristic polynomial. + + See Also + ======== + + charpoly + Compute the fully expanded characteristic polynomial. + charpoly_factor_list + Compute a full factorization of the characteristic polynomial. + """ + M = self + + if not M.is_square: + raise DMNonSquareMatrixError("not square") + + # scc returns indices that permute the matrix into block triangular + # form and can extract the diagonal blocks. M.charpoly() is equal to + # the product of the diagonal block charpolys. + components = M.scc() + + block_factors = [] + + for indices in components: + block = M.extract(indices, indices) + block_factors.append((block.charpoly_base(), 1)) + + return _collect_factors(block_factors) + + def charpoly_base(self): + """ + Base case for :meth:`charpoly_factor_blocks` after block decomposition. + + This method is used internally by :meth:`charpoly_factor_blocks` as the + base case for computing the characteristic polynomial of a block. It is + more efficient to call :meth:`charpoly_factor_blocks`, :meth:`charpoly` + or :meth:`charpoly_factor_list` rather than call this method directly. + + This will use either the dense or the sparse implementation depending + on the sparsity of the matrix and will clear denominators if possible + before calling :meth:`charpoly_berk` to compute the characteristic + polynomial using the Berkowitz algorithm. + + See Also + ======== + + charpoly + charpoly_factor_list + charpoly_factor_blocks + charpoly_berk + """ + M = self + K = M.domain + + # It seems that the sparse implementation is always faster for random + # matrices with fewer than 50% non-zero entries. This does not seem to + # depend on domain, size, bit count etc. + density = self.nnz() / self.shape[0]**2 + if density < 0.5: + M = M.to_sparse() + else: + M = M.to_dense() + + # Clearing denominators is always more efficient if it can be done. + # Doing it here after block decomposition is good because each block + # might have a smaller denominator. However it might be better for + # charpoly and charpoly_factor_list to restore the denominators only at + # the very end so that they can call e.g. dup_factor_list before + # restoring the denominators. The methods would need to be changed to + # return (poly, denom) pairs to make that work though. + clear_denoms = K.is_Field and K.has_assoc_Ring + + if clear_denoms: + clear_denoms = True + d, M = M.clear_denoms(convert=True) + d = d.element + K_f = K + K_r = M.domain + + # Berkowitz algorithm over K_r. + cp = M.charpoly_berk() + + if clear_denoms: + # Restore the denominator in the charpoly over K_f. + # + # If M = N/d then p_M(x) = p_N(x*d)/d^n. + cp = dup_convert(cp, K_r, K_f) + p = [K_f.one, K_f.zero] + q = [K_f.one/d] + cp = dup_transform(cp, p, q, K_f) + + return cp + + def charpoly_berk(self): + """Compute the characteristic polynomial using the Berkowitz algorithm. + + This method directly calls the underlying implementation of the + Berkowitz algorithm (:meth:`sympy.polys.matrices.dense.ddm_berk` or + :meth:`sympy.polys.matrices.sdm.sdm_berk`). + + This is used by :meth:`charpoly` and other methods as the base case for + for computing the characteristic polynomial. However those methods will + apply other optimizations such as block decomposition, clearing + denominators and converting between dense and sparse representations + before calling this method. It is more efficient to call those methods + instead of this one but this method is provided for direct access to + the Berkowitz algorithm. + + Examples + ======== + + >>> from sympy.polys.matrices import DM + >>> from sympy import QQ + >>> M = DM([[6, -1, 0, 0], + ... [9, 12, 0, 0], + ... [0, 0, 1, 2], + ... [0, 0, 5, 6]], QQ) + >>> M.charpoly_berk() + [1, -25, 203, -495, -324] + + See Also + ======== + + charpoly + charpoly_base + charpoly_factor_list + charpoly_factor_blocks + sympy.polys.matrices.dense.ddm_berk + sympy.polys.matrices.sdm.sdm_berk + """ + return self.rep.charpoly() + + @classmethod + def eye(cls, shape, domain): + r""" + Return identity matrix of size n or shape (m, n). + + Examples + ======== + + >>> from sympy.polys.matrices import DomainMatrix + >>> from sympy import QQ + >>> DomainMatrix.eye(3, QQ) + DomainMatrix({0: {0: 1}, 1: {1: 1}, 2: {2: 1}}, (3, 3), QQ) + + """ + if isinstance(shape, int): + shape = (shape, shape) + return cls.from_rep(SDM.eye(shape, domain)) + + @classmethod + def diag(cls, diagonal, domain, shape=None): + r""" + Return diagonal matrix with entries from ``diagonal``. + + Examples + ======== + + >>> from sympy.polys.matrices import DomainMatrix + >>> from sympy import ZZ + >>> DomainMatrix.diag([ZZ(5), ZZ(6)], ZZ) + DomainMatrix({0: {0: 5}, 1: {1: 6}}, (2, 2), ZZ) + + """ + if shape is None: + N = len(diagonal) + shape = (N, N) + return cls.from_rep(SDM.diag(diagonal, domain, shape)) + + @classmethod + def zeros(cls, shape, domain, *, fmt='sparse'): + """Returns a zero DomainMatrix of size shape, belonging to the specified domain + + Examples + ======== + + >>> from sympy.polys.matrices import DomainMatrix + >>> from sympy import QQ + >>> DomainMatrix.zeros((2, 3), QQ) + DomainMatrix({}, (2, 3), QQ) + + """ + return cls.from_rep(SDM.zeros(shape, domain)) + + @classmethod + def ones(cls, shape, domain): + """Returns a DomainMatrix of 1s, of size shape, belonging to the specified domain + + Examples + ======== + + >>> from sympy.polys.matrices import DomainMatrix + >>> from sympy import QQ + >>> DomainMatrix.ones((2,3), QQ) + DomainMatrix([[1, 1, 1], [1, 1, 1]], (2, 3), QQ) + + """ + return cls.from_rep(DDM.ones(shape, domain).to_dfm_or_ddm()) + + def __eq__(A, B): + r""" + Checks for two DomainMatrix matrices to be equal or not + + Parameters + ========== + + A, B: DomainMatrix + to check equality + + Returns + ======= + + Boolean + True for equal, else False + + Raises + ====== + + NotImplementedError + If B is not a DomainMatrix + + Examples + ======== + + >>> from sympy import ZZ + >>> from sympy.polys.matrices import DomainMatrix + >>> A = DomainMatrix([ + ... [ZZ(1), ZZ(2)], + ... [ZZ(3), ZZ(4)]], (2, 2), ZZ) + >>> B = DomainMatrix([ + ... [ZZ(1), ZZ(1)], + ... [ZZ(0), ZZ(1)]], (2, 2), ZZ) + >>> A.__eq__(A) + True + >>> A.__eq__(B) + False + + """ + if not isinstance(A, type(B)): + return NotImplemented + return A.domain == B.domain and A.rep == B.rep + + def unify_eq(A, B): + if A.shape != B.shape: + return False + if A.domain != B.domain: + A, B = A.unify(B) + return A == B + + def lll(A, delta=QQ(3, 4)): + """ + Performs the Lenstra–Lenstra–Lovász (LLL) basis reduction algorithm. + See [1]_ and [2]_. + + Parameters + ========== + + delta : QQ, optional + The Lovász parameter. Must be in the interval (0.25, 1), with larger + values producing a more reduced basis. The default is 0.75 for + historical reasons. + + Returns + ======= + + The reduced basis as a DomainMatrix over ZZ. + + Throws + ====== + + DMValueError: if delta is not in the range (0.25, 1) + DMShapeError: if the matrix is not of shape (m, n) with m <= n + DMDomainError: if the matrix domain is not ZZ + DMRankError: if the matrix contains linearly dependent rows + + Examples + ======== + + >>> from sympy.polys.domains import ZZ, QQ + >>> from sympy.polys.matrices import DM + >>> x = DM([[1, 0, 0, 0, -20160], + ... [0, 1, 0, 0, 33768], + ... [0, 0, 1, 0, 39578], + ... [0, 0, 0, 1, 47757]], ZZ) + >>> y = DM([[10, -3, -2, 8, -4], + ... [3, -9, 8, 1, -11], + ... [-3, 13, -9, -3, -9], + ... [-12, -7, -11, 9, -1]], ZZ) + >>> assert x.lll(delta=QQ(5, 6)) == y + + Notes + ===== + + The implementation is derived from the Maple code given in Figures 4.3 + and 4.4 of [3]_ (pp.68-69). It uses the efficient method of only calculating + state updates as they are required. + + See also + ======== + + lll_transform + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Lenstra%E2%80%93Lenstra%E2%80%93Lov%C3%A1sz_lattice_basis_reduction_algorithm + .. [2] https://web.archive.org/web/20221029115428/https://web.cs.elte.hu/~lovasz/scans/lll.pdf + .. [3] Murray R. Bremner, "Lattice Basis Reduction: An Introduction to the LLL Algorithm and Its Applications" + + """ + return DomainMatrix.from_rep(A.rep.lll(delta=delta)) + + def lll_transform(A, delta=QQ(3, 4)): + """ + Performs the Lenstra–Lenstra–Lovász (LLL) basis reduction algorithm + and returns the reduced basis and transformation matrix. + + Explanation + =========== + + Parameters, algorithm and basis are the same as for :meth:`lll` except that + the return value is a tuple `(B, T)` with `B` the reduced basis and + `T` a transformation matrix. The original basis `A` is transformed to + `B` with `T*A == B`. If only `B` is needed then :meth:`lll` should be + used as it is a little faster. + + Examples + ======== + + >>> from sympy.polys.domains import ZZ, QQ + >>> from sympy.polys.matrices import DM + >>> X = DM([[1, 0, 0, 0, -20160], + ... [0, 1, 0, 0, 33768], + ... [0, 0, 1, 0, 39578], + ... [0, 0, 0, 1, 47757]], ZZ) + >>> B, T = X.lll_transform(delta=QQ(5, 6)) + >>> T * X == B + True + + See also + ======== + + lll + + """ + reduced, transform = A.rep.lll_transform(delta=delta) + return DomainMatrix.from_rep(reduced), DomainMatrix.from_rep(transform) + + +def _collect_factors(factors_list): + """ + Collect repeating factors and sort. + + >>> from sympy.polys.matrices.domainmatrix import _collect_factors + >>> _collect_factors([([1, 2], 2), ([1, 4], 3), ([1, 2], 5)]) + [([1, 4], 3), ([1, 2], 7)] + """ + factors = Counter() + for factor, exponent in factors_list: + factors[tuple(factor)] += exponent + + factors_list = [(list(f), e) for f, e in factors.items()] + + return _sort_factors(factors_list) diff --git a/lib/python3.10/site-packages/sympy/polys/matrices/domainscalar.py b/lib/python3.10/site-packages/sympy/polys/matrices/domainscalar.py new file mode 100644 index 0000000000000000000000000000000000000000..df439a60a0ea0df5f6fac988c06da2a06a4fbac2 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/polys/matrices/domainscalar.py @@ -0,0 +1,122 @@ +""" + +Module for the DomainScalar class. + +A DomainScalar represents an element which is in a particular +Domain. The idea is that the DomainScalar class provides the +convenience routines for unifying elements with different domains. + +It assists in Scalar Multiplication and getitem for DomainMatrix. + +""" +from ..constructor import construct_domain + +from sympy.polys.domains import Domain, ZZ + + +class DomainScalar: + r""" + docstring + """ + + def __new__(cls, element, domain): + if not isinstance(domain, Domain): + raise TypeError("domain should be of type Domain") + if not domain.of_type(element): + raise TypeError("element %s should be in domain %s" % (element, domain)) + return cls.new(element, domain) + + @classmethod + def new(cls, element, domain): + obj = super().__new__(cls) + obj.element = element + obj.domain = domain + return obj + + def __repr__(self): + return repr(self.element) + + @classmethod + def from_sympy(cls, expr): + [domain, [element]] = construct_domain([expr]) + return cls.new(element, domain) + + def to_sympy(self): + return self.domain.to_sympy(self.element) + + def to_domain(self, domain): + element = domain.convert_from(self.element, self.domain) + return self.new(element, domain) + + def convert_to(self, domain): + return self.to_domain(domain) + + def unify(self, other): + domain = self.domain.unify(other.domain) + return self.to_domain(domain), other.to_domain(domain) + + def __bool__(self): + return bool(self.element) + + def __add__(self, other): + if not isinstance(other, DomainScalar): + return NotImplemented + self, other = self.unify(other) + return self.new(self.element + other.element, self.domain) + + def __sub__(self, other): + if not isinstance(other, DomainScalar): + return NotImplemented + self, other = self.unify(other) + return self.new(self.element - other.element, self.domain) + + def __mul__(self, other): + if not isinstance(other, DomainScalar): + if isinstance(other, int): + other = DomainScalar(ZZ(other), ZZ) + else: + return NotImplemented + + self, other = self.unify(other) + return self.new(self.element * other.element, self.domain) + + def __floordiv__(self, other): + if not isinstance(other, DomainScalar): + return NotImplemented + self, other = self.unify(other) + return self.new(self.domain.quo(self.element, other.element), self.domain) + + def __mod__(self, other): + if not isinstance(other, DomainScalar): + return NotImplemented + self, other = self.unify(other) + return self.new(self.domain.rem(self.element, other.element), self.domain) + + def __divmod__(self, other): + if not isinstance(other, DomainScalar): + return NotImplemented + self, other = self.unify(other) + q, r = self.domain.div(self.element, other.element) + return (self.new(q, self.domain), self.new(r, self.domain)) + + def __pow__(self, n): + if not isinstance(n, int): + return NotImplemented + return self.new(self.element**n, self.domain) + + def __pos__(self): + return self.new(+self.element, self.domain) + + def __neg__(self): + return self.new(-self.element, self.domain) + + def __eq__(self, other): + if not isinstance(other, DomainScalar): + return NotImplemented + return self.element == other.element and self.domain == other.domain + + def is_zero(self): + return self.element == self.domain.zero + + def is_one(self): + return self.element == self.domain.one diff --git a/lib/python3.10/site-packages/sympy/polys/matrices/eigen.py b/lib/python3.10/site-packages/sympy/polys/matrices/eigen.py new file mode 100644 index 0000000000000000000000000000000000000000..17d673c6ea09002e1cfd5357f301c447a7af4341 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/polys/matrices/eigen.py @@ -0,0 +1,90 @@ +""" + +Routines for computing eigenvectors with DomainMatrix. + +""" +from sympy.core.symbol import Dummy + +from ..agca.extensions import FiniteExtension +from ..factortools import dup_factor_list +from ..polyroots import roots +from ..polytools import Poly +from ..rootoftools import CRootOf + +from .domainmatrix import DomainMatrix + + +def dom_eigenvects(A, l=Dummy('lambda')): + charpoly = A.charpoly() + rows, cols = A.shape + domain = A.domain + _, factors = dup_factor_list(charpoly, domain) + + rational_eigenvects = [] + algebraic_eigenvects = [] + for base, exp in factors: + if len(base) == 2: + field = domain + eigenval = -base[1] / base[0] + + EE_items = [ + [eigenval if i == j else field.zero for j in range(cols)] + for i in range(rows)] + EE = DomainMatrix(EE_items, (rows, cols), field) + + basis = (A - EE).nullspace(divide_last=True) + rational_eigenvects.append((field, eigenval, exp, basis)) + else: + minpoly = Poly.from_list(base, l, domain=domain) + field = FiniteExtension(minpoly) + eigenval = field(l) + + AA_items = [ + [Poly.from_list([item], l, domain=domain).rep for item in row] + for row in A.rep.to_ddm()] + AA_items = [[field(item) for item in row] for row in AA_items] + AA = DomainMatrix(AA_items, (rows, cols), field) + EE_items = [ + [eigenval if i == j else field.zero for j in range(cols)] + for i in range(rows)] + EE = DomainMatrix(EE_items, (rows, cols), field) + + basis = (AA - EE).nullspace(divide_last=True) + algebraic_eigenvects.append((field, minpoly, exp, basis)) + + return rational_eigenvects, algebraic_eigenvects + + +def dom_eigenvects_to_sympy( + rational_eigenvects, algebraic_eigenvects, + Matrix, **kwargs +): + result = [] + + for field, eigenvalue, multiplicity, eigenvects in rational_eigenvects: + eigenvects = eigenvects.rep.to_ddm() + eigenvalue = field.to_sympy(eigenvalue) + new_eigenvects = [ + Matrix([field.to_sympy(x) for x in vect]) + for vect in eigenvects] + result.append((eigenvalue, multiplicity, new_eigenvects)) + + for field, minpoly, multiplicity, eigenvects in algebraic_eigenvects: + eigenvects = eigenvects.rep.to_ddm() + l = minpoly.gens[0] + + eigenvects = [[field.to_sympy(x) for x in vect] for vect in eigenvects] + + degree = minpoly.degree() + minpoly = minpoly.as_expr() + eigenvals = roots(minpoly, l, **kwargs) + if len(eigenvals) != degree: + eigenvals = [CRootOf(minpoly, l, idx) for idx in range(degree)] + + for eigenvalue in eigenvals: + new_eigenvects = [ + Matrix([x.subs(l, eigenvalue) for x in vect]) + for vect in eigenvects] + result.append((eigenvalue, multiplicity, new_eigenvects)) + + return result diff --git a/lib/python3.10/site-packages/sympy/polys/matrices/exceptions.py b/lib/python3.10/site-packages/sympy/polys/matrices/exceptions.py new file mode 100644 index 0000000000000000000000000000000000000000..b1e5a4195c66aceed2d5ac1994381d3dec6a64ba --- /dev/null +++ b/lib/python3.10/site-packages/sympy/polys/matrices/exceptions.py @@ -0,0 +1,67 @@ +""" + +Module to define exceptions to be used in sympy.polys.matrices modules and +classes. + +Ideally all exceptions raised in these modules would be defined and documented +here and not e.g. imported from matrices. Also ideally generic exceptions like +ValueError/TypeError would not be raised anywhere. + +""" + + +class DMError(Exception): + """Base class for errors raised by DomainMatrix""" + pass + + +class DMBadInputError(DMError): + """list of lists is inconsistent with shape""" + pass + + +class DMDomainError(DMError): + """domains do not match""" + pass + + +class DMNotAField(DMDomainError): + """domain is not a field""" + pass + + +class DMFormatError(DMError): + """mixed dense/sparse not supported""" + pass + + +class DMNonInvertibleMatrixError(DMError): + """The matrix in not invertible""" + pass + + +class DMRankError(DMError): + """matrix does not have expected rank""" + pass + + +class DMShapeError(DMError): + """shapes are inconsistent""" + pass + + +class DMNonSquareMatrixError(DMShapeError): + """The matrix is not square""" + pass + + +class DMValueError(DMError): + """The value passed is invalid""" + pass + + +__all__ = [ + 'DMError', 'DMBadInputError', 'DMDomainError', 'DMFormatError', + 'DMRankError', 'DMShapeError', 'DMNotAField', + 'DMNonInvertibleMatrixError', 'DMNonSquareMatrixError', 'DMValueError' +] diff --git a/lib/python3.10/site-packages/sympy/polys/matrices/linsolve.py b/lib/python3.10/site-packages/sympy/polys/matrices/linsolve.py new file mode 100644 index 0000000000000000000000000000000000000000..af74058d859b744cf8fe1059ddb7c775fece79c7 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/polys/matrices/linsolve.py @@ -0,0 +1,230 @@ +# +# sympy.polys.matrices.linsolve module +# +# This module defines the _linsolve function which is the internal workhorse +# used by linsolve. This computes the solution of a system of linear equations +# using the SDM sparse matrix implementation in sympy.polys.matrices.sdm. This +# is a replacement for solve_lin_sys in sympy.polys.solvers which is +# inefficient for large sparse systems due to the use of a PolyRing with many +# generators: +# +# https://github.com/sympy/sympy/issues/20857 +# +# The implementation of _linsolve here handles: +# +# - Extracting the coefficients from the Expr/Eq input equations. +# - Constructing a domain and converting the coefficients to +# that domain. +# - Using the SDM.rref, SDM.nullspace etc methods to generate the full +# solution working with arithmetic only in the domain of the coefficients. +# +# The routines here are particularly designed to be efficient for large sparse +# systems of linear equations although as well as dense systems. It is +# possible that for some small dense systems solve_lin_sys which uses the +# dense matrix implementation DDM will be more efficient. With smaller systems +# though the bulk of the time is spent just preprocessing the inputs and the +# relative time spent in rref is too small to be noticeable. +# + +from collections import defaultdict + +from sympy.core.add import Add +from sympy.core.mul import Mul +from sympy.core.singleton import S + +from sympy.polys.constructor import construct_domain +from sympy.polys.solvers import PolyNonlinearError + +from .sdm import ( + SDM, + sdm_irref, + sdm_particular_from_rref, + sdm_nullspace_from_rref +) + +from sympy.utilities.misc import filldedent + + +def _linsolve(eqs, syms): + + """Solve a linear system of equations. + + Examples + ======== + + Solve a linear system with a unique solution: + + >>> from sympy import symbols, Eq + >>> from sympy.polys.matrices.linsolve import _linsolve + >>> x, y = symbols('x, y') + >>> eqs = [Eq(x + y, 1), Eq(x - y, 2)] + >>> _linsolve(eqs, [x, y]) + {x: 3/2, y: -1/2} + + In the case of underdetermined systems the solution will be expressed in + terms of the unknown symbols that are unconstrained: + + >>> _linsolve([Eq(x + y, 0)], [x, y]) + {x: -y, y: y} + + """ + # Number of unknowns (columns in the non-augmented matrix) + nsyms = len(syms) + + # Convert to sparse augmented matrix (len(eqs) x (nsyms+1)) + eqsdict, const = _linear_eq_to_dict(eqs, syms) + Aaug = sympy_dict_to_dm(eqsdict, const, syms) + K = Aaug.domain + + # sdm_irref has issues with float matrices. This uses the ddm_rref() + # function. When sdm_rref() can handle float matrices reasonably this + # should be removed... + if K.is_RealField or K.is_ComplexField: + Aaug = Aaug.to_ddm().rref()[0].to_sdm() + + # Compute reduced-row echelon form (RREF) + Arref, pivots, nzcols = sdm_irref(Aaug) + + # No solution: + if pivots and pivots[-1] == nsyms: + return None + + # Particular solution for non-homogeneous system: + P = sdm_particular_from_rref(Arref, nsyms+1, pivots) + + # Nullspace - general solution to homogeneous system + # Note: using nsyms not nsyms+1 to ignore last column + V, nonpivots = sdm_nullspace_from_rref(Arref, K.one, nsyms, pivots, nzcols) + + # Collect together terms from particular and nullspace: + sol = defaultdict(list) + for i, v in P.items(): + sol[syms[i]].append(K.to_sympy(v)) + for npi, Vi in zip(nonpivots, V): + sym = syms[npi] + for i, v in Vi.items(): + sol[syms[i]].append(sym * K.to_sympy(v)) + + # Use a single call to Add for each term: + sol = {s: Add(*terms) for s, terms in sol.items()} + + # Fill in the zeros: + zero = S.Zero + for s in set(syms) - set(sol): + sol[s] = zero + + # All done! + return sol + + +def sympy_dict_to_dm(eqs_coeffs, eqs_rhs, syms): + """Convert a system of dict equations to a sparse augmented matrix""" + elems = set(eqs_rhs).union(*(e.values() for e in eqs_coeffs)) + K, elems_K = construct_domain(elems, field=True, extension=True) + elem_map = dict(zip(elems, elems_K)) + neqs = len(eqs_coeffs) + nsyms = len(syms) + sym2index = dict(zip(syms, range(nsyms))) + eqsdict = [] + for eq, rhs in zip(eqs_coeffs, eqs_rhs): + eqdict = {sym2index[s]: elem_map[c] for s, c in eq.items()} + if rhs: + eqdict[nsyms] = -elem_map[rhs] + if eqdict: + eqsdict.append(eqdict) + sdm_aug = SDM(enumerate(eqsdict), (neqs, nsyms + 1), K) + return sdm_aug + + +def _linear_eq_to_dict(eqs, syms): + """Convert a system Expr/Eq equations into dict form, returning + the coefficient dictionaries and a list of syms-independent terms + from each expression in ``eqs```. + + Examples + ======== + + >>> from sympy.polys.matrices.linsolve import _linear_eq_to_dict + >>> from sympy.abc import x + >>> _linear_eq_to_dict([2*x + 3], {x}) + ([{x: 2}], [3]) + """ + coeffs = [] + ind = [] + symset = set(syms) + for e in eqs: + if e.is_Equality: + coeff, terms = _lin_eq2dict(e.lhs, symset) + cR, tR = _lin_eq2dict(e.rhs, symset) + # there were no nonlinear errors so now + # cancellation is allowed + coeff -= cR + for k, v in tR.items(): + if k in terms: + terms[k] -= v + else: + terms[k] = -v + # don't store coefficients of 0, however + terms = {k: v for k, v in terms.items() if v} + c, d = coeff, terms + else: + c, d = _lin_eq2dict(e, symset) + coeffs.append(d) + ind.append(c) + return coeffs, ind + + +def _lin_eq2dict(a, symset): + """return (c, d) where c is the sym-independent part of ``a`` and + ``d`` is an efficiently calculated dictionary mapping symbols to + their coefficients. A PolyNonlinearError is raised if non-linearity + is detected. + + The values in the dictionary will be non-zero. + + Examples + ======== + + >>> from sympy.polys.matrices.linsolve import _lin_eq2dict + >>> from sympy.abc import x, y + >>> _lin_eq2dict(x + 2*y + 3, {x, y}) + (3, {x: 1, y: 2}) + """ + if a in symset: + return S.Zero, {a: S.One} + elif a.is_Add: + terms_list = defaultdict(list) + coeff_list = [] + for ai in a.args: + ci, ti = _lin_eq2dict(ai, symset) + coeff_list.append(ci) + for mij, cij in ti.items(): + terms_list[mij].append(cij) + coeff = Add(*coeff_list) + terms = {sym: Add(*coeffs) for sym, coeffs in terms_list.items()} + return coeff, terms + elif a.is_Mul: + terms = terms_coeff = None + coeff_list = [] + for ai in a.args: + ci, ti = _lin_eq2dict(ai, symset) + if not ti: + coeff_list.append(ci) + elif terms is None: + terms = ti + terms_coeff = ci + else: + # since ti is not null and we already have + # a term, this is a cross term + raise PolyNonlinearError(filldedent(''' + nonlinear cross-term: %s''' % a)) + coeff = Mul._from_args(coeff_list) + if terms is None: + return coeff, {} + else: + terms = {sym: coeff * c for sym, c in terms.items()} + return coeff * terms_coeff, terms + elif not a.has_xfree(symset): + return a, {} + else: + raise PolyNonlinearError('nonlinear term: %s' % a) diff --git a/lib/python3.10/site-packages/sympy/polys/matrices/lll.py b/lib/python3.10/site-packages/sympy/polys/matrices/lll.py new file mode 100644 index 0000000000000000000000000000000000000000..f33f91d92c5e20f89f302991e494a6a5b9fa4b2e --- /dev/null +++ b/lib/python3.10/site-packages/sympy/polys/matrices/lll.py @@ -0,0 +1,94 @@ +from __future__ import annotations + +from math import floor as mfloor + +from sympy.polys.domains import ZZ, QQ +from sympy.polys.matrices.exceptions import DMRankError, DMShapeError, DMValueError, DMDomainError + + +def _ddm_lll(x, delta=QQ(3, 4), return_transform=False): + if QQ(1, 4) >= delta or delta >= QQ(1, 1): + raise DMValueError("delta must lie in range (0.25, 1)") + if x.shape[0] > x.shape[1]: + raise DMShapeError("input matrix must have shape (m, n) with m <= n") + if x.domain != ZZ: + raise DMDomainError("input matrix domain must be ZZ") + m = x.shape[0] + n = x.shape[1] + k = 1 + y = x.copy() + y_star = x.zeros((m, n), QQ) + mu = x.zeros((m, m), QQ) + g_star = [QQ(0, 1) for _ in range(m)] + half = QQ(1, 2) + T = x.eye(m, ZZ) if return_transform else None + linear_dependent_error = "input matrix contains linearly dependent rows" + + def closest_integer(x): + return ZZ(mfloor(x + half)) + + def lovasz_condition(k: int) -> bool: + return g_star[k] >= ((delta - mu[k][k - 1] ** 2) * g_star[k - 1]) + + def mu_small(k: int, j: int) -> bool: + return abs(mu[k][j]) <= half + + def dot_rows(x, y, rows: tuple[int, int]): + return sum(x[rows[0]][z] * y[rows[1]][z] for z in range(x.shape[1])) + + def reduce_row(T, mu, y, rows: tuple[int, int]): + r = closest_integer(mu[rows[0]][rows[1]]) + y[rows[0]] = [y[rows[0]][z] - r * y[rows[1]][z] for z in range(n)] + mu[rows[0]][:rows[1]] = [mu[rows[0]][z] - r * mu[rows[1]][z] for z in range(rows[1])] + mu[rows[0]][rows[1]] -= r + if return_transform: + T[rows[0]] = [T[rows[0]][z] - r * T[rows[1]][z] for z in range(m)] + + for i in range(m): + y_star[i] = [QQ.convert_from(z, ZZ) for z in y[i]] + for j in range(i): + row_dot = dot_rows(y, y_star, (i, j)) + try: + mu[i][j] = row_dot / g_star[j] + except ZeroDivisionError: + raise DMRankError(linear_dependent_error) + y_star[i] = [y_star[i][z] - mu[i][j] * y_star[j][z] for z in range(n)] + g_star[i] = dot_rows(y_star, y_star, (i, i)) + while k < m: + if not mu_small(k, k - 1): + reduce_row(T, mu, y, (k, k - 1)) + if lovasz_condition(k): + for l in range(k - 2, -1, -1): + if not mu_small(k, l): + reduce_row(T, mu, y, (k, l)) + k += 1 + else: + nu = mu[k][k - 1] + alpha = g_star[k] + nu ** 2 * g_star[k - 1] + try: + beta = g_star[k - 1] / alpha + except ZeroDivisionError: + raise DMRankError(linear_dependent_error) + mu[k][k - 1] = nu * beta + g_star[k] = g_star[k] * beta + g_star[k - 1] = alpha + y[k], y[k - 1] = y[k - 1], y[k] + mu[k][:k - 1], mu[k - 1][:k - 1] = mu[k - 1][:k - 1], mu[k][:k - 1] + for i in range(k + 1, m): + xi = mu[i][k] + mu[i][k] = mu[i][k - 1] - nu * xi + mu[i][k - 1] = mu[k][k - 1] * mu[i][k] + xi + if return_transform: + T[k], T[k - 1] = T[k - 1], T[k] + k = max(k - 1, 1) + assert all(lovasz_condition(i) for i in range(1, m)) + assert all(mu_small(i, j) for i in range(m) for j in range(i)) + return y, T + + +def ddm_lll(x, delta=QQ(3, 4)): + return _ddm_lll(x, delta=delta, return_transform=False)[0] + + +def ddm_lll_transform(x, delta=QQ(3, 4)): + return _ddm_lll(x, delta=delta, return_transform=True) diff --git a/lib/python3.10/site-packages/sympy/polys/matrices/normalforms.py b/lib/python3.10/site-packages/sympy/polys/matrices/normalforms.py new file mode 100644 index 0000000000000000000000000000000000000000..507d9bf53163d56217a9b290bc42510445c20888 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/polys/matrices/normalforms.py @@ -0,0 +1,406 @@ +'''Functions returning normal forms of matrices''' + +from collections import defaultdict + +from .domainmatrix import DomainMatrix +from .exceptions import DMDomainError, DMShapeError +from sympy.ntheory.modular import symmetric_residue +from sympy.polys.domains import QQ, ZZ + + +# TODO (future work): +# There are faster algorithms for Smith and Hermite normal forms, which +# we should implement. See e.g. the Kannan-Bachem algorithm: +# + + +def smith_normal_form(m): + ''' + Return the Smith Normal Form of a matrix `m` over the ring `domain`. + This will only work if the ring is a principal ideal domain. + + Examples + ======== + + >>> from sympy import ZZ + >>> from sympy.polys.matrices import DomainMatrix + >>> from sympy.polys.matrices.normalforms import smith_normal_form + >>> m = DomainMatrix([[ZZ(12), ZZ(6), ZZ(4)], + ... [ZZ(3), ZZ(9), ZZ(6)], + ... [ZZ(2), ZZ(16), ZZ(14)]], (3, 3), ZZ) + >>> print(smith_normal_form(m).to_Matrix()) + Matrix([[1, 0, 0], [0, 10, 0], [0, 0, -30]]) + + ''' + invs = invariant_factors(m) + smf = DomainMatrix.diag(invs, m.domain, m.shape) + return smf + + +def add_columns(m, i, j, a, b, c, d): + # replace m[:, i] by a*m[:, i] + b*m[:, j] + # and m[:, j] by c*m[:, i] + d*m[:, j] + for k in range(len(m)): + e = m[k][i] + m[k][i] = a*e + b*m[k][j] + m[k][j] = c*e + d*m[k][j] + + +def invariant_factors(m): + ''' + Return the tuple of abelian invariants for a matrix `m` + (as in the Smith-Normal form) + + References + ========== + + [1] https://en.wikipedia.org/wiki/Smith_normal_form#Algorithm + [2] https://web.archive.org/web/20200331143852/https://sierra.nmsu.edu/morandi/notes/SmithNormalForm.pdf + + ''' + domain = m.domain + if not domain.is_PID: + msg = "The matrix entries must be over a principal ideal domain" + raise ValueError(msg) + + if 0 in m.shape: + return () + + rows, cols = shape = m.shape + m = list(m.to_dense().rep.to_ddm()) + + def add_rows(m, i, j, a, b, c, d): + # replace m[i, :] by a*m[i, :] + b*m[j, :] + # and m[j, :] by c*m[i, :] + d*m[j, :] + for k in range(cols): + e = m[i][k] + m[i][k] = a*e + b*m[j][k] + m[j][k] = c*e + d*m[j][k] + + def clear_column(m): + # make m[1:, 0] zero by row and column operations + if m[0][0] == 0: + return m # pragma: nocover + pivot = m[0][0] + for j in range(1, rows): + if m[j][0] == 0: + continue + d, r = domain.div(m[j][0], pivot) + if r == 0: + add_rows(m, 0, j, 1, 0, -d, 1) + else: + a, b, g = domain.gcdex(pivot, m[j][0]) + d_0 = domain.div(m[j][0], g)[0] + d_j = domain.div(pivot, g)[0] + add_rows(m, 0, j, a, b, d_0, -d_j) + pivot = g + return m + + def clear_row(m): + # make m[0, 1:] zero by row and column operations + if m[0][0] == 0: + return m # pragma: nocover + pivot = m[0][0] + for j in range(1, cols): + if m[0][j] == 0: + continue + d, r = domain.div(m[0][j], pivot) + if r == 0: + add_columns(m, 0, j, 1, 0, -d, 1) + else: + a, b, g = domain.gcdex(pivot, m[0][j]) + d_0 = domain.div(m[0][j], g)[0] + d_j = domain.div(pivot, g)[0] + add_columns(m, 0, j, a, b, d_0, -d_j) + pivot = g + return m + + # permute the rows and columns until m[0,0] is non-zero if possible + ind = [i for i in range(rows) if m[i][0] != 0] + if ind and ind[0] != 0: + m[0], m[ind[0]] = m[ind[0]], m[0] + else: + ind = [j for j in range(cols) if m[0][j] != 0] + if ind and ind[0] != 0: + for row in m: + row[0], row[ind[0]] = row[ind[0]], row[0] + + # make the first row and column except m[0,0] zero + while (any(m[0][i] != 0 for i in range(1,cols)) or + any(m[i][0] != 0 for i in range(1,rows))): + m = clear_column(m) + m = clear_row(m) + + if 1 in shape: + invs = () + else: + lower_right = DomainMatrix([r[1:] for r in m[1:]], (rows-1, cols-1), domain) + invs = invariant_factors(lower_right) + + if m[0][0]: + result = [m[0][0]] + result.extend(invs) + # in case m[0] doesn't divide the invariants of the rest of the matrix + for i in range(len(result)-1): + if result[i] and domain.div(result[i+1], result[i])[1] != 0: + g = domain.gcd(result[i+1], result[i]) + result[i+1] = domain.div(result[i], g)[0]*result[i+1] + result[i] = g + else: + break + else: + result = invs + (m[0][0],) + return tuple(result) + + +def _gcdex(a, b): + r""" + This supports the functions that compute Hermite Normal Form. + + Explanation + =========== + + Let x, y be the coefficients returned by the extended Euclidean + Algorithm, so that x*a + y*b = g. In the algorithms for computing HNF, + it is critical that x, y not only satisfy the condition of being small + in magnitude -- namely that |x| <= |b|/g, |y| <- |a|/g -- but also that + y == 0 when a | b. + + """ + x, y, g = ZZ.gcdex(a, b) + if a != 0 and b % a == 0: + y = 0 + x = -1 if a < 0 else 1 + return x, y, g + + +def _hermite_normal_form(A): + r""" + Compute the Hermite Normal Form of DomainMatrix *A* over :ref:`ZZ`. + + Parameters + ========== + + A : :py:class:`~.DomainMatrix` over domain :ref:`ZZ`. + + Returns + ======= + + :py:class:`~.DomainMatrix` + The HNF of matrix *A*. + + Raises + ====== + + DMDomainError + If the domain of the matrix is not :ref:`ZZ`. + + References + ========== + + .. [1] Cohen, H. *A Course in Computational Algebraic Number Theory.* + (See Algorithm 2.4.5.) + + """ + if not A.domain.is_ZZ: + raise DMDomainError('Matrix must be over domain ZZ.') + # We work one row at a time, starting from the bottom row, and working our + # way up. + m, n = A.shape + A = A.to_dense().rep.to_ddm().copy() + # Our goal is to put pivot entries in the rightmost columns. + # Invariant: Before processing each row, k should be the index of the + # leftmost column in which we have so far put a pivot. + k = n + for i in range(m - 1, -1, -1): + if k == 0: + # This case can arise when n < m and we've already found n pivots. + # We don't need to consider any more rows, because this is already + # the maximum possible number of pivots. + break + k -= 1 + # k now points to the column in which we want to put a pivot. + # We want zeros in all entries to the left of the pivot column. + for j in range(k - 1, -1, -1): + if A[i][j] != 0: + # Replace cols j, k by lin combs of these cols such that, in row i, + # col j has 0, while col k has the gcd of their row i entries. Note + # that this ensures a nonzero entry in col k. + u, v, d = _gcdex(A[i][k], A[i][j]) + r, s = A[i][k] // d, A[i][j] // d + add_columns(A, k, j, u, v, -s, r) + b = A[i][k] + # Do not want the pivot entry to be negative. + if b < 0: + add_columns(A, k, k, -1, 0, -1, 0) + b = -b + # The pivot entry will be 0 iff the row was 0 from the pivot col all the + # way to the left. In this case, we are still working on the same pivot + # col for the next row. Therefore: + if b == 0: + k += 1 + # If the pivot entry is nonzero, then we want to reduce all entries to its + # right in the sense of the division algorithm, i.e. make them all remainders + # w.r.t. the pivot as divisor. + else: + for j in range(k + 1, n): + q = A[i][j] // b + add_columns(A, j, k, 1, -q, 0, 1) + # Finally, the HNF consists of those columns of A in which we succeeded in making + # a nonzero pivot. + return DomainMatrix.from_rep(A.to_dfm_or_ddm())[:, k:] + + +def _hermite_normal_form_modulo_D(A, D): + r""" + Perform the mod *D* Hermite Normal Form reduction algorithm on + :py:class:`~.DomainMatrix` *A*. + + Explanation + =========== + + If *A* is an $m \times n$ matrix of rank $m$, having Hermite Normal Form + $W$, and if *D* is any positive integer known in advance to be a multiple + of $\det(W)$, then the HNF of *A* can be computed by an algorithm that + works mod *D* in order to prevent coefficient explosion. + + Parameters + ========== + + A : :py:class:`~.DomainMatrix` over :ref:`ZZ` + $m \times n$ matrix, having rank $m$. + D : :ref:`ZZ` + Positive integer, known to be a multiple of the determinant of the + HNF of *A*. + + Returns + ======= + + :py:class:`~.DomainMatrix` + The HNF of matrix *A*. + + Raises + ====== + + DMDomainError + If the domain of the matrix is not :ref:`ZZ`, or + if *D* is given but is not in :ref:`ZZ`. + + DMShapeError + If the matrix has more rows than columns. + + References + ========== + + .. [1] Cohen, H. *A Course in Computational Algebraic Number Theory.* + (See Algorithm 2.4.8.) + + """ + if not A.domain.is_ZZ: + raise DMDomainError('Matrix must be over domain ZZ.') + if not ZZ.of_type(D) or D < 1: + raise DMDomainError('Modulus D must be positive element of domain ZZ.') + + def add_columns_mod_R(m, R, i, j, a, b, c, d): + # replace m[:, i] by (a*m[:, i] + b*m[:, j]) % R + # and m[:, j] by (c*m[:, i] + d*m[:, j]) % R + for k in range(len(m)): + e = m[k][i] + m[k][i] = symmetric_residue((a * e + b * m[k][j]) % R, R) + m[k][j] = symmetric_residue((c * e + d * m[k][j]) % R, R) + + W = defaultdict(dict) + + m, n = A.shape + if n < m: + raise DMShapeError('Matrix must have at least as many columns as rows.') + A = A.to_dense().rep.to_ddm().copy() + k = n + R = D + for i in range(m - 1, -1, -1): + k -= 1 + for j in range(k - 1, -1, -1): + if A[i][j] != 0: + u, v, d = _gcdex(A[i][k], A[i][j]) + r, s = A[i][k] // d, A[i][j] // d + add_columns_mod_R(A, R, k, j, u, v, -s, r) + b = A[i][k] + if b == 0: + A[i][k] = b = R + u, v, d = _gcdex(b, R) + for ii in range(m): + W[ii][i] = u*A[ii][k] % R + if W[i][i] == 0: + W[i][i] = R + for j in range(i + 1, m): + q = W[i][j] // W[i][i] + add_columns(W, j, i, 1, -q, 0, 1) + R //= d + return DomainMatrix(W, (m, m), ZZ).to_dense() + + +def hermite_normal_form(A, *, D=None, check_rank=False): + r""" + Compute the Hermite Normal Form of :py:class:`~.DomainMatrix` *A* over + :ref:`ZZ`. + + Examples + ======== + + >>> from sympy import ZZ + >>> from sympy.polys.matrices import DomainMatrix + >>> from sympy.polys.matrices.normalforms import hermite_normal_form + >>> m = DomainMatrix([[ZZ(12), ZZ(6), ZZ(4)], + ... [ZZ(3), ZZ(9), ZZ(6)], + ... [ZZ(2), ZZ(16), ZZ(14)]], (3, 3), ZZ) + >>> print(hermite_normal_form(m).to_Matrix()) + Matrix([[10, 0, 2], [0, 15, 3], [0, 0, 2]]) + + Parameters + ========== + + A : $m \times n$ ``DomainMatrix`` over :ref:`ZZ`. + + D : :ref:`ZZ`, optional + Let $W$ be the HNF of *A*. If known in advance, a positive integer *D* + being any multiple of $\det(W)$ may be provided. In this case, if *A* + also has rank $m$, then we may use an alternative algorithm that works + mod *D* in order to prevent coefficient explosion. + + check_rank : boolean, optional (default=False) + The basic assumption is that, if you pass a value for *D*, then + you already believe that *A* has rank $m$, so we do not waste time + checking it for you. If you do want this to be checked (and the + ordinary, non-modulo *D* algorithm to be used if the check fails), then + set *check_rank* to ``True``. + + Returns + ======= + + :py:class:`~.DomainMatrix` + The HNF of matrix *A*. + + Raises + ====== + + DMDomainError + If the domain of the matrix is not :ref:`ZZ`, or + if *D* is given but is not in :ref:`ZZ`. + + DMShapeError + If the mod *D* algorithm is used but the matrix has more rows than + columns. + + References + ========== + + .. [1] Cohen, H. *A Course in Computational Algebraic Number Theory.* + (See Algorithms 2.4.5 and 2.4.8.) + + """ + if not A.domain.is_ZZ: + raise DMDomainError('Matrix must be over domain ZZ.') + if D is not None and (not check_rank or A.convert_to(QQ).rank() == A.shape[0]): + return _hermite_normal_form_modulo_D(A, D) + else: + return _hermite_normal_form(A) diff --git a/lib/python3.10/site-packages/sympy/polys/matrices/rref.py b/lib/python3.10/site-packages/sympy/polys/matrices/rref.py new file mode 100644 index 0000000000000000000000000000000000000000..c5a71b04971e8dc8ecac5cc2691f98ba68e35d45 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/polys/matrices/rref.py @@ -0,0 +1,422 @@ +# Algorithms for computing the reduced row echelon form of a matrix. +# +# We need to choose carefully which algorithms to use depending on the domain, +# shape, and sparsity of the matrix as well as things like the bit count in the +# case of ZZ or QQ. This is important because the algorithms have different +# performance characteristics in the extremes of dense vs sparse. +# +# In all cases we use the sparse implementations but we need to choose between +# Gauss-Jordan elimination with division and fraction-free Gauss-Jordan +# elimination. For very sparse matrices over ZZ with low bit counts it is +# asymptotically faster to use Gauss-Jordan elimination with division. For +# dense matrices with high bit counts it is asymptotically faster to use +# fraction-free Gauss-Jordan. +# +# The most important thing is to get the extreme cases right because it can +# make a big difference. In between the extremes though we have to make a +# choice and here we use empirically determined thresholds based on timings +# with random sparse matrices. +# +# In the case of QQ we have to consider the denominators as well. If the +# denominators are small then it is faster to clear them and use fraction-free +# Gauss-Jordan over ZZ. If the denominators are large then it is faster to use +# Gauss-Jordan elimination with division over QQ. +# +# Timings for the various algorithms can be found at +# +# https://github.com/sympy/sympy/issues/25410 +# https://github.com/sympy/sympy/pull/25443 + +from sympy.polys.domains import ZZ + +from sympy.polys.matrices.sdm import SDM, sdm_irref, sdm_rref_den +from sympy.polys.matrices.ddm import DDM +from sympy.polys.matrices.dense import ddm_irref, ddm_irref_den + + +def _dm_rref(M, *, method='auto'): + """ + Compute the reduced row echelon form of a ``DomainMatrix``. + + This function is the implementation of :meth:`DomainMatrix.rref`. + + Chooses the best algorithm depending on the domain, shape, and sparsity of + the matrix as well as things like the bit count in the case of :ref:`ZZ` or + :ref:`QQ`. The result is returned over the field associated with the domain + of the Matrix. + + See Also + ======== + + sympy.polys.matrices.domainmatrix.DomainMatrix.rref + The ``DomainMatrix`` method that calls this function. + sympy.polys.matrices.rref._dm_rref_den + Alternative function for computing RREF with denominator. + """ + method, use_fmt = _dm_rref_choose_method(M, method, denominator=False) + + M, old_fmt = _dm_to_fmt(M, use_fmt) + + if method == 'GJ': + # Use Gauss-Jordan with division over the associated field. + Mf = _to_field(M) + M_rref, pivots = _dm_rref_GJ(Mf) + + elif method == 'FF': + # Use fraction-free GJ over the current domain. + M_rref_f, den, pivots = _dm_rref_den_FF(M) + M_rref = _to_field(M_rref_f) / den + + elif method == 'CD': + # Clear denominators and use fraction-free GJ in the associated ring. + _, Mr = M.clear_denoms_rowwise(convert=True) + M_rref_f, den, pivots = _dm_rref_den_FF(Mr) + M_rref = _to_field(M_rref_f) / den + + else: + raise ValueError(f"Unknown method for rref: {method}") + + M_rref, _ = _dm_to_fmt(M_rref, old_fmt) + + # Invariants: + # - M_rref is in the same format (sparse or dense) as the input matrix. + # - M_rref is in the associated field domain and any denominator was + # divided in (so is implicitly 1 now). + + return M_rref, pivots + + +def _dm_rref_den(M, *, keep_domain=True, method='auto'): + """ + Compute the reduced row echelon form of a ``DomainMatrix`` with denominator. + + This function is the implementation of :meth:`DomainMatrix.rref_den`. + + Chooses the best algorithm depending on the domain, shape, and sparsity of + the matrix as well as things like the bit count in the case of :ref:`ZZ` or + :ref:`QQ`. The result is returned over the same domain as the input matrix + unless ``keep_domain=False`` in which case the result might be over an + associated ring or field domain. + + See Also + ======== + + sympy.polys.matrices.domainmatrix.DomainMatrix.rref_den + The ``DomainMatrix`` method that calls this function. + sympy.polys.matrices.rref._dm_rref + Alternative function for computing RREF without denominator. + """ + method, use_fmt = _dm_rref_choose_method(M, method, denominator=True) + + M, old_fmt = _dm_to_fmt(M, use_fmt) + + if method == 'FF': + # Use fraction-free GJ over the current domain. + M_rref, den, pivots = _dm_rref_den_FF(M) + + elif method == 'GJ': + # Use Gauss-Jordan with division over the associated field. + M_rref_f, pivots = _dm_rref_GJ(_to_field(M)) + + # Convert back to the ring? + if keep_domain and M_rref_f.domain != M.domain: + _, M_rref = M_rref_f.clear_denoms(convert=True) + + if pivots: + den = M_rref[0, pivots[0]].element + else: + den = M_rref.domain.one + else: + # Possibly an associated field + M_rref = M_rref_f + den = M_rref.domain.one + + elif method == 'CD': + # Clear denominators and use fraction-free GJ in the associated ring. + _, Mr = M.clear_denoms_rowwise(convert=True) + + M_rref_r, den, pivots = _dm_rref_den_FF(Mr) + + if keep_domain and M_rref_r.domain != M.domain: + # Convert back to the field + M_rref = _to_field(M_rref_r) / den + den = M.domain.one + else: + # Possibly an associated ring + M_rref = M_rref_r + + if pivots: + den = M_rref[0, pivots[0]].element + else: + den = M_rref.domain.one + else: + raise ValueError(f"Unknown method for rref: {method}") + + M_rref, _ = _dm_to_fmt(M_rref, old_fmt) + + # Invariants: + # - M_rref is in the same format (sparse or dense) as the input matrix. + # - If keep_domain=True then M_rref and den are in the same domain as the + # input matrix + # - If keep_domain=False then M_rref might be in an associated ring or + # field domain but den is always in the same domain as M_rref. + + return M_rref, den, pivots + + +def _dm_to_fmt(M, fmt): + """Convert a matrix to the given format and return the old format.""" + old_fmt = M.rep.fmt + if old_fmt == fmt: + pass + elif fmt == 'dense': + M = M.to_dense() + elif fmt == 'sparse': + M = M.to_sparse() + else: + raise ValueError(f'Unknown format: {fmt}') # pragma: no cover + return M, old_fmt + + +# These are the four basic implementations that we want to choose between: + + +def _dm_rref_GJ(M): + """Compute RREF using Gauss-Jordan elimination with division.""" + if M.rep.fmt == 'sparse': + return _dm_rref_GJ_sparse(M) + else: + return _dm_rref_GJ_dense(M) + + +def _dm_rref_den_FF(M): + """Compute RREF using fraction-free Gauss-Jordan elimination.""" + if M.rep.fmt == 'sparse': + return _dm_rref_den_FF_sparse(M) + else: + return _dm_rref_den_FF_dense(M) + + +def _dm_rref_GJ_sparse(M): + """Compute RREF using sparse Gauss-Jordan elimination with division.""" + M_rref_d, pivots, _ = sdm_irref(M.rep) + M_rref_sdm = SDM(M_rref_d, M.shape, M.domain) + pivots = tuple(pivots) + return M.from_rep(M_rref_sdm), pivots + + +def _dm_rref_GJ_dense(M): + """Compute RREF using dense Gauss-Jordan elimination with division.""" + partial_pivot = M.domain.is_RR or M.domain.is_CC + ddm = M.rep.to_ddm().copy() + pivots = ddm_irref(ddm, _partial_pivot=partial_pivot) + M_rref_ddm = DDM(ddm, M.shape, M.domain) + pivots = tuple(pivots) + return M.from_rep(M_rref_ddm.to_dfm_or_ddm()), pivots + + +def _dm_rref_den_FF_sparse(M): + """Compute RREF using sparse fraction-free Gauss-Jordan elimination.""" + M_rref_d, den, pivots = sdm_rref_den(M.rep, M.domain) + M_rref_sdm = SDM(M_rref_d, M.shape, M.domain) + pivots = tuple(pivots) + return M.from_rep(M_rref_sdm), den, pivots + + +def _dm_rref_den_FF_dense(M): + """Compute RREF using sparse fraction-free Gauss-Jordan elimination.""" + ddm = M.rep.to_ddm().copy() + den, pivots = ddm_irref_den(ddm, M.domain) + M_rref_ddm = DDM(ddm, M.shape, M.domain) + pivots = tuple(pivots) + return M.from_rep(M_rref_ddm.to_dfm_or_ddm()), den, pivots + + +def _dm_rref_choose_method(M, method, *, denominator=False): + """Choose the fastest method for computing RREF for M.""" + + if method != 'auto': + if method.endswith('_dense'): + method = method[:-len('_dense')] + use_fmt = 'dense' + else: + use_fmt = 'sparse' + + else: + # The sparse implementations are always faster + use_fmt = 'sparse' + + K = M.domain + + if K.is_ZZ: + method = _dm_rref_choose_method_ZZ(M, denominator=denominator) + elif K.is_QQ: + method = _dm_rref_choose_method_QQ(M, denominator=denominator) + elif K.is_RR or K.is_CC: + # TODO: Add partial pivot support to the sparse implementations. + method = 'GJ' + use_fmt = 'dense' + elif K.is_EX and M.rep.fmt == 'dense' and not denominator: + # Do not switch to the sparse implementation for EX because the + # domain does not have proper canonicalization and the sparse + # implementation gives equivalent but non-identical results over EX + # from performing arithmetic in a different order. Specifically + # test_issue_23718 ends up getting a more complicated expression + # when using the sparse implementation. Probably the best fix for + # this is something else but for now we stick with the dense + # implementation for EX if the matrix is already dense. + method = 'GJ' + use_fmt = 'dense' + else: + # This is definitely suboptimal. More work is needed to determine + # the best method for computing RREF over different domains. + if denominator: + method = 'FF' + else: + method = 'GJ' + + return method, use_fmt + + +def _dm_rref_choose_method_QQ(M, *, denominator=False): + """Choose the fastest method for computing RREF over QQ.""" + # The same sorts of considerations apply here as in the case of ZZ. Here + # though a new more significant consideration is what sort of denominators + # we have and what to do with them so we focus on that. + + # First compute the density. This is the average number of non-zero entries + # per row but only counting rows that have at least one non-zero entry + # since RREF can ignore fully zero rows. + density, _, ncols = _dm_row_density(M) + + # For sparse matrices use Gauss-Jordan elimination over QQ regardless. + if density < min(5, ncols/2): + return 'GJ' + + # Compare the bit-length of the lcm of the denominators to the bit length + # of the numerators. + # + # The threshold here is empirical: we prefer rref over QQ if clearing + # denominators would result in a numerator matrix having 5x the bit size of + # the current numerators. + numers, denoms = _dm_QQ_numers_denoms(M) + numer_bits = max([n.bit_length() for n in numers], default=1) + + denom_lcm = ZZ.one + for d in denoms: + denom_lcm = ZZ.lcm(denom_lcm, d) + if denom_lcm.bit_length() > 5*numer_bits: + return 'GJ' + + # If we get here then the matrix is dense and the lcm of the denominators + # is not too large compared to the numerators. For particularly small + # denominators it is fastest just to clear them and use fraction-free + # Gauss-Jordan over ZZ. With very small denominators this is a little + # faster than using rref_den over QQ but there is an intermediate regime + # where rref_den over QQ is significantly faster. The small denominator + # case is probably very common because small fractions like 1/2 or 1/3 are + # often seen in user inputs. + + if denom_lcm.bit_length() < 50: + return 'CD' + else: + return 'FF' + + +def _dm_rref_choose_method_ZZ(M, *, denominator=False): + """Choose the fastest method for computing RREF over ZZ.""" + # In the extreme of very sparse matrices and low bit counts it is faster to + # use Gauss-Jordan elimination over QQ rather than fraction-free + # Gauss-Jordan over ZZ. In the opposite extreme of dense matrices and high + # bit counts it is faster to use fraction-free Gauss-Jordan over ZZ. These + # two extreme cases need to be handled differently because they lead to + # different asymptotic complexities. In between these two extremes we need + # a threshold for deciding which method to use. This threshold is + # determined empirically by timing the two methods with random matrices. + + # The disadvantage of using empirical timings is that future optimisations + # might change the relative speeds so this can easily become out of date. + # The main thing is to get the asymptotic complexity right for the extreme + # cases though so the precise value of the threshold is hopefully not too + # important. + + # Empirically determined parameter. + PARAM = 10000 + + # First compute the density. This is the average number of non-zero entries + # per row but only counting rows that have at least one non-zero entry + # since RREF can ignore fully zero rows. + density, nrows_nz, ncols = _dm_row_density(M) + + # For small matrices use QQ if more than half the entries are zero. + if nrows_nz < 10: + if density < ncols/2: + return 'GJ' + else: + return 'FF' + + # These are just shortcuts for the formula below. + if density < 5: + return 'GJ' + elif density > 5 + PARAM/nrows_nz: + return 'FF' # pragma: no cover + + # Maximum bitsize of any entry. + elements = _dm_elements(M) + bits = max([e.bit_length() for e in elements], default=1) + + # Wideness parameter. This is 1 for square or tall matrices but >1 for wide + # matrices. + wideness = max(1, 2/3*ncols/nrows_nz) + + max_density = (5 + PARAM/(nrows_nz*bits**2)) * wideness + + if density < max_density: + return 'GJ' + else: + return 'FF' + + +def _dm_row_density(M): + """Density measure for sparse matrices. + + Defines the "density", ``d`` as the average number of non-zero entries per + row except ignoring rows that are fully zero. RREF can ignore fully zero + rows so they are excluded. By definition ``d >= 1`` except that we define + ``d = 0`` for the zero matrix. + + Returns ``(density, nrows_nz, ncols)`` where ``nrows_nz`` counts the number + of nonzero rows and ``ncols`` is the number of columns. + """ + # Uses the SDM dict-of-dicts representation. + ncols = M.shape[1] + rows_nz = M.rep.to_sdm().values() + if not rows_nz: + return 0, 0, ncols + else: + nrows_nz = len(rows_nz) + density = sum(map(len, rows_nz)) / nrows_nz + return density, nrows_nz, ncols + + +def _dm_elements(M): + """Return nonzero elements of a DomainMatrix.""" + elements, _ = M.to_flat_nz() + return elements + + +def _dm_QQ_numers_denoms(Mq): + """Returns the numerators and denominators of a DomainMatrix over QQ.""" + elements = _dm_elements(Mq) + numers = [e.numerator for e in elements] + denoms = [e.denominator for e in elements] + return numers, denoms + + +def _to_field(M): + """Convert a DomainMatrix to a field if possible.""" + K = M.domain + if K.has_assoc_Field: + return M.to_field() + else: + return M diff --git a/lib/python3.10/site-packages/sympy/polys/matrices/sdm.py b/lib/python3.10/site-packages/sympy/polys/matrices/sdm.py new file mode 100644 index 0000000000000000000000000000000000000000..5afe39adfad5b28dad209742d34af9bf9cac6991 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/polys/matrices/sdm.py @@ -0,0 +1,2164 @@ +""" + +Module for the SDM class. + +""" + +from operator import add, neg, pos, sub, mul +from collections import defaultdict + +from sympy.external.gmpy import GROUND_TYPES +from sympy.utilities.decorator import doctest_depends_on +from sympy.utilities.iterables import _strongly_connected_components + +from .exceptions import DMBadInputError, DMDomainError, DMShapeError + +from sympy.polys.domains import QQ + +from .ddm import DDM + + +if GROUND_TYPES != 'flint': + __doctest_skip__ = ['SDM.to_dfm', 'SDM.to_dfm_or_ddm'] + + +class SDM(dict): + r"""Sparse matrix based on polys domain elements + + This is a dict subclass and is a wrapper for a dict of dicts that supports + basic matrix arithmetic +, -, *, **. + + + In order to create a new :py:class:`~.SDM`, a dict + of dicts mapping non-zero elements to their + corresponding row and column in the matrix is needed. + + We also need to specify the shape and :py:class:`~.Domain` + of our :py:class:`~.SDM` object. + + We declare a 2x2 :py:class:`~.SDM` matrix belonging + to QQ domain as shown below. + The 2x2 Matrix in the example is + + .. math:: + A = \left[\begin{array}{ccc} + 0 & \frac{1}{2} \\ + 0 & 0 \end{array} \right] + + + >>> from sympy.polys.matrices.sdm import SDM + >>> from sympy import QQ + >>> elemsdict = {0:{1:QQ(1, 2)}} + >>> A = SDM(elemsdict, (2, 2), QQ) + >>> A + {0: {1: 1/2}} + + We can manipulate :py:class:`~.SDM` the same way + as a Matrix class + + >>> from sympy import ZZ + >>> A = SDM({0:{1: ZZ(2)}, 1:{0:ZZ(1)}}, (2, 2), ZZ) + >>> B = SDM({0:{0: ZZ(3)}, 1:{1:ZZ(4)}}, (2, 2), ZZ) + >>> A + B + {0: {0: 3, 1: 2}, 1: {0: 1, 1: 4}} + + Multiplication + + >>> A*B + {0: {1: 8}, 1: {0: 3}} + >>> A*ZZ(2) + {0: {1: 4}, 1: {0: 2}} + + """ + + fmt = 'sparse' + is_DFM = False + is_DDM = False + + def __init__(self, elemsdict, shape, domain): + super().__init__(elemsdict) + self.shape = self.rows, self.cols = m, n = shape + self.domain = domain + + if not all(0 <= r < m for r in self): + raise DMBadInputError("Row out of range") + if not all(0 <= c < n for row in self.values() for c in row): + raise DMBadInputError("Column out of range") + + def getitem(self, i, j): + try: + return self[i][j] + except KeyError: + m, n = self.shape + if -m <= i < m and -n <= j < n: + try: + return self[i % m][j % n] + except KeyError: + return self.domain.zero + else: + raise IndexError("index out of range") + + def setitem(self, i, j, value): + m, n = self.shape + if not (-m <= i < m and -n <= j < n): + raise IndexError("index out of range") + i, j = i % m, j % n + if value: + try: + self[i][j] = value + except KeyError: + self[i] = {j: value} + else: + rowi = self.get(i, None) + if rowi is not None: + try: + del rowi[j] + except KeyError: + pass + else: + if not rowi: + del self[i] + + def extract_slice(self, slice1, slice2): + m, n = self.shape + ri = range(m)[slice1] + ci = range(n)[slice2] + + sdm = {} + for i, row in self.items(): + if i in ri: + row = {ci.index(j): e for j, e in row.items() if j in ci} + if row: + sdm[ri.index(i)] = row + + return self.new(sdm, (len(ri), len(ci)), self.domain) + + def extract(self, rows, cols): + if not (self and rows and cols): + return self.zeros((len(rows), len(cols)), self.domain) + + m, n = self.shape + if not (-m <= min(rows) <= max(rows) < m): + raise IndexError('Row index out of range') + if not (-n <= min(cols) <= max(cols) < n): + raise IndexError('Column index out of range') + + # rows and cols can contain duplicates e.g. M[[1, 2, 2], [0, 1]] + # Build a map from row/col in self to list of rows/cols in output + rowmap = defaultdict(list) + colmap = defaultdict(list) + for i2, i1 in enumerate(rows): + rowmap[i1 % m].append(i2) + for j2, j1 in enumerate(cols): + colmap[j1 % n].append(j2) + + # Used to efficiently skip zero rows/cols + rowset = set(rowmap) + colset = set(colmap) + + sdm1 = self + sdm2 = {} + for i1 in rowset & sdm1.keys(): + row1 = sdm1[i1] + row2 = {} + for j1 in colset & row1.keys(): + row1_j1 = row1[j1] + for j2 in colmap[j1]: + row2[j2] = row1_j1 + if row2: + for i2 in rowmap[i1]: + sdm2[i2] = row2.copy() + + return self.new(sdm2, (len(rows), len(cols)), self.domain) + + def __str__(self): + rowsstr = [] + for i, row in self.items(): + elemsstr = ', '.join('%s: %s' % (j, elem) for j, elem in row.items()) + rowsstr.append('%s: {%s}' % (i, elemsstr)) + return '{%s}' % ', '.join(rowsstr) + + def __repr__(self): + cls = type(self).__name__ + rows = dict.__repr__(self) + return '%s(%s, %s, %s)' % (cls, rows, self.shape, self.domain) + + @classmethod + def new(cls, sdm, shape, domain): + """ + + Parameters + ========== + + sdm: A dict of dicts for non-zero elements in SDM + shape: tuple representing dimension of SDM + domain: Represents :py:class:`~.Domain` of SDM + + Returns + ======= + + An :py:class:`~.SDM` object + + Examples + ======== + + >>> from sympy.polys.matrices.sdm import SDM + >>> from sympy import QQ + >>> elemsdict = {0:{1: QQ(2)}} + >>> A = SDM.new(elemsdict, (2, 2), QQ) + >>> A + {0: {1: 2}} + + """ + return cls(sdm, shape, domain) + + def copy(A): + """ + Returns the copy of a :py:class:`~.SDM` object + + Examples + ======== + + >>> from sympy.polys.matrices.sdm import SDM + >>> from sympy import QQ + >>> elemsdict = {0:{1:QQ(2)}, 1:{}} + >>> A = SDM(elemsdict, (2, 2), QQ) + >>> B = A.copy() + >>> B + {0: {1: 2}, 1: {}} + + """ + Ac = {i: Ai.copy() for i, Ai in A.items()} + return A.new(Ac, A.shape, A.domain) + + @classmethod + def from_list(cls, ddm, shape, domain): + """ + Create :py:class:`~.SDM` object from a list of lists. + + Parameters + ========== + + ddm: + list of lists containing domain elements + shape: + Dimensions of :py:class:`~.SDM` matrix + domain: + Represents :py:class:`~.Domain` of :py:class:`~.SDM` object + + Returns + ======= + + :py:class:`~.SDM` containing elements of ddm + + Examples + ======== + + >>> from sympy.polys.matrices.sdm import SDM + >>> from sympy import QQ + >>> ddm = [[QQ(1, 2), QQ(0)], [QQ(0), QQ(3, 4)]] + >>> A = SDM.from_list(ddm, (2, 2), QQ) + >>> A + {0: {0: 1/2}, 1: {1: 3/4}} + + See Also + ======== + + to_list + from_list_flat + from_dok + from_ddm + """ + + m, n = shape + if not (len(ddm) == m and all(len(row) == n for row in ddm)): + raise DMBadInputError("Inconsistent row-list/shape") + getrow = lambda i: {j:ddm[i][j] for j in range(n) if ddm[i][j]} + irows = ((i, getrow(i)) for i in range(m)) + sdm = {i: row for i, row in irows if row} + return cls(sdm, shape, domain) + + @classmethod + def from_ddm(cls, ddm): + """ + Create :py:class:`~.SDM` from a :py:class:`~.DDM`. + + Examples + ======== + + >>> from sympy.polys.matrices.ddm import DDM + >>> from sympy.polys.matrices.sdm import SDM + >>> from sympy import QQ + >>> ddm = DDM( [[QQ(1, 2), 0], [0, QQ(3, 4)]], (2, 2), QQ) + >>> A = SDM.from_ddm(ddm) + >>> A + {0: {0: 1/2}, 1: {1: 3/4}} + >>> SDM.from_ddm(ddm).to_ddm() == ddm + True + + See Also + ======== + + to_ddm + from_list + from_list_flat + from_dok + """ + return cls.from_list(ddm, ddm.shape, ddm.domain) + + def to_list(M): + """ + Convert a :py:class:`~.SDM` object to a list of lists. + + Examples + ======== + + >>> from sympy.polys.matrices.sdm import SDM + >>> from sympy import QQ + >>> elemsdict = {0:{1:QQ(2)}, 1:{}} + >>> A = SDM(elemsdict, (2, 2), QQ) + >>> A.to_list() + [[0, 2], [0, 0]] + + + """ + m, n = M.shape + zero = M.domain.zero + ddm = [[zero] * n for _ in range(m)] + for i, row in M.items(): + for j, e in row.items(): + ddm[i][j] = e + return ddm + + def to_list_flat(M): + """ + Convert :py:class:`~.SDM` to a flat list. + + Examples + ======== + + >>> from sympy.polys.matrices.sdm import SDM + >>> from sympy import QQ + >>> A = SDM({0:{1:QQ(2)}, 1:{0: QQ(3)}}, (2, 2), QQ) + >>> A.to_list_flat() + [0, 2, 3, 0] + >>> A == A.from_list_flat(A.to_list_flat(), A.shape, A.domain) + True + + See Also + ======== + + from_list_flat + to_list + to_dok + to_ddm + """ + m, n = M.shape + zero = M.domain.zero + flat = [zero] * (m * n) + for i, row in M.items(): + for j, e in row.items(): + flat[i*n + j] = e + return flat + + @classmethod + def from_list_flat(cls, elements, shape, domain): + """ + Create :py:class:`~.SDM` from a flat list of elements. + + Examples + ======== + + >>> from sympy.polys.matrices.sdm import SDM + >>> from sympy import QQ + >>> A = SDM.from_list_flat([QQ(0), QQ(2), QQ(0), QQ(0)], (2, 2), QQ) + >>> A + {0: {1: 2}} + >>> A == A.from_list_flat(A.to_list_flat(), A.shape, A.domain) + True + + See Also + ======== + + to_list_flat + from_list + from_dok + from_ddm + """ + m, n = shape + if len(elements) != m * n: + raise DMBadInputError("Inconsistent flat-list shape") + sdm = defaultdict(dict) + for inj, element in enumerate(elements): + if element: + i, j = divmod(inj, n) + sdm[i][j] = element + return cls(sdm, shape, domain) + + def to_flat_nz(M): + """ + Convert :class:`SDM` to a flat list of nonzero elements and data. + + Explanation + =========== + + This is used to operate on a list of the elements of a matrix and then + reconstruct a modified matrix with elements in the same positions using + :meth:`from_flat_nz`. Zero elements are omitted from the list. + + Examples + ======== + + >>> from sympy.polys.matrices.sdm import SDM + >>> from sympy import QQ + >>> A = SDM({0:{1:QQ(2)}, 1:{0: QQ(3)}}, (2, 2), QQ) + >>> elements, data = A.to_flat_nz() + >>> elements + [2, 3] + >>> A == A.from_flat_nz(elements, data, A.domain) + True + + See Also + ======== + + from_flat_nz + to_list_flat + sympy.polys.matrices.ddm.DDM.to_flat_nz + sympy.polys.matrices.domainmatrix.DomainMatrix.to_flat_nz + """ + dok = M.to_dok() + indices = tuple(dok) + elements = list(dok.values()) + data = (indices, M.shape) + return elements, data + + @classmethod + def from_flat_nz(cls, elements, data, domain): + """ + Reconstruct a :class:`~.SDM` after calling :meth:`to_flat_nz`. + + See :meth:`to_flat_nz` for explanation. + + See Also + ======== + + to_flat_nz + from_list_flat + sympy.polys.matrices.ddm.DDM.from_flat_nz + sympy.polys.matrices.domainmatrix.DomainMatrix.from_flat_nz + """ + indices, shape = data + dok = dict(zip(indices, elements)) + return cls.from_dok(dok, shape, domain) + + def to_dod(M): + """ + Convert to dictionary of dictionaries (dod) format. + + Examples + ======== + + >>> from sympy.polys.matrices.sdm import SDM + >>> from sympy import QQ + >>> A = SDM({0: {1: QQ(2)}, 1: {0: QQ(3)}}, (2, 2), QQ) + >>> A.to_dod() + {0: {1: 2}, 1: {0: 3}} + + See Also + ======== + + from_dod + sympy.polys.matrices.domainmatrix.DomainMatrix.to_dod + """ + return {i: row.copy() for i, row in M.items()} + + @classmethod + def from_dod(cls, dod, shape, domain): + """ + Create :py:class:`~.SDM` from dictionary of dictionaries (dod) format. + + Examples + ======== + + >>> from sympy.polys.matrices.sdm import SDM + >>> from sympy import QQ + >>> dod = {0: {1: QQ(2)}, 1: {0: QQ(3)}} + >>> A = SDM.from_dod(dod, (2, 2), QQ) + >>> A + {0: {1: 2}, 1: {0: 3}} + >>> A == SDM.from_dod(A.to_dod(), A.shape, A.domain) + True + + See Also + ======== + + to_dod + sympy.polys.matrices.domainmatrix.DomainMatrix.to_dod + """ + sdm = defaultdict(dict) + for i, row in dod.items(): + for j, e in row.items(): + if e: + sdm[i][j] = e + return cls(sdm, shape, domain) + + def to_dok(M): + """ + Convert to dictionary of keys (dok) format. + + Examples + ======== + + >>> from sympy.polys.matrices.sdm import SDM + >>> from sympy import QQ + >>> A = SDM({0: {1: QQ(2)}, 1: {0: QQ(3)}}, (2, 2), QQ) + >>> A.to_dok() + {(0, 1): 2, (1, 0): 3} + + See Also + ======== + + from_dok + to_list + to_list_flat + to_ddm + """ + return {(i, j): e for i, row in M.items() for j, e in row.items()} + + @classmethod + def from_dok(cls, dok, shape, domain): + """ + Create :py:class:`~.SDM` from dictionary of keys (dok) format. + + Examples + ======== + + >>> from sympy.polys.matrices.sdm import SDM + >>> from sympy import QQ + >>> dok = {(0, 1): QQ(2), (1, 0): QQ(3)} + >>> A = SDM.from_dok(dok, (2, 2), QQ) + >>> A + {0: {1: 2}, 1: {0: 3}} + >>> A == SDM.from_dok(A.to_dok(), A.shape, A.domain) + True + + See Also + ======== + + to_dok + from_list + from_list_flat + from_ddm + """ + sdm = defaultdict(dict) + for (i, j), e in dok.items(): + if e: + sdm[i][j] = e + return cls(sdm, shape, domain) + + def iter_values(M): + """ + Iterate over the nonzero values of a :py:class:`~.SDM` matrix. + + Examples + ======== + + >>> from sympy.polys.matrices.sdm import SDM + >>> from sympy import QQ + >>> A = SDM({0: {1: QQ(2)}, 1: {0: QQ(3)}}, (2, 2), QQ) + >>> list(A.iter_values()) + [2, 3] + + """ + for row in M.values(): + yield from row.values() + + def iter_items(M): + """ + Iterate over indices and values of the nonzero elements. + + Examples + ======== + + >>> from sympy.polys.matrices.sdm import SDM + >>> from sympy import QQ + >>> A = SDM({0: {1: QQ(2)}, 1: {0: QQ(3)}}, (2, 2), QQ) + >>> list(A.iter_items()) + [((0, 1), 2), ((1, 0), 3)] + + See Also + ======== + + sympy.polys.matrices.domainmatrix.DomainMatrix.iter_items + """ + for i, row in M.items(): + for j, e in row.items(): + yield (i, j), e + + def to_ddm(M): + """ + Convert a :py:class:`~.SDM` object to a :py:class:`~.DDM` object + + Examples + ======== + + >>> from sympy.polys.matrices.sdm import SDM + >>> from sympy import QQ + >>> A = SDM({0:{1:QQ(2)}, 1:{}}, (2, 2), QQ) + >>> A.to_ddm() + [[0, 2], [0, 0]] + + """ + return DDM(M.to_list(), M.shape, M.domain) + + def to_sdm(M): + """ + Convert to :py:class:`~.SDM` format (returns self). + """ + return M + + @doctest_depends_on(ground_types=['flint']) + def to_dfm(M): + """ + Convert a :py:class:`~.SDM` object to a :py:class:`~.DFM` object + + Examples + ======== + + >>> from sympy.polys.matrices.sdm import SDM + >>> from sympy import QQ + >>> A = SDM({0:{1:QQ(2)}, 1:{}}, (2, 2), QQ) + >>> A.to_dfm() + [[0, 2], [0, 0]] + + See Also + ======== + + to_ddm + to_dfm_or_ddm + sympy.polys.matrices.domainmatrix.DomainMatrix.to_dfm + """ + return M.to_ddm().to_dfm() + + @doctest_depends_on(ground_types=['flint']) + def to_dfm_or_ddm(M): + """ + Convert to :py:class:`~.DFM` if possible, else :py:class:`~.DDM`. + + Examples + ======== + + >>> from sympy.polys.matrices.sdm import SDM + >>> from sympy import QQ + >>> A = SDM({0:{1:QQ(2)}, 1:{}}, (2, 2), QQ) + >>> A.to_dfm_or_ddm() + [[0, 2], [0, 0]] + >>> type(A.to_dfm_or_ddm()) # depends on the ground types + + + See Also + ======== + + to_ddm + to_dfm + sympy.polys.matrices.domainmatrix.DomainMatrix.to_dfm_or_ddm + """ + return M.to_ddm().to_dfm_or_ddm() + + @classmethod + def zeros(cls, shape, domain): + r""" + + Returns a :py:class:`~.SDM` of size shape, + belonging to the specified domain + + In the example below we declare a matrix A where, + + .. math:: + A := \left[\begin{array}{ccc} + 0 & 0 & 0 \\ + 0 & 0 & 0 \end{array} \right] + + >>> from sympy.polys.matrices.sdm import SDM + >>> from sympy import QQ + >>> A = SDM.zeros((2, 3), QQ) + >>> A + {} + + """ + return cls({}, shape, domain) + + @classmethod + def ones(cls, shape, domain): + one = domain.one + m, n = shape + row = dict(zip(range(n), [one]*n)) + sdm = {i: row.copy() for i in range(m)} + return cls(sdm, shape, domain) + + @classmethod + def eye(cls, shape, domain): + """ + + Returns a identity :py:class:`~.SDM` matrix of dimensions + size x size, belonging to the specified domain + + Examples + ======== + + >>> from sympy.polys.matrices.sdm import SDM + >>> from sympy import QQ + >>> I = SDM.eye((2, 2), QQ) + >>> I + {0: {0: 1}, 1: {1: 1}} + + """ + if isinstance(shape, int): + rows, cols = shape, shape + else: + rows, cols = shape + one = domain.one + sdm = {i: {i: one} for i in range(min(rows, cols))} + return cls(sdm, (rows, cols), domain) + + @classmethod + def diag(cls, diagonal, domain, shape=None): + if shape is None: + shape = (len(diagonal), len(diagonal)) + sdm = {i: {i: v} for i, v in enumerate(diagonal) if v} + return cls(sdm, shape, domain) + + def transpose(M): + """ + + Returns the transpose of a :py:class:`~.SDM` matrix + + Examples + ======== + + >>> from sympy.polys.matrices.sdm import SDM + >>> from sympy import QQ + >>> A = SDM({0:{1:QQ(2)}, 1:{}}, (2, 2), QQ) + >>> A.transpose() + {1: {0: 2}} + + """ + MT = sdm_transpose(M) + return M.new(MT, M.shape[::-1], M.domain) + + def __add__(A, B): + if not isinstance(B, SDM): + return NotImplemented + elif A.shape != B.shape: + raise DMShapeError("Matrix size mismatch: %s + %s" % (A.shape, B.shape)) + return A.add(B) + + def __sub__(A, B): + if not isinstance(B, SDM): + return NotImplemented + elif A.shape != B.shape: + raise DMShapeError("Matrix size mismatch: %s - %s" % (A.shape, B.shape)) + return A.sub(B) + + def __neg__(A): + return A.neg() + + def __mul__(A, B): + """A * B""" + if isinstance(B, SDM): + return A.matmul(B) + elif B in A.domain: + return A.mul(B) + else: + return NotImplemented + + def __rmul__(a, b): + if b in a.domain: + return a.rmul(b) + else: + return NotImplemented + + def matmul(A, B): + """ + Performs matrix multiplication of two SDM matrices + + Parameters + ========== + + A, B: SDM to multiply + + Returns + ======= + + SDM + SDM after multiplication + + Raises + ====== + + DomainError + If domain of A does not match + with that of B + + Examples + ======== + + >>> from sympy import ZZ + >>> from sympy.polys.matrices.sdm import SDM + >>> A = SDM({0:{1: ZZ(2)}, 1:{0:ZZ(1)}}, (2, 2), ZZ) + >>> B = SDM({0:{0:ZZ(2), 1:ZZ(3)}, 1:{0:ZZ(4)}}, (2, 2), ZZ) + >>> A.matmul(B) + {0: {0: 8}, 1: {0: 2, 1: 3}} + + """ + if A.domain != B.domain: + raise DMDomainError + m, n = A.shape + n2, o = B.shape + if n != n2: + raise DMShapeError + C = sdm_matmul(A, B, A.domain, m, o) + return A.new(C, (m, o), A.domain) + + def mul(A, b): + """ + Multiplies each element of A with a scalar b + + Examples + ======== + + >>> from sympy import ZZ + >>> from sympy.polys.matrices.sdm import SDM + >>> A = SDM({0:{1: ZZ(2)}, 1:{0:ZZ(1)}}, (2, 2), ZZ) + >>> A.mul(ZZ(3)) + {0: {1: 6}, 1: {0: 3}} + + """ + Csdm = unop_dict(A, lambda aij: aij*b) + return A.new(Csdm, A.shape, A.domain) + + def rmul(A, b): + Csdm = unop_dict(A, lambda aij: b*aij) + return A.new(Csdm, A.shape, A.domain) + + def mul_elementwise(A, B): + if A.domain != B.domain: + raise DMDomainError + if A.shape != B.shape: + raise DMShapeError + zero = A.domain.zero + fzero = lambda e: zero + Csdm = binop_dict(A, B, mul, fzero, fzero) + return A.new(Csdm, A.shape, A.domain) + + def add(A, B): + """ + + Adds two :py:class:`~.SDM` matrices + + Examples + ======== + + >>> from sympy import ZZ + >>> from sympy.polys.matrices.sdm import SDM + >>> A = SDM({0:{1: ZZ(2)}, 1:{0:ZZ(1)}}, (2, 2), ZZ) + >>> B = SDM({0:{0: ZZ(3)}, 1:{1:ZZ(4)}}, (2, 2), ZZ) + >>> A.add(B) + {0: {0: 3, 1: 2}, 1: {0: 1, 1: 4}} + + """ + Csdm = binop_dict(A, B, add, pos, pos) + return A.new(Csdm, A.shape, A.domain) + + def sub(A, B): + """ + + Subtracts two :py:class:`~.SDM` matrices + + Examples + ======== + + >>> from sympy import ZZ + >>> from sympy.polys.matrices.sdm import SDM + >>> A = SDM({0:{1: ZZ(2)}, 1:{0:ZZ(1)}}, (2, 2), ZZ) + >>> B = SDM({0:{0: ZZ(3)}, 1:{1:ZZ(4)}}, (2, 2), ZZ) + >>> A.sub(B) + {0: {0: -3, 1: 2}, 1: {0: 1, 1: -4}} + + """ + Csdm = binop_dict(A, B, sub, pos, neg) + return A.new(Csdm, A.shape, A.domain) + + def neg(A): + """ + + Returns the negative of a :py:class:`~.SDM` matrix + + Examples + ======== + + >>> from sympy import ZZ + >>> from sympy.polys.matrices.sdm import SDM + >>> A = SDM({0:{1: ZZ(2)}, 1:{0:ZZ(1)}}, (2, 2), ZZ) + >>> A.neg() + {0: {1: -2}, 1: {0: -1}} + + """ + Csdm = unop_dict(A, neg) + return A.new(Csdm, A.shape, A.domain) + + def convert_to(A, K): + """ + Converts the :py:class:`~.Domain` of a :py:class:`~.SDM` matrix to K + + Examples + ======== + + >>> from sympy import ZZ, QQ + >>> from sympy.polys.matrices.sdm import SDM + >>> A = SDM({0:{1: ZZ(2)}, 1:{0:ZZ(1)}}, (2, 2), ZZ) + >>> A.convert_to(QQ) + {0: {1: 2}, 1: {0: 1}} + + """ + Kold = A.domain + if K == Kold: + return A.copy() + Ak = unop_dict(A, lambda e: K.convert_from(e, Kold)) + return A.new(Ak, A.shape, K) + + def nnz(A): + """Number of non-zero elements in the :py:class:`~.SDM` matrix. + + Examples + ======== + + >>> from sympy import ZZ + >>> from sympy.polys.matrices.sdm import SDM + >>> A = SDM({0:{1: ZZ(2)}, 1:{0:ZZ(1)}}, (2, 2), ZZ) + >>> A.nnz() + 2 + + See Also + ======== + + sympy.polys.matrices.domainmatrix.DomainMatrix.nnz + """ + return sum(map(len, A.values())) + + def scc(A): + """Strongly connected components of a square matrix *A*. + + Examples + ======== + + >>> from sympy import ZZ + >>> from sympy.polys.matrices.sdm import SDM + >>> A = SDM({0:{0: ZZ(2)}, 1:{1:ZZ(1)}}, (2, 2), ZZ) + >>> A.scc() + [[0], [1]] + + See also + ======== + + sympy.polys.matrices.domainmatrix.DomainMatrix.scc + """ + rows, cols = A.shape + assert rows == cols + V = range(rows) + Emap = {v: list(A.get(v, [])) for v in V} + return _strongly_connected_components(V, Emap) + + def rref(A): + """ + + Returns reduced-row echelon form and list of pivots for the :py:class:`~.SDM` + + Examples + ======== + + >>> from sympy import QQ + >>> from sympy.polys.matrices.sdm import SDM + >>> A = SDM({0:{0:QQ(1), 1:QQ(2)}, 1:{0:QQ(2), 1:QQ(4)}}, (2, 2), QQ) + >>> A.rref() + ({0: {0: 1, 1: 2}}, [0]) + + """ + B, pivots, _ = sdm_irref(A) + return A.new(B, A.shape, A.domain), pivots + + def rref_den(A): + """ + + Returns reduced-row echelon form (RREF) with denominator and pivots. + + Examples + ======== + + >>> from sympy import QQ + >>> from sympy.polys.matrices.sdm import SDM + >>> A = SDM({0:{0:QQ(1), 1:QQ(2)}, 1:{0:QQ(2), 1:QQ(4)}}, (2, 2), QQ) + >>> A.rref_den() + ({0: {0: 1, 1: 2}}, 1, [0]) + + """ + K = A.domain + A_rref_sdm, denom, pivots = sdm_rref_den(A, K) + A_rref = A.new(A_rref_sdm, A.shape, A.domain) + return A_rref, denom, pivots + + def inv(A): + """ + + Returns inverse of a matrix A + + Examples + ======== + + >>> from sympy import QQ + >>> from sympy.polys.matrices.sdm import SDM + >>> A = SDM({0:{0:QQ(1), 1:QQ(2)}, 1:{0:QQ(3), 1:QQ(4)}}, (2, 2), QQ) + >>> A.inv() + {0: {0: -2, 1: 1}, 1: {0: 3/2, 1: -1/2}} + + """ + return A.to_dfm_or_ddm().inv().to_sdm() + + def det(A): + """ + Returns determinant of A + + Examples + ======== + + >>> from sympy import QQ + >>> from sympy.polys.matrices.sdm import SDM + >>> A = SDM({0:{0:QQ(1), 1:QQ(2)}, 1:{0:QQ(3), 1:QQ(4)}}, (2, 2), QQ) + >>> A.det() + -2 + + """ + # It would be better to have a sparse implementation of det for use + # with very sparse matrices. Extremely sparse matrices probably just + # have determinant zero and we could probably detect that very quickly. + # In the meantime, we convert to a dense matrix and use ddm_idet. + # + # If GROUND_TYPES=flint though then we will use Flint's implementation + # if possible (dfm). + return A.to_dfm_or_ddm().det() + + def lu(A): + """ + + Returns LU decomposition for a matrix A + + Examples + ======== + + >>> from sympy import QQ + >>> from sympy.polys.matrices.sdm import SDM + >>> A = SDM({0:{0:QQ(1), 1:QQ(2)}, 1:{0:QQ(3), 1:QQ(4)}}, (2, 2), QQ) + >>> A.lu() + ({0: {0: 1}, 1: {0: 3, 1: 1}}, {0: {0: 1, 1: 2}, 1: {1: -2}}, []) + + """ + L, U, swaps = A.to_ddm().lu() + return A.from_ddm(L), A.from_ddm(U), swaps + + def lu_solve(A, b): + """ + + Uses LU decomposition to solve Ax = b, + + Examples + ======== + + >>> from sympy import QQ + >>> from sympy.polys.matrices.sdm import SDM + >>> A = SDM({0:{0:QQ(1), 1:QQ(2)}, 1:{0:QQ(3), 1:QQ(4)}}, (2, 2), QQ) + >>> b = SDM({0:{0:QQ(1)}, 1:{0:QQ(2)}}, (2, 1), QQ) + >>> A.lu_solve(b) + {1: {0: 1/2}} + + """ + return A.from_ddm(A.to_ddm().lu_solve(b.to_ddm())) + + def nullspace(A): + """ + Nullspace of a :py:class:`~.SDM` matrix A. + + The domain of the matrix must be a field. + + It is better to use the :meth:`~.DomainMatrix.nullspace` method rather + than this method which is otherwise no longer used. + + Examples + ======== + + >>> from sympy import QQ + >>> from sympy.polys.matrices.sdm import SDM + >>> A = SDM({0:{0:QQ(1), 1:QQ(2)}, 1:{0: QQ(2), 1: QQ(4)}}, (2, 2), QQ) + >>> A.nullspace() + ({0: {0: -2, 1: 1}}, [1]) + + + See Also + ======== + + sympy.polys.matrices.domainmatrix.DomainMatrix.nullspace + The preferred way to get the nullspace of a matrix. + + """ + ncols = A.shape[1] + one = A.domain.one + B, pivots, nzcols = sdm_irref(A) + K, nonpivots = sdm_nullspace_from_rref(B, one, ncols, pivots, nzcols) + K = dict(enumerate(K)) + shape = (len(K), ncols) + return A.new(K, shape, A.domain), nonpivots + + def nullspace_from_rref(A, pivots=None): + """ + Returns nullspace for a :py:class:`~.SDM` matrix ``A`` in RREF. + + The domain of the matrix can be any domain. + + The matrix must already be in reduced row echelon form (RREF). + + Examples + ======== + + >>> from sympy import QQ + >>> from sympy.polys.matrices.sdm import SDM + >>> A = SDM({0:{0:QQ(1), 1:QQ(2)}, 1:{0: QQ(2), 1: QQ(4)}}, (2, 2), QQ) + >>> A_rref, pivots = A.rref() + >>> A_null, nonpivots = A_rref.nullspace_from_rref(pivots) + >>> A_null + {0: {0: -2, 1: 1}} + >>> pivots + [0] + >>> nonpivots + [1] + + See Also + ======== + + sympy.polys.matrices.domainmatrix.DomainMatrix.nullspace + The higher-level function that would usually be called instead of + calling this one directly. + + sympy.polys.matrices.domainmatrix.DomainMatrix.nullspace_from_rref + The higher-level direct equivalent of this function. + + sympy.polys.matrices.ddm.DDM.nullspace_from_rref + The equivalent function for dense :py:class:`~.DDM` matrices. + + """ + m, n = A.shape + K = A.domain + + if pivots is None: + pivots = sorted(map(min, A.values())) + + if not pivots: + return A.eye((n, n), K), list(range(n)) + elif len(pivots) == n: + return A.zeros((0, n), K), [] + + # In fraction-free RREF the nonzero entry inserted for the pivots is + # not necessarily 1. + pivot_val = A[0][pivots[0]] + assert not K.is_zero(pivot_val) + + pivots_set = set(pivots) + + # Loop once over all nonzero entries making a map from column indices + # to the nonzero entries in that column along with the row index of the + # nonzero entry. This is basically the transpose of the matrix. + nonzero_cols = defaultdict(list) + for i, Ai in A.items(): + for j, Aij in Ai.items(): + nonzero_cols[j].append((i, Aij)) + + # Usually in SDM we want to avoid looping over the dimensions of the + # matrix because it is optimised to support extremely sparse matrices. + # Here in nullspace though every zero column becomes a nonzero column + # so we need to loop once over the columns at least (range(n)) rather + # than just the nonzero entries of the matrix. We can still avoid + # an inner loop over the rows though by using the nonzero_cols map. + basis = [] + nonpivots = [] + for j in range(n): + if j in pivots_set: + continue + nonpivots.append(j) + + vec = {j: pivot_val} + for ip, Aij in nonzero_cols[j]: + vec[pivots[ip]] = -Aij + + basis.append(vec) + + sdm = dict(enumerate(basis)) + A_null = A.new(sdm, (len(basis), n), K) + + return (A_null, nonpivots) + + def particular(A): + ncols = A.shape[1] + B, pivots, nzcols = sdm_irref(A) + P = sdm_particular_from_rref(B, ncols, pivots) + rep = {0:P} if P else {} + return A.new(rep, (1, ncols-1), A.domain) + + def hstack(A, *B): + """Horizontally stacks :py:class:`~.SDM` matrices. + + Examples + ======== + + >>> from sympy import ZZ + >>> from sympy.polys.matrices.sdm import SDM + + >>> A = SDM({0: {0: ZZ(1), 1: ZZ(2)}, 1: {0: ZZ(3), 1: ZZ(4)}}, (2, 2), ZZ) + >>> B = SDM({0: {0: ZZ(5), 1: ZZ(6)}, 1: {0: ZZ(7), 1: ZZ(8)}}, (2, 2), ZZ) + >>> A.hstack(B) + {0: {0: 1, 1: 2, 2: 5, 3: 6}, 1: {0: 3, 1: 4, 2: 7, 3: 8}} + + >>> C = SDM({0: {0: ZZ(9), 1: ZZ(10)}, 1: {0: ZZ(11), 1: ZZ(12)}}, (2, 2), ZZ) + >>> A.hstack(B, C) + {0: {0: 1, 1: 2, 2: 5, 3: 6, 4: 9, 5: 10}, 1: {0: 3, 1: 4, 2: 7, 3: 8, 4: 11, 5: 12}} + """ + Anew = dict(A.copy()) + rows, cols = A.shape + domain = A.domain + + for Bk in B: + Bkrows, Bkcols = Bk.shape + assert Bkrows == rows + assert Bk.domain == domain + + for i, Bki in Bk.items(): + Ai = Anew.get(i, None) + if Ai is None: + Anew[i] = Ai = {} + for j, Bkij in Bki.items(): + Ai[j + cols] = Bkij + cols += Bkcols + + return A.new(Anew, (rows, cols), A.domain) + + def vstack(A, *B): + """Vertically stacks :py:class:`~.SDM` matrices. + + Examples + ======== + + >>> from sympy import ZZ + >>> from sympy.polys.matrices.sdm import SDM + + >>> A = SDM({0: {0: ZZ(1), 1: ZZ(2)}, 1: {0: ZZ(3), 1: ZZ(4)}}, (2, 2), ZZ) + >>> B = SDM({0: {0: ZZ(5), 1: ZZ(6)}, 1: {0: ZZ(7), 1: ZZ(8)}}, (2, 2), ZZ) + >>> A.vstack(B) + {0: {0: 1, 1: 2}, 1: {0: 3, 1: 4}, 2: {0: 5, 1: 6}, 3: {0: 7, 1: 8}} + + >>> C = SDM({0: {0: ZZ(9), 1: ZZ(10)}, 1: {0: ZZ(11), 1: ZZ(12)}}, (2, 2), ZZ) + >>> A.vstack(B, C) + {0: {0: 1, 1: 2}, 1: {0: 3, 1: 4}, 2: {0: 5, 1: 6}, 3: {0: 7, 1: 8}, 4: {0: 9, 1: 10}, 5: {0: 11, 1: 12}} + """ + Anew = dict(A.copy()) + rows, cols = A.shape + domain = A.domain + + for Bk in B: + Bkrows, Bkcols = Bk.shape + assert Bkcols == cols + assert Bk.domain == domain + + for i, Bki in Bk.items(): + Anew[i + rows] = Bki + rows += Bkrows + + return A.new(Anew, (rows, cols), A.domain) + + def applyfunc(self, func, domain): + sdm = {i: {j: func(e) for j, e in row.items()} for i, row in self.items()} + return self.new(sdm, self.shape, domain) + + def charpoly(A): + """ + Returns the coefficients of the characteristic polynomial + of the :py:class:`~.SDM` matrix. These elements will be domain elements. + The domain of the elements will be same as domain of the :py:class:`~.SDM`. + + Examples + ======== + + >>> from sympy import QQ, Symbol + >>> from sympy.polys.matrices.sdm import SDM + >>> from sympy.polys import Poly + >>> A = SDM({0:{0:QQ(1), 1:QQ(2)}, 1:{0:QQ(3), 1:QQ(4)}}, (2, 2), QQ) + >>> A.charpoly() + [1, -5, -2] + + We can create a polynomial using the + coefficients using :py:class:`~.Poly` + + >>> x = Symbol('x') + >>> p = Poly(A.charpoly(), x, domain=A.domain) + >>> p + Poly(x**2 - 5*x - 2, x, domain='QQ') + + """ + K = A.domain + n, _ = A.shape + pdict = sdm_berk(A, n, K) + plist = [K.zero] * (n + 1) + for i, pi in pdict.items(): + plist[i] = pi + return plist + + def is_zero_matrix(self): + """ + Says whether this matrix has all zero entries. + """ + return not self + + def is_upper(self): + """ + Says whether this matrix is upper-triangular. True can be returned + even if the matrix is not square. + """ + return all(i <= j for i, row in self.items() for j in row) + + def is_lower(self): + """ + Says whether this matrix is lower-triangular. True can be returned + even if the matrix is not square. + """ + return all(i >= j for i, row in self.items() for j in row) + + def is_diagonal(self): + """ + Says whether this matrix is diagonal. True can be returned + even if the matrix is not square. + """ + return all(i == j for i, row in self.items() for j in row) + + def diagonal(self): + """ + Returns the diagonal of the matrix as a list. + """ + m, n = self.shape + zero = self.domain.zero + return [row.get(i, zero) for i, row in self.items() if i < n] + + def lll(A, delta=QQ(3, 4)): + """ + Returns the LLL-reduced basis for the :py:class:`~.SDM` matrix. + """ + return A.to_dfm_or_ddm().lll(delta=delta).to_sdm() + + def lll_transform(A, delta=QQ(3, 4)): + """ + Returns the LLL-reduced basis and transformation matrix. + """ + reduced, transform = A.to_dfm_or_ddm().lll_transform(delta=delta) + return reduced.to_sdm(), transform.to_sdm() + + +def binop_dict(A, B, fab, fa, fb): + Anz, Bnz = set(A), set(B) + C = {} + + for i in Anz & Bnz: + Ai, Bi = A[i], B[i] + Ci = {} + Anzi, Bnzi = set(Ai), set(Bi) + for j in Anzi & Bnzi: + Cij = fab(Ai[j], Bi[j]) + if Cij: + Ci[j] = Cij + for j in Anzi - Bnzi: + Cij = fa(Ai[j]) + if Cij: + Ci[j] = Cij + for j in Bnzi - Anzi: + Cij = fb(Bi[j]) + if Cij: + Ci[j] = Cij + if Ci: + C[i] = Ci + + for i in Anz - Bnz: + Ai = A[i] + Ci = {} + for j, Aij in Ai.items(): + Cij = fa(Aij) + if Cij: + Ci[j] = Cij + if Ci: + C[i] = Ci + + for i in Bnz - Anz: + Bi = B[i] + Ci = {} + for j, Bij in Bi.items(): + Cij = fb(Bij) + if Cij: + Ci[j] = Cij + if Ci: + C[i] = Ci + + return C + + +def unop_dict(A, f): + B = {} + for i, Ai in A.items(): + Bi = {} + for j, Aij in Ai.items(): + Bij = f(Aij) + if Bij: + Bi[j] = Bij + if Bi: + B[i] = Bi + return B + + +def sdm_transpose(M): + MT = {} + for i, Mi in M.items(): + for j, Mij in Mi.items(): + try: + MT[j][i] = Mij + except KeyError: + MT[j] = {i: Mij} + return MT + + +def sdm_dotvec(A, B, K): + return K.sum(A[j] * B[j] for j in A.keys() & B.keys()) + + +def sdm_matvecmul(A, B, K): + C = {} + for i, Ai in A.items(): + Ci = sdm_dotvec(Ai, B, K) + if Ci: + C[i] = Ci + return C + + +def sdm_matmul(A, B, K, m, o): + # + # Should be fast if A and B are very sparse. + # Consider e.g. A = B = eye(1000). + # + # The idea here is that we compute C = A*B in terms of the rows of C and + # B since the dict of dicts representation naturally stores the matrix as + # rows. The ith row of C (Ci) is equal to the sum of Aik * Bk where Bk is + # the kth row of B. The algorithm below loops over each nonzero element + # Aik of A and if the corresponding row Bj is nonzero then we do + # Ci += Aik * Bk. + # To make this more efficient we don't need to loop over all elements Aik. + # Instead for each row Ai we compute the intersection of the nonzero + # columns in Ai with the nonzero rows in B. That gives the k such that + # Aik and Bk are both nonzero. In Python the intersection of two sets + # of int can be computed very efficiently. + # + if K.is_EXRAW: + return sdm_matmul_exraw(A, B, K, m, o) + + C = {} + B_knz = set(B) + for i, Ai in A.items(): + Ci = {} + Ai_knz = set(Ai) + for k in Ai_knz & B_knz: + Aik = Ai[k] + for j, Bkj in B[k].items(): + Cij = Ci.get(j, None) + if Cij is not None: + Cij = Cij + Aik * Bkj + if Cij: + Ci[j] = Cij + else: + Ci.pop(j) + else: + Cij = Aik * Bkj + if Cij: + Ci[j] = Cij + if Ci: + C[i] = Ci + return C + + +def sdm_matmul_exraw(A, B, K, m, o): + # + # Like sdm_matmul above except that: + # + # - Handles cases like 0*oo -> nan (sdm_matmul skips multipication by zero) + # - Uses K.sum (Add(*items)) for efficient addition of Expr + # + zero = K.zero + C = {} + B_knz = set(B) + for i, Ai in A.items(): + Ci_list = defaultdict(list) + Ai_knz = set(Ai) + + # Nonzero row/column pair + for k in Ai_knz & B_knz: + Aik = Ai[k] + if zero * Aik == zero: + # This is the main inner loop: + for j, Bkj in B[k].items(): + Ci_list[j].append(Aik * Bkj) + else: + for j in range(o): + Ci_list[j].append(Aik * B[k].get(j, zero)) + + # Zero row in B, check for infinities in A + for k in Ai_knz - B_knz: + zAik = zero * Ai[k] + if zAik != zero: + for j in range(o): + Ci_list[j].append(zAik) + + # Add terms using K.sum (Add(*terms)) for efficiency + Ci = {} + for j, Cij_list in Ci_list.items(): + Cij = K.sum(Cij_list) + if Cij: + Ci[j] = Cij + if Ci: + C[i] = Ci + + # Find all infinities in B + for k, Bk in B.items(): + for j, Bkj in Bk.items(): + if zero * Bkj != zero: + for i in range(m): + Aik = A.get(i, {}).get(k, zero) + # If Aik is not zero then this was handled above + if Aik == zero: + Ci = C.get(i, {}) + Cij = Ci.get(j, zero) + Aik * Bkj + if Cij != zero: + Ci[j] = Cij + else: # pragma: no cover + # Not sure how we could get here but let's raise an + # exception just in case. + raise RuntimeError + C[i] = Ci + + return C + + +def sdm_irref(A): + """RREF and pivots of a sparse matrix *A*. + + Compute the reduced row echelon form (RREF) of the matrix *A* and return a + list of the pivot columns. This routine does not work in place and leaves + the original matrix *A* unmodified. + + The domain of the matrix must be a field. + + Examples + ======== + + This routine works with a dict of dicts sparse representation of a matrix: + + >>> from sympy import QQ + >>> from sympy.polys.matrices.sdm import sdm_irref + >>> A = {0: {0: QQ(1), 1: QQ(2)}, 1: {0: QQ(3), 1: QQ(4)}} + >>> Arref, pivots, _ = sdm_irref(A) + >>> Arref + {0: {0: 1}, 1: {1: 1}} + >>> pivots + [0, 1] + + The analogous calculation with :py:class:`~.MutableDenseMatrix` would be + + >>> from sympy import Matrix + >>> M = Matrix([[1, 2], [3, 4]]) + >>> Mrref, pivots = M.rref() + >>> Mrref + Matrix([ + [1, 0], + [0, 1]]) + >>> pivots + (0, 1) + + Notes + ===== + + The cost of this algorithm is determined purely by the nonzero elements of + the matrix. No part of the cost of any step in this algorithm depends on + the number of rows or columns in the matrix. No step depends even on the + number of nonzero rows apart from the primary loop over those rows. The + implementation is much faster than ddm_rref for sparse matrices. In fact + at the time of writing it is also (slightly) faster than the dense + implementation even if the input is a fully dense matrix so it seems to be + faster in all cases. + + The elements of the matrix should support exact division with ``/``. For + example elements of any domain that is a field (e.g. ``QQ``) should be + fine. No attempt is made to handle inexact arithmetic. + + See Also + ======== + + sympy.polys.matrices.domainmatrix.DomainMatrix.rref + The higher-level function that would normally be used to call this + routine. + sympy.polys.matrices.dense.ddm_irref + The dense equivalent of this routine. + sdm_rref_den + Fraction-free version of this routine. + """ + # + # Any zeros in the matrix are not stored at all so an element is zero if + # its row dict has no index at that key. A row is entirely zero if its + # row index is not in the outer dict. Since rref reorders the rows and + # removes zero rows we can completely discard the row indices. The first + # step then copies the row dicts into a list sorted by the index of the + # first nonzero column in each row. + # + # The algorithm then processes each row Ai one at a time. Previously seen + # rows are used to cancel their pivot columns from Ai. Then a pivot from + # Ai is chosen and is cancelled from all previously seen rows. At this + # point Ai joins the previously seen rows. Once all rows are seen all + # elimination has occurred and the rows are sorted by pivot column index. + # + # The previously seen rows are stored in two separate groups. The reduced + # group consists of all rows that have been reduced to a single nonzero + # element (the pivot). There is no need to attempt any further reduction + # with these. Rows that still have other nonzeros need to be considered + # when Ai is cancelled from the previously seen rows. + # + # A dict nonzerocolumns is used to map from a column index to a set of + # previously seen rows that still have a nonzero element in that column. + # This means that we can cancel the pivot from Ai into the previously seen + # rows without needing to loop over each row that might have a zero in + # that column. + # + + # Row dicts sorted by index of first nonzero column + # (Maybe sorting is not needed/useful.) + Arows = sorted((Ai.copy() for Ai in A.values()), key=min) + + # Each processed row has an associated pivot column. + # pivot_row_map maps from the pivot column index to the row dict. + # This means that we can represent a set of rows purely as a set of their + # pivot indices. + pivot_row_map = {} + + # Set of pivot indices for rows that are fully reduced to a single nonzero. + reduced_pivots = set() + + # Set of pivot indices for rows not fully reduced + nonreduced_pivots = set() + + # Map from column index to a set of pivot indices representing the rows + # that have a nonzero at that column. + nonzero_columns = defaultdict(set) + + while Arows: + # Select pivot element and row + Ai = Arows.pop() + + # Nonzero columns from fully reduced pivot rows can be removed + Ai = {j: Aij for j, Aij in Ai.items() if j not in reduced_pivots} + + # Others require full row cancellation + for j in nonreduced_pivots & set(Ai): + Aj = pivot_row_map[j] + Aij = Ai[j] + Ainz = set(Ai) + Ajnz = set(Aj) + for k in Ajnz - Ainz: + Ai[k] = - Aij * Aj[k] + Ai.pop(j) + Ainz.remove(j) + for k in Ajnz & Ainz: + Aik = Ai[k] - Aij * Aj[k] + if Aik: + Ai[k] = Aik + else: + Ai.pop(k) + + # We have now cancelled previously seen pivots from Ai. + # If it is zero then discard it. + if not Ai: + continue + + # Choose a pivot from Ai: + j = min(Ai) + Aij = Ai[j] + pivot_row_map[j] = Ai + Ainz = set(Ai) + + # Normalise the pivot row to make the pivot 1. + # + # This approach is slow for some domains. Cross cancellation might be + # better for e.g. QQ(x) with division delayed to the final steps. + Aijinv = Aij**-1 + for l in Ai: + Ai[l] *= Aijinv + + # Use Aij to cancel column j from all previously seen rows + for k in nonzero_columns.pop(j, ()): + Ak = pivot_row_map[k] + Akj = Ak[j] + Aknz = set(Ak) + for l in Ainz - Aknz: + Ak[l] = - Akj * Ai[l] + nonzero_columns[l].add(k) + Ak.pop(j) + Aknz.remove(j) + for l in Ainz & Aknz: + Akl = Ak[l] - Akj * Ai[l] + if Akl: + Ak[l] = Akl + else: + # Drop nonzero elements + Ak.pop(l) + if l != j: + nonzero_columns[l].remove(k) + if len(Ak) == 1: + reduced_pivots.add(k) + nonreduced_pivots.remove(k) + + if len(Ai) == 1: + reduced_pivots.add(j) + else: + nonreduced_pivots.add(j) + for l in Ai: + if l != j: + nonzero_columns[l].add(j) + + # All done! + pivots = sorted(reduced_pivots | nonreduced_pivots) + pivot2row = {p: n for n, p in enumerate(pivots)} + nonzero_columns = {c: {pivot2row[p] for p in s} for c, s in nonzero_columns.items()} + rows = [pivot_row_map[i] for i in pivots] + rref = dict(enumerate(rows)) + return rref, pivots, nonzero_columns + + +def sdm_rref_den(A, K): + """ + Return the reduced row echelon form (RREF) of A with denominator. + + The RREF is computed using fraction-free Gauss-Jordan elimination. + + Explanation + =========== + + The algorithm used is the fraction-free version of Gauss-Jordan elimination + described as FFGJ in [1]_. Here it is modified to handle zero or missing + pivots and to avoid redundant arithmetic. This implementation is also + optimized for sparse matrices. + + The domain $K$ must support exact division (``K.exquo``) but does not need + to be a field. This method is suitable for most exact rings and fields like + :ref:`ZZ`, :ref:`QQ` and :ref:`QQ(a)`. In the case of :ref:`QQ` or + :ref:`K(x)` it might be more efficient to clear denominators and use + :ref:`ZZ` or :ref:`K[x]` instead. + + For inexact domains like :ref:`RR` and :ref:`CC` use ``ddm_irref`` instead. + + Examples + ======== + + >>> from sympy.polys.matrices.sdm import sdm_rref_den + >>> from sympy.polys.domains import ZZ + >>> A = {0: {0: ZZ(1), 1: ZZ(2)}, 1: {0: ZZ(3), 1: ZZ(4)}} + >>> A_rref, den, pivots = sdm_rref_den(A, ZZ) + >>> A_rref + {0: {0: -2}, 1: {1: -2}} + >>> den + -2 + >>> pivots + [0, 1] + + See Also + ======== + + sympy.polys.matrices.domainmatrix.DomainMatrix.rref_den + Higher-level interface to ``sdm_rref_den`` that would usually be used + instead of calling this function directly. + sympy.polys.matrices.sdm.sdm_rref_den + The ``SDM`` method that uses this function. + sdm_irref + Computes RREF using field division. + ddm_irref_den + The dense version of this algorithm. + + References + ========== + + .. [1] Fraction-free algorithms for linear and polynomial equations. + George C. Nakos , Peter R. Turner , Robert M. Williams. + https://dl.acm.org/doi/10.1145/271130.271133 + """ + # + # We represent each row of the matrix as a dict mapping column indices to + # nonzero elements. We will build the RREF matrix starting from an empty + # matrix and appending one row at a time. At each step we will have the + # RREF of the rows we have processed so far. + # + # Our representation of the RREF divides it into three parts: + # + # 1. Fully reduced rows having only a single nonzero element (the pivot). + # 2. Partially reduced rows having nonzeros after the pivot. + # 3. The current denominator and divisor. + # + # For example if the incremental RREF might be: + # + # [2, 0, 0, 0, 0, 0, 0, 0, 0, 0] + # [0, 0, 2, 0, 0, 0, 7, 0, 0, 0] + # [0, 0, 0, 0, 0, 2, 0, 0, 0, 0] + # [0, 0, 0, 0, 0, 0, 0, 2, 0, 0] + # [0, 0, 0, 0, 0, 0, 0, 0, 2, 0] + # + # Here the second row is partially reduced and the other rows are fully + # reduced. The denominator would be 2 in this case. We distinguish the + # fully reduced rows because we can handle them more efficiently when + # adding a new row. + # + # When adding a new row we need to multiply it by the current denominator. + # Then we reduce the new row by cross cancellation with the previous rows. + # Then if it is not reduced to zero we take its leading entry as the new + # pivot, cross cancel the new row from the previous rows and update the + # denominator. In the fraction-free version this last step requires + # multiplying and dividing the whole matrix by the new pivot and the + # current divisor. The advantage of building the RREF one row at a time is + # that in the sparse case we only need to work with the relatively sparse + # upper rows of the matrix. The simple version of FFGJ in [1] would + # multiply and divide all the dense lower rows at each step. + + # Handle the trivial cases. + if not A: + return ({}, K.one, []) + elif len(A) == 1: + Ai, = A.values() + j = min(Ai) + Aij = Ai[j] + return ({0: Ai.copy()}, Aij, [j]) + + # For inexact domains like RR[x] we use quo and discard the remainder. + # Maybe it would be better for K.exquo to do this automatically. + if K.is_Exact: + exquo = K.exquo + else: + exquo = K.quo + + # Make sure we have the rows in order to make this deterministic from the + # outset. + _, rows_in_order = zip(*sorted(A.items())) + + col_to_row_reduced = {} + col_to_row_unreduced = {} + reduced = col_to_row_reduced.keys() + unreduced = col_to_row_unreduced.keys() + + # Our representation of the RREF so far. + A_rref_rows = [] + denom = None + divisor = None + + # The rows that remain to be added to the RREF. These are sorted by the + # column index of their leading entry. Note that sorted() is stable so the + # previous sort by unique row index is still needed to make this + # deterministic (there may be multiple rows with the same leading column). + A_rows = sorted(rows_in_order, key=min) + + for Ai in A_rows: + + # All fully reduced columns can be immediately discarded. + Ai = {j: Aij for j, Aij in Ai.items() if j not in reduced} + + # We need to multiply the new row by the current denominator to bring + # it into the same scale as the previous rows and then cross-cancel to + # reduce it wrt the previous unreduced rows. All pivots in the previous + # rows are equal to denom so the coefficients we need to make a linear + # combination of the previous rows to cancel into the new row are just + # the ones that are already in the new row *before* we multiply by + # denom. We compute that linear combination first and then multiply the + # new row by denom before subtraction. + Ai_cancel = {} + + for j in unreduced & Ai.keys(): + # Remove the pivot column from the new row since it would become + # zero anyway. + Aij = Ai.pop(j) + + Aj = A_rref_rows[col_to_row_unreduced[j]] + + for k, Ajk in Aj.items(): + Aik_cancel = Ai_cancel.get(k) + if Aik_cancel is None: + Ai_cancel[k] = Aij * Ajk + else: + Aik_cancel = Aik_cancel + Aij * Ajk + if Aik_cancel: + Ai_cancel[k] = Aik_cancel + else: + Ai_cancel.pop(k) + + # Multiply the new row by the current denominator and subtract. + Ai_nz = set(Ai) + Ai_cancel_nz = set(Ai_cancel) + + d = denom or K.one + + for k in Ai_cancel_nz - Ai_nz: + Ai[k] = -Ai_cancel[k] + + for k in Ai_nz - Ai_cancel_nz: + Ai[k] = Ai[k] * d + + for k in Ai_cancel_nz & Ai_nz: + Aik = Ai[k] * d - Ai_cancel[k] + if Aik: + Ai[k] = Aik + else: + Ai.pop(k) + + # Now Ai has the same scale as the other rows and is reduced wrt the + # unreduced rows. + + # If the row is reduced to zero then discard it. + if not Ai: + continue + + # Choose a pivot for this row. + j = min(Ai) + Aij = Ai.pop(j) + + # Cross cancel the unreduced rows by the new row. + # a[k][l] = (a[i][j]*a[k][l] - a[k][j]*a[i][l]) / divisor + for pk, k in list(col_to_row_unreduced.items()): + + Ak = A_rref_rows[k] + + if j not in Ak: + # This row is already reduced wrt the new row but we need to + # bring it to the same scale as the new denominator. This step + # is not needed in sdm_irref. + for l, Akl in Ak.items(): + Akl = Akl * Aij + if divisor is not None: + Akl = exquo(Akl, divisor) + Ak[l] = Akl + continue + + Akj = Ak.pop(j) + Ai_nz = set(Ai) + Ak_nz = set(Ak) + + for l in Ai_nz - Ak_nz: + Ak[l] = - Akj * Ai[l] + if divisor is not None: + Ak[l] = exquo(Ak[l], divisor) + + # This loop also not needed in sdm_irref. + for l in Ak_nz - Ai_nz: + Ak[l] = Aij * Ak[l] + if divisor is not None: + Ak[l] = exquo(Ak[l], divisor) + + for l in Ai_nz & Ak_nz: + Akl = Aij * Ak[l] - Akj * Ai[l] + if Akl: + if divisor is not None: + Akl = exquo(Akl, divisor) + Ak[l] = Akl + else: + Ak.pop(l) + + if not Ak: + col_to_row_unreduced.pop(pk) + col_to_row_reduced[pk] = k + + i = len(A_rref_rows) + A_rref_rows.append(Ai) + if Ai: + col_to_row_unreduced[j] = i + else: + col_to_row_reduced[j] = i + + # Update the denominator. + if not K.is_one(Aij): + if denom is None: + denom = Aij + else: + denom *= Aij + + if divisor is not None: + denom = exquo(denom, divisor) + + # Update the divisor. + divisor = denom + + if denom is None: + denom = K.one + + # Sort the rows by their leading column index. + col_to_row = {**col_to_row_reduced, **col_to_row_unreduced} + row_to_col = {i: j for j, i in col_to_row.items()} + A_rref_rows_col = [(row_to_col[i], Ai) for i, Ai in enumerate(A_rref_rows)] + pivots, A_rref = zip(*sorted(A_rref_rows_col)) + pivots = list(pivots) + + # Insert the pivot values + for i, Ai in enumerate(A_rref): + Ai[pivots[i]] = denom + + A_rref_sdm = dict(enumerate(A_rref)) + + return A_rref_sdm, denom, pivots + + +def sdm_nullspace_from_rref(A, one, ncols, pivots, nonzero_cols): + """Get nullspace from A which is in RREF""" + nonpivots = sorted(set(range(ncols)) - set(pivots)) + + K = [] + for j in nonpivots: + Kj = {j:one} + for i in nonzero_cols.get(j, ()): + Kj[pivots[i]] = -A[i][j] + K.append(Kj) + + return K, nonpivots + + +def sdm_particular_from_rref(A, ncols, pivots): + """Get a particular solution from A which is in RREF""" + P = {} + for i, j in enumerate(pivots): + Ain = A[i].get(ncols-1, None) + if Ain is not None: + P[j] = Ain / A[i][j] + return P + + +def sdm_berk(M, n, K): + """ + Berkowitz algorithm for computing the characteristic polynomial. + + Explanation + =========== + + The Berkowitz algorithm is a division-free algorithm for computing the + characteristic polynomial of a matrix over any commutative ring using only + arithmetic in the coefficient ring. This implementation is for sparse + matrices represented in a dict-of-dicts format (like :class:`SDM`). + + Examples + ======== + + >>> from sympy import Matrix + >>> from sympy.polys.matrices.sdm import sdm_berk + >>> from sympy.polys.domains import ZZ + >>> M = {0: {0: ZZ(1), 1:ZZ(2)}, 1: {0:ZZ(3), 1:ZZ(4)}} + >>> sdm_berk(M, 2, ZZ) + {0: 1, 1: -5, 2: -2} + >>> Matrix([[1, 2], [3, 4]]).charpoly() + PurePoly(lambda**2 - 5*lambda - 2, lambda, domain='ZZ') + + See Also + ======== + + sympy.polys.matrices.domainmatrix.DomainMatrix.charpoly + The high-level interface to this function. + sympy.polys.matrices.dense.ddm_berk + The dense version of this function. + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Samuelson%E2%80%93Berkowitz_algorithm + """ + zero = K.zero + one = K.one + + if n == 0: + return {0: one} + elif n == 1: + pdict = {0: one} + if M00 := M.get(0, {}).get(0, zero): + pdict[1] = -M00 + + # M = [[a, R], + # [C, A]] + a, R, C, A = K.zero, {}, {}, defaultdict(dict) + for i, Mi in M.items(): + for j, Mij in Mi.items(): + if i and j: + A[i-1][j-1] = Mij + elif i: + C[i-1] = Mij + elif j: + R[j-1] = Mij + else: + a = Mij + + # T = [ 1, 0, 0, 0, 0, ... ] + # [ -a, 1, 0, 0, 0, ... ] + # [ -R*C, -a, 1, 0, 0, ... ] + # [ -R*A*C, -R*C, -a, 1, 0, ... ] + # [-R*A^2*C, -R*A*C, -R*C, -a, 1, ... ] + # [ ... ] + # T is (n+1) x n + # + # In the sparse case we might have A^m*C = 0 for some m making T banded + # rather than triangular so we just compute the nonzero entries of the + # first column rather than constructing the matrix explicitly. + + AnC = C + RC = sdm_dotvec(R, C, K) + + Tvals = [one, -a, -RC] + for i in range(3, n+1): + AnC = sdm_matvecmul(A, AnC, K) + if not AnC: + break + RAnC = sdm_dotvec(R, AnC, K) + Tvals.append(-RAnC) + + # Strip trailing zeros + while Tvals and not Tvals[-1]: + Tvals.pop() + + q = sdm_berk(A, n-1, K) + + # This would be the explicit multiplication T*q but we can do better: + # + # T = {} + # for i in range(n+1): + # Ti = {} + # for j in range(max(0, i-len(Tvals)+1), min(i+1, n)): + # Ti[j] = Tvals[i-j] + # T[i] = Ti + # Tq = sdm_matvecmul(T, q, K) + # + # In the sparse case q might be mostly zero. We know that T[i,j] is nonzero + # for i <= j < i + len(Tvals) so if q does not have a nonzero entry in that + # range then Tq[j] must be zero. We exploit this potential banded + # structure and the potential sparsity of q to compute Tq more efficiently. + + Tvals = Tvals[::-1] + + Tq = {} + + for i in range(min(q), min(max(q)+len(Tvals), n+1)): + Ti = dict(enumerate(Tvals, i-len(Tvals)+1)) + if Tqi := sdm_dotvec(Ti, q, K): + Tq[i] = Tqi + + return Tq diff --git a/lib/python3.10/site-packages/sympy/polys/matrices/tests/__init__.py b/lib/python3.10/site-packages/sympy/polys/matrices/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/lib/python3.10/site-packages/sympy/polys/matrices/tests/__pycache__/__init__.cpython-310.pyc b/lib/python3.10/site-packages/sympy/polys/matrices/tests/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4cc0a6e2bae550ce3052fd71b33838444d8fcfc7 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/polys/matrices/tests/__pycache__/__init__.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/polys/matrices/tests/__pycache__/test_ddm.cpython-310.pyc b/lib/python3.10/site-packages/sympy/polys/matrices/tests/__pycache__/test_ddm.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b33a08f81694f8b020ef0ede1b39c54481798936 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/polys/matrices/tests/__pycache__/test_ddm.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/polys/matrices/tests/__pycache__/test_dense.cpython-310.pyc b/lib/python3.10/site-packages/sympy/polys/matrices/tests/__pycache__/test_dense.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..25e6d272bef1c8f0b6410017c1761f5884171544 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/polys/matrices/tests/__pycache__/test_dense.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/polys/matrices/tests/__pycache__/test_domainmatrix.cpython-310.pyc b/lib/python3.10/site-packages/sympy/polys/matrices/tests/__pycache__/test_domainmatrix.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..36dd426ab9aa6c9f5a8100f54d505e5aaa9ce453 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/polys/matrices/tests/__pycache__/test_domainmatrix.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/polys/matrices/tests/__pycache__/test_domainscalar.cpython-310.pyc b/lib/python3.10/site-packages/sympy/polys/matrices/tests/__pycache__/test_domainscalar.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..740fa97b1421a48340feeabe7f6f994c42338429 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/polys/matrices/tests/__pycache__/test_domainscalar.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/polys/matrices/tests/__pycache__/test_eigen.cpython-310.pyc b/lib/python3.10/site-packages/sympy/polys/matrices/tests/__pycache__/test_eigen.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..27deae9bea6226afc78ffe6b8d2725b1ee257a2f Binary files /dev/null and b/lib/python3.10/site-packages/sympy/polys/matrices/tests/__pycache__/test_eigen.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/polys/matrices/tests/__pycache__/test_inverse.cpython-310.pyc b/lib/python3.10/site-packages/sympy/polys/matrices/tests/__pycache__/test_inverse.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f912488c85414f6c89407a078feaa765541854f9 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/polys/matrices/tests/__pycache__/test_inverse.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/polys/matrices/tests/__pycache__/test_linsolve.cpython-310.pyc b/lib/python3.10/site-packages/sympy/polys/matrices/tests/__pycache__/test_linsolve.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..da3f6bd2d9bd9a442b8fd3e96ccc11d2fb52302f Binary files /dev/null and b/lib/python3.10/site-packages/sympy/polys/matrices/tests/__pycache__/test_linsolve.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/polys/matrices/tests/__pycache__/test_lll.cpython-310.pyc b/lib/python3.10/site-packages/sympy/polys/matrices/tests/__pycache__/test_lll.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c728fd24addd20618192a19e2144b9456d3d244b Binary files /dev/null and b/lib/python3.10/site-packages/sympy/polys/matrices/tests/__pycache__/test_lll.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/polys/matrices/tests/__pycache__/test_normalforms.cpython-310.pyc b/lib/python3.10/site-packages/sympy/polys/matrices/tests/__pycache__/test_normalforms.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dd6e8d309c8b68cee8ccdccf9b8b4f812d229894 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/polys/matrices/tests/__pycache__/test_normalforms.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/polys/matrices/tests/__pycache__/test_nullspace.cpython-310.pyc b/lib/python3.10/site-packages/sympy/polys/matrices/tests/__pycache__/test_nullspace.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..689ecabf03ca5025d45ea8a0278be8c8d24976d0 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/polys/matrices/tests/__pycache__/test_nullspace.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/polys/matrices/tests/__pycache__/test_rref.cpython-310.pyc b/lib/python3.10/site-packages/sympy/polys/matrices/tests/__pycache__/test_rref.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7a55d8101d9f7497b0369a74b5e6053ee110b14b Binary files /dev/null and b/lib/python3.10/site-packages/sympy/polys/matrices/tests/__pycache__/test_rref.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/polys/matrices/tests/__pycache__/test_sdm.cpython-310.pyc b/lib/python3.10/site-packages/sympy/polys/matrices/tests/__pycache__/test_sdm.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b94be1a017660a023f10eab43354331b0c43577b Binary files /dev/null and b/lib/python3.10/site-packages/sympy/polys/matrices/tests/__pycache__/test_sdm.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/polys/matrices/tests/__pycache__/test_xxm.cpython-310.pyc b/lib/python3.10/site-packages/sympy/polys/matrices/tests/__pycache__/test_xxm.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7f49901001fe18816cc28708ade5eb9ef6032adf Binary files /dev/null and b/lib/python3.10/site-packages/sympy/polys/matrices/tests/__pycache__/test_xxm.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/polys/matrices/tests/test_ddm.py b/lib/python3.10/site-packages/sympy/polys/matrices/tests/test_ddm.py new file mode 100644 index 0000000000000000000000000000000000000000..44c862461e85d503696e621874c10d67d8ee1f1d --- /dev/null +++ b/lib/python3.10/site-packages/sympy/polys/matrices/tests/test_ddm.py @@ -0,0 +1,558 @@ +from sympy.testing.pytest import raises +from sympy.external.gmpy import GROUND_TYPES + +from sympy.polys import ZZ, QQ + +from sympy.polys.matrices.ddm import DDM +from sympy.polys.matrices.exceptions import ( + DMShapeError, DMNonInvertibleMatrixError, DMDomainError, + DMBadInputError) + + +def test_DDM_init(): + items = [[ZZ(0), ZZ(1), ZZ(2)], [ZZ(3), ZZ(4), ZZ(5)]] + shape = (2, 3) + ddm = DDM(items, shape, ZZ) + assert ddm.shape == shape + assert ddm.rows == 2 + assert ddm.cols == 3 + assert ddm.domain == ZZ + + raises(DMBadInputError, lambda: DDM([[ZZ(2), ZZ(3)]], (2, 2), ZZ)) + raises(DMBadInputError, lambda: DDM([[ZZ(1)], [ZZ(2), ZZ(3)]], (2, 2), ZZ)) + + +def test_DDM_getsetitem(): + ddm = DDM([[ZZ(2), ZZ(3)], [ZZ(4), ZZ(5)]], (2, 2), ZZ) + + assert ddm[0][0] == ZZ(2) + assert ddm[0][1] == ZZ(3) + assert ddm[1][0] == ZZ(4) + assert ddm[1][1] == ZZ(5) + + raises(IndexError, lambda: ddm[2][0]) + raises(IndexError, lambda: ddm[0][2]) + + ddm[0][0] = ZZ(-1) + assert ddm[0][0] == ZZ(-1) + + +def test_DDM_str(): + ddm = DDM([[ZZ(0), ZZ(1)], [ZZ(2), ZZ(3)]], (2, 2), ZZ) + if GROUND_TYPES == 'gmpy': # pragma: no cover + assert str(ddm) == '[[0, 1], [2, 3]]' + assert repr(ddm) == 'DDM([[mpz(0), mpz(1)], [mpz(2), mpz(3)]], (2, 2), ZZ)' + else: # pragma: no cover + assert repr(ddm) == 'DDM([[0, 1], [2, 3]], (2, 2), ZZ)' + assert str(ddm) == '[[0, 1], [2, 3]]' + + +def test_DDM_eq(): + items = [[ZZ(0), ZZ(1)], [ZZ(2), ZZ(3)]] + ddm1 = DDM(items, (2, 2), ZZ) + ddm2 = DDM(items, (2, 2), ZZ) + + assert (ddm1 == ddm1) is True + assert (ddm1 == items) is False + assert (items == ddm1) is False + assert (ddm1 == ddm2) is True + assert (ddm2 == ddm1) is True + + assert (ddm1 != ddm1) is False + assert (ddm1 != items) is True + assert (items != ddm1) is True + assert (ddm1 != ddm2) is False + assert (ddm2 != ddm1) is False + + ddm3 = DDM([[ZZ(0), ZZ(1)], [ZZ(3), ZZ(3)]], (2, 2), ZZ) + ddm3 = DDM(items, (2, 2), QQ) + + assert (ddm1 == ddm3) is False + assert (ddm3 == ddm1) is False + assert (ddm1 != ddm3) is True + assert (ddm3 != ddm1) is True + + +def test_DDM_convert_to(): + ddm = DDM([[ZZ(1), ZZ(2)]], (1, 2), ZZ) + assert ddm.convert_to(ZZ) == ddm + ddmq = ddm.convert_to(QQ) + assert ddmq.domain == QQ + + +def test_DDM_zeros(): + ddmz = DDM.zeros((3, 4), QQ) + assert list(ddmz) == [[QQ(0)] * 4] * 3 + assert ddmz.shape == (3, 4) + assert ddmz.domain == QQ + +def test_DDM_ones(): + ddmone = DDM.ones((2, 3), QQ) + assert list(ddmone) == [[QQ(1)] * 3] * 2 + assert ddmone.shape == (2, 3) + assert ddmone.domain == QQ + +def test_DDM_eye(): + ddmz = DDM.eye(3, QQ) + f = lambda i, j: QQ(1) if i == j else QQ(0) + assert list(ddmz) == [[f(i, j) for i in range(3)] for j in range(3)] + assert ddmz.shape == (3, 3) + assert ddmz.domain == QQ + + +def test_DDM_copy(): + ddm1 = DDM([[QQ(1)], [QQ(2)]], (2, 1), QQ) + ddm2 = ddm1.copy() + assert (ddm1 == ddm2) is True + ddm1[0][0] = QQ(-1) + assert (ddm1 == ddm2) is False + ddm2[0][0] = QQ(-1) + assert (ddm1 == ddm2) is True + + +def test_DDM_transpose(): + ddm = DDM([[QQ(1)], [QQ(2)]], (2, 1), QQ) + ddmT = DDM([[QQ(1), QQ(2)]], (1, 2), QQ) + assert ddm.transpose() == ddmT + ddm02 = DDM([], (0, 2), QQ) + ddm02T = DDM([[], []], (2, 0), QQ) + assert ddm02.transpose() == ddm02T + assert ddm02T.transpose() == ddm02 + ddm0 = DDM([], (0, 0), QQ) + assert ddm0.transpose() == ddm0 + + +def test_DDM_add(): + A = DDM([[ZZ(1)], [ZZ(2)]], (2, 1), ZZ) + B = DDM([[ZZ(3)], [ZZ(4)]], (2, 1), ZZ) + C = DDM([[ZZ(4)], [ZZ(6)]], (2, 1), ZZ) + AQ = DDM([[QQ(1)], [QQ(2)]], (2, 1), QQ) + assert A + B == A.add(B) == C + + raises(DMShapeError, lambda: A + DDM([[ZZ(5)]], (1, 1), ZZ)) + raises(TypeError, lambda: A + ZZ(1)) + raises(TypeError, lambda: ZZ(1) + A) + raises(DMDomainError, lambda: A + AQ) + raises(DMDomainError, lambda: AQ + A) + + +def test_DDM_sub(): + A = DDM([[ZZ(1)], [ZZ(2)]], (2, 1), ZZ) + B = DDM([[ZZ(3)], [ZZ(4)]], (2, 1), ZZ) + C = DDM([[ZZ(-2)], [ZZ(-2)]], (2, 1), ZZ) + AQ = DDM([[QQ(1)], [QQ(2)]], (2, 1), QQ) + D = DDM([[ZZ(5)]], (1, 1), ZZ) + assert A - B == A.sub(B) == C + + raises(TypeError, lambda: A - ZZ(1)) + raises(TypeError, lambda: ZZ(1) - A) + raises(DMShapeError, lambda: A - D) + raises(DMShapeError, lambda: D - A) + raises(DMShapeError, lambda: A.sub(D)) + raises(DMShapeError, lambda: D.sub(A)) + raises(DMDomainError, lambda: A - AQ) + raises(DMDomainError, lambda: AQ - A) + raises(DMDomainError, lambda: A.sub(AQ)) + raises(DMDomainError, lambda: AQ.sub(A)) + + +def test_DDM_neg(): + A = DDM([[ZZ(1)], [ZZ(2)]], (2, 1), ZZ) + An = DDM([[ZZ(-1)], [ZZ(-2)]], (2, 1), ZZ) + assert -A == A.neg() == An + assert -An == An.neg() == A + + +def test_DDM_mul(): + A = DDM([[ZZ(1)]], (1, 1), ZZ) + A2 = DDM([[ZZ(2)]], (1, 1), ZZ) + assert A * ZZ(2) == A2 + assert ZZ(2) * A == A2 + raises(TypeError, lambda: [[1]] * A) + raises(TypeError, lambda: A * [[1]]) + + +def test_DDM_matmul(): + A = DDM([[ZZ(1)], [ZZ(2)]], (2, 1), ZZ) + B = DDM([[ZZ(3), ZZ(4)]], (1, 2), ZZ) + AB = DDM([[ZZ(3), ZZ(4)], [ZZ(6), ZZ(8)]], (2, 2), ZZ) + BA = DDM([[ZZ(11)]], (1, 1), ZZ) + + assert A @ B == A.matmul(B) == AB + assert B @ A == B.matmul(A) == BA + + raises(TypeError, lambda: A @ 1) + raises(TypeError, lambda: A @ [[3, 4]]) + + Bq = DDM([[QQ(3), QQ(4)]], (1, 2), QQ) + + raises(DMDomainError, lambda: A @ Bq) + raises(DMDomainError, lambda: Bq @ A) + + C = DDM([[ZZ(1)]], (1, 1), ZZ) + + assert A @ C == A.matmul(C) == A + + raises(DMShapeError, lambda: C @ A) + raises(DMShapeError, lambda: C.matmul(A)) + + Z04 = DDM([], (0, 4), ZZ) + Z40 = DDM([[]]*4, (4, 0), ZZ) + Z50 = DDM([[]]*5, (5, 0), ZZ) + Z05 = DDM([], (0, 5), ZZ) + Z45 = DDM([[0] * 5] * 4, (4, 5), ZZ) + Z54 = DDM([[0] * 4] * 5, (5, 4), ZZ) + Z00 = DDM([], (0, 0), ZZ) + + assert Z04 @ Z45 == Z04.matmul(Z45) == Z05 + assert Z45 @ Z50 == Z45.matmul(Z50) == Z40 + assert Z00 @ Z04 == Z00.matmul(Z04) == Z04 + assert Z50 @ Z00 == Z50.matmul(Z00) == Z50 + assert Z00 @ Z00 == Z00.matmul(Z00) == Z00 + assert Z50 @ Z04 == Z50.matmul(Z04) == Z54 + + raises(DMShapeError, lambda: Z05 @ Z40) + raises(DMShapeError, lambda: Z05.matmul(Z40)) + + +def test_DDM_hstack(): + A = DDM([[ZZ(1), ZZ(2), ZZ(3)]], (1, 3), ZZ) + B = DDM([[ZZ(4), ZZ(5)]], (1, 2), ZZ) + C = DDM([[ZZ(6)]], (1, 1), ZZ) + + Ah = A.hstack(B) + assert Ah.shape == (1, 5) + assert Ah.domain == ZZ + assert Ah == DDM([[ZZ(1), ZZ(2), ZZ(3), ZZ(4), ZZ(5)]], (1, 5), ZZ) + + Ah = A.hstack(B, C) + assert Ah.shape == (1, 6) + assert Ah.domain == ZZ + assert Ah == DDM([[ZZ(1), ZZ(2), ZZ(3), ZZ(4), ZZ(5), ZZ(6)]], (1, 6), ZZ) + + +def test_DDM_vstack(): + A = DDM([[ZZ(1)], [ZZ(2)], [ZZ(3)]], (3, 1), ZZ) + B = DDM([[ZZ(4)], [ZZ(5)]], (2, 1), ZZ) + C = DDM([[ZZ(6)]], (1, 1), ZZ) + + Ah = A.vstack(B) + assert Ah.shape == (5, 1) + assert Ah.domain == ZZ + assert Ah == DDM([[ZZ(1)], [ZZ(2)], [ZZ(3)], [ZZ(4)], [ZZ(5)]], (5, 1), ZZ) + + Ah = A.vstack(B, C) + assert Ah.shape == (6, 1) + assert Ah.domain == ZZ + assert Ah == DDM([[ZZ(1)], [ZZ(2)], [ZZ(3)], [ZZ(4)], [ZZ(5)], [ZZ(6)]], (6, 1), ZZ) + + +def test_DDM_applyfunc(): + A = DDM([[ZZ(1), ZZ(2), ZZ(3)]], (1, 3), ZZ) + B = DDM([[ZZ(2), ZZ(4), ZZ(6)]], (1, 3), ZZ) + assert A.applyfunc(lambda x: 2*x, ZZ) == B + +def test_DDM_rref(): + + A = DDM([], (0, 4), QQ) + assert A.rref() == (A, []) + + A = DDM([[QQ(0), QQ(1)], [QQ(1), QQ(1)]], (2, 2), QQ) + Ar = DDM([[QQ(1), QQ(0)], [QQ(0), QQ(1)]], (2, 2), QQ) + pivots = [0, 1] + assert A.rref() == (Ar, pivots) + + A = DDM([[QQ(1), QQ(2), QQ(1)], [QQ(3), QQ(4), QQ(1)]], (2, 3), QQ) + Ar = DDM([[QQ(1), QQ(0), QQ(-1)], [QQ(0), QQ(1), QQ(1)]], (2, 3), QQ) + pivots = [0, 1] + assert A.rref() == (Ar, pivots) + + A = DDM([[QQ(3), QQ(4), QQ(1)], [QQ(1), QQ(2), QQ(1)]], (2, 3), QQ) + Ar = DDM([[QQ(1), QQ(0), QQ(-1)], [QQ(0), QQ(1), QQ(1)]], (2, 3), QQ) + pivots = [0, 1] + assert A.rref() == (Ar, pivots) + + A = DDM([[QQ(1), QQ(0)], [QQ(1), QQ(3)], [QQ(0), QQ(1)]], (3, 2), QQ) + Ar = DDM([[QQ(1), QQ(0)], [QQ(0), QQ(1)], [QQ(0), QQ(0)]], (3, 2), QQ) + pivots = [0, 1] + assert A.rref() == (Ar, pivots) + + A = DDM([[QQ(1), QQ(0), QQ(1)], [QQ(3), QQ(0), QQ(1)]], (2, 3), QQ) + Ar = DDM([[QQ(1), QQ(0), QQ(0)], [QQ(0), QQ(0), QQ(1)]], (2, 3), QQ) + pivots = [0, 2] + assert A.rref() == (Ar, pivots) + + +def test_DDM_nullspace(): + # more tests are in test_nullspace.py + A = DDM([[QQ(1), QQ(1)], [QQ(1), QQ(1)]], (2, 2), QQ) + Anull = DDM([[QQ(-1), QQ(1)]], (1, 2), QQ) + nonpivots = [1] + assert A.nullspace() == (Anull, nonpivots) + + +def test_DDM_particular(): + A = DDM([[QQ(1), QQ(0)]], (1, 2), QQ) + assert A.particular() == DDM.zeros((1, 1), QQ) + + +def test_DDM_det(): + # 0x0 case + A = DDM([], (0, 0), ZZ) + assert A.det() == ZZ(1) + + # 1x1 case + A = DDM([[ZZ(2)]], (1, 1), ZZ) + assert A.det() == ZZ(2) + + # 2x2 case + A = DDM([[ZZ(1), ZZ(2)], [ZZ(3), ZZ(4)]], (2, 2), ZZ) + assert A.det() == ZZ(-2) + + # 3x3 with swap + A = DDM([[ZZ(1), ZZ(2), ZZ(3)], [ZZ(1), ZZ(2), ZZ(4)], [ZZ(1), ZZ(2), ZZ(5)]], (3, 3), ZZ) + assert A.det() == ZZ(0) + + # 2x2 QQ case + A = DDM([[QQ(1, 2), QQ(1, 2)], [QQ(1, 3), QQ(1, 4)]], (2, 2), QQ) + assert A.det() == QQ(-1, 24) + + # Nonsquare error + A = DDM([[ZZ(1)], [ZZ(2)]], (2, 1), ZZ) + raises(DMShapeError, lambda: A.det()) + + # Nonsquare error with empty matrix + A = DDM([], (0, 1), ZZ) + raises(DMShapeError, lambda: A.det()) + + +def test_DDM_inv(): + A = DDM([[QQ(1, 1), QQ(2, 1)], [QQ(3, 1), QQ(4, 1)]], (2, 2), QQ) + Ainv = DDM([[QQ(-2, 1), QQ(1, 1)], [QQ(3, 2), QQ(-1, 2)]], (2, 2), QQ) + assert A.inv() == Ainv + + A = DDM([[QQ(1), QQ(2)]], (1, 2), QQ) + raises(DMShapeError, lambda: A.inv()) + + A = DDM([[ZZ(2)]], (1, 1), ZZ) + raises(DMDomainError, lambda: A.inv()) + + A = DDM([], (0, 0), QQ) + assert A.inv() == A + + A = DDM([[QQ(1), QQ(2)], [QQ(2), QQ(4)]], (2, 2), QQ) + raises(DMNonInvertibleMatrixError, lambda: A.inv()) + + +def test_DDM_lu(): + A = DDM([[QQ(1), QQ(2)], [QQ(3), QQ(4)]], (2, 2), QQ) + L, U, swaps = A.lu() + assert L == DDM([[QQ(1), QQ(0)], [QQ(3), QQ(1)]], (2, 2), QQ) + assert U == DDM([[QQ(1), QQ(2)], [QQ(0), QQ(-2)]], (2, 2), QQ) + assert swaps == [] + + A = [[1, 0, 0, 0], [0, 0, 0, 0], [0, 0, 1, 1], [0, 0, 1, 2]] + Lexp = [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 1, 1]] + Uexp = [[1, 0, 0, 0], [0, 0, 0, 0], [0, 0, 1, 1], [0, 0, 0, 1]] + to_dom = lambda rows, dom: [[dom(e) for e in row] for row in rows] + A = DDM(to_dom(A, QQ), (4, 4), QQ) + Lexp = DDM(to_dom(Lexp, QQ), (4, 4), QQ) + Uexp = DDM(to_dom(Uexp, QQ), (4, 4), QQ) + L, U, swaps = A.lu() + assert L == Lexp + assert U == Uexp + assert swaps == [] + + +def test_DDM_lu_solve(): + # Basic example + A = DDM([[QQ(1), QQ(2)], [QQ(3), QQ(4)]], (2, 2), QQ) + b = DDM([[QQ(1)], [QQ(2)]], (2, 1), QQ) + x = DDM([[QQ(0)], [QQ(1, 2)]], (2, 1), QQ) + assert A.lu_solve(b) == x + + # Example with swaps + A = DDM([[QQ(0), QQ(2)], [QQ(3), QQ(4)]], (2, 2), QQ) + assert A.lu_solve(b) == x + + # Overdetermined, consistent + A = DDM([[QQ(1), QQ(2)], [QQ(3), QQ(4)], [QQ(5), QQ(6)]], (3, 2), QQ) + b = DDM([[QQ(1)], [QQ(2)], [QQ(3)]], (3, 1), QQ) + assert A.lu_solve(b) == x + + # Overdetermined, inconsistent + b = DDM([[QQ(1)], [QQ(2)], [QQ(4)]], (3, 1), QQ) + raises(DMNonInvertibleMatrixError, lambda: A.lu_solve(b)) + + # Square, noninvertible + A = DDM([[QQ(1), QQ(2)], [QQ(1), QQ(2)]], (2, 2), QQ) + b = DDM([[QQ(1)], [QQ(2)]], (2, 1), QQ) + raises(DMNonInvertibleMatrixError, lambda: A.lu_solve(b)) + + # Underdetermined + A = DDM([[QQ(1), QQ(2)]], (1, 2), QQ) + b = DDM([[QQ(3)]], (1, 1), QQ) + raises(NotImplementedError, lambda: A.lu_solve(b)) + + # Domain mismatch + bz = DDM([[ZZ(1)], [ZZ(2)]], (2, 1), ZZ) + raises(DMDomainError, lambda: A.lu_solve(bz)) + + # Shape mismatch + b3 = DDM([[QQ(1)], [QQ(2)], [QQ(3)]], (3, 1), QQ) + raises(DMShapeError, lambda: A.lu_solve(b3)) + + +def test_DDM_charpoly(): + A = DDM([], (0, 0), ZZ) + assert A.charpoly() == [ZZ(1)] + + A = DDM([ + [ZZ(1), ZZ(2), ZZ(3)], + [ZZ(4), ZZ(5), ZZ(6)], + [ZZ(7), ZZ(8), ZZ(9)]], (3, 3), ZZ) + Avec = [ZZ(1), ZZ(-15), ZZ(-18), ZZ(0)] + assert A.charpoly() == Avec + + A = DDM([[ZZ(1), ZZ(2)]], (1, 2), ZZ) + raises(DMShapeError, lambda: A.charpoly()) + + +def test_DDM_getitem(): + dm = DDM([ + [ZZ(1), ZZ(2), ZZ(3)], + [ZZ(4), ZZ(5), ZZ(6)], + [ZZ(7), ZZ(8), ZZ(9)]], (3, 3), ZZ) + + assert dm.getitem(1, 1) == ZZ(5) + assert dm.getitem(1, -2) == ZZ(5) + assert dm.getitem(-1, -3) == ZZ(7) + + raises(IndexError, lambda: dm.getitem(3, 3)) + + +def test_DDM_setitem(): + dm = DDM.zeros((3, 3), ZZ) + dm.setitem(0, 0, 1) + dm.setitem(1, -2, 1) + dm.setitem(-1, -1, 1) + assert dm == DDM.eye(3, ZZ) + + raises(IndexError, lambda: dm.setitem(3, 3, 0)) + + +def test_DDM_extract_slice(): + dm = DDM([ + [ZZ(1), ZZ(2), ZZ(3)], + [ZZ(4), ZZ(5), ZZ(6)], + [ZZ(7), ZZ(8), ZZ(9)]], (3, 3), ZZ) + + assert dm.extract_slice(slice(0, 3), slice(0, 3)) == dm + assert dm.extract_slice(slice(1, 3), slice(-2)) == DDM([[4], [7]], (2, 1), ZZ) + assert dm.extract_slice(slice(1, 3), slice(-2)) == DDM([[4], [7]], (2, 1), ZZ) + assert dm.extract_slice(slice(2, 3), slice(-2)) == DDM([[ZZ(7)]], (1, 1), ZZ) + assert dm.extract_slice(slice(0, 2), slice(-2)) == DDM([[1], [4]], (2, 1), ZZ) + assert dm.extract_slice(slice(-1), slice(-1)) == DDM([[1, 2], [4, 5]], (2, 2), ZZ) + + assert dm.extract_slice(slice(2), slice(3, 4)) == DDM([[], []], (2, 0), ZZ) + assert dm.extract_slice(slice(3, 4), slice(2)) == DDM([], (0, 2), ZZ) + assert dm.extract_slice(slice(3, 4), slice(3, 4)) == DDM([], (0, 0), ZZ) + + +def test_DDM_extract(): + dm1 = DDM([ + [ZZ(1), ZZ(2), ZZ(3)], + [ZZ(4), ZZ(5), ZZ(6)], + [ZZ(7), ZZ(8), ZZ(9)]], (3, 3), ZZ) + dm2 = DDM([ + [ZZ(6), ZZ(4)], + [ZZ(3), ZZ(1)]], (2, 2), ZZ) + assert dm1.extract([1, 0], [2, 0]) == dm2 + assert dm1.extract([-2, 0], [-1, 0]) == dm2 + + assert dm1.extract([], []) == DDM.zeros((0, 0), ZZ) + assert dm1.extract([1], []) == DDM.zeros((1, 0), ZZ) + assert dm1.extract([], [1]) == DDM.zeros((0, 1), ZZ) + + raises(IndexError, lambda: dm2.extract([2], [0])) + raises(IndexError, lambda: dm2.extract([0], [2])) + raises(IndexError, lambda: dm2.extract([-3], [0])) + raises(IndexError, lambda: dm2.extract([0], [-3])) + + +def test_DDM_flat(): + dm = DDM([ + [ZZ(6), ZZ(4)], + [ZZ(3), ZZ(1)]], (2, 2), ZZ) + assert dm.flat() == [ZZ(6), ZZ(4), ZZ(3), ZZ(1)] + + +def test_DDM_is_zero_matrix(): + A = DDM([[QQ(1), QQ(0)], [QQ(0), QQ(0)]], (2, 2), QQ) + Azero = DDM.zeros((1, 2), QQ) + assert A.is_zero_matrix() is False + assert Azero.is_zero_matrix() is True + + +def test_DDM_is_upper(): + # Wide matrices: + A = DDM([ + [QQ(1), QQ(2), QQ(3), QQ(4)], + [QQ(0), QQ(5), QQ(6), QQ(7)], + [QQ(0), QQ(0), QQ(8), QQ(9)] + ], (3, 4), QQ) + B = DDM([ + [QQ(1), QQ(2), QQ(3), QQ(4)], + [QQ(0), QQ(5), QQ(6), QQ(7)], + [QQ(0), QQ(7), QQ(8), QQ(9)] + ], (3, 4), QQ) + assert A.is_upper() is True + assert B.is_upper() is False + + # Tall matrices: + A = DDM([ + [QQ(1), QQ(2), QQ(3)], + [QQ(0), QQ(5), QQ(6)], + [QQ(0), QQ(0), QQ(8)], + [QQ(0), QQ(0), QQ(0)] + ], (4, 3), QQ) + B = DDM([ + [QQ(1), QQ(2), QQ(3)], + [QQ(0), QQ(5), QQ(6)], + [QQ(0), QQ(0), QQ(8)], + [QQ(0), QQ(0), QQ(10)] + ], (4, 3), QQ) + assert A.is_upper() is True + assert B.is_upper() is False + + +def test_DDM_is_lower(): + # Tall matrices: + A = DDM([ + [QQ(1), QQ(2), QQ(3), QQ(4)], + [QQ(0), QQ(5), QQ(6), QQ(7)], + [QQ(0), QQ(0), QQ(8), QQ(9)] + ], (3, 4), QQ).transpose() + B = DDM([ + [QQ(1), QQ(2), QQ(3), QQ(4)], + [QQ(0), QQ(5), QQ(6), QQ(7)], + [QQ(0), QQ(7), QQ(8), QQ(9)] + ], (3, 4), QQ).transpose() + assert A.is_lower() is True + assert B.is_lower() is False + + # Wide matrices: + A = DDM([ + [QQ(1), QQ(2), QQ(3)], + [QQ(0), QQ(5), QQ(6)], + [QQ(0), QQ(0), QQ(8)], + [QQ(0), QQ(0), QQ(0)] + ], (4, 3), QQ).transpose() + B = DDM([ + [QQ(1), QQ(2), QQ(3)], + [QQ(0), QQ(5), QQ(6)], + [QQ(0), QQ(0), QQ(8)], + [QQ(0), QQ(0), QQ(10)] + ], (4, 3), QQ).transpose() + assert A.is_lower() is True + assert B.is_lower() is False diff --git a/lib/python3.10/site-packages/sympy/polys/matrices/tests/test_dense.py b/lib/python3.10/site-packages/sympy/polys/matrices/tests/test_dense.py new file mode 100644 index 0000000000000000000000000000000000000000..75315ebf6b2ae7d53b4a5737578d3ac5ed4ea36a --- /dev/null +++ b/lib/python3.10/site-packages/sympy/polys/matrices/tests/test_dense.py @@ -0,0 +1,350 @@ +from sympy.testing.pytest import raises + +from sympy.polys import ZZ, QQ + +from sympy.polys.matrices.ddm import DDM +from sympy.polys.matrices.dense import ( + ddm_transpose, + ddm_iadd, ddm_isub, ddm_ineg, ddm_imatmul, ddm_imul, ddm_irref, + ddm_idet, ddm_iinv, ddm_ilu, ddm_ilu_split, ddm_ilu_solve, ddm_berk) + +from sympy.polys.matrices.exceptions import ( + DMDomainError, + DMNonInvertibleMatrixError, + DMNonSquareMatrixError, + DMShapeError, +) + + +def test_ddm_transpose(): + a = [[1, 2], [3, 4]] + assert ddm_transpose(a) == [[1, 3], [2, 4]] + + +def test_ddm_iadd(): + a = [[1, 2], [3, 4]] + b = [[5, 6], [7, 8]] + ddm_iadd(a, b) + assert a == [[6, 8], [10, 12]] + + +def test_ddm_isub(): + a = [[1, 2], [3, 4]] + b = [[5, 6], [7, 8]] + ddm_isub(a, b) + assert a == [[-4, -4], [-4, -4]] + + +def test_ddm_ineg(): + a = [[1, 2], [3, 4]] + ddm_ineg(a) + assert a == [[-1, -2], [-3, -4]] + + +def test_ddm_matmul(): + a = [[1, 2], [3, 4]] + ddm_imul(a, 2) + assert a == [[2, 4], [6, 8]] + + a = [[1, 2], [3, 4]] + ddm_imul(a, 0) + assert a == [[0, 0], [0, 0]] + + +def test_ddm_imatmul(): + a = [[1, 2, 3], [4, 5, 6]] + b = [[1, 2], [3, 4], [5, 6]] + + c1 = [[0, 0], [0, 0]] + ddm_imatmul(c1, a, b) + assert c1 == [[22, 28], [49, 64]] + + c2 = [[0, 0, 0], [0, 0, 0], [0, 0, 0]] + ddm_imatmul(c2, b, a) + assert c2 == [[9, 12, 15], [19, 26, 33], [29, 40, 51]] + + b3 = [[1], [2], [3]] + c3 = [[0], [0]] + ddm_imatmul(c3, a, b3) + assert c3 == [[14], [32]] + + +def test_ddm_irref(): + # Empty matrix + A = [] + Ar = [] + pivots = [] + assert ddm_irref(A) == pivots + assert A == Ar + + # Standard square case + A = [[QQ(0), QQ(1)], [QQ(1), QQ(1)]] + Ar = [[QQ(1), QQ(0)], [QQ(0), QQ(1)]] + pivots = [0, 1] + assert ddm_irref(A) == pivots + assert A == Ar + + # m < n case + A = [[QQ(1), QQ(2), QQ(1)], [QQ(3), QQ(4), QQ(1)]] + Ar = [[QQ(1), QQ(0), QQ(-1)], [QQ(0), QQ(1), QQ(1)]] + pivots = [0, 1] + assert ddm_irref(A) == pivots + assert A == Ar + + # same m < n but reversed + A = [[QQ(3), QQ(4), QQ(1)], [QQ(1), QQ(2), QQ(1)]] + Ar = [[QQ(1), QQ(0), QQ(-1)], [QQ(0), QQ(1), QQ(1)]] + pivots = [0, 1] + assert ddm_irref(A) == pivots + assert A == Ar + + # m > n case + A = [[QQ(1), QQ(0)], [QQ(1), QQ(3)], [QQ(0), QQ(1)]] + Ar = [[QQ(1), QQ(0)], [QQ(0), QQ(1)], [QQ(0), QQ(0)]] + pivots = [0, 1] + assert ddm_irref(A) == pivots + assert A == Ar + + # Example with missing pivot + A = [[QQ(1), QQ(0), QQ(1)], [QQ(3), QQ(0), QQ(1)]] + Ar = [[QQ(1), QQ(0), QQ(0)], [QQ(0), QQ(0), QQ(1)]] + pivots = [0, 2] + assert ddm_irref(A) == pivots + assert A == Ar + + # Example with missing pivot and no replacement + A = [[QQ(0), QQ(1)], [QQ(0), QQ(2)], [QQ(1), QQ(0)]] + Ar = [[QQ(1), QQ(0)], [QQ(0), QQ(1)], [QQ(0), QQ(0)]] + pivots = [0, 1] + assert ddm_irref(A) == pivots + assert A == Ar + + +def test_ddm_idet(): + A = [] + assert ddm_idet(A, ZZ) == ZZ(1) + + A = [[ZZ(2)]] + assert ddm_idet(A, ZZ) == ZZ(2) + + A = [[ZZ(1), ZZ(2)], [ZZ(3), ZZ(4)]] + assert ddm_idet(A, ZZ) == ZZ(-2) + + A = [[ZZ(1), ZZ(2), ZZ(3)], [ZZ(1), ZZ(2), ZZ(4)], [ZZ(1), ZZ(3), ZZ(5)]] + assert ddm_idet(A, ZZ) == ZZ(-1) + + A = [[ZZ(1), ZZ(2), ZZ(3)], [ZZ(1), ZZ(2), ZZ(4)], [ZZ(1), ZZ(2), ZZ(5)]] + assert ddm_idet(A, ZZ) == ZZ(0) + + A = [[QQ(1, 2), QQ(1, 2)], [QQ(1, 3), QQ(1, 4)]] + assert ddm_idet(A, QQ) == QQ(-1, 24) + + +def test_ddm_inv(): + A = [] + Ainv = [] + ddm_iinv(Ainv, A, QQ) + assert Ainv == A + + A = [] + Ainv = [] + raises(DMDomainError, lambda: ddm_iinv(Ainv, A, ZZ)) + + A = [[QQ(1), QQ(2)]] + Ainv = [[QQ(0), QQ(0)]] + raises(DMNonSquareMatrixError, lambda: ddm_iinv(Ainv, A, QQ)) + + A = [[QQ(1, 1), QQ(2, 1)], [QQ(3, 1), QQ(4, 1)]] + Ainv = [[QQ(0), QQ(0)], [QQ(0), QQ(0)]] + Ainv_expected = [[QQ(-2, 1), QQ(1, 1)], [QQ(3, 2), QQ(-1, 2)]] + ddm_iinv(Ainv, A, QQ) + assert Ainv == Ainv_expected + + A = [[QQ(1, 1), QQ(2, 1)], [QQ(2, 1), QQ(4, 1)]] + Ainv = [[QQ(0), QQ(0)], [QQ(0), QQ(0)]] + raises(DMNonInvertibleMatrixError, lambda: ddm_iinv(Ainv, A, QQ)) + + +def test_ddm_ilu(): + A = [] + Alu = [] + swaps = ddm_ilu(A) + assert A == Alu + assert swaps == [] + + A = [[]] + Alu = [[]] + swaps = ddm_ilu(A) + assert A == Alu + assert swaps == [] + + A = [[QQ(1), QQ(2)], [QQ(3), QQ(4)]] + Alu = [[QQ(1), QQ(2)], [QQ(3), QQ(-2)]] + swaps = ddm_ilu(A) + assert A == Alu + assert swaps == [] + + A = [[QQ(0), QQ(2)], [QQ(3), QQ(4)]] + Alu = [[QQ(3), QQ(4)], [QQ(0), QQ(2)]] + swaps = ddm_ilu(A) + assert A == Alu + assert swaps == [(0, 1)] + + A = [[QQ(1), QQ(2), QQ(3)], [QQ(4), QQ(5), QQ(6)], [QQ(7), QQ(8), QQ(9)]] + Alu = [[QQ(1), QQ(2), QQ(3)], [QQ(4), QQ(-3), QQ(-6)], [QQ(7), QQ(2), QQ(0)]] + swaps = ddm_ilu(A) + assert A == Alu + assert swaps == [] + + A = [[QQ(0), QQ(1), QQ(2)], [QQ(0), QQ(1), QQ(3)], [QQ(1), QQ(1), QQ(2)]] + Alu = [[QQ(1), QQ(1), QQ(2)], [QQ(0), QQ(1), QQ(3)], [QQ(0), QQ(1), QQ(-1)]] + swaps = ddm_ilu(A) + assert A == Alu + assert swaps == [(0, 2)] + + A = [[QQ(1), QQ(2), QQ(3)], [QQ(4), QQ(5), QQ(6)]] + Alu = [[QQ(1), QQ(2), QQ(3)], [QQ(4), QQ(-3), QQ(-6)]] + swaps = ddm_ilu(A) + assert A == Alu + assert swaps == [] + + A = [[QQ(1), QQ(2)], [QQ(3), QQ(4)], [QQ(5), QQ(6)]] + Alu = [[QQ(1), QQ(2)], [QQ(3), QQ(-2)], [QQ(5), QQ(2)]] + swaps = ddm_ilu(A) + assert A == Alu + assert swaps == [] + + +def test_ddm_ilu_split(): + U = [] + L = [] + Uexp = [] + Lexp = [] + swaps = ddm_ilu_split(L, U, QQ) + assert U == Uexp + assert L == Lexp + assert swaps == [] + + U = [[]] + L = [[QQ(1)]] + Uexp = [[]] + Lexp = [[QQ(1)]] + swaps = ddm_ilu_split(L, U, QQ) + assert U == Uexp + assert L == Lexp + assert swaps == [] + + U = [[QQ(1), QQ(2)], [QQ(3), QQ(4)]] + L = [[QQ(1), QQ(0)], [QQ(0), QQ(1)]] + Uexp = [[QQ(1), QQ(2)], [QQ(0), QQ(-2)]] + Lexp = [[QQ(1), QQ(0)], [QQ(3), QQ(1)]] + swaps = ddm_ilu_split(L, U, QQ) + assert U == Uexp + assert L == Lexp + assert swaps == [] + + U = [[QQ(1), QQ(2), QQ(3)], [QQ(4), QQ(5), QQ(6)]] + L = [[QQ(1), QQ(0)], [QQ(0), QQ(1)]] + Uexp = [[QQ(1), QQ(2), QQ(3)], [QQ(0), QQ(-3), QQ(-6)]] + Lexp = [[QQ(1), QQ(0)], [QQ(4), QQ(1)]] + swaps = ddm_ilu_split(L, U, QQ) + assert U == Uexp + assert L == Lexp + assert swaps == [] + + U = [[QQ(1), QQ(2)], [QQ(3), QQ(4)], [QQ(5), QQ(6)]] + L = [[QQ(1), QQ(0), QQ(0)], [QQ(0), QQ(1), QQ(0)], [QQ(0), QQ(0), QQ(1)]] + Uexp = [[QQ(1), QQ(2)], [QQ(0), QQ(-2)], [QQ(0), QQ(0)]] + Lexp = [[QQ(1), QQ(0), QQ(0)], [QQ(3), QQ(1), QQ(0)], [QQ(5), QQ(2), QQ(1)]] + swaps = ddm_ilu_split(L, U, QQ) + assert U == Uexp + assert L == Lexp + assert swaps == [] + + +def test_ddm_ilu_solve(): + # Basic example + # A = [[QQ(1), QQ(2)], [QQ(3), QQ(4)]] + U = [[QQ(1), QQ(2)], [QQ(0), QQ(-2)]] + L = [[QQ(1), QQ(0)], [QQ(3), QQ(1)]] + swaps = [] + b = DDM([[QQ(1)], [QQ(2)]], (2, 1), QQ) + x = DDM([[QQ(0)], [QQ(0)]], (2, 1), QQ) + xexp = DDM([[QQ(0)], [QQ(1, 2)]], (2, 1), QQ) + ddm_ilu_solve(x, L, U, swaps, b) + assert x == xexp + + # Example with swaps + # A = [[QQ(0), QQ(2)], [QQ(3), QQ(4)]] + U = [[QQ(3), QQ(4)], [QQ(0), QQ(2)]] + L = [[QQ(1), QQ(0)], [QQ(0), QQ(1)]] + swaps = [(0, 1)] + b = DDM([[QQ(1)], [QQ(2)]], (2, 1), QQ) + x = DDM([[QQ(0)], [QQ(0)]], (2, 1), QQ) + xexp = DDM([[QQ(0)], [QQ(1, 2)]], (2, 1), QQ) + ddm_ilu_solve(x, L, U, swaps, b) + assert x == xexp + + # Overdetermined, consistent + # A = DDM([[QQ(1), QQ(2)], [QQ(3), QQ(4)], [QQ(5), QQ(6)]], (3, 2), QQ) + U = [[QQ(1), QQ(2)], [QQ(0), QQ(-2)], [QQ(0), QQ(0)]] + L = [[QQ(1), QQ(0), QQ(0)], [QQ(3), QQ(1), QQ(0)], [QQ(5), QQ(2), QQ(1)]] + swaps = [] + b = DDM([[QQ(1)], [QQ(2)], [QQ(3)]], (3, 1), QQ) + x = DDM([[QQ(0)], [QQ(0)]], (2, 1), QQ) + xexp = DDM([[QQ(0)], [QQ(1, 2)]], (2, 1), QQ) + ddm_ilu_solve(x, L, U, swaps, b) + assert x == xexp + + # Overdetermined, inconsistent + b = DDM([[QQ(1)], [QQ(2)], [QQ(4)]], (3, 1), QQ) + raises(DMNonInvertibleMatrixError, lambda: ddm_ilu_solve(x, L, U, swaps, b)) + + # Square, noninvertible + # A = DDM([[QQ(1), QQ(2)], [QQ(1), QQ(2)]], (2, 2), QQ) + U = [[QQ(1), QQ(2)], [QQ(0), QQ(0)]] + L = [[QQ(1), QQ(0)], [QQ(1), QQ(1)]] + swaps = [] + b = DDM([[QQ(1)], [QQ(2)]], (2, 1), QQ) + raises(DMNonInvertibleMatrixError, lambda: ddm_ilu_solve(x, L, U, swaps, b)) + + # Underdetermined + # A = DDM([[QQ(1), QQ(2)]], (1, 2), QQ) + U = [[QQ(1), QQ(2)]] + L = [[QQ(1)]] + swaps = [] + b = DDM([[QQ(3)]], (1, 1), QQ) + raises(NotImplementedError, lambda: ddm_ilu_solve(x, L, U, swaps, b)) + + # Shape mismatch + b3 = DDM([[QQ(1)], [QQ(2)], [QQ(3)]], (3, 1), QQ) + raises(DMShapeError, lambda: ddm_ilu_solve(x, L, U, swaps, b3)) + + # Empty shape mismatch + U = [[QQ(1)]] + L = [[QQ(1)]] + swaps = [] + x = [[QQ(1)]] + b = [] + raises(DMShapeError, lambda: ddm_ilu_solve(x, L, U, swaps, b)) + + # Empty system + U = [] + L = [] + swaps = [] + b = [] + x = [] + ddm_ilu_solve(x, L, U, swaps, b) + assert x == [] + + +def test_ddm_charpoly(): + A = [] + assert ddm_berk(A, ZZ) == [[ZZ(1)]] + + A = [[ZZ(1), ZZ(2), ZZ(3)], [ZZ(4), ZZ(5), ZZ(6)], [ZZ(7), ZZ(8), ZZ(9)]] + Avec = [[ZZ(1)], [ZZ(-15)], [ZZ(-18)], [ZZ(0)]] + assert ddm_berk(A, ZZ) == Avec + + A = DDM([[ZZ(1), ZZ(2)]], (1, 2), ZZ) + raises(DMShapeError, lambda: ddm_berk(A, ZZ)) diff --git a/lib/python3.10/site-packages/sympy/polys/matrices/tests/test_domainmatrix.py b/lib/python3.10/site-packages/sympy/polys/matrices/tests/test_domainmatrix.py new file mode 100644 index 0000000000000000000000000000000000000000..f75124d3822f026505e5f182658e70ba197a3c16 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/polys/matrices/tests/test_domainmatrix.py @@ -0,0 +1,1345 @@ +from sympy.external.gmpy import GROUND_TYPES + +from sympy import Integer, Rational, S, sqrt, Matrix, symbols +from sympy import FF, ZZ, QQ, QQ_I, EXRAW + +from sympy.polys.matrices.domainmatrix import DomainMatrix, DomainScalar, DM +from sympy.polys.matrices.exceptions import ( + DMBadInputError, DMDomainError, DMShapeError, DMFormatError, DMNotAField, + DMNonSquareMatrixError, DMNonInvertibleMatrixError, +) +from sympy.polys.matrices.ddm import DDM +from sympy.polys.matrices.sdm import SDM + +from sympy.testing.pytest import raises + + +def test_DM(): + ddm = DDM([[ZZ(1), ZZ(2)], [ZZ(3), ZZ(4)]], (2, 2), ZZ) + A = DM([[1, 2], [3, 4]], ZZ) + if GROUND_TYPES != 'flint': + assert A.rep == ddm + else: + assert A.rep == ddm.to_dfm() + assert A.shape == (2, 2) + assert A.domain == ZZ + + +def test_DomainMatrix_init(): + lol = [[ZZ(1), ZZ(2)], [ZZ(3), ZZ(4)]] + dod = {0: {0: ZZ(1), 1:ZZ(2)}, 1: {0:ZZ(3), 1:ZZ(4)}} + ddm = DDM(lol, (2, 2), ZZ) + sdm = SDM(dod, (2, 2), ZZ) + + A = DomainMatrix(lol, (2, 2), ZZ) + if GROUND_TYPES != 'flint': + assert A.rep == ddm + else: + assert A.rep == ddm.to_dfm() + assert A.shape == (2, 2) + assert A.domain == ZZ + + A = DomainMatrix(dod, (2, 2), ZZ) + assert A.rep == sdm + assert A.shape == (2, 2) + assert A.domain == ZZ + + raises(TypeError, lambda: DomainMatrix(ddm, (2, 2), ZZ)) + raises(TypeError, lambda: DomainMatrix(sdm, (2, 2), ZZ)) + raises(TypeError, lambda: DomainMatrix(Matrix([[1]]), (1, 1), ZZ)) + + for fmt, rep in [('sparse', sdm), ('dense', ddm)]: + if fmt == 'dense' and GROUND_TYPES == 'flint': + rep = rep.to_dfm() + A = DomainMatrix(lol, (2, 2), ZZ, fmt=fmt) + assert A.rep == rep + A = DomainMatrix(dod, (2, 2), ZZ, fmt=fmt) + assert A.rep == rep + + raises(ValueError, lambda: DomainMatrix(lol, (2, 2), ZZ, fmt='invalid')) + + raises(DMBadInputError, lambda: DomainMatrix([[ZZ(1), ZZ(2)]], (2, 2), ZZ)) + + +def test_DomainMatrix_from_rep(): + ddm = DDM([[ZZ(1), ZZ(2)], [ZZ(3), ZZ(4)]], (2, 2), ZZ) + A = DomainMatrix.from_rep(ddm) + # XXX: Should from_rep convert to DFM? + assert A.rep == ddm + assert A.shape == (2, 2) + assert A.domain == ZZ + + sdm = SDM({0: {0: ZZ(1), 1:ZZ(2)}, 1: {0:ZZ(3), 1:ZZ(4)}}, (2, 2), ZZ) + A = DomainMatrix.from_rep(sdm) + assert A.rep == sdm + assert A.shape == (2, 2) + assert A.domain == ZZ + + A = DomainMatrix([[ZZ(1)]], (1, 1), ZZ) + raises(TypeError, lambda: DomainMatrix.from_rep(A)) + + +def test_DomainMatrix_from_list(): + ddm = DDM([[ZZ(1), ZZ(2)], [ZZ(3), ZZ(4)]], (2, 2), ZZ) + A = DomainMatrix.from_list([[1, 2], [3, 4]], ZZ) + if GROUND_TYPES != 'flint': + assert A.rep == ddm + else: + assert A.rep == ddm.to_dfm() + assert A.shape == (2, 2) + assert A.domain == ZZ + + dom = FF(7) + ddm = DDM([[dom(1), dom(2)], [dom(3), dom(4)]], (2, 2), dom) + A = DomainMatrix.from_list([[1, 2], [3, 4]], dom) + # Not a DFM because FF(7) is not supported by DFM + assert A.rep == ddm + assert A.shape == (2, 2) + assert A.domain == dom + + ddm = DDM([[QQ(1, 2), QQ(3, 1)], [QQ(1, 4), QQ(5, 1)]], (2, 2), QQ) + A = DomainMatrix.from_list([[(1, 2), (3, 1)], [(1, 4), (5, 1)]], QQ) + if GROUND_TYPES != 'flint': + assert A.rep == ddm + else: + assert A.rep == ddm.to_dfm() + assert A.shape == (2, 2) + assert A.domain == QQ + + +def test_DomainMatrix_from_list_sympy(): + ddm = DDM([[ZZ(1), ZZ(2)], [ZZ(3), ZZ(4)]], (2, 2), ZZ) + A = DomainMatrix.from_list_sympy(2, 2, [[1, 2], [3, 4]]) + if GROUND_TYPES != 'flint': + assert A.rep == ddm + else: + assert A.rep == ddm.to_dfm() + assert A.shape == (2, 2) + assert A.domain == ZZ + + K = QQ.algebraic_field(sqrt(2)) + ddm = DDM( + [[K.convert(1 + sqrt(2)), K.convert(2 + sqrt(2))], + [K.convert(3 + sqrt(2)), K.convert(4 + sqrt(2))]], + (2, 2), + K + ) + A = DomainMatrix.from_list_sympy( + 2, 2, [[1 + sqrt(2), 2 + sqrt(2)], [3 + sqrt(2), 4 + sqrt(2)]], + extension=True) + assert A.rep == ddm + assert A.shape == (2, 2) + assert A.domain == K + + +def test_DomainMatrix_from_dict_sympy(): + sdm = SDM({0: {0: QQ(1, 2)}, 1: {1: QQ(2, 3)}}, (2, 2), QQ) + sympy_dict = {0: {0: Rational(1, 2)}, 1: {1: Rational(2, 3)}} + A = DomainMatrix.from_dict_sympy(2, 2, sympy_dict) + assert A.rep == sdm + assert A.shape == (2, 2) + assert A.domain == QQ + + fds = DomainMatrix.from_dict_sympy + raises(DMBadInputError, lambda: fds(2, 2, {3: {0: Rational(1, 2)}})) + raises(DMBadInputError, lambda: fds(2, 2, {0: {3: Rational(1, 2)}})) + + +def test_DomainMatrix_from_Matrix(): + sdm = SDM({0: {0: ZZ(1), 1: ZZ(2)}, 1: {0: ZZ(3), 1: ZZ(4)}}, (2, 2), ZZ) + A = DomainMatrix.from_Matrix(Matrix([[1, 2], [3, 4]])) + assert A.rep == sdm + assert A.shape == (2, 2) + assert A.domain == ZZ + + K = QQ.algebraic_field(sqrt(2)) + sdm = SDM( + {0: {0: K.convert(1 + sqrt(2)), 1: K.convert(2 + sqrt(2))}, + 1: {0: K.convert(3 + sqrt(2)), 1: K.convert(4 + sqrt(2))}}, + (2, 2), + K + ) + A = DomainMatrix.from_Matrix( + Matrix([[1 + sqrt(2), 2 + sqrt(2)], [3 + sqrt(2), 4 + sqrt(2)]]), + extension=True) + assert A.rep == sdm + assert A.shape == (2, 2) + assert A.domain == K + + A = DomainMatrix.from_Matrix(Matrix([[QQ(1, 2), QQ(3, 4)], [QQ(0, 1), QQ(0, 1)]]), fmt='dense') + ddm = DDM([[QQ(1, 2), QQ(3, 4)], [QQ(0, 1), QQ(0, 1)]], (2, 2), QQ) + + if GROUND_TYPES != 'flint': + assert A.rep == ddm + else: + assert A.rep == ddm.to_dfm() + assert A.shape == (2, 2) + assert A.domain == QQ + + +def test_DomainMatrix_eq(): + A = DomainMatrix([[ZZ(1), ZZ(2)], [ZZ(3), ZZ(4)]], (2, 2), ZZ) + assert A == A + B = DomainMatrix([[ZZ(1), ZZ(2)], [ZZ(3), ZZ(1)]], (2, 2), ZZ) + assert A != B + C = [[ZZ(1), ZZ(2)], [ZZ(3), ZZ(4)]] + assert A != C + + +def test_DomainMatrix_unify_eq(): + A = DomainMatrix([[ZZ(1), ZZ(2)], [ZZ(3), ZZ(4)]], (2, 2), ZZ) + B1 = DomainMatrix([[QQ(1), QQ(2)], [QQ(3), QQ(4)]], (2, 2), QQ) + B2 = DomainMatrix([[QQ(1), QQ(3)], [QQ(3), QQ(4)]], (2, 2), QQ) + B3 = DomainMatrix([[ZZ(1)]], (1, 1), ZZ) + assert A.unify_eq(B1) is True + assert A.unify_eq(B2) is False + assert A.unify_eq(B3) is False + + +def test_DomainMatrix_get_domain(): + K, items = DomainMatrix.get_domain([1, 2, 3, 4]) + assert items == [ZZ(1), ZZ(2), ZZ(3), ZZ(4)] + assert K == ZZ + + K, items = DomainMatrix.get_domain([1, 2, 3, Rational(1, 2)]) + assert items == [QQ(1), QQ(2), QQ(3), QQ(1, 2)] + assert K == QQ + + +def test_DomainMatrix_convert_to(): + A = DomainMatrix([[ZZ(1), ZZ(2)], [ZZ(3), ZZ(4)]], (2, 2), ZZ) + Aq = A.convert_to(QQ) + assert Aq == DomainMatrix([[QQ(1), QQ(2)], [QQ(3), QQ(4)]], (2, 2), QQ) + + +def test_DomainMatrix_choose_domain(): + A = [[1, 2], [3, 0]] + assert DM(A, QQ).choose_domain() == DM(A, ZZ) + assert DM(A, QQ).choose_domain(field=True) == DM(A, QQ) + assert DM(A, ZZ).choose_domain(field=True) == DM(A, QQ) + + x = symbols('x') + B = [[1, x], [x**2, x**3]] + assert DM(B, QQ[x]).choose_domain(field=True) == DM(B, ZZ.frac_field(x)) + + +def test_DomainMatrix_to_flat_nz(): + Adm = DM([[1, 2], [3, 0]], ZZ) + Addm = Adm.rep.to_ddm() + Asdm = Adm.rep.to_sdm() + for A in [Adm, Addm, Asdm]: + elems, data = A.to_flat_nz() + assert A.from_flat_nz(elems, data, A.domain) == A + elemsq = [QQ(e) for e in elems] + assert A.from_flat_nz(elemsq, data, QQ) == A.convert_to(QQ) + elems2 = [2*e for e in elems] + assert A.from_flat_nz(elems2, data, A.domain) == 2*A + + +def test_DomainMatrix_to_sympy(): + A = DomainMatrix([[ZZ(1), ZZ(2)], [ZZ(3), ZZ(4)]], (2, 2), ZZ) + assert A.to_sympy() == A.convert_to(EXRAW) + + +def test_DomainMatrix_to_field(): + A = DomainMatrix([[ZZ(1), ZZ(2)], [ZZ(3), ZZ(4)]], (2, 2), ZZ) + Aq = A.to_field() + assert Aq == DomainMatrix([[QQ(1), QQ(2)], [QQ(3), QQ(4)]], (2, 2), QQ) + + +def test_DomainMatrix_to_sparse(): + A = DomainMatrix([[ZZ(1), ZZ(2)], [ZZ(3), ZZ(4)]], (2, 2), ZZ) + A_sparse = A.to_sparse() + assert A_sparse.rep == {0: {0: 1, 1: 2}, 1: {0: 3, 1: 4}} + + +def test_DomainMatrix_to_dense(): + A = DomainMatrix({0: {0: 1, 1: 2}, 1: {0: 3, 1: 4}}, (2, 2), ZZ) + A_dense = A.to_dense() + ddm = DDM([[1, 2], [3, 4]], (2, 2), ZZ) + if GROUND_TYPES != 'flint': + assert A_dense.rep == ddm + else: + assert A_dense.rep == ddm.to_dfm() + + +def test_DomainMatrix_unify(): + Az = DomainMatrix([[ZZ(1), ZZ(2)], [ZZ(3), ZZ(4)]], (2, 2), ZZ) + Aq = DomainMatrix([[QQ(1), QQ(2)], [QQ(3), QQ(4)]], (2, 2), QQ) + assert Az.unify(Az) == (Az, Az) + assert Az.unify(Aq) == (Aq, Aq) + assert Aq.unify(Az) == (Aq, Aq) + assert Aq.unify(Aq) == (Aq, Aq) + + As = DomainMatrix({0: {1: ZZ(1)}, 1:{0:ZZ(2)}}, (2, 2), ZZ) + Ad = DomainMatrix([[ZZ(1), ZZ(2)], [ZZ(3), ZZ(4)]], (2, 2), ZZ) + + assert As.unify(As) == (As, As) + assert Ad.unify(Ad) == (Ad, Ad) + + Bs, Bd = As.unify(Ad, fmt='dense') + assert Bs.rep == DDM([[0, 1], [2, 0]], (2, 2), ZZ).to_dfm_or_ddm() + assert Bd.rep == DDM([[1, 2],[3, 4]], (2, 2), ZZ).to_dfm_or_ddm() + + Bs, Bd = As.unify(Ad, fmt='sparse') + assert Bs.rep == SDM({0: {1: 1}, 1: {0: 2}}, (2, 2), ZZ) + assert Bd.rep == SDM({0: {0: 1, 1: 2}, 1: {0: 3, 1: 4}}, (2, 2), ZZ) + + raises(ValueError, lambda: As.unify(Ad, fmt='invalid')) + + +def test_DomainMatrix_to_Matrix(): + A = DomainMatrix([[ZZ(1), ZZ(2)], [ZZ(3), ZZ(4)]], (2, 2), ZZ) + A_Matrix = Matrix([[1, 2], [3, 4]]) + assert A.to_Matrix() == A_Matrix + assert A.to_sparse().to_Matrix() == A_Matrix + assert A.convert_to(QQ).to_Matrix() == A_Matrix + assert A.convert_to(QQ.algebraic_field(sqrt(2))).to_Matrix() == A_Matrix + + +def test_DomainMatrix_to_list(): + A = DomainMatrix([[ZZ(1), ZZ(2)], [ZZ(3), ZZ(4)]], (2, 2), ZZ) + assert A.to_list() == [[ZZ(1), ZZ(2)], [ZZ(3), ZZ(4)]] + + +def test_DomainMatrix_to_list_flat(): + A = DomainMatrix([[ZZ(1), ZZ(2)], [ZZ(3), ZZ(4)]], (2, 2), ZZ) + assert A.to_list_flat() == [ZZ(1), ZZ(2), ZZ(3), ZZ(4)] + + +def test_DomainMatrix_flat(): + A = DomainMatrix([[ZZ(1), ZZ(2)], [ZZ(3), ZZ(4)]], (2, 2), ZZ) + assert A.flat() == [ZZ(1), ZZ(2), ZZ(3), ZZ(4)] + + +def test_DomainMatrix_from_list_flat(): + nums = [ZZ(1), ZZ(2), ZZ(3), ZZ(4)] + A = DomainMatrix([[ZZ(1), ZZ(2)], [ZZ(3), ZZ(4)]], (2, 2), ZZ) + + assert DomainMatrix.from_list_flat(nums, (2, 2), ZZ) == A + assert DDM.from_list_flat(nums, (2, 2), ZZ) == A.rep.to_ddm() + assert SDM.from_list_flat(nums, (2, 2), ZZ) == A.rep.to_sdm() + + assert A == A.from_list_flat(A.to_list_flat(), A.shape, A.domain) + + raises(DMBadInputError, DomainMatrix.from_list_flat, nums, (2, 3), ZZ) + raises(DMBadInputError, DDM.from_list_flat, nums, (2, 3), ZZ) + raises(DMBadInputError, SDM.from_list_flat, nums, (2, 3), ZZ) + + +def test_DomainMatrix_to_dod(): + A = DomainMatrix([[ZZ(1), ZZ(2)], [ZZ(3), ZZ(4)]], (2, 2), ZZ) + assert A.to_dod() == {0: {0: ZZ(1), 1:ZZ(2)}, 1: {0:ZZ(3), 1:ZZ(4)}} + A = DomainMatrix([[ZZ(1), ZZ(0)], [ZZ(0), ZZ(4)]], (2, 2), ZZ) + assert A.to_dod() == {0: {0: ZZ(1)}, 1: {1: ZZ(4)}} + + +def test_DomainMatrix_from_dod(): + items = {0: {0: ZZ(1), 1:ZZ(2)}, 1: {0:ZZ(3), 1:ZZ(4)}} + A = DM([[1, 2], [3, 4]], ZZ) + assert DomainMatrix.from_dod(items, (2, 2), ZZ) == A.to_sparse() + assert A.from_dod_like(items) == A + assert A.from_dod_like(items, QQ) == A.convert_to(QQ) + + +def test_DomainMatrix_to_dok(): + A = DomainMatrix([[ZZ(1), ZZ(2)], [ZZ(3), ZZ(4)]], (2, 2), ZZ) + assert A.to_dok() == {(0, 0):ZZ(1), (0, 1):ZZ(2), (1, 0):ZZ(3), (1, 1):ZZ(4)} + A = DomainMatrix([[ZZ(1), ZZ(0)], [ZZ(0), ZZ(4)]], (2, 2), ZZ) + dok = {(0, 0):ZZ(1), (1, 1):ZZ(4)} + assert A.to_dok() == dok + assert A.to_dense().to_dok() == dok + assert A.to_sparse().to_dok() == dok + assert A.rep.to_ddm().to_dok() == dok + assert A.rep.to_sdm().to_dok() == dok + + +def test_DomainMatrix_from_dok(): + items = {(0, 0): ZZ(1), (1, 1): ZZ(2)} + A = DM([[1, 0], [0, 2]], ZZ) + assert DomainMatrix.from_dok(items, (2, 2), ZZ) == A.to_sparse() + assert DDM.from_dok(items, (2, 2), ZZ) == A.rep.to_ddm() + assert SDM.from_dok(items, (2, 2), ZZ) == A.rep.to_sdm() + + +def test_DomainMatrix_repr(): + A = DomainMatrix([[ZZ(1), ZZ(2)], [ZZ(3), ZZ(4)]], (2, 2), ZZ) + assert repr(A) == 'DomainMatrix([[1, 2], [3, 4]], (2, 2), ZZ)' + + +def test_DomainMatrix_transpose(): + A = DomainMatrix([[ZZ(1), ZZ(2)], [ZZ(3), ZZ(4)]], (2, 2), ZZ) + AT = DomainMatrix([[ZZ(1), ZZ(3)], [ZZ(2), ZZ(4)]], (2, 2), ZZ) + assert A.transpose() == AT + + +def test_DomainMatrix_is_zero_matrix(): + A = DomainMatrix([[ZZ(1)]], (1, 1), ZZ) + B = DomainMatrix([[ZZ(0)]], (1, 1), ZZ) + assert A.is_zero_matrix is False + assert B.is_zero_matrix is True + + +def test_DomainMatrix_is_upper(): + A = DomainMatrix([[ZZ(1), ZZ(2)], [ZZ(0), ZZ(4)]], (2, 2), ZZ) + B = DomainMatrix([[ZZ(1), ZZ(2)], [ZZ(3), ZZ(4)]], (2, 2), ZZ) + assert A.is_upper is True + assert B.is_upper is False + + +def test_DomainMatrix_is_lower(): + A = DomainMatrix([[ZZ(1), ZZ(0)], [ZZ(3), ZZ(4)]], (2, 2), ZZ) + B = DomainMatrix([[ZZ(1), ZZ(2)], [ZZ(3), ZZ(4)]], (2, 2), ZZ) + assert A.is_lower is True + assert B.is_lower is False + + +def test_DomainMatrix_is_diagonal(): + A = DM([[1, 0], [0, 4]], ZZ) + B = DM([[1, 2], [3, 4]], ZZ) + assert A.is_diagonal is A.to_sparse().is_diagonal is True + assert B.is_diagonal is B.to_sparse().is_diagonal is False + + +def test_DomainMatrix_is_square(): + A = DomainMatrix([[ZZ(1), ZZ(2)], [ZZ(3), ZZ(4)]], (2, 2), ZZ) + B = DomainMatrix([[ZZ(1), ZZ(2)], [ZZ(3), ZZ(4)], [ZZ(5), ZZ(6)]], (3, 2), ZZ) + assert A.is_square is True + assert B.is_square is False + + +def test_DomainMatrix_diagonal(): + A = DM([[1, 2], [3, 4]], ZZ) + assert A.diagonal() == A.to_sparse().diagonal() == [ZZ(1), ZZ(4)] + A = DM([[1, 2], [3, 4], [5, 6]], ZZ) + assert A.diagonal() == A.to_sparse().diagonal() == [ZZ(1), ZZ(4)] + A = DM([[1, 2, 3], [4, 5, 6]], ZZ) + assert A.diagonal() == A.to_sparse().diagonal() == [ZZ(1), ZZ(5)] + + +def test_DomainMatrix_rank(): + A = DomainMatrix([[QQ(1), QQ(2)], [QQ(3), QQ(4)], [QQ(6), QQ(8)]], (3, 2), QQ) + assert A.rank() == 2 + + +def test_DomainMatrix_add(): + A = DomainMatrix([[ZZ(1), ZZ(2)], [ZZ(3), ZZ(4)]], (2, 2), ZZ) + B = DomainMatrix([[ZZ(2), ZZ(4)], [ZZ(6), ZZ(8)]], (2, 2), ZZ) + assert A + A == A.add(A) == B + + A = DomainMatrix([[ZZ(1), ZZ(2)], [ZZ(3), ZZ(4)]], (2, 2), ZZ) + L = [[2, 3], [3, 4]] + raises(TypeError, lambda: A + L) + raises(TypeError, lambda: L + A) + + A1 = DomainMatrix([[ZZ(1), ZZ(2)], [ZZ(3), ZZ(4)]], (2, 2), ZZ) + A2 = DomainMatrix([[ZZ(1), ZZ(2)]], (1, 2), ZZ) + raises(DMShapeError, lambda: A1 + A2) + raises(DMShapeError, lambda: A2 + A1) + raises(DMShapeError, lambda: A1.add(A2)) + raises(DMShapeError, lambda: A2.add(A1)) + + Az = DomainMatrix([[ZZ(1), ZZ(2)], [ZZ(3), ZZ(4)]], (2, 2), ZZ) + Aq = DomainMatrix([[QQ(1), QQ(2)], [QQ(3), QQ(4)]], (2, 2), QQ) + Asum = DomainMatrix([[QQ(2), QQ(4)], [QQ(6), QQ(8)]], (2, 2), QQ) + assert Az + Aq == Asum + assert Aq + Az == Asum + raises(DMDomainError, lambda: Az.add(Aq)) + raises(DMDomainError, lambda: Aq.add(Az)) + + As = DomainMatrix({0: {1: ZZ(1)}, 1: {0: ZZ(2)}}, (2, 2), ZZ) + Ad = DomainMatrix([[ZZ(1), ZZ(2)], [ZZ(3), ZZ(4)]], (2, 2), ZZ) + + Asd = As + Ad + Ads = Ad + As + assert Asd == DomainMatrix([[1, 3], [5, 4]], (2, 2), ZZ) + assert Asd.rep == DDM([[1, 3], [5, 4]], (2, 2), ZZ).to_dfm_or_ddm() + assert Ads == DomainMatrix([[1, 3], [5, 4]], (2, 2), ZZ) + assert Ads.rep == DDM([[1, 3], [5, 4]], (2, 2), ZZ).to_dfm_or_ddm() + raises(DMFormatError, lambda: As.add(Ad)) + + +def test_DomainMatrix_sub(): + A = DomainMatrix([[ZZ(1), ZZ(2)], [ZZ(3), ZZ(4)]], (2, 2), ZZ) + B = DomainMatrix([[ZZ(0), ZZ(0)], [ZZ(0), ZZ(0)]], (2, 2), ZZ) + assert A - A == A.sub(A) == B + + A = DomainMatrix([[ZZ(1), ZZ(2)], [ZZ(3), ZZ(4)]], (2, 2), ZZ) + L = [[2, 3], [3, 4]] + raises(TypeError, lambda: A - L) + raises(TypeError, lambda: L - A) + + A1 = DomainMatrix([[ZZ(1), ZZ(2)], [ZZ(3), ZZ(4)]], (2, 2), ZZ) + A2 = DomainMatrix([[ZZ(1), ZZ(2)]], (1, 2), ZZ) + raises(DMShapeError, lambda: A1 - A2) + raises(DMShapeError, lambda: A2 - A1) + raises(DMShapeError, lambda: A1.sub(A2)) + raises(DMShapeError, lambda: A2.sub(A1)) + + Az = DomainMatrix([[ZZ(1), ZZ(2)], [ZZ(3), ZZ(4)]], (2, 2), ZZ) + Aq = DomainMatrix([[QQ(1), QQ(2)], [QQ(3), QQ(4)]], (2, 2), QQ) + Adiff = DomainMatrix([[QQ(0), QQ(0)], [QQ(0), QQ(0)]], (2, 2), QQ) + assert Az - Aq == Adiff + assert Aq - Az == Adiff + raises(DMDomainError, lambda: Az.sub(Aq)) + raises(DMDomainError, lambda: Aq.sub(Az)) + + As = DomainMatrix({0: {1: ZZ(1)}, 1: {0: ZZ(2)}}, (2, 2), ZZ) + Ad = DomainMatrix([[ZZ(1), ZZ(2)], [ZZ(3), ZZ(4)]], (2, 2), ZZ) + + Asd = As - Ad + Ads = Ad - As + assert Asd == DomainMatrix([[-1, -1], [-1, -4]], (2, 2), ZZ) + assert Asd.rep == DDM([[-1, -1], [-1, -4]], (2, 2), ZZ).to_dfm_or_ddm() + assert Asd == -Ads + assert Asd.rep == -Ads.rep + + +def test_DomainMatrix_neg(): + A = DomainMatrix([[ZZ(1), ZZ(2)], [ZZ(3), ZZ(4)]], (2, 2), ZZ) + Aneg = DomainMatrix([[ZZ(-1), ZZ(-2)], [ZZ(-3), ZZ(-4)]], (2, 2), ZZ) + assert -A == A.neg() == Aneg + + +def test_DomainMatrix_mul(): + A = DomainMatrix([[ZZ(1), ZZ(2)], [ZZ(3), ZZ(4)]], (2, 2), ZZ) + A2 = DomainMatrix([[ZZ(7), ZZ(10)], [ZZ(15), ZZ(22)]], (2, 2), ZZ) + assert A*A == A.matmul(A) == A2 + + A = DomainMatrix([[ZZ(1), ZZ(2)], [ZZ(3), ZZ(4)]], (2, 2), ZZ) + L = [[1, 2], [3, 4]] + raises(TypeError, lambda: A * L) + raises(TypeError, lambda: L * A) + + Az = DomainMatrix([[ZZ(1), ZZ(2)], [ZZ(3), ZZ(4)]], (2, 2), ZZ) + Aq = DomainMatrix([[QQ(1), QQ(2)], [QQ(3), QQ(4)]], (2, 2), QQ) + Aprod = DomainMatrix([[QQ(7), QQ(10)], [QQ(15), QQ(22)]], (2, 2), QQ) + assert Az * Aq == Aprod + assert Aq * Az == Aprod + raises(DMDomainError, lambda: Az.matmul(Aq)) + raises(DMDomainError, lambda: Aq.matmul(Az)) + + A = DomainMatrix([[ZZ(1), ZZ(2)], [ZZ(3), ZZ(4)]], (2, 2), ZZ) + AA = DomainMatrix([[ZZ(2), ZZ(4)], [ZZ(6), ZZ(8)]], (2, 2), ZZ) + x = ZZ(2) + assert A * x == x * A == A.mul(x) == AA + + A = DomainMatrix([[ZZ(1), ZZ(2)], [ZZ(3), ZZ(4)]], (2, 2), ZZ) + AA = DomainMatrix.zeros((2, 2), ZZ) + x = ZZ(0) + assert A * x == x * A == A.mul(x).to_sparse() == AA + + As = DomainMatrix({0: {1: ZZ(1)}, 1: {0: ZZ(2)}}, (2, 2), ZZ) + Ad = DomainMatrix([[ZZ(1), ZZ(2)], [ZZ(3), ZZ(4)]], (2, 2), ZZ) + + Asd = As * Ad + Ads = Ad * As + assert Asd == DomainMatrix([[3, 4], [2, 4]], (2, 2), ZZ) + assert Asd.rep == DDM([[3, 4], [2, 4]], (2, 2), ZZ).to_dfm_or_ddm() + assert Ads == DomainMatrix([[4, 1], [8, 3]], (2, 2), ZZ) + assert Ads.rep == DDM([[4, 1], [8, 3]], (2, 2), ZZ).to_dfm_or_ddm() + + +def test_DomainMatrix_mul_elementwise(): + A = DomainMatrix([[ZZ(2), ZZ(2)], [ZZ(0), ZZ(0)]], (2, 2), ZZ) + B = DomainMatrix([[ZZ(4), ZZ(0)], [ZZ(3), ZZ(0)]], (2, 2), ZZ) + C = DomainMatrix([[ZZ(8), ZZ(0)], [ZZ(0), ZZ(0)]], (2, 2), ZZ) + assert A.mul_elementwise(B) == C + assert B.mul_elementwise(A) == C + + +def test_DomainMatrix_pow(): + eye = DomainMatrix.eye(2, ZZ) + A = DomainMatrix([[ZZ(1), ZZ(2)], [ZZ(3), ZZ(4)]], (2, 2), ZZ) + A2 = DomainMatrix([[ZZ(7), ZZ(10)], [ZZ(15), ZZ(22)]], (2, 2), ZZ) + A3 = DomainMatrix([[ZZ(37), ZZ(54)], [ZZ(81), ZZ(118)]], (2, 2), ZZ) + assert A**0 == A.pow(0) == eye + assert A**1 == A.pow(1) == A + assert A**2 == A.pow(2) == A2 + assert A**3 == A.pow(3) == A3 + + raises(TypeError, lambda: A ** Rational(1, 2)) + raises(NotImplementedError, lambda: A ** -1) + raises(NotImplementedError, lambda: A.pow(-1)) + + A = DomainMatrix.zeros((2, 1), ZZ) + raises(DMNonSquareMatrixError, lambda: A ** 1) + + +def test_DomainMatrix_clear_denoms(): + A = DM([[(1,2),(1,3)],[(1,4),(1,5)]], QQ) + + den_Z = DomainScalar(ZZ(60), ZZ) + Anum_Z = DM([[30, 20], [15, 12]], ZZ) + Anum_Q = Anum_Z.convert_to(QQ) + + assert A.clear_denoms() == (den_Z, Anum_Q) + assert A.clear_denoms(convert=True) == (den_Z, Anum_Z) + assert A * den_Z == Anum_Q + assert A == Anum_Q / den_Z + + +def test_DomainMatrix_clear_denoms_rowwise(): + A = DM([[(1,2),(1,3)],[(1,4),(1,5)]], QQ) + + den_Z = DM([[6, 0], [0, 20]], ZZ).to_sparse() + Anum_Z = DM([[3, 2], [5, 4]], ZZ) + Anum_Q = DM([[3, 2], [5, 4]], QQ) + + assert A.clear_denoms_rowwise() == (den_Z, Anum_Q) + assert A.clear_denoms_rowwise(convert=True) == (den_Z, Anum_Z) + assert den_Z * A == Anum_Q + assert A == den_Z.to_field().inv() * Anum_Q + + A = DM([[(1,2),(1,3),0,0],[0,0,0,0], [(1,4),(1,5),(1,6),(1,7)]], QQ) + den_Z = DM([[6, 0, 0], [0, 1, 0], [0, 0, 420]], ZZ).to_sparse() + Anum_Z = DM([[3, 2, 0, 0], [0, 0, 0, 0], [105, 84, 70, 60]], ZZ) + Anum_Q = Anum_Z.convert_to(QQ) + + assert A.clear_denoms_rowwise() == (den_Z, Anum_Q) + assert A.clear_denoms_rowwise(convert=True) == (den_Z, Anum_Z) + assert den_Z * A == Anum_Q + assert A == den_Z.to_field().inv() * Anum_Q + + +def test_DomainMatrix_cancel_denom(): + A = DM([[2, 4], [6, 8]], ZZ) + assert A.cancel_denom(ZZ(1)) == (DM([[2, 4], [6, 8]], ZZ), ZZ(1)) + assert A.cancel_denom(ZZ(3)) == (DM([[2, 4], [6, 8]], ZZ), ZZ(3)) + assert A.cancel_denom(ZZ(4)) == (DM([[1, 2], [3, 4]], ZZ), ZZ(2)) + + A = DM([[1, 2], [3, 4]], ZZ) + assert A.cancel_denom(ZZ(2)) == (A, ZZ(2)) + assert A.cancel_denom(ZZ(-2)) == (-A, ZZ(2)) + + # Test canonicalization of denominator over Gaussian rationals. + A = DM([[1, 2], [3, 4]], QQ_I) + assert A.cancel_denom(QQ_I(0,2)) == (QQ_I(0,-1)*A, QQ_I(2)) + + raises(ZeroDivisionError, lambda: A.cancel_denom(ZZ(0))) + + +def test_DomainMatrix_cancel_denom_elementwise(): + A = DM([[2, 4], [6, 8]], ZZ) + numers, denoms = A.cancel_denom_elementwise(ZZ(1)) + assert numers == DM([[2, 4], [6, 8]], ZZ) + assert denoms == DM([[1, 1], [1, 1]], ZZ) + numers, denoms = A.cancel_denom_elementwise(ZZ(4)) + assert numers == DM([[1, 1], [3, 2]], ZZ) + assert denoms == DM([[2, 1], [2, 1]], ZZ) + + raises(ZeroDivisionError, lambda: A.cancel_denom_elementwise(ZZ(0))) + + +def test_DomainMatrix_content_primitive(): + A = DM([[2, 4], [6, 8]], ZZ) + A_primitive = DM([[1, 2], [3, 4]], ZZ) + A_content = ZZ(2) + assert A.content() == A_content + assert A.primitive() == (A_content, A_primitive) + + +def test_DomainMatrix_scc(): + Ad = DomainMatrix([[ZZ(1), ZZ(2), ZZ(3)], + [ZZ(0), ZZ(1), ZZ(0)], + [ZZ(2), ZZ(0), ZZ(4)]], (3, 3), ZZ) + As = Ad.to_sparse() + Addm = Ad.rep + Asdm = As.rep + for A in [Ad, As, Addm, Asdm]: + assert Ad.scc() == [[1], [0, 2]] + + A = DM([[ZZ(1), ZZ(2), ZZ(3)]], ZZ) + raises(DMNonSquareMatrixError, lambda: A.scc()) + + +def test_DomainMatrix_rref(): + # More tests in test_rref.py + A = DomainMatrix([], (0, 1), QQ) + assert A.rref() == (A, ()) + + A = DomainMatrix([[QQ(1)]], (1, 1), QQ) + assert A.rref() == (A, (0,)) + + A = DomainMatrix([[QQ(0)]], (1, 1), QQ) + assert A.rref() == (A, ()) + + A = DomainMatrix([[QQ(1), QQ(2)], [QQ(3), QQ(4)]], (2, 2), QQ) + Ar, pivots = A.rref() + assert Ar == DomainMatrix([[QQ(1), QQ(0)], [QQ(0), QQ(1)]], (2, 2), QQ) + assert pivots == (0, 1) + + A = DomainMatrix([[QQ(0), QQ(2)], [QQ(3), QQ(4)]], (2, 2), QQ) + Ar, pivots = A.rref() + assert Ar == DomainMatrix([[QQ(1), QQ(0)], [QQ(0), QQ(1)]], (2, 2), QQ) + assert pivots == (0, 1) + + A = DomainMatrix([[QQ(0), QQ(2)], [QQ(0), QQ(4)]], (2, 2), QQ) + Ar, pivots = A.rref() + assert Ar == DomainMatrix([[QQ(0), QQ(1)], [QQ(0), QQ(0)]], (2, 2), QQ) + assert pivots == (1,) + + Az = DomainMatrix([[ZZ(1), ZZ(2)], [ZZ(3), ZZ(4)]], (2, 2), ZZ) + Ar, pivots = Az.rref() + assert Ar == DomainMatrix([[QQ(1), QQ(0)], [QQ(0), QQ(1)]], (2, 2), QQ) + assert pivots == (0, 1) + + methods = ('auto', 'GJ', 'FF', 'CD', 'GJ_dense', 'FF_dense', 'CD_dense') + Az = DomainMatrix([[ZZ(1), ZZ(2)], [ZZ(3), ZZ(4)]], (2, 2), ZZ) + for method in methods: + Ar, pivots = Az.rref(method=method) + assert Ar == DomainMatrix([[QQ(1), QQ(0)], [QQ(0), QQ(1)]], (2, 2), QQ) + assert pivots == (0, 1) + + raises(ValueError, lambda: Az.rref(method='foo')) + raises(ValueError, lambda: Az.rref_den(method='foo')) + + +def test_DomainMatrix_columnspace(): + A = DomainMatrix([[QQ(1), QQ(-1), QQ(1)], [QQ(2), QQ(-2), QQ(3)]], (2, 3), QQ) + Acol = DomainMatrix([[QQ(1), QQ(1)], [QQ(2), QQ(3)]], (2, 2), QQ) + assert A.columnspace() == Acol + + Az = DomainMatrix([[ZZ(1), ZZ(-1), ZZ(1)], [ZZ(2), ZZ(-2), ZZ(3)]], (2, 3), ZZ) + raises(DMNotAField, lambda: Az.columnspace()) + + A = DomainMatrix([[QQ(1), QQ(-1), QQ(1)], [QQ(2), QQ(-2), QQ(3)]], (2, 3), QQ, fmt='sparse') + Acol = DomainMatrix({0: {0: QQ(1), 1: QQ(1)}, 1: {0: QQ(2), 1: QQ(3)}}, (2, 2), QQ) + assert A.columnspace() == Acol + + +def test_DomainMatrix_rowspace(): + A = DomainMatrix([[QQ(1), QQ(-1), QQ(1)], [QQ(2), QQ(-2), QQ(3)]], (2, 3), QQ) + assert A.rowspace() == A + + Az = DomainMatrix([[ZZ(1), ZZ(-1), ZZ(1)], [ZZ(2), ZZ(-2), ZZ(3)]], (2, 3), ZZ) + raises(DMNotAField, lambda: Az.rowspace()) + + A = DomainMatrix([[QQ(1), QQ(-1), QQ(1)], [QQ(2), QQ(-2), QQ(3)]], (2, 3), QQ, fmt='sparse') + assert A.rowspace() == A + + +def test_DomainMatrix_nullspace(): + A = DomainMatrix([[QQ(1), QQ(1)], [QQ(1), QQ(1)]], (2, 2), QQ) + Anull = DomainMatrix([[QQ(-1), QQ(1)]], (1, 2), QQ) + assert A.nullspace() == Anull + + A = DomainMatrix([[ZZ(1), ZZ(1)], [ZZ(1), ZZ(1)]], (2, 2), ZZ) + Anull = DomainMatrix([[ZZ(-1), ZZ(1)]], (1, 2), ZZ) + assert A.nullspace() == Anull + + raises(DMNotAField, lambda: A.nullspace(divide_last=True)) + + A = DomainMatrix([[ZZ(2), ZZ(2)], [ZZ(2), ZZ(2)]], (2, 2), ZZ) + Anull = DomainMatrix([[ZZ(-2), ZZ(2)]], (1, 2), ZZ) + + Arref, den, pivots = A.rref_den() + assert den == ZZ(2) + assert Arref.nullspace_from_rref() == Anull + assert Arref.nullspace_from_rref(pivots) == Anull + assert Arref.to_sparse().nullspace_from_rref() == Anull.to_sparse() + assert Arref.to_sparse().nullspace_from_rref(pivots) == Anull.to_sparse() + + +def test_DomainMatrix_solve(): + # XXX: Maybe the _solve method should be changed... + A = DomainMatrix([[QQ(1), QQ(2)], [QQ(2), QQ(4)]], (2, 2), QQ) + b = DomainMatrix([[QQ(1)], [QQ(2)]], (2, 1), QQ) + particular = DomainMatrix([[1, 0]], (1, 2), QQ) + nullspace = DomainMatrix([[-2, 1]], (1, 2), QQ) + assert A._solve(b) == (particular, nullspace) + + b3 = DomainMatrix([[QQ(1)], [QQ(1)], [QQ(1)]], (3, 1), QQ) + raises(DMShapeError, lambda: A._solve(b3)) + + bz = DomainMatrix([[ZZ(1)], [ZZ(1)]], (2, 1), ZZ) + raises(DMNotAField, lambda: A._solve(bz)) + + +def test_DomainMatrix_inv(): + A = DomainMatrix([], (0, 0), QQ) + assert A.inv() == A + + A = DomainMatrix([[QQ(1), QQ(2)], [QQ(3), QQ(4)]], (2, 2), QQ) + Ainv = DomainMatrix([[QQ(-2), QQ(1)], [QQ(3, 2), QQ(-1, 2)]], (2, 2), QQ) + assert A.inv() == Ainv + + Az = DomainMatrix([[ZZ(1), ZZ(2)], [ZZ(3), ZZ(4)]], (2, 2), ZZ) + raises(DMNotAField, lambda: Az.inv()) + + Ans = DomainMatrix([[QQ(1), QQ(2)]], (1, 2), QQ) + raises(DMNonSquareMatrixError, lambda: Ans.inv()) + + Aninv = DomainMatrix([[QQ(1), QQ(2)], [QQ(3), QQ(6)]], (2, 2), QQ) + raises(DMNonInvertibleMatrixError, lambda: Aninv.inv()) + + +def test_DomainMatrix_det(): + A = DomainMatrix([], (0, 0), ZZ) + assert A.det() == 1 + + A = DomainMatrix([[1]], (1, 1), ZZ) + assert A.det() == 1 + + A = DomainMatrix([[ZZ(1), ZZ(2)], [ZZ(3), ZZ(4)]], (2, 2), ZZ) + assert A.det() == ZZ(-2) + + A = DomainMatrix([[ZZ(1), ZZ(2), ZZ(3)], [ZZ(1), ZZ(2), ZZ(4)], [ZZ(1), ZZ(3), ZZ(5)]], (3, 3), ZZ) + assert A.det() == ZZ(-1) + + A = DomainMatrix([[ZZ(1), ZZ(2), ZZ(3)], [ZZ(1), ZZ(2), ZZ(4)], [ZZ(1), ZZ(2), ZZ(5)]], (3, 3), ZZ) + assert A.det() == ZZ(0) + + Ans = DomainMatrix([[QQ(1), QQ(2)]], (1, 2), QQ) + raises(DMNonSquareMatrixError, lambda: Ans.det()) + + A = DomainMatrix([[QQ(1), QQ(2)], [QQ(3), QQ(4)]], (2, 2), QQ) + assert A.det() == QQ(-2) + + +def test_DomainMatrix_eval_poly(): + dM = DomainMatrix([[ZZ(1), ZZ(2)], [ZZ(3), ZZ(4)]], (2, 2), ZZ) + p = [ZZ(1), ZZ(2), ZZ(3)] + result = DomainMatrix([[ZZ(12), ZZ(14)], [ZZ(21), ZZ(33)]], (2, 2), ZZ) + assert dM.eval_poly(p) == result == p[0]*dM**2 + p[1]*dM + p[2]*dM**0 + assert dM.eval_poly([]) == dM.zeros(dM.shape, dM.domain) + assert dM.eval_poly([ZZ(2)]) == 2*dM.eye(2, dM.domain) + + dM2 = DomainMatrix([[ZZ(1), ZZ(2)]], (1, 2), ZZ) + raises(DMNonSquareMatrixError, lambda: dM2.eval_poly([ZZ(1)])) + + +def test_DomainMatrix_eval_poly_mul(): + A = DomainMatrix([[ZZ(1), ZZ(2)], [ZZ(3), ZZ(4)]], (2, 2), ZZ) + b = DomainMatrix([[ZZ(1)], [ZZ(2)]], (2, 1), ZZ) + p = [ZZ(1), ZZ(2), ZZ(3)] + result = DomainMatrix([[ZZ(40)], [ZZ(87)]], (2, 1), ZZ) + assert A.eval_poly_mul(p, b) == result == p[0]*A**2*b + p[1]*A*b + p[2]*b + + dM = DomainMatrix([[ZZ(1), ZZ(2)], [ZZ(3), ZZ(4)]], (2, 2), ZZ) + dM1 = DomainMatrix([[ZZ(1)], [ZZ(2)]], (2, 1), ZZ) + raises(DMNonSquareMatrixError, lambda: dM1.eval_poly_mul([ZZ(1)], b)) + b1 = DomainMatrix([[ZZ(1), ZZ(2)]], (1, 2), ZZ) + raises(DMShapeError, lambda: dM.eval_poly_mul([ZZ(1)], b1)) + bq = DomainMatrix([[QQ(1)], [QQ(2)]], (2, 1), QQ) + raises(DMDomainError, lambda: dM.eval_poly_mul([ZZ(1)], bq)) + + +def _check_solve_den(A, b, xnum, xden): + # Examples for solve_den, solve_den_charpoly, solve_den_rref should use + # this so that all methods and types are tested. + + case1 = (A, xnum, b) + case2 = (A.to_sparse(), xnum.to_sparse(), b.to_sparse()) + + for Ai, xnum_i, b_i in [case1, case2]: + # The key invariant for solve_den: + assert Ai*xnum_i == xden*b_i + + # solve_den_rref can differ at least by a minus sign + answers = [(xnum_i, xden), (-xnum_i, -xden)] + assert Ai.solve_den(b) in answers + assert Ai.solve_den(b, method='rref') in answers + assert Ai.solve_den_rref(b) in answers + + # charpoly can only be used if A is square and guarantees to return the + # actual determinant as a denominator. + m, n = Ai.shape + if m == n: + assert Ai.solve_den(b_i, method='charpoly') == (xnum_i, xden) + assert Ai.solve_den_charpoly(b_i) == (xnum_i, xden) + else: + raises(DMNonSquareMatrixError, lambda: Ai.solve_den_charpoly(b)) + raises(DMNonSquareMatrixError, lambda: Ai.solve_den(b, method='charpoly')) + + +def test_DomainMatrix_solve_den(): + A = DomainMatrix([[ZZ(1), ZZ(2)], [ZZ(3), ZZ(4)]], (2, 2), ZZ) + b = DomainMatrix([[ZZ(1)], [ZZ(2)]], (2, 1), ZZ) + result = DomainMatrix([[ZZ(0)], [ZZ(-1)]], (2, 1), ZZ) + den = ZZ(-2) + _check_solve_den(A, b, result, den) + + A = DomainMatrix([ + [ZZ(1), ZZ(2), ZZ(3)], + [ZZ(1), ZZ(2), ZZ(4)], + [ZZ(1), ZZ(3), ZZ(5)]], (3, 3), ZZ) + b = DomainMatrix([[ZZ(1)], [ZZ(2)], [ZZ(3)]], (3, 1), ZZ) + result = DomainMatrix([[ZZ(2)], [ZZ(0)], [ZZ(-1)]], (3, 1), ZZ) + den = ZZ(-1) + _check_solve_den(A, b, result, den) + + A = DomainMatrix([[ZZ(2)], [ZZ(2)]], (2, 1), ZZ) + b = DomainMatrix([[ZZ(3)], [ZZ(3)]], (2, 1), ZZ) + result = DomainMatrix([[ZZ(3)]], (1, 1), ZZ) + den = ZZ(2) + _check_solve_den(A, b, result, den) + + +def test_DomainMatrix_solve_den_charpoly(): + A = DomainMatrix([[ZZ(1), ZZ(2)], [ZZ(3), ZZ(4)]], (2, 2), ZZ) + b = DomainMatrix([[ZZ(1)], [ZZ(2)]], (2, 1), ZZ) + A1 = DomainMatrix([[ZZ(1), ZZ(2)]], (1, 2), ZZ) + raises(DMNonSquareMatrixError, lambda: A1.solve_den_charpoly(b)) + b1 = DomainMatrix([[ZZ(1), ZZ(2)]], (1, 2), ZZ) + raises(DMShapeError, lambda: A.solve_den_charpoly(b1)) + bq = DomainMatrix([[QQ(1)], [QQ(2)]], (2, 1), QQ) + raises(DMDomainError, lambda: A.solve_den_charpoly(bq)) + + +def test_DomainMatrix_solve_den_charpoly_check(): + # Test check + A = DomainMatrix([[ZZ(1), ZZ(2)], [ZZ(2), ZZ(4)]], (2, 2), ZZ) + b = DomainMatrix([[ZZ(1)], [ZZ(3)]], (2, 1), ZZ) + raises(DMNonInvertibleMatrixError, lambda: A.solve_den_charpoly(b)) + adjAb = DomainMatrix([[ZZ(-2)], [ZZ(1)]], (2, 1), ZZ) + assert A.adjugate() * b == adjAb + assert A.solve_den_charpoly(b, check=False) == (adjAb, ZZ(0)) + + +def test_DomainMatrix_solve_den_errors(): + A = DomainMatrix([[ZZ(1), ZZ(2)]], (1, 2), ZZ) + b = DomainMatrix([[ZZ(1)], [ZZ(2)]], (2, 1), ZZ) + raises(DMShapeError, lambda: A.solve_den(b)) + raises(DMShapeError, lambda: A.solve_den_rref(b)) + + A = DomainMatrix([[ZZ(1), ZZ(2)]], (1, 2), ZZ) + b = DomainMatrix([[ZZ(1), ZZ(2)]], (1, 2), ZZ) + raises(DMShapeError, lambda: A.solve_den(b)) + raises(DMShapeError, lambda: A.solve_den_rref(b)) + + A = DomainMatrix([[ZZ(1), ZZ(2)], [ZZ(3), ZZ(4)]], (2, 2), ZZ) + b1 = DomainMatrix([[ZZ(1), ZZ(2)]], (1, 2), ZZ) + raises(DMShapeError, lambda: A.solve_den(b1)) + + A = DomainMatrix([[ZZ(2)]], (1, 1), ZZ) + b = DomainMatrix([[ZZ(2)]], (1, 1), ZZ) + raises(DMBadInputError, lambda: A.solve_den(b1, method='invalid')) + + A = DomainMatrix([[ZZ(1)], [ZZ(2)]], (2, 1), ZZ) + b = DomainMatrix([[ZZ(1)], [ZZ(2)]], (2, 1), ZZ) + raises(DMNonSquareMatrixError, lambda: A.solve_den_charpoly(b)) + + +def test_DomainMatrix_solve_den_rref_underdetermined(): + A = DomainMatrix([[ZZ(1), ZZ(2)], [ZZ(1), ZZ(2)]], (2, 2), ZZ) + b = DomainMatrix([[ZZ(1)], [ZZ(1)]], (2, 1), ZZ) + raises(DMNonInvertibleMatrixError, lambda: A.solve_den(b)) + raises(DMNonInvertibleMatrixError, lambda: A.solve_den_rref(b)) + + +def test_DomainMatrix_adj_poly_det(): + A = DM([[ZZ(1), ZZ(2), ZZ(3)], + [ZZ(4), ZZ(5), ZZ(6)], + [ZZ(7), ZZ(8), ZZ(9)]], ZZ) + p, detA = A.adj_poly_det() + assert p == [ZZ(1), ZZ(-15), ZZ(-18)] + assert A.adjugate() == p[0]*A**2 + p[1]*A**1 + p[2]*A**0 == A.eval_poly(p) + assert A.det() == detA + + A = DM([[ZZ(1), ZZ(2), ZZ(3)], + [ZZ(7), ZZ(8), ZZ(9)]], ZZ) + raises(DMNonSquareMatrixError, lambda: A.adj_poly_det()) + + +def test_DomainMatrix_inv_den(): + A = DomainMatrix([[ZZ(1), ZZ(2)], [ZZ(3), ZZ(4)]], (2, 2), ZZ) + den = ZZ(-2) + result = DomainMatrix([[ZZ(4), ZZ(-2)], [ZZ(-3), ZZ(1)]], (2, 2), ZZ) + assert A.inv_den() == (result, den) + + +def test_DomainMatrix_adjugate(): + A = DomainMatrix([[ZZ(1), ZZ(2)], [ZZ(3), ZZ(4)]], (2, 2), ZZ) + result = DomainMatrix([[ZZ(4), ZZ(-2)], [ZZ(-3), ZZ(1)]], (2, 2), ZZ) + assert A.adjugate() == result + + +def test_DomainMatrix_adj_det(): + A = DomainMatrix([[ZZ(1), ZZ(2)], [ZZ(3), ZZ(4)]], (2, 2), ZZ) + adjA = DomainMatrix([[ZZ(4), ZZ(-2)], [ZZ(-3), ZZ(1)]], (2, 2), ZZ) + assert A.adj_det() == (adjA, ZZ(-2)) + + +def test_DomainMatrix_lu(): + A = DomainMatrix([], (0, 0), QQ) + assert A.lu() == (A, A, []) + + A = DomainMatrix([[QQ(1), QQ(2)], [QQ(3), QQ(4)]], (2, 2), QQ) + L = DomainMatrix([[QQ(1), QQ(0)], [QQ(3), QQ(1)]], (2, 2), QQ) + U = DomainMatrix([[QQ(1), QQ(2)], [QQ(0), QQ(-2)]], (2, 2), QQ) + swaps = [] + assert A.lu() == (L, U, swaps) + + A = DomainMatrix([[QQ(0), QQ(2)], [QQ(3), QQ(4)]], (2, 2), QQ) + L = DomainMatrix([[QQ(1), QQ(0)], [QQ(0), QQ(1)]], (2, 2), QQ) + U = DomainMatrix([[QQ(3), QQ(4)], [QQ(0), QQ(2)]], (2, 2), QQ) + swaps = [(0, 1)] + assert A.lu() == (L, U, swaps) + + A = DomainMatrix([[QQ(1), QQ(2)], [QQ(2), QQ(4)]], (2, 2), QQ) + L = DomainMatrix([[QQ(1), QQ(0)], [QQ(2), QQ(1)]], (2, 2), QQ) + U = DomainMatrix([[QQ(1), QQ(2)], [QQ(0), QQ(0)]], (2, 2), QQ) + swaps = [] + assert A.lu() == (L, U, swaps) + + A = DomainMatrix([[QQ(0), QQ(2)], [QQ(0), QQ(4)]], (2, 2), QQ) + L = DomainMatrix([[QQ(1), QQ(0)], [QQ(0), QQ(1)]], (2, 2), QQ) + U = DomainMatrix([[QQ(0), QQ(2)], [QQ(0), QQ(4)]], (2, 2), QQ) + swaps = [] + assert A.lu() == (L, U, swaps) + + A = DomainMatrix([[QQ(1), QQ(2), QQ(3)], [QQ(4), QQ(5), QQ(6)]], (2, 3), QQ) + L = DomainMatrix([[QQ(1), QQ(0)], [QQ(4), QQ(1)]], (2, 2), QQ) + U = DomainMatrix([[QQ(1), QQ(2), QQ(3)], [QQ(0), QQ(-3), QQ(-6)]], (2, 3), QQ) + swaps = [] + assert A.lu() == (L, U, swaps) + + A = DomainMatrix([[QQ(1), QQ(2)], [QQ(3), QQ(4)], [QQ(5), QQ(6)]], (3, 2), QQ) + L = DomainMatrix([ + [QQ(1), QQ(0), QQ(0)], + [QQ(3), QQ(1), QQ(0)], + [QQ(5), QQ(2), QQ(1)]], (3, 3), QQ) + U = DomainMatrix([[QQ(1), QQ(2)], [QQ(0), QQ(-2)], [QQ(0), QQ(0)]], (3, 2), QQ) + swaps = [] + assert A.lu() == (L, U, swaps) + + A = [[1, 0, 0, 0], [0, 0, 0, 0], [0, 0, 1, 1], [0, 0, 1, 2]] + L = [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 1, 1]] + U = [[1, 0, 0, 0], [0, 0, 0, 0], [0, 0, 1, 1], [0, 0, 0, 1]] + to_dom = lambda rows, dom: [[dom(e) for e in row] for row in rows] + A = DomainMatrix(to_dom(A, QQ), (4, 4), QQ) + L = DomainMatrix(to_dom(L, QQ), (4, 4), QQ) + U = DomainMatrix(to_dom(U, QQ), (4, 4), QQ) + assert A.lu() == (L, U, []) + + A = DomainMatrix([[ZZ(1), ZZ(2)], [ZZ(3), ZZ(4)]], (2, 2), ZZ) + raises(DMNotAField, lambda: A.lu()) + + +def test_DomainMatrix_lu_solve(): + # Base case + A = b = x = DomainMatrix([], (0, 0), QQ) + assert A.lu_solve(b) == x + + # Basic example + A = DomainMatrix([[QQ(1), QQ(2)], [QQ(3), QQ(4)]], (2, 2), QQ) + b = DomainMatrix([[QQ(1)], [QQ(2)]], (2, 1), QQ) + x = DomainMatrix([[QQ(0)], [QQ(1, 2)]], (2, 1), QQ) + assert A.lu_solve(b) == x + + # Example with swaps + A = DomainMatrix([[QQ(0), QQ(2)], [QQ(3), QQ(4)]], (2, 2), QQ) + b = DomainMatrix([[QQ(1)], [QQ(2)]], (2, 1), QQ) + x = DomainMatrix([[QQ(0)], [QQ(1, 2)]], (2, 1), QQ) + assert A.lu_solve(b) == x + + # Non-invertible + A = DomainMatrix([[QQ(1), QQ(2)], [QQ(2), QQ(4)]], (2, 2), QQ) + b = DomainMatrix([[QQ(1)], [QQ(2)]], (2, 1), QQ) + raises(DMNonInvertibleMatrixError, lambda: A.lu_solve(b)) + + # Overdetermined, consistent + A = DomainMatrix([[QQ(1), QQ(2)], [QQ(3), QQ(4)], [QQ(5), QQ(6)]], (3, 2), QQ) + b = DomainMatrix([[QQ(1)], [QQ(2)], [QQ(3)]], (3, 1), QQ) + x = DomainMatrix([[QQ(0)], [QQ(1, 2)]], (2, 1), QQ) + assert A.lu_solve(b) == x + + # Overdetermined, inconsistent + A = DomainMatrix([[QQ(1), QQ(2)], [QQ(3), QQ(4)], [QQ(5), QQ(6)]], (3, 2), QQ) + b = DomainMatrix([[QQ(1)], [QQ(2)], [QQ(4)]], (3, 1), QQ) + raises(DMNonInvertibleMatrixError, lambda: A.lu_solve(b)) + + # Underdetermined + A = DomainMatrix([[QQ(1), QQ(2)]], (1, 2), QQ) + b = DomainMatrix([[QQ(1)]], (1, 1), QQ) + raises(NotImplementedError, lambda: A.lu_solve(b)) + + # Non-field + A = DomainMatrix([[ZZ(1), ZZ(2)], [ZZ(3), ZZ(4)]], (2, 2), ZZ) + b = DomainMatrix([[ZZ(1)], [ZZ(2)]], (2, 1), ZZ) + raises(DMNotAField, lambda: A.lu_solve(b)) + + # Shape mismatch + A = DomainMatrix([[QQ(1), QQ(2)], [QQ(3), QQ(4)]], (2, 2), QQ) + b = DomainMatrix([[QQ(1), QQ(2)]], (1, 2), QQ) + raises(DMShapeError, lambda: A.lu_solve(b)) + + +def test_DomainMatrix_charpoly(): + A = DomainMatrix([], (0, 0), ZZ) + p = [ZZ(1)] + assert A.charpoly() == p + assert A.to_sparse().charpoly() == p + + A = DomainMatrix([[1]], (1, 1), ZZ) + p = [ZZ(1), ZZ(-1)] + assert A.charpoly() == p + assert A.to_sparse().charpoly() == p + + A = DomainMatrix([[ZZ(1), ZZ(2)], [ZZ(3), ZZ(4)]], (2, 2), ZZ) + p = [ZZ(1), ZZ(-5), ZZ(-2)] + assert A.charpoly() == p + assert A.to_sparse().charpoly() == p + + A = DomainMatrix([[ZZ(1), ZZ(2), ZZ(3)], [ZZ(4), ZZ(5), ZZ(6)], [ZZ(7), ZZ(8), ZZ(9)]], (3, 3), ZZ) + p = [ZZ(1), ZZ(-15), ZZ(-18), ZZ(0)] + assert A.charpoly() == p + assert A.to_sparse().charpoly() == p + + A = DomainMatrix([[ZZ(0), ZZ(1), ZZ(0)], + [ZZ(1), ZZ(0), ZZ(1)], + [ZZ(0), ZZ(1), ZZ(0)]], (3, 3), ZZ) + p = [ZZ(1), ZZ(0), ZZ(-2), ZZ(0)] + assert A.charpoly() == p + assert A.to_sparse().charpoly() == p + + A = DM([[17, 0, 30, 0, 0, 0, 0, 0, 0, 0], + [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [69, 0, 0, 0, 0, 86, 0, 0, 0, 0], + [23, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [ 0, 0, 0, 13, 0, 0, 0, 0, 0, 0], + [ 0, 0, 0, 0, 0, 0, 0, 32, 0, 0], + [ 0, 0, 0, 0, 37, 67, 0, 0, 0, 0], + [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], ZZ) + p = ZZ.map([1, -17, -2070, 0, -771420, 0, 0, 0, 0, 0, 0]) + assert A.charpoly() == p + assert A.to_sparse().charpoly() == p + + Ans = DomainMatrix([[QQ(1), QQ(2)]], (1, 2), QQ) + raises(DMNonSquareMatrixError, lambda: Ans.charpoly()) + + +def test_DomainMatrix_charpoly_factor_list(): + A = DomainMatrix([], (0, 0), ZZ) + assert A.charpoly_factor_list() == [] + + A = DM([[1]], ZZ) + assert A.charpoly_factor_list() == [ + ([ZZ(1), ZZ(-1)], 1) + ] + + A = DM([[1, 2], [3, 4]], ZZ) + assert A.charpoly_factor_list() == [ + ([ZZ(1), ZZ(-5), ZZ(-2)], 1) + ] + + A = DM([[1, 2, 0], [3, 4, 0], [0, 0, 1]], ZZ) + assert A.charpoly_factor_list() == [ + ([ZZ(1), ZZ(-1)], 1), + ([ZZ(1), ZZ(-5), ZZ(-2)], 1) + ] + + +def test_DomainMatrix_eye(): + A = DomainMatrix.eye(3, QQ) + assert A.rep == SDM.eye((3, 3), QQ) + assert A.shape == (3, 3) + assert A.domain == QQ + + +def test_DomainMatrix_zeros(): + A = DomainMatrix.zeros((1, 2), QQ) + assert A.rep == SDM.zeros((1, 2), QQ) + assert A.shape == (1, 2) + assert A.domain == QQ + + +def test_DomainMatrix_ones(): + A = DomainMatrix.ones((2, 3), QQ) + if GROUND_TYPES != 'flint': + assert A.rep == DDM.ones((2, 3), QQ) + else: + assert A.rep == SDM.ones((2, 3), QQ).to_dfm() + assert A.shape == (2, 3) + assert A.domain == QQ + + +def test_DomainMatrix_diag(): + A = DomainMatrix({0:{0:ZZ(2)}, 1:{1:ZZ(3)}}, (2, 2), ZZ) + assert DomainMatrix.diag([ZZ(2), ZZ(3)], ZZ) == A + + A = DomainMatrix({0:{0:ZZ(2)}, 1:{1:ZZ(3)}}, (3, 4), ZZ) + assert DomainMatrix.diag([ZZ(2), ZZ(3)], ZZ, (3, 4)) == A + + +def test_DomainMatrix_hstack(): + A = DomainMatrix([[ZZ(1), ZZ(2)], [ZZ(3), ZZ(4)]], (2, 2), ZZ) + B = DomainMatrix([[ZZ(5), ZZ(6)], [ZZ(7), ZZ(8)]], (2, 2), ZZ) + C = DomainMatrix([[ZZ(9), ZZ(10)], [ZZ(11), ZZ(12)]], (2, 2), ZZ) + + AB = DomainMatrix([ + [ZZ(1), ZZ(2), ZZ(5), ZZ(6)], + [ZZ(3), ZZ(4), ZZ(7), ZZ(8)]], (2, 4), ZZ) + ABC = DomainMatrix([ + [ZZ(1), ZZ(2), ZZ(5), ZZ(6), ZZ(9), ZZ(10)], + [ZZ(3), ZZ(4), ZZ(7), ZZ(8), ZZ(11), ZZ(12)]], (2, 6), ZZ) + assert A.hstack(B) == AB + assert A.hstack(B, C) == ABC + + +def test_DomainMatrix_vstack(): + A = DomainMatrix([[ZZ(1), ZZ(2)], [ZZ(3), ZZ(4)]], (2, 2), ZZ) + B = DomainMatrix([[ZZ(5), ZZ(6)], [ZZ(7), ZZ(8)]], (2, 2), ZZ) + C = DomainMatrix([[ZZ(9), ZZ(10)], [ZZ(11), ZZ(12)]], (2, 2), ZZ) + + AB = DomainMatrix([ + [ZZ(1), ZZ(2)], + [ZZ(3), ZZ(4)], + [ZZ(5), ZZ(6)], + [ZZ(7), ZZ(8)]], (4, 2), ZZ) + ABC = DomainMatrix([ + [ZZ(1), ZZ(2)], + [ZZ(3), ZZ(4)], + [ZZ(5), ZZ(6)], + [ZZ(7), ZZ(8)], + [ZZ(9), ZZ(10)], + [ZZ(11), ZZ(12)]], (6, 2), ZZ) + assert A.vstack(B) == AB + assert A.vstack(B, C) == ABC + + +def test_DomainMatrix_applyfunc(): + A = DomainMatrix([[ZZ(1), ZZ(2)]], (1, 2), ZZ) + B = DomainMatrix([[ZZ(2), ZZ(4)]], (1, 2), ZZ) + assert A.applyfunc(lambda x: 2*x) == B + + +def test_DomainMatrix_scalarmul(): + A = DomainMatrix([[ZZ(1), ZZ(2)], [ZZ(3), ZZ(4)]], (2, 2), ZZ) + lamda = DomainScalar(QQ(3)/QQ(2), QQ) + assert A * lamda == DomainMatrix([[QQ(3, 2), QQ(3)], [QQ(9, 2), QQ(6)]], (2, 2), QQ) + assert A * 2 == DomainMatrix([[ZZ(2), ZZ(4)], [ZZ(6), ZZ(8)]], (2, 2), ZZ) + assert 2 * A == DomainMatrix([[ZZ(2), ZZ(4)], [ZZ(6), ZZ(8)]], (2, 2), ZZ) + assert A * DomainScalar(ZZ(0), ZZ) == DomainMatrix({}, (2, 2), ZZ) + assert A * DomainScalar(ZZ(1), ZZ) == A + + raises(TypeError, lambda: A * 1.5) + + +def test_DomainMatrix_truediv(): + A = DomainMatrix.from_Matrix(Matrix([[1, 2], [3, 4]])) + lamda = DomainScalar(QQ(3)/QQ(2), QQ) + assert A / lamda == DomainMatrix({0: {0: QQ(2, 3), 1: QQ(4, 3)}, 1: {0: QQ(2), 1: QQ(8, 3)}}, (2, 2), QQ) + b = DomainScalar(ZZ(1), ZZ) + assert A / b == DomainMatrix({0: {0: QQ(1), 1: QQ(2)}, 1: {0: QQ(3), 1: QQ(4)}}, (2, 2), QQ) + + assert A / 1 == DomainMatrix({0: {0: QQ(1), 1: QQ(2)}, 1: {0: QQ(3), 1: QQ(4)}}, (2, 2), QQ) + assert A / 2 == DomainMatrix({0: {0: QQ(1, 2), 1: QQ(1)}, 1: {0: QQ(3, 2), 1: QQ(2)}}, (2, 2), QQ) + + raises(ZeroDivisionError, lambda: A / 0) + raises(TypeError, lambda: A / 1.5) + raises(ZeroDivisionError, lambda: A / DomainScalar(ZZ(0), ZZ)) + + A = DomainMatrix([[ZZ(1), ZZ(2)], [ZZ(3), ZZ(4)]], (2, 2), ZZ) + assert A.to_field() / 2 == DomainMatrix([[QQ(1, 2), QQ(1)], [QQ(3, 2), QQ(2)]], (2, 2), QQ) + assert A / 2 == DomainMatrix([[QQ(1, 2), QQ(1)], [QQ(3, 2), QQ(2)]], (2, 2), QQ) + assert A.to_field() / QQ(2,3) == DomainMatrix([[QQ(3, 2), QQ(3)], [QQ(9, 2), QQ(6)]], (2, 2), QQ) + + +def test_DomainMatrix_getitem(): + dM = DomainMatrix([ + [ZZ(1), ZZ(2), ZZ(3)], + [ZZ(4), ZZ(5), ZZ(6)], + [ZZ(7), ZZ(8), ZZ(9)]], (3, 3), ZZ) + + assert dM[1:,:-2] == DomainMatrix([[ZZ(4)], [ZZ(7)]], (2, 1), ZZ) + assert dM[2,:-2] == DomainMatrix([[ZZ(7)]], (1, 1), ZZ) + assert dM[:-2,:-2] == DomainMatrix([[ZZ(1)]], (1, 1), ZZ) + assert dM[:-1,0:2] == DomainMatrix([[ZZ(1), ZZ(2)], [ZZ(4), ZZ(5)]], (2, 2), ZZ) + assert dM[:, -1] == DomainMatrix([[ZZ(3)], [ZZ(6)], [ZZ(9)]], (3, 1), ZZ) + assert dM[-1, :] == DomainMatrix([[ZZ(7), ZZ(8), ZZ(9)]], (1, 3), ZZ) + assert dM[::-1, :] == DomainMatrix([ + [ZZ(7), ZZ(8), ZZ(9)], + [ZZ(4), ZZ(5), ZZ(6)], + [ZZ(1), ZZ(2), ZZ(3)]], (3, 3), ZZ) + + raises(IndexError, lambda: dM[4, :-2]) + raises(IndexError, lambda: dM[:-2, 4]) + + assert dM[1, 2] == DomainScalar(ZZ(6), ZZ) + assert dM[-2, 2] == DomainScalar(ZZ(6), ZZ) + assert dM[1, -2] == DomainScalar(ZZ(5), ZZ) + assert dM[-1, -3] == DomainScalar(ZZ(7), ZZ) + + raises(IndexError, lambda: dM[3, 3]) + raises(IndexError, lambda: dM[1, 4]) + raises(IndexError, lambda: dM[-1, -4]) + + dM = DomainMatrix({0: {0: ZZ(1)}}, (10, 10), ZZ) + assert dM[5, 5] == DomainScalar(ZZ(0), ZZ) + assert dM[0, 0] == DomainScalar(ZZ(1), ZZ) + + dM = DomainMatrix({1: {0: 1}}, (2,1), ZZ) + assert dM[0:, 0] == DomainMatrix({1: {0: 1}}, (2, 1), ZZ) + raises(IndexError, lambda: dM[3, 0]) + + dM = DomainMatrix({2: {2: ZZ(1)}, 4: {4: ZZ(1)}}, (5, 5), ZZ) + assert dM[:2,:2] == DomainMatrix({}, (2, 2), ZZ) + assert dM[2:,2:] == DomainMatrix({0: {0: 1}, 2: {2: 1}}, (3, 3), ZZ) + assert dM[3:,3:] == DomainMatrix({1: {1: 1}}, (2, 2), ZZ) + assert dM[2:, 6:] == DomainMatrix({}, (3, 0), ZZ) + + +def test_DomainMatrix_getitem_sympy(): + dM = DomainMatrix({2: {2: ZZ(2)}, 4: {4: ZZ(1)}}, (5, 5), ZZ) + val1 = dM.getitem_sympy(0, 0) + assert val1 is S.Zero + val2 = dM.getitem_sympy(2, 2) + assert val2 == 2 and isinstance(val2, Integer) + + +def test_DomainMatrix_extract(): + dM1 = DomainMatrix([ + [ZZ(1), ZZ(2), ZZ(3)], + [ZZ(4), ZZ(5), ZZ(6)], + [ZZ(7), ZZ(8), ZZ(9)]], (3, 3), ZZ) + dM2 = DomainMatrix([ + [ZZ(1), ZZ(3)], + [ZZ(7), ZZ(9)]], (2, 2), ZZ) + assert dM1.extract([0, 2], [0, 2]) == dM2 + assert dM1.to_sparse().extract([0, 2], [0, 2]) == dM2.to_sparse() + assert dM1.extract([0, -1], [0, -1]) == dM2 + assert dM1.to_sparse().extract([0, -1], [0, -1]) == dM2.to_sparse() + + dM3 = DomainMatrix([ + [ZZ(1), ZZ(2), ZZ(2)], + [ZZ(4), ZZ(5), ZZ(5)], + [ZZ(4), ZZ(5), ZZ(5)]], (3, 3), ZZ) + assert dM1.extract([0, 1, 1], [0, 1, 1]) == dM3 + assert dM1.to_sparse().extract([0, 1, 1], [0, 1, 1]) == dM3.to_sparse() + + empty = [ + ([], [], (0, 0)), + ([1], [], (1, 0)), + ([], [1], (0, 1)), + ] + for rows, cols, size in empty: + assert dM1.extract(rows, cols) == DomainMatrix.zeros(size, ZZ).to_dense() + assert dM1.to_sparse().extract(rows, cols) == DomainMatrix.zeros(size, ZZ) + + dM = DomainMatrix([[ZZ(1), ZZ(2)], [ZZ(3), ZZ(4)]], (2, 2), ZZ) + bad_indices = [([2], [0]), ([0], [2]), ([-3], [0]), ([0], [-3])] + for rows, cols in bad_indices: + raises(IndexError, lambda: dM.extract(rows, cols)) + raises(IndexError, lambda: dM.to_sparse().extract(rows, cols)) + + +def test_DomainMatrix_setitem(): + dM = DomainMatrix({2: {2: ZZ(1)}, 4: {4: ZZ(1)}}, (5, 5), ZZ) + dM[2, 2] = ZZ(2) + assert dM == DomainMatrix({2: {2: ZZ(2)}, 4: {4: ZZ(1)}}, (5, 5), ZZ) + def setitem(i, j, val): + dM[i, j] = val + raises(TypeError, lambda: setitem(2, 2, QQ(1, 2))) + raises(NotImplementedError, lambda: setitem(slice(1, 2), 2, ZZ(1))) + + +def test_DomainMatrix_pickling(): + import pickle + dM = DomainMatrix({2: {2: ZZ(1)}, 4: {4: ZZ(1)}}, (5, 5), ZZ) + assert pickle.loads(pickle.dumps(dM)) == dM + dM = DomainMatrix([[ZZ(1), ZZ(2)], [ZZ(3), ZZ(4)]], (2, 2), ZZ) + assert pickle.loads(pickle.dumps(dM)) == dM diff --git a/lib/python3.10/site-packages/sympy/polys/matrices/tests/test_domainscalar.py b/lib/python3.10/site-packages/sympy/polys/matrices/tests/test_domainscalar.py new file mode 100644 index 0000000000000000000000000000000000000000..8c507caf079cc62ba23ba171a50d0d27c98eb6d9 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/polys/matrices/tests/test_domainscalar.py @@ -0,0 +1,153 @@ +from sympy.testing.pytest import raises + +from sympy.core.symbol import S +from sympy.polys import ZZ, QQ +from sympy.polys.matrices.domainscalar import DomainScalar +from sympy.polys.matrices.domainmatrix import DomainMatrix + + +def test_DomainScalar___new__(): + raises(TypeError, lambda: DomainScalar(ZZ(1), QQ)) + raises(TypeError, lambda: DomainScalar(ZZ(1), 1)) + + +def test_DomainScalar_new(): + A = DomainScalar(ZZ(1), ZZ) + B = A.new(ZZ(4), ZZ) + assert B == DomainScalar(ZZ(4), ZZ) + + +def test_DomainScalar_repr(): + A = DomainScalar(ZZ(1), ZZ) + assert repr(A) in {'1', 'mpz(1)'} + + +def test_DomainScalar_from_sympy(): + expr = S(1) + B = DomainScalar.from_sympy(expr) + assert B == DomainScalar(ZZ(1), ZZ) + + +def test_DomainScalar_to_sympy(): + B = DomainScalar(ZZ(1), ZZ) + expr = B.to_sympy() + assert expr.is_Integer and expr == 1 + + +def test_DomainScalar_to_domain(): + A = DomainScalar(ZZ(1), ZZ) + B = A.to_domain(QQ) + assert B == DomainScalar(QQ(1), QQ) + + +def test_DomainScalar_convert_to(): + A = DomainScalar(ZZ(1), ZZ) + B = A.convert_to(QQ) + assert B == DomainScalar(QQ(1), QQ) + + +def test_DomainScalar_unify(): + A = DomainScalar(ZZ(1), ZZ) + B = DomainScalar(QQ(2), QQ) + A, B = A.unify(B) + assert A.domain == B.domain == QQ + + +def test_DomainScalar_add(): + A = DomainScalar(ZZ(1), ZZ) + B = DomainScalar(QQ(2), QQ) + assert A + B == DomainScalar(QQ(3), QQ) + + raises(TypeError, lambda: A + 1.5) + +def test_DomainScalar_sub(): + A = DomainScalar(ZZ(1), ZZ) + B = DomainScalar(QQ(2), QQ) + assert A - B == DomainScalar(QQ(-1), QQ) + + raises(TypeError, lambda: A - 1.5) + +def test_DomainScalar_mul(): + A = DomainScalar(ZZ(1), ZZ) + B = DomainScalar(QQ(2), QQ) + dm = DomainMatrix([[ZZ(1), ZZ(2)], [ZZ(3), ZZ(4)]], (2, 2), ZZ) + assert A * B == DomainScalar(QQ(2), QQ) + assert A * dm == dm + assert B * 2 == DomainScalar(QQ(4), QQ) + + raises(TypeError, lambda: A * 1.5) + + +def test_DomainScalar_floordiv(): + A = DomainScalar(ZZ(-5), ZZ) + B = DomainScalar(QQ(2), QQ) + assert A // B == DomainScalar(QQ(-5, 2), QQ) + C = DomainScalar(ZZ(2), ZZ) + assert A // C == DomainScalar(ZZ(-3), ZZ) + + raises(TypeError, lambda: A // 1.5) + + +def test_DomainScalar_mod(): + A = DomainScalar(ZZ(5), ZZ) + B = DomainScalar(QQ(2), QQ) + assert A % B == DomainScalar(QQ(0), QQ) + C = DomainScalar(ZZ(2), ZZ) + assert A % C == DomainScalar(ZZ(1), ZZ) + + raises(TypeError, lambda: A % 1.5) + + +def test_DomainScalar_divmod(): + A = DomainScalar(ZZ(5), ZZ) + B = DomainScalar(QQ(2), QQ) + assert divmod(A, B) == (DomainScalar(QQ(5, 2), QQ), DomainScalar(QQ(0), QQ)) + C = DomainScalar(ZZ(2), ZZ) + assert divmod(A, C) == (DomainScalar(ZZ(2), ZZ), DomainScalar(ZZ(1), ZZ)) + + raises(TypeError, lambda: divmod(A, 1.5)) + + +def test_DomainScalar_pow(): + A = DomainScalar(ZZ(-5), ZZ) + B = A**(2) + assert B == DomainScalar(ZZ(25), ZZ) + + raises(TypeError, lambda: A**(1.5)) + + +def test_DomainScalar_pos(): + A = DomainScalar(QQ(2), QQ) + B = DomainScalar(QQ(2), QQ) + assert +A == B + + +def test_DomainScalar_neg(): + A = DomainScalar(QQ(2), QQ) + B = DomainScalar(QQ(-2), QQ) + assert -A == B + + +def test_DomainScalar_eq(): + A = DomainScalar(QQ(2), QQ) + assert A == A + B = DomainScalar(ZZ(-5), ZZ) + assert A != B + C = DomainScalar(ZZ(2), ZZ) + assert A != C + D = [1] + assert A != D + + +def test_DomainScalar_isZero(): + A = DomainScalar(ZZ(0), ZZ) + assert A.is_zero() == True + B = DomainScalar(ZZ(1), ZZ) + assert B.is_zero() == False + + +def test_DomainScalar_isOne(): + A = DomainScalar(ZZ(1), ZZ) + assert A.is_one() == True + B = DomainScalar(ZZ(0), ZZ) + assert B.is_one() == False diff --git a/lib/python3.10/site-packages/sympy/polys/matrices/tests/test_eigen.py b/lib/python3.10/site-packages/sympy/polys/matrices/tests/test_eigen.py new file mode 100644 index 0000000000000000000000000000000000000000..70482eab686d5b4e1c45d552f5eccb5bdaa9e1ed --- /dev/null +++ b/lib/python3.10/site-packages/sympy/polys/matrices/tests/test_eigen.py @@ -0,0 +1,90 @@ +""" +Tests for the sympy.polys.matrices.eigen module +""" + +from sympy.core.singleton import S +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.matrices.dense import Matrix + +from sympy.polys.agca.extensions import FiniteExtension +from sympy.polys.domains import QQ +from sympy.polys.polytools import Poly +from sympy.polys.rootoftools import CRootOf +from sympy.polys.matrices.domainmatrix import DomainMatrix + +from sympy.polys.matrices.eigen import dom_eigenvects, dom_eigenvects_to_sympy + + +def test_dom_eigenvects_rational(): + # Rational eigenvalues + A = DomainMatrix([[QQ(1), QQ(2)], [QQ(1), QQ(2)]], (2, 2), QQ) + rational_eigenvects = [ + (QQ, QQ(3), 1, DomainMatrix([[QQ(1), QQ(1)]], (1, 2), QQ)), + (QQ, QQ(0), 1, DomainMatrix([[QQ(-2), QQ(1)]], (1, 2), QQ)), + ] + assert dom_eigenvects(A) == (rational_eigenvects, []) + + # Test converting to Expr: + sympy_eigenvects = [ + (S(3), 1, [Matrix([1, 1])]), + (S(0), 1, [Matrix([-2, 1])]), + ] + assert dom_eigenvects_to_sympy(rational_eigenvects, [], Matrix) == sympy_eigenvects + + +def test_dom_eigenvects_algebraic(): + # Algebraic eigenvalues + A = DomainMatrix([[QQ(1), QQ(2)], [QQ(3), QQ(4)]], (2, 2), QQ) + Avects = dom_eigenvects(A) + + # Extract the dummy to build the expected result: + lamda = Avects[1][0][1].gens[0] + irreducible = Poly(lamda**2 - 5*lamda - 2, lamda, domain=QQ) + K = FiniteExtension(irreducible) + KK = K.from_sympy + algebraic_eigenvects = [ + (K, irreducible, 1, DomainMatrix([[KK((lamda-4)/3), KK(1)]], (1, 2), K)), + ] + assert Avects == ([], algebraic_eigenvects) + + # Test converting to Expr: + sympy_eigenvects = [ + (S(5)/2 - sqrt(33)/2, 1, [Matrix([[-sqrt(33)/6 - S(1)/2], [1]])]), + (S(5)/2 + sqrt(33)/2, 1, [Matrix([[-S(1)/2 + sqrt(33)/6], [1]])]), + ] + assert dom_eigenvects_to_sympy([], algebraic_eigenvects, Matrix) == sympy_eigenvects + + +def test_dom_eigenvects_rootof(): + # Algebraic eigenvalues + A = DomainMatrix([ + [0, 0, 0, 0, -1], + [1, 0, 0, 0, 1], + [0, 1, 0, 0, 0], + [0, 0, 1, 0, 0], + [0, 0, 0, 1, 0]], (5, 5), QQ) + Avects = dom_eigenvects(A) + + # Extract the dummy to build the expected result: + lamda = Avects[1][0][1].gens[0] + irreducible = Poly(lamda**5 - lamda + 1, lamda, domain=QQ) + K = FiniteExtension(irreducible) + KK = K.from_sympy + algebraic_eigenvects = [ + (K, irreducible, 1, + DomainMatrix([ + [KK(lamda**4-1), KK(lamda**3), KK(lamda**2), KK(lamda), KK(1)] + ], (1, 5), K)), + ] + assert Avects == ([], algebraic_eigenvects) + + # Test converting to Expr (slow): + l0, l1, l2, l3, l4 = [CRootOf(lamda**5 - lamda + 1, i) for i in range(5)] + sympy_eigenvects = [ + (l0, 1, [Matrix([-1 + l0**4, l0**3, l0**2, l0, 1])]), + (l1, 1, [Matrix([-1 + l1**4, l1**3, l1**2, l1, 1])]), + (l2, 1, [Matrix([-1 + l2**4, l2**3, l2**2, l2, 1])]), + (l3, 1, [Matrix([-1 + l3**4, l3**3, l3**2, l3, 1])]), + (l4, 1, [Matrix([-1 + l4**4, l4**3, l4**2, l4, 1])]), + ] + assert dom_eigenvects_to_sympy([], algebraic_eigenvects, Matrix) == sympy_eigenvects diff --git a/lib/python3.10/site-packages/sympy/polys/matrices/tests/test_inverse.py b/lib/python3.10/site-packages/sympy/polys/matrices/tests/test_inverse.py new file mode 100644 index 0000000000000000000000000000000000000000..47c82799324518bd7d1cc2405ade0aa0a5a4f6e9 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/polys/matrices/tests/test_inverse.py @@ -0,0 +1,193 @@ +from sympy import ZZ, Matrix +from sympy.polys.matrices import DM, DomainMatrix +from sympy.polys.matrices.dense import ddm_iinv +from sympy.polys.matrices.exceptions import DMNonInvertibleMatrixError +from sympy.matrices.exceptions import NonInvertibleMatrixError + +import pytest +from sympy.testing.pytest import raises +from sympy.core.numbers import all_close + +from sympy.abc import x + + +# Examples are given as adjugate matrix and determinant adj_det should match +# these exactly but inv_den only matches after cancel_denom. + + +INVERSE_EXAMPLES = [ + + ( + 'zz_1', + DomainMatrix([], (0, 0), ZZ), + DomainMatrix([], (0, 0), ZZ), + ZZ(1), + ), + + ( + 'zz_2', + DM([[2]], ZZ), + DM([[1]], ZZ), + ZZ(2), + ), + + ( + 'zz_3', + DM([[2, 0], + [0, 2]], ZZ), + DM([[2, 0], + [0, 2]], ZZ), + ZZ(4), + ), + + ( + 'zz_4', + DM([[1, 2], + [3, 4]], ZZ), + DM([[ 4, -2], + [-3, 1]], ZZ), + ZZ(-2), + ), + + ( + 'zz_5', + DM([[2, 2, 0], + [0, 2, 2], + [0, 0, 2]], ZZ), + DM([[4, -4, 4], + [0, 4, -4], + [0, 0, 4]], ZZ), + ZZ(8), + ), + + ( + 'zz_6', + DM([[1, 2, 3], + [4, 5, 6], + [7, 8, 9]], ZZ), + DM([[-3, 6, -3], + [ 6, -12, 6], + [-3, 6, -3]], ZZ), + ZZ(0), + ), +] + + +@pytest.mark.parametrize('name, A, A_inv, den', INVERSE_EXAMPLES) +def test_Matrix_inv(name, A, A_inv, den): + + def _check(**kwargs): + if den != 0: + assert A.inv(**kwargs) == A_inv + else: + raises(NonInvertibleMatrixError, lambda: A.inv(**kwargs)) + + K = A.domain + A = A.to_Matrix() + A_inv = A_inv.to_Matrix() / K.to_sympy(den) + _check() + for method in ['GE', 'LU', 'ADJ', 'CH', 'LDL', 'QR']: + _check(method=method) + + +@pytest.mark.parametrize('name, A, A_inv, den', INVERSE_EXAMPLES) +def test_dm_inv_den(name, A, A_inv, den): + if den != 0: + A_inv_f, den_f = A.inv_den() + assert A_inv_f.cancel_denom(den_f) == A_inv.cancel_denom(den) + else: + raises(DMNonInvertibleMatrixError, lambda: A.inv_den()) + + +@pytest.mark.parametrize('name, A, A_inv, den', INVERSE_EXAMPLES) +def test_dm_inv(name, A, A_inv, den): + A = A.to_field() + if den != 0: + A_inv = A_inv.to_field() / den + assert A.inv() == A_inv + else: + raises(DMNonInvertibleMatrixError, lambda: A.inv()) + + +@pytest.mark.parametrize('name, A, A_inv, den', INVERSE_EXAMPLES) +def test_ddm_inv(name, A, A_inv, den): + A = A.to_field().to_ddm() + if den != 0: + A_inv = (A_inv.to_field() / den).to_ddm() + assert A.inv() == A_inv + else: + raises(DMNonInvertibleMatrixError, lambda: A.inv()) + + +@pytest.mark.parametrize('name, A, A_inv, den', INVERSE_EXAMPLES) +def test_sdm_inv(name, A, A_inv, den): + A = A.to_field().to_sdm() + if den != 0: + A_inv = (A_inv.to_field() / den).to_sdm() + assert A.inv() == A_inv + else: + raises(DMNonInvertibleMatrixError, lambda: A.inv()) + + +@pytest.mark.parametrize('name, A, A_inv, den', INVERSE_EXAMPLES) +def test_dense_ddm_iinv(name, A, A_inv, den): + A = A.to_field().to_ddm().copy() + K = A.domain + A_result = A.copy() + if den != 0: + A_inv = (A_inv.to_field() / den).to_ddm() + ddm_iinv(A_result, A, K) + assert A_result == A_inv + else: + raises(DMNonInvertibleMatrixError, lambda: ddm_iinv(A_result, A, K)) + + +@pytest.mark.parametrize('name, A, A_inv, den', INVERSE_EXAMPLES) +def test_Matrix_adjugate(name, A, A_inv, den): + A = A.to_Matrix() + A_inv = A_inv.to_Matrix() + assert A.adjugate() == A_inv + for method in ["bareiss", "berkowitz", "bird", "laplace", "lu"]: + assert A.adjugate(method=method) == A_inv + + +@pytest.mark.parametrize('name, A, A_inv, den', INVERSE_EXAMPLES) +def test_dm_adj_det(name, A, A_inv, den): + assert A.adj_det() == (A_inv, den) + + +def test_inverse_inexact(): + + M = Matrix([[x-0.3, -0.06, -0.22], + [-0.46, x-0.48, -0.41], + [-0.14, -0.39, x-0.64]]) + + Mn = Matrix([[1.0*x**2 - 1.12*x + 0.1473, 0.06*x + 0.0474, 0.22*x - 0.081], + [0.46*x - 0.237, 1.0*x**2 - 0.94*x + 0.1612, 0.41*x - 0.0218], + [0.14*x + 0.1122, 0.39*x - 0.1086, 1.0*x**2 - 0.78*x + 0.1164]]) + + d = 1.0*x**3 - 1.42*x**2 + 0.4249*x - 0.0546540000000002 + + Mi = Mn / d + + M_dm = M.to_DM() + M_dmd = M_dm.to_dense() + M_dm_num, M_dm_den = M_dm.inv_den() + M_dmd_num, M_dmd_den = M_dmd.inv_den() + + # XXX: We don't check M_dm().to_field().inv() which currently uses division + # and produces a more complicate result from gcd cancellation failing. + # DomainMatrix.inv() over RR(x) should be changed to clear denominators and + # use DomainMatrix.inv_den(). + + Minvs = [ + M.inv(), + (M_dm_num.to_field() / M_dm_den).to_Matrix(), + (M_dmd_num.to_field() / M_dmd_den).to_Matrix(), + M_dm_num.to_Matrix() / M_dm_den.as_expr(), + M_dmd_num.to_Matrix() / M_dmd_den.as_expr(), + ] + + for Minv in Minvs: + for Mi1, Mi2 in zip(Minv.flat(), Mi.flat()): + assert all_close(Mi2, Mi1) diff --git a/lib/python3.10/site-packages/sympy/polys/matrices/tests/test_linsolve.py b/lib/python3.10/site-packages/sympy/polys/matrices/tests/test_linsolve.py new file mode 100644 index 0000000000000000000000000000000000000000..9d8cd7eb9feb27c59d6a32ceb3f04118eae971e2 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/polys/matrices/tests/test_linsolve.py @@ -0,0 +1,111 @@ +# +# test_linsolve.py +# +# Test the internal implementation of linsolve. +# + +from sympy.testing.pytest import raises + +from sympy.core.numbers import I +from sympy.core.relational import Eq +from sympy.core.singleton import S +from sympy.abc import x, y, z + +from sympy.polys.matrices.linsolve import _linsolve +from sympy.polys.solvers import PolyNonlinearError + + +def test__linsolve(): + assert _linsolve([], [x]) == {x:x} + assert _linsolve([S.Zero], [x]) == {x:x} + assert _linsolve([x-1,x-2], [x]) is None + assert _linsolve([x-1], [x]) == {x:1} + assert _linsolve([x-1, y], [x, y]) == {x:1, y:S.Zero} + assert _linsolve([2*I], [x]) is None + raises(PolyNonlinearError, lambda: _linsolve([x*(1 + x)], [x])) + + +def test__linsolve_float(): + + # This should give the exact answer: + eqs = [ + y - x, + y - 0.0216 * x + ] + sol = {x:0.0, y:0.0} + assert _linsolve(eqs, (x, y)) == sol + + # Other cases should be close to eps + + def all_close(sol1, sol2, eps=1e-15): + close = lambda a, b: abs(a - b) < eps + assert sol1.keys() == sol2.keys() + return all(close(sol1[s], sol2[s]) for s in sol1) + + eqs = [ + 0.8*x + 0.8*z + 0.2, + 0.9*x + 0.7*y + 0.2*z + 0.9, + 0.7*x + 0.2*y + 0.2*z + 0.5 + ] + sol_exact = {x:-29/42, y:-11/21, z:37/84} + sol_linsolve = _linsolve(eqs, [x,y,z]) + assert all_close(sol_exact, sol_linsolve) + + eqs = [ + 0.9*x + 0.3*y + 0.4*z + 0.6, + 0.6*x + 0.9*y + 0.1*z + 0.7, + 0.4*x + 0.6*y + 0.9*z + 0.5 + ] + sol_exact = {x:-88/175, y:-46/105, z:-1/25} + sol_linsolve = _linsolve(eqs, [x,y,z]) + assert all_close(sol_exact, sol_linsolve) + + eqs = [ + 0.4*x + 0.3*y + 0.6*z + 0.7, + 0.4*x + 0.3*y + 0.9*z + 0.9, + 0.7*x + 0.9*y, + ] + sol_exact = {x:-9/5, y:7/5, z:-2/3} + sol_linsolve = _linsolve(eqs, [x,y,z]) + assert all_close(sol_exact, sol_linsolve) + + eqs = [ + x*(0.7 + 0.6*I) + y*(0.4 + 0.7*I) + z*(0.9 + 0.1*I) + 0.5, + 0.2*I*x + 0.2*I*y + z*(0.9 + 0.2*I) + 0.1, + x*(0.9 + 0.7*I) + y*(0.9 + 0.7*I) + z*(0.9 + 0.4*I) + 0.4, + ] + sol_exact = { + x:-6157/7995 - 411/5330*I, + y:8519/15990 + 1784/7995*I, + z:-34/533 + 107/1599*I, + } + sol_linsolve = _linsolve(eqs, [x,y,z]) + assert all_close(sol_exact, sol_linsolve) + + # XXX: This system for x and y over RR(z) is problematic. + # + # eqs = [ + # x*(0.2*z + 0.9) + y*(0.5*z + 0.8) + 0.6, + # 0.1*x*z + y*(0.1*z + 0.6) + 0.9, + # ] + # + # linsolve(eqs, [x, y]) + # The solution for x comes out as + # + # -3.9e-5*z**2 - 3.6e-5*z - 8.67361737988404e-20 + # x = ---------------------------------------------- + # 3.0e-6*z**3 - 1.3e-5*z**2 - 5.4e-5*z + # + # The 8e-20 in the numerator should be zero which would allow z to cancel + # from top and bottom. It should be possible to avoid this somehow because + # the inverse of the matrix only has a quadratic factor (the determinant) + # in the denominator. + + +def test__linsolve_deprecated(): + raises(PolyNonlinearError, lambda: + _linsolve([Eq(x**2, x**2 + y)], [x, y])) + raises(PolyNonlinearError, lambda: + _linsolve([(x + y)**2 - x**2], [x])) + raises(PolyNonlinearError, lambda: + _linsolve([Eq((x + y)**2, x**2)], [x])) diff --git a/lib/python3.10/site-packages/sympy/polys/matrices/tests/test_lll.py b/lib/python3.10/site-packages/sympy/polys/matrices/tests/test_lll.py new file mode 100644 index 0000000000000000000000000000000000000000..2cf91a00703532f02d763656d6117018fbc496cf --- /dev/null +++ b/lib/python3.10/site-packages/sympy/polys/matrices/tests/test_lll.py @@ -0,0 +1,145 @@ +from sympy.polys.domains import ZZ, QQ +from sympy.polys.matrices import DM +from sympy.polys.matrices.domainmatrix import DomainMatrix +from sympy.polys.matrices.exceptions import DMRankError, DMValueError, DMShapeError, DMDomainError +from sympy.polys.matrices.lll import _ddm_lll, ddm_lll, ddm_lll_transform +from sympy.testing.pytest import raises + + +def test_lll(): + normal_test_data = [ + ( + DM([[1, 0, 0, 0, -20160], + [0, 1, 0, 0, 33768], + [0, 0, 1, 0, 39578], + [0, 0, 0, 1, 47757]], ZZ), + DM([[10, -3, -2, 8, -4], + [3, -9, 8, 1, -11], + [-3, 13, -9, -3, -9], + [-12, -7, -11, 9, -1]], ZZ) + ), + ( + DM([[20, 52, 3456], + [14, 31, -1], + [34, -442, 0]], ZZ), + DM([[14, 31, -1], + [188, -101, -11], + [236, 13, 3443]], ZZ) + ), + ( + DM([[34, -1, -86, 12], + [-54, 34, 55, 678], + [23, 3498, 234, 6783], + [87, 49, 665, 11]], ZZ), + DM([[34, -1, -86, 12], + [291, 43, 149, 83], + [-54, 34, 55, 678], + [-189, 3077, -184, -223]], ZZ) + ) + ] + delta = QQ(5, 6) + for basis_dm, reduced_dm in normal_test_data: + reduced = _ddm_lll(basis_dm.rep.to_ddm(), delta=delta)[0] + assert reduced == reduced_dm.rep.to_ddm() + + reduced = ddm_lll(basis_dm.rep.to_ddm(), delta=delta) + assert reduced == reduced_dm.rep.to_ddm() + + reduced, transform = _ddm_lll(basis_dm.rep.to_ddm(), delta=delta, return_transform=True) + assert reduced == reduced_dm.rep.to_ddm() + assert transform.matmul(basis_dm.rep.to_ddm()) == reduced_dm.rep.to_ddm() + + reduced, transform = ddm_lll_transform(basis_dm.rep.to_ddm(), delta=delta) + assert reduced == reduced_dm.rep.to_ddm() + assert transform.matmul(basis_dm.rep.to_ddm()) == reduced_dm.rep.to_ddm() + + reduced = basis_dm.rep.lll(delta=delta) + assert reduced == reduced_dm.rep + + reduced, transform = basis_dm.rep.lll_transform(delta=delta) + assert reduced == reduced_dm.rep + assert transform.matmul(basis_dm.rep) == reduced_dm.rep + + reduced = basis_dm.rep.to_sdm().lll(delta=delta) + assert reduced == reduced_dm.rep.to_sdm() + + reduced, transform = basis_dm.rep.to_sdm().lll_transform(delta=delta) + assert reduced == reduced_dm.rep.to_sdm() + assert transform.matmul(basis_dm.rep.to_sdm()) == reduced_dm.rep.to_sdm() + + reduced = basis_dm.lll(delta=delta) + assert reduced == reduced_dm + + reduced, transform = basis_dm.lll_transform(delta=delta) + assert reduced == reduced_dm + assert transform.matmul(basis_dm) == reduced_dm + + +def test_lll_linear_dependent(): + linear_dependent_test_data = [ + DM([[0, -1, -2, -3], + [1, 0, -1, -2], + [2, 1, 0, -1], + [3, 2, 1, 0]], ZZ), + DM([[1, 0, 0, 1], + [0, 1, 0, 1], + [0, 0, 1, 1], + [1, 2, 3, 6]], ZZ), + DM([[3, -5, 1], + [4, 6, 0], + [10, -4, 2]], ZZ) + ] + for not_basis in linear_dependent_test_data: + raises(DMRankError, lambda: _ddm_lll(not_basis.rep.to_ddm())) + raises(DMRankError, lambda: ddm_lll(not_basis.rep.to_ddm())) + raises(DMRankError, lambda: not_basis.rep.lll()) + raises(DMRankError, lambda: not_basis.rep.to_sdm().lll()) + raises(DMRankError, lambda: not_basis.lll()) + raises(DMRankError, lambda: _ddm_lll(not_basis.rep.to_ddm(), return_transform=True)) + raises(DMRankError, lambda: ddm_lll_transform(not_basis.rep.to_ddm())) + raises(DMRankError, lambda: not_basis.rep.lll_transform()) + raises(DMRankError, lambda: not_basis.rep.to_sdm().lll_transform()) + raises(DMRankError, lambda: not_basis.lll_transform()) + + +def test_lll_wrong_delta(): + dummy_matrix = DomainMatrix.ones((3, 3), ZZ) + for wrong_delta in [QQ(-1, 4), QQ(0, 1), QQ(1, 4), QQ(1, 1), QQ(100, 1)]: + raises(DMValueError, lambda: _ddm_lll(dummy_matrix.rep, delta=wrong_delta)) + raises(DMValueError, lambda: ddm_lll(dummy_matrix.rep, delta=wrong_delta)) + raises(DMValueError, lambda: dummy_matrix.rep.lll(delta=wrong_delta)) + raises(DMValueError, lambda: dummy_matrix.rep.to_sdm().lll(delta=wrong_delta)) + raises(DMValueError, lambda: dummy_matrix.lll(delta=wrong_delta)) + raises(DMValueError, lambda: _ddm_lll(dummy_matrix.rep, delta=wrong_delta, return_transform=True)) + raises(DMValueError, lambda: ddm_lll_transform(dummy_matrix.rep, delta=wrong_delta)) + raises(DMValueError, lambda: dummy_matrix.rep.lll_transform(delta=wrong_delta)) + raises(DMValueError, lambda: dummy_matrix.rep.to_sdm().lll_transform(delta=wrong_delta)) + raises(DMValueError, lambda: dummy_matrix.lll_transform(delta=wrong_delta)) + + +def test_lll_wrong_shape(): + wrong_shape_matrix = DomainMatrix.ones((4, 3), ZZ) + raises(DMShapeError, lambda: _ddm_lll(wrong_shape_matrix.rep)) + raises(DMShapeError, lambda: ddm_lll(wrong_shape_matrix.rep)) + raises(DMShapeError, lambda: wrong_shape_matrix.rep.lll()) + raises(DMShapeError, lambda: wrong_shape_matrix.rep.to_sdm().lll()) + raises(DMShapeError, lambda: wrong_shape_matrix.lll()) + raises(DMShapeError, lambda: _ddm_lll(wrong_shape_matrix.rep, return_transform=True)) + raises(DMShapeError, lambda: ddm_lll_transform(wrong_shape_matrix.rep)) + raises(DMShapeError, lambda: wrong_shape_matrix.rep.lll_transform()) + raises(DMShapeError, lambda: wrong_shape_matrix.rep.to_sdm().lll_transform()) + raises(DMShapeError, lambda: wrong_shape_matrix.lll_transform()) + + +def test_lll_wrong_domain(): + wrong_domain_matrix = DomainMatrix.ones((3, 3), QQ) + raises(DMDomainError, lambda: _ddm_lll(wrong_domain_matrix.rep)) + raises(DMDomainError, lambda: ddm_lll(wrong_domain_matrix.rep)) + raises(DMDomainError, lambda: wrong_domain_matrix.rep.lll()) + raises(DMDomainError, lambda: wrong_domain_matrix.rep.to_sdm().lll()) + raises(DMDomainError, lambda: wrong_domain_matrix.lll()) + raises(DMDomainError, lambda: _ddm_lll(wrong_domain_matrix.rep, return_transform=True)) + raises(DMDomainError, lambda: ddm_lll_transform(wrong_domain_matrix.rep)) + raises(DMDomainError, lambda: wrong_domain_matrix.rep.lll_transform()) + raises(DMDomainError, lambda: wrong_domain_matrix.rep.to_sdm().lll_transform()) + raises(DMDomainError, lambda: wrong_domain_matrix.lll_transform()) diff --git a/lib/python3.10/site-packages/sympy/polys/matrices/tests/test_normalforms.py b/lib/python3.10/site-packages/sympy/polys/matrices/tests/test_normalforms.py new file mode 100644 index 0000000000000000000000000000000000000000..a3471400c877608003a14e55b4ffe49df6f6bd09 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/polys/matrices/tests/test_normalforms.py @@ -0,0 +1,75 @@ +from sympy.testing.pytest import raises + +from sympy.core.symbol import Symbol +from sympy.polys.matrices.normalforms import ( + invariant_factors, smith_normal_form, + hermite_normal_form, _hermite_normal_form, _hermite_normal_form_modulo_D) +from sympy.polys.domains import ZZ, QQ +from sympy.polys.matrices import DomainMatrix, DM +from sympy.polys.matrices.exceptions import DMDomainError, DMShapeError + + +def test_smith_normal(): + + m = DM([[12, 6, 4, 8], [3, 9, 6, 12], [2, 16, 14, 28], [20, 10, 10, 20]], ZZ) + smf = DM([[1, 0, 0, 0], [0, 10, 0, 0], [0, 0, -30, 0], [0, 0, 0, 0]], ZZ) + assert smith_normal_form(m).to_dense() == smf + + x = Symbol('x') + m = DM([[x-1, 1, -1], + [ 0, x, -1], + [ 0, -1, x]], QQ[x]) + dx = m.domain.gens[0] + assert invariant_factors(m) == (1, dx-1, dx**2-1) + + zr = DomainMatrix([], (0, 2), ZZ) + zc = DomainMatrix([[], []], (2, 0), ZZ) + assert smith_normal_form(zr).to_dense() == zr + assert smith_normal_form(zc).to_dense() == zc + + assert smith_normal_form(DM([[2, 4]], ZZ)).to_dense() == DM([[2, 0]], ZZ) + assert smith_normal_form(DM([[0, -2]], ZZ)).to_dense() == DM([[-2, 0]], ZZ) + assert smith_normal_form(DM([[0], [-2]], ZZ)).to_dense() == DM([[-2], [0]], ZZ) + + m = DM([[3, 0, 0, 0], [0, 0, 0, 0], [0, 0, 2, 0]], ZZ) + snf = DM([[1, 0, 0, 0], [0, 6, 0, 0], [0, 0, 0, 0]], ZZ) + assert smith_normal_form(m).to_dense() == snf + + raises(ValueError, lambda: smith_normal_form(DM([[1]], ZZ[x]))) + + +def test_hermite_normal(): + m = DM([[2, 7, 17, 29, 41], [3, 11, 19, 31, 43], [5, 13, 23, 37, 47]], ZZ) + hnf = DM([[1, 0, 0], [0, 2, 1], [0, 0, 1]], ZZ) + assert hermite_normal_form(m) == hnf + assert hermite_normal_form(m, D=ZZ(2)) == hnf + assert hermite_normal_form(m, D=ZZ(2), check_rank=True) == hnf + + m = m.transpose() + hnf = DM([[37, 0, 19], [222, -6, 113], [48, 0, 25], [0, 2, 1], [0, 0, 1]], ZZ) + assert hermite_normal_form(m) == hnf + raises(DMShapeError, lambda: _hermite_normal_form_modulo_D(m, ZZ(96))) + raises(DMDomainError, lambda: _hermite_normal_form_modulo_D(m, QQ(96))) + + m = DM([[8, 28, 68, 116, 164], [3, 11, 19, 31, 43], [5, 13, 23, 37, 47]], ZZ) + hnf = DM([[4, 0, 0], [0, 2, 1], [0, 0, 1]], ZZ) + assert hermite_normal_form(m) == hnf + assert hermite_normal_form(m, D=ZZ(8)) == hnf + assert hermite_normal_form(m, D=ZZ(8), check_rank=True) == hnf + + m = DM([[10, 8, 6, 30, 2], [45, 36, 27, 18, 9], [5, 4, 3, 2, 1]], ZZ) + hnf = DM([[26, 2], [0, 9], [0, 1]], ZZ) + assert hermite_normal_form(m) == hnf + + m = DM([[2, 7], [0, 0], [0, 0]], ZZ) + hnf = DM([[1], [0], [0]], ZZ) + assert hermite_normal_form(m) == hnf + + m = DM([[-2, 1], [0, 1]], ZZ) + hnf = DM([[2, 1], [0, 1]], ZZ) + assert hermite_normal_form(m) == hnf + + m = DomainMatrix([[QQ(1)]], (1, 1), QQ) + raises(DMDomainError, lambda: hermite_normal_form(m)) + raises(DMDomainError, lambda: _hermite_normal_form(m)) + raises(DMDomainError, lambda: _hermite_normal_form_modulo_D(m, ZZ(1))) diff --git a/lib/python3.10/site-packages/sympy/polys/matrices/tests/test_nullspace.py b/lib/python3.10/site-packages/sympy/polys/matrices/tests/test_nullspace.py new file mode 100644 index 0000000000000000000000000000000000000000..dbb025b7dc9dff31bc97d86e175147ffede5a7e3 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/polys/matrices/tests/test_nullspace.py @@ -0,0 +1,209 @@ +from sympy import ZZ, Matrix +from sympy.polys.matrices import DM, DomainMatrix +from sympy.polys.matrices.ddm import DDM +from sympy.polys.matrices.sdm import SDM + +import pytest + +zeros = lambda shape, K: DomainMatrix.zeros(shape, K).to_dense() +eye = lambda n, K: DomainMatrix.eye(n, K).to_dense() + + +# +# DomainMatrix.nullspace can have a divided answer or can return an undivided +# uncanonical answer. The uncanonical answer is not unique but we can make it +# unique by making it primitive (remove gcd). The tests here all show the +# primitive form. We test two things: +# +# A.nullspace().primitive()[1] == answer. +# A.nullspace(divide_last=True) == _divide_last(answer). +# +# The nullspace as returned by DomainMatrix and related classes is the +# transpose of the nullspace as returned by Matrix. Matrix returns a list of +# of column vectors whereas DomainMatrix returns a matrix whose rows are the +# nullspace vectors. +# + + +NULLSPACE_EXAMPLES = [ + + ( + 'zz_1', + DM([[ 1, 2, 3]], ZZ), + DM([[-2, 1, 0], + [-3, 0, 1]], ZZ), + ), + + ( + 'zz_2', + zeros((0, 0), ZZ), + zeros((0, 0), ZZ), + ), + + ( + 'zz_3', + zeros((2, 0), ZZ), + zeros((0, 0), ZZ), + ), + + ( + 'zz_4', + zeros((0, 2), ZZ), + eye(2, ZZ), + ), + + ( + 'zz_5', + zeros((2, 2), ZZ), + eye(2, ZZ), + ), + + ( + 'zz_6', + DM([[1, 2], + [3, 4]], ZZ), + zeros((0, 2), ZZ), + ), + + ( + 'zz_7', + DM([[1, 1], + [1, 1]], ZZ), + DM([[-1, 1]], ZZ), + ), + + ( + 'zz_8', + DM([[1], + [1]], ZZ), + zeros((0, 1), ZZ), + ), + + ( + 'zz_9', + DM([[1, 1]], ZZ), + DM([[-1, 1]], ZZ), + ), + + ( + 'zz_10', + DM([[0, 0, 0, 0, 0, 1, 0, 0, 0, 0], + [1, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 1, 0, 0, 0, 0, 0, 1, 0, 0], + [0, 0, 0, 1, 0, 0, 0, 0, 1, 0], + [0, 0, 0, 0, 1, 0, 0, 0, 0, 1]], ZZ), + DM([[ 0, 0, 1, 0, 0, 0, 0, 0, 0, 0], + [-1, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [ 0, -1, 0, 0, 0, 0, 0, 1, 0, 0], + [ 0, 0, 0, -1, 0, 0, 0, 0, 1, 0], + [ 0, 0, 0, 0, -1, 0, 0, 0, 0, 1]], ZZ), + ), + +] + + +def _to_DM(A, ans): + """Convert the answer to DomainMatrix.""" + if isinstance(A, DomainMatrix): + return A.to_dense() + elif isinstance(A, DDM): + return DomainMatrix(list(A), A.shape, A.domain).to_dense() + elif isinstance(A, SDM): + return DomainMatrix(dict(A), A.shape, A.domain).to_dense() + else: + assert False # pragma: no cover + + +def _divide_last(null): + """Normalize the nullspace by the rightmost non-zero entry.""" + null = null.to_field() + + if null.is_zero_matrix: + return null + + rows = [] + for i in range(null.shape[0]): + for j in reversed(range(null.shape[1])): + if null[i, j]: + rows.append(null[i, :] / null[i, j]) + break + else: + assert False # pragma: no cover + + return DomainMatrix.vstack(*rows) + + +def _check_primitive(null, null_ans): + """Check that the primitive of the answer matches.""" + null = _to_DM(null, null_ans) + cont, null_prim = null.primitive() + assert null_prim == null_ans + + +def _check_divided(null, null_ans): + """Check the divided answer.""" + null = _to_DM(null, null_ans) + null_ans_norm = _divide_last(null_ans) + assert null == null_ans_norm + + +@pytest.mark.parametrize('name, A, A_null', NULLSPACE_EXAMPLES) +def test_Matrix_nullspace(name, A, A_null): + A = A.to_Matrix() + + A_null_cols = A.nullspace() + + # We have to patch up the case where the nullspace is empty + if A_null_cols: + A_null_found = Matrix.hstack(*A_null_cols) + else: + A_null_found = Matrix.zeros(A.cols, 0) + + A_null_found = A_null_found.to_DM().to_field().to_dense() + + # The Matrix result is the transpose of DomainMatrix result. + A_null_found = A_null_found.transpose() + + _check_divided(A_null_found, A_null) + + +@pytest.mark.parametrize('name, A, A_null', NULLSPACE_EXAMPLES) +def test_dm_dense_nullspace(name, A, A_null): + A = A.to_field().to_dense() + A_null_found = A.nullspace(divide_last=True) + _check_divided(A_null_found, A_null) + + +@pytest.mark.parametrize('name, A, A_null', NULLSPACE_EXAMPLES) +def test_dm_sparse_nullspace(name, A, A_null): + A = A.to_field().to_sparse() + A_null_found = A.nullspace(divide_last=True) + _check_divided(A_null_found, A_null) + + +@pytest.mark.parametrize('name, A, A_null', NULLSPACE_EXAMPLES) +def test_ddm_nullspace(name, A, A_null): + A = A.to_field().to_ddm() + A_null_found, _ = A.nullspace() + _check_divided(A_null_found, A_null) + + +@pytest.mark.parametrize('name, A, A_null', NULLSPACE_EXAMPLES) +def test_sdm_nullspace(name, A, A_null): + A = A.to_field().to_sdm() + A_null_found, _ = A.nullspace() + _check_divided(A_null_found, A_null) + + +@pytest.mark.parametrize('name, A, A_null', NULLSPACE_EXAMPLES) +def test_dm_dense_nullspace_fracfree(name, A, A_null): + A = A.to_dense() + A_null_found = A.nullspace() + _check_primitive(A_null_found, A_null) + + +@pytest.mark.parametrize('name, A, A_null', NULLSPACE_EXAMPLES) +def test_dm_sparse_nullspace_fracfree(name, A, A_null): + A = A.to_sparse() + A_null_found = A.nullspace() + _check_primitive(A_null_found, A_null) diff --git a/lib/python3.10/site-packages/sympy/polys/matrices/tests/test_rref.py b/lib/python3.10/site-packages/sympy/polys/matrices/tests/test_rref.py new file mode 100644 index 0000000000000000000000000000000000000000..49def18c8132c0537540163a96bf6cf323c5a85c --- /dev/null +++ b/lib/python3.10/site-packages/sympy/polys/matrices/tests/test_rref.py @@ -0,0 +1,737 @@ +from sympy import ZZ, QQ, ZZ_I, EX, Matrix, eye, zeros, symbols +from sympy.polys.matrices import DM, DomainMatrix +from sympy.polys.matrices.dense import ddm_irref_den, ddm_irref +from sympy.polys.matrices.ddm import DDM +from sympy.polys.matrices.sdm import SDM, sdm_irref, sdm_rref_den + +import pytest + + +# +# The dense and sparse implementations of rref_den are ddm_irref_den and +# sdm_irref_den. These can give results that differ by some factor and also +# give different results if the order of the rows is changed. The tests below +# show all results on lowest terms as should be returned by cancel_denom. +# +# The EX domain is also a case where the dense and sparse implementations +# can give results in different forms: the results should be equivalent but +# are not canonical because EX does not have a canonical form. +# + + +a, b, c, d = symbols('a, b, c, d') + + +qq_large_1 = DM([ +[ (1,2), (1,3), (1,5), (1,7), (1,11), (1,13), (1,17), (1,19), (1,23), (1,29), (1,31)], +[ (1,37), (1,41), (1,43), (1,47), (1,53), (1,59), (1,61), (1,67), (1,71), (1,73), (1,79)], +[ (1,83), (1,89), (1,97),(1,101),(1,103),(1,107),(1,109),(1,113),(1,127),(1,131),(1,137)], +[(1,139),(1,149),(1,151),(1,157),(1,163),(1,167),(1,173),(1,179),(1,181),(1,191),(1,193)], +[(1,197),(1,199),(1,211),(1,223),(1,227),(1,229),(1,233),(1,239),(1,241),(1,251),(1,257)], +[(1,263),(1,269),(1,271),(1,277),(1,281),(1,283),(1,293),(1,307),(1,311),(1,313),(1,317)], +[(1,331),(1,337),(1,347),(1,349),(1,353),(1,359),(1,367),(1,373),(1,379),(1,383),(1,389)], +[(1,397),(1,401),(1,409),(1,419),(1,421),(1,431),(1,433),(1,439),(1,443),(1,449),(1,457)], +[(1,461),(1,463),(1,467),(1,479),(1,487),(1,491),(1,499),(1,503),(1,509),(1,521),(1,523)], +[(1,541),(1,547),(1,557),(1,563),(1,569),(1,571),(1,577),(1,587),(1,593),(1,599),(1,601)], +[(1,607),(1,613),(1,617),(1,619),(1,631),(1,641),(1,643),(1,647),(1,653),(1,659),(1,661)]], + QQ) + +qq_large_2 = qq_large_1 + 10**100 * DomainMatrix.eye(11, QQ) + + +RREF_EXAMPLES = [ + ( + 'zz_1', + DM([[1, 2, 3]], ZZ), + DM([[1, 2, 3]], ZZ), + ZZ(1), + ), + + ( + 'zz_2', + DomainMatrix([], (0, 0), ZZ), + DomainMatrix([], (0, 0), ZZ), + ZZ(1), + ), + + ( + 'zz_3', + DM([[1, 2], + [3, 4]], ZZ), + DM([[1, 0], + [0, 1]], ZZ), + ZZ(1), + ), + + ( + 'zz_4', + DM([[1, 0], + [3, 4]], ZZ), + DM([[1, 0], + [0, 1]], ZZ), + ZZ(1), + ), + + ( + 'zz_5', + DM([[0, 2], + [3, 4]], ZZ), + DM([[1, 0], + [0, 1]], ZZ), + ZZ(1), + ), + + ( + 'zz_6', + DM([[1, 2, 3], + [4, 5, 6], + [7, 8, 9]], ZZ), + DM([[1, 0, -1], + [0, 1, 2], + [0, 0, 0]], ZZ), + ZZ(1), + ), + + ( + 'zz_7', + DM([[0, 0, 0], + [0, 0, 0], + [1, 0, 0]], ZZ), + DM([[1, 0, 0], + [0, 0, 0], + [0, 0, 0]], ZZ), + ZZ(1), + ), + + ( + 'zz_8', + DM([[0, 0, 0], + [0, 0, 0], + [0, 0, 0]], ZZ), + DM([[0, 0, 0], + [0, 0, 0], + [0, 0, 0]], ZZ), + ZZ(1), + ), + + ( + 'zz_9', + DM([[1, 1, 0], + [0, 0, 2], + [0, 0, 0]], ZZ), + DM([[1, 1, 0], + [0, 0, 1], + [0, 0, 0]], ZZ), + ZZ(1), + ), + + ( + 'zz_10', + DM([[2, 2, 0], + [0, 0, 2], + [0, 0, 0]], ZZ), + DM([[1, 1, 0], + [0, 0, 1], + [0, 0, 0]], ZZ), + ZZ(1), + ), + + ( + 'zz_11', + DM([[2, 2, 0], + [0, 2, 2], + [0, 0, 2]], ZZ), + DM([[1, 0, 0], + [0, 1, 0], + [0, 0, 1]], ZZ), + ZZ(1), + ), + + ( + 'zz_12', + DM([[ 1, 2, 3], + [ 4, 5, 6], + [ 7, 8, 9], + [10, 11, 12]], ZZ), + DM([[1, 0, -1], + [0, 1, 2], + [0, 0, 0], + [0, 0, 0]], ZZ), + ZZ(1), + ), + + ( + 'zz_13', + DM([[ 1, 2, 3], + [ 4, 5, 6], + [ 7, 8, 9], + [10, 11, 13]], ZZ), + DM([[ 1, 0, 0], + [ 0, 1, 0], + [ 0, 0, 1], + [ 0, 0, 0]], ZZ), + ZZ(1), + ), + + ( + 'zz_14', + DM([[1, 2, 4, 3], + [4, 5, 10, 6], + [7, 8, 16, 9]], ZZ), + DM([[1, 0, 0, -1], + [0, 1, 2, 2], + [0, 0, 0, 0]], ZZ), + ZZ(1), + ), + + ( + 'zz_15', + DM([[1, 2, 4, 3], + [4, 5, 10, 6], + [7, 8, 17, 9]], ZZ), + DM([[1, 0, 0, -1], + [0, 1, 0, 2], + [0, 0, 1, 0]], ZZ), + ZZ(1), + ), + + ( + 'zz_16', + DM([[1, 2, 0, 1], + [1, 1, 9, 0]], ZZ), + DM([[1, 0, 18, -1], + [0, 1, -9, 1]], ZZ), + ZZ(1), + ), + + ( + 'zz_17', + DM([[1, 1, 1], + [1, 2, 2]], ZZ), + DM([[1, 0, 0], + [0, 1, 1]], ZZ), + ZZ(1), + ), + + ( + # Here the sparse implementation and dense implementation give very + # different denominators: 4061232 and -1765176. + 'zz_18', + DM([[94, 24, 0, 27, 0], + [79, 0, 0, 0, 0], + [85, 16, 71, 81, 0], + [ 0, 0, 72, 77, 0], + [21, 0, 34, 0, 0]], ZZ), + DM([[ 1, 0, 0, 0, 0], + [ 0, 1, 0, 0, 0], + [ 0, 0, 1, 0, 0], + [ 0, 0, 0, 1, 0], + [ 0, 0, 0, 0, 0]], ZZ), + ZZ(1), + ), + + ( + # Let's have a denominator that cannot be cancelled. + 'zz_19', + DM([[1, 2, 4], + [4, 5, 6]], ZZ), + DM([[3, 0, -8], + [0, 3, 10]], ZZ), + ZZ(3), + ), + + ( + 'zz_20', + DM([[0, 0, 0, 0, 0], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 4]], ZZ), + DM([[0, 0, 0, 0, 1], + [0, 0, 0, 0, 0], + [0, 0, 0, 0, 0]], ZZ), + ZZ(1), + ), + + ( + 'zz_21', + DM([[0, 0, 0, 0, 0, 1, 0, 0, 0, 0], + [1, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 1, 0, 0, 0, 0, 0, 1, 0, 0], + [0, 0, 0, 1, 0, 0, 0, 0, 1, 0], + [0, 0, 0, 0, 1, 0, 0, 0, 0, 1]], ZZ), + DM([[1, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 1, 0, 0, 0, 0, 0, 1, 0, 0], + [0, 0, 0, 1, 0, 0, 0, 0, 1, 0], + [0, 0, 0, 0, 1, 0, 0, 0, 0, 1], + [0, 0, 0, 0, 0, 1, 0, 0, 0, 0]], ZZ), + ZZ(1), + ), + + ( + 'zz_22', + DM([[1, 1, 1, 0, 1], + [1, 1, 0, 1, 0], + [1, 0, 1, 0, 1], + [1, 1, 0, 1, 0], + [1, 0, 0, 0, 0]], ZZ), + DM([[1, 0, 0, 0, 0], + [0, 1, 0, 0, 0], + [0, 0, 1, 0, 1], + [0, 0, 0, 1, 0], + [0, 0, 0, 0, 0]], ZZ), + ZZ(1), + ), + + ( + 'zz_large_1', + DM([ +[ 0, 0, 0, 81, 0, 0, 75, 0, 0, 0, 0, 0, 0, 27, 0, 0, 0, 0, 0, 0], +[ 0, 0, 0, 0, 0, 86, 0, 92, 79, 54, 0, 7, 0, 0, 0, 0, 79, 0, 0, 0], +[89, 54, 81, 0, 0, 20, 0, 0, 0, 0, 0, 0, 51, 0, 94, 0, 0, 77, 0, 0], +[ 0, 0, 0, 96, 0, 0, 0, 0, 0, 0, 0, 0, 48, 29, 0, 0, 5, 0, 32, 0], +[ 0, 70, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 60, 0, 0, 0, 11], +[ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 37, 0, 43, 0, 0], +[ 0, 0, 0, 0, 0, 38, 91, 0, 0, 0, 0, 38, 0, 0, 0, 0, 0, 26, 0, 0], +[69, 0, 0, 0, 0, 0, 94, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 55], +[ 0, 13, 18, 49, 49, 88, 0, 0, 35, 54, 0, 0, 51, 0, 0, 0, 0, 0, 0, 87], +[ 0, 0, 0, 0, 31, 0, 40, 0, 0, 0, 0, 0, 0, 50, 0, 0, 0, 0, 88, 0], +[ 0, 0, 0, 0, 0, 0, 0, 0, 98, 0, 0, 0, 15, 53, 0, 92, 0, 0, 0, 0], +[ 0, 0, 0, 95, 0, 0, 0, 36, 0, 0, 0, 0, 0, 72, 0, 0, 0, 0, 73, 19], +[ 0, 65, 14, 96, 0, 0, 0, 0, 0, 0, 0, 0, 0, 90, 0, 0, 0, 34, 0, 0], +[ 0, 0, 0, 16, 39, 44, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 51, 0, 0], +[ 0, 17, 0, 0, 0, 99, 84, 13, 50, 84, 0, 0, 0, 0, 95, 0, 43, 33, 20, 0], +[79, 0, 17, 52, 99, 12, 69, 0, 98, 0, 68, 0, 0, 0, 0, 0, 0, 0, 0, 0], +[ 0, 0, 0, 82, 0, 44, 0, 0, 0, 97, 0, 0, 0, 0, 0, 10, 0, 0, 31, 0], +[ 0, 0, 21, 0, 67, 0, 0, 0, 0, 0, 4, 0, 50, 0, 0, 0, 33, 0, 0, 0], +[ 0, 0, 0, 0, 9, 42, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 8], +[ 0, 77, 0, 0, 0, 0, 0, 0, 0, 0, 34, 93, 0, 0, 0, 0, 47, 0, 0, 0]], + ZZ), + DM([[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0], + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]], ZZ), + ZZ(1), + ), + + ( + 'zz_large_2', + DM([ +[ 0, 0, 0, 0, 50, 0, 6, 81, 0, 1, 86, 0, 0, 98, 82, 94, 4, 0, 0, 29], +[ 0, 44, 43, 0, 62, 0, 0, 0, 60, 0, 0, 0, 0, 71, 9, 0, 57, 41, 0, 93], +[ 0, 0, 28, 0, 74, 89, 42, 0, 28, 0, 6, 0, 0, 0, 44, 0, 0, 0, 77, 19], +[ 0, 21, 82, 0, 30, 88, 0, 89, 68, 0, 0, 0, 79, 41, 0, 0, 99, 0, 0, 0], +[31, 0, 0, 0, 19, 64, 0, 0, 79, 0, 5, 0, 72, 10, 60, 32, 64, 59, 0, 24], +[ 0, 0, 0, 0, 0, 57, 0, 94, 0, 83, 20, 0, 0, 9, 31, 0, 49, 26, 58, 0], +[ 0, 65, 56, 31, 64, 0, 0, 0, 0, 0, 0, 52, 85, 0, 0, 0, 0, 51, 0, 0], +[ 0, 35, 0, 0, 0, 69, 0, 0, 64, 0, 0, 0, 0, 70, 0, 0, 90, 0, 75, 76], +[69, 7, 0, 90, 0, 0, 84, 0, 47, 69, 19, 20, 42, 0, 0, 32, 71, 35, 0, 0], +[39, 0, 90, 0, 0, 4, 85, 0, 0, 55, 0, 0, 0, 35, 67, 40, 0, 40, 0, 77], +[98, 63, 0, 71, 0, 50, 0, 2, 61, 0, 38, 0, 0, 0, 0, 75, 0, 40, 33, 56], +[ 0, 73, 0, 64, 0, 38, 0, 35, 61, 0, 0, 52, 0, 7, 0, 51, 0, 0, 0, 34], +[ 0, 0, 28, 0, 34, 5, 63, 45, 14, 42, 60, 16, 76, 54, 99, 0, 28, 30, 0, 0], +[58, 37, 14, 0, 0, 0, 94, 0, 0, 90, 0, 0, 0, 0, 0, 0, 0, 8, 90, 53], +[86, 74, 94, 0, 49, 10, 60, 0, 40, 18, 0, 0, 0, 31, 60, 24, 0, 1, 0, 29], +[53, 0, 0, 97, 0, 0, 58, 0, 0, 39, 44, 47, 0, 0, 0, 12, 50, 0, 0, 11], +[ 4, 0, 92, 10, 28, 0, 0, 89, 0, 0, 18, 54, 23, 39, 0, 2, 0, 48, 0, 92], +[ 0, 0, 90, 77, 95, 33, 0, 0, 49, 22, 39, 0, 0, 0, 0, 0, 0, 40, 0, 0], +[96, 0, 0, 0, 0, 38, 86, 0, 22, 76, 0, 0, 0, 0, 83, 88, 95, 65, 72, 0], +[81, 65, 0, 4, 60, 0, 19, 0, 0, 68, 0, 0, 89, 0, 67, 22, 0, 0, 55, 33]], + ZZ), + DM([ +[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], +[0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], +[0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], +[0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0], +[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0], +[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0], +[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0], +[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]], + ZZ), + ZZ(1), + ), + + ( + 'zz_large_3', + DM([ +[62,35,89,58,22,47,30,28,52,72,17,56,80,26,64,21,10,35,24,42,96,32,23,50,92,37,76,94,63,66], +[20,47,96,34,10,98,19,6,29,2,19,92,61,94,38,41,32,9,5,94,31,58,27,41,72,85,61,62,40,46], +[69,26,35,68,25,52,94,13,38,65,81,10,29,15,5,4,13,99,85,0,80,51,60,60,26,77,85,2,87,25], +[99,58,69,15,52,12,18,7,27,56,12,54,21,92,38,95,33,83,28,1,44,8,29,84,92,12,2,25,46,46], +[93,13,55,48,35,87,24,40,23,35,25,32,0,19,0,85,4,79,26,11,46,75,7,96,76,11,7,57,99,75], +[128,85,26,51,161,173,77,78,85,103,123,58,91,147,38,91,161,36,123,81,102,25,75,59,17,150,112,65,77,143], +[15,59,61,82,12,83,34,8,94,71,66,7,91,21,48,69,26,12,64,38,97,87,38,15,51,33,93,43,66,89], +[74,74,53,39,69,90,41,80,32,66,40,83,87,87,61,38,12,80,24,49,37,90,19,33,56,0,46,57,56,60], +[82,11,0,25,56,58,39,49,92,93,80,38,19,62,33,85,19,61,14,30,45,91,97,34,97,53,92,28,33,43], +[83,79,41,16,95,35,53,45,26,4,71,76,61,69,69,72,87,92,59,72,54,11,22,83,8,57,77,55,19,22], +[49,34,13,31,72,77,52,70,46,41,37,6,42,66,35,6,75,33,62,57,30,14,26,31,9,95,89,13,12,90], +[29,3,49,30,51,32,77,41,38,50,16,1,87,81,93,88,58,91,83,0,38,67,29,64,60,84,5,60,23,28], +[79,51,13,20,89,96,25,8,39,62,86,52,49,81,3,85,86,3,61,24,72,11,49,28,8,55,23,52,65,53], +[96,86,73,20,41,20,37,18,10,61,85,24,40,83,69,41,4,92,23,99,64,33,18,36,32,56,60,98,39,24], +[32,62,47,80,51,66,17,1,9,30,65,75,75,88,99,92,64,53,53,86,38,51,41,14,35,18,39,25,26,32], +[39,21,8,16,33,6,35,85,75,62,43,34,18,68,71,28,32,18,12,0,81,53,1,99,3,5,45,99,35,33], +[19,95,89,45,75,94,92,5,84,93,34,17,50,56,79,98,68,82,65,81,51,90,5,95,33,71,46,61,14,7], +[53,92,8,49,67,84,21,79,49,95,66,48,36,14,62,97,26,45,58,31,83,48,11,89,67,72,91,34,56,89], +[56,76,99,92,40,8,0,16,15,48,35,72,91,46,81,14,86,60,51,7,33,12,53,78,48,21,3,89,15,79], +[81,43,33,49,6,49,36,32,57,74,87,91,17,37,31,17,67,1,40,38,69,8,3,48,59,37,64,97,11,3], +[98,48,77,16,2,48,57,38,63,59,79,35,16,71,60,86,71,41,14,76,80,97,77,69,4,58,22,55,26,73], +[80,47,78,44,31,48,47,29,29,62,19,21,17,24,19,3,53,93,97,57,13,54,12,10,77,66,60,75,32,21], +[86,63,2,13,71,38,86,23,18,15,91,65,77,65,9,92,50,0,17,42,99,80,99,27,10,99,92,9,87,84], +[66,27,72,13,13,15,72,75,39,3,14,71,15,68,10,19,49,54,11,29,47,20,63,13,97,47,24,62,16,96], +[42,63,83,60,49,68,9,53,75,87,40,25,12,63,0,12,0,95,46,46,55,25,89,1,51,1,1,96,80,52], +[35,9,97,13,86,39,66,48,41,57,23,38,11,9,35,72,88,13,41,60,10,64,71,23,1,5,23,57,6,19], +[70,61,5,50,72,60,77,13,41,94,1,45,52,22,99,47,27,18,99,42,16,48,26,9,88,77,10,94,11,92], +[55,68,58,2,72,56,81,52,79,37,1,40,21,46,27,60,37,13,97,42,85,98,69,60,76,44,42,46,29,73], +[73,0,43,17,89,97,45,2,68,14,55,60,95,2,74,85,88,68,93,76,38,76,2,51,45,76,50,79,56,18], +[72,58,41,39,24,80,23,79,44,7,98,75,30,6,85,60,20,58,77,71,90,51,38,80,30,15,33,10,82,8]], + ZZ), + Matrix([ + [eye(29) * 2028539767964472550625641331179545072876560857886207583101, + Matrix([ 4260575808093245475167216057435155595594339172099000182569, + 169148395880755256182802335904188369274227936894862744452, + 4915975976683942569102447281579134986891620721539038348914, + 6113916866367364958834844982578214901958429746875633283248, + 5585689617819894460378537031623265659753379011388162534838, + 359776822829880747716695359574308645968094838905181892423, + -2800926112141776386671436511182421432449325232461665113305, + 941642292388230001722444876624818265766384442910688463158, + 3648811843256146649321864698600908938933015862008642023935, + -4104526163246702252932955226754097174212129127510547462419, + -704814955438106792441896903238080197619233342348191408078, + 1640882266829725529929398131287244562048075707575030019335, + -4068330845192910563212155694231438198040299927120544468520, + 136589038308366497790495711534532612862715724187671166593, + 2544937011460702462290799932536905731142196510605191645593, + 755591839174293940486133926192300657264122907519174116472, + -3683838489869297144348089243628436188645897133242795965021, + -522207137101161299969706310062775465103537953077871128403, + -2260451796032703984456606059649402832441331339246756656334, + -6476809325293587953616004856993300606040336446656916663680, + 3521944238996782387785653800944972787867472610035040989081, + 2270762115788407950241944504104975551914297395787473242379, + -3259947194628712441902262570532921252128444706733549251156, + -5624569821491886970999097239695637132075823246850431083557, + -3262698255682055804320585332902837076064075936601504555698, + 5786719943788937667411185880136324396357603606944869545501, + -955257841973865996077323863289453200904051299086000660036, + -1294235552446355326174641248209752679127075717918392702116, + -3718353510747301598130831152458342785269166356215331448279, + ]),], + [zeros(1, 29), zeros(1, 1)], + ]).to_DM().to_dense(), + ZZ(2028539767964472550625641331179545072876560857886207583101), + ), + + + ( + 'qq_1', + DM([[(1,2), 0], [0, 2]], QQ), + DM([[1, 0], [0, 1]], QQ), + QQ(1), + ), + + ( + # Standard square case + 'qq_2', + DM([[0, 1], + [1, 1]], QQ), + DM([[1, 0], + [0, 1]], QQ), + QQ(1), + ), + + ( + # m < n case + 'qq_3', + DM([[1, 2, 1], + [3, 4, 1]], QQ), + DM([[1, 0, -1], + [0, 1, 1]], QQ), + QQ(1), + ), + + ( + # same m < n but reversed + 'qq_4', + DM([[3, 4, 1], + [1, 2, 1]], QQ), + DM([[1, 0, -1], + [0, 1, 1]], QQ), + QQ(1), + ), + + ( + # m > n case + 'qq_5', + DM([[1, 0], + [1, 3], + [0, 1]], QQ), + DM([[1, 0], + [0, 1], + [0, 0]], QQ), + QQ(1), + ), + + ( + # Example with missing pivot + 'qq_6', + DM([[1, 0, 1], + [3, 0, 1]], QQ), + DM([[1, 0, 0], + [0, 0, 1]], QQ), + QQ(1), + ), + + ( + # This is intended to trigger the threshold where we give up on + # clearing denominators. + 'qq_large_1', + qq_large_1, + DomainMatrix.eye(11, QQ).to_dense(), + QQ(1), + ), + + ( + # This is intended to trigger the threshold where we use rref_den over + # QQ. + 'qq_large_2', + qq_large_2, + DomainMatrix.eye(11, QQ).to_dense(), + QQ(1), + ), + + ( + # Example with missing pivot and no replacement + + # This example is just enough to show a different result from the dense + # and sparse versions of the algorithm: + # + # >>> A = Matrix([[0, 1], [0, 2], [1, 0]]) + # >>> A.to_DM().to_sparse().rref_den()[0].to_Matrix() + # Matrix([ + # [1, 0], + # [0, 1], + # [0, 0]]) + # >>> A.to_DM().to_dense().rref_den()[0].to_Matrix() + # Matrix([ + # [2, 0], + # [0, 2], + # [0, 0]]) + # + 'qq_7', + DM([[0, 1], + [0, 2], + [1, 0]], QQ), + DM([[1, 0], + [0, 1], + [0, 0]], QQ), + QQ(1), + ), + + ( + # Gaussian integers + 'zz_i_1', + DM([[(0,1), 1, 1], + [ 1, 1, 1]], ZZ_I), + DM([[1, 0, 0], + [0, 1, 1]], ZZ_I), + ZZ_I(1), + ), + + ( + # EX: test_issue_23718 + 'EX_1', + DM([ + [a, b, 1], + [c, d, 1]], EX), + DM([[a*d - b*c, 0, -b + d], + [ 0, a*d - b*c, a - c]], EX), + EX(a*d - b*c), + ), + +] + + +def _to_DM(A, ans): + """Convert the answer to DomainMatrix.""" + if isinstance(A, DomainMatrix): + return A.to_dense() + elif isinstance(A, Matrix): + return A.to_DM(ans.domain).to_dense() + + if not (hasattr(A, 'shape') and hasattr(A, 'domain')): + shape, domain = ans.shape, ans.domain + else: + shape, domain = A.shape, A.domain + + if isinstance(A, (DDM, list)): + return DomainMatrix(list(A), shape, domain).to_dense() + elif isinstance(A, (SDM, dict)): + return DomainMatrix(dict(A), shape, domain).to_dense() + else: + assert False # pragma: no cover + + +def _pivots(A_rref): + """Return the pivots from the rref of A.""" + return tuple(sorted(map(min, A_rref.to_sdm().values()))) + + +def _check_cancel(result, rref_ans, den_ans): + """Check the cancelled result.""" + rref, den, pivots = result + if isinstance(rref, (DDM, SDM, list, dict)): + assert type(pivots) is list + pivots = tuple(pivots) + rref = _to_DM(rref, rref_ans) + rref2, den2 = rref.cancel_denom(den) + assert rref2 == rref_ans + assert den2 == den_ans + assert pivots == _pivots(rref) + + +def _check_divide(result, rref_ans, den_ans): + """Check the divided result.""" + rref, pivots = result + if isinstance(rref, (DDM, SDM, list, dict)): + assert type(pivots) is list + pivots = tuple(pivots) + rref_ans = rref_ans.to_field() / den_ans + rref = _to_DM(rref, rref_ans) + assert rref == rref_ans + assert _pivots(rref) == pivots + + +@pytest.mark.parametrize('name, A, A_rref, den', RREF_EXAMPLES) +def test_Matrix_rref(name, A, A_rref, den): + K = A.domain + A = A.to_Matrix() + A_rref_found, pivots = A.rref() + if K.is_EX: + A_rref_found = A_rref_found.expand() + _check_divide((A_rref_found, pivots), A_rref, den) + + +@pytest.mark.parametrize('name, A, A_rref, den', RREF_EXAMPLES) +def test_dm_dense_rref(name, A, A_rref, den): + A = A.to_field() + _check_divide(A.rref(), A_rref, den) + + +@pytest.mark.parametrize('name, A, A_rref, den', RREF_EXAMPLES) +def test_dm_dense_rref_den(name, A, A_rref, den): + _check_cancel(A.rref_den(), A_rref, den) + + +@pytest.mark.parametrize('name, A, A_rref, den', RREF_EXAMPLES) +def test_dm_sparse_rref(name, A, A_rref, den): + A = A.to_field().to_sparse() + _check_divide(A.rref(), A_rref, den) + + +@pytest.mark.parametrize('name, A, A_rref, den', RREF_EXAMPLES) +def test_dm_sparse_rref_den(name, A, A_rref, den): + A = A.to_sparse() + _check_cancel(A.rref_den(), A_rref, den) + + +@pytest.mark.parametrize('name, A, A_rref, den', RREF_EXAMPLES) +def test_dm_sparse_rref_den_keep_domain(name, A, A_rref, den): + A = A.to_sparse() + A_rref_f, den_f, pivots_f = A.rref_den(keep_domain=False) + A_rref_f = A_rref_f.to_field() / den_f + _check_divide((A_rref_f, pivots_f), A_rref, den) + + +@pytest.mark.parametrize('name, A, A_rref, den', RREF_EXAMPLES) +def test_dm_sparse_rref_den_keep_domain_CD(name, A, A_rref, den): + A = A.to_sparse() + A_rref_f, den_f, pivots_f = A.rref_den(keep_domain=False, method='CD') + A_rref_f = A_rref_f.to_field() / den_f + _check_divide((A_rref_f, pivots_f), A_rref, den) + + +@pytest.mark.parametrize('name, A, A_rref, den', RREF_EXAMPLES) +def test_dm_sparse_rref_den_keep_domain_GJ(name, A, A_rref, den): + A = A.to_sparse() + A_rref_f, den_f, pivots_f = A.rref_den(keep_domain=False, method='GJ') + A_rref_f = A_rref_f.to_field() / den_f + _check_divide((A_rref_f, pivots_f), A_rref, den) + + +@pytest.mark.parametrize('name, A, A_rref, den', RREF_EXAMPLES) +def test_ddm_rref_den(name, A, A_rref, den): + A = A.to_ddm() + _check_cancel(A.rref_den(), A_rref, den) + + +@pytest.mark.parametrize('name, A, A_rref, den', RREF_EXAMPLES) +def test_sdm_rref_den(name, A, A_rref, den): + A = A.to_sdm() + _check_cancel(A.rref_den(), A_rref, den) + + +@pytest.mark.parametrize('name, A, A_rref, den', RREF_EXAMPLES) +def test_ddm_rref(name, A, A_rref, den): + A = A.to_field().to_ddm() + _check_divide(A.rref(), A_rref, den) + + +@pytest.mark.parametrize('name, A, A_rref, den', RREF_EXAMPLES) +def test_sdm_rref(name, A, A_rref, den): + A = A.to_field().to_sdm() + _check_divide(A.rref(), A_rref, den) + + +@pytest.mark.parametrize('name, A, A_rref, den', RREF_EXAMPLES) +def test_ddm_irref(name, A, A_rref, den): + A = A.to_field().to_ddm().copy() + pivots_found = ddm_irref(A) + _check_divide((A, pivots_found), A_rref, den) + + +@pytest.mark.parametrize('name, A, A_rref, den', RREF_EXAMPLES) +def test_ddm_irref_den(name, A, A_rref, den): + A = A.to_ddm().copy() + (den_found, pivots_found) = ddm_irref_den(A, A.domain) + result = (A, den_found, pivots_found) + _check_cancel(result, A_rref, den) + + +@pytest.mark.parametrize('name, A, A_rref, den', RREF_EXAMPLES) +def test_sparse_sdm_rref(name, A, A_rref, den): + A = A.to_field().to_sdm() + _check_divide(sdm_irref(A)[:2], A_rref, den) + + +@pytest.mark.parametrize('name, A, A_rref, den', RREF_EXAMPLES) +def test_sparse_sdm_rref_den(name, A, A_rref, den): + A = A.to_sdm().copy() + K = A.domain + _check_cancel(sdm_rref_den(A, K), A_rref, den) diff --git a/lib/python3.10/site-packages/sympy/polys/matrices/tests/test_sdm.py b/lib/python3.10/site-packages/sympy/polys/matrices/tests/test_sdm.py new file mode 100644 index 0000000000000000000000000000000000000000..cd7e5d460a1b2d44279a2a1772cc901f80ca733e --- /dev/null +++ b/lib/python3.10/site-packages/sympy/polys/matrices/tests/test_sdm.py @@ -0,0 +1,428 @@ +""" +Tests for the basic functionality of the SDM class. +""" + +from itertools import product + +from sympy.core.singleton import S +from sympy.external.gmpy import GROUND_TYPES +from sympy.testing.pytest import raises + +from sympy.polys.domains import QQ, ZZ, EXRAW +from sympy.polys.matrices.sdm import SDM +from sympy.polys.matrices.ddm import DDM +from sympy.polys.matrices.exceptions import (DMBadInputError, DMDomainError, + DMShapeError) + + +def test_SDM(): + A = SDM({0:{0:ZZ(1)}}, (2, 2), ZZ) + assert A.domain == ZZ + assert A.shape == (2, 2) + assert dict(A) == {0:{0:ZZ(1)}} + + raises(DMBadInputError, lambda: SDM({5:{1:ZZ(0)}}, (2, 2), ZZ)) + raises(DMBadInputError, lambda: SDM({0:{5:ZZ(0)}}, (2, 2), ZZ)) + + +def test_DDM_str(): + sdm = SDM({0:{0:ZZ(1)}, 1:{1:ZZ(1)}}, (2, 2), ZZ) + assert str(sdm) == '{0: {0: 1}, 1: {1: 1}}' + if GROUND_TYPES == 'gmpy': # pragma: no cover + assert repr(sdm) == 'SDM({0: {0: mpz(1)}, 1: {1: mpz(1)}}, (2, 2), ZZ)' + else: # pragma: no cover + assert repr(sdm) == 'SDM({0: {0: 1}, 1: {1: 1}}, (2, 2), ZZ)' + + +def test_SDM_new(): + A = SDM({0:{0:ZZ(1)}}, (2, 2), ZZ) + B = A.new({}, (2, 2), ZZ) + assert B == SDM({}, (2, 2), ZZ) + + +def test_SDM_copy(): + A = SDM({0:{0:ZZ(1)}}, (2, 2), ZZ) + B = A.copy() + assert A == B + A[0][0] = ZZ(2) + assert A != B + + +def test_SDM_from_list(): + A = SDM.from_list([[ZZ(0), ZZ(1)], [ZZ(1), ZZ(0)]], (2, 2), ZZ) + assert A == SDM({0:{1:ZZ(1)}, 1:{0:ZZ(1)}}, (2, 2), ZZ) + + raises(DMBadInputError, lambda: SDM.from_list([[ZZ(0)], [ZZ(0), ZZ(1)]], (2, 2), ZZ)) + raises(DMBadInputError, lambda: SDM.from_list([[ZZ(0), ZZ(1)]], (2, 2), ZZ)) + + +def test_SDM_to_list(): + A = SDM({0:{1: ZZ(1)}}, (2, 2), ZZ) + assert A.to_list() == [[ZZ(0), ZZ(1)], [ZZ(0), ZZ(0)]] + + A = SDM({}, (0, 2), ZZ) + assert A.to_list() == [] + + A = SDM({}, (2, 0), ZZ) + assert A.to_list() == [[], []] + + +def test_SDM_to_list_flat(): + A = SDM({0:{1: ZZ(1)}}, (2, 2), ZZ) + assert A.to_list_flat() == [ZZ(0), ZZ(1), ZZ(0), ZZ(0)] + + +def test_SDM_to_dok(): + A = SDM({0:{1: ZZ(1)}}, (2, 2), ZZ) + assert A.to_dok() == {(0, 1): ZZ(1)} + + +def test_SDM_from_ddm(): + A = DDM([[ZZ(1), ZZ(0)], [ZZ(1), ZZ(0)]], (2, 2), ZZ) + B = SDM.from_ddm(A) + assert B.domain == ZZ + assert B.shape == (2, 2) + assert dict(B) == {0:{0:ZZ(1)}, 1:{0:ZZ(1)}} + + +def test_SDM_to_ddm(): + A = SDM({0:{1: ZZ(1)}}, (2, 2), ZZ) + B = DDM([[ZZ(0), ZZ(1)], [ZZ(0), ZZ(0)]], (2, 2), ZZ) + assert A.to_ddm() == B + + +def test_SDM_to_sdm(): + A = SDM({0:{1: ZZ(1)}}, (2, 2), ZZ) + assert A.to_sdm() == A + + +def test_SDM_getitem(): + A = SDM({0:{1:ZZ(1)}}, (2, 2), ZZ) + assert A.getitem(0, 0) == ZZ.zero + assert A.getitem(0, 1) == ZZ.one + assert A.getitem(1, 0) == ZZ.zero + assert A.getitem(-2, -2) == ZZ.zero + assert A.getitem(-2, -1) == ZZ.one + assert A.getitem(-1, -2) == ZZ.zero + raises(IndexError, lambda: A.getitem(2, 0)) + raises(IndexError, lambda: A.getitem(0, 2)) + + +def test_SDM_setitem(): + A = SDM({0:{1:ZZ(1)}}, (2, 2), ZZ) + A.setitem(0, 0, ZZ(1)) + assert A == SDM({0:{0:ZZ(1), 1:ZZ(1)}}, (2, 2), ZZ) + A.setitem(1, 0, ZZ(1)) + assert A == SDM({0:{0:ZZ(1), 1:ZZ(1)}, 1:{0:ZZ(1)}}, (2, 2), ZZ) + A.setitem(1, 0, ZZ(0)) + assert A == SDM({0:{0:ZZ(1), 1:ZZ(1)}}, (2, 2), ZZ) + # Repeat the above test so that this time the row is empty + A.setitem(1, 0, ZZ(0)) + assert A == SDM({0:{0:ZZ(1), 1:ZZ(1)}}, (2, 2), ZZ) + A.setitem(0, 0, ZZ(0)) + assert A == SDM({0:{1:ZZ(1)}}, (2, 2), ZZ) + # This time the row is there but column is empty + A.setitem(0, 0, ZZ(0)) + assert A == SDM({0:{1:ZZ(1)}}, (2, 2), ZZ) + raises(IndexError, lambda: A.setitem(2, 0, ZZ(1))) + raises(IndexError, lambda: A.setitem(0, 2, ZZ(1))) + + +def test_SDM_extract_slice(): + A = SDM({0:{0:ZZ(1), 1:ZZ(2)}, 1:{0:ZZ(3), 1:ZZ(4)}}, (2, 2), ZZ) + B = A.extract_slice(slice(1, 2), slice(1, 2)) + assert B == SDM({0:{0:ZZ(4)}}, (1, 1), ZZ) + + +def test_SDM_extract(): + A = SDM({0:{0:ZZ(1), 1:ZZ(2)}, 1:{0:ZZ(3), 1:ZZ(4)}}, (2, 2), ZZ) + B = A.extract([1], [1]) + assert B == SDM({0:{0:ZZ(4)}}, (1, 1), ZZ) + B = A.extract([1, 0], [1, 0]) + assert B == SDM({0:{0:ZZ(4), 1:ZZ(3)}, 1:{0:ZZ(2), 1:ZZ(1)}}, (2, 2), ZZ) + B = A.extract([1, 1], [1, 1]) + assert B == SDM({0:{0:ZZ(4), 1:ZZ(4)}, 1:{0:ZZ(4), 1:ZZ(4)}}, (2, 2), ZZ) + B = A.extract([-1], [-1]) + assert B == SDM({0:{0:ZZ(4)}}, (1, 1), ZZ) + + A = SDM({}, (2, 2), ZZ) + B = A.extract([0, 1, 0], [0, 0]) + assert B == SDM({}, (3, 2), ZZ) + + A = SDM({0:{0:ZZ(1), 1:ZZ(2)}, 1:{0:ZZ(3), 1:ZZ(4)}}, (2, 2), ZZ) + assert A.extract([], []) == SDM.zeros((0, 0), ZZ) + assert A.extract([1], []) == SDM.zeros((1, 0), ZZ) + assert A.extract([], [1]) == SDM.zeros((0, 1), ZZ) + + raises(IndexError, lambda: A.extract([2], [0])) + raises(IndexError, lambda: A.extract([0], [2])) + raises(IndexError, lambda: A.extract([-3], [0])) + raises(IndexError, lambda: A.extract([0], [-3])) + + +def test_SDM_zeros(): + A = SDM.zeros((2, 2), ZZ) + assert A.domain == ZZ + assert A.shape == (2, 2) + assert dict(A) == {} + +def test_SDM_ones(): + A = SDM.ones((1, 2), QQ) + assert A.domain == QQ + assert A.shape == (1, 2) + assert dict(A) == {0:{0:QQ(1), 1:QQ(1)}} + +def test_SDM_eye(): + A = SDM.eye((2, 2), ZZ) + assert A.domain == ZZ + assert A.shape == (2, 2) + assert dict(A) == {0:{0:ZZ(1)}, 1:{1:ZZ(1)}} + + +def test_SDM_diag(): + A = SDM.diag([ZZ(1), ZZ(2)], ZZ, (2, 3)) + assert A == SDM({0:{0:ZZ(1)}, 1:{1:ZZ(2)}}, (2, 3), ZZ) + + +def test_SDM_transpose(): + A = SDM({0:{0:ZZ(1), 1:ZZ(2)}, 1:{0:ZZ(3), 1:ZZ(4)}}, (2, 2), ZZ) + B = SDM({0:{0:ZZ(1), 1:ZZ(3)}, 1:{0:ZZ(2), 1:ZZ(4)}}, (2, 2), ZZ) + assert A.transpose() == B + + A = SDM({0:{1:ZZ(2)}}, (2, 2), ZZ) + B = SDM({1:{0:ZZ(2)}}, (2, 2), ZZ) + assert A.transpose() == B + + A = SDM({0:{1:ZZ(2)}}, (1, 2), ZZ) + B = SDM({1:{0:ZZ(2)}}, (2, 1), ZZ) + assert A.transpose() == B + + +def test_SDM_mul(): + A = SDM({0:{0:ZZ(2)}}, (2, 2), ZZ) + B = SDM({0:{0:ZZ(4)}}, (2, 2), ZZ) + assert A*ZZ(2) == B + assert ZZ(2)*A == B + + raises(TypeError, lambda: A*QQ(1, 2)) + raises(TypeError, lambda: QQ(1, 2)*A) + + +def test_SDM_mul_elementwise(): + A = SDM({0:{0:ZZ(2), 1:ZZ(2)}}, (2, 2), ZZ) + B = SDM({0:{0:ZZ(4)}, 1:{0:ZZ(3)}}, (2, 2), ZZ) + C = SDM({0:{0:ZZ(8)}}, (2, 2), ZZ) + assert A.mul_elementwise(B) == C + assert B.mul_elementwise(A) == C + + Aq = A.convert_to(QQ) + A1 = SDM({0:{0:ZZ(1)}}, (1, 1), ZZ) + + raises(DMDomainError, lambda: Aq.mul_elementwise(B)) + raises(DMShapeError, lambda: A1.mul_elementwise(B)) + + +def test_SDM_matmul(): + A = SDM({0:{0:ZZ(2)}}, (2, 2), ZZ) + B = SDM({0:{0:ZZ(4)}}, (2, 2), ZZ) + assert A.matmul(A) == A*A == B + + C = SDM({0:{0:ZZ(2)}}, (2, 2), QQ) + raises(DMDomainError, lambda: A.matmul(C)) + + A = SDM({0:{0:ZZ(1), 1:ZZ(2)}, 1:{0:ZZ(3), 1:ZZ(4)}}, (2, 2), ZZ) + B = SDM({0:{0:ZZ(7), 1:ZZ(10)}, 1:{0:ZZ(15), 1:ZZ(22)}}, (2, 2), ZZ) + assert A.matmul(A) == A*A == B + + A22 = SDM({0:{0:ZZ(4)}}, (2, 2), ZZ) + A32 = SDM({0:{0:ZZ(2)}}, (3, 2), ZZ) + A23 = SDM({0:{0:ZZ(4)}}, (2, 3), ZZ) + A33 = SDM({0:{0:ZZ(8)}}, (3, 3), ZZ) + A22 = SDM({0:{0:ZZ(8)}}, (2, 2), ZZ) + assert A32.matmul(A23) == A33 + assert A23.matmul(A32) == A22 + # XXX: @ not supported by SDM... + #assert A32.matmul(A23) == A32 @ A23 == A33 + #assert A23.matmul(A32) == A23 @ A32 == A22 + #raises(DMShapeError, lambda: A23 @ A22) + raises(DMShapeError, lambda: A23.matmul(A22)) + + A = SDM({0: {0: ZZ(-1), 1: ZZ(1)}}, (1, 2), ZZ) + B = SDM({0: {0: ZZ(-1)}, 1: {0: ZZ(-1)}}, (2, 1), ZZ) + assert A.matmul(B) == A*B == SDM({}, (1, 1), ZZ) + + +def test_matmul_exraw(): + + def dm(d): + result = {} + for i, row in d.items(): + row = {j:val for j, val in row.items() if val} + if row: + result[i] = row + return SDM(result, (2, 2), EXRAW) + + values = [S.NegativeInfinity, S.NegativeOne, S.Zero, S.One, S.Infinity] + for a, b, c, d in product(*[values]*4): + Ad = dm({0: {0:a, 1:b}, 1: {0:c, 1:d}}) + Ad2 = dm({0: {0:a*a + b*c, 1:a*b + b*d}, 1:{0:c*a + d*c, 1: c*b + d*d}}) + assert Ad * Ad == Ad2 + + +def test_SDM_add(): + A = SDM({0:{1:ZZ(1)}, 1:{0:ZZ(2), 1:ZZ(3)}}, (2, 2), ZZ) + B = SDM({0:{0:ZZ(1)}, 1:{0:ZZ(-2), 1:ZZ(3)}}, (2, 2), ZZ) + C = SDM({0:{0:ZZ(1), 1:ZZ(1)}, 1:{1:ZZ(6)}}, (2, 2), ZZ) + assert A.add(B) == B.add(A) == A + B == B + A == C + + A = SDM({0:{1:ZZ(1)}}, (2, 2), ZZ) + B = SDM({0:{0:ZZ(1)}, 1:{0:ZZ(-2), 1:ZZ(3)}}, (2, 2), ZZ) + C = SDM({0:{0:ZZ(1), 1:ZZ(1)}, 1:{0:ZZ(-2), 1:ZZ(3)}}, (2, 2), ZZ) + assert A.add(B) == B.add(A) == A + B == B + A == C + + raises(TypeError, lambda: A + []) + + +def test_SDM_sub(): + A = SDM({0:{1:ZZ(1)}, 1:{0:ZZ(2), 1:ZZ(3)}}, (2, 2), ZZ) + B = SDM({0:{0:ZZ(1)}, 1:{0:ZZ(-2), 1:ZZ(3)}}, (2, 2), ZZ) + C = SDM({0:{0:ZZ(-1), 1:ZZ(1)}, 1:{0:ZZ(4)}}, (2, 2), ZZ) + assert A.sub(B) == A - B == C + + raises(TypeError, lambda: A - []) + + +def test_SDM_neg(): + A = SDM({0:{1:ZZ(1)}, 1:{0:ZZ(2), 1:ZZ(3)}}, (2, 2), ZZ) + B = SDM({0:{1:ZZ(-1)}, 1:{0:ZZ(-2), 1:ZZ(-3)}}, (2, 2), ZZ) + assert A.neg() == -A == B + + +def test_SDM_convert_to(): + A = SDM({0:{1:ZZ(1)}, 1:{0:ZZ(2), 1:ZZ(3)}}, (2, 2), ZZ) + B = SDM({0:{1:QQ(1)}, 1:{0:QQ(2), 1:QQ(3)}}, (2, 2), QQ) + C = A.convert_to(QQ) + assert C == B + assert C.domain == QQ + + D = A.convert_to(ZZ) + assert D == A + assert D.domain == ZZ + + +def test_SDM_hstack(): + A = SDM({0:{1:ZZ(1)}}, (2, 2), ZZ) + B = SDM({1:{1:ZZ(1)}}, (2, 2), ZZ) + AA = SDM({0:{1:ZZ(1), 3:ZZ(1)}}, (2, 4), ZZ) + AB = SDM({0:{1:ZZ(1)}, 1:{3:ZZ(1)}}, (2, 4), ZZ) + assert SDM.hstack(A) == A + assert SDM.hstack(A, A) == AA + assert SDM.hstack(A, B) == AB + + +def test_SDM_vstack(): + A = SDM({0:{1:ZZ(1)}}, (2, 2), ZZ) + B = SDM({1:{1:ZZ(1)}}, (2, 2), ZZ) + AA = SDM({0:{1:ZZ(1)}, 2:{1:ZZ(1)}}, (4, 2), ZZ) + AB = SDM({0:{1:ZZ(1)}, 3:{1:ZZ(1)}}, (4, 2), ZZ) + assert SDM.vstack(A) == A + assert SDM.vstack(A, A) == AA + assert SDM.vstack(A, B) == AB + + +def test_SDM_applyfunc(): + A = SDM({0:{1:ZZ(1)}}, (2, 2), ZZ) + B = SDM({0:{1:ZZ(2)}}, (2, 2), ZZ) + assert A.applyfunc(lambda x: 2*x, ZZ) == B + + +def test_SDM_inv(): + A = SDM({0:{0:QQ(1), 1:QQ(2)}, 1:{0:QQ(3), 1:QQ(4)}}, (2, 2), QQ) + B = SDM({0:{0:QQ(-2), 1:QQ(1)}, 1:{0:QQ(3, 2), 1:QQ(-1, 2)}}, (2, 2), QQ) + assert A.inv() == B + + +def test_SDM_det(): + A = SDM({0:{0:QQ(1), 1:QQ(2)}, 1:{0:QQ(3), 1:QQ(4)}}, (2, 2), QQ) + assert A.det() == QQ(-2) + + +def test_SDM_lu(): + A = SDM({0:{0:QQ(1), 1:QQ(2)}, 1:{0:QQ(3), 1:QQ(4)}}, (2, 2), QQ) + L = SDM({0:{0:QQ(1)}, 1:{0:QQ(3), 1:QQ(1)}}, (2, 2), QQ) + #U = SDM({0:{0:QQ(1), 1:QQ(2)}, 1:{0:QQ(3), 1:QQ(-2)}}, (2, 2), QQ) + #swaps = [] + # This doesn't quite work. U has some nonzero elements in the lower part. + #assert A.lu() == (L, U, swaps) + assert A.lu()[0] == L + + +def test_SDM_lu_solve(): + A = SDM({0:{0:QQ(1), 1:QQ(2)}, 1:{0:QQ(3), 1:QQ(4)}}, (2, 2), QQ) + b = SDM({0:{0:QQ(1)}, 1:{0:QQ(2)}}, (2, 1), QQ) + x = SDM({1:{0:QQ(1, 2)}}, (2, 1), QQ) + assert A.matmul(x) == b + assert A.lu_solve(b) == x + + +def test_SDM_charpoly(): + A = SDM({0:{0:ZZ(1), 1:ZZ(2)}, 1:{0:ZZ(3), 1:ZZ(4)}}, (2, 2), ZZ) + assert A.charpoly() == [ZZ(1), ZZ(-5), ZZ(-2)] + + +def test_SDM_nullspace(): + # More tests are in test_nullspace.py + A = SDM({0:{0:QQ(1), 1:QQ(1)}}, (2, 2), QQ) + assert A.nullspace()[0] == SDM({0:{0:QQ(-1), 1:QQ(1)}}, (1, 2), QQ) + + +def test_SDM_rref(): + # More tests are in test_rref.py + + A = SDM({0:{0:QQ(1), 1:QQ(2)}, + 1:{0:QQ(3), 1:QQ(4)}}, (2, 2), QQ) + A_rref = SDM({0:{0:QQ(1)}, 1:{1:QQ(1)}}, (2, 2), QQ) + assert A.rref() == (A_rref, [0, 1]) + + A = SDM({0: {0: QQ(1), 1: QQ(2), 2: QQ(2)}, + 1: {0: QQ(3), 2: QQ(4)}}, (2, 3), ZZ) + A_rref = SDM({0: {0: QQ(1,1), 2: QQ(4,3)}, + 1: {1: QQ(1,1), 2: QQ(1,3)}}, (2, 3), QQ) + assert A.rref() == (A_rref, [0, 1]) + + +def test_SDM_particular(): + A = SDM({0:{0:QQ(1)}}, (2, 2), QQ) + Apart = SDM.zeros((1, 2), QQ) + assert A.particular() == Apart + + +def test_SDM_is_zero_matrix(): + A = SDM({0: {0: QQ(1)}}, (2, 2), QQ) + Azero = SDM.zeros((1, 2), QQ) + assert A.is_zero_matrix() is False + assert Azero.is_zero_matrix() is True + + +def test_SDM_is_upper(): + A = SDM({0: {0: QQ(1), 1: QQ(2), 2: QQ(3), 3: QQ(4)}, + 1: {1: QQ(5), 2: QQ(6), 3: QQ(7)}, + 2: {2: QQ(8), 3: QQ(9)}}, (3, 4), QQ) + B = SDM({0: {0: QQ(1), 1: QQ(2), 2: QQ(3), 3: QQ(4)}, + 1: {1: QQ(5), 2: QQ(6), 3: QQ(7)}, + 2: {1: QQ(7), 2: QQ(8), 3: QQ(9)}}, (3, 4), QQ) + assert A.is_upper() is True + assert B.is_upper() is False + + +def test_SDM_is_lower(): + A = SDM({0: {0: QQ(1), 1: QQ(2), 2: QQ(3), 3: QQ(4)}, + 1: {1: QQ(5), 2: QQ(6), 3: QQ(7)}, + 2: {2: QQ(8), 3: QQ(9)}}, (3, 4), QQ + ).transpose() + B = SDM({0: {0: QQ(1), 1: QQ(2), 2: QQ(3), 3: QQ(4)}, + 1: {1: QQ(5), 2: QQ(6), 3: QQ(7)}, + 2: {1: QQ(7), 2: QQ(8), 3: QQ(9)}}, (3, 4), QQ + ).transpose() + assert A.is_lower() is True + assert B.is_lower() is False diff --git a/lib/python3.10/site-packages/sympy/polys/matrices/tests/test_xxm.py b/lib/python3.10/site-packages/sympy/polys/matrices/tests/test_xxm.py new file mode 100644 index 0000000000000000000000000000000000000000..96a660123aa232107258a6f0e0a3e24a0bba07b4 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/polys/matrices/tests/test_xxm.py @@ -0,0 +1,864 @@ +# +# Test basic features of DDM, SDM and DFM. +# +# These three types are supposed to be interchangeable, so we should use the +# same tests for all of them for the most part. +# +# The tests here cover the basic part of the inerface that the three types +# should expose and that DomainMatrix should mostly rely on. +# +# More in-depth tests of the heavier algorithms like rref etc should go in +# their own test files. +# +# Any new methods added to the DDM, SDM or DFM classes should be tested here +# and added to all classes. +# + +from sympy.external.gmpy import GROUND_TYPES + +from sympy import ZZ, QQ, GF, ZZ_I, symbols + +from sympy.polys.matrices.exceptions import ( + DMBadInputError, + DMDomainError, + DMNonSquareMatrixError, + DMNonInvertibleMatrixError, + DMShapeError, +) + +from sympy.polys.matrices.domainmatrix import DM, DomainMatrix, DDM, SDM, DFM + +from sympy.testing.pytest import raises, skip +import pytest + + +def test_XXM_constructors(): + """Test the DDM, etc constructors.""" + + lol = [ + [ZZ(1), ZZ(2)], + [ZZ(3), ZZ(4)], + [ZZ(5), ZZ(6)], + ] + dod = { + 0: {0: ZZ(1), 1: ZZ(2)}, + 1: {0: ZZ(3), 1: ZZ(4)}, + 2: {0: ZZ(5), 1: ZZ(6)}, + } + + lol_0x0 = [] + lol_0x2 = [] + lol_2x0 = [[], []] + dod_0x0 = {} + dod_0x2 = {} + dod_2x0 = {} + + lol_bad = [ + [ZZ(1), ZZ(2)], + [ZZ(3), ZZ(4)], + [ZZ(5), ZZ(6), ZZ(7)], + ] + dod_bad = { + 0: {0: ZZ(1), 1: ZZ(2)}, + 1: {0: ZZ(3), 1: ZZ(4)}, + 2: {0: ZZ(5), 1: ZZ(6), 2: ZZ(7)}, + } + + XDM_dense = [DDM] + XDM_sparse = [SDM] + + if GROUND_TYPES == 'flint': + XDM_dense.append(DFM) + + for XDM in XDM_dense: + + A = XDM(lol, (3, 2), ZZ) + assert A.rows == 3 + assert A.cols == 2 + assert A.domain == ZZ + assert A.shape == (3, 2) + if XDM is not DFM: + assert ZZ.of_type(A[0][0]) is True + else: + assert ZZ.of_type(A.rep[0, 0]) is True + + Adm = DomainMatrix(lol, (3, 2), ZZ) + if XDM is DFM: + assert Adm.rep == A + assert Adm.rep.to_ddm() != A + elif GROUND_TYPES == 'flint': + assert Adm.rep.to_ddm() == A + assert Adm.rep != A + else: + assert Adm.rep == A + assert Adm.rep.to_ddm() == A + + assert XDM(lol_0x0, (0, 0), ZZ).shape == (0, 0) + assert XDM(lol_0x2, (0, 2), ZZ).shape == (0, 2) + assert XDM(lol_2x0, (2, 0), ZZ).shape == (2, 0) + raises(DMBadInputError, lambda: XDM(lol, (2, 3), ZZ)) + raises(DMBadInputError, lambda: XDM(lol_bad, (3, 2), ZZ)) + raises(DMBadInputError, lambda: XDM(dod, (3, 2), ZZ)) + + for XDM in XDM_sparse: + + A = XDM(dod, (3, 2), ZZ) + assert A.rows == 3 + assert A.cols == 2 + assert A.domain == ZZ + assert A.shape == (3, 2) + assert ZZ.of_type(A[0][0]) is True + + assert DomainMatrix(dod, (3, 2), ZZ).rep == A + + assert XDM(dod_0x0, (0, 0), ZZ).shape == (0, 0) + assert XDM(dod_0x2, (0, 2), ZZ).shape == (0, 2) + assert XDM(dod_2x0, (2, 0), ZZ).shape == (2, 0) + raises(DMBadInputError, lambda: XDM(dod, (2, 3), ZZ)) + raises(DMBadInputError, lambda: XDM(lol, (3, 2), ZZ)) + raises(DMBadInputError, lambda: XDM(dod_bad, (3, 2), ZZ)) + + raises(DMBadInputError, lambda: DomainMatrix(lol, (2, 3), ZZ)) + raises(DMBadInputError, lambda: DomainMatrix(lol_bad, (3, 2), ZZ)) + raises(DMBadInputError, lambda: DomainMatrix(dod_bad, (3, 2), ZZ)) + + +def test_XXM_eq(): + """Test equality for DDM, SDM, DFM and DomainMatrix.""" + + lol1 = [[ZZ(1), ZZ(2)], [ZZ(3), ZZ(4)]] + dod1 = {0: {0: ZZ(1), 1: ZZ(2)}, 1: {0: ZZ(3), 1: ZZ(4)}} + + lol2 = [[ZZ(1), ZZ(2)], [ZZ(3), ZZ(5)]] + dod2 = {0: {0: ZZ(1), 1: ZZ(2)}, 1: {0: ZZ(3), 1: ZZ(5)}} + + A1_ddm = DDM(lol1, (2, 2), ZZ) + A1_sdm = SDM(dod1, (2, 2), ZZ) + A1_dm_d = DomainMatrix(lol1, (2, 2), ZZ) + A1_dm_s = DomainMatrix(dod1, (2, 2), ZZ) + + A2_ddm = DDM(lol2, (2, 2), ZZ) + A2_sdm = SDM(dod2, (2, 2), ZZ) + A2_dm_d = DomainMatrix(lol2, (2, 2), ZZ) + A2_dm_s = DomainMatrix(dod2, (2, 2), ZZ) + + A1_all = [A1_ddm, A1_sdm, A1_dm_d, A1_dm_s] + A2_all = [A2_ddm, A2_sdm, A2_dm_d, A2_dm_s] + + if GROUND_TYPES == 'flint': + + A1_dfm = DFM([[1, 2], [3, 4]], (2, 2), ZZ) + A2_dfm = DFM([[1, 2], [3, 5]], (2, 2), ZZ) + + A1_all.append(A1_dfm) + A2_all.append(A2_dfm) + + for n, An in enumerate(A1_all): + for m, Am in enumerate(A1_all): + if n == m: + assert (An == Am) is True + assert (An != Am) is False + else: + assert (An == Am) is False + assert (An != Am) is True + + for n, An in enumerate(A2_all): + for m, Am in enumerate(A2_all): + if n == m: + assert (An == Am) is True + assert (An != Am) is False + else: + assert (An == Am) is False + assert (An != Am) is True + + for n, A1 in enumerate(A1_all): + for m, A2 in enumerate(A2_all): + assert (A1 == A2) is False + assert (A1 != A2) is True + + +def test_to_XXM(): + """Test to_ddm etc. for DDM, SDM, DFM and DomainMatrix.""" + + lol = [[ZZ(1), ZZ(2)], [ZZ(3), ZZ(4)]] + dod = {0: {0: ZZ(1), 1: ZZ(2)}, 1: {0: ZZ(3), 1: ZZ(4)}} + + A_ddm = DDM(lol, (2, 2), ZZ) + A_sdm = SDM(dod, (2, 2), ZZ) + A_dm_d = DomainMatrix(lol, (2, 2), ZZ) + A_dm_s = DomainMatrix(dod, (2, 2), ZZ) + + A_all = [A_ddm, A_sdm, A_dm_d, A_dm_s] + + if GROUND_TYPES == 'flint': + A_dfm = DFM(lol, (2, 2), ZZ) + A_all.append(A_dfm) + + for A in A_all: + assert A.to_ddm() == A_ddm + assert A.to_sdm() == A_sdm + if GROUND_TYPES != 'flint': + raises(NotImplementedError, lambda: A.to_dfm()) + assert A.to_dfm_or_ddm() == A_ddm + + # Add e.g. DDM.to_DM()? + # assert A.to_DM() == A_dm + + if GROUND_TYPES == 'flint': + for A in A_all: + assert A.to_dfm() == A_dfm + for K in [ZZ, QQ, GF(5), ZZ_I]: + if isinstance(A, DFM) and not DFM._supports_domain(K): + raises(NotImplementedError, lambda: A.convert_to(K)) + else: + A_K = A.convert_to(K) + if DFM._supports_domain(K): + A_dfm_K = A_dfm.convert_to(K) + assert A_K.to_dfm() == A_dfm_K + assert A_K.to_dfm_or_ddm() == A_dfm_K + else: + raises(NotImplementedError, lambda: A_K.to_dfm()) + assert A_K.to_dfm_or_ddm() == A_ddm.convert_to(K) + + +def test_DFM_domains(): + """Test which domains are supported by DFM.""" + + x, y = symbols('x, y') + + if GROUND_TYPES in ('python', 'gmpy'): + + supported = [] + flint_funcs = {} + not_supported = [ZZ, QQ, GF(5), QQ[x], QQ[x,y]] + + elif GROUND_TYPES == 'flint': + + import flint + supported = [ZZ, QQ] + flint_funcs = { + ZZ: flint.fmpz_mat, + QQ: flint.fmpq_mat, + } + not_supported = [ + # This could be supported but not yet implemented in SymPy: + GF(5), + # Other domains could be supported but not implemented as matrices + # in python-flint: + QQ[x], + QQ[x,y], + QQ.frac_field(x,y), + # Others would potentially never be supported by python-flint: + ZZ_I, + ] + + else: + assert False, "Unknown GROUND_TYPES: %s" % GROUND_TYPES + + for domain in supported: + assert DFM._supports_domain(domain) is True + assert DFM._get_flint_func(domain) == flint_funcs[domain] + for domain in not_supported: + assert DFM._supports_domain(domain) is False + raises(NotImplementedError, lambda: DFM._get_flint_func(domain)) + + +def _DM(lol, typ, K): + """Make a DM of type typ over K from lol.""" + A = DM(lol, K) + + if typ == 'DDM': + return A.to_ddm() + elif typ == 'SDM': + return A.to_sdm() + elif typ == 'DFM': + if GROUND_TYPES != 'flint': + skip("DFM not supported in this ground type") + return A.to_dfm() + else: + assert False, "Unknown type %s" % typ + + +def _DMZ(lol, typ): + """Make a DM of type typ over ZZ from lol.""" + return _DM(lol, typ, ZZ) + + +def _DMQ(lol, typ): + """Make a DM of type typ over QQ from lol.""" + return _DM(lol, typ, QQ) + + +def DM_ddm(lol, K): + """Make a DDM over K from lol.""" + return _DM(lol, 'DDM', K) + + +def DM_sdm(lol, K): + """Make a SDM over K from lol.""" + return _DM(lol, 'SDM', K) + + +def DM_dfm(lol, K): + """Make a DFM over K from lol.""" + return _DM(lol, 'DFM', K) + + +def DMZ_ddm(lol): + """Make a DDM from lol.""" + return _DMZ(lol, 'DDM') + + +def DMZ_sdm(lol): + """Make a SDM from lol.""" + return _DMZ(lol, 'SDM') + + +def DMZ_dfm(lol): + """Make a DFM from lol.""" + return _DMZ(lol, 'DFM') + + +def DMQ_ddm(lol): + """Make a DDM from lol.""" + return _DMQ(lol, 'DDM') + + +def DMQ_sdm(lol): + """Make a SDM from lol.""" + return _DMQ(lol, 'SDM') + + +def DMQ_dfm(lol): + """Make a DFM from lol.""" + return _DMQ(lol, 'DFM') + + +DM_all = [DM_ddm, DM_sdm, DM_dfm] +DMZ_all = [DMZ_ddm, DMZ_sdm, DMZ_dfm] +DMQ_all = [DMQ_ddm, DMQ_sdm, DMQ_dfm] + + +@pytest.mark.parametrize('DM', DMZ_all) +def test_XDM_getitem(DM): + """Test getitem for DDM, etc.""" + + lol = [[0, 1], [2, 0]] + A = DM(lol) + m, n = A.shape + + indices = [-3, -2, -1, 0, 1, 2] + + for i in indices: + for j in indices: + if -2 <= i < m and -2 <= j < n: + assert A.getitem(i, j) == ZZ(lol[i][j]) + else: + raises(IndexError, lambda: A.getitem(i, j)) + + +@pytest.mark.parametrize('DM', DMZ_all) +def test_XDM_setitem(DM): + """Test setitem for DDM, etc.""" + + A = DM([[0, 1, 2], [3, 4, 5]]) + + A.setitem(0, 0, ZZ(6)) + assert A == DM([[6, 1, 2], [3, 4, 5]]) + + A.setitem(0, 1, ZZ(7)) + assert A == DM([[6, 7, 2], [3, 4, 5]]) + + A.setitem(0, 2, ZZ(8)) + assert A == DM([[6, 7, 8], [3, 4, 5]]) + + A.setitem(0, -1, ZZ(9)) + assert A == DM([[6, 7, 9], [3, 4, 5]]) + + A.setitem(0, -2, ZZ(10)) + assert A == DM([[6, 10, 9], [3, 4, 5]]) + + A.setitem(0, -3, ZZ(11)) + assert A == DM([[11, 10, 9], [3, 4, 5]]) + + raises(IndexError, lambda: A.setitem(0, 3, ZZ(12))) + raises(IndexError, lambda: A.setitem(0, -4, ZZ(13))) + + A.setitem(1, 0, ZZ(14)) + assert A == DM([[11, 10, 9], [14, 4, 5]]) + + A.setitem(1, 1, ZZ(15)) + assert A == DM([[11, 10, 9], [14, 15, 5]]) + + A.setitem(-1, 1, ZZ(16)) + assert A == DM([[11, 10, 9], [14, 16, 5]]) + + A.setitem(-2, 1, ZZ(17)) + assert A == DM([[11, 17, 9], [14, 16, 5]]) + + raises(IndexError, lambda: A.setitem(2, 0, ZZ(18))) + raises(IndexError, lambda: A.setitem(-3, 0, ZZ(19))) + + A.setitem(1, 2, ZZ(0)) + assert A == DM([[11, 17, 9], [14, 16, 0]]) + + A.setitem(1, -2, ZZ(0)) + assert A == DM([[11, 17, 9], [14, 0, 0]]) + + A.setitem(1, -3, ZZ(0)) + assert A == DM([[11, 17, 9], [0, 0, 0]]) + + A.setitem(0, 0, ZZ(0)) + assert A == DM([[0, 17, 9], [0, 0, 0]]) + + A.setitem(0, -1, ZZ(0)) + assert A == DM([[0, 17, 0], [0, 0, 0]]) + + A.setitem(0, 0, ZZ(0)) + assert A == DM([[0, 17, 0], [0, 0, 0]]) + + A.setitem(0, -2, ZZ(0)) + assert A == DM([[0, 0, 0], [0, 0, 0]]) + + A.setitem(0, -3, ZZ(1)) + assert A == DM([[1, 0, 0], [0, 0, 0]]) + + +class _Sliced: + def __getitem__(self, item): + return item + + +_slice = _Sliced() + + +@pytest.mark.parametrize('DM', DMZ_all) +def test_XXM_extract_slice(DM): + A = DM([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) + assert A.extract_slice(*_slice[:,:]) == A + assert A.extract_slice(*_slice[1:,:]) == DM([[4, 5, 6], [7, 8, 9]]) + assert A.extract_slice(*_slice[1:,1:]) == DM([[5, 6], [8, 9]]) + assert A.extract_slice(*_slice[1:,:-1]) == DM([[4, 5], [7, 8]]) + assert A.extract_slice(*_slice[1:,:-1:2]) == DM([[4], [7]]) + assert A.extract_slice(*_slice[:,::2]) == DM([[1, 3], [4, 6], [7, 9]]) + assert A.extract_slice(*_slice[::2,:]) == DM([[1, 2, 3], [7, 8, 9]]) + assert A.extract_slice(*_slice[::2,::2]) == DM([[1, 3], [7, 9]]) + assert A.extract_slice(*_slice[::2,::-2]) == DM([[3, 1], [9, 7]]) + assert A.extract_slice(*_slice[::-2,::2]) == DM([[7, 9], [1, 3]]) + assert A.extract_slice(*_slice[::-2,::-2]) == DM([[9, 7], [3, 1]]) + assert A.extract_slice(*_slice[:,::-1]) == DM([[3, 2, 1], [6, 5, 4], [9, 8, 7]]) + assert A.extract_slice(*_slice[::-1,:]) == DM([[7, 8, 9], [4, 5, 6], [1, 2, 3]]) + + +@pytest.mark.parametrize('DM', DMZ_all) +def test_XXM_extract(DM): + + A = DM([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) + + assert A.extract([0, 1, 2], [0, 1, 2]) == A + assert A.extract([1, 2], [1, 2]) == DM([[5, 6], [8, 9]]) + assert A.extract([1, 2], [0, 1]) == DM([[4, 5], [7, 8]]) + assert A.extract([1, 2], [0, 2]) == DM([[4, 6], [7, 9]]) + assert A.extract([1, 2], [0]) == DM([[4], [7]]) + assert A.extract([1, 2], []) == DM([[1]]).zeros((2, 0), ZZ) + assert A.extract([], [0, 1, 2]) == DM([[1]]).zeros((0, 3), ZZ) + + raises(IndexError, lambda: A.extract([1, 2], [0, 3])) + raises(IndexError, lambda: A.extract([1, 2], [0, -4])) + raises(IndexError, lambda: A.extract([3, 1], [0, 1])) + raises(IndexError, lambda: A.extract([-4, 2], [3, 1])) + + B = DM([[0, 0, 0], [0, 0, 0], [0, 0, 0]]) + assert B.extract([1, 2], [1, 2]) == DM([[0, 0], [0, 0]]) + + +def test_XXM_str(): + + A = DomainMatrix([[1, 2, 3], [4, 5, 6], [7, 8, 9]], (3, 3), ZZ) + + assert str(A) == \ + 'DomainMatrix([[1, 2, 3], [4, 5, 6], [7, 8, 9]], (3, 3), ZZ)' + assert str(A.to_ddm()) == \ + '[[1, 2, 3], [4, 5, 6], [7, 8, 9]]' + assert str(A.to_sdm()) == \ + '{0: {0: 1, 1: 2, 2: 3}, 1: {0: 4, 1: 5, 2: 6}, 2: {0: 7, 1: 8, 2: 9}}' + + assert repr(A) == \ + 'DomainMatrix([[1, 2, 3], [4, 5, 6], [7, 8, 9]], (3, 3), ZZ)' + assert repr(A.to_ddm()) == \ + 'DDM([[1, 2, 3], [4, 5, 6], [7, 8, 9]], (3, 3), ZZ)' + assert repr(A.to_sdm()) == \ + 'SDM({0: {0: 1, 1: 2, 2: 3}, 1: {0: 4, 1: 5, 2: 6}, 2: {0: 7, 1: 8, 2: 9}}, (3, 3), ZZ)' + + B = DomainMatrix({0: {0: ZZ(1), 1: ZZ(2)}, 1: {0: ZZ(3)}}, (2, 2), ZZ) + + assert str(B) == \ + 'DomainMatrix({0: {0: 1, 1: 2}, 1: {0: 3}}, (2, 2), ZZ)' + assert str(B.to_ddm()) == \ + '[[1, 2], [3, 0]]' + assert str(B.to_sdm()) == \ + '{0: {0: 1, 1: 2}, 1: {0: 3}}' + + assert repr(B) == \ + 'DomainMatrix({0: {0: 1, 1: 2}, 1: {0: 3}}, (2, 2), ZZ)' + + if GROUND_TYPES != 'gmpy': + assert repr(B.to_ddm()) == \ + 'DDM([[1, 2], [3, 0]], (2, 2), ZZ)' + assert repr(B.to_sdm()) == \ + 'SDM({0: {0: 1, 1: 2}, 1: {0: 3}}, (2, 2), ZZ)' + else: + assert repr(B.to_ddm()) == \ + 'DDM([[mpz(1), mpz(2)], [mpz(3), mpz(0)]], (2, 2), ZZ)' + assert repr(B.to_sdm()) == \ + 'SDM({0: {0: mpz(1), 1: mpz(2)}, 1: {0: mpz(3)}}, (2, 2), ZZ)' + + if GROUND_TYPES == 'flint': + + assert str(A.to_dfm()) == \ + '[[1, 2, 3], [4, 5, 6], [7, 8, 9]]' + assert str(B.to_dfm()) == \ + '[[1, 2], [3, 0]]' + + assert repr(A.to_dfm()) == \ + 'DFM([[1, 2, 3], [4, 5, 6], [7, 8, 9]], (3, 3), ZZ)' + assert repr(B.to_dfm()) == \ + 'DFM([[1, 2], [3, 0]], (2, 2), ZZ)' + + +@pytest.mark.parametrize('DM', DMZ_all) +def test_XXM_from_list(DM): + T = type(DM([[0]])) + + lol = [[1, 2, 4], [4, 5, 6]] + lol_ZZ = [[ZZ(1), ZZ(2), ZZ(4)], [ZZ(4), ZZ(5), ZZ(6)]] + lol_ZZ_bad = [[ZZ(1), ZZ(2), ZZ(4)], [ZZ(4), ZZ(5), ZZ(6), ZZ(7)]] + + assert T.from_list(lol_ZZ, (2, 3), ZZ) == DM(lol) + raises(DMBadInputError, lambda: T.from_list(lol_ZZ_bad, (3, 2), ZZ)) + + +@pytest.mark.parametrize('DM', DMZ_all) +def test_XXM_to_list(DM): + lol = [[1, 2, 4], [4, 5, 6]] + assert DM(lol).to_list() == [[ZZ(1), ZZ(2), ZZ(4)], [ZZ(4), ZZ(5), ZZ(6)]] + + +@pytest.mark.parametrize('DM', DMZ_all) +def test_XXM_to_list_flat(DM): + lol = [[1, 2, 4], [4, 5, 6]] + assert DM(lol).to_list_flat() == [ZZ(1), ZZ(2), ZZ(4), ZZ(4), ZZ(5), ZZ(6)] + + +@pytest.mark.parametrize('DM', DMZ_all) +def test_XXM_from_list_flat(DM): + T = type(DM([[0]])) + flat = [ZZ(1), ZZ(2), ZZ(4), ZZ(4), ZZ(5), ZZ(6)] + assert T.from_list_flat(flat, (2, 3), ZZ) == DM([[1, 2, 4], [4, 5, 6]]) + raises(DMBadInputError, lambda: T.from_list_flat(flat, (3, 3), ZZ)) + + +@pytest.mark.parametrize('DM', DMZ_all) +def test_XXM_to_flat_nz(DM): + M = DM([[1, 2, 0], [0, 0, 0], [0, 0, 3]]) + elements = [ZZ(1), ZZ(2), ZZ(3)] + indices = ((0, 0), (0, 1), (2, 2)) + assert M.to_flat_nz() == (elements, (indices, M.shape)) + + +@pytest.mark.parametrize('DM', DMZ_all) +def test_XXM_from_flat_nz(DM): + T = type(DM([[0]])) + elements = [ZZ(1), ZZ(2), ZZ(3)] + indices = ((0, 0), (0, 1), (2, 2)) + data = (indices, (3, 3)) + result = DM([[1, 2, 0], [0, 0, 0], [0, 0, 3]]) + assert T.from_flat_nz(elements, data, ZZ) == result + raises(DMBadInputError, lambda: T.from_flat_nz(elements, (indices, (2, 3)), ZZ)) + + +@pytest.mark.parametrize('DM', DMZ_all) +def test_XXM_to_dod(DM): + dod = {0: {0: ZZ(1), 2: ZZ(4)}, 1: {0: ZZ(4), 1: ZZ(5), 2: ZZ(6)}} + assert DM([[1, 0, 4], [4, 5, 6]]).to_dod() == dod + + +@pytest.mark.parametrize('DM', DMZ_all) +def test_XXM_from_dod(DM): + T = type(DM([[0]])) + dod = {0: {0: ZZ(1), 2: ZZ(4)}, 1: {0: ZZ(4), 1: ZZ(5), 2: ZZ(6)}} + assert T.from_dod(dod, (2, 3), ZZ) == DM([[1, 0, 4], [4, 5, 6]]) + + +@pytest.mark.parametrize('DM', DMZ_all) +def test_XXM_to_dok(DM): + dod = {(0, 0): ZZ(1), (0, 2): ZZ(4), + (1, 0): ZZ(4), (1, 1): ZZ(5), (1, 2): ZZ(6)} + assert DM([[1, 0, 4], [4, 5, 6]]).to_dok() == dod + + +@pytest.mark.parametrize('DM', DMZ_all) +def test_XXM_from_dok(DM): + T = type(DM([[0]])) + dod = {(0, 0): ZZ(1), (0, 2): ZZ(4), + (1, 0): ZZ(4), (1, 1): ZZ(5), (1, 2): ZZ(6)} + assert T.from_dok(dod, (2, 3), ZZ) == DM([[1, 0, 4], [4, 5, 6]]) + + +@pytest.mark.parametrize('DM', DMZ_all) +def test_XXM_iter_values(DM): + values = [ZZ(1), ZZ(4), ZZ(4), ZZ(5), ZZ(6)] + assert sorted(DM([[1, 0, 4], [4, 5, 6]]).iter_values()) == values + + +@pytest.mark.parametrize('DM', DMZ_all) +def test_XXM_iter_items(DM): + items = [((0, 0), ZZ(1)), ((0, 2), ZZ(4)), + ((1, 0), ZZ(4)), ((1, 1), ZZ(5)), ((1, 2), ZZ(6))] + assert sorted(DM([[1, 0, 4], [4, 5, 6]]).iter_items()) == items + + +@pytest.mark.parametrize('DM', DMZ_all) +def test_XXM_from_ddm(DM): + T = type(DM([[0]])) + ddm = DDM([[1, 2, 4], [4, 5, 6]], (2, 3), ZZ) + assert T.from_ddm(ddm) == DM([[1, 2, 4], [4, 5, 6]]) + + +@pytest.mark.parametrize('DM', DMZ_all) +def test_XXM_zeros(DM): + T = type(DM([[0]])) + assert T.zeros((2, 3), ZZ) == DM([[0, 0, 0], [0, 0, 0]]) + + +@pytest.mark.parametrize('DM', DMZ_all) +def test_XXM_ones(DM): + T = type(DM([[0]])) + assert T.ones((2, 3), ZZ) == DM([[1, 1, 1], [1, 1, 1]]) + + +@pytest.mark.parametrize('DM', DMZ_all) +def test_XXM_eye(DM): + T = type(DM([[0]])) + assert T.eye(3, ZZ) == DM([[1, 0, 0], [0, 1, 0], [0, 0, 1]]) + assert T.eye((3, 2), ZZ) == DM([[1, 0], [0, 1], [0, 0]]) + + +@pytest.mark.parametrize('DM', DMZ_all) +def test_XXM_diag(DM): + T = type(DM([[0]])) + assert T.diag([1, 2, 3], ZZ) == DM([[1, 0, 0], [0, 2, 0], [0, 0, 3]]) + + +@pytest.mark.parametrize('DM', DMZ_all) +def test_XXM_transpose(DM): + A = DM([[1, 2, 3], [4, 5, 6]]) + assert A.transpose() == DM([[1, 4], [2, 5], [3, 6]]) + + +@pytest.mark.parametrize('DM', DMZ_all) +def test_XXM_add(DM): + A = DM([[1, 2, 3], [4, 5, 6]]) + B = DM([[1, 2, 3], [4, 5, 6]]) + C = DM([[2, 4, 6], [8, 10, 12]]) + assert A.add(B) == C + + +@pytest.mark.parametrize('DM', DMZ_all) +def test_XXM_sub(DM): + A = DM([[1, 2, 3], [4, 5, 6]]) + B = DM([[1, 2, 3], [4, 5, 6]]) + C = DM([[0, 0, 0], [0, 0, 0]]) + assert A.sub(B) == C + + +@pytest.mark.parametrize('DM', DMZ_all) +def test_XXM_mul(DM): + A = DM([[1, 2, 3], [4, 5, 6]]) + b = ZZ(2) + assert A.mul(b) == DM([[2, 4, 6], [8, 10, 12]]) + assert A.rmul(b) == DM([[2, 4, 6], [8, 10, 12]]) + + +@pytest.mark.parametrize('DM', DMZ_all) +def test_XXM_matmul(DM): + A = DM([[1, 2, 3], [4, 5, 6]]) + B = DM([[1, 2], [3, 4], [5, 6]]) + C = DM([[22, 28], [49, 64]]) + assert A.matmul(B) == C + + +@pytest.mark.parametrize('DM', DMZ_all) +def test_XXM_mul_elementwise(DM): + A = DM([[1, 2, 3], [4, 5, 6]]) + B = DM([[1, 2, 3], [4, 5, 6]]) + C = DM([[1, 4, 9], [16, 25, 36]]) + assert A.mul_elementwise(B) == C + + +@pytest.mark.parametrize('DM', DMZ_all) +def test_XXM_neg(DM): + A = DM([[1, 2, 3], [4, 5, 6]]) + C = DM([[-1, -2, -3], [-4, -5, -6]]) + assert A.neg() == C + + +@pytest.mark.parametrize('DM', DM_all) +def test_XXM_convert_to(DM): + A = DM([[1, 2, 3], [4, 5, 6]], ZZ) + B = DM([[1, 2, 3], [4, 5, 6]], QQ) + assert A.convert_to(QQ) == B + assert B.convert_to(ZZ) == A + + +@pytest.mark.parametrize('DM', DMZ_all) +def test_XXM_scc(DM): + A = DM([ + [0, 1, 0, 0, 0, 0], + [1, 0, 0, 0, 0, 0], + [0, 0, 1, 0, 0, 0], + [0, 0, 0, 1, 0, 1], + [0, 0, 0, 0, 1, 0], + [0, 0, 0, 1, 0, 1]]) + assert A.scc() == [[0, 1], [2], [3, 5], [4]] + + +@pytest.mark.parametrize('DM', DMZ_all) +def test_XXM_hstack(DM): + A = DM([[1, 2, 3], [4, 5, 6]]) + B = DM([[7, 8], [9, 10]]) + C = DM([[1, 2, 3, 7, 8], [4, 5, 6, 9, 10]]) + ABC = DM([[1, 2, 3, 7, 8, 1, 2, 3, 7, 8], + [4, 5, 6, 9, 10, 4, 5, 6, 9, 10]]) + assert A.hstack(B) == C + assert A.hstack(B, C) == ABC + + +@pytest.mark.parametrize('DM', DMZ_all) +def test_XXM_vstack(DM): + A = DM([[1, 2, 3], [4, 5, 6]]) + B = DM([[7, 8, 9]]) + C = DM([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) + ABC = DM([[1, 2, 3], [4, 5, 6], [7, 8, 9], [1, 2, 3], [4, 5, 6], [7, 8, 9]]) + assert A.vstack(B) == C + assert A.vstack(B, C) == ABC + + +@pytest.mark.parametrize('DM', DMZ_all) +def test_XXM_applyfunc(DM): + A = DM([[1, 2, 3], [4, 5, 6]]) + B = DM([[2, 4, 6], [8, 10, 12]]) + assert A.applyfunc(lambda x: 2*x, ZZ) == B + + +@pytest.mark.parametrize('DM', DMZ_all) +def test_XXM_is_upper(DM): + assert DM([[1, 2, 3], [0, 5, 6]]).is_upper() is True + assert DM([[1, 2, 3], [4, 5, 6]]).is_upper() is False + + +@pytest.mark.parametrize('DM', DMZ_all) +def test_XXM_is_lower(DM): + assert DM([[1, 0, 0], [4, 5, 0]]).is_lower() is True + assert DM([[1, 2, 3], [4, 5, 6]]).is_lower() is False + + +@pytest.mark.parametrize('DM', DMZ_all) +def test_XXM_is_diagonal(DM): + assert DM([[1, 0, 0], [0, 5, 0]]).is_diagonal() is True + assert DM([[1, 2, 3], [4, 5, 6]]).is_diagonal() is False + + +@pytest.mark.parametrize('DM', DMZ_all) +def test_XXM_diagonal(DM): + assert DM([[1, 0, 0], [0, 5, 0]]).diagonal() == [1, 5] + + +@pytest.mark.parametrize('DM', DMZ_all) +def test_XXM_is_zero_matrix(DM): + assert DM([[0, 0, 0], [0, 0, 0]]).is_zero_matrix() is True + assert DM([[1, 0, 0], [0, 0, 0]]).is_zero_matrix() is False + + +@pytest.mark.parametrize('DM', DMZ_all) +def test_XXM_det_ZZ(DM): + assert DM([[1, 2, 3], [4, 5, 6], [7, 8, 9]]).det() == 0 + assert DM([[1, 2, 3], [4, 5, 6], [7, 8, 10]]).det() == -3 + + +@pytest.mark.parametrize('DM', DMQ_all) +def test_XXM_det_QQ(DM): + dM1 = DM([[(1,2), (2,3)], [(3,4), (4,5)]]) + assert dM1.det() == QQ(-1,10) + + +@pytest.mark.parametrize('DM', DMQ_all) +def test_XXM_inv_QQ(DM): + dM1 = DM([[(1,2), (2,3)], [(3,4), (4,5)]]) + dM2 = DM([[(-8,1), (20,3)], [(15,2), (-5,1)]]) + assert dM1.inv() == dM2 + assert dM1.matmul(dM2) == DM([[1, 0], [0, 1]]) + + dM3 = DM([[(1,2), (2,3)], [(1,4), (1,3)]]) + raises(DMNonInvertibleMatrixError, lambda: dM3.inv()) + + dM4 = DM([[(1,2), (2,3), (3,4)], [(1,4), (1,3), (1,2)]]) + raises(DMNonSquareMatrixError, lambda: dM4.inv()) + + +@pytest.mark.parametrize('DM', DMZ_all) +def test_XXM_inv_ZZ(DM): + dM1 = DM([[1, 2, 3], [4, 5, 6], [7, 8, 10]]) + # XXX: Maybe this should return a DM over QQ instead? + # XXX: Handle unimodular matrices? + raises(DMDomainError, lambda: dM1.inv()) + + +@pytest.mark.parametrize('DM', DMZ_all) +def test_XXM_charpoly_ZZ(DM): + dM1 = DM([[1, 2, 3], [4, 5, 6], [7, 8, 10]]) + assert dM1.charpoly() == [1, -16, -12, 3] + + +@pytest.mark.parametrize('DM', DMQ_all) +def test_XXM_charpoly_QQ(DM): + dM1 = DM([[(1,2), (2,3)], [(3,4), (4,5)]]) + assert dM1.charpoly() == [QQ(1,1), QQ(-13,10), QQ(-1,10)] + + +@pytest.mark.parametrize('DM', DMZ_all) +def test_XXM_lu_solve_ZZ(DM): + dM1 = DM([[1, 2, 3], [4, 5, 6], [7, 8, 10]]) + dM2 = DM([[1, 0, 0], [0, 1, 0], [0, 0, 1]]) + raises(DMDomainError, lambda: dM1.lu_solve(dM2)) + + +@pytest.mark.parametrize('DM', DMQ_all) +def test_XXM_lu_solve_QQ(DM): + dM1 = DM([[1, 2, 3], [4, 5, 6], [7, 8, 10]]) + dM2 = DM([[1, 0, 0], [0, 1, 0], [0, 0, 1]]) + dM3 = DM([[(-2,3),(-4,3),(1,1)],[(-2,3),(11,3),(-2,1)],[(1,1),(-2,1),(1,1)]]) + assert dM1.lu_solve(dM2) == dM3 == dM1.inv() + + dM4 = DM([[1, 2, 3], [4, 5, 6]]) + dM5 = DM([[1, 0], [0, 1], [0, 0]]) + raises(DMShapeError, lambda: dM4.lu_solve(dM5)) + + +@pytest.mark.parametrize('DM', DMQ_all) +def test_XXM_nullspace_QQ(DM): + dM1 = DM([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) + # XXX: Change the signature to just return the nullspace. Possibly + # returning the rank or nullity makes sense but the list of nonpivots is + # not useful. + assert dM1.nullspace() == (DM([[1, -2, 1]]), [2]) + + +@pytest.mark.parametrize('DM', DMZ_all) +def test_XXM_lll(DM): + M = DM([[1, 2, 3], [4, 5, 20]]) + M_lll = DM([[1, 2, 3], [-1, -5, 5]]) + T = DM([[1, 0], [-5, 1]]) + assert M.lll() == M_lll + assert M.lll_transform() == (M_lll, T) + assert T.matmul(M) == M_lll diff --git a/lib/python3.10/site-packages/sympy/polys/numberfields/__init__.py b/lib/python3.10/site-packages/sympy/polys/numberfields/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..38403fdf80be22d47589a346d1b1878b982c3c93 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/polys/numberfields/__init__.py @@ -0,0 +1,27 @@ +"""Computational algebraic field theory. """ + +__all__ = [ + 'minpoly', 'minimal_polynomial', + + 'field_isomorphism', 'primitive_element', 'to_number_field', + + 'isolate', + + 'round_two', + + 'prime_decomp', 'prime_valuation', + + 'galois_group', +] + +from .minpoly import minpoly, minimal_polynomial + +from .subfield import field_isomorphism, primitive_element, to_number_field + +from .utilities import isolate + +from .basis import round_two + +from .primes import prime_decomp, prime_valuation + +from .galoisgroups import galois_group diff --git a/lib/python3.10/site-packages/sympy/polys/numberfields/basis.py b/lib/python3.10/site-packages/sympy/polys/numberfields/basis.py new file mode 100644 index 0000000000000000000000000000000000000000..7c9cb41925973b3a10a80cc6ba1442cf44330971 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/polys/numberfields/basis.py @@ -0,0 +1,246 @@ +"""Computing integral bases for number fields. """ + +from sympy.polys.polytools import Poly +from sympy.polys.domains.algebraicfield import AlgebraicField +from sympy.polys.domains.integerring import ZZ +from sympy.polys.domains.rationalfield import QQ +from sympy.utilities.decorator import public +from .modules import ModuleEndomorphism, ModuleHomomorphism, PowerBasis +from .utilities import extract_fundamental_discriminant + + +def _apply_Dedekind_criterion(T, p): + r""" + Apply the "Dedekind criterion" to test whether the order needs to be + enlarged relative to a given prime *p*. + """ + x = T.gen + T_bar = Poly(T, modulus=p) + lc, fl = T_bar.factor_list() + assert lc == 1 + g_bar = Poly(1, x, modulus=p) + for ti_bar, _ in fl: + g_bar *= ti_bar + h_bar = T_bar // g_bar + g = Poly(g_bar, domain=ZZ) + h = Poly(h_bar, domain=ZZ) + f = (g * h - T) // p + f_bar = Poly(f, modulus=p) + Z_bar = f_bar + for b in [g_bar, h_bar]: + Z_bar = Z_bar.gcd(b) + U_bar = T_bar // Z_bar + m = Z_bar.degree() + return U_bar, m + + +def nilradical_mod_p(H, p, q=None): + r""" + Compute the nilradical mod *p* for a given order *H*, and prime *p*. + + Explanation + =========== + + This is the ideal $I$ in $H/pH$ consisting of all elements some positive + power of which is zero in this quotient ring, i.e. is a multiple of *p*. + + Parameters + ========== + + H : :py:class:`~.Submodule` + The given order. + p : int + The rational prime. + q : int, optional + If known, the smallest power of *p* that is $>=$ the dimension of *H*. + If not provided, we compute it here. + + Returns + ======= + + :py:class:`~.Module` representing the nilradical mod *p* in *H*. + + References + ========== + + .. [1] Cohen, H. *A Course in Computational Algebraic Number Theory*. + (See Lemma 6.1.6.) + + """ + n = H.n + if q is None: + q = p + while q < n: + q *= p + phi = ModuleEndomorphism(H, lambda x: x**q) + return phi.kernel(modulus=p) + + +def _second_enlargement(H, p, q): + r""" + Perform the second enlargement in the Round Two algorithm. + """ + Ip = nilradical_mod_p(H, p, q=q) + B = H.parent.submodule_from_matrix(H.matrix * Ip.matrix, denom=H.denom) + C = B + p*H + E = C.endomorphism_ring() + phi = ModuleHomomorphism(H, E, lambda x: E.inner_endomorphism(x)) + gamma = phi.kernel(modulus=p) + G = H.parent.submodule_from_matrix(H.matrix * gamma.matrix, denom=H.denom * p) + H1 = G + H + return H1, Ip + + +@public +def round_two(T, radicals=None): + r""" + Zassenhaus's "Round 2" algorithm. + + Explanation + =========== + + Carry out Zassenhaus's "Round 2" algorithm on an irreducible polynomial + *T* over :ref:`ZZ` or :ref:`QQ`. This computes an integral basis and the + discriminant for the field $K = \mathbb{Q}[x]/(T(x))$. + + Alternatively, you may pass an :py:class:`~.AlgebraicField` instance, in + place of the polynomial *T*, in which case the algorithm is applied to the + minimal polynomial for the field's primitive element. + + Ordinarily this function need not be called directly, as one can instead + access the :py:meth:`~.AlgebraicField.maximal_order`, + :py:meth:`~.AlgebraicField.integral_basis`, and + :py:meth:`~.AlgebraicField.discriminant` methods of an + :py:class:`~.AlgebraicField`. + + Examples + ======== + + Working through an AlgebraicField: + + >>> from sympy import Poly, QQ + >>> from sympy.abc import x + >>> T = Poly(x ** 3 + x ** 2 - 2 * x + 8) + >>> K = QQ.alg_field_from_poly(T, "theta") + >>> print(K.maximal_order()) + Submodule[[2, 0, 0], [0, 2, 0], [0, 1, 1]]/2 + >>> print(K.discriminant()) + -503 + >>> print(K.integral_basis(fmt='sympy')) + [1, theta, theta/2 + theta**2/2] + + Calling directly: + + >>> from sympy import Poly + >>> from sympy.abc import x + >>> from sympy.polys.numberfields.basis import round_two + >>> T = Poly(x ** 3 + x ** 2 - 2 * x + 8) + >>> print(round_two(T)) + (Submodule[[2, 0, 0], [0, 2, 0], [0, 1, 1]]/2, -503) + + The nilradicals mod $p$ that are sometimes computed during the Round Two + algorithm may be useful in further calculations. Pass a dictionary under + `radicals` to receive these: + + >>> T = Poly(x**3 + 3*x**2 + 5) + >>> rad = {} + >>> ZK, dK = round_two(T, radicals=rad) + >>> print(rad) + {3: Submodule[[-1, 1, 0], [-1, 0, 1]]} + + Parameters + ========== + + T : :py:class:`~.Poly`, :py:class:`~.AlgebraicField` + Either (1) the irreducible polynomial over :ref:`ZZ` or :ref:`QQ` + defining the number field, or (2) an :py:class:`~.AlgebraicField` + representing the number field itself. + + radicals : dict, optional + This is a way for any $p$-radicals (if computed) to be returned by + reference. If desired, pass an empty dictionary. If the algorithm + reaches the point where it computes the nilradical mod $p$ of the ring + of integers $Z_K$, then an $\mathbb{F}_p$-basis for this ideal will be + stored in this dictionary under the key ``p``. This can be useful for + other algorithms, such as prime decomposition. + + Returns + ======= + + Pair ``(ZK, dK)``, where: + + ``ZK`` is a :py:class:`~sympy.polys.numberfields.modules.Submodule` + representing the maximal order. + + ``dK`` is the discriminant of the field $K = \mathbb{Q}[x]/(T(x))$. + + See Also + ======== + + .AlgebraicField.maximal_order + .AlgebraicField.integral_basis + .AlgebraicField.discriminant + + References + ========== + + .. [1] Cohen, H. *A Course in Computational Algebraic Number Theory.* + + """ + K = None + if isinstance(T, AlgebraicField): + K, T = T, T.ext.minpoly_of_element() + if ( not T.is_univariate + or not T.is_irreducible + or T.domain not in [ZZ, QQ]): + raise ValueError('Round 2 requires an irreducible univariate polynomial over ZZ or QQ.') + T, _ = T.make_monic_over_integers_by_scaling_roots() + n = T.degree() + D = T.discriminant() + D_modulus = ZZ.from_sympy(abs(D)) + # D must be 0 or 1 mod 4 (see Cohen Sec 4.4), which ensures we can write + # it in the form D = D_0 * F**2, where D_0 is 1 or a fundamental discriminant. + _, F = extract_fundamental_discriminant(D) + Ztheta = PowerBasis(K or T) + H = Ztheta.whole_submodule() + nilrad = None + while F: + # Next prime: + p, e = F.popitem() + U_bar, m = _apply_Dedekind_criterion(T, p) + if m == 0: + continue + # For a given prime p, the first enlargement of the order spanned by + # the current basis can be done in a simple way: + U = Ztheta.element_from_poly(Poly(U_bar, domain=ZZ)) + # TODO: + # Theory says only first m columns of the U//p*H term below are needed. + # Could be slightly more efficient to use only those. Maybe `Submodule` + # class should support a slice operator? + H = H.add(U // p * H, hnf_modulus=D_modulus) + if e <= m: + continue + # A second, and possibly more, enlargements for p will be needed. + # These enlargements require a more involved procedure. + q = p + while q < n: + q *= p + H1, nilrad = _second_enlargement(H, p, q) + while H1 != H: + H = H1 + H1, nilrad = _second_enlargement(H, p, q) + # Note: We do not store all nilradicals mod p, only the very last. This is + # because, unless computed against the entire integral basis, it might not + # be accurate. (In other words, if H was not already equal to ZK when we + # passed it to `_second_enlargement`, then we can't trust the nilradical + # so computed.) Example: if T(x) = x ** 3 + 15 * x ** 2 - 9 * x + 13, then + # F is divisible by 2, 3, and 7, and the nilradical mod 2 as computed above + # will not be accurate for the full, maximal order ZK. + if nilrad is not None and isinstance(radicals, dict): + radicals[p] = nilrad + ZK = H + # Pre-set expensive boolean properties which we already know to be true: + ZK._starts_with_unity = True + ZK._is_sq_maxrank_HNF = True + dK = (D * ZK.matrix.det() ** 2) // ZK.denom ** (2 * n) + return ZK, dK diff --git a/lib/python3.10/site-packages/sympy/polys/numberfields/exceptions.py b/lib/python3.10/site-packages/sympy/polys/numberfields/exceptions.py new file mode 100644 index 0000000000000000000000000000000000000000..6e0d1ddc23c39295626fa036cf34974f50e4f53a --- /dev/null +++ b/lib/python3.10/site-packages/sympy/polys/numberfields/exceptions.py @@ -0,0 +1,54 @@ +"""Special exception classes for numberfields. """ + + +class ClosureFailure(Exception): + r""" + Signals that a :py:class:`ModuleElement` which we tried to represent in a + certain :py:class:`Module` cannot in fact be represented there. + + Examples + ======== + + >>> from sympy.polys import Poly, cyclotomic_poly, ZZ + >>> from sympy.polys.matrices import DomainMatrix + >>> from sympy.polys.numberfields.modules import PowerBasis, to_col + >>> T = Poly(cyclotomic_poly(5)) + >>> A = PowerBasis(T) + >>> B = A.submodule_from_matrix(2 * DomainMatrix.eye(4, ZZ)) + + Because we are in a cyclotomic field, the power basis ``A`` is an integral + basis, and the submodule ``B`` is just the ideal $(2)$. Therefore ``B`` can + represent an element having all even coefficients over the power basis: + + >>> a1 = A(to_col([2, 4, 6, 8])) + >>> print(B.represent(a1)) + DomainMatrix([[1], [2], [3], [4]], (4, 1), ZZ) + + but ``B`` cannot represent an element with an odd coefficient: + + >>> a2 = A(to_col([1, 2, 2, 2])) + >>> B.represent(a2) + Traceback (most recent call last): + ... + ClosureFailure: Element in QQ-span but not ZZ-span of this basis. + + """ + pass + + +class StructureError(Exception): + r""" + Represents cases in which an algebraic structure was expected to have a + certain property, or be of a certain type, but was not. + """ + pass + + +class MissingUnityError(StructureError): + r"""Structure should contain a unity element but does not.""" + pass + + +__all__ = [ + 'ClosureFailure', 'StructureError', 'MissingUnityError', +] diff --git a/lib/python3.10/site-packages/sympy/polys/numberfields/galois_resolvents.py b/lib/python3.10/site-packages/sympy/polys/numberfields/galois_resolvents.py new file mode 100644 index 0000000000000000000000000000000000000000..5d73b56870a498f09102787da3517e7520edb3db --- /dev/null +++ b/lib/python3.10/site-packages/sympy/polys/numberfields/galois_resolvents.py @@ -0,0 +1,676 @@ +r""" +Galois resolvents + +Each of the functions in ``sympy.polys.numberfields.galoisgroups`` that +computes Galois groups for a particular degree $n$ uses resolvents. Given the +polynomial $T$ whose Galois group is to be computed, a resolvent is a +polynomial $R$ whose roots are defined as functions of the roots of $T$. + +One way to compute the coefficients of $R$ is by approximating the roots of $T$ +to sufficient precision. This module defines a :py:class:`~.Resolvent` class +that handles this job, determining the necessary precision, and computing $R$. + +In some cases, the coefficients of $R$ are symmetric in the roots of $T$, +meaning they are equal to fixed functions of the coefficients of $T$. Therefore +another approach is to compute these functions once and for all, and record +them in a lookup table. This module defines code that can compute such tables. +The tables for polynomials $T$ of degrees 4 through 6, produced by this code, +are recorded in the resolvent_lookup.py module. + +""" + +from sympy.core.evalf import ( + evalf, fastlog, _evalf_with_bounded_error, quad_to_mpmath, +) +from sympy.core.symbol import symbols, Dummy +from sympy.polys.densetools import dup_eval +from sympy.polys.domains import ZZ +from sympy.polys.orderings import lex +from sympy.polys.polyroots import preprocess_roots +from sympy.polys.polytools import Poly +from sympy.polys.rings import xring +from sympy.polys.specialpolys import symmetric_poly +from sympy.utilities.lambdify import lambdify + +from mpmath import MPContext +from mpmath.libmp.libmpf import prec_to_dps + + +class GaloisGroupException(Exception): + ... + + +class ResolventException(GaloisGroupException): + ... + + +class Resolvent: + r""" + If $G$ is a subgroup of the symmetric group $S_n$, + $F$ a multivariate polynomial in $\mathbb{Z}[X_1, \ldots, X_n]$, + $H$ the stabilizer of $F$ in $G$ (i.e. the permutations $\sigma$ such that + $F(X_{\sigma(1)}, \ldots, X_{\sigma(n)}) = F(X_1, \ldots, X_n)$), and $s$ + a set of left coset representatives of $H$ in $G$, then the resolvent + polynomial $R(Y)$ is the product over $\sigma \in s$ of + $Y - F(X_{\sigma(1)}, \ldots, X_{\sigma(n)})$. + + For example, consider the resolvent for the form + $$F = X_0 X_2 + X_1 X_3$$ + and the group $G = S_4$. In this case, the stabilizer $H$ is the dihedral + group $D4 = < (0123), (02) >$, and a set of representatives of $G/H$ is + $\{I, (01), (03)\}$. The resolvent can be constructed as follows: + + >>> from sympy.combinatorics.permutations import Permutation + >>> from sympy.core.symbol import symbols + >>> from sympy.polys.numberfields.galoisgroups import Resolvent + >>> X = symbols('X0 X1 X2 X3') + >>> F = X[0]*X[2] + X[1]*X[3] + >>> s = [Permutation([0, 1, 2, 3]), Permutation([1, 0, 2, 3]), + ... Permutation([3, 1, 2, 0])] + >>> R = Resolvent(F, X, s) + + This resolvent has three roots, which are the conjugates of ``F`` under the + three permutations in ``s``: + + >>> R.root_lambdas[0](*X) + X0*X2 + X1*X3 + >>> R.root_lambdas[1](*X) + X0*X3 + X1*X2 + >>> R.root_lambdas[2](*X) + X0*X1 + X2*X3 + + Resolvents are useful for computing Galois groups. Given a polynomial $T$ + of degree $n$, we will use a resolvent $R$ where $Gal(T) \leq G \leq S_n$. + We will then want to substitute the roots of $T$ for the variables $X_i$ + in $R$, and study things like the discriminant of $R$, and the way $R$ + factors over $\mathbb{Q}$. + + From the symmetry in $R$'s construction, and since $Gal(T) \leq G$, we know + from Galois theory that the coefficients of $R$ must lie in $\mathbb{Z}$. + This allows us to compute the coefficients of $R$ by approximating the + roots of $T$ to sufficient precision, plugging these values in for the + variables $X_i$ in the coefficient expressions of $R$, and then simply + rounding to the nearest integer. + + In order to determine a sufficient precision for the roots of $T$, this + ``Resolvent`` class imposes certain requirements on the form ``F``. It + could be possible to design a different ``Resolvent`` class, that made + different precision estimates, and different assumptions about ``F``. + + ``F`` must be homogeneous, and all terms must have unit coefficient. + Furthermore, if $r$ is the number of terms in ``F``, and $t$ the total + degree, and if $m$ is the number of conjugates of ``F``, i.e. the number + of permutations in ``s``, then we require that $m < r 2^t$. Again, it is + not impossible to work with forms ``F`` that violate these assumptions, but + this ``Resolvent`` class requires them. + + Since determining the integer coefficients of the resolvent for a given + polynomial $T$ is one of the main problems this class solves, we take some + time to explain the precision bounds it uses. + + The general problem is: + Given a multivariate polynomial $P \in \mathbb{Z}[X_1, \ldots, X_n]$, and a + bound $M \in \mathbb{R}_+$, compute an $\varepsilon > 0$ such that for any + complex numbers $a_1, \ldots, a_n$ with $|a_i| < M$, if the $a_i$ are + approximated to within an accuracy of $\varepsilon$ by $b_i$, that is, + $|a_i - b_i| < \varepsilon$ for $i = 1, \ldots, n$, then + $|P(a_1, \ldots, a_n) - P(b_1, \ldots, b_n)| < 1/2$. In other words, if it + is known that $P(a_1, \ldots, a_n) = c$ for some $c \in \mathbb{Z}$, then + $P(b_1, \ldots, b_n)$ can be rounded to the nearest integer in order to + determine $c$. + + To derive our error bound, consider the monomial $xyz$. Defining + $d_i = b_i - a_i$, our error is + $|(a_1 + d_1)(a_2 + d_2)(a_3 + d_3) - a_1 a_2 a_3|$, which is bounded + above by $|(M + \varepsilon)^3 - M^3|$. Passing to a general monomial of + total degree $t$, this expression is bounded by + $M^{t-1}\varepsilon(t + 2^t\varepsilon/M)$ provided $\varepsilon < M$, + and by $(t+1)M^{t-1}\varepsilon$ provided $\varepsilon < M/2^t$. + But since our goal is to make the error less than $1/2$, we will choose + $\varepsilon < 1/(2(t+1)M^{t-1})$, which implies the condition that + $\varepsilon < M/2^t$, as long as $M \geq 2$. + + Passing from the general monomial to the general polynomial is easy, by + scaling and summing error bounds. + + In our specific case, we are given a homogeneous polynomial $F$ of + $r$ terms and total degree $t$, all of whose coefficients are $\pm 1$. We + are given the $m$ permutations that make the conjugates of $F$, and + we want to bound the error in the coefficients of the monic polynomial + $R(Y)$ having $F$ and its conjugates as roots (i.e. the resolvent). + + For $j$ from $1$ to $m$, the coefficient of $Y^{m-j}$ in $R(Y)$ is the + $j$th elementary symmetric polynomial in the conjugates of $F$. This sums + the products of these conjugates, taken $j$ at a time, in all possible + combinations. There are $\binom{m}{j}$ such combinations, and each product + of $j$ conjugates of $F$ expands to a sum of $r^j$ terms, each of unit + coefficient, and total degree $jt$. An error bound for the $j$th coeff of + $R$ is therefore + $$\binom{m}{j} r^j (jt + 1) M^{jt - 1} \varepsilon$$ + When our goal is to evaluate all the coefficients of $R$, we will want to + use the maximum of these error bounds. It is clear that this bound is + strictly increasing for $j$ up to the ceiling of $m/2$. After that point, + the first factor $\binom{m}{j}$ begins to decrease, while the others + continue to increase. However, the binomial coefficient never falls by more + than a factor of $1/m$ at a time, so our assumptions that $M \geq 2$ and + $m < r 2^t$ are enough to tell us that the constant coefficient of $R$, + i.e. that where $j = m$, has the largest error bound. Therefore we can use + $$r^m (mt + 1) M^{mt - 1} \varepsilon$$ + as our error bound for all the coefficients. + + Note that this bound is also (more than) adequate to determine whether any + of the roots of $R$ is an integer. Each of these roots is a single + conjugate of $F$, which contains less error than the trace, i.e. the + coefficient of $Y^{m - 1}$. By rounding the roots of $R$ to the nearest + integers, we therefore get all the candidates for integer roots of $R$. By + plugging these candidates into $R$, we can check whether any of them + actually is a root. + + Note: We take the definition of resolvent from Cohen, but the error bound + is ours. + + References + ========== + + .. [1] Cohen, H. *A Course in Computational Algebraic Number Theory*. + (Def 6.3.2) + + """ + + def __init__(self, F, X, s): + r""" + Parameters + ========== + + F : :py:class:`~.Expr` + polynomial in the symbols in *X* + X : list of :py:class:`~.Symbol` + s : list of :py:class:`~.Permutation` + representing the cosets of the stabilizer of *F* in + some subgroup $G$ of $S_n$, where $n$ is the length of *X*. + """ + self.F = F + self.X = X + self.s = s + + # Number of conjugates: + self.m = len(s) + # Total degree of F (computed below): + self.t = None + # Number of terms in F (computed below): + self.r = 0 + + for monom, coeff in Poly(F).terms(): + if abs(coeff) != 1: + raise ResolventException('Resolvent class expects forms with unit coeffs') + t = sum(monom) + if t != self.t and self.t is not None: + raise ResolventException('Resolvent class expects homogeneous forms') + self.t = t + self.r += 1 + + m, t, r = self.m, self.t, self.r + if not m < r * 2**t: + raise ResolventException('Resolvent class expects m < r*2^t') + M = symbols('M') + # Precision sufficient for computing the coeffs of the resolvent: + self.coeff_prec_func = Poly(r**m*(m*t + 1)*M**(m*t - 1)) + # Precision sufficient for checking whether any of the roots of the + # resolvent are integers: + self.root_prec_func = Poly(r*(t + 1)*M**(t - 1)) + + # The conjugates of F are the roots of the resolvent. + # For evaluating these to required numerical precisions, we need + # lambdified versions. + # Note: for a given permutation sigma, the conjugate (sigma F) is + # equivalent to lambda [sigma^(-1) X]: F. + self.root_lambdas = [ + lambdify((~s[j])(X), F) + for j in range(self.m) + ] + + # For evaluating the coeffs, we'll also need lambdified versions of + # the elementary symmetric functions for degree m. + Y = symbols('Y') + R = symbols(' '.join(f'R{i}' for i in range(m))) + f = 1 + for r in R: + f *= (Y - r) + C = Poly(f, Y).coeffs() + self.esf_lambdas = [lambdify(R, c) for c in C] + + def get_prec(self, M, target='coeffs'): + r""" + For a given upper bound *M* on the magnitude of the complex numbers to + be plugged in for this resolvent's symbols, compute a sufficient + precision for evaluating those complex numbers, such that the + coefficients, or the integer roots, of the resolvent can be determined. + + Parameters + ========== + + M : real number + Upper bound on magnitude of the complex numbers to be plugged in. + + target : str, 'coeffs' or 'roots', default='coeffs' + Name the task for which a sufficient precision is desired. + This is either determining the coefficients of the resolvent + ('coeffs') or determining its possible integer roots ('roots'). + The latter may require significantly lower precision. + + Returns + ======= + + int $m$ + such that $2^{-m}$ is a sufficient upper bound on the + error in approximating the complex numbers to be plugged in. + + """ + # As explained in the docstring for this class, our precision estimates + # require that M be at least 2. + M = max(M, 2) + f = self.coeff_prec_func if target == 'coeffs' else self.root_prec_func + r, _, _, _ = evalf(2*f(M), 1, {}) + return fastlog(r) + 1 + + def approximate_roots_of_poly(self, T, target='coeffs'): + """ + Approximate the roots of a given polynomial *T* to sufficient precision + in order to evaluate this resolvent's coefficients, or determine + whether the resolvent has an integer root. + + Parameters + ========== + + T : :py:class:`~.Poly` + + target : str, 'coeffs' or 'roots', default='coeffs' + Set the approximation precision to be sufficient for the desired + task, which is either determining the coefficients of the resolvent + ('coeffs') or determining its possible integer roots ('roots'). + The latter may require significantly lower precision. + + Returns + ======= + + list of elements of :ref:`CC` + + """ + ctx = MPContext() + # Because sympy.polys.polyroots._integer_basis() is called when a CRootOf + # is formed, we proactively extract the integer basis now. This means that + # when we call T.all_roots(), every root will be a CRootOf, not a Mul + # of Integer*CRootOf. + coeff, T = preprocess_roots(T) + coeff = ctx.mpf(str(coeff)) + + scaled_roots = T.all_roots(radicals=False) + + # Since we're going to be approximating the roots of T anyway, we can + # get a good upper bound on the magnitude of the roots by starting with + # a very low precision approx. + approx0 = [coeff * quad_to_mpmath(_evalf_with_bounded_error(r, m=0)) for r in scaled_roots] + # Here we add 1 to account for the possible error in our initial approximation. + M = max(abs(b) for b in approx0) + 1 + m = self.get_prec(M, target=target) + n = fastlog(M._mpf_) + 1 + p = m + n + 1 + ctx.prec = p + d = prec_to_dps(p) + + approx1 = [r.eval_approx(d, return_mpmath=True) for r in scaled_roots] + approx1 = [coeff*ctx.mpc(r) for r in approx1] + + return approx1 + + @staticmethod + def round_mpf(a): + if isinstance(a, int): + return a + # If we use python's built-in `round()`, we lose precision. + # If we use `ZZ` directly, we may add or subtract 1. + # + # XXX: We have to convert to int before converting to ZZ because + # flint.fmpz cannot convert a mpmath mpf. + return ZZ(int(a.context.nint(a))) + + def round_roots_to_integers_for_poly(self, T): + """ + For a given polynomial *T*, round the roots of this resolvent to the + nearest integers. + + Explanation + =========== + + None of the integers returned by this method is guaranteed to be a + root of the resolvent; however, if the resolvent has any integer roots + (for the given polynomial *T*), then they must be among these. + + If the coefficients of the resolvent are also desired, then this method + should not be used. Instead, use the ``eval_for_poly`` method. This + method may be significantly faster than ``eval_for_poly``. + + Parameters + ========== + + T : :py:class:`~.Poly` + + Returns + ======= + + dict + Keys are the indices of those permutations in ``self.s`` such that + the corresponding root did round to a rational integer. + + Values are :ref:`ZZ`. + + + """ + approx_roots_of_T = self.approximate_roots_of_poly(T, target='roots') + approx_roots_of_self = [r(*approx_roots_of_T) for r in self.root_lambdas] + return { + i: self.round_mpf(r.real) + for i, r in enumerate(approx_roots_of_self) + if self.round_mpf(r.imag) == 0 + } + + def eval_for_poly(self, T, find_integer_root=False): + r""" + Compute the integer values of the coefficients of this resolvent, when + plugging in the roots of a given polynomial. + + Parameters + ========== + + T : :py:class:`~.Poly` + + find_integer_root : ``bool``, default ``False`` + If ``True``, then also determine whether the resolvent has an + integer root, and return the first one found, along with its + index, i.e. the index of the permutation ``self.s[i]`` it + corresponds to. + + Returns + ======= + + Tuple ``(R, a, i)`` + + ``R`` is this resolvent as a dense univariate polynomial over + :ref:`ZZ`, i.e. a list of :ref:`ZZ`. + + If *find_integer_root* was ``True``, then ``a`` and ``i`` are the + first integer root found, and its index, if one exists. + Otherwise ``a`` and ``i`` are both ``None``. + + """ + approx_roots_of_T = self.approximate_roots_of_poly(T, target='coeffs') + approx_roots_of_self = [r(*approx_roots_of_T) for r in self.root_lambdas] + approx_coeffs_of_self = [c(*approx_roots_of_self) for c in self.esf_lambdas] + + R = [] + for c in approx_coeffs_of_self: + if self.round_mpf(c.imag) != 0: + # If precision was enough, this should never happen. + raise ResolventException(f"Got non-integer coeff for resolvent: {c}") + R.append(self.round_mpf(c.real)) + + a0, i0 = None, None + + if find_integer_root: + for i, r in enumerate(approx_roots_of_self): + if self.round_mpf(r.imag) != 0: + continue + if not dup_eval(R, (a := self.round_mpf(r.real)), ZZ): + a0, i0 = a, i + break + + return R, a0, i0 + + +def wrap(text, width=80): + """Line wrap a polynomial expression. """ + out = '' + col = 0 + for c in text: + if c == ' ' and col > width: + c, col = '\n', 0 + else: + col += 1 + out += c + return out + + +def s_vars(n): + """Form the symbols s1, s2, ..., sn to stand for elem. symm. polys. """ + return symbols([f's{i + 1}' for i in range(n)]) + + +def sparse_symmetrize_resolvent_coeffs(F, X, s, verbose=False): + """ + Compute the coefficients of a resolvent as functions of the coefficients of + the associated polynomial. + + F must be a sparse polynomial. + """ + import time, sys + # Roots of resolvent as multivariate forms over vars X: + root_forms = [ + F.compose(list(zip(X, sigma(X)))) + for sigma in s + ] + + # Coeffs of resolvent (besides lead coeff of 1) as symmetric forms over vars X: + Y = [Dummy(f'Y{i}') for i in range(len(s))] + coeff_forms = [] + for i in range(1, len(s) + 1): + if verbose: + print('----') + print(f'Computing symmetric poly of degree {i}...') + sys.stdout.flush() + t0 = time.time() + G = symmetric_poly(i, *Y) + t1 = time.time() + if verbose: + print(f'took {t1 - t0} seconds') + print('lambdifying...') + sys.stdout.flush() + t0 = time.time() + C = lambdify(Y, (-1)**i*G) + t1 = time.time() + if verbose: + print(f'took {t1 - t0} seconds') + sys.stdout.flush() + coeff_forms.append(C) + + coeffs = [] + for i, f in enumerate(coeff_forms): + if verbose: + print('----') + print(f'Plugging root forms into elem symm poly {i+1}...') + sys.stdout.flush() + t0 = time.time() + g = f(*root_forms) + t1 = time.time() + coeffs.append(g) + if verbose: + print(f'took {t1 - t0} seconds') + sys.stdout.flush() + + # Now symmetrize these coeffs. This means recasting them as polynomials in + # the elementary symmetric polys over X. + symmetrized = [] + symmetrization_times = [] + ss = s_vars(len(X)) + for i, A in list(enumerate(coeffs)): + if verbose: + print('-----') + print(f'Coeff {i+1}...') + sys.stdout.flush() + t0 = time.time() + B, rem, _ = A.symmetrize() + t1 = time.time() + if rem != 0: + msg = f"Got nonzero remainder {rem} for resolvent (F, X, s) = ({F}, {X}, {s})" + raise ResolventException(msg) + B_str = str(B.as_expr(*ss)) + symmetrized.append(B_str) + symmetrization_times.append(t1 - t0) + if verbose: + print(wrap(B_str)) + print(f'took {t1 - t0} seconds') + sys.stdout.flush() + + return symmetrized, symmetrization_times + + +def define_resolvents(): + """Define all the resolvents for polys T of degree 4 through 6. """ + from sympy.combinatorics.galois import PGL2F5 + from sympy.combinatorics.permutations import Permutation + + R4, X4 = xring("X0,X1,X2,X3", ZZ, lex) + X = X4 + + # The one resolvent used in `_galois_group_degree_4_lookup()`: + F40 = X[0]*X[1]**2 + X[1]*X[2]**2 + X[2]*X[3]**2 + X[3]*X[0]**2 + s40 = [ + Permutation(3), + Permutation(3)(0, 1), + Permutation(3)(0, 2), + Permutation(3)(0, 3), + Permutation(3)(1, 2), + Permutation(3)(2, 3), + ] + + # First resolvent used in `_galois_group_degree_4_root_approx()`: + F41 = X[0]*X[2] + X[1]*X[3] + s41 = [ + Permutation(3), + Permutation(3)(0, 1), + Permutation(3)(0, 3) + ] + + R5, X5 = xring("X0,X1,X2,X3,X4", ZZ, lex) + X = X5 + + # First resolvent used in `_galois_group_degree_5_hybrid()`, + # and only one used in `_galois_group_degree_5_lookup_ext_factor()`: + F51 = ( X[0]**2*(X[1]*X[4] + X[2]*X[3]) + + X[1]**2*(X[2]*X[0] + X[3]*X[4]) + + X[2]**2*(X[3]*X[1] + X[4]*X[0]) + + X[3]**2*(X[4]*X[2] + X[0]*X[1]) + + X[4]**2*(X[0]*X[3] + X[1]*X[2])) + s51 = [ + Permutation(4), + Permutation(4)(0, 1), + Permutation(4)(0, 2), + Permutation(4)(0, 3), + Permutation(4)(0, 4), + Permutation(4)(1, 4) + ] + + R6, X6 = xring("X0,X1,X2,X3,X4,X5", ZZ, lex) + X = X6 + + # First resolvent used in `_galois_group_degree_6_lookup()`: + H = PGL2F5() + term0 = X[0]**2*X[5]**2*(X[1]*X[4] + X[2]*X[3]) + terms = {term0.compose(list(zip(X, s(X)))) for s in H.elements} + F61 = sum(terms) + s61 = [Permutation(5)] + [Permutation(5)(0, n) for n in range(1, 6)] + + # Second resolvent used in `_galois_group_degree_6_lookup()`: + F62 = X[0]*X[1]*X[2] + X[3]*X[4]*X[5] + s62 = [Permutation(5)] + [ + Permutation(5)(i, j + 3) for i in range(3) for j in range(3) + ] + + return { + (4, 0): (F40, X4, s40), + (4, 1): (F41, X4, s41), + (5, 1): (F51, X5, s51), + (6, 1): (F61, X6, s61), + (6, 2): (F62, X6, s62), + } + + +def generate_lambda_lookup(verbose=False, trial_run=False): + """ + Generate the whole lookup table of coeff lambdas, for all resolvents. + """ + jobs = define_resolvents() + lambda_lists = {} + total_time = 0 + time_for_61 = 0 + time_for_61_last = 0 + for k, (F, X, s) in jobs.items(): + symmetrized, times = sparse_symmetrize_resolvent_coeffs(F, X, s, verbose=verbose) + + total_time += sum(times) + if k == (6, 1): + time_for_61 = sum(times) + time_for_61_last = times[-1] + + sv = s_vars(len(X)) + head = f'lambda {", ".join(str(v) for v in sv)}:' + lambda_lists[k] = ',\n '.join([ + f'{head} ({wrap(f)})' + for f in symmetrized + ]) + + if trial_run: + break + + table = ( + "# This table was generated by a call to\n" + "# `sympy.polys.numberfields.galois_resolvents.generate_lambda_lookup()`.\n" + f"# The entire job took {total_time:.2f}s.\n" + f"# Of this, Case (6, 1) took {time_for_61:.2f}s.\n" + f"# The final polynomial of Case (6, 1) alone took {time_for_61_last:.2f}s.\n" + "resolvent_coeff_lambdas = {\n") + + for k, L in lambda_lists.items(): + table += f" {k}: [\n" + table += " " + L + '\n' + table += " ],\n" + table += "}\n" + return table + + +def get_resolvent_by_lookup(T, number): + """ + Use the lookup table, to return a resolvent (as dup) for a given + polynomial *T*. + + Parameters + ========== + + T : Poly + The polynomial whose resolvent is needed + + number : int + For some degrees, there are multiple resolvents. + Use this to indicate which one you want. + + Returns + ======= + + dup + + """ + from sympy.polys.numberfields.resolvent_lookup import resolvent_coeff_lambdas + degree = T.degree() + L = resolvent_coeff_lambdas[(degree, number)] + T_coeffs = T.rep.to_list()[1:] + return [ZZ(1)] + [c(*T_coeffs) for c in L] + + +# Use +# (.venv) $ python -m sympy.polys.numberfields.galois_resolvents +# to reproduce the table found in resolvent_lookup.py +if __name__ == "__main__": + import sys + verbose = '-v' in sys.argv[1:] + trial_run = '-t' in sys.argv[1:] + table = generate_lambda_lookup(verbose=verbose, trial_run=trial_run) + print(table) diff --git a/lib/python3.10/site-packages/sympy/polys/numberfields/galoisgroups.py b/lib/python3.10/site-packages/sympy/polys/numberfields/galoisgroups.py new file mode 100644 index 0000000000000000000000000000000000000000..a0e424bf7554c0cedd926902e7322b9640735a8b --- /dev/null +++ b/lib/python3.10/site-packages/sympy/polys/numberfields/galoisgroups.py @@ -0,0 +1,623 @@ +""" +Compute Galois groups of polynomials. + +We use algorithms from [1], with some modifications to use lookup tables for +resolvents. + +References +========== + +.. [1] Cohen, H. *A Course in Computational Algebraic Number Theory*. + +""" + +from collections import defaultdict +import random + +from sympy.core.symbol import Dummy, symbols +from sympy.ntheory.primetest import is_square +from sympy.polys.domains import ZZ +from sympy.polys.densebasic import dup_random +from sympy.polys.densetools import dup_eval +from sympy.polys.euclidtools import dup_discriminant +from sympy.polys.factortools import dup_factor_list, dup_irreducible_p +from sympy.polys.numberfields.galois_resolvents import ( + GaloisGroupException, get_resolvent_by_lookup, define_resolvents, + Resolvent, +) +from sympy.polys.numberfields.utilities import coeff_search +from sympy.polys.polytools import (Poly, poly_from_expr, + PolificationFailed, ComputationFailed) +from sympy.polys.sqfreetools import dup_sqf_p +from sympy.utilities import public + + +class MaxTriesException(GaloisGroupException): + ... + + +def tschirnhausen_transformation(T, max_coeff=10, max_tries=30, history=None, + fixed_order=True): + r""" + Given a univariate, monic, irreducible polynomial over the integers, find + another such polynomial defining the same number field. + + Explanation + =========== + + See Alg 6.3.4 of [1]. + + Parameters + ========== + + T : Poly + The given polynomial + max_coeff : int + When choosing a transformation as part of the process, + keep the coeffs between plus and minus this. + max_tries : int + Consider at most this many transformations. + history : set, None, optional (default=None) + Pass a set of ``Poly.rep``'s in order to prevent any of these + polynomials from being returned as the polynomial ``U`` i.e. the + transformation of the given polynomial *T*. The given poly *T* will + automatically be added to this set, before we try to find a new one. + fixed_order : bool, default True + If ``True``, work through candidate transformations A(x) in a fixed + order, from small coeffs to large, resulting in deterministic behavior. + If ``False``, the A(x) are chosen randomly, while still working our way + up from small coefficients to larger ones. + + Returns + ======= + + Pair ``(A, U)`` + + ``A`` and ``U`` are ``Poly``, ``A`` is the + transformation, and ``U`` is the transformed polynomial that defines + the same number field as *T*. The polynomial ``A`` maps the roots of + *T* to the roots of ``U``. + + Raises + ====== + + MaxTriesException + if could not find a polynomial before exceeding *max_tries*. + + """ + X = Dummy('X') + n = T.degree() + if history is None: + history = set() + history.add(T.rep) + + if fixed_order: + coeff_generators = {} + deg_coeff_sum = 3 + current_degree = 2 + + def get_coeff_generator(degree): + gen = coeff_generators.get(degree, coeff_search(degree, 1)) + coeff_generators[degree] = gen + return gen + + for i in range(max_tries): + + # We never use linear A(x), since applying a fixed linear transformation + # to all roots will only multiply the discriminant of T by a square + # integer. This will change nothing important. In particular, if disc(T) + # was zero before, it will still be zero now, and typically we apply + # the transformation in hopes of replacing T by a squarefree poly. + + if fixed_order: + # If d is degree and c max coeff, we move through the dc-space + # along lines of constant sum. First d + c = 3 with (d, c) = (2, 1). + # Then d + c = 4 with (d, c) = (3, 1), (2, 2). Then d + c = 5 with + # (d, c) = (4, 1), (3, 2), (2, 3), and so forth. For a given (d, c) + # we go though all sets of coeffs where max = c, before moving on. + gen = get_coeff_generator(current_degree) + coeffs = next(gen) + m = max(abs(c) for c in coeffs) + if current_degree + m > deg_coeff_sum: + if current_degree == 2: + deg_coeff_sum += 1 + current_degree = deg_coeff_sum - 1 + else: + current_degree -= 1 + gen = get_coeff_generator(current_degree) + coeffs = next(gen) + a = [ZZ(1)] + [ZZ(c) for c in coeffs] + + else: + # We use a progressive coeff bound, up to the max specified, since it + # is preferable to succeed with smaller coeffs. + # Give each coeff bound five tries, before incrementing. + C = min(i//5 + 1, max_coeff) + d = random.randint(2, n - 1) + a = dup_random(d, -C, C, ZZ) + + A = Poly(a, T.gen) + U = Poly(T.resultant(X - A), X) + if U.rep not in history and dup_sqf_p(U.rep.to_list(), ZZ): + return A, U + raise MaxTriesException + + +def has_square_disc(T): + """Convenience to check if a Poly or dup has square discriminant. """ + d = T.discriminant() if isinstance(T, Poly) else dup_discriminant(T, ZZ) + return is_square(d) + + +def _galois_group_degree_3(T, max_tries=30, randomize=False): + r""" + Compute the Galois group of a polynomial of degree 3. + + Explanation + =========== + + Uses Prop 6.3.5 of [1]. + + """ + from sympy.combinatorics.galois import S3TransitiveSubgroups + return ((S3TransitiveSubgroups.A3, True) if has_square_disc(T) + else (S3TransitiveSubgroups.S3, False)) + + +def _galois_group_degree_4_root_approx(T, max_tries=30, randomize=False): + r""" + Compute the Galois group of a polynomial of degree 4. + + Explanation + =========== + + Follows Alg 6.3.7 of [1], using a pure root approximation approach. + + """ + from sympy.combinatorics.permutations import Permutation + from sympy.combinatorics.galois import S4TransitiveSubgroups + + X = symbols('X0 X1 X2 X3') + # We start by considering the resolvent for the form + # F = X0*X2 + X1*X3 + # and the group G = S4. In this case, the stabilizer H is D4 = < (0123), (02) >, + # and a set of representatives of G/H is {I, (01), (03)} + F1 = X[0]*X[2] + X[1]*X[3] + s1 = [ + Permutation(3), + Permutation(3)(0, 1), + Permutation(3)(0, 3) + ] + R1 = Resolvent(F1, X, s1) + + # In the second half of the algorithm (if we reach it), we use another + # form and set of coset representatives. However, we may need to permute + # them first, so cannot form their resolvent now. + F2_pre = X[0]*X[1]**2 + X[1]*X[2]**2 + X[2]*X[3]**2 + X[3]*X[0]**2 + s2_pre = [ + Permutation(3), + Permutation(3)(0, 2) + ] + + history = set() + for i in range(max_tries): + if i > 0: + # If we're retrying, need a new polynomial T. + _, T = tschirnhausen_transformation(T, max_tries=max_tries, + history=history, + fixed_order=not randomize) + + R_dup, _, i0 = R1.eval_for_poly(T, find_integer_root=True) + # If R is not squarefree, must retry. + if not dup_sqf_p(R_dup, ZZ): + continue + + # By Prop 6.3.1 of [1], Gal(T) is contained in A4 iff disc(T) is square. + sq_disc = has_square_disc(T) + + if i0 is None: + # By Thm 6.3.3 of [1], Gal(T) is not conjugate to any subgroup of the + # stabilizer H = D4 that we chose. This means Gal(T) is either A4 or S4. + return ((S4TransitiveSubgroups.A4, True) if sq_disc + else (S4TransitiveSubgroups.S4, False)) + + # Gal(T) is conjugate to a subgroup of H = D4, so it is either V, C4 + # or D4 itself. + + if sq_disc: + # Neither C4 nor D4 is contained in A4, so Gal(T) must be V. + return (S4TransitiveSubgroups.V, True) + + # Gal(T) can only be D4 or C4. + # We will now use our second resolvent, with G being that conjugate of D4 that + # Gal(T) is contained in. To determine the right conjugate, we will need + # the permutation corresponding to the integer root we found. + sigma = s1[i0] + # Applying sigma means permuting the args of F, and + # conjugating the set of coset representatives. + F2 = F2_pre.subs(zip(X, sigma(X)), simultaneous=True) + s2 = [sigma*tau*sigma for tau in s2_pre] + R2 = Resolvent(F2, X, s2) + R_dup, _, _ = R2.eval_for_poly(T) + d = dup_discriminant(R_dup, ZZ) + # If d is zero (R has a repeated root), must retry. + if d == 0: + continue + if is_square(d): + return (S4TransitiveSubgroups.C4, False) + else: + return (S4TransitiveSubgroups.D4, False) + + raise MaxTriesException + + +def _galois_group_degree_4_lookup(T, max_tries=30, randomize=False): + r""" + Compute the Galois group of a polynomial of degree 4. + + Explanation + =========== + + Based on Alg 6.3.6 of [1], but uses resolvent coeff lookup. + + """ + from sympy.combinatorics.galois import S4TransitiveSubgroups + + history = set() + for i in range(max_tries): + R_dup = get_resolvent_by_lookup(T, 0) + if dup_sqf_p(R_dup, ZZ): + break + _, T = tschirnhausen_transformation(T, max_tries=max_tries, + history=history, + fixed_order=not randomize) + else: + raise MaxTriesException + + # Compute list L of degrees of irreducible factors of R, in increasing order: + fl = dup_factor_list(R_dup, ZZ) + L = sorted(sum([ + [len(r) - 1] * e for r, e in fl[1] + ], [])) + + if L == [6]: + return ((S4TransitiveSubgroups.A4, True) if has_square_disc(T) + else (S4TransitiveSubgroups.S4, False)) + + if L == [1, 1, 4]: + return (S4TransitiveSubgroups.C4, False) + + if L == [2, 2, 2]: + return (S4TransitiveSubgroups.V, True) + + assert L == [2, 4] + return (S4TransitiveSubgroups.D4, False) + + +def _galois_group_degree_5_hybrid(T, max_tries=30, randomize=False): + r""" + Compute the Galois group of a polynomial of degree 5. + + Explanation + =========== + + Based on Alg 6.3.9 of [1], but uses a hybrid approach, combining resolvent + coeff lookup, with root approximation. + + """ + from sympy.combinatorics.galois import S5TransitiveSubgroups + from sympy.combinatorics.permutations import Permutation + + X5 = symbols("X0,X1,X2,X3,X4") + res = define_resolvents() + F51, _, s51 = res[(5, 1)] + F51 = F51.as_expr(*X5) + R51 = Resolvent(F51, X5, s51) + + history = set() + reached_second_stage = False + for i in range(max_tries): + if i > 0: + _, T = tschirnhausen_transformation(T, max_tries=max_tries, + history=history, + fixed_order=not randomize) + R51_dup = get_resolvent_by_lookup(T, 1) + if not dup_sqf_p(R51_dup, ZZ): + continue + + # First stage + # If we have not yet reached the second stage, then the group still + # might be S5, A5, or M20, so must test for that. + if not reached_second_stage: + sq_disc = has_square_disc(T) + + if dup_irreducible_p(R51_dup, ZZ): + return ((S5TransitiveSubgroups.A5, True) if sq_disc + else (S5TransitiveSubgroups.S5, False)) + + if not sq_disc: + return (S5TransitiveSubgroups.M20, False) + + # Second stage + reached_second_stage = True + # R51 must have an integer root for T. + # To choose our second resolvent, we need to know which conjugate of + # F51 is a root. + rounded_roots = R51.round_roots_to_integers_for_poly(T) + # These are integers, and candidates to be roots of R51. + # We find the first one that actually is a root. + for permutation_index, candidate_root in rounded_roots.items(): + if not dup_eval(R51_dup, candidate_root, ZZ): + break + + X = X5 + F2_pre = X[0]*X[1]**2 + X[1]*X[2]**2 + X[2]*X[3]**2 + X[3]*X[4]**2 + X[4]*X[0]**2 + s2_pre = [ + Permutation(4), + Permutation(4)(0, 1)(2, 4) + ] + + i0 = permutation_index + sigma = s51[i0] + F2 = F2_pre.subs(zip(X, sigma(X)), simultaneous=True) + s2 = [sigma*tau*sigma for tau in s2_pre] + R2 = Resolvent(F2, X, s2) + R_dup, _, _ = R2.eval_for_poly(T) + d = dup_discriminant(R_dup, ZZ) + + if d == 0: + continue + if is_square(d): + return (S5TransitiveSubgroups.C5, True) + else: + return (S5TransitiveSubgroups.D5, True) + + raise MaxTriesException + + +def _galois_group_degree_5_lookup_ext_factor(T, max_tries=30, randomize=False): + r""" + Compute the Galois group of a polynomial of degree 5. + + Explanation + =========== + + Based on Alg 6.3.9 of [1], but uses resolvent coeff lookup, plus + factorization over an algebraic extension. + + """ + from sympy.combinatorics.galois import S5TransitiveSubgroups + + _T = T + + history = set() + for i in range(max_tries): + R_dup = get_resolvent_by_lookup(T, 1) + if dup_sqf_p(R_dup, ZZ): + break + _, T = tschirnhausen_transformation(T, max_tries=max_tries, + history=history, + fixed_order=not randomize) + else: + raise MaxTriesException + + sq_disc = has_square_disc(T) + + if dup_irreducible_p(R_dup, ZZ): + return ((S5TransitiveSubgroups.A5, True) if sq_disc + else (S5TransitiveSubgroups.S5, False)) + + if not sq_disc: + return (S5TransitiveSubgroups.M20, False) + + # If we get this far, Gal(T) can only be D5 or C5. + # But for Gal(T) to have order 5, T must already split completely in + # the extension field obtained by adjoining a single one of its roots. + fl = Poly(_T, domain=ZZ.alg_field_from_poly(_T)).factor_list()[1] + if len(fl) == 5: + return (S5TransitiveSubgroups.C5, True) + else: + return (S5TransitiveSubgroups.D5, True) + + +def _galois_group_degree_6_lookup(T, max_tries=30, randomize=False): + r""" + Compute the Galois group of a polynomial of degree 6. + + Explanation + =========== + + Based on Alg 6.3.10 of [1], but uses resolvent coeff lookup. + + """ + from sympy.combinatorics.galois import S6TransitiveSubgroups + + # First resolvent: + + history = set() + for i in range(max_tries): + R_dup = get_resolvent_by_lookup(T, 1) + if dup_sqf_p(R_dup, ZZ): + break + _, T = tschirnhausen_transformation(T, max_tries=max_tries, + history=history, + fixed_order=not randomize) + else: + raise MaxTriesException + + fl = dup_factor_list(R_dup, ZZ) + + # Group the factors by degree. + factors_by_deg = defaultdict(list) + for r, _ in fl[1]: + factors_by_deg[len(r) - 1].append(r) + + L = sorted(sum([ + [d] * len(ff) for d, ff in factors_by_deg.items() + ], [])) + + T_has_sq_disc = has_square_disc(T) + + if L == [1, 2, 3]: + f1 = factors_by_deg[3][0] + return ((S6TransitiveSubgroups.C6, False) if has_square_disc(f1) + else (S6TransitiveSubgroups.D6, False)) + + elif L == [3, 3]: + f1, f2 = factors_by_deg[3] + any_square = has_square_disc(f1) or has_square_disc(f2) + return ((S6TransitiveSubgroups.G18, False) if any_square + else (S6TransitiveSubgroups.G36m, False)) + + elif L == [2, 4]: + if T_has_sq_disc: + return (S6TransitiveSubgroups.S4p, True) + else: + f1 = factors_by_deg[4][0] + return ((S6TransitiveSubgroups.A4xC2, False) if has_square_disc(f1) + else (S6TransitiveSubgroups.S4xC2, False)) + + elif L == [1, 1, 4]: + return ((S6TransitiveSubgroups.A4, True) if T_has_sq_disc + else (S6TransitiveSubgroups.S4m, False)) + + elif L == [1, 5]: + return ((S6TransitiveSubgroups.PSL2F5, True) if T_has_sq_disc + else (S6TransitiveSubgroups.PGL2F5, False)) + + elif L == [1, 1, 1, 3]: + return (S6TransitiveSubgroups.S3, False) + + assert L == [6] + + # Second resolvent: + + history = set() + for i in range(max_tries): + R_dup = get_resolvent_by_lookup(T, 2) + if dup_sqf_p(R_dup, ZZ): + break + _, T = tschirnhausen_transformation(T, max_tries=max_tries, + history=history, + fixed_order=not randomize) + else: + raise MaxTriesException + + T_has_sq_disc = has_square_disc(T) + + if dup_irreducible_p(R_dup, ZZ): + return ((S6TransitiveSubgroups.A6, True) if T_has_sq_disc + else (S6TransitiveSubgroups.S6, False)) + else: + return ((S6TransitiveSubgroups.G36p, True) if T_has_sq_disc + else (S6TransitiveSubgroups.G72, False)) + + +@public +def galois_group(f, *gens, by_name=False, max_tries=30, randomize=False, **args): + r""" + Compute the Galois group for polynomials *f* up to degree 6. + + Examples + ======== + + >>> from sympy import galois_group + >>> from sympy.abc import x + >>> f = x**4 + 1 + >>> G, alt = galois_group(f) + >>> print(G) + PermutationGroup([ + (0 1)(2 3), + (0 2)(1 3)]) + + The group is returned along with a boolean, indicating whether it is + contained in the alternating group $A_n$, where $n$ is the degree of *T*. + Along with other group properties, this can help determine which group it + is: + + >>> alt + True + >>> G.order() + 4 + + Alternatively, the group can be returned by name: + + >>> G_name, _ = galois_group(f, by_name=True) + >>> print(G_name) + S4TransitiveSubgroups.V + + The group itself can then be obtained by calling the name's + ``get_perm_group()`` method: + + >>> G_name.get_perm_group() + PermutationGroup([ + (0 1)(2 3), + (0 2)(1 3)]) + + Group names are values of the enum classes + :py:class:`sympy.combinatorics.galois.S1TransitiveSubgroups`, + :py:class:`sympy.combinatorics.galois.S2TransitiveSubgroups`, + etc. + + Parameters + ========== + + f : Expr + Irreducible polynomial over :ref:`ZZ` or :ref:`QQ`, whose Galois group + is to be determined. + gens : optional list of symbols + For converting *f* to Poly, and will be passed on to the + :py:func:`~.poly_from_expr` function. + by_name : bool, default False + If ``True``, the Galois group will be returned by name. + Otherwise it will be returned as a :py:class:`~.PermutationGroup`. + max_tries : int, default 30 + Make at most this many attempts in those steps that involve + generating Tschirnhausen transformations. + randomize : bool, default False + If ``True``, then use random coefficients when generating Tschirnhausen + transformations. Otherwise try transformations in a fixed order. Both + approaches start with small coefficients and degrees and work upward. + args : optional + For converting *f* to Poly, and will be passed on to the + :py:func:`~.poly_from_expr` function. + + Returns + ======= + + Pair ``(G, alt)`` + The first element ``G`` indicates the Galois group. It is an instance + of one of the :py:class:`sympy.combinatorics.galois.S1TransitiveSubgroups` + :py:class:`sympy.combinatorics.galois.S2TransitiveSubgroups`, etc. enum + classes if *by_name* was ``True``, and a :py:class:`~.PermutationGroup` + if ``False``. + + The second element is a boolean, saying whether the group is contained + in the alternating group $A_n$ ($n$ the degree of *T*). + + Raises + ====== + + ValueError + if *f* is of an unsupported degree. + + MaxTriesException + if could not complete before exceeding *max_tries* in those steps + that involve generating Tschirnhausen transformations. + + See Also + ======== + + .Poly.galois_group + + """ + gens = gens or [] + args = args or {} + + try: + F, opt = poly_from_expr(f, *gens, **args) + except PolificationFailed as exc: + raise ComputationFailed('galois_group', 1, exc) + + return F.galois_group(by_name=by_name, max_tries=max_tries, + randomize=randomize) diff --git a/lib/python3.10/site-packages/sympy/polys/numberfields/minpoly.py b/lib/python3.10/site-packages/sympy/polys/numberfields/minpoly.py new file mode 100644 index 0000000000000000000000000000000000000000..a3543339bfbaeb0ec5b3bea1ea66c4f354a0c8af --- /dev/null +++ b/lib/python3.10/site-packages/sympy/polys/numberfields/minpoly.py @@ -0,0 +1,883 @@ +"""Minimal polynomials for algebraic numbers.""" + +from functools import reduce + +from sympy.core.add import Add +from sympy.core.exprtools import Factors +from sympy.core.function import expand_mul, expand_multinomial, _mexpand +from sympy.core.mul import Mul +from sympy.core.numbers import (I, Rational, pi, _illegal) +from sympy.core.singleton import S +from sympy.core.symbol import Dummy +from sympy.core.sympify import sympify +from sympy.core.traversal import preorder_traversal +from sympy.functions.elementary.exponential import exp +from sympy.functions.elementary.miscellaneous import sqrt, cbrt +from sympy.functions.elementary.trigonometric import cos, sin, tan +from sympy.ntheory.factor_ import divisors +from sympy.utilities.iterables import subsets + +from sympy.polys.domains import ZZ, QQ, FractionField +from sympy.polys.orthopolys import dup_chebyshevt +from sympy.polys.polyerrors import ( + NotAlgebraic, + GeneratorsError, +) +from sympy.polys.polytools import ( + Poly, PurePoly, invert, factor_list, groebner, resultant, + degree, poly_from_expr, parallel_poly_from_expr, lcm +) +from sympy.polys.polyutils import dict_from_expr, expr_from_dict +from sympy.polys.ring_series import rs_compose_add +from sympy.polys.rings import ring +from sympy.polys.rootoftools import CRootOf +from sympy.polys.specialpolys import cyclotomic_poly +from sympy.utilities import ( + numbered_symbols, public, sift +) + + +def _choose_factor(factors, x, v, dom=QQ, prec=200, bound=5): + """ + Return a factor having root ``v`` + It is assumed that one of the factors has root ``v``. + """ + + if isinstance(factors[0], tuple): + factors = [f[0] for f in factors] + if len(factors) == 1: + return factors[0] + + prec1 = 10 + points = {} + symbols = dom.symbols if hasattr(dom, 'symbols') else [] + while prec1 <= prec: + # when dealing with non-Rational numbers we usually evaluate + # with `subs` argument but we only need a ballpark evaluation + fe = [f.as_expr().xreplace({x:v}) for f in factors] + if v.is_number: + fe = [f.n(prec) for f in fe] + + # assign integers [0, n) to symbols (if any) + for n in subsets(range(bound), k=len(symbols), repetition=True): + for s, i in zip(symbols, n): + points[s] = i + + # evaluate the expression at these points + candidates = [(abs(f.subs(points).n(prec1)), i) + for i,f in enumerate(fe)] + + # if we get invalid numbers (e.g. from division by zero) + # we try again + if any(i in _illegal for i, _ in candidates): + continue + + # find the smallest two -- if they differ significantly + # then we assume we have found the factor that becomes + # 0 when v is substituted into it + can = sorted(candidates) + (a, ix), (b, _) = can[:2] + if b > a * 10**6: # XXX what to use? + return factors[ix] + + prec1 *= 2 + + raise NotImplementedError("multiple candidates for the minimal polynomial of %s" % v) + + +def _is_sum_surds(p): + args = p.args if p.is_Add else [p] + for y in args: + if not ((y**2).is_Rational and y.is_extended_real): + return False + return True + + +def _separate_sq(p): + """ + helper function for ``_minimal_polynomial_sq`` + + It selects a rational ``g`` such that the polynomial ``p`` + consists of a sum of terms whose surds squared have gcd equal to ``g`` + and a sum of terms with surds squared prime with ``g``; + then it takes the field norm to eliminate ``sqrt(g)`` + + See simplify.simplify.split_surds and polytools.sqf_norm. + + Examples + ======== + + >>> from sympy import sqrt + >>> from sympy.abc import x + >>> from sympy.polys.numberfields.minpoly import _separate_sq + >>> p= -x + sqrt(2) + sqrt(3) + sqrt(7) + >>> p = _separate_sq(p); p + -x**2 + 2*sqrt(3)*x + 2*sqrt(7)*x - 2*sqrt(21) - 8 + >>> p = _separate_sq(p); p + -x**4 + 4*sqrt(7)*x**3 - 32*x**2 + 8*sqrt(7)*x + 20 + >>> p = _separate_sq(p); p + -x**8 + 48*x**6 - 536*x**4 + 1728*x**2 - 400 + + """ + def is_sqrt(expr): + return expr.is_Pow and expr.exp is S.Half + # p = c1*sqrt(q1) + ... + cn*sqrt(qn) -> a = [(c1, q1), .., (cn, qn)] + a = [] + for y in p.args: + if not y.is_Mul: + if is_sqrt(y): + a.append((S.One, y**2)) + elif y.is_Atom: + a.append((y, S.One)) + elif y.is_Pow and y.exp.is_integer: + a.append((y, S.One)) + else: + raise NotImplementedError + else: + T, F = sift(y.args, is_sqrt, binary=True) + a.append((Mul(*F), Mul(*T)**2)) + a.sort(key=lambda z: z[1]) + if a[-1][1] is S.One: + # there are no surds + return p + surds = [z for y, z in a] + for i in range(len(surds)): + if surds[i] != 1: + break + from sympy.simplify.radsimp import _split_gcd + g, b1, b2 = _split_gcd(*surds[i:]) + a1 = [] + a2 = [] + for y, z in a: + if z in b1: + a1.append(y*z**S.Half) + else: + a2.append(y*z**S.Half) + p1 = Add(*a1) + p2 = Add(*a2) + p = _mexpand(p1**2) - _mexpand(p2**2) + return p + +def _minimal_polynomial_sq(p, n, x): + """ + Returns the minimal polynomial for the ``nth-root`` of a sum of surds + or ``None`` if it fails. + + Parameters + ========== + + p : sum of surds + n : positive integer + x : variable of the returned polynomial + + Examples + ======== + + >>> from sympy.polys.numberfields.minpoly import _minimal_polynomial_sq + >>> from sympy import sqrt + >>> from sympy.abc import x + >>> q = 1 + sqrt(2) + sqrt(3) + >>> _minimal_polynomial_sq(q, 3, x) + x**12 - 4*x**9 - 4*x**6 + 16*x**3 - 8 + + """ + p = sympify(p) + n = sympify(n) + if not n.is_Integer or not n > 0 or not _is_sum_surds(p): + return None + pn = p**Rational(1, n) + # eliminate the square roots + p -= x + while 1: + p1 = _separate_sq(p) + if p1 is p: + p = p1.subs({x:x**n}) + break + else: + p = p1 + + # _separate_sq eliminates field extensions in a minimal way, so that + # if n = 1 then `p = constant*(minimal_polynomial(p))` + # if n > 1 it contains the minimal polynomial as a factor. + if n == 1: + p1 = Poly(p) + if p.coeff(x**p1.degree(x)) < 0: + p = -p + p = p.primitive()[1] + return p + # by construction `p` has root `pn` + # the minimal polynomial is the factor vanishing in x = pn + factors = factor_list(p)[1] + + result = _choose_factor(factors, x, pn) + return result + +def _minpoly_op_algebraic_element(op, ex1, ex2, x, dom, mp1=None, mp2=None): + """ + return the minimal polynomial for ``op(ex1, ex2)`` + + Parameters + ========== + + op : operation ``Add`` or ``Mul`` + ex1, ex2 : expressions for the algebraic elements + x : indeterminate of the polynomials + dom: ground domain + mp1, mp2 : minimal polynomials for ``ex1`` and ``ex2`` or None + + Examples + ======== + + >>> from sympy import sqrt, Add, Mul, QQ + >>> from sympy.polys.numberfields.minpoly import _minpoly_op_algebraic_element + >>> from sympy.abc import x, y + >>> p1 = sqrt(sqrt(2) + 1) + >>> p2 = sqrt(sqrt(2) - 1) + >>> _minpoly_op_algebraic_element(Mul, p1, p2, x, QQ) + x - 1 + >>> q1 = sqrt(y) + >>> q2 = 1 / y + >>> _minpoly_op_algebraic_element(Add, q1, q2, x, QQ.frac_field(y)) + x**2*y**2 - 2*x*y - y**3 + 1 + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Resultant + .. [2] I.M. Isaacs, Proc. Amer. Math. Soc. 25 (1970), 638 + "Degrees of sums in a separable field extension". + + """ + y = Dummy(str(x)) + if mp1 is None: + mp1 = _minpoly_compose(ex1, x, dom) + if mp2 is None: + mp2 = _minpoly_compose(ex2, y, dom) + else: + mp2 = mp2.subs({x: y}) + + if op is Add: + # mp1a = mp1.subs({x: x - y}) + if dom == QQ: + R, X = ring('X', QQ) + p1 = R(dict_from_expr(mp1)[0]) + p2 = R(dict_from_expr(mp2)[0]) + else: + (p1, p2), _ = parallel_poly_from_expr((mp1, x - y), x, y) + r = p1.compose(p2) + mp1a = r.as_expr() + + elif op is Mul: + mp1a = _muly(mp1, x, y) + else: + raise NotImplementedError('option not available') + + if op is Mul or dom != QQ: + r = resultant(mp1a, mp2, gens=[y, x]) + else: + r = rs_compose_add(p1, p2) + r = expr_from_dict(r.as_expr_dict(), x) + + deg1 = degree(mp1, x) + deg2 = degree(mp2, y) + if op is Mul and deg1 == 1 or deg2 == 1: + # if deg1 = 1, then mp1 = x - a; mp1a = x - y - a; + # r = mp2(x - a), so that `r` is irreducible + return r + + r = Poly(r, x, domain=dom) + _, factors = r.factor_list() + res = _choose_factor(factors, x, op(ex1, ex2), dom) + return res.as_expr() + + +def _invertx(p, x): + """ + Returns ``expand_mul(x**degree(p, x)*p.subs(x, 1/x))`` + """ + p1 = poly_from_expr(p, x)[0] + + n = degree(p1) + a = [c * x**(n - i) for (i,), c in p1.terms()] + return Add(*a) + + +def _muly(p, x, y): + """ + Returns ``_mexpand(y**deg*p.subs({x:x / y}))`` + """ + p1 = poly_from_expr(p, x)[0] + + n = degree(p1) + a = [c * x**i * y**(n - i) for (i,), c in p1.terms()] + return Add(*a) + + +def _minpoly_pow(ex, pw, x, dom, mp=None): + """ + Returns ``minpoly(ex**pw, x)`` + + Parameters + ========== + + ex : algebraic element + pw : rational number + x : indeterminate of the polynomial + dom: ground domain + mp : minimal polynomial of ``p`` + + Examples + ======== + + >>> from sympy import sqrt, QQ, Rational + >>> from sympy.polys.numberfields.minpoly import _minpoly_pow, minpoly + >>> from sympy.abc import x, y + >>> p = sqrt(1 + sqrt(2)) + >>> _minpoly_pow(p, 2, x, QQ) + x**2 - 2*x - 1 + >>> minpoly(p**2, x) + x**2 - 2*x - 1 + >>> _minpoly_pow(y, Rational(1, 3), x, QQ.frac_field(y)) + x**3 - y + >>> minpoly(y**Rational(1, 3), x) + x**3 - y + + """ + pw = sympify(pw) + if not mp: + mp = _minpoly_compose(ex, x, dom) + if not pw.is_rational: + raise NotAlgebraic("%s does not seem to be an algebraic element" % ex) + if pw < 0: + if mp == x: + raise ZeroDivisionError('%s is zero' % ex) + mp = _invertx(mp, x) + if pw == -1: + return mp + pw = -pw + ex = 1/ex + + y = Dummy(str(x)) + mp = mp.subs({x: y}) + n, d = pw.as_numer_denom() + res = Poly(resultant(mp, x**d - y**n, gens=[y]), x, domain=dom) + _, factors = res.factor_list() + res = _choose_factor(factors, x, ex**pw, dom) + return res.as_expr() + + +def _minpoly_add(x, dom, *a): + """ + returns ``minpoly(Add(*a), dom, x)`` + """ + mp = _minpoly_op_algebraic_element(Add, a[0], a[1], x, dom) + p = a[0] + a[1] + for px in a[2:]: + mp = _minpoly_op_algebraic_element(Add, p, px, x, dom, mp1=mp) + p = p + px + return mp + + +def _minpoly_mul(x, dom, *a): + """ + returns ``minpoly(Mul(*a), dom, x)`` + """ + mp = _minpoly_op_algebraic_element(Mul, a[0], a[1], x, dom) + p = a[0] * a[1] + for px in a[2:]: + mp = _minpoly_op_algebraic_element(Mul, p, px, x, dom, mp1=mp) + p = p * px + return mp + + +def _minpoly_sin(ex, x): + """ + Returns the minimal polynomial of ``sin(ex)`` + see https://mathworld.wolfram.com/TrigonometryAngles.html + """ + c, a = ex.args[0].as_coeff_Mul() + if a is pi: + if c.is_rational: + n = c.q + q = sympify(n) + if q.is_prime: + # for a = pi*p/q with q odd prime, using chebyshevt + # write sin(q*a) = mp(sin(a))*sin(a); + # the roots of mp(x) are sin(pi*p/q) for p = 1,..., q - 1 + a = dup_chebyshevt(n, ZZ) + return Add(*[x**(n - i - 1)*a[i] for i in range(n)]) + if c.p == 1: + if q == 9: + return 64*x**6 - 96*x**4 + 36*x**2 - 3 + + if n % 2 == 1: + # for a = pi*p/q with q odd, use + # sin(q*a) = 0 to see that the minimal polynomial must be + # a factor of dup_chebyshevt(n, ZZ) + a = dup_chebyshevt(n, ZZ) + a = [x**(n - i)*a[i] for i in range(n + 1)] + r = Add(*a) + _, factors = factor_list(r) + res = _choose_factor(factors, x, ex) + return res + + expr = ((1 - cos(2*c*pi))/2)**S.Half + res = _minpoly_compose(expr, x, QQ) + return res + + raise NotAlgebraic("%s does not seem to be an algebraic element" % ex) + + +def _minpoly_cos(ex, x): + """ + Returns the minimal polynomial of ``cos(ex)`` + see https://mathworld.wolfram.com/TrigonometryAngles.html + """ + c, a = ex.args[0].as_coeff_Mul() + if a is pi: + if c.is_rational: + if c.p == 1: + if c.q == 7: + return 8*x**3 - 4*x**2 - 4*x + 1 + if c.q == 9: + return 8*x**3 - 6*x - 1 + elif c.p == 2: + q = sympify(c.q) + if q.is_prime: + s = _minpoly_sin(ex, x) + return _mexpand(s.subs({x:sqrt((1 - x)/2)})) + + # for a = pi*p/q, cos(q*a) =T_q(cos(a)) = (-1)**p + n = int(c.q) + a = dup_chebyshevt(n, ZZ) + a = [x**(n - i)*a[i] for i in range(n + 1)] + r = Add(*a) - (-1)**c.p + _, factors = factor_list(r) + res = _choose_factor(factors, x, ex) + return res + + raise NotAlgebraic("%s does not seem to be an algebraic element" % ex) + + +def _minpoly_tan(ex, x): + """ + Returns the minimal polynomial of ``tan(ex)`` + see https://github.com/sympy/sympy/issues/21430 + """ + c, a = ex.args[0].as_coeff_Mul() + if a is pi: + if c.is_rational: + c = c * 2 + n = int(c.q) + a = n if c.p % 2 == 0 else 1 + terms = [] + for k in range((c.p+1)%2, n+1, 2): + terms.append(a*x**k) + a = -(a*(n-k-1)*(n-k)) // ((k+1)*(k+2)) + + r = Add(*terms) + _, factors = factor_list(r) + res = _choose_factor(factors, x, ex) + return res + + raise NotAlgebraic("%s does not seem to be an algebraic element" % ex) + + +def _minpoly_exp(ex, x): + """ + Returns the minimal polynomial of ``exp(ex)`` + """ + c, a = ex.args[0].as_coeff_Mul() + if a == I*pi: + if c.is_rational: + q = sympify(c.q) + if c.p == 1 or c.p == -1: + if q == 3: + return x**2 - x + 1 + if q == 4: + return x**4 + 1 + if q == 6: + return x**4 - x**2 + 1 + if q == 8: + return x**8 + 1 + if q == 9: + return x**6 - x**3 + 1 + if q == 10: + return x**8 - x**6 + x**4 - x**2 + 1 + if q.is_prime: + s = 0 + for i in range(q): + s += (-x)**i + return s + + # x**(2*q) = product(factors) + factors = [cyclotomic_poly(i, x) for i in divisors(2*q)] + mp = _choose_factor(factors, x, ex) + return mp + else: + raise NotAlgebraic("%s does not seem to be an algebraic element" % ex) + raise NotAlgebraic("%s does not seem to be an algebraic element" % ex) + + +def _minpoly_rootof(ex, x): + """ + Returns the minimal polynomial of a ``CRootOf`` object. + """ + p = ex.expr + p = p.subs({ex.poly.gens[0]:x}) + _, factors = factor_list(p, x) + result = _choose_factor(factors, x, ex) + return result + + +def _minpoly_compose(ex, x, dom): + """ + Computes the minimal polynomial of an algebraic element + using operations on minimal polynomials + + Examples + ======== + + >>> from sympy import minimal_polynomial, sqrt, Rational + >>> from sympy.abc import x, y + >>> minimal_polynomial(sqrt(2) + 3*Rational(1, 3), x, compose=True) + x**2 - 2*x - 1 + >>> minimal_polynomial(sqrt(y) + 1/y, x, compose=True) + x**2*y**2 - 2*x*y - y**3 + 1 + + """ + if ex.is_Rational: + return ex.q*x - ex.p + if ex is I: + _, factors = factor_list(x**2 + 1, x, domain=dom) + return x**2 + 1 if len(factors) == 1 else x - I + + if ex is S.GoldenRatio: + _, factors = factor_list(x**2 - x - 1, x, domain=dom) + if len(factors) == 1: + return x**2 - x - 1 + else: + return _choose_factor(factors, x, (1 + sqrt(5))/2, dom=dom) + + if ex is S.TribonacciConstant: + _, factors = factor_list(x**3 - x**2 - x - 1, x, domain=dom) + if len(factors) == 1: + return x**3 - x**2 - x - 1 + else: + fac = (1 + cbrt(19 - 3*sqrt(33)) + cbrt(19 + 3*sqrt(33))) / 3 + return _choose_factor(factors, x, fac, dom=dom) + + if hasattr(dom, 'symbols') and ex in dom.symbols: + return x - ex + + if dom.is_QQ and _is_sum_surds(ex): + # eliminate the square roots + ex -= x + while 1: + ex1 = _separate_sq(ex) + if ex1 is ex: + return ex + else: + ex = ex1 + + if ex.is_Add: + res = _minpoly_add(x, dom, *ex.args) + elif ex.is_Mul: + f = Factors(ex).factors + r = sift(f.items(), lambda itx: itx[0].is_Rational and itx[1].is_Rational) + if r[True] and dom == QQ: + ex1 = Mul(*[bx**ex for bx, ex in r[False] + r[None]]) + r1 = dict(r[True]) + dens = [y.q for y in r1.values()] + lcmdens = reduce(lcm, dens, 1) + neg1 = S.NegativeOne + expn1 = r1.pop(neg1, S.Zero) + nums = [base**(y.p*lcmdens // y.q) for base, y in r1.items()] + ex2 = Mul(*nums) + mp1 = minimal_polynomial(ex1, x) + # use the fact that in SymPy canonicalization products of integers + # raised to rational powers are organized in relatively prime + # bases, and that in ``base**(n/d)`` a perfect power is + # simplified with the root + # Powers of -1 have to be treated separately to preserve sign. + mp2 = ex2.q*x**lcmdens - ex2.p*neg1**(expn1*lcmdens) + ex2 = neg1**expn1 * ex2**Rational(1, lcmdens) + res = _minpoly_op_algebraic_element(Mul, ex1, ex2, x, dom, mp1=mp1, mp2=mp2) + else: + res = _minpoly_mul(x, dom, *ex.args) + elif ex.is_Pow: + res = _minpoly_pow(ex.base, ex.exp, x, dom) + elif ex.__class__ is sin: + res = _minpoly_sin(ex, x) + elif ex.__class__ is cos: + res = _minpoly_cos(ex, x) + elif ex.__class__ is tan: + res = _minpoly_tan(ex, x) + elif ex.__class__ is exp: + res = _minpoly_exp(ex, x) + elif ex.__class__ is CRootOf: + res = _minpoly_rootof(ex, x) + else: + raise NotAlgebraic("%s does not seem to be an algebraic element" % ex) + return res + + +@public +def minimal_polynomial(ex, x=None, compose=True, polys=False, domain=None): + """ + Computes the minimal polynomial of an algebraic element. + + Parameters + ========== + + ex : Expr + Element or expression whose minimal polynomial is to be calculated. + + x : Symbol, optional + Independent variable of the minimal polynomial + + compose : boolean, optional (default=True) + Method to use for computing minimal polynomial. If ``compose=True`` + (default) then ``_minpoly_compose`` is used, if ``compose=False`` then + groebner bases are used. + + polys : boolean, optional (default=False) + If ``True`` returns a ``Poly`` object else an ``Expr`` object. + + domain : Domain, optional + Ground domain + + Notes + ===== + + By default ``compose=True``, the minimal polynomial of the subexpressions of ``ex`` + are computed, then the arithmetic operations on them are performed using the resultant + and factorization. + If ``compose=False``, a bottom-up algorithm is used with ``groebner``. + The default algorithm stalls less frequently. + + If no ground domain is given, it will be generated automatically from the expression. + + Examples + ======== + + >>> from sympy import minimal_polynomial, sqrt, solve, QQ + >>> from sympy.abc import x, y + + >>> minimal_polynomial(sqrt(2), x) + x**2 - 2 + >>> minimal_polynomial(sqrt(2), x, domain=QQ.algebraic_field(sqrt(2))) + x - sqrt(2) + >>> minimal_polynomial(sqrt(2) + sqrt(3), x) + x**4 - 10*x**2 + 1 + >>> minimal_polynomial(solve(x**3 + x + 3)[0], x) + x**3 + x + 3 + >>> minimal_polynomial(sqrt(y), x) + x**2 - y + + """ + + ex = sympify(ex) + if ex.is_number: + # not sure if it's always needed but try it for numbers (issue 8354) + ex = _mexpand(ex, recursive=True) + for expr in preorder_traversal(ex): + if expr.is_AlgebraicNumber: + compose = False + break + + if x is not None: + x, cls = sympify(x), Poly + else: + x, cls = Dummy('x'), PurePoly + + if not domain: + if ex.free_symbols: + domain = FractionField(QQ, list(ex.free_symbols)) + else: + domain = QQ + if hasattr(domain, 'symbols') and x in domain.symbols: + raise GeneratorsError("the variable %s is an element of the ground " + "domain %s" % (x, domain)) + + if compose: + result = _minpoly_compose(ex, x, domain) + result = result.primitive()[1] + c = result.coeff(x**degree(result, x)) + if c.is_negative: + result = expand_mul(-result) + return cls(result, x, field=True) if polys else result.collect(x) + + if not domain.is_QQ: + raise NotImplementedError("groebner method only works for QQ") + + result = _minpoly_groebner(ex, x, cls) + return cls(result, x, field=True) if polys else result.collect(x) + + +def _minpoly_groebner(ex, x, cls): + """ + Computes the minimal polynomial of an algebraic number + using Groebner bases + + Examples + ======== + + >>> from sympy import minimal_polynomial, sqrt, Rational + >>> from sympy.abc import x + >>> minimal_polynomial(sqrt(2) + 3*Rational(1, 3), x, compose=False) + x**2 - 2*x - 1 + + """ + + generator = numbered_symbols('a', cls=Dummy) + mapping, symbols = {}, {} + + def update_mapping(ex, exp, base=None): + a = next(generator) + symbols[ex] = a + + if base is not None: + mapping[ex] = a**exp + base + else: + mapping[ex] = exp.as_expr(a) + + return a + + def bottom_up_scan(ex): + """ + Transform a given algebraic expression *ex* into a multivariate + polynomial, by introducing fresh variables with defining equations. + + Explanation + =========== + + The critical elements of the algebraic expression *ex* are root + extractions, instances of :py:class:`~.AlgebraicNumber`, and negative + powers. + + When we encounter a root extraction or an :py:class:`~.AlgebraicNumber` + we replace this expression with a fresh variable ``a_i``, and record + the defining polynomial for ``a_i``. For example, if ``a_0**(1/3)`` + occurs, we will replace it with ``a_1``, and record the new defining + polynomial ``a_1**3 - a_0``. + + When we encounter a negative power we transform it into a positive + power by algebraically inverting the base. This means computing the + minimal polynomial in ``x`` for the base, inverting ``x`` modulo this + poly (which generates a new polynomial) and then substituting the + original base expression for ``x`` in this last polynomial. + + We return the transformed expression, and we record the defining + equations for new symbols using the ``update_mapping()`` function. + + """ + if ex.is_Atom: + if ex is S.ImaginaryUnit: + if ex not in mapping: + return update_mapping(ex, 2, 1) + else: + return symbols[ex] + elif ex.is_Rational: + return ex + elif ex.is_Add: + return Add(*[ bottom_up_scan(g) for g in ex.args ]) + elif ex.is_Mul: + return Mul(*[ bottom_up_scan(g) for g in ex.args ]) + elif ex.is_Pow: + if ex.exp.is_Rational: + if ex.exp < 0: + minpoly_base = _minpoly_groebner(ex.base, x, cls) + inverse = invert(x, minpoly_base).as_expr() + base_inv = inverse.subs(x, ex.base).expand() + + if ex.exp == -1: + return bottom_up_scan(base_inv) + else: + ex = base_inv**(-ex.exp) + if not ex.exp.is_Integer: + base, exp = ( + ex.base**ex.exp.p).expand(), Rational(1, ex.exp.q) + else: + base, exp = ex.base, ex.exp + base = bottom_up_scan(base) + expr = base**exp + + if expr not in mapping: + if exp.is_Integer: + return expr.expand() + else: + return update_mapping(expr, 1 / exp, -base) + else: + return symbols[expr] + elif ex.is_AlgebraicNumber: + if ex not in mapping: + return update_mapping(ex, ex.minpoly_of_element()) + else: + return symbols[ex] + + raise NotAlgebraic("%s does not seem to be an algebraic number" % ex) + + def simpler_inverse(ex): + """ + Returns True if it is more likely that the minimal polynomial + algorithm works better with the inverse + """ + if ex.is_Pow: + if (1/ex.exp).is_integer and ex.exp < 0: + if ex.base.is_Add: + return True + if ex.is_Mul: + hit = True + for p in ex.args: + if p.is_Add: + return False + if p.is_Pow: + if p.base.is_Add and p.exp > 0: + return False + + if hit: + return True + return False + + inverted = False + ex = expand_multinomial(ex) + if ex.is_AlgebraicNumber: + return ex.minpoly_of_element().as_expr(x) + elif ex.is_Rational: + result = ex.q*x - ex.p + else: + inverted = simpler_inverse(ex) + if inverted: + ex = ex**-1 + res = None + if ex.is_Pow and (1/ex.exp).is_Integer: + n = 1/ex.exp + res = _minimal_polynomial_sq(ex.base, n, x) + + elif _is_sum_surds(ex): + res = _minimal_polynomial_sq(ex, S.One, x) + + if res is not None: + result = res + + if res is None: + bus = bottom_up_scan(ex) + F = [x - bus] + list(mapping.values()) + G = groebner(F, list(symbols.values()) + [x], order='lex') + + _, factors = factor_list(G[-1]) + # by construction G[-1] has root `ex` + result = _choose_factor(factors, x, ex) + if inverted: + result = _invertx(result, x) + if result.coeff(x**degree(result, x)) < 0: + result = expand_mul(-result) + + return result + + +@public +def minpoly(ex, x=None, compose=True, polys=False, domain=None): + """This is a synonym for :py:func:`~.minimal_polynomial`.""" + return minimal_polynomial(ex, x=x, compose=compose, polys=polys, domain=domain) diff --git a/lib/python3.10/site-packages/sympy/polys/numberfields/modules.py b/lib/python3.10/site-packages/sympy/polys/numberfields/modules.py new file mode 100644 index 0000000000000000000000000000000000000000..af2e29bcc9cf73d97def0701712f90db58601b86 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/polys/numberfields/modules.py @@ -0,0 +1,2114 @@ +r"""Modules in number fields. + +The classes defined here allow us to work with finitely generated, free +modules, whose generators are algebraic numbers. + +There is an abstract base class called :py:class:`~.Module`, which has two +concrete subclasses, :py:class:`~.PowerBasis` and :py:class:`~.Submodule`. + +Every module is defined by its basis, or set of generators: + +* For a :py:class:`~.PowerBasis`, the generators are the first $n$ powers + (starting with the zeroth) of an algebraic integer $\theta$ of degree $n$. + The :py:class:`~.PowerBasis` is constructed by passing either the minimal + polynomial of $\theta$, or an :py:class:`~.AlgebraicField` having $\theta$ + as its primitive element. + +* For a :py:class:`~.Submodule`, the generators are a set of + $\mathbb{Q}$-linear combinations of the generators of another module. That + other module is then the "parent" of the :py:class:`~.Submodule`. The + coefficients of the $\mathbb{Q}$-linear combinations may be given by an + integer matrix, and a positive integer denominator. Each column of the matrix + defines a generator. + +>>> from sympy.polys import Poly, cyclotomic_poly, ZZ +>>> from sympy.abc import x +>>> from sympy.polys.matrices import DomainMatrix, DM +>>> from sympy.polys.numberfields.modules import PowerBasis +>>> T = Poly(cyclotomic_poly(5, x)) +>>> A = PowerBasis(T) +>>> print(A) +PowerBasis(x**4 + x**3 + x**2 + x + 1) +>>> B = A.submodule_from_matrix(2 * DomainMatrix.eye(4, ZZ), denom=3) +>>> print(B) +Submodule[[2, 0, 0, 0], [0, 2, 0, 0], [0, 0, 2, 0], [0, 0, 0, 2]]/3 +>>> print(B.parent) +PowerBasis(x**4 + x**3 + x**2 + x + 1) + +Thus, every module is either a :py:class:`~.PowerBasis`, +or a :py:class:`~.Submodule`, some ancestor of which is a +:py:class:`~.PowerBasis`. (If ``S`` is a :py:class:`~.Submodule`, then its +ancestors are ``S.parent``, ``S.parent.parent``, and so on). + +The :py:class:`~.ModuleElement` class represents a linear combination of the +generators of any module. Critically, the coefficients of this linear +combination are not restricted to be integers, but may be any rational +numbers. This is necessary so that any and all algebraic integers be +representable, starting from the power basis in a primitive element $\theta$ +for the number field in question. For example, in a quadratic field +$\mathbb{Q}(\sqrt{d})$ where $d \equiv 1 \mod{4}$, a denominator of $2$ is +needed. + +A :py:class:`~.ModuleElement` can be constructed from an integer column vector +and a denominator: + +>>> U = Poly(x**2 - 5) +>>> M = PowerBasis(U) +>>> e = M(DM([[1], [1]], ZZ), denom=2) +>>> print(e) +[1, 1]/2 +>>> print(e.module) +PowerBasis(x**2 - 5) + +The :py:class:`~.PowerBasisElement` class is a subclass of +:py:class:`~.ModuleElement` that represents elements of a +:py:class:`~.PowerBasis`, and adds functionality pertinent to elements +represented directly over powers of the primitive element $\theta$. + + +Arithmetic with module elements +=============================== + +While a :py:class:`~.ModuleElement` represents a linear combination over the +generators of a particular module, recall that every module is either a +:py:class:`~.PowerBasis` or a descendant (along a chain of +:py:class:`~.Submodule` objects) thereof, so that in fact every +:py:class:`~.ModuleElement` represents an algebraic number in some field +$\mathbb{Q}(\theta)$, where $\theta$ is the defining element of some +:py:class:`~.PowerBasis`. It thus makes sense to talk about the number field +to which a given :py:class:`~.ModuleElement` belongs. + +This means that any two :py:class:`~.ModuleElement` instances can be added, +subtracted, multiplied, or divided, provided they belong to the same number +field. Similarly, since $\mathbb{Q}$ is a subfield of every number field, +any :py:class:`~.ModuleElement` may be added, multiplied, etc. by any +rational number. + +>>> from sympy import QQ +>>> from sympy.polys.numberfields.modules import to_col +>>> T = Poly(cyclotomic_poly(5)) +>>> A = PowerBasis(T) +>>> C = A.submodule_from_matrix(3 * DomainMatrix.eye(4, ZZ)) +>>> e = A(to_col([0, 2, 0, 0]), denom=3) +>>> f = A(to_col([0, 0, 0, 7]), denom=5) +>>> g = C(to_col([1, 1, 1, 1])) +>>> e + f +[0, 10, 0, 21]/15 +>>> e - f +[0, 10, 0, -21]/15 +>>> e - g +[-9, -7, -9, -9]/3 +>>> e + QQ(7, 10) +[21, 20, 0, 0]/30 +>>> e * f +[-14, -14, -14, -14]/15 +>>> e ** 2 +[0, 0, 4, 0]/9 +>>> f // g +[7, 7, 7, 7]/15 +>>> f * QQ(2, 3) +[0, 0, 0, 14]/15 + +However, care must be taken with arithmetic operations on +:py:class:`~.ModuleElement`, because the module $C$ to which the result will +belong will be the nearest common ancestor (NCA) of the modules $A$, $B$ to +which the two operands belong, and $C$ may be different from either or both +of $A$ and $B$. + +>>> A = PowerBasis(T) +>>> B = A.submodule_from_matrix(2 * DomainMatrix.eye(4, ZZ)) +>>> C = A.submodule_from_matrix(3 * DomainMatrix.eye(4, ZZ)) +>>> print((B(0) * C(0)).module == A) +True + +Before the arithmetic operation is performed, copies of the two operands are +automatically converted into elements of the NCA (the operands themselves are +not modified). This upward conversion along an ancestor chain is easy: it just +requires the successive multiplication by the defining matrix of each +:py:class:`~.Submodule`. + +Conversely, downward conversion, i.e. representing a given +:py:class:`~.ModuleElement` in a submodule, is also supported -- namely by +the :py:meth:`~sympy.polys.numberfields.modules.Submodule.represent` method +-- but is not guaranteed to succeed in general, since the given element may +not belong to the submodule. The main circumstance in which this issue tends +to arise is with multiplication, since modules, while closed under addition, +need not be closed under multiplication. + + +Multiplication +-------------- + +Generally speaking, a module need not be closed under multiplication, i.e. need +not form a ring. However, many of the modules we work with in the context of +number fields are in fact rings, and our classes do support multiplication. + +Specifically, any :py:class:`~.Module` can attempt to compute its own +multiplication table, but this does not happen unless an attempt is made to +multiply two :py:class:`~.ModuleElement` instances belonging to it. + +>>> A = PowerBasis(T) +>>> print(A._mult_tab is None) +True +>>> a = A(0)*A(1) +>>> print(A._mult_tab is None) +False + +Every :py:class:`~.PowerBasis` is, by its nature, closed under multiplication, +so instances of :py:class:`~.PowerBasis` can always successfully compute their +multiplication table. + +When a :py:class:`~.Submodule` attempts to compute its multiplication table, +it converts each of its own generators into elements of its parent module, +multiplies them there, in every possible pairing, and then tries to +represent the results in itself, i.e. as $\mathbb{Z}$-linear combinations +over its own generators. This will succeed if and only if the submodule is +in fact closed under multiplication. + + +Module Homomorphisms +==================== + +Many important number theoretic algorithms require the calculation of the +kernel of one or more module homomorphisms. Accordingly we have several +lightweight classes, :py:class:`~.ModuleHomomorphism`, +:py:class:`~.ModuleEndomorphism`, :py:class:`~.InnerEndomorphism`, and +:py:class:`~.EndomorphismRing`, which provide the minimal necessary machinery +to support this. + +""" + +from sympy.core.intfunc import igcd, ilcm +from sympy.core.symbol import Dummy +from sympy.polys.polyclasses import ANP +from sympy.polys.polytools import Poly +from sympy.polys.densetools import dup_clear_denoms +from sympy.polys.domains.algebraicfield import AlgebraicField +from sympy.polys.domains.finitefield import FF +from sympy.polys.domains.rationalfield import QQ +from sympy.polys.domains.integerring import ZZ +from sympy.polys.matrices.domainmatrix import DomainMatrix +from sympy.polys.matrices.exceptions import DMBadInputError +from sympy.polys.matrices.normalforms import hermite_normal_form +from sympy.polys.polyerrors import CoercionFailed, UnificationFailed +from sympy.polys.polyutils import IntegerPowerable +from .exceptions import ClosureFailure, MissingUnityError, StructureError +from .utilities import AlgIntPowers, is_rat, get_num_denom + + +def to_col(coeffs): + r"""Transform a list of integer coefficients into a column vector.""" + return DomainMatrix([[ZZ(c) for c in coeffs]], (1, len(coeffs)), ZZ).transpose() + + +class Module: + """ + Generic finitely-generated module. + + This is an abstract base class, and should not be instantiated directly. + The two concrete subclasses are :py:class:`~.PowerBasis` and + :py:class:`~.Submodule`. + + Every :py:class:`~.Submodule` is derived from another module, referenced + by its ``parent`` attribute. If ``S`` is a submodule, then we refer to + ``S.parent``, ``S.parent.parent``, and so on, as the "ancestors" of + ``S``. Thus, every :py:class:`~.Module` is either a + :py:class:`~.PowerBasis` or a :py:class:`~.Submodule`, some ancestor of + which is a :py:class:`~.PowerBasis`. + """ + + @property + def n(self): + """The number of generators of this module.""" + raise NotImplementedError + + def mult_tab(self): + """ + Get the multiplication table for this module (if closed under mult). + + Explanation + =========== + + Computes a dictionary ``M`` of dictionaries of lists, representing the + upper triangular half of the multiplication table. + + In other words, if ``0 <= i <= j < self.n``, then ``M[i][j]`` is the + list ``c`` of coefficients such that + ``g[i] * g[j] == sum(c[k]*g[k], k in range(self.n))``, + where ``g`` is the list of generators of this module. + + If ``j < i`` then ``M[i][j]`` is undefined. + + Examples + ======== + + >>> from sympy.polys import Poly, cyclotomic_poly + >>> from sympy.polys.numberfields.modules import PowerBasis + >>> T = Poly(cyclotomic_poly(5)) + >>> A = PowerBasis(T) + >>> print(A.mult_tab()) # doctest: +SKIP + {0: {0: [1, 0, 0, 0], 1: [0, 1, 0, 0], 2: [0, 0, 1, 0], 3: [0, 0, 0, 1]}, + 1: {1: [0, 0, 1, 0], 2: [0, 0, 0, 1], 3: [-1, -1, -1, -1]}, + 2: {2: [-1, -1, -1, -1], 3: [1, 0, 0, 0]}, + 3: {3: [0, 1, 0, 0]}} + + Returns + ======= + + dict of dict of lists + + Raises + ====== + + ClosureFailure + If the module is not closed under multiplication. + + """ + raise NotImplementedError + + @property + def parent(self): + """ + The parent module, if any, for this module. + + Explanation + =========== + + For a :py:class:`~.Submodule` this is its ``parent`` attribute; for a + :py:class:`~.PowerBasis` this is ``None``. + + Returns + ======= + + :py:class:`~.Module`, ``None`` + + See Also + ======== + + Module + + """ + return None + + def represent(self, elt): + r""" + Represent a module element as an integer-linear combination over the + generators of this module. + + Explanation + =========== + + In our system, to "represent" always means to write a + :py:class:`~.ModuleElement` as a :ref:`ZZ`-linear combination over the + generators of the present :py:class:`~.Module`. Furthermore, the + incoming :py:class:`~.ModuleElement` must belong to an ancestor of + the present :py:class:`~.Module` (or to the present + :py:class:`~.Module` itself). + + The most common application is to represent a + :py:class:`~.ModuleElement` in a :py:class:`~.Submodule`. For example, + this is involved in computing multiplication tables. + + On the other hand, representing in a :py:class:`~.PowerBasis` is an + odd case, and one which tends not to arise in practice, except for + example when using a :py:class:`~.ModuleEndomorphism` on a + :py:class:`~.PowerBasis`. + + In such a case, (1) the incoming :py:class:`~.ModuleElement` must + belong to the :py:class:`~.PowerBasis` itself (since the latter has no + proper ancestors) and (2) it is "representable" iff it belongs to + $\mathbb{Z}[\theta]$ (although generally a + :py:class:`~.PowerBasisElement` may represent any element of + $\mathbb{Q}(\theta)$, i.e. any algebraic number). + + Examples + ======== + + >>> from sympy import Poly, cyclotomic_poly + >>> from sympy.polys.numberfields.modules import PowerBasis, to_col + >>> from sympy.abc import zeta + >>> T = Poly(cyclotomic_poly(5)) + >>> A = PowerBasis(T) + >>> a = A(to_col([2, 4, 6, 8])) + + The :py:class:`~.ModuleElement` ``a`` has all even coefficients. + If we represent ``a`` in the submodule ``B = 2*A``, the coefficients in + the column vector will be halved: + + >>> B = A.submodule_from_gens([2*A(i) for i in range(4)]) + >>> b = B.represent(a) + >>> print(b.transpose()) # doctest: +SKIP + DomainMatrix([[1, 2, 3, 4]], (1, 4), ZZ) + + However, the element of ``B`` so defined still represents the same + algebraic number: + + >>> print(a.poly(zeta).as_expr()) + 8*zeta**3 + 6*zeta**2 + 4*zeta + 2 + >>> print(B(b).over_power_basis().poly(zeta).as_expr()) + 8*zeta**3 + 6*zeta**2 + 4*zeta + 2 + + Parameters + ========== + + elt : :py:class:`~.ModuleElement` + The module element to be represented. Must belong to some ancestor + module of this module (including this module itself). + + Returns + ======= + + :py:class:`~.DomainMatrix` over :ref:`ZZ` + This will be a column vector, representing the coefficients of a + linear combination of this module's generators, which equals the + given element. + + Raises + ====== + + ClosureFailure + If the given element cannot be represented as a :ref:`ZZ`-linear + combination over this module. + + See Also + ======== + + .Submodule.represent + .PowerBasis.represent + + """ + raise NotImplementedError + + def ancestors(self, include_self=False): + """ + Return the list of ancestor modules of this module, from the + foundational :py:class:`~.PowerBasis` downward, optionally including + ``self``. + + See Also + ======== + + Module + + """ + c = self.parent + a = [] if c is None else c.ancestors(include_self=True) + if include_self: + a.append(self) + return a + + def power_basis_ancestor(self): + """ + Return the :py:class:`~.PowerBasis` that is an ancestor of this module. + + See Also + ======== + + Module + + """ + if isinstance(self, PowerBasis): + return self + c = self.parent + if c is not None: + return c.power_basis_ancestor() + return None + + def nearest_common_ancestor(self, other): + """ + Locate the nearest common ancestor of this module and another. + + Returns + ======= + + :py:class:`~.Module`, ``None`` + + See Also + ======== + + Module + + """ + sA = self.ancestors(include_self=True) + oA = other.ancestors(include_self=True) + nca = None + for sa, oa in zip(sA, oA): + if sa == oa: + nca = sa + else: + break + return nca + + @property + def number_field(self): + r""" + Return the associated :py:class:`~.AlgebraicField`, if any. + + Explanation + =========== + + A :py:class:`~.PowerBasis` can be constructed on a :py:class:`~.Poly` + $f$ or on an :py:class:`~.AlgebraicField` $K$. In the latter case, the + :py:class:`~.PowerBasis` and all its descendant modules will return $K$ + as their ``.number_field`` property, while in the former case they will + all return ``None``. + + Returns + ======= + + :py:class:`~.AlgebraicField`, ``None`` + + """ + return self.power_basis_ancestor().number_field + + def is_compat_col(self, col): + """Say whether *col* is a suitable column vector for this module.""" + return isinstance(col, DomainMatrix) and col.shape == (self.n, 1) and col.domain.is_ZZ + + def __call__(self, spec, denom=1): + r""" + Generate a :py:class:`~.ModuleElement` belonging to this module. + + Examples + ======== + + >>> from sympy.polys import Poly, cyclotomic_poly + >>> from sympy.polys.numberfields.modules import PowerBasis, to_col + >>> T = Poly(cyclotomic_poly(5)) + >>> A = PowerBasis(T) + >>> e = A(to_col([1, 2, 3, 4]), denom=3) + >>> print(e) # doctest: +SKIP + [1, 2, 3, 4]/3 + >>> f = A(2) + >>> print(f) # doctest: +SKIP + [0, 0, 1, 0] + + Parameters + ========== + + spec : :py:class:`~.DomainMatrix`, int + Specifies the numerators of the coefficients of the + :py:class:`~.ModuleElement`. Can be either a column vector over + :ref:`ZZ`, whose length must equal the number $n$ of generators of + this module, or else an integer ``j``, $0 \leq j < n$, which is a + shorthand for column $j$ of $I_n$, the $n \times n$ identity + matrix. + denom : int, optional (default=1) + Denominator for the coefficients of the + :py:class:`~.ModuleElement`. + + Returns + ======= + + :py:class:`~.ModuleElement` + The coefficients are the entries of the *spec* vector, divided by + *denom*. + + """ + if isinstance(spec, int) and 0 <= spec < self.n: + spec = DomainMatrix.eye(self.n, ZZ)[:, spec].to_dense() + if not self.is_compat_col(spec): + raise ValueError('Compatible column vector required.') + return make_mod_elt(self, spec, denom=denom) + + def starts_with_unity(self): + """Say whether the module's first generator equals unity.""" + raise NotImplementedError + + def basis_elements(self): + """ + Get list of :py:class:`~.ModuleElement` being the generators of this + module. + """ + return [self(j) for j in range(self.n)] + + def zero(self): + """Return a :py:class:`~.ModuleElement` representing zero.""" + return self(0) * 0 + + def one(self): + """ + Return a :py:class:`~.ModuleElement` representing unity, + and belonging to the first ancestor of this module (including + itself) that starts with unity. + """ + return self.element_from_rational(1) + + def element_from_rational(self, a): + """ + Return a :py:class:`~.ModuleElement` representing a rational number. + + Explanation + =========== + + The returned :py:class:`~.ModuleElement` will belong to the first + module on this module's ancestor chain (including this module + itself) that starts with unity. + + Examples + ======== + + >>> from sympy.polys import Poly, cyclotomic_poly, QQ + >>> from sympy.polys.numberfields.modules import PowerBasis + >>> T = Poly(cyclotomic_poly(5)) + >>> A = PowerBasis(T) + >>> a = A.element_from_rational(QQ(2, 3)) + >>> print(a) # doctest: +SKIP + [2, 0, 0, 0]/3 + + Parameters + ========== + + a : int, :ref:`ZZ`, :ref:`QQ` + + Returns + ======= + + :py:class:`~.ModuleElement` + + """ + raise NotImplementedError + + def submodule_from_gens(self, gens, hnf=True, hnf_modulus=None): + """ + Form the submodule generated by a list of :py:class:`~.ModuleElement` + belonging to this module. + + Examples + ======== + + >>> from sympy.polys import Poly, cyclotomic_poly + >>> from sympy.polys.numberfields.modules import PowerBasis + >>> T = Poly(cyclotomic_poly(5)) + >>> A = PowerBasis(T) + >>> gens = [A(0), 2*A(1), 3*A(2), 4*A(3)//5] + >>> B = A.submodule_from_gens(gens) + >>> print(B) # doctest: +SKIP + Submodule[[5, 0, 0, 0], [0, 10, 0, 0], [0, 0, 15, 0], [0, 0, 0, 4]]/5 + + Parameters + ========== + + gens : list of :py:class:`~.ModuleElement` belonging to this module. + hnf : boolean, optional (default=True) + If True, we will reduce the matrix into Hermite Normal Form before + forming the :py:class:`~.Submodule`. + hnf_modulus : int, None, optional (default=None) + Modulus for use in the HNF reduction algorithm. See + :py:func:`~sympy.polys.matrices.normalforms.hermite_normal_form`. + + Returns + ======= + + :py:class:`~.Submodule` + + See Also + ======== + + submodule_from_matrix + + """ + if not all(g.module == self for g in gens): + raise ValueError('Generators must belong to this module.') + n = len(gens) + if n == 0: + raise ValueError('Need at least one generator.') + m = gens[0].n + d = gens[0].denom if n == 1 else ilcm(*[g.denom for g in gens]) + B = DomainMatrix.zeros((m, 0), ZZ).hstack(*[(d // g.denom) * g.col for g in gens]) + if hnf: + B = hermite_normal_form(B, D=hnf_modulus) + return self.submodule_from_matrix(B, denom=d) + + def submodule_from_matrix(self, B, denom=1): + """ + Form the submodule generated by the elements of this module indicated + by the columns of a matrix, with an optional denominator. + + Examples + ======== + + >>> from sympy.polys import Poly, cyclotomic_poly, ZZ + >>> from sympy.polys.matrices import DM + >>> from sympy.polys.numberfields.modules import PowerBasis + >>> T = Poly(cyclotomic_poly(5)) + >>> A = PowerBasis(T) + >>> B = A.submodule_from_matrix(DM([ + ... [0, 10, 0, 0], + ... [0, 0, 7, 0], + ... ], ZZ).transpose(), denom=15) + >>> print(B) # doctest: +SKIP + Submodule[[0, 10, 0, 0], [0, 0, 7, 0]]/15 + + Parameters + ========== + + B : :py:class:`~.DomainMatrix` over :ref:`ZZ` + Each column gives the numerators of the coefficients of one + generator of the submodule. Thus, the number of rows of *B* must + equal the number of generators of the present module. + denom : int, optional (default=1) + Common denominator for all generators of the submodule. + + Returns + ======= + + :py:class:`~.Submodule` + + Raises + ====== + + ValueError + If the given matrix *B* is not over :ref:`ZZ` or its number of rows + does not equal the number of generators of the present module. + + See Also + ======== + + submodule_from_gens + + """ + m, n = B.shape + if not B.domain.is_ZZ: + raise ValueError('Matrix must be over ZZ.') + if not m == self.n: + raise ValueError('Matrix row count must match base module.') + return Submodule(self, B, denom=denom) + + def whole_submodule(self): + """ + Return a submodule equal to this entire module. + + Explanation + =========== + + This is useful when you have a :py:class:`~.PowerBasis` and want to + turn it into a :py:class:`~.Submodule` (in order to use methods + belonging to the latter). + + """ + B = DomainMatrix.eye(self.n, ZZ) + return self.submodule_from_matrix(B) + + def endomorphism_ring(self): + """Form the :py:class:`~.EndomorphismRing` for this module.""" + return EndomorphismRing(self) + + +class PowerBasis(Module): + """The module generated by the powers of an algebraic integer.""" + + def __init__(self, T): + """ + Parameters + ========== + + T : :py:class:`~.Poly`, :py:class:`~.AlgebraicField` + Either (1) the monic, irreducible, univariate polynomial over + :ref:`ZZ`, a root of which is the generator of the power basis, + or (2) an :py:class:`~.AlgebraicField` whose primitive element + is the generator of the power basis. + + """ + K = None + if isinstance(T, AlgebraicField): + K, T = T, T.ext.minpoly_of_element() + # Sometimes incoming Polys are formally over QQ, although all their + # coeffs are integral. We want them to be formally over ZZ. + T = T.set_domain(ZZ) + self.K = K + self.T = T + self._n = T.degree() + self._mult_tab = None + + @property + def number_field(self): + return self.K + + def __repr__(self): + return f'PowerBasis({self.T.as_expr()})' + + def __eq__(self, other): + if isinstance(other, PowerBasis): + return self.T == other.T + return NotImplemented + + @property + def n(self): + return self._n + + def mult_tab(self): + if self._mult_tab is None: + self.compute_mult_tab() + return self._mult_tab + + def compute_mult_tab(self): + theta_pow = AlgIntPowers(self.T) + M = {} + n = self.n + for u in range(n): + M[u] = {} + for v in range(u, n): + M[u][v] = theta_pow[u + v] + self._mult_tab = M + + def represent(self, elt): + r""" + Represent a module element as an integer-linear combination over the + generators of this module. + + See Also + ======== + + .Module.represent + .Submodule.represent + + """ + if elt.module == self and elt.denom == 1: + return elt.column() + else: + raise ClosureFailure('Element not representable in ZZ[theta].') + + def starts_with_unity(self): + return True + + def element_from_rational(self, a): + return self(0) * a + + def element_from_poly(self, f): + """ + Produce an element of this module, representing *f* after reduction mod + our defining minimal polynomial. + + Parameters + ========== + + f : :py:class:`~.Poly` over :ref:`ZZ` in same var as our defining poly. + + Returns + ======= + + :py:class:`~.PowerBasisElement` + + """ + n, k = self.n, f.degree() + if k >= n: + f = f % self.T + if f == 0: + return self.zero() + d, c = dup_clear_denoms(f.rep.to_list(), QQ, convert=True) + c = list(reversed(c)) + ell = len(c) + z = [ZZ(0)] * (n - ell) + col = to_col(c + z) + return self(col, denom=d) + + def _element_from_rep_and_mod(self, rep, mod): + """ + Produce a PowerBasisElement representing a given algebraic number. + + Parameters + ========== + + rep : list of coeffs + Represents the number as polynomial in the primitive element of the + field. + + mod : list of coeffs + Represents the minimal polynomial of the primitive element of the + field. + + Returns + ======= + + :py:class:`~.PowerBasisElement` + + """ + if mod != self.T.rep.to_list(): + raise UnificationFailed('Element does not appear to be in the same field.') + return self.element_from_poly(Poly(rep, self.T.gen)) + + def element_from_ANP(self, a): + """Convert an ANP into a PowerBasisElement. """ + return self._element_from_rep_and_mod(a.to_list(), a.mod_to_list()) + + def element_from_alg_num(self, a): + """Convert an AlgebraicNumber into a PowerBasisElement. """ + return self._element_from_rep_and_mod(a.rep.to_list(), a.minpoly.rep.to_list()) + + +class Submodule(Module, IntegerPowerable): + """A submodule of another module.""" + + def __init__(self, parent, matrix, denom=1, mult_tab=None): + """ + Parameters + ========== + + parent : :py:class:`~.Module` + The module from which this one is derived. + matrix : :py:class:`~.DomainMatrix` over :ref:`ZZ` + The matrix whose columns define this submodule's generators as + linear combinations over the parent's generators. + denom : int, optional (default=1) + Denominator for the coefficients given by the matrix. + mult_tab : dict, ``None``, optional + If already known, the multiplication table for this module may be + supplied. + + """ + self._parent = parent + self._matrix = matrix + self._denom = denom + self._mult_tab = mult_tab + self._n = matrix.shape[1] + self._QQ_matrix = None + self._starts_with_unity = None + self._is_sq_maxrank_HNF = None + + def __repr__(self): + r = 'Submodule' + repr(self.matrix.transpose().to_Matrix().tolist()) + if self.denom > 1: + r += f'/{self.denom}' + return r + + def reduced(self): + """ + Produce a reduced version of this submodule. + + Explanation + =========== + + In the reduced version, it is guaranteed that 1 is the only positive + integer dividing both the submodule's denominator, and every entry in + the submodule's matrix. + + Returns + ======= + + :py:class:`~.Submodule` + + """ + if self.denom == 1: + return self + g = igcd(self.denom, *self.coeffs) + if g == 1: + return self + return type(self)(self.parent, (self.matrix / g).convert_to(ZZ), denom=self.denom // g, mult_tab=self._mult_tab) + + def discard_before(self, r): + """ + Produce a new module by discarding all generators before a given + index *r*. + """ + W = self.matrix[:, r:] + s = self.n - r + M = None + mt = self._mult_tab + if mt is not None: + M = {} + for u in range(s): + M[u] = {} + for v in range(u, s): + M[u][v] = mt[r + u][r + v][r:] + return Submodule(self.parent, W, denom=self.denom, mult_tab=M) + + @property + def n(self): + return self._n + + def mult_tab(self): + if self._mult_tab is None: + self.compute_mult_tab() + return self._mult_tab + + def compute_mult_tab(self): + gens = self.basis_element_pullbacks() + M = {} + n = self.n + for u in range(n): + M[u] = {} + for v in range(u, n): + M[u][v] = self.represent(gens[u] * gens[v]).flat() + self._mult_tab = M + + @property + def parent(self): + return self._parent + + @property + def matrix(self): + return self._matrix + + @property + def coeffs(self): + return self.matrix.flat() + + @property + def denom(self): + return self._denom + + @property + def QQ_matrix(self): + """ + :py:class:`~.DomainMatrix` over :ref:`QQ`, equal to + ``self.matrix / self.denom``, and guaranteed to be dense. + + Explanation + =========== + + Depending on how it is formed, a :py:class:`~.DomainMatrix` may have + an internal representation that is sparse or dense. We guarantee a + dense representation here, so that tests for equivalence of submodules + always come out as expected. + + Examples + ======== + + >>> from sympy.polys import Poly, cyclotomic_poly, ZZ + >>> from sympy.abc import x + >>> from sympy.polys.matrices import DomainMatrix + >>> from sympy.polys.numberfields.modules import PowerBasis + >>> T = Poly(cyclotomic_poly(5, x)) + >>> A = PowerBasis(T) + >>> B = A.submodule_from_matrix(3*DomainMatrix.eye(4, ZZ), denom=6) + >>> C = A.submodule_from_matrix(DomainMatrix.eye(4, ZZ), denom=2) + >>> print(B.QQ_matrix == C.QQ_matrix) + True + + Returns + ======= + + :py:class:`~.DomainMatrix` over :ref:`QQ` + + """ + if self._QQ_matrix is None: + self._QQ_matrix = (self.matrix / self.denom).to_dense() + return self._QQ_matrix + + def starts_with_unity(self): + if self._starts_with_unity is None: + self._starts_with_unity = self(0).equiv(1) + return self._starts_with_unity + + def is_sq_maxrank_HNF(self): + if self._is_sq_maxrank_HNF is None: + self._is_sq_maxrank_HNF = is_sq_maxrank_HNF(self._matrix) + return self._is_sq_maxrank_HNF + + def is_power_basis_submodule(self): + return isinstance(self.parent, PowerBasis) + + def element_from_rational(self, a): + if self.starts_with_unity(): + return self(0) * a + else: + return self.parent.element_from_rational(a) + + def basis_element_pullbacks(self): + """ + Return list of this submodule's basis elements as elements of the + submodule's parent module. + """ + return [e.to_parent() for e in self.basis_elements()] + + def represent(self, elt): + """ + Represent a module element as an integer-linear combination over the + generators of this module. + + See Also + ======== + + .Module.represent + .PowerBasis.represent + + """ + if elt.module == self: + return elt.column() + elif elt.module == self.parent: + try: + # The given element should be a ZZ-linear combination over our + # basis vectors; however, due to the presence of denominators, + # we need to solve over QQ. + A = self.QQ_matrix + b = elt.QQ_col + x = A._solve(b)[0].transpose() + x = x.convert_to(ZZ) + except DMBadInputError: + raise ClosureFailure('Element outside QQ-span of this basis.') + except CoercionFailed: + raise ClosureFailure('Element in QQ-span but not ZZ-span of this basis.') + return x + elif isinstance(self.parent, Submodule): + coeffs_in_parent = self.parent.represent(elt) + parent_element = self.parent(coeffs_in_parent) + return self.represent(parent_element) + else: + raise ClosureFailure('Element outside ancestor chain of this module.') + + def is_compat_submodule(self, other): + return isinstance(other, Submodule) and other.parent == self.parent + + def __eq__(self, other): + if self.is_compat_submodule(other): + return other.QQ_matrix == self.QQ_matrix + return NotImplemented + + def add(self, other, hnf=True, hnf_modulus=None): + """ + Add this :py:class:`~.Submodule` to another. + + Explanation + =========== + + This represents the module generated by the union of the two modules' + sets of generators. + + Parameters + ========== + + other : :py:class:`~.Submodule` + hnf : boolean, optional (default=True) + If ``True``, reduce the matrix of the combined module to its + Hermite Normal Form. + hnf_modulus : :ref:`ZZ`, None, optional + If a positive integer is provided, use this as modulus in the + HNF reduction. See + :py:func:`~sympy.polys.matrices.normalforms.hermite_normal_form`. + + Returns + ======= + + :py:class:`~.Submodule` + + """ + d, e = self.denom, other.denom + m = ilcm(d, e) + a, b = m // d, m // e + B = (a * self.matrix).hstack(b * other.matrix) + if hnf: + B = hermite_normal_form(B, D=hnf_modulus) + return self.parent.submodule_from_matrix(B, denom=m) + + def __add__(self, other): + if self.is_compat_submodule(other): + return self.add(other) + return NotImplemented + + __radd__ = __add__ + + def mul(self, other, hnf=True, hnf_modulus=None): + """ + Multiply this :py:class:`~.Submodule` by a rational number, a + :py:class:`~.ModuleElement`, or another :py:class:`~.Submodule`. + + Explanation + =========== + + To multiply by a rational number or :py:class:`~.ModuleElement` means + to form the submodule whose generators are the products of this + quantity with all the generators of the present submodule. + + To multiply by another :py:class:`~.Submodule` means to form the + submodule whose generators are all the products of one generator from + the one submodule, and one generator from the other. + + Parameters + ========== + + other : int, :ref:`ZZ`, :ref:`QQ`, :py:class:`~.ModuleElement`, :py:class:`~.Submodule` + hnf : boolean, optional (default=True) + If ``True``, reduce the matrix of the product module to its + Hermite Normal Form. + hnf_modulus : :ref:`ZZ`, None, optional + If a positive integer is provided, use this as modulus in the + HNF reduction. See + :py:func:`~sympy.polys.matrices.normalforms.hermite_normal_form`. + + Returns + ======= + + :py:class:`~.Submodule` + + """ + if is_rat(other): + a, b = get_num_denom(other) + if a == b == 1: + return self + else: + return Submodule(self.parent, + self.matrix * a, denom=self.denom * b, + mult_tab=None).reduced() + elif isinstance(other, ModuleElement) and other.module == self.parent: + # The submodule is multiplied by an element of the parent module. + # We presume this means we want a new submodule of the parent module. + gens = [other * e for e in self.basis_element_pullbacks()] + return self.parent.submodule_from_gens(gens, hnf=hnf, hnf_modulus=hnf_modulus) + elif self.is_compat_submodule(other): + # This case usually means you're multiplying ideals, and want another + # ideal, i.e. another submodule of the same parent module. + alphas, betas = self.basis_element_pullbacks(), other.basis_element_pullbacks() + gens = [a * b for a in alphas for b in betas] + return self.parent.submodule_from_gens(gens, hnf=hnf, hnf_modulus=hnf_modulus) + return NotImplemented + + def __mul__(self, other): + return self.mul(other) + + __rmul__ = __mul__ + + def _first_power(self): + return self + + def reduce_element(self, elt): + r""" + If this submodule $B$ has defining matrix $W$ in square, maximal-rank + Hermite normal form, then, given an element $x$ of the parent module + $A$, we produce an element $y \in A$ such that $x - y \in B$, and the + $i$th coordinate of $y$ satisfies $0 \leq y_i < w_{i,i}$. This + representative $y$ is unique, in the sense that every element of + the coset $x + B$ reduces to it under this procedure. + + Explanation + =========== + + In the special case where $A$ is a power basis for a number field $K$, + and $B$ is a submodule representing an ideal $I$, this operation + represents one of a few important ways of reducing an element of $K$ + modulo $I$ to obtain a "small" representative. See [Cohen00]_ Section + 1.4.3. + + Examples + ======== + + >>> from sympy import QQ, Poly, symbols + >>> t = symbols('t') + >>> k = QQ.alg_field_from_poly(Poly(t**3 + t**2 - 2*t + 8)) + >>> Zk = k.maximal_order() + >>> A = Zk.parent + >>> B = (A(2) - 3*A(0))*Zk + >>> B.reduce_element(A(2)) + [3, 0, 0] + + Parameters + ========== + + elt : :py:class:`~.ModuleElement` + An element of this submodule's parent module. + + Returns + ======= + + elt : :py:class:`~.ModuleElement` + An element of this submodule's parent module. + + Raises + ====== + + NotImplementedError + If the given :py:class:`~.ModuleElement` does not belong to this + submodule's parent module. + StructureError + If this submodule's defining matrix is not in square, maximal-rank + Hermite normal form. + + References + ========== + + .. [Cohen00] Cohen, H. *Advanced Topics in Computational Number + Theory.* + + """ + if not elt.module == self.parent: + raise NotImplementedError + if not self.is_sq_maxrank_HNF(): + msg = "Reduction not implemented unless matrix square max-rank HNF" + raise StructureError(msg) + B = self.basis_element_pullbacks() + a = elt + for i in range(self.n - 1, -1, -1): + b = B[i] + q = a.coeffs[i]*b.denom // (b.coeffs[i]*a.denom) + a -= q*b + return a + + +def is_sq_maxrank_HNF(dm): + r""" + Say whether a :py:class:`~.DomainMatrix` is in that special case of Hermite + Normal Form, in which the matrix is also square and of maximal rank. + + Explanation + =========== + + We commonly work with :py:class:`~.Submodule` instances whose matrix is in + this form, and it can be useful to be able to check that this condition is + satisfied. + + For example this is the case with the :py:class:`~.Submodule` ``ZK`` + returned by :py:func:`~sympy.polys.numberfields.basis.round_two`, which + represents the maximal order in a number field, and with ideals formed + therefrom, such as ``2 * ZK``. + + """ + if dm.domain.is_ZZ and dm.is_square and dm.is_upper: + n = dm.shape[0] + for i in range(n): + d = dm[i, i].element + if d <= 0: + return False + for j in range(i + 1, n): + if not (0 <= dm[i, j].element < d): + return False + return True + return False + + +def make_mod_elt(module, col, denom=1): + r""" + Factory function which builds a :py:class:`~.ModuleElement`, but ensures + that it is a :py:class:`~.PowerBasisElement` if the module is a + :py:class:`~.PowerBasis`. + """ + if isinstance(module, PowerBasis): + return PowerBasisElement(module, col, denom=denom) + else: + return ModuleElement(module, col, denom=denom) + + +class ModuleElement(IntegerPowerable): + r""" + Represents an element of a :py:class:`~.Module`. + + NOTE: Should not be constructed directly. Use the + :py:meth:`~.Module.__call__` method or the :py:func:`make_mod_elt()` + factory function instead. + """ + + def __init__(self, module, col, denom=1): + """ + Parameters + ========== + + module : :py:class:`~.Module` + The module to which this element belongs. + col : :py:class:`~.DomainMatrix` over :ref:`ZZ` + Column vector giving the numerators of the coefficients of this + element. + denom : int, optional (default=1) + Denominator for the coefficients of this element. + + """ + self.module = module + self.col = col + self.denom = denom + self._QQ_col = None + + def __repr__(self): + r = str([int(c) for c in self.col.flat()]) + if self.denom > 1: + r += f'/{self.denom}' + return r + + def reduced(self): + """ + Produce a reduced version of this ModuleElement, i.e. one in which the + gcd of the denominator together with all numerator coefficients is 1. + """ + if self.denom == 1: + return self + g = igcd(self.denom, *self.coeffs) + if g == 1: + return self + return type(self)(self.module, + (self.col / g).convert_to(ZZ), + denom=self.denom // g) + + def reduced_mod_p(self, p): + """ + Produce a version of this :py:class:`~.ModuleElement` in which all + numerator coefficients have been reduced mod *p*. + """ + return make_mod_elt(self.module, + self.col.convert_to(FF(p)).convert_to(ZZ), + denom=self.denom) + + @classmethod + def from_int_list(cls, module, coeffs, denom=1): + """ + Make a :py:class:`~.ModuleElement` from a list of ints (instead of a + column vector). + """ + col = to_col(coeffs) + return cls(module, col, denom=denom) + + @property + def n(self): + """The length of this element's column.""" + return self.module.n + + def __len__(self): + return self.n + + def column(self, domain=None): + """ + Get a copy of this element's column, optionally converting to a domain. + """ + if domain is None: + return self.col.copy() + else: + return self.col.convert_to(domain) + + @property + def coeffs(self): + return self.col.flat() + + @property + def QQ_col(self): + """ + :py:class:`~.DomainMatrix` over :ref:`QQ`, equal to + ``self.col / self.denom``, and guaranteed to be dense. + + See Also + ======== + + .Submodule.QQ_matrix + + """ + if self._QQ_col is None: + self._QQ_col = (self.col / self.denom).to_dense() + return self._QQ_col + + def to_parent(self): + """ + Transform into a :py:class:`~.ModuleElement` belonging to the parent of + this element's module. + """ + if not isinstance(self.module, Submodule): + raise ValueError('Not an element of a Submodule.') + return make_mod_elt( + self.module.parent, self.module.matrix * self.col, + denom=self.module.denom * self.denom) + + def to_ancestor(self, anc): + """ + Transform into a :py:class:`~.ModuleElement` belonging to a given + ancestor of this element's module. + + Parameters + ========== + + anc : :py:class:`~.Module` + + """ + if anc == self.module: + return self + else: + return self.to_parent().to_ancestor(anc) + + def over_power_basis(self): + """ + Transform into a :py:class:`~.PowerBasisElement` over our + :py:class:`~.PowerBasis` ancestor. + """ + e = self + while not isinstance(e.module, PowerBasis): + e = e.to_parent() + return e + + def is_compat(self, other): + """ + Test whether other is another :py:class:`~.ModuleElement` with same + module. + """ + return isinstance(other, ModuleElement) and other.module == self.module + + def unify(self, other): + """ + Try to make a compatible pair of :py:class:`~.ModuleElement`, one + equivalent to this one, and one equivalent to the other. + + Explanation + =========== + + We search for the nearest common ancestor module for the pair of + elements, and represent each one there. + + Returns + ======= + + Pair ``(e1, e2)`` + Each ``ei`` is a :py:class:`~.ModuleElement`, they belong to the + same :py:class:`~.Module`, ``e1`` is equivalent to ``self``, and + ``e2`` is equivalent to ``other``. + + Raises + ====== + + UnificationFailed + If ``self`` and ``other`` have no common ancestor module. + + """ + if self.module == other.module: + return self, other + nca = self.module.nearest_common_ancestor(other.module) + if nca is not None: + return self.to_ancestor(nca), other.to_ancestor(nca) + raise UnificationFailed(f"Cannot unify {self} with {other}") + + def __eq__(self, other): + if self.is_compat(other): + return self.QQ_col == other.QQ_col + return NotImplemented + + def equiv(self, other): + """ + A :py:class:`~.ModuleElement` may test as equivalent to a rational + number or another :py:class:`~.ModuleElement`, if they represent the + same algebraic number. + + Explanation + =========== + + This method is intended to check equivalence only in those cases in + which it is easy to test; namely, when *other* is either a + :py:class:`~.ModuleElement` that can be unified with this one (i.e. one + which shares a common :py:class:`~.PowerBasis` ancestor), or else a + rational number (which is easy because every :py:class:`~.PowerBasis` + represents every rational number). + + Parameters + ========== + + other : int, :ref:`ZZ`, :ref:`QQ`, :py:class:`~.ModuleElement` + + Returns + ======= + + bool + + Raises + ====== + + UnificationFailed + If ``self`` and ``other`` do not share a common + :py:class:`~.PowerBasis` ancestor. + + """ + if self == other: + return True + elif isinstance(other, ModuleElement): + a, b = self.unify(other) + return a == b + elif is_rat(other): + if isinstance(self, PowerBasisElement): + return self == self.module(0) * other + else: + return self.over_power_basis().equiv(other) + return False + + def __add__(self, other): + """ + A :py:class:`~.ModuleElement` can be added to a rational number, or to + another :py:class:`~.ModuleElement`. + + Explanation + =========== + + When the other summand is a rational number, it will be converted into + a :py:class:`~.ModuleElement` (belonging to the first ancestor of this + module that starts with unity). + + In all cases, the sum belongs to the nearest common ancestor (NCA) of + the modules of the two summands. If the NCA does not exist, we return + ``NotImplemented``. + """ + if self.is_compat(other): + d, e = self.denom, other.denom + m = ilcm(d, e) + u, v = m // d, m // e + col = to_col([u * a + v * b for a, b in zip(self.coeffs, other.coeffs)]) + return type(self)(self.module, col, denom=m).reduced() + elif isinstance(other, ModuleElement): + try: + a, b = self.unify(other) + except UnificationFailed: + return NotImplemented + return a + b + elif is_rat(other): + return self + self.module.element_from_rational(other) + return NotImplemented + + __radd__ = __add__ + + def __neg__(self): + return self * -1 + + def __sub__(self, other): + return self + (-other) + + def __rsub__(self, other): + return -self + other + + def __mul__(self, other): + """ + A :py:class:`~.ModuleElement` can be multiplied by a rational number, + or by another :py:class:`~.ModuleElement`. + + Explanation + =========== + + When the multiplier is a rational number, the product is computed by + operating directly on the coefficients of this + :py:class:`~.ModuleElement`. + + When the multiplier is another :py:class:`~.ModuleElement`, the product + will belong to the nearest common ancestor (NCA) of the modules of the + two operands, and that NCA must have a multiplication table. If the NCA + does not exist, we return ``NotImplemented``. If the NCA does not have + a mult. table, ``ClosureFailure`` will be raised. + """ + if self.is_compat(other): + M = self.module.mult_tab() + A, B = self.col.flat(), other.col.flat() + n = self.n + C = [0] * n + for u in range(n): + for v in range(u, n): + c = A[u] * B[v] + if v > u: + c += A[v] * B[u] + if c != 0: + R = M[u][v] + for k in range(n): + C[k] += c * R[k] + d = self.denom * other.denom + return self.from_int_list(self.module, C, denom=d) + elif isinstance(other, ModuleElement): + try: + a, b = self.unify(other) + except UnificationFailed: + return NotImplemented + return a * b + elif is_rat(other): + a, b = get_num_denom(other) + if a == b == 1: + return self + else: + return make_mod_elt(self.module, + self.col * a, denom=self.denom * b).reduced() + return NotImplemented + + __rmul__ = __mul__ + + def _zeroth_power(self): + return self.module.one() + + def _first_power(self): + return self + + def __floordiv__(self, a): + if is_rat(a): + a = QQ(a) + return self * (1/a) + elif isinstance(a, ModuleElement): + return self * (1//a) + return NotImplemented + + def __rfloordiv__(self, a): + return a // self.over_power_basis() + + def __mod__(self, m): + r""" + Reduce this :py:class:`~.ModuleElement` mod a :py:class:`~.Submodule`. + + Parameters + ========== + + m : int, :ref:`ZZ`, :ref:`QQ`, :py:class:`~.Submodule` + If a :py:class:`~.Submodule`, reduce ``self`` relative to this. + If an integer or rational, reduce relative to the + :py:class:`~.Submodule` that is our own module times this constant. + + See Also + ======== + + .Submodule.reduce_element + + """ + if is_rat(m): + m = m * self.module.whole_submodule() + if isinstance(m, Submodule) and m.parent == self.module: + return m.reduce_element(self) + return NotImplemented + + +class PowerBasisElement(ModuleElement): + r""" + Subclass for :py:class:`~.ModuleElement` instances whose module is a + :py:class:`~.PowerBasis`. + """ + + @property + def T(self): + """Access the defining polynomial of the :py:class:`~.PowerBasis`.""" + return self.module.T + + def numerator(self, x=None): + """Obtain the numerator as a polynomial over :ref:`ZZ`.""" + x = x or self.T.gen + return Poly(reversed(self.coeffs), x, domain=ZZ) + + def poly(self, x=None): + """Obtain the number as a polynomial over :ref:`QQ`.""" + return self.numerator(x=x) // self.denom + + @property + def is_rational(self): + """Say whether this element represents a rational number.""" + return self.col[1:, :].is_zero_matrix + + @property + def generator(self): + """ + Return a :py:class:`~.Symbol` to be used when expressing this element + as a polynomial. + + If we have an associated :py:class:`~.AlgebraicField` whose primitive + element has an alias symbol, we use that. Otherwise we use the variable + of the minimal polynomial defining the power basis to which we belong. + """ + K = self.module.number_field + return K.ext.alias if K and K.ext.is_aliased else self.T.gen + + def as_expr(self, x=None): + """Create a Basic expression from ``self``. """ + return self.poly(x or self.generator).as_expr() + + def norm(self, T=None): + """Compute the norm of this number.""" + T = T or self.T + x = T.gen + A = self.numerator(x=x) + return T.resultant(A) // self.denom ** self.n + + def inverse(self): + f = self.poly() + f_inv = f.invert(self.T) + return self.module.element_from_poly(f_inv) + + def __rfloordiv__(self, a): + return self.inverse() * a + + def _negative_power(self, e, modulo=None): + return self.inverse() ** abs(e) + + def to_ANP(self): + """Convert to an equivalent :py:class:`~.ANP`. """ + return ANP(list(reversed(self.QQ_col.flat())), QQ.map(self.T.rep.to_list()), QQ) + + def to_alg_num(self): + """ + Try to convert to an equivalent :py:class:`~.AlgebraicNumber`. + + Explanation + =========== + + In general, the conversion from an :py:class:`~.AlgebraicNumber` to a + :py:class:`~.PowerBasisElement` throws away information, because an + :py:class:`~.AlgebraicNumber` specifies a complex embedding, while a + :py:class:`~.PowerBasisElement` does not. However, in some cases it is + possible to convert a :py:class:`~.PowerBasisElement` back into an + :py:class:`~.AlgebraicNumber`, namely when the associated + :py:class:`~.PowerBasis` has a reference to an + :py:class:`~.AlgebraicField`. + + Returns + ======= + + :py:class:`~.AlgebraicNumber` + + Raises + ====== + + StructureError + If the :py:class:`~.PowerBasis` to which this element belongs does + not have an associated :py:class:`~.AlgebraicField`. + + """ + K = self.module.number_field + if K: + return K.to_alg_num(self.to_ANP()) + raise StructureError("No associated AlgebraicField") + + +class ModuleHomomorphism: + r"""A homomorphism from one module to another.""" + + def __init__(self, domain, codomain, mapping): + r""" + Parameters + ========== + + domain : :py:class:`~.Module` + The domain of the mapping. + + codomain : :py:class:`~.Module` + The codomain of the mapping. + + mapping : callable + An arbitrary callable is accepted, but should be chosen so as + to represent an actual module homomorphism. In particular, should + accept elements of *domain* and return elements of *codomain*. + + Examples + ======== + + >>> from sympy import Poly, cyclotomic_poly + >>> from sympy.polys.numberfields.modules import PowerBasis, ModuleHomomorphism + >>> T = Poly(cyclotomic_poly(5)) + >>> A = PowerBasis(T) + >>> B = A.submodule_from_gens([2*A(j) for j in range(4)]) + >>> phi = ModuleHomomorphism(A, B, lambda x: 6*x) + >>> print(phi.matrix()) # doctest: +SKIP + DomainMatrix([[3, 0, 0, 0], [0, 3, 0, 0], [0, 0, 3, 0], [0, 0, 0, 3]], (4, 4), ZZ) + + """ + self.domain = domain + self.codomain = codomain + self.mapping = mapping + + def matrix(self, modulus=None): + r""" + Compute the matrix of this homomorphism. + + Parameters + ========== + + modulus : int, optional + A positive prime number $p$ if the matrix should be reduced mod + $p$. + + Returns + ======= + + :py:class:`~.DomainMatrix` + The matrix is over :ref:`ZZ`, or else over :ref:`GF(p)` if a + modulus was given. + + """ + basis = self.domain.basis_elements() + cols = [self.codomain.represent(self.mapping(elt)) for elt in basis] + if not cols: + return DomainMatrix.zeros((self.codomain.n, 0), ZZ).to_dense() + M = cols[0].hstack(*cols[1:]) + if modulus: + M = M.convert_to(FF(modulus)) + return M + + def kernel(self, modulus=None): + r""" + Compute a Submodule representing the kernel of this homomorphism. + + Parameters + ========== + + modulus : int, optional + A positive prime number $p$ if the kernel should be computed mod + $p$. + + Returns + ======= + + :py:class:`~.Submodule` + This submodule's generators span the kernel of this + homomorphism over :ref:`ZZ`, or else over :ref:`GF(p)` if a + modulus was given. + + """ + M = self.matrix(modulus=modulus) + if modulus is None: + M = M.convert_to(QQ) + # Note: Even when working over a finite field, what we want here is + # the pullback into the integers, so in this case the conversion to ZZ + # below is appropriate. When working over ZZ, the kernel should be a + # ZZ-submodule, so, while the conversion to QQ above was required in + # order for the nullspace calculation to work, conversion back to ZZ + # afterward should always work. + # TODO: + # Watch , which calls + # for fraction-free algorithms. If this is implemented, we can skip + # the conversion to `QQ` above. + K = M.nullspace().convert_to(ZZ).transpose() + return self.domain.submodule_from_matrix(K) + + +class ModuleEndomorphism(ModuleHomomorphism): + r"""A homomorphism from one module to itself.""" + + def __init__(self, domain, mapping): + r""" + Parameters + ========== + + domain : :py:class:`~.Module` + The common domain and codomain of the mapping. + + mapping : callable + An arbitrary callable is accepted, but should be chosen so as + to represent an actual module endomorphism. In particular, should + accept and return elements of *domain*. + + """ + super().__init__(domain, domain, mapping) + + +class InnerEndomorphism(ModuleEndomorphism): + r""" + An inner endomorphism on a module, i.e. the endomorphism corresponding to + multiplication by a fixed element. + """ + + def __init__(self, domain, multiplier): + r""" + Parameters + ========== + + domain : :py:class:`~.Module` + The domain and codomain of the endomorphism. + + multiplier : :py:class:`~.ModuleElement` + The element $a$ defining the mapping as $x \mapsto a x$. + + """ + super().__init__(domain, lambda x: multiplier * x) + self.multiplier = multiplier + + +class EndomorphismRing: + r"""The ring of endomorphisms on a module.""" + + def __init__(self, domain): + """ + Parameters + ========== + + domain : :py:class:`~.Module` + The domain and codomain of the endomorphisms. + + """ + self.domain = domain + + def inner_endomorphism(self, multiplier): + r""" + Form an inner endomorphism belonging to this endomorphism ring. + + Parameters + ========== + + multiplier : :py:class:`~.ModuleElement` + Element $a$ defining the inner endomorphism $x \mapsto a x$. + + Returns + ======= + + :py:class:`~.InnerEndomorphism` + + """ + return InnerEndomorphism(self.domain, multiplier) + + def represent(self, element): + r""" + Represent an element of this endomorphism ring, as a single column + vector. + + Explanation + =========== + + Let $M$ be a module, and $E$ its ring of endomorphisms. Let $N$ be + another module, and consider a homomorphism $\varphi: N \rightarrow E$. + In the event that $\varphi$ is to be represented by a matrix $A$, each + column of $A$ must represent an element of $E$. This is possible when + the elements of $E$ are themselves representable as matrices, by + stacking the columns of such a matrix into a single column. + + This method supports calculating such matrices $A$, by representing + an element of this endomorphism ring first as a matrix, and then + stacking that matrix's columns into a single column. + + Examples + ======== + + Note that in these examples we print matrix transposes, to make their + columns easier to inspect. + + >>> from sympy import Poly, cyclotomic_poly + >>> from sympy.polys.numberfields.modules import PowerBasis + >>> from sympy.polys.numberfields.modules import ModuleHomomorphism + >>> T = Poly(cyclotomic_poly(5)) + >>> M = PowerBasis(T) + >>> E = M.endomorphism_ring() + + Let $\zeta$ be a primitive 5th root of unity, a generator of our field, + and consider the inner endomorphism $\tau$ on the ring of integers, + induced by $\zeta$: + + >>> zeta = M(1) + >>> tau = E.inner_endomorphism(zeta) + >>> tau.matrix().transpose() # doctest: +SKIP + DomainMatrix( + [[0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1], [-1, -1, -1, -1]], + (4, 4), ZZ) + + The matrix representation of $\tau$ is as expected. The first column + shows that multiplying by $\zeta$ carries $1$ to $\zeta$, the second + column that it carries $\zeta$ to $\zeta^2$, and so forth. + + The ``represent`` method of the endomorphism ring ``E`` stacks these + into a single column: + + >>> E.represent(tau).transpose() # doctest: +SKIP + DomainMatrix( + [[0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, -1, -1, -1, -1]], + (1, 16), ZZ) + + This is useful when we want to consider a homomorphism $\varphi$ having + ``E`` as codomain: + + >>> phi = ModuleHomomorphism(M, E, lambda x: E.inner_endomorphism(x)) + + and we want to compute the matrix of such a homomorphism: + + >>> phi.matrix().transpose() # doctest: +SKIP + DomainMatrix( + [[1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1], + [0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, -1, -1, -1, -1], + [0, 0, 1, 0, 0, 0, 0, 1, -1, -1, -1, -1, 1, 0, 0, 0], + [0, 0, 0, 1, -1, -1, -1, -1, 1, 0, 0, 0, 0, 1, 0, 0]], + (4, 16), ZZ) + + Note that the stacked matrix of $\tau$ occurs as the second column in + this example. This is because $\zeta$ is the second basis element of + ``M``, and $\varphi(\zeta) = \tau$. + + Parameters + ========== + + element : :py:class:`~.ModuleEndomorphism` belonging to this ring. + + Returns + ======= + + :py:class:`~.DomainMatrix` + Column vector equalling the vertical stacking of all the columns + of the matrix that represents the given *element* as a mapping. + + """ + if isinstance(element, ModuleEndomorphism) and element.domain == self.domain: + M = element.matrix() + # Transform the matrix into a single column, which should reproduce + # the original columns, one after another. + m, n = M.shape + if n == 0: + return M + return M[:, 0].vstack(*[M[:, j] for j in range(1, n)]) + raise NotImplementedError + + +def find_min_poly(alpha, domain, x=None, powers=None): + r""" + Find a polynomial of least degree (not necessarily irreducible) satisfied + by an element of a finitely-generated ring with unity. + + Examples + ======== + + For the $n$th cyclotomic field, $n$ an odd prime, consider the quadratic + equation whose roots are the two periods of length $(n-1)/2$. Article 356 + of Gauss tells us that we should get $x^2 + x - (n-1)/4$ or + $x^2 + x + (n+1)/4$ according to whether $n$ is 1 or 3 mod 4, respectively. + + >>> from sympy import Poly, cyclotomic_poly, primitive_root, QQ + >>> from sympy.abc import x + >>> from sympy.polys.numberfields.modules import PowerBasis, find_min_poly + >>> n = 13 + >>> g = primitive_root(n) + >>> C = PowerBasis(Poly(cyclotomic_poly(n, x))) + >>> ee = [g**(2*k+1) % n for k in range((n-1)//2)] + >>> eta = sum(C(e) for e in ee) + >>> print(find_min_poly(eta, QQ, x=x).as_expr()) + x**2 + x - 3 + >>> n = 19 + >>> g = primitive_root(n) + >>> C = PowerBasis(Poly(cyclotomic_poly(n, x))) + >>> ee = [g**(2*k+2) % n for k in range((n-1)//2)] + >>> eta = sum(C(e) for e in ee) + >>> print(find_min_poly(eta, QQ, x=x).as_expr()) + x**2 + x + 5 + + Parameters + ========== + + alpha : :py:class:`~.ModuleElement` + The element whose min poly is to be found, and whose module has + multiplication and starts with unity. + + domain : :py:class:`~.Domain` + The desired domain of the polynomial. + + x : :py:class:`~.Symbol`, optional + The desired variable for the polynomial. + + powers : list, optional + If desired, pass an empty list. The powers of *alpha* (as + :py:class:`~.ModuleElement` instances) from the zeroth up to the degree + of the min poly will be recorded here, as we compute them. + + Returns + ======= + + :py:class:`~.Poly`, ``None`` + The minimal polynomial for alpha, or ``None`` if no polynomial could be + found over the desired domain. + + Raises + ====== + + MissingUnityError + If the module to which alpha belongs does not start with unity. + ClosureFailure + If the module to which alpha belongs is not closed under + multiplication. + + """ + R = alpha.module + if not R.starts_with_unity(): + raise MissingUnityError("alpha must belong to finitely generated ring with unity.") + if powers is None: + powers = [] + one = R(0) + powers.append(one) + powers_matrix = one.column(domain=domain) + ak = alpha + m = None + for k in range(1, R.n + 1): + powers.append(ak) + ak_col = ak.column(domain=domain) + try: + X = powers_matrix._solve(ak_col)[0] + except DMBadInputError: + # This means alpha^k still isn't in the domain-span of the lower powers. + powers_matrix = powers_matrix.hstack(ak_col) + ak *= alpha + else: + # alpha^k is in the domain-span of the lower powers, so we have found a + # minimal-degree poly for alpha. + coeffs = [1] + [-c for c in reversed(X.to_list_flat())] + x = x or Dummy('x') + if domain.is_FF: + m = Poly(coeffs, x, modulus=domain.mod) + else: + m = Poly(coeffs, x, domain=domain) + break + return m diff --git a/lib/python3.10/site-packages/sympy/polys/numberfields/primes.py b/lib/python3.10/site-packages/sympy/polys/numberfields/primes.py new file mode 100644 index 0000000000000000000000000000000000000000..8f28f13d94f33ed59cded8eabd05e9cf7d0f103f --- /dev/null +++ b/lib/python3.10/site-packages/sympy/polys/numberfields/primes.py @@ -0,0 +1,784 @@ +"""Prime ideals in number fields. """ + +from sympy.polys.polytools import Poly +from sympy.polys.domains.finitefield import FF +from sympy.polys.domains.rationalfield import QQ +from sympy.polys.domains.integerring import ZZ +from sympy.polys.matrices.domainmatrix import DomainMatrix +from sympy.polys.polyerrors import CoercionFailed +from sympy.polys.polyutils import IntegerPowerable +from sympy.utilities.decorator import public +from .basis import round_two, nilradical_mod_p +from .exceptions import StructureError +from .modules import ModuleEndomorphism, find_min_poly +from .utilities import coeff_search, supplement_a_subspace + + +def _check_formal_conditions_for_maximal_order(submodule): + r""" + Several functions in this module accept an argument which is to be a + :py:class:`~.Submodule` representing the maximal order in a number field, + such as returned by the :py:func:`~sympy.polys.numberfields.basis.round_two` + algorithm. + + We do not attempt to check that the given ``Submodule`` actually represents + a maximal order, but we do check a basic set of formal conditions that the + ``Submodule`` must satisfy, at a minimum. The purpose is to catch an + obviously ill-formed argument. + """ + prefix = 'The submodule representing the maximal order should ' + cond = None + if not submodule.is_power_basis_submodule(): + cond = 'be a direct submodule of a power basis.' + elif not submodule.starts_with_unity(): + cond = 'have 1 as its first generator.' + elif not submodule.is_sq_maxrank_HNF(): + cond = 'have square matrix, of maximal rank, in Hermite Normal Form.' + if cond is not None: + raise StructureError(prefix + cond) + + +class PrimeIdeal(IntegerPowerable): + r""" + A prime ideal in a ring of algebraic integers. + """ + + def __init__(self, ZK, p, alpha, f, e=None): + """ + Parameters + ========== + + ZK : :py:class:`~.Submodule` + The maximal order where this ideal lives. + p : int + The rational prime this ideal divides. + alpha : :py:class:`~.PowerBasisElement` + Such that the ideal is equal to ``p*ZK + alpha*ZK``. + f : int + The inertia degree. + e : int, ``None``, optional + The ramification index, if already known. If ``None``, we will + compute it here. + + """ + _check_formal_conditions_for_maximal_order(ZK) + self.ZK = ZK + self.p = p + self.alpha = alpha + self.f = f + self._test_factor = None + self.e = e if e is not None else self.valuation(p * ZK) + + def __str__(self): + if self.is_inert: + return f'({self.p})' + return f'({self.p}, {self.alpha.as_expr()})' + + @property + def is_inert(self): + """ + Say whether the rational prime we divide is inert, i.e. stays prime in + our ring of integers. + """ + return self.f == self.ZK.n + + def repr(self, field_gen=None, just_gens=False): + """ + Print a representation of this prime ideal. + + Examples + ======== + + >>> from sympy import cyclotomic_poly, QQ + >>> from sympy.abc import x, zeta + >>> T = cyclotomic_poly(7, x) + >>> K = QQ.algebraic_field((T, zeta)) + >>> P = K.primes_above(11) + >>> print(P[0].repr()) + [ (11, x**3 + 5*x**2 + 4*x - 1) e=1, f=3 ] + >>> print(P[0].repr(field_gen=zeta)) + [ (11, zeta**3 + 5*zeta**2 + 4*zeta - 1) e=1, f=3 ] + >>> print(P[0].repr(field_gen=zeta, just_gens=True)) + (11, zeta**3 + 5*zeta**2 + 4*zeta - 1) + + Parameters + ========== + + field_gen : :py:class:`~.Symbol`, ``None``, optional (default=None) + The symbol to use for the generator of the field. This will appear + in our representation of ``self.alpha``. If ``None``, we use the + variable of the defining polynomial of ``self.ZK``. + just_gens : bool, optional (default=False) + If ``True``, just print the "(p, alpha)" part, showing "just the + generators" of the prime ideal. Otherwise, print a string of the + form "[ (p, alpha) e=..., f=... ]", giving the ramification index + and inertia degree, along with the generators. + + """ + field_gen = field_gen or self.ZK.parent.T.gen + p, alpha, e, f = self.p, self.alpha, self.e, self.f + alpha_rep = str(alpha.numerator(x=field_gen).as_expr()) + if alpha.denom > 1: + alpha_rep = f'({alpha_rep})/{alpha.denom}' + gens = f'({p}, {alpha_rep})' + if just_gens: + return gens + return f'[ {gens} e={e}, f={f} ]' + + def __repr__(self): + return self.repr() + + def as_submodule(self): + r""" + Represent this prime ideal as a :py:class:`~.Submodule`. + + Explanation + =========== + + The :py:class:`~.PrimeIdeal` class serves to bundle information about + a prime ideal, such as its inertia degree, ramification index, and + two-generator representation, as well as to offer helpful methods like + :py:meth:`~.PrimeIdeal.valuation` and + :py:meth:`~.PrimeIdeal.test_factor`. + + However, in order to be added and multiplied by other ideals or + rational numbers, it must first be converted into a + :py:class:`~.Submodule`, which is a class that supports these + operations. + + In many cases, the user need not perform this conversion deliberately, + since it is automatically performed by the arithmetic operator methods + :py:meth:`~.PrimeIdeal.__add__` and :py:meth:`~.PrimeIdeal.__mul__`. + + Raising a :py:class:`~.PrimeIdeal` to a non-negative integer power is + also supported. + + Examples + ======== + + >>> from sympy import Poly, cyclotomic_poly, prime_decomp + >>> T = Poly(cyclotomic_poly(7)) + >>> P0 = prime_decomp(7, T)[0] + >>> print(P0**6 == 7*P0.ZK) + True + + Note that, on both sides of the equation above, we had a + :py:class:`~.Submodule`. In the next equation we recall that adding + ideals yields their GCD. This time, we need a deliberate conversion + to :py:class:`~.Submodule` on the right: + + >>> print(P0 + 7*P0.ZK == P0.as_submodule()) + True + + Returns + ======= + + :py:class:`~.Submodule` + Will be equal to ``self.p * self.ZK + self.alpha * self.ZK``. + + See Also + ======== + + __add__ + __mul__ + + """ + M = self.p * self.ZK + self.alpha * self.ZK + # Pre-set expensive boolean properties whose value we already know: + M._starts_with_unity = False + M._is_sq_maxrank_HNF = True + return M + + def __eq__(self, other): + if isinstance(other, PrimeIdeal): + return self.as_submodule() == other.as_submodule() + return NotImplemented + + def __add__(self, other): + """ + Convert to a :py:class:`~.Submodule` and add to another + :py:class:`~.Submodule`. + + See Also + ======== + + as_submodule + + """ + return self.as_submodule() + other + + __radd__ = __add__ + + def __mul__(self, other): + """ + Convert to a :py:class:`~.Submodule` and multiply by another + :py:class:`~.Submodule` or a rational number. + + See Also + ======== + + as_submodule + + """ + return self.as_submodule() * other + + __rmul__ = __mul__ + + def _zeroth_power(self): + return self.ZK + + def _first_power(self): + return self + + def test_factor(self): + r""" + Compute a test factor for this prime ideal. + + Explanation + =========== + + Write $\mathfrak{p}$ for this prime ideal, $p$ for the rational prime + it divides. Then, for computing $\mathfrak{p}$-adic valuations it is + useful to have a number $\beta \in \mathbb{Z}_K$ such that + $p/\mathfrak{p} = p \mathbb{Z}_K + \beta \mathbb{Z}_K$. + + Essentially, this is the same as the number $\Psi$ (or the "reagent") + from Kummer's 1847 paper (*Ueber die Zerlegung...*, Crelle vol. 35) in + which ideal divisors were invented. + """ + if self._test_factor is None: + self._test_factor = _compute_test_factor(self.p, [self.alpha], self.ZK) + return self._test_factor + + def valuation(self, I): + r""" + Compute the $\mathfrak{p}$-adic valuation of integral ideal I at this + prime ideal. + + Parameters + ========== + + I : :py:class:`~.Submodule` + + See Also + ======== + + prime_valuation + + """ + return prime_valuation(I, self) + + def reduce_element(self, elt): + """ + Reduce a :py:class:`~.PowerBasisElement` to a "small representative" + modulo this prime ideal. + + Parameters + ========== + + elt : :py:class:`~.PowerBasisElement` + The element to be reduced. + + Returns + ======= + + :py:class:`~.PowerBasisElement` + The reduced element. + + See Also + ======== + + reduce_ANP + reduce_alg_num + .Submodule.reduce_element + + """ + return self.as_submodule().reduce_element(elt) + + def reduce_ANP(self, a): + """ + Reduce an :py:class:`~.ANP` to a "small representative" modulo this + prime ideal. + + Parameters + ========== + + elt : :py:class:`~.ANP` + The element to be reduced. + + Returns + ======= + + :py:class:`~.ANP` + The reduced element. + + See Also + ======== + + reduce_element + reduce_alg_num + .Submodule.reduce_element + + """ + elt = self.ZK.parent.element_from_ANP(a) + red = self.reduce_element(elt) + return red.to_ANP() + + def reduce_alg_num(self, a): + """ + Reduce an :py:class:`~.AlgebraicNumber` to a "small representative" + modulo this prime ideal. + + Parameters + ========== + + elt : :py:class:`~.AlgebraicNumber` + The element to be reduced. + + Returns + ======= + + :py:class:`~.AlgebraicNumber` + The reduced element. + + See Also + ======== + + reduce_element + reduce_ANP + .Submodule.reduce_element + + """ + elt = self.ZK.parent.element_from_alg_num(a) + red = self.reduce_element(elt) + return a.field_element(list(reversed(red.QQ_col.flat()))) + + +def _compute_test_factor(p, gens, ZK): + r""" + Compute the test factor for a :py:class:`~.PrimeIdeal` $\mathfrak{p}$. + + Parameters + ========== + + p : int + The rational prime $\mathfrak{p}$ divides + + gens : list of :py:class:`PowerBasisElement` + A complete set of generators for $\mathfrak{p}$ over *ZK*, EXCEPT that + an element equivalent to rational *p* can and should be omitted (since + it has no effect except to waste time). + + ZK : :py:class:`~.Submodule` + The maximal order where the prime ideal $\mathfrak{p}$ lives. + + Returns + ======= + + :py:class:`~.PowerBasisElement` + + References + ========== + + .. [1] Cohen, H. *A Course in Computational Algebraic Number Theory.* + (See Proposition 4.8.15.) + + """ + _check_formal_conditions_for_maximal_order(ZK) + E = ZK.endomorphism_ring() + matrices = [E.inner_endomorphism(g).matrix(modulus=p) for g in gens] + B = DomainMatrix.zeros((0, ZK.n), FF(p)).vstack(*matrices) + # A nonzero element of the nullspace of B will represent a + # lin comb over the omegas which (i) is not a multiple of p + # (since it is nonzero over FF(p)), while (ii) is such that + # its product with each g in gens _is_ a multiple of p (since + # B represents multiplication by these generators). Theory + # predicts that such an element must exist, so nullspace should + # be non-trivial. + x = B.nullspace()[0, :].transpose() + beta = ZK.parent(ZK.matrix * x.convert_to(ZZ), denom=ZK.denom) + return beta + + +@public +def prime_valuation(I, P): + r""" + Compute the *P*-adic valuation for an integral ideal *I*. + + Examples + ======== + + >>> from sympy import QQ + >>> from sympy.polys.numberfields import prime_valuation + >>> K = QQ.cyclotomic_field(5) + >>> P = K.primes_above(5) + >>> ZK = K.maximal_order() + >>> print(prime_valuation(25*ZK, P[0])) + 8 + + Parameters + ========== + + I : :py:class:`~.Submodule` + An integral ideal whose valuation is desired. + + P : :py:class:`~.PrimeIdeal` + The prime at which to compute the valuation. + + Returns + ======= + + int + + See Also + ======== + + .PrimeIdeal.valuation + + References + ========== + + .. [1] Cohen, H. *A Course in Computational Algebraic Number Theory.* + (See Algorithm 4.8.17.) + + """ + p, ZK = P.p, P.ZK + n, W, d = ZK.n, ZK.matrix, ZK.denom + + A = W.convert_to(QQ).inv() * I.matrix * d / I.denom + # Although A must have integer entries, given that I is an integral ideal, + # as a DomainMatrix it will still be over QQ, so we convert back: + A = A.convert_to(ZZ) + D = A.det() + if D % p != 0: + return 0 + + beta = P.test_factor() + + f = d ** n // W.det() + need_complete_test = (f % p == 0) + v = 0 + while True: + # Entering the loop, the cols of A represent lin combs of omegas. + # Turn them into lin combs of thetas: + A = W * A + # And then one column at a time... + for j in range(n): + c = ZK.parent(A[:, j], denom=d) + c *= beta + # ...turn back into lin combs of omegas, after multiplying by beta: + c = ZK.represent(c).flat() + for i in range(n): + A[i, j] = c[i] + if A[n - 1, n - 1].element % p != 0: + break + A = A / p + # As noted above, domain converts to QQ even when division goes evenly. + # So must convert back, even when we don't "need_complete_test". + if need_complete_test: + # In this case, having a non-integer entry is actually just our + # halting condition. + try: + A = A.convert_to(ZZ) + except CoercionFailed: + break + else: + # In this case theory says we should not have any non-integer entries. + A = A.convert_to(ZZ) + v += 1 + return v + + +def _two_elt_rep(gens, ZK, p, f=None, Np=None): + r""" + Given a set of *ZK*-generators of a prime ideal, compute a set of just two + *ZK*-generators for the same ideal, one of which is *p* itself. + + Parameters + ========== + + gens : list of :py:class:`PowerBasisElement` + Generators for the prime ideal over *ZK*, the ring of integers of the + field $K$. + + ZK : :py:class:`~.Submodule` + The maximal order in $K$. + + p : int + The rational prime divided by the prime ideal. + + f : int, optional + The inertia degree of the prime ideal, if known. + + Np : int, optional + The norm $p^f$ of the prime ideal, if known. + NOTE: There is no reason to supply both *f* and *Np*. Either one will + save us from having to compute the norm *Np* ourselves. If both are known, + *Np* is preferred since it saves one exponentiation. + + Returns + ======= + + :py:class:`~.PowerBasisElement` representing a single algebraic integer + alpha such that the prime ideal is equal to ``p*ZK + alpha*ZK``. + + References + ========== + + .. [1] Cohen, H. *A Course in Computational Algebraic Number Theory.* + (See Algorithm 4.7.10.) + + """ + _check_formal_conditions_for_maximal_order(ZK) + pb = ZK.parent + T = pb.T + # Detect the special cases in which either (a) all generators are multiples + # of p, or (b) there are no generators (so `all` is vacuously true): + if all((g % p).equiv(0) for g in gens): + return pb.zero() + + if Np is None: + if f is not None: + Np = p**f + else: + Np = abs(pb.submodule_from_gens(gens).matrix.det()) + + omega = ZK.basis_element_pullbacks() + beta = [p*om for om in omega[1:]] # note: we omit omega[0] == 1 + beta += gens + search = coeff_search(len(beta), 1) + for c in search: + alpha = sum(ci*betai for ci, betai in zip(c, beta)) + # Note: It may be tempting to reduce alpha mod p here, to try to work + # with smaller numbers, but must not do that, as it can result in an + # infinite loop! E.g. try factoring 2 in Q(sqrt(-7)). + n = alpha.norm(T) // Np + if n % p != 0: + # Now can reduce alpha mod p. + return alpha % p + + +def _prime_decomp_easy_case(p, ZK): + r""" + Compute the decomposition of rational prime *p* in the ring of integers + *ZK* (given as a :py:class:`~.Submodule`), in the "easy case", i.e. the + case where *p* does not divide the index of $\theta$ in *ZK*, where + $\theta$ is the generator of the ``PowerBasis`` of which *ZK* is a + ``Submodule``. + """ + T = ZK.parent.T + T_bar = Poly(T, modulus=p) + lc, fl = T_bar.factor_list() + if len(fl) == 1 and fl[0][1] == 1: + return [PrimeIdeal(ZK, p, ZK.parent.zero(), ZK.n, 1)] + return [PrimeIdeal(ZK, p, + ZK.parent.element_from_poly(Poly(t, domain=ZZ)), + t.degree(), e) + for t, e in fl] + + +def _prime_decomp_compute_kernel(I, p, ZK): + r""" + Parameters + ========== + + I : :py:class:`~.Module` + An ideal of ``ZK/pZK``. + p : int + The rational prime being factored. + ZK : :py:class:`~.Submodule` + The maximal order. + + Returns + ======= + + Pair ``(N, G)``, where: + + ``N`` is a :py:class:`~.Module` representing the kernel of the map + ``a |--> a**p - a`` on ``(O/pO)/I``, guaranteed to be a module with + unity. + + ``G`` is a :py:class:`~.Module` representing a basis for the separable + algebra ``A = O/I`` (see Cohen). + + """ + W = I.matrix + n, r = W.shape + # Want to take the Fp-basis given by the columns of I, adjoin (1, 0, ..., 0) + # (which we know is not already in there since I is a basis for a prime ideal) + # and then supplement this with additional columns to make an invertible n x n + # matrix. This will then represent a full basis for ZK, whose first r columns + # are pullbacks of the basis for I. + if r == 0: + B = W.eye(n, ZZ) + else: + B = W.hstack(W.eye(n, ZZ)[:, 0]) + if B.shape[1] < n: + B = supplement_a_subspace(B.convert_to(FF(p))).convert_to(ZZ) + + G = ZK.submodule_from_matrix(B) + # Must compute G's multiplication table _before_ discarding the first r + # columns. (See Step 9 in Alg 6.2.9 in Cohen, where the betas are actually + # needed in order to represent each product of gammas. However, once we've + # found the representations, then we can ignore the betas.) + G.compute_mult_tab() + G = G.discard_before(r) + + phi = ModuleEndomorphism(G, lambda x: x**p - x) + N = phi.kernel(modulus=p) + assert N.starts_with_unity() + return N, G + + +def _prime_decomp_maximal_ideal(I, p, ZK): + r""" + We have reached the case where we have a maximal (hence prime) ideal *I*, + which we know because the quotient ``O/I`` is a field. + + Parameters + ========== + + I : :py:class:`~.Module` + An ideal of ``O/pO``. + p : int + The rational prime being factored. + ZK : :py:class:`~.Submodule` + The maximal order. + + Returns + ======= + + :py:class:`~.PrimeIdeal` instance representing this prime + + """ + m, n = I.matrix.shape + f = m - n + G = ZK.matrix * I.matrix + gens = [ZK.parent(G[:, j], denom=ZK.denom) for j in range(G.shape[1])] + alpha = _two_elt_rep(gens, ZK, p, f=f) + return PrimeIdeal(ZK, p, alpha, f) + + +def _prime_decomp_split_ideal(I, p, N, G, ZK): + r""" + Perform the step in the prime decomposition algorithm where we have determined + the quotient ``ZK/I`` is _not_ a field, and we want to perform a non-trivial + factorization of *I* by locating an idempotent element of ``ZK/I``. + """ + assert I.parent == ZK and G.parent is ZK and N.parent is G + # Since ZK/I is not a field, the kernel computed in the previous step contains + # more than just the prime field Fp, and our basis N for the nullspace therefore + # contains at least a second column (which represents an element outside Fp). + # Let alpha be such an element: + alpha = N(1).to_parent() + assert alpha.module is G + + alpha_powers = [] + m = find_min_poly(alpha, FF(p), powers=alpha_powers) + # TODO (future work): + # We don't actually need full factorization, so might use a faster method + # to just break off a single non-constant factor m1? + lc, fl = m.factor_list() + m1 = fl[0][0] + m2 = m.quo(m1) + U, V, g = m1.gcdex(m2) + # Sanity check: theory says m is squarefree, so m1, m2 should be coprime: + assert g == 1 + E = list(reversed(Poly(U * m1, domain=ZZ).rep.to_list())) + eps1 = sum(E[i]*alpha_powers[i] for i in range(len(E))) + eps2 = 1 - eps1 + idemps = [eps1, eps2] + factors = [] + for eps in idemps: + e = eps.to_parent() + assert e.module is ZK + D = I.matrix.convert_to(FF(p)).hstack(*[ + (e * om).column(domain=FF(p)) for om in ZK.basis_elements() + ]) + W = D.columnspace().convert_to(ZZ) + H = ZK.submodule_from_matrix(W) + factors.append(H) + return factors + + +@public +def prime_decomp(p, T=None, ZK=None, dK=None, radical=None): + r""" + Compute the decomposition of rational prime *p* in a number field. + + Explanation + =========== + + Ordinarily this should be accessed through the + :py:meth:`~.AlgebraicField.primes_above` method of an + :py:class:`~.AlgebraicField`. + + Examples + ======== + + >>> from sympy import Poly, QQ + >>> from sympy.abc import x, theta + >>> T = Poly(x ** 3 + x ** 2 - 2 * x + 8) + >>> K = QQ.algebraic_field((T, theta)) + >>> print(K.primes_above(2)) + [[ (2, x**2 + 1) e=1, f=1 ], [ (2, (x**2 + 3*x + 2)/2) e=1, f=1 ], + [ (2, (3*x**2 + 3*x)/2) e=1, f=1 ]] + + Parameters + ========== + + p : int + The rational prime whose decomposition is desired. + + T : :py:class:`~.Poly`, optional + Monic irreducible polynomial defining the number field $K$ in which to + factor. NOTE: at least one of *T* or *ZK* must be provided. + + ZK : :py:class:`~.Submodule`, optional + The maximal order for $K$, if already known. + NOTE: at least one of *T* or *ZK* must be provided. + + dK : int, optional + The discriminant of the field $K$, if already known. + + radical : :py:class:`~.Submodule`, optional + The nilradical mod *p* in the integers of $K$, if already known. + + Returns + ======= + + List of :py:class:`~.PrimeIdeal` instances. + + References + ========== + + .. [1] Cohen, H. *A Course in Computational Algebraic Number Theory.* + (See Algorithm 6.2.9.) + + """ + if T is None and ZK is None: + raise ValueError('At least one of T or ZK must be provided.') + if ZK is not None: + _check_formal_conditions_for_maximal_order(ZK) + if T is None: + T = ZK.parent.T + radicals = {} + if dK is None or ZK is None: + ZK, dK = round_two(T, radicals=radicals) + dT = T.discriminant() + f_squared = dT // dK + if f_squared % p != 0: + return _prime_decomp_easy_case(p, ZK) + radical = radical or radicals.get(p) or nilradical_mod_p(ZK, p) + stack = [radical] + primes = [] + while stack: + I = stack.pop() + N, G = _prime_decomp_compute_kernel(I, p, ZK) + if N.n == 1: + P = _prime_decomp_maximal_ideal(I, p, ZK) + primes.append(P) + else: + I1, I2 = _prime_decomp_split_ideal(I, p, N, G, ZK) + stack.extend([I1, I2]) + return primes diff --git a/lib/python3.10/site-packages/sympy/polys/numberfields/resolvent_lookup.py b/lib/python3.10/site-packages/sympy/polys/numberfields/resolvent_lookup.py new file mode 100644 index 0000000000000000000000000000000000000000..71812c0d7aec6501039eefe4f3602b1916628071 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/polys/numberfields/resolvent_lookup.py @@ -0,0 +1,456 @@ +"""Lookup table for Galois resolvents for polys of degree 4 through 6. """ +# This table was generated by a call to +# `sympy.polys.numberfields.galois_resolvents.generate_lambda_lookup()`. +# The entire job took 543.23s. +# Of this, Case (6, 1) took 539.03s. +# The final polynomial of Case (6, 1) alone took 455.09s. +resolvent_coeff_lambdas = { + (4, 0): [ + lambda s1, s2, s3, s4: (-2*s1*s2 + 6*s3), + lambda s1, s2, s3, s4: (2*s1**3*s3 + s1**2*s2**2 + s1**2*s4 - 17*s1*s2*s3 + 2*s2**3 - 8*s2*s4 + 24*s3**2), + lambda s1, s2, s3, s4: (-2*s1**5*s4 - 2*s1**4*s2*s3 + 10*s1**3*s2*s4 + 8*s1**3*s3**2 + 10*s1**2*s2**2*s3 - +12*s1**2*s3*s4 - 2*s1*s2**4 - 54*s1*s2*s3**2 + 32*s1*s4**2 + 8*s2**3*s3 - 32*s2*s3*s4 ++ 56*s3**3), + lambda s1, s2, s3, s4: (2*s1**6*s2*s4 + s1**6*s3**2 - 5*s1**5*s3*s4 - 11*s1**4*s2**2*s4 - 13*s1**4*s2*s3**2 ++ 7*s1**4*s4**2 + 3*s1**3*s2**3*s3 + 30*s1**3*s2*s3*s4 + 22*s1**3*s3**3 + 10*s1**2*s2**3*s4 ++ 33*s1**2*s2**2*s3**2 - 72*s1**2*s2*s4**2 - 36*s1**2*s3**2*s4 - 13*s1*s2**4*s3 + +48*s1*s2**2*s3*s4 - 116*s1*s2*s3**3 + 144*s1*s3*s4**2 + s2**6 - 12*s2**4*s4 + 22*s2**3*s3**2 ++ 48*s2**2*s4**2 - 120*s2*s3**2*s4 + 96*s3**4 - 64*s4**3), + lambda s1, s2, s3, s4: (-2*s1**8*s3*s4 - s1**7*s4**2 + 22*s1**6*s2*s3*s4 + 2*s1**6*s3**3 - 2*s1**5*s2**3*s4 +- s1**5*s2**2*s3**2 - 29*s1**5*s3**2*s4 - 60*s1**4*s2**2*s3*s4 - 19*s1**4*s2*s3**3 ++ 38*s1**4*s3*s4**2 + 9*s1**3*s2**4*s4 + 10*s1**3*s2**3*s3**2 + 24*s1**3*s2**2*s4**2 ++ 134*s1**3*s2*s3**2*s4 + 28*s1**3*s3**4 + 16*s1**3*s4**3 - s1**2*s2**5*s3 - 4*s1**2*s2**3*s3*s4 ++ 34*s1**2*s2**2*s3**3 - 288*s1**2*s2*s3*s4**2 - 104*s1**2*s3**3*s4 - 19*s1*s2**4*s3**2 ++ 120*s1*s2**2*s3**2*s4 - 128*s1*s2*s3**4 + 336*s1*s3**2*s4**2 + 2*s2**6*s3 - 24*s2**4*s3*s4 ++ 28*s2**3*s3**3 + 96*s2**2*s3*s4**2 - 176*s2*s3**3*s4 + 96*s3**5 - 128*s3*s4**3), + lambda s1, s2, s3, s4: (s1**10*s4**2 - 11*s1**8*s2*s4**2 - 2*s1**8*s3**2*s4 + s1**7*s2**2*s3*s4 + 15*s1**7*s3*s4**2 ++ 45*s1**6*s2**2*s4**2 + 17*s1**6*s2*s3**2*s4 + s1**6*s3**4 - 5*s1**6*s4**3 - 12*s1**5*s2**3*s3*s4 +- 133*s1**5*s2*s3*s4**2 - 22*s1**5*s3**3*s4 + s1**4*s2**5*s4 - 76*s1**4*s2**3*s4**2 +- 6*s1**4*s2**2*s3**2*s4 - 12*s1**4*s2*s3**4 + 32*s1**4*s2*s4**3 + 128*s1**4*s3**2*s4**2 ++ 29*s1**3*s2**4*s3*s4 + 2*s1**3*s2**3*s3**3 + 344*s1**3*s2**2*s3*s4**2 + 48*s1**3*s2*s3**3*s4 ++ 16*s1**3*s3**5 - 48*s1**3*s3*s4**3 - 4*s1**2*s2**6*s4 + 32*s1**2*s2**4*s4**2 - 134*s1**2*s2**3*s3**2*s4 ++ 36*s1**2*s2**2*s3**4 - 64*s1**2*s2**2*s4**3 - 648*s1**2*s2*s3**2*s4**2 - 48*s1**2*s3**4*s4 ++ 16*s1*s2**5*s3*s4 - 12*s1*s2**4*s3**3 - 128*s1*s2**3*s3*s4**2 + 296*s1*s2**2*s3**3*s4 +- 96*s1*s2*s3**5 + 256*s1*s2*s3*s4**3 + 416*s1*s3**3*s4**2 + s2**6*s3**2 - 28*s2**4*s3**2*s4 ++ 16*s2**3*s3**4 + 176*s2**2*s3**2*s4**2 - 224*s2*s3**4*s4 + 64*s3**6 - 320*s3**2*s4**3) + ], + (4, 1): [ + lambda s1, s2, s3, s4: (-s2), + lambda s1, s2, s3, s4: (s1*s3 - 4*s4), + lambda s1, s2, s3, s4: (-s1**2*s4 + 4*s2*s4 - s3**2) + ], + (5, 1): [ + lambda s1, s2, s3, s4, s5: (-2*s1*s3 + 8*s4), + lambda s1, s2, s3, s4, s5: (-8*s1**3*s5 + 2*s1**2*s2*s4 + s1**2*s3**2 + 30*s1*s2*s5 - 14*s1*s3*s4 - 6*s2**2*s4 ++ 2*s2*s3**2 - 50*s3*s5 + 40*s4**2), + lambda s1, s2, s3, s4, s5: (16*s1**4*s3*s5 - 2*s1**4*s4**2 - 2*s1**3*s2**2*s5 - 2*s1**3*s2*s3*s4 - 44*s1**3*s4*s5 +- 66*s1**2*s2*s3*s5 + 21*s1**2*s2*s4**2 + 6*s1**2*s3**2*s4 - 50*s1**2*s5**2 + 9*s1*s2**3*s5 ++ 5*s1*s2**2*s3*s4 - 2*s1*s2*s3**3 + 190*s1*s2*s4*s5 + 120*s1*s3**2*s5 - 80*s1*s3*s4**2 +- 15*s2**2*s3*s5 - 40*s2**2*s4**2 + 21*s2*s3**2*s4 + 125*s2*s5**2 - 2*s3**4 - 400*s3*s4*s5 ++ 160*s4**3), + lambda s1, s2, s3, s4, s5: (16*s1**6*s5**2 - 8*s1**5*s2*s4*s5 - 8*s1**5*s3**2*s5 + 2*s1**5*s3*s4**2 + 2*s1**4*s2**2*s3*s5 ++ s1**4*s2**2*s4**2 - 120*s1**4*s2*s5**2 + 68*s1**4*s3*s4*s5 - 8*s1**4*s4**3 + 46*s1**3*s2**2*s4*s5 ++ 28*s1**3*s2*s3**2*s5 - 19*s1**3*s2*s3*s4**2 + 250*s1**3*s3*s5**2 - 144*s1**3*s4**2*s5 +- 9*s1**2*s2**3*s3*s5 - 6*s1**2*s2**3*s4**2 + 3*s1**2*s2**2*s3**2*s4 + 225*s1**2*s2**2*s5**2 +- 354*s1**2*s2*s3*s4*s5 + 76*s1**2*s2*s4**3 - 70*s1**2*s3**3*s5 + 41*s1**2*s3**2*s4**2 +- 200*s1**2*s4*s5**2 - 54*s1*s2**3*s4*s5 + 45*s1*s2**2*s3**2*s5 + 30*s1*s2**2*s3*s4**2 +- 19*s1*s2*s3**3*s4 - 875*s1*s2*s3*s5**2 + 640*s1*s2*s4**2*s5 + 2*s1*s3**5 + 630*s1*s3**2*s4*s5 +- 264*s1*s3*s4**3 + 9*s2**4*s4**2 - 6*s2**3*s3**2*s4 + s2**2*s3**4 + 90*s2**2*s3*s4*s5 +- 136*s2**2*s4**3 - 50*s2*s3**3*s5 + 76*s2*s3**2*s4**2 + 500*s2*s4*s5**2 - 8*s3**4*s4 ++ 625*s3**2*s5**2 - 1400*s3*s4**2*s5 + 400*s4**4), + lambda s1, s2, s3, s4, s5: (-32*s1**7*s3*s5**2 + 8*s1**7*s4**2*s5 + 8*s1**6*s2**2*s5**2 + 8*s1**6*s2*s3*s4*s5 +- 2*s1**6*s2*s4**3 + 48*s1**6*s4*s5**2 - 2*s1**5*s2**3*s4*s5 + 264*s1**5*s2*s3*s5**2 +- 94*s1**5*s2*s4**2*s5 - 24*s1**5*s3**2*s4*s5 + 6*s1**5*s3*s4**3 - 56*s1**5*s5**3 +- 66*s1**4*s2**3*s5**2 - 50*s1**4*s2**2*s3*s4*s5 + 19*s1**4*s2**2*s4**3 + 8*s1**4*s2*s3**3*s5 +- 2*s1**4*s2*s3**2*s4**2 - 318*s1**4*s2*s4*s5**2 - 352*s1**4*s3**2*s5**2 + 166*s1**4*s3*s4**2*s5 ++ 3*s1**4*s4**4 + 15*s1**3*s2**4*s4*s5 - 2*s1**3*s2**3*s3**2*s5 - s1**3*s2**3*s3*s4**2 +- 574*s1**3*s2**2*s3*s5**2 + 347*s1**3*s2**2*s4**2*s5 + 194*s1**3*s2*s3**2*s4*s5 - +89*s1**3*s2*s3*s4**3 + 350*s1**3*s2*s5**3 - 8*s1**3*s3**4*s5 + 4*s1**3*s3**3*s4**2 ++ 1090*s1**3*s3*s4*s5**2 - 364*s1**3*s4**3*s5 + 162*s1**2*s2**4*s5**2 + 33*s1**2*s2**3*s3*s4*s5 +- 51*s1**2*s2**3*s4**3 - 32*s1**2*s2**2*s3**3*s5 + 28*s1**2*s2**2*s3**2*s4**2 + 305*s1**2*s2**2*s4*s5**2 +- 2*s1**2*s2*s3**4*s4 + 1340*s1**2*s2*s3**2*s5**2 - 901*s1**2*s2*s3*s4**2*s5 + 76*s1**2*s2*s4**4 +- 234*s1**2*s3**3*s4*s5 + 102*s1**2*s3**2*s4**3 - 750*s1**2*s3*s5**3 - 550*s1**2*s4**2*s5**2 +- 27*s1*s2**5*s4*s5 + 9*s1*s2**4*s3**2*s5 + 3*s1*s2**4*s3*s4**2 - s1*s2**3*s3**3*s4 ++ 180*s1*s2**3*s3*s5**2 - 366*s1*s2**3*s4**2*s5 - 231*s1*s2**2*s3**2*s4*s5 + 212*s1*s2**2*s3*s4**3 +- 375*s1*s2**2*s5**3 + 112*s1*s2*s3**4*s5 - 89*s1*s2*s3**3*s4**2 - 3075*s1*s2*s3*s4*s5**2 ++ 1640*s1*s2*s4**3*s5 + 6*s1*s3**5*s4 - 850*s1*s3**3*s5**2 + 1220*s1*s3**2*s4**2*s5 +- 384*s1*s3*s4**4 + 2500*s1*s4*s5**3 - 108*s2**5*s5**2 + 117*s2**4*s3*s4*s5 + 32*s2**4*s4**3 +- 31*s2**3*s3**3*s5 - 51*s2**3*s3**2*s4**2 + 525*s2**3*s4*s5**2 + 19*s2**2*s3**4*s4 +- 325*s2**2*s3**2*s5**2 + 260*s2**2*s3*s4**2*s5 - 256*s2**2*s4**4 - 2*s2*s3**6 + 105*s2*s3**3*s4*s5 ++ 76*s2*s3**2*s4**3 + 625*s2*s3*s5**3 - 500*s2*s4**2*s5**2 - 58*s3**5*s5 + 3*s3**4*s4**2 ++ 2750*s3**2*s4*s5**2 - 2400*s3*s4**3*s5 + 512*s4**5 - 3125*s5**4), + lambda s1, s2, s3, s4, s5: (16*s1**8*s3**2*s5**2 - 8*s1**8*s3*s4**2*s5 + s1**8*s4**4 - 8*s1**7*s2**2*s3*s5**2 ++ 2*s1**7*s2**2*s4**2*s5 - 48*s1**7*s3*s4*s5**2 + 12*s1**7*s4**3*s5 + s1**6*s2**4*s5**2 ++ 12*s1**6*s2**2*s4*s5**2 - 144*s1**6*s2*s3**2*s5**2 + 88*s1**6*s2*s3*s4**2*s5 - 13*s1**6*s2*s4**4 ++ 56*s1**6*s3*s5**3 + 86*s1**6*s4**2*s5**2 + 72*s1**5*s2**3*s3*s5**2 - 22*s1**5*s2**3*s4**2*s5 +- 4*s1**5*s2**2*s3**2*s4*s5 + s1**5*s2**2*s3*s4**3 - 14*s1**5*s2**2*s5**3 + 304*s1**5*s2*s3*s4*s5**2 +- 148*s1**5*s2*s4**3*s5 + 152*s1**5*s3**3*s5**2 - 54*s1**5*s3**2*s4**2*s5 + 5*s1**5*s3*s4**4 +- 468*s1**5*s4*s5**3 - 9*s1**4*s2**5*s5**2 + s1**4*s2**4*s3*s4*s5 - 76*s1**4*s2**3*s4*s5**2 ++ 370*s1**4*s2**2*s3**2*s5**2 - 287*s1**4*s2**2*s3*s4**2*s5 + 65*s1**4*s2**2*s4**4 +- 28*s1**4*s2*s3**3*s4*s5 + 5*s1**4*s2*s3**2*s4**3 - 200*s1**4*s2*s3*s5**3 - 294*s1**4*s2*s4**2*s5**2 ++ 8*s1**4*s3**5*s5 - 2*s1**4*s3**4*s4**2 - 676*s1**4*s3**2*s4*s5**2 + 180*s1**4*s3*s4**3*s5 ++ 17*s1**4*s4**5 + 625*s1**4*s5**4 - 210*s1**3*s2**4*s3*s5**2 + 76*s1**3*s2**4*s4**2*s5 ++ 43*s1**3*s2**3*s3**2*s4*s5 - 15*s1**3*s2**3*s3*s4**3 + 50*s1**3*s2**3*s5**3 - 6*s1**3*s2**2*s3**4*s5 ++ 2*s1**3*s2**2*s3**3*s4**2 - 397*s1**3*s2**2*s3*s4*s5**2 + 514*s1**3*s2**2*s4**3*s5 +- 700*s1**3*s2*s3**3*s5**2 + 447*s1**3*s2*s3**2*s4**2*s5 - 118*s1**3*s2*s3*s4**4 + +2300*s1**3*s2*s4*s5**3 - 12*s1**3*s3**4*s4*s5 + 6*s1**3*s3**3*s4**3 + 250*s1**3*s3**2*s5**3 ++ 1470*s1**3*s3*s4**2*s5**2 - 276*s1**3*s4**4*s5 + 27*s1**2*s2**6*s5**2 - 9*s1**2*s2**5*s3*s4*s5 ++ s1**2*s2**5*s4**3 + s1**2*s2**4*s3**3*s5 + 141*s1**2*s2**4*s4*s5**2 - 185*s1**2*s2**3*s3**2*s5**2 ++ 168*s1**2*s2**3*s3*s4**2*s5 - 128*s1**2*s2**3*s4**4 + 93*s1**2*s2**2*s3**3*s4*s5 ++ 19*s1**2*s2**2*s3**2*s4**3 - 125*s1**2*s2**2*s3*s5**3 - 610*s1**2*s2**2*s4**2*s5**2 +- 36*s1**2*s2*s3**5*s5 + 5*s1**2*s2*s3**4*s4**2 + 1995*s1**2*s2*s3**2*s4*s5**2 - 1174*s1**2*s2*s3*s4**3*s5 +- 16*s1**2*s2*s4**5 - 3125*s1**2*s2*s5**4 + 375*s1**2*s3**4*s5**2 - 172*s1**2*s3**3*s4**2*s5 ++ 82*s1**2*s3**2*s4**4 - 3500*s1**2*s3*s4*s5**3 - 1450*s1**2*s4**3*s5**2 + 198*s1*s2**5*s3*s5**2 +- 78*s1*s2**5*s4**2*s5 - 95*s1*s2**4*s3**2*s4*s5 + 44*s1*s2**4*s3*s4**3 + 25*s1*s2**3*s3**4*s5 +- 15*s1*s2**3*s3**3*s4**2 + 15*s1*s2**3*s3*s4*s5**2 - 384*s1*s2**3*s4**3*s5 + s1*s2**2*s3**5*s4 ++ 525*s1*s2**2*s3**3*s5**2 - 528*s1*s2**2*s3**2*s4**2*s5 + 384*s1*s2**2*s3*s4**4 - +1750*s1*s2**2*s4*s5**3 - 29*s1*s2*s3**4*s4*s5 - 118*s1*s2*s3**3*s4**3 + 625*s1*s2*s3**2*s5**3 +- 850*s1*s2*s3*s4**2*s5**2 + 1760*s1*s2*s4**4*s5 + 38*s1*s3**6*s5 + 5*s1*s3**5*s4**2 +- 2050*s1*s3**3*s4*s5**2 + 780*s1*s3**2*s4**3*s5 - 192*s1*s3*s4**5 + 3125*s1*s3*s5**4 ++ 7500*s1*s4**2*s5**3 - 27*s2**7*s5**2 + 18*s2**6*s3*s4*s5 - 4*s2**6*s4**3 - 4*s2**5*s3**3*s5 ++ s2**5*s3**2*s4**2 - 99*s2**5*s4*s5**2 - 150*s2**4*s3**2*s5**2 + 196*s2**4*s3*s4**2*s5 ++ 48*s2**4*s4**4 + 12*s2**3*s3**3*s4*s5 - 128*s2**3*s3**2*s4**3 + 1200*s2**3*s4**2*s5**2 +- 12*s2**2*s3**5*s5 + 65*s2**2*s3**4*s4**2 - 725*s2**2*s3**2*s4*s5**2 - 160*s2**2*s3*s4**3*s5 +- 192*s2**2*s4**5 + 3125*s2**2*s5**4 - 13*s2*s3**6*s4 - 125*s2*s3**4*s5**2 + 590*s2*s3**3*s4**2*s5 +- 16*s2*s3**2*s4**4 - 1250*s2*s3*s4*s5**3 - 2000*s2*s4**3*s5**2 + s3**8 - 124*s3**5*s4*s5 ++ 17*s3**4*s4**3 + 3250*s3**2*s4**2*s5**2 - 1600*s3*s4**4*s5 + 256*s4**6 - 9375*s4*s5**4) + ], + (6, 1): [ + lambda s1, s2, s3, s4, s5, s6: (8*s1*s5 - 2*s2*s4 - 18*s6), + lambda s1, s2, s3, s4, s5, s6: (-50*s1**2*s4*s6 + 40*s1**2*s5**2 + 30*s1*s2*s3*s6 - 14*s1*s2*s4*s5 - 6*s1*s3**2*s5 ++ 2*s1*s3*s4**2 - 30*s1*s5*s6 - 8*s2**3*s6 + 2*s2**2*s3*s5 + s2**2*s4**2 + 114*s2*s4*s6 +- 50*s2*s5**2 - 54*s3**2*s6 + 30*s3*s4*s5 - 8*s4**3 - 135*s6**2), + lambda s1, s2, s3, s4, s5, s6: (125*s1**3*s3*s6**2 - 400*s1**3*s4*s5*s6 + 160*s1**3*s5**3 - 50*s1**2*s2**2*s6**2 + +190*s1**2*s2*s3*s5*s6 + 120*s1**2*s2*s4**2*s6 - 80*s1**2*s2*s4*s5**2 - 15*s1**2*s3**2*s4*s6 +- 40*s1**2*s3**2*s5**2 + 21*s1**2*s3*s4**2*s5 - 2*s1**2*s4**4 + 900*s1**2*s4*s6**2 +- 80*s1**2*s5**2*s6 - 44*s1*s2**3*s5*s6 - 66*s1*s2**2*s3*s4*s6 + 21*s1*s2**2*s3*s5**2 ++ 6*s1*s2**2*s4**2*s5 + 9*s1*s2*s3**3*s6 + 5*s1*s2*s3**2*s4*s5 - 2*s1*s2*s3*s4**3 +- 990*s1*s2*s3*s6**2 + 920*s1*s2*s4*s5*s6 - 400*s1*s2*s5**3 - 135*s1*s3**2*s5*s6 - +126*s1*s3*s4**2*s6 + 190*s1*s3*s4*s5**2 - 44*s1*s4**3*s5 - 2070*s1*s5*s6**2 + 16*s2**4*s4*s6 +- 2*s2**4*s5**2 - 2*s2**3*s3**2*s6 - 2*s2**3*s3*s4*s5 + 304*s2**3*s6**2 - 126*s2**2*s3*s5*s6 +- 232*s2**2*s4**2*s6 + 120*s2**2*s4*s5**2 + 198*s2*s3**2*s4*s6 - 15*s2*s3**2*s5**2 +- 66*s2*s3*s4**2*s5 + 16*s2*s4**4 - 1440*s2*s4*s6**2 + 900*s2*s5**2*s6 - 27*s3**4*s6 ++ 9*s3**3*s4*s5 - 2*s3**2*s4**3 + 1350*s3**2*s6**2 - 990*s3*s4*s5*s6 + 125*s3*s5**3 ++ 304*s4**3*s6 - 50*s4**2*s5**2 + 3240*s6**3), + lambda s1, s2, s3, s4, s5, s6: (500*s1**4*s3*s5*s6**2 + 625*s1**4*s4**2*s6**2 - 1400*s1**4*s4*s5**2*s6 + 400*s1**4*s5**4 +- 200*s1**3*s2**2*s5*s6**2 - 875*s1**3*s2*s3*s4*s6**2 + 640*s1**3*s2*s3*s5**2*s6 + +630*s1**3*s2*s4**2*s5*s6 - 264*s1**3*s2*s4*s5**3 + 90*s1**3*s3**2*s4*s5*s6 - 136*s1**3*s3**2*s5**3 +- 50*s1**3*s3*s4**3*s6 + 76*s1**3*s3*s4**2*s5**2 - 1125*s1**3*s3*s6**3 - 8*s1**3*s4**4*s5 ++ 2550*s1**3*s4*s5*s6**2 - 200*s1**3*s5**3*s6 + 250*s1**2*s2**3*s4*s6**2 - 144*s1**2*s2**3*s5**2*s6 ++ 225*s1**2*s2**2*s3**2*s6**2 - 354*s1**2*s2**2*s3*s4*s5*s6 + 76*s1**2*s2**2*s3*s5**3 +- 70*s1**2*s2**2*s4**3*s6 + 41*s1**2*s2**2*s4**2*s5**2 + 450*s1**2*s2**2*s6**3 - 54*s1**2*s2*s3**3*s5*s6 ++ 45*s1**2*s2*s3**2*s4**2*s6 + 30*s1**2*s2*s3**2*s4*s5**2 - 19*s1**2*s2*s3*s4**3*s5 +- 2880*s1**2*s2*s3*s5*s6**2 + 2*s1**2*s2*s4**5 - 3480*s1**2*s2*s4**2*s6**2 + 4692*s1**2*s2*s4*s5**2*s6 +- 1400*s1**2*s2*s5**4 + 9*s1**2*s3**4*s5**2 - 6*s1**2*s3**3*s4**2*s5 + s1**2*s3**2*s4**4 ++ 1485*s1**2*s3**2*s4*s6**2 - 522*s1**2*s3**2*s5**2*s6 - 1257*s1**2*s3*s4**2*s5*s6 ++ 640*s1**2*s3*s4*s5**3 + 218*s1**2*s4**4*s6 - 144*s1**2*s4**3*s5**2 + 1350*s1**2*s4*s6**3 +- 5175*s1**2*s5**2*s6**2 - 120*s1*s2**4*s3*s6**2 + 68*s1*s2**4*s4*s5*s6 - 8*s1*s2**4*s5**3 ++ 46*s1*s2**3*s3**2*s5*s6 + 28*s1*s2**3*s3*s4**2*s6 - 19*s1*s2**3*s3*s4*s5**2 + 868*s1*s2**3*s5*s6**2 +- 9*s1*s2**2*s3**3*s4*s6 - 6*s1*s2**2*s3**3*s5**2 + 3*s1*s2**2*s3**2*s4**2*s5 + 2484*s1*s2**2*s3*s4*s6**2 +- 1257*s1*s2**2*s3*s5**2*s6 - 1356*s1*s2**2*s4**2*s5*s6 + 630*s1*s2**2*s4*s5**3 - +891*s1*s2*s3**3*s6**2 + 882*s1*s2*s3**2*s4*s5*s6 + 90*s1*s2*s3**2*s5**3 + 84*s1*s2*s3*s4**3*s6 +- 354*s1*s2*s3*s4**2*s5**2 + 3240*s1*s2*s3*s6**3 + 68*s1*s2*s4**4*s5 - 4392*s1*s2*s4*s5*s6**2 ++ 2550*s1*s2*s5**3*s6 + 54*s1*s3**4*s5*s6 - 54*s1*s3**3*s4**2*s6 - 54*s1*s3**3*s4*s5**2 ++ 46*s1*s3**2*s4**3*s5 + 2727*s1*s3**2*s5*s6**2 - 8*s1*s3*s4**5 + 756*s1*s3*s4**2*s6**2 +- 2880*s1*s3*s4*s5**2*s6 + 500*s1*s3*s5**4 + 868*s1*s4**3*s5*s6 - 200*s1*s4**2*s5**3 ++ 8100*s1*s5*s6**3 + 16*s2**6*s6**2 - 8*s2**5*s3*s5*s6 - 8*s2**5*s4**2*s6 + 2*s2**5*s4*s5**2 ++ 2*s2**4*s3**2*s4*s6 + s2**4*s3**2*s5**2 - 688*s2**4*s4*s6**2 + 218*s2**4*s5**2*s6 ++ 234*s2**3*s3**2*s6**2 + 84*s2**3*s3*s4*s5*s6 - 50*s2**3*s3*s5**3 + 168*s2**3*s4**3*s6 +- 70*s2**3*s4**2*s5**2 - 1224*s2**3*s6**3 - 54*s2**2*s3**3*s5*s6 - 144*s2**2*s3**2*s4**2*s6 ++ 45*s2**2*s3**2*s4*s5**2 + 28*s2**2*s3*s4**3*s5 + 756*s2**2*s3*s5*s6**2 - 8*s2**2*s4**5 ++ 4320*s2**2*s4**2*s6**2 - 3480*s2**2*s4*s5**2*s6 + 625*s2**2*s5**4 + 27*s2*s3**4*s4*s6 +- 9*s2*s3**3*s4**2*s5 + 2*s2*s3**2*s4**4 - 4752*s2*s3**2*s4*s6**2 + 1485*s2*s3**2*s5**2*s6 ++ 2484*s2*s3*s4**2*s5*s6 - 875*s2*s3*s4*s5**3 - 688*s2*s4**4*s6 + 250*s2*s4**3*s5**2 +- 4536*s2*s4*s6**3 + 1350*s2*s5**2*s6**2 + 972*s3**4*s6**2 - 891*s3**3*s4*s5*s6 + +234*s3**2*s4**3*s6 + 225*s3**2*s4**2*s5**2 - 1944*s3**2*s6**3 - 120*s3*s4**4*s5 + +3240*s3*s4*s5*s6**2 - 1125*s3*s5**3*s6 + 16*s4**6 - 1224*s4**3*s6**2 + 450*s4**2*s5**2*s6), + lambda s1, s2, s3, s4, s5, s6: (-3125*s1**6*s6**4 + 2500*s1**5*s2*s5*s6**3 + 625*s1**5*s3*s4*s6**3 - 500*s1**5*s3*s5**2*s6**2 ++ 2750*s1**5*s4**2*s5*s6**2 - 2400*s1**5*s4*s5**3*s6 + 512*s1**5*s5**5 - 750*s1**4*s2**2*s4*s6**3 +- 550*s1**4*s2**2*s5**2*s6**2 - 375*s1**4*s2*s3**2*s6**3 - 3075*s1**4*s2*s3*s4*s5*s6**2 ++ 1640*s1**4*s2*s3*s5**3*s6 - 850*s1**4*s2*s4**3*s6**2 + 1220*s1**4*s2*s4**2*s5**2*s6 +- 384*s1**4*s2*s4*s5**4 + 22500*s1**4*s2*s6**4 + 525*s1**4*s3**3*s5*s6**2 - 325*s1**4*s3**2*s4**2*s6**2 ++ 260*s1**4*s3**2*s4*s5**2*s6 - 256*s1**4*s3**2*s5**4 + 105*s1**4*s3*s4**3*s5*s6 + +76*s1**4*s3*s4**2*s5**3 + 375*s1**4*s3*s5*s6**3 - 58*s1**4*s4**5*s6 + 3*s1**4*s4**4*s5**2 +- 12750*s1**4*s4**2*s6**3 + 3700*s1**4*s4*s5**2*s6**2 + 640*s1**4*s5**4*s6 + 350*s1**3*s2**3*s3*s6**3 ++ 1090*s1**3*s2**3*s4*s5*s6**2 - 364*s1**3*s2**3*s5**3*s6 + 305*s1**3*s2**2*s3**2*s5*s6**2 ++ 1340*s1**3*s2**2*s3*s4**2*s6**2 - 901*s1**3*s2**2*s3*s4*s5**2*s6 + 76*s1**3*s2**2*s3*s5**4 +- 234*s1**3*s2**2*s4**3*s5*s6 + 102*s1**3*s2**2*s4**2*s5**3 - 16650*s1**3*s2**2*s5*s6**3 ++ 180*s1**3*s2*s3**3*s4*s6**2 - 366*s1**3*s2*s3**3*s5**2*s6 - 231*s1**3*s2*s3**2*s4**2*s5*s6 ++ 212*s1**3*s2*s3**2*s4*s5**3 + 112*s1**3*s2*s3*s4**4*s6 - 89*s1**3*s2*s3*s4**3*s5**2 ++ 10950*s1**3*s2*s3*s4*s6**3 + 1555*s1**3*s2*s3*s5**2*s6**2 + 6*s1**3*s2*s4**5*s5 +- 9540*s1**3*s2*s4**2*s5*s6**2 + 9016*s1**3*s2*s4*s5**3*s6 - 2400*s1**3*s2*s5**5 - +108*s1**3*s3**5*s6**2 + 117*s1**3*s3**4*s4*s5*s6 + 32*s1**3*s3**4*s5**3 - 31*s1**3*s3**3*s4**3*s6 +- 51*s1**3*s3**3*s4**2*s5**2 - 2025*s1**3*s3**3*s6**3 + 19*s1**3*s3**2*s4**4*s5 + +2955*s1**3*s3**2*s4*s5*s6**2 - 1436*s1**3*s3**2*s5**3*s6 - 2*s1**3*s3*s4**6 + 2770*s1**3*s3*s4**3*s6**2 +- 5123*s1**3*s3*s4**2*s5**2*s6 + 1640*s1**3*s3*s4*s5**4 - 40500*s1**3*s3*s6**4 + 914*s1**3*s4**4*s5*s6 +- 364*s1**3*s4**3*s5**3 + 53550*s1**3*s4*s5*s6**3 - 17930*s1**3*s5**3*s6**2 - 56*s1**2*s2**5*s6**3 +- 318*s1**2*s2**4*s3*s5*s6**2 - 352*s1**2*s2**4*s4**2*s6**2 + 166*s1**2*s2**4*s4*s5**2*s6 ++ 3*s1**2*s2**4*s5**4 - 574*s1**2*s2**3*s3**2*s4*s6**2 + 347*s1**2*s2**3*s3**2*s5**2*s6 ++ 194*s1**2*s2**3*s3*s4**2*s5*s6 - 89*s1**2*s2**3*s3*s4*s5**3 - 8*s1**2*s2**3*s4**4*s6 ++ 4*s1**2*s2**3*s4**3*s5**2 + 560*s1**2*s2**3*s4*s6**3 + 3662*s1**2*s2**3*s5**2*s6**2 ++ 162*s1**2*s2**2*s3**4*s6**2 + 33*s1**2*s2**2*s3**3*s4*s5*s6 - 51*s1**2*s2**2*s3**3*s5**3 +- 32*s1**2*s2**2*s3**2*s4**3*s6 + 28*s1**2*s2**2*s3**2*s4**2*s5**2 + 270*s1**2*s2**2*s3**2*s6**3 +- 2*s1**2*s2**2*s3*s4**4*s5 + 4872*s1**2*s2**2*s3*s4*s5*s6**2 - 5123*s1**2*s2**2*s3*s5**3*s6 ++ 2144*s1**2*s2**2*s4**3*s6**2 - 2812*s1**2*s2**2*s4**2*s5**2*s6 + 1220*s1**2*s2**2*s4*s5**4 +- 37800*s1**2*s2**2*s6**4 - 27*s1**2*s2*s3**5*s5*s6 + 9*s1**2*s2*s3**4*s4**2*s6 + +3*s1**2*s2*s3**4*s4*s5**2 - s1**2*s2*s3**3*s4**3*s5 - 3078*s1**2*s2*s3**3*s5*s6**2 +- 4014*s1**2*s2*s3**2*s4**2*s6**2 + 5412*s1**2*s2*s3**2*s4*s5**2*s6 + 260*s1**2*s2*s3**2*s5**4 +- 310*s1**2*s2*s3*s4**3*s5*s6 - 901*s1**2*s2*s3*s4**2*s5**3 - 3780*s1**2*s2*s3*s5*s6**3 ++ 166*s1**2*s2*s4**4*s5**2 + 40320*s1**2*s2*s4**2*s6**3 - 25344*s1**2*s2*s4*s5**2*s6**2 ++ 3700*s1**2*s2*s5**4*s6 + 918*s1**2*s3**4*s4*s6**2 + 27*s1**2*s3**4*s5**2*s6 - 342*s1**2*s3**3*s4**2*s5*s6 +- 366*s1**2*s3**3*s4*s5**3 + 32*s1**2*s3**2*s4**4*s6 + 347*s1**2*s3**2*s4**3*s5**2 +- 4590*s1**2*s3**2*s4*s6**3 + 594*s1**2*s3**2*s5**2*s6**2 - 94*s1**2*s3*s4**5*s5 + +3618*s1**2*s3*s4**2*s5*s6**2 + 1555*s1**2*s3*s4*s5**3*s6 - 500*s1**2*s3*s5**5 + 8*s1**2*s4**7 +- 7192*s1**2*s4**4*s6**2 + 3662*s1**2*s4**3*s5**2*s6 - 550*s1**2*s4**2*s5**4 - 48600*s1**2*s4*s6**4 ++ 1080*s1**2*s5**2*s6**3 + 48*s1*s2**6*s5*s6**2 + 264*s1*s2**5*s3*s4*s6**2 - 94*s1*s2**5*s3*s5**2*s6 +- 24*s1*s2**5*s4**2*s5*s6 + 6*s1*s2**5*s4*s5**3 - 66*s1*s2**4*s3**3*s6**2 - 50*s1*s2**4*s3**2*s4*s5*s6 ++ 19*s1*s2**4*s3**2*s5**3 + 8*s1*s2**4*s3*s4**3*s6 - 2*s1*s2**4*s3*s4**2*s5**2 - 552*s1*s2**4*s3*s6**3 +- 2560*s1*s2**4*s4*s5*s6**2 + 914*s1*s2**4*s5**3*s6 + 15*s1*s2**3*s3**4*s5*s6 - 2*s1*s2**3*s3**3*s4**2*s6 +- s1*s2**3*s3**3*s4*s5**2 + 1602*s1*s2**3*s3**2*s5*s6**2 - 608*s1*s2**3*s3*s4**2*s6**2 +- 310*s1*s2**3*s3*s4*s5**2*s6 + 105*s1*s2**3*s3*s5**4 + 600*s1*s2**3*s4**3*s5*s6 - +234*s1*s2**3*s4**2*s5**3 + 31368*s1*s2**3*s5*s6**3 + 756*s1*s2**2*s3**3*s4*s6**2 - +342*s1*s2**2*s3**3*s5**2*s6 + 216*s1*s2**2*s3**2*s4**2*s5*s6 - 231*s1*s2**2*s3**2*s4*s5**3 +- 192*s1*s2**2*s3*s4**4*s6 + 194*s1*s2**2*s3*s4**3*s5**2 - 39096*s1*s2**2*s3*s4*s6**3 ++ 3618*s1*s2**2*s3*s5**2*s6**2 - 24*s1*s2**2*s4**5*s5 + 9408*s1*s2**2*s4**2*s5*s6**2 +- 9540*s1*s2**2*s4*s5**3*s6 + 2750*s1*s2**2*s5**5 - 162*s1*s2*s3**5*s6**2 - 378*s1*s2*s3**4*s4*s5*s6 ++ 117*s1*s2*s3**4*s5**3 + 150*s1*s2*s3**3*s4**3*s6 + 33*s1*s2*s3**3*s4**2*s5**2 + +10044*s1*s2*s3**3*s6**3 - 50*s1*s2*s3**2*s4**4*s5 - 8640*s1*s2*s3**2*s4*s5*s6**2 + +2955*s1*s2*s3**2*s5**3*s6 + 8*s1*s2*s3*s4**6 + 6144*s1*s2*s3*s4**3*s6**2 + 4872*s1*s2*s3*s4**2*s5**2*s6 +- 3075*s1*s2*s3*s4*s5**4 + 174960*s1*s2*s3*s6**4 - 2560*s1*s2*s4**4*s5*s6 + 1090*s1*s2*s4**3*s5**3 +- 148824*s1*s2*s4*s5*s6**3 + 53550*s1*s2*s5**3*s6**2 + 81*s1*s3**6*s5*s6 - 27*s1*s3**5*s4**2*s6 +- 27*s1*s3**5*s4*s5**2 + 15*s1*s3**4*s4**3*s5 + 2430*s1*s3**4*s5*s6**2 - 2*s1*s3**3*s4**5 +- 2052*s1*s3**3*s4**2*s6**2 - 3078*s1*s3**3*s4*s5**2*s6 + 525*s1*s3**3*s5**4 + 1602*s1*s3**2*s4**3*s5*s6 ++ 305*s1*s3**2*s4**2*s5**3 + 18144*s1*s3**2*s5*s6**3 - 104*s1*s3*s4**5*s6 - 318*s1*s3*s4**4*s5**2 +- 33696*s1*s3*s4**2*s6**3 - 3780*s1*s3*s4*s5**2*s6**2 + 375*s1*s3*s5**4*s6 + 48*s1*s4**6*s5 ++ 31368*s1*s4**3*s5*s6**2 - 16650*s1*s4**2*s5**3*s6 + 2500*s1*s4*s5**5 + 77760*s1*s5*s6**4 +- 32*s2**7*s4*s6**2 + 8*s2**7*s5**2*s6 + 8*s2**6*s3**2*s6**2 + 8*s2**6*s3*s4*s5*s6 +- 2*s2**6*s3*s5**3 + 96*s2**6*s6**3 - 2*s2**5*s3**3*s5*s6 - 104*s2**5*s3*s5*s6**2 ++ 416*s2**5*s4**2*s6**2 - 58*s2**5*s5**4 - 312*s2**4*s3**2*s4*s6**2 + 32*s2**4*s3**2*s5**2*s6 +- 192*s2**4*s3*s4**2*s5*s6 + 112*s2**4*s3*s4*s5**3 - 8*s2**4*s4**3*s5**2 + 4224*s2**4*s4*s6**3 +- 7192*s2**4*s5**2*s6**2 + 54*s2**3*s3**4*s6**2 + 150*s2**3*s3**3*s4*s5*s6 - 31*s2**3*s3**3*s5**3 +- 32*s2**3*s3**2*s4**2*s5**2 - 864*s2**3*s3**2*s6**3 + 8*s2**3*s3*s4**4*s5 + 6144*s2**3*s3*s4*s5*s6**2 ++ 2770*s2**3*s3*s5**3*s6 - 4032*s2**3*s4**3*s6**2 + 2144*s2**3*s4**2*s5**2*s6 - 850*s2**3*s4*s5**4 +- 16416*s2**3*s6**4 - 27*s2**2*s3**5*s5*s6 + 9*s2**2*s3**4*s4*s5**2 - 2*s2**2*s3**3*s4**3*s5 +- 2052*s2**2*s3**3*s5*s6**2 + 2376*s2**2*s3**2*s4**2*s6**2 - 4014*s2**2*s3**2*s4*s5**2*s6 +- 325*s2**2*s3**2*s5**4 - 608*s2**2*s3*s4**3*s5*s6 + 1340*s2**2*s3*s4**2*s5**3 - 33696*s2**2*s3*s5*s6**3 ++ 416*s2**2*s4**5*s6 - 352*s2**2*s4**4*s5**2 - 6048*s2**2*s4**2*s6**3 + 40320*s2**2*s4*s5**2*s6**2 +- 12750*s2**2*s5**4*s6 - 324*s2*s3**4*s4*s6**2 + 918*s2*s3**4*s5**2*s6 + 756*s2*s3**3*s4**2*s5*s6 ++ 180*s2*s3**3*s4*s5**3 - 312*s2*s3**2*s4**4*s6 - 574*s2*s3**2*s4**3*s5**2 + 43416*s2*s3**2*s4*s6**3 +- 4590*s2*s3**2*s5**2*s6**2 + 264*s2*s3*s4**5*s5 - 39096*s2*s3*s4**2*s5*s6**2 + 10950*s2*s3*s4*s5**3*s6 ++ 625*s2*s3*s5**5 - 32*s2*s4**7 + 4224*s2*s4**4*s6**2 + 560*s2*s4**3*s5**2*s6 - 750*s2*s4**2*s5**4 ++ 85536*s2*s4*s6**4 - 48600*s2*s5**2*s6**3 - 162*s3**5*s4*s5*s6 - 108*s3**5*s5**3 ++ 54*s3**4*s4**3*s6 + 162*s3**4*s4**2*s5**2 - 11664*s3**4*s6**3 - 66*s3**3*s4**4*s5 ++ 10044*s3**3*s4*s5*s6**2 - 2025*s3**3*s5**3*s6 + 8*s3**2*s4**6 - 864*s3**2*s4**3*s6**2 ++ 270*s3**2*s4**2*s5**2*s6 - 375*s3**2*s4*s5**4 - 163296*s3**2*s6**4 - 552*s3*s4**4*s5*s6 ++ 350*s3*s4**3*s5**3 + 174960*s3*s4*s5*s6**3 - 40500*s3*s5**3*s6**2 + 96*s4**6*s6 +- 56*s4**5*s5**2 - 16416*s4**3*s6**3 - 37800*s4**2*s5**2*s6**2 + 22500*s4*s5**4*s6 +- 3125*s5**6 - 93312*s6**5), + lambda s1, s2, s3, s4, s5, s6: (-9375*s1**7*s5*s6**4 + 3125*s1**6*s2*s4*s6**4 + 7500*s1**6*s2*s5**2*s6**3 + 3125*s1**6*s3**2*s6**4 +- 1250*s1**6*s3*s4*s5*s6**3 - 2000*s1**6*s3*s5**3*s6**2 + 3250*s1**6*s4**2*s5**2*s6**2 +- 1600*s1**6*s4*s5**4*s6 + 256*s1**6*s5**6 + 40625*s1**6*s6**5 - 3125*s1**5*s2**2*s3*s6**4 +- 3500*s1**5*s2**2*s4*s5*s6**3 - 1450*s1**5*s2**2*s5**3*s6**2 - 1750*s1**5*s2*s3**2*s5*s6**3 ++ 625*s1**5*s2*s3*s4**2*s6**3 - 850*s1**5*s2*s3*s4*s5**2*s6**2 + 1760*s1**5*s2*s3*s5**4*s6 +- 2050*s1**5*s2*s4**3*s5*s6**2 + 780*s1**5*s2*s4**2*s5**3*s6 - 192*s1**5*s2*s4*s5**5 ++ 35000*s1**5*s2*s5*s6**4 + 1200*s1**5*s3**3*s5**2*s6**2 - 725*s1**5*s3**2*s4**2*s5*s6**2 +- 160*s1**5*s3**2*s4*s5**3*s6 - 192*s1**5*s3**2*s5**5 - 125*s1**5*s3*s4**4*s6**2 + +590*s1**5*s3*s4**3*s5**2*s6 - 16*s1**5*s3*s4**2*s5**4 - 20625*s1**5*s3*s4*s6**4 + +17250*s1**5*s3*s5**2*s6**3 - 124*s1**5*s4**5*s5*s6 + 17*s1**5*s4**4*s5**3 - 20250*s1**5*s4**2*s5*s6**3 ++ 1900*s1**5*s4*s5**3*s6**2 + 1344*s1**5*s5**5*s6 + 625*s1**4*s2**4*s6**4 + 2300*s1**4*s2**3*s3*s5*s6**3 ++ 250*s1**4*s2**3*s4**2*s6**3 + 1470*s1**4*s2**3*s4*s5**2*s6**2 - 276*s1**4*s2**3*s5**4*s6 +- 125*s1**4*s2**2*s3**2*s4*s6**3 - 610*s1**4*s2**2*s3**2*s5**2*s6**2 + 1995*s1**4*s2**2*s3*s4**2*s5*s6**2 +- 1174*s1**4*s2**2*s3*s4*s5**3*s6 - 16*s1**4*s2**2*s3*s5**5 + 375*s1**4*s2**2*s4**4*s6**2 +- 172*s1**4*s2**2*s4**3*s5**2*s6 + 82*s1**4*s2**2*s4**2*s5**4 - 7750*s1**4*s2**2*s4*s6**4 +- 46650*s1**4*s2**2*s5**2*s6**3 + 15*s1**4*s2*s3**3*s4*s5*s6**2 - 384*s1**4*s2*s3**3*s5**3*s6 ++ 525*s1**4*s2*s3**2*s4**3*s6**2 - 528*s1**4*s2*s3**2*s4**2*s5**2*s6 + 384*s1**4*s2*s3**2*s4*s5**4 +- 10125*s1**4*s2*s3**2*s6**4 - 29*s1**4*s2*s3*s4**4*s5*s6 - 118*s1**4*s2*s3*s4**3*s5**3 ++ 36700*s1**4*s2*s3*s4*s5*s6**3 + 2410*s1**4*s2*s3*s5**3*s6**2 + 38*s1**4*s2*s4**6*s6 ++ 5*s1**4*s2*s4**5*s5**2 + 5550*s1**4*s2*s4**3*s6**3 - 10040*s1**4*s2*s4**2*s5**2*s6**2 ++ 5800*s1**4*s2*s4*s5**4*s6 - 1600*s1**4*s2*s5**6 - 292500*s1**4*s2*s6**5 - 99*s1**4*s3**5*s5*s6**2 +- 150*s1**4*s3**4*s4**2*s6**2 + 196*s1**4*s3**4*s4*s5**2*s6 + 48*s1**4*s3**4*s5**4 ++ 12*s1**4*s3**3*s4**3*s5*s6 - 128*s1**4*s3**3*s4**2*s5**3 - 6525*s1**4*s3**3*s5*s6**3 +- 12*s1**4*s3**2*s4**5*s6 + 65*s1**4*s3**2*s4**4*s5**2 + 225*s1**4*s3**2*s4**2*s6**3 ++ 80*s1**4*s3**2*s4*s5**2*s6**2 - 13*s1**4*s3*s4**6*s5 + 5145*s1**4*s3*s4**3*s5*s6**2 +- 6746*s1**4*s3*s4**2*s5**3*s6 + 1760*s1**4*s3*s4*s5**5 - 103500*s1**4*s3*s5*s6**4 ++ s1**4*s4**8 + 954*s1**4*s4**5*s6**2 + 449*s1**4*s4**4*s5**2*s6 - 276*s1**4*s4**3*s5**4 ++ 70125*s1**4*s4**2*s6**4 + 58900*s1**4*s4*s5**2*s6**3 - 23310*s1**4*s5**4*s6**2 - +468*s1**3*s2**5*s5*s6**3 - 200*s1**3*s2**4*s3*s4*s6**3 - 294*s1**3*s2**4*s3*s5**2*s6**2 +- 676*s1**3*s2**4*s4**2*s5*s6**2 + 180*s1**3*s2**4*s4*s5**3*s6 + 17*s1**3*s2**4*s5**5 ++ 50*s1**3*s2**3*s3**3*s6**3 - 397*s1**3*s2**3*s3**2*s4*s5*s6**2 + 514*s1**3*s2**3*s3**2*s5**3*s6 +- 700*s1**3*s2**3*s3*s4**3*s6**2 + 447*s1**3*s2**3*s3*s4**2*s5**2*s6 - 118*s1**3*s2**3*s3*s4*s5**4 ++ 11700*s1**3*s2**3*s3*s6**4 - 12*s1**3*s2**3*s4**4*s5*s6 + 6*s1**3*s2**3*s4**3*s5**3 ++ 10360*s1**3*s2**3*s4*s5*s6**3 + 11404*s1**3*s2**3*s5**3*s6**2 + 141*s1**3*s2**2*s3**4*s5*s6**2 +- 185*s1**3*s2**2*s3**3*s4**2*s6**2 + 168*s1**3*s2**2*s3**3*s4*s5**2*s6 - 128*s1**3*s2**2*s3**3*s5**4 ++ 93*s1**3*s2**2*s3**2*s4**3*s5*s6 + 19*s1**3*s2**2*s3**2*s4**2*s5**3 + 5895*s1**3*s2**2*s3**2*s5*s6**3 +- 36*s1**3*s2**2*s3*s4**5*s6 + 5*s1**3*s2**2*s3*s4**4*s5**2 - 12020*s1**3*s2**2*s3*s4**2*s6**3 +- 5698*s1**3*s2**2*s3*s4*s5**2*s6**2 - 6746*s1**3*s2**2*s3*s5**4*s6 + 5064*s1**3*s2**2*s4**3*s5*s6**2 +- 762*s1**3*s2**2*s4**2*s5**3*s6 + 780*s1**3*s2**2*s4*s5**5 + 93900*s1**3*s2**2*s5*s6**4 ++ 198*s1**3*s2*s3**5*s4*s6**2 - 78*s1**3*s2*s3**5*s5**2*s6 - 95*s1**3*s2*s3**4*s4**2*s5*s6 ++ 44*s1**3*s2*s3**4*s4*s5**3 + 25*s1**3*s2*s3**3*s4**4*s6 - 15*s1**3*s2*s3**3*s4**3*s5**2 ++ 1935*s1**3*s2*s3**3*s4*s6**3 - 2808*s1**3*s2*s3**3*s5**2*s6**2 + s1**3*s2*s3**2*s4**5*s5 +- 4844*s1**3*s2*s3**2*s4**2*s5*s6**2 + 8996*s1**3*s2*s3**2*s4*s5**3*s6 - 160*s1**3*s2*s3**2*s5**5 +- 3616*s1**3*s2*s3*s4**4*s6**2 + 500*s1**3*s2*s3*s4**3*s5**2*s6 - 1174*s1**3*s2*s3*s4**2*s5**4 ++ 72900*s1**3*s2*s3*s4*s6**4 - 55665*s1**3*s2*s3*s5**2*s6**3 + 128*s1**3*s2*s4**5*s5*s6 ++ 180*s1**3*s2*s4**4*s5**3 + 16240*s1**3*s2*s4**2*s5*s6**3 - 9330*s1**3*s2*s4*s5**3*s6**2 ++ 1900*s1**3*s2*s5**5*s6 - 27*s1**3*s3**7*s6**2 + 18*s1**3*s3**6*s4*s5*s6 - 4*s1**3*s3**6*s5**3 +- 4*s1**3*s3**5*s4**3*s6 + s1**3*s3**5*s4**2*s5**2 + 54*s1**3*s3**5*s6**3 + 1143*s1**3*s3**4*s4*s5*s6**2 +- 820*s1**3*s3**4*s5**3*s6 + 923*s1**3*s3**3*s4**3*s6**2 + 57*s1**3*s3**3*s4**2*s5**2*s6 +- 384*s1**3*s3**3*s4*s5**4 + 29700*s1**3*s3**3*s6**4 - 547*s1**3*s3**2*s4**4*s5*s6 ++ 514*s1**3*s3**2*s4**3*s5**3 - 10305*s1**3*s3**2*s4*s5*s6**3 - 7405*s1**3*s3**2*s5**3*s6**2 ++ 108*s1**3*s3*s4**6*s6 - 148*s1**3*s3*s4**5*s5**2 - 11360*s1**3*s3*s4**3*s6**3 + +22209*s1**3*s3*s4**2*s5**2*s6**2 + 2410*s1**3*s3*s4*s5**4*s6 - 2000*s1**3*s3*s5**6 ++ 432000*s1**3*s3*s6**5 + 12*s1**3*s4**7*s5 - 22624*s1**3*s4**4*s5*s6**2 + 11404*s1**3*s4**3*s5**3*s6 +- 1450*s1**3*s4**2*s5**5 - 242100*s1**3*s4*s5*s6**4 + 58430*s1**3*s5**3*s6**3 + 56*s1**2*s2**6*s4*s6**3 ++ 86*s1**2*s2**6*s5**2*s6**2 - 14*s1**2*s2**5*s3**2*s6**3 + 304*s1**2*s2**5*s3*s4*s5*s6**2 +- 148*s1**2*s2**5*s3*s5**3*s6 + 152*s1**2*s2**5*s4**3*s6**2 - 54*s1**2*s2**5*s4**2*s5**2*s6 ++ 5*s1**2*s2**5*s4*s5**4 - 2472*s1**2*s2**5*s6**4 - 76*s1**2*s2**4*s3**3*s5*s6**2 ++ 370*s1**2*s2**4*s3**2*s4**2*s6**2 - 287*s1**2*s2**4*s3**2*s4*s5**2*s6 + 65*s1**2*s2**4*s3**2*s5**4 +- 28*s1**2*s2**4*s3*s4**3*s5*s6 + 5*s1**2*s2**4*s3*s4**2*s5**3 - 8092*s1**2*s2**4*s3*s5*s6**3 ++ 8*s1**2*s2**4*s4**5*s6 - 2*s1**2*s2**4*s4**4*s5**2 + 1096*s1**2*s2**4*s4**2*s6**3 +- 5144*s1**2*s2**4*s4*s5**2*s6**2 + 449*s1**2*s2**4*s5**4*s6 - 210*s1**2*s2**3*s3**4*s4*s6**2 ++ 76*s1**2*s2**3*s3**4*s5**2*s6 + 43*s1**2*s2**3*s3**3*s4**2*s5*s6 - 15*s1**2*s2**3*s3**3*s4*s5**3 +- 6*s1**2*s2**3*s3**2*s4**4*s6 + 2*s1**2*s2**3*s3**2*s4**3*s5**2 + 1962*s1**2*s2**3*s3**2*s4*s6**3 ++ 3181*s1**2*s2**3*s3**2*s5**2*s6**2 + 1684*s1**2*s2**3*s3*s4**2*s5*s6**2 + 500*s1**2*s2**3*s3*s4*s5**3*s6 ++ 590*s1**2*s2**3*s3*s5**5 - 168*s1**2*s2**3*s4**4*s6**2 - 494*s1**2*s2**3*s4**3*s5**2*s6 +- 172*s1**2*s2**3*s4**2*s5**4 - 22080*s1**2*s2**3*s4*s6**4 + 58894*s1**2*s2**3*s5**2*s6**3 ++ 27*s1**2*s2**2*s3**6*s6**2 - 9*s1**2*s2**2*s3**5*s4*s5*s6 + s1**2*s2**2*s3**5*s5**3 ++ s1**2*s2**2*s3**4*s4**3*s6 - 486*s1**2*s2**2*s3**4*s6**3 + 1071*s1**2*s2**2*s3**3*s4*s5*s6**2 ++ 57*s1**2*s2**2*s3**3*s5**3*s6 + 2262*s1**2*s2**2*s3**2*s4**3*s6**2 - 2742*s1**2*s2**2*s3**2*s4**2*s5**2*s6 +- 528*s1**2*s2**2*s3**2*s4*s5**4 - 29160*s1**2*s2**2*s3**2*s6**4 + 772*s1**2*s2**2*s3*s4**4*s5*s6 ++ 447*s1**2*s2**2*s3*s4**3*s5**3 - 96732*s1**2*s2**2*s3*s4*s5*s6**3 + 22209*s1**2*s2**2*s3*s5**3*s6**2 +- 160*s1**2*s2**2*s4**6*s6 - 54*s1**2*s2**2*s4**5*s5**2 - 7992*s1**2*s2**2*s4**3*s6**3 ++ 8634*s1**2*s2**2*s4**2*s5**2*s6**2 - 10040*s1**2*s2**2*s4*s5**4*s6 + 3250*s1**2*s2**2*s5**6 ++ 529200*s1**2*s2**2*s6**5 - 351*s1**2*s2*s3**5*s5*s6**2 - 1215*s1**2*s2*s3**4*s4**2*s6**2 +- 360*s1**2*s2*s3**4*s4*s5**2*s6 + 196*s1**2*s2*s3**4*s5**4 + 741*s1**2*s2*s3**3*s4**3*s5*s6 ++ 168*s1**2*s2*s3**3*s4**2*s5**3 + 11718*s1**2*s2*s3**3*s5*s6**3 - 106*s1**2*s2*s3**2*s4**5*s6 +- 287*s1**2*s2*s3**2*s4**4*s5**2 + 22572*s1**2*s2*s3**2*s4**2*s6**3 - 8892*s1**2*s2*s3**2*s4*s5**2*s6**2 ++ 80*s1**2*s2*s3**2*s5**4*s6 + 88*s1**2*s2*s3*s4**6*s5 + 22144*s1**2*s2*s3*s4**3*s5*s6**2 +- 5698*s1**2*s2*s3*s4**2*s5**3*s6 - 850*s1**2*s2*s3*s4*s5**5 + 169560*s1**2*s2*s3*s5*s6**4 +- 8*s1**2*s2*s4**8 + 3032*s1**2*s2*s4**5*s6**2 - 5144*s1**2*s2*s4**4*s5**2*s6 + 1470*s1**2*s2*s4**3*s5**4 +- 249480*s1**2*s2*s4**2*s6**4 - 105390*s1**2*s2*s4*s5**2*s6**3 + 58900*s1**2*s2*s5**4*s6**2 ++ 162*s1**2*s3**6*s4*s6**2 + 216*s1**2*s3**6*s5**2*s6 - 216*s1**2*s3**5*s4**2*s5*s6 +- 78*s1**2*s3**5*s4*s5**3 + 36*s1**2*s3**4*s4**4*s6 + 76*s1**2*s3**4*s4**3*s5**2 - +3564*s1**2*s3**4*s4*s6**3 + 8802*s1**2*s3**4*s5**2*s6**2 - 22*s1**2*s3**3*s4**5*s5 +- 11475*s1**2*s3**3*s4**2*s5*s6**2 - 2808*s1**2*s3**3*s4*s5**3*s6 + 1200*s1**2*s3**3*s5**5 ++ 2*s1**2*s3**2*s4**7 + 222*s1**2*s3**2*s4**4*s6**2 + 3181*s1**2*s3**2*s4**3*s5**2*s6 +- 610*s1**2*s3**2*s4**2*s5**4 - 165240*s1**2*s3**2*s4*s6**4 + 118260*s1**2*s3**2*s5**2*s6**3 ++ 572*s1**2*s3*s4**5*s5*s6 - 294*s1**2*s3*s4**4*s5**3 - 32616*s1**2*s3*s4**2*s5*s6**3 +- 55665*s1**2*s3*s4*s5**3*s6**2 + 17250*s1**2*s3*s5**5*s6 - 232*s1**2*s4**7*s6 + 86*s1**2*s4**6*s5**2 ++ 48408*s1**2*s4**4*s6**3 + 58894*s1**2*s4**3*s5**2*s6**2 - 46650*s1**2*s4**2*s5**4*s6 ++ 7500*s1**2*s4*s5**6 - 129600*s1**2*s4*s6**5 + 41040*s1**2*s5**2*s6**4 - 48*s1*s2**7*s4*s5*s6**2 ++ 12*s1*s2**7*s5**3*s6 + 12*s1*s2**6*s3**2*s5*s6**2 - 144*s1*s2**6*s3*s4**2*s6**2 ++ 88*s1*s2**6*s3*s4*s5**2*s6 - 13*s1*s2**6*s3*s5**4 + 1680*s1*s2**6*s5*s6**3 + 72*s1*s2**5*s3**3*s4*s6**2 +- 22*s1*s2**5*s3**3*s5**2*s6 - 4*s1*s2**5*s3**2*s4**2*s5*s6 + s1*s2**5*s3**2*s4*s5**3 +- 144*s1*s2**5*s3*s4*s6**3 + 572*s1*s2**5*s3*s5**2*s6**2 + 736*s1*s2**5*s4**2*s5*s6**2 ++ 128*s1*s2**5*s4*s5**3*s6 - 124*s1*s2**5*s5**5 - 9*s1*s2**4*s3**5*s6**2 + s1*s2**4*s3**4*s4*s5*s6 ++ 36*s1*s2**4*s3**3*s6**3 - 2028*s1*s2**4*s3**2*s4*s5*s6**2 - 547*s1*s2**4*s3**2*s5**3*s6 +- 480*s1*s2**4*s3*s4**3*s6**2 + 772*s1*s2**4*s3*s4**2*s5**2*s6 - 29*s1*s2**4*s3*s4*s5**4 ++ 6336*s1*s2**4*s3*s6**4 - 12*s1*s2**4*s4**3*s5**3 + 4368*s1*s2**4*s4*s5*s6**3 - 22624*s1*s2**4*s5**3*s6**2 ++ 441*s1*s2**3*s3**4*s5*s6**2 + 336*s1*s2**3*s3**3*s4**2*s6**2 + 741*s1*s2**3*s3**3*s4*s5**2*s6 ++ 12*s1*s2**3*s3**3*s5**4 - 868*s1*s2**3*s3**2*s4**3*s5*s6 + 93*s1*s2**3*s3**2*s4**2*s5**3 ++ 11016*s1*s2**3*s3**2*s5*s6**3 + 176*s1*s2**3*s3*s4**5*s6 - 28*s1*s2**3*s3*s4**4*s5**2 ++ 14784*s1*s2**3*s3*s4**2*s6**3 + 22144*s1*s2**3*s3*s4*s5**2*s6**2 + 5145*s1*s2**3*s3*s5**4*s6 +- 11344*s1*s2**3*s4**3*s5*s6**2 + 5064*s1*s2**3*s4**2*s5**3*s6 - 2050*s1*s2**3*s4*s5**5 +- 346896*s1*s2**3*s5*s6**4 - 54*s1*s2**2*s3**5*s4*s6**2 - 216*s1*s2**2*s3**5*s5**2*s6 ++ 324*s1*s2**2*s3**4*s4**2*s5*s6 - 95*s1*s2**2*s3**4*s4*s5**3 - 80*s1*s2**2*s3**3*s4**4*s6 ++ 43*s1*s2**2*s3**3*s4**3*s5**2 - 12204*s1*s2**2*s3**3*s4*s6**3 - 11475*s1*s2**2*s3**3*s5**2*s6**2 +- 4*s1*s2**2*s3**2*s4**5*s5 - 3888*s1*s2**2*s3**2*s4**2*s5*s6**2 - 4844*s1*s2**2*s3**2*s4*s5**3*s6 +- 725*s1*s2**2*s3**2*s5**5 - 1312*s1*s2**2*s3*s4**4*s6**2 + 1684*s1*s2**2*s3*s4**3*s5**2*s6 ++ 1995*s1*s2**2*s3*s4**2*s5**4 + 139104*s1*s2**2*s3*s4*s6**4 - 32616*s1*s2**2*s3*s5**2*s6**3 ++ 736*s1*s2**2*s4**5*s5*s6 - 676*s1*s2**2*s4**4*s5**3 + 131040*s1*s2**2*s4**2*s5*s6**3 ++ 16240*s1*s2**2*s4*s5**3*s6**2 - 20250*s1*s2**2*s5**5*s6 - 27*s1*s2*s3**6*s4*s5*s6 ++ 18*s1*s2*s3**6*s5**3 + 9*s1*s2*s3**5*s4**3*s6 - 9*s1*s2*s3**5*s4**2*s5**2 + 1944*s1*s2*s3**5*s6**3 ++ s1*s2*s3**4*s4**4*s5 + 6156*s1*s2*s3**4*s4*s5*s6**2 + 1143*s1*s2*s3**4*s5**3*s6 ++ 324*s1*s2*s3**3*s4**3*s6**2 + 1071*s1*s2*s3**3*s4**2*s5**2*s6 + 15*s1*s2*s3**3*s4*s5**4 +- 7776*s1*s2*s3**3*s6**4 - 2028*s1*s2*s3**2*s4**4*s5*s6 - 397*s1*s2*s3**2*s4**3*s5**3 ++ 112860*s1*s2*s3**2*s4*s5*s6**3 - 10305*s1*s2*s3**2*s5**3*s6**2 + 336*s1*s2*s3*s4**6*s6 ++ 304*s1*s2*s3*s4**5*s5**2 - 68976*s1*s2*s3*s4**3*s6**3 - 96732*s1*s2*s3*s4**2*s5**2*s6**2 ++ 36700*s1*s2*s3*s4*s5**4*s6 - 1250*s1*s2*s3*s5**6 - 1477440*s1*s2*s3*s6**5 - 48*s1*s2*s4**7*s5 ++ 4368*s1*s2*s4**4*s5*s6**2 + 10360*s1*s2*s4**3*s5**3*s6 - 3500*s1*s2*s4**2*s5**5 ++ 935280*s1*s2*s4*s5*s6**4 - 242100*s1*s2*s5**3*s6**3 - 972*s1*s3**6*s5*s6**2 - 351*s1*s3**5*s4*s5**2*s6 +- 99*s1*s3**5*s5**4 + 441*s1*s3**4*s4**3*s5*s6 + 141*s1*s3**4*s4**2*s5**3 - 36936*s1*s3**4*s5*s6**3 +- 84*s1*s3**3*s4**5*s6 - 76*s1*s3**3*s4**4*s5**2 + 17496*s1*s3**3*s4**2*s6**3 + 11718*s1*s3**3*s4*s5**2*s6**2 +- 6525*s1*s3**3*s5**4*s6 + 12*s1*s3**2*s4**6*s5 + 11016*s1*s3**2*s4**3*s5*s6**2 + +5895*s1*s3**2*s4**2*s5**3*s6 - 1750*s1*s3**2*s4*s5**5 - 252720*s1*s3**2*s5*s6**4 - +2544*s1*s3*s4**5*s6**2 - 8092*s1*s3*s4**4*s5**2*s6 + 2300*s1*s3*s4**3*s5**4 + 536544*s1*s3*s4**2*s6**4 ++ 169560*s1*s3*s4*s5**2*s6**3 - 103500*s1*s3*s5**4*s6**2 + 1680*s1*s4**6*s5*s6 - 468*s1*s4**5*s5**3 +- 346896*s1*s4**3*s5*s6**3 + 93900*s1*s4**2*s5**3*s6**2 + 35000*s1*s4*s5**5*s6 - 9375*s1*s5**7 ++ 108864*s1*s5*s6**5 + 16*s2**8*s4**2*s6**2 - 8*s2**8*s4*s5**2*s6 + s2**8*s5**4 - +8*s2**7*s3**2*s4*s6**2 + 2*s2**7*s3**2*s5**2*s6 - 96*s2**7*s4*s6**3 - 232*s2**7*s5**2*s6**2 ++ s2**6*s3**4*s6**2 + 24*s2**6*s3**2*s6**3 + 336*s2**6*s3*s4*s5*s6**2 + 108*s2**6*s3*s5**3*s6 +- 32*s2**6*s4**3*s6**2 - 160*s2**6*s4**2*s5**2*s6 + 38*s2**6*s4*s5**4 + 144*s2**6*s6**4 +- 84*s2**5*s3**3*s5*s6**2 + 8*s2**5*s3**2*s4**2*s6**2 - 106*s2**5*s3**2*s4*s5**2*s6 +- 12*s2**5*s3**2*s5**4 + 176*s2**5*s3*s4**3*s5*s6 - 36*s2**5*s3*s4**2*s5**3 - 2544*s2**5*s3*s5*s6**3 +- 32*s2**5*s4**5*s6 + 8*s2**5*s4**4*s5**2 - 3072*s2**5*s4**2*s6**3 + 3032*s2**5*s4*s5**2*s6**2 ++ 954*s2**5*s5**4*s6 + 36*s2**4*s3**4*s5**2*s6 - 80*s2**4*s3**3*s4**2*s5*s6 + 25*s2**4*s3**3*s4*s5**3 ++ 16*s2**4*s3**2*s4**4*s6 - 6*s2**4*s3**2*s4**3*s5**2 + 2520*s2**4*s3**2*s4*s6**3 ++ 222*s2**4*s3**2*s5**2*s6**2 - 1312*s2**4*s3*s4**2*s5*s6**2 - 3616*s2**4*s3*s4*s5**3*s6 +- 125*s2**4*s3*s5**5 + 1296*s2**4*s4**4*s6**2 - 168*s2**4*s4**3*s5**2*s6 + 375*s2**4*s4**2*s5**4 ++ 19296*s2**4*s4*s6**4 + 48408*s2**4*s5**2*s6**3 + 9*s2**3*s3**5*s4*s5*s6 - 4*s2**3*s3**5*s5**3 +- 2*s2**3*s3**4*s4**3*s6 + s2**3*s3**4*s4**2*s5**2 - 432*s2**3*s3**4*s6**3 + 324*s2**3*s3**3*s4*s5*s6**2 ++ 923*s2**3*s3**3*s5**3*s6 - 752*s2**3*s3**2*s4**3*s6**2 + 2262*s2**3*s3**2*s4**2*s5**2*s6 ++ 525*s2**3*s3**2*s4*s5**4 - 9936*s2**3*s3**2*s6**4 - 480*s2**3*s3*s4**4*s5*s6 - 700*s2**3*s3*s4**3*s5**3 +- 68976*s2**3*s3*s4*s5*s6**3 - 11360*s2**3*s3*s5**3*s6**2 - 32*s2**3*s4**6*s6 + 152*s2**3*s4**5*s5**2 ++ 6912*s2**3*s4**3*s6**3 - 7992*s2**3*s4**2*s5**2*s6**2 + 5550*s2**3*s4*s5**4*s6 - +29376*s2**3*s6**5 + 108*s2**2*s3**4*s4**2*s6**2 - 1215*s2**2*s3**4*s4*s5**2*s6 - 150*s2**2*s3**4*s5**4 ++ 336*s2**2*s3**3*s4**3*s5*s6 - 185*s2**2*s3**3*s4**2*s5**3 + 17496*s2**2*s3**3*s5*s6**3 ++ 8*s2**2*s3**2*s4**5*s6 + 370*s2**2*s3**2*s4**4*s5**2 - 864*s2**2*s3**2*s4**2*s6**3 ++ 22572*s2**2*s3**2*s4*s5**2*s6**2 + 225*s2**2*s3**2*s5**4*s6 - 144*s2**2*s3*s4**6*s5 ++ 14784*s2**2*s3*s4**3*s5*s6**2 - 12020*s2**2*s3*s4**2*s5**3*s6 + 625*s2**2*s3*s4*s5**5 ++ 536544*s2**2*s3*s5*s6**4 + 16*s2**2*s4**8 - 3072*s2**2*s4**5*s6**2 + 1096*s2**2*s4**4*s5**2*s6 ++ 250*s2**2*s4**3*s5**4 - 93744*s2**2*s4**2*s6**4 - 249480*s2**2*s4*s5**2*s6**3 + +70125*s2**2*s5**4*s6**2 + 162*s2*s3**6*s5**2*s6 - 54*s2*s3**5*s4**2*s5*s6 + 198*s2*s3**5*s4*s5**3 +- 210*s2*s3**4*s4**3*s5**2 - 3564*s2*s3**4*s5**2*s6**2 + 72*s2*s3**3*s4**5*s5 - 12204*s2*s3**3*s4**2*s5*s6**2 ++ 1935*s2*s3**3*s4*s5**3*s6 - 8*s2*s3**2*s4**7 + 2520*s2*s3**2*s4**4*s6**2 + 1962*s2*s3**2*s4**3*s5**2*s6 +- 125*s2*s3**2*s4**2*s5**4 - 178848*s2*s3**2*s4*s6**4 - 165240*s2*s3**2*s5**2*s6**3 +- 144*s2*s3*s4**5*s5*s6 - 200*s2*s3*s4**4*s5**3 + 139104*s2*s3*s4**2*s5*s6**3 + 72900*s2*s3*s4*s5**3*s6**2 +- 20625*s2*s3*s5**5*s6 - 96*s2*s4**7*s6 + 56*s2*s4**6*s5**2 + 19296*s2*s4**4*s6**3 +- 22080*s2*s4**3*s5**2*s6**2 - 7750*s2*s4**2*s5**4*s6 + 3125*s2*s4*s5**6 + 248832*s2*s4*s6**5 +- 129600*s2*s5**2*s6**4 - 27*s3**7*s5**3 + 27*s3**6*s4**2*s5**2 - 9*s3**5*s4**4*s5 ++ 1944*s3**5*s4*s5*s6**2 + 54*s3**5*s5**3*s6 + s3**4*s4**6 - 432*s3**4*s4**3*s6**2 +- 486*s3**4*s4**2*s5**2*s6 + 46656*s3**4*s6**4 + 36*s3**3*s4**4*s5*s6 + 50*s3**3*s4**3*s5**3 +- 7776*s3**3*s4*s5*s6**3 + 29700*s3**3*s5**3*s6**2 + 24*s3**2*s4**6*s6 - 14*s3**2*s4**5*s5**2 +- 9936*s3**2*s4**3*s6**3 - 29160*s3**2*s4**2*s5**2*s6**2 - 10125*s3**2*s4*s5**4*s6 ++ 3125*s3**2*s5**6 + 1026432*s3**2*s6**5 + 6336*s3*s4**4*s5*s6**2 + 11700*s3*s4**3*s5**3*s6 +- 3125*s3*s4**2*s5**5 - 1477440*s3*s4*s5*s6**4 + 432000*s3*s5**3*s6**3 + 144*s4**6*s6**2 +- 2472*s4**5*s5**2*s6 + 625*s4**4*s5**4 - 29376*s4**3*s6**4 + 529200*s4**2*s5**2*s6**3 +- 292500*s4*s5**4*s6**2 + 40625*s5**6*s6 - 186624*s6**6) + ], + (6, 2): [ + lambda s1, s2, s3, s4, s5, s6: (-s3), + lambda s1, s2, s3, s4, s5, s6: (-s1*s5 + s2*s4 - 9*s6), + lambda s1, s2, s3, s4, s5, s6: (s1*s2*s6 + 2*s1*s3*s5 - s1*s4**2 - s2**2*s5 + 6*s3*s6 + s4*s5), + lambda s1, s2, s3, s4, s5, s6: (s1**2*s4*s6 - s1**2*s5**2 - 3*s1*s2*s3*s6 + s1*s2*s4*s5 + 9*s1*s5*s6 + s2**3*s6 - +9*s2*s4*s6 + s2*s5**2 + 3*s3**2*s6 - 3*s3*s4*s5 + s4**3 + 27*s6**2), + lambda s1, s2, s3, s4, s5, s6: (-2*s1**3*s6**2 + 2*s1**2*s2*s5*s6 + 2*s1**2*s3*s4*s6 - s1**2*s3*s5**2 - s1*s2**2*s4*s6 +- 3*s1*s2*s6**2 - 16*s1*s3*s5*s6 + 4*s1*s4**2*s6 + 2*s1*s4*s5**2 + 4*s2**2*s5*s6 + +s2*s3*s4*s6 + 2*s2*s3*s5**2 - s2*s4**2*s5 - 9*s3*s6**2 - 3*s4*s5*s6 - 2*s5**3), + lambda s1, s2, s3, s4, s5, s6: (s1**3*s3*s6**2 - 3*s1**3*s4*s5*s6 + s1**3*s5**3 - s1**2*s2**2*s6**2 + s1**2*s2*s3*s5*s6 +- 2*s1**2*s4*s6**2 + 6*s1**2*s5**2*s6 + 16*s1*s2*s3*s6**2 - 3*s1*s2*s5**3 - s1*s3**2*s5*s6 +- 2*s1*s3*s4**2*s6 + s1*s3*s4*s5**2 - 30*s1*s5*s6**2 - 4*s2**3*s6**2 - 2*s2**2*s3*s5*s6 ++ s2**2*s4**2*s6 + 18*s2*s4*s6**2 - 2*s2*s5**2*s6 - 15*s3**2*s6**2 + 16*s3*s4*s5*s6 ++ s3*s5**3 - 4*s4**3*s6 - s4**2*s5**2 - 27*s6**3), + lambda s1, s2, s3, s4, s5, s6: (s1**4*s5*s6**2 + 2*s1**3*s2*s4*s6**2 - s1**3*s2*s5**2*s6 - s1**3*s3**2*s6**2 + 9*s1**3*s6**3 +- 14*s1**2*s2*s5*s6**2 - 11*s1**2*s3*s4*s6**2 + 6*s1**2*s3*s5**2*s6 + 3*s1**2*s4**2*s5*s6 +- s1**2*s4*s5**3 + 3*s1*s2**2*s5**2*s6 + 3*s1*s2*s3**2*s6**2 - s1*s2*s3*s4*s5*s6 + +39*s1*s3*s5*s6**2 - 14*s1*s4*s5**2*s6 + s1*s5**4 - 11*s2*s3*s5**2*s6 + 2*s2*s4*s5**3 +- 3*s3**3*s6**2 + 3*s3**2*s4*s5*s6 - s3**2*s5**3 + 9*s5**3*s6), + lambda s1, s2, s3, s4, s5, s6: (-s1**4*s2*s6**3 + s1**4*s3*s5*s6**2 - 4*s1**3*s3*s6**3 + 10*s1**3*s4*s5*s6**2 - 4*s1**3*s5**3*s6 ++ 8*s1**2*s2**2*s6**3 - 8*s1**2*s2*s3*s5*s6**2 - 2*s1**2*s2*s4**2*s6**2 + s1**2*s2*s4*s5**2*s6 ++ s1**2*s3**2*s4*s6**2 - 6*s1**2*s4*s6**3 - 7*s1**2*s5**2*s6**2 - 24*s1*s2*s3*s6**3 +- 4*s1*s2*s4*s5*s6**2 + 10*s1*s2*s5**3*s6 + 8*s1*s3**2*s5*s6**2 + 8*s1*s3*s4**2*s6**2 +- 8*s1*s3*s4*s5**2*s6 + s1*s3*s5**4 + 36*s1*s5*s6**3 + 8*s2**2*s3*s5*s6**2 - 2*s2**2*s4*s5**2*s6 +- 2*s2*s3**2*s4*s6**2 + s2*s3**2*s5**2*s6 - 6*s2*s5**2*s6**2 + 18*s3**2*s6**3 - 24*s3*s4*s5*s6**2 +- 4*s3*s5**3*s6 + 8*s4**2*s5**2*s6 - s4*s5**4), + lambda s1, s2, s3, s4, s5, s6: (-s1**5*s4*s6**3 - 2*s1**4*s5*s6**3 + 3*s1**3*s2*s5**2*s6**2 + 3*s1**3*s3**2*s6**3 +- s1**3*s3*s4*s5*s6**2 - 8*s1**3*s6**4 + 16*s1**2*s2*s5*s6**3 + 8*s1**2*s3*s4*s6**3 +- 6*s1**2*s3*s5**2*s6**2 - 8*s1**2*s4**2*s5*s6**2 + 3*s1**2*s4*s5**3*s6 - 8*s1*s2**2*s5**2*s6**2 +- 8*s1*s2*s3**2*s6**3 + 8*s1*s2*s3*s4*s5*s6**2 - s1*s2*s3*s5**3*s6 - s1*s3**3*s5*s6**2 +- 24*s1*s3*s5*s6**3 + 16*s1*s4*s5**2*s6**2 - 2*s1*s5**4*s6 + 8*s2*s3*s5**2*s6**2 - +s2*s5**5 + 8*s3**3*s6**3 - 8*s3**2*s4*s5*s6**2 + 3*s3**2*s5**3*s6 - 8*s5**3*s6**2), + lambda s1, s2, s3, s4, s5, s6: (s1**6*s6**4 - 4*s1**4*s2*s6**4 - 2*s1**4*s3*s5*s6**3 + s1**4*s4**2*s6**3 + 8*s1**3*s3*s6**4 +- 4*s1**3*s4*s5*s6**3 + 2*s1**3*s5**3*s6**2 + 8*s1**2*s2*s3*s5*s6**3 - 2*s1**2*s2*s4*s5**2*s6**2 +- 2*s1**2*s3**2*s4*s6**3 + s1**2*s3**2*s5**2*s6**2 - 4*s1*s2*s5**3*s6**2 - 12*s1*s3**2*s5*s6**3 ++ 8*s1*s3*s4*s5**2*s6**2 - 2*s1*s3*s5**4*s6 + s2**2*s5**4*s6 - 2*s2*s3**2*s5**2*s6**2 ++ s3**4*s6**3 + 8*s3*s5**3*s6**2 - 4*s4*s5**4*s6 + s5**6) + ], +} diff --git a/lib/python3.10/site-packages/sympy/polys/numberfields/subfield.py b/lib/python3.10/site-packages/sympy/polys/numberfields/subfield.py new file mode 100644 index 0000000000000000000000000000000000000000..b959ddeb27a6bc5719dc4fc567bd20a3fd936798 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/polys/numberfields/subfield.py @@ -0,0 +1,507 @@ +r""" +Functions in ``polys.numberfields.subfield`` solve the "Subfield Problem" and +allied problems, for algebraic number fields. + +Following Cohen (see [Cohen93]_ Section 4.5), we can define the main problem as +follows: + +* **Subfield Problem:** + + Given two number fields $\mathbb{Q}(\alpha)$, $\mathbb{Q}(\beta)$ + via the minimal polynomials for their generators $\alpha$ and $\beta$, decide + whether one field is isomorphic to a subfield of the other. + +From a solution to this problem flow solutions to the following problems as +well: + +* **Primitive Element Problem:** + + Given several algebraic numbers + $\alpha_1, \ldots, \alpha_m$, compute a single algebraic number $\theta$ + such that $\mathbb{Q}(\alpha_1, \ldots, \alpha_m) = \mathbb{Q}(\theta)$. + +* **Field Isomorphism Problem:** + + Decide whether two number fields + $\mathbb{Q}(\alpha)$, $\mathbb{Q}(\beta)$ are isomorphic. + +* **Field Membership Problem:** + + Given two algebraic numbers $\alpha$, + $\beta$, decide whether $\alpha \in \mathbb{Q}(\beta)$, and if so write + $\alpha = f(\beta)$ for some $f(x) \in \mathbb{Q}[x]$. +""" + +from sympy.core.add import Add +from sympy.core.numbers import AlgebraicNumber +from sympy.core.singleton import S +from sympy.core.symbol import Dummy +from sympy.core.sympify import sympify, _sympify +from sympy.ntheory import sieve +from sympy.polys.densetools import dup_eval +from sympy.polys.domains import QQ +from sympy.polys.numberfields.minpoly import _choose_factor, minimal_polynomial +from sympy.polys.polyerrors import IsomorphismFailed +from sympy.polys.polytools import Poly, PurePoly, factor_list +from sympy.utilities import public + +from mpmath import MPContext + + +def is_isomorphism_possible(a, b): + """Necessary but not sufficient test for isomorphism. """ + n = a.minpoly.degree() + m = b.minpoly.degree() + + if m % n != 0: + return False + + if n == m: + return True + + da = a.minpoly.discriminant() + db = b.minpoly.discriminant() + + i, k, half = 1, m//n, db//2 + + while True: + p = sieve[i] + P = p**k + + if P > half: + break + + if ((da % p) % 2) and not (db % P): + return False + + i += 1 + + return True + + +def field_isomorphism_pslq(a, b): + """Construct field isomorphism using PSLQ algorithm. """ + if not a.root.is_real or not b.root.is_real: + raise NotImplementedError("PSLQ doesn't support complex coefficients") + + f = a.minpoly + g = b.minpoly.replace(f.gen) + + n, m, prev = 100, b.minpoly.degree(), None + ctx = MPContext() + + for i in range(1, 5): + A = a.root.evalf(n) + B = b.root.evalf(n) + + basis = [1, B] + [ B**i for i in range(2, m) ] + [-A] + + ctx.dps = n + coeffs = ctx.pslq(basis, maxcoeff=10**10, maxsteps=1000) + + if coeffs is None: + # PSLQ can't find an integer linear combination. Give up. + break + + if coeffs != prev: + prev = coeffs + else: + # Increasing precision didn't produce anything new. Give up. + break + + # We have + # c0 + c1*B + c2*B^2 + ... + cm-1*B^(m-1) - cm*A ~ 0. + # So bring cm*A to the other side, and divide through by cm, + # for an approximate representation of A as a polynomial in B. + # (We know cm != 0 since `b.minpoly` is irreducible.) + coeffs = [S(c)/coeffs[-1] for c in coeffs[:-1]] + + # Throw away leading zeros. + while not coeffs[-1]: + coeffs.pop() + + coeffs = list(reversed(coeffs)) + h = Poly(coeffs, f.gen, domain='QQ') + + # We only have A ~ h(B). We must check whether the relation is exact. + if f.compose(h).rem(g).is_zero: + # Now we know that h(b) is in fact equal to _some conjugate of_ a. + # But from the very precise approximation A ~ h(B) we can assume + # the conjugate is a itself. + return coeffs + else: + n *= 2 + + return None + + +def field_isomorphism_factor(a, b): + """Construct field isomorphism via factorization. """ + _, factors = factor_list(a.minpoly, extension=b) + for f, _ in factors: + if f.degree() == 1: + # Any linear factor f(x) represents some conjugate of a in QQ(b). + # We want to know whether this linear factor represents a itself. + # Let f = x - c + c = -f.rep.TC() + # Write c as polynomial in b + coeffs = c.to_sympy_list() + d, terms = len(coeffs) - 1, [] + for i, coeff in enumerate(coeffs): + terms.append(coeff*b.root**(d - i)) + r = Add(*terms) + # Check whether we got the number a + if a.minpoly.same_root(r, a): + return coeffs + + # If none of the linear factors represented a in QQ(b), then in fact a is + # not an element of QQ(b). + return None + + +@public +def field_isomorphism(a, b, *, fast=True): + r""" + Find an embedding of one number field into another. + + Explanation + =========== + + This function looks for an isomorphism from $\mathbb{Q}(a)$ onto some + subfield of $\mathbb{Q}(b)$. Thus, it solves the Subfield Problem. + + Examples + ======== + + >>> from sympy import sqrt, field_isomorphism, I + >>> print(field_isomorphism(3, sqrt(2))) # doctest: +SKIP + [3] + >>> print(field_isomorphism( I*sqrt(3), I*sqrt(3)/2)) # doctest: +SKIP + [2, 0] + + Parameters + ========== + + a : :py:class:`~.Expr` + Any expression representing an algebraic number. + b : :py:class:`~.Expr` + Any expression representing an algebraic number. + fast : boolean, optional (default=True) + If ``True``, we first attempt a potentially faster way of computing the + isomorphism, falling back on a slower method if this fails. If + ``False``, we go directly to the slower method, which is guaranteed to + return a result. + + Returns + ======= + + List of rational numbers, or None + If $\mathbb{Q}(a)$ is not isomorphic to some subfield of + $\mathbb{Q}(b)$, then return ``None``. Otherwise, return a list of + rational numbers representing an element of $\mathbb{Q}(b)$ to which + $a$ may be mapped, in order to define a monomorphism, i.e. an + isomorphism from $\mathbb{Q}(a)$ to some subfield of $\mathbb{Q}(b)$. + The elements of the list are the coefficients of falling powers of $b$. + + """ + a, b = sympify(a), sympify(b) + + if not a.is_AlgebraicNumber: + a = AlgebraicNumber(a) + + if not b.is_AlgebraicNumber: + b = AlgebraicNumber(b) + + a = a.to_primitive_element() + b = b.to_primitive_element() + + if a == b: + return a.coeffs() + + n = a.minpoly.degree() + m = b.minpoly.degree() + + if n == 1: + return [a.root] + + if m % n != 0: + return None + + if fast: + try: + result = field_isomorphism_pslq(a, b) + + if result is not None: + return result + except NotImplementedError: + pass + + return field_isomorphism_factor(a, b) + + +def _switch_domain(g, K): + # An algebraic relation f(a, b) = 0 over Q can also be written + # g(b) = 0 where g is in Q(a)[x] and h(a) = 0 where h is in Q(b)[x]. + # This function transforms g into h where Q(b) = K. + frep = g.rep.inject() + hrep = frep.eject(K, front=True) + + return g.new(hrep, g.gens[0]) + + +def _linsolve(p): + # Compute root of linear polynomial. + c, d = p.rep.to_list() + return -d/c + + +@public +def primitive_element(extension, x=None, *, ex=False, polys=False): + r""" + Find a single generator for a number field given by several generators. + + Explanation + =========== + + The basic problem is this: Given several algebraic numbers + $\alpha_1, \alpha_2, \ldots, \alpha_n$, find a single algebraic number + $\theta$ such that + $\mathbb{Q}(\alpha_1, \alpha_2, \ldots, \alpha_n) = \mathbb{Q}(\theta)$. + + This function actually guarantees that $\theta$ will be a linear + combination of the $\alpha_i$, with non-negative integer coefficients. + + Furthermore, if desired, this function will tell you how to express each + $\alpha_i$ as a $\mathbb{Q}$-linear combination of the powers of $\theta$. + + Examples + ======== + + >>> from sympy import primitive_element, sqrt, S, minpoly, simplify + >>> from sympy.abc import x + >>> f, lincomb, reps = primitive_element([sqrt(2), sqrt(3)], x, ex=True) + + Then ``lincomb`` tells us the primitive element as a linear combination of + the given generators ``sqrt(2)`` and ``sqrt(3)``. + + >>> print(lincomb) + [1, 1] + + This means the primtiive element is $\sqrt{2} + \sqrt{3}$. + Meanwhile ``f`` is the minimal polynomial for this primitive element. + + >>> print(f) + x**4 - 10*x**2 + 1 + >>> print(minpoly(sqrt(2) + sqrt(3), x)) + x**4 - 10*x**2 + 1 + + Finally, ``reps`` (which was returned only because we set keyword arg + ``ex=True``) tells us how to recover each of the generators $\sqrt{2}$ and + $\sqrt{3}$ as $\mathbb{Q}$-linear combinations of the powers of the + primitive element $\sqrt{2} + \sqrt{3}$. + + >>> print([S(r) for r in reps[0]]) + [1/2, 0, -9/2, 0] + >>> theta = sqrt(2) + sqrt(3) + >>> print(simplify(theta**3/2 - 9*theta/2)) + sqrt(2) + >>> print([S(r) for r in reps[1]]) + [-1/2, 0, 11/2, 0] + >>> print(simplify(-theta**3/2 + 11*theta/2)) + sqrt(3) + + Parameters + ========== + + extension : list of :py:class:`~.Expr` + Each expression must represent an algebraic number $\alpha_i$. + x : :py:class:`~.Symbol`, optional (default=None) + The desired symbol to appear in the computed minimal polynomial for the + primitive element $\theta$. If ``None``, we use a dummy symbol. + ex : boolean, optional (default=False) + If and only if ``True``, compute the representation of each $\alpha_i$ + as a $\mathbb{Q}$-linear combination over the powers of $\theta$. + polys : boolean, optional (default=False) + If ``True``, return the minimal polynomial as a :py:class:`~.Poly`. + Otherwise return it as an :py:class:`~.Expr`. + + Returns + ======= + + Pair (f, coeffs) or triple (f, coeffs, reps), where: + ``f`` is the minimal polynomial for the primitive element. + ``coeffs`` gives the primitive element as a linear combination of the + given generators. + ``reps`` is present if and only if argument ``ex=True`` was passed, + and is a list of lists of rational numbers. Each list gives the + coefficients of falling powers of the primitive element, to recover + one of the original, given generators. + + """ + if not extension: + raise ValueError("Cannot compute primitive element for empty extension") + extension = [_sympify(ext) for ext in extension] + + if x is not None: + x, cls = sympify(x), Poly + else: + x, cls = Dummy('x'), PurePoly + + if not ex: + gen, coeffs = extension[0], [1] + g = minimal_polynomial(gen, x, polys=True) + for ext in extension[1:]: + if ext.is_Rational: + coeffs.append(0) + continue + _, factors = factor_list(g, extension=ext) + g = _choose_factor(factors, x, gen) + [s], _, g = g.sqf_norm() + gen += s*ext + coeffs.append(s) + + if not polys: + return g.as_expr(), coeffs + else: + return cls(g), coeffs + + gen, coeffs = extension[0], [1] + f = minimal_polynomial(gen, x, polys=True) + K = QQ.algebraic_field((f, gen)) # incrementally constructed field + reps = [K.unit] # representations of extension elements in K + for ext in extension[1:]: + if ext.is_Rational: + coeffs.append(0) # rational ext is not included in the expression of a primitive element + reps.append(K.convert(ext)) # but it is included in reps + continue + p = minimal_polynomial(ext, x, polys=True) + L = QQ.algebraic_field((p, ext)) + _, factors = factor_list(f, domain=L) + f = _choose_factor(factors, x, gen) + [s], g, f = f.sqf_norm() + gen += s*ext + coeffs.append(s) + K = QQ.algebraic_field((f, gen)) + h = _switch_domain(g, K) + erep = _linsolve(h.gcd(p)) # ext as element of K + ogen = K.unit - s*erep # old gen as element of K + reps = [dup_eval(_.to_list(), ogen, K) for _ in reps] + [erep] + + if K.ext.root.is_Rational: # all extensions are rational + H = [K.convert(_).rep for _ in extension] + coeffs = [0]*len(extension) + f = cls(x, domain=QQ) + else: + H = [_.to_list() for _ in reps] + if not polys: + return f.as_expr(), coeffs, H + else: + return f, coeffs, H + + +@public +def to_number_field(extension, theta=None, *, gen=None, alias=None): + r""" + Express one algebraic number in the field generated by another. + + Explanation + =========== + + Given two algebraic numbers $\eta, \theta$, this function either expresses + $\eta$ as an element of $\mathbb{Q}(\theta)$, or else raises an exception + if $\eta \not\in \mathbb{Q}(\theta)$. + + This function is essentially just a convenience, utilizing + :py:func:`~.field_isomorphism` (our solution of the Subfield Problem) to + solve this, the Field Membership Problem. + + As an additional convenience, this function allows you to pass a list of + algebraic numbers $\alpha_1, \alpha_2, \ldots, \alpha_n$ instead of $\eta$. + It then computes $\eta$ for you, as a solution of the Primitive Element + Problem, using :py:func:`~.primitive_element` on the list of $\alpha_i$. + + Examples + ======== + + >>> from sympy import sqrt, to_number_field + >>> eta = sqrt(2) + >>> theta = sqrt(2) + sqrt(3) + >>> a = to_number_field(eta, theta) + >>> print(type(a)) + + >>> a.root + sqrt(2) + sqrt(3) + >>> print(a) + sqrt(2) + >>> a.coeffs() + [1/2, 0, -9/2, 0] + + We get an :py:class:`~.AlgebraicNumber`, whose ``.root`` is $\theta$, whose + value is $\eta$, and whose ``.coeffs()`` show how to write $\eta$ as a + $\mathbb{Q}$-linear combination in falling powers of $\theta$. + + Parameters + ========== + + extension : :py:class:`~.Expr` or list of :py:class:`~.Expr` + Either the algebraic number that is to be expressed in the other field, + or else a list of algebraic numbers, a primitive element for which is + to be expressed in the other field. + theta : :py:class:`~.Expr`, None, optional (default=None) + If an :py:class:`~.Expr` representing an algebraic number, behavior is + as described under **Explanation**. If ``None``, then this function + reduces to a shorthand for calling :py:func:`~.primitive_element` on + ``extension`` and turning the computed primitive element into an + :py:class:`~.AlgebraicNumber`. + gen : :py:class:`~.Symbol`, None, optional (default=None) + If provided, this will be used as the generator symbol for the minimal + polynomial in the returned :py:class:`~.AlgebraicNumber`. + alias : str, :py:class:`~.Symbol`, None, optional (default=None) + If provided, this will be used as the alias symbol for the returned + :py:class:`~.AlgebraicNumber`. + + Returns + ======= + + AlgebraicNumber + Belonging to $\mathbb{Q}(\theta)$ and equaling $\eta$. + + Raises + ====== + + IsomorphismFailed + If $\eta \not\in \mathbb{Q}(\theta)$. + + See Also + ======== + + field_isomorphism + primitive_element + + """ + if hasattr(extension, '__iter__'): + extension = list(extension) + else: + extension = [extension] + + if len(extension) == 1 and isinstance(extension[0], tuple): + return AlgebraicNumber(extension[0], alias=alias) + + minpoly, coeffs = primitive_element(extension, gen, polys=True) + root = sum(coeff*ext for coeff, ext in zip(coeffs, extension)) + + if theta is None: + return AlgebraicNumber((minpoly, root), alias=alias) + else: + theta = sympify(theta) + + if not theta.is_AlgebraicNumber: + theta = AlgebraicNumber(theta, gen=gen, alias=alias) + + coeffs = field_isomorphism(root, theta) + + if coeffs is not None: + return AlgebraicNumber(theta, coeffs, alias=alias) + else: + raise IsomorphismFailed( + "%s is not in a subfield of %s" % (root, theta.root)) diff --git a/lib/python3.10/site-packages/sympy/polys/numberfields/tests/__pycache__/__init__.cpython-310.pyc b/lib/python3.10/site-packages/sympy/polys/numberfields/tests/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6daf01882059624ac74b04cec5487ecd1bfe1c65 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/polys/numberfields/tests/__pycache__/__init__.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/polys/numberfields/utilities.py b/lib/python3.10/site-packages/sympy/polys/numberfields/utilities.py new file mode 100644 index 0000000000000000000000000000000000000000..fe583efb440f02f1b16c38fb7d03621c1f97e83d --- /dev/null +++ b/lib/python3.10/site-packages/sympy/polys/numberfields/utilities.py @@ -0,0 +1,474 @@ +"""Utilities for algebraic number theory. """ + +from sympy.core.sympify import sympify +from sympy.ntheory.factor_ import factorint +from sympy.polys.domains.rationalfield import QQ +from sympy.polys.domains.integerring import ZZ +from sympy.polys.matrices.exceptions import DMRankError +from sympy.polys.numberfields.minpoly import minpoly +from sympy.printing.lambdarepr import IntervalPrinter +from sympy.utilities.decorator import public +from sympy.utilities.lambdify import lambdify + +from mpmath import mp + + +def is_rat(c): + r""" + Test whether an argument is of an acceptable type to be used as a rational + number. + + Explanation + =========== + + Returns ``True`` on any argument of type ``int``, :ref:`ZZ`, or :ref:`QQ`. + + See Also + ======== + + is_int + + """ + # ``c in QQ`` is too accepting (e.g. ``3.14 in QQ`` is ``True``), + # ``QQ.of_type(c)`` is too demanding (e.g. ``QQ.of_type(3)`` is ``False``). + # + # Meanwhile, if gmpy2 is installed then ``ZZ.of_type()`` accepts only + # ``mpz``, not ``int``, so we need another clause to ensure ``int`` is + # accepted. + return isinstance(c, int) or ZZ.of_type(c) or QQ.of_type(c) + + +def is_int(c): + r""" + Test whether an argument is of an acceptable type to be used as an integer. + + Explanation + =========== + + Returns ``True`` on any argument of type ``int`` or :ref:`ZZ`. + + See Also + ======== + + is_rat + + """ + # If gmpy2 is installed then ``ZZ.of_type()`` accepts only + # ``mpz``, not ``int``, so we need another clause to ensure ``int`` is + # accepted. + return isinstance(c, int) or ZZ.of_type(c) + + +def get_num_denom(c): + r""" + Given any argument on which :py:func:`~.is_rat` is ``True``, return the + numerator and denominator of this number. + + See Also + ======== + + is_rat + + """ + r = QQ(c) + return r.numerator, r.denominator + + +@public +def extract_fundamental_discriminant(a): + r""" + Extract a fundamental discriminant from an integer *a*. + + Explanation + =========== + + Given any rational integer *a* that is 0 or 1 mod 4, write $a = d f^2$, + where $d$ is either 1 or a fundamental discriminant, and return a pair + of dictionaries ``(D, F)`` giving the prime factorizations of $d$ and $f$ + respectively, in the same format returned by :py:func:`~.factorint`. + + A fundamental discriminant $d$ is different from unity, and is either + 1 mod 4 and squarefree, or is 0 mod 4 and such that $d/4$ is squarefree + and 2 or 3 mod 4. This is the same as being the discriminant of some + quadratic field. + + Examples + ======== + + >>> from sympy.polys.numberfields.utilities import extract_fundamental_discriminant + >>> print(extract_fundamental_discriminant(-432)) + ({3: 1, -1: 1}, {2: 2, 3: 1}) + + For comparison: + + >>> from sympy import factorint + >>> print(factorint(-432)) + {2: 4, 3: 3, -1: 1} + + Parameters + ========== + + a: int, must be 0 or 1 mod 4 + + Returns + ======= + + Pair ``(D, F)`` of dictionaries. + + Raises + ====== + + ValueError + If *a* is not 0 or 1 mod 4. + + References + ========== + + .. [1] Cohen, H. *A Course in Computational Algebraic Number Theory.* + (See Prop. 5.1.3) + + """ + if a % 4 not in [0, 1]: + raise ValueError('To extract fundamental discriminant, number must be 0 or 1 mod 4.') + if a == 0: + return {}, {0: 1} + if a == 1: + return {}, {} + a_factors = factorint(a) + D = {} + F = {} + # First pass: just make d squarefree, and a/d a perfect square. + # We'll count primes (and units! i.e. -1) that are 3 mod 4 and present in d. + num_3_mod_4 = 0 + for p, e in a_factors.items(): + if e % 2 == 1: + D[p] = 1 + if p % 4 == 3: + num_3_mod_4 += 1 + if e >= 3: + F[p] = (e - 1) // 2 + else: + F[p] = e // 2 + # Second pass: if d is cong. to 2 or 3 mod 4, then we must steal away + # another factor of 4 from f**2 and give it to d. + even = 2 in D + if even or num_3_mod_4 % 2 == 1: + e2 = F[2] + assert e2 > 0 + if e2 == 1: + del F[2] + else: + F[2] = e2 - 1 + D[2] = 3 if even else 2 + return D, F + + +@public +class AlgIntPowers: + r""" + Compute the powers of an algebraic integer. + + Explanation + =========== + + Given an algebraic integer $\theta$ by its monic irreducible polynomial + ``T`` over :ref:`ZZ`, this class computes representations of arbitrarily + high powers of $\theta$, as :ref:`ZZ`-linear combinations over + $\{1, \theta, \ldots, \theta^{n-1}\}$, where $n = \deg(T)$. + + The representations are computed using the linear recurrence relations for + powers of $\theta$, derived from the polynomial ``T``. See [1], Sec. 4.2.2. + + Optionally, the representations may be reduced with respect to a modulus. + + Examples + ======== + + >>> from sympy import Poly, cyclotomic_poly + >>> from sympy.polys.numberfields.utilities import AlgIntPowers + >>> T = Poly(cyclotomic_poly(5)) + >>> zeta_pow = AlgIntPowers(T) + >>> print(zeta_pow[0]) + [1, 0, 0, 0] + >>> print(zeta_pow[1]) + [0, 1, 0, 0] + >>> print(zeta_pow[4]) # doctest: +SKIP + [-1, -1, -1, -1] + >>> print(zeta_pow[24]) # doctest: +SKIP + [-1, -1, -1, -1] + + References + ========== + + .. [1] Cohen, H. *A Course in Computational Algebraic Number Theory.* + + """ + + def __init__(self, T, modulus=None): + """ + Parameters + ========== + + T : :py:class:`~.Poly` + The monic irreducible polynomial over :ref:`ZZ` defining the + algebraic integer. + + modulus : int, None, optional + If not ``None``, all representations will be reduced w.r.t. this. + + """ + self.T = T + self.modulus = modulus + self.n = T.degree() + self.powers_n_and_up = [[-c % self for c in reversed(T.rep.to_list())][:-1]] + self.max_so_far = self.n + + def red(self, exp): + return exp if self.modulus is None else exp % self.modulus + + def __rmod__(self, other): + return self.red(other) + + def compute_up_through(self, e): + m = self.max_so_far + if e <= m: return + n = self.n + r = self.powers_n_and_up + c = r[0] + for k in range(m+1, e+1): + b = r[k-1-n][n-1] + r.append( + [c[0]*b % self] + [ + (r[k-1-n][i-1] + c[i]*b) % self for i in range(1, n) + ] + ) + self.max_so_far = e + + def get(self, e): + n = self.n + if e < 0: + raise ValueError('Exponent must be non-negative.') + elif e < n: + return [1 if i == e else 0 for i in range(n)] + else: + self.compute_up_through(e) + return self.powers_n_and_up[e - n] + + def __getitem__(self, item): + return self.get(item) + + +@public +def coeff_search(m, R): + r""" + Generate coefficients for searching through polynomials. + + Explanation + =========== + + Lead coeff is always non-negative. Explore all combinations with coeffs + bounded in absolute value before increasing the bound. Skip the all-zero + list, and skip any repeats. See examples. + + Examples + ======== + + >>> from sympy.polys.numberfields.utilities import coeff_search + >>> cs = coeff_search(2, 1) + >>> C = [next(cs) for i in range(13)] + >>> print(C) + [[1, 1], [1, 0], [1, -1], [0, 1], [2, 2], [2, 1], [2, 0], [2, -1], [2, -2], + [1, 2], [1, -2], [0, 2], [3, 3]] + + Parameters + ========== + + m : int + Length of coeff list. + R : int + Initial max abs val for coeffs (will increase as search proceeds). + + Returns + ======= + + generator + Infinite generator of lists of coefficients. + + """ + R0 = R + c = [R] * m + while True: + if R == R0 or R in c or -R in c: + yield c[:] + j = m - 1 + while c[j] == -R: + j -= 1 + c[j] -= 1 + for i in range(j + 1, m): + c[i] = R + for j in range(m): + if c[j] != 0: + break + else: + R += 1 + c = [R] * m + + +def supplement_a_subspace(M): + r""" + Extend a basis for a subspace to a basis for the whole space. + + Explanation + =========== + + Given an $n \times r$ matrix *M* of rank $r$ (so $r \leq n$), this function + computes an invertible $n \times n$ matrix $B$ such that the first $r$ + columns of $B$ equal *M*. + + This operation can be interpreted as a way of extending a basis for a + subspace, to give a basis for the whole space. + + To be precise, suppose you have an $n$-dimensional vector space $V$, with + basis $\{v_1, v_2, \ldots, v_n\}$, and an $r$-dimensional subspace $W$ of + $V$, spanned by a basis $\{w_1, w_2, \ldots, w_r\}$, where the $w_j$ are + given as linear combinations of the $v_i$. If the columns of *M* represent + the $w_j$ as such linear combinations, then the columns of the matrix $B$ + computed by this function give a new basis $\{u_1, u_2, \ldots, u_n\}$ for + $V$, again relative to the $\{v_i\}$ basis, and such that $u_j = w_j$ + for $1 \leq j \leq r$. + + Examples + ======== + + Note: The function works in terms of columns, so in these examples we + print matrix transposes in order to make the columns easier to inspect. + + >>> from sympy.polys.matrices import DM + >>> from sympy import QQ, FF + >>> from sympy.polys.numberfields.utilities import supplement_a_subspace + >>> M = DM([[1, 7, 0], [2, 3, 4]], QQ).transpose() + >>> print(supplement_a_subspace(M).to_Matrix().transpose()) + Matrix([[1, 7, 0], [2, 3, 4], [1, 0, 0]]) + + >>> M2 = M.convert_to(FF(7)) + >>> print(M2.to_Matrix().transpose()) + Matrix([[1, 0, 0], [2, 3, -3]]) + >>> print(supplement_a_subspace(M2).to_Matrix().transpose()) + Matrix([[1, 0, 0], [2, 3, -3], [0, 1, 0]]) + + Parameters + ========== + + M : :py:class:`~.DomainMatrix` + The columns give the basis for the subspace. + + Returns + ======= + + :py:class:`~.DomainMatrix` + This matrix is invertible and its first $r$ columns equal *M*. + + Raises + ====== + + DMRankError + If *M* was not of maximal rank. + + References + ========== + + .. [1] Cohen, H. *A Course in Computational Algebraic Number Theory* + (See Sec. 2.3.2.) + + """ + n, r = M.shape + # Let In be the n x n identity matrix. + # Form the augmented matrix [M | In] and compute RREF. + Maug = M.hstack(M.eye(n, M.domain)) + R, pivots = Maug.rref() + if pivots[:r] != tuple(range(r)): + raise DMRankError('M was not of maximal rank') + # Let J be the n x r matrix equal to the first r columns of In. + # Since M is of rank r, RREF reduces [M | In] to [J | A], where A is the product of + # elementary matrices Ei corresp. to the row ops performed by RREF. Since the Ei are + # invertible, so is A. Let B = A^(-1). + A = R[:, r:] + B = A.inv() + # Then B is the desired matrix. It is invertible, since B^(-1) == A. + # And A * [M | In] == [J | A] + # => A * M == J + # => M == B * J == the first r columns of B. + return B + + +@public +def isolate(alg, eps=None, fast=False): + """ + Find a rational isolating interval for a real algebraic number. + + Examples + ======== + + >>> from sympy import isolate, sqrt, Rational + >>> print(isolate(sqrt(2))) # doctest: +SKIP + (1, 2) + >>> print(isolate(sqrt(2), eps=Rational(1, 100))) + (24/17, 17/12) + + Parameters + ========== + + alg : str, int, :py:class:`~.Expr` + The algebraic number to be isolated. Must be a real number, to use this + particular function. However, see also :py:meth:`.Poly.intervals`, + which isolates complex roots when you pass ``all=True``. + eps : positive element of :ref:`QQ`, None, optional (default=None) + Precision to be passed to :py:meth:`.Poly.refine_root` + fast : boolean, optional (default=False) + Say whether fast refinement procedure should be used. + (Will be passed to :py:meth:`.Poly.refine_root`.) + + Returns + ======= + + Pair of rational numbers defining an isolating interval for the given + algebraic number. + + See Also + ======== + + .Poly.intervals + + """ + alg = sympify(alg) + + if alg.is_Rational: + return (alg, alg) + elif not alg.is_real: + raise NotImplementedError( + "complex algebraic numbers are not supported") + + func = lambdify((), alg, modules="mpmath", printer=IntervalPrinter()) + + poly = minpoly(alg, polys=True) + intervals = poly.intervals(sqf=True) + + dps, done = mp.dps, False + + try: + while not done: + alg = func() + + for a, b in intervals: + if a <= alg.a and alg.b <= b: + done = True + break + else: + mp.dps *= 2 + finally: + mp.dps = dps + + if eps is not None: + a, b = poly.refine_root(a, b, eps=eps, fast=fast) + + return (a, b) diff --git a/lib/python3.10/site-packages/sympy/polys/tests/test_modulargcd.py b/lib/python3.10/site-packages/sympy/polys/tests/test_modulargcd.py new file mode 100644 index 0000000000000000000000000000000000000000..235fb8df0a5c582a82626326aedc0d6727b1c21a --- /dev/null +++ b/lib/python3.10/site-packages/sympy/polys/tests/test_modulargcd.py @@ -0,0 +1,325 @@ +from sympy.polys.rings import ring +from sympy.polys.domains import ZZ, QQ, AlgebraicField +from sympy.polys.modulargcd import ( + modgcd_univariate, + modgcd_bivariate, + _chinese_remainder_reconstruction_multivariate, + modgcd_multivariate, + _to_ZZ_poly, + _to_ANP_poly, + func_field_modgcd, + _func_field_modgcd_m) +from sympy.functions.elementary.miscellaneous import sqrt + + +def test_modgcd_univariate_integers(): + R, x = ring("x", ZZ) + + f, g = R.zero, R.zero + assert modgcd_univariate(f, g) == (0, 0, 0) + + f, g = R.zero, x + assert modgcd_univariate(f, g) == (x, 0, 1) + assert modgcd_univariate(g, f) == (x, 1, 0) + + f, g = R.zero, -x + assert modgcd_univariate(f, g) == (x, 0, -1) + assert modgcd_univariate(g, f) == (x, -1, 0) + + f, g = 2*x, R(2) + assert modgcd_univariate(f, g) == (2, x, 1) + + f, g = 2*x + 2, 6*x**2 - 6 + assert modgcd_univariate(f, g) == (2*x + 2, 1, 3*x - 3) + + f = x**4 + 8*x**3 + 21*x**2 + 22*x + 8 + g = x**3 + 6*x**2 + 11*x + 6 + + h = x**2 + 3*x + 2 + + cff = x**2 + 5*x + 4 + cfg = x + 3 + + assert modgcd_univariate(f, g) == (h, cff, cfg) + + f = x**4 - 4 + g = x**4 + 4*x**2 + 4 + + h = x**2 + 2 + + cff = x**2 - 2 + cfg = x**2 + 2 + + assert modgcd_univariate(f, g) == (h, cff, cfg) + + f = x**8 + x**6 - 3*x**4 - 3*x**3 + 8*x**2 + 2*x - 5 + g = 3*x**6 + 5*x**4 - 4*x**2 - 9*x + 21 + + h = 1 + + cff = f + cfg = g + + assert modgcd_univariate(f, g) == (h, cff, cfg) + + f = - 352518131239247345597970242177235495263669787845475025293906825864749649589178600387510272*x**49 \ + + 46818041807522713962450042363465092040687472354933295397472942006618953623327997952*x**42 \ + + 378182690892293941192071663536490788434899030680411695933646320291525827756032*x**35 \ + + 112806468807371824947796775491032386836656074179286744191026149539708928*x**28 \ + - 12278371209708240950316872681744825481125965781519138077173235712*x**21 \ + + 289127344604779611146960547954288113529690984687482920704*x**14 \ + + 19007977035740498977629742919480623972236450681*x**7 \ + + 311973482284542371301330321821976049 + + g = 365431878023781158602430064717380211405897160759702125019136*x**21 \ + + 197599133478719444145775798221171663643171734081650688*x**14 \ + - 9504116979659010018253915765478924103928886144*x**7 \ + - 311973482284542371301330321821976049 + + assert modgcd_univariate(f, f.diff(x))[0] == g + + f = 1317378933230047068160*x + 2945748836994210856960 + g = 120352542776360960*x + 269116466014453760 + + h = 120352542776360960*x + 269116466014453760 + cff = 10946 + cfg = 1 + + assert modgcd_univariate(f, g) == (h, cff, cfg) + + +def test_modgcd_bivariate_integers(): + R, x, y = ring("x,y", ZZ) + + f, g = R.zero, R.zero + assert modgcd_bivariate(f, g) == (0, 0, 0) + + f, g = 2*x, R(2) + assert modgcd_bivariate(f, g) == (2, x, 1) + + f, g = x + 2*y, x + y + assert modgcd_bivariate(f, g) == (1, f, g) + + f, g = x**2 + 2*x*y + y**2, x**3 + y**3 + assert modgcd_bivariate(f, g) == (x + y, x + y, x**2 - x*y + y**2) + + f, g = x*y**2 + 2*x*y + x, x*y**3 + x + assert modgcd_bivariate(f, g) == (x*y + x, y + 1, y**2 - y + 1) + + f, g = x**2*y**2 + x**2*y + 1, x*y**2 + x*y + 1 + assert modgcd_bivariate(f, g) == (1, f, g) + + f = 2*x*y**2 + 4*x*y + 2*x + y**2 + 2*y + 1 + g = 2*x*y**3 + 2*x + y**3 + 1 + assert modgcd_bivariate(f, g) == (2*x*y + 2*x + y + 1, y + 1, y**2 - y + 1) + + f, g = 2*x**2 + 4*x + 2, x + 1 + assert modgcd_bivariate(f, g) == (x + 1, 2*x + 2, 1) + + f, g = x + 1, 2*x**2 + 4*x + 2 + assert modgcd_bivariate(f, g) == (x + 1, 1, 2*x + 2) + + f = 2*x**2 + 4*x*y - 2*x - 4*y + g = x**2 + x - 2 + assert modgcd_bivariate(f, g) == (x - 1, 2*x + 4*y, x + 2) + + f = 2*x**2 + 2*x*y - 3*x - 3*y + g = 4*x*y - 2*x + 4*y**2 - 2*y + assert modgcd_bivariate(f, g) == (x + y, 2*x - 3, 4*y - 2) + + +def test_chinese_remainder(): + R, x, y = ring("x, y", ZZ) + p, q = 3, 5 + + hp = x**3*y - x**2 - 1 + hq = -x**3*y - 2*x*y**2 + 2 + + hpq = _chinese_remainder_reconstruction_multivariate(hp, hq, p, q) + + assert hpq.trunc_ground(p) == hp + assert hpq.trunc_ground(q) == hq + + T, z = ring("z", R) + p, q = 3, 7 + + hp = (x*y + 1)*z**2 + x + hq = (x**2 - 3*y)*z + 2 + + hpq = _chinese_remainder_reconstruction_multivariate(hp, hq, p, q) + + assert hpq.trunc_ground(p) == hp + assert hpq.trunc_ground(q) == hq + + +def test_modgcd_multivariate_integers(): + R, x, y = ring("x,y", ZZ) + + f, g = R.zero, R.zero + assert modgcd_multivariate(f, g) == (0, 0, 0) + + f, g = 2*x**2 + 4*x + 2, x + 1 + assert modgcd_multivariate(f, g) == (x + 1, 2*x + 2, 1) + + f, g = x + 1, 2*x**2 + 4*x + 2 + assert modgcd_multivariate(f, g) == (x + 1, 1, 2*x + 2) + + f = 2*x**2 + 2*x*y - 3*x - 3*y + g = 4*x*y - 2*x + 4*y**2 - 2*y + assert modgcd_multivariate(f, g) == (x + y, 2*x - 3, 4*y - 2) + + f, g = x*y**2 + 2*x*y + x, x*y**3 + x + assert modgcd_multivariate(f, g) == (x*y + x, y + 1, y**2 - y + 1) + + f, g = x**2*y**2 + x**2*y + 1, x*y**2 + x*y + 1 + assert modgcd_multivariate(f, g) == (1, f, g) + + f = x**4 + 8*x**3 + 21*x**2 + 22*x + 8 + g = x**3 + 6*x**2 + 11*x + 6 + + h = x**2 + 3*x + 2 + + cff = x**2 + 5*x + 4 + cfg = x + 3 + + assert modgcd_multivariate(f, g) == (h, cff, cfg) + + R, x, y, z, u = ring("x,y,z,u", ZZ) + + f, g = x + y + z, -x - y - z - u + assert modgcd_multivariate(f, g) == (1, f, g) + + f, g = u**2 + 2*u + 1, 2*u + 2 + assert modgcd_multivariate(f, g) == (u + 1, u + 1, 2) + + f, g = z**2*u**2 + 2*z**2*u + z**2 + z*u + z, u**2 + 2*u + 1 + h, cff, cfg = u + 1, z**2*u + z**2 + z, u + 1 + + assert modgcd_multivariate(f, g) == (h, cff, cfg) + assert modgcd_multivariate(g, f) == (h, cfg, cff) + + R, x, y, z = ring("x,y,z", ZZ) + + f, g = x - y*z, x - y*z + assert modgcd_multivariate(f, g) == (x - y*z, 1, 1) + + f, g, h = R.fateman_poly_F_1() + H, cff, cfg = modgcd_multivariate(f, g) + + assert H == h and H*cff == f and H*cfg == g + + R, x, y, z, u, v = ring("x,y,z,u,v", ZZ) + + f, g, h = R.fateman_poly_F_1() + H, cff, cfg = modgcd_multivariate(f, g) + + assert H == h and H*cff == f and H*cfg == g + + R, x, y, z, u, v, a, b = ring("x,y,z,u,v,a,b", ZZ) + + f, g, h = R.fateman_poly_F_1() + H, cff, cfg = modgcd_multivariate(f, g) + + assert H == h and H*cff == f and H*cfg == g + + R, x, y, z, u, v, a, b, c, d = ring("x,y,z,u,v,a,b,c,d", ZZ) + + f, g, h = R.fateman_poly_F_1() + H, cff, cfg = modgcd_multivariate(f, g) + + assert H == h and H*cff == f and H*cfg == g + + R, x, y, z = ring("x,y,z", ZZ) + + f, g, h = R.fateman_poly_F_2() + H, cff, cfg = modgcd_multivariate(f, g) + + assert H == h and H*cff == f and H*cfg == g + + f, g, h = R.fateman_poly_F_3() + H, cff, cfg = modgcd_multivariate(f, g) + + assert H == h and H*cff == f and H*cfg == g + + R, x, y, z, t = ring("x,y,z,t", ZZ) + + f, g, h = R.fateman_poly_F_3() + H, cff, cfg = modgcd_multivariate(f, g) + + assert H == h and H*cff == f and H*cfg == g + + +def test_to_ZZ_ANP_poly(): + A = AlgebraicField(QQ, sqrt(2)) + R, x = ring("x", A) + f = x*(sqrt(2) + 1) + + T, x_, z_ = ring("x_, z_", ZZ) + f_ = x_*z_ + x_ + + assert _to_ZZ_poly(f, T) == f_ + assert _to_ANP_poly(f_, R) == f + + R, x, t, s = ring("x, t, s", A) + f = x*t**2 + x*s + sqrt(2) + + D, t_, s_ = ring("t_, s_", ZZ) + T, x_, z_ = ring("x_, z_", D) + f_ = (t_**2 + s_)*x_ + z_ + + assert _to_ZZ_poly(f, T) == f_ + assert _to_ANP_poly(f_, R) == f + + +def test_modgcd_algebraic_field(): + A = AlgebraicField(QQ, sqrt(2)) + R, x = ring("x", A) + one = A.one + + f, g = 2*x, R(2) + assert func_field_modgcd(f, g) == (one, f, g) + + f, g = 2*x, R(sqrt(2)) + assert func_field_modgcd(f, g) == (one, f, g) + + f, g = 2*x + 2, 6*x**2 - 6 + assert func_field_modgcd(f, g) == (x + 1, R(2), 6*x - 6) + + R, x, y = ring("x, y", A) + + f, g = x + sqrt(2)*y, x + y + assert func_field_modgcd(f, g) == (one, f, g) + + f, g = x*y + sqrt(2)*y**2, R(sqrt(2))*y + assert func_field_modgcd(f, g) == (y, x + sqrt(2)*y, R(sqrt(2))) + + f, g = x**2 + 2*sqrt(2)*x*y + 2*y**2, x + sqrt(2)*y + assert func_field_modgcd(f, g) == (g, g, one) + + A = AlgebraicField(QQ, sqrt(2), sqrt(3)) + R, x, y, z = ring("x, y, z", A) + + h = x**2*y**7 + sqrt(6)/21*z + f, g = h*(27*y**3 + 1), h*(y + x) + assert func_field_modgcd(f, g) == (h, 27*y**3+1, y+x) + + h = x**13*y**3 + 1/2*x**10 + 1/sqrt(2) + f, g = h*(x + 1), h*sqrt(2)/sqrt(3) + assert func_field_modgcd(f, g) == (h, x + 1, R(sqrt(2)/sqrt(3))) + + A = AlgebraicField(QQ, sqrt(2)**(-1)*sqrt(3)) + R, x = ring("x", A) + + f, g = x + 1, x - 1 + assert func_field_modgcd(f, g) == (A.one, f, g) + + +# when func_field_modgcd suppors function fields, this test can be changed +def test_modgcd_func_field(): + D, t = ring("t", ZZ) + R, x, z = ring("x, z", D) + + minpoly = (z**2*t**2 + z**2*t - 1).drop(0) + f, g = x + 1, x - 1 + + assert _func_field_modgcd_m(f, g, minpoly) == R.one diff --git a/lib/python3.10/site-packages/sympy/polys/tests/test_monomials.py b/lib/python3.10/site-packages/sympy/polys/tests/test_monomials.py new file mode 100644 index 0000000000000000000000000000000000000000..c5ed28ba0e8e3f8e9f85c543a4fffcaef855fff8 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/polys/tests/test_monomials.py @@ -0,0 +1,269 @@ +"""Tests for tools and arithmetics for monomials of distributed polynomials. """ + +from sympy.polys.monomials import ( + itermonomials, monomial_count, + monomial_mul, monomial_div, + monomial_gcd, monomial_lcm, + monomial_max, monomial_min, + monomial_divides, monomial_pow, + Monomial, +) + +from sympy.polys.polyerrors import ExactQuotientFailed + +from sympy.abc import a, b, c, x, y, z +from sympy.core import S, symbols +from sympy.testing.pytest import raises + +def test_monomials(): + + # total_degree tests + assert set(itermonomials([], 0)) == {S.One} + assert set(itermonomials([], 1)) == {S.One} + assert set(itermonomials([], 2)) == {S.One} + + assert set(itermonomials([], 0, 0)) == {S.One} + assert set(itermonomials([], 1, 0)) == {S.One} + assert set(itermonomials([], 2, 0)) == {S.One} + + raises(StopIteration, lambda: next(itermonomials([], 0, 1))) + raises(StopIteration, lambda: next(itermonomials([], 0, 2))) + raises(StopIteration, lambda: next(itermonomials([], 0, 3))) + + assert set(itermonomials([], 0, 1)) == set() + assert set(itermonomials([], 0, 2)) == set() + assert set(itermonomials([], 0, 3)) == set() + + raises(ValueError, lambda: set(itermonomials([], -1))) + raises(ValueError, lambda: set(itermonomials([x], -1))) + raises(ValueError, lambda: set(itermonomials([x, y], -1))) + + assert set(itermonomials([x], 0)) == {S.One} + assert set(itermonomials([x], 1)) == {S.One, x} + assert set(itermonomials([x], 2)) == {S.One, x, x**2} + assert set(itermonomials([x], 3)) == {S.One, x, x**2, x**3} + + assert set(itermonomials([x, y], 0)) == {S.One} + assert set(itermonomials([x, y], 1)) == {S.One, x, y} + assert set(itermonomials([x, y], 2)) == {S.One, x, y, x**2, y**2, x*y} + assert set(itermonomials([x, y], 3)) == \ + {S.One, x, y, x**2, x**3, y**2, y**3, x*y, x*y**2, y*x**2} + + i, j, k = symbols('i j k', commutative=False) + assert set(itermonomials([i, j, k], 0)) == {S.One} + assert set(itermonomials([i, j, k], 1)) == {S.One, i, j, k} + assert set(itermonomials([i, j, k], 2)) == \ + {S.One, i, j, k, i**2, j**2, k**2, i*j, i*k, j*i, j*k, k*i, k*j} + + assert set(itermonomials([i, j, k], 3)) == \ + {S.One, i, j, k, i**2, j**2, k**2, i*j, i*k, j*i, j*k, k*i, k*j, + i**3, j**3, k**3, + i**2 * j, i**2 * k, j * i**2, k * i**2, + j**2 * i, j**2 * k, i * j**2, k * j**2, + k**2 * i, k**2 * j, i * k**2, j * k**2, + i*j*i, i*k*i, j*i*j, j*k*j, k*i*k, k*j*k, + i*j*k, i*k*j, j*i*k, j*k*i, k*i*j, k*j*i, + } + + assert set(itermonomials([x, i, j], 0)) == {S.One} + assert set(itermonomials([x, i, j], 1)) == {S.One, x, i, j} + assert set(itermonomials([x, i, j], 2)) == {S.One, x, i, j, x*i, x*j, i*j, j*i, x**2, i**2, j**2} + assert set(itermonomials([x, i, j], 3)) == \ + {S.One, x, i, j, x*i, x*j, i*j, j*i, x**2, i**2, j**2, + x**3, i**3, j**3, + x**2 * i, x**2 * j, + x * i**2, j * i**2, i**2 * j, i*j*i, + x * j**2, i * j**2, j**2 * i, j*i*j, + x * i * j, x * j * i + } + + # degree_list tests + assert set(itermonomials([], [])) == {S.One} + + raises(ValueError, lambda: set(itermonomials([], [0]))) + raises(ValueError, lambda: set(itermonomials([], [1]))) + raises(ValueError, lambda: set(itermonomials([], [2]))) + + raises(ValueError, lambda: set(itermonomials([x], [1], []))) + raises(ValueError, lambda: set(itermonomials([x], [1, 2], []))) + raises(ValueError, lambda: set(itermonomials([x], [1, 2, 3], []))) + + raises(ValueError, lambda: set(itermonomials([x], [], [1]))) + raises(ValueError, lambda: set(itermonomials([x], [], [1, 2]))) + raises(ValueError, lambda: set(itermonomials([x], [], [1, 2, 3]))) + + raises(ValueError, lambda: set(itermonomials([x, y], [1, 2], [1, 2, 3]))) + raises(ValueError, lambda: set(itermonomials([x, y, z], [1, 2, 3], [0, 1]))) + + raises(ValueError, lambda: set(itermonomials([x], [1], [-1]))) + raises(ValueError, lambda: set(itermonomials([x, y], [1, 2], [1, -1]))) + + raises(ValueError, lambda: set(itermonomials([], [], 1))) + raises(ValueError, lambda: set(itermonomials([], [], 2))) + raises(ValueError, lambda: set(itermonomials([], [], 3))) + + raises(ValueError, lambda: set(itermonomials([x, y], [0, 1], [1, 2]))) + raises(ValueError, lambda: set(itermonomials([x, y, z], [0, 0, 3], [0, 1, 2]))) + + assert set(itermonomials([x], [0])) == {S.One} + assert set(itermonomials([x], [1])) == {S.One, x} + assert set(itermonomials([x], [2])) == {S.One, x, x**2} + assert set(itermonomials([x], [3])) == {S.One, x, x**2, x**3} + + assert set(itermonomials([x], [3], [1])) == {x, x**3, x**2} + assert set(itermonomials([x], [3], [2])) == {x**3, x**2} + + assert set(itermonomials([x, y], 3, 3)) == {x**3, x**2*y, x*y**2, y**3} + assert set(itermonomials([x, y], 3, 2)) == {x**2, x*y, y**2, x**3, x**2*y, x*y**2, y**3} + + assert set(itermonomials([x, y], [0, 0])) == {S.One} + assert set(itermonomials([x, y], [0, 1])) == {S.One, y} + assert set(itermonomials([x, y], [0, 2])) == {S.One, y, y**2} + assert set(itermonomials([x, y], [0, 2], [0, 1])) == {y, y**2} + assert set(itermonomials([x, y], [0, 2], [0, 2])) == {y**2} + + assert set(itermonomials([x, y], [1, 0])) == {S.One, x} + assert set(itermonomials([x, y], [1, 1])) == {S.One, x, y, x*y} + assert set(itermonomials([x, y], [1, 2])) == {S.One, x, y, x*y, y**2, x*y**2} + assert set(itermonomials([x, y], [1, 2], [1, 1])) == {x*y, x*y**2} + assert set(itermonomials([x, y], [1, 2], [1, 2])) == {x*y**2} + + assert set(itermonomials([x, y], [2, 0])) == {S.One, x, x**2} + assert set(itermonomials([x, y], [2, 1])) == {S.One, x, y, x*y, x**2, x**2*y} + assert set(itermonomials([x, y], [2, 2])) == \ + {S.One, y**2, x*y**2, x, x*y, x**2, x**2*y**2, y, x**2*y} + + i, j, k = symbols('i j k', commutative=False) + assert set(itermonomials([i, j, k], 2, 2)) == \ + {k*i, i**2, i*j, j*k, j*i, k**2, j**2, k*j, i*k} + assert set(itermonomials([i, j, k], 3, 2)) == \ + {j*k**2, i*k**2, k*i*j, k*i**2, k**2, j*k*j, k*j**2, i*k*i, i*j, + j**2*k, i**2*j, j*i*k, j**3, i**3, k*j*i, j*k*i, j*i, + k**2*j, j*i**2, k*j, k*j*k, i*j*i, j*i*j, i*j**2, j**2, + k*i*k, i**2, j*k, i*k, i*k*j, k**3, i**2*k, j**2*i, k**2*i, + i*j*k, k*i + } + assert set(itermonomials([i, j, k], [0, 0, 0])) == {S.One} + assert set(itermonomials([i, j, k], [0, 0, 1])) == {1, k} + assert set(itermonomials([i, j, k], [0, 1, 0])) == {1, j} + assert set(itermonomials([i, j, k], [1, 0, 0])) == {i, 1} + assert set(itermonomials([i, j, k], [0, 0, 2])) == {k**2, 1, k} + assert set(itermonomials([i, j, k], [0, 2, 0])) == {1, j, j**2} + assert set(itermonomials([i, j, k], [2, 0, 0])) == {i, 1, i**2} + assert set(itermonomials([i, j, k], [1, 1, 1])) == {1, k, j, j*k, i*k, i, i*j, i*j*k} + assert set(itermonomials([i, j, k], [2, 2, 2])) == \ + {1, k, i**2*k**2, j*k, j**2, i, i*k, j*k**2, i*j**2*k**2, + i**2*j, i**2*j**2, k**2, j**2*k, i*j**2*k, + j**2*k**2, i*j, i**2*k, i**2*j**2*k, j, i**2*j*k, + i*j**2, i*k**2, i*j*k, i**2*j**2*k**2, i*j*k**2, i**2, i**2*j*k**2 + } + + assert set(itermonomials([x, j, k], [0, 0, 0])) == {S.One} + assert set(itermonomials([x, j, k], [0, 0, 1])) == {1, k} + assert set(itermonomials([x, j, k], [0, 1, 0])) == {1, j} + assert set(itermonomials([x, j, k], [1, 0, 0])) == {x, 1} + assert set(itermonomials([x, j, k], [0, 0, 2])) == {k**2, 1, k} + assert set(itermonomials([x, j, k], [0, 2, 0])) == {1, j, j**2} + assert set(itermonomials([x, j, k], [2, 0, 0])) == {x, 1, x**2} + assert set(itermonomials([x, j, k], [1, 1, 1])) == {1, k, j, j*k, x*k, x, x*j, x*j*k} + assert set(itermonomials([x, j, k], [2, 2, 2])) == \ + {1, k, x**2*k**2, j*k, j**2, x, x*k, j*k**2, x*j**2*k**2, + x**2*j, x**2*j**2, k**2, j**2*k, x*j**2*k, + j**2*k**2, x*j, x**2*k, x**2*j**2*k, j, x**2*j*k, + x*j**2, x*k**2, x*j*k, x**2*j**2*k**2, x*j*k**2, x**2, x**2*j*k**2 + } + +def test_monomial_count(): + assert monomial_count(2, 2) == 6 + assert monomial_count(2, 3) == 10 + +def test_monomial_mul(): + assert monomial_mul((3, 4, 1), (1, 2, 0)) == (4, 6, 1) + +def test_monomial_div(): + assert monomial_div((3, 4, 1), (1, 2, 0)) == (2, 2, 1) + +def test_monomial_gcd(): + assert monomial_gcd((3, 4, 1), (1, 2, 0)) == (1, 2, 0) + +def test_monomial_lcm(): + assert monomial_lcm((3, 4, 1), (1, 2, 0)) == (3, 4, 1) + +def test_monomial_max(): + assert monomial_max((3, 4, 5), (0, 5, 1), (6, 3, 9)) == (6, 5, 9) + +def test_monomial_pow(): + assert monomial_pow((1, 2, 3), 3) == (3, 6, 9) + +def test_monomial_min(): + assert monomial_min((3, 4, 5), (0, 5, 1), (6, 3, 9)) == (0, 3, 1) + +def test_monomial_divides(): + assert monomial_divides((1, 2, 3), (4, 5, 6)) is True + assert monomial_divides((1, 2, 3), (0, 5, 6)) is False + +def test_Monomial(): + m = Monomial((3, 4, 1), (x, y, z)) + n = Monomial((1, 2, 0), (x, y, z)) + + assert m.as_expr() == x**3*y**4*z + assert n.as_expr() == x**1*y**2 + + assert m.as_expr(a, b, c) == a**3*b**4*c + assert n.as_expr(a, b, c) == a**1*b**2 + + assert m.exponents == (3, 4, 1) + assert m.gens == (x, y, z) + + assert n.exponents == (1, 2, 0) + assert n.gens == (x, y, z) + + assert m == (3, 4, 1) + assert n != (3, 4, 1) + assert m != (1, 2, 0) + assert n == (1, 2, 0) + assert (m == 1) is False + + assert m[0] == m[-3] == 3 + assert m[1] == m[-2] == 4 + assert m[2] == m[-1] == 1 + + assert n[0] == n[-3] == 1 + assert n[1] == n[-2] == 2 + assert n[2] == n[-1] == 0 + + assert m[:2] == (3, 4) + assert n[:2] == (1, 2) + + assert m*n == Monomial((4, 6, 1)) + assert m/n == Monomial((2, 2, 1)) + + assert m*(1, 2, 0) == Monomial((4, 6, 1)) + assert m/(1, 2, 0) == Monomial((2, 2, 1)) + + assert m.gcd(n) == Monomial((1, 2, 0)) + assert m.lcm(n) == Monomial((3, 4, 1)) + + assert m.gcd((1, 2, 0)) == Monomial((1, 2, 0)) + assert m.lcm((1, 2, 0)) == Monomial((3, 4, 1)) + + assert m**0 == Monomial((0, 0, 0)) + assert m**1 == m + assert m**2 == Monomial((6, 8, 2)) + assert m**3 == Monomial((9, 12, 3)) + _a = Monomial((0, 0, 0)) + for n in range(10): + assert _a == m**n + _a *= m + + raises(ExactQuotientFailed, lambda: m/Monomial((5, 2, 0))) + + mm = Monomial((1, 2, 3)) + raises(ValueError, lambda: mm.as_expr()) + assert str(mm) == 'Monomial((1, 2, 3))' + assert str(m) == 'x**3*y**4*z**1' + raises(NotImplementedError, lambda: m*1) + raises(NotImplementedError, lambda: m/1) + raises(ValueError, lambda: m**-1) + raises(TypeError, lambda: m.gcd(3)) + raises(TypeError, lambda: m.lcm(3)) diff --git a/lib/python3.10/site-packages/sympy/polys/tests/test_multivariate_resultants.py b/lib/python3.10/site-packages/sympy/polys/tests/test_multivariate_resultants.py new file mode 100644 index 0000000000000000000000000000000000000000..0799feb41fc875cf038723916a3efd62ff31b1b4 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/polys/tests/test_multivariate_resultants.py @@ -0,0 +1,294 @@ +"""Tests for Dixon's and Macaulay's classes. """ + +from sympy.matrices.dense import Matrix +from sympy.polys.polytools import factor +from sympy.core import symbols +from sympy.tensor.indexed import IndexedBase + +from sympy.polys.multivariate_resultants import (DixonResultant, + MacaulayResultant) + +c, d = symbols("a, b") +x, y = symbols("x, y") + +p = c * x + y +q = x + d * y + +dixon = DixonResultant(polynomials=[p, q], variables=[x, y]) +macaulay = MacaulayResultant(polynomials=[p, q], variables=[x, y]) + +def test_dixon_resultant_init(): + """Test init method of DixonResultant.""" + a = IndexedBase("alpha") + + assert dixon.polynomials == [p, q] + assert dixon.variables == [x, y] + assert dixon.n == 2 + assert dixon.m == 2 + assert dixon.dummy_variables == [a[0], a[1]] + +def test_get_dixon_polynomial_numerical(): + """Test Dixon's polynomial for a numerical example.""" + a = IndexedBase("alpha") + + p = x + y + q = x ** 2 + y **3 + h = x ** 2 + y + + dixon = DixonResultant([p, q, h], [x, y]) + polynomial = -x * y ** 2 * a[0] - x * y ** 2 * a[1] - x * y * a[0] \ + * a[1] - x * y * a[1] ** 2 - x * a[0] * a[1] ** 2 + x * a[0] - \ + y ** 2 * a[0] * a[1] + y ** 2 * a[1] - y * a[0] * a[1] ** 2 + y * \ + a[1] ** 2 + + assert dixon.get_dixon_polynomial().as_expr().expand() == polynomial + +def test_get_max_degrees(): + """Tests max degrees function.""" + + p = x + y + q = x ** 2 + y **3 + h = x ** 2 + y + + dixon = DixonResultant(polynomials=[p, q, h], variables=[x, y]) + dixon_polynomial = dixon.get_dixon_polynomial() + + assert dixon.get_max_degrees(dixon_polynomial) == [1, 2] + +def test_get_dixon_matrix(): + """Test Dixon's resultant for a numerical example.""" + + x, y = symbols('x, y') + + p = x + y + q = x ** 2 + y ** 3 + h = x ** 2 + y + + dixon = DixonResultant([p, q, h], [x, y]) + polynomial = dixon.get_dixon_polynomial() + + assert dixon.get_dixon_matrix(polynomial).det() == 0 + +def test_get_dixon_matrix_example_two(): + """Test Dixon's matrix for example from [Palancz08]_.""" + x, y, z = symbols('x, y, z') + + f = x ** 2 + y ** 2 - 1 + z * 0 + g = x ** 2 + z ** 2 - 1 + y * 0 + h = y ** 2 + z ** 2 - 1 + + example_two = DixonResultant([f, g, h], [y, z]) + poly = example_two.get_dixon_polynomial() + matrix = example_two.get_dixon_matrix(poly) + + expr = 1 - 8 * x ** 2 + 24 * x ** 4 - 32 * x ** 6 + 16 * x ** 8 + assert (matrix.det() - expr).expand() == 0 + +def test_KSY_precondition(): + """Tests precondition for KSY Resultant.""" + A, B, C = symbols('A, B, C') + + m1 = Matrix([[1, 2, 3], + [4, 5, 12], + [6, 7, 18]]) + + m2 = Matrix([[0, C**2], + [-2 * C, -C ** 2]]) + + m3 = Matrix([[1, 0], + [0, 1]]) + + m4 = Matrix([[A**2, 0, 1], + [A, 1, 1 / A]]) + + m5 = Matrix([[5, 1], + [2, B], + [0, 1], + [0, 0]]) + + assert dixon.KSY_precondition(m1) == False + assert dixon.KSY_precondition(m2) == True + assert dixon.KSY_precondition(m3) == True + assert dixon.KSY_precondition(m4) == False + assert dixon.KSY_precondition(m5) == True + +def test_delete_zero_rows_and_columns(): + """Tests method for deleting rows and columns containing only zeros.""" + A, B, C = symbols('A, B, C') + + m1 = Matrix([[0, 0], + [0, 0], + [1, 2]]) + + m2 = Matrix([[0, 1, 2], + [0, 3, 4], + [0, 5, 6]]) + + m3 = Matrix([[0, 0, 0, 0], + [0, 1, 2, 0], + [0, 3, 4, 0], + [0, 0, 0, 0]]) + + m4 = Matrix([[1, 0, 2], + [0, 0, 0], + [3, 0, 4]]) + + m5 = Matrix([[0, 0, 0, 1], + [0, 0, 0, 2], + [0, 0, 0, 3], + [0, 0, 0, 4]]) + + m6 = Matrix([[0, 0, A], + [B, 0, 0], + [0, 0, C]]) + + assert dixon.delete_zero_rows_and_columns(m1) == Matrix([[1, 2]]) + + assert dixon.delete_zero_rows_and_columns(m2) == Matrix([[1, 2], + [3, 4], + [5, 6]]) + + assert dixon.delete_zero_rows_and_columns(m3) == Matrix([[1, 2], + [3, 4]]) + + assert dixon.delete_zero_rows_and_columns(m4) == Matrix([[1, 2], + [3, 4]]) + + assert dixon.delete_zero_rows_and_columns(m5) == Matrix([[1], + [2], + [3], + [4]]) + + assert dixon.delete_zero_rows_and_columns(m6) == Matrix([[0, A], + [B, 0], + [0, C]]) + +def test_product_leading_entries(): + """Tests product of leading entries method.""" + A, B = symbols('A, B') + + m1 = Matrix([[1, 2, 3], + [0, 4, 5], + [0, 0, 6]]) + + m2 = Matrix([[0, 0, 1], + [2, 0, 3]]) + + m3 = Matrix([[0, 0, 0], + [1, 2, 3], + [0, 0, 0]]) + + m4 = Matrix([[0, 0, A], + [1, 2, 3], + [B, 0, 0]]) + + assert dixon.product_leading_entries(m1) == 24 + assert dixon.product_leading_entries(m2) == 2 + assert dixon.product_leading_entries(m3) == 1 + assert dixon.product_leading_entries(m4) == A * B + +def test_get_KSY_Dixon_resultant_example_one(): + """Tests the KSY Dixon resultant for example one""" + x, y, z = symbols('x, y, z') + + p = x * y * z + q = x**2 - z**2 + h = x + y + z + dixon = DixonResultant([p, q, h], [x, y]) + dixon_poly = dixon.get_dixon_polynomial() + dixon_matrix = dixon.get_dixon_matrix(dixon_poly) + D = dixon.get_KSY_Dixon_resultant(dixon_matrix) + + assert D == -z**3 + +def test_get_KSY_Dixon_resultant_example_two(): + """Tests the KSY Dixon resultant for example two""" + x, y, A = symbols('x, y, A') + + p = x * y + x * A + x - A**2 - A + y**2 + y + q = x**2 + x * A - x + x * y + y * A - y + h = x**2 + x * y + 2 * x - x * A - y * A - 2 * A + + dixon = DixonResultant([p, q, h], [x, y]) + dixon_poly = dixon.get_dixon_polynomial() + dixon_matrix = dixon.get_dixon_matrix(dixon_poly) + D = factor(dixon.get_KSY_Dixon_resultant(dixon_matrix)) + + assert D == -8*A*(A - 1)*(A + 2)*(2*A - 1)**2 + +def test_macaulay_resultant_init(): + """Test init method of MacaulayResultant.""" + + assert macaulay.polynomials == [p, q] + assert macaulay.variables == [x, y] + assert macaulay.n == 2 + assert macaulay.degrees == [1, 1] + assert macaulay.degree_m == 1 + assert macaulay.monomials_size == 2 + +def test_get_degree_m(): + assert macaulay._get_degree_m() == 1 + +def test_get_size(): + assert macaulay.get_size() == 2 + +def test_macaulay_example_one(): + """Tests the Macaulay for example from [Bruce97]_""" + + x, y, z = symbols('x, y, z') + a_1_1, a_1_2, a_1_3 = symbols('a_1_1, a_1_2, a_1_3') + a_2_2, a_2_3, a_3_3 = symbols('a_2_2, a_2_3, a_3_3') + b_1_1, b_1_2, b_1_3 = symbols('b_1_1, b_1_2, b_1_3') + b_2_2, b_2_3, b_3_3 = symbols('b_2_2, b_2_3, b_3_3') + c_1, c_2, c_3 = symbols('c_1, c_2, c_3') + + f_1 = a_1_1 * x ** 2 + a_1_2 * x * y + a_1_3 * x * z + \ + a_2_2 * y ** 2 + a_2_3 * y * z + a_3_3 * z ** 2 + f_2 = b_1_1 * x ** 2 + b_1_2 * x * y + b_1_3 * x * z + \ + b_2_2 * y ** 2 + b_2_3 * y * z + b_3_3 * z ** 2 + f_3 = c_1 * x + c_2 * y + c_3 * z + + mac = MacaulayResultant([f_1, f_2, f_3], [x, y, z]) + + assert mac.degrees == [2, 2, 1] + assert mac.degree_m == 3 + + assert mac.monomial_set == [x ** 3, x ** 2 * y, x ** 2 * z, + x * y ** 2, + x * y * z, x * z ** 2, y ** 3, + y ** 2 *z, y * z ** 2, z ** 3] + assert mac.monomials_size == 10 + assert mac.get_row_coefficients() == [[x, y, z], [x, y, z], + [x * y, x * z, y * z, z ** 2]] + + matrix = mac.get_matrix() + assert matrix.shape == (mac.monomials_size, mac.monomials_size) + assert mac.get_submatrix(matrix) == Matrix([[a_1_1, a_2_2], + [b_1_1, b_2_2]]) + +def test_macaulay_example_two(): + """Tests the Macaulay formulation for example from [Stiller96]_.""" + + x, y, z = symbols('x, y, z') + a_0, a_1, a_2 = symbols('a_0, a_1, a_2') + b_0, b_1, b_2 = symbols('b_0, b_1, b_2') + c_0, c_1, c_2, c_3, c_4 = symbols('c_0, c_1, c_2, c_3, c_4') + + f = a_0 * y - a_1 * x + a_2 * z + g = b_1 * x ** 2 + b_0 * y ** 2 - b_2 * z ** 2 + h = c_0 * y - c_1 * x ** 3 + c_2 * x ** 2 * z - c_3 * x * z ** 2 + \ + c_4 * z ** 3 + + mac = MacaulayResultant([f, g, h], [x, y, z]) + + assert mac.degrees == [1, 2, 3] + assert mac.degree_m == 4 + assert mac.monomials_size == 15 + assert len(mac.get_row_coefficients()) == mac.n + + matrix = mac.get_matrix() + assert matrix.shape == (mac.monomials_size, mac.monomials_size) + assert mac.get_submatrix(matrix) == Matrix([[-a_1, a_0, a_2, 0], + [0, -a_1, 0, 0], + [0, 0, -a_1, 0], + [0, 0, 0, -a_1]]) diff --git a/lib/python3.10/site-packages/sympy/polys/tests/test_orderings.py b/lib/python3.10/site-packages/sympy/polys/tests/test_orderings.py new file mode 100644 index 0000000000000000000000000000000000000000..d61d4887754c9d9f49905c2e131d253a45cf2ffd --- /dev/null +++ b/lib/python3.10/site-packages/sympy/polys/tests/test_orderings.py @@ -0,0 +1,124 @@ +"""Tests of monomial orderings. """ + +from sympy.polys.orderings import ( + monomial_key, lex, grlex, grevlex, ilex, igrlex, + LexOrder, InverseOrder, ProductOrder, build_product_order, +) + +from sympy.abc import x, y, z, t +from sympy.core import S +from sympy.testing.pytest import raises + +def test_lex_order(): + assert lex((1, 2, 3)) == (1, 2, 3) + assert str(lex) == 'lex' + + assert lex((1, 2, 3)) == lex((1, 2, 3)) + + assert lex((2, 2, 3)) > lex((1, 2, 3)) + assert lex((1, 3, 3)) > lex((1, 2, 3)) + assert lex((1, 2, 4)) > lex((1, 2, 3)) + + assert lex((0, 2, 3)) < lex((1, 2, 3)) + assert lex((1, 1, 3)) < lex((1, 2, 3)) + assert lex((1, 2, 2)) < lex((1, 2, 3)) + + assert lex.is_global is True + assert lex == LexOrder() + assert lex != grlex + +def test_grlex_order(): + assert grlex((1, 2, 3)) == (6, (1, 2, 3)) + assert str(grlex) == 'grlex' + + assert grlex((1, 2, 3)) == grlex((1, 2, 3)) + + assert grlex((2, 2, 3)) > grlex((1, 2, 3)) + assert grlex((1, 3, 3)) > grlex((1, 2, 3)) + assert grlex((1, 2, 4)) > grlex((1, 2, 3)) + + assert grlex((0, 2, 3)) < grlex((1, 2, 3)) + assert grlex((1, 1, 3)) < grlex((1, 2, 3)) + assert grlex((1, 2, 2)) < grlex((1, 2, 3)) + + assert grlex((2, 2, 3)) > grlex((1, 2, 4)) + assert grlex((1, 3, 3)) > grlex((1, 2, 4)) + + assert grlex((0, 2, 3)) < grlex((1, 2, 2)) + assert grlex((1, 1, 3)) < grlex((1, 2, 2)) + + assert grlex((0, 1, 1)) > grlex((0, 0, 2)) + assert grlex((0, 3, 1)) < grlex((2, 2, 1)) + + assert grlex.is_global is True + +def test_grevlex_order(): + assert grevlex((1, 2, 3)) == (6, (-3, -2, -1)) + assert str(grevlex) == 'grevlex' + + assert grevlex((1, 2, 3)) == grevlex((1, 2, 3)) + + assert grevlex((2, 2, 3)) > grevlex((1, 2, 3)) + assert grevlex((1, 3, 3)) > grevlex((1, 2, 3)) + assert grevlex((1, 2, 4)) > grevlex((1, 2, 3)) + + assert grevlex((0, 2, 3)) < grevlex((1, 2, 3)) + assert grevlex((1, 1, 3)) < grevlex((1, 2, 3)) + assert grevlex((1, 2, 2)) < grevlex((1, 2, 3)) + + assert grevlex((2, 2, 3)) > grevlex((1, 2, 4)) + assert grevlex((1, 3, 3)) > grevlex((1, 2, 4)) + + assert grevlex((0, 2, 3)) < grevlex((1, 2, 2)) + assert grevlex((1, 1, 3)) < grevlex((1, 2, 2)) + + assert grevlex((0, 1, 1)) > grevlex((0, 0, 2)) + assert grevlex((0, 3, 1)) < grevlex((2, 2, 1)) + + assert grevlex.is_global is True + +def test_InverseOrder(): + ilex = InverseOrder(lex) + igrlex = InverseOrder(grlex) + + assert ilex((1, 2, 3)) > ilex((2, 0, 3)) + assert igrlex((1, 2, 3)) < igrlex((0, 2, 3)) + assert str(ilex) == "ilex" + assert str(igrlex) == "igrlex" + assert ilex.is_global is False + assert igrlex.is_global is False + assert ilex != igrlex + assert ilex == InverseOrder(LexOrder()) + +def test_ProductOrder(): + P = ProductOrder((grlex, lambda m: m[:2]), (grlex, lambda m: m[2:])) + assert P((1, 3, 3, 4, 5)) > P((2, 1, 5, 5, 5)) + assert str(P) == "ProductOrder(grlex, grlex)" + assert P.is_global is True + assert ProductOrder((grlex, None), (ilex, None)).is_global is None + assert ProductOrder((igrlex, None), (ilex, None)).is_global is False + +def test_monomial_key(): + assert monomial_key() == lex + + assert monomial_key('lex') == lex + assert monomial_key('grlex') == grlex + assert monomial_key('grevlex') == grevlex + + raises(ValueError, lambda: monomial_key('foo')) + raises(ValueError, lambda: monomial_key(1)) + + M = [x, x**2*z**2, x*y, x**2, S.One, y**2, x**3, y, z, x*y**2*z, x**2*y**2] + assert sorted(M, key=monomial_key('lex', [z, y, x])) == \ + [S.One, x, x**2, x**3, y, x*y, y**2, x**2*y**2, z, x*y**2*z, x**2*z**2] + assert sorted(M, key=monomial_key('grlex', [z, y, x])) == \ + [S.One, x, y, z, x**2, x*y, y**2, x**3, x**2*y**2, x*y**2*z, x**2*z**2] + assert sorted(M, key=monomial_key('grevlex', [z, y, x])) == \ + [S.One, x, y, z, x**2, x*y, y**2, x**3, x**2*y**2, x**2*z**2, x*y**2*z] + +def test_build_product_order(): + assert build_product_order((("grlex", x, y), ("grlex", z, t)), [x, y, z, t])((4, 5, 6, 7)) == \ + ((9, (4, 5)), (13, (6, 7))) + + assert build_product_order((("grlex", x, y), ("grlex", z, t)), [x, y, z, t]) == \ + build_product_order((("grlex", x, y), ("grlex", z, t)), [x, y, z, t]) diff --git a/lib/python3.10/site-packages/sympy/polys/tests/test_orthopolys.py b/lib/python3.10/site-packages/sympy/polys/tests/test_orthopolys.py new file mode 100644 index 0000000000000000000000000000000000000000..e81fbe75aa6285d229ba817026f44b23b76abd6a --- /dev/null +++ b/lib/python3.10/site-packages/sympy/polys/tests/test_orthopolys.py @@ -0,0 +1,175 @@ +"""Tests for efficient functions for generating orthogonal polynomials. """ + +from sympy.core.numbers import Rational as Q +from sympy.core.singleton import S +from sympy.core.symbol import symbols +from sympy.polys.polytools import Poly +from sympy.testing.pytest import raises + +from sympy.polys.orthopolys import ( + jacobi_poly, + gegenbauer_poly, + chebyshevt_poly, + chebyshevu_poly, + hermite_poly, + hermite_prob_poly, + legendre_poly, + laguerre_poly, + spherical_bessel_fn, +) + +from sympy.abc import x, a, b + + +def test_jacobi_poly(): + raises(ValueError, lambda: jacobi_poly(-1, a, b, x)) + + assert jacobi_poly(1, a, b, x, polys=True) == Poly( + (a/2 + b/2 + 1)*x + a/2 - b/2, x, domain='ZZ(a,b)') + + assert jacobi_poly(0, a, b, x) == 1 + assert jacobi_poly(1, a, b, x) == a/2 - b/2 + x*(a/2 + b/2 + 1) + assert jacobi_poly(2, a, b, x) == (a**2/8 - a*b/4 - a/8 + b**2/8 - b/8 + + x**2*(a**2/8 + a*b/4 + a*Q(7, 8) + b**2/8 + + b*Q(7, 8) + Q(3, 2)) + x*(a**2/4 + + a*Q(3, 4) - b**2/4 - b*Q(3, 4)) - S.Half) + + assert jacobi_poly(1, a, b, polys=True) == Poly( + (a/2 + b/2 + 1)*x + a/2 - b/2, x, domain='ZZ(a,b)') + + +def test_gegenbauer_poly(): + raises(ValueError, lambda: gegenbauer_poly(-1, a, x)) + + assert gegenbauer_poly( + 1, a, x, polys=True) == Poly(2*a*x, x, domain='ZZ(a)') + + assert gegenbauer_poly(0, a, x) == 1 + assert gegenbauer_poly(1, a, x) == 2*a*x + assert gegenbauer_poly(2, a, x) == -a + x**2*(2*a**2 + 2*a) + assert gegenbauer_poly( + 3, a, x) == x**3*(4*a**3/3 + 4*a**2 + a*Q(8, 3)) + x*(-2*a**2 - 2*a) + + assert gegenbauer_poly(1, S.Half).dummy_eq(x) + assert gegenbauer_poly(1, a, polys=True) == Poly(2*a*x, x, domain='ZZ(a)') + + +def test_chebyshevt_poly(): + raises(ValueError, lambda: chebyshevt_poly(-1, x)) + + assert chebyshevt_poly(1, x, polys=True) == Poly(x) + + assert chebyshevt_poly(0, x) == 1 + assert chebyshevt_poly(1, x) == x + assert chebyshevt_poly(2, x) == 2*x**2 - 1 + assert chebyshevt_poly(3, x) == 4*x**3 - 3*x + assert chebyshevt_poly(4, x) == 8*x**4 - 8*x**2 + 1 + assert chebyshevt_poly(5, x) == 16*x**5 - 20*x**3 + 5*x + assert chebyshevt_poly(6, x) == 32*x**6 - 48*x**4 + 18*x**2 - 1 + assert chebyshevt_poly(75, x) == (2*chebyshevt_poly(37, x)*chebyshevt_poly(38, x) - x).expand() + assert chebyshevt_poly(100, x) == (2*chebyshevt_poly(50, x)**2 - 1).expand() + + assert chebyshevt_poly(1).dummy_eq(x) + assert chebyshevt_poly(1, polys=True) == Poly(x) + + +def test_chebyshevu_poly(): + raises(ValueError, lambda: chebyshevu_poly(-1, x)) + + assert chebyshevu_poly(1, x, polys=True) == Poly(2*x) + + assert chebyshevu_poly(0, x) == 1 + assert chebyshevu_poly(1, x) == 2*x + assert chebyshevu_poly(2, x) == 4*x**2 - 1 + assert chebyshevu_poly(3, x) == 8*x**3 - 4*x + assert chebyshevu_poly(4, x) == 16*x**4 - 12*x**2 + 1 + assert chebyshevu_poly(5, x) == 32*x**5 - 32*x**3 + 6*x + assert chebyshevu_poly(6, x) == 64*x**6 - 80*x**4 + 24*x**2 - 1 + + assert chebyshevu_poly(1).dummy_eq(2*x) + assert chebyshevu_poly(1, polys=True) == Poly(2*x) + + +def test_hermite_poly(): + raises(ValueError, lambda: hermite_poly(-1, x)) + + assert hermite_poly(1, x, polys=True) == Poly(2*x) + + assert hermite_poly(0, x) == 1 + assert hermite_poly(1, x) == 2*x + assert hermite_poly(2, x) == 4*x**2 - 2 + assert hermite_poly(3, x) == 8*x**3 - 12*x + assert hermite_poly(4, x) == 16*x**4 - 48*x**2 + 12 + assert hermite_poly(5, x) == 32*x**5 - 160*x**3 + 120*x + assert hermite_poly(6, x) == 64*x**6 - 480*x**4 + 720*x**2 - 120 + + assert hermite_poly(1).dummy_eq(2*x) + assert hermite_poly(1, polys=True) == Poly(2*x) + + +def test_hermite_prob_poly(): + raises(ValueError, lambda: hermite_prob_poly(-1, x)) + + assert hermite_prob_poly(1, x, polys=True) == Poly(x) + + assert hermite_prob_poly(0, x) == 1 + assert hermite_prob_poly(1, x) == x + assert hermite_prob_poly(2, x) == x**2 - 1 + assert hermite_prob_poly(3, x) == x**3 - 3*x + assert hermite_prob_poly(4, x) == x**4 - 6*x**2 + 3 + assert hermite_prob_poly(5, x) == x**5 - 10*x**3 + 15*x + assert hermite_prob_poly(6, x) == x**6 - 15*x**4 + 45*x**2 - 15 + + assert hermite_prob_poly(1).dummy_eq(x) + assert hermite_prob_poly(1, polys=True) == Poly(x) + + +def test_legendre_poly(): + raises(ValueError, lambda: legendre_poly(-1, x)) + + assert legendre_poly(1, x, polys=True) == Poly(x, domain='QQ') + + assert legendre_poly(0, x) == 1 + assert legendre_poly(1, x) == x + assert legendre_poly(2, x) == Q(3, 2)*x**2 - Q(1, 2) + assert legendre_poly(3, x) == Q(5, 2)*x**3 - Q(3, 2)*x + assert legendre_poly(4, x) == Q(35, 8)*x**4 - Q(30, 8)*x**2 + Q(3, 8) + assert legendre_poly(5, x) == Q(63, 8)*x**5 - Q(70, 8)*x**3 + Q(15, 8)*x + assert legendre_poly(6, x) == Q( + 231, 16)*x**6 - Q(315, 16)*x**4 + Q(105, 16)*x**2 - Q(5, 16) + + assert legendre_poly(1).dummy_eq(x) + assert legendre_poly(1, polys=True) == Poly(x) + + +def test_laguerre_poly(): + raises(ValueError, lambda: laguerre_poly(-1, x)) + + assert laguerre_poly(1, x, polys=True) == Poly(-x + 1, domain='QQ') + + assert laguerre_poly(0, x) == 1 + assert laguerre_poly(1, x) == -x + 1 + assert laguerre_poly(2, x) == Q(1, 2)*x**2 - Q(4, 2)*x + 1 + assert laguerre_poly(3, x) == -Q(1, 6)*x**3 + Q(9, 6)*x**2 - Q(18, 6)*x + 1 + assert laguerre_poly(4, x) == Q( + 1, 24)*x**4 - Q(16, 24)*x**3 + Q(72, 24)*x**2 - Q(96, 24)*x + 1 + assert laguerre_poly(5, x) == -Q(1, 120)*x**5 + Q(25, 120)*x**4 - Q( + 200, 120)*x**3 + Q(600, 120)*x**2 - Q(600, 120)*x + 1 + assert laguerre_poly(6, x) == Q(1, 720)*x**6 - Q(36, 720)*x**5 + Q(450, 720)*x**4 - Q(2400, 720)*x**3 + Q(5400, 720)*x**2 - Q(4320, 720)*x + 1 + + assert laguerre_poly(0, x, a) == 1 + assert laguerre_poly(1, x, a) == -x + a + 1 + assert laguerre_poly(2, x, a) == x**2/2 + (-a - 2)*x + a**2/2 + a*Q(3, 2) + 1 + assert laguerre_poly(3, x, a) == -x**3/6 + (a/2 + Q( + 3)/2)*x**2 + (-a**2/2 - a*Q(5, 2) - 3)*x + a**3/6 + a**2 + a*Q(11, 6) + 1 + + assert laguerre_poly(1).dummy_eq(-x + 1) + assert laguerre_poly(1, polys=True) == Poly(-x + 1) + + +def test_spherical_bessel_fn(): + x, z = symbols("x z") + assert spherical_bessel_fn(1, z) == 1/z**2 + assert spherical_bessel_fn(2, z) == -1/z + 3/z**3 + assert spherical_bessel_fn(3, z) == -6/z**2 + 15/z**4 + assert spherical_bessel_fn(4, z) == 1/z - 45/z**3 + 105/z**5 diff --git a/lib/python3.10/site-packages/sympy/polys/tests/test_partfrac.py b/lib/python3.10/site-packages/sympy/polys/tests/test_partfrac.py new file mode 100644 index 0000000000000000000000000000000000000000..83c5d48383d20e67dbb53c081093ad35e654c9a0 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/polys/tests/test_partfrac.py @@ -0,0 +1,249 @@ +"""Tests for algorithms for partial fraction decomposition of rational +functions. """ + +from sympy.polys.partfrac import ( + apart_undetermined_coeffs, + apart, + apart_list, assemble_partfrac_list +) + +from sympy.core.expr import Expr +from sympy.core.function import Lambda +from sympy.core.numbers import (E, I, Rational, pi, all_close) +from sympy.core.relational import Eq +from sympy.core.singleton import S +from sympy.core.symbol import (Dummy, Symbol) +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.matrices.dense import Matrix +from sympy.polys.polytools import (Poly, factor) +from sympy.polys.rationaltools import together +from sympy.polys.rootoftools import RootSum +from sympy.testing.pytest import raises, XFAIL +from sympy.abc import x, y, a, b, c + + +def test_apart(): + assert apart(1) == 1 + assert apart(1, x) == 1 + + f, g = (x**2 + 1)/(x + 1), 2/(x + 1) + x - 1 + + assert apart(f, full=False) == g + assert apart(f, full=True) == g + + f, g = 1/(x + 2)/(x + 1), 1/(1 + x) - 1/(2 + x) + + assert apart(f, full=False) == g + assert apart(f, full=True) == g + + f, g = 1/(x + 1)/(x + 5), -1/(5 + x)/4 + 1/(1 + x)/4 + + assert apart(f, full=False) == g + assert apart(f, full=True) == g + + assert apart((E*x + 2)/(x - pi)*(x - 1), x) == \ + 2 - E + E*pi + E*x + (E*pi + 2)*(pi - 1)/(x - pi) + + assert apart(Eq((x**2 + 1)/(x + 1), x), x) == Eq(x - 1 + 2/(x + 1), x) + + assert apart(x/2, y) == x/2 + + f, g = (x+y)/(2*x - y), Rational(3, 2)*y/(2*x - y) + S.Half + + assert apart(f, x, full=False) == g + assert apart(f, x, full=True) == g + + f, g = (x+y)/(2*x - y), 3*x/(2*x - y) - 1 + + assert apart(f, y, full=False) == g + assert apart(f, y, full=True) == g + + raises(NotImplementedError, lambda: apart(1/(x + 1)/(y + 2))) + + +def test_apart_matrix(): + M = Matrix(2, 2, lambda i, j: 1/(x + i + 1)/(x + j)) + + assert apart(M) == Matrix([ + [1/x - 1/(x + 1), (x + 1)**(-2)], + [1/(2*x) - (S.Half)/(x + 2), 1/(x + 1) - 1/(x + 2)], + ]) + + +def test_apart_symbolic(): + f = a*x**4 + (2*b + 2*a*c)*x**3 + (4*b*c - a**2 + a*c**2)*x**2 + \ + (-2*a*b + 2*b*c**2)*x - b**2 + g = a**2*x**4 + (2*a*b + 2*c*a**2)*x**3 + (4*a*b*c + b**2 + + a**2*c**2)*x**2 + (2*c*b**2 + 2*a*b*c**2)*x + b**2*c**2 + + assert apart(f/g, x) == 1/a - 1/(x + c)**2 - b**2/(a*(a*x + b)**2) + + assert apart(1/((x + a)*(x + b)*(x + c)), x) == \ + 1/((a - c)*(b - c)*(c + x)) - 1/((a - b)*(b - c)*(b + x)) + \ + 1/((a - b)*(a - c)*(a + x)) + + +def _make_extension_example(): + # https://github.com/sympy/sympy/issues/18531 + from sympy.core import Mul + def mul2(expr): + # 2-arg mul hack... + return Mul(2, expr, evaluate=False) + + f = ((x**2 + 1)**3/((x - 1)**2*(x + 1)**2*(-x**2 + 2*x + 1)*(x**2 + 2*x - 1))) + g = (1/mul2(x - sqrt(2) + 1) + - 1/mul2(x - sqrt(2) - 1) + + 1/mul2(x + 1 + sqrt(2)) + - 1/mul2(x - 1 + sqrt(2)) + + 1/mul2((x + 1)**2) + + 1/mul2((x - 1)**2)) + return f, g + + +def test_apart_extension(): + f = 2/(x**2 + 1) + g = I/(x + I) - I/(x - I) + + assert apart(f, extension=I) == g + assert apart(f, gaussian=True) == g + + f = x/((x - 2)*(x + I)) + + assert factor(together(apart(f)).expand()) == f + + f, g = _make_extension_example() + + # XXX: Only works with dotprodsimp. See test_apart_extension_xfail below + from sympy.matrices import dotprodsimp + with dotprodsimp(True): + assert apart(f, x, extension={sqrt(2)}) == g + + +def test_apart_extension_xfail(): + f, g = _make_extension_example() + assert apart(f, x, extension={sqrt(2)}) == g + + +def test_apart_full(): + f = 1/(x**2 + 1) + + assert apart(f, full=False) == f + assert apart(f, full=True).dummy_eq( + -RootSum(x**2 + 1, Lambda(a, a/(x - a)), auto=False)/2) + + f = 1/(x**3 + x + 1) + + assert apart(f, full=False) == f + assert apart(f, full=True).dummy_eq( + RootSum(x**3 + x + 1, + Lambda(a, (a**2*Rational(6, 31) - a*Rational(9, 31) + Rational(4, 31))/(x - a)), auto=False)) + + f = 1/(x**5 + 1) + + assert apart(f, full=False) == \ + (Rational(-1, 5))*((x**3 - 2*x**2 + 3*x - 4)/(x**4 - x**3 + x**2 - + x + 1)) + (Rational(1, 5))/(x + 1) + assert apart(f, full=True).dummy_eq( + -RootSum(x**4 - x**3 + x**2 - x + 1, + Lambda(a, a/(x - a)), auto=False)/5 + (Rational(1, 5))/(x + 1)) + + +def test_apart_full_floats(): + # https://github.com/sympy/sympy/issues/26648 + f = ( + 6.43369157032015e-9*x**3 + 1.35203404799555e-5*x**2 + + 0.00357538393743079*x + 0.085 + )/( + 4.74334912634438e-11*x**4 + 4.09576274286244e-6*x**3 + + 0.00334241812250921*x**2 + 0.15406018058983*x + 1.0 + ) + + expected = ( + 133.599202650992/(x + 85524.0054884464) + + 1.07757928431867/(x + 774.88576677949) + + 0.395006955518971/(x + 40.7977016133126) + + 0.564264854137341/(x + 7.79746609204661) + ) + + f_apart = apart(f, full=True).evalf() + + # There is a significant floating point error in this operation. + assert all_close(f_apart, expected, rtol=1e-3, atol=1e-5) + + +def test_apart_undetermined_coeffs(): + p = Poly(2*x - 3) + q = Poly(x**9 - x**8 - x**6 + x**5 - 2*x**2 + 3*x - 1) + r = (-x**7 - x**6 - x**5 + 4)/(x**8 - x**5 - 2*x + 1) + 1/(x - 1) + + assert apart_undetermined_coeffs(p, q) == r + + p = Poly(1, x, domain='ZZ[a,b]') + q = Poly((x + a)*(x + b), x, domain='ZZ[a,b]') + r = 1/((a - b)*(b + x)) - 1/((a - b)*(a + x)) + + assert apart_undetermined_coeffs(p, q) == r + + +def test_apart_list(): + from sympy.utilities.iterables import numbered_symbols + def dummy_eq(i, j): + if type(i) in (list, tuple): + return all(dummy_eq(i, j) for i, j in zip(i, j)) + return i == j or i.dummy_eq(j) + + w0, w1, w2 = Symbol("w0"), Symbol("w1"), Symbol("w2") + _a = Dummy("a") + + f = (-2*x - 2*x**2) / (3*x**2 - 6*x) + got = apart_list(f, x, dummies=numbered_symbols("w")) + ans = (-1, Poly(Rational(2, 3), x, domain='QQ'), + [(Poly(w0 - 2, w0, domain='ZZ'), Lambda(_a, 2), Lambda(_a, -_a + x), 1)]) + assert dummy_eq(got, ans) + + got = apart_list(2/(x**2-2), x, dummies=numbered_symbols("w")) + ans = (1, Poly(0, x, domain='ZZ'), [(Poly(w0**2 - 2, w0, domain='ZZ'), + Lambda(_a, _a/2), + Lambda(_a, -_a + x), 1)]) + assert dummy_eq(got, ans) + + f = 36 / (x**5 - 2*x**4 - 2*x**3 + 4*x**2 + x - 2) + got = apart_list(f, x, dummies=numbered_symbols("w")) + ans = (1, Poly(0, x, domain='ZZ'), + [(Poly(w0 - 2, w0, domain='ZZ'), Lambda(_a, 4), Lambda(_a, -_a + x), 1), + (Poly(w1**2 - 1, w1, domain='ZZ'), Lambda(_a, -3*_a - 6), Lambda(_a, -_a + x), 2), + (Poly(w2 + 1, w2, domain='ZZ'), Lambda(_a, -4), Lambda(_a, -_a + x), 1)]) + assert dummy_eq(got, ans) + + +def test_assemble_partfrac_list(): + f = 36 / (x**5 - 2*x**4 - 2*x**3 + 4*x**2 + x - 2) + pfd = apart_list(f) + assert assemble_partfrac_list(pfd) == -4/(x + 1) - 3/(x + 1)**2 - 9/(x - 1)**2 + 4/(x - 2) + + a = Dummy("a") + pfd = (1, Poly(0, x, domain='ZZ'), [([sqrt(2),-sqrt(2)], Lambda(a, a/2), Lambda(a, -a + x), 1)]) + assert assemble_partfrac_list(pfd) == -1/(sqrt(2)*(x + sqrt(2))) + 1/(sqrt(2)*(x - sqrt(2))) + + +@XFAIL +def test_noncommutative_pseudomultivariate(): + # apart doesn't go inside noncommutative expressions + class foo(Expr): + is_commutative=False + e = x/(x + x*y) + c = 1/(1 + y) + assert apart(e + foo(e)) == c + foo(c) + assert apart(e*foo(e)) == c*foo(c) + +def test_noncommutative(): + class foo(Expr): + is_commutative=False + e = x/(x + x*y) + c = 1/(1 + y) + assert apart(e + foo()) == c + foo() + +def test_issue_5798(): + assert apart( + 2*x/(x**2 + 1) - (x - 1)/(2*(x**2 + 1)) + 1/(2*(x + 1)) - 2/x) == \ + (3*x + 1)/(x**2 + 1)/2 + 1/(x + 1)/2 - 2/x diff --git a/lib/python3.10/site-packages/sympy/printing/__pycache__/__init__.cpython-310.pyc b/lib/python3.10/site-packages/sympy/printing/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c41c999ae8627dd3a059f7af27dbd6fc6e3a4591 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/printing/__pycache__/__init__.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/printing/__pycache__/aesaracode.cpython-310.pyc b/lib/python3.10/site-packages/sympy/printing/__pycache__/aesaracode.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2a9855b8a025b5b16da23e6d3676623edd764ff3 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/printing/__pycache__/aesaracode.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/printing/__pycache__/c.cpython-310.pyc b/lib/python3.10/site-packages/sympy/printing/__pycache__/c.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e21dce24978cd7b16f5e6e828f81fb31908f9f79 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/printing/__pycache__/c.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/printing/__pycache__/codeprinter.cpython-310.pyc b/lib/python3.10/site-packages/sympy/printing/__pycache__/codeprinter.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c353fdc1d62d992d828769c7391369e37d427cea Binary files /dev/null and b/lib/python3.10/site-packages/sympy/printing/__pycache__/codeprinter.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/printing/__pycache__/conventions.cpython-310.pyc b/lib/python3.10/site-packages/sympy/printing/__pycache__/conventions.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dca164e0e43860ee5733e1405d4254a6ab79cfe0 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/printing/__pycache__/conventions.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/printing/__pycache__/cxx.cpython-310.pyc b/lib/python3.10/site-packages/sympy/printing/__pycache__/cxx.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..56146a147994f6b2115d7617fe63a2d737bf6378 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/printing/__pycache__/cxx.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/printing/__pycache__/defaults.cpython-310.pyc b/lib/python3.10/site-packages/sympy/printing/__pycache__/defaults.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fbc3c80679cf77e1075e14177bc56717e0c297f3 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/printing/__pycache__/defaults.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/printing/__pycache__/dot.cpython-310.pyc b/lib/python3.10/site-packages/sympy/printing/__pycache__/dot.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..155cc831186205186c28d4913e788ad83c1c34c7 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/printing/__pycache__/dot.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/printing/__pycache__/fortran.cpython-310.pyc b/lib/python3.10/site-packages/sympy/printing/__pycache__/fortran.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..332e8383a1579a2441c1ca504e2b0addf1dd18c0 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/printing/__pycache__/fortran.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/printing/__pycache__/glsl.cpython-310.pyc b/lib/python3.10/site-packages/sympy/printing/__pycache__/glsl.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..381cc8dc31983cc5b7a67a607660f50a8b0cc6ed Binary files /dev/null and b/lib/python3.10/site-packages/sympy/printing/__pycache__/glsl.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/printing/__pycache__/gtk.cpython-310.pyc b/lib/python3.10/site-packages/sympy/printing/__pycache__/gtk.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c27380ee95f953fa39c0183acaca199f180ff09a Binary files /dev/null and b/lib/python3.10/site-packages/sympy/printing/__pycache__/gtk.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/printing/__pycache__/jscode.cpython-310.pyc b/lib/python3.10/site-packages/sympy/printing/__pycache__/jscode.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f6341b776749b71502778eee9d40baffd187b061 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/printing/__pycache__/jscode.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/printing/__pycache__/julia.cpython-310.pyc b/lib/python3.10/site-packages/sympy/printing/__pycache__/julia.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ebd8df6f958ac99a75f9ed2dc972e8830da8085e Binary files /dev/null and b/lib/python3.10/site-packages/sympy/printing/__pycache__/julia.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/printing/__pycache__/lambdarepr.cpython-310.pyc b/lib/python3.10/site-packages/sympy/printing/__pycache__/lambdarepr.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..97a7e9c0531e4c136370458648bbd05ee7fb9c8e Binary files /dev/null and b/lib/python3.10/site-packages/sympy/printing/__pycache__/lambdarepr.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/printing/__pycache__/llvmjitcode.cpython-310.pyc b/lib/python3.10/site-packages/sympy/printing/__pycache__/llvmjitcode.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..73edb0e2ce3fc496ea5a161c53a7314c1c89ecfb Binary files /dev/null and b/lib/python3.10/site-packages/sympy/printing/__pycache__/llvmjitcode.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/printing/__pycache__/maple.cpython-310.pyc b/lib/python3.10/site-packages/sympy/printing/__pycache__/maple.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..68fd4cb1404865bdc50e7d24493a0f49866a0ea5 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/printing/__pycache__/maple.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/printing/__pycache__/mathematica.cpython-310.pyc b/lib/python3.10/site-packages/sympy/printing/__pycache__/mathematica.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..81376a3862591f05109f6e289e932fea4952b70a Binary files /dev/null and b/lib/python3.10/site-packages/sympy/printing/__pycache__/mathematica.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/printing/__pycache__/mathml.cpython-310.pyc b/lib/python3.10/site-packages/sympy/printing/__pycache__/mathml.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7cf65fd1e564b25635ff08db4077074dee8f6012 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/printing/__pycache__/mathml.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/printing/__pycache__/numpy.cpython-310.pyc b/lib/python3.10/site-packages/sympy/printing/__pycache__/numpy.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..388096f93f906f8784b891d224bf0e339f837c78 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/printing/__pycache__/numpy.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/printing/__pycache__/octave.cpython-310.pyc b/lib/python3.10/site-packages/sympy/printing/__pycache__/octave.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..74480c14113d01ae0f9a32e81ab24edc1803a547 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/printing/__pycache__/octave.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/printing/__pycache__/precedence.cpython-310.pyc b/lib/python3.10/site-packages/sympy/printing/__pycache__/precedence.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..47957a0bdd5c72ee280d8f77153cbbfe3d4159ed Binary files /dev/null and b/lib/python3.10/site-packages/sympy/printing/__pycache__/precedence.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/printing/__pycache__/preview.cpython-310.pyc b/lib/python3.10/site-packages/sympy/printing/__pycache__/preview.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..aee2a6c4c4c252f1400f1c7ce1e1d9a8a749e221 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/printing/__pycache__/preview.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/printing/__pycache__/printer.cpython-310.pyc b/lib/python3.10/site-packages/sympy/printing/__pycache__/printer.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2df29ea1f825f6b3707ae8e2f735e71c33a1e528 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/printing/__pycache__/printer.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/printing/__pycache__/pycode.cpython-310.pyc b/lib/python3.10/site-packages/sympy/printing/__pycache__/pycode.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8bfd97163ef567b0f998a2d03af608aab870bb67 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/printing/__pycache__/pycode.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/printing/__pycache__/python.cpython-310.pyc b/lib/python3.10/site-packages/sympy/printing/__pycache__/python.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..79a55cacb2a56d155c1a8c02df11bdb97e4f7595 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/printing/__pycache__/python.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/printing/__pycache__/rcode.cpython-310.pyc b/lib/python3.10/site-packages/sympy/printing/__pycache__/rcode.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3e2b9d19127fa1946a09ddc054ff84efe2286cd5 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/printing/__pycache__/rcode.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/printing/__pycache__/repr.cpython-310.pyc b/lib/python3.10/site-packages/sympy/printing/__pycache__/repr.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8f90ca7efa9ac5b2049e95c287aaf0d9b5c47154 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/printing/__pycache__/repr.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/printing/__pycache__/rust.cpython-310.pyc b/lib/python3.10/site-packages/sympy/printing/__pycache__/rust.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6ef35e015c2ec43a626d16a62a9728f786eb0ab2 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/printing/__pycache__/rust.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/printing/__pycache__/smtlib.cpython-310.pyc b/lib/python3.10/site-packages/sympy/printing/__pycache__/smtlib.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c94cf8b3a4a6e5281c460eb5375ce523995407ee Binary files /dev/null and b/lib/python3.10/site-packages/sympy/printing/__pycache__/smtlib.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/printing/__pycache__/str.cpython-310.pyc b/lib/python3.10/site-packages/sympy/printing/__pycache__/str.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..40cf3b1eaee1abeb72da1234ac60adcfc0179581 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/printing/__pycache__/str.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/printing/__pycache__/tableform.cpython-310.pyc b/lib/python3.10/site-packages/sympy/printing/__pycache__/tableform.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..068bcfdca31ce50ad0298198a63371cb8c1c8f85 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/printing/__pycache__/tableform.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/printing/__pycache__/tensorflow.cpython-310.pyc b/lib/python3.10/site-packages/sympy/printing/__pycache__/tensorflow.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2de3308cc34efb29c015e0eec8d342d684539030 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/printing/__pycache__/tensorflow.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/printing/__pycache__/theanocode.cpython-310.pyc b/lib/python3.10/site-packages/sympy/printing/__pycache__/theanocode.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1e6c8631cb4c9f20451898340cbb6f7c24c9b87e Binary files /dev/null and b/lib/python3.10/site-packages/sympy/printing/__pycache__/theanocode.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/printing/__pycache__/tree.cpython-310.pyc b/lib/python3.10/site-packages/sympy/printing/__pycache__/tree.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8266badc2017ce8f29c9b25f2472adcbbf2fbf93 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/printing/__pycache__/tree.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/printing/pretty/__init__.py b/lib/python3.10/site-packages/sympy/printing/pretty/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..cbabc649152a3c353a37225d342064634fbb5805 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/printing/pretty/__init__.py @@ -0,0 +1,12 @@ +"""ASCII-ART 2D pretty-printer""" + +from .pretty import (pretty, pretty_print, pprint, pprint_use_unicode, + pprint_try_use_unicode, pager_print) + +# if unicode output is available -- let's use it +pprint_try_use_unicode() + +__all__ = [ + 'pretty', 'pretty_print', 'pprint', 'pprint_use_unicode', + 'pprint_try_use_unicode', 'pager_print', +] diff --git a/lib/python3.10/site-packages/sympy/printing/pretty/__pycache__/__init__.cpython-310.pyc b/lib/python3.10/site-packages/sympy/printing/pretty/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..52118a89f7a0314d306c42262c8a2315caee9480 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/printing/pretty/__pycache__/__init__.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/printing/pretty/__pycache__/pretty.cpython-310.pyc b/lib/python3.10/site-packages/sympy/printing/pretty/__pycache__/pretty.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..13748669e079f7a584090ebba902d7c99bf4dc3c Binary files /dev/null and b/lib/python3.10/site-packages/sympy/printing/pretty/__pycache__/pretty.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/printing/pretty/__pycache__/pretty_symbology.cpython-310.pyc b/lib/python3.10/site-packages/sympy/printing/pretty/__pycache__/pretty_symbology.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e852507bd70b33fd7a5fe6eae30c878fc4e7900a Binary files /dev/null and b/lib/python3.10/site-packages/sympy/printing/pretty/__pycache__/pretty_symbology.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/printing/pretty/__pycache__/stringpict.cpython-310.pyc b/lib/python3.10/site-packages/sympy/printing/pretty/__pycache__/stringpict.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..27091dceb2cd427da4c39d99acfa6494c3b2c16e Binary files /dev/null and b/lib/python3.10/site-packages/sympy/printing/pretty/__pycache__/stringpict.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/printing/pretty/pretty.py b/lib/python3.10/site-packages/sympy/printing/pretty/pretty.py new file mode 100644 index 0000000000000000000000000000000000000000..b945f009119b24fc95e8452d91359957baba26a8 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/printing/pretty/pretty.py @@ -0,0 +1,2937 @@ +import itertools + +from sympy.core import S +from sympy.core.add import Add +from sympy.core.containers import Tuple +from sympy.core.function import Function +from sympy.core.mul import Mul +from sympy.core.numbers import Number, Rational +from sympy.core.power import Pow +from sympy.core.sorting import default_sort_key +from sympy.core.symbol import Symbol +from sympy.core.sympify import SympifyError +from sympy.printing.conventions import requires_partial +from sympy.printing.precedence import PRECEDENCE, precedence, precedence_traditional +from sympy.printing.printer import Printer, print_function +from sympy.printing.str import sstr +from sympy.utilities.iterables import has_variety +from sympy.utilities.exceptions import sympy_deprecation_warning + +from sympy.printing.pretty.stringpict import prettyForm, stringPict +from sympy.printing.pretty.pretty_symbology import hobj, vobj, xobj, \ + xsym, pretty_symbol, pretty_atom, pretty_use_unicode, greek_unicode, U, \ + pretty_try_use_unicode, annotated, is_subscriptable_in_unicode, center_pad, root as nth_root + +# rename for usage from outside +pprint_use_unicode = pretty_use_unicode +pprint_try_use_unicode = pretty_try_use_unicode + + +class PrettyPrinter(Printer): + """Printer, which converts an expression into 2D ASCII-art figure.""" + printmethod = "_pretty" + + _default_settings = { + "order": None, + "full_prec": "auto", + "use_unicode": None, + "wrap_line": True, + "num_columns": None, + "use_unicode_sqrt_char": True, + "root_notation": True, + "mat_symbol_style": "plain", + "imaginary_unit": "i", + "perm_cyclic": True + } + + def __init__(self, settings=None): + Printer.__init__(self, settings) + + if not isinstance(self._settings['imaginary_unit'], str): + raise TypeError("'imaginary_unit' must a string, not {}".format(self._settings['imaginary_unit'])) + elif self._settings['imaginary_unit'] not in ("i", "j"): + raise ValueError("'imaginary_unit' must be either 'i' or 'j', not '{}'".format(self._settings['imaginary_unit'])) + + def emptyPrinter(self, expr): + return prettyForm(str(expr)) + + @property + def _use_unicode(self): + if self._settings['use_unicode']: + return True + else: + return pretty_use_unicode() + + def doprint(self, expr): + return self._print(expr).render(**self._settings) + + # empty op so _print(stringPict) returns the same + def _print_stringPict(self, e): + return e + + def _print_basestring(self, e): + return prettyForm(e) + + def _print_atan2(self, e): + pform = prettyForm(*self._print_seq(e.args).parens()) + pform = prettyForm(*pform.left('atan2')) + return pform + + def _print_Symbol(self, e, bold_name=False): + symb = pretty_symbol(e.name, bold_name) + return prettyForm(symb) + _print_RandomSymbol = _print_Symbol + def _print_MatrixSymbol(self, e): + return self._print_Symbol(e, self._settings['mat_symbol_style'] == "bold") + + def _print_Float(self, e): + # we will use StrPrinter's Float printer, but we need to handle the + # full_prec ourselves, according to the self._print_level + full_prec = self._settings["full_prec"] + if full_prec == "auto": + full_prec = self._print_level == 1 + return prettyForm(sstr(e, full_prec=full_prec)) + + def _print_Cross(self, e): + vec1 = e._expr1 + vec2 = e._expr2 + pform = self._print(vec2) + pform = prettyForm(*pform.left('(')) + pform = prettyForm(*pform.right(')')) + pform = prettyForm(*pform.left(self._print(U('MULTIPLICATION SIGN')))) + pform = prettyForm(*pform.left(')')) + pform = prettyForm(*pform.left(self._print(vec1))) + pform = prettyForm(*pform.left('(')) + return pform + + def _print_Curl(self, e): + vec = e._expr + pform = self._print(vec) + pform = prettyForm(*pform.left('(')) + pform = prettyForm(*pform.right(')')) + pform = prettyForm(*pform.left(self._print(U('MULTIPLICATION SIGN')))) + pform = prettyForm(*pform.left(self._print(U('NABLA')))) + return pform + + def _print_Divergence(self, e): + vec = e._expr + pform = self._print(vec) + pform = prettyForm(*pform.left('(')) + pform = prettyForm(*pform.right(')')) + pform = prettyForm(*pform.left(self._print(U('DOT OPERATOR')))) + pform = prettyForm(*pform.left(self._print(U('NABLA')))) + return pform + + def _print_Dot(self, e): + vec1 = e._expr1 + vec2 = e._expr2 + pform = self._print(vec2) + pform = prettyForm(*pform.left('(')) + pform = prettyForm(*pform.right(')')) + pform = prettyForm(*pform.left(self._print(U('DOT OPERATOR')))) + pform = prettyForm(*pform.left(')')) + pform = prettyForm(*pform.left(self._print(vec1))) + pform = prettyForm(*pform.left('(')) + return pform + + def _print_Gradient(self, e): + func = e._expr + pform = self._print(func) + pform = prettyForm(*pform.left('(')) + pform = prettyForm(*pform.right(')')) + pform = prettyForm(*pform.left(self._print(U('NABLA')))) + return pform + + def _print_Laplacian(self, e): + func = e._expr + pform = self._print(func) + pform = prettyForm(*pform.left('(')) + pform = prettyForm(*pform.right(')')) + pform = prettyForm(*pform.left(self._print(U('INCREMENT')))) + return pform + + def _print_Atom(self, e): + try: + # print atoms like Exp1 or Pi + return prettyForm(pretty_atom(e.__class__.__name__, printer=self)) + except KeyError: + return self.emptyPrinter(e) + + # Infinity inherits from Number, so we have to override _print_XXX order + _print_Infinity = _print_Atom + _print_NegativeInfinity = _print_Atom + _print_EmptySet = _print_Atom + _print_Naturals = _print_Atom + _print_Naturals0 = _print_Atom + _print_Integers = _print_Atom + _print_Rationals = _print_Atom + _print_Complexes = _print_Atom + + _print_EmptySequence = _print_Atom + + def _print_Reals(self, e): + if self._use_unicode: + return self._print_Atom(e) + else: + inf_list = ['-oo', 'oo'] + return self._print_seq(inf_list, '(', ')') + + def _print_subfactorial(self, e): + x = e.args[0] + pform = self._print(x) + # Add parentheses if needed + if not ((x.is_Integer and x.is_nonnegative) or x.is_Symbol): + pform = prettyForm(*pform.parens()) + pform = prettyForm(*pform.left('!')) + return pform + + def _print_factorial(self, e): + x = e.args[0] + pform = self._print(x) + # Add parentheses if needed + if not ((x.is_Integer and x.is_nonnegative) or x.is_Symbol): + pform = prettyForm(*pform.parens()) + pform = prettyForm(*pform.right('!')) + return pform + + def _print_factorial2(self, e): + x = e.args[0] + pform = self._print(x) + # Add parentheses if needed + if not ((x.is_Integer and x.is_nonnegative) or x.is_Symbol): + pform = prettyForm(*pform.parens()) + pform = prettyForm(*pform.right('!!')) + return pform + + def _print_binomial(self, e): + n, k = e.args + + n_pform = self._print(n) + k_pform = self._print(k) + + bar = ' '*max(n_pform.width(), k_pform.width()) + + pform = prettyForm(*k_pform.above(bar)) + pform = prettyForm(*pform.above(n_pform)) + pform = prettyForm(*pform.parens('(', ')')) + + pform.baseline = (pform.baseline + 1)//2 + + return pform + + def _print_Relational(self, e): + op = prettyForm(' ' + xsym(e.rel_op) + ' ') + + l = self._print(e.lhs) + r = self._print(e.rhs) + pform = prettyForm(*stringPict.next(l, op, r), binding=prettyForm.OPEN) + return pform + + def _print_Not(self, e): + from sympy.logic.boolalg import (Equivalent, Implies) + if self._use_unicode: + arg = e.args[0] + pform = self._print(arg) + if isinstance(arg, Equivalent): + return self._print_Equivalent(arg, altchar=pretty_atom('NotEquiv')) + if isinstance(arg, Implies): + return self._print_Implies(arg, altchar=pretty_atom('NotArrow')) + + if arg.is_Boolean and not arg.is_Not: + pform = prettyForm(*pform.parens()) + + return prettyForm(*pform.left(pretty_atom('Not'))) + else: + return self._print_Function(e) + + def __print_Boolean(self, e, char, sort=True): + args = e.args + if sort: + args = sorted(e.args, key=default_sort_key) + arg = args[0] + pform = self._print(arg) + + if arg.is_Boolean and not arg.is_Not: + pform = prettyForm(*pform.parens()) + + for arg in args[1:]: + pform_arg = self._print(arg) + + if arg.is_Boolean and not arg.is_Not: + pform_arg = prettyForm(*pform_arg.parens()) + + pform = prettyForm(*pform.right(' %s ' % char)) + pform = prettyForm(*pform.right(pform_arg)) + + return pform + + def _print_And(self, e): + if self._use_unicode: + return self.__print_Boolean(e, pretty_atom('And')) + else: + return self._print_Function(e, sort=True) + + def _print_Or(self, e): + if self._use_unicode: + return self.__print_Boolean(e, pretty_atom('Or')) + else: + return self._print_Function(e, sort=True) + + def _print_Xor(self, e): + if self._use_unicode: + return self.__print_Boolean(e, pretty_atom("Xor")) + else: + return self._print_Function(e, sort=True) + + def _print_Nand(self, e): + if self._use_unicode: + return self.__print_Boolean(e, pretty_atom('Nand')) + else: + return self._print_Function(e, sort=True) + + def _print_Nor(self, e): + if self._use_unicode: + return self.__print_Boolean(e, pretty_atom('Nor')) + else: + return self._print_Function(e, sort=True) + + def _print_Implies(self, e, altchar=None): + if self._use_unicode: + return self.__print_Boolean(e, altchar or pretty_atom('Arrow'), sort=False) + else: + return self._print_Function(e) + + def _print_Equivalent(self, e, altchar=None): + if self._use_unicode: + return self.__print_Boolean(e, altchar or pretty_atom('Equiv')) + else: + return self._print_Function(e, sort=True) + + def _print_conjugate(self, e): + pform = self._print(e.args[0]) + return prettyForm( *pform.above( hobj('_', pform.width())) ) + + def _print_Abs(self, e): + pform = self._print(e.args[0]) + pform = prettyForm(*pform.parens('|', '|')) + return pform + + def _print_floor(self, e): + if self._use_unicode: + pform = self._print(e.args[0]) + pform = prettyForm(*pform.parens('lfloor', 'rfloor')) + return pform + else: + return self._print_Function(e) + + def _print_ceiling(self, e): + if self._use_unicode: + pform = self._print(e.args[0]) + pform = prettyForm(*pform.parens('lceil', 'rceil')) + return pform + else: + return self._print_Function(e) + + def _print_Derivative(self, deriv): + if requires_partial(deriv.expr) and self._use_unicode: + deriv_symbol = U('PARTIAL DIFFERENTIAL') + else: + deriv_symbol = r'd' + x = None + count_total_deriv = 0 + + for sym, num in reversed(deriv.variable_count): + s = self._print(sym) + ds = prettyForm(*s.left(deriv_symbol)) + count_total_deriv += num + + if (not num.is_Integer) or (num > 1): + ds = ds**prettyForm(str(num)) + + if x is None: + x = ds + else: + x = prettyForm(*x.right(' ')) + x = prettyForm(*x.right(ds)) + + f = prettyForm( + binding=prettyForm.FUNC, *self._print(deriv.expr).parens()) + + pform = prettyForm(deriv_symbol) + + if (count_total_deriv > 1) != False: + pform = pform**prettyForm(str(count_total_deriv)) + + pform = prettyForm(*pform.below(stringPict.LINE, x)) + pform.baseline = pform.baseline + 1 + pform = prettyForm(*stringPict.next(pform, f)) + pform.binding = prettyForm.MUL + + return pform + + def _print_Cycle(self, dc): + from sympy.combinatorics.permutations import Permutation, Cycle + # for Empty Cycle + if dc == Cycle(): + cyc = stringPict('') + return prettyForm(*cyc.parens()) + + dc_list = Permutation(dc.list()).cyclic_form + # for Identity Cycle + if dc_list == []: + cyc = self._print(dc.size - 1) + return prettyForm(*cyc.parens()) + + cyc = stringPict('') + for i in dc_list: + l = self._print(str(tuple(i)).replace(',', '')) + cyc = prettyForm(*cyc.right(l)) + return cyc + + def _print_Permutation(self, expr): + from sympy.combinatorics.permutations import Permutation, Cycle + + perm_cyclic = Permutation.print_cyclic + if perm_cyclic is not None: + sympy_deprecation_warning( + f""" + Setting Permutation.print_cyclic is deprecated. Instead use + init_printing(perm_cyclic={perm_cyclic}). + """, + deprecated_since_version="1.6", + active_deprecations_target="deprecated-permutation-print_cyclic", + stacklevel=7, + ) + else: + perm_cyclic = self._settings.get("perm_cyclic", True) + + if perm_cyclic: + return self._print_Cycle(Cycle(expr)) + + lower = expr.array_form + upper = list(range(len(lower))) + + result = stringPict('') + first = True + for u, l in zip(upper, lower): + s1 = self._print(u) + s2 = self._print(l) + col = prettyForm(*s1.below(s2)) + if first: + first = False + else: + col = prettyForm(*col.left(" ")) + result = prettyForm(*result.right(col)) + return prettyForm(*result.parens()) + + + def _print_Integral(self, integral): + f = integral.function + + # Add parentheses if arg involves addition of terms and + # create a pretty form for the argument + prettyF = self._print(f) + # XXX generalize parens + if f.is_Add: + prettyF = prettyForm(*prettyF.parens()) + + # dx dy dz ... + arg = prettyF + for x in integral.limits: + prettyArg = self._print(x[0]) + # XXX qparens (parens if needs-parens) + if prettyArg.width() > 1: + prettyArg = prettyForm(*prettyArg.parens()) + + arg = prettyForm(*arg.right(' d', prettyArg)) + + # \int \int \int ... + firstterm = True + s = None + for lim in integral.limits: + # Create bar based on the height of the argument + h = arg.height() + H = h + 2 + + # XXX hack! + ascii_mode = not self._use_unicode + if ascii_mode: + H += 2 + + vint = vobj('int', H) + + # Construct the pretty form with the integral sign and the argument + pform = prettyForm(vint) + pform.baseline = arg.baseline + ( + H - h)//2 # covering the whole argument + + if len(lim) > 1: + # Create pretty forms for endpoints, if definite integral. + # Do not print empty endpoints. + if len(lim) == 2: + prettyA = prettyForm("") + prettyB = self._print(lim[1]) + if len(lim) == 3: + prettyA = self._print(lim[1]) + prettyB = self._print(lim[2]) + + if ascii_mode: # XXX hack + # Add spacing so that endpoint can more easily be + # identified with the correct integral sign + spc = max(1, 3 - prettyB.width()) + prettyB = prettyForm(*prettyB.left(' ' * spc)) + + spc = max(1, 4 - prettyA.width()) + prettyA = prettyForm(*prettyA.right(' ' * spc)) + + pform = prettyForm(*pform.above(prettyB)) + pform = prettyForm(*pform.below(prettyA)) + + if not ascii_mode: # XXX hack + pform = prettyForm(*pform.right(' ')) + + if firstterm: + s = pform # first term + firstterm = False + else: + s = prettyForm(*s.left(pform)) + + pform = prettyForm(*arg.left(s)) + pform.binding = prettyForm.MUL + return pform + + def _print_Product(self, expr): + func = expr.term + pretty_func = self._print(func) + + horizontal_chr = xobj('_', 1) + corner_chr = xobj('_', 1) + vertical_chr = xobj('|', 1) + + if self._use_unicode: + # use unicode corners + horizontal_chr = xobj('-', 1) + corner_chr = xobj('UpTack', 1) + + func_height = pretty_func.height() + + first = True + max_upper = 0 + sign_height = 0 + + for lim in expr.limits: + pretty_lower, pretty_upper = self.__print_SumProduct_Limits(lim) + + width = (func_height + 2) * 5 // 3 - 2 + sign_lines = [horizontal_chr + corner_chr + (horizontal_chr * (width-2)) + corner_chr + horizontal_chr] + for _ in range(func_height + 1): + sign_lines.append(' ' + vertical_chr + (' ' * (width-2)) + vertical_chr + ' ') + + pretty_sign = stringPict('') + pretty_sign = prettyForm(*pretty_sign.stack(*sign_lines)) + + + max_upper = max(max_upper, pretty_upper.height()) + + if first: + sign_height = pretty_sign.height() + + pretty_sign = prettyForm(*pretty_sign.above(pretty_upper)) + pretty_sign = prettyForm(*pretty_sign.below(pretty_lower)) + + if first: + pretty_func.baseline = 0 + first = False + + height = pretty_sign.height() + padding = stringPict('') + padding = prettyForm(*padding.stack(*[' ']*(height - 1))) + pretty_sign = prettyForm(*pretty_sign.right(padding)) + + pretty_func = prettyForm(*pretty_sign.right(pretty_func)) + + pretty_func.baseline = max_upper + sign_height//2 + pretty_func.binding = prettyForm.MUL + return pretty_func + + def __print_SumProduct_Limits(self, lim): + def print_start(lhs, rhs): + op = prettyForm(' ' + xsym("==") + ' ') + l = self._print(lhs) + r = self._print(rhs) + pform = prettyForm(*stringPict.next(l, op, r)) + return pform + + prettyUpper = self._print(lim[2]) + prettyLower = print_start(lim[0], lim[1]) + return prettyLower, prettyUpper + + def _print_Sum(self, expr): + ascii_mode = not self._use_unicode + + def asum(hrequired, lower, upper, use_ascii): + def adjust(s, wid=None, how='<^>'): + if not wid or len(s) > wid: + return s + need = wid - len(s) + if how in ('<^>', "<") or how not in list('<^>'): + return s + ' '*need + half = need//2 + lead = ' '*half + if how == ">": + return " "*need + s + return lead + s + ' '*(need - len(lead)) + + h = max(hrequired, 2) + d = h//2 + w = d + 1 + more = hrequired % 2 + + lines = [] + if use_ascii: + lines.append("_"*(w) + ' ') + lines.append(r"\%s`" % (' '*(w - 1))) + for i in range(1, d): + lines.append('%s\\%s' % (' '*i, ' '*(w - i))) + if more: + lines.append('%s)%s' % (' '*(d), ' '*(w - d))) + for i in reversed(range(1, d)): + lines.append('%s/%s' % (' '*i, ' '*(w - i))) + lines.append("/" + "_"*(w - 1) + ',') + return d, h + more, lines, more + else: + w = w + more + d = d + more + vsum = vobj('sum', 4) + lines.append("_"*(w)) + for i in range(0, d): + lines.append('%s%s%s' % (' '*i, vsum[2], ' '*(w - i - 1))) + for i in reversed(range(0, d)): + lines.append('%s%s%s' % (' '*i, vsum[4], ' '*(w - i - 1))) + lines.append(vsum[8]*(w)) + return d, h + 2*more, lines, more + + f = expr.function + + prettyF = self._print(f) + + if f.is_Add: # add parens + prettyF = prettyForm(*prettyF.parens()) + + H = prettyF.height() + 2 + + # \sum \sum \sum ... + first = True + max_upper = 0 + sign_height = 0 + + for lim in expr.limits: + prettyLower, prettyUpper = self.__print_SumProduct_Limits(lim) + + max_upper = max(max_upper, prettyUpper.height()) + + # Create sum sign based on the height of the argument + d, h, slines, adjustment = asum( + H, prettyLower.width(), prettyUpper.width(), ascii_mode) + prettySign = stringPict('') + prettySign = prettyForm(*prettySign.stack(*slines)) + + if first: + sign_height = prettySign.height() + + prettySign = prettyForm(*prettySign.above(prettyUpper)) + prettySign = prettyForm(*prettySign.below(prettyLower)) + + if first: + # change F baseline so it centers on the sign + prettyF.baseline -= d - (prettyF.height()//2 - + prettyF.baseline) + first = False + + # put padding to the right + pad = stringPict('') + pad = prettyForm(*pad.stack(*[' ']*h)) + prettySign = prettyForm(*prettySign.right(pad)) + # put the present prettyF to the right + prettyF = prettyForm(*prettySign.right(prettyF)) + + # adjust baseline of ascii mode sigma with an odd height so that it is + # exactly through the center + ascii_adjustment = ascii_mode if not adjustment else 0 + prettyF.baseline = max_upper + sign_height//2 + ascii_adjustment + + prettyF.binding = prettyForm.MUL + return prettyF + + def _print_Limit(self, l): + e, z, z0, dir = l.args + + E = self._print(e) + if precedence(e) <= PRECEDENCE["Mul"]: + E = prettyForm(*E.parens('(', ')')) + Lim = prettyForm('lim') + + LimArg = self._print(z) + if self._use_unicode: + LimArg = prettyForm(*LimArg.right(f"{xobj('-', 1)}{pretty_atom('Arrow')}")) + else: + LimArg = prettyForm(*LimArg.right('->')) + LimArg = prettyForm(*LimArg.right(self._print(z0))) + + if str(dir) == '+-' or z0 in (S.Infinity, S.NegativeInfinity): + dir = "" + else: + if self._use_unicode: + dir = pretty_atom('SuperscriptPlus') if str(dir) == "+" else pretty_atom('SuperscriptMinus') + + LimArg = prettyForm(*LimArg.right(self._print(dir))) + + Lim = prettyForm(*Lim.below(LimArg)) + Lim = prettyForm(*Lim.right(E), binding=prettyForm.MUL) + + return Lim + + def _print_matrix_contents(self, e): + """ + This method factors out what is essentially grid printing. + """ + M = e # matrix + Ms = {} # i,j -> pretty(M[i,j]) + for i in range(M.rows): + for j in range(M.cols): + Ms[i, j] = self._print(M[i, j]) + + # h- and v- spacers + hsep = 2 + vsep = 1 + + # max width for columns + maxw = [-1] * M.cols + + for j in range(M.cols): + maxw[j] = max([Ms[i, j].width() for i in range(M.rows)] or [0]) + + # drawing result + D = None + + for i in range(M.rows): + + D_row = None + for j in range(M.cols): + s = Ms[i, j] + + # reshape s to maxw + # XXX this should be generalized, and go to stringPict.reshape ? + assert s.width() <= maxw[j] + + # hcenter it, +0.5 to the right 2 + # ( it's better to align formula starts for say 0 and r ) + # XXX this is not good in all cases -- maybe introduce vbaseline? + left, right = center_pad(s.width(), maxw[j]) + + s = prettyForm(*s.right(right)) + s = prettyForm(*s.left(left)) + + # we don't need vcenter cells -- this is automatically done in + # a pretty way because when their baselines are taking into + # account in .right() + + if D_row is None: + D_row = s # first box in a row + continue + + D_row = prettyForm(*D_row.right(' '*hsep)) # h-spacer + D_row = prettyForm(*D_row.right(s)) + + if D is None: + D = D_row # first row in a picture + continue + + # v-spacer + for _ in range(vsep): + D = prettyForm(*D.below(' ')) + + D = prettyForm(*D.below(D_row)) + + if D is None: + D = prettyForm('') # Empty Matrix + + return D + + def _print_MatrixBase(self, e, lparens='[', rparens=']'): + D = self._print_matrix_contents(e) + D.baseline = D.height()//2 + D = prettyForm(*D.parens(lparens, rparens)) + return D + + def _print_Determinant(self, e): + mat = e.arg + if mat.is_MatrixExpr: + from sympy.matrices.expressions.blockmatrix import BlockMatrix + if isinstance(mat, BlockMatrix): + return self._print_MatrixBase(mat.blocks, lparens='|', rparens='|') + D = self._print(mat) + D.baseline = D.height()//2 + return prettyForm(*D.parens('|', '|')) + else: + return self._print_MatrixBase(mat, lparens='|', rparens='|') + + def _print_TensorProduct(self, expr): + # This should somehow share the code with _print_WedgeProduct: + if self._use_unicode: + circled_times = "\u2297" + else: + circled_times = ".*" + return self._print_seq(expr.args, None, None, circled_times, + parenthesize=lambda x: precedence_traditional(x) <= PRECEDENCE["Mul"]) + + def _print_WedgeProduct(self, expr): + # This should somehow share the code with _print_TensorProduct: + if self._use_unicode: + wedge_symbol = "\u2227" + else: + wedge_symbol = '/\\' + return self._print_seq(expr.args, None, None, wedge_symbol, + parenthesize=lambda x: precedence_traditional(x) <= PRECEDENCE["Mul"]) + + def _print_Trace(self, e): + D = self._print(e.arg) + D = prettyForm(*D.parens('(',')')) + D.baseline = D.height()//2 + D = prettyForm(*D.left('\n'*(0) + 'tr')) + return D + + + def _print_MatrixElement(self, expr): + from sympy.matrices import MatrixSymbol + if (isinstance(expr.parent, MatrixSymbol) + and expr.i.is_number and expr.j.is_number): + return self._print( + Symbol(expr.parent.name + '_%d%d' % (expr.i, expr.j))) + else: + prettyFunc = self._print(expr.parent) + prettyFunc = prettyForm(*prettyFunc.parens()) + prettyIndices = self._print_seq((expr.i, expr.j), delimiter=', ' + ).parens(left='[', right=']')[0] + pform = prettyForm(binding=prettyForm.FUNC, + *stringPict.next(prettyFunc, prettyIndices)) + + # store pform parts so it can be reassembled e.g. when powered + pform.prettyFunc = prettyFunc + pform.prettyArgs = prettyIndices + + return pform + + + def _print_MatrixSlice(self, m): + # XXX works only for applied functions + from sympy.matrices import MatrixSymbol + prettyFunc = self._print(m.parent) + if not isinstance(m.parent, MatrixSymbol): + prettyFunc = prettyForm(*prettyFunc.parens()) + def ppslice(x, dim): + x = list(x) + if x[2] == 1: + del x[2] + if x[0] == 0: + x[0] = '' + if x[1] == dim: + x[1] = '' + return prettyForm(*self._print_seq(x, delimiter=':')) + prettyArgs = self._print_seq((ppslice(m.rowslice, m.parent.rows), + ppslice(m.colslice, m.parent.cols)), delimiter=', ').parens(left='[', right=']')[0] + + pform = prettyForm( + binding=prettyForm.FUNC, *stringPict.next(prettyFunc, prettyArgs)) + + # store pform parts so it can be reassembled e.g. when powered + pform.prettyFunc = prettyFunc + pform.prettyArgs = prettyArgs + + return pform + + def _print_Transpose(self, expr): + mat = expr.arg + pform = self._print(mat) + from sympy.matrices import MatrixSymbol, BlockMatrix + if (not isinstance(mat, MatrixSymbol) and + not isinstance(mat, BlockMatrix) and mat.is_MatrixExpr): + pform = prettyForm(*pform.parens()) + pform = pform**(prettyForm('T')) + return pform + + def _print_Adjoint(self, expr): + mat = expr.arg + pform = self._print(mat) + if self._use_unicode: + dag = prettyForm(pretty_atom('Dagger')) + else: + dag = prettyForm('+') + from sympy.matrices import MatrixSymbol, BlockMatrix + if (not isinstance(mat, MatrixSymbol) and + not isinstance(mat, BlockMatrix) and mat.is_MatrixExpr): + pform = prettyForm(*pform.parens()) + pform = pform**dag + return pform + + def _print_BlockMatrix(self, B): + if B.blocks.shape == (1, 1): + return self._print(B.blocks[0, 0]) + return self._print(B.blocks) + + def _print_MatAdd(self, expr): + s = None + for item in expr.args: + pform = self._print(item) + if s is None: + s = pform # First element + else: + coeff = item.as_coeff_mmul()[0] + if S(coeff).could_extract_minus_sign(): + s = prettyForm(*stringPict.next(s, ' ')) + pform = self._print(item) + else: + s = prettyForm(*stringPict.next(s, ' + ')) + s = prettyForm(*stringPict.next(s, pform)) + + return s + + def _print_MatMul(self, expr): + args = list(expr.args) + from sympy.matrices.expressions.hadamard import HadamardProduct + from sympy.matrices.expressions.kronecker import KroneckerProduct + from sympy.matrices.expressions.matadd import MatAdd + for i, a in enumerate(args): + if (isinstance(a, (Add, MatAdd, HadamardProduct, KroneckerProduct)) + and len(expr.args) > 1): + args[i] = prettyForm(*self._print(a).parens()) + else: + args[i] = self._print(a) + + return prettyForm.__mul__(*args) + + def _print_Identity(self, expr): + if self._use_unicode: + return prettyForm(pretty_atom('IdentityMatrix')) + else: + return prettyForm('I') + + def _print_ZeroMatrix(self, expr): + if self._use_unicode: + return prettyForm(pretty_atom('ZeroMatrix')) + else: + return prettyForm('0') + + def _print_OneMatrix(self, expr): + if self._use_unicode: + return prettyForm(pretty_atom("OneMatrix")) + else: + return prettyForm('1') + + def _print_DotProduct(self, expr): + args = list(expr.args) + + for i, a in enumerate(args): + args[i] = self._print(a) + return prettyForm.__mul__(*args) + + def _print_MatPow(self, expr): + pform = self._print(expr.base) + from sympy.matrices import MatrixSymbol + if not isinstance(expr.base, MatrixSymbol) and expr.base.is_MatrixExpr: + pform = prettyForm(*pform.parens()) + pform = pform**(self._print(expr.exp)) + return pform + + def _print_HadamardProduct(self, expr): + from sympy.matrices.expressions.hadamard import HadamardProduct + from sympy.matrices.expressions.matadd import MatAdd + from sympy.matrices.expressions.matmul import MatMul + if self._use_unicode: + delim = pretty_atom('Ring') + else: + delim = '.*' + return self._print_seq(expr.args, None, None, delim, + parenthesize=lambda x: isinstance(x, (MatAdd, MatMul, HadamardProduct))) + + def _print_HadamardPower(self, expr): + # from sympy import MatAdd, MatMul + if self._use_unicode: + circ = pretty_atom('Ring') + else: + circ = self._print('.') + pretty_base = self._print(expr.base) + pretty_exp = self._print(expr.exp) + if precedence(expr.exp) < PRECEDENCE["Mul"]: + pretty_exp = prettyForm(*pretty_exp.parens()) + pretty_circ_exp = prettyForm( + binding=prettyForm.LINE, + *stringPict.next(circ, pretty_exp) + ) + return pretty_base**pretty_circ_exp + + def _print_KroneckerProduct(self, expr): + from sympy.matrices.expressions.matadd import MatAdd + from sympy.matrices.expressions.matmul import MatMul + if self._use_unicode: + delim = f" {pretty_atom('TensorProduct')} " + else: + delim = ' x ' + return self._print_seq(expr.args, None, None, delim, + parenthesize=lambda x: isinstance(x, (MatAdd, MatMul))) + + def _print_FunctionMatrix(self, X): + D = self._print(X.lamda.expr) + D = prettyForm(*D.parens('[', ']')) + return D + + def _print_TransferFunction(self, expr): + if not expr.num == 1: + num, den = expr.num, expr.den + res = Mul(num, Pow(den, -1, evaluate=False), evaluate=False) + return self._print_Mul(res) + else: + return self._print(1)/self._print(expr.den) + + def _print_Series(self, expr): + args = list(expr.args) + for i, a in enumerate(expr.args): + args[i] = prettyForm(*self._print(a).parens()) + return prettyForm.__mul__(*args) + + def _print_MIMOSeries(self, expr): + from sympy.physics.control.lti import MIMOParallel + args = list(expr.args) + pretty_args = [] + for a in reversed(args): + if (isinstance(a, MIMOParallel) and len(expr.args) > 1): + expression = self._print(a) + expression.baseline = expression.height()//2 + pretty_args.append(prettyForm(*expression.parens())) + else: + expression = self._print(a) + expression.baseline = expression.height()//2 + pretty_args.append(expression) + return prettyForm.__mul__(*pretty_args) + + def _print_Parallel(self, expr): + s = None + for item in expr.args: + pform = self._print(item) + if s is None: + s = pform # First element + else: + s = prettyForm(*stringPict.next(s)) + s.baseline = s.height()//2 + s = prettyForm(*stringPict.next(s, ' + ')) + s = prettyForm(*stringPict.next(s, pform)) + return s + + def _print_MIMOParallel(self, expr): + from sympy.physics.control.lti import TransferFunctionMatrix + s = None + for item in expr.args: + pform = self._print(item) + if s is None: + s = pform # First element + else: + s = prettyForm(*stringPict.next(s)) + s.baseline = s.height()//2 + s = prettyForm(*stringPict.next(s, ' + ')) + if isinstance(item, TransferFunctionMatrix): + s.baseline = s.height() - 1 + s = prettyForm(*stringPict.next(s, pform)) + # s.baseline = s.height()//2 + return s + + def _print_Feedback(self, expr): + from sympy.physics.control import TransferFunction, Series + + num, tf = expr.sys1, TransferFunction(1, 1, expr.var) + num_arg_list = list(num.args) if isinstance(num, Series) else [num] + den_arg_list = list(expr.sys2.args) if \ + isinstance(expr.sys2, Series) else [expr.sys2] + + if isinstance(num, Series) and isinstance(expr.sys2, Series): + den = Series(*num_arg_list, *den_arg_list) + elif isinstance(num, Series) and isinstance(expr.sys2, TransferFunction): + if expr.sys2 == tf: + den = Series(*num_arg_list) + else: + den = Series(*num_arg_list, expr.sys2) + elif isinstance(num, TransferFunction) and isinstance(expr.sys2, Series): + if num == tf: + den = Series(*den_arg_list) + else: + den = Series(num, *den_arg_list) + else: + if num == tf: + den = Series(*den_arg_list) + elif expr.sys2 == tf: + den = Series(*num_arg_list) + else: + den = Series(*num_arg_list, *den_arg_list) + + denom = prettyForm(*stringPict.next(self._print(tf))) + denom.baseline = denom.height()//2 + denom = prettyForm(*stringPict.next(denom, ' + ')) if expr.sign == -1 \ + else prettyForm(*stringPict.next(denom, ' - ')) + denom = prettyForm(*stringPict.next(denom, self._print(den))) + + return self._print(num)/denom + + def _print_MIMOFeedback(self, expr): + from sympy.physics.control import MIMOSeries, TransferFunctionMatrix + + inv_mat = self._print(MIMOSeries(expr.sys2, expr.sys1)) + plant = self._print(expr.sys1) + _feedback = prettyForm(*stringPict.next(inv_mat)) + _feedback = prettyForm(*stringPict.right("I + ", _feedback)) if expr.sign == -1 \ + else prettyForm(*stringPict.right("I - ", _feedback)) + _feedback = prettyForm(*stringPict.parens(_feedback)) + _feedback.baseline = 0 + _feedback = prettyForm(*stringPict.right(_feedback, '-1 ')) + _feedback.baseline = _feedback.height()//2 + _feedback = prettyForm.__mul__(_feedback, prettyForm(" ")) + if isinstance(expr.sys1, TransferFunctionMatrix): + _feedback.baseline = _feedback.height() - 1 + _feedback = prettyForm(*stringPict.next(_feedback, plant)) + return _feedback + + def _print_TransferFunctionMatrix(self, expr): + mat = self._print(expr._expr_mat) + mat.baseline = mat.height() - 1 + subscript = greek_unicode['tau'] if self._use_unicode else r'{t}' + mat = prettyForm(*mat.right(subscript)) + return mat + + def _print_StateSpace(self, expr): + from sympy.matrices.expressions.blockmatrix import BlockMatrix + A = expr._A + B = expr._B + C = expr._C + D = expr._D + mat = BlockMatrix([[A, B], [C, D]]) + return self._print(mat.blocks) + + def _print_BasisDependent(self, expr): + from sympy.vector import Vector + + if not self._use_unicode: + raise NotImplementedError("ASCII pretty printing of BasisDependent is not implemented") + + if expr == expr.zero: + return prettyForm(expr.zero._pretty_form) + o1 = [] + vectstrs = [] + if isinstance(expr, Vector): + items = expr.separate().items() + else: + items = [(0, expr)] + for system, vect in items: + inneritems = list(vect.components.items()) + inneritems.sort(key = lambda x: x[0].__str__()) + for k, v in inneritems: + #if the coef of the basis vector is 1 + #we skip the 1 + if v == 1: + o1.append("" + + k._pretty_form) + #Same for -1 + elif v == -1: + o1.append("(-1) " + + k._pretty_form) + #For a general expr + else: + #We always wrap the measure numbers in + #parentheses + arg_str = self._print( + v).parens()[0] + + o1.append(arg_str + ' ' + k._pretty_form) + vectstrs.append(k._pretty_form) + + #outstr = u("").join(o1) + if o1[0].startswith(" + "): + o1[0] = o1[0][3:] + elif o1[0].startswith(" "): + o1[0] = o1[0][1:] + #Fixing the newlines + lengths = [] + strs = [''] + flag = [] + for i, partstr in enumerate(o1): + flag.append(0) + # XXX: What is this hack? + if '\n' in partstr: + tempstr = partstr + tempstr = tempstr.replace(vectstrs[i], '') + if xobj(')_ext', 1) in tempstr: # If scalar is a fraction + for paren in range(len(tempstr)): + flag[i] = 1 + if tempstr[paren] == xobj(')_ext', 1) and tempstr[paren + 1] == '\n': + # We want to place the vector string after all the right parentheses, because + # otherwise, the vector will be in the middle of the string + tempstr = tempstr[:paren] + xobj(')_ext', 1)\ + + ' ' + vectstrs[i] + tempstr[paren + 1:] + break + elif xobj(')_lower_hook', 1) in tempstr: + # We want to place the vector string after all the right parentheses, because + # otherwise, the vector will be in the middle of the string. For this reason, + # we insert the vector string at the rightmost index. + index = tempstr.rfind(xobj(')_lower_hook', 1)) + if index != -1: # then this character was found in this string + flag[i] = 1 + tempstr = tempstr[:index] + xobj(')_lower_hook', 1)\ + + ' ' + vectstrs[i] + tempstr[index + 1:] + o1[i] = tempstr + + o1 = [x.split('\n') for x in o1] + n_newlines = max(len(x) for x in o1) # Width of part in its pretty form + + if 1 in flag: # If there was a fractional scalar + for i, parts in enumerate(o1): + if len(parts) == 1: # If part has no newline + parts.insert(0, ' ' * (len(parts[0]))) + flag[i] = 1 + + for i, parts in enumerate(o1): + lengths.append(len(parts[flag[i]])) + for j in range(n_newlines): + if j+1 <= len(parts): + if j >= len(strs): + strs.append(' ' * (sum(lengths[:-1]) + + 3*(len(lengths)-1))) + if j == flag[i]: + strs[flag[i]] += parts[flag[i]] + ' + ' + else: + strs[j] += parts[j] + ' '*(lengths[-1] - + len(parts[j])+ + 3) + else: + if j >= len(strs): + strs.append(' ' * (sum(lengths[:-1]) + + 3*(len(lengths)-1))) + strs[j] += ' '*(lengths[-1]+3) + + return prettyForm('\n'.join([s[:-3] for s in strs])) + + def _print_NDimArray(self, expr): + from sympy.matrices.immutable import ImmutableMatrix + + if expr.rank() == 0: + return self._print(expr[()]) + + level_str = [[]] + [[] for i in range(expr.rank())] + shape_ranges = [list(range(i)) for i in expr.shape] + # leave eventual matrix elements unflattened + mat = lambda x: ImmutableMatrix(x, evaluate=False) + for outer_i in itertools.product(*shape_ranges): + level_str[-1].append(expr[outer_i]) + even = True + for back_outer_i in range(expr.rank()-1, -1, -1): + if len(level_str[back_outer_i+1]) < expr.shape[back_outer_i]: + break + if even: + level_str[back_outer_i].append(level_str[back_outer_i+1]) + else: + level_str[back_outer_i].append(mat( + level_str[back_outer_i+1])) + if len(level_str[back_outer_i + 1]) == 1: + level_str[back_outer_i][-1] = mat( + [[level_str[back_outer_i][-1]]]) + even = not even + level_str[back_outer_i+1] = [] + + out_expr = level_str[0][0] + if expr.rank() % 2 == 1: + out_expr = mat([out_expr]) + + return self._print(out_expr) + + def _printer_tensor_indices(self, name, indices, index_map={}): + center = stringPict(name) + top = stringPict(" "*center.width()) + bot = stringPict(" "*center.width()) + + last_valence = None + prev_map = None + + for index in indices: + indpic = self._print(index.args[0]) + if ((index in index_map) or prev_map) and last_valence == index.is_up: + if index.is_up: + top = prettyForm(*stringPict.next(top, ",")) + else: + bot = prettyForm(*stringPict.next(bot, ",")) + if index in index_map: + indpic = prettyForm(*stringPict.next(indpic, "=")) + indpic = prettyForm(*stringPict.next(indpic, self._print(index_map[index]))) + prev_map = True + else: + prev_map = False + if index.is_up: + top = stringPict(*top.right(indpic)) + center = stringPict(*center.right(" "*indpic.width())) + bot = stringPict(*bot.right(" "*indpic.width())) + else: + bot = stringPict(*bot.right(indpic)) + center = stringPict(*center.right(" "*indpic.width())) + top = stringPict(*top.right(" "*indpic.width())) + last_valence = index.is_up + + pict = prettyForm(*center.above(top)) + pict = prettyForm(*pict.below(bot)) + return pict + + def _print_Tensor(self, expr): + name = expr.args[0].name + indices = expr.get_indices() + return self._printer_tensor_indices(name, indices) + + def _print_TensorElement(self, expr): + name = expr.expr.args[0].name + indices = expr.expr.get_indices() + index_map = expr.index_map + return self._printer_tensor_indices(name, indices, index_map) + + def _print_TensMul(self, expr): + sign, args = expr._get_args_for_traditional_printer() + args = [ + prettyForm(*self._print(i).parens()) if + precedence_traditional(i) < PRECEDENCE["Mul"] else self._print(i) + for i in args + ] + pform = prettyForm.__mul__(*args) + if sign: + return prettyForm(*pform.left(sign)) + else: + return pform + + def _print_TensAdd(self, expr): + args = [ + prettyForm(*self._print(i).parens()) if + precedence_traditional(i) < PRECEDENCE["Mul"] else self._print(i) + for i in expr.args + ] + return prettyForm.__add__(*args) + + def _print_TensorIndex(self, expr): + sym = expr.args[0] + if not expr.is_up: + sym = -sym + return self._print(sym) + + def _print_PartialDerivative(self, deriv): + if self._use_unicode: + deriv_symbol = U('PARTIAL DIFFERENTIAL') + else: + deriv_symbol = r'd' + x = None + + for variable in reversed(deriv.variables): + s = self._print(variable) + ds = prettyForm(*s.left(deriv_symbol)) + + if x is None: + x = ds + else: + x = prettyForm(*x.right(' ')) + x = prettyForm(*x.right(ds)) + + f = prettyForm( + binding=prettyForm.FUNC, *self._print(deriv.expr).parens()) + + pform = prettyForm(deriv_symbol) + + if len(deriv.variables) > 1: + pform = pform**self._print(len(deriv.variables)) + + pform = prettyForm(*pform.below(stringPict.LINE, x)) + pform.baseline = pform.baseline + 1 + pform = prettyForm(*stringPict.next(pform, f)) + pform.binding = prettyForm.MUL + + return pform + + def _print_Piecewise(self, pexpr): + + P = {} + for n, ec in enumerate(pexpr.args): + P[n, 0] = self._print(ec.expr) + if ec.cond == True: + P[n, 1] = prettyForm('otherwise') + else: + P[n, 1] = prettyForm( + *prettyForm('for ').right(self._print(ec.cond))) + hsep = 2 + vsep = 1 + len_args = len(pexpr.args) + + # max widths + maxw = [max(P[i, j].width() for i in range(len_args)) + for j in range(2)] + + # FIXME: Refactor this code and matrix into some tabular environment. + # drawing result + D = None + + for i in range(len_args): + D_row = None + for j in range(2): + p = P[i, j] + assert p.width() <= maxw[j] + + wdelta = maxw[j] - p.width() + wleft = wdelta // 2 + wright = wdelta - wleft + + p = prettyForm(*p.right(' '*wright)) + p = prettyForm(*p.left(' '*wleft)) + + if D_row is None: + D_row = p + continue + + D_row = prettyForm(*D_row.right(' '*hsep)) # h-spacer + D_row = prettyForm(*D_row.right(p)) + if D is None: + D = D_row # first row in a picture + continue + + # v-spacer + for _ in range(vsep): + D = prettyForm(*D.below(' ')) + + D = prettyForm(*D.below(D_row)) + + D = prettyForm(*D.parens('{', '')) + D.baseline = D.height()//2 + D.binding = prettyForm.OPEN + return D + + def _print_ITE(self, ite): + from sympy.functions.elementary.piecewise import Piecewise + return self._print(ite.rewrite(Piecewise)) + + def _hprint_vec(self, v): + D = None + + for a in v: + p = a + if D is None: + D = p + else: + D = prettyForm(*D.right(', ')) + D = prettyForm(*D.right(p)) + if D is None: + D = stringPict(' ') + + return D + + def _hprint_vseparator(self, p1, p2, left=None, right=None, delimiter='', ifascii_nougly=False): + if ifascii_nougly and not self._use_unicode: + return self._print_seq((p1, '|', p2), left=left, right=right, + delimiter=delimiter, ifascii_nougly=True) + tmp = self._print_seq((p1, p2,), left=left, right=right, delimiter=delimiter) + sep = stringPict(vobj('|', tmp.height()), baseline=tmp.baseline) + return self._print_seq((p1, sep, p2), left=left, right=right, + delimiter=delimiter) + + def _print_hyper(self, e): + # FIXME refactor Matrix, Piecewise, and this into a tabular environment + ap = [self._print(a) for a in e.ap] + bq = [self._print(b) for b in e.bq] + + P = self._print(e.argument) + P.baseline = P.height()//2 + + # Drawing result - first create the ap, bq vectors + D = None + for v in [ap, bq]: + D_row = self._hprint_vec(v) + if D is None: + D = D_row # first row in a picture + else: + D = prettyForm(*D.below(' ')) + D = prettyForm(*D.below(D_row)) + + # make sure that the argument `z' is centred vertically + D.baseline = D.height()//2 + + # insert horizontal separator + P = prettyForm(*P.left(' ')) + D = prettyForm(*D.right(' ')) + + # insert separating `|` + D = self._hprint_vseparator(D, P) + + # add parens + D = prettyForm(*D.parens('(', ')')) + + # create the F symbol + above = D.height()//2 - 1 + below = D.height() - above - 1 + + sz, t, b, add, img = annotated('F') + F = prettyForm('\n' * (above - t) + img + '\n' * (below - b), + baseline=above + sz) + add = (sz + 1)//2 + + F = prettyForm(*F.left(self._print(len(e.ap)))) + F = prettyForm(*F.right(self._print(len(e.bq)))) + F.baseline = above + add + + D = prettyForm(*F.right(' ', D)) + + return D + + def _print_meijerg(self, e): + # FIXME refactor Matrix, Piecewise, and this into a tabular environment + + v = {} + v[(0, 0)] = [self._print(a) for a in e.an] + v[(0, 1)] = [self._print(a) for a in e.aother] + v[(1, 0)] = [self._print(b) for b in e.bm] + v[(1, 1)] = [self._print(b) for b in e.bother] + + P = self._print(e.argument) + P.baseline = P.height()//2 + + vp = {} + for idx in v: + vp[idx] = self._hprint_vec(v[idx]) + + for i in range(2): + maxw = max(vp[(0, i)].width(), vp[(1, i)].width()) + for j in range(2): + s = vp[(j, i)] + left = (maxw - s.width()) // 2 + right = maxw - left - s.width() + s = prettyForm(*s.left(' ' * left)) + s = prettyForm(*s.right(' ' * right)) + vp[(j, i)] = s + + D1 = prettyForm(*vp[(0, 0)].right(' ', vp[(0, 1)])) + D1 = prettyForm(*D1.below(' ')) + D2 = prettyForm(*vp[(1, 0)].right(' ', vp[(1, 1)])) + D = prettyForm(*D1.below(D2)) + + # make sure that the argument `z' is centred vertically + D.baseline = D.height()//2 + + # insert horizontal separator + P = prettyForm(*P.left(' ')) + D = prettyForm(*D.right(' ')) + + # insert separating `|` + D = self._hprint_vseparator(D, P) + + # add parens + D = prettyForm(*D.parens('(', ')')) + + # create the G symbol + above = D.height()//2 - 1 + below = D.height() - above - 1 + + sz, t, b, add, img = annotated('G') + F = prettyForm('\n' * (above - t) + img + '\n' * (below - b), + baseline=above + sz) + + pp = self._print(len(e.ap)) + pq = self._print(len(e.bq)) + pm = self._print(len(e.bm)) + pn = self._print(len(e.an)) + + def adjust(p1, p2): + diff = p1.width() - p2.width() + if diff == 0: + return p1, p2 + elif diff > 0: + return p1, prettyForm(*p2.left(' '*diff)) + else: + return prettyForm(*p1.left(' '*-diff)), p2 + pp, pm = adjust(pp, pm) + pq, pn = adjust(pq, pn) + pu = prettyForm(*pm.right(', ', pn)) + pl = prettyForm(*pp.right(', ', pq)) + + ht = F.baseline - above - 2 + if ht > 0: + pu = prettyForm(*pu.below('\n'*ht)) + p = prettyForm(*pu.below(pl)) + + F.baseline = above + F = prettyForm(*F.right(p)) + + F.baseline = above + add + + D = prettyForm(*F.right(' ', D)) + + return D + + def _print_ExpBase(self, e): + # TODO should exp_polar be printed differently? + # what about exp_polar(0), exp_polar(1)? + base = prettyForm(pretty_atom('Exp1', 'e')) + return base ** self._print(e.args[0]) + + def _print_Exp1(self, e): + return prettyForm(pretty_atom('Exp1', 'e')) + + def _print_Function(self, e, sort=False, func_name=None, left='(', + right=')'): + # optional argument func_name for supplying custom names + # XXX works only for applied functions + return self._helper_print_function(e.func, e.args, sort=sort, func_name=func_name, left=left, right=right) + + def _print_mathieuc(self, e): + return self._print_Function(e, func_name='C') + + def _print_mathieus(self, e): + return self._print_Function(e, func_name='S') + + def _print_mathieucprime(self, e): + return self._print_Function(e, func_name="C'") + + def _print_mathieusprime(self, e): + return self._print_Function(e, func_name="S'") + + def _helper_print_function(self, func, args, sort=False, func_name=None, + delimiter=', ', elementwise=False, left='(', + right=')'): + if sort: + args = sorted(args, key=default_sort_key) + + if not func_name and hasattr(func, "__name__"): + func_name = func.__name__ + + if func_name: + prettyFunc = self._print(Symbol(func_name)) + else: + prettyFunc = prettyForm(*self._print(func).parens()) + + if elementwise: + if self._use_unicode: + circ = pretty_atom('Modifier Letter Low Ring') + else: + circ = '.' + circ = self._print(circ) + prettyFunc = prettyForm( + binding=prettyForm.LINE, + *stringPict.next(prettyFunc, circ) + ) + + prettyArgs = prettyForm(*self._print_seq(args, delimiter=delimiter).parens( + left=left, right=right)) + + pform = prettyForm( + binding=prettyForm.FUNC, *stringPict.next(prettyFunc, prettyArgs)) + + # store pform parts so it can be reassembled e.g. when powered + pform.prettyFunc = prettyFunc + pform.prettyArgs = prettyArgs + + return pform + + def _print_ElementwiseApplyFunction(self, e): + func = e.function + arg = e.expr + args = [arg] + return self._helper_print_function(func, args, delimiter="", elementwise=True) + + @property + def _special_function_classes(self): + from sympy.functions.special.tensor_functions import KroneckerDelta + from sympy.functions.special.gamma_functions import gamma, lowergamma + from sympy.functions.special.zeta_functions import lerchphi + from sympy.functions.special.beta_functions import beta + from sympy.functions.special.delta_functions import DiracDelta + from sympy.functions.special.error_functions import Chi + return {KroneckerDelta: [greek_unicode['delta'], 'delta'], + gamma: [greek_unicode['Gamma'], 'Gamma'], + lerchphi: [greek_unicode['Phi'], 'lerchphi'], + lowergamma: [greek_unicode['gamma'], 'gamma'], + beta: [greek_unicode['Beta'], 'B'], + DiracDelta: [greek_unicode['delta'], 'delta'], + Chi: ['Chi', 'Chi']} + + def _print_FunctionClass(self, expr): + for cls in self._special_function_classes: + if issubclass(expr, cls) and expr.__name__ == cls.__name__: + if self._use_unicode: + return prettyForm(self._special_function_classes[cls][0]) + else: + return prettyForm(self._special_function_classes[cls][1]) + func_name = expr.__name__ + return prettyForm(pretty_symbol(func_name)) + + def _print_GeometryEntity(self, expr): + # GeometryEntity is based on Tuple but should not print like a Tuple + return self.emptyPrinter(expr) + + def _print_polylog(self, e): + subscript = self._print(e.args[0]) + if self._use_unicode and is_subscriptable_in_unicode(subscript): + return self._print_Function(Function('Li_%s' % subscript)(e.args[1])) + return self._print_Function(e) + + def _print_lerchphi(self, e): + func_name = greek_unicode['Phi'] if self._use_unicode else 'lerchphi' + return self._print_Function(e, func_name=func_name) + + def _print_dirichlet_eta(self, e): + func_name = greek_unicode['eta'] if self._use_unicode else 'dirichlet_eta' + return self._print_Function(e, func_name=func_name) + + def _print_Heaviside(self, e): + func_name = greek_unicode['theta'] if self._use_unicode else 'Heaviside' + if e.args[1] is S.Half: + pform = prettyForm(*self._print(e.args[0]).parens()) + pform = prettyForm(*pform.left(func_name)) + return pform + else: + return self._print_Function(e, func_name=func_name) + + def _print_fresnels(self, e): + return self._print_Function(e, func_name="S") + + def _print_fresnelc(self, e): + return self._print_Function(e, func_name="C") + + def _print_airyai(self, e): + return self._print_Function(e, func_name="Ai") + + def _print_airybi(self, e): + return self._print_Function(e, func_name="Bi") + + def _print_airyaiprime(self, e): + return self._print_Function(e, func_name="Ai'") + + def _print_airybiprime(self, e): + return self._print_Function(e, func_name="Bi'") + + def _print_LambertW(self, e): + return self._print_Function(e, func_name="W") + + def _print_Covariance(self, e): + return self._print_Function(e, func_name="Cov") + + def _print_Variance(self, e): + return self._print_Function(e, func_name="Var") + + def _print_Probability(self, e): + return self._print_Function(e, func_name="P") + + def _print_Expectation(self, e): + return self._print_Function(e, func_name="E", left='[', right=']') + + def _print_Lambda(self, e): + expr = e.expr + sig = e.signature + if self._use_unicode: + arrow = f" {pretty_atom('ArrowFromBar')} " + else: + arrow = " -> " + if len(sig) == 1 and sig[0].is_symbol: + sig = sig[0] + var_form = self._print(sig) + + return prettyForm(*stringPict.next(var_form, arrow, self._print(expr)), binding=8) + + def _print_Order(self, expr): + pform = self._print(expr.expr) + if (expr.point and any(p != S.Zero for p in expr.point)) or \ + len(expr.variables) > 1: + pform = prettyForm(*pform.right("; ")) + if len(expr.variables) > 1: + pform = prettyForm(*pform.right(self._print(expr.variables))) + elif len(expr.variables): + pform = prettyForm(*pform.right(self._print(expr.variables[0]))) + if self._use_unicode: + pform = prettyForm(*pform.right(f" {pretty_atom('Arrow')} ")) + else: + pform = prettyForm(*pform.right(" -> ")) + if len(expr.point) > 1: + pform = prettyForm(*pform.right(self._print(expr.point))) + else: + pform = prettyForm(*pform.right(self._print(expr.point[0]))) + pform = prettyForm(*pform.parens()) + pform = prettyForm(*pform.left("O")) + return pform + + def _print_SingularityFunction(self, e): + if self._use_unicode: + shift = self._print(e.args[0]-e.args[1]) + n = self._print(e.args[2]) + base = prettyForm("<") + base = prettyForm(*base.right(shift)) + base = prettyForm(*base.right(">")) + pform = base**n + return pform + else: + n = self._print(e.args[2]) + shift = self._print(e.args[0]-e.args[1]) + base = self._print_seq(shift, "<", ">", ' ') + return base**n + + def _print_beta(self, e): + func_name = greek_unicode['Beta'] if self._use_unicode else 'B' + return self._print_Function(e, func_name=func_name) + + def _print_betainc(self, e): + func_name = "B'" + return self._print_Function(e, func_name=func_name) + + def _print_betainc_regularized(self, e): + func_name = 'I' + return self._print_Function(e, func_name=func_name) + + def _print_gamma(self, e): + func_name = greek_unicode['Gamma'] if self._use_unicode else 'Gamma' + return self._print_Function(e, func_name=func_name) + + def _print_uppergamma(self, e): + func_name = greek_unicode['Gamma'] if self._use_unicode else 'Gamma' + return self._print_Function(e, func_name=func_name) + + def _print_lowergamma(self, e): + func_name = greek_unicode['gamma'] if self._use_unicode else 'lowergamma' + return self._print_Function(e, func_name=func_name) + + def _print_DiracDelta(self, e): + if self._use_unicode: + if len(e.args) == 2: + a = prettyForm(greek_unicode['delta']) + b = self._print(e.args[1]) + b = prettyForm(*b.parens()) + c = self._print(e.args[0]) + c = prettyForm(*c.parens()) + pform = a**b + pform = prettyForm(*pform.right(' ')) + pform = prettyForm(*pform.right(c)) + return pform + pform = self._print(e.args[0]) + pform = prettyForm(*pform.parens()) + pform = prettyForm(*pform.left(greek_unicode['delta'])) + return pform + else: + return self._print_Function(e) + + def _print_expint(self, e): + subscript = self._print(e.args[0]) + if self._use_unicode and is_subscriptable_in_unicode(subscript): + return self._print_Function(Function('E_%s' % subscript)(e.args[1])) + return self._print_Function(e) + + def _print_Chi(self, e): + # This needs a special case since otherwise it comes out as greek + # letter chi... + prettyFunc = prettyForm("Chi") + prettyArgs = prettyForm(*self._print_seq(e.args).parens()) + + pform = prettyForm( + binding=prettyForm.FUNC, *stringPict.next(prettyFunc, prettyArgs)) + + # store pform parts so it can be reassembled e.g. when powered + pform.prettyFunc = prettyFunc + pform.prettyArgs = prettyArgs + + return pform + + def _print_elliptic_e(self, e): + pforma0 = self._print(e.args[0]) + if len(e.args) == 1: + pform = pforma0 + else: + pforma1 = self._print(e.args[1]) + pform = self._hprint_vseparator(pforma0, pforma1) + pform = prettyForm(*pform.parens()) + pform = prettyForm(*pform.left('E')) + return pform + + def _print_elliptic_k(self, e): + pform = self._print(e.args[0]) + pform = prettyForm(*pform.parens()) + pform = prettyForm(*pform.left('K')) + return pform + + def _print_elliptic_f(self, e): + pforma0 = self._print(e.args[0]) + pforma1 = self._print(e.args[1]) + pform = self._hprint_vseparator(pforma0, pforma1) + pform = prettyForm(*pform.parens()) + pform = prettyForm(*pform.left('F')) + return pform + + def _print_elliptic_pi(self, e): + name = greek_unicode['Pi'] if self._use_unicode else 'Pi' + pforma0 = self._print(e.args[0]) + pforma1 = self._print(e.args[1]) + if len(e.args) == 2: + pform = self._hprint_vseparator(pforma0, pforma1) + else: + pforma2 = self._print(e.args[2]) + pforma = self._hprint_vseparator(pforma1, pforma2, ifascii_nougly=False) + pforma = prettyForm(*pforma.left('; ')) + pform = prettyForm(*pforma.left(pforma0)) + pform = prettyForm(*pform.parens()) + pform = prettyForm(*pform.left(name)) + return pform + + def _print_GoldenRatio(self, expr): + if self._use_unicode: + return prettyForm(pretty_symbol('phi')) + return self._print(Symbol("GoldenRatio")) + + def _print_EulerGamma(self, expr): + if self._use_unicode: + return prettyForm(pretty_symbol('gamma')) + return self._print(Symbol("EulerGamma")) + + def _print_Catalan(self, expr): + return self._print(Symbol("G")) + + def _print_Mod(self, expr): + pform = self._print(expr.args[0]) + if pform.binding > prettyForm.MUL: + pform = prettyForm(*pform.parens()) + pform = prettyForm(*pform.right(' mod ')) + pform = prettyForm(*pform.right(self._print(expr.args[1]))) + pform.binding = prettyForm.OPEN + return pform + + def _print_Add(self, expr, order=None): + terms = self._as_ordered_terms(expr, order=order) + pforms, indices = [], [] + + def pretty_negative(pform, index): + """Prepend a minus sign to a pretty form. """ + #TODO: Move this code to prettyForm + if index == 0: + if pform.height() > 1: + pform_neg = '- ' + else: + pform_neg = '-' + else: + pform_neg = ' - ' + + if (pform.binding > prettyForm.NEG + or pform.binding == prettyForm.ADD): + p = stringPict(*pform.parens()) + else: + p = pform + p = stringPict.next(pform_neg, p) + # Lower the binding to NEG, even if it was higher. Otherwise, it + # will print as a + ( - (b)), instead of a - (b). + return prettyForm(binding=prettyForm.NEG, *p) + + for i, term in enumerate(terms): + if term.is_Mul and term.could_extract_minus_sign(): + coeff, other = term.as_coeff_mul(rational=False) + if coeff == -1: + negterm = Mul(*other, evaluate=False) + else: + negterm = Mul(-coeff, *other, evaluate=False) + pform = self._print(negterm) + pforms.append(pretty_negative(pform, i)) + elif term.is_Rational and term.q > 1: + pforms.append(None) + indices.append(i) + elif term.is_Number and term < 0: + pform = self._print(-term) + pforms.append(pretty_negative(pform, i)) + elif term.is_Relational: + pforms.append(prettyForm(*self._print(term).parens())) + else: + pforms.append(self._print(term)) + + if indices: + large = True + + for pform in pforms: + if pform is not None and pform.height() > 1: + break + else: + large = False + + for i in indices: + term, negative = terms[i], False + + if term < 0: + term, negative = -term, True + + if large: + pform = prettyForm(str(term.p))/prettyForm(str(term.q)) + else: + pform = self._print(term) + + if negative: + pform = pretty_negative(pform, i) + + pforms[i] = pform + + return prettyForm.__add__(*pforms) + + def _print_Mul(self, product): + from sympy.physics.units import Quantity + + # Check for unevaluated Mul. In this case we need to make sure the + # identities are visible, multiple Rational factors are not combined + # etc so we display in a straight-forward form that fully preserves all + # args and their order. + args = product.args + if args[0] is S.One or any(isinstance(arg, Number) for arg in args[1:]): + strargs = list(map(self._print, args)) + # XXX: This is a hack to work around the fact that + # prettyForm.__mul__ absorbs a leading -1 in the args. Probably it + # would be better to fix this in prettyForm.__mul__ instead. + negone = strargs[0] == '-1' + if negone: + strargs[0] = prettyForm('1', 0, 0) + obj = prettyForm.__mul__(*strargs) + if negone: + obj = prettyForm('-' + obj.s, obj.baseline, obj.binding) + return obj + + a = [] # items in the numerator + b = [] # items that are in the denominator (if any) + + if self.order not in ('old', 'none'): + args = product.as_ordered_factors() + else: + args = list(product.args) + + # If quantities are present append them at the back + args = sorted(args, key=lambda x: isinstance(x, Quantity) or + (isinstance(x, Pow) and isinstance(x.base, Quantity))) + + # Gather terms for numerator/denominator + for item in args: + if item.is_commutative and item.is_Pow and item.exp.is_Rational and item.exp.is_negative: + if item.exp != -1: + b.append(Pow(item.base, -item.exp, evaluate=False)) + else: + b.append(Pow(item.base, -item.exp)) + elif item.is_Rational and item is not S.Infinity: + if item.p != 1: + a.append( Rational(item.p) ) + if item.q != 1: + b.append( Rational(item.q) ) + else: + a.append(item) + + # Convert to pretty forms. Parentheses are added by `__mul__`. + a = [self._print(ai) for ai in a] + b = [self._print(bi) for bi in b] + + # Construct a pretty form + if len(b) == 0: + return prettyForm.__mul__(*a) + else: + if len(a) == 0: + a.append( self._print(S.One) ) + return prettyForm.__mul__(*a)/prettyForm.__mul__(*b) + + # A helper function for _print_Pow to print x**(1/n) + def _print_nth_root(self, base, root): + bpretty = self._print(base) + + # In very simple cases, use a single-char root sign + if (self._settings['use_unicode_sqrt_char'] and self._use_unicode + and root == 2 and bpretty.height() == 1 + and (bpretty.width() == 1 + or (base.is_Integer and base.is_nonnegative))): + return prettyForm(*bpretty.left(nth_root[2])) + + # Construct root sign, start with the \/ shape + _zZ = xobj('/', 1) + rootsign = xobj('\\', 1) + _zZ + # Constructing the number to put on root + rpretty = self._print(root) + # roots look bad if they are not a single line + if rpretty.height() != 1: + return self._print(base)**self._print(1/root) + # If power is half, no number should appear on top of root sign + exp = '' if root == 2 else str(rpretty).ljust(2) + if len(exp) > 2: + rootsign = ' '*(len(exp) - 2) + rootsign + # Stack the exponent + rootsign = stringPict(exp + '\n' + rootsign) + rootsign.baseline = 0 + # Diagonal: length is one less than height of base + linelength = bpretty.height() - 1 + diagonal = stringPict('\n'.join( + ' '*(linelength - i - 1) + _zZ + ' '*i + for i in range(linelength) + )) + # Put baseline just below lowest line: next to exp + diagonal.baseline = linelength - 1 + # Make the root symbol + rootsign = prettyForm(*rootsign.right(diagonal)) + # Det the baseline to match contents to fix the height + # but if the height of bpretty is one, the rootsign must be one higher + rootsign.baseline = max(1, bpretty.baseline) + #build result + s = prettyForm(hobj('_', 2 + bpretty.width())) + s = prettyForm(*bpretty.above(s)) + s = prettyForm(*s.left(rootsign)) + return s + + def _print_Pow(self, power): + from sympy.simplify.simplify import fraction + b, e = power.as_base_exp() + if power.is_commutative: + if e is S.NegativeOne: + return prettyForm("1")/self._print(b) + n, d = fraction(e) + if n is S.One and d.is_Atom and not e.is_Integer and (e.is_Rational or d.is_Symbol) \ + and self._settings['root_notation']: + return self._print_nth_root(b, d) + if e.is_Rational and e < 0: + return prettyForm("1")/self._print(Pow(b, -e, evaluate=False)) + + if b.is_Relational: + return prettyForm(*self._print(b).parens()).__pow__(self._print(e)) + + return self._print(b)**self._print(e) + + def _print_UnevaluatedExpr(self, expr): + return self._print(expr.args[0]) + + def __print_numer_denom(self, p, q): + if q == 1: + if p < 0: + return prettyForm(str(p), binding=prettyForm.NEG) + else: + return prettyForm(str(p)) + elif abs(p) >= 10 and abs(q) >= 10: + # If more than one digit in numer and denom, print larger fraction + if p < 0: + return prettyForm(str(p), binding=prettyForm.NEG)/prettyForm(str(q)) + # Old printing method: + #pform = prettyForm(str(-p))/prettyForm(str(q)) + #return prettyForm(binding=prettyForm.NEG, *pform.left('- ')) + else: + return prettyForm(str(p))/prettyForm(str(q)) + else: + return None + + def _print_Rational(self, expr): + result = self.__print_numer_denom(expr.p, expr.q) + + if result is not None: + return result + else: + return self.emptyPrinter(expr) + + def _print_Fraction(self, expr): + result = self.__print_numer_denom(expr.numerator, expr.denominator) + + if result is not None: + return result + else: + return self.emptyPrinter(expr) + + def _print_ProductSet(self, p): + if len(p.sets) >= 1 and not has_variety(p.sets): + return self._print(p.sets[0]) ** self._print(len(p.sets)) + else: + prod_char = pretty_atom('Multiplication') if self._use_unicode else 'x' + return self._print_seq(p.sets, None, None, ' %s ' % prod_char, + parenthesize=lambda set: set.is_Union or + set.is_Intersection or set.is_ProductSet) + + def _print_FiniteSet(self, s): + items = sorted(s.args, key=default_sort_key) + return self._print_seq(items, '{', '}', ', ' ) + + def _print_Range(self, s): + + if self._use_unicode: + dots = pretty_atom('Dots') + else: + dots = '...' + + if s.start.is_infinite and s.stop.is_infinite: + if s.step.is_positive: + printset = dots, -1, 0, 1, dots + else: + printset = dots, 1, 0, -1, dots + elif s.start.is_infinite: + printset = dots, s[-1] - s.step, s[-1] + elif s.stop.is_infinite: + it = iter(s) + printset = next(it), next(it), dots + elif len(s) > 4: + it = iter(s) + printset = next(it), next(it), dots, s[-1] + else: + printset = tuple(s) + + return self._print_seq(printset, '{', '}', ', ' ) + + def _print_Interval(self, i): + if i.start == i.end: + return self._print_seq(i.args[:1], '{', '}') + + else: + if i.left_open: + left = '(' + else: + left = '[' + + if i.right_open: + right = ')' + else: + right = ']' + + return self._print_seq(i.args[:2], left, right) + + def _print_AccumulationBounds(self, i): + left = '<' + right = '>' + + return self._print_seq(i.args[:2], left, right) + + def _print_Intersection(self, u): + + delimiter = ' %s ' % pretty_atom('Intersection', 'n') + + return self._print_seq(u.args, None, None, delimiter, + parenthesize=lambda set: set.is_ProductSet or + set.is_Union or set.is_Complement) + + def _print_Union(self, u): + + union_delimiter = ' %s ' % pretty_atom('Union', 'U') + + return self._print_seq(u.args, None, None, union_delimiter, + parenthesize=lambda set: set.is_ProductSet or + set.is_Intersection or set.is_Complement) + + def _print_SymmetricDifference(self, u): + if not self._use_unicode: + raise NotImplementedError("ASCII pretty printing of SymmetricDifference is not implemented") + + sym_delimeter = ' %s ' % pretty_atom('SymmetricDifference') + + return self._print_seq(u.args, None, None, sym_delimeter) + + def _print_Complement(self, u): + + delimiter = r' \ ' + + return self._print_seq(u.args, None, None, delimiter, + parenthesize=lambda set: set.is_ProductSet or set.is_Intersection + or set.is_Union) + + def _print_ImageSet(self, ts): + if self._use_unicode: + inn = pretty_atom("SmallElementOf") + else: + inn = 'in' + fun = ts.lamda + sets = ts.base_sets + signature = fun.signature + expr = self._print(fun.expr) + + # TODO: the stuff to the left of the | and the stuff to the right of + # the | should have independent baselines, that way something like + # ImageSet(Lambda(x, 1/x**2), S.Naturals) prints the "x in N" part + # centered on the right instead of aligned with the fraction bar on + # the left. The same also applies to ConditionSet and ComplexRegion + if len(signature) == 1: + S = self._print_seq((signature[0], inn, sets[0]), + delimiter=' ') + return self._hprint_vseparator(expr, S, + left='{', right='}', + ifascii_nougly=True, delimiter=' ') + else: + pargs = tuple(j for var, setv in zip(signature, sets) for j in + (var, ' ', inn, ' ', setv, ", ")) + S = self._print_seq(pargs[:-1], delimiter='') + return self._hprint_vseparator(expr, S, + left='{', right='}', + ifascii_nougly=True, delimiter=' ') + + def _print_ConditionSet(self, ts): + if self._use_unicode: + inn = pretty_atom('SmallElementOf') + # using _and because and is a keyword and it is bad practice to + # overwrite them + _and = pretty_atom('And') + else: + inn = 'in' + _and = 'and' + + variables = self._print_seq(Tuple(ts.sym)) + as_expr = getattr(ts.condition, 'as_expr', None) + if as_expr is not None: + cond = self._print(ts.condition.as_expr()) + else: + cond = self._print(ts.condition) + if self._use_unicode: + cond = self._print(cond) + cond = prettyForm(*cond.parens()) + + if ts.base_set is S.UniversalSet: + return self._hprint_vseparator(variables, cond, left="{", + right="}", ifascii_nougly=True, + delimiter=' ') + + base = self._print(ts.base_set) + C = self._print_seq((variables, inn, base, _and, cond), + delimiter=' ') + return self._hprint_vseparator(variables, C, left="{", right="}", + ifascii_nougly=True, delimiter=' ') + + def _print_ComplexRegion(self, ts): + if self._use_unicode: + inn = pretty_atom('SmallElementOf') + else: + inn = 'in' + variables = self._print_seq(ts.variables) + expr = self._print(ts.expr) + prodsets = self._print(ts.sets) + + C = self._print_seq((variables, inn, prodsets), + delimiter=' ') + return self._hprint_vseparator(expr, C, left="{", right="}", + ifascii_nougly=True, delimiter=' ') + + def _print_Contains(self, e): + var, set = e.args + if self._use_unicode: + el = f" {pretty_atom('ElementOf')} " + return prettyForm(*stringPict.next(self._print(var), + el, self._print(set)), binding=8) + else: + return prettyForm(sstr(e)) + + def _print_FourierSeries(self, s): + if s.an.formula is S.Zero and s.bn.formula is S.Zero: + return self._print(s.a0) + if self._use_unicode: + dots = pretty_atom('Dots') + else: + dots = '...' + return self._print_Add(s.truncate()) + self._print(dots) + + def _print_FormalPowerSeries(self, s): + return self._print_Add(s.infinite) + + def _print_SetExpr(self, se): + pretty_set = prettyForm(*self._print(se.set).parens()) + pretty_name = self._print(Symbol("SetExpr")) + return prettyForm(*pretty_name.right(pretty_set)) + + def _print_SeqFormula(self, s): + if self._use_unicode: + dots = pretty_atom('Dots') + else: + dots = '...' + + if len(s.start.free_symbols) > 0 or len(s.stop.free_symbols) > 0: + raise NotImplementedError("Pretty printing of sequences with symbolic bound not implemented") + + if s.start is S.NegativeInfinity: + stop = s.stop + printset = (dots, s.coeff(stop - 3), s.coeff(stop - 2), + s.coeff(stop - 1), s.coeff(stop)) + elif s.stop is S.Infinity or s.length > 4: + printset = s[:4] + printset.append(dots) + printset = tuple(printset) + else: + printset = tuple(s) + return self._print_list(printset) + + _print_SeqPer = _print_SeqFormula + _print_SeqAdd = _print_SeqFormula + _print_SeqMul = _print_SeqFormula + + def _print_seq(self, seq, left=None, right=None, delimiter=', ', + parenthesize=lambda x: False, ifascii_nougly=True): + + pforms = [] + for item in seq: + pform = self._print(item) + if parenthesize(item): + pform = prettyForm(*pform.parens()) + if pforms: + pforms.append(delimiter) + pforms.append(pform) + + if not pforms: + s = stringPict('') + else: + s = prettyForm(*stringPict.next(*pforms)) + + s = prettyForm(*s.parens(left, right, ifascii_nougly=ifascii_nougly)) + return s + + def join(self, delimiter, args): + pform = None + + for arg in args: + if pform is None: + pform = arg + else: + pform = prettyForm(*pform.right(delimiter)) + pform = prettyForm(*pform.right(arg)) + + if pform is None: + return prettyForm("") + else: + return pform + + def _print_list(self, l): + return self._print_seq(l, '[', ']') + + def _print_tuple(self, t): + if len(t) == 1: + ptuple = prettyForm(*stringPict.next(self._print(t[0]), ',')) + return prettyForm(*ptuple.parens('(', ')', ifascii_nougly=True)) + else: + return self._print_seq(t, '(', ')') + + def _print_Tuple(self, expr): + return self._print_tuple(expr) + + def _print_dict(self, d): + keys = sorted(d.keys(), key=default_sort_key) + items = [] + + for k in keys: + K = self._print(k) + V = self._print(d[k]) + s = prettyForm(*stringPict.next(K, ': ', V)) + + items.append(s) + + return self._print_seq(items, '{', '}') + + def _print_Dict(self, d): + return self._print_dict(d) + + def _print_set(self, s): + if not s: + return prettyForm('set()') + items = sorted(s, key=default_sort_key) + pretty = self._print_seq(items) + pretty = prettyForm(*pretty.parens('{', '}', ifascii_nougly=True)) + return pretty + + def _print_frozenset(self, s): + if not s: + return prettyForm('frozenset()') + items = sorted(s, key=default_sort_key) + pretty = self._print_seq(items) + pretty = prettyForm(*pretty.parens('{', '}', ifascii_nougly=True)) + pretty = prettyForm(*pretty.parens('(', ')', ifascii_nougly=True)) + pretty = prettyForm(*stringPict.next(type(s).__name__, pretty)) + return pretty + + def _print_UniversalSet(self, s): + if self._use_unicode: + return prettyForm(pretty_atom('Universe')) + else: + return prettyForm('UniversalSet') + + def _print_PolyRing(self, ring): + return prettyForm(sstr(ring)) + + def _print_FracField(self, field): + return prettyForm(sstr(field)) + + def _print_FreeGroupElement(self, elm): + return prettyForm(str(elm)) + + def _print_PolyElement(self, poly): + return prettyForm(sstr(poly)) + + def _print_FracElement(self, frac): + return prettyForm(sstr(frac)) + + def _print_AlgebraicNumber(self, expr): + if expr.is_aliased: + return self._print(expr.as_poly().as_expr()) + else: + return self._print(expr.as_expr()) + + def _print_ComplexRootOf(self, expr): + args = [self._print_Add(expr.expr, order='lex'), expr.index] + pform = prettyForm(*self._print_seq(args).parens()) + pform = prettyForm(*pform.left('CRootOf')) + return pform + + def _print_RootSum(self, expr): + args = [self._print_Add(expr.expr, order='lex')] + + if expr.fun is not S.IdentityFunction: + args.append(self._print(expr.fun)) + + pform = prettyForm(*self._print_seq(args).parens()) + pform = prettyForm(*pform.left('RootSum')) + + return pform + + def _print_FiniteField(self, expr): + if self._use_unicode: + form = f"{pretty_atom('Integers')}_%d" + else: + form = 'GF(%d)' + + return prettyForm(pretty_symbol(form % expr.mod)) + + def _print_IntegerRing(self, expr): + if self._use_unicode: + return prettyForm(pretty_atom('Integers')) + else: + return prettyForm('ZZ') + + def _print_RationalField(self, expr): + if self._use_unicode: + return prettyForm(pretty_atom('Rationals')) + else: + return prettyForm('QQ') + + def _print_RealField(self, domain): + if self._use_unicode: + prefix = pretty_atom("Reals") + else: + prefix = 'RR' + + if domain.has_default_precision: + return prettyForm(prefix) + else: + return self._print(pretty_symbol(prefix + "_" + str(domain.precision))) + + def _print_ComplexField(self, domain): + if self._use_unicode: + prefix = pretty_atom('Complexes') + else: + prefix = 'CC' + + if domain.has_default_precision: + return prettyForm(prefix) + else: + return self._print(pretty_symbol(prefix + "_" + str(domain.precision))) + + def _print_PolynomialRing(self, expr): + args = list(expr.symbols) + + if not expr.order.is_default: + order = prettyForm(*prettyForm("order=").right(self._print(expr.order))) + args.append(order) + + pform = self._print_seq(args, '[', ']') + pform = prettyForm(*pform.left(self._print(expr.domain))) + + return pform + + def _print_FractionField(self, expr): + args = list(expr.symbols) + + if not expr.order.is_default: + order = prettyForm(*prettyForm("order=").right(self._print(expr.order))) + args.append(order) + + pform = self._print_seq(args, '(', ')') + pform = prettyForm(*pform.left(self._print(expr.domain))) + + return pform + + def _print_PolynomialRingBase(self, expr): + g = expr.symbols + if str(expr.order) != str(expr.default_order): + g = g + ("order=" + str(expr.order),) + pform = self._print_seq(g, '[', ']') + pform = prettyForm(*pform.left(self._print(expr.domain))) + + return pform + + def _print_GroebnerBasis(self, basis): + exprs = [ self._print_Add(arg, order=basis.order) + for arg in basis.exprs ] + exprs = prettyForm(*self.join(", ", exprs).parens(left="[", right="]")) + + gens = [ self._print(gen) for gen in basis.gens ] + + domain = prettyForm( + *prettyForm("domain=").right(self._print(basis.domain))) + order = prettyForm( + *prettyForm("order=").right(self._print(basis.order))) + + pform = self.join(", ", [exprs] + gens + [domain, order]) + + pform = prettyForm(*pform.parens()) + pform = prettyForm(*pform.left(basis.__class__.__name__)) + + return pform + + def _print_Subs(self, e): + pform = self._print(e.expr) + pform = prettyForm(*pform.parens()) + + h = pform.height() if pform.height() > 1 else 2 + rvert = stringPict(vobj('|', h), baseline=pform.baseline) + pform = prettyForm(*pform.right(rvert)) + + b = pform.baseline + pform.baseline = pform.height() - 1 + pform = prettyForm(*pform.right(self._print_seq([ + self._print_seq((self._print(v[0]), xsym('=='), self._print(v[1])), + delimiter='') for v in zip(e.variables, e.point) ]))) + + pform.baseline = b + return pform + + def _print_number_function(self, e, name): + # Print name_arg[0] for one argument or name_arg[0](arg[1]) + # for more than one argument + pform = prettyForm(name) + arg = self._print(e.args[0]) + pform_arg = prettyForm(" "*arg.width()) + pform_arg = prettyForm(*pform_arg.below(arg)) + pform = prettyForm(*pform.right(pform_arg)) + if len(e.args) == 1: + return pform + m, x = e.args + # TODO: copy-pasted from _print_Function: can we do better? + prettyFunc = pform + prettyArgs = prettyForm(*self._print_seq([x]).parens()) + pform = prettyForm( + binding=prettyForm.FUNC, *stringPict.next(prettyFunc, prettyArgs)) + pform.prettyFunc = prettyFunc + pform.prettyArgs = prettyArgs + return pform + + def _print_euler(self, e): + return self._print_number_function(e, "E") + + def _print_catalan(self, e): + return self._print_number_function(e, "C") + + def _print_bernoulli(self, e): + return self._print_number_function(e, "B") + + _print_bell = _print_bernoulli + + def _print_lucas(self, e): + return self._print_number_function(e, "L") + + def _print_fibonacci(self, e): + return self._print_number_function(e, "F") + + def _print_tribonacci(self, e): + return self._print_number_function(e, "T") + + def _print_stieltjes(self, e): + if self._use_unicode: + return self._print_number_function(e, greek_unicode['gamma']) + else: + return self._print_number_function(e, "stieltjes") + + def _print_KroneckerDelta(self, e): + pform = self._print(e.args[0]) + pform = prettyForm(*pform.right(prettyForm(','))) + pform = prettyForm(*pform.right(self._print(e.args[1]))) + if self._use_unicode: + a = stringPict(pretty_symbol('delta')) + else: + a = stringPict('d') + b = pform + top = stringPict(*b.left(' '*a.width())) + bot = stringPict(*a.right(' '*b.width())) + return prettyForm(binding=prettyForm.POW, *bot.below(top)) + + def _print_RandomDomain(self, d): + if hasattr(d, 'as_boolean'): + pform = self._print('Domain: ') + pform = prettyForm(*pform.right(self._print(d.as_boolean()))) + return pform + elif hasattr(d, 'set'): + pform = self._print('Domain: ') + pform = prettyForm(*pform.right(self._print(d.symbols))) + pform = prettyForm(*pform.right(self._print(' in '))) + pform = prettyForm(*pform.right(self._print(d.set))) + return pform + elif hasattr(d, 'symbols'): + pform = self._print('Domain on ') + pform = prettyForm(*pform.right(self._print(d.symbols))) + return pform + else: + return self._print(None) + + def _print_DMP(self, p): + try: + if p.ring is not None: + # TODO incorporate order + return self._print(p.ring.to_sympy(p)) + except SympifyError: + pass + return self._print(repr(p)) + + def _print_DMF(self, p): + return self._print_DMP(p) + + def _print_Object(self, object): + return self._print(pretty_symbol(object.name)) + + def _print_Morphism(self, morphism): + arrow = xsym("-->") + + domain = self._print(morphism.domain) + codomain = self._print(morphism.codomain) + tail = domain.right(arrow, codomain)[0] + + return prettyForm(tail) + + def _print_NamedMorphism(self, morphism): + pretty_name = self._print(pretty_symbol(morphism.name)) + pretty_morphism = self._print_Morphism(morphism) + return prettyForm(pretty_name.right(":", pretty_morphism)[0]) + + def _print_IdentityMorphism(self, morphism): + from sympy.categories import NamedMorphism + return self._print_NamedMorphism( + NamedMorphism(morphism.domain, morphism.codomain, "id")) + + def _print_CompositeMorphism(self, morphism): + + circle = xsym(".") + + # All components of the morphism have names and it is thus + # possible to build the name of the composite. + component_names_list = [pretty_symbol(component.name) for + component in morphism.components] + component_names_list.reverse() + component_names = circle.join(component_names_list) + ":" + + pretty_name = self._print(component_names) + pretty_morphism = self._print_Morphism(morphism) + return prettyForm(pretty_name.right(pretty_morphism)[0]) + + def _print_Category(self, category): + return self._print(pretty_symbol(category.name)) + + def _print_Diagram(self, diagram): + if not diagram.premises: + # This is an empty diagram. + return self._print(S.EmptySet) + + pretty_result = self._print(diagram.premises) + if diagram.conclusions: + results_arrow = " %s " % xsym("==>") + + pretty_conclusions = self._print(diagram.conclusions)[0] + pretty_result = pretty_result.right( + results_arrow, pretty_conclusions) + + return prettyForm(pretty_result[0]) + + def _print_DiagramGrid(self, grid): + from sympy.matrices import Matrix + matrix = Matrix([[grid[i, j] if grid[i, j] else Symbol(" ") + for j in range(grid.width)] + for i in range(grid.height)]) + return self._print_matrix_contents(matrix) + + def _print_FreeModuleElement(self, m): + # Print as row vector for convenience, for now. + return self._print_seq(m, '[', ']') + + def _print_SubModule(self, M): + gens = [[M.ring.to_sympy(g) for g in gen] for gen in M.gens] + return self._print_seq(gens, '<', '>') + + def _print_FreeModule(self, M): + return self._print(M.ring)**self._print(M.rank) + + def _print_ModuleImplementedIdeal(self, M): + sym = M.ring.to_sympy + return self._print_seq([sym(x) for [x] in M._module.gens], '<', '>') + + def _print_QuotientRing(self, R): + return self._print(R.ring) / self._print(R.base_ideal) + + def _print_QuotientRingElement(self, R): + return self._print(R.ring.to_sympy(R)) + self._print(R.ring.base_ideal) + + def _print_QuotientModuleElement(self, m): + return self._print(m.data) + self._print(m.module.killed_module) + + def _print_QuotientModule(self, M): + return self._print(M.base) / self._print(M.killed_module) + + def _print_MatrixHomomorphism(self, h): + matrix = self._print(h._sympy_matrix()) + matrix.baseline = matrix.height() // 2 + pform = prettyForm(*matrix.right(' : ', self._print(h.domain), + ' %s> ' % hobj('-', 2), self._print(h.codomain))) + return pform + + def _print_Manifold(self, manifold): + return self._print(manifold.name) + + def _print_Patch(self, patch): + return self._print(patch.name) + + def _print_CoordSystem(self, coords): + return self._print(coords.name) + + def _print_BaseScalarField(self, field): + string = field._coord_sys.symbols[field._index].name + return self._print(pretty_symbol(string)) + + def _print_BaseVectorField(self, field): + s = U('PARTIAL DIFFERENTIAL') + '_' + field._coord_sys.symbols[field._index].name + return self._print(pretty_symbol(s)) + + def _print_Differential(self, diff): + if self._use_unicode: + d = pretty_atom('Differential') + else: + d = 'd' + field = diff._form_field + if hasattr(field, '_coord_sys'): + string = field._coord_sys.symbols[field._index].name + return self._print(d + ' ' + pretty_symbol(string)) + else: + pform = self._print(field) + pform = prettyForm(*pform.parens()) + return prettyForm(*pform.left(d)) + + def _print_Tr(self, p): + #TODO: Handle indices + pform = self._print(p.args[0]) + pform = prettyForm(*pform.left('%s(' % (p.__class__.__name__))) + pform = prettyForm(*pform.right(')')) + return pform + + def _print_primenu(self, e): + pform = self._print(e.args[0]) + pform = prettyForm(*pform.parens()) + if self._use_unicode: + pform = prettyForm(*pform.left(greek_unicode['nu'])) + else: + pform = prettyForm(*pform.left('nu')) + return pform + + def _print_primeomega(self, e): + pform = self._print(e.args[0]) + pform = prettyForm(*pform.parens()) + if self._use_unicode: + pform = prettyForm(*pform.left(greek_unicode['Omega'])) + else: + pform = prettyForm(*pform.left('Omega')) + return pform + + def _print_Quantity(self, e): + if e.name.name == 'degree': + if self._use_unicode: + pform = self._print(pretty_atom('Degree')) + else: + pform = self._print(chr(176)) + return pform + else: + return self.emptyPrinter(e) + + def _print_AssignmentBase(self, e): + + op = prettyForm(' ' + xsym(e.op) + ' ') + + l = self._print(e.lhs) + r = self._print(e.rhs) + pform = prettyForm(*stringPict.next(l, op, r)) + return pform + + def _print_Str(self, s): + return self._print(s.name) + + +@print_function(PrettyPrinter) +def pretty(expr, **settings): + """Returns a string containing the prettified form of expr. + + For information on keyword arguments see pretty_print function. + + """ + pp = PrettyPrinter(settings) + + # XXX: this is an ugly hack, but at least it works + use_unicode = pp._settings['use_unicode'] + uflag = pretty_use_unicode(use_unicode) + + try: + return pp.doprint(expr) + finally: + pretty_use_unicode(uflag) + + +def pretty_print(expr, **kwargs): + """Prints expr in pretty form. + + pprint is just a shortcut for this function. + + Parameters + ========== + + expr : expression + The expression to print. + + wrap_line : bool, optional (default=True) + Line wrapping enabled/disabled. + + num_columns : int or None, optional (default=None) + Number of columns before line breaking (default to None which reads + the terminal width), useful when using SymPy without terminal. + + use_unicode : bool or None, optional (default=None) + Use unicode characters, such as the Greek letter pi instead of + the string pi. + + full_prec : bool or string, optional (default="auto") + Use full precision. + + order : bool or string, optional (default=None) + Set to 'none' for long expressions if slow; default is None. + + use_unicode_sqrt_char : bool, optional (default=True) + Use compact single-character square root symbol (when unambiguous). + + root_notation : bool, optional (default=True) + Set to 'False' for printing exponents of the form 1/n in fractional form. + By default exponent is printed in root form. + + mat_symbol_style : string, optional (default="plain") + Set to "bold" for printing MatrixSymbols using a bold mathematical symbol face. + By default the standard face is used. + + imaginary_unit : string, optional (default="i") + Letter to use for imaginary unit when use_unicode is True. + Can be "i" (default) or "j". + """ + print(pretty(expr, **kwargs)) + +pprint = pretty_print + + +def pager_print(expr, **settings): + """Prints expr using the pager, in pretty form. + + This invokes a pager command using pydoc. Lines are not wrapped + automatically. This routine is meant to be used with a pager that allows + sideways scrolling, like ``less -S``. + + Parameters are the same as for ``pretty_print``. If you wish to wrap lines, + pass ``num_columns=None`` to auto-detect the width of the terminal. + + """ + from pydoc import pager + from locale import getpreferredencoding + if 'num_columns' not in settings: + settings['num_columns'] = 500000 # disable line wrap + pager(pretty(expr, **settings).encode(getpreferredencoding())) diff --git a/lib/python3.10/site-packages/sympy/printing/pretty/pretty_symbology.py b/lib/python3.10/site-packages/sympy/printing/pretty/pretty_symbology.py new file mode 100644 index 0000000000000000000000000000000000000000..d12fff726702101c167a5fef5cba387b4918749d --- /dev/null +++ b/lib/python3.10/site-packages/sympy/printing/pretty/pretty_symbology.py @@ -0,0 +1,732 @@ +"""Symbolic primitives + unicode/ASCII abstraction for pretty.py""" + +import sys +import warnings +from string import ascii_lowercase, ascii_uppercase +import unicodedata + +unicode_warnings = '' + +def U(name): + """ + Get a unicode character by name or, None if not found. + + This exists because older versions of Python use older unicode databases. + """ + try: + return unicodedata.lookup(name) + except KeyError: + global unicode_warnings + unicode_warnings += 'No \'%s\' in unicodedata\n' % name + return None + +from sympy.printing.conventions import split_super_sub +from sympy.core.alphabets import greeks +from sympy.utilities.exceptions import sympy_deprecation_warning + +# prefix conventions when constructing tables +# L - LATIN i +# G - GREEK beta +# D - DIGIT 0 +# S - SYMBOL + + + +__all__ = ['greek_unicode', 'sub', 'sup', 'xsym', 'vobj', 'hobj', 'pretty_symbol', + 'annotated', 'center_pad', 'center'] + + +_use_unicode = False + + +def pretty_use_unicode(flag=None): + """Set whether pretty-printer should use unicode by default""" + global _use_unicode + global unicode_warnings + if flag is None: + return _use_unicode + + if flag and unicode_warnings: + # print warnings (if any) on first unicode usage + warnings.warn(unicode_warnings) + unicode_warnings = '' + + use_unicode_prev = _use_unicode + _use_unicode = flag + return use_unicode_prev + + +def pretty_try_use_unicode(): + """See if unicode output is available and leverage it if possible""" + + encoding = getattr(sys.stdout, 'encoding', None) + + # this happens when e.g. stdout is redirected through a pipe, or is + # e.g. a cStringIO.StringO + if encoding is None: + return # sys.stdout has no encoding + + symbols = [] + + # see if we can represent greek alphabet + symbols += greek_unicode.values() + + # and atoms + symbols += atoms_table.values() + + for s in symbols: + if s is None: + return # common symbols not present! + + try: + s.encode(encoding) + except UnicodeEncodeError: + return + + # all the characters were present and encodable + pretty_use_unicode(True) + + +def xstr(*args): + sympy_deprecation_warning( + """ + The sympy.printing.pretty.pretty_symbology.xstr() function is + deprecated. Use str() instead. + """, + deprecated_since_version="1.7", + active_deprecations_target="deprecated-pretty-printing-functions" + ) + return str(*args) + +# GREEK +g = lambda l: U('GREEK SMALL LETTER %s' % l.upper()) +G = lambda l: U('GREEK CAPITAL LETTER %s' % l.upper()) + +greek_letters = list(greeks) # make a copy +# deal with Unicode's funny spelling of lambda +greek_letters[greek_letters.index('lambda')] = 'lamda' + +# {} greek letter -> (g,G) +greek_unicode = {L: g(L) for L in greek_letters} +greek_unicode.update((L[0].upper() + L[1:], G(L)) for L in greek_letters) + +# aliases +greek_unicode['lambda'] = greek_unicode['lamda'] +greek_unicode['Lambda'] = greek_unicode['Lamda'] +greek_unicode['varsigma'] = '\N{GREEK SMALL LETTER FINAL SIGMA}' + +# BOLD +b = lambda l: U('MATHEMATICAL BOLD SMALL %s' % l.upper()) +B = lambda l: U('MATHEMATICAL BOLD CAPITAL %s' % l.upper()) + +bold_unicode = {l: b(l) for l in ascii_lowercase} +bold_unicode.update((L, B(L)) for L in ascii_uppercase) + +# GREEK BOLD +gb = lambda l: U('MATHEMATICAL BOLD SMALL %s' % l.upper()) +GB = lambda l: U('MATHEMATICAL BOLD CAPITAL %s' % l.upper()) + +greek_bold_letters = list(greeks) # make a copy, not strictly required here +# deal with Unicode's funny spelling of lambda +greek_bold_letters[greek_bold_letters.index('lambda')] = 'lamda' + +# {} greek letter -> (g,G) +greek_bold_unicode = {L: g(L) for L in greek_bold_letters} +greek_bold_unicode.update((L[0].upper() + L[1:], G(L)) for L in greek_bold_letters) +greek_bold_unicode['lambda'] = greek_unicode['lamda'] +greek_bold_unicode['Lambda'] = greek_unicode['Lamda'] +greek_bold_unicode['varsigma'] = '\N{MATHEMATICAL BOLD SMALL FINAL SIGMA}' + +digit_2txt = { + '0': 'ZERO', + '1': 'ONE', + '2': 'TWO', + '3': 'THREE', + '4': 'FOUR', + '5': 'FIVE', + '6': 'SIX', + '7': 'SEVEN', + '8': 'EIGHT', + '9': 'NINE', +} + +symb_2txt = { + '+': 'PLUS SIGN', + '-': 'MINUS', + '=': 'EQUALS SIGN', + '(': 'LEFT PARENTHESIS', + ')': 'RIGHT PARENTHESIS', + '[': 'LEFT SQUARE BRACKET', + ']': 'RIGHT SQUARE BRACKET', + '{': 'LEFT CURLY BRACKET', + '}': 'RIGHT CURLY BRACKET', + + # non-std + '{}': 'CURLY BRACKET', + 'sum': 'SUMMATION', + 'int': 'INTEGRAL', +} + +# SUBSCRIPT & SUPERSCRIPT +LSUB = lambda letter: U('LATIN SUBSCRIPT SMALL LETTER %s' % letter.upper()) +GSUB = lambda letter: U('GREEK SUBSCRIPT SMALL LETTER %s' % letter.upper()) +DSUB = lambda digit: U('SUBSCRIPT %s' % digit_2txt[digit]) +SSUB = lambda symb: U('SUBSCRIPT %s' % symb_2txt[symb]) + +LSUP = lambda letter: U('SUPERSCRIPT LATIN SMALL LETTER %s' % letter.upper()) +DSUP = lambda digit: U('SUPERSCRIPT %s' % digit_2txt[digit]) +SSUP = lambda symb: U('SUPERSCRIPT %s' % symb_2txt[symb]) + +sub = {} # symb -> subscript symbol +sup = {} # symb -> superscript symbol + +# latin subscripts +for l in 'aeioruvxhklmnpst': + sub[l] = LSUB(l) + +for l in 'in': + sup[l] = LSUP(l) + +for gl in ['beta', 'gamma', 'rho', 'phi', 'chi']: + sub[gl] = GSUB(gl) + +for d in [str(i) for i in range(10)]: + sub[d] = DSUB(d) + sup[d] = DSUP(d) + +for s in '+-=()': + sub[s] = SSUB(s) + sup[s] = SSUP(s) + +# Variable modifiers +# TODO: Make brackets adjust to height of contents +modifier_dict = { + # Accents + 'mathring': lambda s: center_accent(s, '\N{COMBINING RING ABOVE}'), + 'ddddot': lambda s: center_accent(s, '\N{COMBINING FOUR DOTS ABOVE}'), + 'dddot': lambda s: center_accent(s, '\N{COMBINING THREE DOTS ABOVE}'), + 'ddot': lambda s: center_accent(s, '\N{COMBINING DIAERESIS}'), + 'dot': lambda s: center_accent(s, '\N{COMBINING DOT ABOVE}'), + 'check': lambda s: center_accent(s, '\N{COMBINING CARON}'), + 'breve': lambda s: center_accent(s, '\N{COMBINING BREVE}'), + 'acute': lambda s: center_accent(s, '\N{COMBINING ACUTE ACCENT}'), + 'grave': lambda s: center_accent(s, '\N{COMBINING GRAVE ACCENT}'), + 'tilde': lambda s: center_accent(s, '\N{COMBINING TILDE}'), + 'hat': lambda s: center_accent(s, '\N{COMBINING CIRCUMFLEX ACCENT}'), + 'bar': lambda s: center_accent(s, '\N{COMBINING OVERLINE}'), + 'vec': lambda s: center_accent(s, '\N{COMBINING RIGHT ARROW ABOVE}'), + 'prime': lambda s: s+'\N{PRIME}', + 'prm': lambda s: s+'\N{PRIME}', + # # Faces -- these are here for some compatibility with latex printing + # 'bold': lambda s: s, + # 'bm': lambda s: s, + # 'cal': lambda s: s, + # 'scr': lambda s: s, + # 'frak': lambda s: s, + # Brackets + 'norm': lambda s: '\N{DOUBLE VERTICAL LINE}'+s+'\N{DOUBLE VERTICAL LINE}', + 'avg': lambda s: '\N{MATHEMATICAL LEFT ANGLE BRACKET}'+s+'\N{MATHEMATICAL RIGHT ANGLE BRACKET}', + 'abs': lambda s: '\N{VERTICAL LINE}'+s+'\N{VERTICAL LINE}', + 'mag': lambda s: '\N{VERTICAL LINE}'+s+'\N{VERTICAL LINE}', +} + +# VERTICAL OBJECTS +HUP = lambda symb: U('%s UPPER HOOK' % symb_2txt[symb]) +CUP = lambda symb: U('%s UPPER CORNER' % symb_2txt[symb]) +MID = lambda symb: U('%s MIDDLE PIECE' % symb_2txt[symb]) +EXT = lambda symb: U('%s EXTENSION' % symb_2txt[symb]) +HLO = lambda symb: U('%s LOWER HOOK' % symb_2txt[symb]) +CLO = lambda symb: U('%s LOWER CORNER' % symb_2txt[symb]) +TOP = lambda symb: U('%s TOP' % symb_2txt[symb]) +BOT = lambda symb: U('%s BOTTOM' % symb_2txt[symb]) + +# {} '(' -> (extension, start, end, middle) 1-character +_xobj_unicode = { + + # vertical symbols + # (( ext, top, bot, mid ), c1) + '(': (( EXT('('), HUP('('), HLO('(') ), '('), + ')': (( EXT(')'), HUP(')'), HLO(')') ), ')'), + '[': (( EXT('['), CUP('['), CLO('[') ), '['), + ']': (( EXT(']'), CUP(']'), CLO(']') ), ']'), + '{': (( EXT('{}'), HUP('{'), HLO('{'), MID('{') ), '{'), + '}': (( EXT('{}'), HUP('}'), HLO('}'), MID('}') ), '}'), + '|': U('BOX DRAWINGS LIGHT VERTICAL'), + 'Tee': U('BOX DRAWINGS LIGHT UP AND HORIZONTAL'), + 'UpTack': U('BOX DRAWINGS LIGHT DOWN AND HORIZONTAL'), + 'corner_up_centre' + '(_ext': U('LEFT PARENTHESIS EXTENSION'), + ')_ext': U('RIGHT PARENTHESIS EXTENSION'), + '(_lower_hook': U('LEFT PARENTHESIS LOWER HOOK'), + ')_lower_hook': U('RIGHT PARENTHESIS LOWER HOOK'), + '(_upper_hook': U('LEFT PARENTHESIS UPPER HOOK'), + ')_upper_hook': U('RIGHT PARENTHESIS UPPER HOOK'), + '<': ((U('BOX DRAWINGS LIGHT VERTICAL'), + U('BOX DRAWINGS LIGHT DIAGONAL UPPER RIGHT TO LOWER LEFT'), + U('BOX DRAWINGS LIGHT DIAGONAL UPPER LEFT TO LOWER RIGHT')), '<'), + + '>': ((U('BOX DRAWINGS LIGHT VERTICAL'), + U('BOX DRAWINGS LIGHT DIAGONAL UPPER LEFT TO LOWER RIGHT'), + U('BOX DRAWINGS LIGHT DIAGONAL UPPER RIGHT TO LOWER LEFT')), '>'), + + 'lfloor': (( EXT('['), EXT('['), CLO('[') ), U('LEFT FLOOR')), + 'rfloor': (( EXT(']'), EXT(']'), CLO(']') ), U('RIGHT FLOOR')), + 'lceil': (( EXT('['), CUP('['), EXT('[') ), U('LEFT CEILING')), + 'rceil': (( EXT(']'), CUP(']'), EXT(']') ), U('RIGHT CEILING')), + + 'int': (( EXT('int'), U('TOP HALF INTEGRAL'), U('BOTTOM HALF INTEGRAL') ), U('INTEGRAL')), + 'sum': (( U('BOX DRAWINGS LIGHT DIAGONAL UPPER LEFT TO LOWER RIGHT'), '_', U('OVERLINE'), U('BOX DRAWINGS LIGHT DIAGONAL UPPER RIGHT TO LOWER LEFT')), U('N-ARY SUMMATION')), + + # horizontal objects + #'-': '-', + '-': U('BOX DRAWINGS LIGHT HORIZONTAL'), + '_': U('LOW LINE'), + # We used to use this, but LOW LINE looks better for roots, as it's a + # little lower (i.e., it lines up with the / perfectly. But perhaps this + # one would still be wanted for some cases? + # '_': U('HORIZONTAL SCAN LINE-9'), + + # diagonal objects '\' & '/' ? + '/': U('BOX DRAWINGS LIGHT DIAGONAL UPPER RIGHT TO LOWER LEFT'), + '\\': U('BOX DRAWINGS LIGHT DIAGONAL UPPER LEFT TO LOWER RIGHT'), +} + +_xobj_ascii = { + # vertical symbols + # (( ext, top, bot, mid ), c1) + '(': (( '|', '/', '\\' ), '('), + ')': (( '|', '\\', '/' ), ')'), + +# XXX this looks ugly +# '[': (( '|', '-', '-' ), '['), +# ']': (( '|', '-', '-' ), ']'), +# XXX not so ugly :( + '[': (( '[', '[', '[' ), '['), + ']': (( ']', ']', ']' ), ']'), + + '{': (( '|', '/', '\\', '<' ), '{'), + '}': (( '|', '\\', '/', '>' ), '}'), + '|': '|', + + '<': (( '|', '/', '\\' ), '<'), + '>': (( '|', '\\', '/' ), '>'), + + 'int': ( ' | ', ' /', '/ ' ), + + # horizontal objects + '-': '-', + '_': '_', + + # diagonal objects '\' & '/' ? + '/': '/', + '\\': '\\', +} + + +def xobj(symb, length): + """Construct spatial object of given length. + + return: [] of equal-length strings + """ + + if length <= 0: + raise ValueError("Length should be greater than 0") + + # TODO robustify when no unicodedat available + if _use_unicode: + _xobj = _xobj_unicode + else: + _xobj = _xobj_ascii + + vinfo = _xobj[symb] + + c1 = top = bot = mid = None + + if not isinstance(vinfo, tuple): # 1 entry + ext = vinfo + else: + if isinstance(vinfo[0], tuple): # (vlong), c1 + vlong = vinfo[0] + c1 = vinfo[1] + else: # (vlong), c1 + vlong = vinfo + + ext = vlong[0] + + try: + top = vlong[1] + bot = vlong[2] + mid = vlong[3] + except IndexError: + pass + + if c1 is None: + c1 = ext + if top is None: + top = ext + if bot is None: + bot = ext + if mid is not None: + if (length % 2) == 0: + # even height, but we have to print it somehow anyway... + # XXX is it ok? + length += 1 + + else: + mid = ext + + if length == 1: + return c1 + + res = [] + next = (length - 2)//2 + nmid = (length - 2) - next*2 + + res += [top] + res += [ext]*next + res += [mid]*nmid + res += [ext]*next + res += [bot] + + return res + + +def vobj(symb, height): + """Construct vertical object of a given height + + see: xobj + """ + return '\n'.join( xobj(symb, height) ) + + +def hobj(symb, width): + """Construct horizontal object of a given width + + see: xobj + """ + return ''.join( xobj(symb, width) ) + +# RADICAL +# n -> symbol +root = { + 2: U('SQUARE ROOT'), # U('RADICAL SYMBOL BOTTOM') + 3: U('CUBE ROOT'), + 4: U('FOURTH ROOT'), +} + + +# RATIONAL +VF = lambda txt: U('VULGAR FRACTION %s' % txt) + +# (p,q) -> symbol +frac = { + (1, 2): VF('ONE HALF'), + (1, 3): VF('ONE THIRD'), + (2, 3): VF('TWO THIRDS'), + (1, 4): VF('ONE QUARTER'), + (3, 4): VF('THREE QUARTERS'), + (1, 5): VF('ONE FIFTH'), + (2, 5): VF('TWO FIFTHS'), + (3, 5): VF('THREE FIFTHS'), + (4, 5): VF('FOUR FIFTHS'), + (1, 6): VF('ONE SIXTH'), + (5, 6): VF('FIVE SIXTHS'), + (1, 8): VF('ONE EIGHTH'), + (3, 8): VF('THREE EIGHTHS'), + (5, 8): VF('FIVE EIGHTHS'), + (7, 8): VF('SEVEN EIGHTHS'), +} + + +# atom symbols +_xsym = { + '==': ('=', '='), + '<': ('<', '<'), + '>': ('>', '>'), + '<=': ('<=', U('LESS-THAN OR EQUAL TO')), + '>=': ('>=', U('GREATER-THAN OR EQUAL TO')), + '!=': ('!=', U('NOT EQUAL TO')), + ':=': (':=', ':='), + '+=': ('+=', '+='), + '-=': ('-=', '-='), + '*=': ('*=', '*='), + '/=': ('/=', '/='), + '%=': ('%=', '%='), + '*': ('*', U('DOT OPERATOR')), + '-->': ('-->', U('EM DASH') + U('EM DASH') + + U('BLACK RIGHT-POINTING TRIANGLE') if U('EM DASH') + and U('BLACK RIGHT-POINTING TRIANGLE') else None), + '==>': ('==>', U('BOX DRAWINGS DOUBLE HORIZONTAL') + + U('BOX DRAWINGS DOUBLE HORIZONTAL') + + U('BLACK RIGHT-POINTING TRIANGLE') if + U('BOX DRAWINGS DOUBLE HORIZONTAL') and + U('BOX DRAWINGS DOUBLE HORIZONTAL') and + U('BLACK RIGHT-POINTING TRIANGLE') else None), + '.': ('*', U('RING OPERATOR')), +} + + +def xsym(sym): + """get symbology for a 'character'""" + op = _xsym[sym] + + if _use_unicode: + return op[1] + else: + return op[0] + + +# SYMBOLS + +atoms_table = { + # class how-to-display + 'Exp1': U('SCRIPT SMALL E'), + 'Pi': U('GREEK SMALL LETTER PI'), + 'Infinity': U('INFINITY'), + 'NegativeInfinity': U('INFINITY') and ('-' + U('INFINITY')), # XXX what to do here + #'ImaginaryUnit': U('GREEK SMALL LETTER IOTA'), + #'ImaginaryUnit': U('MATHEMATICAL ITALIC SMALL I'), + 'ImaginaryUnit': U('DOUBLE-STRUCK ITALIC SMALL I'), + 'EmptySet': U('EMPTY SET'), + 'Naturals': U('DOUBLE-STRUCK CAPITAL N'), + 'Naturals0': (U('DOUBLE-STRUCK CAPITAL N') and + (U('DOUBLE-STRUCK CAPITAL N') + + U('SUBSCRIPT ZERO'))), + 'Integers': U('DOUBLE-STRUCK CAPITAL Z'), + 'Rationals': U('DOUBLE-STRUCK CAPITAL Q'), + 'Reals': U('DOUBLE-STRUCK CAPITAL R'), + 'Complexes': U('DOUBLE-STRUCK CAPITAL C'), + 'Universe': U('MATHEMATICAL DOUBLE-STRUCK CAPITAL U'), + 'IdentityMatrix': U('MATHEMATICAL DOUBLE-STRUCK CAPITAL I'), + 'ZeroMatrix': U('MATHEMATICAL DOUBLE-STRUCK DIGIT ZERO'), + 'OneMatrix': U('MATHEMATICAL DOUBLE-STRUCK DIGIT ONE'), + 'Differential': U('DOUBLE-STRUCK ITALIC SMALL D'), + 'Union': U('UNION'), + 'ElementOf': U('ELEMENT OF'), + 'SmallElementOf': U('SMALL ELEMENT OF'), + 'SymmetricDifference': U('INCREMENT'), + 'Intersection': U('INTERSECTION'), + 'Ring': U('RING OPERATOR'), + 'Multiplication': U('MULTIPLICATION SIGN'), + 'TensorProduct': U('N-ARY CIRCLED TIMES OPERATOR'), + 'Dots': U('HORIZONTAL ELLIPSIS'), + 'Modifier Letter Low Ring':U('Modifier Letter Low Ring'), + 'EmptySequence': 'EmptySequence', + 'SuperscriptPlus': U('SUPERSCRIPT PLUS SIGN'), + 'SuperscriptMinus': U('SUPERSCRIPT MINUS'), + 'Dagger': U('DAGGER'), + 'Degree': U('DEGREE SIGN'), + #Logic Symbols + 'And': U('LOGICAL AND'), + 'Or': U('LOGICAL OR'), + 'Not': U('NOT SIGN'), + 'Nor': U('NOR'), + 'Nand': U('NAND'), + 'Xor': U('XOR'), + 'Equiv': U('LEFT RIGHT DOUBLE ARROW'), + 'NotEquiv': U('LEFT RIGHT DOUBLE ARROW WITH STROKE'), + 'Implies': U('LEFT RIGHT DOUBLE ARROW'), + 'NotImplies': U('LEFT RIGHT DOUBLE ARROW WITH STROKE'), + 'Arrow': U('RIGHTWARDS ARROW'), + 'ArrowFromBar': U('RIGHTWARDS ARROW FROM BAR'), + 'NotArrow': U('RIGHTWARDS ARROW WITH STROKE'), + 'Tautology': U('BOX DRAWINGS LIGHT UP AND HORIZONTAL'), + 'Contradiction': U('BOX DRAWINGS LIGHT DOWN AND HORIZONTAL') +} + + +def pretty_atom(atom_name, default=None, printer=None): + """return pretty representation of an atom""" + if _use_unicode: + if printer is not None and atom_name == 'ImaginaryUnit' and printer._settings['imaginary_unit'] == 'j': + return U('DOUBLE-STRUCK ITALIC SMALL J') + else: + return atoms_table[atom_name] + else: + if default is not None: + return default + + raise KeyError('only unicode') # send it default printer + + +def pretty_symbol(symb_name, bold_name=False): + """return pretty representation of a symbol""" + # let's split symb_name into symbol + index + # UC: beta1 + # UC: f_beta + + if not _use_unicode: + return symb_name + + name, sups, subs = split_super_sub(symb_name) + + def translate(s, bold_name) : + if bold_name: + gG = greek_bold_unicode.get(s) + else: + gG = greek_unicode.get(s) + if gG is not None: + return gG + for key in sorted(modifier_dict.keys(), key=lambda k:len(k), reverse=True) : + if s.lower().endswith(key) and len(s)>len(key): + return modifier_dict[key](translate(s[:-len(key)], bold_name)) + if bold_name: + return ''.join([bold_unicode[c] for c in s]) + return s + + name = translate(name, bold_name) + + # Let's prettify sups/subs. If it fails at one of them, pretty sups/subs are + # not used at all. + def pretty_list(l, mapping): + result = [] + for s in l: + pretty = mapping.get(s) + if pretty is None: + try: # match by separate characters + pretty = ''.join([mapping[c] for c in s]) + except (TypeError, KeyError): + return None + result.append(pretty) + return result + + pretty_sups = pretty_list(sups, sup) + if pretty_sups is not None: + pretty_subs = pretty_list(subs, sub) + else: + pretty_subs = None + + # glue the results into one string + if pretty_subs is None: # nice formatting of sups/subs did not work + if subs: + name += '_'+'_'.join([translate(s, bold_name) for s in subs]) + if sups: + name += '__'+'__'.join([translate(s, bold_name) for s in sups]) + return name + else: + sups_result = ' '.join(pretty_sups) + subs_result = ' '.join(pretty_subs) + + return ''.join([name, sups_result, subs_result]) + + +def annotated(letter): + """ + Return a stylised drawing of the letter ``letter``, together with + information on how to put annotations (super- and subscripts to the + left and to the right) on it. + + See pretty.py functions _print_meijerg, _print_hyper on how to use this + information. + """ + ucode_pics = { + 'F': (2, 0, 2, 0, '\N{BOX DRAWINGS LIGHT DOWN AND RIGHT}\N{BOX DRAWINGS LIGHT HORIZONTAL}\n' + '\N{BOX DRAWINGS LIGHT VERTICAL AND RIGHT}\N{BOX DRAWINGS LIGHT HORIZONTAL}\n' + '\N{BOX DRAWINGS LIGHT UP}'), + 'G': (3, 0, 3, 1, '\N{BOX DRAWINGS LIGHT ARC DOWN AND RIGHT}\N{BOX DRAWINGS LIGHT HORIZONTAL}\N{BOX DRAWINGS LIGHT ARC DOWN AND LEFT}\n' + '\N{BOX DRAWINGS LIGHT VERTICAL}\N{BOX DRAWINGS LIGHT RIGHT}\N{BOX DRAWINGS LIGHT DOWN AND LEFT}\n' + '\N{BOX DRAWINGS LIGHT ARC UP AND RIGHT}\N{BOX DRAWINGS LIGHT HORIZONTAL}\N{BOX DRAWINGS LIGHT ARC UP AND LEFT}') + } + ascii_pics = { + 'F': (3, 0, 3, 0, ' _\n|_\n|\n'), + 'G': (3, 0, 3, 1, ' __\n/__\n\\_|') + } + + if _use_unicode: + return ucode_pics[letter] + else: + return ascii_pics[letter] + +_remove_combining = dict.fromkeys(list(range(ord('\N{COMBINING GRAVE ACCENT}'), ord('\N{COMBINING LATIN SMALL LETTER X}'))) + + list(range(ord('\N{COMBINING LEFT HARPOON ABOVE}'), ord('\N{COMBINING ASTERISK ABOVE}')))) + +def is_combining(sym): + """Check whether symbol is a unicode modifier. """ + + return ord(sym) in _remove_combining + + +def center_accent(string, accent): + """ + Returns a string with accent inserted on the middle character. Useful to + put combining accents on symbol names, including multi-character names. + + Parameters + ========== + + string : string + The string to place the accent in. + accent : string + The combining accent to insert + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Combining_character + .. [2] https://en.wikipedia.org/wiki/Combining_Diacritical_Marks + + """ + + # Accent is placed on the previous character, although it may not always look + # like that depending on console + midpoint = len(string) // 2 + 1 + firstpart = string[:midpoint] + secondpart = string[midpoint:] + return firstpart + accent + secondpart + + +def line_width(line): + """Unicode combining symbols (modifiers) are not ever displayed as + separate symbols and thus should not be counted + """ + return len(line.translate(_remove_combining)) + + +def is_subscriptable_in_unicode(subscript): + """ + Checks whether a string is subscriptable in unicode or not. + + Parameters + ========== + + subscript: the string which needs to be checked + + Examples + ======== + + >>> from sympy.printing.pretty.pretty_symbology import is_subscriptable_in_unicode + >>> is_subscriptable_in_unicode('abc') + False + >>> is_subscriptable_in_unicode('123') + True + + """ + return all(character in sub for character in subscript) + + +def center_pad(wstring, wtarget, fillchar=' '): + """ + Return the padding strings necessary to center a string of + wstring characters wide in a wtarget wide space. + + The line_width wstring should always be less or equal to wtarget + or else a ValueError will be raised. + """ + if wstring > wtarget: + raise ValueError('not enough space for string') + wdelta = wtarget - wstring + + wleft = wdelta // 2 # favor left '1 ' + wright = wdelta - wleft + + left = fillchar * wleft + right = fillchar * wright + + return left, right + + +def center(string, width, fillchar=' '): + """Return a centered string of length determined by `line_width` + that uses `fillchar` for padding. + """ + left, right = center_pad(line_width(string), width, fillchar) + return ''.join([left, string, right]) diff --git a/lib/python3.10/site-packages/sympy/printing/pretty/stringpict.py b/lib/python3.10/site-packages/sympy/printing/pretty/stringpict.py new file mode 100644 index 0000000000000000000000000000000000000000..b6055f09c83b2abbe0c492991aaee4dff5b34f49 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/printing/pretty/stringpict.py @@ -0,0 +1,537 @@ +"""Prettyprinter by Jurjen Bos. +(I hate spammers: mail me at pietjepuk314 at the reverse of ku.oc.oohay). +All objects have a method that create a "stringPict", +that can be used in the str method for pretty printing. + +Updates by Jason Gedge (email at cs mun ca) + - terminal_string() method + - minor fixes and changes (mostly to prettyForm) + +TODO: + - Allow left/center/right alignment options for above/below and + top/center/bottom alignment options for left/right +""" + +import shutil + +from .pretty_symbology import hobj, vobj, xsym, xobj, pretty_use_unicode, line_width, center +from sympy.utilities.exceptions import sympy_deprecation_warning + +_GLOBAL_WRAP_LINE = None + +class stringPict: + """An ASCII picture. + The pictures are represented as a list of equal length strings. + """ + #special value for stringPict.below + LINE = 'line' + + def __init__(self, s, baseline=0): + """Initialize from string. + Multiline strings are centered. + """ + self.s = s + #picture is a string that just can be printed + self.picture = stringPict.equalLengths(s.splitlines()) + #baseline is the line number of the "base line" + self.baseline = baseline + self.binding = None + + @staticmethod + def equalLengths(lines): + # empty lines + if not lines: + return [''] + + width = max(line_width(line) for line in lines) + return [center(line, width) for line in lines] + + def height(self): + """The height of the picture in characters.""" + return len(self.picture) + + def width(self): + """The width of the picture in characters.""" + return line_width(self.picture[0]) + + @staticmethod + def next(*args): + """Put a string of stringPicts next to each other. + Returns string, baseline arguments for stringPict. + """ + #convert everything to stringPicts + objects = [] + for arg in args: + if isinstance(arg, str): + arg = stringPict(arg) + objects.append(arg) + + #make a list of pictures, with equal height and baseline + newBaseline = max(obj.baseline for obj in objects) + newHeightBelowBaseline = max( + obj.height() - obj.baseline + for obj in objects) + newHeight = newBaseline + newHeightBelowBaseline + + pictures = [] + for obj in objects: + oneEmptyLine = [' '*obj.width()] + basePadding = newBaseline - obj.baseline + totalPadding = newHeight - obj.height() + pictures.append( + oneEmptyLine * basePadding + + obj.picture + + oneEmptyLine * (totalPadding - basePadding)) + + result = [''.join(lines) for lines in zip(*pictures)] + return '\n'.join(result), newBaseline + + def right(self, *args): + r"""Put pictures next to this one. + Returns string, baseline arguments for stringPict. + (Multiline) strings are allowed, and are given a baseline of 0. + + Examples + ======== + + >>> from sympy.printing.pretty.stringpict import stringPict + >>> print(stringPict("10").right(" + ",stringPict("1\r-\r2",1))[0]) + 1 + 10 + - + 2 + + """ + return stringPict.next(self, *args) + + def left(self, *args): + """Put pictures (left to right) at left. + Returns string, baseline arguments for stringPict. + """ + return stringPict.next(*(args + (self,))) + + @staticmethod + def stack(*args): + """Put pictures on top of each other, + from top to bottom. + Returns string, baseline arguments for stringPict. + The baseline is the baseline of the second picture. + Everything is centered. + Baseline is the baseline of the second picture. + Strings are allowed. + The special value stringPict.LINE is a row of '-' extended to the width. + """ + #convert everything to stringPicts; keep LINE + objects = [] + for arg in args: + if arg is not stringPict.LINE and isinstance(arg, str): + arg = stringPict(arg) + objects.append(arg) + + #compute new width + newWidth = max( + obj.width() + for obj in objects + if obj is not stringPict.LINE) + + lineObj = stringPict(hobj('-', newWidth)) + + #replace LINE with proper lines + for i, obj in enumerate(objects): + if obj is stringPict.LINE: + objects[i] = lineObj + + #stack the pictures, and center the result + newPicture = [center(line, newWidth) for obj in objects for line in obj.picture] + newBaseline = objects[0].height() + objects[1].baseline + return '\n'.join(newPicture), newBaseline + + def below(self, *args): + """Put pictures under this picture. + Returns string, baseline arguments for stringPict. + Baseline is baseline of top picture + + Examples + ======== + + >>> from sympy.printing.pretty.stringpict import stringPict + >>> print(stringPict("x+3").below( + ... stringPict.LINE, '3')[0]) #doctest: +NORMALIZE_WHITESPACE + x+3 + --- + 3 + + """ + s, baseline = stringPict.stack(self, *args) + return s, self.baseline + + def above(self, *args): + """Put pictures above this picture. + Returns string, baseline arguments for stringPict. + Baseline is baseline of bottom picture. + """ + string, baseline = stringPict.stack(*(args + (self,))) + baseline = len(string.splitlines()) - self.height() + self.baseline + return string, baseline + + def parens(self, left='(', right=')', ifascii_nougly=False): + """Put parentheses around self. + Returns string, baseline arguments for stringPict. + + left or right can be None or empty string which means 'no paren from + that side' + """ + h = self.height() + b = self.baseline + + # XXX this is a hack -- ascii parens are ugly! + if ifascii_nougly and not pretty_use_unicode(): + h = 1 + b = 0 + + res = self + + if left: + lparen = stringPict(vobj(left, h), baseline=b) + res = stringPict(*lparen.right(self)) + if right: + rparen = stringPict(vobj(right, h), baseline=b) + res = stringPict(*res.right(rparen)) + + return ('\n'.join(res.picture), res.baseline) + + def leftslash(self): + """Precede object by a slash of the proper size. + """ + # XXX not used anywhere ? + height = max( + self.baseline, + self.height() - 1 - self.baseline)*2 + 1 + slash = '\n'.join( + ' '*(height - i - 1) + xobj('/', 1) + ' '*i + for i in range(height) + ) + return self.left(stringPict(slash, height//2)) + + def root(self, n=None): + """Produce a nice root symbol. + Produces ugly results for big n inserts. + """ + # XXX not used anywhere + # XXX duplicate of root drawing in pretty.py + #put line over expression + result = self.above('_'*self.width()) + #construct right half of root symbol + height = self.height() + slash = '\n'.join( + ' ' * (height - i - 1) + '/' + ' ' * i + for i in range(height) + ) + slash = stringPict(slash, height - 1) + #left half of root symbol + if height > 2: + downline = stringPict('\\ \n \\', 1) + else: + downline = stringPict('\\') + #put n on top, as low as possible + if n is not None and n.width() > downline.width(): + downline = downline.left(' '*(n.width() - downline.width())) + downline = downline.above(n) + #build root symbol + root = downline.right(slash) + #glue it on at the proper height + #normally, the root symbel is as high as self + #which is one less than result + #this moves the root symbol one down + #if the root became higher, the baseline has to grow too + root.baseline = result.baseline - result.height() + root.height() + return result.left(root) + + def render(self, * args, **kwargs): + """Return the string form of self. + + Unless the argument line_break is set to False, it will + break the expression in a form that can be printed + on the terminal without being broken up. + """ + if _GLOBAL_WRAP_LINE is not None: + kwargs["wrap_line"] = _GLOBAL_WRAP_LINE + + if kwargs["wrap_line"] is False: + return "\n".join(self.picture) + + if kwargs["num_columns"] is not None: + # Read the argument num_columns if it is not None + ncols = kwargs["num_columns"] + else: + # Attempt to get a terminal width + ncols = self.terminal_width() + + if ncols <= 0: + ncols = 80 + + # If smaller than the terminal width, no need to correct + if self.width() <= ncols: + return type(self.picture[0])(self) + + """ + Break long-lines in a visually pleasing format. + without overflow indicators | with overflow indicators + | 2 2 3 | | 2 2 3 ↪| + |6*x *y + 4*x*y + | |6*x *y + 4*x*y + ↪| + | | | | + | 3 4 4 | |↪ 3 4 4 | + |4*y*x + x + y | |↪ 4*y*x + x + y | + |a*c*e + a*c*f + a*d | |a*c*e + a*c*f + a*d ↪| + |*e + a*d*f + b*c*e | | | + |+ b*c*f + b*d*e + b | |↪ *e + a*d*f + b*c* ↪| + |*d*f | | | + | | |↪ e + b*c*f + b*d*e ↪| + | | | | + | | |↪ + b*d*f | + """ + + overflow_first = "" + if kwargs["use_unicode"] or pretty_use_unicode(): + overflow_start = "\N{RIGHTWARDS ARROW WITH HOOK} " + overflow_end = " \N{RIGHTWARDS ARROW WITH HOOK}" + else: + overflow_start = "> " + overflow_end = " >" + + def chunks(line): + """Yields consecutive chunks of line_width ncols""" + prefix = overflow_first + width, start = line_width(prefix + overflow_end), 0 + for i, x in enumerate(line): + wx = line_width(x) + # Only flush the screen when the current character overflows. + # This way, combining marks can be appended even when width == ncols. + if width + wx > ncols: + yield prefix + line[start:i] + overflow_end + prefix = overflow_start + width, start = line_width(prefix + overflow_end), i + width += wx + yield prefix + line[start:] + + # Concurrently assemble chunks of all lines into individual screens + pictures = zip(*map(chunks, self.picture)) + + # Join lines of each screen into sub-pictures + pictures = ["\n".join(picture) for picture in pictures] + + # Add spacers between sub-pictures + return "\n\n".join(pictures) + + def terminal_width(self): + """Return the terminal width if possible, otherwise return 0. + """ + size = shutil.get_terminal_size(fallback=(0, 0)) + return size.columns + + def __eq__(self, o): + if isinstance(o, str): + return '\n'.join(self.picture) == o + elif isinstance(o, stringPict): + return o.picture == self.picture + return False + + def __hash__(self): + return super().__hash__() + + def __str__(self): + return '\n'.join(self.picture) + + def __repr__(self): + return "stringPict(%r,%d)" % ('\n'.join(self.picture), self.baseline) + + def __getitem__(self, index): + return self.picture[index] + + def __len__(self): + return len(self.s) + + +class prettyForm(stringPict): + """ + Extension of the stringPict class that knows about basic math applications, + optimizing double minus signs. + + "Binding" is interpreted as follows:: + + ATOM this is an atom: never needs to be parenthesized + FUNC this is a function application: parenthesize if added (?) + DIV this is a division: make wider division if divided + POW this is a power: only parenthesize if exponent + MUL this is a multiplication: parenthesize if powered + ADD this is an addition: parenthesize if multiplied or powered + NEG this is a negative number: optimize if added, parenthesize if + multiplied or powered + OPEN this is an open object: parenthesize if added, multiplied, or + powered (example: Piecewise) + """ + ATOM, FUNC, DIV, POW, MUL, ADD, NEG, OPEN = range(8) + + def __init__(self, s, baseline=0, binding=0, unicode=None): + """Initialize from stringPict and binding power.""" + stringPict.__init__(self, s, baseline) + self.binding = binding + if unicode is not None: + sympy_deprecation_warning( + """ + The unicode argument to prettyForm is deprecated. Only the s + argument (the first positional argument) should be passed. + """, + deprecated_since_version="1.7", + active_deprecations_target="deprecated-pretty-printing-functions") + self._unicode = unicode or s + + @property + def unicode(self): + sympy_deprecation_warning( + """ + The prettyForm.unicode attribute is deprecated. Use the + prettyForm.s attribute instead. + """, + deprecated_since_version="1.7", + active_deprecations_target="deprecated-pretty-printing-functions") + return self._unicode + + # Note: code to handle subtraction is in _print_Add + + def __add__(self, *others): + """Make a pretty addition. + Addition of negative numbers is simplified. + """ + arg = self + if arg.binding > prettyForm.NEG: + arg = stringPict(*arg.parens()) + result = [arg] + for arg in others: + #add parentheses for weak binders + if arg.binding > prettyForm.NEG: + arg = stringPict(*arg.parens()) + #use existing minus sign if available + if arg.binding != prettyForm.NEG: + result.append(' + ') + result.append(arg) + return prettyForm(binding=prettyForm.ADD, *stringPict.next(*result)) + + def __truediv__(self, den, slashed=False): + """Make a pretty division; stacked or slashed. + """ + if slashed: + raise NotImplementedError("Can't do slashed fraction yet") + num = self + if num.binding == prettyForm.DIV: + num = stringPict(*num.parens()) + if den.binding == prettyForm.DIV: + den = stringPict(*den.parens()) + + if num.binding==prettyForm.NEG: + num = num.right(" ")[0] + + return prettyForm(binding=prettyForm.DIV, *stringPict.stack( + num, + stringPict.LINE, + den)) + + def __mul__(self, *others): + """Make a pretty multiplication. + Parentheses are needed around +, - and neg. + """ + quantity = { + 'degree': "\N{DEGREE SIGN}" + } + + if len(others) == 0: + return self # We aren't actually multiplying... So nothing to do here. + + # add parens on args that need them + arg = self + if arg.binding > prettyForm.MUL and arg.binding != prettyForm.NEG: + arg = stringPict(*arg.parens()) + result = [arg] + for arg in others: + if arg.picture[0] not in quantity.values(): + result.append(xsym('*')) + #add parentheses for weak binders + if arg.binding > prettyForm.MUL and arg.binding != prettyForm.NEG: + arg = stringPict(*arg.parens()) + result.append(arg) + + len_res = len(result) + for i in range(len_res): + if i < len_res - 1 and result[i] == '-1' and result[i + 1] == xsym('*'): + # substitute -1 by -, like in -1*x -> -x + result.pop(i) + result.pop(i) + result.insert(i, '-') + if result[0][0] == '-': + # if there is a - sign in front of all + # This test was failing to catch a prettyForm.__mul__(prettyForm("-1", 0, 6)) being negative + bin = prettyForm.NEG + if result[0] == '-': + right = result[1] + if right.picture[right.baseline][0] == '-': + result[0] = '- ' + else: + bin = prettyForm.MUL + return prettyForm(binding=bin, *stringPict.next(*result)) + + def __repr__(self): + return "prettyForm(%r,%d,%d)" % ( + '\n'.join(self.picture), + self.baseline, + self.binding) + + def __pow__(self, b): + """Make a pretty power. + """ + a = self + use_inline_func_form = False + if b.binding == prettyForm.POW: + b = stringPict(*b.parens()) + if a.binding > prettyForm.FUNC: + a = stringPict(*a.parens()) + elif a.binding == prettyForm.FUNC: + # heuristic for when to use inline power + if b.height() > 1: + a = stringPict(*a.parens()) + else: + use_inline_func_form = True + + if use_inline_func_form: + # 2 + # sin + + (x) + b.baseline = a.prettyFunc.baseline + b.height() + func = stringPict(*a.prettyFunc.right(b)) + return prettyForm(*func.right(a.prettyArgs)) + else: + # 2 <-- top + # (x+y) <-- bot + top = stringPict(*b.left(' '*a.width())) + bot = stringPict(*a.right(' '*b.width())) + + return prettyForm(binding=prettyForm.POW, *bot.above(top)) + + simpleFunctions = ["sin", "cos", "tan"] + + @staticmethod + def apply(function, *args): + """Functions of one or more variables. + """ + if function in prettyForm.simpleFunctions: + #simple function: use only space if possible + assert len( + args) == 1, "Simple function %s must have 1 argument" % function + arg = args[0].__pretty__() + if arg.binding <= prettyForm.DIV: + #optimization: no parentheses necessary + return prettyForm(binding=prettyForm.FUNC, *arg.left(function + ' ')) + argumentList = [] + for arg in args: + argumentList.append(',') + argumentList.append(arg.__pretty__()) + argumentList = stringPict(*stringPict.next(*argumentList[1:])) + argumentList = stringPict(*argumentList.parens()) + return prettyForm(binding=prettyForm.ATOM, *argumentList.left(function)) diff --git a/lib/python3.10/site-packages/sympy/printing/pretty/tests/__init__.py b/lib/python3.10/site-packages/sympy/printing/pretty/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/lib/python3.10/site-packages/sympy/printing/pretty/tests/__pycache__/__init__.cpython-310.pyc b/lib/python3.10/site-packages/sympy/printing/pretty/tests/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ec99aa99299f8665503e9ddca1de128ce6b59c11 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/printing/pretty/tests/__pycache__/__init__.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/printing/pretty/tests/test_pretty.py b/lib/python3.10/site-packages/sympy/printing/pretty/tests/test_pretty.py new file mode 100644 index 0000000000000000000000000000000000000000..1cca79bd1dc5c3ba81483c8fe2e87c35926d1b94 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/printing/pretty/tests/test_pretty.py @@ -0,0 +1,7972 @@ +# -*- coding: utf-8 -*- +from sympy.concrete.products import Product +from sympy.concrete.summations import Sum +from sympy.core.add import Add +from sympy.core.basic import Basic +from sympy.core.containers import (Dict, Tuple) +from sympy.core.function import (Derivative, Function, Lambda, Subs) +from sympy.core.mul import Mul +from sympy.core import (EulerGamma, GoldenRatio, Catalan) +from sympy.core.numbers import (I, Rational, oo, pi) +from sympy.core.power import Pow +from sympy.core.relational import (Eq, Ge, Gt, Le, Lt, Ne) +from sympy.core.singleton import S +from sympy.core.symbol import (Symbol, symbols) +from sympy.functions.elementary.complexes import conjugate +from sympy.functions.elementary.exponential import LambertW +from sympy.functions.special.bessel import (airyai, airyaiprime, airybi, airybiprime) +from sympy.functions.special.delta_functions import Heaviside +from sympy.functions.special.error_functions import (fresnelc, fresnels) +from sympy.functions.special.singularity_functions import SingularityFunction +from sympy.functions.special.zeta_functions import dirichlet_eta +from sympy.geometry.line import (Ray, Segment) +from sympy.integrals.integrals import Integral +from sympy.logic.boolalg import (And, Equivalent, ITE, Implies, Nand, Nor, Not, Or, Xor) +from sympy.matrices.dense import (Matrix, diag) +from sympy.matrices.expressions.slice import MatrixSlice +from sympy.matrices.expressions.trace import Trace +from sympy.polys.domains.finitefield import FF +from sympy.polys.domains.integerring import ZZ +from sympy.polys.domains.rationalfield import QQ +from sympy.polys.domains.realfield import RR +from sympy.polys.orderings import (grlex, ilex) +from sympy.polys.polytools import groebner +from sympy.polys.rootoftools import (RootSum, rootof) +from sympy.series.formal import fps +from sympy.series.fourier import fourier_series +from sympy.series.limits import Limit +from sympy.series.order import O +from sympy.series.sequences import (SeqAdd, SeqFormula, SeqMul, SeqPer) +from sympy.sets.contains import Contains +from sympy.sets.fancysets import Range +from sympy.sets.sets import (Complement, FiniteSet, Intersection, Interval, Union) +from sympy.codegen.ast import (Assignment, AddAugmentedAssignment, + SubAugmentedAssignment, MulAugmentedAssignment, DivAugmentedAssignment, ModAugmentedAssignment) +from sympy.core.expr import UnevaluatedExpr +from sympy.physics.quantum.trace import Tr + +from sympy.functions import (Abs, Chi, Ci, Ei, KroneckerDelta, + Piecewise, Shi, Si, atan2, beta, binomial, catalan, ceiling, cos, + euler, exp, expint, factorial, factorial2, floor, gamma, hyper, log, + meijerg, sin, sqrt, subfactorial, tan, uppergamma, lerchphi, polylog, + elliptic_k, elliptic_f, elliptic_e, elliptic_pi, DiracDelta, bell, + bernoulli, fibonacci, tribonacci, lucas, stieltjes, mathieuc, mathieus, + mathieusprime, mathieucprime) + +from sympy.matrices import (Adjoint, Inverse, MatrixSymbol, Transpose, + KroneckerProduct, BlockMatrix, OneMatrix, ZeroMatrix) +from sympy.matrices.expressions import hadamard_power + +from sympy.physics import mechanics +from sympy.physics.control.lti import (TransferFunction, Feedback, TransferFunctionMatrix, + Series, Parallel, MIMOSeries, MIMOParallel, MIMOFeedback, StateSpace) +from sympy.physics.units import joule, degree +from sympy.printing.pretty import pprint, pretty as xpretty +from sympy.printing.pretty.pretty_symbology import center_accent, is_combining, center +from sympy.sets.conditionset import ConditionSet + +from sympy.sets import ImageSet, ProductSet +from sympy.sets.setexpr import SetExpr +from sympy.stats.crv_types import Normal +from sympy.stats.symbolic_probability import (Covariance, Expectation, + Probability, Variance) +from sympy.tensor.array import (ImmutableDenseNDimArray, ImmutableSparseNDimArray, + MutableDenseNDimArray, MutableSparseNDimArray, tensorproduct) +from sympy.tensor.functions import TensorProduct +from sympy.tensor.tensor import (TensorIndexType, tensor_indices, TensorHead, + TensorElement, tensor_heads) + +from sympy.testing.pytest import raises, _both_exp_pow, warns_deprecated_sympy + +from sympy.vector import CoordSys3D, Gradient, Curl, Divergence, Dot, Cross, Laplacian + + + +import sympy as sym +class lowergamma(sym.lowergamma): + pass # testing notation inheritance by a subclass with same name + +a, b, c, d, x, y, z, k, n, s, p = symbols('a,b,c,d,x,y,z,k,n,s,p') +f = Function("f") +th = Symbol('theta') +ph = Symbol('phi') + +""" +Expressions whose pretty-printing is tested here: +(A '#' to the right of an expression indicates that its various acceptable +orderings are accounted for by the tests.) + + +BASIC EXPRESSIONS: + +oo +(x**2) +1/x +y*x**-2 +x**Rational(-5,2) +(-2)**x +Pow(3, 1, evaluate=False) +(x**2 + x + 1) # +1-x # +1-2*x # +x/y +-x/y +(x+2)/y # +(1+x)*y #3 +-5*x/(x+10) # correct placement of negative sign +1 - Rational(3,2)*(x+1) +-(-x + 5)*(-x - 2*sqrt(2) + 5) - (-y + 5)*(-y + 5) # issue 5524 + + +ORDERING: + +x**2 + x + 1 +1 - x +1 - 2*x +2*x**4 + y**2 - x**2 + y**3 + + +RELATIONAL: + +Eq(x, y) +Lt(x, y) +Gt(x, y) +Le(x, y) +Ge(x, y) +Ne(x/(y+1), y**2) # + + +RATIONAL NUMBERS: + +y*x**-2 +y**Rational(3,2) * x**Rational(-5,2) +sin(x)**3/tan(x)**2 + + +FUNCTIONS (ABS, CONJ, EXP, FUNCTION BRACES, FACTORIAL, FLOOR, CEILING): + +(2*x + exp(x)) # +Abs(x) +Abs(x/(x**2+1)) # +Abs(1 / (y - Abs(x))) +factorial(n) +factorial(2*n) +subfactorial(n) +subfactorial(2*n) +factorial(factorial(factorial(n))) +factorial(n+1) # +conjugate(x) +conjugate(f(x+1)) # +f(x) +f(x, y) +f(x/(y+1), y) # +f(x**x**x**x**x**x) +sin(x)**2 +conjugate(a+b*I) +conjugate(exp(a+b*I)) +conjugate( f(1 + conjugate(f(x))) ) # +f(x/(y+1), y) # denom of first arg +floor(1 / (y - floor(x))) +ceiling(1 / (y - ceiling(x))) + + +SQRT: + +sqrt(2) +2**Rational(1,3) +2**Rational(1,1000) +sqrt(x**2 + 1) +(1 + sqrt(5))**Rational(1,3) +2**(1/x) +sqrt(2+pi) +(2+(1+x**2)/(2+x))**Rational(1,4)+(1+x**Rational(1,1000))/sqrt(3+x**2) + + +DERIVATIVES: + +Derivative(log(x), x, evaluate=False) +Derivative(log(x), x, evaluate=False) + x # +Derivative(log(x) + x**2, x, y, evaluate=False) +Derivative(2*x*y, y, x, evaluate=False) + x**2 # +beta(alpha).diff(alpha) + + +INTEGRALS: + +Integral(log(x), x) +Integral(x**2, x) +Integral((sin(x))**2 / (tan(x))**2) +Integral(x**(2**x), x) +Integral(x**2, (x,1,2)) +Integral(x**2, (x,Rational(1,2),10)) +Integral(x**2*y**2, x,y) +Integral(x**2, (x, None, 1)) +Integral(x**2, (x, 1, None)) +Integral(sin(th)/cos(ph), (th,0,pi), (ph, 0, 2*pi)) + + +MATRICES: + +Matrix([[x**2+1, 1], [y, x+y]]) # +Matrix([[x/y, y, th], [0, exp(I*k*ph), 1]]) + + +PIECEWISE: + +Piecewise((x,x<1),(x**2,True)) + +ITE: + +ITE(x, y, z) + +SEQUENCES (TUPLES, LISTS, DICTIONARIES): + +() +[] +{} +(1/x,) +[x**2, 1/x, x, y, sin(th)**2/cos(ph)**2] +(x**2, 1/x, x, y, sin(th)**2/cos(ph)**2) +{x: sin(x)} +{1/x: 1/y, x: sin(x)**2} # +[x**2] +(x**2,) +{x**2: 1} + + +LIMITS: + +Limit(x, x, oo) +Limit(x**2, x, 0) +Limit(1/x, x, 0) +Limit(sin(x)/x, x, 0) + + +UNITS: + +joule => kg*m**2/s + + +SUBS: + +Subs(f(x), x, ph**2) +Subs(f(x).diff(x), x, 0) +Subs(f(x).diff(x)/y, (x, y), (0, Rational(1, 2))) + + +ORDER: + +O(1) +O(1/x) +O(x**2 + y**2) + +""" + + +def pretty(expr, order=None): + """ASCII pretty-printing""" + return xpretty(expr, order=order, use_unicode=False, wrap_line=False) + + +def upretty(expr, order=None): + """Unicode pretty-printing""" + return xpretty(expr, order=order, use_unicode=True, wrap_line=False) + + +def test_pretty_ascii_str(): + assert pretty( 'xxx' ) == 'xxx' + assert pretty( "xxx" ) == 'xxx' + assert pretty( 'xxx\'xxx' ) == 'xxx\'xxx' + assert pretty( 'xxx"xxx' ) == 'xxx\"xxx' + assert pretty( 'xxx\"xxx' ) == 'xxx\"xxx' + assert pretty( "xxx'xxx" ) == 'xxx\'xxx' + assert pretty( "xxx\'xxx" ) == 'xxx\'xxx' + assert pretty( "xxx\"xxx" ) == 'xxx\"xxx' + assert pretty( "xxx\"xxx\'xxx" ) == 'xxx"xxx\'xxx' + assert pretty( "xxx\nxxx" ) == 'xxx\nxxx' + + +def test_pretty_unicode_str(): + assert pretty( 'xxx' ) == 'xxx' + assert pretty( 'xxx' ) == 'xxx' + assert pretty( 'xxx\'xxx' ) == 'xxx\'xxx' + assert pretty( 'xxx"xxx' ) == 'xxx\"xxx' + assert pretty( 'xxx\"xxx' ) == 'xxx\"xxx' + assert pretty( "xxx'xxx" ) == 'xxx\'xxx' + assert pretty( "xxx\'xxx" ) == 'xxx\'xxx' + assert pretty( "xxx\"xxx" ) == 'xxx\"xxx' + assert pretty( "xxx\"xxx\'xxx" ) == 'xxx"xxx\'xxx' + assert pretty( "xxx\nxxx" ) == 'xxx\nxxx' + + +def test_upretty_greek(): + assert upretty( oo ) == '∞' + assert upretty( Symbol('alpha^+_1') ) == 'α⁺₁' + assert upretty( Symbol('beta') ) == 'β' + assert upretty(Symbol('lambda')) == 'λ' + + +def test_upretty_multiindex(): + assert upretty( Symbol('beta12') ) == 'β₁₂' + assert upretty( Symbol('Y00') ) == 'Y₀₀' + assert upretty( Symbol('Y_00') ) == 'Y₀₀' + assert upretty( Symbol('F^+-') ) == 'F⁺⁻' + + +def test_upretty_sub_super(): + assert upretty( Symbol('beta_1_2') ) == 'β₁ ₂' + assert upretty( Symbol('beta^1^2') ) == 'β¹ ²' + assert upretty( Symbol('beta_1^2') ) == 'β²₁' + assert upretty( Symbol('beta_10_20') ) == 'β₁₀ ₂₀' + assert upretty( Symbol('beta_ax_gamma^i') ) == 'βⁱₐₓ ᵧ' + assert upretty( Symbol("F^1^2_3_4") ) == 'F¹ ²₃ ₄' + assert upretty( Symbol("F_1_2^3^4") ) == 'F³ ⁴₁ ₂' + assert upretty( Symbol("F_1_2_3_4") ) == 'F₁ ₂ ₃ ₄' + assert upretty( Symbol("F^1^2^3^4") ) == 'F¹ ² ³ ⁴' + + +def test_upretty_subs_missing_in_24(): + assert upretty( Symbol('F_beta') ) == 'Fᵦ' + assert upretty( Symbol('F_gamma') ) == 'Fᵧ' + assert upretty( Symbol('F_rho') ) == 'Fᵨ' + assert upretty( Symbol('F_phi') ) == 'Fᵩ' + assert upretty( Symbol('F_chi') ) == 'Fᵪ' + + assert upretty( Symbol('F_a') ) == 'Fₐ' + assert upretty( Symbol('F_e') ) == 'Fₑ' + assert upretty( Symbol('F_i') ) == 'Fᵢ' + assert upretty( Symbol('F_o') ) == 'Fₒ' + assert upretty( Symbol('F_u') ) == 'Fᵤ' + assert upretty( Symbol('F_r') ) == 'Fᵣ' + assert upretty( Symbol('F_v') ) == 'Fᵥ' + assert upretty( Symbol('F_x') ) == 'Fₓ' + + +def test_missing_in_2X_issue_9047(): + assert upretty( Symbol('F_h') ) == 'Fₕ' + assert upretty( Symbol('F_k') ) == 'Fₖ' + assert upretty( Symbol('F_l') ) == 'Fₗ' + assert upretty( Symbol('F_m') ) == 'Fₘ' + assert upretty( Symbol('F_n') ) == 'Fₙ' + assert upretty( Symbol('F_p') ) == 'Fₚ' + assert upretty( Symbol('F_s') ) == 'Fₛ' + assert upretty( Symbol('F_t') ) == 'Fₜ' + + +def test_upretty_modifiers(): + # Accents + assert upretty( Symbol('Fmathring') ) == 'F̊' + assert upretty( Symbol('Fddddot') ) == 'F⃜' + assert upretty( Symbol('Fdddot') ) == 'F⃛' + assert upretty( Symbol('Fddot') ) == 'F̈' + assert upretty( Symbol('Fdot') ) == 'Ḟ' + assert upretty( Symbol('Fcheck') ) == 'F̌' + assert upretty( Symbol('Fbreve') ) == 'F̆' + assert upretty( Symbol('Facute') ) == 'F́' + assert upretty( Symbol('Fgrave') ) == 'F̀' + assert upretty( Symbol('Ftilde') ) == 'F̃' + assert upretty( Symbol('Fhat') ) == 'F̂' + assert upretty( Symbol('Fbar') ) == 'F̅' + assert upretty( Symbol('Fvec') ) == 'F⃗' + assert upretty( Symbol('Fprime') ) == 'F′' + assert upretty( Symbol('Fprm') ) == 'F′' + # No faces are actually implemented, but test to make sure the modifiers are stripped + assert upretty( Symbol('Fbold') ) == 'Fbold' + assert upretty( Symbol('Fbm') ) == 'Fbm' + assert upretty( Symbol('Fcal') ) == 'Fcal' + assert upretty( Symbol('Fscr') ) == 'Fscr' + assert upretty( Symbol('Ffrak') ) == 'Ffrak' + # Brackets + assert upretty( Symbol('Fnorm') ) == '‖F‖' + assert upretty( Symbol('Favg') ) == '⟨F⟩' + assert upretty( Symbol('Fabs') ) == '|F|' + assert upretty( Symbol('Fmag') ) == '|F|' + # Combinations + assert upretty( Symbol('xvecdot') ) == 'x⃗̇' + assert upretty( Symbol('xDotVec') ) == 'ẋ⃗' + assert upretty( Symbol('xHATNorm') ) == '‖x̂‖' + assert upretty( Symbol('xMathring_yCheckPRM__zbreveAbs') ) == 'x̊_y̌′__|z̆|' + assert upretty( Symbol('alphadothat_nVECDOT__tTildePrime') ) == 'α̇̂_n⃗̇__t̃′' + assert upretty( Symbol('x_dot') ) == 'x_dot' + assert upretty( Symbol('x__dot') ) == 'x__dot' + + +def test_pretty_Cycle(): + from sympy.combinatorics.permutations import Cycle + assert pretty(Cycle(1, 2)) == '(1 2)' + assert pretty(Cycle(2)) == '(2)' + assert pretty(Cycle(1, 3)(4, 5)) == '(1 3)(4 5)' + assert pretty(Cycle()) == '()' + + +def test_pretty_Permutation(): + from sympy.combinatorics.permutations import Permutation + p1 = Permutation(1, 2)(3, 4) + assert xpretty(p1, perm_cyclic=True, use_unicode=True) == "(1 2)(3 4)" + assert xpretty(p1, perm_cyclic=True, use_unicode=False) == "(1 2)(3 4)" + assert xpretty(p1, perm_cyclic=False, use_unicode=True) == \ + '⎛0 1 2 3 4⎞\n'\ + '⎝0 2 1 4 3⎠' + assert xpretty(p1, perm_cyclic=False, use_unicode=False) == \ + "/0 1 2 3 4\\\n"\ + "\\0 2 1 4 3/" + + with warns_deprecated_sympy(): + old_print_cyclic = Permutation.print_cyclic + Permutation.print_cyclic = False + assert xpretty(p1, use_unicode=True) == \ + '⎛0 1 2 3 4⎞\n'\ + '⎝0 2 1 4 3⎠' + assert xpretty(p1, use_unicode=False) == \ + "/0 1 2 3 4\\\n"\ + "\\0 2 1 4 3/" + Permutation.print_cyclic = old_print_cyclic + + +def test_pretty_basic(): + assert pretty( -Rational(1)/2 ) == '-1/2' + assert pretty( -Rational(13)/22 ) == \ +"""\ +-13 \n\ +----\n\ + 22 \ +""" + expr = oo + ascii_str = \ +"""\ +oo\ +""" + ucode_str = \ +"""\ +∞\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = (x**2) + ascii_str = \ +"""\ + 2\n\ +x \ +""" + ucode_str = \ +"""\ + 2\n\ +x \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = 1/x + ascii_str = \ +"""\ +1\n\ +-\n\ +x\ +""" + ucode_str = \ +"""\ +1\n\ +─\n\ +x\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + # not the same as 1/x + expr = x**-1.0 + ascii_str = \ +"""\ + -1.0\n\ +x \ +""" + ucode_str = \ +"""\ + -1.0\n\ +x \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + # see issue #2860 + expr = Pow(S(2), -1.0, evaluate=False) + ascii_str = \ +"""\ + -1.0\n\ +2 \ +""" + ucode_str = \ +"""\ + -1.0\n\ +2 \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = y*x**-2 + ascii_str = \ +"""\ +y \n\ +--\n\ + 2\n\ +x \ +""" + ucode_str = \ +"""\ +y \n\ +──\n\ + 2\n\ +x \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + #see issue #14033 + expr = x**Rational(1, 3) + ascii_str = \ +"""\ + 1/3\n\ +x \ +""" + ucode_str = \ +"""\ + 1/3\n\ +x \ +""" + assert xpretty(expr, use_unicode=False, wrap_line=False,\ + root_notation = False) == ascii_str + assert xpretty(expr, use_unicode=True, wrap_line=False,\ + root_notation = False) == ucode_str + + expr = x**Rational(-5, 2) + ascii_str = \ +"""\ + 1 \n\ +----\n\ + 5/2\n\ +x \ +""" + ucode_str = \ +"""\ + 1 \n\ +────\n\ + 5/2\n\ +x \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = (-2)**x + ascii_str = \ +"""\ + x\n\ +(-2) \ +""" + ucode_str = \ +"""\ + x\n\ +(-2) \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + # See issue 4923 + expr = Pow(3, 1, evaluate=False) + ascii_str = \ +"""\ + 1\n\ +3 \ +""" + ucode_str = \ +"""\ + 1\n\ +3 \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = (x**2 + x + 1) + ascii_str_1 = \ +"""\ + 2\n\ +1 + x + x \ +""" + ascii_str_2 = \ +"""\ + 2 \n\ +x + x + 1\ +""" + ascii_str_3 = \ +"""\ + 2 \n\ +x + 1 + x\ +""" + ucode_str_1 = \ +"""\ + 2\n\ +1 + x + x \ +""" + ucode_str_2 = \ +"""\ + 2 \n\ +x + x + 1\ +""" + ucode_str_3 = \ +"""\ + 2 \n\ +x + 1 + x\ +""" + assert pretty(expr) in [ascii_str_1, ascii_str_2, ascii_str_3] + assert upretty(expr) in [ucode_str_1, ucode_str_2, ucode_str_3] + + expr = 1 - x + ascii_str_1 = \ +"""\ +1 - x\ +""" + ascii_str_2 = \ +"""\ +-x + 1\ +""" + ucode_str_1 = \ +"""\ +1 - x\ +""" + ucode_str_2 = \ +"""\ +-x + 1\ +""" + assert pretty(expr) in [ascii_str_1, ascii_str_2] + assert upretty(expr) in [ucode_str_1, ucode_str_2] + + expr = 1 - 2*x + ascii_str_1 = \ +"""\ +1 - 2*x\ +""" + ascii_str_2 = \ +"""\ +-2*x + 1\ +""" + ucode_str_1 = \ +"""\ +1 - 2⋅x\ +""" + ucode_str_2 = \ +"""\ +-2⋅x + 1\ +""" + assert pretty(expr) in [ascii_str_1, ascii_str_2] + assert upretty(expr) in [ucode_str_1, ucode_str_2] + + expr = x/y + ascii_str = \ +"""\ +x\n\ +-\n\ +y\ +""" + ucode_str = \ +"""\ +x\n\ +─\n\ +y\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = -x/y + ascii_str = \ +"""\ +-x \n\ +---\n\ + y \ +""" + ucode_str = \ +"""\ +-x \n\ +───\n\ + y \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = (x + 2)/y + ascii_str_1 = \ +"""\ +2 + x\n\ +-----\n\ + y \ +""" + ascii_str_2 = \ +"""\ +x + 2\n\ +-----\n\ + y \ +""" + ucode_str_1 = \ +"""\ +2 + x\n\ +─────\n\ + y \ +""" + ucode_str_2 = \ +"""\ +x + 2\n\ +─────\n\ + y \ +""" + assert pretty(expr) in [ascii_str_1, ascii_str_2] + assert upretty(expr) in [ucode_str_1, ucode_str_2] + + expr = (1 + x)*y + ascii_str_1 = \ +"""\ +y*(1 + x)\ +""" + ascii_str_2 = \ +"""\ +(1 + x)*y\ +""" + ascii_str_3 = \ +"""\ +y*(x + 1)\ +""" + ucode_str_1 = \ +"""\ +y⋅(1 + x)\ +""" + ucode_str_2 = \ +"""\ +(1 + x)⋅y\ +""" + ucode_str_3 = \ +"""\ +y⋅(x + 1)\ +""" + assert pretty(expr) in [ascii_str_1, ascii_str_2, ascii_str_3] + assert upretty(expr) in [ucode_str_1, ucode_str_2, ucode_str_3] + + # Test for correct placement of the negative sign + expr = -5*x/(x + 10) + ascii_str_1 = \ +"""\ +-5*x \n\ +------\n\ +10 + x\ +""" + ascii_str_2 = \ +"""\ +-5*x \n\ +------\n\ +x + 10\ +""" + ucode_str_1 = \ +"""\ +-5⋅x \n\ +──────\n\ +10 + x\ +""" + ucode_str_2 = \ +"""\ +-5⋅x \n\ +──────\n\ +x + 10\ +""" + assert pretty(expr) in [ascii_str_1, ascii_str_2] + assert upretty(expr) in [ucode_str_1, ucode_str_2] + + expr = -S.Half - 3*x + ascii_str = \ +"""\ +-3*x - 1/2\ +""" + ucode_str = \ +"""\ +-3⋅x - 1/2\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = S.Half - 3*x + ascii_str = \ +"""\ +1/2 - 3*x\ +""" + ucode_str = \ +"""\ +1/2 - 3⋅x\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = -S.Half - 3*x/2 + ascii_str = \ +"""\ + 3*x 1\n\ +- --- - -\n\ + 2 2\ +""" + ucode_str = \ +"""\ + 3⋅x 1\n\ +- ─── - ─\n\ + 2 2\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = S.Half - 3*x/2 + ascii_str = \ +"""\ +1 3*x\n\ +- - ---\n\ +2 2 \ +""" + ucode_str = \ +"""\ +1 3⋅x\n\ +─ - ───\n\ +2 2 \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + +def test_negative_fractions(): + expr = -x/y + ascii_str =\ +"""\ +-x \n\ +---\n\ + y \ +""" + ucode_str =\ +"""\ +-x \n\ +───\n\ + y \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + expr = -x*z/y + ascii_str =\ +"""\ +-x*z \n\ +-----\n\ + y \ +""" + ucode_str =\ +"""\ +-x⋅z \n\ +─────\n\ + y \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + expr = x**2/y + ascii_str =\ +"""\ + 2\n\ +x \n\ +--\n\ +y \ +""" + ucode_str =\ +"""\ + 2\n\ +x \n\ +──\n\ +y \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + expr = -x**2/y + ascii_str =\ +"""\ + 2 \n\ +-x \n\ +----\n\ + y \ +""" + ucode_str =\ +"""\ + 2 \n\ +-x \n\ +────\n\ + y \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + expr = -x/(y*z) + ascii_str =\ +"""\ +-x \n\ +---\n\ +y*z\ +""" + ucode_str =\ +"""\ +-x \n\ +───\n\ +y⋅z\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + expr = -a/y**2 + ascii_str =\ +"""\ +-a \n\ +---\n\ + 2 \n\ +y \ +""" + ucode_str =\ +"""\ +-a \n\ +───\n\ + 2 \n\ +y \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + expr = y**(-a/b) + ascii_str =\ +"""\ + -a \n\ + ---\n\ + b \n\ +y \ +""" + ucode_str =\ +"""\ + -a \n\ + ───\n\ + b \n\ +y \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + expr = -1/y**2 + ascii_str =\ +"""\ +-1 \n\ +---\n\ + 2 \n\ +y \ +""" + ucode_str =\ +"""\ +-1 \n\ +───\n\ + 2 \n\ +y \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + expr = -10/b**2 + ascii_str =\ +"""\ +-10 \n\ +----\n\ + 2 \n\ + b \ +""" + ucode_str =\ +"""\ +-10 \n\ +────\n\ + 2 \n\ + b \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + expr = Rational(-200, 37) + ascii_str =\ +"""\ +-200 \n\ +-----\n\ + 37 \ +""" + ucode_str =\ +"""\ +-200 \n\ +─────\n\ + 37 \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + +def test_Mul(): + expr = Mul(0, 1, evaluate=False) + assert pretty(expr) == "0*1" + assert upretty(expr) == "0⋅1" + expr = Mul(1, 0, evaluate=False) + assert pretty(expr) == "1*0" + assert upretty(expr) == "1⋅0" + expr = Mul(1, 1, evaluate=False) + assert pretty(expr) == "1*1" + assert upretty(expr) == "1⋅1" + expr = Mul(1, 1, 1, evaluate=False) + assert pretty(expr) == "1*1*1" + assert upretty(expr) == "1⋅1⋅1" + expr = Mul(1, 2, evaluate=False) + assert pretty(expr) == "1*2" + assert upretty(expr) == "1⋅2" + expr = Add(0, 1, evaluate=False) + assert pretty(expr) == "0 + 1" + assert upretty(expr) == "0 + 1" + expr = Mul(1, 1, 2, evaluate=False) + assert pretty(expr) == "1*1*2" + assert upretty(expr) == "1⋅1⋅2" + expr = Add(0, 0, 1, evaluate=False) + assert pretty(expr) == "0 + 0 + 1" + assert upretty(expr) == "0 + 0 + 1" + expr = Mul(1, -1, evaluate=False) + assert pretty(expr) == "1*-1" + assert upretty(expr) == "1⋅-1" + expr = Mul(1.0, x, evaluate=False) + assert pretty(expr) == "1.0*x" + assert upretty(expr) == "1.0⋅x" + expr = Mul(1, 1, 2, 3, x, evaluate=False) + assert pretty(expr) == "1*1*2*3*x" + assert upretty(expr) == "1⋅1⋅2⋅3⋅x" + expr = Mul(-1, 1, evaluate=False) + assert pretty(expr) == "-1*1" + assert upretty(expr) == "-1⋅1" + expr = Mul(4, 3, 2, 1, 0, y, x, evaluate=False) + assert pretty(expr) == "4*3*2*1*0*y*x" + assert upretty(expr) == "4⋅3⋅2⋅1⋅0⋅y⋅x" + expr = Mul(4, 3, 2, 1+z, 0, y, x, evaluate=False) + assert pretty(expr) == "4*3*2*(z + 1)*0*y*x" + assert upretty(expr) == "4⋅3⋅2⋅(z + 1)⋅0⋅y⋅x" + expr = Mul(Rational(2, 3), Rational(5, 7), evaluate=False) + assert pretty(expr) == "2/3*5/7" + assert upretty(expr) == "2/3⋅5/7" + expr = Mul(x + y, Rational(1, 2), evaluate=False) + assert pretty(expr) == "(x + y)*1/2" + assert upretty(expr) == "(x + y)⋅1/2" + expr = Mul(Rational(1, 2), x + y, evaluate=False) + assert pretty(expr) == "x + y\n-----\n 2 " + assert upretty(expr) == "x + y\n─────\n 2 " + expr = Mul(S.One, x + y, evaluate=False) + assert pretty(expr) == "1*(x + y)" + assert upretty(expr) == "1⋅(x + y)" + expr = Mul(x - y, S.One, evaluate=False) + assert pretty(expr) == "(x - y)*1" + assert upretty(expr) == "(x - y)⋅1" + expr = Mul(Rational(1, 2), x - y, S.One, x + y, evaluate=False) + assert pretty(expr) == "1/2*(x - y)*1*(x + y)" + assert upretty(expr) == "1/2⋅(x - y)⋅1⋅(x + y)" + expr = Mul(x + y, Rational(3, 4), S.One, y - z, evaluate=False) + assert pretty(expr) == "(x + y)*3/4*1*(y - z)" + assert upretty(expr) == "(x + y)⋅3/4⋅1⋅(y - z)" + expr = Mul(x + y, Rational(1, 1), Rational(3, 4), Rational(5, 6),evaluate=False) + assert pretty(expr) == "(x + y)*1*3/4*5/6" + assert upretty(expr) == "(x + y)⋅1⋅3/4⋅5/6" + expr = Mul(Rational(3, 4), x + y, S.One, y - z, evaluate=False) + assert pretty(expr) == "3/4*(x + y)*1*(y - z)" + assert upretty(expr) == "3/4⋅(x + y)⋅1⋅(y - z)" + + +def test_issue_5524(): + assert pretty(-(-x + 5)*(-x - 2*sqrt(2) + 5) - (-y + 5)*(-y + 5)) == \ +"""\ + 2 / ___ \\\n\ +- (5 - y) + (x - 5)*\\-x - 2*\\/ 2 + 5/\ +""" + + assert upretty(-(-x + 5)*(-x - 2*sqrt(2) + 5) - (-y + 5)*(-y + 5)) == \ +"""\ + 2 \n\ +- (5 - y) + (x - 5)⋅(-x - 2⋅√2 + 5)\ +""" + + +def test_pretty_ordering(): + assert pretty(x**2 + x + 1, order='lex') == \ +"""\ + 2 \n\ +x + x + 1\ +""" + assert pretty(x**2 + x + 1, order='rev-lex') == \ +"""\ + 2\n\ +1 + x + x \ +""" + assert pretty(1 - x, order='lex') == '-x + 1' + assert pretty(1 - x, order='rev-lex') == '1 - x' + + assert pretty(1 - 2*x, order='lex') == '-2*x + 1' + assert pretty(1 - 2*x, order='rev-lex') == '1 - 2*x' + + f = 2*x**4 + y**2 - x**2 + y**3 + assert pretty(f, order=None) == \ +"""\ + 4 2 3 2\n\ +2*x - x + y + y \ +""" + assert pretty(f, order='lex') == \ +"""\ + 4 2 3 2\n\ +2*x - x + y + y \ +""" + assert pretty(f, order='rev-lex') == \ +"""\ + 2 3 2 4\n\ +y + y - x + 2*x \ +""" + + expr = x - x**3/6 + x**5/120 + O(x**6) + ascii_str = \ +"""\ + 3 5 \n\ + x x / 6\\\n\ +x - -- + --- + O\\x /\n\ + 6 120 \ +""" + ucode_str = \ +"""\ + 3 5 \n\ + x x ⎛ 6⎞\n\ +x - ── + ─── + O⎝x ⎠\n\ + 6 120 \ +""" + assert pretty(expr, order=None) == ascii_str + assert upretty(expr, order=None) == ucode_str + + assert pretty(expr, order='lex') == ascii_str + assert upretty(expr, order='lex') == ucode_str + + assert pretty(expr, order='rev-lex') == ascii_str + assert upretty(expr, order='rev-lex') == ucode_str + + +def test_EulerGamma(): + assert pretty(EulerGamma) == str(EulerGamma) == "EulerGamma" + assert upretty(EulerGamma) == "γ" + + +def test_GoldenRatio(): + assert pretty(GoldenRatio) == str(GoldenRatio) == "GoldenRatio" + assert upretty(GoldenRatio) == "φ" + + +def test_Catalan(): + assert pretty(Catalan) == upretty(Catalan) == "G" + + +def test_pretty_relational(): + expr = Eq(x, y) + ascii_str = \ +"""\ +x = y\ +""" + ucode_str = \ +"""\ +x = y\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = Lt(x, y) + ascii_str = \ +"""\ +x < y\ +""" + ucode_str = \ +"""\ +x < y\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = Gt(x, y) + ascii_str = \ +"""\ +x > y\ +""" + ucode_str = \ +"""\ +x > y\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = Le(x, y) + ascii_str = \ +"""\ +x <= y\ +""" + ucode_str = \ +"""\ +x ≤ y\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = Ge(x, y) + ascii_str = \ +"""\ +x >= y\ +""" + ucode_str = \ +"""\ +x ≥ y\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = Ne(x/(y + 1), y**2) + ascii_str_1 = \ +"""\ + x 2\n\ +----- != y \n\ +1 + y \ +""" + ascii_str_2 = \ +"""\ + x 2\n\ +----- != y \n\ +y + 1 \ +""" + ucode_str_1 = \ +"""\ + x 2\n\ +───── ≠ y \n\ +1 + y \ +""" + ucode_str_2 = \ +"""\ + x 2\n\ +───── ≠ y \n\ +y + 1 \ +""" + assert pretty(expr) in [ascii_str_1, ascii_str_2] + assert upretty(expr) in [ucode_str_1, ucode_str_2] + + +def test_Assignment(): + expr = Assignment(x, y) + ascii_str = \ +"""\ +x := y\ +""" + ucode_str = \ +"""\ +x := y\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + +def test_AugmentedAssignment(): + expr = AddAugmentedAssignment(x, y) + ascii_str = \ +"""\ +x += y\ +""" + ucode_str = \ +"""\ +x += y\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = SubAugmentedAssignment(x, y) + ascii_str = \ +"""\ +x -= y\ +""" + ucode_str = \ +"""\ +x -= y\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = MulAugmentedAssignment(x, y) + ascii_str = \ +"""\ +x *= y\ +""" + ucode_str = \ +"""\ +x *= y\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = DivAugmentedAssignment(x, y) + ascii_str = \ +"""\ +x /= y\ +""" + ucode_str = \ +"""\ +x /= y\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = ModAugmentedAssignment(x, y) + ascii_str = \ +"""\ +x %= y\ +""" + ucode_str = \ +"""\ +x %= y\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + +def test_pretty_rational(): + expr = y*x**-2 + ascii_str = \ +"""\ +y \n\ +--\n\ + 2\n\ +x \ +""" + ucode_str = \ +"""\ +y \n\ +──\n\ + 2\n\ +x \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = y**Rational(3, 2) * x**Rational(-5, 2) + ascii_str = \ +"""\ + 3/2\n\ +y \n\ +----\n\ + 5/2\n\ +x \ +""" + ucode_str = \ +"""\ + 3/2\n\ +y \n\ +────\n\ + 5/2\n\ +x \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = sin(x)**3/tan(x)**2 + ascii_str = \ +"""\ + 3 \n\ +sin (x)\n\ +-------\n\ + 2 \n\ +tan (x)\ +""" + ucode_str = \ +"""\ + 3 \n\ +sin (x)\n\ +───────\n\ + 2 \n\ +tan (x)\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + +@_both_exp_pow +def test_pretty_functions(): + """Tests for Abs, conjugate, exp, function braces, and factorial.""" + expr = (2*x + exp(x)) + ascii_str_1 = \ +"""\ + x\n\ +2*x + e \ +""" + ascii_str_2 = \ +"""\ + x \n\ +e + 2*x\ +""" + ucode_str_1 = \ +"""\ + x\n\ +2⋅x + ℯ \ +""" + ucode_str_2 = \ +"""\ + x \n\ +ℯ + 2⋅x\ +""" + ucode_str_3 = \ +"""\ + x \n\ +ℯ + 2⋅x\ +""" + assert pretty(expr) in [ascii_str_1, ascii_str_2] + assert upretty(expr) in [ucode_str_1, ucode_str_2, ucode_str_3] + + expr = Abs(x) + ascii_str = \ +"""\ +|x|\ +""" + ucode_str = \ +"""\ +│x│\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = Abs(x/(x**2 + 1)) + ascii_str_1 = \ +"""\ +| x |\n\ +|------|\n\ +| 2|\n\ +|1 + x |\ +""" + ascii_str_2 = \ +"""\ +| x |\n\ +|------|\n\ +| 2 |\n\ +|x + 1|\ +""" + ucode_str_1 = \ +"""\ +│ x │\n\ +│──────│\n\ +│ 2│\n\ +│1 + x │\ +""" + ucode_str_2 = \ +"""\ +│ x │\n\ +│──────│\n\ +│ 2 │\n\ +│x + 1│\ +""" + assert pretty(expr) in [ascii_str_1, ascii_str_2] + assert upretty(expr) in [ucode_str_1, ucode_str_2] + + expr = Abs(1 / (y - Abs(x))) + ascii_str = \ +"""\ + 1 \n\ +---------\n\ +|y - |x||\ +""" + ucode_str = \ +"""\ + 1 \n\ +─────────\n\ +│y - │x││\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + n = Symbol('n', integer=True) + expr = factorial(n) + ascii_str = \ +"""\ +n!\ +""" + ucode_str = \ +"""\ +n!\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = factorial(2*n) + ascii_str = \ +"""\ +(2*n)!\ +""" + ucode_str = \ +"""\ +(2⋅n)!\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = factorial(factorial(factorial(n))) + ascii_str = \ +"""\ +((n!)!)!\ +""" + ucode_str = \ +"""\ +((n!)!)!\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = factorial(n + 1) + ascii_str_1 = \ +"""\ +(1 + n)!\ +""" + ascii_str_2 = \ +"""\ +(n + 1)!\ +""" + ucode_str_1 = \ +"""\ +(1 + n)!\ +""" + ucode_str_2 = \ +"""\ +(n + 1)!\ +""" + + assert pretty(expr) in [ascii_str_1, ascii_str_2] + assert upretty(expr) in [ucode_str_1, ucode_str_2] + + expr = subfactorial(n) + ascii_str = \ +"""\ +!n\ +""" + ucode_str = \ +"""\ +!n\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = subfactorial(2*n) + ascii_str = \ +"""\ +!(2*n)\ +""" + ucode_str = \ +"""\ +!(2⋅n)\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + n = Symbol('n', integer=True) + expr = factorial2(n) + ascii_str = \ +"""\ +n!!\ +""" + ucode_str = \ +"""\ +n!!\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = factorial2(2*n) + ascii_str = \ +"""\ +(2*n)!!\ +""" + ucode_str = \ +"""\ +(2⋅n)!!\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = factorial2(factorial2(factorial2(n))) + ascii_str = \ +"""\ +((n!!)!!)!!\ +""" + ucode_str = \ +"""\ +((n!!)!!)!!\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = factorial2(n + 1) + ascii_str_1 = \ +"""\ +(1 + n)!!\ +""" + ascii_str_2 = \ +"""\ +(n + 1)!!\ +""" + ucode_str_1 = \ +"""\ +(1 + n)!!\ +""" + ucode_str_2 = \ +"""\ +(n + 1)!!\ +""" + + assert pretty(expr) in [ascii_str_1, ascii_str_2] + assert upretty(expr) in [ucode_str_1, ucode_str_2] + + expr = 2*binomial(n, k) + ascii_str = \ +"""\ + /n\\\n\ +2*| |\n\ + \\k/\ +""" + ucode_str = \ +"""\ + ⎛n⎞\n\ +2⋅⎜ ⎟\n\ + ⎝k⎠\ +""" + + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = 2*binomial(2*n, k) + ascii_str = \ +"""\ + /2*n\\\n\ +2*| |\n\ + \\ k /\ +""" + ucode_str = \ +"""\ + ⎛2⋅n⎞\n\ +2⋅⎜ ⎟\n\ + ⎝ k ⎠\ +""" + + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = 2*binomial(n**2, k) + ascii_str = \ +"""\ + / 2\\\n\ + |n |\n\ +2*| |\n\ + \\k /\ +""" + ucode_str = \ +"""\ + ⎛ 2⎞\n\ + ⎜n ⎟\n\ +2⋅⎜ ⎟\n\ + ⎝k ⎠\ +""" + + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = catalan(n) + ascii_str = \ +"""\ +C \n\ + n\ +""" + ucode_str = \ +"""\ +C \n\ + n\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = catalan(n) + ascii_str = \ +"""\ +C \n\ + n\ +""" + ucode_str = \ +"""\ +C \n\ + n\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = bell(n) + ascii_str = \ +"""\ +B \n\ + n\ +""" + ucode_str = \ +"""\ +B \n\ + n\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = bernoulli(n) + ascii_str = \ +"""\ +B \n\ + n\ +""" + ucode_str = \ +"""\ +B \n\ + n\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = bernoulli(n, x) + ascii_str = \ +"""\ +B (x)\n\ + n \ +""" + ucode_str = \ +"""\ +B (x)\n\ + n \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = fibonacci(n) + ascii_str = \ +"""\ +F \n\ + n\ +""" + ucode_str = \ +"""\ +F \n\ + n\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = lucas(n) + ascii_str = \ +"""\ +L \n\ + n\ +""" + ucode_str = \ +"""\ +L \n\ + n\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = tribonacci(n) + ascii_str = \ +"""\ +T \n\ + n\ +""" + ucode_str = \ +"""\ +T \n\ + n\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = stieltjes(n) + ascii_str = \ +"""\ +stieltjes \n\ + n\ +""" + ucode_str = \ +"""\ +γ \n\ + n\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = stieltjes(n, x) + ascii_str = \ +"""\ +stieltjes (x)\n\ + n \ +""" + ucode_str = \ +"""\ +γ (x)\n\ + n \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = mathieuc(x, y, z) + ascii_str = 'C(x, y, z)' + ucode_str = 'C(x, y, z)' + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = mathieus(x, y, z) + ascii_str = 'S(x, y, z)' + ucode_str = 'S(x, y, z)' + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = mathieucprime(x, y, z) + ascii_str = "C'(x, y, z)" + ucode_str = "C'(x, y, z)" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = mathieusprime(x, y, z) + ascii_str = "S'(x, y, z)" + ucode_str = "S'(x, y, z)" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = conjugate(x) + ascii_str = \ +"""\ +_\n\ +x\ +""" + ucode_str = \ +"""\ +_\n\ +x\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + f = Function('f') + expr = conjugate(f(x + 1)) + ascii_str_1 = \ +"""\ +________\n\ +f(1 + x)\ +""" + ascii_str_2 = \ +"""\ +________\n\ +f(x + 1)\ +""" + ucode_str_1 = \ +"""\ +________\n\ +f(1 + x)\ +""" + ucode_str_2 = \ +"""\ +________\n\ +f(x + 1)\ +""" + assert pretty(expr) in [ascii_str_1, ascii_str_2] + assert upretty(expr) in [ucode_str_1, ucode_str_2] + + expr = f(x) + ascii_str = \ +"""\ +f(x)\ +""" + ucode_str = \ +"""\ +f(x)\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = f(x, y) + ascii_str = \ +"""\ +f(x, y)\ +""" + ucode_str = \ +"""\ +f(x, y)\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = f(x/(y + 1), y) + ascii_str_1 = \ +"""\ + / x \\\n\ +f|-----, y|\n\ + \\1 + y /\ +""" + ascii_str_2 = \ +"""\ + / x \\\n\ +f|-----, y|\n\ + \\y + 1 /\ +""" + ucode_str_1 = \ +"""\ + ⎛ x ⎞\n\ +f⎜─────, y⎟\n\ + ⎝1 + y ⎠\ +""" + ucode_str_2 = \ +"""\ + ⎛ x ⎞\n\ +f⎜─────, y⎟\n\ + ⎝y + 1 ⎠\ +""" + assert pretty(expr) in [ascii_str_1, ascii_str_2] + assert upretty(expr) in [ucode_str_1, ucode_str_2] + + expr = f(x**x**x**x**x**x) + ascii_str = \ +"""\ + / / / / / x\\\\\\\\\\ + | | | | \\x /|||| + | | | \\x /||| + | | \\x /|| + | \\x /| +f\\x /\ +""" + ucode_str = \ +"""\ + ⎛ ⎛ ⎛ ⎛ ⎛ x⎞⎞⎞⎞⎞ + ⎜ ⎜ ⎜ ⎜ ⎝x ⎠⎟⎟⎟⎟ + ⎜ ⎜ ⎜ ⎝x ⎠⎟⎟⎟ + ⎜ ⎜ ⎝x ⎠⎟⎟ + ⎜ ⎝x ⎠⎟ +f⎝x ⎠\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = sin(x)**2 + ascii_str = \ +"""\ + 2 \n\ +sin (x)\ +""" + ucode_str = \ +"""\ + 2 \n\ +sin (x)\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = conjugate(a + b*I) + ascii_str = \ +"""\ +_ _\n\ +a - I*b\ +""" + ucode_str = \ +"""\ +_ _\n\ +a - ⅈ⋅b\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = conjugate(exp(a + b*I)) + ascii_str = \ +"""\ + _ _\n\ + a - I*b\n\ +e \ +""" + ucode_str = \ +"""\ + _ _\n\ + a - ⅈ⋅b\n\ +ℯ \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = conjugate( f(1 + conjugate(f(x))) ) + ascii_str_1 = \ +"""\ +___________\n\ + / ____\\\n\ +f\\1 + f(x)/\ +""" + ascii_str_2 = \ +"""\ +___________\n\ + /____ \\\n\ +f\\f(x) + 1/\ +""" + ucode_str_1 = \ +"""\ +___________\n\ + ⎛ ____⎞\n\ +f⎝1 + f(x)⎠\ +""" + ucode_str_2 = \ +"""\ +___________\n\ + ⎛____ ⎞\n\ +f⎝f(x) + 1⎠\ +""" + assert pretty(expr) in [ascii_str_1, ascii_str_2] + assert upretty(expr) in [ucode_str_1, ucode_str_2] + + expr = f(x/(y + 1), y) + ascii_str_1 = \ +"""\ + / x \\\n\ +f|-----, y|\n\ + \\1 + y /\ +""" + ascii_str_2 = \ +"""\ + / x \\\n\ +f|-----, y|\n\ + \\y + 1 /\ +""" + ucode_str_1 = \ +"""\ + ⎛ x ⎞\n\ +f⎜─────, y⎟\n\ + ⎝1 + y ⎠\ +""" + ucode_str_2 = \ +"""\ + ⎛ x ⎞\n\ +f⎜─────, y⎟\n\ + ⎝y + 1 ⎠\ +""" + assert pretty(expr) in [ascii_str_1, ascii_str_2] + assert upretty(expr) in [ucode_str_1, ucode_str_2] + + expr = floor(1 / (y - floor(x))) + ascii_str = \ +"""\ + / 1 \\\n\ +floor|------------|\n\ + \\y - floor(x)/\ +""" + ucode_str = \ +"""\ +⎢ 1 ⎥\n\ +⎢───────⎥\n\ +⎣y - ⌊x⌋⎦\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = ceiling(1 / (y - ceiling(x))) + ascii_str = \ +"""\ + / 1 \\\n\ +ceiling|--------------|\n\ + \\y - ceiling(x)/\ +""" + ucode_str = \ +"""\ +⎡ 1 ⎤\n\ +⎢───────⎥\n\ +⎢y - ⌈x⌉⎥\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = euler(n) + ascii_str = \ +"""\ +E \n\ + n\ +""" + ucode_str = \ +"""\ +E \n\ + n\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = euler(1/(1 + 1/(1 + 1/n))) + ascii_str = \ +"""\ +E \n\ + 1 \n\ + ---------\n\ + 1 \n\ + 1 + -----\n\ + 1\n\ + 1 + -\n\ + n\ +""" + + ucode_str = \ +"""\ +E \n\ + 1 \n\ + ─────────\n\ + 1 \n\ + 1 + ─────\n\ + 1\n\ + 1 + ─\n\ + n\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = euler(n, x) + ascii_str = \ +"""\ +E (x)\n\ + n \ +""" + ucode_str = \ +"""\ +E (x)\n\ + n \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = euler(n, x/2) + ascii_str = \ +"""\ + /x\\\n\ +E |-|\n\ + n\\2/\ +""" + ucode_str = \ +"""\ + ⎛x⎞\n\ +E ⎜─⎟\n\ + n⎝2⎠\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + +def test_pretty_sqrt(): + expr = sqrt(2) + ascii_str = \ +"""\ + ___\n\ +\\/ 2 \ +""" + ucode_str = \ +"√2" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = 2**Rational(1, 3) + ascii_str = \ +"""\ +3 ___\n\ +\\/ 2 \ +""" + ucode_str = \ +"""\ +3 ___\n\ +╲╱ 2 \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = 2**Rational(1, 1000) + ascii_str = \ +"""\ +1000___\n\ + \\/ 2 \ +""" + ucode_str = \ +"""\ +1000___\n\ + ╲╱ 2 \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = sqrt(x**2 + 1) + ascii_str = \ +"""\ + ________\n\ + / 2 \n\ +\\/ x + 1 \ +""" + ucode_str = \ +"""\ + ________\n\ + ╱ 2 \n\ +╲╱ x + 1 \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = (1 + sqrt(5))**Rational(1, 3) + ascii_str = \ +"""\ + ___________\n\ +3 / ___ \n\ +\\/ 1 + \\/ 5 \ +""" + ucode_str = \ +"""\ +3 ________\n\ +╲╱ 1 + √5 \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = 2**(1/x) + ascii_str = \ +"""\ +x ___\n\ +\\/ 2 \ +""" + ucode_str = \ +"""\ +x ___\n\ +╲╱ 2 \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = sqrt(2 + pi) + ascii_str = \ +"""\ + ________\n\ +\\/ 2 + pi \ +""" + ucode_str = \ +"""\ + _______\n\ +╲╱ 2 + π \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = (2 + ( + 1 + x**2)/(2 + x))**Rational(1, 4) + (1 + x**Rational(1, 1000))/sqrt(3 + x**2) + ascii_str = \ +"""\ + ____________ \n\ + / 2 1000___ \n\ + / x + 1 \\/ x + 1\n\ +4 / 2 + ------ + -----------\n\ +\\/ x + 2 ________\n\ + / 2 \n\ + \\/ x + 3 \ +""" + ucode_str = \ +"""\ + ____________ \n\ + ╱ 2 1000___ \n\ + ╱ x + 1 ╲╱ x + 1\n\ +4 ╱ 2 + ────── + ───────────\n\ +╲╱ x + 2 ________\n\ + ╱ 2 \n\ + ╲╱ x + 3 \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + +def test_pretty_sqrt_char_knob(): + # See PR #9234. + expr = sqrt(2) + ucode_str1 = \ +"""\ + ___\n\ +╲╱ 2 \ +""" + ucode_str2 = \ +"√2" + assert xpretty(expr, use_unicode=True, + use_unicode_sqrt_char=False) == ucode_str1 + assert xpretty(expr, use_unicode=True, + use_unicode_sqrt_char=True) == ucode_str2 + + +def test_pretty_sqrt_longsymbol_no_sqrt_char(): + # Do not use unicode sqrt char for long symbols (see PR #9234). + expr = sqrt(Symbol('C1')) + ucode_str = \ +"""\ + ____\n\ +╲╱ C₁ \ +""" + assert upretty(expr) == ucode_str + + +def test_pretty_KroneckerDelta(): + x, y = symbols("x, y") + expr = KroneckerDelta(x, y) + ascii_str = \ +"""\ +d \n\ + x,y\ +""" + ucode_str = \ +"""\ +δ \n\ + x,y\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + +def test_pretty_product(): + n, m, k, l = symbols('n m k l') + f = symbols('f', cls=Function) + expr = Product(f((n/3)**2), (n, k**2, l)) + + unicode_str = \ +"""\ + l \n\ +─┬──────┬─ \n\ + │ │ ⎛ 2⎞\n\ + │ │ ⎜n ⎟\n\ + │ │ f⎜──⎟\n\ + │ │ ⎝9 ⎠\n\ + │ │ \n\ + 2 \n\ + n = k """ + ascii_str = \ +"""\ + l \n\ +__________ \n\ + | | / 2\\\n\ + | | |n |\n\ + | | f|--|\n\ + | | \\9 /\n\ + | | \n\ + 2 \n\ + n = k """ + + expr = Product(f((n/3)**2), (n, k**2, l), (l, 1, m)) + + unicode_str = \ +"""\ + m l \n\ +─┬──────┬─ ─┬──────┬─ \n\ + │ │ │ │ ⎛ 2⎞\n\ + │ │ │ │ ⎜n ⎟\n\ + │ │ │ │ f⎜──⎟\n\ + │ │ │ │ ⎝9 ⎠\n\ + │ │ │ │ \n\ + l = 1 2 \n\ + n = k """ + ascii_str = \ +"""\ + m l \n\ +__________ __________ \n\ + | | | | / 2\\\n\ + | | | | |n |\n\ + | | | | f|--|\n\ + | | | | \\9 /\n\ + | | | | \n\ + l = 1 2 \n\ + n = k """ + + assert pretty(expr) == ascii_str + assert upretty(expr) == unicode_str + + +def test_pretty_Lambda(): + # S.IdentityFunction is a special case + expr = Lambda(y, y) + assert pretty(expr) == "x -> x" + assert upretty(expr) == "x ↦ x" + + expr = Lambda(x, x+1) + assert pretty(expr) == "x -> x + 1" + assert upretty(expr) == "x ↦ x + 1" + + expr = Lambda(x, x**2) + ascii_str = \ +"""\ + 2\n\ +x -> x \ +""" + ucode_str = \ +"""\ + 2\n\ +x ↦ x \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = Lambda(x, x**2)**2 + ascii_str = \ +"""\ + 2 +/ 2\\ \n\ +\\x -> x / \ +""" + ucode_str = \ +"""\ + 2 +⎛ 2⎞ \n\ +⎝x ↦ x ⎠ \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = Lambda((x, y), x) + ascii_str = "(x, y) -> x" + ucode_str = "(x, y) ↦ x" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = Lambda((x, y), x**2) + ascii_str = \ +"""\ + 2\n\ +(x, y) -> x \ +""" + ucode_str = \ +"""\ + 2\n\ +(x, y) ↦ x \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = Lambda(((x, y),), x**2) + ascii_str = \ +"""\ + 2\n\ +((x, y),) -> x \ +""" + ucode_str = \ +"""\ + 2\n\ +((x, y),) ↦ x \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + +def test_pretty_TransferFunction(): + tf1 = TransferFunction(s - 1, s + 1, s) + assert upretty(tf1) == "s - 1\n─────\ns + 1" + tf2 = TransferFunction(2*s + 1, 3 - p, s) + assert upretty(tf2) == "2⋅s + 1\n───────\n 3 - p " + tf3 = TransferFunction(p, p + 1, p) + assert upretty(tf3) == " p \n─────\np + 1" + + +def test_pretty_Series(): + tf1 = TransferFunction(x + y, x - 2*y, y) + tf2 = TransferFunction(x - y, x + y, y) + tf3 = TransferFunction(x**2 + y, y - x, y) + tf4 = TransferFunction(2, 3, y) + + tfm1 = TransferFunctionMatrix([[tf1, tf2], [tf3, tf4]]) + tfm2 = TransferFunctionMatrix([[tf3], [-tf4]]) + tfm3 = TransferFunctionMatrix([[tf1, -tf2, -tf3], [tf3, -tf4, tf2]]) + tfm4 = TransferFunctionMatrix([[tf1, tf2], [tf3, -tf4], [-tf2, -tf1]]) + tfm5 = TransferFunctionMatrix([[-tf2, -tf1], [tf4, -tf3], [tf1, tf2]]) + + expected1 = \ +"""\ + ⎛ 2 ⎞\n\ +⎛ x + y ⎞ ⎜x + y⎟\n\ +⎜───────⎟⋅⎜──────⎟\n\ +⎝x - 2⋅y⎠ ⎝-x + y⎠\ +""" + expected2 = \ +"""\ +⎛-x + y⎞ ⎛-x - y ⎞\n\ +⎜──────⎟⋅⎜───────⎟\n\ +⎝x + y ⎠ ⎝x - 2⋅y⎠\ +""" + expected3 = \ +"""\ +⎛ 2 ⎞ \n\ +⎜x + y⎟ ⎛ x + y ⎞ ⎛-x - y x - y⎞\n\ +⎜──────⎟⋅⎜───────⎟⋅⎜─────── + ─────⎟\n\ +⎝-x + y⎠ ⎝x - 2⋅y⎠ ⎝x - 2⋅y x + y⎠\ +""" + expected4 = \ +"""\ + ⎛ 2 ⎞\n\ +⎛ x + y x - y⎞ ⎜x - y x + y⎟\n\ +⎜─────── + ─────⎟⋅⎜───── + ──────⎟\n\ +⎝x - 2⋅y x + y⎠ ⎝x + y -x + y⎠\ +""" + expected5 = \ +"""\ +⎡ x + y x - y⎤ ⎡ 2 ⎤ \n\ +⎢─────── ─────⎥ ⎢x + y⎥ \n\ +⎢x - 2⋅y x + y⎥ ⎢──────⎥ \n\ +⎢ ⎥ ⎢-x + y⎥ \n\ +⎢ 2 ⎥ ⋅⎢ ⎥ \n\ +⎢x + y 2 ⎥ ⎢ -2 ⎥ \n\ +⎢────── ─ ⎥ ⎢ ─── ⎥ \n\ +⎣-x + y 3 ⎦τ ⎣ 3 ⎦τ\ +""" + expected6 = \ +"""\ + ⎛⎡ x + y x - y ⎤ ⎡ x - y x + y ⎤ ⎞\n\ + ⎜⎢─────── ───── ⎥ ⎢ ───── ───────⎥ ⎟\n\ +⎡ x + y x - y⎤ ⎡ 2 ⎤ ⎜⎢x - 2⋅y x + y ⎥ ⎢ x + y x - 2⋅y⎥ ⎟\n\ +⎢─────── ─────⎥ ⎢ x + y -x + y - x - y⎥ ⎜⎢ ⎥ ⎢ ⎥ ⎟\n\ +⎢x - 2⋅y x + y⎥ ⎢─────── ────── ────────⎥ ⎜⎢ 2 ⎥ ⎢ 2 ⎥ ⎟\n\ +⎢ ⎥ ⎢x - 2⋅y x + y -x + y ⎥ ⎜⎢x + y -2 ⎥ ⎢ -2 x + y ⎥ ⎟\n\ +⎢ 2 ⎥ ⋅⎢ ⎥ ⋅⎜⎢────── ─── ⎥ + ⎢ ─── ────── ⎥ ⎟\n\ +⎢x + y 2 ⎥ ⎢ 2 ⎥ ⎜⎢-x + y 3 ⎥ ⎢ 3 -x + y ⎥ ⎟\n\ +⎢────── ─ ⎥ ⎢x + y -2 x - y ⎥ ⎜⎢ ⎥ ⎢ ⎥ ⎟\n\ +⎣-x + y 3 ⎦τ ⎢────── ─── ───── ⎥ ⎜⎢-x + y -x - y ⎥ ⎢-x - y -x + y ⎥ ⎟\n\ + ⎣-x + y 3 x + y ⎦τ ⎜⎢────── ───────⎥ ⎢─────── ────── ⎥ ⎟\n\ + ⎝⎣x + y x - 2⋅y⎦τ ⎣x - 2⋅y x + y ⎦τ⎠\ +""" + + assert upretty(Series(tf1, tf3)) == expected1 + assert upretty(Series(-tf2, -tf1)) == expected2 + assert upretty(Series(tf3, tf1, Parallel(-tf1, tf2))) == expected3 + assert upretty(Series(Parallel(tf1, tf2), Parallel(tf2, tf3))) == expected4 + assert upretty(MIMOSeries(tfm2, tfm1)) == expected5 + assert upretty(MIMOSeries(MIMOParallel(tfm4, -tfm5), tfm3, tfm1)) == expected6 + + +def test_pretty_Parallel(): + tf1 = TransferFunction(x + y, x - 2*y, y) + tf2 = TransferFunction(x - y, x + y, y) + tf3 = TransferFunction(x**2 + y, y - x, y) + tf4 = TransferFunction(y**2 - x, x**3 + x, y) + + tfm1 = TransferFunctionMatrix([[tf1, tf2], [tf3, -tf4], [-tf2, -tf1]]) + tfm2 = TransferFunctionMatrix([[-tf2, -tf1], [tf4, -tf3], [tf1, tf2]]) + tfm3 = TransferFunctionMatrix([[-tf1, tf2], [-tf3, tf4], [tf2, tf1]]) + tfm4 = TransferFunctionMatrix([[-tf1, -tf2], [-tf3, -tf4]]) + + expected1 = \ +"""\ + x + y x - y\n\ +─────── + ─────\n\ +x - 2⋅y x + y\ +""" + expected2 = \ +"""\ +-x + y -x - y \n\ +────── + ─────── +x + y x - 2⋅y\ +""" + expected3 = \ +"""\ + 2 \n\ +x + y x + y ⎛-x - y ⎞ ⎛x - y⎞ +────── + ─────── + ⎜───────⎟⋅⎜─────⎟ +-x + y x - 2⋅y ⎝x - 2⋅y⎠ ⎝x + y⎠\ +""" + + expected4 = \ +"""\ + ⎛ 2 ⎞\n\ +⎛ x + y ⎞ ⎛x - y⎞ ⎛x - y⎞ ⎜x + y⎟\n\ +⎜───────⎟⋅⎜─────⎟ + ⎜─────⎟⋅⎜──────⎟\n\ +⎝x - 2⋅y⎠ ⎝x + y⎠ ⎝x + y⎠ ⎝-x + y⎠\ +""" + expected5 = \ +"""\ +⎡ x + y -x + y ⎤ ⎡ x - y x + y ⎤ ⎡ x + y x - y ⎤ \n\ +⎢─────── ────── ⎥ ⎢ ───── ───────⎥ ⎢─────── ───── ⎥ \n\ +⎢x - 2⋅y x + y ⎥ ⎢ x + y x - 2⋅y⎥ ⎢x - 2⋅y x + y ⎥ \n\ +⎢ ⎥ ⎢ ⎥ ⎢ ⎥ \n\ +⎢ 2 2 ⎥ ⎢ 2 2 ⎥ ⎢ 2 2 ⎥ \n\ +⎢x + y x - y ⎥ ⎢x - y x + y ⎥ ⎢x + y x - y ⎥ \n\ +⎢────── ────── ⎥ + ⎢────── ────── ⎥ + ⎢────── ────── ⎥ \n\ +⎢-x + y 3 ⎥ ⎢ 3 -x + y ⎥ ⎢-x + y 3 ⎥ \n\ +⎢ x + x ⎥ ⎢x + x ⎥ ⎢ x + x ⎥ \n\ +⎢ ⎥ ⎢ ⎥ ⎢ ⎥ \n\ +⎢-x + y -x - y ⎥ ⎢-x - y -x + y ⎥ ⎢-x + y -x - y ⎥ \n\ +⎢────── ───────⎥ ⎢─────── ────── ⎥ ⎢────── ───────⎥ \n\ +⎣x + y x - 2⋅y⎦τ ⎣x - 2⋅y x + y ⎦τ ⎣x + y x - 2⋅y⎦τ\ +""" + expected6 = \ +"""\ +⎡ x - y x + y ⎤ ⎡-x + y -x - y ⎤ \n\ +⎢ ───── ───────⎥ ⎢────── ─────── ⎥ \n\ +⎢ x + y x - 2⋅y⎥ ⎡-x - y -x + y⎤ ⎢x + y x - 2⋅y ⎥ \n\ +⎢ ⎥ ⎢─────── ──────⎥ ⎢ ⎥ \n\ +⎢ 2 2 ⎥ ⎢x - 2⋅y x + y ⎥ ⎢ 2 2 ⎥ \n\ +⎢x - y x + y ⎥ ⎢ ⎥ ⎢-x + y - x - y⎥ \n\ +⎢────── ────── ⎥ ⋅⎢ 2 2⎥ + ⎢─────── ────────⎥ \n\ +⎢ 3 -x + y ⎥ ⎢- x - y x - y ⎥ ⎢ 3 -x + y ⎥ \n\ +⎢x + x ⎥ ⎢──────── ──────⎥ ⎢x + x ⎥ \n\ +⎢ ⎥ ⎢ -x + y 3 ⎥ ⎢ ⎥ \n\ +⎢-x - y -x + y ⎥ ⎣ x + x⎦τ ⎢ x + y x - y ⎥ \n\ +⎢─────── ────── ⎥ ⎢─────── ───── ⎥ \n\ +⎣x - 2⋅y x + y ⎦τ ⎣x - 2⋅y x + y ⎦τ\ +""" + assert upretty(Parallel(tf1, tf2)) == expected1 + assert upretty(Parallel(-tf2, -tf1)) == expected2 + assert upretty(Parallel(tf3, tf1, Series(-tf1, tf2))) == expected3 + assert upretty(Parallel(Series(tf1, tf2), Series(tf2, tf3))) == expected4 + assert upretty(MIMOParallel(-tfm3, -tfm2, tfm1)) == expected5 + assert upretty(MIMOParallel(MIMOSeries(tfm4, -tfm2), tfm2)) == expected6 + + +def test_pretty_Feedback(): + tf = TransferFunction(1, 1, y) + tf1 = TransferFunction(x + y, x - 2*y, y) + tf2 = TransferFunction(x - y, x + y, y) + tf3 = TransferFunction(y**2 - 2*y + 1, y + 5, y) + tf4 = TransferFunction(x - 2*y**3, x + y, x) + tf5 = TransferFunction(1 - x, x - y, y) + tf6 = TransferFunction(2, 2, x) + expected1 = \ +"""\ + ⎛1⎞ \n\ + ⎜─⎟ \n\ + ⎝1⎠ \n\ +─────────────\n\ +1 ⎛ x + y ⎞\n\ +─ + ⎜───────⎟\n\ +1 ⎝x - 2⋅y⎠\ +""" + expected2 = \ +"""\ + ⎛1⎞ \n\ + ⎜─⎟ \n\ + ⎝1⎠ \n\ +────────────────────────────────────\n\ + ⎛ 2 ⎞\n\ +1 ⎛x - y⎞ ⎛ x + y ⎞ ⎜y - 2⋅y + 1⎟\n\ +─ + ⎜─────⎟⋅⎜───────⎟⋅⎜────────────⎟\n\ +1 ⎝x + y⎠ ⎝x - 2⋅y⎠ ⎝ y + 5 ⎠\ +""" + expected3 = \ +"""\ + ⎛ x + y ⎞ \n\ + ⎜───────⎟ \n\ + ⎝x - 2⋅y⎠ \n\ +────────────────────────────────────────────\n\ + ⎛ 2 ⎞ \n\ +1 ⎛ x + y ⎞ ⎛x - y⎞ ⎜y - 2⋅y + 1⎟ ⎛1 - x⎞\n\ +─ + ⎜───────⎟⋅⎜─────⎟⋅⎜────────────⎟⋅⎜─────⎟\n\ +1 ⎝x - 2⋅y⎠ ⎝x + y⎠ ⎝ y + 5 ⎠ ⎝x - y⎠\ +""" + expected4 = \ +"""\ + ⎛ x + y ⎞ ⎛x - y⎞ \n\ + ⎜───────⎟⋅⎜─────⎟ \n\ + ⎝x - 2⋅y⎠ ⎝x + y⎠ \n\ +─────────────────────\n\ +1 ⎛ x + y ⎞ ⎛x - y⎞\n\ +─ + ⎜───────⎟⋅⎜─────⎟\n\ +1 ⎝x - 2⋅y⎠ ⎝x + y⎠\ +""" + expected5 = \ +"""\ + ⎛ x + y ⎞ ⎛x - y⎞ \n\ + ⎜───────⎟⋅⎜─────⎟ \n\ + ⎝x - 2⋅y⎠ ⎝x + y⎠ \n\ +─────────────────────────────\n\ +1 ⎛ x + y ⎞ ⎛x - y⎞ ⎛1 - x⎞\n\ +─ + ⎜───────⎟⋅⎜─────⎟⋅⎜─────⎟\n\ +1 ⎝x - 2⋅y⎠ ⎝x + y⎠ ⎝x - y⎠\ +""" + expected6 = \ +"""\ + ⎛ 2 ⎞ \n\ + ⎜y - 2⋅y + 1⎟ ⎛1 - x⎞ \n\ + ⎜────────────⎟⋅⎜─────⎟ \n\ + ⎝ y + 5 ⎠ ⎝x - y⎠ \n\ +────────────────────────────────────────────\n\ + ⎛ 2 ⎞ \n\ +1 ⎜y - 2⋅y + 1⎟ ⎛1 - x⎞ ⎛x - y⎞ ⎛ x + y ⎞\n\ +─ + ⎜────────────⎟⋅⎜─────⎟⋅⎜─────⎟⋅⎜───────⎟\n\ +1 ⎝ y + 5 ⎠ ⎝x - y⎠ ⎝x + y⎠ ⎝x - 2⋅y⎠\ +""" + expected7 = \ +"""\ + ⎛ 3⎞ \n\ + ⎜x - 2⋅y ⎟ \n\ + ⎜────────⎟ \n\ + ⎝ x + y ⎠ \n\ +──────────────────\n\ + ⎛ 3⎞ \n\ +1 ⎜x - 2⋅y ⎟ ⎛2⎞\n\ +─ + ⎜────────⎟⋅⎜─⎟\n\ +1 ⎝ x + y ⎠ ⎝2⎠\ +""" + expected8 = \ +"""\ + ⎛1 - x⎞ \n\ + ⎜─────⎟ \n\ + ⎝x - y⎠ \n\ +───────────\n\ +1 ⎛1 - x⎞\n\ +─ + ⎜─────⎟\n\ +1 ⎝x - y⎠\ +""" + expected9 = \ +"""\ + ⎛ x + y ⎞ ⎛x - y⎞ \n\ + ⎜───────⎟⋅⎜─────⎟ \n\ + ⎝x - 2⋅y⎠ ⎝x + y⎠ \n\ +─────────────────────────────\n\ +1 ⎛ x + y ⎞ ⎛x - y⎞ ⎛1 - x⎞\n\ +─ - ⎜───────⎟⋅⎜─────⎟⋅⎜─────⎟\n\ +1 ⎝x - 2⋅y⎠ ⎝x + y⎠ ⎝x - y⎠\ +""" + expected10 = \ +"""\ + ⎛1 - x⎞ \n\ + ⎜─────⎟ \n\ + ⎝x - y⎠ \n\ +───────────\n\ +1 ⎛1 - x⎞\n\ +─ - ⎜─────⎟\n\ +1 ⎝x - y⎠\ +""" + assert upretty(Feedback(tf, tf1)) == expected1 + assert upretty(Feedback(tf, tf2*tf1*tf3)) == expected2 + assert upretty(Feedback(tf1, tf2*tf3*tf5)) == expected3 + assert upretty(Feedback(tf1*tf2, tf)) == expected4 + assert upretty(Feedback(tf1*tf2, tf5)) == expected5 + assert upretty(Feedback(tf3*tf5, tf2*tf1)) == expected6 + assert upretty(Feedback(tf4, tf6)) == expected7 + assert upretty(Feedback(tf5, tf)) == expected8 + + assert upretty(Feedback(tf1*tf2, tf5, 1)) == expected9 + assert upretty(Feedback(tf5, tf, 1)) == expected10 + + +def test_pretty_MIMOFeedback(): + tf1 = TransferFunction(x + y, x - 2*y, y) + tf2 = TransferFunction(x - y, x + y, y) + tfm_1 = TransferFunctionMatrix([[tf1, tf2], [tf2, tf1]]) + tfm_2 = TransferFunctionMatrix([[tf2, tf1], [tf1, tf2]]) + tfm_3 = TransferFunctionMatrix([[tf1, tf1], [tf2, tf2]]) + + expected1 = \ +"""\ +⎛ ⎡ x + y x - y ⎤ ⎡ x - y x + y ⎤ ⎞-1 ⎡ x + y x - y ⎤ \n\ +⎜ ⎢─────── ───── ⎥ ⎢ ───── ───────⎥ ⎟ ⎢─────── ───── ⎥ \n\ +⎜ ⎢x - 2⋅y x + y ⎥ ⎢ x + y x - 2⋅y⎥ ⎟ ⎢x - 2⋅y x + y ⎥ \n\ +⎜I - ⎢ ⎥ ⋅⎢ ⎥ ⎟ ⋅ ⎢ ⎥ \n\ +⎜ ⎢ x - y x + y ⎥ ⎢ x + y x - y ⎥ ⎟ ⎢ x - y x + y ⎥ \n\ +⎜ ⎢ ───── ───────⎥ ⎢─────── ───── ⎥ ⎟ ⎢ ───── ───────⎥ \n\ +⎝ ⎣ x + y x - 2⋅y⎦τ ⎣x - 2⋅y x + y ⎦τ⎠ ⎣ x + y x - 2⋅y⎦τ\ +""" + expected2 = \ +"""\ +⎛ ⎡ x + y x - y ⎤ ⎡ x - y x + y ⎤ ⎡ x + y x + y ⎤ ⎞-1 ⎡ x + y x - y ⎤ ⎡ x - y x + y ⎤ \n\ +⎜ ⎢─────── ───── ⎥ ⎢ ───── ───────⎥ ⎢─────── ───────⎥ ⎟ ⎢─────── ───── ⎥ ⎢ ───── ───────⎥ \n\ +⎜ ⎢x - 2⋅y x + y ⎥ ⎢ x + y x - 2⋅y⎥ ⎢x - 2⋅y x - 2⋅y⎥ ⎟ ⎢x - 2⋅y x + y ⎥ ⎢ x + y x - 2⋅y⎥ \n\ +⎜I + ⎢ ⎥ ⋅⎢ ⎥ ⋅⎢ ⎥ ⎟ ⋅ ⎢ ⎥ ⋅⎢ ⎥ \n\ +⎜ ⎢ x - y x + y ⎥ ⎢ x + y x - y ⎥ ⎢ x - y x - y ⎥ ⎟ ⎢ x - y x + y ⎥ ⎢ x + y x - y ⎥ \n\ +⎜ ⎢ ───── ───────⎥ ⎢─────── ───── ⎥ ⎢ ───── ───── ⎥ ⎟ ⎢ ───── ───────⎥ ⎢─────── ───── ⎥ \n\ +⎝ ⎣ x + y x - 2⋅y⎦τ ⎣x - 2⋅y x + y ⎦τ ⎣ x + y x + y ⎦τ⎠ ⎣ x + y x - 2⋅y⎦τ ⎣x - 2⋅y x + y ⎦τ\ +""" + + assert upretty(MIMOFeedback(tfm_1, tfm_2, 1)) == \ + expected1 # Positive MIMOFeedback + assert upretty(MIMOFeedback(tfm_1*tfm_2, tfm_3)) == \ + expected2 # Negative MIMOFeedback (Default) + + +def test_pretty_TransferFunctionMatrix(): + tf1 = TransferFunction(x + y, x - 2*y, y) + tf2 = TransferFunction(x - y, x + y, y) + tf3 = TransferFunction(y**2 - 2*y + 1, y + 5, y) + tf4 = TransferFunction(y, x**2 + x + 1, y) + tf5 = TransferFunction(1 - x, x - y, y) + tf6 = TransferFunction(2, 2, y) + expected1 = \ +"""\ +⎡ x + y ⎤ \n\ +⎢───────⎥ \n\ +⎢x - 2⋅y⎥ \n\ +⎢ ⎥ \n\ +⎢ x - y ⎥ \n\ +⎢ ───── ⎥ \n\ +⎣ x + y ⎦τ\ +""" + expected2 = \ +"""\ +⎡ x + y ⎤ \n\ +⎢ ─────── ⎥ \n\ +⎢ x - 2⋅y ⎥ \n\ +⎢ ⎥ \n\ +⎢ x - y ⎥ \n\ +⎢ ───── ⎥ \n\ +⎢ x + y ⎥ \n\ +⎢ ⎥ \n\ +⎢ 2 ⎥ \n\ +⎢- y + 2⋅y - 1⎥ \n\ +⎢──────────────⎥ \n\ +⎣ y + 5 ⎦τ\ +""" + expected3 = \ +"""\ +⎡ x + y x - y ⎤ \n\ +⎢ ─────── ───── ⎥ \n\ +⎢ x - 2⋅y x + y ⎥ \n\ +⎢ ⎥ \n\ +⎢ 2 ⎥ \n\ +⎢y - 2⋅y + 1 y ⎥ \n\ +⎢──────────── ──────────⎥ \n\ +⎢ y + 5 2 ⎥ \n\ +⎢ x + x + 1⎥ \n\ +⎢ ⎥ \n\ +⎢ 1 - x 2 ⎥ \n\ +⎢ ───── ─ ⎥ \n\ +⎣ x - y 2 ⎦τ\ +""" + expected4 = \ +"""\ +⎡ x - y x + y y ⎤ \n\ +⎢ ───── ─────── ──────────⎥ \n\ +⎢ x + y x - 2⋅y 2 ⎥ \n\ +⎢ x + x + 1⎥ \n\ +⎢ ⎥ \n\ +⎢ 2 ⎥ \n\ +⎢- y + 2⋅y - 1 x - 1 -2 ⎥ \n\ +⎢────────────── ───── ─── ⎥ \n\ +⎣ y + 5 x - y 2 ⎦τ\ +""" + expected5 = \ +"""\ +⎡ x + y x - y x + y y ⎤ \n\ +⎢───────⋅───── ─────── ──────────⎥ \n\ +⎢x - 2⋅y x + y x - 2⋅y 2 ⎥ \n\ +⎢ x + x + 1⎥ \n\ +⎢ ⎥ \n\ +⎢ 1 - x 2 x + y -2 ⎥ \n\ +⎢ ───── + ─ ─────── ─── ⎥ \n\ +⎣ x - y 2 x - 2⋅y 2 ⎦τ\ +""" + + assert upretty(TransferFunctionMatrix([[tf1], [tf2]])) == expected1 + assert upretty(TransferFunctionMatrix([[tf1], [tf2], [-tf3]])) == expected2 + assert upretty(TransferFunctionMatrix([[tf1, tf2], [tf3, tf4], [tf5, tf6]])) == expected3 + assert upretty(TransferFunctionMatrix([[tf2, tf1, tf4], [-tf3, -tf5, -tf6]])) == expected4 + assert upretty(TransferFunctionMatrix([[Series(tf2, tf1), tf1, tf4], [Parallel(tf6, tf5), tf1, -tf6]])) == \ + expected5 + + +def test_pretty_StateSpace(): + ss1 = StateSpace(Matrix([a]), Matrix([b]), Matrix([c]), Matrix([d])) + A = Matrix([[0, 1], [1, 0]]) + B = Matrix([1, 0]) + C = Matrix([[0, 1]]) + D = Matrix([0]) + ss2 = StateSpace(A, B, C, D) + ss3 = StateSpace(Matrix([[-1.5, -2], [1, 0]]), + Matrix([[0.5, 0], [0, 1]]), + Matrix([[0, 1], [0, 2]]), + Matrix([[2, 2], [1, 1]])) + + expected1 = \ +"""\ +⎡[a] [b]⎤\n\ +⎢ ⎥\n\ +⎣[c] [d]⎦\ +""" + expected2 = \ +"""\ +⎡⎡0 1⎤ ⎡1⎤⎤\n\ +⎢⎢ ⎥ ⎢ ⎥⎥\n\ +⎢⎣1 0⎦ ⎣0⎦⎥\n\ +⎢ ⎥\n\ +⎣[0 1] [0]⎦\ +""" + expected3 = \ +"""\ +⎡⎡-1.5 -2⎤ ⎡0.5 0⎤⎤\n\ +⎢⎢ ⎥ ⎢ ⎥⎥\n\ +⎢⎣ 1 0 ⎦ ⎣ 0 1⎦⎥\n\ +⎢ ⎥\n\ +⎢ ⎡0 1⎤ ⎡2 2⎤ ⎥\n\ +⎢ ⎢ ⎥ ⎢ ⎥ ⎥\n\ +⎣ ⎣0 2⎦ ⎣1 1⎦ ⎦\ +""" + + assert upretty(ss1) == expected1 + assert upretty(ss2) == expected2 + assert upretty(ss3) == expected3 + +def test_pretty_order(): + expr = O(1) + ascii_str = \ +"""\ +O(1)\ +""" + ucode_str = \ +"""\ +O(1)\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = O(1/x) + ascii_str = \ +"""\ + /1\\\n\ +O|-|\n\ + \\x/\ +""" + ucode_str = \ +"""\ + ⎛1⎞\n\ +O⎜─⎟\n\ + ⎝x⎠\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = O(x**2 + y**2) + ascii_str = \ +"""\ + / 2 2 \\\n\ +O\\x + y ; (x, y) -> (0, 0)/\ +""" + ucode_str = \ +"""\ + ⎛ 2 2 ⎞\n\ +O⎝x + y ; (x, y) → (0, 0)⎠\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = O(1, (x, oo)) + ascii_str = \ +"""\ +O(1; x -> oo)\ +""" + ucode_str = \ +"""\ +O(1; x → ∞)\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = O(1/x, (x, oo)) + ascii_str = \ +"""\ + /1 \\\n\ +O|-; x -> oo|\n\ + \\x /\ +""" + ucode_str = \ +"""\ + ⎛1 ⎞\n\ +O⎜─; x → ∞⎟\n\ + ⎝x ⎠\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = O(x**2 + y**2, (x, oo), (y, oo)) + ascii_str = \ +"""\ + / 2 2 \\\n\ +O\\x + y ; (x, y) -> (oo, oo)/\ +""" + ucode_str = \ +"""\ + ⎛ 2 2 ⎞\n\ +O⎝x + y ; (x, y) → (∞, ∞)⎠\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + +def test_pretty_derivatives(): + # Simple + expr = Derivative(log(x), x, evaluate=False) + ascii_str = \ +"""\ +d \n\ +--(log(x))\n\ +dx \ +""" + ucode_str = \ +"""\ +d \n\ +──(log(x))\n\ +dx \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = Derivative(log(x), x, evaluate=False) + x + ascii_str_1 = \ +"""\ + d \n\ +x + --(log(x))\n\ + dx \ +""" + ascii_str_2 = \ +"""\ +d \n\ +--(log(x)) + x\n\ +dx \ +""" + ucode_str_1 = \ +"""\ + d \n\ +x + ──(log(x))\n\ + dx \ +""" + ucode_str_2 = \ +"""\ +d \n\ +──(log(x)) + x\n\ +dx \ +""" + assert pretty(expr) in [ascii_str_1, ascii_str_2] + assert upretty(expr) in [ucode_str_1, ucode_str_2] + + # basic partial derivatives + expr = Derivative(log(x + y) + x, x) + ascii_str_1 = \ +"""\ +d \n\ +--(log(x + y) + x)\n\ +dx \ +""" + ascii_str_2 = \ +"""\ +d \n\ +--(x + log(x + y))\n\ +dx \ +""" + ucode_str_1 = \ +"""\ +∂ \n\ +──(log(x + y) + x)\n\ +∂x \ +""" + ucode_str_2 = \ +"""\ +∂ \n\ +──(x + log(x + y))\n\ +∂x \ +""" + assert pretty(expr) in [ascii_str_1, ascii_str_2] + assert upretty(expr) in [ucode_str_1, ucode_str_2], upretty(expr) + + # Multiple symbols + expr = Derivative(log(x) + x**2, x, y) + ascii_str_1 = \ +"""\ + 2 \n\ + d / 2\\\n\ +-----\\log(x) + x /\n\ +dy dx \ +""" + ascii_str_2 = \ +"""\ + 2 \n\ + d / 2 \\\n\ +-----\\x + log(x)/\n\ +dy dx \ +""" + ascii_str_3 = \ +"""\ + 2 \n\ + d / 2 \\\n\ +-----\\x + log(x)/\n\ +dy dx \ +""" + ucode_str_1 = \ +"""\ + 2 \n\ + d ⎛ 2⎞\n\ +─────⎝log(x) + x ⎠\n\ +dy dx \ +""" + ucode_str_2 = \ +"""\ + 2 \n\ + d ⎛ 2 ⎞\n\ +─────⎝x + log(x)⎠\n\ +dy dx \ +""" + ucode_str_3 = \ +"""\ + 2 \n\ + d ⎛ 2 ⎞\n\ +─────⎝x + log(x)⎠\n\ +dy dx \ +""" + assert pretty(expr) in [ascii_str_1, ascii_str_2, ascii_str_3] + assert upretty(expr) in [ucode_str_1, ucode_str_2, ucode_str_3] + + expr = Derivative(2*x*y, y, x) + x**2 + ascii_str_1 = \ +"""\ + 2 \n\ + d 2\n\ +-----(2*x*y) + x \n\ +dx dy \ +""" + ascii_str_2 = \ +"""\ + 2 \n\ + 2 d \n\ +x + -----(2*x*y)\n\ + dx dy \ +""" + ascii_str_3 = \ +"""\ + 2 \n\ + 2 d \n\ +x + -----(2*x*y)\n\ + dx dy \ +""" + ucode_str_1 = \ +"""\ + 2 \n\ + ∂ 2\n\ +─────(2⋅x⋅y) + x \n\ +∂x ∂y \ +""" + ucode_str_2 = \ +"""\ + 2 \n\ + 2 ∂ \n\ +x + ─────(2⋅x⋅y)\n\ + ∂x ∂y \ +""" + ucode_str_3 = \ +"""\ + 2 \n\ + 2 ∂ \n\ +x + ─────(2⋅x⋅y)\n\ + ∂x ∂y \ +""" + assert pretty(expr) in [ascii_str_1, ascii_str_2, ascii_str_3] + assert upretty(expr) in [ucode_str_1, ucode_str_2, ucode_str_3] + + expr = Derivative(2*x*y, x, x) + ascii_str = \ +"""\ + 2 \n\ +d \n\ +---(2*x*y)\n\ + 2 \n\ +dx \ +""" + ucode_str = \ +"""\ + 2 \n\ +∂ \n\ +───(2⋅x⋅y)\n\ + 2 \n\ +∂x \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = Derivative(2*x*y, x, 17) + ascii_str = \ +"""\ + 17 \n\ +d \n\ +----(2*x*y)\n\ + 17 \n\ +dx \ +""" + ucode_str = \ +"""\ + 17 \n\ +∂ \n\ +────(2⋅x⋅y)\n\ + 17 \n\ +∂x \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = Derivative(2*x*y, x, x, y) + ascii_str = \ +"""\ + 3 \n\ + d \n\ +------(2*x*y)\n\ + 2 \n\ +dy dx \ +""" + ucode_str = \ +"""\ + 3 \n\ + ∂ \n\ +──────(2⋅x⋅y)\n\ + 2 \n\ +∂y ∂x \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + # Greek letters + alpha = Symbol('alpha') + beta = Function('beta') + expr = beta(alpha).diff(alpha) + ascii_str = \ +"""\ + d \n\ +------(beta(alpha))\n\ +dalpha \ +""" + ucode_str = \ +"""\ +d \n\ +──(β(α))\n\ +dα \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = Derivative(f(x), (x, n)) + + ascii_str = \ +"""\ + n \n\ +d \n\ +---(f(x))\n\ + n \n\ +dx \ +""" + ucode_str = \ +"""\ + n \n\ +d \n\ +───(f(x))\n\ + n \n\ +dx \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + +def test_pretty_integrals(): + expr = Integral(log(x), x) + ascii_str = \ +"""\ + / \n\ + | \n\ + | log(x) dx\n\ + | \n\ +/ \ +""" + ucode_str = \ +"""\ +⌠ \n\ +⎮ log(x) dx\n\ +⌡ \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = Integral(x**2, x) + ascii_str = \ +"""\ + / \n\ + | \n\ + | 2 \n\ + | x dx\n\ + | \n\ +/ \ +""" + ucode_str = \ +"""\ +⌠ \n\ +⎮ 2 \n\ +⎮ x dx\n\ +⌡ \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = Integral((sin(x))**2 / (tan(x))**2) + ascii_str = \ +"""\ + / \n\ + | \n\ + | 2 \n\ + | sin (x) \n\ + | ------- dx\n\ + | 2 \n\ + | tan (x) \n\ + | \n\ +/ \ +""" + ucode_str = \ +"""\ +⌠ \n\ +⎮ 2 \n\ +⎮ sin (x) \n\ +⎮ ─────── dx\n\ +⎮ 2 \n\ +⎮ tan (x) \n\ +⌡ \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = Integral(x**(2**x), x) + ascii_str = \ +"""\ + / \n\ + | \n\ + | / x\\ \n\ + | \\2 / \n\ + | x dx\n\ + | \n\ +/ \ +""" + ucode_str = \ +"""\ +⌠ \n\ +⎮ ⎛ x⎞ \n\ +⎮ ⎝2 ⎠ \n\ +⎮ x dx\n\ +⌡ \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = Integral(x**2, (x, 1, 2)) + ascii_str = \ +"""\ + 2 \n\ + / \n\ + | \n\ + | 2 \n\ + | x dx\n\ + | \n\ +/ \n\ +1 \ +""" + ucode_str = \ +"""\ +2 \n\ +⌠ \n\ +⎮ 2 \n\ +⎮ x dx\n\ +⌡ \n\ +1 \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = Integral(x**2, (x, Rational(1, 2), 10)) + ascii_str = \ +"""\ + 10 \n\ + / \n\ + | \n\ + | 2 \n\ + | x dx\n\ + | \n\ +/ \n\ +1/2 \ +""" + ucode_str = \ +"""\ +10 \n\ +⌠ \n\ +⎮ 2 \n\ +⎮ x dx\n\ +⌡ \n\ +1/2 \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = Integral(x**2*y**2, x, y) + ascii_str = \ +"""\ + / / \n\ + | | \n\ + | | 2 2 \n\ + | | x *y dx dy\n\ + | | \n\ +/ / \ +""" + ucode_str = \ +"""\ +⌠ ⌠ \n\ +⎮ ⎮ 2 2 \n\ +⎮ ⎮ x ⋅y dx dy\n\ +⌡ ⌡ \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = Integral(sin(th)/cos(ph), (th, 0, pi), (ph, 0, 2*pi)) + ascii_str = \ +"""\ + 2*pi pi \n\ + / / \n\ + | | \n\ + | | sin(theta) \n\ + | | ---------- d(theta) d(phi)\n\ + | | cos(phi) \n\ + | | \n\ + / / \n\ +0 0 \ +""" + ucode_str = \ +"""\ +2⋅π π \n\ + ⌠ ⌠ \n\ + ⎮ ⎮ sin(θ) \n\ + ⎮ ⎮ ────── dθ dφ\n\ + ⎮ ⎮ cos(φ) \n\ + ⌡ ⌡ \n\ + 0 0 \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + +def test_pretty_matrix(): + # Empty Matrix + expr = Matrix() + ascii_str = "[]" + unicode_str = "[]" + assert pretty(expr) == ascii_str + assert upretty(expr) == unicode_str + expr = Matrix(2, 0, lambda i, j: 0) + ascii_str = "[]" + unicode_str = "[]" + assert pretty(expr) == ascii_str + assert upretty(expr) == unicode_str + expr = Matrix(0, 2, lambda i, j: 0) + ascii_str = "[]" + unicode_str = "[]" + assert pretty(expr) == ascii_str + assert upretty(expr) == unicode_str + expr = Matrix([[x**2 + 1, 1], [y, x + y]]) + ascii_str_1 = \ +"""\ +[ 2 ] +[1 + x 1 ] +[ ] +[ y x + y]\ +""" + ascii_str_2 = \ +"""\ +[ 2 ] +[x + 1 1 ] +[ ] +[ y x + y]\ +""" + ucode_str_1 = \ +"""\ +⎡ 2 ⎤ +⎢1 + x 1 ⎥ +⎢ ⎥ +⎣ y x + y⎦\ +""" + ucode_str_2 = \ +"""\ +⎡ 2 ⎤ +⎢x + 1 1 ⎥ +⎢ ⎥ +⎣ y x + y⎦\ +""" + assert pretty(expr) in [ascii_str_1, ascii_str_2] + assert upretty(expr) in [ucode_str_1, ucode_str_2] + + expr = Matrix([[x/y, y, th], [0, exp(I*k*ph), 1]]) + ascii_str = \ +"""\ +[x ] +[- y theta] +[y ] +[ ] +[ I*k*phi ] +[0 e 1 ]\ +""" + ucode_str = \ +"""\ +⎡x ⎤ +⎢─ y θ⎥ +⎢y ⎥ +⎢ ⎥ +⎢ ⅈ⋅k⋅φ ⎥ +⎣0 ℯ 1⎦\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + unicode_str = \ +"""\ +⎡v̇_msc_00 0 0 ⎤ +⎢ ⎥ +⎢ 0 v̇_msc_01 0 ⎥ +⎢ ⎥ +⎣ 0 0 v̇_msc_02⎦\ +""" + + expr = diag(*MatrixSymbol('vdot_msc',1,3)) + assert upretty(expr) == unicode_str + + +def test_pretty_ndim_arrays(): + x, y, z, w = symbols("x y z w") + + for ArrayType in (ImmutableDenseNDimArray, ImmutableSparseNDimArray, MutableDenseNDimArray, MutableSparseNDimArray): + # Basic: scalar array + M = ArrayType(x) + + assert pretty(M) == "x" + assert upretty(M) == "x" + + M = ArrayType([[1/x, y], [z, w]]) + M1 = ArrayType([1/x, y, z]) + + M2 = tensorproduct(M1, M) + M3 = tensorproduct(M, M) + + ascii_str = \ +"""\ +[1 ]\n\ +[- y]\n\ +[x ]\n\ +[ ]\n\ +[z w]\ +""" + ucode_str = \ +"""\ +⎡1 ⎤\n\ +⎢─ y⎥\n\ +⎢x ⎥\n\ +⎢ ⎥\n\ +⎣z w⎦\ +""" + assert pretty(M) == ascii_str + assert upretty(M) == ucode_str + + ascii_str = \ +"""\ +[1 ]\n\ +[- y z]\n\ +[x ]\ +""" + ucode_str = \ +"""\ +⎡1 ⎤\n\ +⎢─ y z⎥\n\ +⎣x ⎦\ +""" + assert pretty(M1) == ascii_str + assert upretty(M1) == ucode_str + + ascii_str = \ +"""\ +[[1 y] ]\n\ +[[-- -] [z ]]\n\ +[[ 2 x] [ y 2 ] [- y*z]]\n\ +[[x ] [ - y ] [x ]]\n\ +[[ ] [ x ] [ ]]\n\ +[[z w] [ ] [ 2 ]]\n\ +[[- -] [y*z w*y] [z w*z]]\n\ +[[x x] ]\ +""" + ucode_str = \ +"""\ +⎡⎡1 y⎤ ⎤\n\ +⎢⎢── ─⎥ ⎡z ⎤⎥\n\ +⎢⎢ 2 x⎥ ⎡ y 2 ⎤ ⎢─ y⋅z⎥⎥\n\ +⎢⎢x ⎥ ⎢ ─ y ⎥ ⎢x ⎥⎥\n\ +⎢⎢ ⎥ ⎢ x ⎥ ⎢ ⎥⎥\n\ +⎢⎢z w⎥ ⎢ ⎥ ⎢ 2 ⎥⎥\n\ +⎢⎢─ ─⎥ ⎣y⋅z w⋅y⎦ ⎣z w⋅z⎦⎥\n\ +⎣⎣x x⎦ ⎦\ +""" + assert pretty(M2) == ascii_str + assert upretty(M2) == ucode_str + + ascii_str = \ +"""\ +[ [1 y] ]\n\ +[ [-- -] ]\n\ +[ [ 2 x] [ y 2 ]]\n\ +[ [x ] [ - y ]]\n\ +[ [ ] [ x ]]\n\ +[ [z w] [ ]]\n\ +[ [- -] [y*z w*y]]\n\ +[ [x x] ]\n\ +[ ]\n\ +[[z ] [ w ]]\n\ +[[- y*z] [ - w*y]]\n\ +[[x ] [ x ]]\n\ +[[ ] [ ]]\n\ +[[ 2 ] [ 2 ]]\n\ +[[z w*z] [w*z w ]]\ +""" + ucode_str = \ +"""\ +⎡ ⎡1 y⎤ ⎤\n\ +⎢ ⎢── ─⎥ ⎥\n\ +⎢ ⎢ 2 x⎥ ⎡ y 2 ⎤⎥\n\ +⎢ ⎢x ⎥ ⎢ ─ y ⎥⎥\n\ +⎢ ⎢ ⎥ ⎢ x ⎥⎥\n\ +⎢ ⎢z w⎥ ⎢ ⎥⎥\n\ +⎢ ⎢─ ─⎥ ⎣y⋅z w⋅y⎦⎥\n\ +⎢ ⎣x x⎦ ⎥\n\ +⎢ ⎥\n\ +⎢⎡z ⎤ ⎡ w ⎤⎥\n\ +⎢⎢─ y⋅z⎥ ⎢ ─ w⋅y⎥⎥\n\ +⎢⎢x ⎥ ⎢ x ⎥⎥\n\ +⎢⎢ ⎥ ⎢ ⎥⎥\n\ +⎢⎢ 2 ⎥ ⎢ 2 ⎥⎥\n\ +⎣⎣z w⋅z⎦ ⎣w⋅z w ⎦⎦\ +""" + assert pretty(M3) == ascii_str + assert upretty(M3) == ucode_str + + Mrow = ArrayType([[x, y, 1 / z]]) + Mcolumn = ArrayType([[x], [y], [1 / z]]) + Mcol2 = ArrayType([Mcolumn.tolist()]) + + ascii_str = \ +"""\ +[[ 1]]\n\ +[[x y -]]\n\ +[[ z]]\ +""" + ucode_str = \ +"""\ +⎡⎡ 1⎤⎤\n\ +⎢⎢x y ─⎥⎥\n\ +⎣⎣ z⎦⎦\ +""" + assert pretty(Mrow) == ascii_str + assert upretty(Mrow) == ucode_str + + ascii_str = \ +"""\ +[x]\n\ +[ ]\n\ +[y]\n\ +[ ]\n\ +[1]\n\ +[-]\n\ +[z]\ +""" + ucode_str = \ +"""\ +⎡x⎤\n\ +⎢ ⎥\n\ +⎢y⎥\n\ +⎢ ⎥\n\ +⎢1⎥\n\ +⎢─⎥\n\ +⎣z⎦\ +""" + assert pretty(Mcolumn) == ascii_str + assert upretty(Mcolumn) == ucode_str + + ascii_str = \ +"""\ +[[x]]\n\ +[[ ]]\n\ +[[y]]\n\ +[[ ]]\n\ +[[1]]\n\ +[[-]]\n\ +[[z]]\ +""" + ucode_str = \ +"""\ +⎡⎡x⎤⎤\n\ +⎢⎢ ⎥⎥\n\ +⎢⎢y⎥⎥\n\ +⎢⎢ ⎥⎥\n\ +⎢⎢1⎥⎥\n\ +⎢⎢─⎥⎥\n\ +⎣⎣z⎦⎦\ +""" + assert pretty(Mcol2) == ascii_str + assert upretty(Mcol2) == ucode_str + + +def test_tensor_TensorProduct(): + A = MatrixSymbol("A", 3, 3) + B = MatrixSymbol("B", 3, 3) + assert upretty(TensorProduct(A, B)) == "A\u2297B" + assert upretty(TensorProduct(A, B, A)) == "A\u2297B\u2297A" + + +def test_diffgeom_print_WedgeProduct(): + from sympy.diffgeom.rn import R2 + from sympy.diffgeom import WedgeProduct + wp = WedgeProduct(R2.dx, R2.dy) + assert upretty(wp) == "ⅆ x∧ⅆ y" + assert pretty(wp) == r"d x/\d y" + + +def test_Adjoint(): + X = MatrixSymbol('X', 2, 2) + Y = MatrixSymbol('Y', 2, 2) + assert pretty(Adjoint(X)) == " +\nX " + assert pretty(Adjoint(X + Y)) == " +\n(X + Y) " + assert pretty(Adjoint(X) + Adjoint(Y)) == " + +\nX + Y " + assert pretty(Adjoint(X*Y)) == " +\n(X*Y) " + assert pretty(Adjoint(Y)*Adjoint(X)) == " + +\nY *X " + assert pretty(Adjoint(X**2)) == " +\n/ 2\\ \n\\X / " + assert pretty(Adjoint(X)**2) == " 2\n/ +\\ \n\\X / " + assert pretty(Adjoint(Inverse(X))) == " +\n/ -1\\ \n\\X / " + assert pretty(Inverse(Adjoint(X))) == " -1\n/ +\\ \n\\X / " + assert pretty(Adjoint(Transpose(X))) == " +\n/ T\\ \n\\X / " + assert pretty(Transpose(Adjoint(X))) == " T\n/ +\\ \n\\X / " + assert upretty(Adjoint(X)) == " †\nX " + assert upretty(Adjoint(X + Y)) == " †\n(X + Y) " + assert upretty(Adjoint(X) + Adjoint(Y)) == " † †\nX + Y " + assert upretty(Adjoint(X*Y)) == " †\n(X⋅Y) " + assert upretty(Adjoint(Y)*Adjoint(X)) == " † †\nY ⋅X " + assert upretty(Adjoint(X**2)) == \ + " †\n⎛ 2⎞ \n⎝X ⎠ " + assert upretty(Adjoint(X)**2) == \ + " 2\n⎛ †⎞ \n⎝X ⎠ " + assert upretty(Adjoint(Inverse(X))) == \ + " †\n⎛ -1⎞ \n⎝X ⎠ " + assert upretty(Inverse(Adjoint(X))) == \ + " -1\n⎛ †⎞ \n⎝X ⎠ " + assert upretty(Adjoint(Transpose(X))) == \ + " †\n⎛ T⎞ \n⎝X ⎠ " + assert upretty(Transpose(Adjoint(X))) == \ + " T\n⎛ †⎞ \n⎝X ⎠ " + m = Matrix(((1, 2), (3, 4))) + assert upretty(Adjoint(m)) == \ + ' †\n'\ + '⎡1 2⎤ \n'\ + '⎢ ⎥ \n'\ + '⎣3 4⎦ ' + assert upretty(Adjoint(m+X)) == \ + ' †\n'\ + '⎛⎡1 2⎤ ⎞ \n'\ + '⎜⎢ ⎥ + X⎟ \n'\ + '⎝⎣3 4⎦ ⎠ ' + assert upretty(Adjoint(BlockMatrix(((OneMatrix(2, 2), X), + (m, ZeroMatrix(2, 2)))))) == \ + ' †\n'\ + '⎡ 𝟙 X⎤ \n'\ + '⎢ ⎥ \n'\ + '⎢⎡1 2⎤ ⎥ \n'\ + '⎢⎢ ⎥ 𝟘⎥ \n'\ + '⎣⎣3 4⎦ ⎦ ' + + +def test_Transpose(): + X = MatrixSymbol('X', 2, 2) + Y = MatrixSymbol('Y', 2, 2) + assert pretty(Transpose(X)) == " T\nX " + assert pretty(Transpose(X + Y)) == " T\n(X + Y) " + assert pretty(Transpose(X) + Transpose(Y)) == " T T\nX + Y " + assert pretty(Transpose(X*Y)) == " T\n(X*Y) " + assert pretty(Transpose(Y)*Transpose(X)) == " T T\nY *X " + assert pretty(Transpose(X**2)) == " T\n/ 2\\ \n\\X / " + assert pretty(Transpose(X)**2) == " 2\n/ T\\ \n\\X / " + assert pretty(Transpose(Inverse(X))) == " T\n/ -1\\ \n\\X / " + assert pretty(Inverse(Transpose(X))) == " -1\n/ T\\ \n\\X / " + assert upretty(Transpose(X)) == " T\nX " + assert upretty(Transpose(X + Y)) == " T\n(X + Y) " + assert upretty(Transpose(X) + Transpose(Y)) == " T T\nX + Y " + assert upretty(Transpose(X*Y)) == " T\n(X⋅Y) " + assert upretty(Transpose(Y)*Transpose(X)) == " T T\nY ⋅X " + assert upretty(Transpose(X**2)) == \ + " T\n⎛ 2⎞ \n⎝X ⎠ " + assert upretty(Transpose(X)**2) == \ + " 2\n⎛ T⎞ \n⎝X ⎠ " + assert upretty(Transpose(Inverse(X))) == \ + " T\n⎛ -1⎞ \n⎝X ⎠ " + assert upretty(Inverse(Transpose(X))) == \ + " -1\n⎛ T⎞ \n⎝X ⎠ " + m = Matrix(((1, 2), (3, 4))) + assert upretty(Transpose(m)) == \ + ' T\n'\ + '⎡1 2⎤ \n'\ + '⎢ ⎥ \n'\ + '⎣3 4⎦ ' + assert upretty(Transpose(m+X)) == \ + ' T\n'\ + '⎛⎡1 2⎤ ⎞ \n'\ + '⎜⎢ ⎥ + X⎟ \n'\ + '⎝⎣3 4⎦ ⎠ ' + assert upretty(Transpose(BlockMatrix(((OneMatrix(2, 2), X), + (m, ZeroMatrix(2, 2)))))) == \ + ' T\n'\ + '⎡ 𝟙 X⎤ \n'\ + '⎢ ⎥ \n'\ + '⎢⎡1 2⎤ ⎥ \n'\ + '⎢⎢ ⎥ 𝟘⎥ \n'\ + '⎣⎣3 4⎦ ⎦ ' + + +def test_pretty_Trace_issue_9044(): + X = Matrix([[1, 2], [3, 4]]) + Y = Matrix([[2, 4], [6, 8]]) + ascii_str_1 = \ +"""\ + /[1 2]\\ +tr|[ ]| + \\[3 4]/\ +""" + ucode_str_1 = \ +"""\ + ⎛⎡1 2⎤⎞ +tr⎜⎢ ⎥⎟ + ⎝⎣3 4⎦⎠\ +""" + ascii_str_2 = \ +"""\ + /[1 2]\\ /[2 4]\\ +tr|[ ]| + tr|[ ]| + \\[3 4]/ \\[6 8]/\ +""" + ucode_str_2 = \ +"""\ + ⎛⎡1 2⎤⎞ ⎛⎡2 4⎤⎞ +tr⎜⎢ ⎥⎟ + tr⎜⎢ ⎥⎟ + ⎝⎣3 4⎦⎠ ⎝⎣6 8⎦⎠\ +""" + assert pretty(Trace(X)) == ascii_str_1 + assert upretty(Trace(X)) == ucode_str_1 + + assert pretty(Trace(X) + Trace(Y)) == ascii_str_2 + assert upretty(Trace(X) + Trace(Y)) == ucode_str_2 + + +def test_MatrixSlice(): + n = Symbol('n', integer=True) + x, y, z, w, t, = symbols('x y z w t') + X = MatrixSymbol('X', n, n) + Y = MatrixSymbol('Y', 10, 10) + Z = MatrixSymbol('Z', 10, 10) + + expr = MatrixSlice(X, (None, None, None), (None, None, None)) + assert pretty(expr) == upretty(expr) == 'X[:, :]' + expr = X[x:x + 1, y:y + 1] + assert pretty(expr) == upretty(expr) == 'X[x:x + 1, y:y + 1]' + expr = X[x:x + 1:2, y:y + 1:2] + assert pretty(expr) == upretty(expr) == 'X[x:x + 1:2, y:y + 1:2]' + expr = X[:x, y:] + assert pretty(expr) == upretty(expr) == 'X[:x, y:]' + expr = X[:x, y:] + assert pretty(expr) == upretty(expr) == 'X[:x, y:]' + expr = X[x:, :y] + assert pretty(expr) == upretty(expr) == 'X[x:, :y]' + expr = X[x:y, z:w] + assert pretty(expr) == upretty(expr) == 'X[x:y, z:w]' + expr = X[x:y:t, w:t:x] + assert pretty(expr) == upretty(expr) == 'X[x:y:t, w:t:x]' + expr = X[x::y, t::w] + assert pretty(expr) == upretty(expr) == 'X[x::y, t::w]' + expr = X[:x:y, :t:w] + assert pretty(expr) == upretty(expr) == 'X[:x:y, :t:w]' + expr = X[::x, ::y] + assert pretty(expr) == upretty(expr) == 'X[::x, ::y]' + expr = MatrixSlice(X, (0, None, None), (0, None, None)) + assert pretty(expr) == upretty(expr) == 'X[:, :]' + expr = MatrixSlice(X, (None, n, None), (None, n, None)) + assert pretty(expr) == upretty(expr) == 'X[:, :]' + expr = MatrixSlice(X, (0, n, None), (0, n, None)) + assert pretty(expr) == upretty(expr) == 'X[:, :]' + expr = MatrixSlice(X, (0, n, 2), (0, n, 2)) + assert pretty(expr) == upretty(expr) == 'X[::2, ::2]' + expr = X[1:2:3, 4:5:6] + assert pretty(expr) == upretty(expr) == 'X[1:2:3, 4:5:6]' + expr = X[1:3:5, 4:6:8] + assert pretty(expr) == upretty(expr) == 'X[1:3:5, 4:6:8]' + expr = X[1:10:2] + assert pretty(expr) == upretty(expr) == 'X[1:10:2, :]' + expr = Y[:5, 1:9:2] + assert pretty(expr) == upretty(expr) == 'Y[:5, 1:9:2]' + expr = Y[:5, 1:10:2] + assert pretty(expr) == upretty(expr) == 'Y[:5, 1::2]' + expr = Y[5, :5:2] + assert pretty(expr) == upretty(expr) == 'Y[5:6, :5:2]' + expr = X[0:1, 0:1] + assert pretty(expr) == upretty(expr) == 'X[:1, :1]' + expr = X[0:1:2, 0:1:2] + assert pretty(expr) == upretty(expr) == 'X[:1:2, :1:2]' + expr = (Y + Z)[2:, 2:] + assert pretty(expr) == upretty(expr) == '(Y + Z)[2:, 2:]' + + +def test_MatrixExpressions(): + n = Symbol('n', integer=True) + X = MatrixSymbol('X', n, n) + + assert pretty(X) == upretty(X) == "X" + + # Apply function elementwise (`ElementwiseApplyFunc`): + + expr = (X.T*X).applyfunc(sin) + + ascii_str = """\ + / T \\\n\ +(d -> sin(d)).\\X *X/\ +""" + ucode_str = """\ + ⎛ T ⎞\n\ +(d ↦ sin(d))˳⎝X ⋅X⎠\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + lamda = Lambda(x, 1/x) + expr = (n*X).applyfunc(lamda) + ascii_str = """\ +/ 1\\ \n\ +|x -> -|.(n*X)\n\ +\\ x/ \ +""" + ucode_str = """\ +⎛ 1⎞ \n\ +⎜x ↦ ─⎟˳(n⋅X)\n\ +⎝ x⎠ \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + +def test_pretty_dotproduct(): + from sympy.matrices.expressions.dotproduct import DotProduct + n = symbols("n", integer=True) + A = MatrixSymbol('A', n, 1) + B = MatrixSymbol('B', n, 1) + C = Matrix(1, 3, [1, 2, 3]) + D = Matrix(1, 3, [1, 3, 4]) + + assert pretty(DotProduct(A, B)) == "A*B" + assert pretty(DotProduct(C, D)) == "[1 2 3]*[1 3 4]" + assert upretty(DotProduct(A, B)) == "A⋅B" + assert upretty(DotProduct(C, D)) == "[1 2 3]⋅[1 3 4]" + + +def test_pretty_Determinant(): + from sympy.matrices import Determinant, Inverse, BlockMatrix, OneMatrix, ZeroMatrix + m = Matrix(((1, 2), (3, 4))) + assert upretty(Determinant(m)) == '│1 2│\n│ │\n│3 4│' + assert upretty(Determinant(Inverse(m))) == \ + '│ -1│\n'\ + '│⎡1 2⎤ │\n'\ + '│⎢ ⎥ │\n'\ + '│⎣3 4⎦ │' + X = MatrixSymbol('X', 2, 2) + assert upretty(Determinant(X)) == '│X│' + assert upretty(Determinant(X + m)) == \ + '│⎡1 2⎤ │\n'\ + '│⎢ ⎥ + X│\n'\ + '│⎣3 4⎦ │' + assert upretty(Determinant(BlockMatrix(((OneMatrix(2, 2), X), + (m, ZeroMatrix(2, 2)))))) == \ + '│ 𝟙 X│\n'\ + '│ │\n'\ + '│⎡1 2⎤ │\n'\ + '│⎢ ⎥ 𝟘│\n'\ + '│⎣3 4⎦ │' + + +def test_pretty_piecewise(): + expr = Piecewise((x, x < 1), (x**2, True)) + ascii_str = \ +"""\ +/x for x < 1\n\ +| \n\ +< 2 \n\ +|x otherwise\n\ +\\ \ +""" + ucode_str = \ +"""\ +⎧x for x < 1\n\ +⎪ \n\ +⎨ 2 \n\ +⎪x otherwise\n\ +⎩ \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = -Piecewise((x, x < 1), (x**2, True)) + ascii_str = \ +"""\ + //x for x < 1\\\n\ + || |\n\ +-|< 2 |\n\ + ||x otherwise|\n\ + \\\\ /\ +""" + ucode_str = \ +"""\ + ⎛⎧x for x < 1⎞\n\ + ⎜⎪ ⎟\n\ +-⎜⎨ 2 ⎟\n\ + ⎜⎪x otherwise⎟\n\ + ⎝⎩ ⎠\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = x + Piecewise((x, x > 0), (y, True)) + Piecewise((x/y, x < 2), + (y**2, x > 2), (1, True)) + 1 + ascii_str = \ +"""\ + //x \\ \n\ + ||- for x < 2| \n\ + ||y | \n\ + //x for x > 0\\ || | \n\ +x + |< | + |< 2 | + 1\n\ + \\\\y otherwise/ ||y for x > 2| \n\ + || | \n\ + ||1 otherwise| \n\ + \\\\ / \ +""" + ucode_str = \ +"""\ + ⎛⎧x ⎞ \n\ + ⎜⎪─ for x < 2⎟ \n\ + ⎜⎪y ⎟ \n\ + ⎛⎧x for x > 0⎞ ⎜⎪ ⎟ \n\ +x + ⎜⎨ ⎟ + ⎜⎨ 2 ⎟ + 1\n\ + ⎝⎩y otherwise⎠ ⎜⎪y for x > 2⎟ \n\ + ⎜⎪ ⎟ \n\ + ⎜⎪1 otherwise⎟ \n\ + ⎝⎩ ⎠ \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = x - Piecewise((x, x > 0), (y, True)) + Piecewise((x/y, x < 2), + (y**2, x > 2), (1, True)) + 1 + ascii_str = \ +"""\ + //x \\ \n\ + ||- for x < 2| \n\ + ||y | \n\ + //x for x > 0\\ || | \n\ +x - |< | + |< 2 | + 1\n\ + \\\\y otherwise/ ||y for x > 2| \n\ + || | \n\ + ||1 otherwise| \n\ + \\\\ / \ +""" + ucode_str = \ +"""\ + ⎛⎧x ⎞ \n\ + ⎜⎪─ for x < 2⎟ \n\ + ⎜⎪y ⎟ \n\ + ⎛⎧x for x > 0⎞ ⎜⎪ ⎟ \n\ +x - ⎜⎨ ⎟ + ⎜⎨ 2 ⎟ + 1\n\ + ⎝⎩y otherwise⎠ ⎜⎪y for x > 2⎟ \n\ + ⎜⎪ ⎟ \n\ + ⎜⎪1 otherwise⎟ \n\ + ⎝⎩ ⎠ \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = x*Piecewise((x, x > 0), (y, True)) + ascii_str = \ +"""\ + //x for x > 0\\\n\ +x*|< |\n\ + \\\\y otherwise/\ +""" + ucode_str = \ +"""\ + ⎛⎧x for x > 0⎞\n\ +x⋅⎜⎨ ⎟\n\ + ⎝⎩y otherwise⎠\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = Piecewise((x, x > 0), (y, True))*Piecewise((x/y, x < 2), (y**2, x > + 2), (1, True)) + ascii_str = \ +"""\ + //x \\\n\ + ||- for x < 2|\n\ + ||y |\n\ +//x for x > 0\\ || |\n\ +|< |*|< 2 |\n\ +\\\\y otherwise/ ||y for x > 2|\n\ + || |\n\ + ||1 otherwise|\n\ + \\\\ /\ +""" + ucode_str = \ +"""\ + ⎛⎧x ⎞\n\ + ⎜⎪─ for x < 2⎟\n\ + ⎜⎪y ⎟\n\ +⎛⎧x for x > 0⎞ ⎜⎪ ⎟\n\ +⎜⎨ ⎟⋅⎜⎨ 2 ⎟\n\ +⎝⎩y otherwise⎠ ⎜⎪y for x > 2⎟\n\ + ⎜⎪ ⎟\n\ + ⎜⎪1 otherwise⎟\n\ + ⎝⎩ ⎠\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = -Piecewise((x, x > 0), (y, True))*Piecewise((x/y, x < 2), (y**2, x + > 2), (1, True)) + ascii_str = \ +"""\ + //x \\\n\ + ||- for x < 2|\n\ + ||y |\n\ + //x for x > 0\\ || |\n\ +-|< |*|< 2 |\n\ + \\\\y otherwise/ ||y for x > 2|\n\ + || |\n\ + ||1 otherwise|\n\ + \\\\ /\ +""" + ucode_str = \ +"""\ + ⎛⎧x ⎞\n\ + ⎜⎪─ for x < 2⎟\n\ + ⎜⎪y ⎟\n\ + ⎛⎧x for x > 0⎞ ⎜⎪ ⎟\n\ +-⎜⎨ ⎟⋅⎜⎨ 2 ⎟\n\ + ⎝⎩y otherwise⎠ ⎜⎪y for x > 2⎟\n\ + ⎜⎪ ⎟\n\ + ⎜⎪1 otherwise⎟\n\ + ⎝⎩ ⎠\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = Piecewise((0, Abs(1/y) < 1), (1, Abs(y) < 1), (y*meijerg(((2, 1), + ()), ((), (1, 0)), 1/y), True)) + ascii_str = \ +"""\ +/ 1 \n\ +| 0 for --- < 1\n\ +| |y| \n\ +| \n\ +< 1 for |y| < 1\n\ +| \n\ +| __0, 2 /1, 2 | 1\\ \n\ +|y*/__ | | -| otherwise \n\ +\\ \\_|2, 2 \\ 0, 1 | y/ \ +""" + ucode_str = \ +"""\ +⎧ 1 \n\ +⎪ 0 for ─── < 1\n\ +⎪ │y│ \n\ +⎪ \n\ +⎨ 1 for │y│ < 1\n\ +⎪ \n\ +⎪ ╭─╮0, 2 ⎛1, 2 │ 1⎞ \n\ +⎪y⋅│╶┐ ⎜ │ ─⎟ otherwise \n\ +⎩ ╰─╯2, 2 ⎝ 0, 1 │ y⎠ \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + # XXX: We have to use evaluate=False here because Piecewise._eval_power + # denests the power. + expr = Pow(Piecewise((x, x > 0), (y, True)), 2, evaluate=False) + ascii_str = \ +"""\ + 2\n\ +//x for x > 0\\ \n\ +|< | \n\ +\\\\y otherwise/ \ +""" + ucode_str = \ +"""\ + 2\n\ +⎛⎧x for x > 0⎞ \n\ +⎜⎨ ⎟ \n\ +⎝⎩y otherwise⎠ \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + +def test_pretty_ITE(): + expr = ITE(x, y, z) + assert pretty(expr) == ( + '/y for x \n' + '< \n' + '\\z otherwise' + ) + assert upretty(expr) == """\ +⎧y for x \n\ +⎨ \n\ +⎩z otherwise\ +""" + + +def test_pretty_seq(): + expr = () + ascii_str = \ +"""\ +()\ +""" + ucode_str = \ +"""\ +()\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = [] + ascii_str = \ +"""\ +[]\ +""" + ucode_str = \ +"""\ +[]\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = {} + expr_2 = {} + ascii_str = \ +"""\ +{}\ +""" + ucode_str = \ +"""\ +{}\ +""" + assert pretty(expr) == ascii_str + assert pretty(expr_2) == ascii_str + assert upretty(expr) == ucode_str + assert upretty(expr_2) == ucode_str + + expr = (1/x,) + ascii_str = \ +"""\ + 1 \n\ +(-,)\n\ + x \ +""" + ucode_str = \ +"""\ +⎛1 ⎞\n\ +⎜─,⎟\n\ +⎝x ⎠\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = [x**2, 1/x, x, y, sin(th)**2/cos(ph)**2] + ascii_str = \ +"""\ + 2 \n\ + 2 1 sin (theta) \n\ +[x , -, x, y, -----------]\n\ + x 2 \n\ + cos (phi) \ +""" + ucode_str = \ +"""\ +⎡ 2 ⎤\n\ +⎢ 2 1 sin (θ)⎥\n\ +⎢x , ─, x, y, ───────⎥\n\ +⎢ x 2 ⎥\n\ +⎣ cos (φ)⎦\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = (x**2, 1/x, x, y, sin(th)**2/cos(ph)**2) + ascii_str = \ +"""\ + 2 \n\ + 2 1 sin (theta) \n\ +(x , -, x, y, -----------)\n\ + x 2 \n\ + cos (phi) \ +""" + ucode_str = \ +"""\ +⎛ 2 ⎞\n\ +⎜ 2 1 sin (θ)⎟\n\ +⎜x , ─, x, y, ───────⎟\n\ +⎜ x 2 ⎟\n\ +⎝ cos (φ)⎠\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = Tuple(x**2, 1/x, x, y, sin(th)**2/cos(ph)**2) + ascii_str = \ +"""\ + 2 \n\ + 2 1 sin (theta) \n\ +(x , -, x, y, -----------)\n\ + x 2 \n\ + cos (phi) \ +""" + ucode_str = \ +"""\ +⎛ 2 ⎞\n\ +⎜ 2 1 sin (θ)⎟\n\ +⎜x , ─, x, y, ───────⎟\n\ +⎜ x 2 ⎟\n\ +⎝ cos (φ)⎠\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = {x: sin(x)} + expr_2 = Dict({x: sin(x)}) + ascii_str = \ +"""\ +{x: sin(x)}\ +""" + ucode_str = \ +"""\ +{x: sin(x)}\ +""" + assert pretty(expr) == ascii_str + assert pretty(expr_2) == ascii_str + assert upretty(expr) == ucode_str + assert upretty(expr_2) == ucode_str + + expr = {1/x: 1/y, x: sin(x)**2} + expr_2 = Dict({1/x: 1/y, x: sin(x)**2}) + ascii_str = \ +"""\ + 1 1 2 \n\ +{-: -, x: sin (x)}\n\ + x y \ +""" + ucode_str = \ +"""\ +⎧1 1 2 ⎫\n\ +⎨─: ─, x: sin (x)⎬\n\ +⎩x y ⎭\ +""" + assert pretty(expr) == ascii_str + assert pretty(expr_2) == ascii_str + assert upretty(expr) == ucode_str + assert upretty(expr_2) == ucode_str + + # There used to be a bug with pretty-printing sequences of even height. + expr = [x**2] + ascii_str = \ +"""\ + 2 \n\ +[x ]\ +""" + ucode_str = \ +"""\ +⎡ 2⎤\n\ +⎣x ⎦\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = (x**2,) + ascii_str = \ +"""\ + 2 \n\ +(x ,)\ +""" + ucode_str = \ +"""\ +⎛ 2 ⎞\n\ +⎝x ,⎠\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = Tuple(x**2) + ascii_str = \ +"""\ + 2 \n\ +(x ,)\ +""" + ucode_str = \ +"""\ +⎛ 2 ⎞\n\ +⎝x ,⎠\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = {x**2: 1} + expr_2 = Dict({x**2: 1}) + ascii_str = \ +"""\ + 2 \n\ +{x : 1}\ +""" + ucode_str = \ +"""\ +⎧ 2 ⎫\n\ +⎨x : 1⎬\n\ +⎩ ⎭\ +""" + assert pretty(expr) == ascii_str + assert pretty(expr_2) == ascii_str + assert upretty(expr) == ucode_str + assert upretty(expr_2) == ucode_str + + +def test_any_object_in_sequence(): + # Cf. issue 5306 + b1 = Basic() + b2 = Basic(Basic()) + + expr = [b2, b1] + assert pretty(expr) == "[Basic(Basic()), Basic()]" + assert upretty(expr) == "[Basic(Basic()), Basic()]" + + expr = {b2, b1} + assert pretty(expr) == "{Basic(), Basic(Basic())}" + assert upretty(expr) == "{Basic(), Basic(Basic())}" + + expr = {b2: b1, b1: b2} + expr2 = Dict({b2: b1, b1: b2}) + assert pretty(expr) == "{Basic(): Basic(Basic()), Basic(Basic()): Basic()}" + assert pretty( + expr2) == "{Basic(): Basic(Basic()), Basic(Basic()): Basic()}" + assert upretty( + expr) == "{Basic(): Basic(Basic()), Basic(Basic()): Basic()}" + assert upretty( + expr2) == "{Basic(): Basic(Basic()), Basic(Basic()): Basic()}" + + +def test_print_builtin_set(): + assert pretty(set()) == 'set()' + assert upretty(set()) == 'set()' + + assert pretty(frozenset()) == 'frozenset()' + assert upretty(frozenset()) == 'frozenset()' + + s1 = {1/x, x} + s2 = frozenset(s1) + + assert pretty(s1) == \ +"""\ + 1 \n\ +{-, x} + x \ +""" + assert upretty(s1) == \ +"""\ +⎧1 ⎫ +⎨─, x⎬ +⎩x ⎭\ +""" + + assert pretty(s2) == \ +"""\ + 1 \n\ +frozenset({-, x}) + x \ +""" + assert upretty(s2) == \ +"""\ + ⎛⎧1 ⎫⎞ +frozenset⎜⎨─, x⎬⎟ + ⎝⎩x ⎭⎠\ +""" + + +def test_pretty_sets(): + s = FiniteSet + assert pretty(s(*[x*y, x**2])) == \ +"""\ + 2 \n\ +{x , x*y}\ +""" + assert pretty(s(*range(1, 6))) == "{1, 2, 3, 4, 5}" + assert pretty(s(*range(1, 13))) == "{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}" + + assert pretty({x*y, x**2}) == \ +"""\ + 2 \n\ +{x , x*y}\ +""" + assert pretty(set(range(1, 6))) == "{1, 2, 3, 4, 5}" + assert pretty(set(range(1, 13))) == \ + "{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}" + + assert pretty(frozenset([x*y, x**2])) == \ +"""\ + 2 \n\ +frozenset({x , x*y})\ +""" + assert pretty(frozenset(range(1, 6))) == "frozenset({1, 2, 3, 4, 5})" + assert pretty(frozenset(range(1, 13))) == \ + "frozenset({1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12})" + + assert pretty(Range(0, 3, 1)) == '{0, 1, 2}' + + ascii_str = '{0, 1, ..., 29}' + ucode_str = '{0, 1, …, 29}' + assert pretty(Range(0, 30, 1)) == ascii_str + assert upretty(Range(0, 30, 1)) == ucode_str + + ascii_str = '{30, 29, ..., 2}' + ucode_str = '{30, 29, …, 2}' + assert pretty(Range(30, 1, -1)) == ascii_str + assert upretty(Range(30, 1, -1)) == ucode_str + + ascii_str = '{0, 2, ...}' + ucode_str = '{0, 2, …}' + assert pretty(Range(0, oo, 2)) == ascii_str + assert upretty(Range(0, oo, 2)) == ucode_str + + ascii_str = '{..., 2, 0}' + ucode_str = '{…, 2, 0}' + assert pretty(Range(oo, -2, -2)) == ascii_str + assert upretty(Range(oo, -2, -2)) == ucode_str + + ascii_str = '{-2, -3, ...}' + ucode_str = '{-2, -3, …}' + assert pretty(Range(-2, -oo, -1)) == ascii_str + assert upretty(Range(-2, -oo, -1)) == ucode_str + + +def test_pretty_SetExpr(): + iv = Interval(1, 3) + se = SetExpr(iv) + ascii_str = "SetExpr([1, 3])" + ucode_str = "SetExpr([1, 3])" + assert pretty(se) == ascii_str + assert upretty(se) == ucode_str + + +def test_pretty_ImageSet(): + imgset = ImageSet(Lambda((x, y), x + y), {1, 2, 3}, {3, 4}) + ascii_str = '{x + y | x in {1, 2, 3}, y in {3, 4}}' + ucode_str = '{x + y │ x ∊ {1, 2, 3}, y ∊ {3, 4}}' + assert pretty(imgset) == ascii_str + assert upretty(imgset) == ucode_str + + imgset = ImageSet(Lambda(((x, y),), x + y), ProductSet({1, 2, 3}, {3, 4})) + ascii_str = '{x + y | (x, y) in {1, 2, 3} x {3, 4}}' + ucode_str = '{x + y │ (x, y) ∊ {1, 2, 3} × {3, 4}}' + assert pretty(imgset) == ascii_str + assert upretty(imgset) == ucode_str + + imgset = ImageSet(Lambda(x, x**2), S.Naturals) + ascii_str = '''\ + 2 \n\ +{x | x in Naturals}''' + ucode_str = '''\ +⎧ 2 │ ⎫\n\ +⎨x │ x ∊ ℕ⎬\n\ +⎩ │ ⎭''' + assert pretty(imgset) == ascii_str + assert upretty(imgset) == ucode_str + + # TODO: The "x in N" parts below should be centered independently of the + # 1/x**2 fraction + imgset = ImageSet(Lambda(x, 1/x**2), S.Naturals) + ascii_str = '''\ + 1 \n\ +{-- | x in Naturals} + 2 \n\ + x ''' + ucode_str = '''\ +⎧1 │ ⎫\n\ +⎪── │ x ∊ ℕ⎪\n\ +⎨ 2 │ ⎬\n\ +⎪x │ ⎪\n\ +⎩ │ ⎭''' + assert pretty(imgset) == ascii_str + assert upretty(imgset) == ucode_str + + imgset = ImageSet(Lambda((x, y), 1/(x + y)**2), S.Naturals, S.Naturals) + ascii_str = '''\ + 1 \n\ +{-------- | x in Naturals, y in Naturals} + 2 \n\ + (x + y) ''' + ucode_str = '''\ +⎧ 1 │ ⎫ +⎪──────── │ x ∊ ℕ, y ∊ ℕ⎪ +⎨ 2 │ ⎬ +⎪(x + y) │ ⎪ +⎩ │ ⎭''' + assert pretty(imgset) == ascii_str + assert upretty(imgset) == ucode_str + + # issue 23449 centering issue + assert upretty([Symbol("ihat") / (Symbol("i") + 1)]) == '''\ +⎡ î ⎤ +⎢─────⎥ +⎣i + 1⎦\ +''' + assert upretty(Matrix([Symbol("ihat"), Symbol("i") + 1])) == '''\ +⎡ î ⎤ +⎢ ⎥ +⎣i + 1⎦\ +''' + + +def test_pretty_ConditionSet(): + ascii_str = '{x | x in (-oo, oo) and sin(x) = 0}' + ucode_str = '{x │ x ∊ ℝ ∧ (sin(x) = 0)}' + assert pretty(ConditionSet(x, Eq(sin(x), 0), S.Reals)) == ascii_str + assert upretty(ConditionSet(x, Eq(sin(x), 0), S.Reals)) == ucode_str + + assert pretty(ConditionSet(x, Contains(x, S.Reals, evaluate=False), FiniteSet(1))) == '{1}' + assert upretty(ConditionSet(x, Contains(x, S.Reals, evaluate=False), FiniteSet(1))) == '{1}' + + assert pretty(ConditionSet(x, And(x > 1, x < -1), FiniteSet(1, 2, 3))) == "EmptySet" + assert upretty(ConditionSet(x, And(x > 1, x < -1), FiniteSet(1, 2, 3))) == "∅" + + assert pretty(ConditionSet(x, Or(x > 1, x < -1), FiniteSet(1, 2))) == '{2}' + assert upretty(ConditionSet(x, Or(x > 1, x < -1), FiniteSet(1, 2))) == '{2}' + + condset = ConditionSet(x, 1/x**2 > 0) + ascii_str = '''\ + 1 \n\ +{x | -- > 0} + 2 \n\ + x ''' + ucode_str = '''\ +⎧ │ ⎛1 ⎞⎫ +⎪x │ ⎜── > 0⎟⎪ +⎨ │ ⎜ 2 ⎟⎬ +⎪ │ ⎝x ⎠⎪ +⎩ │ ⎭''' + assert pretty(condset) == ascii_str + assert upretty(condset) == ucode_str + + condset = ConditionSet(x, 1/x**2 > 0, S.Reals) + ascii_str = '''\ + 1 \n\ +{x | x in (-oo, oo) and -- > 0} + 2 \n\ + x ''' + ucode_str = '''\ +⎧ │ ⎛1 ⎞⎫ +⎪x │ x ∊ ℝ ∧ ⎜── > 0⎟⎪ +⎨ │ ⎜ 2 ⎟⎬ +⎪ │ ⎝x ⎠⎪ +⎩ │ ⎭''' + assert pretty(condset) == ascii_str + assert upretty(condset) == ucode_str + + +def test_pretty_ComplexRegion(): + from sympy.sets.fancysets import ComplexRegion + cregion = ComplexRegion(Interval(3, 5)*Interval(4, 6)) + ascii_str = '{x + y*I | x, y in [3, 5] x [4, 6]}' + ucode_str = '{x + y⋅ⅈ │ x, y ∊ [3, 5] × [4, 6]}' + assert pretty(cregion) == ascii_str + assert upretty(cregion) == ucode_str + + cregion = ComplexRegion(Interval(0, 1)*Interval(0, 2*pi), polar=True) + ascii_str = '{r*(I*sin(theta) + cos(theta)) | r, theta in [0, 1] x [0, 2*pi)}' + ucode_str = '{r⋅(ⅈ⋅sin(θ) + cos(θ)) │ r, θ ∊ [0, 1] × [0, 2⋅π)}' + assert pretty(cregion) == ascii_str + assert upretty(cregion) == ucode_str + + cregion = ComplexRegion(Interval(3, 1/a**2)*Interval(4, 6)) + ascii_str = '''\ + 1 \n\ +{x + y*I | x, y in [3, --] x [4, 6]} + 2 \n\ + a ''' + ucode_str = '''\ +⎧ │ ⎡ 1 ⎤ ⎫ +⎪x + y⋅ⅈ │ x, y ∊ ⎢3, ──⎥ × [4, 6]⎪ +⎨ │ ⎢ 2⎥ ⎬ +⎪ │ ⎣ a ⎦ ⎪ +⎩ │ ⎭''' + assert pretty(cregion) == ascii_str + assert upretty(cregion) == ucode_str + + cregion = ComplexRegion(Interval(0, 1/a**2)*Interval(0, 2*pi), polar=True) + ascii_str = '''\ + 1 \n\ +{r*(I*sin(theta) + cos(theta)) | r, theta in [0, --] x [0, 2*pi)} + 2 \n\ + a ''' + ucode_str = '''\ +⎧ │ ⎡ 1 ⎤ ⎫ +⎪r⋅(ⅈ⋅sin(θ) + cos(θ)) │ r, θ ∊ ⎢0, ──⎥ × [0, 2⋅π)⎪ +⎨ │ ⎢ 2⎥ ⎬ +⎪ │ ⎣ a ⎦ ⎪ +⎩ │ ⎭''' + assert pretty(cregion) == ascii_str + assert upretty(cregion) == ucode_str + + +def test_pretty_Union_issue_10414(): + a, b = Interval(2, 3), Interval(4, 7) + ucode_str = '[2, 3] ∪ [4, 7]' + ascii_str = '[2, 3] U [4, 7]' + assert upretty(Union(a, b)) == ucode_str + assert pretty(Union(a, b)) == ascii_str + + +def test_pretty_Intersection_issue_10414(): + x, y, z, w = symbols('x, y, z, w') + a, b = Interval(x, y), Interval(z, w) + ucode_str = '[x, y] ∩ [z, w]' + ascii_str = '[x, y] n [z, w]' + assert upretty(Intersection(a, b)) == ucode_str + assert pretty(Intersection(a, b)) == ascii_str + + +def test_ProductSet_exponent(): + ucode_str = ' 1\n[0, 1] ' + assert upretty(Interval(0, 1)**1) == ucode_str + ucode_str = ' 2\n[0, 1] ' + assert upretty(Interval(0, 1)**2) == ucode_str + + +def test_ProductSet_parenthesis(): + ucode_str = '([4, 7] × {1, 2}) ∪ ([2, 3] × [4, 7])' + + a, b = Interval(2, 3), Interval(4, 7) + assert upretty(Union(a*b, b*FiniteSet(1, 2))) == ucode_str + + +def test_ProductSet_prod_char_issue_10413(): + ascii_str = '[2, 3] x [4, 7]' + ucode_str = '[2, 3] × [4, 7]' + + a, b = Interval(2, 3), Interval(4, 7) + assert pretty(a*b) == ascii_str + assert upretty(a*b) == ucode_str + + +def test_pretty_sequences(): + s1 = SeqFormula(a**2, (0, oo)) + s2 = SeqPer((1, 2)) + + ascii_str = '[0, 1, 4, 9, ...]' + ucode_str = '[0, 1, 4, 9, …]' + + assert pretty(s1) == ascii_str + assert upretty(s1) == ucode_str + + ascii_str = '[1, 2, 1, 2, ...]' + ucode_str = '[1, 2, 1, 2, …]' + assert pretty(s2) == ascii_str + assert upretty(s2) == ucode_str + + s3 = SeqFormula(a**2, (0, 2)) + s4 = SeqPer((1, 2), (0, 2)) + + ascii_str = '[0, 1, 4]' + ucode_str = '[0, 1, 4]' + + assert pretty(s3) == ascii_str + assert upretty(s3) == ucode_str + + ascii_str = '[1, 2, 1]' + ucode_str = '[1, 2, 1]' + assert pretty(s4) == ascii_str + assert upretty(s4) == ucode_str + + s5 = SeqFormula(a**2, (-oo, 0)) + s6 = SeqPer((1, 2), (-oo, 0)) + + ascii_str = '[..., 9, 4, 1, 0]' + ucode_str = '[…, 9, 4, 1, 0]' + + assert pretty(s5) == ascii_str + assert upretty(s5) == ucode_str + + ascii_str = '[..., 2, 1, 2, 1]' + ucode_str = '[…, 2, 1, 2, 1]' + assert pretty(s6) == ascii_str + assert upretty(s6) == ucode_str + + ascii_str = '[1, 3, 5, 11, ...]' + ucode_str = '[1, 3, 5, 11, …]' + + assert pretty(SeqAdd(s1, s2)) == ascii_str + assert upretty(SeqAdd(s1, s2)) == ucode_str + + ascii_str = '[1, 3, 5]' + ucode_str = '[1, 3, 5]' + + assert pretty(SeqAdd(s3, s4)) == ascii_str + assert upretty(SeqAdd(s3, s4)) == ucode_str + + ascii_str = '[..., 11, 5, 3, 1]' + ucode_str = '[…, 11, 5, 3, 1]' + + assert pretty(SeqAdd(s5, s6)) == ascii_str + assert upretty(SeqAdd(s5, s6)) == ucode_str + + ascii_str = '[0, 2, 4, 18, ...]' + ucode_str = '[0, 2, 4, 18, …]' + + assert pretty(SeqMul(s1, s2)) == ascii_str + assert upretty(SeqMul(s1, s2)) == ucode_str + + ascii_str = '[0, 2, 4]' + ucode_str = '[0, 2, 4]' + + assert pretty(SeqMul(s3, s4)) == ascii_str + assert upretty(SeqMul(s3, s4)) == ucode_str + + ascii_str = '[..., 18, 4, 2, 0]' + ucode_str = '[…, 18, 4, 2, 0]' + + assert pretty(SeqMul(s5, s6)) == ascii_str + assert upretty(SeqMul(s5, s6)) == ucode_str + + # Sequences with symbolic limits, issue 12629 + s7 = SeqFormula(a**2, (a, 0, x)) + raises(NotImplementedError, lambda: pretty(s7)) + raises(NotImplementedError, lambda: upretty(s7)) + + b = Symbol('b') + s8 = SeqFormula(b*a**2, (a, 0, 2)) + ascii_str = '[0, b, 4*b]' + ucode_str = '[0, b, 4⋅b]' + assert pretty(s8) == ascii_str + assert upretty(s8) == ucode_str + + +def test_pretty_FourierSeries(): + f = fourier_series(x, (x, -pi, pi)) + + ascii_str = \ +"""\ + 2*sin(3*x) \n\ +2*sin(x) - sin(2*x) + ---------- + ...\n\ + 3 \ +""" + + ucode_str = \ +"""\ + 2⋅sin(3⋅x) \n\ +2⋅sin(x) - sin(2⋅x) + ────────── + …\n\ + 3 \ +""" + + assert pretty(f) == ascii_str + assert upretty(f) == ucode_str + + +def test_pretty_FormalPowerSeries(): + f = fps(log(1 + x)) + + + ascii_str = \ +"""\ + oo \n\ +____ \n\ +\\ ` \n\ + \\ -k k \n\ + \\ -(-1) *x \n\ + / -----------\n\ + / k \n\ +/___, \n\ +k = 1 \ +""" + + ucode_str = \ +"""\ + ∞ \n\ +____ \n\ +╲ \n\ + ╲ -k k \n\ + ╲ -(-1) ⋅x \n\ + ╱ ───────────\n\ + ╱ k \n\ +╱ \n\ +‾‾‾‾ \n\ +k = 1 \ +""" + + assert pretty(f) == ascii_str + assert upretty(f) == ucode_str + + +def test_pretty_limits(): + expr = Limit(x, x, oo) + ascii_str = \ +"""\ + lim x\n\ +x->oo \ +""" + ucode_str = \ +"""\ +lim x\n\ +x─→∞ \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = Limit(x**2, x, 0) + ascii_str = \ +"""\ + 2\n\ + lim x \n\ +x->0+ \ +""" + ucode_str = \ +"""\ + 2\n\ + lim x \n\ +x─→0⁺ \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = Limit(1/x, x, 0) + ascii_str = \ +"""\ + 1\n\ + lim -\n\ +x->0+x\ +""" + ucode_str = \ +"""\ + 1\n\ + lim ─\n\ +x─→0⁺x\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = Limit(sin(x)/x, x, 0) + ascii_str = \ +"""\ + /sin(x)\\\n\ + lim |------|\n\ +x->0+\\ x /\ +""" + ucode_str = \ +"""\ + ⎛sin(x)⎞\n\ + lim ⎜──────⎟\n\ +x─→0⁺⎝ x ⎠\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = Limit(sin(x)/x, x, 0, "-") + ascii_str = \ +"""\ + /sin(x)\\\n\ + lim |------|\n\ +x->0-\\ x /\ +""" + ucode_str = \ +"""\ + ⎛sin(x)⎞\n\ + lim ⎜──────⎟\n\ +x─→0⁻⎝ x ⎠\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = Limit(x + sin(x), x, 0) + ascii_str = \ +"""\ + lim (x + sin(x))\n\ +x->0+ \ +""" + ucode_str = \ +"""\ + lim (x + sin(x))\n\ +x─→0⁺ \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = Limit(x, x, 0)**2 + ascii_str = \ +"""\ + 2\n\ +/ lim x\\ \n\ +\\x->0+ / \ +""" + ucode_str = \ +"""\ + 2\n\ +⎛ lim x⎞ \n\ +⎝x─→0⁺ ⎠ \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = Limit(x*Limit(y/2,y,0), x, 0) + ascii_str = \ +"""\ + / /y\\\\\n\ + lim |x* lim |-||\n\ +x->0+\\ y->0+\\2//\ +""" + ucode_str = \ +"""\ + ⎛ ⎛y⎞⎞\n\ + lim ⎜x⋅ lim ⎜─⎟⎟\n\ +x─→0⁺⎝ y─→0⁺⎝2⎠⎠\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = 2*Limit(x*Limit(y/2,y,0), x, 0) + ascii_str = \ +"""\ + / /y\\\\\n\ +2* lim |x* lim |-||\n\ + x->0+\\ y->0+\\2//\ +""" + ucode_str = \ +"""\ + ⎛ ⎛y⎞⎞\n\ +2⋅ lim ⎜x⋅ lim ⎜─⎟⎟\n\ + x─→0⁺⎝ y─→0⁺⎝2⎠⎠\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = Limit(sin(x), x, 0, dir='+-') + ascii_str = \ +"""\ +lim sin(x)\n\ +x->0 \ +""" + ucode_str = \ +"""\ +lim sin(x)\n\ +x─→0 \ +""" + + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + +def test_pretty_ComplexRootOf(): + expr = rootof(x**5 + 11*x - 2, 0) + ascii_str = \ +"""\ + / 5 \\\n\ +CRootOf\\x + 11*x - 2, 0/\ +""" + ucode_str = \ +"""\ + ⎛ 5 ⎞\n\ +CRootOf⎝x + 11⋅x - 2, 0⎠\ +""" + + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + +def test_pretty_RootSum(): + expr = RootSum(x**5 + 11*x - 2, auto=False) + ascii_str = \ +"""\ + / 5 \\\n\ +RootSum\\x + 11*x - 2/\ +""" + ucode_str = \ +"""\ + ⎛ 5 ⎞\n\ +RootSum⎝x + 11⋅x - 2⎠\ +""" + + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = RootSum(x**5 + 11*x - 2, Lambda(z, exp(z))) + ascii_str = \ +"""\ + / 5 z\\\n\ +RootSum\\x + 11*x - 2, z -> e /\ +""" + ucode_str = \ +"""\ + ⎛ 5 z⎞\n\ +RootSum⎝x + 11⋅x - 2, z ↦ ℯ ⎠\ +""" + + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + +def test_GroebnerBasis(): + expr = groebner([], x, y) + + ascii_str = \ +"""\ +GroebnerBasis([], x, y, domain=ZZ, order=lex)\ +""" + ucode_str = \ +"""\ +GroebnerBasis([], x, y, domain=ℤ, order=lex)\ +""" + + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + F = [x**2 - 3*y - x + 1, y**2 - 2*x + y - 1] + expr = groebner(F, x, y, order='grlex') + + ascii_str = \ +"""\ + /[ 2 2 ] \\\n\ +GroebnerBasis\\[x - x - 3*y + 1, y - 2*x + y - 1], x, y, domain=ZZ, order=grlex/\ +""" + ucode_str = \ +"""\ + ⎛⎡ 2 2 ⎤ ⎞\n\ +GroebnerBasis⎝⎣x - x - 3⋅y + 1, y - 2⋅x + y - 1⎦, x, y, domain=ℤ, order=grlex⎠\ +""" + + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = expr.fglm('lex') + + ascii_str = \ +"""\ + /[ 2 4 3 2 ] \\\n\ +GroebnerBasis\\[2*x - y - y + 1, y + 2*y - 3*y - 16*y + 7], x, y, domain=ZZ, order=lex/\ +""" + ucode_str = \ +"""\ + ⎛⎡ 2 4 3 2 ⎤ ⎞\n\ +GroebnerBasis⎝⎣2⋅x - y - y + 1, y + 2⋅y - 3⋅y - 16⋅y + 7⎦, x, y, domain=ℤ, order=lex⎠\ +""" + + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + +def test_pretty_UniversalSet(): + assert pretty(S.UniversalSet) == "UniversalSet" + assert upretty(S.UniversalSet) == '𝕌' + + +def test_pretty_Boolean(): + expr = Not(x, evaluate=False) + + assert pretty(expr) == "Not(x)" + assert upretty(expr) == "¬x" + + expr = And(x, y) + + assert pretty(expr) == "And(x, y)" + assert upretty(expr) == "x ∧ y" + + expr = Or(x, y) + + assert pretty(expr) == "Or(x, y)" + assert upretty(expr) == "x ∨ y" + + syms = symbols('a:f') + expr = And(*syms) + + assert pretty(expr) == "And(a, b, c, d, e, f)" + assert upretty(expr) == "a ∧ b ∧ c ∧ d ∧ e ∧ f" + + expr = Or(*syms) + + assert pretty(expr) == "Or(a, b, c, d, e, f)" + assert upretty(expr) == "a ∨ b ∨ c ∨ d ∨ e ∨ f" + + expr = Xor(x, y, evaluate=False) + + assert pretty(expr) == "Xor(x, y)" + assert upretty(expr) == "x ⊻ y" + + expr = Nand(x, y, evaluate=False) + + assert pretty(expr) == "Nand(x, y)" + assert upretty(expr) == "x ⊼ y" + + expr = Nor(x, y, evaluate=False) + + assert pretty(expr) == "Nor(x, y)" + assert upretty(expr) == "x ⊽ y" + + expr = Implies(x, y, evaluate=False) + + assert pretty(expr) == "Implies(x, y)" + assert upretty(expr) == "x → y" + + # don't sort args + expr = Implies(y, x, evaluate=False) + + assert pretty(expr) == "Implies(y, x)" + assert upretty(expr) == "y → x" + + expr = Equivalent(x, y, evaluate=False) + + assert pretty(expr) == "Equivalent(x, y)" + assert upretty(expr) == "x ⇔ y" + + expr = Equivalent(y, x, evaluate=False) + + assert pretty(expr) == "Equivalent(x, y)" + assert upretty(expr) == "x ⇔ y" + + +def test_pretty_Domain(): + expr = FF(23) + + assert pretty(expr) == "GF(23)" + assert upretty(expr) == "ℤ₂₃" + + expr = ZZ + + assert pretty(expr) == "ZZ" + assert upretty(expr) == "ℤ" + + expr = QQ + + assert pretty(expr) == "QQ" + assert upretty(expr) == "ℚ" + + expr = RR + + assert pretty(expr) == "RR" + assert upretty(expr) == "ℝ" + + expr = QQ[x] + + assert pretty(expr) == "QQ[x]" + assert upretty(expr) == "ℚ[x]" + + expr = QQ[x, y] + + assert pretty(expr) == "QQ[x, y]" + assert upretty(expr) == "ℚ[x, y]" + + expr = ZZ.frac_field(x) + + assert pretty(expr) == "ZZ(x)" + assert upretty(expr) == "ℤ(x)" + + expr = ZZ.frac_field(x, y) + + assert pretty(expr) == "ZZ(x, y)" + assert upretty(expr) == "ℤ(x, y)" + + expr = QQ.poly_ring(x, y, order=grlex) + + assert pretty(expr) == "QQ[x, y, order=grlex]" + assert upretty(expr) == "ℚ[x, y, order=grlex]" + + expr = QQ.poly_ring(x, y, order=ilex) + + assert pretty(expr) == "QQ[x, y, order=ilex]" + assert upretty(expr) == "ℚ[x, y, order=ilex]" + + +def test_pretty_prec(): + assert xpretty(S("0.3"), full_prec=True, wrap_line=False) == "0.300000000000000" + assert xpretty(S("0.3"), full_prec="auto", wrap_line=False) == "0.300000000000000" + assert xpretty(S("0.3"), full_prec=False, wrap_line=False) == "0.3" + assert xpretty(S("0.3")*x, full_prec=True, use_unicode=False, wrap_line=False) in [ + "0.300000000000000*x", + "x*0.300000000000000" + ] + assert xpretty(S("0.3")*x, full_prec="auto", use_unicode=False, wrap_line=False) in [ + "0.3*x", + "x*0.3" + ] + assert xpretty(S("0.3")*x, full_prec=False, use_unicode=False, wrap_line=False) in [ + "0.3*x", + "x*0.3" + ] + + +def test_pprint(): + import sys + from io import StringIO + fd = StringIO() + sso = sys.stdout + sys.stdout = fd + try: + pprint(pi, use_unicode=False, wrap_line=False) + finally: + sys.stdout = sso + assert fd.getvalue() == 'pi\n' + + +def test_pretty_class(): + """Test that the printer dispatcher correctly handles classes.""" + class C: + pass # C has no .__class__ and this was causing problems + + class D: + pass + + assert pretty( C ) == str( C ) + assert pretty( D ) == str( D ) + + +def test_pretty_no_wrap_line(): + huge_expr = 0 + for i in range(20): + huge_expr += i*sin(i + x) + assert xpretty(huge_expr ).find('\n') != -1 + assert xpretty(huge_expr, wrap_line=False).find('\n') == -1 + + +def test_settings(): + raises(TypeError, lambda: pretty(S(4), method="garbage")) + + +def test_pretty_sum(): + from sympy.abc import x, a, b, k, m, n + + expr = Sum(k**k, (k, 0, n)) + ascii_str = \ +"""\ + n \n\ +___ \n\ +\\ ` \n\ + \\ k\n\ + / k \n\ +/__, \n\ +k = 0 \ +""" + ucode_str = \ +"""\ + n \n\ + ___ \n\ + ╲ \n\ + ╲ k\n\ + ╱ k \n\ + ╱ \n\ + ‾‾‾ \n\ +k = 0 \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = Sum(k**k, (k, oo, n)) + ascii_str = \ +"""\ + n \n\ + ___ \n\ + \\ ` \n\ + \\ k\n\ + / k \n\ + /__, \n\ +k = oo \ +""" + ucode_str = \ +"""\ + n \n\ + ___ \n\ + ╲ \n\ + ╲ k\n\ + ╱ k \n\ + ╱ \n\ + ‾‾‾ \n\ +k = ∞ \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = Sum(k**(Integral(x**n, (x, -oo, oo))), (k, 0, n**n)) + ascii_str = \ +"""\ + n \n\ + n \n\ +______ \n\ +\\ ` \n\ + \\ oo \n\ + \\ / \n\ + \\ | \n\ + \\ | n \n\ + ) | x dx\n\ + / | \n\ + / / \n\ + / -oo \n\ + / k \n\ +/_____, \n\ + k = 0 \ +""" + ucode_str = \ +"""\ + n \n\ + n \n\ +______ \n\ +╲ \n\ + ╲ \n\ + ╲ ∞ \n\ + ╲ ⌠ \n\ + ╲ ⎮ n \n\ + ╱ ⎮ x dx\n\ + ╱ ⌡ \n\ + ╱ -∞ \n\ + ╱ k \n\ +╱ \n\ +‾‾‾‾‾‾ \n\ +k = 0 \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = Sum(k**( + Integral(x**n, (x, -oo, oo))), (k, 0, Integral(x**x, (x, -oo, oo)))) + ascii_str = \ +"""\ + oo \n\ + / \n\ + | \n\ + | x \n\ + | x dx \n\ + | \n\ +/ \n\ +-oo \n\ + ______ \n\ + \\ ` \n\ + \\ oo \n\ + \\ / \n\ + \\ | \n\ + \\ | n \n\ + ) | x dx\n\ + / | \n\ + / / \n\ + / -oo \n\ + / k \n\ + /_____, \n\ + k = 0 \ +""" + ucode_str = \ +"""\ +∞ \n\ +⌠ \n\ +⎮ x \n\ +⎮ x dx \n\ +⌡ \n\ +-∞ \n\ + ______ \n\ + ╲ \n\ + ╲ \n\ + ╲ ∞ \n\ + ╲ ⌠ \n\ + ╲ ⎮ n \n\ + ╱ ⎮ x dx\n\ + ╱ ⌡ \n\ + ╱ -∞ \n\ + ╱ k \n\ + ╱ \n\ + ‾‾‾‾‾‾ \n\ + k = 0 \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = Sum(k**(Integral(x**n, (x, -oo, oo))), ( + k, x + n + x**2 + n**2 + (x/n) + (1/x), Integral(x**x, (x, -oo, oo)))) + ascii_str = \ +"""\ + oo \n\ + / \n\ + | \n\ + | x \n\ + | x dx \n\ + | \n\ + / \n\ + -oo \n\ + ______ \n\ + \\ ` \n\ + \\ oo \n\ + \\ / \n\ + \\ | \n\ + \\ | n \n\ + ) | x dx\n\ + / | \n\ + / / \n\ + / -oo \n\ + / k \n\ + /_____, \n\ + 2 2 1 x \n\ +k = n + n + x + x + - + - \n\ + x n \ +""" + ucode_str = \ +"""\ + ∞ \n\ + ⌠ \n\ + ⎮ x \n\ + ⎮ x dx \n\ + ⌡ \n\ + -∞ \n\ + ______ \n\ + ╲ \n\ + ╲ \n\ + ╲ ∞ \n\ + ╲ ⌠ \n\ + ╲ ⎮ n \n\ + ╱ ⎮ x dx\n\ + ╱ ⌡ \n\ + ╱ -∞ \n\ + ╱ k \n\ + ╱ \n\ + ‾‾‾‾‾‾ \n\ + 2 2 1 x \n\ +k = n + n + x + x + ─ + ─ \n\ + x n \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = Sum(k**( + Integral(x**n, (x, -oo, oo))), (k, 0, x + n + x**2 + n**2 + (x/n) + (1/x))) + ascii_str = \ +"""\ + 2 2 1 x \n\ +n + n + x + x + - + - \n\ + x n \n\ + ______ \n\ + \\ ` \n\ + \\ oo \n\ + \\ / \n\ + \\ | \n\ + \\ | n \n\ + ) | x dx\n\ + / | \n\ + / / \n\ + / -oo \n\ + / k \n\ + /_____, \n\ + k = 0 \ +""" + ucode_str = \ +"""\ + 2 2 1 x \n\ +n + n + x + x + ─ + ─ \n\ + x n \n\ + ______ \n\ + ╲ \n\ + ╲ \n\ + ╲ ∞ \n\ + ╲ ⌠ \n\ + ╲ ⎮ n \n\ + ╱ ⎮ x dx\n\ + ╱ ⌡ \n\ + ╱ -∞ \n\ + ╱ k \n\ + ╱ \n\ + ‾‾‾‾‾‾ \n\ + k = 0 \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = Sum(x, (x, 0, oo)) + ascii_str = \ +"""\ + oo \n\ + __ \n\ + \\ ` \n\ + ) x\n\ + /_, \n\ +x = 0 \ +""" + ucode_str = \ +"""\ + ∞ \n\ + ___ \n\ + ╲ \n\ + ╲ \n\ + ╱ x\n\ + ╱ \n\ + ‾‾‾ \n\ +x = 0 \ +""" + + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = Sum(x**2, (x, 0, oo)) + ascii_str = \ +"""\ + oo \n\ +___ \n\ +\\ ` \n\ + \\ 2\n\ + / x \n\ +/__, \n\ +x = 0 \ +""" + ucode_str = \ +"""\ + ∞ \n\ + ___ \n\ + ╲ \n\ + ╲ 2\n\ + ╱ x \n\ + ╱ \n\ + ‾‾‾ \n\ +x = 0 \ +""" + + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = Sum(x/2, (x, 0, oo)) + ascii_str = \ +"""\ + oo \n\ +___ \n\ +\\ ` \n\ + \\ x\n\ + ) -\n\ + / 2\n\ +/__, \n\ +x = 0 \ +""" + ucode_str = \ +"""\ + ∞ \n\ +____ \n\ +╲ \n\ + ╲ \n\ + ╲ x\n\ + ╱ ─\n\ + ╱ 2\n\ +╱ \n\ +‾‾‾‾ \n\ +x = 0 \ +""" + + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = Sum(x**3/2, (x, 0, oo)) + ascii_str = \ +"""\ + oo \n\ +____ \n\ +\\ ` \n\ + \\ 3\n\ + \\ x \n\ + / --\n\ + / 2 \n\ +/___, \n\ +x = 0 \ +""" + ucode_str = \ +"""\ + ∞ \n\ +____ \n\ +╲ \n\ + ╲ 3\n\ + ╲ x \n\ + ╱ ──\n\ + ╱ 2 \n\ +╱ \n\ +‾‾‾‾ \n\ +x = 0 \ +""" + + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = Sum((x**3*y**(x/2))**n, (x, 0, oo)) + ascii_str = \ +"""\ + oo \n\ +____ \n\ +\\ ` \n\ + \\ n\n\ + \\ / x\\ \n\ + ) | -| \n\ + / | 3 2| \n\ + / \\x *y / \n\ +/___, \n\ +x = 0 \ +""" + ucode_str = \ +"""\ + ∞ \n\ +_____ \n\ +╲ \n\ + ╲ \n\ + ╲ n\n\ + ╲ ⎛ x⎞ \n\ + ╱ ⎜ ─⎟ \n\ + ╱ ⎜ 3 2⎟ \n\ + ╱ ⎝x ⋅y ⎠ \n\ +╱ \n\ +‾‾‾‾‾ \n\ +x = 0 \ +""" + + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = Sum(1/x**2, (x, 0, oo)) + ascii_str = \ +"""\ + oo \n\ +____ \n\ +\\ ` \n\ + \\ 1 \n\ + \\ --\n\ + / 2\n\ + / x \n\ +/___, \n\ +x = 0 \ +""" + ucode_str = \ +"""\ + ∞ \n\ +____ \n\ +╲ \n\ + ╲ 1 \n\ + ╲ ──\n\ + ╱ 2\n\ + ╱ x \n\ +╱ \n\ +‾‾‾‾ \n\ +x = 0 \ +""" + + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = Sum(1/y**(a/b), (x, 0, oo)) + ascii_str = \ +"""\ + oo \n\ +____ \n\ +\\ ` \n\ + \\ -a \n\ + \\ ---\n\ + / b \n\ + / y \n\ +/___, \n\ +x = 0 \ +""" + ucode_str = \ +"""\ + ∞ \n\ +____ \n\ +╲ \n\ + ╲ -a \n\ + ╲ ───\n\ + ╱ b \n\ + ╱ y \n\ +╱ \n\ +‾‾‾‾ \n\ +x = 0 \ +""" + + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = Sum(1/y**(a/b), (x, 0, oo), (y, 1, 2)) + ascii_str = \ +"""\ + 2 oo \n\ +____ ____ \n\ +\\ ` \\ ` \n\ + \\ \\ -a\n\ + \\ \\ --\n\ + / / b \n\ + / / y \n\ +/___, /___, \n\ +y = 1 x = 0 \ +""" + ucode_str = \ +"""\ + 2 ∞ \n\ +____ ____ \n\ +╲ ╲ \n\ + ╲ ╲ -a\n\ + ╲ ╲ ──\n\ + ╱ ╱ b \n\ + ╱ ╱ y \n\ +╱ ╱ \n\ +‾‾‾‾ ‾‾‾‾ \n\ +y = 1 x = 0 \ +""" + expr = Sum(1/(1 + 1/( + 1 + 1/k)) + 1, (k, 111, 1 + 1/n), (k, 1/(1 + m), oo)) + 1/(1 + 1/k) + ascii_str = \ +"""\ + 1 \n\ + 1 + - \n\ + oo n \n\ + _____ _____ \n\ + \\ ` \\ ` \n\ + \\ \\ / 1 \\ \n\ + \\ \\ |1 + ---------| \n\ + \\ \\ | 1 | 1 \n\ + ) ) | 1 + -----| + -----\n\ + / / | 1| 1\n\ + / / | 1 + -| 1 + -\n\ + / / \\ k/ k\n\ + /____, /____, \n\ + 1 k = 111 \n\ +k = ----- \n\ + m + 1 \ +""" + ucode_str = \ +"""\ + 1 \n\ + 1 + ─ \n\ + ∞ n \n\ + ______ ______ \n\ + ╲ ╲ \n\ + ╲ ╲ \n\ + ╲ ╲ ⎛ 1 ⎞ \n\ + ╲ ╲ ⎜1 + ─────────⎟ \n\ + ╲ ╲ ⎜ 1 ⎟ 1 \n\ + ╱ ╱ ⎜ 1 + ─────⎟ + ─────\n\ + ╱ ╱ ⎜ 1⎟ 1\n\ + ╱ ╱ ⎜ 1 + ─⎟ 1 + ─\n\ + ╱ ╱ ⎝ k⎠ k\n\ + ╱ ╱ \n\ + ‾‾‾‾‾‾ ‾‾‾‾‾‾ \n\ + 1 k = 111 \n\ +k = ───── \n\ + m + 1 \ +""" + + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + +def test_units(): + expr = joule + ascii_str1 = \ +"""\ + 2\n\ +kilogram*meter \n\ +---------------\n\ + 2 \n\ + second \ +""" + unicode_str1 = \ +"""\ + 2\n\ +kilogram⋅meter \n\ +───────────────\n\ + 2 \n\ + second \ +""" + + ascii_str2 = \ +"""\ + 2\n\ +3*x*y*kilogram*meter \n\ +---------------------\n\ + 2 \n\ + second \ +""" + unicode_str2 = \ +"""\ + 2\n\ +3⋅x⋅y⋅kilogram⋅meter \n\ +─────────────────────\n\ + 2 \n\ + second \ +""" + + from sympy.physics.units import kg, m, s + assert upretty(expr) == "joule" + assert pretty(expr) == "joule" + assert upretty(expr.convert_to(kg*m**2/s**2)) == unicode_str1 + assert pretty(expr.convert_to(kg*m**2/s**2)) == ascii_str1 + assert upretty(3*kg*x*m**2*y/s**2) == unicode_str2 + assert pretty(3*kg*x*m**2*y/s**2) == ascii_str2 + + +def test_pretty_Subs(): + f = Function('f') + expr = Subs(f(x), x, ph**2) + ascii_str = \ +"""\ +(f(x))| 2\n\ + |x=phi \ +""" + unicode_str = \ +"""\ +(f(x))│ 2\n\ + │x=φ \ +""" + + assert pretty(expr) == ascii_str + assert upretty(expr) == unicode_str + + expr = Subs(f(x).diff(x), x, 0) + ascii_str = \ +"""\ +/d \\| \n\ +|--(f(x))|| \n\ +\\dx /|x=0\ +""" + unicode_str = \ +"""\ +⎛d ⎞│ \n\ +⎜──(f(x))⎟│ \n\ +⎝dx ⎠│x=0\ +""" + + assert pretty(expr) == ascii_str + assert upretty(expr) == unicode_str + + expr = Subs(f(x).diff(x)/y, (x, y), (0, Rational(1, 2))) + ascii_str = \ +"""\ +/d \\| \n\ +|--(f(x))|| \n\ +|dx || \n\ +|--------|| \n\ +\\ y /|x=0, y=1/2\ +""" + unicode_str = \ +"""\ +⎛d ⎞│ \n\ +⎜──(f(x))⎟│ \n\ +⎜dx ⎟│ \n\ +⎜────────⎟│ \n\ +⎝ y ⎠│x=0, y=1/2\ +""" + + assert pretty(expr) == ascii_str + assert upretty(expr) == unicode_str + + +def test_gammas(): + assert upretty(lowergamma(x, y)) == "γ(x, y)" + assert upretty(uppergamma(x, y)) == "Γ(x, y)" + assert xpretty(gamma(x), use_unicode=True) == 'Γ(x)' + assert xpretty(gamma, use_unicode=True) == 'Γ' + assert xpretty(symbols('gamma', cls=Function)(x), use_unicode=True) == 'γ(x)' + assert xpretty(symbols('gamma', cls=Function), use_unicode=True) == 'γ' + + +def test_beta(): + assert xpretty(beta(x,y), use_unicode=True) == 'Β(x, y)' + assert xpretty(beta(x,y), use_unicode=False) == 'B(x, y)' + assert xpretty(beta, use_unicode=True) == 'Β' + assert xpretty(beta, use_unicode=False) == 'B' + mybeta = Function('beta') + assert xpretty(mybeta(x), use_unicode=True) == 'β(x)' + assert xpretty(mybeta(x, y, z), use_unicode=False) == 'beta(x, y, z)' + assert xpretty(mybeta, use_unicode=True) == 'β' + + +# test that notation passes to subclasses of the same name only +def test_function_subclass_different_name(): + class mygamma(gamma): + pass + assert xpretty(mygamma, use_unicode=True) == r"mygamma" + assert xpretty(mygamma(x), use_unicode=True) == r"mygamma(x)" + + +def test_SingularityFunction(): + assert xpretty(SingularityFunction(x, 0, n), use_unicode=True) == ( +"""\ + n\n\ + \ +""") + assert xpretty(SingularityFunction(x, 1, n), use_unicode=True) == ( +"""\ + n\n\ + \ +""") + assert xpretty(SingularityFunction(x, -1, n), use_unicode=True) == ( +"""\ + n\n\ + \ +""") + assert xpretty(SingularityFunction(x, a, n), use_unicode=True) == ( +"""\ + n\n\ +<-a + x> \ +""") + assert xpretty(SingularityFunction(x, y, n), use_unicode=True) == ( +"""\ + n\n\ + \ +""") + assert xpretty(SingularityFunction(x, 0, n), use_unicode=False) == ( +"""\ + n\n\ + \ +""") + assert xpretty(SingularityFunction(x, 1, n), use_unicode=False) == ( +"""\ + n\n\ + \ +""") + assert xpretty(SingularityFunction(x, -1, n), use_unicode=False) == ( +"""\ + n\n\ + \ +""") + assert xpretty(SingularityFunction(x, a, n), use_unicode=False) == ( +"""\ + n\n\ +<-a + x> \ +""") + assert xpretty(SingularityFunction(x, y, n), use_unicode=False) == ( +"""\ + n\n\ + \ +""") + + +def test_deltas(): + assert xpretty(DiracDelta(x), use_unicode=True) == 'δ(x)' + assert xpretty(DiracDelta(x, 1), use_unicode=True) == \ +"""\ + (1) \n\ +δ (x)\ +""" + assert xpretty(x*DiracDelta(x, 1), use_unicode=True) == \ +"""\ + (1) \n\ +x⋅δ (x)\ +""" + + +def test_hyper(): + expr = hyper((), (), z) + ucode_str = \ +"""\ + ┌─ ⎛ │ ⎞\n\ + ├─ ⎜ │ z⎟\n\ +0╵ 0 ⎝ │ ⎠\ +""" + ascii_str = \ +"""\ + _ \n\ + |_ / | \\\n\ + | | | z|\n\ +0 0 \\ | /\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = hyper((), (1,), x) + ucode_str = \ +"""\ + ┌─ ⎛ │ ⎞\n\ + ├─ ⎜ │ x⎟\n\ +0╵ 1 ⎝1 │ ⎠\ +""" + ascii_str = \ +"""\ + _ \n\ + |_ / | \\\n\ + | | | x|\n\ +0 1 \\1 | /\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = hyper([2], [1], x) + ucode_str = \ +"""\ + ┌─ ⎛2 │ ⎞\n\ + ├─ ⎜ │ x⎟\n\ +1╵ 1 ⎝1 │ ⎠\ +""" + ascii_str = \ +"""\ + _ \n\ + |_ /2 | \\\n\ + | | | x|\n\ +1 1 \\1 | /\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = hyper((pi/3, -2*k), (3, 4, 5, -3), x) + ucode_str = \ +"""\ + ⎛ π │ ⎞\n\ + ┌─ ⎜ ─, -2⋅k │ ⎟\n\ + ├─ ⎜ 3 │ x⎟\n\ +2╵ 4 ⎜ │ ⎟\n\ + ⎝-3, 3, 4, 5 │ ⎠\ +""" + ascii_str = \ +"""\ + \n\ + _ / pi | \\\n\ + |_ | --, -2*k | |\n\ + | | 3 | x|\n\ +2 4 | | |\n\ + \\-3, 3, 4, 5 | /\ +""" + + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = hyper((pi, S('2/3'), -2*k), (3, 4, 5, -3), x**2) + ucode_str = \ +"""\ + ┌─ ⎛2/3, π, -2⋅k │ 2⎞\n\ + ├─ ⎜ │ x ⎟\n\ +3╵ 4 ⎝-3, 3, 4, 5 │ ⎠\ +""" + ascii_str = \ +"""\ + _ \n\ + |_ /2/3, pi, -2*k | 2\\ + | | | x | +3 4 \\ -3, 3, 4, 5 | /""" + + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = hyper([1, 2], [3, 4], 1/(1/(1/(1/x + 1) + 1) + 1)) + ucode_str = \ +"""\ + ⎛ │ 1 ⎞\n\ + ⎜ │ ─────────────⎟\n\ + ⎜ │ 1 ⎟\n\ + ┌─ ⎜1, 2 │ 1 + ─────────⎟\n\ + ├─ ⎜ │ 1 ⎟\n\ +2╵ 2 ⎜3, 4 │ 1 + ─────⎟\n\ + ⎜ │ 1⎟\n\ + ⎜ │ 1 + ─⎟\n\ + ⎝ │ x⎠\ +""" + + ascii_str = \ +"""\ + \n\ + / | 1 \\\n\ + | | -------------|\n\ + _ | | 1 |\n\ + |_ |1, 2 | 1 + ---------|\n\ + | | | 1 |\n\ +2 2 |3, 4 | 1 + -----|\n\ + | | 1|\n\ + | | 1 + -|\n\ + \\ | x/\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + +def test_meijerg(): + expr = meijerg([pi, pi, x], [1], [0, 1], [1, 2, 3], z) + ucode_str = \ +"""\ +╭─╮2, 3 ⎛π, π, x 1 │ ⎞\n\ +│╶┐ ⎜ │ z⎟\n\ +╰─╯4, 5 ⎝ 0, 1 1, 2, 3 │ ⎠\ +""" + ascii_str = \ +"""\ + __2, 3 /pi, pi, x 1 | \\\n\ +/__ | | z|\n\ +\\_|4, 5 \\ 0, 1 1, 2, 3 | /\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = meijerg([1, pi/7], [2, pi, 5], [], [], z**2) + ucode_str = \ +"""\ + ⎛ π │ ⎞\n\ +╭─╮0, 2 ⎜1, ─ 2, 5, π │ 2⎟\n\ +│╶┐ ⎜ 7 │ z ⎟\n\ +╰─╯5, 0 ⎜ │ ⎟\n\ + ⎝ │ ⎠\ +""" + ascii_str = \ +"""\ + / pi | \\\n\ + __0, 2 |1, -- 2, 5, pi | 2|\n\ +/__ | 7 | z |\n\ +\\_|5, 0 | | |\n\ + \\ | /\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + ucode_str = \ +"""\ +╭─╮ 1, 10 ⎛1, 1, 1, 1, 1, 1, 1, 1, 1, 1 1 │ ⎞\n\ +│╶┐ ⎜ │ z⎟\n\ +╰─╯11, 2 ⎝ 1 1 │ ⎠\ +""" + ascii_str = \ +"""\ + __ 1, 10 /1, 1, 1, 1, 1, 1, 1, 1, 1, 1 1 | \\\n\ +/__ | | z|\n\ +\\_|11, 2 \\ 1 1 | /\ +""" + + expr = meijerg([1]*10, [1], [1], [1], z) + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = meijerg([1, 2, ], [4, 3], [3], [4, 5], 1/(1/(1/(1/x + 1) + 1) + 1)) + + ucode_str = \ +"""\ + ⎛ │ 1 ⎞\n\ + ⎜ │ ─────────────⎟\n\ + ⎜ │ 1 ⎟\n\ +╭─╮1, 2 ⎜1, 2 3, 4 │ 1 + ─────────⎟\n\ +│╶┐ ⎜ │ 1 ⎟\n\ +╰─╯4, 3 ⎜ 3 4, 5 │ 1 + ─────⎟\n\ + ⎜ │ 1⎟\n\ + ⎜ │ 1 + ─⎟\n\ + ⎝ │ x⎠\ +""" + + ascii_str = \ +"""\ + / | 1 \\\n\ + | | -------------|\n\ + | | 1 |\n\ + __1, 2 |1, 2 3, 4 | 1 + ---------|\n\ +/__ | | 1 |\n\ +\\_|4, 3 | 3 4, 5 | 1 + -----|\n\ + | | 1|\n\ + | | 1 + -|\n\ + \\ | x/\ +""" + + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = Integral(expr, x) + + ucode_str = \ +"""\ +⌠ \n\ +⎮ ⎛ │ 1 ⎞ \n\ +⎮ ⎜ │ ─────────────⎟ \n\ +⎮ ⎜ │ 1 ⎟ \n\ +⎮ ╭─╮1, 2 ⎜1, 2 3, 4 │ 1 + ─────────⎟ \n\ +⎮ │╶┐ ⎜ │ 1 ⎟ dx\n\ +⎮ ╰─╯4, 3 ⎜ 3 4, 5 │ 1 + ─────⎟ \n\ +⎮ ⎜ │ 1⎟ \n\ +⎮ ⎜ │ 1 + ─⎟ \n\ +⎮ ⎝ │ x⎠ \n\ +⌡ \ +""" + + ascii_str = \ +"""\ + / \n\ + | \n\ + | / | 1 \\ \n\ + | | | -------------| \n\ + | | | 1 | \n\ + | __1, 2 |1, 2 3, 4 | 1 + ---------| \n\ + | /__ | | 1 | dx\n\ + | \\_|4, 3 | 3 4, 5 | 1 + -----| \n\ + | | | 1| \n\ + | | | 1 + -| \n\ + | \\ | x/ \n\ + | \n\ +/ \ +""" + + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + +def test_noncommutative(): + A, B, C = symbols('A,B,C', commutative=False) + + expr = A*B*C**-1 + ascii_str = \ +"""\ + -1\n\ +A*B*C \ +""" + ucode_str = \ +"""\ + -1\n\ +A⋅B⋅C \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = C**-1*A*B + ascii_str = \ +"""\ + -1 \n\ +C *A*B\ +""" + ucode_str = \ +"""\ + -1 \n\ +C ⋅A⋅B\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = A*C**-1*B + ascii_str = \ +"""\ + -1 \n\ +A*C *B\ +""" + ucode_str = \ +"""\ + -1 \n\ +A⋅C ⋅B\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = A*C**-1*B/x + ascii_str = \ +"""\ + -1 \n\ +A*C *B\n\ +-------\n\ + x \ +""" + ucode_str = \ +"""\ + -1 \n\ +A⋅C ⋅B\n\ +───────\n\ + x \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + +def test_pretty_special_functions(): + x, y = symbols("x y") + + # atan2 + expr = atan2(y/sqrt(200), sqrt(x)) + ascii_str = \ +"""\ + / ___ \\\n\ + |\\/ 2 *y ___|\n\ +atan2|-------, \\/ x |\n\ + \\ 20 /\ +""" + ucode_str = \ +"""\ + ⎛√2⋅y ⎞\n\ +atan2⎜────, √x⎟\n\ + ⎝ 20 ⎠\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + +def test_pretty_geometry(): + e = Segment((0, 1), (0, 2)) + assert pretty(e) == 'Segment2D(Point2D(0, 1), Point2D(0, 2))' + e = Ray((1, 1), angle=4.02*pi) + assert pretty(e) == 'Ray2D(Point2D(1, 1), Point2D(2, tan(pi/50) + 1))' + + +def test_expint(): + expr = Ei(x) + string = 'Ei(x)' + assert pretty(expr) == string + assert upretty(expr) == string + + expr = expint(1, z) + ucode_str = "E₁(z)" + ascii_str = "expint(1, z)" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + assert pretty(Shi(x)) == 'Shi(x)' + assert pretty(Si(x)) == 'Si(x)' + assert pretty(Ci(x)) == 'Ci(x)' + assert pretty(Chi(x)) == 'Chi(x)' + assert upretty(Shi(x)) == 'Shi(x)' + assert upretty(Si(x)) == 'Si(x)' + assert upretty(Ci(x)) == 'Ci(x)' + assert upretty(Chi(x)) == 'Chi(x)' + + +def test_elliptic_functions(): + ascii_str = \ +"""\ + / 1 \\\n\ +K|-----|\n\ + \\z + 1/\ +""" + ucode_str = \ +"""\ + ⎛ 1 ⎞\n\ +K⎜─────⎟\n\ + ⎝z + 1⎠\ +""" + expr = elliptic_k(1/(z + 1)) + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + ascii_str = \ +"""\ + / | 1 \\\n\ +F|1|-----|\n\ + \\ |z + 1/\ +""" + ucode_str = \ +"""\ + ⎛ │ 1 ⎞\n\ +F⎜1│─────⎟\n\ + ⎝ │z + 1⎠\ +""" + expr = elliptic_f(1, 1/(1 + z)) + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + ascii_str = \ +"""\ + / 1 \\\n\ +E|-----|\n\ + \\z + 1/\ +""" + ucode_str = \ +"""\ + ⎛ 1 ⎞\n\ +E⎜─────⎟\n\ + ⎝z + 1⎠\ +""" + expr = elliptic_e(1/(z + 1)) + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + ascii_str = \ +"""\ + / | 1 \\\n\ +E|1|-----|\n\ + \\ |z + 1/\ +""" + ucode_str = \ +"""\ + ⎛ │ 1 ⎞\n\ +E⎜1│─────⎟\n\ + ⎝ │z + 1⎠\ +""" + expr = elliptic_e(1, 1/(1 + z)) + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + ascii_str = \ +"""\ + / |4\\\n\ +Pi|3|-|\n\ + \\ |x/\ +""" + ucode_str = \ +"""\ + ⎛ │4⎞\n\ +Π⎜3│─⎟\n\ + ⎝ │x⎠\ +""" + expr = elliptic_pi(3, 4/x) + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + ascii_str = \ +"""\ + / 4| \\\n\ +Pi|3; -|6|\n\ + \\ x| /\ +""" + ucode_str = \ +"""\ + ⎛ 4│ ⎞\n\ +Π⎜3; ─│6⎟\n\ + ⎝ x│ ⎠\ +""" + expr = elliptic_pi(3, 4/x, 6) + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + +def test_RandomDomain(): + from sympy.stats import Normal, Die, Exponential, pspace, where + X = Normal('x1', 0, 1) + assert upretty(where(X > 0)) == "Domain: 0 < x₁ ∧ x₁ < ∞" + + D = Die('d1', 6) + assert upretty(where(D > 4)) == 'Domain: d₁ = 5 ∨ d₁ = 6' + + A = Exponential('a', 1) + B = Exponential('b', 1) + assert upretty(pspace(Tuple(A, B)).domain) == \ + 'Domain: 0 ≤ a ∧ 0 ≤ b ∧ a < ∞ ∧ b < ∞' + + +def test_PrettyPoly(): + F = QQ.frac_field(x, y) + R = QQ.poly_ring(x, y) + + expr = F.convert(x/(x + y)) + assert pretty(expr) == "x/(x + y)" + assert upretty(expr) == "x/(x + y)" + + expr = R.convert(x + y) + assert pretty(expr) == "x + y" + assert upretty(expr) == "x + y" + + +def test_issue_6285(): + assert pretty(Pow(2, -5, evaluate=False)) == '1 \n--\n 5\n2 ' + assert pretty(Pow(x, (1/pi))) == \ + ' 1 \n'\ + ' --\n'\ + ' pi\n'\ + 'x ' + + +def test_issue_6359(): + assert pretty(Integral(x**2, x)**2) == \ +"""\ + 2 +/ / \\ \n\ +| | | \n\ +| | 2 | \n\ +| | x dx| \n\ +| | | \n\ +\\/ / \ +""" + assert upretty(Integral(x**2, x)**2) == \ +"""\ + 2 +⎛⌠ ⎞ \n\ +⎜⎮ 2 ⎟ \n\ +⎜⎮ x dx⎟ \n\ +⎝⌡ ⎠ \ +""" + + assert pretty(Sum(x**2, (x, 0, 1))**2) == \ +"""\ + 2\n\ +/ 1 \\ \n\ +|___ | \n\ +|\\ ` | \n\ +| \\ 2| \n\ +| / x | \n\ +|/__, | \n\ +\\x = 0 / \ +""" + assert upretty(Sum(x**2, (x, 0, 1))**2) == \ +"""\ + 2 +⎛ 1 ⎞ \n\ +⎜ ___ ⎟ \n\ +⎜ ╲ ⎟ \n\ +⎜ ╲ 2⎟ \n\ +⎜ ╱ x ⎟ \n\ +⎜ ╱ ⎟ \n\ +⎜ ‾‾‾ ⎟ \n\ +⎝x = 0 ⎠ \ +""" + + assert pretty(Product(x**2, (x, 1, 2))**2) == \ +"""\ + 2 +/ 2 \\ \n\ +|______ | \n\ +| | | 2| \n\ +| | | x | \n\ +| | | | \n\ +\\x = 1 / \ +""" + assert upretty(Product(x**2, (x, 1, 2))**2) == \ +"""\ + 2 +⎛ 2 ⎞ \n\ +⎜─┬──┬─ ⎟ \n\ +⎜ │ │ 2⎟ \n\ +⎜ │ │ x ⎟ \n\ +⎜ │ │ ⎟ \n\ +⎝x = 1 ⎠ \ +""" + + f = Function('f') + assert pretty(Derivative(f(x), x)**2) == \ +"""\ + 2 +/d \\ \n\ +|--(f(x))| \n\ +\\dx / \ +""" + assert upretty(Derivative(f(x), x)**2) == \ +"""\ + 2 +⎛d ⎞ \n\ +⎜──(f(x))⎟ \n\ +⎝dx ⎠ \ +""" + + +def test_issue_6739(): + ascii_str = \ +"""\ + 1 \n\ +-----\n\ + ___\n\ +\\/ x \ +""" + ucode_str = \ +"""\ +1 \n\ +──\n\ +√x\ +""" + assert pretty(1/sqrt(x)) == ascii_str + assert upretty(1/sqrt(x)) == ucode_str + + +def test_complicated_symbol_unchanged(): + for symb_name in ["dexpr2_d1tau", "dexpr2^d1tau"]: + assert pretty(Symbol(symb_name)) == symb_name + + +def test_categories(): + from sympy.categories import (Object, IdentityMorphism, + NamedMorphism, Category, Diagram, DiagramGrid) + + A1 = Object("A1") + A2 = Object("A2") + A3 = Object("A3") + + f1 = NamedMorphism(A1, A2, "f1") + f2 = NamedMorphism(A2, A3, "f2") + id_A1 = IdentityMorphism(A1) + + K1 = Category("K1") + + assert pretty(A1) == "A1" + assert upretty(A1) == "A₁" + + assert pretty(f1) == "f1:A1-->A2" + assert upretty(f1) == "f₁:A₁——▶A₂" + assert pretty(id_A1) == "id:A1-->A1" + assert upretty(id_A1) == "id:A₁——▶A₁" + + assert pretty(f2*f1) == "f2*f1:A1-->A3" + assert upretty(f2*f1) == "f₂∘f₁:A₁——▶A₃" + + assert pretty(K1) == "K1" + assert upretty(K1) == "K₁" + + # Test how diagrams are printed. + d = Diagram() + assert pretty(d) == "EmptySet" + assert upretty(d) == "∅" + + d = Diagram({f1: "unique", f2: S.EmptySet}) + assert pretty(d) == "{f2*f1:A1-->A3: EmptySet, id:A1-->A1: " \ + "EmptySet, id:A2-->A2: EmptySet, id:A3-->A3: " \ + "EmptySet, f1:A1-->A2: {unique}, f2:A2-->A3: EmptySet}" + + assert upretty(d) == "{f₂∘f₁:A₁——▶A₃: ∅, id:A₁——▶A₁: ∅, " \ + "id:A₂——▶A₂: ∅, id:A₃——▶A₃: ∅, f₁:A₁——▶A₂: {unique}, f₂:A₂——▶A₃: ∅}" + + d = Diagram({f1: "unique", f2: S.EmptySet}, {f2 * f1: "unique"}) + assert pretty(d) == "{f2*f1:A1-->A3: EmptySet, id:A1-->A1: " \ + "EmptySet, id:A2-->A2: EmptySet, id:A3-->A3: " \ + "EmptySet, f1:A1-->A2: {unique}, f2:A2-->A3: EmptySet}" \ + " ==> {f2*f1:A1-->A3: {unique}}" + assert upretty(d) == "{f₂∘f₁:A₁——▶A₃: ∅, id:A₁——▶A₁: ∅, id:A₂——▶A₂: " \ + "∅, id:A₃——▶A₃: ∅, f₁:A₁——▶A₂: {unique}, f₂:A₂——▶A₃: ∅}" \ + " ══▶ {f₂∘f₁:A₁——▶A₃: {unique}}" + + grid = DiagramGrid(d) + assert pretty(grid) == "A1 A2\n \nA3 " + assert upretty(grid) == "A₁ A₂\n \nA₃ " + + +def test_PrettyModules(): + R = QQ.old_poly_ring(x, y) + F = R.free_module(2) + M = F.submodule([x, y], [1, x**2]) + + ucode_str = \ +"""\ + 2\n\ +ℚ[x, y] \ +""" + ascii_str = \ +"""\ + 2\n\ +QQ[x, y] \ +""" + + assert upretty(F) == ucode_str + assert pretty(F) == ascii_str + + ucode_str = \ +"""\ +╱ ⎡ 2⎤╲\n\ +╲[x, y], ⎣1, x ⎦╱\ +""" + ascii_str = \ +"""\ + 2 \n\ +<[x, y], [1, x ]>\ +""" + + assert upretty(M) == ucode_str + assert pretty(M) == ascii_str + + I = R.ideal(x**2, y) + + ucode_str = \ +"""\ +╱ 2 ╲\n\ +╲x , y╱\ +""" + + ascii_str = \ +"""\ + 2 \n\ +\ +""" + + assert upretty(I) == ucode_str + assert pretty(I) == ascii_str + + Q = F / M + + ucode_str = \ +"""\ + 2 \n\ + ℚ[x, y] \n\ +─────────────────\n\ +╱ ⎡ 2⎤╲\n\ +╲[x, y], ⎣1, x ⎦╱\ +""" + + ascii_str = \ +"""\ + 2 \n\ + QQ[x, y] \n\ +-----------------\n\ + 2 \n\ +<[x, y], [1, x ]>\ +""" + + assert upretty(Q) == ucode_str + assert pretty(Q) == ascii_str + + ucode_str = \ +"""\ +╱⎡ 3⎤ ╲\n\ +│⎢ x ⎥ ╱ ⎡ 2⎤╲ ╱ ⎡ 2⎤╲│\n\ +│⎢1, ──⎥ + ╲[x, y], ⎣1, x ⎦╱, [2, y] + ╲[x, y], ⎣1, x ⎦╱│\n\ +╲⎣ 2 ⎦ ╱\ +""" + + ascii_str = \ +"""\ + 3 \n\ + x 2 2 \n\ +<[1, --] + <[x, y], [1, x ]>, [2, y] + <[x, y], [1, x ]>>\n\ + 2 \ +""" + + +def test_QuotientRing(): + R = QQ.old_poly_ring(x)/[x**2 + 1] + + ucode_str = \ +"""\ + ℚ[x] \n\ +────────\n\ +╱ 2 ╲\n\ +╲x + 1╱\ +""" + + ascii_str = \ +"""\ + QQ[x] \n\ +--------\n\ + 2 \n\ +\ +""" + + assert upretty(R) == ucode_str + assert pretty(R) == ascii_str + + ucode_str = \ +"""\ + ╱ 2 ╲\n\ +1 + ╲x + 1╱\ +""" + + ascii_str = \ +"""\ + 2 \n\ +1 + \ +""" + + assert upretty(R.one) == ucode_str + assert pretty(R.one) == ascii_str + + +def test_Homomorphism(): + from sympy.polys.agca import homomorphism + + R = QQ.old_poly_ring(x) + + expr = homomorphism(R.free_module(1), R.free_module(1), [0]) + + ucode_str = \ +"""\ + 1 1\n\ +[0] : ℚ[x] ──> ℚ[x] \ +""" + + ascii_str = \ +"""\ + 1 1\n\ +[0] : QQ[x] --> QQ[x] \ +""" + + assert upretty(expr) == ucode_str + assert pretty(expr) == ascii_str + + expr = homomorphism(R.free_module(2), R.free_module(2), [0, 0]) + + ucode_str = \ +"""\ +⎡0 0⎤ 2 2\n\ +⎢ ⎥ : ℚ[x] ──> ℚ[x] \n\ +⎣0 0⎦ \ +""" + + ascii_str = \ +"""\ +[0 0] 2 2\n\ +[ ] : QQ[x] --> QQ[x] \n\ +[0 0] \ +""" + + assert upretty(expr) == ucode_str + assert pretty(expr) == ascii_str + + expr = homomorphism(R.free_module(1), R.free_module(1) / [[x]], [0]) + + ucode_str = \ +"""\ + 1\n\ + 1 ℚ[x] \n\ +[0] : ℚ[x] ──> ─────\n\ + <[x]>\ +""" + + ascii_str = \ +"""\ + 1\n\ + 1 QQ[x] \n\ +[0] : QQ[x] --> ------\n\ + <[x]> \ +""" + + assert upretty(expr) == ucode_str + assert pretty(expr) == ascii_str + + +def test_Tr(): + A, B = symbols('A B', commutative=False) + t = Tr(A*B) + assert pretty(t) == r'Tr(A*B)' + assert upretty(t) == 'Tr(A⋅B)' + + +def test_pretty_Add(): + eq = Mul(-2, x - 2, evaluate=False) + 5 + assert pretty(eq) == '5 - 2*(x - 2)' + + +def test_issue_7179(): + assert upretty(Not(Equivalent(x, y))) == 'x ⇎ y' + assert upretty(Not(Implies(x, y))) == 'x ↛ y' + + +def test_issue_7180(): + assert upretty(Equivalent(x, y)) == 'x ⇔ y' + + +def test_pretty_Complement(): + assert pretty(S.Reals - S.Naturals) == '(-oo, oo) \\ Naturals' + assert upretty(S.Reals - S.Naturals) == 'ℝ \\ ℕ' + assert pretty(S.Reals - S.Naturals0) == '(-oo, oo) \\ Naturals0' + assert upretty(S.Reals - S.Naturals0) == 'ℝ \\ ℕ₀' + + +def test_pretty_SymmetricDifference(): + from sympy.sets.sets import SymmetricDifference + assert upretty(SymmetricDifference(Interval(2,3), Interval(3,5), \ + evaluate = False)) == '[2, 3] ∆ [3, 5]' + with raises(NotImplementedError): + pretty(SymmetricDifference(Interval(2,3), Interval(3,5), evaluate = False)) + + +def test_pretty_Contains(): + assert pretty(Contains(x, S.Integers)) == 'Contains(x, Integers)' + assert upretty(Contains(x, S.Integers)) == 'x ∈ ℤ' + + +def test_issue_8292(): + from sympy.core import sympify + e = sympify('((x+x**4)/(x-1))-(2*(x-1)**4/(x-1)**4)', evaluate=False) + ucode_str = \ +"""\ + 4 4 \n\ + 2⋅(x - 1) x + x\n\ +- ────────── + ──────\n\ + 4 x - 1 \n\ + (x - 1) \ +""" + ascii_str = \ +"""\ + 4 4 \n\ + 2*(x - 1) x + x\n\ +- ---------- + ------\n\ + 4 x - 1 \n\ + (x - 1) \ +""" + assert pretty(e) == ascii_str + assert upretty(e) == ucode_str + + +def test_issue_4335(): + y = Function('y') + expr = -y(x).diff(x) + ucode_str = \ +"""\ + d \n\ +-──(y(x))\n\ + dx \ +""" + ascii_str = \ +"""\ + d \n\ +- --(y(x))\n\ + dx \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + +def test_issue_8344(): + from sympy.core import sympify + e = sympify('2*x*y**2/1**2 + 1', evaluate=False) + ucode_str = \ +"""\ + 2 \n\ +2⋅x⋅y \n\ +────── + 1\n\ + 2 \n\ + 1 \ +""" + assert upretty(e) == ucode_str + + +def test_issue_6324(): + x = Pow(2, 3, evaluate=False) + y = Pow(10, -2, evaluate=False) + e = Mul(x, y, evaluate=False) + ucode_str = \ +"""\ + 3 \n\ +2 \n\ +───\n\ + 2\n\ +10 \ +""" + assert upretty(e) == ucode_str + + +def test_issue_7927(): + e = sin(x/2)**cos(x/2) + ucode_str = \ +"""\ + ⎛x⎞\n\ + cos⎜─⎟\n\ + ⎝2⎠\n\ +⎛ ⎛x⎞⎞ \n\ +⎜sin⎜─⎟⎟ \n\ +⎝ ⎝2⎠⎠ \ +""" + assert upretty(e) == ucode_str + e = sin(x)**(S(11)/13) + ucode_str = \ +"""\ + 11\n\ + ──\n\ + 13\n\ +(sin(x)) \ +""" + assert upretty(e) == ucode_str + + +def test_issue_6134(): + from sympy.abc import lamda, t + phi = Function('phi') + + e = lamda*x*Integral(phi(t)*pi*sin(pi*t), (t, 0, 1)) + lamda*x**2*Integral(phi(t)*2*pi*sin(2*pi*t), (t, 0, 1)) + ucode_str = \ +"""\ + 1 1 \n\ + 2 ⌠ ⌠ \n\ +λ⋅x ⋅⎮ 2⋅π⋅φ(t)⋅sin(2⋅π⋅t) dt + λ⋅x⋅⎮ π⋅φ(t)⋅sin(π⋅t) dt\n\ + ⌡ ⌡ \n\ + 0 0 \ +""" + assert upretty(e) == ucode_str + + +def test_issue_9877(): + ucode_str1 = '(2, 3) ∪ ([1, 2] \\ {x})' + a, b, c = Interval(2, 3, True, True), Interval(1, 2), FiniteSet(x) + assert upretty(Union(a, Complement(b, c))) == ucode_str1 + + ucode_str2 = '{x} ∩ {y} ∩ ({z} \\ [1, 2])' + d, e, f, g = FiniteSet(x), FiniteSet(y), FiniteSet(z), Interval(1, 2) + assert upretty(Intersection(d, e, Complement(f, g))) == ucode_str2 + + +def test_issue_13651(): + expr1 = c + Mul(-1, a + b, evaluate=False) + assert pretty(expr1) == 'c - (a + b)' + expr2 = c + Mul(-1, a - b + d, evaluate=False) + assert pretty(expr2) == 'c - (a - b + d)' + + +def test_pretty_primenu(): + from sympy.functions.combinatorial.numbers import primenu + + ascii_str1 = "nu(n)" + ucode_str1 = "ν(n)" + + n = symbols('n', integer=True) + assert pretty(primenu(n)) == ascii_str1 + assert upretty(primenu(n)) == ucode_str1 + + +def test_pretty_primeomega(): + from sympy.functions.combinatorial.numbers import primeomega + + ascii_str1 = "Omega(n)" + ucode_str1 = "Ω(n)" + + n = symbols('n', integer=True) + assert pretty(primeomega(n)) == ascii_str1 + assert upretty(primeomega(n)) == ucode_str1 + + +def test_pretty_Mod(): + from sympy.core import Mod + + ascii_str1 = "x mod 7" + ucode_str1 = "x mod 7" + + ascii_str2 = "(x + 1) mod 7" + ucode_str2 = "(x + 1) mod 7" + + ascii_str3 = "2*x mod 7" + ucode_str3 = "2⋅x mod 7" + + ascii_str4 = "(x mod 7) + 1" + ucode_str4 = "(x mod 7) + 1" + + ascii_str5 = "2*(x mod 7)" + ucode_str5 = "2⋅(x mod 7)" + + x = symbols('x', integer=True) + assert pretty(Mod(x, 7)) == ascii_str1 + assert upretty(Mod(x, 7)) == ucode_str1 + assert pretty(Mod(x + 1, 7)) == ascii_str2 + assert upretty(Mod(x + 1, 7)) == ucode_str2 + assert pretty(Mod(2 * x, 7)) == ascii_str3 + assert upretty(Mod(2 * x, 7)) == ucode_str3 + assert pretty(Mod(x, 7) + 1) == ascii_str4 + assert upretty(Mod(x, 7) + 1) == ucode_str4 + assert pretty(2 * Mod(x, 7)) == ascii_str5 + assert upretty(2 * Mod(x, 7)) == ucode_str5 + + +def test_issue_11801(): + assert pretty(Symbol("")) == "" + assert upretty(Symbol("")) == "" + + +def test_pretty_UnevaluatedExpr(): + x = symbols('x') + he = UnevaluatedExpr(1/x) + + ucode_str = \ +"""\ +1\n\ +─\n\ +x\ +""" + + assert upretty(he) == ucode_str + + ucode_str = \ +"""\ + 2\n\ +⎛1⎞ \n\ +⎜─⎟ \n\ +⎝x⎠ \ +""" + + assert upretty(he**2) == ucode_str + + ucode_str = \ +"""\ + 1\n\ +1 + ─\n\ + x\ +""" + + assert upretty(he + 1) == ucode_str + + ucode_str = \ +('''\ + 1\n\ +x⋅─\n\ + x\ +''') + assert upretty(x*he) == ucode_str + + +def test_issue_10472(): + M = (Matrix([[0, 0], [0, 0]]), Matrix([0, 0])) + + ucode_str = \ +"""\ +⎛⎡0 0⎤ ⎡0⎤⎞ +⎜⎢ ⎥, ⎢ ⎥⎟ +⎝⎣0 0⎦ ⎣0⎦⎠\ +""" + assert upretty(M) == ucode_str + + +def test_MatrixElement_printing(): + # test cases for issue #11821 + A = MatrixSymbol("A", 1, 3) + B = MatrixSymbol("B", 1, 3) + C = MatrixSymbol("C", 1, 3) + + ascii_str1 = "A_00" + ucode_str1 = "A₀₀" + assert pretty(A[0, 0]) == ascii_str1 + assert upretty(A[0, 0]) == ucode_str1 + + ascii_str1 = "3*A_00" + ucode_str1 = "3⋅A₀₀" + assert pretty(3*A[0, 0]) == ascii_str1 + assert upretty(3*A[0, 0]) == ucode_str1 + + ascii_str1 = "(-B + A)[0, 0]" + ucode_str1 = "(-B + A)[0, 0]" + F = C[0, 0].subs(C, A - B) + assert pretty(F) == ascii_str1 + assert upretty(F) == ucode_str1 + + +def test_issue_12675(): + x, y, t, j = symbols('x y t j') + e = CoordSys3D('e') + + ucode_str = \ +"""\ +⎛ t⎞ \n\ +⎜⎛x⎞ ⎟ j_e\n\ +⎜⎜─⎟ ⎟ \n\ +⎝⎝y⎠ ⎠ \ +""" + assert upretty((x/y)**t*e.j) == ucode_str + ucode_str = \ +"""\ +⎛1⎞ \n\ +⎜─⎟ j_e\n\ +⎝y⎠ \ +""" + assert upretty((1/y)*e.j) == ucode_str + + +def test_MatrixSymbol_printing(): + # test cases for issue #14237 + A = MatrixSymbol("A", 3, 3) + B = MatrixSymbol("B", 3, 3) + C = MatrixSymbol("C", 3, 3) + assert pretty(-A*B*C) == "-A*B*C" + assert pretty(A - B) == "-B + A" + assert pretty(A*B*C - A*B - B*C) == "-A*B -B*C + A*B*C" + + # issue #14814 + x = MatrixSymbol('x', n, n) + y = MatrixSymbol('y*', n, n) + assert pretty(x + y) == "x + y*" + ascii_str = \ +"""\ + 2 \n\ +-2*y* -a*x\ +""" + assert pretty(-a*x + -2*y*y) == ascii_str + + +def test_degree_printing(): + expr1 = 90*degree + assert pretty(expr1) == '90°' + expr2 = x*degree + assert pretty(expr2) == 'x°' + expr3 = cos(x*degree + 90*degree) + assert pretty(expr3) == 'cos(x° + 90°)' + + +def test_vector_expr_pretty_printing(): + A = CoordSys3D('A') + + assert upretty(Cross(A.i, A.x*A.i+3*A.y*A.j)) == "(i_A)×((x_A) i_A + (3⋅y_A) j_A)" + assert upretty(x*Cross(A.i, A.j)) == 'x⋅(i_A)×(j_A)' + + assert upretty(Curl(A.x*A.i + 3*A.y*A.j)) == "∇×((x_A) i_A + (3⋅y_A) j_A)" + + assert upretty(Divergence(A.x*A.i + 3*A.y*A.j)) == "∇⋅((x_A) i_A + (3⋅y_A) j_A)" + + assert upretty(Dot(A.i, A.x*A.i+3*A.y*A.j)) == "(i_A)⋅((x_A) i_A + (3⋅y_A) j_A)" + + assert upretty(Gradient(A.x+3*A.y)) == "∇(x_A + 3⋅y_A)" + assert upretty(Laplacian(A.x+3*A.y)) == "∆(x_A + 3⋅y_A)" + # TODO: add support for ASCII pretty. + + +def test_pretty_print_tensor_expr(): + L = TensorIndexType("L") + i, j, k = tensor_indices("i j k", L) + i0 = tensor_indices("i_0", L) + A, B, C, D = tensor_heads("A B C D", [L]) + H = TensorHead("H", [L, L]) + + expr = -i + ascii_str = \ +"""\ +-i\ +""" + ucode_str = \ +"""\ +-i\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = A(i) + ascii_str = \ +"""\ + i\n\ +A \n\ + \ +""" + ucode_str = \ +"""\ + i\n\ +A \n\ + \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = A(i0) + ascii_str = \ +"""\ + i_0\n\ +A \n\ + \ +""" + ucode_str = \ +"""\ + i₀\n\ +A \n\ + \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = A(-i) + ascii_str = \ +"""\ + \n\ +A \n\ + i\ +""" + ucode_str = \ +"""\ + \n\ +A \n\ + i\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = -3*A(-i) + ascii_str = \ +"""\ + \n\ +-3*A \n\ + i\ +""" + ucode_str = \ +"""\ + \n\ +-3⋅A \n\ + i\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = H(i, -j) + ascii_str = \ +"""\ + i \n\ +H \n\ + j\ +""" + ucode_str = \ +"""\ + i \n\ +H \n\ + j\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = H(i, -i) + ascii_str = \ +"""\ + L_0 \n\ +H \n\ + L_0\ +""" + ucode_str = \ +"""\ + L₀ \n\ +H \n\ + L₀\ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = H(i, -j)*A(j)*B(k) + ascii_str = \ +"""\ + i L_0 k\n\ +H *A *B \n\ + L_0 \ +""" + ucode_str = \ +"""\ + i L₀ k\n\ +H ⋅A ⋅B \n\ + L₀ \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = (1+x)*A(i) + ascii_str = \ +"""\ + i\n\ +(x + 1)*A \n\ + \ +""" + ucode_str = \ +"""\ + i\n\ +(x + 1)⋅A \n\ + \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = A(i) + 3*B(i) + ascii_str = \ +"""\ + i i\n\ +3*B + A \n\ + \ +""" + ucode_str = \ +"""\ + i i\n\ +3⋅B + A \n\ + \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + +def test_pretty_print_tensor_partial_deriv(): + from sympy.tensor.toperators import PartialDerivative + + L = TensorIndexType("L") + i, j, k = tensor_indices("i j k", L) + + A, B, C, D = tensor_heads("A B C D", [L]) + + H = TensorHead("H", [L, L]) + + expr = PartialDerivative(A(i), A(j)) + ascii_str = \ +"""\ + d / i\\\n\ +---|A |\n\ + j\\ /\n\ +dA \n\ + \ +""" + ucode_str = \ +"""\ + ∂ ⎛ i⎞\n\ +───⎜A ⎟\n\ + j⎝ ⎠\n\ +∂A \n\ + \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = A(i)*PartialDerivative(H(k, -i), A(j)) + ascii_str = \ +"""\ + L_0 d / k \\\n\ +A *---|H |\n\ + j\\ L_0/\n\ + dA \n\ + \ +""" + ucode_str = \ +"""\ + L₀ ∂ ⎛ k ⎞\n\ +A ⋅───⎜H ⎟\n\ + j⎝ L₀⎠\n\ + ∂A \n\ + \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = A(i)*PartialDerivative(B(k)*C(-i) + 3*H(k, -i), A(j)) + ascii_str = \ +"""\ + L_0 d / k k \\\n\ +A *---|3*H + B *C |\n\ + j\\ L_0 L_0/\n\ + dA \n\ + \ +""" + ucode_str = \ +"""\ + L₀ ∂ ⎛ k k ⎞\n\ +A ⋅───⎜3⋅H + B ⋅C ⎟\n\ + j⎝ L₀ L₀⎠\n\ + ∂A \n\ + \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = (A(i) + B(i))*PartialDerivative(C(j), D(j)) + ascii_str = \ +"""\ +/ i i\\ d / L_0\\\n\ +|A + B |*-----|C |\n\ +\\ / L_0\\ /\n\ + dD \n\ + \ +""" + ucode_str = \ +"""\ +⎛ i i⎞ ∂ ⎛ L₀⎞\n\ +⎜A + B ⎟⋅────⎜C ⎟\n\ +⎝ ⎠ L₀⎝ ⎠\n\ + ∂D \n\ + \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = (A(i) + B(i))*PartialDerivative(C(-i), D(j)) + ascii_str = \ +"""\ +/ L_0 L_0\\ d / \\\n\ +|A + B |*---|C |\n\ +\\ / j\\ L_0/\n\ + dD \n\ + \ +""" + ucode_str = \ +"""\ +⎛ L₀ L₀⎞ ∂ ⎛ ⎞\n\ +⎜A + B ⎟⋅───⎜C ⎟\n\ +⎝ ⎠ j⎝ L₀⎠\n\ + ∂D \n\ + \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = PartialDerivative(B(-i) + A(-i), A(-j), A(-n)) + ucode_str = """\ + 2 \n\ + ∂ ⎛ ⎞\n\ +───────⎜A + B ⎟\n\ + ⎝ i i⎠\n\ +∂A ∂A \n\ + n j \ +""" + assert upretty(expr) == ucode_str + + expr = PartialDerivative(3*A(-i), A(-j), A(-n)) + ucode_str = """\ + 2 \n\ + ∂ ⎛ ⎞\n\ +───────⎜3⋅A ⎟\n\ + ⎝ i⎠\n\ +∂A ∂A \n\ + n j \ +""" + assert upretty(expr) == ucode_str + + expr = TensorElement(H(i, j), {i:1}) + ascii_str = \ +"""\ + i=1,j\n\ +H \n\ + \ +""" + ucode_str = ascii_str + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = TensorElement(H(i, j), {i: 1, j: 1}) + ascii_str = \ +"""\ + i=1,j=1\n\ +H \n\ + \ +""" + ucode_str = ascii_str + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = TensorElement(H(i, j), {j: 1}) + ascii_str = \ +"""\ + i,j=1\n\ +H \n\ + \ +""" + ucode_str = ascii_str + + expr = TensorElement(H(-i, j), {-i: 1}) + ascii_str = \ +"""\ + j\n\ +H \n\ + i=1 \ +""" + ucode_str = ascii_str + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + +def test_issue_15560(): + a = MatrixSymbol('a', 1, 1) + e = pretty(a*(KroneckerProduct(a, a))) + result = 'a*(a x a)' + assert e == result + + +def test_print_polylog(): + # Part of issue 6013 + uresult = 'Li₂(3)' + aresult = 'polylog(2, 3)' + assert pretty(polylog(2, 3)) == aresult + assert upretty(polylog(2, 3)) == uresult + + +# Issue #25312 +def test_print_expint_polylog_symbolic_order(): + s, z = symbols("s, z") + uresult = 'Liₛ(z)' + aresult = 'polylog(s, z)' + assert pretty(polylog(s, z)) == aresult + assert upretty(polylog(s, z)) == uresult + # TODO: TBD polylog(s - 1, z) + uresult = 'Eₛ(z)' + aresult = 'expint(s, z)' + assert pretty(expint(s, z)) == aresult + assert upretty(expint(s, z)) == uresult + + + +def test_print_polylog_long_order_issue_25309(): + s, z = symbols("s, z") + ucode_str = \ +"""\ + ⎛ 2 ⎞\n\ +polylog⎝s , z⎠\ +""" + assert upretty(polylog(s**2, z)) == ucode_str + + +def test_print_lerchphi(): + # Part of issue 6013 + a = Symbol('a') + pretty(lerchphi(a, 1, 2)) + uresult = 'Φ(a, 1, 2)' + aresult = 'lerchphi(a, 1, 2)' + assert pretty(lerchphi(a, 1, 2)) == aresult + assert upretty(lerchphi(a, 1, 2)) == uresult + + +def test_issue_15583(): + + N = mechanics.ReferenceFrame('N') + result = '(n_x, n_y, n_z)' + e = pretty((N.x, N.y, N.z)) + assert e == result + + +def test_matrixSymbolBold(): + # Issue 15871 + def boldpretty(expr): + return xpretty(expr, use_unicode=True, wrap_line=False, mat_symbol_style="bold") + + from sympy.matrices.expressions.trace import trace + A = MatrixSymbol("A", 2, 2) + assert boldpretty(trace(A)) == 'tr(𝐀)' + + A = MatrixSymbol("A", 3, 3) + B = MatrixSymbol("B", 3, 3) + C = MatrixSymbol("C", 3, 3) + + assert boldpretty(-A) == '-𝐀' + assert boldpretty(A - A*B - B) == '-𝐁 -𝐀⋅𝐁 + 𝐀' + assert boldpretty(-A*B - A*B*C - B) == '-𝐁 -𝐀⋅𝐁 -𝐀⋅𝐁⋅𝐂' + + A = MatrixSymbol("Addot", 3, 3) + assert boldpretty(A) == '𝐀̈' + omega = MatrixSymbol("omega", 3, 3) + assert boldpretty(omega) == 'ω' + omega = MatrixSymbol("omeganorm", 3, 3) + assert boldpretty(omega) == '‖ω‖' + + a = Symbol('alpha') + b = Symbol('b') + c = MatrixSymbol("c", 3, 1) + d = MatrixSymbol("d", 3, 1) + + assert boldpretty(a*B*c+b*d) == 'b⋅𝐝 + α⋅𝐁⋅𝐜' + + d = MatrixSymbol("delta", 3, 1) + B = MatrixSymbol("Beta", 3, 3) + + assert boldpretty(a*B*c+b*d) == 'b⋅δ + α⋅Β⋅𝐜' + + A = MatrixSymbol("A_2", 3, 3) + assert boldpretty(A) == '𝐀₂' + + +def test_center_accent(): + assert center_accent('a', '\N{COMBINING TILDE}') == 'ã' + assert center_accent('aa', '\N{COMBINING TILDE}') == 'aã' + assert center_accent('aaa', '\N{COMBINING TILDE}') == 'aãa' + assert center_accent('aaaa', '\N{COMBINING TILDE}') == 'aaãa' + assert center_accent('aaaaa', '\N{COMBINING TILDE}') == 'aaãaa' + assert center_accent('abcdefg', '\N{COMBINING FOUR DOTS ABOVE}') == 'abcd⃜efg' + + +def test_imaginary_unit(): + from sympy.printing.pretty import pretty # b/c it was redefined above + assert pretty(1 + I, use_unicode=False) == '1 + I' + assert pretty(1 + I, use_unicode=True) == '1 + ⅈ' + assert pretty(1 + I, use_unicode=False, imaginary_unit='j') == '1 + I' + assert pretty(1 + I, use_unicode=True, imaginary_unit='j') == '1 + ⅉ' + + raises(TypeError, lambda: pretty(I, imaginary_unit=I)) + raises(ValueError, lambda: pretty(I, imaginary_unit="kkk")) + + +def test_str_special_matrices(): + from sympy.matrices import Identity, ZeroMatrix, OneMatrix + assert pretty(Identity(4)) == 'I' + assert upretty(Identity(4)) == '𝕀' + assert pretty(ZeroMatrix(2, 2)) == '0' + assert upretty(ZeroMatrix(2, 2)) == '𝟘' + assert pretty(OneMatrix(2, 2)) == '1' + assert upretty(OneMatrix(2, 2)) == '𝟙' + + +def test_pretty_misc_functions(): + assert pretty(LambertW(x)) == 'W(x)' + assert upretty(LambertW(x)) == 'W(x)' + assert pretty(LambertW(x, y)) == 'W(x, y)' + assert upretty(LambertW(x, y)) == 'W(x, y)' + assert pretty(airyai(x)) == 'Ai(x)' + assert upretty(airyai(x)) == 'Ai(x)' + assert pretty(airybi(x)) == 'Bi(x)' + assert upretty(airybi(x)) == 'Bi(x)' + assert pretty(airyaiprime(x)) == "Ai'(x)" + assert upretty(airyaiprime(x)) == "Ai'(x)" + assert pretty(airybiprime(x)) == "Bi'(x)" + assert upretty(airybiprime(x)) == "Bi'(x)" + assert pretty(fresnelc(x)) == 'C(x)' + assert upretty(fresnelc(x)) == 'C(x)' + assert pretty(fresnels(x)) == 'S(x)' + assert upretty(fresnels(x)) == 'S(x)' + assert pretty(Heaviside(x)) == 'Heaviside(x)' + assert upretty(Heaviside(x)) == 'θ(x)' + assert pretty(Heaviside(x, y)) == 'Heaviside(x, y)' + assert upretty(Heaviside(x, y)) == 'θ(x, y)' + assert pretty(dirichlet_eta(x)) == 'dirichlet_eta(x)' + assert upretty(dirichlet_eta(x)) == 'η(x)' + + +def test_hadamard_power(): + m, n, p = symbols('m, n, p', integer=True) + A = MatrixSymbol('A', m, n) + B = MatrixSymbol('B', m, n) + + # Testing printer: + expr = hadamard_power(A, n) + ascii_str = \ +"""\ + .n\n\ +A \ +""" + ucode_str = \ +"""\ + ∘n\n\ +A \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = hadamard_power(A, 1+n) + ascii_str = \ +"""\ + .(n + 1)\n\ +A \ +""" + ucode_str = \ +"""\ + ∘(n + 1)\n\ +A \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + expr = hadamard_power(A*B.T, 1+n) + ascii_str = \ +"""\ + .(n + 1)\n\ +/ T\\ \n\ +\\A*B / \ +""" + ucode_str = \ +"""\ + ∘(n + 1)\n\ +⎛ T⎞ \n\ +⎝A⋅B ⎠ \ +""" + assert pretty(expr) == ascii_str + assert upretty(expr) == ucode_str + + +def test_issue_17258(): + n = Symbol('n', integer=True) + assert pretty(Sum(n, (n, -oo, 1))) == \ + ' 1 \n'\ + ' __ \n'\ + ' \\ ` \n'\ + ' ) n\n'\ + ' /_, \n'\ + 'n = -oo ' + + assert upretty(Sum(n, (n, -oo, 1))) == \ +"""\ + 1 \n\ + ___ \n\ + ╲ \n\ + ╲ \n\ + ╱ n\n\ + ╱ \n\ + ‾‾‾ \n\ +n = -∞ \ +""" + + +def test_is_combining(): + line = "v̇_m" + assert [is_combining(sym) for sym in line] == \ + [False, True, False, False] + + +def test_issue_17616(): + assert pretty(pi**(1/exp(1))) == \ + ' / -1\\\n'\ + ' \\e /\n'\ + 'pi ' + + assert upretty(pi**(1/exp(1))) == \ + ' ⎛ -1⎞\n'\ + ' ⎝ℯ ⎠\n'\ + 'π ' + + assert pretty(pi**(1/pi)) == \ + ' 1 \n'\ + ' --\n'\ + ' pi\n'\ + 'pi ' + + assert upretty(pi**(1/pi)) == \ + ' 1\n'\ + ' ─\n'\ + ' π\n'\ + 'π ' + + assert pretty(pi**(1/EulerGamma)) == \ + ' 1 \n'\ + ' ----------\n'\ + ' EulerGamma\n'\ + 'pi ' + + assert upretty(pi**(1/EulerGamma)) == \ + ' 1\n'\ + ' ─\n'\ + ' γ\n'\ + 'π ' + + z = Symbol("x_17") + assert upretty(7**(1/z)) == \ + 'x₁₇___\n'\ + ' ╲╱ 7 ' + + assert pretty(7**(1/z)) == \ + 'x_17___\n'\ + ' \\/ 7 ' + + +def test_issue_17857(): + assert pretty(Range(-oo, oo)) == '{..., -1, 0, 1, ...}' + assert pretty(Range(oo, -oo, -1)) == '{..., 1, 0, -1, ...}' + + +def test_issue_18272(): + x = Symbol('x') + n = Symbol('n') + + assert upretty(ConditionSet(x, Eq(-x + exp(x), 0), S.Complexes)) == \ + '⎧ │ ⎛ x ⎞⎫\n'\ + '⎨x │ x ∊ ℂ ∧ ⎝-x + ℯ = 0⎠⎬\n'\ + '⎩ │ ⎭' + assert upretty(ConditionSet(x, Contains(n/2, Interval(0, oo)), FiniteSet(-n/2, n/2))) == \ + '⎧ │ ⎧-n n⎫ ⎛n ⎞⎫\n'\ + '⎨x │ x ∊ ⎨───, ─⎬ ∧ ⎜─ ∈ [0, ∞)⎟⎬\n'\ + '⎩ │ ⎩ 2 2⎭ ⎝2 ⎠⎭' + assert upretty(ConditionSet(x, Eq(Piecewise((1, x >= 3), (x/2 - 1/2, x >= 2), (1/2, x >= 1), + (x/2, True)) - 1/2, 0), Interval(0, 3))) == \ + '⎧ │ ⎛⎛⎧ 1 for x ≥ 3⎞ ⎞⎫\n'\ + '⎪ │ ⎜⎜⎪ ⎟ ⎟⎪\n'\ + '⎪ │ ⎜⎜⎪x ⎟ ⎟⎪\n'\ + '⎪ │ ⎜⎜⎪─ - 0.5 for x ≥ 2⎟ ⎟⎪\n'\ + '⎪ │ ⎜⎜⎪2 ⎟ ⎟⎪\n'\ + '⎨x │ x ∊ [0, 3] ∧ ⎜⎜⎨ ⎟ - 0.5 = 0⎟⎬\n'\ + '⎪ │ ⎜⎜⎪ 0.5 for x ≥ 1⎟ ⎟⎪\n'\ + '⎪ │ ⎜⎜⎪ ⎟ ⎟⎪\n'\ + '⎪ │ ⎜⎜⎪ x ⎟ ⎟⎪\n'\ + '⎪ │ ⎜⎜⎪ ─ otherwise⎟ ⎟⎪\n'\ + '⎩ │ ⎝⎝⎩ 2 ⎠ ⎠⎭' + + +def test_Str(): + from sympy.core.symbol import Str + assert pretty(Str('x')) == 'x' + + +def test_symbolic_probability(): + mu = symbols("mu") + sigma = symbols("sigma", positive=True) + X = Normal("X", mu, sigma) + assert pretty(Expectation(X)) == r'E[X]' + assert pretty(Variance(X)) == r'Var(X)' + assert pretty(Probability(X > 0)) == r'P(X > 0)' + Y = Normal("Y", mu, sigma) + assert pretty(Covariance(X, Y)) == 'Cov(X, Y)' + + +def test_issue_21758(): + from sympy.functions.elementary.piecewise import piecewise_fold + from sympy.series.fourier import FourierSeries + x = Symbol('x') + k, n = symbols('k n') + fo = FourierSeries(x, (x, -pi, pi), (0, SeqFormula(0, (k, 1, oo)), SeqFormula( + Piecewise((-2*pi*cos(n*pi)/n + 2*sin(n*pi)/n**2, (n > -oo) & (n < oo) & Ne(n, 0)), + (0, True))*sin(n*x)/pi, (n, 1, oo)))) + assert upretty(piecewise_fold(fo)) == \ + '⎧ 2⋅sin(3⋅x) \n'\ + '⎪2⋅sin(x) - sin(2⋅x) + ────────── + … for n > -∞ ∧ n < ∞ ∧ n ≠ 0\n'\ + '⎨ 3 \n'\ + '⎪ \n'\ + '⎩ 0 otherwise ' + assert pretty(FourierSeries(x, (x, -pi, pi), (0, SeqFormula(0, (k, 1, oo)), + SeqFormula(0, (n, 1, oo))))) == '0' + + +def test_diffgeom(): + from sympy.diffgeom import Manifold, Patch, CoordSystem, BaseScalarField + x,y = symbols('x y', real=True) + m = Manifold('M', 2) + assert pretty(m) == 'M' + p = Patch('P', m) + assert pretty(p) == "P" + rect = CoordSystem('rect', p, [x, y]) + assert pretty(rect) == "rect" + b = BaseScalarField(rect, 0) + assert pretty(b) == "x" + + +def test_deprecated_prettyForm(): + with warns_deprecated_sympy(): + from sympy.printing.pretty.pretty_symbology import xstr + assert xstr(1) == '1' + + with warns_deprecated_sympy(): + from sympy.printing.pretty.stringpict import prettyForm + p = prettyForm('s', unicode='s') + + with warns_deprecated_sympy(): + assert p.unicode == p.s == 's' + + +def test_center(): + assert center('1', 2) == '1 ' + assert center('1', 3) == ' 1 ' + assert center('1', 3, '-') == '-1-' + assert center('1', 5, '-') == '--1--' diff --git a/lib/python3.10/site-packages/sympy/printing/tests/__init__.py b/lib/python3.10/site-packages/sympy/printing/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/lib/python3.10/site-packages/sympy/printing/tests/__pycache__/__init__.cpython-310.pyc b/lib/python3.10/site-packages/sympy/printing/tests/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6f95c6b3225af4464126d6e64d94a6e9320e7f7b Binary files /dev/null and b/lib/python3.10/site-packages/sympy/printing/tests/__pycache__/__init__.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/printing/tests/__pycache__/test_aesaracode.cpython-310.pyc b/lib/python3.10/site-packages/sympy/printing/tests/__pycache__/test_aesaracode.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ce814fa8bd2ef29ffdfa903a1edcb526cc25f754 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/printing/tests/__pycache__/test_aesaracode.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/printing/tests/__pycache__/test_c.cpython-310.pyc b/lib/python3.10/site-packages/sympy/printing/tests/__pycache__/test_c.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..635ef5ba4a2986b07e8716bcff6bb75407112fb5 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/printing/tests/__pycache__/test_c.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/printing/tests/__pycache__/test_codeprinter.cpython-310.pyc b/lib/python3.10/site-packages/sympy/printing/tests/__pycache__/test_codeprinter.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1a7b25ca07ff4ea0ab9ffbc2444fe02c291cd980 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/printing/tests/__pycache__/test_codeprinter.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/printing/tests/__pycache__/test_conventions.cpython-310.pyc b/lib/python3.10/site-packages/sympy/printing/tests/__pycache__/test_conventions.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..65439bd438cf48c0ce5f63662bb539b44f2c8fb7 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/printing/tests/__pycache__/test_conventions.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/printing/tests/__pycache__/test_cupy.cpython-310.pyc b/lib/python3.10/site-packages/sympy/printing/tests/__pycache__/test_cupy.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ffa85f12ba28158a96e8dd1d2e9e3016a8f0df66 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/printing/tests/__pycache__/test_cupy.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/printing/tests/__pycache__/test_cxx.cpython-310.pyc b/lib/python3.10/site-packages/sympy/printing/tests/__pycache__/test_cxx.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d6b364af129ef4ed4e1a059c221b5c34f6593a21 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/printing/tests/__pycache__/test_cxx.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/printing/tests/__pycache__/test_dot.cpython-310.pyc b/lib/python3.10/site-packages/sympy/printing/tests/__pycache__/test_dot.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0d8537ff1f019f445fde536139c754bea32ec68b Binary files /dev/null and b/lib/python3.10/site-packages/sympy/printing/tests/__pycache__/test_dot.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/printing/tests/__pycache__/test_fortran.cpython-310.pyc b/lib/python3.10/site-packages/sympy/printing/tests/__pycache__/test_fortran.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..15e4b008d84ca4d171412b71b6ae5de80ab1e201 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/printing/tests/__pycache__/test_fortran.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/printing/tests/__pycache__/test_glsl.cpython-310.pyc b/lib/python3.10/site-packages/sympy/printing/tests/__pycache__/test_glsl.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..44144357d1e2acb16aa5bf3235a3273d75b287bd Binary files /dev/null and b/lib/python3.10/site-packages/sympy/printing/tests/__pycache__/test_glsl.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/printing/tests/__pycache__/test_gtk.cpython-310.pyc b/lib/python3.10/site-packages/sympy/printing/tests/__pycache__/test_gtk.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5470b2da878b300d44b6fa11b2e286a781c751eb Binary files /dev/null and b/lib/python3.10/site-packages/sympy/printing/tests/__pycache__/test_gtk.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/printing/tests/__pycache__/test_jax.cpython-310.pyc b/lib/python3.10/site-packages/sympy/printing/tests/__pycache__/test_jax.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..08c94dc81ebf9da33f0f634e3acb101a8dcb43ce Binary files /dev/null and b/lib/python3.10/site-packages/sympy/printing/tests/__pycache__/test_jax.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/printing/tests/__pycache__/test_jscode.cpython-310.pyc b/lib/python3.10/site-packages/sympy/printing/tests/__pycache__/test_jscode.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d7918b39fa1b075515dc2cd2126e337036eb93f8 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/printing/tests/__pycache__/test_jscode.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/printing/tests/__pycache__/test_julia.cpython-310.pyc b/lib/python3.10/site-packages/sympy/printing/tests/__pycache__/test_julia.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..de0b3e59cad735f9f951805171b45e551263d296 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/printing/tests/__pycache__/test_julia.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/printing/tests/__pycache__/test_lambdarepr.cpython-310.pyc b/lib/python3.10/site-packages/sympy/printing/tests/__pycache__/test_lambdarepr.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2cca681c4567749475ad6a556ad77fe9641d8f58 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/printing/tests/__pycache__/test_lambdarepr.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/printing/tests/__pycache__/test_llvmjit.cpython-310.pyc b/lib/python3.10/site-packages/sympy/printing/tests/__pycache__/test_llvmjit.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a9a1fcdfead37db571a1a45d0451621166cdd1b6 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/printing/tests/__pycache__/test_llvmjit.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/printing/tests/__pycache__/test_maple.cpython-310.pyc b/lib/python3.10/site-packages/sympy/printing/tests/__pycache__/test_maple.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7277ac2cfa7133438984ad1c3df454ba6d2a0b16 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/printing/tests/__pycache__/test_maple.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/printing/tests/__pycache__/test_mathematica.cpython-310.pyc b/lib/python3.10/site-packages/sympy/printing/tests/__pycache__/test_mathematica.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..74bf8bb726f663df5c3bdf82d3c4228d7aa63eec Binary files /dev/null and b/lib/python3.10/site-packages/sympy/printing/tests/__pycache__/test_mathematica.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/printing/tests/__pycache__/test_mathml.cpython-310.pyc b/lib/python3.10/site-packages/sympy/printing/tests/__pycache__/test_mathml.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2ae808b9f9cac2dcad411a4fc7f4091aa7cb518f Binary files /dev/null and b/lib/python3.10/site-packages/sympy/printing/tests/__pycache__/test_mathml.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/printing/tests/__pycache__/test_numpy.cpython-310.pyc b/lib/python3.10/site-packages/sympy/printing/tests/__pycache__/test_numpy.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..97a1997165f9ec69749f61428a3f7efff5d22587 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/printing/tests/__pycache__/test_numpy.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/printing/tests/__pycache__/test_octave.cpython-310.pyc b/lib/python3.10/site-packages/sympy/printing/tests/__pycache__/test_octave.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cb243e8672743aaadb6e1bdfa8c151e235c39582 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/printing/tests/__pycache__/test_octave.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/printing/tests/__pycache__/test_precedence.cpython-310.pyc b/lib/python3.10/site-packages/sympy/printing/tests/__pycache__/test_precedence.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e1ed00fb281db5e3177548b5a549ddc318b4354b Binary files /dev/null and b/lib/python3.10/site-packages/sympy/printing/tests/__pycache__/test_precedence.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/printing/tests/__pycache__/test_preview.cpython-310.pyc b/lib/python3.10/site-packages/sympy/printing/tests/__pycache__/test_preview.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..39ad9216b017c7773eddd3e363183691d71c2624 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/printing/tests/__pycache__/test_preview.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/printing/tests/__pycache__/test_pycode.cpython-310.pyc b/lib/python3.10/site-packages/sympy/printing/tests/__pycache__/test_pycode.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..daec54e5cc628889c02ac1cdb7e749212dc1e7e2 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/printing/tests/__pycache__/test_pycode.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/printing/tests/__pycache__/test_python.cpython-310.pyc b/lib/python3.10/site-packages/sympy/printing/tests/__pycache__/test_python.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c38b5885529c3b221f5f81dc8db21403cda9db46 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/printing/tests/__pycache__/test_python.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/printing/tests/__pycache__/test_rcode.cpython-310.pyc b/lib/python3.10/site-packages/sympy/printing/tests/__pycache__/test_rcode.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cba7b04c1fb74621c412c23850942254dfecb9aa Binary files /dev/null and b/lib/python3.10/site-packages/sympy/printing/tests/__pycache__/test_rcode.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/printing/tests/__pycache__/test_repr.cpython-310.pyc b/lib/python3.10/site-packages/sympy/printing/tests/__pycache__/test_repr.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..de7150e29017b679dcd7aff702cfce9b18fed79a Binary files /dev/null and b/lib/python3.10/site-packages/sympy/printing/tests/__pycache__/test_repr.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/printing/tests/__pycache__/test_rust.cpython-310.pyc b/lib/python3.10/site-packages/sympy/printing/tests/__pycache__/test_rust.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..06c22b9c0a669f575e6784d73c0e4428df1aa2e6 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/printing/tests/__pycache__/test_rust.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/printing/tests/__pycache__/test_smtlib.cpython-310.pyc b/lib/python3.10/site-packages/sympy/printing/tests/__pycache__/test_smtlib.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f38afc6110c4654eaa21fd6f3358d42a19738260 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/printing/tests/__pycache__/test_smtlib.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/printing/tests/__pycache__/test_str.cpython-310.pyc b/lib/python3.10/site-packages/sympy/printing/tests/__pycache__/test_str.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1d663545d1f05c0bc7f0a0c0bd92e638aaf1be05 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/printing/tests/__pycache__/test_str.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/printing/tests/__pycache__/test_tableform.cpython-310.pyc b/lib/python3.10/site-packages/sympy/printing/tests/__pycache__/test_tableform.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..aa4a47e058c407c7dd0f2c192f0c2efe1e8c2050 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/printing/tests/__pycache__/test_tableform.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/printing/tests/__pycache__/test_tensorflow.cpython-310.pyc b/lib/python3.10/site-packages/sympy/printing/tests/__pycache__/test_tensorflow.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..11ffbca4fcd8b72031b207cac6606616db39bb89 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/printing/tests/__pycache__/test_tensorflow.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/printing/tests/__pycache__/test_theanocode.cpython-310.pyc b/lib/python3.10/site-packages/sympy/printing/tests/__pycache__/test_theanocode.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b710bc2dd6f1a0dd34f8ed8e836673c093874d87 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/printing/tests/__pycache__/test_theanocode.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/printing/tests/__pycache__/test_tree.cpython-310.pyc b/lib/python3.10/site-packages/sympy/printing/tests/__pycache__/test_tree.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2454119653f28231f61749aa5939d593395a29d0 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/printing/tests/__pycache__/test_tree.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/printing/tests/test_aesaracode.py b/lib/python3.10/site-packages/sympy/printing/tests/test_aesaracode.py new file mode 100644 index 0000000000000000000000000000000000000000..28dee8fd8ed63aeea75ecf421085aed824aa4f17 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/printing/tests/test_aesaracode.py @@ -0,0 +1,626 @@ +""" +Important note on tests in this module - the Aesara printing functions use a +global cache by default, which means that tests using it will modify global +state and thus not be independent from each other. Instead of using the "cache" +keyword argument each time, this module uses the aesara_code_ and +aesara_function_ functions defined below which default to using a new, empty +cache instead. +""" + +import logging + +from sympy.external import import_module +from sympy.testing.pytest import raises, SKIP + +from sympy.utilities.exceptions import ignore_warnings + + +aesaralogger = logging.getLogger('aesara.configdefaults') +aesaralogger.setLevel(logging.CRITICAL) +aesara = import_module('aesara') +aesaralogger.setLevel(logging.WARNING) + + +if aesara: + import numpy as np + aet = aesara.tensor + from aesara.scalar.basic import ScalarType + from aesara.graph.basic import Variable + from aesara.tensor.var import TensorVariable + from aesara.tensor.elemwise import Elemwise, DimShuffle + from aesara.tensor.math import Dot + + from sympy.printing.aesaracode import true_divide + + xt, yt, zt = [aet.scalar(name, 'floatX') for name in 'xyz'] + Xt, Yt, Zt = [aet.tensor('floatX', (False, False), name=n) for n in 'XYZ'] +else: + #bin/test will not execute any tests now + disabled = True + +import sympy as sy +from sympy.core.singleton import S +from sympy.abc import x, y, z, t +from sympy.printing.aesaracode import (aesara_code, dim_handling, + aesara_function) + + +# Default set of matrix symbols for testing - make square so we can both +# multiply and perform elementwise operations between them. +X, Y, Z = [sy.MatrixSymbol(n, 4, 4) for n in 'XYZ'] + +# For testing AppliedUndef +f_t = sy.Function('f')(t) + + +def aesara_code_(expr, **kwargs): + """ Wrapper for aesara_code that uses a new, empty cache by default. """ + kwargs.setdefault('cache', {}) + return aesara_code(expr, **kwargs) + +def aesara_function_(inputs, outputs, **kwargs): + """ Wrapper for aesara_function that uses a new, empty cache by default. """ + kwargs.setdefault('cache', {}) + return aesara_function(inputs, outputs, **kwargs) + + +def fgraph_of(*exprs): + """ Transform SymPy expressions into Aesara Computation. + + Parameters + ========== + exprs + SymPy expressions + + Returns + ======= + aesara.graph.fg.FunctionGraph + """ + outs = list(map(aesara_code_, exprs)) + ins = list(aesara.graph.basic.graph_inputs(outs)) + ins, outs = aesara.graph.basic.clone(ins, outs) + return aesara.graph.fg.FunctionGraph(ins, outs) + + +def aesara_simplify(fgraph): + """ Simplify a Aesara Computation. + + Parameters + ========== + fgraph : aesara.graph.fg.FunctionGraph + + Returns + ======= + aesara.graph.fg.FunctionGraph + """ + mode = aesara.compile.get_default_mode().excluding("fusion") + fgraph = fgraph.clone() + mode.optimizer.rewrite(fgraph) + return fgraph + + +def theq(a, b): + """ Test two Aesara objects for equality. + + Also accepts numeric types and lists/tuples of supported types. + + Note - debugprint() has a bug where it will accept numeric types but does + not respect the "file" argument and in this case and instead prints the number + to stdout and returns an empty string. This can lead to tests passing where + they should fail because any two numbers will always compare as equal. To + prevent this we treat numbers as a separate case. + """ + numeric_types = (int, float, np.number) + a_is_num = isinstance(a, numeric_types) + b_is_num = isinstance(b, numeric_types) + + # Compare numeric types using regular equality + if a_is_num or b_is_num: + if not (a_is_num and b_is_num): + return False + + return a == b + + # Compare sequences element-wise + a_is_seq = isinstance(a, (tuple, list)) + b_is_seq = isinstance(b, (tuple, list)) + + if a_is_seq or b_is_seq: + if not (a_is_seq and b_is_seq) or type(a) != type(b): + return False + + return list(map(theq, a)) == list(map(theq, b)) + + # Otherwise, assume debugprint() can handle it + astr = aesara.printing.debugprint(a, file='str') + bstr = aesara.printing.debugprint(b, file='str') + + # Check for bug mentioned above + for argname, argval, argstr in [('a', a, astr), ('b', b, bstr)]: + if argstr == '': + raise TypeError( + 'aesara.printing.debugprint(%s) returned empty string ' + '(%s is instance of %r)' + % (argname, argname, type(argval)) + ) + + return astr == bstr + + +def test_example_symbols(): + """ + Check that the example symbols in this module print to their Aesara + equivalents, as many of the other tests depend on this. + """ + assert theq(xt, aesara_code_(x)) + assert theq(yt, aesara_code_(y)) + assert theq(zt, aesara_code_(z)) + assert theq(Xt, aesara_code_(X)) + assert theq(Yt, aesara_code_(Y)) + assert theq(Zt, aesara_code_(Z)) + + +def test_Symbol(): + """ Test printing a Symbol to a aesara variable. """ + xx = aesara_code_(x) + assert isinstance(xx, Variable) + assert xx.broadcastable == () + assert xx.name == x.name + + xx2 = aesara_code_(x, broadcastables={x: (False,)}) + assert xx2.broadcastable == (False,) + assert xx2.name == x.name + +def test_MatrixSymbol(): + """ Test printing a MatrixSymbol to a aesara variable. """ + XX = aesara_code_(X) + assert isinstance(XX, TensorVariable) + assert XX.broadcastable == (False, False) + +@SKIP # TODO - this is currently not checked but should be implemented +def test_MatrixSymbol_wrong_dims(): + """ Test MatrixSymbol with invalid broadcastable. """ + bcs = [(), (False,), (True,), (True, False), (False, True,), (True, True)] + for bc in bcs: + with raises(ValueError): + aesara_code_(X, broadcastables={X: bc}) + +def test_AppliedUndef(): + """ Test printing AppliedUndef instance, which works similarly to Symbol. """ + ftt = aesara_code_(f_t) + assert isinstance(ftt, TensorVariable) + assert ftt.broadcastable == () + assert ftt.name == 'f_t' + + +def test_add(): + expr = x + y + comp = aesara_code_(expr) + assert comp.owner.op == aesara.tensor.add + +def test_trig(): + assert theq(aesara_code_(sy.sin(x)), aet.sin(xt)) + assert theq(aesara_code_(sy.tan(x)), aet.tan(xt)) + +def test_many(): + """ Test printing a complex expression with multiple symbols. """ + expr = sy.exp(x**2 + sy.cos(y)) * sy.log(2*z) + comp = aesara_code_(expr) + expected = aet.exp(xt**2 + aet.cos(yt)) * aet.log(2*zt) + assert theq(comp, expected) + + +def test_dtype(): + """ Test specifying specific data types through the dtype argument. """ + for dtype in ['float32', 'float64', 'int8', 'int16', 'int32', 'int64']: + assert aesara_code_(x, dtypes={x: dtype}).type.dtype == dtype + + # "floatX" type + assert aesara_code_(x, dtypes={x: 'floatX'}).type.dtype in ('float32', 'float64') + + # Type promotion + assert aesara_code_(x + 1, dtypes={x: 'float32'}).type.dtype == 'float32' + assert aesara_code_(x + y, dtypes={x: 'float64', y: 'float32'}).type.dtype == 'float64' + + +def test_broadcastables(): + """ Test the "broadcastables" argument when printing symbol-like objects. """ + + # No restrictions on shape + for s in [x, f_t]: + for bc in [(), (False,), (True,), (False, False), (True, False)]: + assert aesara_code_(s, broadcastables={s: bc}).broadcastable == bc + + # TODO - matrix broadcasting? + +def test_broadcasting(): + """ Test "broadcastable" attribute after applying element-wise binary op. """ + + expr = x + y + + cases = [ + [(), (), ()], + [(False,), (False,), (False,)], + [(True,), (False,), (False,)], + [(False, True), (False, False), (False, False)], + [(True, False), (False, False), (False, False)], + ] + + for bc1, bc2, bc3 in cases: + comp = aesara_code_(expr, broadcastables={x: bc1, y: bc2}) + assert comp.broadcastable == bc3 + + +def test_MatMul(): + expr = X*Y*Z + expr_t = aesara_code_(expr) + assert isinstance(expr_t.owner.op, Dot) + assert theq(expr_t, Xt.dot(Yt).dot(Zt)) + +def test_Transpose(): + assert isinstance(aesara_code_(X.T).owner.op, DimShuffle) + +def test_MatAdd(): + expr = X+Y+Z + assert isinstance(aesara_code_(expr).owner.op, Elemwise) + + +def test_Rationals(): + assert theq(aesara_code_(sy.Integer(2) / 3), true_divide(2, 3)) + assert theq(aesara_code_(S.Half), true_divide(1, 2)) + +def test_Integers(): + assert aesara_code_(sy.Integer(3)) == 3 + +def test_factorial(): + n = sy.Symbol('n') + assert aesara_code_(sy.factorial(n)) + +def test_Derivative(): + with ignore_warnings(UserWarning): + simp = lambda expr: aesara_simplify(fgraph_of(expr)) + assert theq(simp(aesara_code_(sy.Derivative(sy.sin(x), x, evaluate=False))), + simp(aesara.grad(aet.sin(xt), xt))) + + +def test_aesara_function_simple(): + """ Test aesara_function() with single output. """ + f = aesara_function_([x, y], [x+y]) + assert f(2, 3) == 5 + +def test_aesara_function_multi(): + """ Test aesara_function() with multiple outputs. """ + f = aesara_function_([x, y], [x+y, x-y]) + o1, o2 = f(2, 3) + assert o1 == 5 + assert o2 == -1 + +def test_aesara_function_numpy(): + """ Test aesara_function() vs Numpy implementation. """ + f = aesara_function_([x, y], [x+y], dim=1, + dtypes={x: 'float64', y: 'float64'}) + assert np.linalg.norm(f([1, 2], [3, 4]) - np.asarray([4, 6])) < 1e-9 + + f = aesara_function_([x, y], [x+y], dtypes={x: 'float64', y: 'float64'}, + dim=1) + xx = np.arange(3).astype('float64') + yy = 2*np.arange(3).astype('float64') + assert np.linalg.norm(f(xx, yy) - 3*np.arange(3)) < 1e-9 + + +def test_aesara_function_matrix(): + m = sy.Matrix([[x, y], [z, x + y + z]]) + expected = np.array([[1.0, 2.0], [3.0, 1.0 + 2.0 + 3.0]]) + f = aesara_function_([x, y, z], [m]) + np.testing.assert_allclose(f(1.0, 2.0, 3.0), expected) + f = aesara_function_([x, y, z], [m], scalar=True) + np.testing.assert_allclose(f(1.0, 2.0, 3.0), expected) + f = aesara_function_([x, y, z], [m, m]) + assert isinstance(f(1.0, 2.0, 3.0), type([])) + np.testing.assert_allclose(f(1.0, 2.0, 3.0)[0], expected) + np.testing.assert_allclose(f(1.0, 2.0, 3.0)[1], expected) + +def test_dim_handling(): + assert dim_handling([x], dim=2) == {x: (False, False)} + assert dim_handling([x, y], dims={x: 1, y: 2}) == {x: (False, True), + y: (False, False)} + assert dim_handling([x], broadcastables={x: (False,)}) == {x: (False,)} + +def test_aesara_function_kwargs(): + """ + Test passing additional kwargs from aesara_function() to aesara.function(). + """ + import numpy as np + f = aesara_function_([x, y, z], [x+y], dim=1, on_unused_input='ignore', + dtypes={x: 'float64', y: 'float64', z: 'float64'}) + assert np.linalg.norm(f([1, 2], [3, 4], [0, 0]) - np.asarray([4, 6])) < 1e-9 + + f = aesara_function_([x, y, z], [x+y], + dtypes={x: 'float64', y: 'float64', z: 'float64'}, + dim=1, on_unused_input='ignore') + xx = np.arange(3).astype('float64') + yy = 2*np.arange(3).astype('float64') + zz = 2*np.arange(3).astype('float64') + assert np.linalg.norm(f(xx, yy, zz) - 3*np.arange(3)) < 1e-9 + +def test_aesara_function_scalar(): + """ Test the "scalar" argument to aesara_function(). """ + from aesara.compile.function.types import Function + + args = [ + ([x, y], [x + y], None, [0]), # Single 0d output + ([X, Y], [X + Y], None, [2]), # Single 2d output + ([x, y], [x + y], {x: 0, y: 1}, [1]), # Single 1d output + ([x, y], [x + y, x - y], None, [0, 0]), # Two 0d outputs + ([x, y, X, Y], [x + y, X + Y], None, [0, 2]), # One 0d output, one 2d + ] + + # Create and test functions with and without the scalar setting + for inputs, outputs, in_dims, out_dims in args: + for scalar in [False, True]: + + f = aesara_function_(inputs, outputs, dims=in_dims, scalar=scalar) + + # Check the aesara_function attribute is set whether wrapped or not + assert isinstance(f.aesara_function, Function) + + # Feed in inputs of the appropriate size and get outputs + in_values = [ + np.ones([1 if bc else 5 for bc in i.type.broadcastable]) + for i in f.aesara_function.input_storage + ] + out_values = f(*in_values) + if not isinstance(out_values, list): + out_values = [out_values] + + # Check output types and shapes + assert len(out_dims) == len(out_values) + for d, value in zip(out_dims, out_values): + + if scalar and d == 0: + # Should have been converted to a scalar value + assert isinstance(value, np.number) + + else: + # Otherwise should be an array + assert isinstance(value, np.ndarray) + assert value.ndim == d + +def test_aesara_function_bad_kwarg(): + """ + Passing an unknown keyword argument to aesara_function() should raise an + exception. + """ + raises(Exception, lambda : aesara_function_([x], [x+1], foobar=3)) + + +def test_slice(): + assert aesara_code_(slice(1, 2, 3)) == slice(1, 2, 3) + + def theq_slice(s1, s2): + for attr in ['start', 'stop', 'step']: + a1 = getattr(s1, attr) + a2 = getattr(s2, attr) + if a1 is None or a2 is None: + if not (a1 is None or a2 is None): + return False + elif not theq(a1, a2): + return False + return True + + dtypes = {x: 'int32', y: 'int32'} + assert theq_slice(aesara_code_(slice(x, y), dtypes=dtypes), slice(xt, yt)) + assert theq_slice(aesara_code_(slice(1, x, 3), dtypes=dtypes), slice(1, xt, 3)) + +def test_MatrixSlice(): + cache = {} + + n = sy.Symbol('n', integer=True) + X = sy.MatrixSymbol('X', n, n) + + Y = X[1:2:3, 4:5:6] + Yt = aesara_code_(Y, cache=cache) + + s = ScalarType('int64') + assert tuple(Yt.owner.op.idx_list) == (slice(s, s, s), slice(s, s, s)) + assert Yt.owner.inputs[0] == aesara_code_(X, cache=cache) + # == doesn't work in Aesara like it does in SymPy. You have to use + # equals. + assert all(Yt.owner.inputs[i].data == i for i in range(1, 7)) + + k = sy.Symbol('k') + aesara_code_(k, dtypes={k: 'int32'}) + start, stop, step = 4, k, 2 + Y = X[start:stop:step] + Yt = aesara_code_(Y, dtypes={n: 'int32', k: 'int32'}) + # assert Yt.owner.op.idx_list[0].stop == kt + +def test_BlockMatrix(): + n = sy.Symbol('n', integer=True) + A, B, C, D = [sy.MatrixSymbol(name, n, n) for name in 'ABCD'] + At, Bt, Ct, Dt = map(aesara_code_, (A, B, C, D)) + Block = sy.BlockMatrix([[A, B], [C, D]]) + Blockt = aesara_code_(Block) + solutions = [aet.join(0, aet.join(1, At, Bt), aet.join(1, Ct, Dt)), + aet.join(1, aet.join(0, At, Ct), aet.join(0, Bt, Dt))] + assert any(theq(Blockt, solution) for solution in solutions) + +@SKIP +def test_BlockMatrix_Inverse_execution(): + k, n = 2, 4 + dtype = 'float32' + A = sy.MatrixSymbol('A', n, k) + B = sy.MatrixSymbol('B', n, n) + inputs = A, B + output = B.I*A + + cutsizes = {A: [(n//2, n//2), (k//2, k//2)], + B: [(n//2, n//2), (n//2, n//2)]} + cutinputs = [sy.blockcut(i, *cutsizes[i]) for i in inputs] + cutoutput = output.subs(dict(zip(inputs, cutinputs))) + + dtypes = dict(zip(inputs, [dtype]*len(inputs))) + f = aesara_function_(inputs, [output], dtypes=dtypes, cache={}) + fblocked = aesara_function_(inputs, [sy.block_collapse(cutoutput)], + dtypes=dtypes, cache={}) + + ninputs = [np.random.rand(*x.shape).astype(dtype) for x in inputs] + ninputs = [np.arange(n*k).reshape(A.shape).astype(dtype), + np.eye(n).astype(dtype)] + ninputs[1] += np.ones(B.shape)*1e-5 + + assert np.allclose(f(*ninputs), fblocked(*ninputs), rtol=1e-5) + +def test_DenseMatrix(): + from aesara.tensor.basic import Join + + t = sy.Symbol('theta') + for MatrixType in [sy.Matrix, sy.ImmutableMatrix]: + X = MatrixType([[sy.cos(t), -sy.sin(t)], [sy.sin(t), sy.cos(t)]]) + tX = aesara_code_(X) + assert isinstance(tX, TensorVariable) + assert isinstance(tX.owner.op, Join) + + +def test_cache_basic(): + """ Test single symbol-like objects are cached when printed by themselves. """ + + # Pairs of objects which should be considered equivalent with respect to caching + pairs = [ + (x, sy.Symbol('x')), + (X, sy.MatrixSymbol('X', *X.shape)), + (f_t, sy.Function('f')(sy.Symbol('t'))), + ] + + for s1, s2 in pairs: + cache = {} + st = aesara_code_(s1, cache=cache) + + # Test hit with same instance + assert aesara_code_(s1, cache=cache) is st + + # Test miss with same instance but new cache + assert aesara_code_(s1, cache={}) is not st + + # Test hit with different but equivalent instance + assert aesara_code_(s2, cache=cache) is st + +def test_global_cache(): + """ Test use of the global cache. """ + from sympy.printing.aesaracode import global_cache + + backup = dict(global_cache) + try: + # Temporarily empty global cache + global_cache.clear() + + for s in [x, X, f_t]: + st = aesara_code(s) + assert aesara_code(s) is st + + finally: + # Restore global cache + global_cache.update(backup) + +def test_cache_types_distinct(): + """ + Test that symbol-like objects of different types (Symbol, MatrixSymbol, + AppliedUndef) are distinguished by the cache even if they have the same + name. + """ + symbols = [sy.Symbol('f_t'), sy.MatrixSymbol('f_t', 4, 4), f_t] + + cache = {} # Single shared cache + printed = {} + + for s in symbols: + st = aesara_code_(s, cache=cache) + assert st not in printed.values() + printed[s] = st + + # Check all printed objects are distinct + assert len(set(map(id, printed.values()))) == len(symbols) + + # Check retrieving + for s, st in printed.items(): + assert aesara_code(s, cache=cache) is st + +def test_symbols_are_created_once(): + """ + Test that a symbol is cached and reused when it appears in an expression + more than once. + """ + expr = sy.Add(x, x, evaluate=False) + comp = aesara_code_(expr) + + assert theq(comp, xt + xt) + assert not theq(comp, xt + aesara_code_(x)) + +def test_cache_complex(): + """ + Test caching on a complicated expression with multiple symbols appearing + multiple times. + """ + expr = x ** 2 + (y - sy.exp(x)) * sy.sin(z - x * y) + symbol_names = {s.name for s in expr.free_symbols} + expr_t = aesara_code_(expr) + + # Iterate through variables in the Aesara computational graph that the + # printed expression depends on + seen = set() + for v in aesara.graph.basic.ancestors([expr_t]): + # Owner-less, non-constant variables should be our symbols + if v.owner is None and not isinstance(v, aesara.graph.basic.Constant): + # Check it corresponds to a symbol and appears only once + assert v.name in symbol_names + assert v.name not in seen + seen.add(v.name) + + # Check all were present + assert seen == symbol_names + + +def test_Piecewise(): + # A piecewise linear + expr = sy.Piecewise((0, x<0), (x, x<2), (1, True)) # ___/III + result = aesara_code_(expr) + assert result.owner.op == aet.switch + + expected = aet.switch(xt<0, 0, aet.switch(xt<2, xt, 1)) + assert theq(result, expected) + + expr = sy.Piecewise((x, x < 0)) + result = aesara_code_(expr) + expected = aet.switch(xt < 0, xt, np.nan) + assert theq(result, expected) + + expr = sy.Piecewise((0, sy.And(x>0, x<2)), \ + (x, sy.Or(x>2, x<0))) + result = aesara_code_(expr) + expected = aet.switch(aet.and_(xt>0,xt<2), 0, \ + aet.switch(aet.or_(xt>2, xt<0), xt, np.nan)) + assert theq(result, expected) + + +def test_Relationals(): + assert theq(aesara_code_(sy.Eq(x, y)), aet.eq(xt, yt)) + # assert theq(aesara_code_(sy.Ne(x, y)), aet.neq(xt, yt)) # TODO - implement + assert theq(aesara_code_(x > y), xt > yt) + assert theq(aesara_code_(x < y), xt < yt) + assert theq(aesara_code_(x >= y), xt >= yt) + assert theq(aesara_code_(x <= y), xt <= yt) + + +def test_complexfunctions(): + dtypes = {x:'complex128', y:'complex128'} + xt, yt = aesara_code(x, dtypes=dtypes), aesara_code(y, dtypes=dtypes) + from sympy.functions.elementary.complexes import conjugate + from aesara.tensor import as_tensor_variable as atv + from aesara.tensor import complex as cplx + assert theq(aesara_code(y*conjugate(x), dtypes=dtypes), yt*(xt.conj())) + assert theq(aesara_code((1+2j)*x), xt*(atv(1.0)+atv(2.0)*cplx(0,1))) + + +def test_constantfunctions(): + tf = aesara_function([],[1+1j]) + assert(tf()==1+1j) diff --git a/lib/python3.10/site-packages/sympy/printing/tests/test_c.py b/lib/python3.10/site-packages/sympy/printing/tests/test_c.py new file mode 100644 index 0000000000000000000000000000000000000000..11836539f0b03a94cfa8f6ee52460ca6a9ffa1a1 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/printing/tests/test_c.py @@ -0,0 +1,883 @@ +from sympy.core import ( + S, pi, oo, Symbol, symbols, Rational, Integer, Float, Function, Mod, GoldenRatio, EulerGamma, Catalan, + Lambda, Dummy, nan, Mul, Pow, UnevaluatedExpr +) +from sympy.core.relational import (Eq, Ge, Gt, Le, Lt, Ne) +from sympy.functions import ( + Abs, acos, acosh, asin, asinh, atan, atanh, atan2, ceiling, cos, cosh, erf, + erfc, exp, floor, gamma, log, loggamma, Max, Min, Piecewise, sign, sin, sinh, + sqrt, tan, tanh, fibonacci, lucas +) +from sympy.sets import Range +from sympy.logic import ITE, Implies, Equivalent +from sympy.codegen import For, aug_assign, Assignment +from sympy.testing.pytest import raises, XFAIL +from sympy.printing.codeprinter import PrintMethodNotImplementedError +from sympy.printing.c import C89CodePrinter, C99CodePrinter, get_math_macros +from sympy.codegen.ast import ( + AddAugmentedAssignment, Element, Type, FloatType, Declaration, Pointer, Variable, value_const, pointer_const, + While, Scope, Print, FunctionPrototype, FunctionDefinition, FunctionCall, Return, + real, float32, float64, float80, float128, intc, Comment, CodeBlock, stderr, QuotedString +) +from sympy.codegen.cfunctions import expm1, log1p, exp2, log2, fma, log10, Cbrt, hypot, Sqrt +from sympy.codegen.cnodes import restrict +from sympy.utilities.lambdify import implemented_function +from sympy.tensor import IndexedBase, Idx +from sympy.matrices import Matrix, MatrixSymbol, SparseMatrix + +from sympy.printing.codeprinter import ccode + +x, y, z = symbols('x,y,z') + + +def test_printmethod(): + class fabs(Abs): + def _ccode(self, printer): + return "fabs(%s)" % printer._print(self.args[0]) + + assert ccode(fabs(x)) == "fabs(x)" + + +def test_ccode_sqrt(): + assert ccode(sqrt(x)) == "sqrt(x)" + assert ccode(x**0.5) == "sqrt(x)" + assert ccode(sqrt(x)) == "sqrt(x)" + + +def test_ccode_Pow(): + assert ccode(x**3) == "pow(x, 3)" + assert ccode(x**(y**3)) == "pow(x, pow(y, 3))" + g = implemented_function('g', Lambda(x, 2*x)) + assert ccode(1/(g(x)*3.5)**(x - y**x)/(x**2 + y)) == \ + "pow(3.5*2*x, -x + pow(y, x))/(pow(x, 2) + y)" + assert ccode(x**-1.0) == '1.0/x' + assert ccode(x**Rational(2, 3)) == 'pow(x, 2.0/3.0)' + assert ccode(x**Rational(2, 3), type_aliases={real: float80}) == 'powl(x, 2.0L/3.0L)' + _cond_cfunc = [(lambda base, exp: exp.is_integer, "dpowi"), + (lambda base, exp: not exp.is_integer, "pow")] + assert ccode(x**3, user_functions={'Pow': _cond_cfunc}) == 'dpowi(x, 3)' + assert ccode(x**0.5, user_functions={'Pow': _cond_cfunc}) == 'pow(x, 0.5)' + assert ccode(x**Rational(16, 5), user_functions={'Pow': _cond_cfunc}) == 'pow(x, 16.0/5.0)' + _cond_cfunc2 = [(lambda base, exp: base == 2, lambda base, exp: 'exp2(%s)' % exp), + (lambda base, exp: base != 2, 'pow')] + # Related to gh-11353 + assert ccode(2**x, user_functions={'Pow': _cond_cfunc2}) == 'exp2(x)' + assert ccode(x**2, user_functions={'Pow': _cond_cfunc2}) == 'pow(x, 2)' + # For issue 14160 + assert ccode(Mul(-2, x, Pow(Mul(y,y,evaluate=False), -1, evaluate=False), + evaluate=False)) == '-2*x/(y*y)' + + +def test_ccode_Max(): + # Test for gh-11926 + assert ccode(Max(x,x*x),user_functions={"Max":"my_max", "Pow":"my_pow"}) == 'my_max(x, my_pow(x, 2))' + + +def test_ccode_Min_performance(): + #Shouldn't take more than a few seconds + big_min = Min(*symbols('a[0:50]')) + for curr_standard in ('c89', 'c99', 'c11'): + output = ccode(big_min, standard=curr_standard) + assert output.count('(') == output.count(')') + + +def test_ccode_constants_mathh(): + assert ccode(exp(1)) == "M_E" + assert ccode(pi) == "M_PI" + assert ccode(oo, standard='c89') == "HUGE_VAL" + assert ccode(-oo, standard='c89') == "-HUGE_VAL" + assert ccode(oo) == "INFINITY" + assert ccode(-oo, standard='c99') == "-INFINITY" + assert ccode(pi, type_aliases={real: float80}) == "M_PIl" + + +def test_ccode_constants_other(): + assert ccode(2*GoldenRatio) == "const double GoldenRatio = %s;\n2*GoldenRatio" % GoldenRatio.evalf(17) + assert ccode( + 2*Catalan) == "const double Catalan = %s;\n2*Catalan" % Catalan.evalf(17) + assert ccode(2*EulerGamma) == "const double EulerGamma = %s;\n2*EulerGamma" % EulerGamma.evalf(17) + + +def test_ccode_Rational(): + assert ccode(Rational(3, 7)) == "3.0/7.0" + assert ccode(Rational(3, 7), type_aliases={real: float80}) == "3.0L/7.0L" + assert ccode(Rational(18, 9)) == "2" + assert ccode(Rational(3, -7)) == "-3.0/7.0" + assert ccode(Rational(3, -7), type_aliases={real: float80}) == "-3.0L/7.0L" + assert ccode(Rational(-3, -7)) == "3.0/7.0" + assert ccode(Rational(-3, -7), type_aliases={real: float80}) == "3.0L/7.0L" + assert ccode(x + Rational(3, 7)) == "x + 3.0/7.0" + assert ccode(x + Rational(3, 7), type_aliases={real: float80}) == "x + 3.0L/7.0L" + assert ccode(Rational(3, 7)*x) == "(3.0/7.0)*x" + assert ccode(Rational(3, 7)*x, type_aliases={real: float80}) == "(3.0L/7.0L)*x" + + +def test_ccode_Integer(): + assert ccode(Integer(67)) == "67" + assert ccode(Integer(-1)) == "-1" + + +def test_ccode_functions(): + assert ccode(sin(x) ** cos(x)) == "pow(sin(x), cos(x))" + + +def test_ccode_inline_function(): + x = symbols('x') + g = implemented_function('g', Lambda(x, 2*x)) + assert ccode(g(x)) == "2*x" + g = implemented_function('g', Lambda(x, 2*x/Catalan)) + assert ccode( + g(x)) == "const double Catalan = %s;\n2*x/Catalan" % Catalan.evalf(17) + A = IndexedBase('A') + i = Idx('i', symbols('n', integer=True)) + g = implemented_function('g', Lambda(x, x*(1 + x)*(2 + x))) + assert ccode(g(A[i]), assign_to=A[i]) == ( + "for (int i=0; i y" + assert ccode(Ge(x, y)) == "x >= y" + + +def test_ccode_Piecewise(): + expr = Piecewise((x, x < 1), (x**2, True)) + assert ccode(expr) == ( + "((x < 1) ? (\n" + " x\n" + ")\n" + ": (\n" + " pow(x, 2)\n" + "))") + assert ccode(expr, assign_to="c") == ( + "if (x < 1) {\n" + " c = x;\n" + "}\n" + "else {\n" + " c = pow(x, 2);\n" + "}") + expr = Piecewise((x, x < 1), (x + 1, x < 2), (x**2, True)) + assert ccode(expr) == ( + "((x < 1) ? (\n" + " x\n" + ")\n" + ": ((x < 2) ? (\n" + " x + 1\n" + ")\n" + ": (\n" + " pow(x, 2)\n" + ")))") + assert ccode(expr, assign_to='c') == ( + "if (x < 1) {\n" + " c = x;\n" + "}\n" + "else if (x < 2) {\n" + " c = x + 1;\n" + "}\n" + "else {\n" + " c = pow(x, 2);\n" + "}") + # Check that Piecewise without a True (default) condition error + expr = Piecewise((x, x < 1), (x**2, x > 1), (sin(x), x > 0)) + raises(ValueError, lambda: ccode(expr)) + + +def test_ccode_sinc(): + from sympy.functions.elementary.trigonometric import sinc + expr = sinc(x) + assert ccode(expr) == ( + "(((x != 0) ? (\n" + " sin(x)/x\n" + ")\n" + ": (\n" + " 1\n" + ")))") + + +def test_ccode_Piecewise_deep(): + p = ccode(2*Piecewise((x, x < 1), (x + 1, x < 2), (x**2, True))) + assert p == ( + "2*((x < 1) ? (\n" + " x\n" + ")\n" + ": ((x < 2) ? (\n" + " x + 1\n" + ")\n" + ": (\n" + " pow(x, 2)\n" + ")))") + expr = x*y*z + x**2 + y**2 + Piecewise((0, x < 0.5), (1, True)) + cos(z) - 1 + assert ccode(expr) == ( + "pow(x, 2) + x*y*z + pow(y, 2) + ((x < 0.5) ? (\n" + " 0\n" + ")\n" + ": (\n" + " 1\n" + ")) + cos(z) - 1") + assert ccode(expr, assign_to='c') == ( + "c = pow(x, 2) + x*y*z + pow(y, 2) + ((x < 0.5) ? (\n" + " 0\n" + ")\n" + ": (\n" + " 1\n" + ")) + cos(z) - 1;") + + +def test_ccode_ITE(): + expr = ITE(x < 1, y, z) + assert ccode(expr) == ( + "((x < 1) ? (\n" + " y\n" + ")\n" + ": (\n" + " z\n" + "))") + + +def test_ccode_settings(): + raises(TypeError, lambda: ccode(sin(x), method="garbage")) + + +def test_ccode_Indexed(): + s, n, m, o = symbols('s n m o', integer=True) + i, j, k = Idx('i', n), Idx('j', m), Idx('k', o) + + x = IndexedBase('x')[j] + A = IndexedBase('A')[i, j] + B = IndexedBase('B')[i, j, k] + + p = C99CodePrinter() + + assert p._print_Indexed(x) == 'x[j]' + assert p._print_Indexed(A) == 'A[%s]' % (m*i+j) + assert p._print_Indexed(B) == 'B[%s]' % (i*o*m+j*o+k) + + A = IndexedBase('A', shape=(5,3))[i, j] + assert p._print_Indexed(A) == 'A[%s]' % (3*i + j) + + A = IndexedBase('A', shape=(5,3), strides='F')[i, j] + assert ccode(A) == 'A[%s]' % (i + 5*j) + + A = IndexedBase('A', shape=(29,29), strides=(1, s), offset=o)[i, j] + assert ccode(A) == 'A[o + s*j + i]' + + Abase = IndexedBase('A', strides=(s, m, n), offset=o) + assert ccode(Abase[i, j, k]) == 'A[m*j + n*k + o + s*i]' + assert ccode(Abase[2, 3, k]) == 'A[3*m + n*k + o + 2*s]' + + +def test_Element(): + assert ccode(Element('x', 'ij')) == 'x[i][j]' + assert ccode(Element('x', 'ij', strides='kl', offset='o')) == 'x[i*k + j*l + o]' + assert ccode(Element('x', (3,))) == 'x[3]' + assert ccode(Element('x', (3,4,5))) == 'x[3][4][5]' + + +def test_ccode_Indexed_without_looking_for_contraction(): + len_y = 5 + y = IndexedBase('y', shape=(len_y,)) + x = IndexedBase('x', shape=(len_y,)) + Dy = IndexedBase('Dy', shape=(len_y-1,)) + i = Idx('i', len_y-1) + e = Eq(Dy[i], (y[i+1]-y[i])/(x[i+1]-x[i])) + code0 = ccode(e.rhs, assign_to=e.lhs, contract=False) + assert code0 == 'Dy[i] = (y[%s] - y[i])/(x[%s] - x[i]);' % (i + 1, i + 1) + + +def test_ccode_loops_matrix_vector(): + n, m = symbols('n m', integer=True) + A = IndexedBase('A') + x = IndexedBase('x') + y = IndexedBase('y') + i = Idx('i', m) + j = Idx('j', n) + + s = ( + 'for (int i=0; i0), (y, True)), sin(z)]) + A = MatrixSymbol('A', 3, 1) + assert ccode(mat, A) == ( + "A[0] = x*y;\n" + "if (y > 0) {\n" + " A[1] = x + 2;\n" + "}\n" + "else {\n" + " A[1] = y;\n" + "}\n" + "A[2] = sin(z);") + # Test using MatrixElements in expressions + expr = Piecewise((2*A[2, 0], x > 0), (A[2, 0], True)) + sin(A[1, 0]) + A[0, 0] + assert ccode(expr) == ( + "((x > 0) ? (\n" + " 2*A[2]\n" + ")\n" + ": (\n" + " A[2]\n" + ")) + sin(A[1]) + A[0]") + # Test using MatrixElements in a Matrix + q = MatrixSymbol('q', 5, 1) + M = MatrixSymbol('M', 3, 3) + m = Matrix([[sin(q[1,0]), 0, cos(q[2,0])], + [q[1,0] + q[2,0], q[3, 0], 5], + [2*q[4, 0]/q[1,0], sqrt(q[0,0]) + 4, 0]]) + assert ccode(m, M) == ( + "M[0] = sin(q[1]);\n" + "M[1] = 0;\n" + "M[2] = cos(q[2]);\n" + "M[3] = q[1] + q[2];\n" + "M[4] = q[3];\n" + "M[5] = 5;\n" + "M[6] = 2*q[4]/q[1];\n" + "M[7] = sqrt(q[0]) + 4;\n" + "M[8] = 0;") + + +def test_sparse_matrix(): + # gh-15791 + with raises(PrintMethodNotImplementedError): + ccode(SparseMatrix([[1, 2, 3]])) + + assert 'Not supported in C' in C89CodePrinter({'strict': False}).doprint(SparseMatrix([[1, 2, 3]])) + + + +def test_ccode_reserved_words(): + x, y = symbols('x, if') + with raises(ValueError): + ccode(y**2, error_on_reserved=True, standard='C99') + assert ccode(y**2) == 'pow(if_, 2)' + assert ccode(x * y**2, dereference=[y]) == 'pow((*if_), 2)*x' + assert ccode(y**2, reserved_word_suffix='_unreserved') == 'pow(if_unreserved, 2)' + + +def test_ccode_sign(): + expr1, ref1 = sign(x) * y, 'y*(((x) > 0) - ((x) < 0))' + expr2, ref2 = sign(cos(x)), '(((cos(x)) > 0) - ((cos(x)) < 0))' + expr3, ref3 = sign(2 * x + x**2) * x + x**2, 'pow(x, 2) + x*(((pow(x, 2) + 2*x) > 0) - ((pow(x, 2) + 2*x) < 0))' + assert ccode(expr1) == ref1 + assert ccode(expr1, 'z') == 'z = %s;' % ref1 + assert ccode(expr2) == ref2 + assert ccode(expr3) == ref3 + +def test_ccode_Assignment(): + assert ccode(Assignment(x, y + z)) == 'x = y + z;' + assert ccode(aug_assign(x, '+', y + z)) == 'x += y + z;' + + +def test_ccode_For(): + f = For(x, Range(0, 10, 2), [aug_assign(y, '*', x)]) + assert ccode(f) == ("for (x = 0; x < 10; x += 2) {\n" + " y *= x;\n" + "}") + +def test_ccode_Max_Min(): + assert ccode(Max(x, 0), standard='C89') == '((0 > x) ? 0 : x)' + assert ccode(Max(x, 0), standard='C99') == 'fmax(0, x)' + assert ccode(Min(x, 0, sqrt(x)), standard='c89') == ( + '((0 < ((x < sqrt(x)) ? x : sqrt(x))) ? 0 : ((x < sqrt(x)) ? x : sqrt(x)))' + ) + +def test_ccode_standard(): + assert ccode(expm1(x), standard='c99') == 'expm1(x)' + assert ccode(nan, standard='c99') == 'NAN' + assert ccode(float('nan'), standard='c99') == 'NAN' + + +def test_C89CodePrinter(): + c89printer = C89CodePrinter() + assert c89printer.language == 'C' + assert c89printer.standard == 'C89' + assert 'void' in c89printer.reserved_words + assert 'template' not in c89printer.reserved_words + + +def test_C99CodePrinter(): + assert C99CodePrinter().doprint(expm1(x)) == 'expm1(x)' + assert C99CodePrinter().doprint(log1p(x)) == 'log1p(x)' + assert C99CodePrinter().doprint(exp2(x)) == 'exp2(x)' + assert C99CodePrinter().doprint(log2(x)) == 'log2(x)' + assert C99CodePrinter().doprint(fma(x, y, -z)) == 'fma(x, y, -z)' + assert C99CodePrinter().doprint(log10(x)) == 'log10(x)' + assert C99CodePrinter().doprint(Cbrt(x)) == 'cbrt(x)' # note Cbrt due to cbrt already taken. + assert C99CodePrinter().doprint(hypot(x, y)) == 'hypot(x, y)' + assert C99CodePrinter().doprint(loggamma(x)) == 'lgamma(x)' + assert C99CodePrinter().doprint(Max(x, 3, x**2)) == 'fmax(3, fmax(x, pow(x, 2)))' + assert C99CodePrinter().doprint(Min(x, 3)) == 'fmin(3, x)' + c99printer = C99CodePrinter() + assert c99printer.language == 'C' + assert c99printer.standard == 'C99' + assert 'restrict' in c99printer.reserved_words + assert 'using' not in c99printer.reserved_words + + +@XFAIL +def test_C99CodePrinter__precision_f80(): + f80_printer = C99CodePrinter({"type_aliases": {real: float80}}) + assert f80_printer.doprint(sin(x + Float('2.1'))) == 'sinl(x + 2.1L)' + + +def test_C99CodePrinter__precision(): + n = symbols('n', integer=True) + p = symbols('p', integer=True, positive=True) + f32_printer = C99CodePrinter({"type_aliases": {real: float32}}) + f64_printer = C99CodePrinter({"type_aliases": {real: float64}}) + f80_printer = C99CodePrinter({"type_aliases": {real: float80}}) + assert f32_printer.doprint(sin(x+2.1)) == 'sinf(x + 2.1F)' + assert f64_printer.doprint(sin(x+2.1)) == 'sin(x + 2.1000000000000001)' + assert f80_printer.doprint(sin(x+Float('2.0'))) == 'sinl(x + 2.0L)' + + for printer, suffix in zip([f32_printer, f64_printer, f80_printer], ['f', '', 'l']): + def check(expr, ref): + assert printer.doprint(expr) == ref.format(s=suffix, S=suffix.upper()) + check(Abs(n), 'abs(n)') + check(Abs(x + 2.0), 'fabs{s}(x + 2.0{S})') + check(sin(x + 4.0)**cos(x - 2.0), 'pow{s}(sin{s}(x + 4.0{S}), cos{s}(x - 2.0{S}))') + check(exp(x*8.0), 'exp{s}(8.0{S}*x)') + check(exp2(x), 'exp2{s}(x)') + check(expm1(x*4.0), 'expm1{s}(4.0{S}*x)') + check(Mod(p, 2), 'p % 2') + check(Mod(2*p + 3, 3*p + 5, evaluate=False), '(2*p + 3) % (3*p + 5)') + check(Mod(x + 2.0, 3.0), 'fmod{s}(1.0{S}*x + 2.0{S}, 3.0{S})') + check(Mod(x, 2.0*x + 3.0), 'fmod{s}(1.0{S}*x, 2.0{S}*x + 3.0{S})') + check(log(x/2), 'log{s}((1.0{S}/2.0{S})*x)') + check(log10(3*x/2), 'log10{s}((3.0{S}/2.0{S})*x)') + check(log2(x*8.0), 'log2{s}(8.0{S}*x)') + check(log1p(x), 'log1p{s}(x)') + check(2**x, 'pow{s}(2, x)') + check(2.0**x, 'pow{s}(2.0{S}, x)') + check(x**3, 'pow{s}(x, 3)') + check(x**4.0, 'pow{s}(x, 4.0{S})') + check(sqrt(3+x), 'sqrt{s}(x + 3)') + check(Cbrt(x-2.0), 'cbrt{s}(x - 2.0{S})') + check(hypot(x, y), 'hypot{s}(x, y)') + check(sin(3.*x + 2.), 'sin{s}(3.0{S}*x + 2.0{S})') + check(cos(3.*x - 1.), 'cos{s}(3.0{S}*x - 1.0{S})') + check(tan(4.*y + 2.), 'tan{s}(4.0{S}*y + 2.0{S})') + check(asin(3.*x + 2.), 'asin{s}(3.0{S}*x + 2.0{S})') + check(acos(3.*x + 2.), 'acos{s}(3.0{S}*x + 2.0{S})') + check(atan(3.*x + 2.), 'atan{s}(3.0{S}*x + 2.0{S})') + check(atan2(3.*x, 2.*y), 'atan2{s}(3.0{S}*x, 2.0{S}*y)') + + check(sinh(3.*x + 2.), 'sinh{s}(3.0{S}*x + 2.0{S})') + check(cosh(3.*x - 1.), 'cosh{s}(3.0{S}*x - 1.0{S})') + check(tanh(4.0*y + 2.), 'tanh{s}(4.0{S}*y + 2.0{S})') + check(asinh(3.*x + 2.), 'asinh{s}(3.0{S}*x + 2.0{S})') + check(acosh(3.*x + 2.), 'acosh{s}(3.0{S}*x + 2.0{S})') + check(atanh(3.*x + 2.), 'atanh{s}(3.0{S}*x + 2.0{S})') + check(erf(42.*x), 'erf{s}(42.0{S}*x)') + check(erfc(42.*x), 'erfc{s}(42.0{S}*x)') + check(gamma(x), 'tgamma{s}(x)') + check(loggamma(x), 'lgamma{s}(x)') + + check(ceiling(x + 2.), "ceil{s}(x) + 2") + check(floor(x + 2.), "floor{s}(x) + 2") + check(fma(x, y, -z), 'fma{s}(x, y, -z)') + check(Max(x, 8.0, x**4.0), 'fmax{s}(8.0{S}, fmax{s}(x, pow{s}(x, 4.0{S})))') + check(Min(x, 2.0), 'fmin{s}(2.0{S}, x)') + + +def test_get_math_macros(): + macros = get_math_macros() + assert macros[exp(1)] == 'M_E' + assert macros[1/Sqrt(2)] == 'M_SQRT1_2' + + +def test_ccode_Declaration(): + i = symbols('i', integer=True) + var1 = Variable(i, type=Type.from_expr(i)) + dcl1 = Declaration(var1) + assert ccode(dcl1) == 'int i' + + var2 = Variable(x, type=float32, attrs={value_const}) + dcl2a = Declaration(var2) + assert ccode(dcl2a) == 'const float x' + dcl2b = var2.as_Declaration(value=pi) + assert ccode(dcl2b) == 'const float x = M_PI' + + var3 = Variable(y, type=Type('bool')) + dcl3 = Declaration(var3) + printer = C89CodePrinter() + assert 'stdbool.h' not in printer.headers + assert printer.doprint(dcl3) == 'bool y' + assert 'stdbool.h' in printer.headers + + u = symbols('u', real=True) + ptr4 = Pointer.deduced(u, attrs={pointer_const, restrict}) + dcl4 = Declaration(ptr4) + assert ccode(dcl4) == 'double * const restrict u' + + var5 = Variable(x, Type('__float128'), attrs={value_const}) + dcl5a = Declaration(var5) + assert ccode(dcl5a) == 'const __float128 x' + var5b = Variable(var5.symbol, var5.type, pi, attrs=var5.attrs) + dcl5b = Declaration(var5b) + assert ccode(dcl5b) == 'const __float128 x = M_PI' + + +def test_C99CodePrinter_custom_type(): + # We will look at __float128 (new in glibc 2.26) + f128 = FloatType('_Float128', float128.nbits, float128.nmant, float128.nexp) + p128 = C99CodePrinter({ + "type_aliases": {real: f128}, + "type_literal_suffixes": {f128: 'Q'}, + "type_func_suffixes": {f128: 'f128'}, + "type_math_macro_suffixes": { + real: 'f128', + f128: 'f128' + }, + "type_macros": { + f128: ('__STDC_WANT_IEC_60559_TYPES_EXT__',) + } + }) + assert p128.doprint(x) == 'x' + assert not p128.headers + assert not p128.libraries + assert not p128.macros + assert p128.doprint(2.0) == '2.0Q' + assert not p128.headers + assert not p128.libraries + assert p128.macros == {'__STDC_WANT_IEC_60559_TYPES_EXT__'} + + assert p128.doprint(Rational(1, 2)) == '1.0Q/2.0Q' + assert p128.doprint(sin(x)) == 'sinf128(x)' + assert p128.doprint(cos(2., evaluate=False)) == 'cosf128(2.0Q)' + assert p128.doprint(x**-1.0) == '1.0Q/x' + + var5 = Variable(x, f128, attrs={value_const}) + + dcl5a = Declaration(var5) + assert ccode(dcl5a) == 'const _Float128 x' + var5b = Variable(x, f128, pi, attrs={value_const}) + dcl5b = Declaration(var5b) + assert p128.doprint(dcl5b) == 'const _Float128 x = M_PIf128' + var5b = Variable(x, f128, value=Catalan.evalf(38), attrs={value_const}) + dcl5c = Declaration(var5b) + assert p128.doprint(dcl5c) == 'const _Float128 x = %sQ' % Catalan.evalf(f128.decimal_dig) + + +def test_MatrixElement_printing(): + # test cases for issue #11821 + A = MatrixSymbol("A", 1, 3) + B = MatrixSymbol("B", 1, 3) + C = MatrixSymbol("C", 1, 3) + + assert(ccode(A[0, 0]) == "A[0]") + assert(ccode(3 * A[0, 0]) == "3*A[0]") + + F = C[0, 0].subs(C, A - B) + assert(ccode(F) == "(A - B)[0]") + +def test_ccode_math_macros(): + assert ccode(z + exp(1)) == 'z + M_E' + assert ccode(z + log2(exp(1))) == 'z + M_LOG2E' + assert ccode(z + 1/log(2)) == 'z + M_LOG2E' + assert ccode(z + log(2)) == 'z + M_LN2' + assert ccode(z + log(10)) == 'z + M_LN10' + assert ccode(z + pi) == 'z + M_PI' + assert ccode(z + pi/2) == 'z + M_PI_2' + assert ccode(z + pi/4) == 'z + M_PI_4' + assert ccode(z + 1/pi) == 'z + M_1_PI' + assert ccode(z + 2/pi) == 'z + M_2_PI' + assert ccode(z + 2/sqrt(pi)) == 'z + M_2_SQRTPI' + assert ccode(z + 2/Sqrt(pi)) == 'z + M_2_SQRTPI' + assert ccode(z + sqrt(2)) == 'z + M_SQRT2' + assert ccode(z + Sqrt(2)) == 'z + M_SQRT2' + assert ccode(z + 1/sqrt(2)) == 'z + M_SQRT1_2' + assert ccode(z + 1/Sqrt(2)) == 'z + M_SQRT1_2' + + +def test_ccode_Type(): + assert ccode(Type('float')) == 'float' + assert ccode(intc) == 'int' + + +def test_ccode_codegen_ast(): + # Note that C only allows comments of the form /* ... */, double forward + # slash is not standard C, and some C compilers will grind to a halt upon + # encountering them. + assert ccode(Comment("this is a comment")) == "/* this is a comment */" # not // + assert ccode(While(abs(x) > 1, [aug_assign(x, '-', 1)])) == ( + 'while (fabs(x) > 1) {\n' + ' x -= 1;\n' + '}' + ) + assert ccode(Scope([AddAugmentedAssignment(x, 1)])) == ( + '{\n' + ' x += 1;\n' + '}' + ) + inp_x = Declaration(Variable(x, type=real)) + assert ccode(FunctionPrototype(real, 'pwer', [inp_x])) == 'double pwer(double x)' + assert ccode(FunctionDefinition(real, 'pwer', [inp_x], [Assignment(x, x**2)])) == ( + 'double pwer(double x){\n' + ' x = pow(x, 2);\n' + '}' + ) + + # Elements of CodeBlock are formatted as statements: + block = CodeBlock( + x, + Print([x, y], "%d %d"), + Print([QuotedString('hello'), y], "%s %d", file=stderr), + FunctionCall('pwer', [x]), + Return(x), + ) + assert ccode(block) == '\n'.join([ + 'x;', + 'printf("%d %d", x, y);', + 'fprintf(stderr, "%s %d", "hello", y);', + 'pwer(x);', + 'return x;', + ]) + +def test_ccode_UnevaluatedExpr(): + assert ccode(UnevaluatedExpr(y * x) + z) == "z + x*y" + assert ccode(UnevaluatedExpr(y + x) + z) == "z + (x + y)" # gh-21955 + w = symbols('w') + assert ccode(UnevaluatedExpr(y + x) + UnevaluatedExpr(z + w)) == "(w + z) + (x + y)" + + p, q, r = symbols("p q r", real=True) + q_r = UnevaluatedExpr(q + r) + expr = abs(exp(p+q_r)) + assert ccode(expr) == "exp(p + (q + r))" + + +def test_ccode_array_like_containers(): + assert ccode([2,3,4]) == "{2, 3, 4}" + assert ccode((2,3,4)) == "{2, 3, 4}" diff --git a/lib/python3.10/site-packages/sympy/printing/tests/test_codeprinter.py b/lib/python3.10/site-packages/sympy/printing/tests/test_codeprinter.py new file mode 100644 index 0000000000000000000000000000000000000000..2d89a27dab37f352aa6aa0a41e8ffa6155b65518 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/printing/tests/test_codeprinter.py @@ -0,0 +1,55 @@ +from sympy.printing.codeprinter import CodePrinter, PrintMethodNotImplementedError +from sympy.core import symbols +from sympy.core.symbol import Dummy +from sympy.testing.pytest import raises + + +def setup_test_printer(**kwargs): + p = CodePrinter(settings=kwargs) + p._not_supported = set() + p._number_symbols = set() + return p + + +def test_print_Dummy(): + d = Dummy('d') + p = setup_test_printer() + assert p._print_Dummy(d) == "d_%i" % d.dummy_index + +def test_print_Symbol(): + + x, y = symbols('x, if') + + p = setup_test_printer() + assert p._print(x) == 'x' + assert p._print(y) == 'if' + + p.reserved_words.update(['if']) + assert p._print(y) == 'if_' + + p = setup_test_printer(error_on_reserved=True) + p.reserved_words.update(['if']) + with raises(ValueError): + p._print(y) + + p = setup_test_printer(reserved_word_suffix='_He_Man') + p.reserved_words.update(['if']) + assert p._print(y) == 'if_He_Man' + +def test_issue_15791(): + class CrashingCodePrinter(CodePrinter): + def emptyPrinter(self, obj): + raise NotImplementedError + + from sympy.matrices import ( + MutableSparseMatrix, + ImmutableSparseMatrix, + ) + + c = CrashingCodePrinter() + + # these should not silently succeed + with raises(PrintMethodNotImplementedError): + c.doprint(ImmutableSparseMatrix(2, 2, {})) + with raises(PrintMethodNotImplementedError): + c.doprint(MutableSparseMatrix(2, 2, {})) diff --git a/lib/python3.10/site-packages/sympy/printing/tests/test_conventions.py b/lib/python3.10/site-packages/sympy/printing/tests/test_conventions.py new file mode 100644 index 0000000000000000000000000000000000000000..e8f1fa8532f96130828b89d1ba5ba11fd5bed7a4 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/printing/tests/test_conventions.py @@ -0,0 +1,116 @@ +# -*- coding: utf-8 -*- + +from sympy.core.function import (Derivative, Function) +from sympy.core.numbers import oo +from sympy.core.symbol import symbols +from sympy.functions.elementary.exponential import exp +from sympy.functions.elementary.trigonometric import cos +from sympy.integrals.integrals import Integral +from sympy.functions.special.bessel import besselj +from sympy.functions.special.polynomials import legendre +from sympy.functions.combinatorial.numbers import bell +from sympy.printing.conventions import split_super_sub, requires_partial +from sympy.testing.pytest import XFAIL + +def test_super_sub(): + assert split_super_sub("beta_13_2") == ("beta", [], ["13", "2"]) + assert split_super_sub("beta_132_20") == ("beta", [], ["132", "20"]) + assert split_super_sub("beta_13") == ("beta", [], ["13"]) + assert split_super_sub("x_a_b") == ("x", [], ["a", "b"]) + assert split_super_sub("x_1_2_3") == ("x", [], ["1", "2", "3"]) + assert split_super_sub("x_a_b1") == ("x", [], ["a", "b1"]) + assert split_super_sub("x_a_1") == ("x", [], ["a", "1"]) + assert split_super_sub("x_1_a") == ("x", [], ["1", "a"]) + assert split_super_sub("x_1^aa") == ("x", ["aa"], ["1"]) + assert split_super_sub("x_1__aa") == ("x", ["aa"], ["1"]) + assert split_super_sub("x_11^a") == ("x", ["a"], ["11"]) + assert split_super_sub("x_11__a") == ("x", ["a"], ["11"]) + assert split_super_sub("x_a_b_c_d") == ("x", [], ["a", "b", "c", "d"]) + assert split_super_sub("x_a_b^c^d") == ("x", ["c", "d"], ["a", "b"]) + assert split_super_sub("x_a_b__c__d") == ("x", ["c", "d"], ["a", "b"]) + assert split_super_sub("x_a^b_c^d") == ("x", ["b", "d"], ["a", "c"]) + assert split_super_sub("x_a__b_c__d") == ("x", ["b", "d"], ["a", "c"]) + assert split_super_sub("x^a^b_c_d") == ("x", ["a", "b"], ["c", "d"]) + assert split_super_sub("x__a__b_c_d") == ("x", ["a", "b"], ["c", "d"]) + assert split_super_sub("x^a^b^c^d") == ("x", ["a", "b", "c", "d"], []) + assert split_super_sub("x__a__b__c__d") == ("x", ["a", "b", "c", "d"], []) + assert split_super_sub("alpha_11") == ("alpha", [], ["11"]) + assert split_super_sub("alpha_11_11") == ("alpha", [], ["11", "11"]) + assert split_super_sub("w1") == ("w", [], ["1"]) + assert split_super_sub("w𝟙") == ("w", [], ["𝟙"]) + assert split_super_sub("w11") == ("w", [], ["11"]) + assert split_super_sub("w𝟙𝟙") == ("w", [], ["𝟙𝟙"]) + assert split_super_sub("w𝟙2𝟙") == ("w", [], ["𝟙2𝟙"]) + assert split_super_sub("w1^a") == ("w", ["a"], ["1"]) + assert split_super_sub("ω1") == ("ω", [], ["1"]) + assert split_super_sub("ω11") == ("ω", [], ["11"]) + assert split_super_sub("ω1^a") == ("ω", ["a"], ["1"]) + assert split_super_sub("ω𝟙^α") == ("ω", ["α"], ["𝟙"]) + assert split_super_sub("ω𝟙2^3α") == ("ω", ["3α"], ["𝟙2"]) + assert split_super_sub("") == ("", [], []) + + +def test_requires_partial(): + x, y, z, t, nu = symbols('x y z t nu') + n = symbols('n', integer=True) + + f = x * y + assert requires_partial(Derivative(f, x)) is True + assert requires_partial(Derivative(f, y)) is True + + ## integrating out one of the variables + assert requires_partial(Derivative(Integral(exp(-x * y), (x, 0, oo)), y, evaluate=False)) is False + + ## bessel function with smooth parameter + f = besselj(nu, x) + assert requires_partial(Derivative(f, x)) is True + assert requires_partial(Derivative(f, nu)) is True + + ## bessel function with integer parameter + f = besselj(n, x) + assert requires_partial(Derivative(f, x)) is False + # this is not really valid (differentiating with respect to an integer) + # but there's no reason to use the partial derivative symbol there. make + # sure we don't throw an exception here, though + assert requires_partial(Derivative(f, n)) is False + + ## bell polynomial + f = bell(n, x) + assert requires_partial(Derivative(f, x)) is False + # again, invalid + assert requires_partial(Derivative(f, n)) is False + + ## legendre polynomial + f = legendre(0, x) + assert requires_partial(Derivative(f, x)) is False + + f = legendre(n, x) + assert requires_partial(Derivative(f, x)) is False + # again, invalid + assert requires_partial(Derivative(f, n)) is False + + f = x ** n + assert requires_partial(Derivative(f, x)) is False + + assert requires_partial(Derivative(Integral((x*y) ** n * exp(-x * y), (x, 0, oo)), y, evaluate=False)) is False + + # parametric equation + f = (exp(t), cos(t)) + g = sum(f) + assert requires_partial(Derivative(g, t)) is False + + f = symbols('f', cls=Function) + assert requires_partial(Derivative(f(x), x)) is False + assert requires_partial(Derivative(f(x), y)) is False + assert requires_partial(Derivative(f(x, y), x)) is True + assert requires_partial(Derivative(f(x, y), y)) is True + assert requires_partial(Derivative(f(x, y), z)) is True + assert requires_partial(Derivative(f(x, y), x, y)) is True + +@XFAIL +def test_requires_partial_unspecified_variables(): + x, y = symbols('x y') + # function of unspecified variables + f = symbols('f', cls=Function) + assert requires_partial(Derivative(f, x)) is False + assert requires_partial(Derivative(f, x, y)) is True diff --git a/lib/python3.10/site-packages/sympy/printing/tests/test_cupy.py b/lib/python3.10/site-packages/sympy/printing/tests/test_cupy.py new file mode 100644 index 0000000000000000000000000000000000000000..cf111ec1623390a3dbbf489235d2ed387624a36c --- /dev/null +++ b/lib/python3.10/site-packages/sympy/printing/tests/test_cupy.py @@ -0,0 +1,56 @@ +from sympy.concrete.summations import Sum +from sympy.functions.elementary.exponential import log +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.utilities.lambdify import lambdify +from sympy.abc import x, i, a, b +from sympy.codegen.numpy_nodes import logaddexp +from sympy.printing.numpy import CuPyPrinter, _cupy_known_constants, _cupy_known_functions + +from sympy.testing.pytest import skip, raises +from sympy.external import import_module + +cp = import_module('cupy') + +def test_cupy_print(): + prntr = CuPyPrinter() + assert prntr.doprint(logaddexp(a, b)) == 'cupy.logaddexp(a, b)' + assert prntr.doprint(sqrt(x)) == 'cupy.sqrt(x)' + assert prntr.doprint(log(x)) == 'cupy.log(x)' + assert prntr.doprint("acos(x)") == 'cupy.arccos(x)' + assert prntr.doprint("exp(x)") == 'cupy.exp(x)' + assert prntr.doprint("Abs(x)") == 'abs(x)' + +def test_not_cupy_print(): + prntr = CuPyPrinter() + with raises(NotImplementedError): + prntr.doprint("abcd(x)") + +def test_cupy_sum(): + if not cp: + skip("CuPy not installed") + + s = Sum(x ** i, (i, a, b)) + f = lambdify((a, b, x), s, 'cupy') + + a_, b_ = 0, 10 + x_ = cp.linspace(-1, +1, 10) + assert cp.allclose(f(a_, b_, x_), sum(x_ ** i_ for i_ in range(a_, b_ + 1))) + + s = Sum(i * x, (i, a, b)) + f = lambdify((a, b, x), s, 'numpy') + + a_, b_ = 0, 10 + x_ = cp.linspace(-1, +1, 10) + assert cp.allclose(f(a_, b_, x_), sum(i_ * x_ for i_ in range(a_, b_ + 1))) + +def test_cupy_known_funcs_consts(): + assert _cupy_known_constants['NaN'] == 'cupy.nan' + assert _cupy_known_constants['EulerGamma'] == 'cupy.euler_gamma' + + assert _cupy_known_functions['acos'] == 'cupy.arccos' + assert _cupy_known_functions['log'] == 'cupy.log' + +def test_cupy_print_methods(): + prntr = CuPyPrinter() + assert hasattr(prntr, '_print_acos') + assert hasattr(prntr, '_print_log') diff --git a/lib/python3.10/site-packages/sympy/printing/tests/test_cxx.py b/lib/python3.10/site-packages/sympy/printing/tests/test_cxx.py new file mode 100644 index 0000000000000000000000000000000000000000..d84ec75cbf0eeb60a1176b9cb3b401a3384454e7 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/printing/tests/test_cxx.py @@ -0,0 +1,86 @@ +from sympy.core.numbers import Float, Integer, Rational +from sympy.core.symbol import symbols +from sympy.functions import beta, Ei, zeta, Max, Min, sqrt, riemann_xi, frac +from sympy.printing.cxx import CXX98CodePrinter, CXX11CodePrinter, CXX17CodePrinter, cxxcode +from sympy.codegen.cfunctions import log1p + + +x, y, u, v = symbols('x y u v') + + +def test_CXX98CodePrinter(): + assert CXX98CodePrinter().doprint(Max(x, 3)) in ('std::max(x, 3)', 'std::max(3, x)') + assert CXX98CodePrinter().doprint(Min(x, 3, sqrt(x))) == 'std::min(3, std::min(x, std::sqrt(x)))' + cxx98printer = CXX98CodePrinter() + assert cxx98printer.language == 'C++' + assert cxx98printer.standard == 'C++98' + assert 'template' in cxx98printer.reserved_words + assert 'alignas' not in cxx98printer.reserved_words + + +def test_CXX11CodePrinter(): + assert CXX11CodePrinter().doprint(log1p(x)) == 'std::log1p(x)' + + cxx11printer = CXX11CodePrinter() + assert cxx11printer.language == 'C++' + assert cxx11printer.standard == 'C++11' + assert 'operator' in cxx11printer.reserved_words + assert 'noexcept' in cxx11printer.reserved_words + assert 'concept' not in cxx11printer.reserved_words + + +def test_subclass_print_method(): + class MyPrinter(CXX11CodePrinter): + def _print_log1p(self, expr): + return 'my_library::log1p(%s)' % ', '.join(map(self._print, expr.args)) + + assert MyPrinter().doprint(log1p(x)) == 'my_library::log1p(x)' + + +def test_subclass_print_method__ns(): + class MyPrinter(CXX11CodePrinter): + _ns = 'my_library::' + + p = CXX11CodePrinter() + myp = MyPrinter() + + assert p.doprint(log1p(x)) == 'std::log1p(x)' + assert myp.doprint(log1p(x)) == 'my_library::log1p(x)' + + +def test_CXX17CodePrinter(): + assert CXX17CodePrinter().doprint(beta(x, y)) == 'std::beta(x, y)' + assert CXX17CodePrinter().doprint(Ei(x)) == 'std::expint(x)' + assert CXX17CodePrinter().doprint(zeta(x)) == 'std::riemann_zeta(x)' + + # Automatic rewrite + assert CXX17CodePrinter().doprint(frac(x)) == '(x - std::floor(x))' + assert CXX17CodePrinter().doprint(riemann_xi(x)) == '((1.0/2.0)*std::pow(M_PI, -1.0/2.0*x)*x*(x - 1)*std::tgamma((1.0/2.0)*x)*std::riemann_zeta(x))' + + +def test_cxxcode(): + assert sorted(cxxcode(sqrt(x)*.5).split('*')) == sorted(['0.5', 'std::sqrt(x)']) + +def test_cxxcode_nested_minmax(): + assert cxxcode(Max(Min(x, y), Min(u, v))) \ + == 'std::max(std::min(u, v), std::min(x, y))' + assert cxxcode(Min(Max(x, y), Max(u, v))) \ + == 'std::min(std::max(u, v), std::max(x, y))' + +def test_subclass_Integer_Float(): + class MyPrinter(CXX17CodePrinter): + def _print_Integer(self, arg): + return 'bigInt("%s")' % super()._print_Integer(arg) + + def _print_Float(self, arg): + rat = Rational(arg) + return 'bigFloat(%s, %s)' % ( + self._print(Integer(rat.p)), + self._print(Integer(rat.q)) + ) + + p = MyPrinter() + for i in range(13): + assert p.doprint(i) == 'bigInt("%d")' % i + assert p.doprint(Float(0.5)) == 'bigFloat(bigInt("1"), bigInt("2"))' + assert p.doprint(x**-1.0) == 'bigFloat(bigInt("1"), bigInt("1"))/x' diff --git a/lib/python3.10/site-packages/sympy/printing/tests/test_dot.py b/lib/python3.10/site-packages/sympy/printing/tests/test_dot.py new file mode 100644 index 0000000000000000000000000000000000000000..6213e237fb7aac6460a956b4c9fc1f7c8710fec6 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/printing/tests/test_dot.py @@ -0,0 +1,134 @@ +from sympy.printing.dot import (purestr, styleof, attrprint, dotnode, + dotedges, dotprint) +from sympy.core.basic import Basic +from sympy.core.expr import Expr +from sympy.core.numbers import (Float, Integer) +from sympy.core.singleton import S +from sympy.core.symbol import (Symbol, symbols) +from sympy.printing.repr import srepr +from sympy.abc import x + + +def test_purestr(): + assert purestr(Symbol('x')) == "Symbol('x')" + assert purestr(Basic(S(1), S(2))) == "Basic(Integer(1), Integer(2))" + assert purestr(Float(2)) == "Float('2.0', precision=53)" + + assert purestr(Symbol('x'), with_args=True) == ("Symbol('x')", ()) + assert purestr(Basic(S(1), S(2)), with_args=True) == \ + ('Basic(Integer(1), Integer(2))', ('Integer(1)', 'Integer(2)')) + assert purestr(Float(2), with_args=True) == \ + ("Float('2.0', precision=53)", ()) + + +def test_styleof(): + styles = [(Basic, {'color': 'blue', 'shape': 'ellipse'}), + (Expr, {'color': 'black'})] + assert styleof(Basic(S(1)), styles) == {'color': 'blue', 'shape': 'ellipse'} + + assert styleof(x + 1, styles) == {'color': 'black', 'shape': 'ellipse'} + + +def test_attrprint(): + assert attrprint({'color': 'blue', 'shape': 'ellipse'}) == \ + '"color"="blue", "shape"="ellipse"' + +def test_dotnode(): + + assert dotnode(x, repeat=False) == \ + '"Symbol(\'x\')" ["color"="black", "label"="x", "shape"="ellipse"];' + assert dotnode(x+2, repeat=False) == \ + '"Add(Integer(2), Symbol(\'x\'))" ' \ + '["color"="black", "label"="Add", "shape"="ellipse"];', \ + dotnode(x+2,repeat=0) + + assert dotnode(x + x**2, repeat=False) == \ + '"Add(Symbol(\'x\'), Pow(Symbol(\'x\'), Integer(2)))" ' \ + '["color"="black", "label"="Add", "shape"="ellipse"];' + assert dotnode(x + x**2, repeat=True) == \ + '"Add(Symbol(\'x\'), Pow(Symbol(\'x\'), Integer(2)))_()" ' \ + '["color"="black", "label"="Add", "shape"="ellipse"];' + +def test_dotedges(): + assert sorted(dotedges(x+2, repeat=False)) == [ + '"Add(Integer(2), Symbol(\'x\'))" -> "Integer(2)";', + '"Add(Integer(2), Symbol(\'x\'))" -> "Symbol(\'x\')";' + ] + assert sorted(dotedges(x + 2, repeat=True)) == [ + '"Add(Integer(2), Symbol(\'x\'))_()" -> "Integer(2)_(0,)";', + '"Add(Integer(2), Symbol(\'x\'))_()" -> "Symbol(\'x\')_(1,)";' + ] + +def test_dotprint(): + text = dotprint(x+2, repeat=False) + assert all(e in text for e in dotedges(x+2, repeat=False)) + assert all( + n in text for n in [dotnode(expr, repeat=False) + for expr in (x, Integer(2), x+2)]) + assert 'digraph' in text + + text = dotprint(x+x**2, repeat=False) + assert all(e in text for e in dotedges(x+x**2, repeat=False)) + assert all( + n in text for n in [dotnode(expr, repeat=False) + for expr in (x, Integer(2), x**2)]) + assert 'digraph' in text + + text = dotprint(x+x**2, repeat=True) + assert all(e in text for e in dotedges(x+x**2, repeat=True)) + assert all( + n in text for n in [dotnode(expr, pos=()) + for expr in [x + x**2]]) + + text = dotprint(x**x, repeat=True) + assert all(e in text for e in dotedges(x**x, repeat=True)) + assert all( + n in text for n in [dotnode(x, pos=(0,)), dotnode(x, pos=(1,))]) + assert 'digraph' in text + +def test_dotprint_depth(): + text = dotprint(3*x+2, depth=1) + assert dotnode(3*x+2) in text + assert dotnode(x) not in text + text = dotprint(3*x+2) + assert "depth" not in text + +def test_Matrix_and_non_basics(): + from sympy.matrices.expressions.matexpr import MatrixSymbol + n = Symbol('n') + assert dotprint(MatrixSymbol('X', n, n)) == \ +"""digraph{ + +# Graph style +"ordering"="out" +"rankdir"="TD" + +######### +# Nodes # +######### + +"MatrixSymbol(Str('X'), Symbol('n'), Symbol('n'))_()" ["color"="black", "label"="MatrixSymbol", "shape"="ellipse"]; +"Str('X')_(0,)" ["color"="blue", "label"="X", "shape"="ellipse"]; +"Symbol('n')_(1,)" ["color"="black", "label"="n", "shape"="ellipse"]; +"Symbol('n')_(2,)" ["color"="black", "label"="n", "shape"="ellipse"]; + +######### +# Edges # +######### + +"MatrixSymbol(Str('X'), Symbol('n'), Symbol('n'))_()" -> "Str('X')_(0,)"; +"MatrixSymbol(Str('X'), Symbol('n'), Symbol('n'))_()" -> "Symbol('n')_(1,)"; +"MatrixSymbol(Str('X'), Symbol('n'), Symbol('n'))_()" -> "Symbol('n')_(2,)"; +}""" + + +def test_labelfunc(): + text = dotprint(x + 2, labelfunc=srepr) + assert "Symbol('x')" in text + assert "Integer(2)" in text + + +def test_commutative(): + x, y = symbols('x y', commutative=False) + assert dotprint(x + y) == dotprint(y + x) + assert dotprint(x*y) != dotprint(y*x) diff --git a/lib/python3.10/site-packages/sympy/printing/tests/test_fortran.py b/lib/python3.10/site-packages/sympy/printing/tests/test_fortran.py new file mode 100644 index 0000000000000000000000000000000000000000..c28a1ea16dcf2157b58d763286428dccc1944b71 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/printing/tests/test_fortran.py @@ -0,0 +1,854 @@ +from sympy.core.add import Add +from sympy.core.expr import Expr +from sympy.core.function import (Function, Lambda, diff) +from sympy.core.mod import Mod +from sympy.core import (Catalan, EulerGamma, GoldenRatio) +from sympy.core.numbers import (E, Float, I, Integer, Rational, pi) +from sympy.core.relational import Eq +from sympy.core.singleton import S +from sympy.core.symbol import (Dummy, symbols) +from sympy.functions.combinatorial.factorials import factorial +from sympy.functions.elementary.complexes import (conjugate, sign) +from sympy.functions.elementary.exponential import (exp, log) +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.piecewise import Piecewise +from sympy.functions.elementary.trigonometric import (atan2, cos, sin) +from sympy.functions.special.gamma_functions import gamma +from sympy.integrals.integrals import Integral +from sympy.sets.fancysets import Range + +from sympy.codegen import For, Assignment, aug_assign +from sympy.codegen.ast import Declaration, Variable, float32, float64, \ + value_const, real, bool_, While, FunctionPrototype, FunctionDefinition, \ + integer, Return, Element +from sympy.core.expr import UnevaluatedExpr +from sympy.core.relational import Relational +from sympy.logic.boolalg import And, Or, Not, Equivalent, Xor +from sympy.matrices import Matrix, MatrixSymbol +from sympy.printing.fortran import fcode, FCodePrinter +from sympy.tensor import IndexedBase, Idx +from sympy.tensor.array.expressions import ArraySymbol, ArrayElement +from sympy.utilities.lambdify import implemented_function +from sympy.testing.pytest import raises + + +def test_UnevaluatedExpr(): + p, q, r = symbols("p q r", real=True) + q_r = UnevaluatedExpr(q + r) + expr = abs(exp(p+q_r)) + assert fcode(expr, source_format="free") == "exp(p + (q + r))" + x, y, z = symbols("x y z") + y_z = UnevaluatedExpr(y + z) + expr2 = abs(exp(x+y_z)) + assert fcode(expr2, human=False)[2].lstrip() == "exp(re(x) + re(y + z))" + assert fcode(expr2, user_functions={"re": "realpart"}).lstrip() == "exp(realpart(x) + realpart(y + z))" + + +def test_printmethod(): + x = symbols('x') + + class nint(Function): + def _fcode(self, printer): + return "nint(%s)" % printer._print(self.args[0]) + assert fcode(nint(x)) == " nint(x)" + + +def test_fcode_sign(): #issue 12267 + x=symbols('x') + y=symbols('y', integer=True) + z=symbols('z', complex=True) + assert fcode(sign(x), standard=95, source_format='free') == "merge(0d0, dsign(1d0, x), x == 0d0)" + assert fcode(sign(y), standard=95, source_format='free') == "merge(0, isign(1, y), y == 0)" + assert fcode(sign(z), standard=95, source_format='free') == "merge(cmplx(0d0, 0d0), z/abs(z), abs(z) == 0d0)" + raises(NotImplementedError, lambda: fcode(sign(x))) + + +def test_fcode_Pow(): + x, y = symbols('x,y') + n = symbols('n', integer=True) + + assert fcode(x**3) == " x**3" + assert fcode(x**(y**3)) == " x**(y**3)" + assert fcode(1/(sin(x)*3.5)**(x - y**x)/(x**2 + y)) == \ + " (3.5d0*sin(x))**(-x + y**x)/(x**2 + y)" + assert fcode(sqrt(x)) == ' sqrt(x)' + assert fcode(sqrt(n)) == ' sqrt(dble(n))' + assert fcode(x**0.5) == ' sqrt(x)' + assert fcode(sqrt(x)) == ' sqrt(x)' + assert fcode(sqrt(10)) == ' sqrt(10.0d0)' + assert fcode(x**-1.0) == ' 1d0/x' + assert fcode(x**-2.0, 'y', source_format='free') == 'y = x**(-2.0d0)' # 2823 + assert fcode(x**Rational(3, 7)) == ' x**(3.0d0/7.0d0)' + + +def test_fcode_Rational(): + x = symbols('x') + assert fcode(Rational(3, 7)) == " 3.0d0/7.0d0" + assert fcode(Rational(18, 9)) == " 2" + assert fcode(Rational(3, -7)) == " -3.0d0/7.0d0" + assert fcode(Rational(-3, -7)) == " 3.0d0/7.0d0" + assert fcode(x + Rational(3, 7)) == " x + 3.0d0/7.0d0" + assert fcode(Rational(3, 7)*x) == " (3.0d0/7.0d0)*x" + + +def test_fcode_Integer(): + assert fcode(Integer(67)) == " 67" + assert fcode(Integer(-1)) == " -1" + + +def test_fcode_Float(): + assert fcode(Float(42.0)) == " 42.0000000000000d0" + assert fcode(Float(-1e20)) == " -1.00000000000000d+20" + + +def test_fcode_functions(): + x, y = symbols('x,y') + assert fcode(sin(x) ** cos(y)) == " sin(x)**cos(y)" + raises(NotImplementedError, lambda: fcode(Mod(x, y), standard=66)) + raises(NotImplementedError, lambda: fcode(x % y, standard=66)) + raises(NotImplementedError, lambda: fcode(Mod(x, y), standard=77)) + raises(NotImplementedError, lambda: fcode(x % y, standard=77)) + for standard in [90, 95, 2003, 2008]: + assert fcode(Mod(x, y), standard=standard) == " modulo(x, y)" + assert fcode(x % y, standard=standard) == " modulo(x, y)" + + +def test_case(): + ob = FCodePrinter() + x,x_,x__,y,X,X_,Y = symbols('x,x_,x__,y,X,X_,Y') + assert fcode(exp(x_) + sin(x*y) + cos(X*Y)) == \ + ' exp(x_) + sin(x*y) + cos(X__*Y_)' + assert fcode(exp(x__) + 2*x*Y*X_**Rational(7, 2)) == \ + ' 2*X_**(7.0d0/2.0d0)*Y*x + exp(x__)' + assert fcode(exp(x_) + sin(x*y) + cos(X*Y), name_mangling=False) == \ + ' exp(x_) + sin(x*y) + cos(X*Y)' + assert fcode(x - cos(X), name_mangling=False) == ' x - cos(X)' + assert ob.doprint(X*sin(x) + x_, assign_to='me') == ' me = X*sin(x_) + x__' + assert ob.doprint(X*sin(x), assign_to='mu') == ' mu = X*sin(x_)' + assert ob.doprint(x_, assign_to='ad') == ' ad = x__' + n, m = symbols('n,m', integer=True) + A = IndexedBase('A') + x = IndexedBase('x') + y = IndexedBase('y') + i = Idx('i', m) + I = Idx('I', n) + assert fcode(A[i, I]*x[I], assign_to=y[i], source_format='free') == ( + "do i = 1, m\n" + " y(i) = 0\n" + "end do\n" + "do i = 1, m\n" + " do I_ = 1, n\n" + " y(i) = A(i, I_)*x(I_) + y(i)\n" + " end do\n" + "end do" ) + + +#issue 6814 +def test_fcode_functions_with_integers(): + x= symbols('x') + log10_17 = log(10).evalf(17) + loglog10_17 = '0.8340324452479558d0' + assert fcode(x * log(10)) == " x*%sd0" % log10_17 + assert fcode(x * log(10)) == " x*%sd0" % log10_17 + assert fcode(x * log(S(10))) == " x*%sd0" % log10_17 + assert fcode(log(S(10))) == " %sd0" % log10_17 + assert fcode(exp(10)) == " %sd0" % exp(10).evalf(17) + assert fcode(x * log(log(10))) == " x*%s" % loglog10_17 + assert fcode(x * log(log(S(10)))) == " x*%s" % loglog10_17 + + +def test_fcode_NumberSymbol(): + prec = 17 + p = FCodePrinter() + assert fcode(Catalan) == ' parameter (Catalan = %sd0)\n Catalan' % Catalan.evalf(prec) + assert fcode(EulerGamma) == ' parameter (EulerGamma = %sd0)\n EulerGamma' % EulerGamma.evalf(prec) + assert fcode(E) == ' parameter (E = %sd0)\n E' % E.evalf(prec) + assert fcode(GoldenRatio) == ' parameter (GoldenRatio = %sd0)\n GoldenRatio' % GoldenRatio.evalf(prec) + assert fcode(pi) == ' parameter (pi = %sd0)\n pi' % pi.evalf(prec) + assert fcode( + pi, precision=5) == ' parameter (pi = %sd0)\n pi' % pi.evalf(5) + assert fcode(Catalan, human=False) == ({ + (Catalan, p._print(Catalan.evalf(prec)))}, set(), ' Catalan') + assert fcode(EulerGamma, human=False) == ({(EulerGamma, p._print( + EulerGamma.evalf(prec)))}, set(), ' EulerGamma') + assert fcode(E, human=False) == ( + {(E, p._print(E.evalf(prec)))}, set(), ' E') + assert fcode(GoldenRatio, human=False) == ({(GoldenRatio, p._print( + GoldenRatio.evalf(prec)))}, set(), ' GoldenRatio') + assert fcode(pi, human=False) == ( + {(pi, p._print(pi.evalf(prec)))}, set(), ' pi') + assert fcode(pi, precision=5, human=False) == ( + {(pi, p._print(pi.evalf(5)))}, set(), ' pi') + + +def test_fcode_complex(): + assert fcode(I) == " cmplx(0,1)" + x = symbols('x') + assert fcode(4*I) == " cmplx(0,4)" + assert fcode(3 + 4*I) == " cmplx(3,4)" + assert fcode(3 + 4*I + x) == " cmplx(3,4) + x" + assert fcode(I*x) == " cmplx(0,1)*x" + assert fcode(3 + 4*I - x) == " cmplx(3,4) - x" + x = symbols('x', imaginary=True) + assert fcode(5*x) == " 5*x" + assert fcode(I*x) == " cmplx(0,1)*x" + assert fcode(3 + x) == " x + 3" + + +def test_implicit(): + x, y = symbols('x,y') + assert fcode(sin(x)) == " sin(x)" + assert fcode(atan2(x, y)) == " atan2(x, y)" + assert fcode(conjugate(x)) == " conjg(x)" + + +def test_not_fortran(): + x = symbols('x') + g = Function('g') + with raises(NotImplementedError): + fcode(gamma(x)) + assert fcode(Integral(sin(x)), strict=False) == "C Not supported in Fortran:\nC Integral\n Integral(sin(x), x)" + with raises(NotImplementedError): + fcode(g(x)) + + +def test_user_functions(): + x = symbols('x') + assert fcode(sin(x), user_functions={"sin": "zsin"}) == " zsin(x)" + x = symbols('x') + assert fcode( + gamma(x), user_functions={"gamma": "mygamma"}) == " mygamma(x)" + g = Function('g') + assert fcode(g(x), user_functions={"g": "great"}) == " great(x)" + n = symbols('n', integer=True) + assert fcode( + factorial(n), user_functions={"factorial": "fct"}) == " fct(n)" + + +def test_inline_function(): + x = symbols('x') + g = implemented_function('g', Lambda(x, 2*x)) + assert fcode(g(x)) == " 2*x" + g = implemented_function('g', Lambda(x, 2*pi/x)) + assert fcode(g(x)) == ( + " parameter (pi = %sd0)\n" + " 2*pi/x" + ) % pi.evalf(17) + A = IndexedBase('A') + i = Idx('i', symbols('n', integer=True)) + g = implemented_function('g', Lambda(x, x*(1 + x)*(2 + x))) + assert fcode(g(A[i]), assign_to=A[i]) == ( + " do i = 1, n\n" + " A(i) = (A(i) + 1)*(A(i) + 2)*A(i)\n" + " end do" + ) + + +def test_assign_to(): + x = symbols('x') + assert fcode(sin(x), assign_to="s") == " s = sin(x)" + + +def test_line_wrapping(): + x, y = symbols('x,y') + assert fcode(((x + y)**10).expand(), assign_to="var") == ( + " var = x**10 + 10*x**9*y + 45*x**8*y**2 + 120*x**7*y**3 + 210*x**6*\n" + " @ y**4 + 252*x**5*y**5 + 210*x**4*y**6 + 120*x**3*y**7 + 45*x**2*y\n" + " @ **8 + 10*x*y**9 + y**10" + ) + e = [x**i for i in range(11)] + assert fcode(Add(*e)) == ( + " x**10 + x**9 + x**8 + x**7 + x**6 + x**5 + x**4 + x**3 + x**2 + x\n" + " @ + 1" + ) + + +def test_fcode_precedence(): + x, y = symbols("x y") + assert fcode(And(x < y, y < x + 1), source_format="free") == \ + "x < y .and. y < x + 1" + assert fcode(Or(x < y, y < x + 1), source_format="free") == \ + "x < y .or. y < x + 1" + assert fcode(Xor(x < y, y < x + 1, evaluate=False), + source_format="free") == "x < y .neqv. y < x + 1" + assert fcode(Equivalent(x < y, y < x + 1), source_format="free") == \ + "x < y .eqv. y < x + 1" + + +def test_fcode_Logical(): + x, y, z = symbols("x y z") + # unary Not + assert fcode(Not(x), source_format="free") == ".not. x" + # binary And + assert fcode(And(x, y), source_format="free") == "x .and. y" + assert fcode(And(x, Not(y)), source_format="free") == "x .and. .not. y" + assert fcode(And(Not(x), y), source_format="free") == "y .and. .not. x" + assert fcode(And(Not(x), Not(y)), source_format="free") == \ + ".not. x .and. .not. y" + assert fcode(Not(And(x, y), evaluate=False), source_format="free") == \ + ".not. (x .and. y)" + # binary Or + assert fcode(Or(x, y), source_format="free") == "x .or. y" + assert fcode(Or(x, Not(y)), source_format="free") == "x .or. .not. y" + assert fcode(Or(Not(x), y), source_format="free") == "y .or. .not. x" + assert fcode(Or(Not(x), Not(y)), source_format="free") == \ + ".not. x .or. .not. y" + assert fcode(Not(Or(x, y), evaluate=False), source_format="free") == \ + ".not. (x .or. y)" + # mixed And/Or + assert fcode(And(Or(y, z), x), source_format="free") == "x .and. (y .or. z)" + assert fcode(And(Or(z, x), y), source_format="free") == "y .and. (x .or. z)" + assert fcode(And(Or(x, y), z), source_format="free") == "z .and. (x .or. y)" + assert fcode(Or(And(y, z), x), source_format="free") == "x .or. y .and. z" + assert fcode(Or(And(z, x), y), source_format="free") == "y .or. x .and. z" + assert fcode(Or(And(x, y), z), source_format="free") == "z .or. x .and. y" + # trinary And + assert fcode(And(x, y, z), source_format="free") == "x .and. y .and. z" + assert fcode(And(x, y, Not(z)), source_format="free") == \ + "x .and. y .and. .not. z" + assert fcode(And(x, Not(y), z), source_format="free") == \ + "x .and. z .and. .not. y" + assert fcode(And(Not(x), y, z), source_format="free") == \ + "y .and. z .and. .not. x" + assert fcode(Not(And(x, y, z), evaluate=False), source_format="free") == \ + ".not. (x .and. y .and. z)" + # trinary Or + assert fcode(Or(x, y, z), source_format="free") == "x .or. y .or. z" + assert fcode(Or(x, y, Not(z)), source_format="free") == \ + "x .or. y .or. .not. z" + assert fcode(Or(x, Not(y), z), source_format="free") == \ + "x .or. z .or. .not. y" + assert fcode(Or(Not(x), y, z), source_format="free") == \ + "y .or. z .or. .not. x" + assert fcode(Not(Or(x, y, z), evaluate=False), source_format="free") == \ + ".not. (x .or. y .or. z)" + + +def test_fcode_Xlogical(): + x, y, z = symbols("x y z") + # binary Xor + assert fcode(Xor(x, y, evaluate=False), source_format="free") == \ + "x .neqv. y" + assert fcode(Xor(x, Not(y), evaluate=False), source_format="free") == \ + "x .neqv. .not. y" + assert fcode(Xor(Not(x), y, evaluate=False), source_format="free") == \ + "y .neqv. .not. x" + assert fcode(Xor(Not(x), Not(y), evaluate=False), + source_format="free") == ".not. x .neqv. .not. y" + assert fcode(Not(Xor(x, y, evaluate=False), evaluate=False), + source_format="free") == ".not. (x .neqv. y)" + # binary Equivalent + assert fcode(Equivalent(x, y), source_format="free") == "x .eqv. y" + assert fcode(Equivalent(x, Not(y)), source_format="free") == \ + "x .eqv. .not. y" + assert fcode(Equivalent(Not(x), y), source_format="free") == \ + "y .eqv. .not. x" + assert fcode(Equivalent(Not(x), Not(y)), source_format="free") == \ + ".not. x .eqv. .not. y" + assert fcode(Not(Equivalent(x, y), evaluate=False), + source_format="free") == ".not. (x .eqv. y)" + # mixed And/Equivalent + assert fcode(Equivalent(And(y, z), x), source_format="free") == \ + "x .eqv. y .and. z" + assert fcode(Equivalent(And(z, x), y), source_format="free") == \ + "y .eqv. x .and. z" + assert fcode(Equivalent(And(x, y), z), source_format="free") == \ + "z .eqv. x .and. y" + assert fcode(And(Equivalent(y, z), x), source_format="free") == \ + "x .and. (y .eqv. z)" + assert fcode(And(Equivalent(z, x), y), source_format="free") == \ + "y .and. (x .eqv. z)" + assert fcode(And(Equivalent(x, y), z), source_format="free") == \ + "z .and. (x .eqv. y)" + # mixed Or/Equivalent + assert fcode(Equivalent(Or(y, z), x), source_format="free") == \ + "x .eqv. y .or. z" + assert fcode(Equivalent(Or(z, x), y), source_format="free") == \ + "y .eqv. x .or. z" + assert fcode(Equivalent(Or(x, y), z), source_format="free") == \ + "z .eqv. x .or. y" + assert fcode(Or(Equivalent(y, z), x), source_format="free") == \ + "x .or. (y .eqv. z)" + assert fcode(Or(Equivalent(z, x), y), source_format="free") == \ + "y .or. (x .eqv. z)" + assert fcode(Or(Equivalent(x, y), z), source_format="free") == \ + "z .or. (x .eqv. y)" + # mixed Xor/Equivalent + assert fcode(Equivalent(Xor(y, z, evaluate=False), x), + source_format="free") == "x .eqv. (y .neqv. z)" + assert fcode(Equivalent(Xor(z, x, evaluate=False), y), + source_format="free") == "y .eqv. (x .neqv. z)" + assert fcode(Equivalent(Xor(x, y, evaluate=False), z), + source_format="free") == "z .eqv. (x .neqv. y)" + assert fcode(Xor(Equivalent(y, z), x, evaluate=False), + source_format="free") == "x .neqv. (y .eqv. z)" + assert fcode(Xor(Equivalent(z, x), y, evaluate=False), + source_format="free") == "y .neqv. (x .eqv. z)" + assert fcode(Xor(Equivalent(x, y), z, evaluate=False), + source_format="free") == "z .neqv. (x .eqv. y)" + # mixed And/Xor + assert fcode(Xor(And(y, z), x, evaluate=False), source_format="free") == \ + "x .neqv. y .and. z" + assert fcode(Xor(And(z, x), y, evaluate=False), source_format="free") == \ + "y .neqv. x .and. z" + assert fcode(Xor(And(x, y), z, evaluate=False), source_format="free") == \ + "z .neqv. x .and. y" + assert fcode(And(Xor(y, z, evaluate=False), x), source_format="free") == \ + "x .and. (y .neqv. z)" + assert fcode(And(Xor(z, x, evaluate=False), y), source_format="free") == \ + "y .and. (x .neqv. z)" + assert fcode(And(Xor(x, y, evaluate=False), z), source_format="free") == \ + "z .and. (x .neqv. y)" + # mixed Or/Xor + assert fcode(Xor(Or(y, z), x, evaluate=False), source_format="free") == \ + "x .neqv. y .or. z" + assert fcode(Xor(Or(z, x), y, evaluate=False), source_format="free") == \ + "y .neqv. x .or. z" + assert fcode(Xor(Or(x, y), z, evaluate=False), source_format="free") == \ + "z .neqv. x .or. y" + assert fcode(Or(Xor(y, z, evaluate=False), x), source_format="free") == \ + "x .or. (y .neqv. z)" + assert fcode(Or(Xor(z, x, evaluate=False), y), source_format="free") == \ + "y .or. (x .neqv. z)" + assert fcode(Or(Xor(x, y, evaluate=False), z), source_format="free") == \ + "z .or. (x .neqv. y)" + # trinary Xor + assert fcode(Xor(x, y, z, evaluate=False), source_format="free") == \ + "x .neqv. y .neqv. z" + assert fcode(Xor(x, y, Not(z), evaluate=False), source_format="free") == \ + "x .neqv. y .neqv. .not. z" + assert fcode(Xor(x, Not(y), z, evaluate=False), source_format="free") == \ + "x .neqv. z .neqv. .not. y" + assert fcode(Xor(Not(x), y, z, evaluate=False), source_format="free") == \ + "y .neqv. z .neqv. .not. x" + + +def test_fcode_Relational(): + x, y = symbols("x y") + assert fcode(Relational(x, y, "=="), source_format="free") == "x == y" + assert fcode(Relational(x, y, "!="), source_format="free") == "x /= y" + assert fcode(Relational(x, y, ">="), source_format="free") == "x >= y" + assert fcode(Relational(x, y, "<="), source_format="free") == "x <= y" + assert fcode(Relational(x, y, ">"), source_format="free") == "x > y" + assert fcode(Relational(x, y, "<"), source_format="free") == "x < y" + + +def test_fcode_Piecewise(): + x = symbols('x') + expr = Piecewise((x, x < 1), (x**2, True)) + # Check that inline conditional (merge) fails if standard isn't 95+ + raises(NotImplementedError, lambda: fcode(expr)) + code = fcode(expr, standard=95) + expected = " merge(x, x**2, x < 1)" + assert code == expected + assert fcode(Piecewise((x, x < 1), (x**2, True)), assign_to="var") == ( + " if (x < 1) then\n" + " var = x\n" + " else\n" + " var = x**2\n" + " end if" + ) + a = cos(x)/x + b = sin(x)/x + for i in range(10): + a = diff(a, x) + b = diff(b, x) + expected = ( + " if (x < 0) then\n" + " weird_name = -cos(x)/x + 10*sin(x)/x**2 + 90*cos(x)/x**3 - 720*\n" + " @ sin(x)/x**4 - 5040*cos(x)/x**5 + 30240*sin(x)/x**6 + 151200*cos(x\n" + " @ )/x**7 - 604800*sin(x)/x**8 - 1814400*cos(x)/x**9 + 3628800*sin(x\n" + " @ )/x**10 + 3628800*cos(x)/x**11\n" + " else\n" + " weird_name = -sin(x)/x - 10*cos(x)/x**2 + 90*sin(x)/x**3 + 720*\n" + " @ cos(x)/x**4 - 5040*sin(x)/x**5 - 30240*cos(x)/x**6 + 151200*sin(x\n" + " @ )/x**7 + 604800*cos(x)/x**8 - 1814400*sin(x)/x**9 - 3628800*cos(x\n" + " @ )/x**10 + 3628800*sin(x)/x**11\n" + " end if" + ) + code = fcode(Piecewise((a, x < 0), (b, True)), assign_to="weird_name") + assert code == expected + code = fcode(Piecewise((x, x < 1), (x**2, x > 1), (sin(x), True)), standard=95) + expected = " merge(x, merge(x**2, sin(x), x > 1), x < 1)" + assert code == expected + # Check that Piecewise without a True (default) condition error + expr = Piecewise((x, x < 1), (x**2, x > 1), (sin(x), x > 0)) + raises(ValueError, lambda: fcode(expr)) + + +def test_wrap_fortran(): + # "########################################################################" + printer = FCodePrinter() + lines = [ + "C This is a long comment on a single line that must be wrapped properly to produce nice output", + " this = is + a + long + and + nasty + fortran + statement + that * must + be + wrapped + properly", + " this = is + a + long + and + nasty + fortran + statement + that * must + be + wrapped + properly", + " this = is + a + long + and + nasty + fortran + statement + that * must + be + wrapped + properly", + " this = is + a + long + and + nasty + fortran + statement + that*must + be + wrapped + properly", + " this = is + a + long + and + nasty + fortran + statement + that*must + be + wrapped + properly", + " this = is + a + long + and + nasty + fortran + statement + that*must + be + wrapped + properly", + " this = is + a + long + and + nasty + fortran + statement + that*must + be + wrapped + properly", + " this = is + a + long + and + nasty + fortran + statement + that**must + be + wrapped + properly", + " this = is + a + long + and + nasty + fortran + statement + that**must + be + wrapped + properly", + " this = is + a + long + and + nasty + fortran + statement + that**must + be + wrapped + properly", + " this = is + a + long + and + nasty + fortran + statement + that**must + be + wrapped + properly", + " this = is + a + long + and + nasty + fortran + statement + that**must + be + wrapped + properly", + " this = is + a + long + and + nasty + fortran + statement(that)/must + be + wrapped + properly", + " this = is + a + long + and + nasty + fortran + statement(that)/must + be + wrapped + properly", + ] + wrapped_lines = printer._wrap_fortran(lines) + expected_lines = [ + "C This is a long comment on a single line that must be wrapped", + "C properly to produce nice output", + " this = is + a + long + and + nasty + fortran + statement + that *", + " @ must + be + wrapped + properly", + " this = is + a + long + and + nasty + fortran + statement + that *", + " @ must + be + wrapped + properly", + " this = is + a + long + and + nasty + fortran + statement + that", + " @ * must + be + wrapped + properly", + " this = is + a + long + and + nasty + fortran + statement + that*", + " @ must + be + wrapped + properly", + " this = is + a + long + and + nasty + fortran + statement + that*", + " @ must + be + wrapped + properly", + " this = is + a + long + and + nasty + fortran + statement + that", + " @ *must + be + wrapped + properly", + " this = is + a + long + and + nasty + fortran + statement +", + " @ that*must + be + wrapped + properly", + " this = is + a + long + and + nasty + fortran + statement + that**", + " @ must + be + wrapped + properly", + " this = is + a + long + and + nasty + fortran + statement + that**", + " @ must + be + wrapped + properly", + " this = is + a + long + and + nasty + fortran + statement + that", + " @ **must + be + wrapped + properly", + " this = is + a + long + and + nasty + fortran + statement + that", + " @ **must + be + wrapped + properly", + " this = is + a + long + and + nasty + fortran + statement +", + " @ that**must + be + wrapped + properly", + " this = is + a + long + and + nasty + fortran + statement(that)/", + " @ must + be + wrapped + properly", + " this = is + a + long + and + nasty + fortran + statement(that)", + " @ /must + be + wrapped + properly", + ] + for line in wrapped_lines: + assert len(line) <= 72 + for w, e in zip(wrapped_lines, expected_lines): + assert w == e + assert len(wrapped_lines) == len(expected_lines) + + +def test_wrap_fortran_keep_d0(): + printer = FCodePrinter() + lines = [ + ' this_variable_is_very_long_because_we_try_to_test_line_break=1.0d0', + ' this_variable_is_very_long_because_we_try_to_test_line_break =1.0d0', + ' this_variable_is_very_long_because_we_try_to_test_line_break = 1.0d0', + ' this_variable_is_very_long_because_we_try_to_test_line_break = 1.0d0', + ' this_variable_is_very_long_because_we_try_to_test_line_break = 1.0d0', + ' this_variable_is_very_long_because_we_try_to_test_line_break = 10.0d0' + ] + expected = [ + ' this_variable_is_very_long_because_we_try_to_test_line_break=1.0d0', + ' this_variable_is_very_long_because_we_try_to_test_line_break =', + ' @ 1.0d0', + ' this_variable_is_very_long_because_we_try_to_test_line_break =', + ' @ 1.0d0', + ' this_variable_is_very_long_because_we_try_to_test_line_break =', + ' @ 1.0d0', + ' this_variable_is_very_long_because_we_try_to_test_line_break =', + ' @ 1.0d0', + ' this_variable_is_very_long_because_we_try_to_test_line_break =', + ' @ 10.0d0' + ] + assert printer._wrap_fortran(lines) == expected + + +def test_settings(): + raises(TypeError, lambda: fcode(S(4), method="garbage")) + + +def test_free_form_code_line(): + x, y = symbols('x,y') + assert fcode(cos(x) + sin(y), source_format='free') == "sin(y) + cos(x)" + + +def test_free_form_continuation_line(): + x, y = symbols('x,y') + result = fcode(((cos(x) + sin(y))**(7)).expand(), source_format='free') + expected = ( + 'sin(y)**7 + 7*sin(y)**6*cos(x) + 21*sin(y)**5*cos(x)**2 + 35*sin(y)**4* &\n' + ' cos(x)**3 + 35*sin(y)**3*cos(x)**4 + 21*sin(y)**2*cos(x)**5 + 7* &\n' + ' sin(y)*cos(x)**6 + cos(x)**7' + ) + assert result == expected + + +def test_free_form_comment_line(): + printer = FCodePrinter({'source_format': 'free'}) + lines = [ "! This is a long comment on a single line that must be wrapped properly to produce nice output"] + expected = [ + '! This is a long comment on a single line that must be wrapped properly', + '! to produce nice output'] + assert printer._wrap_fortran(lines) == expected + + +def test_loops(): + n, m = symbols('n,m', integer=True) + A = IndexedBase('A') + x = IndexedBase('x') + y = IndexedBase('y') + i = Idx('i', m) + j = Idx('j', n) + + expected = ( + 'do i = 1, m\n' + ' y(i) = 0\n' + 'end do\n' + 'do i = 1, m\n' + ' do j = 1, n\n' + ' y(i) = %(rhs)s\n' + ' end do\n' + 'end do' + ) + + code = fcode(A[i, j]*x[j], assign_to=y[i], source_format='free') + assert (code == expected % {'rhs': 'y(i) + A(i, j)*x(j)'} or + code == expected % {'rhs': 'y(i) + x(j)*A(i, j)'} or + code == expected % {'rhs': 'x(j)*A(i, j) + y(i)'} or + code == expected % {'rhs': 'A(i, j)*x(j) + y(i)'}) + + +def test_dummy_loops(): + i, m = symbols('i m', integer=True, cls=Dummy) + x = IndexedBase('x') + y = IndexedBase('y') + i = Idx(i, m) + + expected = ( + 'do i_%(icount)i = 1, m_%(mcount)i\n' + ' y(i_%(icount)i) = x(i_%(icount)i)\n' + 'end do' + ) % {'icount': i.label.dummy_index, 'mcount': m.dummy_index} + code = fcode(x[i], assign_to=y[i], source_format='free') + assert code == expected + + +def test_fcode_Indexed_without_looking_for_contraction(): + len_y = 5 + y = IndexedBase('y', shape=(len_y,)) + x = IndexedBase('x', shape=(len_y,)) + Dy = IndexedBase('Dy', shape=(len_y-1,)) + i = Idx('i', len_y-1) + e=Eq(Dy[i], (y[i+1]-y[i])/(x[i+1]-x[i])) + code0 = fcode(e.rhs, assign_to=e.lhs, contract=False) + assert code0.endswith('Dy(i) = (y(i + 1) - y(i))/(x(i + 1) - x(i))') + + +def test_element_like_objects(): + len_y = 5 + y = ArraySymbol('y', shape=(len_y,)) + x = ArraySymbol('x', shape=(len_y,)) + Dy = ArraySymbol('Dy', shape=(len_y-1,)) + i = Idx('i', len_y-1) + e=Eq(Dy[i], (y[i+1]-y[i])/(x[i+1]-x[i])) + code0 = fcode(Assignment(e.lhs, e.rhs)) + assert code0.endswith('Dy(i) = (y(i + 1) - y(i))/(x(i + 1) - x(i))') + + class ElementExpr(Element, Expr): + pass + + e = e.subs((a, ElementExpr(a.name, a.indices)) for a in e.atoms(ArrayElement) ) + e=Eq(Dy[i], (y[i+1]-y[i])/(x[i+1]-x[i])) + code0 = fcode(Assignment(e.lhs, e.rhs)) + assert code0.endswith('Dy(i) = (y(i + 1) - y(i))/(x(i + 1) - x(i))') + + +def test_derived_classes(): + class MyFancyFCodePrinter(FCodePrinter): + _default_settings = FCodePrinter._default_settings.copy() + + printer = MyFancyFCodePrinter() + x = symbols('x') + assert printer.doprint(sin(x), "bork") == " bork = sin(x)" + + +def test_indent(): + codelines = ( + 'subroutine test(a)\n' + 'integer :: a, i, j\n' + '\n' + 'do\n' + 'do \n' + 'do j = 1, 5\n' + 'if (a>b) then\n' + 'if(b>0) then\n' + 'a = 3\n' + 'donot_indent_me = 2\n' + 'do_not_indent_me_either = 2\n' + 'ifIam_indented_something_went_wrong = 2\n' + 'if_I_am_indented_something_went_wrong = 2\n' + 'end should not be unindented here\n' + 'end if\n' + 'endif\n' + 'end do\n' + 'end do\n' + 'enddo\n' + 'end subroutine\n' + '\n' + 'subroutine test2(a)\n' + 'integer :: a\n' + 'do\n' + 'a = a + 1\n' + 'end do \n' + 'end subroutine\n' + ) + expected = ( + 'subroutine test(a)\n' + 'integer :: a, i, j\n' + '\n' + 'do\n' + ' do \n' + ' do j = 1, 5\n' + ' if (a>b) then\n' + ' if(b>0) then\n' + ' a = 3\n' + ' donot_indent_me = 2\n' + ' do_not_indent_me_either = 2\n' + ' ifIam_indented_something_went_wrong = 2\n' + ' if_I_am_indented_something_went_wrong = 2\n' + ' end should not be unindented here\n' + ' end if\n' + ' endif\n' + ' end do\n' + ' end do\n' + 'enddo\n' + 'end subroutine\n' + '\n' + 'subroutine test2(a)\n' + 'integer :: a\n' + 'do\n' + ' a = a + 1\n' + 'end do \n' + 'end subroutine\n' + ) + p = FCodePrinter({'source_format': 'free'}) + result = p.indent_code(codelines) + assert result == expected + +def test_Matrix_printing(): + x, y, z = symbols('x,y,z') + # Test returning a Matrix + mat = Matrix([x*y, Piecewise((2 + x, y>0), (y, True)), sin(z)]) + A = MatrixSymbol('A', 3, 1) + assert fcode(mat, A) == ( + " A(1, 1) = x*y\n" + " if (y > 0) then\n" + " A(2, 1) = x + 2\n" + " else\n" + " A(2, 1) = y\n" + " end if\n" + " A(3, 1) = sin(z)") + # Test using MatrixElements in expressions + expr = Piecewise((2*A[2, 0], x > 0), (A[2, 0], True)) + sin(A[1, 0]) + A[0, 0] + assert fcode(expr, standard=95) == ( + " merge(2*A(3, 1), A(3, 1), x > 0) + sin(A(2, 1)) + A(1, 1)") + # Test using MatrixElements in a Matrix + q = MatrixSymbol('q', 5, 1) + M = MatrixSymbol('M', 3, 3) + m = Matrix([[sin(q[1,0]), 0, cos(q[2,0])], + [q[1,0] + q[2,0], q[3, 0], 5], + [2*q[4, 0]/q[1,0], sqrt(q[0,0]) + 4, 0]]) + assert fcode(m, M) == ( + " M(1, 1) = sin(q(2, 1))\n" + " M(2, 1) = q(2, 1) + q(3, 1)\n" + " M(3, 1) = 2*q(5, 1)/q(2, 1)\n" + " M(1, 2) = 0\n" + " M(2, 2) = q(4, 1)\n" + " M(3, 2) = sqrt(q(1, 1)) + 4\n" + " M(1, 3) = cos(q(3, 1))\n" + " M(2, 3) = 5\n" + " M(3, 3) = 0") + + +def test_fcode_For(): + x, y = symbols('x y') + + f = For(x, Range(0, 10, 2), [Assignment(y, x * y)]) + sol = fcode(f) + assert sol == (" do x = 0, 9, 2\n" + " y = x*y\n" + " end do") + + +def test_fcode_Declaration(): + def check(expr, ref, **kwargs): + assert fcode(expr, standard=95, source_format='free', **kwargs) == ref + + i = symbols('i', integer=True) + var1 = Variable.deduced(i) + dcl1 = Declaration(var1) + check(dcl1, "integer*4 :: i") + + + x, y = symbols('x y') + var2 = Variable(x, float32, value=42, attrs={value_const}) + dcl2b = Declaration(var2) + check(dcl2b, 'real*4, parameter :: x = 42') + + var3 = Variable(y, type=bool_) + dcl3 = Declaration(var3) + check(dcl3, 'logical :: y') + + check(float32, "real*4") + check(float64, "real*8") + check(real, "real*4", type_aliases={real: float32}) + check(real, "real*8", type_aliases={real: float64}) + + +def test_MatrixElement_printing(): + # test cases for issue #11821 + A = MatrixSymbol("A", 1, 3) + B = MatrixSymbol("B", 1, 3) + C = MatrixSymbol("C", 1, 3) + + assert(fcode(A[0, 0]) == " A(1, 1)") + assert(fcode(3 * A[0, 0]) == " 3*A(1, 1)") + + F = C[0, 0].subs(C, A - B) + assert(fcode(F) == " (A - B)(1, 1)") + + +def test_aug_assign(): + x = symbols('x') + assert fcode(aug_assign(x, '+', 1), source_format='free') == 'x = x + 1' + + +def test_While(): + x = symbols('x') + assert fcode(While(abs(x) > 1, [aug_assign(x, '-', 1)]), source_format='free') == ( + 'do while (abs(x) > 1)\n' + ' x = x - 1\n' + 'end do' + ) + + +def test_FunctionPrototype_print(): + x = symbols('x') + n = symbols('n', integer=True) + vx = Variable(x, type=real) + vn = Variable(n, type=integer) + fp1 = FunctionPrototype(real, 'power', [vx, vn]) + # Should be changed to proper test once multi-line generation is working + # see https://github.com/sympy/sympy/issues/15824 + raises(NotImplementedError, lambda: fcode(fp1)) + + +def test_FunctionDefinition_print(): + x = symbols('x') + n = symbols('n', integer=True) + vx = Variable(x, type=real) + vn = Variable(n, type=integer) + body = [Assignment(x, x**n), Return(x)] + fd1 = FunctionDefinition(real, 'power', [vx, vn], body) + # Should be changed to proper test once multi-line generation is working + # see https://github.com/sympy/sympy/issues/15824 + raises(NotImplementedError, lambda: fcode(fd1)) diff --git a/lib/python3.10/site-packages/sympy/printing/tests/test_glsl.py b/lib/python3.10/site-packages/sympy/printing/tests/test_glsl.py new file mode 100644 index 0000000000000000000000000000000000000000..86ec1dfe4a37d141e8435c369cb692d3a9a3b7bc --- /dev/null +++ b/lib/python3.10/site-packages/sympy/printing/tests/test_glsl.py @@ -0,0 +1,998 @@ +from sympy.core import (pi, symbols, Rational, Integer, GoldenRatio, EulerGamma, + Catalan, Lambda, Dummy, Eq, Ne, Le, Lt, Gt, Ge) +from sympy.functions import Piecewise, sin, cos, Abs, exp, ceiling, sqrt +from sympy.testing.pytest import raises, warns_deprecated_sympy +from sympy.printing.glsl import GLSLPrinter +from sympy.printing.str import StrPrinter +from sympy.utilities.lambdify import implemented_function +from sympy.tensor import IndexedBase, Idx +from sympy.matrices import Matrix, MatrixSymbol +from sympy.core import Tuple +from sympy.printing.glsl import glsl_code +import textwrap + +x, y, z = symbols('x,y,z') + + +def test_printmethod(): + assert glsl_code(Abs(x)) == "abs(x)" + +def test_print_without_operators(): + assert glsl_code(x*y,use_operators = False) == 'mul(x, y)' + assert glsl_code(x**y+z,use_operators = False) == 'add(pow(x, y), z)' + assert glsl_code(x*(y+z),use_operators = False) == 'mul(x, add(y, z))' + assert glsl_code(x*(y+z),use_operators = False) == 'mul(x, add(y, z))' + assert glsl_code(x*(y+z**y**0.5),use_operators = False) == 'mul(x, add(y, pow(z, sqrt(y))))' + assert glsl_code(-x-y, use_operators=False, zero='zero()') == 'sub(zero(), add(x, y))' + assert glsl_code(-x-y, use_operators=False) == 'sub(0.0, add(x, y))' + +def test_glsl_code_sqrt(): + assert glsl_code(sqrt(x)) == "sqrt(x)" + assert glsl_code(x**0.5) == "sqrt(x)" + assert glsl_code(sqrt(x)) == "sqrt(x)" + + +def test_glsl_code_Pow(): + g = implemented_function('g', Lambda(x, 2*x)) + assert glsl_code(x**3) == "pow(x, 3.0)" + assert glsl_code(x**(y**3)) == "pow(x, pow(y, 3.0))" + assert glsl_code(1/(g(x)*3.5)**(x - y**x)/(x**2 + y)) == \ + "pow(3.5*2*x, -x + pow(y, x))/(pow(x, 2.0) + y)" + assert glsl_code(x**-1.0) == '1.0/x' + + +def test_glsl_code_Relational(): + assert glsl_code(Eq(x, y)) == "x == y" + assert glsl_code(Ne(x, y)) == "x != y" + assert glsl_code(Le(x, y)) == "x <= y" + assert glsl_code(Lt(x, y)) == "x < y" + assert glsl_code(Gt(x, y)) == "x > y" + assert glsl_code(Ge(x, y)) == "x >= y" + + +def test_glsl_code_constants_mathh(): + assert glsl_code(exp(1)) == "float E = 2.71828183;\nE" + assert glsl_code(pi) == "float pi = 3.14159265;\npi" + # assert glsl_code(oo) == "Number.POSITIVE_INFINITY" + # assert glsl_code(-oo) == "Number.NEGATIVE_INFINITY" + + +def test_glsl_code_constants_other(): + assert glsl_code(2*GoldenRatio) == "float GoldenRatio = 1.61803399;\n2*GoldenRatio" + assert glsl_code(2*Catalan) == "float Catalan = 0.915965594;\n2*Catalan" + assert glsl_code(2*EulerGamma) == "float EulerGamma = 0.577215665;\n2*EulerGamma" + + +def test_glsl_code_Rational(): + assert glsl_code(Rational(3, 7)) == "3.0/7.0" + assert glsl_code(Rational(18, 9)) == "2" + assert glsl_code(Rational(3, -7)) == "-3.0/7.0" + assert glsl_code(Rational(-3, -7)) == "3.0/7.0" + + +def test_glsl_code_Integer(): + assert glsl_code(Integer(67)) == "67" + assert glsl_code(Integer(-1)) == "-1" + + +def test_glsl_code_functions(): + assert glsl_code(sin(x) ** cos(x)) == "pow(sin(x), cos(x))" + + +def test_glsl_code_inline_function(): + x = symbols('x') + g = implemented_function('g', Lambda(x, 2*x)) + assert glsl_code(g(x)) == "2*x" + g = implemented_function('g', Lambda(x, 2*x/Catalan)) + assert glsl_code(g(x)) == "float Catalan = 0.915965594;\n2*x/Catalan" + A = IndexedBase('A') + i = Idx('i', symbols('n', integer=True)) + g = implemented_function('g', Lambda(x, x*(1 + x)*(2 + x))) + assert glsl_code(g(A[i]), assign_to=A[i]) == ( + "for (int i=0; i 1), (sin(x), x > 0)) + raises(ValueError, lambda: glsl_code(expr)) + + +def test_glsl_code_Piecewise_deep(): + p = glsl_code(2*Piecewise((x, x < 1), (x**2, True))) + s = \ +"""\ +2*((x < 1) ? ( + x +) +: ( + pow(x, 2.0) +))\ +""" + assert p == s + + +def test_glsl_code_settings(): + raises(TypeError, lambda: glsl_code(sin(x), method="garbage")) + + +def test_glsl_code_Indexed(): + n, m, o = symbols('n m o', integer=True) + i, j, k = Idx('i', n), Idx('j', m), Idx('k', o) + p = GLSLPrinter() + p._not_c = set() + + x = IndexedBase('x')[j] + assert p._print_Indexed(x) == 'x[j]' + A = IndexedBase('A')[i, j] + assert p._print_Indexed(A) == 'A[%s]' % (m*i+j) + B = IndexedBase('B')[i, j, k] + assert p._print_Indexed(B) == 'B[%s]' % (i*o*m+j*o+k) + + assert p._not_c == set() + +def test_glsl_code_list_tuple_Tuple(): + assert glsl_code([1,2,3,4]) == 'vec4(1, 2, 3, 4)' + assert glsl_code([1,2,3],glsl_types=False) == 'float[3](1, 2, 3)' + assert glsl_code([1,2,3]) == glsl_code((1,2,3)) + assert glsl_code([1,2,3]) == glsl_code(Tuple(1,2,3)) + + m = MatrixSymbol('A',3,4) + assert glsl_code([m[0],m[1]]) + +def test_glsl_code_loops_matrix_vector(): + n, m = symbols('n m', integer=True) + A = IndexedBase('A') + x = IndexedBase('x') + y = IndexedBase('y') + i = Idx('i', m) + j = Idx('j', n) + + s = ( + 'for (int i=0; i0), (y, True)), sin(z)]) + A = MatrixSymbol('A', 3, 1) + assert glsl_code(mat, assign_to=A) == ( +'''A[0][0] = x*y; +if (y > 0) { + A[1][0] = x + 2; +} +else { + A[1][0] = y; +} +A[2][0] = sin(z);''' ) + assert glsl_code(Matrix([A[0],A[1]])) + # Test using MatrixElements in expressions + expr = Piecewise((2*A[2, 0], x > 0), (A[2, 0], True)) + sin(A[1, 0]) + A[0, 0] + assert glsl_code(expr) == ( +'''((x > 0) ? ( + 2*A[2][0] +) +: ( + A[2][0] +)) + sin(A[1][0]) + A[0][0]''' ) + + # Test using MatrixElements in a Matrix + q = MatrixSymbol('q', 5, 1) + M = MatrixSymbol('M', 3, 3) + m = Matrix([[sin(q[1,0]), 0, cos(q[2,0])], + [q[1,0] + q[2,0], q[3, 0], 5], + [2*q[4, 0]/q[1,0], sqrt(q[0,0]) + 4, 0]]) + assert glsl_code(m,M) == ( +'''M[0][0] = sin(q[1]); +M[0][1] = 0; +M[0][2] = cos(q[2]); +M[1][0] = q[1] + q[2]; +M[1][1] = q[3]; +M[1][2] = 5; +M[2][0] = 2*q[4]/q[1]; +M[2][1] = sqrt(q[0]) + 4; +M[2][2] = 0;''' + ) + +def test_Matrices_1x7(): + gl = glsl_code + A = Matrix([1,2,3,4,5,6,7]) + assert gl(A) == 'float[7](1, 2, 3, 4, 5, 6, 7)' + assert gl(A.transpose()) == 'float[7](1, 2, 3, 4, 5, 6, 7)' + +def test_Matrices_1x7_array_type_int(): + gl = glsl_code + A = Matrix([1,2,3,4,5,6,7]) + assert gl(A, array_type='int') == 'int[7](1, 2, 3, 4, 5, 6, 7)' + +def test_Tuple_array_type_custom(): + gl = glsl_code + A = symbols('a b c') + assert gl(A, array_type='AbcType', glsl_types=False) == 'AbcType[3](a, b, c)' + +def test_Matrices_1x7_spread_assign_to_symbols(): + gl = glsl_code + A = Matrix([1,2,3,4,5,6,7]) + assign_to = symbols('x.a x.b x.c x.d x.e x.f x.g') + assert gl(A, assign_to=assign_to) == textwrap.dedent('''\ + x.a = 1; + x.b = 2; + x.c = 3; + x.d = 4; + x.e = 5; + x.f = 6; + x.g = 7;''' + ) + +def test_spread_assign_to_nested_symbols(): + gl = glsl_code + expr = ((1,2,3), (1,2,3)) + assign_to = (symbols('a b c'), symbols('x y z')) + assert gl(expr, assign_to=assign_to) == textwrap.dedent('''\ + a = 1; + b = 2; + c = 3; + x = 1; + y = 2; + z = 3;''' + ) + +def test_spread_assign_to_deeply_nested_symbols(): + gl = glsl_code + a, b, c, x, y, z = symbols('a b c x y z') + expr = (((1,2),3), ((1,2),3)) + assign_to = (((a, b), c), ((x, y), z)) + assert gl(expr, assign_to=assign_to) == textwrap.dedent('''\ + a = 1; + b = 2; + c = 3; + x = 1; + y = 2; + z = 3;''' + ) + +def test_matrix_of_tuples_spread_assign_to_symbols(): + gl = glsl_code + with warns_deprecated_sympy(): + expr = Matrix([[(1,2),(3,4)],[(5,6),(7,8)]]) + assign_to = (symbols('a b'), symbols('c d'), symbols('e f'), symbols('g h')) + assert gl(expr, assign_to) == textwrap.dedent('''\ + a = 1; + b = 2; + c = 3; + d = 4; + e = 5; + f = 6; + g = 7; + h = 8;''' + ) + +def test_cannot_assign_to_cause_mismatched_length(): + expr = (1, 2) + assign_to = symbols('x y z') + raises(ValueError, lambda: glsl_code(expr, assign_to)) + +def test_matrix_4x4_assign(): + gl = glsl_code + expr = MatrixSymbol('A',4,4) * MatrixSymbol('B',4,4) + MatrixSymbol('C',4,4) + assign_to = MatrixSymbol('X',4,4) + assert gl(expr, assign_to=assign_to) == textwrap.dedent('''\ + X[0][0] = A[0][0]*B[0][0] + A[0][1]*B[1][0] + A[0][2]*B[2][0] + A[0][3]*B[3][0] + C[0][0]; + X[0][1] = A[0][0]*B[0][1] + A[0][1]*B[1][1] + A[0][2]*B[2][1] + A[0][3]*B[3][1] + C[0][1]; + X[0][2] = A[0][0]*B[0][2] + A[0][1]*B[1][2] + A[0][2]*B[2][2] + A[0][3]*B[3][2] + C[0][2]; + X[0][3] = A[0][0]*B[0][3] + A[0][1]*B[1][3] + A[0][2]*B[2][3] + A[0][3]*B[3][3] + C[0][3]; + X[1][0] = A[1][0]*B[0][0] + A[1][1]*B[1][0] + A[1][2]*B[2][0] + A[1][3]*B[3][0] + C[1][0]; + X[1][1] = A[1][0]*B[0][1] + A[1][1]*B[1][1] + A[1][2]*B[2][1] + A[1][3]*B[3][1] + C[1][1]; + X[1][2] = A[1][0]*B[0][2] + A[1][1]*B[1][2] + A[1][2]*B[2][2] + A[1][3]*B[3][2] + C[1][2]; + X[1][3] = A[1][0]*B[0][3] + A[1][1]*B[1][3] + A[1][2]*B[2][3] + A[1][3]*B[3][3] + C[1][3]; + X[2][0] = A[2][0]*B[0][0] + A[2][1]*B[1][0] + A[2][2]*B[2][0] + A[2][3]*B[3][0] + C[2][0]; + X[2][1] = A[2][0]*B[0][1] + A[2][1]*B[1][1] + A[2][2]*B[2][1] + A[2][3]*B[3][1] + C[2][1]; + X[2][2] = A[2][0]*B[0][2] + A[2][1]*B[1][2] + A[2][2]*B[2][2] + A[2][3]*B[3][2] + C[2][2]; + X[2][3] = A[2][0]*B[0][3] + A[2][1]*B[1][3] + A[2][2]*B[2][3] + A[2][3]*B[3][3] + C[2][3]; + X[3][0] = A[3][0]*B[0][0] + A[3][1]*B[1][0] + A[3][2]*B[2][0] + A[3][3]*B[3][0] + C[3][0]; + X[3][1] = A[3][0]*B[0][1] + A[3][1]*B[1][1] + A[3][2]*B[2][1] + A[3][3]*B[3][1] + C[3][1]; + X[3][2] = A[3][0]*B[0][2] + A[3][1]*B[1][2] + A[3][2]*B[2][2] + A[3][3]*B[3][2] + C[3][2]; + X[3][3] = A[3][0]*B[0][3] + A[3][1]*B[1][3] + A[3][2]*B[2][3] + A[3][3]*B[3][3] + C[3][3];''' + ) + +def test_1xN_vecs(): + gl = glsl_code + for i in range(1,10): + A = Matrix(range(i)) + assert gl(A.transpose()) == gl(A) + assert gl(A,mat_transpose=True) == gl(A) + if i > 1: + if i <= 4: + assert gl(A) == 'vec%s(%s)' % (i,', '.join(str(s) for s in range(i))) + else: + assert gl(A) == 'float[%s](%s)' % (i,', '.join(str(s) for s in range(i))) + +def test_MxN_mats(): + generatedAssertions='def test_misc_mats():\n' + for i in range(1,6): + for j in range(1,6): + A = Matrix([[x + y*j for x in range(j)] for y in range(i)]) + gl = glsl_code(A) + glTransposed = glsl_code(A,mat_transpose=True) + generatedAssertions+=' mat = '+StrPrinter()._print(A)+'\n\n' + generatedAssertions+=' gl = \'\'\''+gl+'\'\'\'\n' + generatedAssertions+=' glTransposed = \'\'\''+glTransposed+'\'\'\'\n\n' + generatedAssertions+=' assert glsl_code(mat) == gl\n' + generatedAssertions+=' assert glsl_code(mat,mat_transpose=True) == glTransposed\n' + if i == 1 and j == 1: + assert gl == '0' + elif i <= 4 and j <= 4 and i>1 and j>1: + assert gl.startswith('mat%s' % j) + assert glTransposed.startswith('mat%s' % i) + elif i == 1 and j <= 4: + assert gl.startswith('vec') + elif j == 1 and i <= 4: + assert gl.startswith('vec') + elif i == 1: + assert gl.startswith('float[%s]('% j*i) + assert glTransposed.startswith('float[%s]('% j*i) + elif j == 1: + assert gl.startswith('float[%s]('% i*j) + assert glTransposed.startswith('float[%s]('% i*j) + else: + assert gl.startswith('float[%s](' % (i*j)) + assert glTransposed.startswith('float[%s](' % (i*j)) + glNested = glsl_code(A,mat_nested=True) + glNestedTransposed = glsl_code(A,mat_transpose=True,mat_nested=True) + assert glNested.startswith('float[%s][%s]' % (i,j)) + assert glNestedTransposed.startswith('float[%s][%s]' % (j,i)) + generatedAssertions+=' glNested = \'\'\''+glNested+'\'\'\'\n' + generatedAssertions+=' glNestedTransposed = \'\'\''+glNestedTransposed+'\'\'\'\n\n' + generatedAssertions+=' assert glsl_code(mat,mat_nested=True) == glNested\n' + generatedAssertions+=' assert glsl_code(mat,mat_nested=True,mat_transpose=True) == glNestedTransposed\n\n' + generateAssertions = False # set this to true to write bake these generated tests to a file + if generateAssertions: + gen = open('test_glsl_generated_matrices.py','w') + gen.write(generatedAssertions) + gen.close() + + +# these assertions were generated from the previous function +# glsl has complicated rules and this makes it easier to look over all the cases +def test_misc_mats(): + + mat = Matrix([[0]]) + + gl = '''0''' + glTransposed = '''0''' + + assert glsl_code(mat) == gl + assert glsl_code(mat,mat_transpose=True) == glTransposed + + mat = Matrix([[0, 1]]) + + gl = '''vec2(0, 1)''' + glTransposed = '''vec2(0, 1)''' + + assert glsl_code(mat) == gl + assert glsl_code(mat,mat_transpose=True) == glTransposed + + mat = Matrix([[0, 1, 2]]) + + gl = '''vec3(0, 1, 2)''' + glTransposed = '''vec3(0, 1, 2)''' + + assert glsl_code(mat) == gl + assert glsl_code(mat,mat_transpose=True) == glTransposed + + mat = Matrix([[0, 1, 2, 3]]) + + gl = '''vec4(0, 1, 2, 3)''' + glTransposed = '''vec4(0, 1, 2, 3)''' + + assert glsl_code(mat) == gl + assert glsl_code(mat,mat_transpose=True) == glTransposed + + mat = Matrix([[0, 1, 2, 3, 4]]) + + gl = '''float[5](0, 1, 2, 3, 4)''' + glTransposed = '''float[5](0, 1, 2, 3, 4)''' + + assert glsl_code(mat) == gl + assert glsl_code(mat,mat_transpose=True) == glTransposed + + mat = Matrix([ +[0], +[1]]) + + gl = '''vec2(0, 1)''' + glTransposed = '''vec2(0, 1)''' + + assert glsl_code(mat) == gl + assert glsl_code(mat,mat_transpose=True) == glTransposed + + mat = Matrix([ +[0, 1], +[2, 3]]) + + gl = '''mat2(0, 1, 2, 3)''' + glTransposed = '''mat2(0, 2, 1, 3)''' + + assert glsl_code(mat) == gl + assert glsl_code(mat,mat_transpose=True) == glTransposed + + mat = Matrix([ +[0, 1, 2], +[3, 4, 5]]) + + gl = '''mat3x2(0, 1, 2, 3, 4, 5)''' + glTransposed = '''mat2x3(0, 3, 1, 4, 2, 5)''' + + assert glsl_code(mat) == gl + assert glsl_code(mat,mat_transpose=True) == glTransposed + + mat = Matrix([ +[0, 1, 2, 3], +[4, 5, 6, 7]]) + + gl = '''mat4x2(0, 1, 2, 3, 4, 5, 6, 7)''' + glTransposed = '''mat2x4(0, 4, 1, 5, 2, 6, 3, 7)''' + + assert glsl_code(mat) == gl + assert glsl_code(mat,mat_transpose=True) == glTransposed + + mat = Matrix([ +[0, 1, 2, 3, 4], +[5, 6, 7, 8, 9]]) + + gl = '''float[10]( + 0, 1, 2, 3, 4, + 5, 6, 7, 8, 9 +) /* a 2x5 matrix */''' + glTransposed = '''float[10]( + 0, 5, + 1, 6, + 2, 7, + 3, 8, + 4, 9 +) /* a 5x2 matrix */''' + + assert glsl_code(mat) == gl + assert glsl_code(mat,mat_transpose=True) == glTransposed + glNested = '''float[2][5]( + float[](0, 1, 2, 3, 4), + float[](5, 6, 7, 8, 9) +)''' + glNestedTransposed = '''float[5][2]( + float[](0, 5), + float[](1, 6), + float[](2, 7), + float[](3, 8), + float[](4, 9) +)''' + + assert glsl_code(mat,mat_nested=True) == glNested + assert glsl_code(mat,mat_nested=True,mat_transpose=True) == glNestedTransposed + + mat = Matrix([ +[0], +[1], +[2]]) + + gl = '''vec3(0, 1, 2)''' + glTransposed = '''vec3(0, 1, 2)''' + + assert glsl_code(mat) == gl + assert glsl_code(mat,mat_transpose=True) == glTransposed + + mat = Matrix([ +[0, 1], +[2, 3], +[4, 5]]) + + gl = '''mat2x3(0, 1, 2, 3, 4, 5)''' + glTransposed = '''mat3x2(0, 2, 4, 1, 3, 5)''' + + assert glsl_code(mat) == gl + assert glsl_code(mat,mat_transpose=True) == glTransposed + + mat = Matrix([ +[0, 1, 2], +[3, 4, 5], +[6, 7, 8]]) + + gl = '''mat3(0, 1, 2, 3, 4, 5, 6, 7, 8)''' + glTransposed = '''mat3(0, 3, 6, 1, 4, 7, 2, 5, 8)''' + + assert glsl_code(mat) == gl + assert glsl_code(mat,mat_transpose=True) == glTransposed + + mat = Matrix([ +[0, 1, 2, 3], +[4, 5, 6, 7], +[8, 9, 10, 11]]) + + gl = '''mat4x3(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11)''' + glTransposed = '''mat3x4(0, 4, 8, 1, 5, 9, 2, 6, 10, 3, 7, 11)''' + + assert glsl_code(mat) == gl + assert glsl_code(mat,mat_transpose=True) == glTransposed + + mat = Matrix([ +[ 0, 1, 2, 3, 4], +[ 5, 6, 7, 8, 9], +[10, 11, 12, 13, 14]]) + + gl = '''float[15]( + 0, 1, 2, 3, 4, + 5, 6, 7, 8, 9, + 10, 11, 12, 13, 14 +) /* a 3x5 matrix */''' + glTransposed = '''float[15]( + 0, 5, 10, + 1, 6, 11, + 2, 7, 12, + 3, 8, 13, + 4, 9, 14 +) /* a 5x3 matrix */''' + + assert glsl_code(mat) == gl + assert glsl_code(mat,mat_transpose=True) == glTransposed + glNested = '''float[3][5]( + float[]( 0, 1, 2, 3, 4), + float[]( 5, 6, 7, 8, 9), + float[](10, 11, 12, 13, 14) +)''' + glNestedTransposed = '''float[5][3]( + float[](0, 5, 10), + float[](1, 6, 11), + float[](2, 7, 12), + float[](3, 8, 13), + float[](4, 9, 14) +)''' + + assert glsl_code(mat,mat_nested=True) == glNested + assert glsl_code(mat,mat_nested=True,mat_transpose=True) == glNestedTransposed + + mat = Matrix([ +[0], +[1], +[2], +[3]]) + + gl = '''vec4(0, 1, 2, 3)''' + glTransposed = '''vec4(0, 1, 2, 3)''' + + assert glsl_code(mat) == gl + assert glsl_code(mat,mat_transpose=True) == glTransposed + + mat = Matrix([ +[0, 1], +[2, 3], +[4, 5], +[6, 7]]) + + gl = '''mat2x4(0, 1, 2, 3, 4, 5, 6, 7)''' + glTransposed = '''mat4x2(0, 2, 4, 6, 1, 3, 5, 7)''' + + assert glsl_code(mat) == gl + assert glsl_code(mat,mat_transpose=True) == glTransposed + + mat = Matrix([ +[0, 1, 2], +[3, 4, 5], +[6, 7, 8], +[9, 10, 11]]) + + gl = '''mat3x4(0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11)''' + glTransposed = '''mat4x3(0, 3, 6, 9, 1, 4, 7, 10, 2, 5, 8, 11)''' + + assert glsl_code(mat) == gl + assert glsl_code(mat,mat_transpose=True) == glTransposed + + mat = Matrix([ +[ 0, 1, 2, 3], +[ 4, 5, 6, 7], +[ 8, 9, 10, 11], +[12, 13, 14, 15]]) + + gl = '''mat4( 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15)''' + glTransposed = '''mat4(0, 4, 8, 12, 1, 5, 9, 13, 2, 6, 10, 14, 3, 7, 11, 15)''' + + assert glsl_code(mat) == gl + assert glsl_code(mat,mat_transpose=True) == glTransposed + + mat = Matrix([ +[ 0, 1, 2, 3, 4], +[ 5, 6, 7, 8, 9], +[10, 11, 12, 13, 14], +[15, 16, 17, 18, 19]]) + + gl = '''float[20]( + 0, 1, 2, 3, 4, + 5, 6, 7, 8, 9, + 10, 11, 12, 13, 14, + 15, 16, 17, 18, 19 +) /* a 4x5 matrix */''' + glTransposed = '''float[20]( + 0, 5, 10, 15, + 1, 6, 11, 16, + 2, 7, 12, 17, + 3, 8, 13, 18, + 4, 9, 14, 19 +) /* a 5x4 matrix */''' + + assert glsl_code(mat) == gl + assert glsl_code(mat,mat_transpose=True) == glTransposed + glNested = '''float[4][5]( + float[]( 0, 1, 2, 3, 4), + float[]( 5, 6, 7, 8, 9), + float[](10, 11, 12, 13, 14), + float[](15, 16, 17, 18, 19) +)''' + glNestedTransposed = '''float[5][4]( + float[](0, 5, 10, 15), + float[](1, 6, 11, 16), + float[](2, 7, 12, 17), + float[](3, 8, 13, 18), + float[](4, 9, 14, 19) +)''' + + assert glsl_code(mat,mat_nested=True) == glNested + assert glsl_code(mat,mat_nested=True,mat_transpose=True) == glNestedTransposed + + mat = Matrix([ +[0], +[1], +[2], +[3], +[4]]) + + gl = '''float[5](0, 1, 2, 3, 4)''' + glTransposed = '''float[5](0, 1, 2, 3, 4)''' + + assert glsl_code(mat) == gl + assert glsl_code(mat,mat_transpose=True) == glTransposed + + mat = Matrix([ +[0, 1], +[2, 3], +[4, 5], +[6, 7], +[8, 9]]) + + gl = '''float[10]( + 0, 1, + 2, 3, + 4, 5, + 6, 7, + 8, 9 +) /* a 5x2 matrix */''' + glTransposed = '''float[10]( + 0, 2, 4, 6, 8, + 1, 3, 5, 7, 9 +) /* a 2x5 matrix */''' + + assert glsl_code(mat) == gl + assert glsl_code(mat,mat_transpose=True) == glTransposed + glNested = '''float[5][2]( + float[](0, 1), + float[](2, 3), + float[](4, 5), + float[](6, 7), + float[](8, 9) +)''' + glNestedTransposed = '''float[2][5]( + float[](0, 2, 4, 6, 8), + float[](1, 3, 5, 7, 9) +)''' + + assert glsl_code(mat,mat_nested=True) == glNested + assert glsl_code(mat,mat_nested=True,mat_transpose=True) == glNestedTransposed + + mat = Matrix([ +[ 0, 1, 2], +[ 3, 4, 5], +[ 6, 7, 8], +[ 9, 10, 11], +[12, 13, 14]]) + + gl = '''float[15]( + 0, 1, 2, + 3, 4, 5, + 6, 7, 8, + 9, 10, 11, + 12, 13, 14 +) /* a 5x3 matrix */''' + glTransposed = '''float[15]( + 0, 3, 6, 9, 12, + 1, 4, 7, 10, 13, + 2, 5, 8, 11, 14 +) /* a 3x5 matrix */''' + + assert glsl_code(mat) == gl + assert glsl_code(mat,mat_transpose=True) == glTransposed + glNested = '''float[5][3]( + float[]( 0, 1, 2), + float[]( 3, 4, 5), + float[]( 6, 7, 8), + float[]( 9, 10, 11), + float[](12, 13, 14) +)''' + glNestedTransposed = '''float[3][5]( + float[](0, 3, 6, 9, 12), + float[](1, 4, 7, 10, 13), + float[](2, 5, 8, 11, 14) +)''' + + assert glsl_code(mat,mat_nested=True) == glNested + assert glsl_code(mat,mat_nested=True,mat_transpose=True) == glNestedTransposed + + mat = Matrix([ +[ 0, 1, 2, 3], +[ 4, 5, 6, 7], +[ 8, 9, 10, 11], +[12, 13, 14, 15], +[16, 17, 18, 19]]) + + gl = '''float[20]( + 0, 1, 2, 3, + 4, 5, 6, 7, + 8, 9, 10, 11, + 12, 13, 14, 15, + 16, 17, 18, 19 +) /* a 5x4 matrix */''' + glTransposed = '''float[20]( + 0, 4, 8, 12, 16, + 1, 5, 9, 13, 17, + 2, 6, 10, 14, 18, + 3, 7, 11, 15, 19 +) /* a 4x5 matrix */''' + + assert glsl_code(mat) == gl + assert glsl_code(mat,mat_transpose=True) == glTransposed + glNested = '''float[5][4]( + float[]( 0, 1, 2, 3), + float[]( 4, 5, 6, 7), + float[]( 8, 9, 10, 11), + float[](12, 13, 14, 15), + float[](16, 17, 18, 19) +)''' + glNestedTransposed = '''float[4][5]( + float[](0, 4, 8, 12, 16), + float[](1, 5, 9, 13, 17), + float[](2, 6, 10, 14, 18), + float[](3, 7, 11, 15, 19) +)''' + + assert glsl_code(mat,mat_nested=True) == glNested + assert glsl_code(mat,mat_nested=True,mat_transpose=True) == glNestedTransposed + + mat = Matrix([ +[ 0, 1, 2, 3, 4], +[ 5, 6, 7, 8, 9], +[10, 11, 12, 13, 14], +[15, 16, 17, 18, 19], +[20, 21, 22, 23, 24]]) + + gl = '''float[25]( + 0, 1, 2, 3, 4, + 5, 6, 7, 8, 9, + 10, 11, 12, 13, 14, + 15, 16, 17, 18, 19, + 20, 21, 22, 23, 24 +) /* a 5x5 matrix */''' + glTransposed = '''float[25]( + 0, 5, 10, 15, 20, + 1, 6, 11, 16, 21, + 2, 7, 12, 17, 22, + 3, 8, 13, 18, 23, + 4, 9, 14, 19, 24 +) /* a 5x5 matrix */''' + + assert glsl_code(mat) == gl + assert glsl_code(mat,mat_transpose=True) == glTransposed + glNested = '''float[5][5]( + float[]( 0, 1, 2, 3, 4), + float[]( 5, 6, 7, 8, 9), + float[](10, 11, 12, 13, 14), + float[](15, 16, 17, 18, 19), + float[](20, 21, 22, 23, 24) +)''' + glNestedTransposed = '''float[5][5]( + float[](0, 5, 10, 15, 20), + float[](1, 6, 11, 16, 21), + float[](2, 7, 12, 17, 22), + float[](3, 8, 13, 18, 23), + float[](4, 9, 14, 19, 24) +)''' + + assert glsl_code(mat,mat_nested=True) == glNested + assert glsl_code(mat,mat_nested=True,mat_transpose=True) == glNestedTransposed diff --git a/lib/python3.10/site-packages/sympy/printing/tests/test_gtk.py b/lib/python3.10/site-packages/sympy/printing/tests/test_gtk.py new file mode 100644 index 0000000000000000000000000000000000000000..5a595ab04d3a29d23e06ec12207bf917392aebce --- /dev/null +++ b/lib/python3.10/site-packages/sympy/printing/tests/test_gtk.py @@ -0,0 +1,18 @@ +from sympy.functions.elementary.trigonometric import sin +from sympy.printing.gtk import print_gtk +from sympy.testing.pytest import XFAIL, raises + +# this test fails if python-lxml isn't installed. We don't want to depend on +# anything with SymPy + + +@XFAIL +def test_1(): + from sympy.abc import x + print_gtk(x**2, start_viewer=False) + print_gtk(x**2 + sin(x)/4, start_viewer=False) + + +def test_settings(): + from sympy.abc import x + raises(TypeError, lambda: print_gtk(x, method="garbage")) diff --git a/lib/python3.10/site-packages/sympy/printing/tests/test_jax.py b/lib/python3.10/site-packages/sympy/printing/tests/test_jax.py new file mode 100644 index 0000000000000000000000000000000000000000..4a58b0bada1d93ce0ea573d81502448c322751c4 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/printing/tests/test_jax.py @@ -0,0 +1,370 @@ +from sympy.concrete.summations import Sum +from sympy.core.mod import Mod +from sympy.core.relational import (Equality, Unequality) +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.piecewise import Piecewise +from sympy.matrices.expressions.blockmatrix import BlockMatrix +from sympy.matrices.expressions.matexpr import MatrixSymbol +from sympy.matrices.expressions.special import Identity +from sympy.utilities.lambdify import lambdify + +from sympy.abc import x, i, j, a, b, c, d +from sympy.core import Function, Pow, Symbol +from sympy.codegen.matrix_nodes import MatrixSolve +from sympy.codegen.numpy_nodes import logaddexp, logaddexp2 +from sympy.codegen.cfunctions import log1p, expm1, hypot, log10, exp2, log2, Sqrt +from sympy.tensor.array import Array +from sympy.tensor.array.expressions.array_expressions import ArrayTensorProduct, ArrayAdd, \ + PermuteDims, ArrayDiagonal +from sympy.printing.numpy import JaxPrinter, _jax_known_constants, _jax_known_functions +from sympy.tensor.array.expressions.from_matrix_to_array import convert_matrix_to_array + +from sympy.testing.pytest import skip, raises +from sympy.external import import_module + +# Unlike NumPy which will aggressively promote operands to double precision, +# jax always uses single precision. Double precision in jax can be +# configured before the call to `import jax`, however this must be explicitly +# configured and is not fully supported. Thus, the tests here have been modified +# from the tests in test_numpy.py, only in the fact that they assert lambdify +# function accuracy to only single precision accuracy. +# https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision + +jax = import_module('jax') + +if jax: + deafult_float_info = jax.numpy.finfo(jax.numpy.array([]).dtype) + JAX_DEFAULT_EPSILON = deafult_float_info.eps + + +def test_jax_piecewise_regression(): + """ + NumPyPrinter needs to print Piecewise()'s choicelist as a list to avoid + breaking compatibility with numpy 1.8. This is not necessary in numpy 1.9+. + See gh-9747 and gh-9749 for details. + """ + printer = JaxPrinter() + p = Piecewise((1, x < 0), (0, True)) + assert printer.doprint(p) == \ + 'jax.numpy.select([jax.numpy.less(x, 0),True], [1,0], default=jax.numpy.nan)' + assert printer.module_imports == {'jax.numpy': {'select', 'less', 'nan'}} + + +def test_jax_logaddexp(): + lae = logaddexp(a, b) + assert JaxPrinter().doprint(lae) == 'jax.numpy.logaddexp(a, b)' + lae2 = logaddexp2(a, b) + assert JaxPrinter().doprint(lae2) == 'jax.numpy.logaddexp2(a, b)' + + +def test_jax_sum(): + if not jax: + skip("JAX not installed") + + s = Sum(x ** i, (i, a, b)) + f = lambdify((a, b, x), s, 'jax') + + a_, b_ = 0, 10 + x_ = jax.numpy.linspace(-1, +1, 10) + assert jax.numpy.allclose(f(a_, b_, x_), sum(x_ ** i_ for i_ in range(a_, b_ + 1))) + + s = Sum(i * x, (i, a, b)) + f = lambdify((a, b, x), s, 'jax') + + a_, b_ = 0, 10 + x_ = jax.numpy.linspace(-1, +1, 10) + assert jax.numpy.allclose(f(a_, b_, x_), sum(i_ * x_ for i_ in range(a_, b_ + 1))) + + +def test_jax_multiple_sums(): + if not jax: + skip("JAX not installed") + + s = Sum((x + j) * i, (i, a, b), (j, c, d)) + f = lambdify((a, b, c, d, x), s, 'jax') + + a_, b_ = 0, 10 + c_, d_ = 11, 21 + x_ = jax.numpy.linspace(-1, +1, 10) + assert jax.numpy.allclose(f(a_, b_, c_, d_, x_), + sum((x_ + j_) * i_ for i_ in range(a_, b_ + 1) for j_ in range(c_, d_ + 1))) + + +def test_jax_codegen_einsum(): + if not jax: + skip("JAX not installed") + + M = MatrixSymbol("M", 2, 2) + N = MatrixSymbol("N", 2, 2) + + cg = convert_matrix_to_array(M * N) + f = lambdify((M, N), cg, 'jax') + + ma = jax.numpy.array([[1, 2], [3, 4]]) + mb = jax.numpy.array([[1,-2], [-1, 3]]) + assert (f(ma, mb) == jax.numpy.matmul(ma, mb)).all() + + +def test_jax_codegen_extra(): + if not jax: + skip("JAX not installed") + + M = MatrixSymbol("M", 2, 2) + N = MatrixSymbol("N", 2, 2) + P = MatrixSymbol("P", 2, 2) + Q = MatrixSymbol("Q", 2, 2) + ma = jax.numpy.array([[1, 2], [3, 4]]) + mb = jax.numpy.array([[1,-2], [-1, 3]]) + mc = jax.numpy.array([[2, 0], [1, 2]]) + md = jax.numpy.array([[1,-1], [4, 7]]) + + cg = ArrayTensorProduct(M, N) + f = lambdify((M, N), cg, 'jax') + assert (f(ma, mb) == jax.numpy.einsum(ma, [0, 1], mb, [2, 3])).all() + + cg = ArrayAdd(M, N) + f = lambdify((M, N), cg, 'jax') + assert (f(ma, mb) == ma+mb).all() + + cg = ArrayAdd(M, N, P) + f = lambdify((M, N, P), cg, 'jax') + assert (f(ma, mb, mc) == ma+mb+mc).all() + + cg = ArrayAdd(M, N, P, Q) + f = lambdify((M, N, P, Q), cg, 'jax') + assert (f(ma, mb, mc, md) == ma+mb+mc+md).all() + + cg = PermuteDims(M, [1, 0]) + f = lambdify((M,), cg, 'jax') + assert (f(ma) == ma.T).all() + + cg = PermuteDims(ArrayTensorProduct(M, N), [1, 2, 3, 0]) + f = lambdify((M, N), cg, 'jax') + assert (f(ma, mb) == jax.numpy.transpose(jax.numpy.einsum(ma, [0, 1], mb, [2, 3]), (1, 2, 3, 0))).all() + + cg = ArrayDiagonal(ArrayTensorProduct(M, N), (1, 2)) + f = lambdify((M, N), cg, 'jax') + assert (f(ma, mb) == jax.numpy.diagonal(jax.numpy.einsum(ma, [0, 1], mb, [2, 3]), axis1=1, axis2=2)).all() + + +def test_jax_relational(): + if not jax: + skip("JAX not installed") + + e = Equality(x, 1) + + f = lambdify((x,), e, 'jax') + x_ = jax.numpy.array([0, 1, 2]) + assert jax.numpy.array_equal(f(x_), [False, True, False]) + + e = Unequality(x, 1) + + f = lambdify((x,), e, 'jax') + x_ = jax.numpy.array([0, 1, 2]) + assert jax.numpy.array_equal(f(x_), [True, False, True]) + + e = (x < 1) + + f = lambdify((x,), e, 'jax') + x_ = jax.numpy.array([0, 1, 2]) + assert jax.numpy.array_equal(f(x_), [True, False, False]) + + e = (x <= 1) + + f = lambdify((x,), e, 'jax') + x_ = jax.numpy.array([0, 1, 2]) + assert jax.numpy.array_equal(f(x_), [True, True, False]) + + e = (x > 1) + + f = lambdify((x,), e, 'jax') + x_ = jax.numpy.array([0, 1, 2]) + assert jax.numpy.array_equal(f(x_), [False, False, True]) + + e = (x >= 1) + + f = lambdify((x,), e, 'jax') + x_ = jax.numpy.array([0, 1, 2]) + assert jax.numpy.array_equal(f(x_), [False, True, True]) + + # Multi-condition expressions + e = (x >= 1) & (x < 2) + f = lambdify((x,), e, 'jax') + x_ = jax.numpy.array([0, 1, 2]) + assert jax.numpy.array_equal(f(x_), [False, True, False]) + + e = (x >= 1) | (x < 2) + f = lambdify((x,), e, 'jax') + x_ = jax.numpy.array([0, 1, 2]) + assert jax.numpy.array_equal(f(x_), [True, True, True]) + +def test_jax_mod(): + if not jax: + skip("JAX not installed") + + e = Mod(a, b) + f = lambdify((a, b), e, 'jax') + + a_ = jax.numpy.array([0, 1, 2, 3]) + b_ = 2 + assert jax.numpy.array_equal(f(a_, b_), [0, 1, 0, 1]) + + a_ = jax.numpy.array([0, 1, 2, 3]) + b_ = jax.numpy.array([2, 2, 2, 2]) + assert jax.numpy.array_equal(f(a_, b_), [0, 1, 0, 1]) + + a_ = jax.numpy.array([2, 3, 4, 5]) + b_ = jax.numpy.array([2, 3, 4, 5]) + assert jax.numpy.array_equal(f(a_, b_), [0, 0, 0, 0]) + + +def test_jax_pow(): + if not jax: + skip('JAX not installed') + + expr = Pow(2, -1, evaluate=False) + f = lambdify([], expr, 'jax') + assert f() == 0.5 + + +def test_jax_expm1(): + if not jax: + skip("JAX not installed") + + f = lambdify((a,), expm1(a), 'jax') + assert abs(f(1e-10) - 1e-10 - 5e-21) <= 1e-10 * JAX_DEFAULT_EPSILON + + +def test_jax_log1p(): + if not jax: + skip("JAX not installed") + + f = lambdify((a,), log1p(a), 'jax') + assert abs(f(1e-99) - 1e-99) <= 1e-99 * JAX_DEFAULT_EPSILON + +def test_jax_hypot(): + if not jax: + skip("JAX not installed") + assert abs(lambdify((a, b), hypot(a, b), 'jax')(3, 4) - 5) <= JAX_DEFAULT_EPSILON + +def test_jax_log10(): + if not jax: + skip("JAX not installed") + + assert abs(lambdify((a,), log10(a), 'jax')(100) - 2) <= JAX_DEFAULT_EPSILON + + +def test_jax_exp2(): + if not jax: + skip("JAX not installed") + assert abs(lambdify((a,), exp2(a), 'jax')(5) - 32) <= JAX_DEFAULT_EPSILON + + +def test_jax_log2(): + if not jax: + skip("JAX not installed") + assert abs(lambdify((a,), log2(a), 'jax')(256) - 8) <= JAX_DEFAULT_EPSILON + + +def test_jax_Sqrt(): + if not jax: + skip("JAX not installed") + assert abs(lambdify((a,), Sqrt(a), 'jax')(4) - 2) <= JAX_DEFAULT_EPSILON + + +def test_jax_sqrt(): + if not jax: + skip("JAX not installed") + assert abs(lambdify((a,), sqrt(a), 'jax')(4) - 2) <= JAX_DEFAULT_EPSILON + + +def test_jax_matsolve(): + if not jax: + skip("JAX not installed") + + M = MatrixSymbol("M", 3, 3) + x = MatrixSymbol("x", 3, 1) + + expr = M**(-1) * x + x + matsolve_expr = MatrixSolve(M, x) + x + + f = lambdify((M, x), expr, 'jax') + f_matsolve = lambdify((M, x), matsolve_expr, 'jax') + + m0 = jax.numpy.array([[1, 2, 3], [3, 2, 5], [5, 6, 7]]) + assert jax.numpy.linalg.matrix_rank(m0) == 3 + + x0 = jax.numpy.array([3, 4, 5]) + + assert jax.numpy.allclose(f_matsolve(m0, x0), f(m0, x0)) + + +def test_16857(): + if not jax: + skip("JAX not installed") + + a_1 = MatrixSymbol('a_1', 10, 3) + a_2 = MatrixSymbol('a_2', 10, 3) + a_3 = MatrixSymbol('a_3', 10, 3) + a_4 = MatrixSymbol('a_4', 10, 3) + A = BlockMatrix([[a_1, a_2], [a_3, a_4]]) + assert A.shape == (20, 6) + + printer = JaxPrinter() + assert printer.doprint(A) == 'jax.numpy.block([[a_1, a_2], [a_3, a_4]])' + + +def test_issue_17006(): + if not jax: + skip("JAX not installed") + + M = MatrixSymbol("M", 2, 2) + + f = lambdify(M, M + Identity(2), 'jax') + ma = jax.numpy.array([[1, 2], [3, 4]]) + mr = jax.numpy.array([[2, 2], [3, 5]]) + + assert (f(ma) == mr).all() + + from sympy.core.symbol import symbols + n = symbols('n', integer=True) + N = MatrixSymbol("M", n, n) + raises(NotImplementedError, lambda: lambdify(N, N + Identity(n), 'jax')) + + +def test_jax_array(): + assert JaxPrinter().doprint(Array(((1, 2), (3, 5)))) == 'jax.numpy.array([[1, 2], [3, 5]])' + assert JaxPrinter().doprint(Array((1, 2))) == 'jax.numpy.array((1, 2))' + + +def test_jax_known_funcs_consts(): + assert _jax_known_constants['NaN'] == 'jax.numpy.nan' + assert _jax_known_constants['EulerGamma'] == 'jax.numpy.euler_gamma' + + assert _jax_known_functions['acos'] == 'jax.numpy.arccos' + assert _jax_known_functions['log'] == 'jax.numpy.log' + + +def test_jax_print_methods(): + prntr = JaxPrinter() + assert hasattr(prntr, '_print_acos') + assert hasattr(prntr, '_print_log') + + +def test_jax_printmethod(): + printer = JaxPrinter() + assert hasattr(printer, 'printmethod') + assert printer.printmethod == '_jaxcode' + + +def test_jax_custom_print_method(): + + class expm1(Function): + + def _jaxcode(self, printer): + x, = self.args + function = f'expm1({printer._print(x)})' + return printer._module_format(printer._module + '.' + function) + + printer = JaxPrinter() + assert printer.doprint(expm1(Symbol('x'))) == 'jax.numpy.expm1(x)' diff --git a/lib/python3.10/site-packages/sympy/printing/tests/test_jscode.py b/lib/python3.10/site-packages/sympy/printing/tests/test_jscode.py new file mode 100644 index 0000000000000000000000000000000000000000..9199a8e0d62e87f2e964cb1712726a21c894fd20 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/printing/tests/test_jscode.py @@ -0,0 +1,396 @@ +from sympy.core import (pi, oo, symbols, Rational, Integer, GoldenRatio, + EulerGamma, Catalan, Lambda, Dummy, S, Eq, Ne, Le, + Lt, Gt, Ge, Mod) +from sympy.functions import (Piecewise, sin, cos, Abs, exp, ceiling, sqrt, + sinh, cosh, tanh, asin, acos, acosh, Max, Min) +from sympy.testing.pytest import raises +from sympy.printing.jscode import JavascriptCodePrinter +from sympy.utilities.lambdify import implemented_function +from sympy.tensor import IndexedBase, Idx +from sympy.matrices import Matrix, MatrixSymbol + +from sympy.printing.jscode import jscode + +x, y, z = symbols('x,y,z') + + +def test_printmethod(): + assert jscode(Abs(x)) == "Math.abs(x)" + + +def test_jscode_sqrt(): + assert jscode(sqrt(x)) == "Math.sqrt(x)" + assert jscode(x**0.5) == "Math.sqrt(x)" + assert jscode(x**(S.One/3)) == "Math.cbrt(x)" + + +def test_jscode_Pow(): + g = implemented_function('g', Lambda(x, 2*x)) + assert jscode(x**3) == "Math.pow(x, 3)" + assert jscode(x**(y**3)) == "Math.pow(x, Math.pow(y, 3))" + assert jscode(1/(g(x)*3.5)**(x - y**x)/(x**2 + y)) == \ + "Math.pow(3.5*2*x, -x + Math.pow(y, x))/(Math.pow(x, 2) + y)" + assert jscode(x**-1.0) == '1/x' + + +def test_jscode_constants_mathh(): + assert jscode(exp(1)) == "Math.E" + assert jscode(pi) == "Math.PI" + assert jscode(oo) == "Number.POSITIVE_INFINITY" + assert jscode(-oo) == "Number.NEGATIVE_INFINITY" + + +def test_jscode_constants_other(): + assert jscode( + 2*GoldenRatio) == "var GoldenRatio = %s;\n2*GoldenRatio" % GoldenRatio.evalf(17) + assert jscode(2*Catalan) == "var Catalan = %s;\n2*Catalan" % Catalan.evalf(17) + assert jscode( + 2*EulerGamma) == "var EulerGamma = %s;\n2*EulerGamma" % EulerGamma.evalf(17) + + +def test_jscode_Rational(): + assert jscode(Rational(3, 7)) == "3/7" + assert jscode(Rational(18, 9)) == "2" + assert jscode(Rational(3, -7)) == "-3/7" + assert jscode(Rational(-3, -7)) == "3/7" + + +def test_Relational(): + assert jscode(Eq(x, y)) == "x == y" + assert jscode(Ne(x, y)) == "x != y" + assert jscode(Le(x, y)) == "x <= y" + assert jscode(Lt(x, y)) == "x < y" + assert jscode(Gt(x, y)) == "x > y" + assert jscode(Ge(x, y)) == "x >= y" + + +def test_Mod(): + assert jscode(Mod(x, y)) == '((x % y) + y) % y' + assert jscode(Mod(x, x + y)) == '((x % (x + y)) + (x + y)) % (x + y)' + p1, p2 = symbols('p1 p2', positive=True) + assert jscode(Mod(p1, p2)) == 'p1 % p2' + assert jscode(Mod(p1, p2 + 3)) == 'p1 % (p2 + 3)' + assert jscode(Mod(-3, -7, evaluate=False)) == '(-3) % (-7)' + assert jscode(-Mod(p1, p2)) == '-(p1 % p2)' + assert jscode(x*Mod(p1, p2)) == 'x*(p1 % p2)' + + +def test_jscode_Integer(): + assert jscode(Integer(67)) == "67" + assert jscode(Integer(-1)) == "-1" + + +def test_jscode_functions(): + assert jscode(sin(x) ** cos(x)) == "Math.pow(Math.sin(x), Math.cos(x))" + assert jscode(sinh(x) * cosh(x)) == "Math.sinh(x)*Math.cosh(x)" + assert jscode(Max(x, y) + Min(x, y)) == "Math.max(x, y) + Math.min(x, y)" + assert jscode(tanh(x)*acosh(y)) == "Math.tanh(x)*Math.acosh(y)" + assert jscode(asin(x)-acos(y)) == "-Math.acos(y) + Math.asin(x)" + + +def test_jscode_inline_function(): + x = symbols('x') + g = implemented_function('g', Lambda(x, 2*x)) + assert jscode(g(x)) == "2*x" + g = implemented_function('g', Lambda(x, 2*x/Catalan)) + assert jscode(g(x)) == "var Catalan = %s;\n2*x/Catalan" % Catalan.evalf(17) + A = IndexedBase('A') + i = Idx('i', symbols('n', integer=True)) + g = implemented_function('g', Lambda(x, x*(1 + x)*(2 + x))) + assert jscode(g(A[i]), assign_to=A[i]) == ( + "for (var i=0; i 1), (sin(x), x > 0)) + raises(ValueError, lambda: jscode(expr)) + + +def test_jscode_Piecewise_deep(): + p = jscode(2*Piecewise((x, x < 1), (x**2, True))) + s = \ +"""\ +2*((x < 1) ? ( + x +) +: ( + Math.pow(x, 2) +))\ +""" + assert p == s + + +def test_jscode_settings(): + raises(TypeError, lambda: jscode(sin(x), method="garbage")) + + +def test_jscode_Indexed(): + n, m, o = symbols('n m o', integer=True) + i, j, k = Idx('i', n), Idx('j', m), Idx('k', o) + p = JavascriptCodePrinter() + p._not_c = set() + + x = IndexedBase('x')[j] + assert p._print_Indexed(x) == 'x[j]' + A = IndexedBase('A')[i, j] + assert p._print_Indexed(A) == 'A[%s]' % (m*i+j) + B = IndexedBase('B')[i, j, k] + assert p._print_Indexed(B) == 'B[%s]' % (i*o*m+j*o+k) + + assert p._not_c == set() + + +def test_jscode_loops_matrix_vector(): + n, m = symbols('n m', integer=True) + A = IndexedBase('A') + x = IndexedBase('x') + y = IndexedBase('y') + i = Idx('i', m) + j = Idx('j', n) + + s = ( + 'for (var i=0; i0), (y, True)), sin(z)]) + A = MatrixSymbol('A', 3, 1) + assert jscode(mat, A) == ( + "A[0] = x*y;\n" + "if (y > 0) {\n" + " A[1] = x + 2;\n" + "}\n" + "else {\n" + " A[1] = y;\n" + "}\n" + "A[2] = Math.sin(z);") + # Test using MatrixElements in expressions + expr = Piecewise((2*A[2, 0], x > 0), (A[2, 0], True)) + sin(A[1, 0]) + A[0, 0] + assert jscode(expr) == ( + "((x > 0) ? (\n" + " 2*A[2]\n" + ")\n" + ": (\n" + " A[2]\n" + ")) + Math.sin(A[1]) + A[0]") + # Test using MatrixElements in a Matrix + q = MatrixSymbol('q', 5, 1) + M = MatrixSymbol('M', 3, 3) + m = Matrix([[sin(q[1,0]), 0, cos(q[2,0])], + [q[1,0] + q[2,0], q[3, 0], 5], + [2*q[4, 0]/q[1,0], sqrt(q[0,0]) + 4, 0]]) + assert jscode(m, M) == ( + "M[0] = Math.sin(q[1]);\n" + "M[1] = 0;\n" + "M[2] = Math.cos(q[2]);\n" + "M[3] = q[1] + q[2];\n" + "M[4] = q[3];\n" + "M[5] = 5;\n" + "M[6] = 2*q[4]/q[1];\n" + "M[7] = Math.sqrt(q[0]) + 4;\n" + "M[8] = 0;") + + +def test_MatrixElement_printing(): + # test cases for issue #11821 + A = MatrixSymbol("A", 1, 3) + B = MatrixSymbol("B", 1, 3) + C = MatrixSymbol("C", 1, 3) + + assert(jscode(A[0, 0]) == "A[0]") + assert(jscode(3 * A[0, 0]) == "3*A[0]") + + F = C[0, 0].subs(C, A - B) + assert(jscode(F) == "(A - B)[0]") diff --git a/lib/python3.10/site-packages/sympy/printing/tests/test_julia.py b/lib/python3.10/site-packages/sympy/printing/tests/test_julia.py new file mode 100644 index 0000000000000000000000000000000000000000..8bfea1035ed9909f55eb5b0c55d99a33689000bb --- /dev/null +++ b/lib/python3.10/site-packages/sympy/printing/tests/test_julia.py @@ -0,0 +1,386 @@ +from sympy.core import (S, pi, oo, symbols, Function, Rational, Integer, + Tuple, Symbol, Eq, Ne, Le, Lt, Gt, Ge) +from sympy.core import EulerGamma, GoldenRatio, Catalan, Lambda, Mul, Pow +from sympy.functions import Piecewise, sqrt, ceiling, exp, sin, cos +from sympy.testing.pytest import raises +from sympy.utilities.lambdify import implemented_function +from sympy.matrices import (eye, Matrix, MatrixSymbol, Identity, + HadamardProduct, SparseMatrix) +from sympy.functions.special.bessel import (jn, yn, besselj, bessely, besseli, + besselk, hankel1, hankel2, airyai, + airybi, airyaiprime, airybiprime) +from sympy.testing.pytest import XFAIL + +from sympy.printing.julia import julia_code + +x, y, z = symbols('x,y,z') + + +def test_Integer(): + assert julia_code(Integer(67)) == "67" + assert julia_code(Integer(-1)) == "-1" + + +def test_Rational(): + assert julia_code(Rational(3, 7)) == "3 // 7" + assert julia_code(Rational(18, 9)) == "2" + assert julia_code(Rational(3, -7)) == "-3 // 7" + assert julia_code(Rational(-3, -7)) == "3 // 7" + assert julia_code(x + Rational(3, 7)) == "x + 3 // 7" + assert julia_code(Rational(3, 7)*x) == "(3 // 7) * x" + + +def test_Relational(): + assert julia_code(Eq(x, y)) == "x == y" + assert julia_code(Ne(x, y)) == "x != y" + assert julia_code(Le(x, y)) == "x <= y" + assert julia_code(Lt(x, y)) == "x < y" + assert julia_code(Gt(x, y)) == "x > y" + assert julia_code(Ge(x, y)) == "x >= y" + + +def test_Function(): + assert julia_code(sin(x) ** cos(x)) == "sin(x) .^ cos(x)" + assert julia_code(abs(x)) == "abs(x)" + assert julia_code(ceiling(x)) == "ceil(x)" + + +def test_Pow(): + assert julia_code(x**3) == "x .^ 3" + assert julia_code(x**(y**3)) == "x .^ (y .^ 3)" + assert julia_code(x**Rational(2, 3)) == 'x .^ (2 // 3)' + g = implemented_function('g', Lambda(x, 2*x)) + assert julia_code(1/(g(x)*3.5)**(x - y**x)/(x**2 + y)) == \ + "(3.5 * 2 * x) .^ (-x + y .^ x) ./ (x .^ 2 + y)" + # For issue 14160 + assert julia_code(Mul(-2, x, Pow(Mul(y,y,evaluate=False), -1, evaluate=False), + evaluate=False)) == '-2 * x ./ (y .* y)' + + +def test_basic_ops(): + assert julia_code(x*y) == "x .* y" + assert julia_code(x + y) == "x + y" + assert julia_code(x - y) == "x - y" + assert julia_code(-x) == "-x" + + +def test_1_over_x_and_sqrt(): + # 1.0 and 0.5 would do something different in regular StrPrinter, + # but these are exact in IEEE floating point so no different here. + assert julia_code(1/x) == '1 ./ x' + assert julia_code(x**-1) == julia_code(x**-1.0) == '1 ./ x' + assert julia_code(1/sqrt(x)) == '1 ./ sqrt(x)' + assert julia_code(x**-S.Half) == julia_code(x**-0.5) == '1 ./ sqrt(x)' + assert julia_code(sqrt(x)) == 'sqrt(x)' + assert julia_code(x**S.Half) == julia_code(x**0.5) == 'sqrt(x)' + assert julia_code(1/pi) == '1 / pi' + assert julia_code(pi**-1) == julia_code(pi**-1.0) == '1 / pi' + assert julia_code(pi**-0.5) == '1 / sqrt(pi)' + + +def test_mix_number_mult_symbols(): + assert julia_code(3*x) == "3 * x" + assert julia_code(pi*x) == "pi * x" + assert julia_code(3/x) == "3 ./ x" + assert julia_code(pi/x) == "pi ./ x" + assert julia_code(x/3) == "x / 3" + assert julia_code(x/pi) == "x / pi" + assert julia_code(x*y) == "x .* y" + assert julia_code(3*x*y) == "3 * x .* y" + assert julia_code(3*pi*x*y) == "3 * pi * x .* y" + assert julia_code(x/y) == "x ./ y" + assert julia_code(3*x/y) == "3 * x ./ y" + assert julia_code(x*y/z) == "x .* y ./ z" + assert julia_code(x/y*z) == "x .* z ./ y" + assert julia_code(1/x/y) == "1 ./ (x .* y)" + assert julia_code(2*pi*x/y/z) == "2 * pi * x ./ (y .* z)" + assert julia_code(3*pi/x) == "3 * pi ./ x" + assert julia_code(S(3)/5) == "3 // 5" + assert julia_code(S(3)/5*x) == "(3 // 5) * x" + assert julia_code(x/y/z) == "x ./ (y .* z)" + assert julia_code((x+y)/z) == "(x + y) ./ z" + assert julia_code((x+y)/(z+x)) == "(x + y) ./ (x + z)" + assert julia_code((x+y)/EulerGamma) == "(x + y) / eulergamma" + assert julia_code(x/3/pi) == "x / (3 * pi)" + assert julia_code(S(3)/5*x*y/pi) == "(3 // 5) * x .* y / pi" + + +def test_mix_number_pow_symbols(): + assert julia_code(pi**3) == 'pi ^ 3' + assert julia_code(x**2) == 'x .^ 2' + assert julia_code(x**(pi**3)) == 'x .^ (pi ^ 3)' + assert julia_code(x**y) == 'x .^ y' + assert julia_code(x**(y**z)) == 'x .^ (y .^ z)' + assert julia_code((x**y)**z) == '(x .^ y) .^ z' + + +def test_imag(): + I = S('I') + assert julia_code(I) == "im" + assert julia_code(5*I) == "5im" + assert julia_code((S(3)/2)*I) == "(3 // 2) * im" + assert julia_code(3+4*I) == "3 + 4im" + + +def test_constants(): + assert julia_code(pi) == "pi" + assert julia_code(oo) == "Inf" + assert julia_code(-oo) == "-Inf" + assert julia_code(S.NegativeInfinity) == "-Inf" + assert julia_code(S.NaN) == "NaN" + assert julia_code(S.Exp1) == "e" + assert julia_code(exp(1)) == "e" + + +def test_constants_other(): + assert julia_code(2*GoldenRatio) == "2 * golden" + assert julia_code(2*Catalan) == "2 * catalan" + assert julia_code(2*EulerGamma) == "2 * eulergamma" + + +def test_boolean(): + assert julia_code(x & y) == "x && y" + assert julia_code(x | y) == "x || y" + assert julia_code(~x) == "!x" + assert julia_code(x & y & z) == "x && y && z" + assert julia_code(x | y | z) == "x || y || z" + assert julia_code((x & y) | z) == "z || x && y" + assert julia_code((x | y) & z) == "z && (x || y)" + + +def test_Matrices(): + assert julia_code(Matrix(1, 1, [10])) == "[10]" + A = Matrix([[1, sin(x/2), abs(x)], + [0, 1, pi], + [0, exp(1), ceiling(x)]]); + expected = ("[1 sin(x / 2) abs(x);\n" + "0 1 pi;\n" + "0 e ceil(x)]") + assert julia_code(A) == expected + # row and columns + assert julia_code(A[:,0]) == "[1, 0, 0]" + assert julia_code(A[0,:]) == "[1 sin(x / 2) abs(x)]" + # empty matrices + assert julia_code(Matrix(0, 0, [])) == 'zeros(0, 0)' + assert julia_code(Matrix(0, 3, [])) == 'zeros(0, 3)' + # annoying to read but correct + assert julia_code(Matrix([[x, x - y, -y]])) == "[x x - y -y]" + + +def test_vector_entries_hadamard(): + # For a row or column, user might to use the other dimension + A = Matrix([[1, sin(2/x), 3*pi/x/5]]) + assert julia_code(A) == "[1 sin(2 ./ x) (3 // 5) * pi ./ x]" + assert julia_code(A.T) == "[1, sin(2 ./ x), (3 // 5) * pi ./ x]" + + +@XFAIL +def test_Matrices_entries_not_hadamard(): + # For Matrix with col >= 2, row >= 2, they need to be scalars + # FIXME: is it worth worrying about this? Its not wrong, just + # leave it user's responsibility to put scalar data for x. + A = Matrix([[1, sin(2/x), 3*pi/x/5], [1, 2, x*y]]) + expected = ("[1 sin(2/x) 3*pi/(5*x);\n" + "1 2 x*y]") # <- we give x.*y + assert julia_code(A) == expected + + +def test_MatrixSymbol(): + n = Symbol('n', integer=True) + A = MatrixSymbol('A', n, n) + B = MatrixSymbol('B', n, n) + assert julia_code(A*B) == "A * B" + assert julia_code(B*A) == "B * A" + assert julia_code(2*A*B) == "2 * A * B" + assert julia_code(B*2*A) == "2 * B * A" + assert julia_code(A*(B + 3*Identity(n))) == "A * (3 * eye(n) + B)" + assert julia_code(A**(x**2)) == "A ^ (x .^ 2)" + assert julia_code(A**3) == "A ^ 3" + assert julia_code(A**S.Half) == "A ^ (1 // 2)" + + +def test_special_matrices(): + assert julia_code(6*Identity(3)) == "6 * eye(3)" + + +def test_containers(): + assert julia_code([1, 2, 3, [4, 5, [6, 7]], 8, [9, 10], 11]) == \ + "Any[1, 2, 3, Any[4, 5, Any[6, 7]], 8, Any[9, 10], 11]" + assert julia_code((1, 2, (3, 4))) == "(1, 2, (3, 4))" + assert julia_code([1]) == "Any[1]" + assert julia_code((1,)) == "(1,)" + assert julia_code(Tuple(*[1, 2, 3])) == "(1, 2, 3)" + assert julia_code((1, x*y, (3, x**2))) == "(1, x .* y, (3, x .^ 2))" + # scalar, matrix, empty matrix and empty list + assert julia_code((1, eye(3), Matrix(0, 0, []), [])) == "(1, [1 0 0;\n0 1 0;\n0 0 1], zeros(0, 0), Any[])" + + +def test_julia_noninline(): + source = julia_code((x+y)/Catalan, assign_to='me', inline=False) + expected = ( + "const Catalan = %s\n" + "me = (x + y) / Catalan" + ) % Catalan.evalf(17) + assert source == expected + + +def test_julia_piecewise(): + expr = Piecewise((x, x < 1), (x**2, True)) + assert julia_code(expr) == "((x < 1) ? (x) : (x .^ 2))" + assert julia_code(expr, assign_to="r") == ( + "r = ((x < 1) ? (x) : (x .^ 2))") + assert julia_code(expr, assign_to="r", inline=False) == ( + "if (x < 1)\n" + " r = x\n" + "else\n" + " r = x .^ 2\n" + "end") + expr = Piecewise((x**2, x < 1), (x**3, x < 2), (x**4, x < 3), (x**5, True)) + expected = ("((x < 1) ? (x .^ 2) :\n" + "(x < 2) ? (x .^ 3) :\n" + "(x < 3) ? (x .^ 4) : (x .^ 5))") + assert julia_code(expr) == expected + assert julia_code(expr, assign_to="r") == "r = " + expected + assert julia_code(expr, assign_to="r", inline=False) == ( + "if (x < 1)\n" + " r = x .^ 2\n" + "elseif (x < 2)\n" + " r = x .^ 3\n" + "elseif (x < 3)\n" + " r = x .^ 4\n" + "else\n" + " r = x .^ 5\n" + "end") + # Check that Piecewise without a True (default) condition error + expr = Piecewise((x, x < 1), (x**2, x > 1), (sin(x), x > 0)) + raises(ValueError, lambda: julia_code(expr)) + + +def test_julia_piecewise_times_const(): + pw = Piecewise((x, x < 1), (x**2, True)) + assert julia_code(2*pw) == "2 * ((x < 1) ? (x) : (x .^ 2))" + assert julia_code(pw/x) == "((x < 1) ? (x) : (x .^ 2)) ./ x" + assert julia_code(pw/(x*y)) == "((x < 1) ? (x) : (x .^ 2)) ./ (x .* y)" + assert julia_code(pw/3) == "((x < 1) ? (x) : (x .^ 2)) / 3" + + +def test_julia_matrix_assign_to(): + A = Matrix([[1, 2, 3]]) + assert julia_code(A, assign_to='a') == "a = [1 2 3]" + A = Matrix([[1, 2], [3, 4]]) + assert julia_code(A, assign_to='A') == "A = [1 2;\n3 4]" + + +def test_julia_matrix_assign_to_more(): + # assigning to Symbol or MatrixSymbol requires lhs/rhs match + A = Matrix([[1, 2, 3]]) + B = MatrixSymbol('B', 1, 3) + C = MatrixSymbol('C', 2, 3) + assert julia_code(A, assign_to=B) == "B = [1 2 3]" + raises(ValueError, lambda: julia_code(A, assign_to=x)) + raises(ValueError, lambda: julia_code(A, assign_to=C)) + + +def test_julia_matrix_1x1(): + A = Matrix([[3]]) + B = MatrixSymbol('B', 1, 1) + C = MatrixSymbol('C', 1, 2) + assert julia_code(A, assign_to=B) == "B = [3]" + # FIXME? + #assert julia_code(A, assign_to=x) == "x = [3]" + raises(ValueError, lambda: julia_code(A, assign_to=C)) + + +def test_julia_matrix_elements(): + A = Matrix([[x, 2, x*y]]) + assert julia_code(A[0, 0]**2 + A[0, 1] + A[0, 2]) == "x .^ 2 + x .* y + 2" + A = MatrixSymbol('AA', 1, 3) + assert julia_code(A) == "AA" + assert julia_code(A[0, 0]**2 + sin(A[0,1]) + A[0,2]) == \ + "sin(AA[1,2]) + AA[1,1] .^ 2 + AA[1,3]" + assert julia_code(sum(A)) == "AA[1,1] + AA[1,2] + AA[1,3]" + + +def test_julia_boolean(): + assert julia_code(True) == "true" + assert julia_code(S.true) == "true" + assert julia_code(False) == "false" + assert julia_code(S.false) == "false" + + +def test_julia_not_supported(): + with raises(NotImplementedError): + julia_code(S.ComplexInfinity) + + f = Function('f') + assert julia_code(f(x).diff(x), strict=False) == ( + "# Not supported in Julia:\n" + "# Derivative\n" + "Derivative(f(x), x)" + ) + + +def test_trick_indent_with_end_else_words(): + # words starting with "end" or "else" do not confuse the indenter + t1 = S('endless'); + t2 = S('elsewhere'); + pw = Piecewise((t1, x < 0), (t2, x <= 1), (1, True)) + assert julia_code(pw, inline=False) == ( + "if (x < 0)\n" + " endless\n" + "elseif (x <= 1)\n" + " elsewhere\n" + "else\n" + " 1\n" + "end") + + +def test_haramard(): + A = MatrixSymbol('A', 3, 3) + B = MatrixSymbol('B', 3, 3) + v = MatrixSymbol('v', 3, 1) + h = MatrixSymbol('h', 1, 3) + C = HadamardProduct(A, B) + assert julia_code(C) == "A .* B" + assert julia_code(C*v) == "(A .* B) * v" + assert julia_code(h*C*v) == "h * (A .* B) * v" + assert julia_code(C*A) == "(A .* B) * A" + # mixing Hadamard and scalar strange b/c we vectorize scalars + assert julia_code(C*x*y) == "(x .* y) * (A .* B)" + + +def test_sparse(): + M = SparseMatrix(5, 6, {}) + M[2, 2] = 10; + M[1, 2] = 20; + M[1, 3] = 22; + M[0, 3] = 30; + M[3, 0] = x*y; + assert julia_code(M) == ( + "sparse([4, 2, 3, 1, 2], [1, 3, 3, 4, 4], [x .* y, 20, 10, 30, 22], 5, 6)" + ) + + +def test_specfun(): + n = Symbol('n') + for f in [besselj, bessely, besseli, besselk]: + assert julia_code(f(n, x)) == f.__name__ + '(n, x)' + for f in [airyai, airyaiprime, airybi, airybiprime]: + assert julia_code(f(x)) == f.__name__ + '(x)' + assert julia_code(hankel1(n, x)) == 'hankelh1(n, x)' + assert julia_code(hankel2(n, x)) == 'hankelh2(n, x)' + assert julia_code(jn(n, x)) == 'sqrt(2) * sqrt(pi) * sqrt(1 ./ x) .* besselj(n + 1 // 2, x) / 2' + assert julia_code(yn(n, x)) == 'sqrt(2) * sqrt(pi) * sqrt(1 ./ x) .* bessely(n + 1 // 2, x) / 2' + + +def test_MatrixElement_printing(): + # test cases for issue #11821 + A = MatrixSymbol("A", 1, 3) + B = MatrixSymbol("B", 1, 3) + C = MatrixSymbol("C", 1, 3) + + assert(julia_code(A[0, 0]) == "A[1,1]") + assert(julia_code(3 * A[0, 0]) == "3 * A[1,1]") + + F = C[0, 0].subs(C, A - B) + assert(julia_code(F) == "(A - B)[1,1]") diff --git a/lib/python3.10/site-packages/sympy/printing/tests/test_lambdarepr.py b/lib/python3.10/site-packages/sympy/printing/tests/test_lambdarepr.py new file mode 100644 index 0000000000000000000000000000000000000000..e027fc673d2ed69f36c73614d01a4d6f4ef331ad --- /dev/null +++ b/lib/python3.10/site-packages/sympy/printing/tests/test_lambdarepr.py @@ -0,0 +1,246 @@ +from sympy.concrete.summations import Sum +from sympy.core.expr import Expr +from sympy.core.symbol import symbols +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.piecewise import Piecewise +from sympy.functions.elementary.trigonometric import sin +from sympy.matrices.dense import MutableDenseMatrix as Matrix +from sympy.sets.sets import Interval +from sympy.utilities.lambdify import lambdify +from sympy.testing.pytest import raises + +from sympy.printing.tensorflow import TensorflowPrinter +from sympy.printing.lambdarepr import lambdarepr, LambdaPrinter, NumExprPrinter + + +x, y, z = symbols("x,y,z") +i, a, b = symbols("i,a,b") +j, c, d = symbols("j,c,d") + + +def test_basic(): + assert lambdarepr(x*y) == "x*y" + assert lambdarepr(x + y) in ["y + x", "x + y"] + assert lambdarepr(x**y) == "x**y" + + +def test_matrix(): + # Test printing a Matrix that has an element that is printed differently + # with the LambdaPrinter than with the StrPrinter. + e = x % 2 + assert lambdarepr(e) != str(e) + assert lambdarepr(Matrix([e])) == 'ImmutableDenseMatrix([[x % 2]])' + + +def test_piecewise(): + # In each case, test eval() the lambdarepr() to make sure there are a + # correct number of parentheses. It will give a SyntaxError if there aren't. + + h = "lambda x: " + + p = Piecewise((x, x < 0)) + l = lambdarepr(p) + eval(h + l) + assert l == "((x) if (x < 0) else None)" + + p = Piecewise( + (1, x < 1), + (2, x < 2), + (0, True) + ) + l = lambdarepr(p) + eval(h + l) + assert l == "((1) if (x < 1) else (2) if (x < 2) else (0))" + + p = Piecewise( + (1, x < 1), + (2, x < 2), + ) + l = lambdarepr(p) + eval(h + l) + assert l == "((1) if (x < 1) else (2) if (x < 2) else None)" + + p = Piecewise( + (x, x < 1), + (x**2, Interval(3, 4, True, False).contains(x)), + (0, True), + ) + l = lambdarepr(p) + eval(h + l) + assert l == "((x) if (x < 1) else (x**2) if (((x <= 4)) and ((x > 3))) else (0))" + + p = Piecewise( + (x**2, x < 0), + (x, x < 1), + (2 - x, x >= 1), + (0, True), evaluate=False + ) + l = lambdarepr(p) + eval(h + l) + assert l == "((x**2) if (x < 0) else (x) if (x < 1)"\ + " else (2 - x) if (x >= 1) else (0))" + + p = Piecewise( + (x**2, x < 0), + (x, x < 1), + (2 - x, x >= 1), evaluate=False + ) + l = lambdarepr(p) + eval(h + l) + assert l == "((x**2) if (x < 0) else (x) if (x < 1)"\ + " else (2 - x) if (x >= 1) else None)" + + p = Piecewise( + (1, x >= 1), + (2, x >= 2), + (3, x >= 3), + (4, x >= 4), + (5, x >= 5), + (6, True) + ) + l = lambdarepr(p) + eval(h + l) + assert l == "((1) if (x >= 1) else (2) if (x >= 2) else (3) if (x >= 3)"\ + " else (4) if (x >= 4) else (5) if (x >= 5) else (6))" + + p = Piecewise( + (1, x <= 1), + (2, x <= 2), + (3, x <= 3), + (4, x <= 4), + (5, x <= 5), + (6, True) + ) + l = lambdarepr(p) + eval(h + l) + assert l == "((1) if (x <= 1) else (2) if (x <= 2) else (3) if (x <= 3)"\ + " else (4) if (x <= 4) else (5) if (x <= 5) else (6))" + + p = Piecewise( + (1, x > 1), + (2, x > 2), + (3, x > 3), + (4, x > 4), + (5, x > 5), + (6, True) + ) + l = lambdarepr(p) + eval(h + l) + assert l =="((1) if (x > 1) else (2) if (x > 2) else (3) if (x > 3)"\ + " else (4) if (x > 4) else (5) if (x > 5) else (6))" + + p = Piecewise( + (1, x < 1), + (2, x < 2), + (3, x < 3), + (4, x < 4), + (5, x < 5), + (6, True) + ) + l = lambdarepr(p) + eval(h + l) + assert l == "((1) if (x < 1) else (2) if (x < 2) else (3) if (x < 3)"\ + " else (4) if (x < 4) else (5) if (x < 5) else (6))" + + p = Piecewise( + (Piecewise( + (1, x > 0), + (2, True) + ), y > 0), + (3, True) + ) + l = lambdarepr(p) + eval(h + l) + assert l == "((((1) if (x > 0) else (2))) if (y > 0) else (3))" + + +def test_sum__1(): + # In each case, test eval() the lambdarepr() to make sure that + # it evaluates to the same results as the symbolic expression + s = Sum(x ** i, (i, a, b)) + l = lambdarepr(s) + assert l == "(builtins.sum(x**i for i in range(a, b+1)))" + + args = x, a, b + f = lambdify(args, s) + v = 2, 3, 8 + assert f(*v) == s.subs(zip(args, v)).doit() + +def test_sum__2(): + s = Sum(i * x, (i, a, b)) + l = lambdarepr(s) + assert l == "(builtins.sum(i*x for i in range(a, b+1)))" + + args = x, a, b + f = lambdify(args, s) + v = 2, 3, 8 + assert f(*v) == s.subs(zip(args, v)).doit() + + +def test_multiple_sums(): + s = Sum(i * x + j, (i, a, b), (j, c, d)) + + l = lambdarepr(s) + assert l == "(builtins.sum(i*x + j for i in range(a, b+1) for j in range(c, d+1)))" + + args = x, a, b, c, d + f = lambdify(args, s) + vals = 2, 3, 4, 5, 6 + f_ref = s.subs(zip(args, vals)).doit() + f_res = f(*vals) + assert f_res == f_ref + + +def test_sqrt(): + prntr = LambdaPrinter({'standard' : 'python3'}) + assert prntr._print_Pow(sqrt(x), rational=False) == 'sqrt(x)' + assert prntr._print_Pow(sqrt(x), rational=True) == 'x**(1/2)' + + +def test_settings(): + raises(TypeError, lambda: lambdarepr(sin(x), method="garbage")) + + +def test_numexpr(): + # test ITE rewrite as Piecewise + from sympy.logic.boolalg import ITE + expr = ITE(x > 0, True, False, evaluate=False) + assert NumExprPrinter().doprint(expr) == \ + "numexpr.evaluate('where((x > 0), True, False)', truediv=True)" + + from sympy.codegen.ast import Return, FunctionDefinition, Variable, Assignment + func_def = FunctionDefinition(None, 'foo', [Variable(x)], [Assignment(y,x), Return(y**2)]) + expected = "def foo(x):\n"\ + " y = numexpr.evaluate('x', truediv=True)\n"\ + " return numexpr.evaluate('y**2', truediv=True)" + assert NumExprPrinter().doprint(func_def) == expected + + +class CustomPrintedObject(Expr): + def _lambdacode(self, printer): + return 'lambda' + + def _tensorflowcode(self, printer): + return 'tensorflow' + + def _numpycode(self, printer): + return 'numpy' + + def _numexprcode(self, printer): + return 'numexpr' + + def _mpmathcode(self, printer): + return 'mpmath' + + +def test_printmethod(): + # In each case, printmethod is called to test + # its working + + obj = CustomPrintedObject() + assert LambdaPrinter().doprint(obj) == 'lambda' + assert TensorflowPrinter().doprint(obj) == 'tensorflow' + assert NumExprPrinter().doprint(obj) == "numexpr.evaluate('numexpr', truediv=True)" + + assert NumExprPrinter().doprint(Piecewise((y, x >= 0), (z, x < 0))) == \ + "numexpr.evaluate('where((x >= 0), y, z)', truediv=True)" diff --git a/lib/python3.10/site-packages/sympy/printing/tests/test_latex.py b/lib/python3.10/site-packages/sympy/printing/tests/test_latex.py new file mode 100644 index 0000000000000000000000000000000000000000..28872fb3a7d1836897fe3fd9146b3c2919c919ce --- /dev/null +++ b/lib/python3.10/site-packages/sympy/printing/tests/test_latex.py @@ -0,0 +1,3142 @@ +from sympy import MatAdd, MatMul, Array +from sympy.algebras.quaternion import Quaternion +from sympy.calculus.accumulationbounds import AccumBounds +from sympy.combinatorics.permutations import Cycle, Permutation, AppliedPermutation +from sympy.concrete.products import Product +from sympy.concrete.summations import Sum +from sympy.core.containers import Tuple, Dict +from sympy.core.expr import UnevaluatedExpr +from sympy.core.function import (Derivative, Function, Lambda, Subs, diff) +from sympy.core.mod import Mod +from sympy.core.mul import Mul +from sympy.core.numbers import (AlgebraicNumber, Float, I, Integer, Rational, oo, pi) +from sympy.core.parameters import evaluate +from sympy.core.power import Pow +from sympy.core.relational import Eq, Ne +from sympy.core.singleton import S +from sympy.core.symbol import (Symbol, Wild, symbols) +from sympy.functions.combinatorial.factorials import (FallingFactorial, RisingFactorial, binomial, factorial, factorial2, subfactorial) +from sympy.functions.combinatorial.numbers import (bernoulli, bell, catalan, euler, genocchi, + lucas, fibonacci, tribonacci, divisor_sigma, udivisor_sigma, + mobius, primenu, primeomega, + totient, reduced_totient) +from sympy.functions.elementary.complexes import (Abs, arg, conjugate, im, polar_lift, re) +from sympy.functions.elementary.exponential import (LambertW, exp, log) +from sympy.functions.elementary.hyperbolic import (asinh, coth) +from sympy.functions.elementary.integers import (ceiling, floor, frac) +from sympy.functions.elementary.miscellaneous import (Max, Min, root, sqrt) +from sympy.functions.elementary.piecewise import Piecewise +from sympy.functions.elementary.trigonometric import (acsc, asin, cos, cot, sin, tan) +from sympy.functions.special.beta_functions import beta +from sympy.functions.special.delta_functions import (DiracDelta, Heaviside) +from sympy.functions.special.elliptic_integrals import (elliptic_e, elliptic_f, elliptic_k, elliptic_pi) +from sympy.functions.special.error_functions import (Chi, Ci, Ei, Shi, Si, expint) +from sympy.functions.special.gamma_functions import (gamma, uppergamma) +from sympy.functions.special.hyper import (hyper, meijerg) +from sympy.functions.special.mathieu_functions import (mathieuc, mathieucprime, mathieus, mathieusprime) +from sympy.functions.special.polynomials import (assoc_laguerre, assoc_legendre, chebyshevt, chebyshevu, gegenbauer, hermite, jacobi, laguerre, legendre) +from sympy.functions.special.singularity_functions import SingularityFunction +from sympy.functions.special.spherical_harmonics import (Ynm, Znm) +from sympy.functions.special.tensor_functions import (KroneckerDelta, LeviCivita) +from sympy.functions.special.zeta_functions import (dirichlet_eta, lerchphi, polylog, stieltjes, zeta) +from sympy.integrals.integrals import Integral +from sympy.integrals.transforms import (CosineTransform, FourierTransform, InverseCosineTransform, InverseFourierTransform, InverseLaplaceTransform, InverseMellinTransform, InverseSineTransform, LaplaceTransform, MellinTransform, SineTransform) +from sympy.logic import Implies +from sympy.logic.boolalg import (And, Or, Xor, Equivalent, false, Not, true) +from sympy.matrices.dense import Matrix +from sympy.matrices.expressions.kronecker import KroneckerProduct +from sympy.matrices.expressions.matexpr import MatrixSymbol +from sympy.matrices.expressions.permutation import PermutationMatrix +from sympy.matrices.expressions.slice import MatrixSlice +from sympy.physics.control.lti import TransferFunction, Series, Parallel, Feedback, TransferFunctionMatrix, MIMOSeries, MIMOParallel, MIMOFeedback +from sympy.physics.quantum import Commutator, Operator +from sympy.physics.quantum.trace import Tr +from sympy.physics.units import meter, gibibyte, gram, microgram, second, milli, micro +from sympy.polys.domains.integerring import ZZ +from sympy.polys.fields import field +from sympy.polys.polytools import Poly +from sympy.polys.rings import ring +from sympy.polys.rootoftools import (RootSum, rootof) +from sympy.series.formal import fps +from sympy.series.fourier import fourier_series +from sympy.series.limits import Limit +from sympy.series.order import Order +from sympy.series.sequences import (SeqAdd, SeqFormula, SeqMul, SeqPer) +from sympy.sets.conditionset import ConditionSet +from sympy.sets.contains import Contains +from sympy.sets.fancysets import (ComplexRegion, ImageSet, Range) +from sympy.sets.ordinals import Ordinal, OrdinalOmega, OmegaPower +from sympy.sets.powerset import PowerSet +from sympy.sets.sets import (FiniteSet, Interval, Union, Intersection, Complement, SymmetricDifference, ProductSet) +from sympy.sets.setexpr import SetExpr +from sympy.stats.crv_types import Normal +from sympy.stats.symbolic_probability import (Covariance, Expectation, + Probability, Variance) +from sympy.tensor.array import (ImmutableDenseNDimArray, + ImmutableSparseNDimArray, + MutableSparseNDimArray, + MutableDenseNDimArray, + tensorproduct) +from sympy.tensor.array.expressions.array_expressions import ArraySymbol, ArrayElement +from sympy.tensor.indexed import (Idx, Indexed, IndexedBase) +from sympy.tensor.toperators import PartialDerivative +from sympy.vector import CoordSys3D, Cross, Curl, Dot, Divergence, Gradient, Laplacian + + +from sympy.testing.pytest import (XFAIL, raises, _both_exp_pow, + warns_deprecated_sympy) +from sympy.printing.latex import (latex, translate, greek_letters_set, + tex_greek_dictionary, multiline_latex, + latex_escape, LatexPrinter) + +import sympy as sym + +from sympy.abc import mu, tau + + +class lowergamma(sym.lowergamma): + pass # testing notation inheritance by a subclass with same name + + +x, y, z, t, w, a, b, c, s, p = symbols('x y z t w a b c s p') +k, m, n = symbols('k m n', integer=True) + + +def test_printmethod(): + class R(Abs): + def _latex(self, printer): + return "foo(%s)" % printer._print(self.args[0]) + assert latex(R(x)) == r"foo(x)" + + class R(Abs): + def _latex(self, printer): + return "foo" + assert latex(R(x)) == r"foo" + + +def test_latex_basic(): + assert latex(1 + x) == r"x + 1" + assert latex(x**2) == r"x^{2}" + assert latex(x**(1 + x)) == r"x^{x + 1}" + assert latex(x**3 + x + 1 + x**2) == r"x^{3} + x^{2} + x + 1" + + assert latex(2*x*y) == r"2 x y" + assert latex(2*x*y, mul_symbol='dot') == r"2 \cdot x \cdot y" + assert latex(3*x**2*y, mul_symbol='\\,') == r"3\,x^{2}\,y" + assert latex(1.5*3**x, mul_symbol='\\,') == r"1.5 \cdot 3^{x}" + + assert latex(x**S.Half**5) == r"\sqrt[32]{x}" + assert latex(Mul(S.Half, x**2, -5, evaluate=False)) == r"\frac{1}{2} x^{2} \left(-5\right)" + assert latex(Mul(S.Half, x**2, 5, evaluate=False)) == r"\frac{1}{2} x^{2} \cdot 5" + assert latex(Mul(-5, -5, evaluate=False)) == r"\left(-5\right) \left(-5\right)" + assert latex(Mul(5, -5, evaluate=False)) == r"5 \left(-5\right)" + assert latex(Mul(S.Half, -5, S.Half, evaluate=False)) == r"\frac{1}{2} \left(-5\right) \frac{1}{2}" + assert latex(Mul(5, I, 5, evaluate=False)) == r"5 i 5" + assert latex(Mul(5, I, -5, evaluate=False)) == r"5 i \left(-5\right)" + assert latex(Mul(Pow(x, 2), S.Half*x + 1)) == r"x^{2} \left(\frac{x}{2} + 1\right)" + assert latex(Mul(Pow(x, 3), Rational(2, 3)*x + 1)) == r"x^{3} \left(\frac{2 x}{3} + 1\right)" + assert latex(Mul(Pow(x, 11), 2*x + 1)) == r"x^{11} \left(2 x + 1\right)" + + assert latex(Mul(0, 1, evaluate=False)) == r'0 \cdot 1' + assert latex(Mul(1, 0, evaluate=False)) == r'1 \cdot 0' + assert latex(Mul(1, 1, evaluate=False)) == r'1 \cdot 1' + assert latex(Mul(-1, 1, evaluate=False)) == r'\left(-1\right) 1' + assert latex(Mul(1, 1, 1, evaluate=False)) == r'1 \cdot 1 \cdot 1' + assert latex(Mul(1, 2, evaluate=False)) == r'1 \cdot 2' + assert latex(Mul(1, S.Half, evaluate=False)) == r'1 \cdot \frac{1}{2}' + assert latex(Mul(1, 1, S.Half, evaluate=False)) == \ + r'1 \cdot 1 \cdot \frac{1}{2}' + assert latex(Mul(1, 1, 2, 3, x, evaluate=False)) == \ + r'1 \cdot 1 \cdot 2 \cdot 3 x' + assert latex(Mul(1, -1, evaluate=False)) == r'1 \left(-1\right)' + assert latex(Mul(4, 3, 2, 1, 0, y, x, evaluate=False)) == \ + r'4 \cdot 3 \cdot 2 \cdot 1 \cdot 0 y x' + assert latex(Mul(4, 3, 2, 1+z, 0, y, x, evaluate=False)) == \ + r'4 \cdot 3 \cdot 2 \left(z + 1\right) 0 y x' + assert latex(Mul(Rational(2, 3), Rational(5, 7), evaluate=False)) == \ + r'\frac{2}{3} \cdot \frac{5}{7}' + + assert latex(1/x) == r"\frac{1}{x}" + assert latex(1/x, fold_short_frac=True) == r"1 / x" + assert latex(-S(3)/2) == r"- \frac{3}{2}" + assert latex(-S(3)/2, fold_short_frac=True) == r"- 3 / 2" + assert latex(1/x**2) == r"\frac{1}{x^{2}}" + assert latex(1/(x + y)/2) == r"\frac{1}{2 \left(x + y\right)}" + assert latex(x/2) == r"\frac{x}{2}" + assert latex(x/2, fold_short_frac=True) == r"x / 2" + assert latex((x + y)/(2*x)) == r"\frac{x + y}{2 x}" + assert latex((x + y)/(2*x), fold_short_frac=True) == \ + r"\left(x + y\right) / 2 x" + assert latex((x + y)/(2*x), long_frac_ratio=0) == \ + r"\frac{1}{2 x} \left(x + y\right)" + assert latex((x + y)/x) == r"\frac{x + y}{x}" + assert latex((x + y)/x, long_frac_ratio=3) == r"\frac{x + y}{x}" + assert latex((2*sqrt(2)*x)/3) == r"\frac{2 \sqrt{2} x}{3}" + assert latex((2*sqrt(2)*x)/3, long_frac_ratio=2) == \ + r"\frac{2 x}{3} \sqrt{2}" + assert latex(binomial(x, y)) == r"{\binom{x}{y}}" + + x_star = Symbol('x^*') + f = Function('f') + assert latex(x_star**2) == r"\left(x^{*}\right)^{2}" + assert latex(x_star**2, parenthesize_super=False) == r"{x^{*}}^{2}" + assert latex(Derivative(f(x_star), x_star,2)) == r"\frac{d^{2}}{d \left(x^{*}\right)^{2}} f{\left(x^{*} \right)}" + assert latex(Derivative(f(x_star), x_star,2), parenthesize_super=False) == r"\frac{d^{2}}{d {x^{*}}^{2}} f{\left(x^{*} \right)}" + + assert latex(2*Integral(x, x)/3) == r"\frac{2 \int x\, dx}{3}" + assert latex(2*Integral(x, x)/3, fold_short_frac=True) == \ + r"\left(2 \int x\, dx\right) / 3" + + assert latex(sqrt(x)) == r"\sqrt{x}" + assert latex(x**Rational(1, 3)) == r"\sqrt[3]{x}" + assert latex(x**Rational(1, 3), root_notation=False) == r"x^{\frac{1}{3}}" + assert latex(sqrt(x)**3) == r"x^{\frac{3}{2}}" + assert latex(sqrt(x), itex=True) == r"\sqrt{x}" + assert latex(x**Rational(1, 3), itex=True) == r"\root{3}{x}" + assert latex(sqrt(x)**3, itex=True) == r"x^{\frac{3}{2}}" + assert latex(x**Rational(3, 4)) == r"x^{\frac{3}{4}}" + assert latex(x**Rational(3, 4), fold_frac_powers=True) == r"x^{3/4}" + assert latex((x + 1)**Rational(3, 4)) == \ + r"\left(x + 1\right)^{\frac{3}{4}}" + assert latex((x + 1)**Rational(3, 4), fold_frac_powers=True) == \ + r"\left(x + 1\right)^{3/4}" + assert latex(AlgebraicNumber(sqrt(2))) == r"\sqrt{2}" + assert latex(AlgebraicNumber(sqrt(2), [3, -7])) == r"-7 + 3 \sqrt{2}" + assert latex(AlgebraicNumber(sqrt(2), alias='alpha')) == r"\alpha" + assert latex(AlgebraicNumber(sqrt(2), [3, -7], alias='alpha')) == \ + r"3 \alpha - 7" + assert latex(AlgebraicNumber(2**(S(1)/3), [1, 3, -7], alias='beta')) == \ + r"\beta^{2} + 3 \beta - 7" + + k = ZZ.cyclotomic_field(5) + assert latex(k.ext.field_element([1, 2, 3, 4])) == \ + r"\zeta^{3} + 2 \zeta^{2} + 3 \zeta + 4" + assert latex(k.ext.field_element([1, 2, 3, 4]), order='old') == \ + r"4 + 3 \zeta + 2 \zeta^{2} + \zeta^{3}" + assert latex(k.primes_above(19)[0]) == \ + r"\left(19, \zeta^{2} + 5 \zeta + 1\right)" + assert latex(k.primes_above(19)[0], order='old') == \ + r"\left(19, 1 + 5 \zeta + \zeta^{2}\right)" + assert latex(k.primes_above(7)[0]) == r"\left(7\right)" + + assert latex(1.5e20*x) == r"1.5 \cdot 10^{20} x" + assert latex(1.5e20*x, mul_symbol='dot') == r"1.5 \cdot 10^{20} \cdot x" + assert latex(1.5e20*x, mul_symbol='times') == \ + r"1.5 \times 10^{20} \times x" + + assert latex(1/sin(x)) == r"\frac{1}{\sin{\left(x \right)}}" + assert latex(sin(x)**-1) == r"\frac{1}{\sin{\left(x \right)}}" + assert latex(sin(x)**Rational(3, 2)) == \ + r"\sin^{\frac{3}{2}}{\left(x \right)}" + assert latex(sin(x)**Rational(3, 2), fold_frac_powers=True) == \ + r"\sin^{3/2}{\left(x \right)}" + + assert latex(~x) == r"\neg x" + assert latex(x & y) == r"x \wedge y" + assert latex(x & y & z) == r"x \wedge y \wedge z" + assert latex(x | y) == r"x \vee y" + assert latex(x | y | z) == r"x \vee y \vee z" + assert latex((x & y) | z) == r"z \vee \left(x \wedge y\right)" + assert latex(Implies(x, y)) == r"x \Rightarrow y" + assert latex(~(x >> ~y)) == r"x \not\Rightarrow \neg y" + assert latex(Implies(Or(x,y), z)) == r"\left(x \vee y\right) \Rightarrow z" + assert latex(Implies(z, Or(x,y))) == r"z \Rightarrow \left(x \vee y\right)" + assert latex(~(x & y)) == r"\neg \left(x \wedge y\right)" + + assert latex(~x, symbol_names={x: "x_i"}) == r"\neg x_i" + assert latex(x & y, symbol_names={x: "x_i", y: "y_i"}) == \ + r"x_i \wedge y_i" + assert latex(x & y & z, symbol_names={x: "x_i", y: "y_i", z: "z_i"}) == \ + r"x_i \wedge y_i \wedge z_i" + assert latex(x | y, symbol_names={x: "x_i", y: "y_i"}) == r"x_i \vee y_i" + assert latex(x | y | z, symbol_names={x: "x_i", y: "y_i", z: "z_i"}) == \ + r"x_i \vee y_i \vee z_i" + assert latex((x & y) | z, symbol_names={x: "x_i", y: "y_i", z: "z_i"}) == \ + r"z_i \vee \left(x_i \wedge y_i\right)" + assert latex(Implies(x, y), symbol_names={x: "x_i", y: "y_i"}) == \ + r"x_i \Rightarrow y_i" + assert latex(Pow(Rational(1, 3), -1, evaluate=False)) == r"\frac{1}{\frac{1}{3}}" + assert latex(Pow(Rational(1, 3), -2, evaluate=False)) == r"\frac{1}{(\frac{1}{3})^{2}}" + assert latex(Pow(Integer(1)/100, -1, evaluate=False)) == r"\frac{1}{\frac{1}{100}}" + + + p = Symbol('p', positive=True) + assert latex(exp(-p)*log(p)) == r"e^{- p} \log{\left(p \right)}" + + +def test_latex_builtins(): + assert latex(True) == r"\text{True}" + assert latex(False) == r"\text{False}" + assert latex(None) == r"\text{None}" + assert latex(true) == r"\text{True}" + assert latex(false) == r'\text{False}' + + +def test_latex_SingularityFunction(): + assert latex(SingularityFunction(x, 4, 5)) == \ + r"{\left\langle x - 4 \right\rangle}^{5}" + assert latex(SingularityFunction(x, -3, 4)) == \ + r"{\left\langle x + 3 \right\rangle}^{4}" + assert latex(SingularityFunction(x, 0, 4)) == \ + r"{\left\langle x \right\rangle}^{4}" + assert latex(SingularityFunction(x, a, n)) == \ + r"{\left\langle - a + x \right\rangle}^{n}" + assert latex(SingularityFunction(x, 4, -2)) == \ + r"{\left\langle x - 4 \right\rangle}^{-2}" + assert latex(SingularityFunction(x, 4, -1)) == \ + r"{\left\langle x - 4 \right\rangle}^{-1}" + + assert latex(SingularityFunction(x, 4, 5)**3) == \ + r"{\left({\langle x - 4 \rangle}^{5}\right)}^{3}" + assert latex(SingularityFunction(x, -3, 4)**3) == \ + r"{\left({\langle x + 3 \rangle}^{4}\right)}^{3}" + assert latex(SingularityFunction(x, 0, 4)**3) == \ + r"{\left({\langle x \rangle}^{4}\right)}^{3}" + assert latex(SingularityFunction(x, a, n)**3) == \ + r"{\left({\langle - a + x \rangle}^{n}\right)}^{3}" + assert latex(SingularityFunction(x, 4, -2)**3) == \ + r"{\left({\langle x - 4 \rangle}^{-2}\right)}^{3}" + assert latex((SingularityFunction(x, 4, -1)**3)**3) == \ + r"{\left({\langle x - 4 \rangle}^{-1}\right)}^{9}" + + +def test_latex_cycle(): + assert latex(Cycle(1, 2, 4)) == r"\left( 1\; 2\; 4\right)" + assert latex(Cycle(1, 2)(4, 5, 6)) == \ + r"\left( 1\; 2\right)\left( 4\; 5\; 6\right)" + assert latex(Cycle()) == r"\left( \right)" + + +def test_latex_permutation(): + assert latex(Permutation(1, 2, 4)) == r"\left( 1\; 2\; 4\right)" + assert latex(Permutation(1, 2)(4, 5, 6)) == \ + r"\left( 1\; 2\right)\left( 4\; 5\; 6\right)" + assert latex(Permutation()) == r"\left( \right)" + assert latex(Permutation(2, 4)*Permutation(5)) == \ + r"\left( 2\; 4\right)\left( 5\right)" + assert latex(Permutation(5)) == r"\left( 5\right)" + + assert latex(Permutation(0, 1), perm_cyclic=False) == \ + r"\begin{pmatrix} 0 & 1 \\ 1 & 0 \end{pmatrix}" + assert latex(Permutation(0, 1)(2, 3), perm_cyclic=False) == \ + r"\begin{pmatrix} 0 & 1 & 2 & 3 \\ 1 & 0 & 3 & 2 \end{pmatrix}" + assert latex(Permutation(), perm_cyclic=False) == \ + r"\left( \right)" + + with warns_deprecated_sympy(): + old_print_cyclic = Permutation.print_cyclic + Permutation.print_cyclic = False + assert latex(Permutation(0, 1)(2, 3)) == \ + r"\begin{pmatrix} 0 & 1 & 2 & 3 \\ 1 & 0 & 3 & 2 \end{pmatrix}" + Permutation.print_cyclic = old_print_cyclic + +def test_latex_Float(): + assert latex(Float(1.0e100)) == r"1.0 \cdot 10^{100}" + assert latex(Float(1.0e-100)) == r"1.0 \cdot 10^{-100}" + assert latex(Float(1.0e-100), mul_symbol="times") == \ + r"1.0 \times 10^{-100}" + assert latex(Float('10000.0'), full_prec=False, min=-2, max=2) == \ + r"1.0 \cdot 10^{4}" + assert latex(Float('10000.0'), full_prec=False, min=-2, max=4) == \ + r"1.0 \cdot 10^{4}" + assert latex(Float('10000.0'), full_prec=False, min=-2, max=5) == \ + r"10000.0" + assert latex(Float('0.099999'), full_prec=True, min=-2, max=5) == \ + r"9.99990000000000 \cdot 10^{-2}" + + +def test_latex_vector_expressions(): + A = CoordSys3D('A') + + assert latex(Cross(A.i, A.j*A.x*3+A.k)) == \ + r"\mathbf{\hat{i}_{A}} \times \left(\left(3 \mathbf{{x}_{A}}\right)\mathbf{\hat{j}_{A}} + \mathbf{\hat{k}_{A}}\right)" + assert latex(Cross(A.i, A.j)) == \ + r"\mathbf{\hat{i}_{A}} \times \mathbf{\hat{j}_{A}}" + assert latex(x*Cross(A.i, A.j)) == \ + r"x \left(\mathbf{\hat{i}_{A}} \times \mathbf{\hat{j}_{A}}\right)" + assert latex(Cross(x*A.i, A.j)) == \ + r'- \mathbf{\hat{j}_{A}} \times \left(\left(x\right)\mathbf{\hat{i}_{A}}\right)' + + assert latex(Curl(3*A.x*A.j)) == \ + r"\nabla\times \left(\left(3 \mathbf{{x}_{A}}\right)\mathbf{\hat{j}_{A}}\right)" + assert latex(Curl(3*A.x*A.j+A.i)) == \ + r"\nabla\times \left(\mathbf{\hat{i}_{A}} + \left(3 \mathbf{{x}_{A}}\right)\mathbf{\hat{j}_{A}}\right)" + assert latex(Curl(3*x*A.x*A.j)) == \ + r"\nabla\times \left(\left(3 \mathbf{{x}_{A}} x\right)\mathbf{\hat{j}_{A}}\right)" + assert latex(x*Curl(3*A.x*A.j)) == \ + r"x \left(\nabla\times \left(\left(3 \mathbf{{x}_{A}}\right)\mathbf{\hat{j}_{A}}\right)\right)" + + assert latex(Divergence(3*A.x*A.j+A.i)) == \ + r"\nabla\cdot \left(\mathbf{\hat{i}_{A}} + \left(3 \mathbf{{x}_{A}}\right)\mathbf{\hat{j}_{A}}\right)" + assert latex(Divergence(3*A.x*A.j)) == \ + r"\nabla\cdot \left(\left(3 \mathbf{{x}_{A}}\right)\mathbf{\hat{j}_{A}}\right)" + assert latex(x*Divergence(3*A.x*A.j)) == \ + r"x \left(\nabla\cdot \left(\left(3 \mathbf{{x}_{A}}\right)\mathbf{\hat{j}_{A}}\right)\right)" + + assert latex(Dot(A.i, A.j*A.x*3+A.k)) == \ + r"\mathbf{\hat{i}_{A}} \cdot \left(\left(3 \mathbf{{x}_{A}}\right)\mathbf{\hat{j}_{A}} + \mathbf{\hat{k}_{A}}\right)" + assert latex(Dot(A.i, A.j)) == \ + r"\mathbf{\hat{i}_{A}} \cdot \mathbf{\hat{j}_{A}}" + assert latex(Dot(x*A.i, A.j)) == \ + r"\mathbf{\hat{j}_{A}} \cdot \left(\left(x\right)\mathbf{\hat{i}_{A}}\right)" + assert latex(x*Dot(A.i, A.j)) == \ + r"x \left(\mathbf{\hat{i}_{A}} \cdot \mathbf{\hat{j}_{A}}\right)" + + assert latex(Gradient(A.x)) == r"\nabla \mathbf{{x}_{A}}" + assert latex(Gradient(A.x + 3*A.y)) == \ + r"\nabla \left(\mathbf{{x}_{A}} + 3 \mathbf{{y}_{A}}\right)" + assert latex(x*Gradient(A.x)) == r"x \left(\nabla \mathbf{{x}_{A}}\right)" + assert latex(Gradient(x*A.x)) == r"\nabla \left(\mathbf{{x}_{A}} x\right)" + + assert latex(Laplacian(A.x)) == r"\Delta \mathbf{{x}_{A}}" + assert latex(Laplacian(A.x + 3*A.y)) == \ + r"\Delta \left(\mathbf{{x}_{A}} + 3 \mathbf{{y}_{A}}\right)" + assert latex(x*Laplacian(A.x)) == r"x \left(\Delta \mathbf{{x}_{A}}\right)" + assert latex(Laplacian(x*A.x)) == r"\Delta \left(\mathbf{{x}_{A}} x\right)" + +def test_latex_symbols(): + Gamma, lmbda, rho = symbols('Gamma, lambda, rho') + tau, Tau, TAU, taU = symbols('tau, Tau, TAU, taU') + assert latex(tau) == r"\tau" + assert latex(Tau) == r"\mathrm{T}" + assert latex(TAU) == r"\tau" + assert latex(taU) == r"\tau" + # Check that all capitalized greek letters are handled explicitly + capitalized_letters = {l.capitalize() for l in greek_letters_set} + assert len(capitalized_letters - set(tex_greek_dictionary.keys())) == 0 + assert latex(Gamma + lmbda) == r"\Gamma + \lambda" + assert latex(Gamma * lmbda) == r"\Gamma \lambda" + assert latex(Symbol('q1')) == r"q_{1}" + assert latex(Symbol('q21')) == r"q_{21}" + assert latex(Symbol('epsilon0')) == r"\epsilon_{0}" + assert latex(Symbol('omega1')) == r"\omega_{1}" + assert latex(Symbol('91')) == r"91" + assert latex(Symbol('alpha_new')) == r"\alpha_{new}" + assert latex(Symbol('C^orig')) == r"C^{orig}" + assert latex(Symbol('x^alpha')) == r"x^{\alpha}" + assert latex(Symbol('beta^alpha')) == r"\beta^{\alpha}" + assert latex(Symbol('e^Alpha')) == r"e^{\mathrm{A}}" + assert latex(Symbol('omega_alpha^beta')) == r"\omega^{\beta}_{\alpha}" + assert latex(Symbol('omega') ** Symbol('beta')) == r"\omega^{\beta}" + + +@XFAIL +def test_latex_symbols_failing(): + rho, mass, volume = symbols('rho, mass, volume') + assert latex( + volume * rho == mass) == r"\rho \mathrm{volume} = \mathrm{mass}" + assert latex(volume / mass * rho == 1) == \ + r"\rho \mathrm{volume} {\mathrm{mass}}^{(-1)} = 1" + assert latex(mass**3 * volume**3) == \ + r"{\mathrm{mass}}^{3} \cdot {\mathrm{volume}}^{3}" + + +@_both_exp_pow +def test_latex_functions(): + assert latex(exp(x)) == r"e^{x}" + assert latex(exp(1) + exp(2)) == r"e + e^{2}" + + f = Function('f') + assert latex(f(x)) == r'f{\left(x \right)}' + assert latex(f) == r'f' + + g = Function('g') + assert latex(g(x, y)) == r'g{\left(x,y \right)}' + assert latex(g) == r'g' + + h = Function('h') + assert latex(h(x, y, z)) == r'h{\left(x,y,z \right)}' + assert latex(h) == r'h' + + Li = Function('Li') + assert latex(Li) == r'\operatorname{Li}' + assert latex(Li(x)) == r'\operatorname{Li}{\left(x \right)}' + + mybeta = Function('beta') + # not to be confused with the beta function + assert latex(mybeta(x, y, z)) == r"\beta{\left(x,y,z \right)}" + assert latex(beta(x, y)) == r'\operatorname{B}\left(x, y\right)' + assert latex(beta(x, evaluate=False)) == r'\operatorname{B}\left(x, x\right)' + assert latex(beta(x, y)**2) == r'\operatorname{B}^{2}\left(x, y\right)' + assert latex(mybeta(x)) == r"\beta{\left(x \right)}" + assert latex(mybeta) == r"\beta" + + g = Function('gamma') + # not to be confused with the gamma function + assert latex(g(x, y, z)) == r"\gamma{\left(x,y,z \right)}" + assert latex(g(x)) == r"\gamma{\left(x \right)}" + assert latex(g) == r"\gamma" + + a_1 = Function('a_1') + assert latex(a_1) == r"a_{1}" + assert latex(a_1(x)) == r"a_{1}{\left(x \right)}" + assert latex(Function('a_1')) == r"a_{1}" + + # Issue #16925 + # multi letter function names + # > simple + assert latex(Function('ab')) == r"\operatorname{ab}" + assert latex(Function('ab1')) == r"\operatorname{ab}_{1}" + assert latex(Function('ab12')) == r"\operatorname{ab}_{12}" + assert latex(Function('ab_1')) == r"\operatorname{ab}_{1}" + assert latex(Function('ab_12')) == r"\operatorname{ab}_{12}" + assert latex(Function('ab_c')) == r"\operatorname{ab}_{c}" + assert latex(Function('ab_cd')) == r"\operatorname{ab}_{cd}" + # > with argument + assert latex(Function('ab')(Symbol('x'))) == r"\operatorname{ab}{\left(x \right)}" + assert latex(Function('ab1')(Symbol('x'))) == r"\operatorname{ab}_{1}{\left(x \right)}" + assert latex(Function('ab12')(Symbol('x'))) == r"\operatorname{ab}_{12}{\left(x \right)}" + assert latex(Function('ab_1')(Symbol('x'))) == r"\operatorname{ab}_{1}{\left(x \right)}" + assert latex(Function('ab_c')(Symbol('x'))) == r"\operatorname{ab}_{c}{\left(x \right)}" + assert latex(Function('ab_cd')(Symbol('x'))) == r"\operatorname{ab}_{cd}{\left(x \right)}" + + # > with power + # does not work on functions without brackets + + # > with argument and power combined + assert latex(Function('ab')()**2) == r"\operatorname{ab}^{2}{\left( \right)}" + assert latex(Function('ab1')()**2) == r"\operatorname{ab}_{1}^{2}{\left( \right)}" + assert latex(Function('ab12')()**2) == r"\operatorname{ab}_{12}^{2}{\left( \right)}" + assert latex(Function('ab_1')()**2) == r"\operatorname{ab}_{1}^{2}{\left( \right)}" + assert latex(Function('ab_12')()**2) == r"\operatorname{ab}_{12}^{2}{\left( \right)}" + assert latex(Function('ab')(Symbol('x'))**2) == r"\operatorname{ab}^{2}{\left(x \right)}" + assert latex(Function('ab1')(Symbol('x'))**2) == r"\operatorname{ab}_{1}^{2}{\left(x \right)}" + assert latex(Function('ab12')(Symbol('x'))**2) == r"\operatorname{ab}_{12}^{2}{\left(x \right)}" + assert latex(Function('ab_1')(Symbol('x'))**2) == r"\operatorname{ab}_{1}^{2}{\left(x \right)}" + assert latex(Function('ab_12')(Symbol('x'))**2) == \ + r"\operatorname{ab}_{12}^{2}{\left(x \right)}" + + # single letter function names + # > simple + assert latex(Function('a')) == r"a" + assert latex(Function('a1')) == r"a_{1}" + assert latex(Function('a12')) == r"a_{12}" + assert latex(Function('a_1')) == r"a_{1}" + assert latex(Function('a_12')) == r"a_{12}" + + # > with argument + assert latex(Function('a')()) == r"a{\left( \right)}" + assert latex(Function('a1')()) == r"a_{1}{\left( \right)}" + assert latex(Function('a12')()) == r"a_{12}{\left( \right)}" + assert latex(Function('a_1')()) == r"a_{1}{\left( \right)}" + assert latex(Function('a_12')()) == r"a_{12}{\left( \right)}" + + # > with power + # does not work on functions without brackets + + # > with argument and power combined + assert latex(Function('a')()**2) == r"a^{2}{\left( \right)}" + assert latex(Function('a1')()**2) == r"a_{1}^{2}{\left( \right)}" + assert latex(Function('a12')()**2) == r"a_{12}^{2}{\left( \right)}" + assert latex(Function('a_1')()**2) == r"a_{1}^{2}{\left( \right)}" + assert latex(Function('a_12')()**2) == r"a_{12}^{2}{\left( \right)}" + assert latex(Function('a')(Symbol('x'))**2) == r"a^{2}{\left(x \right)}" + assert latex(Function('a1')(Symbol('x'))**2) == r"a_{1}^{2}{\left(x \right)}" + assert latex(Function('a12')(Symbol('x'))**2) == r"a_{12}^{2}{\left(x \right)}" + assert latex(Function('a_1')(Symbol('x'))**2) == r"a_{1}^{2}{\left(x \right)}" + assert latex(Function('a_12')(Symbol('x'))**2) == r"a_{12}^{2}{\left(x \right)}" + + assert latex(Function('a')()**32) == r"a^{32}{\left( \right)}" + assert latex(Function('a1')()**32) == r"a_{1}^{32}{\left( \right)}" + assert latex(Function('a12')()**32) == r"a_{12}^{32}{\left( \right)}" + assert latex(Function('a_1')()**32) == r"a_{1}^{32}{\left( \right)}" + assert latex(Function('a_12')()**32) == r"a_{12}^{32}{\left( \right)}" + assert latex(Function('a')(Symbol('x'))**32) == r"a^{32}{\left(x \right)}" + assert latex(Function('a1')(Symbol('x'))**32) == r"a_{1}^{32}{\left(x \right)}" + assert latex(Function('a12')(Symbol('x'))**32) == r"a_{12}^{32}{\left(x \right)}" + assert latex(Function('a_1')(Symbol('x'))**32) == r"a_{1}^{32}{\left(x \right)}" + assert latex(Function('a_12')(Symbol('x'))**32) == r"a_{12}^{32}{\left(x \right)}" + + assert latex(Function('a')()**a) == r"a^{a}{\left( \right)}" + assert latex(Function('a1')()**a) == r"a_{1}^{a}{\left( \right)}" + assert latex(Function('a12')()**a) == r"a_{12}^{a}{\left( \right)}" + assert latex(Function('a_1')()**a) == r"a_{1}^{a}{\left( \right)}" + assert latex(Function('a_12')()**a) == r"a_{12}^{a}{\left( \right)}" + assert latex(Function('a')(Symbol('x'))**a) == r"a^{a}{\left(x \right)}" + assert latex(Function('a1')(Symbol('x'))**a) == r"a_{1}^{a}{\left(x \right)}" + assert latex(Function('a12')(Symbol('x'))**a) == r"a_{12}^{a}{\left(x \right)}" + assert latex(Function('a_1')(Symbol('x'))**a) == r"a_{1}^{a}{\left(x \right)}" + assert latex(Function('a_12')(Symbol('x'))**a) == r"a_{12}^{a}{\left(x \right)}" + + ab = Symbol('ab') + assert latex(Function('a')()**ab) == r"a^{ab}{\left( \right)}" + assert latex(Function('a1')()**ab) == r"a_{1}^{ab}{\left( \right)}" + assert latex(Function('a12')()**ab) == r"a_{12}^{ab}{\left( \right)}" + assert latex(Function('a_1')()**ab) == r"a_{1}^{ab}{\left( \right)}" + assert latex(Function('a_12')()**ab) == r"a_{12}^{ab}{\left( \right)}" + assert latex(Function('a')(Symbol('x'))**ab) == r"a^{ab}{\left(x \right)}" + assert latex(Function('a1')(Symbol('x'))**ab) == r"a_{1}^{ab}{\left(x \right)}" + assert latex(Function('a12')(Symbol('x'))**ab) == r"a_{12}^{ab}{\left(x \right)}" + assert latex(Function('a_1')(Symbol('x'))**ab) == r"a_{1}^{ab}{\left(x \right)}" + assert latex(Function('a_12')(Symbol('x'))**ab) == r"a_{12}^{ab}{\left(x \right)}" + + assert latex(Function('a^12')(x)) == R"a^{12}{\left(x \right)}" + assert latex(Function('a^12')(x) ** ab) == R"\left(a^{12}\right)^{ab}{\left(x \right)}" + assert latex(Function('a__12')(x)) == R"a^{12}{\left(x \right)}" + assert latex(Function('a__12')(x) ** ab) == R"\left(a^{12}\right)^{ab}{\left(x \right)}" + assert latex(Function('a_1__1_2')(x)) == R"a^{1}_{1 2}{\left(x \right)}" + + # issue 5868 + omega1 = Function('omega1') + assert latex(omega1) == r"\omega_{1}" + assert latex(omega1(x)) == r"\omega_{1}{\left(x \right)}" + + assert latex(sin(x)) == r"\sin{\left(x \right)}" + assert latex(sin(x), fold_func_brackets=True) == r"\sin {x}" + assert latex(sin(2*x**2), fold_func_brackets=True) == \ + r"\sin {2 x^{2}}" + assert latex(sin(x**2), fold_func_brackets=True) == \ + r"\sin {x^{2}}" + + assert latex(asin(x)**2) == r"\operatorname{asin}^{2}{\left(x \right)}" + assert latex(asin(x)**2, inv_trig_style="full") == \ + r"\arcsin^{2}{\left(x \right)}" + assert latex(asin(x)**2, inv_trig_style="power") == \ + r"\sin^{-1}{\left(x \right)}^{2}" + assert latex(asin(x**2), inv_trig_style="power", + fold_func_brackets=True) == \ + r"\sin^{-1} {x^{2}}" + assert latex(acsc(x), inv_trig_style="full") == \ + r"\operatorname{arccsc}{\left(x \right)}" + assert latex(asinh(x), inv_trig_style="full") == \ + r"\operatorname{arsinh}{\left(x \right)}" + + assert latex(factorial(k)) == r"k!" + assert latex(factorial(-k)) == r"\left(- k\right)!" + assert latex(factorial(k)**2) == r"k!^{2}" + + assert latex(subfactorial(k)) == r"!k" + assert latex(subfactorial(-k)) == r"!\left(- k\right)" + assert latex(subfactorial(k)**2) == r"\left(!k\right)^{2}" + + assert latex(factorial2(k)) == r"k!!" + assert latex(factorial2(-k)) == r"\left(- k\right)!!" + assert latex(factorial2(k)**2) == r"k!!^{2}" + + assert latex(binomial(2, k)) == r"{\binom{2}{k}}" + assert latex(binomial(2, k)**2) == r"{\binom{2}{k}}^{2}" + + assert latex(FallingFactorial(3, k)) == r"{\left(3\right)}_{k}" + assert latex(RisingFactorial(3, k)) == r"{3}^{\left(k\right)}" + + assert latex(floor(x)) == r"\left\lfloor{x}\right\rfloor" + assert latex(ceiling(x)) == r"\left\lceil{x}\right\rceil" + assert latex(frac(x)) == r"\operatorname{frac}{\left(x\right)}" + assert latex(floor(x)**2) == r"\left\lfloor{x}\right\rfloor^{2}" + assert latex(ceiling(x)**2) == r"\left\lceil{x}\right\rceil^{2}" + assert latex(frac(x)**2) == r"\operatorname{frac}{\left(x\right)}^{2}" + + assert latex(Min(x, 2, x**3)) == r"\min\left(2, x, x^{3}\right)" + assert latex(Min(x, y)**2) == r"\min\left(x, y\right)^{2}" + assert latex(Max(x, 2, x**3)) == r"\max\left(2, x, x^{3}\right)" + assert latex(Max(x, y)**2) == r"\max\left(x, y\right)^{2}" + assert latex(Abs(x)) == r"\left|{x}\right|" + assert latex(Abs(x)**2) == r"\left|{x}\right|^{2}" + assert latex(re(x)) == r"\operatorname{re}{\left(x\right)}" + assert latex(re(x + y)) == \ + r"\operatorname{re}{\left(x\right)} + \operatorname{re}{\left(y\right)}" + assert latex(im(x)) == r"\operatorname{im}{\left(x\right)}" + assert latex(conjugate(x)) == r"\overline{x}" + assert latex(conjugate(x)**2) == r"\overline{x}^{2}" + assert latex(conjugate(x**2)) == r"\overline{x}^{2}" + assert latex(gamma(x)) == r"\Gamma\left(x\right)" + w = Wild('w') + assert latex(gamma(w)) == r"\Gamma\left(w\right)" + assert latex(Order(x)) == r"O\left(x\right)" + assert latex(Order(x, x)) == r"O\left(x\right)" + assert latex(Order(x, (x, 0))) == r"O\left(x\right)" + assert latex(Order(x, (x, oo))) == r"O\left(x; x\rightarrow \infty\right)" + assert latex(Order(x - y, (x, y))) == \ + r"O\left(x - y; x\rightarrow y\right)" + assert latex(Order(x, x, y)) == \ + r"O\left(x; \left( x, \ y\right)\rightarrow \left( 0, \ 0\right)\right)" + assert latex(Order(x, x, y)) == \ + r"O\left(x; \left( x, \ y\right)\rightarrow \left( 0, \ 0\right)\right)" + assert latex(Order(x, (x, oo), (y, oo))) == \ + r"O\left(x; \left( x, \ y\right)\rightarrow \left( \infty, \ \infty\right)\right)" + assert latex(lowergamma(x, y)) == r'\gamma\left(x, y\right)' + assert latex(lowergamma(x, y)**2) == r'\gamma^{2}\left(x, y\right)' + assert latex(uppergamma(x, y)) == r'\Gamma\left(x, y\right)' + assert latex(uppergamma(x, y)**2) == r'\Gamma^{2}\left(x, y\right)' + + assert latex(cot(x)) == r'\cot{\left(x \right)}' + assert latex(coth(x)) == r'\coth{\left(x \right)}' + assert latex(re(x)) == r'\operatorname{re}{\left(x\right)}' + assert latex(im(x)) == r'\operatorname{im}{\left(x\right)}' + assert latex(root(x, y)) == r'x^{\frac{1}{y}}' + assert latex(arg(x)) == r'\arg{\left(x \right)}' + + assert latex(zeta(x)) == r"\zeta\left(x\right)" + assert latex(zeta(x)**2) == r"\zeta^{2}\left(x\right)" + assert latex(zeta(x, y)) == r"\zeta\left(x, y\right)" + assert latex(zeta(x, y)**2) == r"\zeta^{2}\left(x, y\right)" + assert latex(dirichlet_eta(x)) == r"\eta\left(x\right)" + assert latex(dirichlet_eta(x)**2) == r"\eta^{2}\left(x\right)" + assert latex(polylog(x, y)) == r"\operatorname{Li}_{x}\left(y\right)" + assert latex( + polylog(x, y)**2) == r"\operatorname{Li}_{x}^{2}\left(y\right)" + assert latex(lerchphi(x, y, n)) == r"\Phi\left(x, y, n\right)" + assert latex(lerchphi(x, y, n)**2) == r"\Phi^{2}\left(x, y, n\right)" + assert latex(stieltjes(x)) == r"\gamma_{x}" + assert latex(stieltjes(x)**2) == r"\gamma_{x}^{2}" + assert latex(stieltjes(x, y)) == r"\gamma_{x}\left(y\right)" + assert latex(stieltjes(x, y)**2) == r"\gamma_{x}\left(y\right)^{2}" + + assert latex(elliptic_k(z)) == r"K\left(z\right)" + assert latex(elliptic_k(z)**2) == r"K^{2}\left(z\right)" + assert latex(elliptic_f(x, y)) == r"F\left(x\middle| y\right)" + assert latex(elliptic_f(x, y)**2) == r"F^{2}\left(x\middle| y\right)" + assert latex(elliptic_e(x, y)) == r"E\left(x\middle| y\right)" + assert latex(elliptic_e(x, y)**2) == r"E^{2}\left(x\middle| y\right)" + assert latex(elliptic_e(z)) == r"E\left(z\right)" + assert latex(elliptic_e(z)**2) == r"E^{2}\left(z\right)" + assert latex(elliptic_pi(x, y, z)) == r"\Pi\left(x; y\middle| z\right)" + assert latex(elliptic_pi(x, y, z)**2) == \ + r"\Pi^{2}\left(x; y\middle| z\right)" + assert latex(elliptic_pi(x, y)) == r"\Pi\left(x\middle| y\right)" + assert latex(elliptic_pi(x, y)**2) == r"\Pi^{2}\left(x\middle| y\right)" + + assert latex(Ei(x)) == r'\operatorname{Ei}{\left(x \right)}' + assert latex(Ei(x)**2) == r'\operatorname{Ei}^{2}{\left(x \right)}' + assert latex(expint(x, y)) == r'\operatorname{E}_{x}\left(y\right)' + assert latex(expint(x, y)**2) == r'\operatorname{E}_{x}^{2}\left(y\right)' + assert latex(Shi(x)**2) == r'\operatorname{Shi}^{2}{\left(x \right)}' + assert latex(Si(x)**2) == r'\operatorname{Si}^{2}{\left(x \right)}' + assert latex(Ci(x)**2) == r'\operatorname{Ci}^{2}{\left(x \right)}' + assert latex(Chi(x)**2) == r'\operatorname{Chi}^{2}\left(x\right)' + assert latex(Chi(x)) == r'\operatorname{Chi}\left(x\right)' + assert latex(jacobi(n, a, b, x)) == \ + r'P_{n}^{\left(a,b\right)}\left(x\right)' + assert latex(jacobi(n, a, b, x)**2) == \ + r'\left(P_{n}^{\left(a,b\right)}\left(x\right)\right)^{2}' + assert latex(gegenbauer(n, a, x)) == \ + r'C_{n}^{\left(a\right)}\left(x\right)' + assert latex(gegenbauer(n, a, x)**2) == \ + r'\left(C_{n}^{\left(a\right)}\left(x\right)\right)^{2}' + assert latex(chebyshevt(n, x)) == r'T_{n}\left(x\right)' + assert latex(chebyshevt(n, x)**2) == \ + r'\left(T_{n}\left(x\right)\right)^{2}' + assert latex(chebyshevu(n, x)) == r'U_{n}\left(x\right)' + assert latex(chebyshevu(n, x)**2) == \ + r'\left(U_{n}\left(x\right)\right)^{2}' + assert latex(legendre(n, x)) == r'P_{n}\left(x\right)' + assert latex(legendre(n, x)**2) == r'\left(P_{n}\left(x\right)\right)^{2}' + assert latex(assoc_legendre(n, a, x)) == \ + r'P_{n}^{\left(a\right)}\left(x\right)' + assert latex(assoc_legendre(n, a, x)**2) == \ + r'\left(P_{n}^{\left(a\right)}\left(x\right)\right)^{2}' + assert latex(laguerre(n, x)) == r'L_{n}\left(x\right)' + assert latex(laguerre(n, x)**2) == r'\left(L_{n}\left(x\right)\right)^{2}' + assert latex(assoc_laguerre(n, a, x)) == \ + r'L_{n}^{\left(a\right)}\left(x\right)' + assert latex(assoc_laguerre(n, a, x)**2) == \ + r'\left(L_{n}^{\left(a\right)}\left(x\right)\right)^{2}' + assert latex(hermite(n, x)) == r'H_{n}\left(x\right)' + assert latex(hermite(n, x)**2) == r'\left(H_{n}\left(x\right)\right)^{2}' + + theta = Symbol("theta", real=True) + phi = Symbol("phi", real=True) + assert latex(Ynm(n, m, theta, phi)) == r'Y_{n}^{m}\left(\theta,\phi\right)' + assert latex(Ynm(n, m, theta, phi)**3) == \ + r'\left(Y_{n}^{m}\left(\theta,\phi\right)\right)^{3}' + assert latex(Znm(n, m, theta, phi)) == r'Z_{n}^{m}\left(\theta,\phi\right)' + assert latex(Znm(n, m, theta, phi)**3) == \ + r'\left(Z_{n}^{m}\left(\theta,\phi\right)\right)^{3}' + + # Test latex printing of function names with "_" + assert latex(polar_lift(0)) == \ + r"\operatorname{polar\_lift}{\left(0 \right)}" + assert latex(polar_lift(0)**3) == \ + r"\operatorname{polar\_lift}^{3}{\left(0 \right)}" + + assert latex(totient(n)) == r'\phi\left(n\right)' + assert latex(totient(n) ** 2) == r'\left(\phi\left(n\right)\right)^{2}' + + assert latex(reduced_totient(n)) == r'\lambda\left(n\right)' + assert latex(reduced_totient(n) ** 2) == \ + r'\left(\lambda\left(n\right)\right)^{2}' + + assert latex(divisor_sigma(x)) == r"\sigma\left(x\right)" + assert latex(divisor_sigma(x)**2) == r"\sigma^{2}\left(x\right)" + assert latex(divisor_sigma(x, y)) == r"\sigma_y\left(x\right)" + assert latex(divisor_sigma(x, y)**2) == r"\sigma^{2}_y\left(x\right)" + + assert latex(udivisor_sigma(x)) == r"\sigma^*\left(x\right)" + assert latex(udivisor_sigma(x)**2) == r"\sigma^*^{2}\left(x\right)" + assert latex(udivisor_sigma(x, y)) == r"\sigma^*_y\left(x\right)" + assert latex(udivisor_sigma(x, y)**2) == r"\sigma^*^{2}_y\left(x\right)" + + assert latex(primenu(n)) == r'\nu\left(n\right)' + assert latex(primenu(n) ** 2) == r'\left(\nu\left(n\right)\right)^{2}' + + assert latex(primeomega(n)) == r'\Omega\left(n\right)' + assert latex(primeomega(n) ** 2) == \ + r'\left(\Omega\left(n\right)\right)^{2}' + + assert latex(LambertW(n)) == r'W\left(n\right)' + assert latex(LambertW(n, -1)) == r'W_{-1}\left(n\right)' + assert latex(LambertW(n, k)) == r'W_{k}\left(n\right)' + assert latex(LambertW(n) * LambertW(n)) == r"W^{2}\left(n\right)" + assert latex(Pow(LambertW(n), 2)) == r"W^{2}\left(n\right)" + assert latex(LambertW(n)**k) == r"W^{k}\left(n\right)" + assert latex(LambertW(n, k)**p) == r"W^{p}_{k}\left(n\right)" + + assert latex(Mod(x, 7)) == r'x \bmod 7' + assert latex(Mod(x + 1, 7)) == r'\left(x + 1\right) \bmod 7' + assert latex(Mod(7, x + 1)) == r'7 \bmod \left(x + 1\right)' + assert latex(Mod(2 * x, 7)) == r'2 x \bmod 7' + assert latex(Mod(7, 2 * x)) == r'7 \bmod 2 x' + assert latex(Mod(x, 7) + 1) == r'\left(x \bmod 7\right) + 1' + assert latex(2 * Mod(x, 7)) == r'2 \left(x \bmod 7\right)' + assert latex(Mod(7, 2 * x)**n) == r'\left(7 \bmod 2 x\right)^{n}' + + # some unknown function name should get rendered with \operatorname + fjlkd = Function('fjlkd') + assert latex(fjlkd(x)) == r'\operatorname{fjlkd}{\left(x \right)}' + # even when it is referred to without an argument + assert latex(fjlkd) == r'\operatorname{fjlkd}' + + +# test that notation passes to subclasses of the same name only +def test_function_subclass_different_name(): + class mygamma(gamma): + pass + assert latex(mygamma) == r"\operatorname{mygamma}" + assert latex(mygamma(x)) == r"\operatorname{mygamma}{\left(x \right)}" + + +def test_hyper_printing(): + from sympy.abc import x, z + + assert latex(meijerg(Tuple(pi, pi, x), Tuple(1), + (0, 1), Tuple(1, 2, 3/pi), z)) == \ + r'{G_{4, 5}^{2, 3}\left(\begin{matrix} \pi, \pi, x & 1 \\0, 1 & 1, 2, '\ + r'\frac{3}{\pi} \end{matrix} \middle| {z} \right)}' + assert latex(meijerg(Tuple(), Tuple(1), (0,), Tuple(), z)) == \ + r'{G_{1, 1}^{1, 0}\left(\begin{matrix} & 1 \\0 & \end{matrix} \middle| {z} \right)}' + assert latex(hyper((x, 2), (3,), z)) == \ + r'{{}_{2}F_{1}\left(\begin{matrix} 2, x ' \ + r'\\ 3 \end{matrix}\middle| {z} \right)}' + assert latex(hyper(Tuple(), Tuple(1), z)) == \ + r'{{}_{0}F_{1}\left(\begin{matrix} ' \ + r'\\ 1 \end{matrix}\middle| {z} \right)}' + + +def test_latex_bessel(): + from sympy.functions.special.bessel import (besselj, bessely, besseli, + besselk, hankel1, hankel2, + jn, yn, hn1, hn2) + from sympy.abc import z + assert latex(besselj(n, z**2)**k) == r'J^{k}_{n}\left(z^{2}\right)' + assert latex(bessely(n, z)) == r'Y_{n}\left(z\right)' + assert latex(besseli(n, z)) == r'I_{n}\left(z\right)' + assert latex(besselk(n, z)) == r'K_{n}\left(z\right)' + assert latex(hankel1(n, z**2)**2) == \ + r'\left(H^{(1)}_{n}\left(z^{2}\right)\right)^{2}' + assert latex(hankel2(n, z)) == r'H^{(2)}_{n}\left(z\right)' + assert latex(jn(n, z)) == r'j_{n}\left(z\right)' + assert latex(yn(n, z)) == r'y_{n}\left(z\right)' + assert latex(hn1(n, z)) == r'h^{(1)}_{n}\left(z\right)' + assert latex(hn2(n, z)) == r'h^{(2)}_{n}\left(z\right)' + + +def test_latex_fresnel(): + from sympy.functions.special.error_functions import (fresnels, fresnelc) + from sympy.abc import z + assert latex(fresnels(z)) == r'S\left(z\right)' + assert latex(fresnelc(z)) == r'C\left(z\right)' + assert latex(fresnels(z)**2) == r'S^{2}\left(z\right)' + assert latex(fresnelc(z)**2) == r'C^{2}\left(z\right)' + + +def test_latex_brackets(): + assert latex((-1)**x) == r"\left(-1\right)^{x}" + + +def test_latex_indexed(): + Psi_symbol = Symbol('Psi_0', complex=True, real=False) + Psi_indexed = IndexedBase(Symbol('Psi', complex=True, real=False)) + symbol_latex = latex(Psi_symbol * conjugate(Psi_symbol)) + indexed_latex = latex(Psi_indexed[0] * conjugate(Psi_indexed[0])) + # \\overline{{\\Psi}_{0}} {\\Psi}_{0} vs. \\Psi_{0} \\overline{\\Psi_{0}} + assert symbol_latex == r'\Psi_{0} \overline{\Psi_{0}}' + assert indexed_latex == r'\overline{{\Psi}_{0}} {\Psi}_{0}' + + # Symbol('gamma') gives r'\gamma' + interval = '\\mathrel{..}\\nobreak ' + assert latex(Indexed('x1', Symbol('i'))) == r'{x_{1}}_{i}' + assert latex(Indexed('x2', Idx('i'))) == r'{x_{2}}_{i}' + assert latex(Indexed('x3', Idx('i', Symbol('N')))) == r'{x_{3}}_{{i}_{0'+interval+'N - 1}}' + assert latex(Indexed('x3', Idx('i', Symbol('N')+1))) == r'{x_{3}}_{{i}_{0'+interval+'N}}' + assert latex(Indexed('x4', Idx('i', (Symbol('a'),Symbol('b'))))) == r'{x_{4}}_{{i}_{a'+interval+'b}}' + assert latex(IndexedBase('gamma')) == r'\gamma' + assert latex(IndexedBase('a b')) == r'a b' + assert latex(IndexedBase('a_b')) == r'a_{b}' + + +def test_latex_derivatives(): + # regular "d" for ordinary derivatives + assert latex(diff(x**3, x, evaluate=False)) == \ + r"\frac{d}{d x} x^{3}" + assert latex(diff(sin(x) + x**2, x, evaluate=False)) == \ + r"\frac{d}{d x} \left(x^{2} + \sin{\left(x \right)}\right)" + assert latex(diff(diff(sin(x) + x**2, x, evaluate=False), evaluate=False))\ + == \ + r"\frac{d^{2}}{d x^{2}} \left(x^{2} + \sin{\left(x \right)}\right)" + assert latex(diff(diff(diff(sin(x) + x**2, x, evaluate=False), evaluate=False), evaluate=False)) == \ + r"\frac{d^{3}}{d x^{3}} \left(x^{2} + \sin{\left(x \right)}\right)" + + # \partial for partial derivatives + assert latex(diff(sin(x * y), x, evaluate=False)) == \ + r"\frac{\partial}{\partial x} \sin{\left(x y \right)}" + assert latex(diff(sin(x * y) + x**2, x, evaluate=False)) == \ + r"\frac{\partial}{\partial x} \left(x^{2} + \sin{\left(x y \right)}\right)" + assert latex(diff(diff(sin(x*y) + x**2, x, evaluate=False), x, evaluate=False)) == \ + r"\frac{\partial^{2}}{\partial x^{2}} \left(x^{2} + \sin{\left(x y \right)}\right)" + assert latex(diff(diff(diff(sin(x*y) + x**2, x, evaluate=False), x, evaluate=False), x, evaluate=False)) == \ + r"\frac{\partial^{3}}{\partial x^{3}} \left(x^{2} + \sin{\left(x y \right)}\right)" + + # mixed partial derivatives + f = Function("f") + assert latex(diff(diff(f(x, y), x, evaluate=False), y, evaluate=False)) == \ + r"\frac{\partial^{2}}{\partial y\partial x} " + latex(f(x, y)) + + assert latex(diff(diff(diff(f(x, y), x, evaluate=False), x, evaluate=False), y, evaluate=False)) == \ + r"\frac{\partial^{3}}{\partial y\partial x^{2}} " + latex(f(x, y)) + + # for negative nested Derivative + assert latex(diff(-diff(y**2,x,evaluate=False),x,evaluate=False)) == r'\frac{d}{d x} \left(- \frac{d}{d x} y^{2}\right)' + assert latex(diff(diff(-diff(diff(y,x,evaluate=False),x,evaluate=False),x,evaluate=False),x,evaluate=False)) == \ + r'\frac{d^{2}}{d x^{2}} \left(- \frac{d^{2}}{d x^{2}} y\right)' + + # use ordinary d when one of the variables has been integrated out + assert latex(diff(Integral(exp(-x*y), (x, 0, oo)), y, evaluate=False)) == \ + r"\frac{d}{d y} \int\limits_{0}^{\infty} e^{- x y}\, dx" + + # Derivative wrapped in power: + assert latex(diff(x, x, evaluate=False)**2) == \ + r"\left(\frac{d}{d x} x\right)^{2}" + + assert latex(diff(f(x), x)**2) == \ + r"\left(\frac{d}{d x} f{\left(x \right)}\right)^{2}" + + assert latex(diff(f(x), (x, n))) == \ + r"\frac{d^{n}}{d x^{n}} f{\left(x \right)}" + + x1 = Symbol('x1') + x2 = Symbol('x2') + assert latex(diff(f(x1, x2), x1)) == r'\frac{\partial}{\partial x_{1}} f{\left(x_{1},x_{2} \right)}' + + n1 = Symbol('n1') + assert latex(diff(f(x), (x, n1))) == r'\frac{d^{n_{1}}}{d x^{n_{1}}} f{\left(x \right)}' + + n2 = Symbol('n2') + assert latex(diff(f(x), (x, Max(n1, n2)))) == \ + r'\frac{d^{\max\left(n_{1}, n_{2}\right)}}{d x^{\max\left(n_{1}, n_{2}\right)}} f{\left(x \right)}' + + # set diff operator + assert latex(diff(f(x), x), diff_operator="rd") == r'\frac{\mathrm{d}}{\mathrm{d} x} f{\left(x \right)}' + + +def test_latex_subs(): + assert latex(Subs(x*y, (x, y), (1, 2))) == r'\left. x y \right|_{\substack{ x=1\\ y=2 }}' + + +def test_latex_integrals(): + assert latex(Integral(log(x), x)) == r"\int \log{\left(x \right)}\, dx" + assert latex(Integral(x**2, (x, 0, 1))) == \ + r"\int\limits_{0}^{1} x^{2}\, dx" + assert latex(Integral(x**2, (x, 10, 20))) == \ + r"\int\limits_{10}^{20} x^{2}\, dx" + assert latex(Integral(y*x**2, (x, 0, 1), y)) == \ + r"\int\int\limits_{0}^{1} x^{2} y\, dx\, dy" + assert latex(Integral(y*x**2, (x, 0, 1), y), mode='equation*') == \ + r"\begin{equation*}\int\int\limits_{0}^{1} x^{2} y\, dx\, dy\end{equation*}" + assert latex(Integral(y*x**2, (x, 0, 1), y), mode='equation*', itex=True) \ + == r"$$\int\int_{0}^{1} x^{2} y\, dx\, dy$$" + assert latex(Integral(x, (x, 0))) == r"\int\limits^{0} x\, dx" + assert latex(Integral(x*y, x, y)) == r"\iint x y\, dx\, dy" + assert latex(Integral(x*y*z, x, y, z)) == r"\iiint x y z\, dx\, dy\, dz" + assert latex(Integral(x*y*z*t, x, y, z, t)) == \ + r"\iiiint t x y z\, dx\, dy\, dz\, dt" + assert latex(Integral(x, x, x, x, x, x, x)) == \ + r"\int\int\int\int\int\int x\, dx\, dx\, dx\, dx\, dx\, dx" + assert latex(Integral(x, x, y, (z, 0, 1))) == \ + r"\int\limits_{0}^{1}\int\int x\, dx\, dy\, dz" + + # for negative nested Integral + assert latex(Integral(-Integral(y**2,x),x)) == \ + r'\int \left(- \int y^{2}\, dx\right)\, dx' + assert latex(Integral(-Integral(-Integral(y,x),x),x)) == \ + r'\int \left(- \int \left(- \int y\, dx\right)\, dx\right)\, dx' + + # fix issue #10806 + assert latex(Integral(z, z)**2) == r"\left(\int z\, dz\right)^{2}" + assert latex(Integral(x + z, z)) == r"\int \left(x + z\right)\, dz" + assert latex(Integral(x+z/2, z)) == \ + r"\int \left(x + \frac{z}{2}\right)\, dz" + assert latex(Integral(x**y, z)) == r"\int x^{y}\, dz" + + # set diff operator + assert latex(Integral(x, x), diff_operator="rd") == r'\int x\, \mathrm{d}x' + assert latex(Integral(x, (x, 0, 1)), diff_operator="rd") == r'\int\limits_{0}^{1} x\, \mathrm{d}x' + + +def test_latex_sets(): + for s in (frozenset, set): + assert latex(s([x*y, x**2])) == r"\left\{x^{2}, x y\right\}" + assert latex(s(range(1, 6))) == r"\left\{1, 2, 3, 4, 5\right\}" + assert latex(s(range(1, 13))) == \ + r"\left\{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12\right\}" + + s = FiniteSet + assert latex(s(*[x*y, x**2])) == r"\left\{x^{2}, x y\right\}" + assert latex(s(*range(1, 6))) == r"\left\{1, 2, 3, 4, 5\right\}" + assert latex(s(*range(1, 13))) == \ + r"\left\{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12\right\}" + + +def test_latex_SetExpr(): + iv = Interval(1, 3) + se = SetExpr(iv) + assert latex(se) == r"SetExpr\left(\left[1, 3\right]\right)" + + +def test_latex_Range(): + assert latex(Range(1, 51)) == r'\left\{1, 2, \ldots, 50\right\}' + assert latex(Range(1, 4)) == r'\left\{1, 2, 3\right\}' + assert latex(Range(0, 3, 1)) == r'\left\{0, 1, 2\right\}' + assert latex(Range(0, 30, 1)) == r'\left\{0, 1, \ldots, 29\right\}' + assert latex(Range(30, 1, -1)) == r'\left\{30, 29, \ldots, 2\right\}' + assert latex(Range(0, oo, 2)) == r'\left\{0, 2, \ldots\right\}' + assert latex(Range(oo, -2, -2)) == r'\left\{\ldots, 2, 0\right\}' + assert latex(Range(-2, -oo, -1)) == r'\left\{-2, -3, \ldots\right\}' + assert latex(Range(-oo, oo)) == r'\left\{\ldots, -1, 0, 1, \ldots\right\}' + assert latex(Range(oo, -oo, -1)) == r'\left\{\ldots, 1, 0, -1, \ldots\right\}' + + a, b, c = symbols('a:c') + assert latex(Range(a, b, c)) == r'\text{Range}\left(a, b, c\right)' + assert latex(Range(a, 10, 1)) == r'\text{Range}\left(a, 10\right)' + assert latex(Range(0, b, 1)) == r'\text{Range}\left(b\right)' + assert latex(Range(0, 10, c)) == r'\text{Range}\left(0, 10, c\right)' + + i = Symbol('i', integer=True) + n = Symbol('n', negative=True, integer=True) + p = Symbol('p', positive=True, integer=True) + + assert latex(Range(i, i + 3)) == r'\left\{i, i + 1, i + 2\right\}' + assert latex(Range(-oo, n, 2)) == r'\left\{\ldots, n - 4, n - 2\right\}' + assert latex(Range(p, oo)) == r'\left\{p, p + 1, \ldots\right\}' + # The following will work if __iter__ is improved + # assert latex(Range(-3, p + 7)) == r'\left\{-3, -2, \ldots, p + 6\right\}' + # Must have integer assumptions + assert latex(Range(a, a + 3)) == r'\text{Range}\left(a, a + 3\right)' + + +def test_latex_sequences(): + s1 = SeqFormula(a**2, (0, oo)) + s2 = SeqPer((1, 2)) + + latex_str = r'\left[0, 1, 4, 9, \ldots\right]' + assert latex(s1) == latex_str + + latex_str = r'\left[1, 2, 1, 2, \ldots\right]' + assert latex(s2) == latex_str + + s3 = SeqFormula(a**2, (0, 2)) + s4 = SeqPer((1, 2), (0, 2)) + + latex_str = r'\left[0, 1, 4\right]' + assert latex(s3) == latex_str + + latex_str = r'\left[1, 2, 1\right]' + assert latex(s4) == latex_str + + s5 = SeqFormula(a**2, (-oo, 0)) + s6 = SeqPer((1, 2), (-oo, 0)) + + latex_str = r'\left[\ldots, 9, 4, 1, 0\right]' + assert latex(s5) == latex_str + + latex_str = r'\left[\ldots, 2, 1, 2, 1\right]' + assert latex(s6) == latex_str + + latex_str = r'\left[1, 3, 5, 11, \ldots\right]' + assert latex(SeqAdd(s1, s2)) == latex_str + + latex_str = r'\left[1, 3, 5\right]' + assert latex(SeqAdd(s3, s4)) == latex_str + + latex_str = r'\left[\ldots, 11, 5, 3, 1\right]' + assert latex(SeqAdd(s5, s6)) == latex_str + + latex_str = r'\left[0, 2, 4, 18, \ldots\right]' + assert latex(SeqMul(s1, s2)) == latex_str + + latex_str = r'\left[0, 2, 4\right]' + assert latex(SeqMul(s3, s4)) == latex_str + + latex_str = r'\left[\ldots, 18, 4, 2, 0\right]' + assert latex(SeqMul(s5, s6)) == latex_str + + # Sequences with symbolic limits, issue 12629 + s7 = SeqFormula(a**2, (a, 0, x)) + latex_str = r'\left\{a^{2}\right\}_{a=0}^{x}' + assert latex(s7) == latex_str + + b = Symbol('b') + s8 = SeqFormula(b*a**2, (a, 0, 2)) + latex_str = r'\left[0, b, 4 b\right]' + assert latex(s8) == latex_str + + +def test_latex_FourierSeries(): + latex_str = \ + r'2 \sin{\left(x \right)} - \sin{\left(2 x \right)} + \frac{2 \sin{\left(3 x \right)}}{3} + \ldots' + assert latex(fourier_series(x, (x, -pi, pi))) == latex_str + + +def test_latex_FormalPowerSeries(): + latex_str = r'\sum_{k=1}^{\infty} - \frac{\left(-1\right)^{- k} x^{k}}{k}' + assert latex(fps(log(1 + x))) == latex_str + + +def test_latex_intervals(): + a = Symbol('a', real=True) + assert latex(Interval(0, 0)) == r"\left\{0\right\}" + assert latex(Interval(0, a)) == r"\left[0, a\right]" + assert latex(Interval(0, a, False, False)) == r"\left[0, a\right]" + assert latex(Interval(0, a, True, False)) == r"\left(0, a\right]" + assert latex(Interval(0, a, False, True)) == r"\left[0, a\right)" + assert latex(Interval(0, a, True, True)) == r"\left(0, a\right)" + + +def test_latex_AccumuBounds(): + a = Symbol('a', real=True) + assert latex(AccumBounds(0, 1)) == r"\left\langle 0, 1\right\rangle" + assert latex(AccumBounds(0, a)) == r"\left\langle 0, a\right\rangle" + assert latex(AccumBounds(a + 1, a + 2)) == \ + r"\left\langle a + 1, a + 2\right\rangle" + + +def test_latex_emptyset(): + assert latex(S.EmptySet) == r"\emptyset" + + +def test_latex_universalset(): + assert latex(S.UniversalSet) == r"\mathbb{U}" + + +def test_latex_commutator(): + A = Operator('A') + B = Operator('B') + comm = Commutator(B, A) + assert latex(comm.doit()) == r"- (A B - B A)" + + +def test_latex_union(): + assert latex(Union(Interval(0, 1), Interval(2, 3))) == \ + r"\left[0, 1\right] \cup \left[2, 3\right]" + assert latex(Union(Interval(1, 1), Interval(2, 2), Interval(3, 4))) == \ + r"\left\{1, 2\right\} \cup \left[3, 4\right]" + + +def test_latex_intersection(): + assert latex(Intersection(Interval(0, 1), Interval(x, y))) == \ + r"\left[0, 1\right] \cap \left[x, y\right]" + + +def test_latex_symmetric_difference(): + assert latex(SymmetricDifference(Interval(2, 5), Interval(4, 7), + evaluate=False)) == \ + r'\left[2, 5\right] \triangle \left[4, 7\right]' + + +def test_latex_Complement(): + assert latex(Complement(S.Reals, S.Naturals)) == \ + r"\mathbb{R} \setminus \mathbb{N}" + + +def test_latex_productset(): + line = Interval(0, 1) + bigline = Interval(0, 10) + fset = FiniteSet(1, 2, 3) + assert latex(line**2) == r"%s^{2}" % latex(line) + assert latex(line**10) == r"%s^{10}" % latex(line) + assert latex((line * bigline * fset).flatten()) == r"%s \times %s \times %s" % ( + latex(line), latex(bigline), latex(fset)) + + +def test_latex_powerset(): + fset = FiniteSet(1, 2, 3) + assert latex(PowerSet(fset)) == r'\mathcal{P}\left(\left\{1, 2, 3\right\}\right)' + + +def test_latex_ordinals(): + w = OrdinalOmega() + assert latex(w) == r"\omega" + wp = OmegaPower(2, 3) + assert latex(wp) == r'3 \omega^{2}' + assert latex(Ordinal(wp, OmegaPower(1, 1))) == r'3 \omega^{2} + \omega' + assert latex(Ordinal(OmegaPower(2, 1), OmegaPower(1, 2))) == r'\omega^{2} + 2 \omega' + + +def test_set_operators_parenthesis(): + a, b, c, d = symbols('a:d') + A = FiniteSet(a) + B = FiniteSet(b) + C = FiniteSet(c) + D = FiniteSet(d) + + U1 = Union(A, B, evaluate=False) + U2 = Union(C, D, evaluate=False) + I1 = Intersection(A, B, evaluate=False) + I2 = Intersection(C, D, evaluate=False) + C1 = Complement(A, B, evaluate=False) + C2 = Complement(C, D, evaluate=False) + D1 = SymmetricDifference(A, B, evaluate=False) + D2 = SymmetricDifference(C, D, evaluate=False) + # XXX ProductSet does not support evaluate keyword + P1 = ProductSet(A, B) + P2 = ProductSet(C, D) + + assert latex(Intersection(A, U2, evaluate=False)) == \ + r'\left\{a\right\} \cap ' \ + r'\left(\left\{c\right\} \cup \left\{d\right\}\right)' + assert latex(Intersection(U1, U2, evaluate=False)) == \ + r'\left(\left\{a\right\} \cup \left\{b\right\}\right) ' \ + r'\cap \left(\left\{c\right\} \cup \left\{d\right\}\right)' + assert latex(Intersection(C1, C2, evaluate=False)) == \ + r'\left(\left\{a\right\} \setminus ' \ + r'\left\{b\right\}\right) \cap \left(\left\{c\right\} ' \ + r'\setminus \left\{d\right\}\right)' + assert latex(Intersection(D1, D2, evaluate=False)) == \ + r'\left(\left\{a\right\} \triangle ' \ + r'\left\{b\right\}\right) \cap \left(\left\{c\right\} ' \ + r'\triangle \left\{d\right\}\right)' + assert latex(Intersection(P1, P2, evaluate=False)) == \ + r'\left(\left\{a\right\} \times \left\{b\right\}\right) ' \ + r'\cap \left(\left\{c\right\} \times ' \ + r'\left\{d\right\}\right)' + + assert latex(Union(A, I2, evaluate=False)) == \ + r'\left\{a\right\} \cup ' \ + r'\left(\left\{c\right\} \cap \left\{d\right\}\right)' + assert latex(Union(I1, I2, evaluate=False)) == \ + r'\left(\left\{a\right\} \cap \left\{b\right\}\right) ' \ + r'\cup \left(\left\{c\right\} \cap \left\{d\right\}\right)' + assert latex(Union(C1, C2, evaluate=False)) == \ + r'\left(\left\{a\right\} \setminus ' \ + r'\left\{b\right\}\right) \cup \left(\left\{c\right\} ' \ + r'\setminus \left\{d\right\}\right)' + assert latex(Union(D1, D2, evaluate=False)) == \ + r'\left(\left\{a\right\} \triangle ' \ + r'\left\{b\right\}\right) \cup \left(\left\{c\right\} ' \ + r'\triangle \left\{d\right\}\right)' + assert latex(Union(P1, P2, evaluate=False)) == \ + r'\left(\left\{a\right\} \times \left\{b\right\}\right) ' \ + r'\cup \left(\left\{c\right\} \times ' \ + r'\left\{d\right\}\right)' + + assert latex(Complement(A, C2, evaluate=False)) == \ + r'\left\{a\right\} \setminus \left(\left\{c\right\} ' \ + r'\setminus \left\{d\right\}\right)' + assert latex(Complement(U1, U2, evaluate=False)) == \ + r'\left(\left\{a\right\} \cup \left\{b\right\}\right) ' \ + r'\setminus \left(\left\{c\right\} \cup ' \ + r'\left\{d\right\}\right)' + assert latex(Complement(I1, I2, evaluate=False)) == \ + r'\left(\left\{a\right\} \cap \left\{b\right\}\right) ' \ + r'\setminus \left(\left\{c\right\} \cap ' \ + r'\left\{d\right\}\right)' + assert latex(Complement(D1, D2, evaluate=False)) == \ + r'\left(\left\{a\right\} \triangle ' \ + r'\left\{b\right\}\right) \setminus ' \ + r'\left(\left\{c\right\} \triangle \left\{d\right\}\right)' + assert latex(Complement(P1, P2, evaluate=False)) == \ + r'\left(\left\{a\right\} \times \left\{b\right\}\right) '\ + r'\setminus \left(\left\{c\right\} \times '\ + r'\left\{d\right\}\right)' + + assert latex(SymmetricDifference(A, D2, evaluate=False)) == \ + r'\left\{a\right\} \triangle \left(\left\{c\right\} ' \ + r'\triangle \left\{d\right\}\right)' + assert latex(SymmetricDifference(U1, U2, evaluate=False)) == \ + r'\left(\left\{a\right\} \cup \left\{b\right\}\right) ' \ + r'\triangle \left(\left\{c\right\} \cup ' \ + r'\left\{d\right\}\right)' + assert latex(SymmetricDifference(I1, I2, evaluate=False)) == \ + r'\left(\left\{a\right\} \cap \left\{b\right\}\right) ' \ + r'\triangle \left(\left\{c\right\} \cap ' \ + r'\left\{d\right\}\right)' + assert latex(SymmetricDifference(C1, C2, evaluate=False)) == \ + r'\left(\left\{a\right\} \setminus ' \ + r'\left\{b\right\}\right) \triangle ' \ + r'\left(\left\{c\right\} \setminus \left\{d\right\}\right)' + assert latex(SymmetricDifference(P1, P2, evaluate=False)) == \ + r'\left(\left\{a\right\} \times \left\{b\right\}\right) ' \ + r'\triangle \left(\left\{c\right\} \times ' \ + r'\left\{d\right\}\right)' + + # XXX This can be incorrect since cartesian product is not associative + assert latex(ProductSet(A, P2).flatten()) == \ + r'\left\{a\right\} \times \left\{c\right\} \times ' \ + r'\left\{d\right\}' + assert latex(ProductSet(U1, U2)) == \ + r'\left(\left\{a\right\} \cup \left\{b\right\}\right) ' \ + r'\times \left(\left\{c\right\} \cup ' \ + r'\left\{d\right\}\right)' + assert latex(ProductSet(I1, I2)) == \ + r'\left(\left\{a\right\} \cap \left\{b\right\}\right) ' \ + r'\times \left(\left\{c\right\} \cap ' \ + r'\left\{d\right\}\right)' + assert latex(ProductSet(C1, C2)) == \ + r'\left(\left\{a\right\} \setminus ' \ + r'\left\{b\right\}\right) \times \left(\left\{c\right\} ' \ + r'\setminus \left\{d\right\}\right)' + assert latex(ProductSet(D1, D2)) == \ + r'\left(\left\{a\right\} \triangle ' \ + r'\left\{b\right\}\right) \times \left(\left\{c\right\} ' \ + r'\triangle \left\{d\right\}\right)' + + +def test_latex_Complexes(): + assert latex(S.Complexes) == r"\mathbb{C}" + + +def test_latex_Naturals(): + assert latex(S.Naturals) == r"\mathbb{N}" + + +def test_latex_Naturals0(): + assert latex(S.Naturals0) == r"\mathbb{N}_0" + + +def test_latex_Integers(): + assert latex(S.Integers) == r"\mathbb{Z}" + + +def test_latex_ImageSet(): + x = Symbol('x') + assert latex(ImageSet(Lambda(x, x**2), S.Naturals)) == \ + r"\left\{x^{2}\; \middle|\; x \in \mathbb{N}\right\}" + + y = Symbol('y') + imgset = ImageSet(Lambda((x, y), x + y), {1, 2, 3}, {3, 4}) + assert latex(imgset) == \ + r"\left\{x + y\; \middle|\; x \in \left\{1, 2, 3\right\}, y \in \left\{3, 4\right\}\right\}" + + imgset = ImageSet(Lambda(((x, y),), x + y), ProductSet({1, 2, 3}, {3, 4})) + assert latex(imgset) == \ + r"\left\{x + y\; \middle|\; \left( x, \ y\right) \in \left\{1, 2, 3\right\} \times \left\{3, 4\right\}\right\}" + + +def test_latex_ConditionSet(): + x = Symbol('x') + assert latex(ConditionSet(x, Eq(x**2, 1), S.Reals)) == \ + r"\left\{x\; \middle|\; x \in \mathbb{R} \wedge x^{2} = 1 \right\}" + assert latex(ConditionSet(x, Eq(x**2, 1), S.UniversalSet)) == \ + r"\left\{x\; \middle|\; x^{2} = 1 \right\}" + + +def test_latex_ComplexRegion(): + assert latex(ComplexRegion(Interval(3, 5)*Interval(4, 6))) == \ + r"\left\{x + y i\; \middle|\; x, y \in \left[3, 5\right] \times \left[4, 6\right] \right\}" + assert latex(ComplexRegion(Interval(0, 1)*Interval(0, 2*pi), polar=True)) == \ + r"\left\{r \left(i \sin{\left(\theta \right)} + \cos{\left(\theta "\ + r"\right)}\right)\; \middle|\; r, \theta \in \left[0, 1\right] \times \left[0, 2 \pi\right) \right\}" + + +def test_latex_Contains(): + x = Symbol('x') + assert latex(Contains(x, S.Naturals)) == r"x \in \mathbb{N}" + + +def test_latex_sum(): + assert latex(Sum(x*y**2, (x, -2, 2), (y, -5, 5))) == \ + r"\sum_{\substack{-2 \leq x \leq 2\\-5 \leq y \leq 5}} x y^{2}" + assert latex(Sum(x**2, (x, -2, 2))) == \ + r"\sum_{x=-2}^{2} x^{2}" + assert latex(Sum(x**2 + y, (x, -2, 2))) == \ + r"\sum_{x=-2}^{2} \left(x^{2} + y\right)" + assert latex(Sum(x**2 + y, (x, -2, 2))**2) == \ + r"\left(\sum_{x=-2}^{2} \left(x^{2} + y\right)\right)^{2}" + + +def test_latex_product(): + assert latex(Product(x*y**2, (x, -2, 2), (y, -5, 5))) == \ + r"\prod_{\substack{-2 \leq x \leq 2\\-5 \leq y \leq 5}} x y^{2}" + assert latex(Product(x**2, (x, -2, 2))) == \ + r"\prod_{x=-2}^{2} x^{2}" + assert latex(Product(x**2 + y, (x, -2, 2))) == \ + r"\prod_{x=-2}^{2} \left(x^{2} + y\right)" + + assert latex(Product(x, (x, -2, 2))**2) == \ + r"\left(\prod_{x=-2}^{2} x\right)^{2}" + + +def test_latex_limits(): + assert latex(Limit(x, x, oo)) == r"\lim_{x \to \infty} x" + + # issue 8175 + f = Function('f') + assert latex(Limit(f(x), x, 0)) == r"\lim_{x \to 0^+} f{\left(x \right)}" + assert latex(Limit(f(x), x, 0, "-")) == \ + r"\lim_{x \to 0^-} f{\left(x \right)}" + + # issue #10806 + assert latex(Limit(f(x), x, 0)**2) == \ + r"\left(\lim_{x \to 0^+} f{\left(x \right)}\right)^{2}" + # bi-directional limit + assert latex(Limit(f(x), x, 0, dir='+-')) == \ + r"\lim_{x \to 0} f{\left(x \right)}" + + +def test_latex_log(): + assert latex(log(x)) == r"\log{\left(x \right)}" + assert latex(log(x), ln_notation=True) == r"\ln{\left(x \right)}" + assert latex(log(x) + log(y)) == \ + r"\log{\left(x \right)} + \log{\left(y \right)}" + assert latex(log(x) + log(y), ln_notation=True) == \ + r"\ln{\left(x \right)} + \ln{\left(y \right)}" + assert latex(pow(log(x), x)) == r"\log{\left(x \right)}^{x}" + assert latex(pow(log(x), x), ln_notation=True) == \ + r"\ln{\left(x \right)}^{x}" + + +def test_issue_3568(): + beta = Symbol(r'\beta') + y = beta + x + assert latex(y) in [r'\beta + x', r'x + \beta'] + + beta = Symbol(r'beta') + y = beta + x + assert latex(y) in [r'\beta + x', r'x + \beta'] + + +def test_latex(): + assert latex((2*tau)**Rational(7, 2)) == r"8 \sqrt{2} \tau^{\frac{7}{2}}" + assert latex((2*mu)**Rational(7, 2), mode='equation*') == \ + r"\begin{equation*}8 \sqrt{2} \mu^{\frac{7}{2}}\end{equation*}" + assert latex((2*mu)**Rational(7, 2), mode='equation', itex=True) == \ + r"$$8 \sqrt{2} \mu^{\frac{7}{2}}$$" + assert latex([2/x, y]) == r"\left[ \frac{2}{x}, \ y\right]" + + +def test_latex_dict(): + d = {Rational(1): 1, x**2: 2, x: 3, x**3: 4} + assert latex(d) == \ + r'\left\{ 1 : 1, \ x : 3, \ x^{2} : 2, \ x^{3} : 4\right\}' + D = Dict(d) + assert latex(D) == \ + r'\left\{ 1 : 1, \ x : 3, \ x^{2} : 2, \ x^{3} : 4\right\}' + + +def test_latex_list(): + ll = [Symbol('omega1'), Symbol('a'), Symbol('alpha')] + assert latex(ll) == r'\left[ \omega_{1}, \ a, \ \alpha\right]' + + +def test_latex_NumberSymbols(): + assert latex(S.Catalan) == "G" + assert latex(S.EulerGamma) == r"\gamma" + assert latex(S.Exp1) == "e" + assert latex(S.GoldenRatio) == r"\phi" + assert latex(S.Pi) == r"\pi" + assert latex(S.TribonacciConstant) == r"\text{TribonacciConstant}" + + +def test_latex_rational(): + # tests issue 3973 + assert latex(-Rational(1, 2)) == r"- \frac{1}{2}" + assert latex(Rational(-1, 2)) == r"- \frac{1}{2}" + assert latex(Rational(1, -2)) == r"- \frac{1}{2}" + assert latex(-Rational(-1, 2)) == r"\frac{1}{2}" + assert latex(-Rational(1, 2)*x) == r"- \frac{x}{2}" + assert latex(-Rational(1, 2)*x + Rational(-2, 3)*y) == \ + r"- \frac{x}{2} - \frac{2 y}{3}" + + +def test_latex_inverse(): + # tests issue 4129 + assert latex(1/x) == r"\frac{1}{x}" + assert latex(1/(x + y)) == r"\frac{1}{x + y}" + + +def test_latex_DiracDelta(): + assert latex(DiracDelta(x)) == r"\delta\left(x\right)" + assert latex(DiracDelta(x)**2) == r"\left(\delta\left(x\right)\right)^{2}" + assert latex(DiracDelta(x, 0)) == r"\delta\left(x\right)" + assert latex(DiracDelta(x, 5)) == \ + r"\delta^{\left( 5 \right)}\left( x \right)" + assert latex(DiracDelta(x, 5)**2) == \ + r"\left(\delta^{\left( 5 \right)}\left( x \right)\right)^{2}" + + +def test_latex_Heaviside(): + assert latex(Heaviside(x)) == r"\theta\left(x\right)" + assert latex(Heaviside(x)**2) == r"\left(\theta\left(x\right)\right)^{2}" + + +def test_latex_KroneckerDelta(): + assert latex(KroneckerDelta(x, y)) == r"\delta_{x y}" + assert latex(KroneckerDelta(x, y + 1)) == r"\delta_{x, y + 1}" + # issue 6578 + assert latex(KroneckerDelta(x + 1, y)) == r"\delta_{y, x + 1}" + assert latex(Pow(KroneckerDelta(x, y), 2, evaluate=False)) == \ + r"\left(\delta_{x y}\right)^{2}" + + +def test_latex_LeviCivita(): + assert latex(LeviCivita(x, y, z)) == r"\varepsilon_{x y z}" + assert latex(LeviCivita(x, y, z)**2) == \ + r"\left(\varepsilon_{x y z}\right)^{2}" + assert latex(LeviCivita(x, y, z + 1)) == r"\varepsilon_{x, y, z + 1}" + assert latex(LeviCivita(x, y + 1, z)) == r"\varepsilon_{x, y + 1, z}" + assert latex(LeviCivita(x + 1, y, z)) == r"\varepsilon_{x + 1, y, z}" + + +def test_mode(): + expr = x + y + assert latex(expr) == r'x + y' + assert latex(expr, mode='plain') == r'x + y' + assert latex(expr, mode='inline') == r'$x + y$' + assert latex( + expr, mode='equation*') == r'\begin{equation*}x + y\end{equation*}' + assert latex( + expr, mode='equation') == r'\begin{equation}x + y\end{equation}' + raises(ValueError, lambda: latex(expr, mode='foo')) + + +def test_latex_mathieu(): + assert latex(mathieuc(x, y, z)) == r"C\left(x, y, z\right)" + assert latex(mathieus(x, y, z)) == r"S\left(x, y, z\right)" + assert latex(mathieuc(x, y, z)**2) == r"C\left(x, y, z\right)^{2}" + assert latex(mathieus(x, y, z)**2) == r"S\left(x, y, z\right)^{2}" + assert latex(mathieucprime(x, y, z)) == r"C^{\prime}\left(x, y, z\right)" + assert latex(mathieusprime(x, y, z)) == r"S^{\prime}\left(x, y, z\right)" + assert latex(mathieucprime(x, y, z)**2) == r"C^{\prime}\left(x, y, z\right)^{2}" + assert latex(mathieusprime(x, y, z)**2) == r"S^{\prime}\left(x, y, z\right)^{2}" + +def test_latex_Piecewise(): + p = Piecewise((x, x < 1), (x**2, True)) + assert latex(p) == r"\begin{cases} x & \text{for}\: x < 1 \\x^{2} &" \ + r" \text{otherwise} \end{cases}" + assert latex(p, itex=True) == \ + r"\begin{cases} x & \text{for}\: x \lt 1 \\x^{2} &" \ + r" \text{otherwise} \end{cases}" + p = Piecewise((x, x < 0), (0, x >= 0)) + assert latex(p) == r'\begin{cases} x & \text{for}\: x < 0 \\0 &' \ + r' \text{otherwise} \end{cases}' + A, B = symbols("A B", commutative=False) + p = Piecewise((A**2, Eq(A, B)), (A*B, True)) + s = r"\begin{cases} A^{2} & \text{for}\: A = B \\A B & \text{otherwise} \end{cases}" + assert latex(p) == s + assert latex(A*p) == r"A \left(%s\right)" % s + assert latex(p*A) == r"\left(%s\right) A" % s + assert latex(Piecewise((x, x < 1), (x**2, x < 2))) == \ + r'\begin{cases} x & ' \ + r'\text{for}\: x < 1 \\x^{2} & \text{for}\: x < 2 \end{cases}' + + +def test_latex_Matrix(): + M = Matrix([[1 + x, y], [y, x - 1]]) + assert latex(M) == \ + r'\left[\begin{matrix}x + 1 & y\\y & x - 1\end{matrix}\right]' + assert latex(M, mode='inline') == \ + r'$\left[\begin{smallmatrix}x + 1 & y\\' \ + r'y & x - 1\end{smallmatrix}\right]$' + assert latex(M, mat_str='array') == \ + r'\left[\begin{array}{cc}x + 1 & y\\y & x - 1\end{array}\right]' + assert latex(M, mat_str='bmatrix') == \ + r'\left[\begin{bmatrix}x + 1 & y\\y & x - 1\end{bmatrix}\right]' + assert latex(M, mat_delim=None, mat_str='bmatrix') == \ + r'\begin{bmatrix}x + 1 & y\\y & x - 1\end{bmatrix}' + + M2 = Matrix(1, 11, range(11)) + assert latex(M2) == \ + r'\left[\begin{array}{ccccccccccc}' \ + r'0 & 1 & 2 & 3 & 4 & 5 & 6 & 7 & 8 & 9 & 10\end{array}\right]' + + +def test_latex_matrix_with_functions(): + t = symbols('t') + theta1 = symbols('theta1', cls=Function) + + M = Matrix([[sin(theta1(t)), cos(theta1(t))], + [cos(theta1(t).diff(t)), sin(theta1(t).diff(t))]]) + + expected = (r'\left[\begin{matrix}\sin{\left(' + r'\theta_{1}{\left(t \right)} \right)} & ' + r'\cos{\left(\theta_{1}{\left(t \right)} \right)' + r'}\\\cos{\left(\frac{d}{d t} \theta_{1}{\left(t ' + r'\right)} \right)} & \sin{\left(\frac{d}{d t} ' + r'\theta_{1}{\left(t \right)} \right' + r')}\end{matrix}\right]') + + assert latex(M) == expected + + +def test_latex_NDimArray(): + x, y, z, w = symbols("x y z w") + + for ArrayType in (ImmutableDenseNDimArray, ImmutableSparseNDimArray, + MutableDenseNDimArray, MutableSparseNDimArray): + # Basic: scalar array + M = ArrayType(x) + + assert latex(M) == r"x" + + M = ArrayType([[1 / x, y], [z, w]]) + M1 = ArrayType([1 / x, y, z]) + + M2 = tensorproduct(M1, M) + M3 = tensorproduct(M, M) + + assert latex(M) == \ + r'\left[\begin{matrix}\frac{1}{x} & y\\z & w\end{matrix}\right]' + assert latex(M1) == \ + r"\left[\begin{matrix}\frac{1}{x} & y & z\end{matrix}\right]" + assert latex(M2) == \ + r"\left[\begin{matrix}" \ + r"\left[\begin{matrix}\frac{1}{x^{2}} & \frac{y}{x}\\\frac{z}{x} & \frac{w}{x}\end{matrix}\right] & " \ + r"\left[\begin{matrix}\frac{y}{x} & y^{2}\\y z & w y\end{matrix}\right] & " \ + r"\left[\begin{matrix}\frac{z}{x} & y z\\z^{2} & w z\end{matrix}\right]" \ + r"\end{matrix}\right]" + assert latex(M3) == \ + r"""\left[\begin{matrix}"""\ + r"""\left[\begin{matrix}\frac{1}{x^{2}} & \frac{y}{x}\\\frac{z}{x} & \frac{w}{x}\end{matrix}\right] & """\ + r"""\left[\begin{matrix}\frac{y}{x} & y^{2}\\y z & w y\end{matrix}\right]\\"""\ + r"""\left[\begin{matrix}\frac{z}{x} & y z\\z^{2} & w z\end{matrix}\right] & """\ + r"""\left[\begin{matrix}\frac{w}{x} & w y\\w z & w^{2}\end{matrix}\right]"""\ + r"""\end{matrix}\right]""" + + Mrow = ArrayType([[x, y, 1/z]]) + Mcolumn = ArrayType([[x], [y], [1/z]]) + Mcol2 = ArrayType([Mcolumn.tolist()]) + + assert latex(Mrow) == \ + r"\left[\left[\begin{matrix}x & y & \frac{1}{z}\end{matrix}\right]\right]" + assert latex(Mcolumn) == \ + r"\left[\begin{matrix}x\\y\\\frac{1}{z}\end{matrix}\right]" + assert latex(Mcol2) == \ + r'\left[\begin{matrix}\left[\begin{matrix}x\\y\\\frac{1}{z}\end{matrix}\right]\end{matrix}\right]' + + +def test_latex_mul_symbol(): + assert latex(4*4**x, mul_symbol='times') == r"4 \times 4^{x}" + assert latex(4*4**x, mul_symbol='dot') == r"4 \cdot 4^{x}" + assert latex(4*4**x, mul_symbol='ldot') == r"4 \,.\, 4^{x}" + + assert latex(4*x, mul_symbol='times') == r"4 \times x" + assert latex(4*x, mul_symbol='dot') == r"4 \cdot x" + assert latex(4*x, mul_symbol='ldot') == r"4 \,.\, x" + + +def test_latex_issue_4381(): + y = 4*4**log(2) + assert latex(y) == r'4 \cdot 4^{\log{\left(2 \right)}}' + assert latex(1/y) == r'\frac{1}{4 \cdot 4^{\log{\left(2 \right)}}}' + + +def test_latex_issue_4576(): + assert latex(Symbol("beta_13_2")) == r"\beta_{13 2}" + assert latex(Symbol("beta_132_20")) == r"\beta_{132 20}" + assert latex(Symbol("beta_13")) == r"\beta_{13}" + assert latex(Symbol("x_a_b")) == r"x_{a b}" + assert latex(Symbol("x_1_2_3")) == r"x_{1 2 3}" + assert latex(Symbol("x_a_b1")) == r"x_{a b1}" + assert latex(Symbol("x_a_1")) == r"x_{a 1}" + assert latex(Symbol("x_1_a")) == r"x_{1 a}" + assert latex(Symbol("x_1^aa")) == r"x^{aa}_{1}" + assert latex(Symbol("x_1__aa")) == r"x^{aa}_{1}" + assert latex(Symbol("x_11^a")) == r"x^{a}_{11}" + assert latex(Symbol("x_11__a")) == r"x^{a}_{11}" + assert latex(Symbol("x_a_a_a_a")) == r"x_{a a a a}" + assert latex(Symbol("x_a_a^a^a")) == r"x^{a a}_{a a}" + assert latex(Symbol("x_a_a__a__a")) == r"x^{a a}_{a a}" + assert latex(Symbol("alpha_11")) == r"\alpha_{11}" + assert latex(Symbol("alpha_11_11")) == r"\alpha_{11 11}" + assert latex(Symbol("alpha_alpha")) == r"\alpha_{\alpha}" + assert latex(Symbol("alpha^aleph")) == r"\alpha^{\aleph}" + assert latex(Symbol("alpha__aleph")) == r"\alpha^{\aleph}" + + +def test_latex_pow_fraction(): + x = Symbol('x') + # Testing exp + assert r'e^{-x}' in latex(exp(-x)/2).replace(' ', '') # Remove Whitespace + + # Testing e^{-x} in case future changes alter behavior of muls or fracs + # In particular current output is \frac{1}{2}e^{- x} but perhaps this will + # change to \frac{e^{-x}}{2} + + # Testing general, non-exp, power + assert r'3^{-x}' in latex(3**-x/2).replace(' ', '') + + +def test_noncommutative(): + A, B, C = symbols('A,B,C', commutative=False) + + assert latex(A*B*C**-1) == r"A B C^{-1}" + assert latex(C**-1*A*B) == r"C^{-1} A B" + assert latex(A*C**-1*B) == r"A C^{-1} B" + + +def test_latex_order(): + expr = x**3 + x**2*y + y**4 + 3*x*y**3 + + assert latex(expr, order='lex') == r"x^{3} + x^{2} y + 3 x y^{3} + y^{4}" + assert latex( + expr, order='rev-lex') == r"y^{4} + 3 x y^{3} + x^{2} y + x^{3}" + assert latex(expr, order='none') == r"x^{3} + y^{4} + y x^{2} + 3 x y^{3}" + + +def test_latex_Lambda(): + assert latex(Lambda(x, x + 1)) == r"\left( x \mapsto x + 1 \right)" + assert latex(Lambda((x, y), x + 1)) == r"\left( \left( x, \ y\right) \mapsto x + 1 \right)" + assert latex(Lambda(x, x)) == r"\left( x \mapsto x \right)" + +def test_latex_PolyElement(): + Ruv, u, v = ring("u,v", ZZ) + Rxyz, x, y, z = ring("x,y,z", Ruv) + + assert latex(x - x) == r"0" + assert latex(x - 1) == r"x - 1" + assert latex(x + 1) == r"x + 1" + + assert latex((u**2 + 3*u*v + 1)*x**2*y + u + 1) == \ + r"\left({u}^{2} + 3 u v + 1\right) {x}^{2} y + u + 1" + assert latex((u**2 + 3*u*v + 1)*x**2*y + (u + 1)*x) == \ + r"\left({u}^{2} + 3 u v + 1\right) {x}^{2} y + \left(u + 1\right) x" + assert latex((u**2 + 3*u*v + 1)*x**2*y + (u + 1)*x + 1) == \ + r"\left({u}^{2} + 3 u v + 1\right) {x}^{2} y + \left(u + 1\right) x + 1" + assert latex((-u**2 + 3*u*v - 1)*x**2*y - (u + 1)*x - 1) == \ + r"-\left({u}^{2} - 3 u v + 1\right) {x}^{2} y - \left(u + 1\right) x - 1" + + assert latex(-(v**2 + v + 1)*x + 3*u*v + 1) == \ + r"-\left({v}^{2} + v + 1\right) x + 3 u v + 1" + assert latex(-(v**2 + v + 1)*x - 3*u*v + 1) == \ + r"-\left({v}^{2} + v + 1\right) x - 3 u v + 1" + + +def test_latex_FracElement(): + Fuv, u, v = field("u,v", ZZ) + Fxyzt, x, y, z, t = field("x,y,z,t", Fuv) + + assert latex(x - x) == r"0" + assert latex(x - 1) == r"x - 1" + assert latex(x + 1) == r"x + 1" + + assert latex(x/3) == r"\frac{x}{3}" + assert latex(x/z) == r"\frac{x}{z}" + assert latex(x*y/z) == r"\frac{x y}{z}" + assert latex(x/(z*t)) == r"\frac{x}{z t}" + assert latex(x*y/(z*t)) == r"\frac{x y}{z t}" + + assert latex((x - 1)/y) == r"\frac{x - 1}{y}" + assert latex((x + 1)/y) == r"\frac{x + 1}{y}" + assert latex((-x - 1)/y) == r"\frac{-x - 1}{y}" + assert latex((x + 1)/(y*z)) == r"\frac{x + 1}{y z}" + assert latex(-y/(x + 1)) == r"\frac{-y}{x + 1}" + assert latex(y*z/(x + 1)) == r"\frac{y z}{x + 1}" + + assert latex(((u + 1)*x*y + 1)/((v - 1)*z - 1)) == \ + r"\frac{\left(u + 1\right) x y + 1}{\left(v - 1\right) z - 1}" + assert latex(((u + 1)*x*y + 1)/((v - 1)*z - t*u*v - 1)) == \ + r"\frac{\left(u + 1\right) x y + 1}{\left(v - 1\right) z - u v t - 1}" + + +def test_latex_Poly(): + assert latex(Poly(x**2 + 2 * x, x)) == \ + r"\operatorname{Poly}{\left( x^{2} + 2 x, x, domain=\mathbb{Z} \right)}" + assert latex(Poly(x/y, x)) == \ + r"\operatorname{Poly}{\left( \frac{1}{y} x, x, domain=\mathbb{Z}\left(y\right) \right)}" + assert latex(Poly(2.0*x + y)) == \ + r"\operatorname{Poly}{\left( 2.0 x + 1.0 y, x, y, domain=\mathbb{R} \right)}" + + +def test_latex_Poly_order(): + assert latex(Poly([a, 1, b, 2, c, 3], x)) == \ + r'\operatorname{Poly}{\left( a x^{5} + x^{4} + b x^{3} + 2 x^{2} + c'\ + r' x + 3, x, domain=\mathbb{Z}\left[a, b, c\right] \right)}' + assert latex(Poly([a, 1, b+c, 2, 3], x)) == \ + r'\operatorname{Poly}{\left( a x^{4} + x^{3} + \left(b + c\right) '\ + r'x^{2} + 2 x + 3, x, domain=\mathbb{Z}\left[a, b, c\right] \right)}' + assert latex(Poly(a*x**3 + x**2*y - x*y - c*y**3 - b*x*y**2 + y - a*x + b, + (x, y))) == \ + r'\operatorname{Poly}{\left( a x^{3} + x^{2}y - b xy^{2} - xy - '\ + r'a x - c y^{3} + y + b, x, y, domain=\mathbb{Z}\left[a, b, c\right] \right)}' + + +def test_latex_ComplexRootOf(): + assert latex(rootof(x**5 + x + 3, 0)) == \ + r"\operatorname{CRootOf} {\left(x^{5} + x + 3, 0\right)}" + + +def test_latex_RootSum(): + assert latex(RootSum(x**5 + x + 3, sin)) == \ + r"\operatorname{RootSum} {\left(x^{5} + x + 3, \left( x \mapsto \sin{\left(x \right)} \right)\right)}" + + +def test_settings(): + raises(TypeError, lambda: latex(x*y, method="garbage")) + + +def test_latex_numbers(): + assert latex(catalan(n)) == r"C_{n}" + assert latex(catalan(n)**2) == r"C_{n}^{2}" + assert latex(bernoulli(n)) == r"B_{n}" + assert latex(bernoulli(n, x)) == r"B_{n}\left(x\right)" + assert latex(bernoulli(n)**2) == r"B_{n}^{2}" + assert latex(bernoulli(n, x)**2) == r"B_{n}^{2}\left(x\right)" + assert latex(genocchi(n)) == r"G_{n}" + assert latex(genocchi(n, x)) == r"G_{n}\left(x\right)" + assert latex(genocchi(n)**2) == r"G_{n}^{2}" + assert latex(genocchi(n, x)**2) == r"G_{n}^{2}\left(x\right)" + assert latex(bell(n)) == r"B_{n}" + assert latex(bell(n, x)) == r"B_{n}\left(x\right)" + assert latex(bell(n, m, (x, y))) == r"B_{n, m}\left(x, y\right)" + assert latex(bell(n)**2) == r"B_{n}^{2}" + assert latex(bell(n, x)**2) == r"B_{n}^{2}\left(x\right)" + assert latex(bell(n, m, (x, y))**2) == r"B_{n, m}^{2}\left(x, y\right)" + assert latex(fibonacci(n)) == r"F_{n}" + assert latex(fibonacci(n, x)) == r"F_{n}\left(x\right)" + assert latex(fibonacci(n)**2) == r"F_{n}^{2}" + assert latex(fibonacci(n, x)**2) == r"F_{n}^{2}\left(x\right)" + assert latex(lucas(n)) == r"L_{n}" + assert latex(lucas(n)**2) == r"L_{n}^{2}" + assert latex(tribonacci(n)) == r"T_{n}" + assert latex(tribonacci(n, x)) == r"T_{n}\left(x\right)" + assert latex(tribonacci(n)**2) == r"T_{n}^{2}" + assert latex(tribonacci(n, x)**2) == r"T_{n}^{2}\left(x\right)" + assert latex(mobius(n)) == r"\mu\left(n\right)" + assert latex(mobius(n)**2) == r"\mu^{2}\left(n\right)" + + +def test_latex_euler(): + assert latex(euler(n)) == r"E_{n}" + assert latex(euler(n, x)) == r"E_{n}\left(x\right)" + assert latex(euler(n, x)**2) == r"E_{n}^{2}\left(x\right)" + + +def test_lamda(): + assert latex(Symbol('lamda')) == r"\lambda" + assert latex(Symbol('Lamda')) == r"\Lambda" + + +def test_custom_symbol_names(): + x = Symbol('x') + y = Symbol('y') + assert latex(x) == r"x" + assert latex(x, symbol_names={x: "x_i"}) == r"x_i" + assert latex(x + y, symbol_names={x: "x_i"}) == r"x_i + y" + assert latex(x**2, symbol_names={x: "x_i"}) == r"x_i^{2}" + assert latex(x + y, symbol_names={x: "x_i", y: "y_j"}) == r"x_i + y_j" + + +def test_matAdd(): + C = MatrixSymbol('C', 5, 5) + B = MatrixSymbol('B', 5, 5) + + n = symbols("n") + h = MatrixSymbol("h", 1, 1) + + assert latex(C - 2*B) in [r'- 2 B + C', r'C -2 B'] + assert latex(C + 2*B) in [r'2 B + C', r'C + 2 B'] + assert latex(B - 2*C) in [r'B - 2 C', r'- 2 C + B'] + assert latex(B + 2*C) in [r'B + 2 C', r'2 C + B'] + + assert latex(n * h - (-h + h.T) * (h + h.T)) == 'n h - \\left(- h + h^{T}\\right) \\left(h + h^{T}\\right)' + assert latex(MatAdd(MatAdd(h, h), MatAdd(h, h))) == '\\left(h + h\\right) + \\left(h + h\\right)' + assert latex(MatMul(MatMul(h, h), MatMul(h, h))) == '\\left(h h\\right) \\left(h h\\right)' + + +def test_matMul(): + A = MatrixSymbol('A', 5, 5) + B = MatrixSymbol('B', 5, 5) + x = Symbol('x') + assert latex(2*A) == r'2 A' + assert latex(2*x*A) == r'2 x A' + assert latex(-2*A) == r'- 2 A' + assert latex(1.5*A) == r'1.5 A' + assert latex(sqrt(2)*A) == r'\sqrt{2} A' + assert latex(-sqrt(2)*A) == r'- \sqrt{2} A' + assert latex(2*sqrt(2)*x*A) == r'2 \sqrt{2} x A' + assert latex(-2*A*(A + 2*B)) in [r'- 2 A \left(A + 2 B\right)', + r'- 2 A \left(2 B + A\right)'] + + +def test_latex_MatrixSlice(): + n = Symbol('n', integer=True) + x, y, z, w, t, = symbols('x y z w t') + X = MatrixSymbol('X', n, n) + Y = MatrixSymbol('Y', 10, 10) + Z = MatrixSymbol('Z', 10, 10) + + assert latex(MatrixSlice(X, (None, None, None), (None, None, None))) == r'X\left[:, :\right]' + assert latex(X[x:x + 1, y:y + 1]) == r'X\left[x:x + 1, y:y + 1\right]' + assert latex(X[x:x + 1:2, y:y + 1:2]) == r'X\left[x:x + 1:2, y:y + 1:2\right]' + assert latex(X[:x, y:]) == r'X\left[:x, y:\right]' + assert latex(X[:x, y:]) == r'X\left[:x, y:\right]' + assert latex(X[x:, :y]) == r'X\left[x:, :y\right]' + assert latex(X[x:y, z:w]) == r'X\left[x:y, z:w\right]' + assert latex(X[x:y:t, w:t:x]) == r'X\left[x:y:t, w:t:x\right]' + assert latex(X[x::y, t::w]) == r'X\left[x::y, t::w\right]' + assert latex(X[:x:y, :t:w]) == r'X\left[:x:y, :t:w\right]' + assert latex(X[::x, ::y]) == r'X\left[::x, ::y\right]' + assert latex(MatrixSlice(X, (0, None, None), (0, None, None))) == r'X\left[:, :\right]' + assert latex(MatrixSlice(X, (None, n, None), (None, n, None))) == r'X\left[:, :\right]' + assert latex(MatrixSlice(X, (0, n, None), (0, n, None))) == r'X\left[:, :\right]' + assert latex(MatrixSlice(X, (0, n, 2), (0, n, 2))) == r'X\left[::2, ::2\right]' + assert latex(X[1:2:3, 4:5:6]) == r'X\left[1:2:3, 4:5:6\right]' + assert latex(X[1:3:5, 4:6:8]) == r'X\left[1:3:5, 4:6:8\right]' + assert latex(X[1:10:2]) == r'X\left[1:10:2, :\right]' + assert latex(Y[:5, 1:9:2]) == r'Y\left[:5, 1:9:2\right]' + assert latex(Y[:5, 1:10:2]) == r'Y\left[:5, 1::2\right]' + assert latex(Y[5, :5:2]) == r'Y\left[5:6, :5:2\right]' + assert latex(X[0:1, 0:1]) == r'X\left[:1, :1\right]' + assert latex(X[0:1:2, 0:1:2]) == r'X\left[:1:2, :1:2\right]' + assert latex((Y + Z)[2:, 2:]) == r'\left(Y + Z\right)\left[2:, 2:\right]' + + +def test_latex_RandomDomain(): + from sympy.stats import Normal, Die, Exponential, pspace, where + from sympy.stats.rv import RandomDomain + + X = Normal('x1', 0, 1) + assert latex(where(X > 0)) == r"\text{Domain: }0 < x_{1} \wedge x_{1} < \infty" + + D = Die('d1', 6) + assert latex(where(D > 4)) == r"\text{Domain: }d_{1} = 5 \vee d_{1} = 6" + + A = Exponential('a', 1) + B = Exponential('b', 1) + assert latex( + pspace(Tuple(A, B)).domain) == \ + r"\text{Domain: }0 \leq a \wedge 0 \leq b \wedge a < \infty \wedge b < \infty" + + assert latex(RandomDomain(FiniteSet(x), FiniteSet(1, 2))) == \ + r'\text{Domain: }\left\{x\right\} \in \left\{1, 2\right\}' + +def test_PrettyPoly(): + from sympy.polys.domains import QQ + F = QQ.frac_field(x, y) + R = QQ[x, y] + + assert latex(F.convert(x/(x + y))) == latex(x/(x + y)) + assert latex(R.convert(x + y)) == latex(x + y) + + +def test_integral_transforms(): + x = Symbol("x") + k = Symbol("k") + f = Function("f") + a = Symbol("a") + b = Symbol("b") + + assert latex(MellinTransform(f(x), x, k)) == \ + r"\mathcal{M}_{x}\left[f{\left(x \right)}\right]\left(k\right)" + assert latex(InverseMellinTransform(f(k), k, x, a, b)) == \ + r"\mathcal{M}^{-1}_{k}\left[f{\left(k \right)}\right]\left(x\right)" + + assert latex(LaplaceTransform(f(x), x, k)) == \ + r"\mathcal{L}_{x}\left[f{\left(x \right)}\right]\left(k\right)" + assert latex(InverseLaplaceTransform(f(k), k, x, (a, b))) == \ + r"\mathcal{L}^{-1}_{k}\left[f{\left(k \right)}\right]\left(x\right)" + + assert latex(FourierTransform(f(x), x, k)) == \ + r"\mathcal{F}_{x}\left[f{\left(x \right)}\right]\left(k\right)" + assert latex(InverseFourierTransform(f(k), k, x)) == \ + r"\mathcal{F}^{-1}_{k}\left[f{\left(k \right)}\right]\left(x\right)" + + assert latex(CosineTransform(f(x), x, k)) == \ + r"\mathcal{COS}_{x}\left[f{\left(x \right)}\right]\left(k\right)" + assert latex(InverseCosineTransform(f(k), k, x)) == \ + r"\mathcal{COS}^{-1}_{k}\left[f{\left(k \right)}\right]\left(x\right)" + + assert latex(SineTransform(f(x), x, k)) == \ + r"\mathcal{SIN}_{x}\left[f{\left(x \right)}\right]\left(k\right)" + assert latex(InverseSineTransform(f(k), k, x)) == \ + r"\mathcal{SIN}^{-1}_{k}\left[f{\left(k \right)}\right]\left(x\right)" + + +def test_PolynomialRingBase(): + from sympy.polys.domains import QQ + assert latex(QQ.old_poly_ring(x, y)) == r"\mathbb{Q}\left[x, y\right]" + assert latex(QQ.old_poly_ring(x, y, order="ilex")) == \ + r"S_<^{-1}\mathbb{Q}\left[x, y\right]" + + +def test_categories(): + from sympy.categories import (Object, IdentityMorphism, + NamedMorphism, Category, Diagram, + DiagramGrid) + + A1 = Object("A1") + A2 = Object("A2") + A3 = Object("A3") + + f1 = NamedMorphism(A1, A2, "f1") + f2 = NamedMorphism(A2, A3, "f2") + id_A1 = IdentityMorphism(A1) + + K1 = Category("K1") + + assert latex(A1) == r"A_{1}" + assert latex(f1) == r"f_{1}:A_{1}\rightarrow A_{2}" + assert latex(id_A1) == r"id:A_{1}\rightarrow A_{1}" + assert latex(f2*f1) == r"f_{2}\circ f_{1}:A_{1}\rightarrow A_{3}" + + assert latex(K1) == r"\mathbf{K_{1}}" + + d = Diagram() + assert latex(d) == r"\emptyset" + + d = Diagram({f1: "unique", f2: S.EmptySet}) + assert latex(d) == r"\left\{ f_{2}\circ f_{1}:A_{1}" \ + r"\rightarrow A_{3} : \emptyset, \ id:A_{1}\rightarrow " \ + r"A_{1} : \emptyset, \ id:A_{2}\rightarrow A_{2} : " \ + r"\emptyset, \ id:A_{3}\rightarrow A_{3} : \emptyset, " \ + r"\ f_{1}:A_{1}\rightarrow A_{2} : \left\{unique\right\}, " \ + r"\ f_{2}:A_{2}\rightarrow A_{3} : \emptyset\right\}" + + d = Diagram({f1: "unique", f2: S.EmptySet}, {f2 * f1: "unique"}) + assert latex(d) == r"\left\{ f_{2}\circ f_{1}:A_{1}" \ + r"\rightarrow A_{3} : \emptyset, \ id:A_{1}\rightarrow " \ + r"A_{1} : \emptyset, \ id:A_{2}\rightarrow A_{2} : " \ + r"\emptyset, \ id:A_{3}\rightarrow A_{3} : \emptyset, " \ + r"\ f_{1}:A_{1}\rightarrow A_{2} : \left\{unique\right\}," \ + r" \ f_{2}:A_{2}\rightarrow A_{3} : \emptyset\right\}" \ + r"\Longrightarrow \left\{ f_{2}\circ f_{1}:A_{1}" \ + r"\rightarrow A_{3} : \left\{unique\right\}\right\}" + + # A linear diagram. + A = Object("A") + B = Object("B") + C = Object("C") + f = NamedMorphism(A, B, "f") + g = NamedMorphism(B, C, "g") + d = Diagram([f, g]) + grid = DiagramGrid(d) + + assert latex(grid) == r"\begin{array}{cc}" + "\n" \ + r"A & B \\" + "\n" \ + r" & C " + "\n" \ + r"\end{array}" + "\n" + + +def test_Modules(): + from sympy.polys.domains import QQ + from sympy.polys.agca import homomorphism + + R = QQ.old_poly_ring(x, y) + F = R.free_module(2) + M = F.submodule([x, y], [1, x**2]) + + assert latex(F) == r"{\mathbb{Q}\left[x, y\right]}^{2}" + assert latex(M) == \ + r"\left\langle {\left[ {x},{y} \right]},{\left[ {1},{x^{2}} \right]} \right\rangle" + + I = R.ideal(x**2, y) + assert latex(I) == r"\left\langle {x^{2}},{y} \right\rangle" + + Q = F / M + assert latex(Q) == \ + r"\frac{{\mathbb{Q}\left[x, y\right]}^{2}}{\left\langle {\left[ {x},"\ + r"{y} \right]},{\left[ {1},{x^{2}} \right]} \right\rangle}" + assert latex(Q.submodule([1, x**3/2], [2, y])) == \ + r"\left\langle {{\left[ {1},{\frac{x^{3}}{2}} \right]} + {\left"\ + r"\langle {\left[ {x},{y} \right]},{\left[ {1},{x^{2}} \right]} "\ + r"\right\rangle}},{{\left[ {2},{y} \right]} + {\left\langle {\left[ "\ + r"{x},{y} \right]},{\left[ {1},{x^{2}} \right]} \right\rangle}} \right\rangle" + + h = homomorphism(QQ.old_poly_ring(x).free_module(2), + QQ.old_poly_ring(x).free_module(2), [0, 0]) + + assert latex(h) == \ + r"{\left[\begin{matrix}0 & 0\\0 & 0\end{matrix}\right]} : "\ + r"{{\mathbb{Q}\left[x\right]}^{2}} \to {{\mathbb{Q}\left[x\right]}^{2}}" + + +def test_QuotientRing(): + from sympy.polys.domains import QQ + R = QQ.old_poly_ring(x)/[x**2 + 1] + + assert latex(R) == \ + r"\frac{\mathbb{Q}\left[x\right]}{\left\langle {x^{2} + 1} \right\rangle}" + assert latex(R.one) == r"{1} + {\left\langle {x^{2} + 1} \right\rangle}" + + +def test_Tr(): + #TODO: Handle indices + A, B = symbols('A B', commutative=False) + t = Tr(A*B) + assert latex(t) == r'\operatorname{tr}\left(A B\right)' + + +def test_Determinant(): + from sympy.matrices import Determinant, Inverse, BlockMatrix, OneMatrix, ZeroMatrix + m = Matrix(((1, 2), (3, 4))) + assert latex(Determinant(m)) == '\\left|{\\begin{matrix}1 & 2\\\\3 & 4\\end{matrix}}\\right|' + assert latex(Determinant(Inverse(m))) == \ + '\\left|{\\left[\\begin{matrix}1 & 2\\\\3 & 4\\end{matrix}\\right]^{-1}}\\right|' + X = MatrixSymbol('X', 2, 2) + assert latex(Determinant(X)) == '\\left|{X}\\right|' + assert latex(Determinant(X + m)) == \ + '\\left|{\\left[\\begin{matrix}1 & 2\\\\3 & 4\\end{matrix}\\right] + X}\\right|' + assert latex(Determinant(BlockMatrix(((OneMatrix(2, 2), X), + (m, ZeroMatrix(2, 2)))))) == \ + '\\left|{\\begin{matrix}1 & X\\\\\\left[\\begin{matrix}1 & 2\\\\3 & 4\\end{matrix}\\right] & 0\\end{matrix}}\\right|' + + +def test_Adjoint(): + from sympy.matrices import Adjoint, Inverse, Transpose + X = MatrixSymbol('X', 2, 2) + Y = MatrixSymbol('Y', 2, 2) + assert latex(Adjoint(X)) == r'X^{\dagger}' + assert latex(Adjoint(X + Y)) == r'\left(X + Y\right)^{\dagger}' + assert latex(Adjoint(X) + Adjoint(Y)) == r'X^{\dagger} + Y^{\dagger}' + assert latex(Adjoint(X*Y)) == r'\left(X Y\right)^{\dagger}' + assert latex(Adjoint(Y)*Adjoint(X)) == r'Y^{\dagger} X^{\dagger}' + assert latex(Adjoint(X**2)) == r'\left(X^{2}\right)^{\dagger}' + assert latex(Adjoint(X)**2) == r'\left(X^{\dagger}\right)^{2}' + assert latex(Adjoint(Inverse(X))) == r'\left(X^{-1}\right)^{\dagger}' + assert latex(Inverse(Adjoint(X))) == r'\left(X^{\dagger}\right)^{-1}' + assert latex(Adjoint(Transpose(X))) == r'\left(X^{T}\right)^{\dagger}' + assert latex(Transpose(Adjoint(X))) == r'\left(X^{\dagger}\right)^{T}' + assert latex(Transpose(Adjoint(X) + Y)) == r'\left(X^{\dagger} + Y\right)^{T}' + m = Matrix(((1, 2), (3, 4))) + assert latex(Adjoint(m)) == '\\left[\\begin{matrix}1 & 2\\\\3 & 4\\end{matrix}\\right]^{\\dagger}' + assert latex(Adjoint(m+X)) == \ + '\\left(\\left[\\begin{matrix}1 & 2\\\\3 & 4\\end{matrix}\\right] + X\\right)^{\\dagger}' + from sympy.matrices import BlockMatrix, OneMatrix, ZeroMatrix + assert latex(Adjoint(BlockMatrix(((OneMatrix(2, 2), X), + (m, ZeroMatrix(2, 2)))))) == \ + '\\left[\\begin{matrix}1 & X\\\\\\left[\\begin{matrix}1 & 2\\\\3 & 4\\end{matrix}\\right] & 0\\end{matrix}\\right]^{\\dagger}' + # Issue 20959 + Mx = MatrixSymbol('M^x', 2, 2) + assert latex(Adjoint(Mx)) == r'\left(M^{x}\right)^{\dagger}' + + # adjoint style + assert latex(Adjoint(X), adjoint_style="star") == r'X^{\ast}' + assert latex(Adjoint(X + Y), adjoint_style="hermitian") == r'\left(X + Y\right)^{\mathsf{H}}' + assert latex(Adjoint(X) + Adjoint(Y), adjoint_style="dagger") == r'X^{\dagger} + Y^{\dagger}' + assert latex(Adjoint(Y)*Adjoint(X)) == r'Y^{\dagger} X^{\dagger}' + assert latex(Adjoint(X**2), adjoint_style="star") == r'\left(X^{2}\right)^{\ast}' + assert latex(Adjoint(X)**2, adjoint_style="hermitian") == r'\left(X^{\mathsf{H}}\right)^{2}' + +def test_Transpose(): + from sympy.matrices import Transpose, MatPow, HadamardPower + X = MatrixSymbol('X', 2, 2) + Y = MatrixSymbol('Y', 2, 2) + assert latex(Transpose(X)) == r'X^{T}' + assert latex(Transpose(X + Y)) == r'\left(X + Y\right)^{T}' + + assert latex(Transpose(HadamardPower(X, 2))) == r'\left(X^{\circ {2}}\right)^{T}' + assert latex(HadamardPower(Transpose(X), 2)) == r'\left(X^{T}\right)^{\circ {2}}' + assert latex(Transpose(MatPow(X, 2))) == r'\left(X^{2}\right)^{T}' + assert latex(MatPow(Transpose(X), 2)) == r'\left(X^{T}\right)^{2}' + m = Matrix(((1, 2), (3, 4))) + assert latex(Transpose(m)) == '\\left[\\begin{matrix}1 & 2\\\\3 & 4\\end{matrix}\\right]^{T}' + assert latex(Transpose(m+X)) == \ + '\\left(\\left[\\begin{matrix}1 & 2\\\\3 & 4\\end{matrix}\\right] + X\\right)^{T}' + from sympy.matrices import BlockMatrix, OneMatrix, ZeroMatrix + assert latex(Transpose(BlockMatrix(((OneMatrix(2, 2), X), + (m, ZeroMatrix(2, 2)))))) == \ + '\\left[\\begin{matrix}1 & X\\\\\\left[\\begin{matrix}1 & 2\\\\3 & 4\\end{matrix}\\right] & 0\\end{matrix}\\right]^{T}' + # Issue 20959 + Mx = MatrixSymbol('M^x', 2, 2) + assert latex(Transpose(Mx)) == r'\left(M^{x}\right)^{T}' + + +def test_Hadamard(): + from sympy.matrices import HadamardProduct, HadamardPower + from sympy.matrices.expressions import MatAdd, MatMul, MatPow + X = MatrixSymbol('X', 2, 2) + Y = MatrixSymbol('Y', 2, 2) + assert latex(HadamardProduct(X, Y*Y)) == r'X \circ Y^{2}' + assert latex(HadamardProduct(X, Y)*Y) == r'\left(X \circ Y\right) Y' + + assert latex(HadamardPower(X, 2)) == r'X^{\circ {2}}' + assert latex(HadamardPower(X, -1)) == r'X^{\circ \left({-1}\right)}' + assert latex(HadamardPower(MatAdd(X, Y), 2)) == \ + r'\left(X + Y\right)^{\circ {2}}' + assert latex(HadamardPower(MatMul(X, Y), 2)) == \ + r'\left(X Y\right)^{\circ {2}}' + + assert latex(HadamardPower(MatPow(X, -1), -1)) == \ + r'\left(X^{-1}\right)^{\circ \left({-1}\right)}' + assert latex(MatPow(HadamardPower(X, -1), -1)) == \ + r'\left(X^{\circ \left({-1}\right)}\right)^{-1}' + + assert latex(HadamardPower(X, n+1)) == \ + r'X^{\circ \left({n + 1}\right)}' + + +def test_MatPow(): + from sympy.matrices.expressions import MatPow + X = MatrixSymbol('X', 2, 2) + Y = MatrixSymbol('Y', 2, 2) + assert latex(MatPow(X, 2)) == 'X^{2}' + assert latex(MatPow(X*X, 2)) == '\\left(X^{2}\\right)^{2}' + assert latex(MatPow(X*Y, 2)) == '\\left(X Y\\right)^{2}' + assert latex(MatPow(X + Y, 2)) == '\\left(X + Y\\right)^{2}' + assert latex(MatPow(X + X, 2)) == '\\left(2 X\\right)^{2}' + # Issue 20959 + Mx = MatrixSymbol('M^x', 2, 2) + assert latex(MatPow(Mx, 2)) == r'\left(M^{x}\right)^{2}' + + +def test_ElementwiseApplyFunction(): + X = MatrixSymbol('X', 2, 2) + expr = (X.T*X).applyfunc(sin) + assert latex(expr) == r"{\left( d \mapsto \sin{\left(d \right)} \right)}_{\circ}\left({X^{T} X}\right)" + expr = X.applyfunc(Lambda(x, 1/x)) + assert latex(expr) == r'{\left( x \mapsto \frac{1}{x} \right)}_{\circ}\left({X}\right)' + + +def test_ZeroMatrix(): + from sympy.matrices.expressions.special import ZeroMatrix + assert latex(ZeroMatrix(1, 1), mat_symbol_style='plain') == r"0" + assert latex(ZeroMatrix(1, 1), mat_symbol_style='bold') == r"\mathbf{0}" + + +def test_OneMatrix(): + from sympy.matrices.expressions.special import OneMatrix + assert latex(OneMatrix(3, 4), mat_symbol_style='plain') == r"1" + assert latex(OneMatrix(3, 4), mat_symbol_style='bold') == r"\mathbf{1}" + + +def test_Identity(): + from sympy.matrices.expressions.special import Identity + assert latex(Identity(1), mat_symbol_style='plain') == r"\mathbb{I}" + assert latex(Identity(1), mat_symbol_style='bold') == r"\mathbf{I}" + + +def test_latex_DFT_IDFT(): + from sympy.matrices.expressions.fourier import DFT, IDFT + assert latex(DFT(13)) == r"\text{DFT}_{13}" + assert latex(IDFT(x)) == r"\text{IDFT}_{x}" + + +def test_boolean_args_order(): + syms = symbols('a:f') + + expr = And(*syms) + assert latex(expr) == r'a \wedge b \wedge c \wedge d \wedge e \wedge f' + + expr = Or(*syms) + assert latex(expr) == r'a \vee b \vee c \vee d \vee e \vee f' + + expr = Equivalent(*syms) + assert latex(expr) == \ + r'a \Leftrightarrow b \Leftrightarrow c \Leftrightarrow d \Leftrightarrow e \Leftrightarrow f' + + expr = Xor(*syms) + assert latex(expr) == \ + r'a \veebar b \veebar c \veebar d \veebar e \veebar f' + + +def test_imaginary(): + i = sqrt(-1) + assert latex(i) == r'i' + + +def test_builtins_without_args(): + assert latex(sin) == r'\sin' + assert latex(cos) == r'\cos' + assert latex(tan) == r'\tan' + assert latex(log) == r'\log' + assert latex(Ei) == r'\operatorname{Ei}' + assert latex(zeta) == r'\zeta' + + +def test_latex_greek_functions(): + # bug because capital greeks that have roman equivalents should not use + # \Alpha, \Beta, \Eta, etc. + s = Function('Alpha') + assert latex(s) == r'\mathrm{A}' + assert latex(s(x)) == r'\mathrm{A}{\left(x \right)}' + s = Function('Beta') + assert latex(s) == r'\mathrm{B}' + s = Function('Eta') + assert latex(s) == r'\mathrm{H}' + assert latex(s(x)) == r'\mathrm{H}{\left(x \right)}' + + # bug because sympy.core.numbers.Pi is special + p = Function('Pi') + # assert latex(p(x)) == r'\Pi{\left(x \right)}' + assert latex(p) == r'\Pi' + + # bug because not all greeks are included + c = Function('chi') + assert latex(c(x)) == r'\chi{\left(x \right)}' + assert latex(c) == r'\chi' + + +def test_translate(): + s = 'Alpha' + assert translate(s) == r'\mathrm{A}' + s = 'Beta' + assert translate(s) == r'\mathrm{B}' + s = 'Eta' + assert translate(s) == r'\mathrm{H}' + s = 'omicron' + assert translate(s) == r'o' + s = 'Pi' + assert translate(s) == r'\Pi' + s = 'pi' + assert translate(s) == r'\pi' + s = 'LamdaHatDOT' + assert translate(s) == r'\dot{\hat{\Lambda}}' + + +def test_other_symbols(): + from sympy.printing.latex import other_symbols + for s in other_symbols: + assert latex(symbols(s)) == r"" "\\" + s + + +def test_modifiers(): + # Test each modifier individually in the simplest case + # (with funny capitalizations) + assert latex(symbols("xMathring")) == r"\mathring{x}" + assert latex(symbols("xCheck")) == r"\check{x}" + assert latex(symbols("xBreve")) == r"\breve{x}" + assert latex(symbols("xAcute")) == r"\acute{x}" + assert latex(symbols("xGrave")) == r"\grave{x}" + assert latex(symbols("xTilde")) == r"\tilde{x}" + assert latex(symbols("xPrime")) == r"{x}'" + assert latex(symbols("xddDDot")) == r"\ddddot{x}" + assert latex(symbols("xDdDot")) == r"\dddot{x}" + assert latex(symbols("xDDot")) == r"\ddot{x}" + assert latex(symbols("xBold")) == r"\boldsymbol{x}" + assert latex(symbols("xnOrM")) == r"\left\|{x}\right\|" + assert latex(symbols("xAVG")) == r"\left\langle{x}\right\rangle" + assert latex(symbols("xHat")) == r"\hat{x}" + assert latex(symbols("xDot")) == r"\dot{x}" + assert latex(symbols("xBar")) == r"\bar{x}" + assert latex(symbols("xVec")) == r"\vec{x}" + assert latex(symbols("xAbs")) == r"\left|{x}\right|" + assert latex(symbols("xMag")) == r"\left|{x}\right|" + assert latex(symbols("xPrM")) == r"{x}'" + assert latex(symbols("xBM")) == r"\boldsymbol{x}" + # Test strings that are *only* the names of modifiers + assert latex(symbols("Mathring")) == r"Mathring" + assert latex(symbols("Check")) == r"Check" + assert latex(symbols("Breve")) == r"Breve" + assert latex(symbols("Acute")) == r"Acute" + assert latex(symbols("Grave")) == r"Grave" + assert latex(symbols("Tilde")) == r"Tilde" + assert latex(symbols("Prime")) == r"Prime" + assert latex(symbols("DDot")) == r"\dot{D}" + assert latex(symbols("Bold")) == r"Bold" + assert latex(symbols("NORm")) == r"NORm" + assert latex(symbols("AVG")) == r"AVG" + assert latex(symbols("Hat")) == r"Hat" + assert latex(symbols("Dot")) == r"Dot" + assert latex(symbols("Bar")) == r"Bar" + assert latex(symbols("Vec")) == r"Vec" + assert latex(symbols("Abs")) == r"Abs" + assert latex(symbols("Mag")) == r"Mag" + assert latex(symbols("PrM")) == r"PrM" + assert latex(symbols("BM")) == r"BM" + assert latex(symbols("hbar")) == r"\hbar" + # Check a few combinations + assert latex(symbols("xvecdot")) == r"\dot{\vec{x}}" + assert latex(symbols("xDotVec")) == r"\vec{\dot{x}}" + assert latex(symbols("xHATNorm")) == r"\left\|{\hat{x}}\right\|" + # Check a couple big, ugly combinations + assert latex(symbols('xMathringBm_yCheckPRM__zbreveAbs')) == \ + r"\boldsymbol{\mathring{x}}^{\left|{\breve{z}}\right|}_{{\check{y}}'}" + assert latex(symbols('alphadothat_nVECDOT__tTildePrime')) == \ + r"\hat{\dot{\alpha}}^{{\tilde{t}}'}_{\dot{\vec{n}}}" + + +def test_greek_symbols(): + assert latex(Symbol('alpha')) == r'\alpha' + assert latex(Symbol('beta')) == r'\beta' + assert latex(Symbol('gamma')) == r'\gamma' + assert latex(Symbol('delta')) == r'\delta' + assert latex(Symbol('epsilon')) == r'\epsilon' + assert latex(Symbol('zeta')) == r'\zeta' + assert latex(Symbol('eta')) == r'\eta' + assert latex(Symbol('theta')) == r'\theta' + assert latex(Symbol('iota')) == r'\iota' + assert latex(Symbol('kappa')) == r'\kappa' + assert latex(Symbol('lambda')) == r'\lambda' + assert latex(Symbol('mu')) == r'\mu' + assert latex(Symbol('nu')) == r'\nu' + assert latex(Symbol('xi')) == r'\xi' + assert latex(Symbol('omicron')) == r'o' + assert latex(Symbol('pi')) == r'\pi' + assert latex(Symbol('rho')) == r'\rho' + assert latex(Symbol('sigma')) == r'\sigma' + assert latex(Symbol('tau')) == r'\tau' + assert latex(Symbol('upsilon')) == r'\upsilon' + assert latex(Symbol('phi')) == r'\phi' + assert latex(Symbol('chi')) == r'\chi' + assert latex(Symbol('psi')) == r'\psi' + assert latex(Symbol('omega')) == r'\omega' + + assert latex(Symbol('Alpha')) == r'\mathrm{A}' + assert latex(Symbol('Beta')) == r'\mathrm{B}' + assert latex(Symbol('Gamma')) == r'\Gamma' + assert latex(Symbol('Delta')) == r'\Delta' + assert latex(Symbol('Epsilon')) == r'\mathrm{E}' + assert latex(Symbol('Zeta')) == r'\mathrm{Z}' + assert latex(Symbol('Eta')) == r'\mathrm{H}' + assert latex(Symbol('Theta')) == r'\Theta' + assert latex(Symbol('Iota')) == r'\mathrm{I}' + assert latex(Symbol('Kappa')) == r'\mathrm{K}' + assert latex(Symbol('Lambda')) == r'\Lambda' + assert latex(Symbol('Mu')) == r'\mathrm{M}' + assert latex(Symbol('Nu')) == r'\mathrm{N}' + assert latex(Symbol('Xi')) == r'\Xi' + assert latex(Symbol('Omicron')) == r'\mathrm{O}' + assert latex(Symbol('Pi')) == r'\Pi' + assert latex(Symbol('Rho')) == r'\mathrm{P}' + assert latex(Symbol('Sigma')) == r'\Sigma' + assert latex(Symbol('Tau')) == r'\mathrm{T}' + assert latex(Symbol('Upsilon')) == r'\Upsilon' + assert latex(Symbol('Phi')) == r'\Phi' + assert latex(Symbol('Chi')) == r'\mathrm{X}' + assert latex(Symbol('Psi')) == r'\Psi' + assert latex(Symbol('Omega')) == r'\Omega' + + assert latex(Symbol('varepsilon')) == r'\varepsilon' + assert latex(Symbol('varkappa')) == r'\varkappa' + assert latex(Symbol('varphi')) == r'\varphi' + assert latex(Symbol('varpi')) == r'\varpi' + assert latex(Symbol('varrho')) == r'\varrho' + assert latex(Symbol('varsigma')) == r'\varsigma' + assert latex(Symbol('vartheta')) == r'\vartheta' + + +def test_fancyset_symbols(): + assert latex(S.Rationals) == r'\mathbb{Q}' + assert latex(S.Naturals) == r'\mathbb{N}' + assert latex(S.Naturals0) == r'\mathbb{N}_0' + assert latex(S.Integers) == r'\mathbb{Z}' + assert latex(S.Reals) == r'\mathbb{R}' + assert latex(S.Complexes) == r'\mathbb{C}' + + +@XFAIL +def test_builtin_without_args_mismatched_names(): + assert latex(CosineTransform) == r'\mathcal{COS}' + + +def test_builtin_no_args(): + assert latex(Chi) == r'\operatorname{Chi}' + assert latex(beta) == r'\operatorname{B}' + assert latex(gamma) == r'\Gamma' + assert latex(KroneckerDelta) == r'\delta' + assert latex(DiracDelta) == r'\delta' + assert latex(lowergamma) == r'\gamma' + + +def test_issue_6853(): + p = Function('Pi') + assert latex(p(x)) == r"\Pi{\left(x \right)}" + + +def test_Mul(): + e = Mul(-2, x + 1, evaluate=False) + assert latex(e) == r'- 2 \left(x + 1\right)' + e = Mul(2, x + 1, evaluate=False) + assert latex(e) == r'2 \left(x + 1\right)' + e = Mul(S.Half, x + 1, evaluate=False) + assert latex(e) == r'\frac{x + 1}{2}' + e = Mul(y, x + 1, evaluate=False) + assert latex(e) == r'y \left(x + 1\right)' + e = Mul(-y, x + 1, evaluate=False) + assert latex(e) == r'- y \left(x + 1\right)' + e = Mul(-2, x + 1) + assert latex(e) == r'- 2 x - 2' + e = Mul(2, x + 1) + assert latex(e) == r'2 x + 2' + + +def test_Pow(): + e = Pow(2, 2, evaluate=False) + assert latex(e) == r'2^{2}' + assert latex(x**(Rational(-1, 3))) == r'\frac{1}{\sqrt[3]{x}}' + x2 = Symbol(r'x^2') + assert latex(x2**2) == r'\left(x^{2}\right)^{2}' + # Issue 11011 + assert latex(S('1.453e4500')**x) == r'{1.453 \cdot 10^{4500}}^{x}' + + +def test_issue_7180(): + assert latex(Equivalent(x, y)) == r"x \Leftrightarrow y" + assert latex(Not(Equivalent(x, y))) == r"x \not\Leftrightarrow y" + + +def test_issue_8409(): + assert latex(S.Half**n) == r"\left(\frac{1}{2}\right)^{n}" + + +def test_issue_8470(): + from sympy.parsing.sympy_parser import parse_expr + e = parse_expr("-B*A", evaluate=False) + assert latex(e) == r"A \left(- B\right)" + + +def test_issue_15439(): + x = MatrixSymbol('x', 2, 2) + y = MatrixSymbol('y', 2, 2) + assert latex((x * y).subs(y, -y)) == r"x \left(- y\right)" + assert latex((x * y).subs(y, -2*y)) == r"x \left(- 2 y\right)" + assert latex((x * y).subs(x, -x)) == r"\left(- x\right) y" + + +def test_issue_2934(): + assert latex(Symbol(r'\frac{a_1}{b_1}')) == r'\frac{a_1}{b_1}' + + +def test_issue_10489(): + latexSymbolWithBrace = r'C_{x_{0}}' + s = Symbol(latexSymbolWithBrace) + assert latex(s) == latexSymbolWithBrace + assert latex(cos(s)) == r'\cos{\left(C_{x_{0}} \right)}' + + +def test_issue_12886(): + m__1, l__1 = symbols('m__1, l__1') + assert latex(m__1**2 + l__1**2) == \ + r'\left(l^{1}\right)^{2} + \left(m^{1}\right)^{2}' + + +def test_issue_13559(): + from sympy.parsing.sympy_parser import parse_expr + expr = parse_expr('5/1', evaluate=False) + assert latex(expr) == r"\frac{5}{1}" + + +def test_issue_13651(): + expr = c + Mul(-1, a + b, evaluate=False) + assert latex(expr) == r"c - \left(a + b\right)" + + +def test_latex_UnevaluatedExpr(): + x = symbols("x") + he = UnevaluatedExpr(1/x) + assert latex(he) == latex(1/x) == r"\frac{1}{x}" + assert latex(he**2) == r"\left(\frac{1}{x}\right)^{2}" + assert latex(he + 1) == r"1 + \frac{1}{x}" + assert latex(x*he) == r"x \frac{1}{x}" + + +def test_MatrixElement_printing(): + # test cases for issue #11821 + A = MatrixSymbol("A", 1, 3) + B = MatrixSymbol("B", 1, 3) + C = MatrixSymbol("C", 1, 3) + + assert latex(A[0, 0]) == r"{A}_{0,0}" + assert latex(3 * A[0, 0]) == r"3 {A}_{0,0}" + + F = C[0, 0].subs(C, A - B) + assert latex(F) == r"{\left(A - B\right)}_{0,0}" + + i, j, k = symbols("i j k") + M = MatrixSymbol("M", k, k) + N = MatrixSymbol("N", k, k) + assert latex((M*N)[i, j]) == \ + r'\sum_{i_{1}=0}^{k - 1} {M}_{i,i_{1}} {N}_{i_{1},j}' + + X_a = MatrixSymbol('X_a', 3, 3) + assert latex(X_a[0, 0]) == r"{X_{a}}_{0,0}" + + +def test_MatrixSymbol_printing(): + # test cases for issue #14237 + A = MatrixSymbol("A", 3, 3) + B = MatrixSymbol("B", 3, 3) + C = MatrixSymbol("C", 3, 3) + + assert latex(-A) == r"- A" + assert latex(A - A*B - B) == r"A - A B - B" + assert latex(-A*B - A*B*C - B) == r"- A B - A B C - B" + + +def test_KroneckerProduct_printing(): + A = MatrixSymbol('A', 3, 3) + B = MatrixSymbol('B', 2, 2) + assert latex(KroneckerProduct(A, B)) == r'A \otimes B' + + +def test_Series_printing(): + tf1 = TransferFunction(x*y**2 - z, y**3 - t**3, y) + tf2 = TransferFunction(x - y, x + y, y) + tf3 = TransferFunction(t*x**2 - t**w*x + w, t - y, y) + assert latex(Series(tf1, tf2)) == \ + r'\left(\frac{x y^{2} - z}{- t^{3} + y^{3}}\right) \left(\frac{x - y}{x + y}\right)' + assert latex(Series(tf1, tf2, tf3)) == \ + r'\left(\frac{x y^{2} - z}{- t^{3} + y^{3}}\right) \left(\frac{x - y}{x + y}\right) \left(\frac{t x^{2} - t^{w} x + w}{t - y}\right)' + assert latex(Series(-tf2, tf1)) == \ + r'\left(\frac{- x + y}{x + y}\right) \left(\frac{x y^{2} - z}{- t^{3} + y^{3}}\right)' + + M_1 = Matrix([[5/s], [5/(2*s)]]) + T_1 = TransferFunctionMatrix.from_Matrix(M_1, s) + M_2 = Matrix([[5, 6*s**3]]) + T_2 = TransferFunctionMatrix.from_Matrix(M_2, s) + # Brackets + assert latex(T_1*(T_2 + T_2)) == \ + r'\left[\begin{matrix}\frac{5}{s}\\\frac{5}{2 s}\end{matrix}\right]_\tau\cdot\left(\left[\begin{matrix}\frac{5}{1} &' \ + r' \frac{6 s^{3}}{1}\end{matrix}\right]_\tau + \left[\begin{matrix}\frac{5}{1} & \frac{6 s^{3}}{1}\end{matrix}\right]_\tau\right)' \ + == latex(MIMOSeries(MIMOParallel(T_2, T_2), T_1)) + # No Brackets + M_3 = Matrix([[5, 6], [6, 5/s]]) + T_3 = TransferFunctionMatrix.from_Matrix(M_3, s) + assert latex(T_1*T_2 + T_3) == r'\left[\begin{matrix}\frac{5}{s}\\\frac{5}{2 s}\end{matrix}\right]_\tau\cdot\left[\begin{matrix}' \ + r'\frac{5}{1} & \frac{6 s^{3}}{1}\end{matrix}\right]_\tau + \left[\begin{matrix}\frac{5}{1} & \frac{6}{1}\\\frac{6}{1} & ' \ + r'\frac{5}{s}\end{matrix}\right]_\tau' == latex(MIMOParallel(MIMOSeries(T_2, T_1), T_3)) + + +def test_TransferFunction_printing(): + tf1 = TransferFunction(x - 1, x + 1, x) + assert latex(tf1) == r"\frac{x - 1}{x + 1}" + tf2 = TransferFunction(x + 1, 2 - y, x) + assert latex(tf2) == r"\frac{x + 1}{2 - y}" + tf3 = TransferFunction(y, y**2 + 2*y + 3, y) + assert latex(tf3) == r"\frac{y}{y^{2} + 2 y + 3}" + + +def test_Parallel_printing(): + tf1 = TransferFunction(x*y**2 - z, y**3 - t**3, y) + tf2 = TransferFunction(x - y, x + y, y) + assert latex(Parallel(tf1, tf2)) == \ + r'\frac{x y^{2} - z}{- t^{3} + y^{3}} + \frac{x - y}{x + y}' + assert latex(Parallel(-tf2, tf1)) == \ + r'\frac{- x + y}{x + y} + \frac{x y^{2} - z}{- t^{3} + y^{3}}' + + M_1 = Matrix([[5, 6], [6, 5/s]]) + T_1 = TransferFunctionMatrix.from_Matrix(M_1, s) + M_2 = Matrix([[5/s, 6], [6, 5/(s - 1)]]) + T_2 = TransferFunctionMatrix.from_Matrix(M_2, s) + M_3 = Matrix([[6, 5/(s*(s - 1))], [5, 6]]) + T_3 = TransferFunctionMatrix.from_Matrix(M_3, s) + assert latex(T_1 + T_2 + T_3) == r'\left[\begin{matrix}\frac{5}{1} & \frac{6}{1}\\\frac{6}{1} & \frac{5}{s}\end{matrix}\right]' \ + r'_\tau + \left[\begin{matrix}\frac{5}{s} & \frac{6}{1}\\\frac{6}{1} & \frac{5}{s - 1}\end{matrix}\right]_\tau + \left[\begin{matrix}' \ + r'\frac{6}{1} & \frac{5}{s \left(s - 1\right)}\\\frac{5}{1} & \frac{6}{1}\end{matrix}\right]_\tau' \ + == latex(MIMOParallel(T_1, T_2, T_3)) == latex(MIMOParallel(T_1, MIMOParallel(T_2, T_3))) == latex(MIMOParallel(MIMOParallel(T_1, T_2), T_3)) + + +def test_TransferFunctionMatrix_printing(): + tf1 = TransferFunction(p, p + x, p) + tf2 = TransferFunction(-s + p, p + s, p) + tf3 = TransferFunction(p, y**2 + 2*y + 3, p) + assert latex(TransferFunctionMatrix([[tf1], [tf2]])) == \ + r'\left[\begin{matrix}\frac{p}{p + x}\\\frac{p - s}{p + s}\end{matrix}\right]_\tau' + assert latex(TransferFunctionMatrix([[tf1, tf2], [tf3, -tf1]])) == \ + r'\left[\begin{matrix}\frac{p}{p + x} & \frac{p - s}{p + s}\\\frac{p}{y^{2} + 2 y + 3} & \frac{\left(-1\right) p}{p + x}\end{matrix}\right]_\tau' + + +def test_Feedback_printing(): + tf1 = TransferFunction(p, p + x, p) + tf2 = TransferFunction(-s + p, p + s, p) + # Negative Feedback (Default) + assert latex(Feedback(tf1, tf2)) == \ + r'\frac{\frac{p}{p + x}}{\frac{1}{1} + \left(\frac{p}{p + x}\right) \left(\frac{p - s}{p + s}\right)}' + assert latex(Feedback(tf1*tf2, TransferFunction(1, 1, p))) == \ + r'\frac{\left(\frac{p}{p + x}\right) \left(\frac{p - s}{p + s}\right)}{\frac{1}{1} + \left(\frac{p}{p + x}\right) \left(\frac{p - s}{p + s}\right)}' + # Positive Feedback + assert latex(Feedback(tf1, tf2, 1)) == \ + r'\frac{\frac{p}{p + x}}{\frac{1}{1} - \left(\frac{p}{p + x}\right) \left(\frac{p - s}{p + s}\right)}' + assert latex(Feedback(tf1*tf2, sign=1)) == \ + r'\frac{\left(\frac{p}{p + x}\right) \left(\frac{p - s}{p + s}\right)}{\frac{1}{1} - \left(\frac{p}{p + x}\right) \left(\frac{p - s}{p + s}\right)}' + + +def test_MIMOFeedback_printing(): + tf1 = TransferFunction(1, s, s) + tf2 = TransferFunction(s, s**2 - 1, s) + tf3 = TransferFunction(s, s - 1, s) + tf4 = TransferFunction(s**2, s**2 - 1, s) + + tfm_1 = TransferFunctionMatrix([[tf1, tf2], [tf3, tf4]]) + tfm_2 = TransferFunctionMatrix([[tf4, tf3], [tf2, tf1]]) + + # Negative Feedback (Default) + assert latex(MIMOFeedback(tfm_1, tfm_2)) == \ + r'\left(I_{\tau} + \left[\begin{matrix}\frac{1}{s} & \frac{s}{s^{2} - 1}\\\frac{s}{s - 1} & \frac{s^{2}}{s^{2} - 1}\end{matrix}\right]_\tau\cdot\left[' \ + r'\begin{matrix}\frac{s^{2}}{s^{2} - 1} & \frac{s}{s - 1}\\\frac{s}{s^{2} - 1} & \frac{1}{s}\end{matrix}\right]_\tau\right)^{-1} \cdot \left[\begin{matrix}' \ + r'\frac{1}{s} & \frac{s}{s^{2} - 1}\\\frac{s}{s - 1} & \frac{s^{2}}{s^{2} - 1}\end{matrix}\right]_\tau' + + # Positive Feedback + assert latex(MIMOFeedback(tfm_1*tfm_2, tfm_1, 1)) == \ + r'\left(I_{\tau} - \left[\begin{matrix}\frac{1}{s} & \frac{s}{s^{2} - 1}\\\frac{s}{s - 1} & \frac{s^{2}}{s^{2} - 1}\end{matrix}\right]_\tau\cdot\left' \ + r'[\begin{matrix}\frac{s^{2}}{s^{2} - 1} & \frac{s}{s - 1}\\\frac{s}{s^{2} - 1} & \frac{1}{s}\end{matrix}\right]_\tau\cdot\left[\begin{matrix}\frac{1}{s} & \frac{s}{s^{2} - 1}' \ + r'\\\frac{s}{s - 1} & \frac{s^{2}}{s^{2} - 1}\end{matrix}\right]_\tau\right)^{-1} \cdot \left[\begin{matrix}\frac{1}{s} & \frac{s}{s^{2} - 1}' \ + r'\\\frac{s}{s - 1} & \frac{s^{2}}{s^{2} - 1}\end{matrix}\right]_\tau\cdot\left[\begin{matrix}\frac{s^{2}}{s^{2} - 1} & \frac{s}{s - 1}\\\frac{s}{s^{2} - 1}' \ + r' & \frac{1}{s}\end{matrix}\right]_\tau' + + +def test_Quaternion_latex_printing(): + q = Quaternion(x, y, z, t) + assert latex(q) == r"x + y i + z j + t k" + q = Quaternion(x, y, z, x*t) + assert latex(q) == r"x + y i + z j + t x k" + q = Quaternion(x, y, z, x + t) + assert latex(q) == r"x + y i + z j + \left(t + x\right) k" + + +def test_TensorProduct_printing(): + from sympy.tensor.functions import TensorProduct + A = MatrixSymbol("A", 3, 3) + B = MatrixSymbol("B", 3, 3) + assert latex(TensorProduct(A, B)) == r"A \otimes B" + + +def test_WedgeProduct_printing(): + from sympy.diffgeom.rn import R2 + from sympy.diffgeom import WedgeProduct + wp = WedgeProduct(R2.dx, R2.dy) + assert latex(wp) == r"\operatorname{d}x \wedge \operatorname{d}y" + + +def test_issue_9216(): + expr_1 = Pow(1, -1, evaluate=False) + assert latex(expr_1) == r"1^{-1}" + + expr_2 = Pow(1, Pow(1, -1, evaluate=False), evaluate=False) + assert latex(expr_2) == r"1^{1^{-1}}" + + expr_3 = Pow(3, -2, evaluate=False) + assert latex(expr_3) == r"\frac{1}{9}" + + expr_4 = Pow(1, -2, evaluate=False) + assert latex(expr_4) == r"1^{-2}" + + +def test_latex_printer_tensor(): + from sympy.tensor.tensor import TensorIndexType, tensor_indices, TensorHead, tensor_heads + L = TensorIndexType("L") + i, j, k, l = tensor_indices("i j k l", L) + i0 = tensor_indices("i_0", L) + A, B, C, D = tensor_heads("A B C D", [L]) + H = TensorHead("H", [L, L]) + K = TensorHead("K", [L, L, L, L]) + + assert latex(i) == r"{}^{i}" + assert latex(-i) == r"{}_{i}" + + expr = A(i) + assert latex(expr) == r"A{}^{i}" + + expr = A(i0) + assert latex(expr) == r"A{}^{i_{0}}" + + expr = A(-i) + assert latex(expr) == r"A{}_{i}" + + expr = -3*A(i) + assert latex(expr) == r"-3A{}^{i}" + + expr = K(i, j, -k, -i0) + assert latex(expr) == r"K{}^{ij}{}_{ki_{0}}" + + expr = K(i, -j, -k, i0) + assert latex(expr) == r"K{}^{i}{}_{jk}{}^{i_{0}}" + + expr = K(i, -j, k, -i0) + assert latex(expr) == r"K{}^{i}{}_{j}{}^{k}{}_{i_{0}}" + + expr = H(i, -j) + assert latex(expr) == r"H{}^{i}{}_{j}" + + expr = H(i, j) + assert latex(expr) == r"H{}^{ij}" + + expr = H(-i, -j) + assert latex(expr) == r"H{}_{ij}" + + expr = (1+x)*A(i) + assert latex(expr) == r"\left(x + 1\right)A{}^{i}" + + expr = H(i, -i) + assert latex(expr) == r"H{}^{L_{0}}{}_{L_{0}}" + + expr = H(i, -j)*A(j)*B(k) + assert latex(expr) == r"H{}^{i}{}_{L_{0}}A{}^{L_{0}}B{}^{k}" + + expr = A(i) + 3*B(i) + assert latex(expr) == r"3B{}^{i} + A{}^{i}" + + # Test ``TensorElement``: + from sympy.tensor.tensor import TensorElement + + expr = TensorElement(K(i, j, k, l), {i: 3, k: 2}) + assert latex(expr) == r'K{}^{i=3,j,k=2,l}' + + expr = TensorElement(K(i, j, k, l), {i: 3}) + assert latex(expr) == r'K{}^{i=3,jkl}' + + expr = TensorElement(K(i, -j, k, l), {i: 3, k: 2}) + assert latex(expr) == r'K{}^{i=3}{}_{j}{}^{k=2,l}' + + expr = TensorElement(K(i, -j, k, -l), {i: 3, k: 2}) + assert latex(expr) == r'K{}^{i=3}{}_{j}{}^{k=2}{}_{l}' + + expr = TensorElement(K(i, j, -k, -l), {i: 3, -k: 2}) + assert latex(expr) == r'K{}^{i=3,j}{}_{k=2,l}' + + expr = TensorElement(K(i, j, -k, -l), {i: 3}) + assert latex(expr) == r'K{}^{i=3,j}{}_{kl}' + + expr = PartialDerivative(A(i), A(i)) + assert latex(expr) == r"\frac{\partial}{\partial {A{}^{L_{0}}}}{A{}^{L_{0}}}" + + expr = PartialDerivative(A(-i), A(-j)) + assert latex(expr) == r"\frac{\partial}{\partial {A{}_{j}}}{A{}_{i}}" + + expr = PartialDerivative(K(i, j, -k, -l), A(m), A(-n)) + assert latex(expr) == r"\frac{\partial^{2}}{\partial {A{}^{m}} \partial {A{}_{n}}}{K{}^{ij}{}_{kl}}" + + expr = PartialDerivative(B(-i) + A(-i), A(-j), A(-n)) + assert latex(expr) == r"\frac{\partial^{2}}{\partial {A{}_{j}} \partial {A{}_{n}}}{\left(A{}_{i} + B{}_{i}\right)}" + + expr = PartialDerivative(3*A(-i), A(-j), A(-n)) + assert latex(expr) == r"\frac{\partial^{2}}{\partial {A{}_{j}} \partial {A{}_{n}}}{\left(3A{}_{i}\right)}" + + +def test_multiline_latex(): + a, b, c, d, e, f = symbols('a b c d e f') + expr = -a + 2*b -3*c +4*d -5*e + expected = r"\begin{eqnarray}" + "\n"\ + r"f & = &- a \nonumber\\" + "\n"\ + r"& & + 2 b \nonumber\\" + "\n"\ + r"& & - 3 c \nonumber\\" + "\n"\ + r"& & + 4 d \nonumber\\" + "\n"\ + r"& & - 5 e " + "\n"\ + r"\end{eqnarray}" + assert multiline_latex(f, expr, environment="eqnarray") == expected + + expected2 = r'\begin{eqnarray}' + '\n'\ + r'f & = &- a + 2 b \nonumber\\' + '\n'\ + r'& & - 3 c + 4 d \nonumber\\' + '\n'\ + r'& & - 5 e ' + '\n'\ + r'\end{eqnarray}' + + assert multiline_latex(f, expr, 2, environment="eqnarray") == expected2 + + expected3 = r'\begin{eqnarray}' + '\n'\ + r'f & = &- a + 2 b - 3 c \nonumber\\'+ '\n'\ + r'& & + 4 d - 5 e ' + '\n'\ + r'\end{eqnarray}' + + assert multiline_latex(f, expr, 3, environment="eqnarray") == expected3 + + expected3dots = r'\begin{eqnarray}' + '\n'\ + r'f & = &- a + 2 b - 3 c \dots\nonumber\\'+ '\n'\ + r'& & + 4 d - 5 e ' + '\n'\ + r'\end{eqnarray}' + + assert multiline_latex(f, expr, 3, environment="eqnarray", use_dots=True) == expected3dots + + expected3align = r'\begin{align*}' + '\n'\ + r'f = &- a + 2 b - 3 c \\'+ '\n'\ + r'& + 4 d - 5 e ' + '\n'\ + r'\end{align*}' + + assert multiline_latex(f, expr, 3) == expected3align + assert multiline_latex(f, expr, 3, environment='align*') == expected3align + + expected2ieee = r'\begin{IEEEeqnarray}{rCl}' + '\n'\ + r'f & = &- a + 2 b \nonumber\\' + '\n'\ + r'& & - 3 c + 4 d \nonumber\\' + '\n'\ + r'& & - 5 e ' + '\n'\ + r'\end{IEEEeqnarray}' + + assert multiline_latex(f, expr, 2, environment="IEEEeqnarray") == expected2ieee + + raises(ValueError, lambda: multiline_latex(f, expr, environment="foo")) + +def test_issue_15353(): + a, x = symbols('a x') + # Obtained from nonlinsolve([(sin(a*x)),cos(a*x)],[x,a]) + sol = ConditionSet( + Tuple(x, a), Eq(sin(a*x), 0) & Eq(cos(a*x), 0), S.Complexes**2) + assert latex(sol) == \ + r'\left\{\left( x, \ a\right)\; \middle|\; \left( x, \ a\right) \in ' \ + r'\mathbb{C}^{2} \wedge \sin{\left(a x \right)} = 0 \wedge ' \ + r'\cos{\left(a x \right)} = 0 \right\}' + + +def test_latex_symbolic_probability(): + mu = symbols("mu") + sigma = symbols("sigma", positive=True) + X = Normal("X", mu, sigma) + assert latex(Expectation(X)) == r'\operatorname{E}\left[X\right]' + assert latex(Variance(X)) == r'\operatorname{Var}\left(X\right)' + assert latex(Probability(X > 0)) == r'\operatorname{P}\left(X > 0\right)' + Y = Normal("Y", mu, sigma) + assert latex(Covariance(X, Y)) == r'\operatorname{Cov}\left(X, Y\right)' + + +def test_trace(): + # Issue 15303 + from sympy.matrices.expressions.trace import trace + A = MatrixSymbol("A", 2, 2) + assert latex(trace(A)) == r"\operatorname{tr}\left(A \right)" + assert latex(trace(A**2)) == r"\operatorname{tr}\left(A^{2} \right)" + + +def test_print_basic(): + # Issue 15303 + from sympy.core.basic import Basic + from sympy.core.expr import Expr + + # dummy class for testing printing where the function is not + # implemented in latex.py + class UnimplementedExpr(Expr): + def __new__(cls, e): + return Basic.__new__(cls, e) + + # dummy function for testing + def unimplemented_expr(expr): + return UnimplementedExpr(expr).doit() + + # override class name to use superscript / subscript + def unimplemented_expr_sup_sub(expr): + result = UnimplementedExpr(expr) + result.__class__.__name__ = 'UnimplementedExpr_x^1' + return result + + assert latex(unimplemented_expr(x)) == r'\operatorname{UnimplementedExpr}\left(x\right)' + assert latex(unimplemented_expr(x**2)) == \ + r'\operatorname{UnimplementedExpr}\left(x^{2}\right)' + assert latex(unimplemented_expr_sup_sub(x)) == \ + r'\operatorname{UnimplementedExpr^{1}_{x}}\left(x\right)' + + +def test_MatrixSymbol_bold(): + # Issue #15871 + from sympy.matrices.expressions.trace import trace + A = MatrixSymbol("A", 2, 2) + assert latex(trace(A), mat_symbol_style='bold') == \ + r"\operatorname{tr}\left(\mathbf{A} \right)" + assert latex(trace(A), mat_symbol_style='plain') == \ + r"\operatorname{tr}\left(A \right)" + + A = MatrixSymbol("A", 3, 3) + B = MatrixSymbol("B", 3, 3) + C = MatrixSymbol("C", 3, 3) + + assert latex(-A, mat_symbol_style='bold') == r"- \mathbf{A}" + assert latex(A - A*B - B, mat_symbol_style='bold') == \ + r"\mathbf{A} - \mathbf{A} \mathbf{B} - \mathbf{B}" + assert latex(-A*B - A*B*C - B, mat_symbol_style='bold') == \ + r"- \mathbf{A} \mathbf{B} - \mathbf{A} \mathbf{B} \mathbf{C} - \mathbf{B}" + + A_k = MatrixSymbol("A_k", 3, 3) + assert latex(A_k, mat_symbol_style='bold') == r"\mathbf{A}_{k}" + + A = MatrixSymbol(r"\nabla_k", 3, 3) + assert latex(A, mat_symbol_style='bold') == r"\mathbf{\nabla}_{k}" + +def test_AppliedPermutation(): + p = Permutation(0, 1, 2) + x = Symbol('x') + assert latex(AppliedPermutation(p, x)) == \ + r'\sigma_{\left( 0\; 1\; 2\right)}(x)' + + +def test_PermutationMatrix(): + p = Permutation(0, 1, 2) + assert latex(PermutationMatrix(p)) == r'P_{\left( 0\; 1\; 2\right)}' + p = Permutation(0, 3)(1, 2) + assert latex(PermutationMatrix(p)) == \ + r'P_{\left( 0\; 3\right)\left( 1\; 2\right)}' + + +def test_issue_21758(): + from sympy.functions.elementary.piecewise import piecewise_fold + from sympy.series.fourier import FourierSeries + x = Symbol('x') + k, n = symbols('k n') + fo = FourierSeries(x, (x, -pi, pi), (0, SeqFormula(0, (k, 1, oo)), SeqFormula( + Piecewise((-2*pi*cos(n*pi)/n + 2*sin(n*pi)/n**2, (n > -oo) & (n < oo) & Ne(n, 0)), + (0, True))*sin(n*x)/pi, (n, 1, oo)))) + assert latex(piecewise_fold(fo)) == '\\begin{cases} 2 \\sin{\\left(x \\right)}' \ + ' - \\sin{\\left(2 x \\right)} + \\frac{2 \\sin{\\left(3 x \\right)}}{3} +' \ + ' \\ldots & \\text{for}\\: n > -\\infty \\wedge n < \\infty \\wedge ' \ + 'n \\neq 0 \\\\0 & \\text{otherwise} \\end{cases}' + assert latex(FourierSeries(x, (x, -pi, pi), (0, SeqFormula(0, (k, 1, oo)), + SeqFormula(0, (n, 1, oo))))) == '0' + + +def test_imaginary_unit(): + assert latex(1 + I) == r'1 + i' + assert latex(1 + I, imaginary_unit='i') == r'1 + i' + assert latex(1 + I, imaginary_unit='j') == r'1 + j' + assert latex(1 + I, imaginary_unit='foo') == r'1 + foo' + assert latex(I, imaginary_unit="ti") == r'\text{i}' + assert latex(I, imaginary_unit="tj") == r'\text{j}' + + +def test_text_re_im(): + assert latex(im(x), gothic_re_im=True) == r'\Im{\left(x\right)}' + assert latex(im(x), gothic_re_im=False) == r'\operatorname{im}{\left(x\right)}' + assert latex(re(x), gothic_re_im=True) == r'\Re{\left(x\right)}' + assert latex(re(x), gothic_re_im=False) == r'\operatorname{re}{\left(x\right)}' + + +def test_latex_diffgeom(): + from sympy.diffgeom import Manifold, Patch, CoordSystem, BaseScalarField, Differential + from sympy.diffgeom.rn import R2 + x,y = symbols('x y', real=True) + m = Manifold('M', 2) + assert latex(m) == r'\text{M}' + p = Patch('P', m) + assert latex(p) == r'\text{P}_{\text{M}}' + rect = CoordSystem('rect', p, [x, y]) + assert latex(rect) == r'\text{rect}^{\text{P}}_{\text{M}}' + b = BaseScalarField(rect, 0) + assert latex(b) == r'\mathbf{x}' + + g = Function('g') + s_field = g(R2.x, R2.y) + assert latex(Differential(s_field)) == \ + r'\operatorname{d}\left(g{\left(\mathbf{x},\mathbf{y} \right)}\right)' + + +def test_unit_printing(): + assert latex(5*meter) == r'5 \text{m}' + assert latex(3*gibibyte) == r'3 \text{gibibyte}' + assert latex(4*microgram/second) == r'\frac{4 \mu\text{g}}{\text{s}}' + assert latex(4*micro*gram/second) == r'\frac{4 \mu \text{g}}{\text{s}}' + assert latex(5*milli*meter) == r'5 \text{m} \text{m}' + assert latex(milli) == r'\text{m}' + + +def test_issue_17092(): + x_star = Symbol('x^*') + assert latex(Derivative(x_star, x_star,2)) == r'\frac{d^{2}}{d \left(x^{*}\right)^{2}} x^{*}' + + +def test_latex_decimal_separator(): + + x, y, z, t = symbols('x y z t') + k, m, n = symbols('k m n', integer=True) + f, g, h = symbols('f g h', cls=Function) + + # comma decimal_separator + assert(latex([1, 2.3, 4.5], decimal_separator='comma') == r'\left[ 1; \ 2{,}3; \ 4{,}5\right]') + assert(latex(FiniteSet(1, 2.3, 4.5), decimal_separator='comma') == r'\left\{1; 2{,}3; 4{,}5\right\}') + assert(latex((1, 2.3, 4.6), decimal_separator = 'comma') == r'\left( 1; \ 2{,}3; \ 4{,}6\right)') + assert(latex((1,), decimal_separator='comma') == r'\left( 1;\right)') + + # period decimal_separator + assert(latex([1, 2.3, 4.5], decimal_separator='period') == r'\left[ 1, \ 2.3, \ 4.5\right]' ) + assert(latex(FiniteSet(1, 2.3, 4.5), decimal_separator='period') == r'\left\{1, 2.3, 4.5\right\}') + assert(latex((1, 2.3, 4.6), decimal_separator = 'period') == r'\left( 1, \ 2.3, \ 4.6\right)') + assert(latex((1,), decimal_separator='period') == r'\left( 1,\right)') + + # default decimal_separator + assert(latex([1, 2.3, 4.5]) == r'\left[ 1, \ 2.3, \ 4.5\right]') + assert(latex(FiniteSet(1, 2.3, 4.5)) == r'\left\{1, 2.3, 4.5\right\}') + assert(latex((1, 2.3, 4.6)) == r'\left( 1, \ 2.3, \ 4.6\right)') + assert(latex((1,)) == r'\left( 1,\right)') + + assert(latex(Mul(3.4,5.3), decimal_separator = 'comma') == r'18{,}02') + assert(latex(3.4*5.3, decimal_separator = 'comma') == r'18{,}02') + x = symbols('x') + y = symbols('y') + z = symbols('z') + assert(latex(x*5.3 + 2**y**3.4 + 4.5 + z, decimal_separator = 'comma') == r'2^{y^{3{,}4}} + 5{,}3 x + z + 4{,}5') + + assert(latex(0.987, decimal_separator='comma') == r'0{,}987') + assert(latex(S(0.987), decimal_separator='comma') == r'0{,}987') + assert(latex(.3, decimal_separator='comma') == r'0{,}3') + assert(latex(S(.3), decimal_separator='comma') == r'0{,}3') + + + assert(latex(5.8*10**(-7), decimal_separator='comma') == r'5{,}8 \cdot 10^{-7}') + assert(latex(S(5.7)*10**(-7), decimal_separator='comma') == r'5{,}7 \cdot 10^{-7}') + assert(latex(S(5.7*10**(-7)), decimal_separator='comma') == r'5{,}7 \cdot 10^{-7}') + + x = symbols('x') + assert(latex(1.2*x+3.4, decimal_separator='comma') == r'1{,}2 x + 3{,}4') + assert(latex(FiniteSet(1, 2.3, 4.5), decimal_separator='period') == r'\left\{1, 2.3, 4.5\right\}') + + # Error Handling tests + raises(ValueError, lambda: latex([1,2.3,4.5], decimal_separator='non_existing_decimal_separator_in_list')) + raises(ValueError, lambda: latex(FiniteSet(1,2.3,4.5), decimal_separator='non_existing_decimal_separator_in_set')) + raises(ValueError, lambda: latex((1,2.3,4.5), decimal_separator='non_existing_decimal_separator_in_tuple')) + +def test_Str(): + from sympy.core.symbol import Str + assert str(Str('x')) == r'x' + +def test_latex_escape(): + assert latex_escape(r"~^\&%$#_{}") == "".join([ + r'\textasciitilde', + r'\textasciicircum', + r'\textbackslash', + r'\&', + r'\%', + r'\$', + r'\#', + r'\_', + r'\{', + r'\}', + ]) + +def test_emptyPrinter(): + class MyObject: + def __repr__(self): + return "" + + # unknown objects are monospaced + assert latex(MyObject()) == r"\mathtt{\text{}}" + + # even if they are nested within other objects + assert latex((MyObject(),)) == r"\left( \mathtt{\text{}},\right)" + +def test_global_settings(): + import inspect + + # settings should be visible in the signature of `latex` + assert inspect.signature(latex).parameters['imaginary_unit'].default == r'i' + assert latex(I) == r'i' + try: + # but changing the defaults... + LatexPrinter.set_global_settings(imaginary_unit='j') + # ... should change the signature + assert inspect.signature(latex).parameters['imaginary_unit'].default == r'j' + assert latex(I) == r'j' + finally: + # there's no public API to undo this, but we need to make sure we do + # so as not to impact other tests + del LatexPrinter._global_settings['imaginary_unit'] + + # check we really did undo it + assert inspect.signature(latex).parameters['imaginary_unit'].default == r'i' + assert latex(I) == r'i' + +def test_pickleable(): + # this tests that the _PrintFunction instance is pickleable + import pickle + assert pickle.loads(pickle.dumps(latex)) is latex + +def test_printing_latex_array_expressions(): + assert latex(ArraySymbol("A", (2, 3, 4))) == "A" + assert latex(ArrayElement("A", (2, 1/(1-x), 0))) == "{{A}_{2, \\frac{1}{1 - x}, 0}}" + M = MatrixSymbol("M", 3, 3) + N = MatrixSymbol("N", 3, 3) + assert latex(ArrayElement(M*N, [x, 0])) == "{{\\left(M N\\right)}_{x, 0}}" + +def test_Array(): + arr = Array(range(10)) + assert latex(arr) == r'\left[\begin{matrix}0 & 1 & 2 & 3 & 4 & 5 & 6 & 7 & 8 & 9\end{matrix}\right]' + + arr = Array(range(11)) + # fill the empty argument with a bunch of 'c' to avoid latex errors + assert latex(arr) == r'\left[\begin{array}{ccccccccccc}0 & 1 & 2 & 3 & 4 & 5 & 6 & 7 & 8 & 9 & 10\end{array}\right]' + +def test_latex_with_unevaluated(): + with evaluate(False): + assert latex(a * a) == r"a a" diff --git a/lib/python3.10/site-packages/sympy/printing/tests/test_llvmjit.py b/lib/python3.10/site-packages/sympy/printing/tests/test_llvmjit.py new file mode 100644 index 0000000000000000000000000000000000000000..709476f1d7517dc629210341594a70dc6f41808f --- /dev/null +++ b/lib/python3.10/site-packages/sympy/printing/tests/test_llvmjit.py @@ -0,0 +1,224 @@ +from sympy.external import import_module +from sympy.testing.pytest import raises +import ctypes + + +if import_module('llvmlite'): + import sympy.printing.llvmjitcode as g +else: + disabled = True + +import sympy +from sympy.abc import a, b, n + + +# copied from numpy.isclose documentation +def isclose(a, b): + rtol = 1e-5 + atol = 1e-8 + return abs(a-b) <= atol + rtol*abs(b) + + +def test_simple_expr(): + e = a + 1.0 + f = g.llvm_callable([a], e) + res = float(e.subs({a: 4.0}).evalf()) + jit_res = f(4.0) + + assert isclose(jit_res, res) + + +def test_two_arg(): + e = 4.0*a + b + 3.0 + f = g.llvm_callable([a, b], e) + res = float(e.subs({a: 4.0, b: 3.0}).evalf()) + jit_res = f(4.0, 3.0) + + assert isclose(jit_res, res) + + +def test_func(): + e = 4.0*sympy.exp(-a) + f = g.llvm_callable([a], e) + res = float(e.subs({a: 1.5}).evalf()) + jit_res = f(1.5) + + assert isclose(jit_res, res) + + +def test_two_func(): + e = 4.0*sympy.exp(-a) + sympy.exp(b) + f = g.llvm_callable([a, b], e) + res = float(e.subs({a: 1.5, b: 2.0}).evalf()) + jit_res = f(1.5, 2.0) + + assert isclose(jit_res, res) + + +def test_two_sqrt(): + e = 4.0*sympy.sqrt(a) + sympy.sqrt(b) + f = g.llvm_callable([a, b], e) + res = float(e.subs({a: 1.5, b: 2.0}).evalf()) + jit_res = f(1.5, 2.0) + + assert isclose(jit_res, res) + + +def test_two_pow(): + e = a**1.5 + b**7 + f = g.llvm_callable([a, b], e) + res = float(e.subs({a: 1.5, b: 2.0}).evalf()) + jit_res = f(1.5, 2.0) + + assert isclose(jit_res, res) + + +def test_callback(): + e = a + 1.2 + f = g.llvm_callable([a], e, callback_type='scipy.integrate.test') + m = ctypes.c_int(1) + array_type = ctypes.c_double * 1 + inp = {a: 2.2} + array = array_type(inp[a]) + jit_res = f(m, array) + + res = float(e.subs(inp).evalf()) + + assert isclose(jit_res, res) + + +def test_callback_cubature(): + e = a + 1.2 + f = g.llvm_callable([a], e, callback_type='cubature') + m = ctypes.c_int(1) + array_type = ctypes.c_double * 1 + inp = {a: 2.2} + array = array_type(inp[a]) + out_array = array_type(0.0) + jit_ret = f(m, array, None, m, out_array) + + assert jit_ret == 0 + + res = float(e.subs(inp).evalf()) + + assert isclose(out_array[0], res) + + +def test_callback_two(): + e = 3*a*b + f = g.llvm_callable([a, b], e, callback_type='scipy.integrate.test') + m = ctypes.c_int(2) + array_type = ctypes.c_double * 2 + inp = {a: 0.2, b: 1.7} + array = array_type(inp[a], inp[b]) + jit_res = f(m, array) + + res = float(e.subs(inp).evalf()) + + assert isclose(jit_res, res) + + +def test_callback_alt_two(): + d = sympy.IndexedBase('d') + e = 3*d[0]*d[1] + f = g.llvm_callable([n, d], e, callback_type='scipy.integrate.test') + m = ctypes.c_int(2) + array_type = ctypes.c_double * 2 + inp = {d[0]: 0.2, d[1]: 1.7} + array = array_type(inp[d[0]], inp[d[1]]) + jit_res = f(m, array) + + res = float(e.subs(inp).evalf()) + + assert isclose(jit_res, res) + + +def test_multiple_statements(): + # Match return from CSE + e = [[(b, 4.0*a)], [b + 5]] + f = g.llvm_callable([a], e) + b_val = e[0][0][1].subs({a: 1.5}) + res = float(e[1][0].subs({b: b_val}).evalf()) + jit_res = f(1.5) + assert isclose(jit_res, res) + + f_callback = g.llvm_callable([a], e, callback_type='scipy.integrate.test') + m = ctypes.c_int(1) + array_type = ctypes.c_double * 1 + array = array_type(1.5) + jit_callback_res = f_callback(m, array) + assert isclose(jit_callback_res, res) + + +def test_cse(): + e = a*a + b*b + sympy.exp(-a*a - b*b) + e2 = sympy.cse(e) + f = g.llvm_callable([a, b], e2) + res = float(e.subs({a: 2.3, b: 0.1}).evalf()) + jit_res = f(2.3, 0.1) + + assert isclose(jit_res, res) + + +def eval_cse(e, sub_dict): + tmp_dict = {} + for tmp_name, tmp_expr in e[0]: + e2 = tmp_expr.subs(sub_dict) + e3 = e2.subs(tmp_dict) + tmp_dict[tmp_name] = e3 + return [e.subs(sub_dict).subs(tmp_dict) for e in e[1]] + + +def test_cse_multiple(): + e1 = a*a + e2 = a*a + b*b + e3 = sympy.cse([e1, e2]) + + raises(NotImplementedError, + lambda: g.llvm_callable([a, b], e3, callback_type='scipy.integrate')) + + f = g.llvm_callable([a, b], e3) + jit_res = f(0.1, 1.5) + assert len(jit_res) == 2 + res = eval_cse(e3, {a: 0.1, b: 1.5}) + assert isclose(res[0], jit_res[0]) + assert isclose(res[1], jit_res[1]) + + +def test_callback_cubature_multiple(): + e1 = a*a + e2 = a*a + b*b + e3 = sympy.cse([e1, e2, 4*e2]) + f = g.llvm_callable([a, b], e3, callback_type='cubature') + + # Number of input variables + ndim = 2 + # Number of output expression values + outdim = 3 + + m = ctypes.c_int(ndim) + fdim = ctypes.c_int(outdim) + array_type = ctypes.c_double * ndim + out_array_type = ctypes.c_double * outdim + inp = {a: 0.2, b: 1.5} + array = array_type(inp[a], inp[b]) + out_array = out_array_type() + jit_ret = f(m, array, None, fdim, out_array) + + assert jit_ret == 0 + + res = eval_cse(e3, inp) + + assert isclose(out_array[0], res[0]) + assert isclose(out_array[1], res[1]) + assert isclose(out_array[2], res[2]) + + +def test_symbol_not_found(): + e = a*a + b + raises(LookupError, lambda: g.llvm_callable([a], e)) + + +def test_bad_callback(): + e = a + raises(ValueError, lambda: g.llvm_callable([a], e, callback_type='bad_callback')) diff --git a/lib/python3.10/site-packages/sympy/printing/tests/test_maple.py b/lib/python3.10/site-packages/sympy/printing/tests/test_maple.py new file mode 100644 index 0000000000000000000000000000000000000000..9bb4c512ad3203bd64ae56b350e15734b3a6afb0 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/printing/tests/test_maple.py @@ -0,0 +1,381 @@ +from sympy.core import (S, pi, oo, symbols, Function, Rational, Integer, + Tuple, Symbol, Eq, Ne, Le, Lt, Gt, Ge) +from sympy.core import EulerGamma, GoldenRatio, Catalan, Lambda, Mul, Pow +from sympy.functions import Piecewise, sqrt, ceiling, exp, sin, cos, sinc, lucas +from sympy.testing.pytest import raises +from sympy.utilities.lambdify import implemented_function +from sympy.matrices import (eye, Matrix, MatrixSymbol, Identity, + HadamardProduct, SparseMatrix) +from sympy.functions.special.bessel import besseli + +from sympy.printing.maple import maple_code + +x, y, z = symbols('x,y,z') + + +def test_Integer(): + assert maple_code(Integer(67)) == "67" + assert maple_code(Integer(-1)) == "-1" + + +def test_Rational(): + assert maple_code(Rational(3, 7)) == "3/7" + assert maple_code(Rational(18, 9)) == "2" + assert maple_code(Rational(3, -7)) == "-3/7" + assert maple_code(Rational(-3, -7)) == "3/7" + assert maple_code(x + Rational(3, 7)) == "x + 3/7" + assert maple_code(Rational(3, 7) * x) == '(3/7)*x' + + +def test_Relational(): + assert maple_code(Eq(x, y)) == "x = y" + assert maple_code(Ne(x, y)) == "x <> y" + assert maple_code(Le(x, y)) == "x <= y" + assert maple_code(Lt(x, y)) == "x < y" + assert maple_code(Gt(x, y)) == "x > y" + assert maple_code(Ge(x, y)) == "x >= y" + + +def test_Function(): + assert maple_code(sin(x) ** cos(x)) == "sin(x)^cos(x)" + assert maple_code(abs(x)) == "abs(x)" + assert maple_code(ceiling(x)) == "ceil(x)" + + +def test_Pow(): + assert maple_code(x ** 3) == "x^3" + assert maple_code(x ** (y ** 3)) == "x^(y^3)" + + assert maple_code((x ** 3) ** y) == "(x^3)^y" + assert maple_code(x ** Rational(2, 3)) == 'x^(2/3)' + + g = implemented_function('g', Lambda(x, 2 * x)) + assert maple_code(1 / (g(x) * 3.5) ** (x - y ** x) / (x ** 2 + y)) == \ + "(3.5*2*x)^(-x + y^x)/(x^2 + y)" + # For issue 14160 + assert maple_code(Mul(-2, x, Pow(Mul(y, y, evaluate=False), -1, evaluate=False), + evaluate=False)) == '-2*x/(y*y)' + + +def test_basic_ops(): + assert maple_code(x * y) == "x*y" + assert maple_code(x + y) == "x + y" + assert maple_code(x - y) == "x - y" + assert maple_code(-x) == "-x" + + +def test_1_over_x_and_sqrt(): + # 1.0 and 0.5 would do something different in regular StrPrinter, + # but these are exact in IEEE floating point so no different here. + assert maple_code(1 / x) == '1/x' + assert maple_code(x ** -1) == maple_code(x ** -1.0) == '1/x' + assert maple_code(1 / sqrt(x)) == '1/sqrt(x)' + assert maple_code(x ** -S.Half) == maple_code(x ** -0.5) == '1/sqrt(x)' + assert maple_code(sqrt(x)) == 'sqrt(x)' + assert maple_code(x ** S.Half) == maple_code(x ** 0.5) == 'sqrt(x)' + assert maple_code(1 / pi) == '1/Pi' + assert maple_code(pi ** -1) == maple_code(pi ** -1.0) == '1/Pi' + assert maple_code(pi ** -0.5) == '1/sqrt(Pi)' + + +def test_mix_number_mult_symbols(): + assert maple_code(3 * x) == "3*x" + assert maple_code(pi * x) == "Pi*x" + assert maple_code(3 / x) == "3/x" + assert maple_code(pi / x) == "Pi/x" + assert maple_code(x / 3) == '(1/3)*x' + assert maple_code(x / pi) == "x/Pi" + assert maple_code(x * y) == "x*y" + assert maple_code(3 * x * y) == "3*x*y" + assert maple_code(3 * pi * x * y) == "3*Pi*x*y" + assert maple_code(x / y) == "x/y" + assert maple_code(3 * x / y) == "3*x/y" + assert maple_code(x * y / z) == "x*y/z" + assert maple_code(x / y * z) == "x*z/y" + assert maple_code(1 / x / y) == "1/(x*y)" + assert maple_code(2 * pi * x / y / z) == "2*Pi*x/(y*z)" + assert maple_code(3 * pi / x) == "3*Pi/x" + assert maple_code(S(3) / 5) == "3/5" + assert maple_code(S(3) / 5 * x) == '(3/5)*x' + assert maple_code(x / y / z) == "x/(y*z)" + assert maple_code((x + y) / z) == "(x + y)/z" + assert maple_code((x + y) / (z + x)) == "(x + y)/(x + z)" + assert maple_code((x + y) / EulerGamma) == '(x + y)/gamma' + assert maple_code(x / 3 / pi) == '(1/3)*x/Pi' + assert maple_code(S(3) / 5 * x * y / pi) == '(3/5)*x*y/Pi' + + +def test_mix_number_pow_symbols(): + assert maple_code(pi ** 3) == 'Pi^3' + assert maple_code(x ** 2) == 'x^2' + + assert maple_code(x ** (pi ** 3)) == 'x^(Pi^3)' + assert maple_code(x ** y) == 'x^y' + + assert maple_code(x ** (y ** z)) == 'x^(y^z)' + assert maple_code((x ** y) ** z) == '(x^y)^z' + + +def test_imag(): + I = S('I') + assert maple_code(I) == "I" + assert maple_code(5 * I) == "5*I" + + assert maple_code((S(3) / 2) * I) == "(3/2)*I" + assert maple_code(3 + 4 * I) == "3 + 4*I" + + +def test_constants(): + assert maple_code(pi) == "Pi" + assert maple_code(oo) == "infinity" + assert maple_code(-oo) == "-infinity" + assert maple_code(S.NegativeInfinity) == "-infinity" + assert maple_code(S.NaN) == "undefined" + assert maple_code(S.Exp1) == "exp(1)" + assert maple_code(exp(1)) == "exp(1)" + + +def test_constants_other(): + assert maple_code(2 * GoldenRatio) == '2*(1/2 + (1/2)*sqrt(5))' + assert maple_code(2 * Catalan) == '2*Catalan' + assert maple_code(2 * EulerGamma) == "2*gamma" + + +def test_boolean(): + assert maple_code(x & y) == "x and y" + assert maple_code(x | y) == "x or y" + assert maple_code(~x) == "not x" + assert maple_code(x & y & z) == "x and y and z" + assert maple_code(x | y | z) == "x or y or z" + assert maple_code((x & y) | z) == "z or x and y" + assert maple_code((x | y) & z) == "z and (x or y)" + + +def test_Matrices(): + assert maple_code(Matrix(1, 1, [10])) == \ + 'Matrix([[10]], storage = rectangular)' + + A = Matrix([[1, sin(x / 2), abs(x)], + [0, 1, pi], + [0, exp(1), ceiling(x)]]) + expected = \ + 'Matrix(' \ + '[[1, sin((1/2)*x), abs(x)],' \ + ' [0, 1, Pi],' \ + ' [0, exp(1), ceil(x)]], ' \ + 'storage = rectangular)' + assert maple_code(A) == expected + + # row and columns + assert maple_code(A[:, 0]) == \ + 'Matrix([[1], [0], [0]], storage = rectangular)' + assert maple_code(A[0, :]) == \ + 'Matrix([[1, sin((1/2)*x), abs(x)]], storage = rectangular)' + assert maple_code(Matrix([[x, x - y, -y]])) == \ + 'Matrix([[x, x - y, -y]], storage = rectangular)' + + # empty matrices + assert maple_code(Matrix(0, 0, [])) == \ + 'Matrix([], storage = rectangular)' + assert maple_code(Matrix(0, 3, [])) == \ + 'Matrix([], storage = rectangular)' + +def test_SparseMatrices(): + assert maple_code(SparseMatrix(Identity(2))) == 'Matrix([[1, 0], [0, 1]], storage = sparse)' + + +def test_vector_entries_hadamard(): + # For a row or column, user might to use the other dimension + A = Matrix([[1, sin(2 / x), 3 * pi / x / 5]]) + assert maple_code(A) == \ + 'Matrix([[1, sin(2/x), (3/5)*Pi/x]], storage = rectangular)' + assert maple_code(A.T) == \ + 'Matrix([[1], [sin(2/x)], [(3/5)*Pi/x]], storage = rectangular)' + + +def test_Matrices_entries_not_hadamard(): + A = Matrix([[1, sin(2 / x), 3 * pi / x / 5], [1, 2, x * y]]) + expected = \ + 'Matrix([[1, sin(2/x), (3/5)*Pi/x], [1, 2, x*y]], ' \ + 'storage = rectangular)' + assert maple_code(A) == expected + + +def test_MatrixSymbol(): + n = Symbol('n', integer=True) + A = MatrixSymbol('A', n, n) + B = MatrixSymbol('B', n, n) + assert maple_code(A * B) == "A.B" + assert maple_code(B * A) == "B.A" + assert maple_code(2 * A * B) == "2*A.B" + assert maple_code(B * 2 * A) == "2*B.A" + + assert maple_code( + A * (B + 3 * Identity(n))) == "A.(3*Matrix(n, shape = identity) + B)" + + assert maple_code(A ** (x ** 2)) == "MatrixPower(A, x^2)" + assert maple_code(A ** 3) == "MatrixPower(A, 3)" + assert maple_code(A ** (S.Half)) == "MatrixPower(A, 1/2)" + + +def test_special_matrices(): + assert maple_code(6 * Identity(3)) == "6*Matrix([[1, 0, 0], [0, 1, 0], [0, 0, 1]], storage = sparse)" + assert maple_code(Identity(x)) == 'Matrix(x, shape = identity)' + + +def test_containers(): + assert maple_code([1, 2, 3, [4, 5, [6, 7]], 8, [9, 10], 11]) == \ + "[1, 2, 3, [4, 5, [6, 7]], 8, [9, 10], 11]" + + assert maple_code((1, 2, (3, 4))) == "[1, 2, [3, 4]]" + assert maple_code([1]) == "[1]" + assert maple_code((1,)) == "[1]" + assert maple_code(Tuple(*[1, 2, 3])) == "[1, 2, 3]" + assert maple_code((1, x * y, (3, x ** 2))) == "[1, x*y, [3, x^2]]" + # scalar, matrix, empty matrix and empty list + + assert maple_code((1, eye(3), Matrix(0, 0, []), [])) == \ + "[1, Matrix([[1, 0, 0], [0, 1, 0], [0, 0, 1]], storage = rectangular), Matrix([], storage = rectangular), []]" + + +def test_maple_noninline(): + source = maple_code((x + y)/Catalan, assign_to='me', inline=False) + expected = "me := (x + y)/Catalan" + + assert source == expected + + +def test_maple_matrix_assign_to(): + A = Matrix([[1, 2, 3]]) + assert maple_code(A, assign_to='a') == "a := Matrix([[1, 2, 3]], storage = rectangular)" + A = Matrix([[1, 2], [3, 4]]) + assert maple_code(A, assign_to='A') == "A := Matrix([[1, 2], [3, 4]], storage = rectangular)" + + +def test_maple_matrix_assign_to_more(): + # assigning to Symbol or MatrixSymbol requires lhs/rhs match + A = Matrix([[1, 2, 3]]) + B = MatrixSymbol('B', 1, 3) + C = MatrixSymbol('C', 2, 3) + assert maple_code(A, assign_to=B) == "B := Matrix([[1, 2, 3]], storage = rectangular)" + raises(ValueError, lambda: maple_code(A, assign_to=x)) + raises(ValueError, lambda: maple_code(A, assign_to=C)) + + +def test_maple_matrix_1x1(): + A = Matrix([[3]]) + assert maple_code(A, assign_to='B') == "B := Matrix([[3]], storage = rectangular)" + + +def test_maple_matrix_elements(): + A = Matrix([[x, 2, x * y]]) + + assert maple_code(A[0, 0] ** 2 + A[0, 1] + A[0, 2]) == "x^2 + x*y + 2" + AA = MatrixSymbol('AA', 1, 3) + assert maple_code(AA) == "AA" + + assert maple_code(AA[0, 0] ** 2 + sin(AA[0, 1]) + AA[0, 2]) == \ + "sin(AA[1, 2]) + AA[1, 1]^2 + AA[1, 3]" + assert maple_code(sum(AA)) == "AA[1, 1] + AA[1, 2] + AA[1, 3]" + + +def test_maple_boolean(): + assert maple_code(True) == "true" + assert maple_code(S.true) == "true" + assert maple_code(False) == "false" + assert maple_code(S.false) == "false" + + +def test_sparse(): + M = SparseMatrix(5, 6, {}) + M[2, 2] = 10 + M[1, 2] = 20 + M[1, 3] = 22 + M[0, 3] = 30 + M[3, 0] = x * y + assert maple_code(M) == \ + 'Matrix([[0, 0, 0, 30, 0, 0],' \ + ' [0, 0, 20, 22, 0, 0],' \ + ' [0, 0, 10, 0, 0, 0],' \ + ' [x*y, 0, 0, 0, 0, 0],' \ + ' [0, 0, 0, 0, 0, 0]], ' \ + 'storage = sparse)' + +# Not an important point. +def test_maple_not_supported(): + with raises(NotImplementedError): + maple_code(S.ComplexInfinity) + + +def test_MatrixElement_printing(): + # test cases for issue #11821 + A = MatrixSymbol("A", 1, 3) + B = MatrixSymbol("B", 1, 3) + + assert (maple_code(A[0, 0]) == "A[1, 1]") + assert (maple_code(3 * A[0, 0]) == "3*A[1, 1]") + + F = A-B + + assert (maple_code(F[0,0]) == "A[1, 1] - B[1, 1]") + + +def test_hadamard(): + A = MatrixSymbol('A', 3, 3) + B = MatrixSymbol('B', 3, 3) + v = MatrixSymbol('v', 3, 1) + h = MatrixSymbol('h', 1, 3) + C = HadamardProduct(A, B) + assert maple_code(C) == "A*B" + + assert maple_code(C * v) == "(A*B).v" + # HadamardProduct is higher than dot product. + + assert maple_code(h * C * v) == "h.(A*B).v" + + assert maple_code(C * A) == "(A*B).A" + # mixing Hadamard and scalar strange b/c we vectorize scalars + + assert maple_code(C * x * y) == "x*y*(A*B)" + + +def test_maple_piecewise(): + expr = Piecewise((x, x < 1), (x ** 2, True)) + + assert maple_code(expr) == "piecewise(x < 1, x, x^2)" + assert maple_code(expr, assign_to="r") == ( + "r := piecewise(x < 1, x, x^2)") + + expr = Piecewise((x ** 2, x < 1), (x ** 3, x < 2), (x ** 4, x < 3), (x ** 5, True)) + expected = "piecewise(x < 1, x^2, x < 2, x^3, x < 3, x^4, x^5)" + assert maple_code(expr) == expected + assert maple_code(expr, assign_to="r") == "r := " + expected + + # Check that Piecewise without a True (default) condition error + expr = Piecewise((x, x < 1), (x ** 2, x > 1), (sin(x), x > 0)) + raises(ValueError, lambda: maple_code(expr)) + + +def test_maple_piecewise_times_const(): + pw = Piecewise((x, x < 1), (x ** 2, True)) + + assert maple_code(2 * pw) == "2*piecewise(x < 1, x, x^2)" + assert maple_code(pw / x) == "piecewise(x < 1, x, x^2)/x" + assert maple_code(pw / (x * y)) == "piecewise(x < 1, x, x^2)/(x*y)" + assert maple_code(pw / 3) == "(1/3)*piecewise(x < 1, x, x^2)" + + +def test_maple_derivatives(): + f = Function('f') + assert maple_code(f(x).diff(x)) == 'diff(f(x), x)' + assert maple_code(f(x).diff(x, 2)) == 'diff(f(x), x$2)' + + +def test_automatic_rewrites(): + assert maple_code(lucas(x)) == '(2^(-x)*((1 - sqrt(5))^x + (1 + sqrt(5))^x))' + assert maple_code(sinc(x)) == '(piecewise(x <> 0, sin(x)/x, 1))' + + +def test_specfun(): + assert maple_code('asin(x)') == 'arcsin(x)' + assert maple_code(besseli(x, y)) == 'BesselI(x, y)' diff --git a/lib/python3.10/site-packages/sympy/printing/tests/test_mathematica.py b/lib/python3.10/site-packages/sympy/printing/tests/test_mathematica.py new file mode 100644 index 0000000000000000000000000000000000000000..5780ceb900ab5ad34ba0dccdb10281a3f1dc5d18 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/printing/tests/test_mathematica.py @@ -0,0 +1,287 @@ +from sympy.core import (S, pi, oo, symbols, Function, Rational, Integer, Tuple, + Derivative, Eq, Ne, Le, Lt, Gt, Ge) +from sympy.integrals import Integral +from sympy.concrete import Sum +from sympy.functions import (exp, sin, cos, fresnelc, fresnels, conjugate, Max, + Min, gamma, polygamma, loggamma, erf, erfi, erfc, + erf2, expint, erfinv, erfcinv, Ei, Si, Ci, li, + Shi, Chi, uppergamma, beta, subfactorial, erf2inv, + factorial, factorial2, catalan, RisingFactorial, + FallingFactorial, harmonic, atan2, sec, acsc, + hermite, laguerre, assoc_laguerre, jacobi, + gegenbauer, chebyshevt, chebyshevu, legendre, + assoc_legendre, Li, LambertW) + +from sympy.printing.mathematica import mathematica_code as mcode + +x, y, z, w = symbols('x,y,z,w') +f = Function('f') + + +def test_Integer(): + assert mcode(Integer(67)) == "67" + assert mcode(Integer(-1)) == "-1" + + +def test_Rational(): + assert mcode(Rational(3, 7)) == "3/7" + assert mcode(Rational(18, 9)) == "2" + assert mcode(Rational(3, -7)) == "-3/7" + assert mcode(Rational(-3, -7)) == "3/7" + assert mcode(x + Rational(3, 7)) == "x + 3/7" + assert mcode(Rational(3, 7)*x) == "(3/7)*x" + + +def test_Relational(): + assert mcode(Eq(x, y)) == "x == y" + assert mcode(Ne(x, y)) == "x != y" + assert mcode(Le(x, y)) == "x <= y" + assert mcode(Lt(x, y)) == "x < y" + assert mcode(Gt(x, y)) == "x > y" + assert mcode(Ge(x, y)) == "x >= y" + + +def test_Function(): + assert mcode(f(x, y, z)) == "f[x, y, z]" + assert mcode(sin(x) ** cos(x)) == "Sin[x]^Cos[x]" + assert mcode(sec(x) * acsc(x)) == "ArcCsc[x]*Sec[x]" + assert mcode(atan2(x, y)) == "ArcTan[x, y]" + assert mcode(conjugate(x)) == "Conjugate[x]" + assert mcode(Max(x, y, z)*Min(y, z)) == "Max[x, y, z]*Min[y, z]" + assert mcode(fresnelc(x)) == "FresnelC[x]" + assert mcode(fresnels(x)) == "FresnelS[x]" + assert mcode(gamma(x)) == "Gamma[x]" + assert mcode(uppergamma(x, y)) == "Gamma[x, y]" + assert mcode(polygamma(x, y)) == "PolyGamma[x, y]" + assert mcode(loggamma(x)) == "LogGamma[x]" + assert mcode(erf(x)) == "Erf[x]" + assert mcode(erfc(x)) == "Erfc[x]" + assert mcode(erfi(x)) == "Erfi[x]" + assert mcode(erf2(x, y)) == "Erf[x, y]" + assert mcode(expint(x, y)) == "ExpIntegralE[x, y]" + assert mcode(erfcinv(x)) == "InverseErfc[x]" + assert mcode(erfinv(x)) == "InverseErf[x]" + assert mcode(erf2inv(x, y)) == "InverseErf[x, y]" + assert mcode(Ei(x)) == "ExpIntegralEi[x]" + assert mcode(Ci(x)) == "CosIntegral[x]" + assert mcode(li(x)) == "LogIntegral[x]" + assert mcode(Si(x)) == "SinIntegral[x]" + assert mcode(Shi(x)) == "SinhIntegral[x]" + assert mcode(Chi(x)) == "CoshIntegral[x]" + assert mcode(beta(x, y)) == "Beta[x, y]" + assert mcode(factorial(x)) == "Factorial[x]" + assert mcode(factorial2(x)) == "Factorial2[x]" + assert mcode(subfactorial(x)) == "Subfactorial[x]" + assert mcode(FallingFactorial(x, y)) == "FactorialPower[x, y]" + assert mcode(RisingFactorial(x, y)) == "Pochhammer[x, y]" + assert mcode(catalan(x)) == "CatalanNumber[x]" + assert mcode(harmonic(x)) == "HarmonicNumber[x]" + assert mcode(harmonic(x, y)) == "HarmonicNumber[x, y]" + assert mcode(Li(x)) == "LogIntegral[x] - LogIntegral[2]" + assert mcode(LambertW(x)) == "ProductLog[x]" + assert mcode(LambertW(x, -1)) == "ProductLog[-1, x]" + assert mcode(LambertW(x, y)) == "ProductLog[y, x]" + + +def test_special_polynomials(): + assert mcode(hermite(x, y)) == "HermiteH[x, y]" + assert mcode(laguerre(x, y)) == "LaguerreL[x, y]" + assert mcode(assoc_laguerre(x, y, z)) == "LaguerreL[x, y, z]" + assert mcode(jacobi(x, y, z, w)) == "JacobiP[x, y, z, w]" + assert mcode(gegenbauer(x, y, z)) == "GegenbauerC[x, y, z]" + assert mcode(chebyshevt(x, y)) == "ChebyshevT[x, y]" + assert mcode(chebyshevu(x, y)) == "ChebyshevU[x, y]" + assert mcode(legendre(x, y)) == "LegendreP[x, y]" + assert mcode(assoc_legendre(x, y, z)) == "LegendreP[x, y, z]" + + +def test_Pow(): + assert mcode(x**3) == "x^3" + assert mcode(x**(y**3)) == "x^(y^3)" + assert mcode(1/(f(x)*3.5)**(x - y**x)/(x**2 + y)) == \ + "(3.5*f[x])^(-x + y^x)/(x^2 + y)" + assert mcode(x**-1.0) == 'x^(-1.0)' + assert mcode(x**Rational(2, 3)) == 'x^(2/3)' + + +def test_Mul(): + A, B, C, D = symbols('A B C D', commutative=False) + assert mcode(x*y*z) == "x*y*z" + assert mcode(x*y*A) == "x*y*A" + assert mcode(x*y*A*B) == "x*y*A**B" + assert mcode(x*y*A*B*C) == "x*y*A**B**C" + assert mcode(x*A*B*(C + D)*A*y) == "x*y*A**B**(C + D)**A" + + +def test_constants(): + assert mcode(S.Zero) == "0" + assert mcode(S.One) == "1" + assert mcode(S.NegativeOne) == "-1" + assert mcode(S.Half) == "1/2" + assert mcode(S.ImaginaryUnit) == "I" + + assert mcode(oo) == "Infinity" + assert mcode(S.NegativeInfinity) == "-Infinity" + assert mcode(S.ComplexInfinity) == "ComplexInfinity" + assert mcode(S.NaN) == "Indeterminate" + + assert mcode(S.Exp1) == "E" + assert mcode(pi) == "Pi" + assert mcode(S.GoldenRatio) == "GoldenRatio" + assert mcode(S.TribonacciConstant) == \ + "(1/3 + (1/3)*(19 - 3*33^(1/2))^(1/3) + " \ + "(1/3)*(3*33^(1/2) + 19)^(1/3))" + assert mcode(2*S.TribonacciConstant) == \ + "2*(1/3 + (1/3)*(19 - 3*33^(1/2))^(1/3) + " \ + "(1/3)*(3*33^(1/2) + 19)^(1/3))" + assert mcode(S.EulerGamma) == "EulerGamma" + assert mcode(S.Catalan) == "Catalan" + + +def test_containers(): + assert mcode([1, 2, 3, [4, 5, [6, 7]], 8, [9, 10], 11]) == \ + "{1, 2, 3, {4, 5, {6, 7}}, 8, {9, 10}, 11}" + assert mcode((1, 2, (3, 4))) == "{1, 2, {3, 4}}" + assert mcode([1]) == "{1}" + assert mcode((1,)) == "{1}" + assert mcode(Tuple(*[1, 2, 3])) == "{1, 2, 3}" + + +def test_matrices(): + from sympy.matrices import MutableDenseMatrix, MutableSparseMatrix, \ + ImmutableDenseMatrix, ImmutableSparseMatrix + A = MutableDenseMatrix( + [[1, -1, 0, 0], + [0, 1, -1, 0], + [0, 0, 1, -1], + [0, 0, 0, 1]] + ) + B = MutableSparseMatrix(A) + C = ImmutableDenseMatrix(A) + D = ImmutableSparseMatrix(A) + + assert mcode(C) == mcode(A) == \ + "{{1, -1, 0, 0}, " \ + "{0, 1, -1, 0}, " \ + "{0, 0, 1, -1}, " \ + "{0, 0, 0, 1}}" + + assert mcode(D) == mcode(B) == \ + "SparseArray[{" \ + "{1, 1} -> 1, {1, 2} -> -1, {2, 2} -> 1, {2, 3} -> -1, " \ + "{3, 3} -> 1, {3, 4} -> -1, {4, 4} -> 1" \ + "}, {4, 4}]" + + # Trivial cases of matrices + assert mcode(MutableDenseMatrix(0, 0, [])) == '{}' + assert mcode(MutableSparseMatrix(0, 0, [])) == 'SparseArray[{}, {0, 0}]' + assert mcode(MutableDenseMatrix(0, 3, [])) == '{}' + assert mcode(MutableSparseMatrix(0, 3, [])) == 'SparseArray[{}, {0, 3}]' + assert mcode(MutableDenseMatrix(3, 0, [])) == '{{}, {}, {}}' + assert mcode(MutableSparseMatrix(3, 0, [])) == 'SparseArray[{}, {3, 0}]' + +def test_NDArray(): + from sympy.tensor.array import ( + MutableDenseNDimArray, ImmutableDenseNDimArray, + MutableSparseNDimArray, ImmutableSparseNDimArray) + + example = MutableDenseNDimArray( + [[[1, 2, 3, 4], + [5, 6, 7, 8], + [9, 10, 11, 12]], + [[13, 14, 15, 16], + [17, 18, 19, 20], + [21, 22, 23, 24]]] + ) + + assert mcode(example) == \ + "{{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}, " \ + "{{13, 14, 15, 16}, {17, 18, 19, 20}, {21, 22, 23, 24}}}" + + example = ImmutableDenseNDimArray(example) + + assert mcode(example) == \ + "{{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}}, " \ + "{{13, 14, 15, 16}, {17, 18, 19, 20}, {21, 22, 23, 24}}}" + + example = MutableSparseNDimArray(example) + + assert mcode(example) == \ + "SparseArray[{" \ + "{1, 1, 1} -> 1, {1, 1, 2} -> 2, {1, 1, 3} -> 3, " \ + "{1, 1, 4} -> 4, {1, 2, 1} -> 5, {1, 2, 2} -> 6, " \ + "{1, 2, 3} -> 7, {1, 2, 4} -> 8, {1, 3, 1} -> 9, " \ + "{1, 3, 2} -> 10, {1, 3, 3} -> 11, {1, 3, 4} -> 12, " \ + "{2, 1, 1} -> 13, {2, 1, 2} -> 14, {2, 1, 3} -> 15, " \ + "{2, 1, 4} -> 16, {2, 2, 1} -> 17, {2, 2, 2} -> 18, " \ + "{2, 2, 3} -> 19, {2, 2, 4} -> 20, {2, 3, 1} -> 21, " \ + "{2, 3, 2} -> 22, {2, 3, 3} -> 23, {2, 3, 4} -> 24" \ + "}, {2, 3, 4}]" + + example = ImmutableSparseNDimArray(example) + + assert mcode(example) == \ + "SparseArray[{" \ + "{1, 1, 1} -> 1, {1, 1, 2} -> 2, {1, 1, 3} -> 3, " \ + "{1, 1, 4} -> 4, {1, 2, 1} -> 5, {1, 2, 2} -> 6, " \ + "{1, 2, 3} -> 7, {1, 2, 4} -> 8, {1, 3, 1} -> 9, " \ + "{1, 3, 2} -> 10, {1, 3, 3} -> 11, {1, 3, 4} -> 12, " \ + "{2, 1, 1} -> 13, {2, 1, 2} -> 14, {2, 1, 3} -> 15, " \ + "{2, 1, 4} -> 16, {2, 2, 1} -> 17, {2, 2, 2} -> 18, " \ + "{2, 2, 3} -> 19, {2, 2, 4} -> 20, {2, 3, 1} -> 21, " \ + "{2, 3, 2} -> 22, {2, 3, 3} -> 23, {2, 3, 4} -> 24" \ + "}, {2, 3, 4}]" + + +def test_Integral(): + assert mcode(Integral(sin(sin(x)), x)) == "Hold[Integrate[Sin[Sin[x]], x]]" + assert mcode(Integral(exp(-x**2 - y**2), + (x, -oo, oo), + (y, -oo, oo))) == \ + "Hold[Integrate[Exp[-x^2 - y^2], {x, -Infinity, Infinity}, " \ + "{y, -Infinity, Infinity}]]" + + +def test_Derivative(): + assert mcode(Derivative(sin(x), x)) == "Hold[D[Sin[x], x]]" + assert mcode(Derivative(x, x)) == "Hold[D[x, x]]" + assert mcode(Derivative(sin(x)*y**4, x, 2)) == "Hold[D[y^4*Sin[x], {x, 2}]]" + assert mcode(Derivative(sin(x)*y**4, x, y, x)) == "Hold[D[y^4*Sin[x], x, y, x]]" + assert mcode(Derivative(sin(x)*y**4, x, y, 3, x)) == "Hold[D[y^4*Sin[x], x, {y, 3}, x]]" + + +def test_Sum(): + assert mcode(Sum(sin(x), (x, 0, 10))) == "Hold[Sum[Sin[x], {x, 0, 10}]]" + assert mcode(Sum(exp(-x**2 - y**2), + (x, -oo, oo), + (y, -oo, oo))) == \ + "Hold[Sum[Exp[-x^2 - y^2], {x, -Infinity, Infinity}, " \ + "{y, -Infinity, Infinity}]]" + + +def test_comment(): + from sympy.printing.mathematica import MCodePrinter + assert MCodePrinter()._get_comment("Hello World") == \ + "(* Hello World *)" + + +def test_userfuncs(): + # Dictionary mutation test + some_function = symbols("some_function", cls=Function) + my_user_functions = {"some_function": "SomeFunction"} + assert mcode( + some_function(z), + user_functions=my_user_functions) == \ + 'SomeFunction[z]' + assert mcode( + some_function(z), + user_functions=my_user_functions) == \ + 'SomeFunction[z]' + + # List argument test + my_user_functions = \ + {"some_function": [(lambda x: True, "SomeOtherFunction")]} + assert mcode( + some_function(z), + user_functions=my_user_functions) == \ + 'SomeOtherFunction[z]' diff --git a/lib/python3.10/site-packages/sympy/printing/tests/test_mathml.py b/lib/python3.10/site-packages/sympy/printing/tests/test_mathml.py new file mode 100644 index 0000000000000000000000000000000000000000..433115d402fa4566f452409f3e47a608e900ee8c --- /dev/null +++ b/lib/python3.10/site-packages/sympy/printing/tests/test_mathml.py @@ -0,0 +1,2037 @@ +from sympy.calculus.accumulationbounds import AccumBounds +from sympy.concrete.summations import Sum +from sympy.core.basic import Basic +from sympy.core.containers import Tuple +from sympy.core.function import Derivative, Lambda, diff, Function +from sympy.core.numbers import (zoo, Float, Integer, I, oo, pi, E, + Rational) +from sympy.core.relational import Lt, Ge, Ne, Eq +from sympy.core.singleton import S +from sympy.core.symbol import symbols, Symbol +from sympy.core.sympify import sympify +from sympy.functions.combinatorial.factorials import (factorial2, + binomial, factorial) +from sympy.functions.combinatorial.numbers import (lucas, bell, + catalan, euler, tribonacci, fibonacci, bernoulli, primenu, primeomega, + totient, reduced_totient) +from sympy.functions.elementary.complexes import re, im, conjugate, Abs +from sympy.functions.elementary.exponential import exp, LambertW, log +from sympy.functions.elementary.hyperbolic import (tanh, acoth, atanh, + coth, asinh, acsch, asech, acosh, csch, sinh, cosh, sech) +from sympy.functions.elementary.integers import ceiling, floor +from sympy.functions.elementary.miscellaneous import Max, Min +from sympy.functions.elementary.trigonometric import (csc, sec, tan, + atan, sin, asec, cot, cos, acot, acsc, asin, acos) +from sympy.functions.special.delta_functions import Heaviside +from sympy.functions.special.elliptic_integrals import (elliptic_pi, + elliptic_f, elliptic_k, elliptic_e) +from sympy.functions.special.error_functions import (fresnelc, + fresnels, Ei, expint) +from sympy.functions.special.gamma_functions import (gamma, uppergamma, + lowergamma) +from sympy.functions.special.mathieu_functions import (mathieusprime, + mathieus, mathieucprime, mathieuc) +from sympy.functions.special.polynomials import (jacobi, chebyshevu, + chebyshevt, hermite, assoc_legendre, gegenbauer, assoc_laguerre, + legendre, laguerre) +from sympy.functions.special.singularity_functions import SingularityFunction +from sympy.functions.special.zeta_functions import (polylog, stieltjes, + lerchphi, dirichlet_eta, zeta) +from sympy.integrals.integrals import Integral +from sympy.logic.boolalg import (Xor, Or, false, true, And, Equivalent, + Implies, Not) +from sympy.matrices.dense import Matrix +from sympy.matrices.expressions.determinant import Determinant +from sympy.matrices.expressions.matexpr import MatrixSymbol +from sympy.physics.quantum import (ComplexSpace, FockSpace, hbar, + HilbertSpace, Dagger) +from sympy.printing.mathml import (MathMLPresentationPrinter, + MathMLPrinter, MathMLContentPrinter, mathml) +from sympy.series.limits import Limit +from sympy.sets.contains import Contains +from sympy.sets.fancysets import Range +from sympy.sets.sets import (Interval, Union, SymmetricDifference, + Complement, FiniteSet, Intersection, ProductSet) +from sympy.stats.rv import RandomSymbol +from sympy.tensor.indexed import IndexedBase +from sympy.vector import (Divergence, CoordSys3D, Cross, Curl, Dot, + Laplacian, Gradient) +from sympy.testing.pytest import raises + +x, y, z, a, b, c, d, e, n = symbols('x:z a:e n') +mp = MathMLContentPrinter() +mpp = MathMLPresentationPrinter() + + +def test_mathml_printer(): + m = MathMLPrinter() + assert m.doprint(1+x) == mp.doprint(1+x) + + +def test_content_printmethod(): + assert mp.doprint(1 + x) == 'x1' + + +def test_content_mathml_core(): + mml_1 = mp._print(1 + x) + assert mml_1.nodeName == 'apply' + nodes = mml_1.childNodes + assert len(nodes) == 3 + assert nodes[0].nodeName == 'plus' + assert nodes[0].hasChildNodes() is False + assert nodes[0].nodeValue is None + assert nodes[1].nodeName in ['cn', 'ci'] + if nodes[1].nodeName == 'cn': + assert nodes[1].childNodes[0].nodeValue == '1' + assert nodes[2].childNodes[0].nodeValue == 'x' + else: + assert nodes[1].childNodes[0].nodeValue == 'x' + assert nodes[2].childNodes[0].nodeValue == '1' + + mml_2 = mp._print(x**2) + assert mml_2.nodeName == 'apply' + nodes = mml_2.childNodes + assert nodes[1].childNodes[0].nodeValue == 'x' + assert nodes[2].childNodes[0].nodeValue == '2' + + mml_3 = mp._print(2*x) + assert mml_3.nodeName == 'apply' + nodes = mml_3.childNodes + assert nodes[0].nodeName == 'times' + assert nodes[1].childNodes[0].nodeValue == '2' + assert nodes[2].childNodes[0].nodeValue == 'x' + + mml = mp._print(Float(1.0, 2)*x) + assert mml.nodeName == 'apply' + nodes = mml.childNodes + assert nodes[0].nodeName == 'times' + assert nodes[1].childNodes[0].nodeValue == '1.0' + assert nodes[2].childNodes[0].nodeValue == 'x' + + +def test_content_mathml_functions(): + mml_1 = mp._print(sin(x)) + assert mml_1.nodeName == 'apply' + assert mml_1.childNodes[0].nodeName == 'sin' + assert mml_1.childNodes[1].nodeName == 'ci' + + mml_2 = mp._print(diff(sin(x), x, evaluate=False)) + assert mml_2.nodeName == 'apply' + assert mml_2.childNodes[0].nodeName == 'diff' + assert mml_2.childNodes[1].nodeName == 'bvar' + assert mml_2.childNodes[1].childNodes[ + 0].nodeName == 'ci' # below bvar there's x/ci> + + mml_3 = mp._print(diff(cos(x*y), x, evaluate=False)) + assert mml_3.nodeName == 'apply' + assert mml_3.childNodes[0].nodeName == 'partialdiff' + assert mml_3.childNodes[1].nodeName == 'bvar' + assert mml_3.childNodes[1].childNodes[ + 0].nodeName == 'ci' # below bvar there's x/ci> + + mml_4 = mp._print(Lambda((x, y), x * y)) + assert mml_4.nodeName == 'lambda' + assert mml_4.childNodes[0].nodeName == 'bvar' + assert mml_4.childNodes[0].childNodes[ + 0].nodeName == 'ci' # below bvar there's x/ci> + assert mml_4.childNodes[1].nodeName == 'bvar' + assert mml_4.childNodes[1].childNodes[ + 0].nodeName == 'ci' # below bvar there's y/ci> + assert mml_4.childNodes[2].nodeName == 'apply' + + +def test_content_mathml_limits(): + # XXX No unevaluated limits + lim_fun = sin(x)/x + mml_1 = mp._print(Limit(lim_fun, x, 0)) + assert mml_1.childNodes[0].nodeName == 'limit' + assert mml_1.childNodes[1].nodeName == 'bvar' + assert mml_1.childNodes[2].nodeName == 'lowlimit' + assert mml_1.childNodes[3].toxml() == mp._print(lim_fun).toxml() + + +def test_content_mathml_integrals(): + integrand = x + mml_1 = mp._print(Integral(integrand, (x, 0, 1))) + assert mml_1.childNodes[0].nodeName == 'int' + assert mml_1.childNodes[1].nodeName == 'bvar' + assert mml_1.childNodes[2].nodeName == 'lowlimit' + assert mml_1.childNodes[3].nodeName == 'uplimit' + assert mml_1.childNodes[4].toxml() == mp._print(integrand).toxml() + + +def test_content_mathml_matrices(): + A = Matrix([1, 2, 3]) + B = Matrix([[0, 5, 4], [2, 3, 1], [9, 7, 9]]) + mll_1 = mp._print(A) + assert mll_1.childNodes[0].nodeName == 'matrixrow' + assert mll_1.childNodes[0].childNodes[0].nodeName == 'cn' + assert mll_1.childNodes[0].childNodes[0].childNodes[0].nodeValue == '1' + assert mll_1.childNodes[1].nodeName == 'matrixrow' + assert mll_1.childNodes[1].childNodes[0].nodeName == 'cn' + assert mll_1.childNodes[1].childNodes[0].childNodes[0].nodeValue == '2' + assert mll_1.childNodes[2].nodeName == 'matrixrow' + assert mll_1.childNodes[2].childNodes[0].nodeName == 'cn' + assert mll_1.childNodes[2].childNodes[0].childNodes[0].nodeValue == '3' + mll_2 = mp._print(B) + assert mll_2.childNodes[0].nodeName == 'matrixrow' + assert mll_2.childNodes[0].childNodes[0].nodeName == 'cn' + assert mll_2.childNodes[0].childNodes[0].childNodes[0].nodeValue == '0' + assert mll_2.childNodes[0].childNodes[1].nodeName == 'cn' + assert mll_2.childNodes[0].childNodes[1].childNodes[0].nodeValue == '5' + assert mll_2.childNodes[0].childNodes[2].nodeName == 'cn' + assert mll_2.childNodes[0].childNodes[2].childNodes[0].nodeValue == '4' + assert mll_2.childNodes[1].nodeName == 'matrixrow' + assert mll_2.childNodes[1].childNodes[0].nodeName == 'cn' + assert mll_2.childNodes[1].childNodes[0].childNodes[0].nodeValue == '2' + assert mll_2.childNodes[1].childNodes[1].nodeName == 'cn' + assert mll_2.childNodes[1].childNodes[1].childNodes[0].nodeValue == '3' + assert mll_2.childNodes[1].childNodes[2].nodeName == 'cn' + assert mll_2.childNodes[1].childNodes[2].childNodes[0].nodeValue == '1' + assert mll_2.childNodes[2].nodeName == 'matrixrow' + assert mll_2.childNodes[2].childNodes[0].nodeName == 'cn' + assert mll_2.childNodes[2].childNodes[0].childNodes[0].nodeValue == '9' + assert mll_2.childNodes[2].childNodes[1].nodeName == 'cn' + assert mll_2.childNodes[2].childNodes[1].childNodes[0].nodeValue == '7' + assert mll_2.childNodes[2].childNodes[2].nodeName == 'cn' + assert mll_2.childNodes[2].childNodes[2].childNodes[0].nodeValue == '9' + + +def test_content_mathml_sums(): + summand = x + mml_1 = mp._print(Sum(summand, (x, 1, 10))) + assert mml_1.childNodes[0].nodeName == 'sum' + assert mml_1.childNodes[1].nodeName == 'bvar' + assert mml_1.childNodes[2].nodeName == 'lowlimit' + assert mml_1.childNodes[3].nodeName == 'uplimit' + assert mml_1.childNodes[4].toxml() == mp._print(summand).toxml() + + +def test_content_mathml_tuples(): + mml_1 = mp._print([2]) + assert mml_1.nodeName == 'list' + assert mml_1.childNodes[0].nodeName == 'cn' + assert len(mml_1.childNodes) == 1 + + mml_2 = mp._print([2, Integer(1)]) + assert mml_2.nodeName == 'list' + assert mml_2.childNodes[0].nodeName == 'cn' + assert mml_2.childNodes[1].nodeName == 'cn' + assert len(mml_2.childNodes) == 2 + + +def test_content_mathml_add(): + mml = mp._print(x**5 - x**4 + x) + assert mml.childNodes[0].nodeName == 'plus' + assert mml.childNodes[1].childNodes[0].nodeName == 'minus' + assert mml.childNodes[1].childNodes[1].nodeName == 'apply' + + +def test_content_mathml_Rational(): + mml_1 = mp._print(Rational(1, 1)) + """should just return a number""" + assert mml_1.nodeName == 'cn' + + mml_2 = mp._print(Rational(2, 5)) + assert mml_2.childNodes[0].nodeName == 'divide' + + +def test_content_mathml_constants(): + mml = mp._print(I) + assert mml.nodeName == 'imaginaryi' + + mml = mp._print(E) + assert mml.nodeName == 'exponentiale' + + mml = mp._print(oo) + assert mml.nodeName == 'infinity' + + mml = mp._print(pi) + assert mml.nodeName == 'pi' + + assert mathml(hbar) == '' + assert mathml(S.TribonacciConstant) == '' + assert mathml(S.GoldenRatio) == 'φ' + mml = mathml(S.EulerGamma) + assert mml == '' + + mml = mathml(S.EmptySet) + assert mml == '' + + mml = mathml(S.true) + assert mml == '' + + mml = mathml(S.false) + assert mml == '' + + mml = mathml(S.NaN) + assert mml == '' + + +def test_content_mathml_trig(): + mml = mp._print(sin(x)) + assert mml.childNodes[0].nodeName == 'sin' + + mml = mp._print(cos(x)) + assert mml.childNodes[0].nodeName == 'cos' + + mml = mp._print(tan(x)) + assert mml.childNodes[0].nodeName == 'tan' + + mml = mp._print(cot(x)) + assert mml.childNodes[0].nodeName == 'cot' + + mml = mp._print(csc(x)) + assert mml.childNodes[0].nodeName == 'csc' + + mml = mp._print(sec(x)) + assert mml.childNodes[0].nodeName == 'sec' + + mml = mp._print(asin(x)) + assert mml.childNodes[0].nodeName == 'arcsin' + + mml = mp._print(acos(x)) + assert mml.childNodes[0].nodeName == 'arccos' + + mml = mp._print(atan(x)) + assert mml.childNodes[0].nodeName == 'arctan' + + mml = mp._print(acot(x)) + assert mml.childNodes[0].nodeName == 'arccot' + + mml = mp._print(acsc(x)) + assert mml.childNodes[0].nodeName == 'arccsc' + + mml = mp._print(asec(x)) + assert mml.childNodes[0].nodeName == 'arcsec' + + mml = mp._print(sinh(x)) + assert mml.childNodes[0].nodeName == 'sinh' + + mml = mp._print(cosh(x)) + assert mml.childNodes[0].nodeName == 'cosh' + + mml = mp._print(tanh(x)) + assert mml.childNodes[0].nodeName == 'tanh' + + mml = mp._print(coth(x)) + assert mml.childNodes[0].nodeName == 'coth' + + mml = mp._print(csch(x)) + assert mml.childNodes[0].nodeName == 'csch' + + mml = mp._print(sech(x)) + assert mml.childNodes[0].nodeName == 'sech' + + mml = mp._print(asinh(x)) + assert mml.childNodes[0].nodeName == 'arcsinh' + + mml = mp._print(atanh(x)) + assert mml.childNodes[0].nodeName == 'arctanh' + + mml = mp._print(acosh(x)) + assert mml.childNodes[0].nodeName == 'arccosh' + + mml = mp._print(acoth(x)) + assert mml.childNodes[0].nodeName == 'arccoth' + + mml = mp._print(acsch(x)) + assert mml.childNodes[0].nodeName == 'arccsch' + + mml = mp._print(asech(x)) + assert mml.childNodes[0].nodeName == 'arcsech' + + +def test_content_mathml_relational(): + mml_1 = mp._print(Eq(x, 1)) + assert mml_1.nodeName == 'apply' + assert mml_1.childNodes[0].nodeName == 'eq' + assert mml_1.childNodes[1].nodeName == 'ci' + assert mml_1.childNodes[1].childNodes[0].nodeValue == 'x' + assert mml_1.childNodes[2].nodeName == 'cn' + assert mml_1.childNodes[2].childNodes[0].nodeValue == '1' + + mml_2 = mp._print(Ne(1, x)) + assert mml_2.nodeName == 'apply' + assert mml_2.childNodes[0].nodeName == 'neq' + assert mml_2.childNodes[1].nodeName == 'cn' + assert mml_2.childNodes[1].childNodes[0].nodeValue == '1' + assert mml_2.childNodes[2].nodeName == 'ci' + assert mml_2.childNodes[2].childNodes[0].nodeValue == 'x' + + mml_3 = mp._print(Ge(1, x)) + assert mml_3.nodeName == 'apply' + assert mml_3.childNodes[0].nodeName == 'geq' + assert mml_3.childNodes[1].nodeName == 'cn' + assert mml_3.childNodes[1].childNodes[0].nodeValue == '1' + assert mml_3.childNodes[2].nodeName == 'ci' + assert mml_3.childNodes[2].childNodes[0].nodeValue == 'x' + + mml_4 = mp._print(Lt(1, x)) + assert mml_4.nodeName == 'apply' + assert mml_4.childNodes[0].nodeName == 'lt' + assert mml_4.childNodes[1].nodeName == 'cn' + assert mml_4.childNodes[1].childNodes[0].nodeValue == '1' + assert mml_4.childNodes[2].nodeName == 'ci' + assert mml_4.childNodes[2].childNodes[0].nodeValue == 'x' + + +def test_content_symbol(): + mml = mp._print(x) + assert mml.nodeName == 'ci' + assert mml.childNodes[0].nodeValue == 'x' + del mml + + mml = mp._print(Symbol("x^2")) + assert mml.nodeName == 'ci' + assert mml.childNodes[0].nodeName == 'mml:msup' + assert mml.childNodes[0].childNodes[0].nodeName == 'mml:mi' + assert mml.childNodes[0].childNodes[0].childNodes[0].nodeValue == 'x' + assert mml.childNodes[0].childNodes[1].nodeName == 'mml:mi' + assert mml.childNodes[0].childNodes[1].childNodes[0].nodeValue == '2' + del mml + + mml = mp._print(Symbol("x__2")) + assert mml.nodeName == 'ci' + assert mml.childNodes[0].nodeName == 'mml:msup' + assert mml.childNodes[0].childNodes[0].nodeName == 'mml:mi' + assert mml.childNodes[0].childNodes[0].childNodes[0].nodeValue == 'x' + assert mml.childNodes[0].childNodes[1].nodeName == 'mml:mi' + assert mml.childNodes[0].childNodes[1].childNodes[0].nodeValue == '2' + del mml + + mml = mp._print(Symbol("x_2")) + assert mml.nodeName == 'ci' + assert mml.childNodes[0].nodeName == 'mml:msub' + assert mml.childNodes[0].childNodes[0].nodeName == 'mml:mi' + assert mml.childNodes[0].childNodes[0].childNodes[0].nodeValue == 'x' + assert mml.childNodes[0].childNodes[1].nodeName == 'mml:mi' + assert mml.childNodes[0].childNodes[1].childNodes[0].nodeValue == '2' + del mml + + mml = mp._print(Symbol("x^3_2")) + assert mml.nodeName == 'ci' + assert mml.childNodes[0].nodeName == 'mml:msubsup' + assert mml.childNodes[0].childNodes[0].nodeName == 'mml:mi' + assert mml.childNodes[0].childNodes[0].childNodes[0].nodeValue == 'x' + assert mml.childNodes[0].childNodes[1].nodeName == 'mml:mi' + assert mml.childNodes[0].childNodes[1].childNodes[0].nodeValue == '2' + assert mml.childNodes[0].childNodes[2].nodeName == 'mml:mi' + assert mml.childNodes[0].childNodes[2].childNodes[0].nodeValue == '3' + del mml + + mml = mp._print(Symbol("x__3_2")) + assert mml.nodeName == 'ci' + assert mml.childNodes[0].nodeName == 'mml:msubsup' + assert mml.childNodes[0].childNodes[0].nodeName == 'mml:mi' + assert mml.childNodes[0].childNodes[0].childNodes[0].nodeValue == 'x' + assert mml.childNodes[0].childNodes[1].nodeName == 'mml:mi' + assert mml.childNodes[0].childNodes[1].childNodes[0].nodeValue == '2' + assert mml.childNodes[0].childNodes[2].nodeName == 'mml:mi' + assert mml.childNodes[0].childNodes[2].childNodes[0].nodeValue == '3' + del mml + + mml = mp._print(Symbol("x_2_a")) + assert mml.nodeName == 'ci' + assert mml.childNodes[0].nodeName == 'mml:msub' + assert mml.childNodes[0].childNodes[0].nodeName == 'mml:mi' + assert mml.childNodes[0].childNodes[0].childNodes[0].nodeValue == 'x' + assert mml.childNodes[0].childNodes[1].nodeName == 'mml:mrow' + assert mml.childNodes[0].childNodes[1].childNodes[0].nodeName == 'mml:mi' + assert mml.childNodes[0].childNodes[1].childNodes[0].childNodes[ + 0].nodeValue == '2' + assert mml.childNodes[0].childNodes[1].childNodes[1].nodeName == 'mml:mo' + assert mml.childNodes[0].childNodes[1].childNodes[1].childNodes[ + 0].nodeValue == ' ' + assert mml.childNodes[0].childNodes[1].childNodes[2].nodeName == 'mml:mi' + assert mml.childNodes[0].childNodes[1].childNodes[2].childNodes[ + 0].nodeValue == 'a' + del mml + + mml = mp._print(Symbol("x^2^a")) + assert mml.nodeName == 'ci' + assert mml.childNodes[0].nodeName == 'mml:msup' + assert mml.childNodes[0].childNodes[0].nodeName == 'mml:mi' + assert mml.childNodes[0].childNodes[0].childNodes[0].nodeValue == 'x' + assert mml.childNodes[0].childNodes[1].nodeName == 'mml:mrow' + assert mml.childNodes[0].childNodes[1].childNodes[0].nodeName == 'mml:mi' + assert mml.childNodes[0].childNodes[1].childNodes[0].childNodes[ + 0].nodeValue == '2' + assert mml.childNodes[0].childNodes[1].childNodes[1].nodeName == 'mml:mo' + assert mml.childNodes[0].childNodes[1].childNodes[1].childNodes[ + 0].nodeValue == ' ' + assert mml.childNodes[0].childNodes[1].childNodes[2].nodeName == 'mml:mi' + assert mml.childNodes[0].childNodes[1].childNodes[2].childNodes[ + 0].nodeValue == 'a' + del mml + + mml = mp._print(Symbol("x__2__a")) + assert mml.nodeName == 'ci' + assert mml.childNodes[0].nodeName == 'mml:msup' + assert mml.childNodes[0].childNodes[0].nodeName == 'mml:mi' + assert mml.childNodes[0].childNodes[0].childNodes[0].nodeValue == 'x' + assert mml.childNodes[0].childNodes[1].nodeName == 'mml:mrow' + assert mml.childNodes[0].childNodes[1].childNodes[0].nodeName == 'mml:mi' + assert mml.childNodes[0].childNodes[1].childNodes[0].childNodes[ + 0].nodeValue == '2' + assert mml.childNodes[0].childNodes[1].childNodes[1].nodeName == 'mml:mo' + assert mml.childNodes[0].childNodes[1].childNodes[1].childNodes[ + 0].nodeValue == ' ' + assert mml.childNodes[0].childNodes[1].childNodes[2].nodeName == 'mml:mi' + assert mml.childNodes[0].childNodes[1].childNodes[2].childNodes[ + 0].nodeValue == 'a' + del mml + + +def test_content_mathml_greek(): + mml = mp._print(Symbol('alpha')) + assert mml.nodeName == 'ci' + assert mml.childNodes[0].nodeValue == '\N{GREEK SMALL LETTER ALPHA}' + + assert mp.doprint(Symbol('alpha')) == 'α' + assert mp.doprint(Symbol('beta')) == 'β' + assert mp.doprint(Symbol('gamma')) == 'γ' + assert mp.doprint(Symbol('delta')) == 'δ' + assert mp.doprint(Symbol('epsilon')) == 'ε' + assert mp.doprint(Symbol('zeta')) == 'ζ' + assert mp.doprint(Symbol('eta')) == 'η' + assert mp.doprint(Symbol('theta')) == 'θ' + assert mp.doprint(Symbol('iota')) == 'ι' + assert mp.doprint(Symbol('kappa')) == 'κ' + assert mp.doprint(Symbol('lambda')) == 'λ' + assert mp.doprint(Symbol('mu')) == 'μ' + assert mp.doprint(Symbol('nu')) == 'ν' + assert mp.doprint(Symbol('xi')) == 'ξ' + assert mp.doprint(Symbol('omicron')) == 'ο' + assert mp.doprint(Symbol('pi')) == 'π' + assert mp.doprint(Symbol('rho')) == 'ρ' + assert mp.doprint(Symbol('varsigma')) == 'ς' + assert mp.doprint(Symbol('sigma')) == 'σ' + assert mp.doprint(Symbol('tau')) == 'τ' + assert mp.doprint(Symbol('upsilon')) == 'υ' + assert mp.doprint(Symbol('phi')) == 'φ' + assert mp.doprint(Symbol('chi')) == 'χ' + assert mp.doprint(Symbol('psi')) == 'ψ' + assert mp.doprint(Symbol('omega')) == 'ω' + + assert mp.doprint(Symbol('Alpha')) == 'Α' + assert mp.doprint(Symbol('Beta')) == 'Β' + assert mp.doprint(Symbol('Gamma')) == 'Γ' + assert mp.doprint(Symbol('Delta')) == 'Δ' + assert mp.doprint(Symbol('Epsilon')) == 'Ε' + assert mp.doprint(Symbol('Zeta')) == 'Ζ' + assert mp.doprint(Symbol('Eta')) == 'Η' + assert mp.doprint(Symbol('Theta')) == 'Θ' + assert mp.doprint(Symbol('Iota')) == 'Ι' + assert mp.doprint(Symbol('Kappa')) == 'Κ' + assert mp.doprint(Symbol('Lambda')) == 'Λ' + assert mp.doprint(Symbol('Mu')) == 'Μ' + assert mp.doprint(Symbol('Nu')) == 'Ν' + assert mp.doprint(Symbol('Xi')) == 'Ξ' + assert mp.doprint(Symbol('Omicron')) == 'Ο' + assert mp.doprint(Symbol('Pi')) == 'Π' + assert mp.doprint(Symbol('Rho')) == 'Ρ' + assert mp.doprint(Symbol('Sigma')) == 'Σ' + assert mp.doprint(Symbol('Tau')) == 'Τ' + assert mp.doprint(Symbol('Upsilon')) == 'Υ' + assert mp.doprint(Symbol('Phi')) == 'Φ' + assert mp.doprint(Symbol('Chi')) == 'Χ' + assert mp.doprint(Symbol('Psi')) == 'Ψ' + assert mp.doprint(Symbol('Omega')) == 'Ω' + + +def test_content_mathml_order(): + expr = x**3 + x**2*y + 3*x*y**3 + y**4 + + mp = MathMLContentPrinter({'order': 'lex'}) + mml = mp._print(expr) + + assert mml.childNodes[1].childNodes[0].nodeName == 'power' + assert mml.childNodes[1].childNodes[1].childNodes[0].data == 'x' + assert mml.childNodes[1].childNodes[2].childNodes[0].data == '3' + + assert mml.childNodes[4].childNodes[0].nodeName == 'power' + assert mml.childNodes[4].childNodes[1].childNodes[0].data == 'y' + assert mml.childNodes[4].childNodes[2].childNodes[0].data == '4' + + mp = MathMLContentPrinter({'order': 'rev-lex'}) + mml = mp._print(expr) + + assert mml.childNodes[1].childNodes[0].nodeName == 'power' + assert mml.childNodes[1].childNodes[1].childNodes[0].data == 'y' + assert mml.childNodes[1].childNodes[2].childNodes[0].data == '4' + + assert mml.childNodes[4].childNodes[0].nodeName == 'power' + assert mml.childNodes[4].childNodes[1].childNodes[0].data == 'x' + assert mml.childNodes[4].childNodes[2].childNodes[0].data == '3' + + +def test_content_settings(): + raises(TypeError, lambda: mathml(x, method="garbage")) + + +def test_content_mathml_logic(): + assert mathml(And(x, y)) == 'xy' + assert mathml(Or(x, y)) == 'xy' + assert mathml(Xor(x, y)) == 'xy' + assert mathml(Implies(x, y)) == 'xy' + assert mathml(Not(x)) == 'x' + + +def test_content_finite_sets(): + assert mathml(FiniteSet(a)) == 'a' + assert mathml(FiniteSet(a, b)) == 'ab' + assert mathml(FiniteSet(FiniteSet(a, b), c)) == \ + 'cab' + + A = FiniteSet(a) + B = FiniteSet(b) + C = FiniteSet(c) + D = FiniteSet(d) + + U1 = Union(A, B, evaluate=False) + U2 = Union(C, D, evaluate=False) + I1 = Intersection(A, B, evaluate=False) + I2 = Intersection(C, D, evaluate=False) + C1 = Complement(A, B, evaluate=False) + C2 = Complement(C, D, evaluate=False) + # XXX ProductSet does not support evaluate keyword + P1 = ProductSet(A, B) + P2 = ProductSet(C, D) + + assert mathml(U1) == \ + 'ab' + assert mathml(I1) == \ + 'ab' \ + '' + assert mathml(C1) == \ + 'ab' + assert mathml(P1) == \ + 'ab' \ + '' + + assert mathml(Intersection(A, U2, evaluate=False)) == \ + 'a' \ + 'cd' + assert mathml(Intersection(U1, U2, evaluate=False)) == \ + 'a' \ + 'bc' \ + 'd' + + # XXX Does the parenthesis appear correctly for these examples in mathjax? + assert mathml(Intersection(C1, C2, evaluate=False)) == \ + 'a' \ + 'bc' \ + 'd' + assert mathml(Intersection(P1, P2, evaluate=False)) == \ + 'a' \ + 'b' \ + 'cd' + + assert mathml(Union(A, I2, evaluate=False)) == \ + 'a' \ + 'cd' + assert mathml(Union(I1, I2, evaluate=False)) == \ + 'a' \ + 'bc' \ + 'd' + assert mathml(Union(C1, C2, evaluate=False)) == \ + 'a' \ + 'bc' \ + 'd' + assert mathml(Union(P1, P2, evaluate=False)) == \ + 'a' \ + 'b' \ + 'cd' + + assert mathml(Complement(A, C2, evaluate=False)) == \ + 'a' \ + 'cd' + assert mathml(Complement(U1, U2, evaluate=False)) == \ + 'a' \ + 'bc' \ + 'd' + assert mathml(Complement(I1, I2, evaluate=False)) == \ + 'a' \ + 'bc' \ + 'd' + assert mathml(Complement(P1, P2, evaluate=False)) == \ + 'a' \ + 'b' \ + 'cd' + + assert mathml(ProductSet(A, P2)) == \ + 'a' \ + 'c' \ + 'd' + assert mathml(ProductSet(U1, U2)) == \ + 'a' \ + 'bc' \ + 'd' + assert mathml(ProductSet(I1, I2)) == \ + 'a' \ + 'b' \ + 'cd' + assert mathml(ProductSet(C1, C2)) == \ + 'a' \ + 'b' \ + 'cd' + + +def test_presentation_printmethod(): + assert mpp.doprint(1 + x) == 'x+1' + assert mpp.doprint(x**2) == 'x2' + assert mpp.doprint(x**-1) == '1x' + assert mpp.doprint(x**-2) == \ + '1x2' + assert mpp.doprint(2*x) == \ + '2x' + + +def test_presentation_mathml_core(): + mml_1 = mpp._print(1 + x) + assert mml_1.nodeName == 'mrow' + nodes = mml_1.childNodes + assert len(nodes) == 3 + assert nodes[0].nodeName in ['mi', 'mn'] + assert nodes[1].nodeName == 'mo' + if nodes[0].nodeName == 'mn': + assert nodes[0].childNodes[0].nodeValue == '1' + assert nodes[2].childNodes[0].nodeValue == 'x' + else: + assert nodes[0].childNodes[0].nodeValue == 'x' + assert nodes[2].childNodes[0].nodeValue == '1' + + mml_2 = mpp._print(x**2) + assert mml_2.nodeName == 'msup' + nodes = mml_2.childNodes + assert nodes[0].childNodes[0].nodeValue == 'x' + assert nodes[1].childNodes[0].nodeValue == '2' + + mml_3 = mpp._print(2*x) + assert mml_3.nodeName == 'mrow' + nodes = mml_3.childNodes + assert nodes[0].childNodes[0].nodeValue == '2' + assert nodes[1].childNodes[0].nodeValue == '⁢' + assert nodes[2].childNodes[0].nodeValue == 'x' + + mml = mpp._print(Float(1.0, 2)*x) + assert mml.nodeName == 'mrow' + nodes = mml.childNodes + assert nodes[0].childNodes[0].nodeValue == '1.0' + assert nodes[1].childNodes[0].nodeValue == '⁢' + assert nodes[2].childNodes[0].nodeValue == 'x' + + +def test_presentation_mathml_functions(): + mml_1 = mpp._print(sin(x)) + assert mml_1.childNodes[0].childNodes[0 + ].nodeValue == 'sin' + assert mml_1.childNodes[1].childNodes[0 + ].childNodes[0].nodeValue == 'x' + + mml_2 = mpp._print(diff(sin(x), x, evaluate=False)) + assert mml_2.nodeName == 'mrow' + assert mml_2.childNodes[0].childNodes[0 + ].childNodes[0].childNodes[0].nodeValue == 'ⅆ' + assert mml_2.childNodes[1].childNodes[1 + ].nodeName == 'mfenced' + assert mml_2.childNodes[0].childNodes[1 + ].childNodes[0].childNodes[0].nodeValue == 'ⅆ' + + mml_3 = mpp._print(diff(cos(x*y), x, evaluate=False)) + assert mml_3.childNodes[0].nodeName == 'mfrac' + assert mml_3.childNodes[0].childNodes[0 + ].childNodes[0].childNodes[0].nodeValue == '∂' + assert mml_3.childNodes[1].childNodes[0 + ].childNodes[0].nodeValue == 'cos' + + +def test_print_derivative(): + f = Function('f') + d = Derivative(f(x, y, z), x, z, x, z, z, y) + assert mathml(d) == \ + 'yz2xzxxyz' + assert mathml(d, printer='presentation') == \ + '6y2zxzxfxyz' + + +def test_presentation_mathml_limits(): + lim_fun = sin(x)/x + mml_1 = mpp._print(Limit(lim_fun, x, 0)) + assert mml_1.childNodes[0].nodeName == 'munder' + assert mml_1.childNodes[0].childNodes[0 + ].childNodes[0].nodeValue == 'lim' + assert mml_1.childNodes[0].childNodes[1 + ].childNodes[0].childNodes[0 + ].nodeValue == 'x' + assert mml_1.childNodes[0].childNodes[1 + ].childNodes[1].childNodes[0 + ].nodeValue == '→' + assert mml_1.childNodes[0].childNodes[1 + ].childNodes[2].childNodes[0 + ].nodeValue == '0' + + +def test_presentation_mathml_integrals(): + assert mpp.doprint(Integral(x, (x, 0, 1))) == \ + '01'\ + 'xx' + assert mpp.doprint(Integral(log(x), x)) == \ + 'logx'\ + 'x' + assert mpp.doprint(Integral(x*y, x, y)) == \ + 'x'\ + 'yyx' + z, w = symbols('z w') + assert mpp.doprint(Integral(x*y*z, x, y, z)) == \ + 'x'\ + 'yz'\ + 'zyx' + assert mpp.doprint(Integral(x*y*z*w, x, y, z, w)) == \ + ''\ + 'w'\ + 'xy'\ + 'zw'\ + 'zyx' + assert mpp.doprint(Integral(x, x, y, (z, 0, 1))) == \ + '01'\ + 'xz'\ + 'yx' + assert mpp.doprint(Integral(x, (x, 0))) == \ + '0x'\ + 'x' + + +def test_presentation_mathml_matrices(): + A = Matrix([1, 2, 3]) + B = Matrix([[0, 5, 4], [2, 3, 1], [9, 7, 9]]) + mll_1 = mpp._print(A) + assert mll_1.childNodes[0].nodeName == 'mtable' + assert mll_1.childNodes[0].childNodes[0].nodeName == 'mtr' + assert len(mll_1.childNodes[0].childNodes) == 3 + assert mll_1.childNodes[0].childNodes[0].childNodes[0].nodeName == 'mtd' + assert len(mll_1.childNodes[0].childNodes[0].childNodes) == 1 + assert mll_1.childNodes[0].childNodes[0].childNodes[0 + ].childNodes[0].childNodes[0].nodeValue == '1' + assert mll_1.childNodes[0].childNodes[1].childNodes[0 + ].childNodes[0].childNodes[0].nodeValue == '2' + assert mll_1.childNodes[0].childNodes[2].childNodes[0 + ].childNodes[0].childNodes[0].nodeValue == '3' + mll_2 = mpp._print(B) + assert mll_2.childNodes[0].nodeName == 'mtable' + assert mll_2.childNodes[0].childNodes[0].nodeName == 'mtr' + assert len(mll_2.childNodes[0].childNodes) == 3 + assert mll_2.childNodes[0].childNodes[0].childNodes[0].nodeName == 'mtd' + assert len(mll_2.childNodes[0].childNodes[0].childNodes) == 3 + assert mll_2.childNodes[0].childNodes[0].childNodes[0 + ].childNodes[0].childNodes[0].nodeValue == '0' + assert mll_2.childNodes[0].childNodes[0].childNodes[1 + ].childNodes[0].childNodes[0].nodeValue == '5' + assert mll_2.childNodes[0].childNodes[0].childNodes[2 + ].childNodes[0].childNodes[0].nodeValue == '4' + assert mll_2.childNodes[0].childNodes[1].childNodes[0 + ].childNodes[0].childNodes[0].nodeValue == '2' + assert mll_2.childNodes[0].childNodes[1].childNodes[1 + ].childNodes[0].childNodes[0].nodeValue == '3' + assert mll_2.childNodes[0].childNodes[1].childNodes[2 + ].childNodes[0].childNodes[0].nodeValue == '1' + assert mll_2.childNodes[0].childNodes[2].childNodes[0 + ].childNodes[0].childNodes[0].nodeValue == '9' + assert mll_2.childNodes[0].childNodes[2].childNodes[1 + ].childNodes[0].childNodes[0].nodeValue == '7' + assert mll_2.childNodes[0].childNodes[2].childNodes[2 + ].childNodes[0].childNodes[0].nodeValue == '9' + + +def test_presentation_mathml_sums(): + summand = x + mml_1 = mpp._print(Sum(summand, (x, 1, 10))) + assert mml_1.childNodes[0].nodeName == 'munderover' + assert len(mml_1.childNodes[0].childNodes) == 3 + assert mml_1.childNodes[0].childNodes[0].childNodes[0 + ].nodeValue == '∑' + assert len(mml_1.childNodes[0].childNodes[1].childNodes) == 3 + assert mml_1.childNodes[0].childNodes[2].childNodes[0 + ].nodeValue == '10' + assert mml_1.childNodes[1].childNodes[0].nodeValue == 'x' + + +def test_presentation_mathml_add(): + mml = mpp._print(x**5 - x**4 + x) + assert len(mml.childNodes) == 5 + assert mml.childNodes[0].childNodes[0].childNodes[0 + ].nodeValue == 'x' + assert mml.childNodes[0].childNodes[1].childNodes[0 + ].nodeValue == '5' + assert mml.childNodes[1].childNodes[0].nodeValue == '-' + assert mml.childNodes[2].childNodes[0].childNodes[0 + ].nodeValue == 'x' + assert mml.childNodes[2].childNodes[1].childNodes[0 + ].nodeValue == '4' + assert mml.childNodes[3].childNodes[0].nodeValue == '+' + assert mml.childNodes[4].childNodes[0].nodeValue == 'x' + + +def test_presentation_mathml_Rational(): + mml_1 = mpp._print(Rational(1, 1)) + assert mml_1.nodeName == 'mn' + + mml_2 = mpp._print(Rational(2, 5)) + assert mml_2.nodeName == 'mfrac' + assert mml_2.childNodes[0].childNodes[0].nodeValue == '2' + assert mml_2.childNodes[1].childNodes[0].nodeValue == '5' + + +def test_presentation_mathml_constants(): + mml = mpp._print(I) + assert mml.childNodes[0].nodeValue == 'ⅈ' + + mml = mpp._print(E) + assert mml.childNodes[0].nodeValue == 'ⅇ' + + mml = mpp._print(oo) + assert mml.childNodes[0].nodeValue == '∞' + + mml = mpp._print(pi) + assert mml.childNodes[0].nodeValue == 'π' + + assert mathml(hbar, printer='presentation') == '' + assert mathml(S.TribonacciConstant, printer='presentation' + ) == 'TribonacciConstant' + assert mathml(S.EulerGamma, printer='presentation' + ) == 'γ' + assert mathml(S.GoldenRatio, printer='presentation' + ) == 'Φ' + + assert mathml(zoo, printer='presentation') == \ + '~' + + assert mathml(S.NaN, printer='presentation') == 'NaN' + +def test_presentation_mathml_trig(): + mml = mpp._print(sin(x)) + assert mml.childNodes[0].childNodes[0].nodeValue == 'sin' + + mml = mpp._print(cos(x)) + assert mml.childNodes[0].childNodes[0].nodeValue == 'cos' + + mml = mpp._print(tan(x)) + assert mml.childNodes[0].childNodes[0].nodeValue == 'tan' + + mml = mpp._print(asin(x)) + assert mml.childNodes[0].childNodes[0].nodeValue == 'arcsin' + + mml = mpp._print(acos(x)) + assert mml.childNodes[0].childNodes[0].nodeValue == 'arccos' + + mml = mpp._print(atan(x)) + assert mml.childNodes[0].childNodes[0].nodeValue == 'arctan' + + mml = mpp._print(sinh(x)) + assert mml.childNodes[0].childNodes[0].nodeValue == 'sinh' + + mml = mpp._print(cosh(x)) + assert mml.childNodes[0].childNodes[0].nodeValue == 'cosh' + + mml = mpp._print(tanh(x)) + assert mml.childNodes[0].childNodes[0].nodeValue == 'tanh' + + mml = mpp._print(asinh(x)) + assert mml.childNodes[0].childNodes[0].nodeValue == 'arcsinh' + + mml = mpp._print(atanh(x)) + assert mml.childNodes[0].childNodes[0].nodeValue == 'arctanh' + + mml = mpp._print(acosh(x)) + assert mml.childNodes[0].childNodes[0].nodeValue == 'arccosh' + + +def test_presentation_mathml_relational(): + mml_1 = mpp._print(Eq(x, 1)) + assert len(mml_1.childNodes) == 3 + assert mml_1.childNodes[0].nodeName == 'mi' + assert mml_1.childNodes[0].childNodes[0].nodeValue == 'x' + assert mml_1.childNodes[1].nodeName == 'mo' + assert mml_1.childNodes[1].childNodes[0].nodeValue == '=' + assert mml_1.childNodes[2].nodeName == 'mn' + assert mml_1.childNodes[2].childNodes[0].nodeValue == '1' + + mml_2 = mpp._print(Ne(1, x)) + assert len(mml_2.childNodes) == 3 + assert mml_2.childNodes[0].nodeName == 'mn' + assert mml_2.childNodes[0].childNodes[0].nodeValue == '1' + assert mml_2.childNodes[1].nodeName == 'mo' + assert mml_2.childNodes[1].childNodes[0].nodeValue == '≠' + assert mml_2.childNodes[2].nodeName == 'mi' + assert mml_2.childNodes[2].childNodes[0].nodeValue == 'x' + + mml_3 = mpp._print(Ge(1, x)) + assert len(mml_3.childNodes) == 3 + assert mml_3.childNodes[0].nodeName == 'mn' + assert mml_3.childNodes[0].childNodes[0].nodeValue == '1' + assert mml_3.childNodes[1].nodeName == 'mo' + assert mml_3.childNodes[1].childNodes[0].nodeValue == '≥' + assert mml_3.childNodes[2].nodeName == 'mi' + assert mml_3.childNodes[2].childNodes[0].nodeValue == 'x' + + mml_4 = mpp._print(Lt(1, x)) + assert len(mml_4.childNodes) == 3 + assert mml_4.childNodes[0].nodeName == 'mn' + assert mml_4.childNodes[0].childNodes[0].nodeValue == '1' + assert mml_4.childNodes[1].nodeName == 'mo' + assert mml_4.childNodes[1].childNodes[0].nodeValue == '<' + assert mml_4.childNodes[2].nodeName == 'mi' + assert mml_4.childNodes[2].childNodes[0].nodeValue == 'x' + + +def test_presentation_symbol(): + mml = mpp._print(x) + assert mml.nodeName == 'mi' + assert mml.childNodes[0].nodeValue == 'x' + del mml + + mml = mpp._print(Symbol("x^2")) + assert mml.nodeName == 'msup' + assert mml.childNodes[0].nodeName == 'mi' + assert mml.childNodes[0].childNodes[0].nodeValue == 'x' + assert mml.childNodes[1].nodeName == 'mi' + assert mml.childNodes[1].childNodes[0].nodeValue == '2' + del mml + + mml = mpp._print(Symbol("x__2")) + assert mml.nodeName == 'msup' + assert mml.childNodes[0].nodeName == 'mi' + assert mml.childNodes[0].childNodes[0].nodeValue == 'x' + assert mml.childNodes[1].nodeName == 'mi' + assert mml.childNodes[1].childNodes[0].nodeValue == '2' + del mml + + mml = mpp._print(Symbol("x_2")) + assert mml.nodeName == 'msub' + assert mml.childNodes[0].nodeName == 'mi' + assert mml.childNodes[0].childNodes[0].nodeValue == 'x' + assert mml.childNodes[1].nodeName == 'mi' + assert mml.childNodes[1].childNodes[0].nodeValue == '2' + del mml + + mml = mpp._print(Symbol("x^3_2")) + assert mml.nodeName == 'msubsup' + assert mml.childNodes[0].nodeName == 'mi' + assert mml.childNodes[0].childNodes[0].nodeValue == 'x' + assert mml.childNodes[1].nodeName == 'mi' + assert mml.childNodes[1].childNodes[0].nodeValue == '2' + assert mml.childNodes[2].nodeName == 'mi' + assert mml.childNodes[2].childNodes[0].nodeValue == '3' + del mml + + mml = mpp._print(Symbol("x__3_2")) + assert mml.nodeName == 'msubsup' + assert mml.childNodes[0].nodeName == 'mi' + assert mml.childNodes[0].childNodes[0].nodeValue == 'x' + assert mml.childNodes[1].nodeName == 'mi' + assert mml.childNodes[1].childNodes[0].nodeValue == '2' + assert mml.childNodes[2].nodeName == 'mi' + assert mml.childNodes[2].childNodes[0].nodeValue == '3' + del mml + + mml = mpp._print(Symbol("x_2_a")) + assert mml.nodeName == 'msub' + assert mml.childNodes[0].nodeName == 'mi' + assert mml.childNodes[0].childNodes[0].nodeValue == 'x' + assert mml.childNodes[1].nodeName == 'mrow' + assert mml.childNodes[1].childNodes[0].nodeName == 'mi' + assert mml.childNodes[1].childNodes[0].childNodes[0].nodeValue == '2' + assert mml.childNodes[1].childNodes[1].nodeName == 'mo' + assert mml.childNodes[1].childNodes[1].childNodes[0].nodeValue == ' ' + assert mml.childNodes[1].childNodes[2].nodeName == 'mi' + assert mml.childNodes[1].childNodes[2].childNodes[0].nodeValue == 'a' + del mml + + mml = mpp._print(Symbol("x^2^a")) + assert mml.nodeName == 'msup' + assert mml.childNodes[0].nodeName == 'mi' + assert mml.childNodes[0].childNodes[0].nodeValue == 'x' + assert mml.childNodes[1].nodeName == 'mrow' + assert mml.childNodes[1].childNodes[0].nodeName == 'mi' + assert mml.childNodes[1].childNodes[0].childNodes[0].nodeValue == '2' + assert mml.childNodes[1].childNodes[1].nodeName == 'mo' + assert mml.childNodes[1].childNodes[1].childNodes[0].nodeValue == ' ' + assert mml.childNodes[1].childNodes[2].nodeName == 'mi' + assert mml.childNodes[1].childNodes[2].childNodes[0].nodeValue == 'a' + del mml + + mml = mpp._print(Symbol("x__2__a")) + assert mml.nodeName == 'msup' + assert mml.childNodes[0].nodeName == 'mi' + assert mml.childNodes[0].childNodes[0].nodeValue == 'x' + assert mml.childNodes[1].nodeName == 'mrow' + assert mml.childNodes[1].childNodes[0].nodeName == 'mi' + assert mml.childNodes[1].childNodes[0].childNodes[0].nodeValue == '2' + assert mml.childNodes[1].childNodes[1].nodeName == 'mo' + assert mml.childNodes[1].childNodes[1].childNodes[0].nodeValue == ' ' + assert mml.childNodes[1].childNodes[2].nodeName == 'mi' + assert mml.childNodes[1].childNodes[2].childNodes[0].nodeValue == 'a' + del mml + + +def test_presentation_mathml_greek(): + mml = mpp._print(Symbol('alpha')) + assert mml.nodeName == 'mi' + assert mml.childNodes[0].nodeValue == '\N{GREEK SMALL LETTER ALPHA}' + + assert mpp.doprint(Symbol('alpha')) == 'α' + assert mpp.doprint(Symbol('beta')) == 'β' + assert mpp.doprint(Symbol('gamma')) == 'γ' + assert mpp.doprint(Symbol('delta')) == 'δ' + assert mpp.doprint(Symbol('epsilon')) == 'ε' + assert mpp.doprint(Symbol('zeta')) == 'ζ' + assert mpp.doprint(Symbol('eta')) == 'η' + assert mpp.doprint(Symbol('theta')) == 'θ' + assert mpp.doprint(Symbol('iota')) == 'ι' + assert mpp.doprint(Symbol('kappa')) == 'κ' + assert mpp.doprint(Symbol('lambda')) == 'λ' + assert mpp.doprint(Symbol('mu')) == 'μ' + assert mpp.doprint(Symbol('nu')) == 'ν' + assert mpp.doprint(Symbol('xi')) == 'ξ' + assert mpp.doprint(Symbol('omicron')) == 'ο' + assert mpp.doprint(Symbol('pi')) == 'π' + assert mpp.doprint(Symbol('rho')) == 'ρ' + assert mpp.doprint(Symbol('varsigma')) == 'ς' + assert mpp.doprint(Symbol('sigma')) == 'σ' + assert mpp.doprint(Symbol('tau')) == 'τ' + assert mpp.doprint(Symbol('upsilon')) == 'υ' + assert mpp.doprint(Symbol('phi')) == 'φ' + assert mpp.doprint(Symbol('chi')) == 'χ' + assert mpp.doprint(Symbol('psi')) == 'ψ' + assert mpp.doprint(Symbol('omega')) == 'ω' + + assert mpp.doprint(Symbol('Alpha')) == 'Α' + assert mpp.doprint(Symbol('Beta')) == 'Β' + assert mpp.doprint(Symbol('Gamma')) == 'Γ' + assert mpp.doprint(Symbol('Delta')) == 'Δ' + assert mpp.doprint(Symbol('Epsilon')) == 'Ε' + assert mpp.doprint(Symbol('Zeta')) == 'Ζ' + assert mpp.doprint(Symbol('Eta')) == 'Η' + assert mpp.doprint(Symbol('Theta')) == 'Θ' + assert mpp.doprint(Symbol('Iota')) == 'Ι' + assert mpp.doprint(Symbol('Kappa')) == 'Κ' + assert mpp.doprint(Symbol('Lambda')) == 'Λ' + assert mpp.doprint(Symbol('Mu')) == 'Μ' + assert mpp.doprint(Symbol('Nu')) == 'Ν' + assert mpp.doprint(Symbol('Xi')) == 'Ξ' + assert mpp.doprint(Symbol('Omicron')) == 'Ο' + assert mpp.doprint(Symbol('Pi')) == 'Π' + assert mpp.doprint(Symbol('Rho')) == 'Ρ' + assert mpp.doprint(Symbol('Sigma')) == 'Σ' + assert mpp.doprint(Symbol('Tau')) == 'Τ' + assert mpp.doprint(Symbol('Upsilon')) == 'Υ' + assert mpp.doprint(Symbol('Phi')) == 'Φ' + assert mpp.doprint(Symbol('Chi')) == 'Χ' + assert mpp.doprint(Symbol('Psi')) == 'Ψ' + assert mpp.doprint(Symbol('Omega')) == 'Ω' + + +def test_presentation_mathml_order(): + expr = x**3 + x**2*y + 3*x*y**3 + y**4 + + mp = MathMLPresentationPrinter({'order': 'lex'}) + mml = mp._print(expr) + assert mml.childNodes[0].nodeName == 'msup' + assert mml.childNodes[0].childNodes[0].childNodes[0].nodeValue == 'x' + assert mml.childNodes[0].childNodes[1].childNodes[0].nodeValue == '3' + + assert mml.childNodes[6].nodeName == 'msup' + assert mml.childNodes[6].childNodes[0].childNodes[0].nodeValue == 'y' + assert mml.childNodes[6].childNodes[1].childNodes[0].nodeValue == '4' + + mp = MathMLPresentationPrinter({'order': 'rev-lex'}) + mml = mp._print(expr) + + assert mml.childNodes[0].nodeName == 'msup' + assert mml.childNodes[0].childNodes[0].childNodes[0].nodeValue == 'y' + assert mml.childNodes[0].childNodes[1].childNodes[0].nodeValue == '4' + + assert mml.childNodes[6].nodeName == 'msup' + assert mml.childNodes[6].childNodes[0].childNodes[0].nodeValue == 'x' + assert mml.childNodes[6].childNodes[1].childNodes[0].nodeValue == '3' + + +def test_print_intervals(): + a = Symbol('a', real=True) + assert mpp.doprint(Interval(0, a)) == \ + '0a' + assert mpp.doprint(Interval(0, a, False, False)) == \ + '0a' + assert mpp.doprint(Interval(0, a, True, False)) == \ + '0a' + assert mpp.doprint(Interval(0, a, False, True)) == \ + '0a' + assert mpp.doprint(Interval(0, a, True, True)) == \ + '0a' + + +def test_print_tuples(): + assert mpp.doprint(Tuple(0,)) == \ + '0' + assert mpp.doprint(Tuple(0, a)) == \ + '0a' + assert mpp.doprint(Tuple(0, a, a)) == \ + '0aa' + assert mpp.doprint(Tuple(0, 1, 2, 3, 4)) == \ + '01234' + assert mpp.doprint(Tuple(0, 1, Tuple(2, 3, 4))) == \ + '0123'\ + '4' + + +def test_print_re_im(): + assert mpp.doprint(re(x)) == \ + 'Rx' + assert mpp.doprint(im(x)) == \ + 'Ix' + assert mpp.doprint(re(x + 1)) == \ + 'Rx'\ + '+1' + assert mpp.doprint(im(x + 1)) == \ + 'Ix' + + +def test_print_Abs(): + assert mpp.doprint(Abs(x)) == \ + 'x' + assert mpp.doprint(Abs(x + 1)) == \ + 'x+1' + + +def test_print_Determinant(): + assert mpp.doprint(Determinant(Matrix([[1, 2], [3, 4]]))) == \ + '1234' + + +def test_presentation_settings(): + raises(TypeError, lambda: mathml(x, printer='presentation', + method="garbage")) + + +def test_print_domains(): + from sympy.sets import Integers, Naturals, Naturals0, Reals, Complexes + + assert mpp.doprint(Complexes) == '' + assert mpp.doprint(Integers) == '' + assert mpp.doprint(Naturals) == '' + assert mpp.doprint(Naturals0) == \ + '0' + assert mpp.doprint(Reals) == '' + + +def test_print_expression_with_minus(): + assert mpp.doprint(-x) == '-x' + assert mpp.doprint(-x/y) == \ + '-xy' + assert mpp.doprint(-Rational(1, 2)) == \ + '-12' + + +def test_print_AssocOp(): + from sympy.core.operations import AssocOp + + class TestAssocOp(AssocOp): + identity = 0 + + expr = TestAssocOp(1, 2) + assert mpp.doprint(expr) == \ + 'testassocop12' + + +def test_print_basic(): + expr = Basic(S(1), S(2)) + assert mpp.doprint(expr) == \ + 'basic12' + assert mp.doprint(expr) == '12' + + +def test_mat_delim_print(): + expr = Matrix([[1, 2], [3, 4]]) + assert mathml(expr, printer='presentation', mat_delim='[') == \ + '1'\ + '234'\ + '' + assert mathml(expr, printer='presentation', mat_delim='(') == \ + '12'\ + '34' + assert mathml(expr, printer='presentation', mat_delim='') == \ + '12'\ + '34' + + +def test_ln_notation_print(): + expr = log(x) + assert mathml(expr, printer='presentation') == \ + 'logx' + assert mathml(expr, printer='presentation', ln_notation=False) == \ + 'logx' + assert mathml(expr, printer='presentation', ln_notation=True) == \ + 'lnx' + + +def test_mul_symbol_print(): + expr = x * y + assert mathml(expr, printer='presentation') == \ + 'xy' + assert mathml(expr, printer='presentation', mul_symbol=None) == \ + 'xy' + assert mathml(expr, printer='presentation', mul_symbol='dot') == \ + 'x·y' + assert mathml(expr, printer='presentation', mul_symbol='ldot') == \ + 'xy' + assert mathml(expr, printer='presentation', mul_symbol='times') == \ + 'x×y' + + +def test_print_lerchphi(): + assert mpp.doprint(lerchphi(1, 2, 3)) == \ + 'Φ123' + + +def test_print_polylog(): + assert mp.doprint(polylog(x, y)) == \ + 'xy' + assert mpp.doprint(polylog(x, y)) == \ + 'Lixy' + + +def test_print_set_frozenset(): + f = frozenset({1, 5, 3}) + assert mpp.doprint(f) == \ + '135' + s = set({1, 2, 3}) + assert mpp.doprint(s) == \ + '123' + + +def test_print_FiniteSet(): + f1 = FiniteSet(x, 1, 3) + assert mpp.doprint(f1) == \ + '13x' + + +def test_print_LambertW(): + assert mpp.doprint(LambertW(x)) == 'Wx' + assert mpp.doprint(LambertW(x, y)) == 'Wxy' + + +def test_print_EmptySet(): + assert mpp.doprint(S.EmptySet) == '' + + +def test_print_UniversalSet(): + assert mpp.doprint(S.UniversalSet) == '𝕌' + + +def test_print_spaces(): + assert mpp.doprint(HilbertSpace()) == '' + assert mpp.doprint(ComplexSpace(2)) == '𝒞2' + assert mpp.doprint(FockSpace()) == '' + + +def test_print_constants(): + assert mpp.doprint(hbar) == '' + assert mpp.doprint(S.TribonacciConstant) == 'TribonacciConstant' + assert mpp.doprint(S.GoldenRatio) == 'Φ' + assert mpp.doprint(S.EulerGamma) == 'γ' + + +def test_print_Contains(): + assert mpp.doprint(Contains(x, S.Naturals)) == \ + 'x' + + +def test_print_Dagger(): + assert mpp.doprint(Dagger(x)) == 'x' + + +def test_print_SetOp(): + f1 = FiniteSet(x, 1, 3) + f2 = FiniteSet(y, 2, 4) + + prntr = lambda x: mathml(x, printer='presentation') + + assert prntr(Union(f1, f2, evaluate=False)) == \ + '13x'\ + '2'\ + '4y' + assert prntr(Intersection(f1, f2, evaluate=False)) == \ + '13x'\ + '2'\ + '4y' + assert prntr(Complement(f1, f2, evaluate=False)) == \ + '13x'\ + '2'\ + '4y' + assert prntr(SymmetricDifference(f1, f2, evaluate=False)) == \ + '13x'\ + '2'\ + '4y' + + A = FiniteSet(a) + C = FiniteSet(c) + D = FiniteSet(d) + + U1 = Union(C, D, evaluate=False) + I1 = Intersection(C, D, evaluate=False) + C1 = Complement(C, D, evaluate=False) + D1 = SymmetricDifference(C, D, evaluate=False) + # XXX ProductSet does not support evaluate keyword + P1 = ProductSet(C, D) + + assert prntr(Union(A, I1, evaluate=False)) == \ + 'a' \ + '' \ + 'c' \ + 'd' + assert prntr(Intersection(A, C1, evaluate=False)) == \ + 'a' \ + '' \ + 'c' \ + 'd' + assert prntr(Complement(A, D1, evaluate=False)) == \ + 'a' \ + '' \ + 'c' \ + 'd' + assert prntr(SymmetricDifference(A, P1, evaluate=False)) == \ + 'a' \ + '' \ + 'c×' \ + 'd' + assert prntr(ProductSet(A, U1)) == \ + 'a' \ + '×' \ + 'c' \ + 'd' + + +def test_print_logic(): + assert mpp.doprint(And(x, y)) == \ + 'xy' + assert mpp.doprint(Or(x, y)) == \ + 'xy' + assert mpp.doprint(Xor(x, y)) == \ + 'xy' + assert mpp.doprint(Implies(x, y)) == \ + 'xy' + assert mpp.doprint(Equivalent(x, y)) == \ + 'xy' + + assert mpp.doprint(And(Eq(x, y), x > 4)) == \ + 'x=y'\ + 'x>4' + assert mpp.doprint(And(Eq(x, 3), y < 3, x > y + 1)) == \ + 'x=3'\ + 'x>y+1'\ + 'y<3' + assert mpp.doprint(Or(Eq(x, y), x > 4)) == \ + 'x=y'\ + 'x>4' + assert mpp.doprint(And(Eq(x, 3), Or(y < 3, x > y + 1))) == \ + 'x=3'\ + 'x>y+'\ + '1y<'\ + '3' + + assert mpp.doprint(Not(x)) == '¬x' + assert mpp.doprint(Not(And(x, y))) == \ + '¬x'\ + 'y' + + +def test_root_notation_print(): + assert mathml(x**(S.One/3), printer='presentation') == \ + 'x3' + assert mathml(x**(S.One/3), printer='presentation', root_notation=False) ==\ + 'x13' + assert mathml(x**(S.One/3), printer='content') == \ + '3x' + assert mathml(x**(S.One/3), printer='content', root_notation=False) == \ + 'x13' + assert mathml(x**(Rational(-1, 3)), printer='presentation') == \ + '1x3' + assert mathml(x**(Rational(-1, 3)), printer='presentation', root_notation=False) \ + == '1x13' + + +def test_fold_frac_powers_print(): + expr = x ** Rational(5, 2) + assert mathml(expr, printer='presentation') == \ + 'x52' + assert mathml(expr, printer='presentation', fold_frac_powers=True) == \ + 'x52' + assert mathml(expr, printer='presentation', fold_frac_powers=False) == \ + 'x52' + + +def test_fold_short_frac_print(): + expr = Rational(2, 5) + assert mathml(expr, printer='presentation') == \ + '25' + assert mathml(expr, printer='presentation', fold_short_frac=True) == \ + '25' + assert mathml(expr, printer='presentation', fold_short_frac=False) == \ + '25' + + +def test_print_factorials(): + assert mpp.doprint(factorial(x)) == 'x!' + assert mpp.doprint(factorial(x + 1)) == \ + 'x+1!' + assert mpp.doprint(factorial2(x)) == 'x!!' + assert mpp.doprint(factorial2(x + 1)) == \ + 'x+1!!' + assert mpp.doprint(binomial(x, y)) == \ + 'xy' + assert mpp.doprint(binomial(4, x + y)) == \ + '4x'\ + '+y' + + +def test_print_floor(): + expr = floor(x) + assert mathml(expr, printer='presentation') == \ + 'x' + + +def test_print_ceiling(): + expr = ceiling(x) + assert mathml(expr, printer='presentation') == \ + 'x' + + +def test_print_Lambda(): + expr = Lambda(x, x+1) + assert mathml(expr, printer='presentation') == \ + 'xx+'\ + '1' + expr = Lambda((x, y), x + y) + assert mathml(expr, printer='presentation') == \ + 'xy'\ + 'x+y' + + +def test_print_conjugate(): + assert mpp.doprint(conjugate(x)) == \ + 'x' + assert mpp.doprint(conjugate(x + 1)) == \ + 'x+1' + + +def test_print_AccumBounds(): + a = Symbol('a', real=True) + assert mpp.doprint(AccumBounds(0, 1)) == '01' + assert mpp.doprint(AccumBounds(0, a)) == '0a' + assert mpp.doprint(AccumBounds(a + 1, a + 2)) == 'a+1a+2' + + +def test_print_Float(): + assert mpp.doprint(Float(1e100)) == '1.0·10100' + assert mpp.doprint(Float(1e-100)) == '1.0·10-100' + assert mpp.doprint(Float(-1e100)) == '-1.0·10100' + assert mpp.doprint(Float(1.0*oo)) == '' + assert mpp.doprint(Float(-1.0*oo)) == '-' + + +def test_print_different_functions(): + assert mpp.doprint(gamma(x)) == 'Γx' + assert mpp.doprint(lowergamma(x, y)) == 'γxy' + assert mpp.doprint(uppergamma(x, y)) == 'Γxy' + assert mpp.doprint(zeta(x)) == 'ζx' + assert mpp.doprint(zeta(x, y)) == 'ζxy' + assert mpp.doprint(dirichlet_eta(x)) == 'ηx' + assert mpp.doprint(elliptic_k(x)) == 'Κx' + assert mpp.doprint(totient(x)) == 'ϕx' + assert mpp.doprint(reduced_totient(x)) == 'λx' + assert mpp.doprint(primenu(x)) == 'νx' + assert mpp.doprint(primeomega(x)) == 'Ωx' + assert mpp.doprint(fresnels(x)) == 'Sx' + assert mpp.doprint(fresnelc(x)) == 'Cx' + assert mpp.doprint(Heaviside(x)) == 'Θx12' + + +def test_mathml_builtins(): + assert mpp.doprint(None) == 'None' + assert mpp.doprint(true) == 'True' + assert mpp.doprint(false) == 'False' + + +def test_mathml_Range(): + assert mpp.doprint(Range(1, 51)) == \ + '1250' + assert mpp.doprint(Range(1, 4)) == \ + '123' + assert mpp.doprint(Range(0, 3, 1)) == \ + '012' + assert mpp.doprint(Range(0, 30, 1)) == \ + '0129' + assert mpp.doprint(Range(30, 1, -1)) == \ + '3029'\ + '2' + assert mpp.doprint(Range(0, oo, 2)) == \ + '02' + assert mpp.doprint(Range(oo, -2, -2)) == \ + '20' + assert mpp.doprint(Range(-2, -oo, -1)) == \ + '-2-3' + + +def test_print_exp(): + assert mpp.doprint(exp(x)) == \ + 'x' + assert mpp.doprint(exp(1) + exp(2)) == \ + '+2' + + +def test_print_MinMax(): + assert mpp.doprint(Min(x, y)) == \ + 'minxy' + assert mpp.doprint(Min(x, 2, x**3)) == \ + 'min2xx'\ + '3' + assert mpp.doprint(Max(x, y)) == \ + 'maxxy' + assert mpp.doprint(Max(x, 2, x**3)) == \ + 'max2xx'\ + '3' + + +def test_mathml_presentation_numbers(): + n = Symbol('n') + assert mathml(catalan(n), printer='presentation') == \ + 'Cn' + assert mathml(bernoulli(n), printer='presentation') == \ + 'Bn' + assert mathml(bell(n), printer='presentation') == \ + 'Bn' + assert mathml(euler(n), printer='presentation') == \ + 'En' + assert mathml(fibonacci(n), printer='presentation') == \ + 'Fn' + assert mathml(lucas(n), printer='presentation') == \ + 'Ln' + assert mathml(tribonacci(n), printer='presentation') == \ + 'Tn' + assert mathml(bernoulli(n, x), printer='presentation') == \ + 'Bnx' + assert mathml(bell(n, x), printer='presentation') == \ + 'Bnx' + assert mathml(euler(n, x), printer='presentation') == \ + 'Enx' + assert mathml(fibonacci(n, x), printer='presentation') == \ + 'Fnx' + assert mathml(tribonacci(n, x), printer='presentation') == \ + 'Tnx' + + +def test_mathml_presentation_mathieu(): + assert mathml(mathieuc(x, y, z), printer='presentation') == \ + 'Cxyz' + assert mathml(mathieus(x, y, z), printer='presentation') == \ + 'Sxyz' + assert mathml(mathieucprime(x, y, z), printer='presentation') == \ + 'C′xyz' + assert mathml(mathieusprime(x, y, z), printer='presentation') == \ + 'S′xyz' + + +def test_mathml_presentation_stieltjes(): + assert mathml(stieltjes(n), printer='presentation') == \ + 'γn' + assert mathml(stieltjes(n, x), printer='presentation') == \ + 'γnx' + + +def test_print_matrix_symbol(): + A = MatrixSymbol('A', 1, 2) + assert mpp.doprint(A) == 'A' + assert mp.doprint(A) == 'A' + assert mathml(A, printer='presentation', mat_symbol_style="bold") == \ + 'A' + # No effect in content printer + assert mathml(A, mat_symbol_style="bold") == 'A' + + +def test_print_hadamard(): + from sympy.matrices.expressions import HadamardProduct + from sympy.matrices.expressions import Transpose + + X = MatrixSymbol('X', 2, 2) + Y = MatrixSymbol('Y', 2, 2) + + assert mathml(HadamardProduct(X, Y*Y), printer="presentation") == \ + '' \ + 'X' \ + '' \ + 'Y2' \ + '' + + assert mathml(HadamardProduct(X, Y)*Y, printer="presentation") == \ + '' \ + '' \ + 'XY' \ + '' \ + 'Y' \ + '' + + assert mathml(HadamardProduct(X, Y, Y), printer="presentation") == \ + '' \ + 'X' \ + 'Y' \ + 'Y' \ + '' + + assert mathml( + Transpose(HadamardProduct(X, Y)), printer="presentation") == \ + '' \ + '' \ + 'XY' \ + '' \ + 'T' \ + '' + + +def test_print_random_symbol(): + R = RandomSymbol(Symbol('R')) + assert mpp.doprint(R) == 'R' + assert mp.doprint(R) == 'R' + + +def test_print_IndexedBase(): + assert mathml(IndexedBase(a)[b], printer='presentation') == \ + 'ab' + assert mathml(IndexedBase(a)[b, c, d], printer='presentation') == \ + 'abcd' + assert mathml(IndexedBase(a)[b]*IndexedBase(c)[d]*IndexedBase(e), + printer='presentation') == \ + 'ab⁢'\ + 'cde' + + +def test_print_Indexed(): + assert mathml(IndexedBase(a), printer='presentation') == 'a' + assert mathml(IndexedBase(a/b), printer='presentation') == \ + 'ab' + assert mathml(IndexedBase((a, b)), printer='presentation') == \ + 'ab' + +def test_print_MatrixElement(): + i, j = symbols('i j') + A = MatrixSymbol('A', i, j) + assert mathml(A[0,0],printer = 'presentation') == \ + 'A00' + assert mathml(A[i,j], printer = 'presentation') == \ + 'Aij' + assert mathml(A[i*j,0], printer = 'presentation') == \ + 'Aij0' + + +def test_print_Vector(): + ACS = CoordSys3D('A') + assert mathml(Cross(ACS.i, ACS.j*ACS.x*3 + ACS.k), printer='presentation') == \ + 'i^'\ + 'A×'\ + '3'\ + 'xA'\ + ''\ + 'j^'\ + 'A+'\ + 'k^'\ + 'A' + assert mathml(Cross(ACS.i, ACS.j), printer='presentation') == \ + 'i^'\ + 'A×'\ + 'j^'\ + 'A' + assert mathml(x*Cross(ACS.i, ACS.j), printer='presentation') == \ + 'x'\ + 'i^'\ + 'A×'\ + 'j^'\ + 'A' + assert mathml(Cross(x*ACS.i, ACS.j), printer='presentation') == \ + '-j'\ + '^A'\ + '×x'\ + 'i'\ + '^A'\ + '' + assert mathml(Curl(3*ACS.x*ACS.j), printer='presentation') == \ + '×'\ + '3'\ + 'xA'\ + ''\ + 'j^'\ + 'A' + assert mathml(Curl(3*x*ACS.x*ACS.j), printer='presentation') == \ + '×'\ + '3x'\ + 'A'\ + 'x'\ + 'j^'\ + 'A' + assert mathml(x*Curl(3*ACS.x*ACS.j), printer='presentation') == \ + 'x'\ + '×3'\ + 'x'\ + 'A'\ + 'j'\ + '^A'\ + '' + assert mathml(Curl(3*x*ACS.x*ACS.j + ACS.i), printer='presentation') == \ + '×'\ + 'i^'\ + 'A+'\ + '3x'\ + 'A'\ + 'x'\ + 'j^'\ + 'A' + assert mathml(Divergence(3*ACS.x*ACS.j), printer='presentation') == \ + '·'\ + '3x'\ + 'A'\ + 'j'\ + '^A' + assert mathml(x*Divergence(3*ACS.x*ACS.j), printer='presentation') == \ + 'x'\ + '·3'\ + 'x'\ + 'A'\ + 'j'\ + '^A'\ + '' + assert mathml(Divergence(3*x*ACS.x*ACS.j + ACS.i), printer='presentation') == \ + '·'\ + 'i^'\ + 'A+'\ + '3'\ + 'xA'\ + 'x'\ + 'j'\ + '^A' + assert mathml(Dot(ACS.i, ACS.j*ACS.x*3+ACS.k), printer='presentation') == \ + 'i^'\ + 'A·'\ + '3'\ + 'xA'\ + ''\ + 'j^'\ + 'A+'\ + 'k^'\ + 'A' + assert mathml(Dot(ACS.i, ACS.j), printer='presentation') == \ + 'i^'\ + 'A·'\ + 'j^'\ + 'A' + assert mathml(Dot(x*ACS.i, ACS.j), printer='presentation') == \ + 'j^'\ + 'A·'\ + 'x'\ + 'i^'\ + 'A' + assert mathml(x*Dot(ACS.i, ACS.j), printer='presentation') == \ + 'x'\ + 'i^'\ + 'A·'\ + 'j^'\ + 'A' + assert mathml(Gradient(ACS.x), printer='presentation') == \ + 'x'\ + 'A' + assert mathml(Gradient(ACS.x + 3*ACS.y), printer='presentation') == \ + ''\ + 'xA+3'\ + 'y'\ + 'A' + assert mathml(x*Gradient(ACS.x), printer='presentation') == \ + 'x'\ + 'xA'\ + '' + assert mathml(Gradient(x*ACS.x), printer='presentation') == \ + ''\ + 'xA'\ + 'x' + assert mathml(Cross(ACS.x, ACS.z) + Cross(ACS.z, ACS.x), printer='presentation') == \ + '0^' + assert mathml(Cross(ACS.z, ACS.x), printer='presentation') == \ + '-x'\ + 'A×'\ + 'zA' + assert mathml(Laplacian(ACS.x), printer='presentation') == \ + 'x'\ + 'A' + assert mathml(Laplacian(ACS.x + 3*ACS.y), printer='presentation') == \ + ''\ + 'xA+3'\ + 'y'\ + 'A' + assert mathml(x*Laplacian(ACS.x), printer='presentation') == \ + 'x'\ + 'xA'\ + '' + assert mathml(Laplacian(x*ACS.x), printer='presentation') == \ + ''\ + 'xA'\ + 'x' + +def test_print_elliptic_f(): + assert mathml(elliptic_f(x, y), printer = 'presentation') == \ + '𝖥xy' + assert mathml(elliptic_f(x/y, y), printer = 'presentation') == \ + '𝖥xyy' + +def test_print_elliptic_e(): + assert mathml(elliptic_e(x), printer = 'presentation') == \ + '𝖤x' + assert mathml(elliptic_e(x, y), printer = 'presentation') == \ + '𝖤xy' + +def test_print_elliptic_pi(): + assert mathml(elliptic_pi(x, y), printer = 'presentation') == \ + '𝛱xy' + assert mathml(elliptic_pi(x, y, z), printer = 'presentation') == \ + '𝛱xyz' + +def test_print_Ei(): + assert mathml(Ei(x), printer = 'presentation') == \ + 'Eix' + assert mathml(Ei(x**y), printer = 'presentation') == \ + 'Eixy' + +def test_print_expint(): + assert mathml(expint(x, y), printer = 'presentation') == \ + 'Exy' + assert mathml(expint(IndexedBase(x)[1], IndexedBase(x)[2]), printer = 'presentation') == \ + 'Ex1x2' + +def test_print_jacobi(): + assert mathml(jacobi(n, a, b, x), printer = 'presentation') == \ + 'Pnabx' + +def test_print_gegenbauer(): + assert mathml(gegenbauer(n, a, x), printer = 'presentation') == \ + 'Cnax' + +def test_print_chebyshevt(): + assert mathml(chebyshevt(n, x), printer = 'presentation') == \ + 'Tnx' + +def test_print_chebyshevu(): + assert mathml(chebyshevu(n, x), printer = 'presentation') == \ + 'Unx' + +def test_print_legendre(): + assert mathml(legendre(n, x), printer = 'presentation') == \ + 'Pnx' + +def test_print_assoc_legendre(): + assert mathml(assoc_legendre(n, a, x), printer = 'presentation') == \ + 'Pnax' + +def test_print_laguerre(): + assert mathml(laguerre(n, x), printer = 'presentation') == \ + 'Lnx' + +def test_print_assoc_laguerre(): + assert mathml(assoc_laguerre(n, a, x), printer = 'presentation') == \ + 'Lnax' + +def test_print_hermite(): + assert mathml(hermite(n, x), printer = 'presentation') == \ + 'Hnx' + +def test_mathml_SingularityFunction(): + assert mathml(SingularityFunction(x, 4, 5), printer='presentation') == \ + 'x' \ + '-45' + assert mathml(SingularityFunction(x, -3, 4), printer='presentation') == \ + 'x' \ + '+34' + assert mathml(SingularityFunction(x, 0, 4), printer='presentation') == \ + 'x' \ + '4' + assert mathml(SingularityFunction(x, a, n), printer='presentation') == \ + '' \ + '-a+x' \ + 'n' + assert mathml(SingularityFunction(x, 4, -2), printer='presentation') == \ + 'x' \ + '-4-2' + assert mathml(SingularityFunction(x, 4, -1), printer='presentation') == \ + 'x' \ + '-4-1' + + +def test_mathml_matrix_functions(): + from sympy.matrices import Adjoint, Inverse, Transpose + X = MatrixSymbol('X', 2, 2) + Y = MatrixSymbol('Y', 2, 2) + assert mathml(Adjoint(X), printer='presentation') == \ + 'X' + assert mathml(Adjoint(X + Y), printer='presentation') == \ + 'X+Y' + assert mathml(Adjoint(X) + Adjoint(Y), printer='presentation') == \ + 'X+' \ + 'Y' + assert mathml(Adjoint(X*Y), printer='presentation') == \ + 'X' \ + 'Y' + assert mathml(Adjoint(Y)*Adjoint(X), printer='presentation') == \ + 'Y⁢' \ + 'X' + assert mathml(Adjoint(X**2), printer='presentation') == \ + 'X2' + assert mathml(Adjoint(X)**2, printer='presentation') == \ + 'X2' + assert mathml(Adjoint(Inverse(X)), printer='presentation') == \ + 'X-1' + assert mathml(Inverse(Adjoint(X)), printer='presentation') == \ + 'X-1' + assert mathml(Adjoint(Transpose(X)), printer='presentation') == \ + 'XT' + assert mathml(Transpose(Adjoint(X)), printer='presentation') == \ + 'XT' + assert mathml(Transpose(Adjoint(X) + Y), printer='presentation') == \ + 'X' \ + '+YT' + assert mathml(Transpose(X), printer='presentation') == \ + 'XT' + assert mathml(Transpose(X + Y), printer='presentation') == \ + 'X+YT' + + +def test_mathml_special_matrices(): + from sympy.matrices import Identity, ZeroMatrix, OneMatrix + assert mathml(Identity(4), printer='presentation') == '𝕀' + assert mathml(ZeroMatrix(2, 2), printer='presentation') == '𝟘' + assert mathml(OneMatrix(2, 2), printer='presentation') == '𝟙' + +def test_mathml_piecewise(): + from sympy.functions.elementary.piecewise import Piecewise + # Content MathML + assert mathml(Piecewise((x, x <= 1), (x**2, True))) == \ + 'xx1x2' + + raises(ValueError, lambda: mathml(Piecewise((x, x <= 1)))) + + +def test_issue_17857(): + assert mathml(Range(-oo, oo), printer='presentation') == \ + '-101' + assert mathml(Range(oo, -oo, -1), printer='presentation') == \ + '10-1' + + +def test_float_roundtrip(): + x = sympify(0.8975979010256552) + y = float(mp.doprint(x).strip('')) + assert x == y diff --git a/lib/python3.10/site-packages/sympy/printing/tests/test_numpy.py b/lib/python3.10/site-packages/sympy/printing/tests/test_numpy.py new file mode 100644 index 0000000000000000000000000000000000000000..a64a7368b0ef0ab2b425166b5a1aba57121b47fd --- /dev/null +++ b/lib/python3.10/site-packages/sympy/printing/tests/test_numpy.py @@ -0,0 +1,365 @@ +from sympy.concrete.summations import Sum +from sympy.core.mod import Mod +from sympy.core.relational import (Equality, Unequality) +from sympy.core.symbol import Symbol +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.piecewise import Piecewise +from sympy.functions.special.gamma_functions import polygamma +from sympy.functions.special.error_functions import (Si, Ci) +from sympy.matrices.expressions.blockmatrix import BlockMatrix +from sympy.matrices.expressions.matexpr import MatrixSymbol +from sympy.matrices.expressions.special import Identity +from sympy.utilities.lambdify import lambdify +from sympy import symbols, Min, Max + +from sympy.abc import x, i, j, a, b, c, d +from sympy.core import Pow +from sympy.codegen.matrix_nodes import MatrixSolve +from sympy.codegen.numpy_nodes import logaddexp, logaddexp2 +from sympy.codegen.cfunctions import log1p, expm1, hypot, log10, exp2, log2, Sqrt +from sympy.tensor.array import Array +from sympy.tensor.array.expressions.array_expressions import ArrayTensorProduct, ArrayAdd, \ + PermuteDims, ArrayDiagonal +from sympy.printing.numpy import NumPyPrinter, SciPyPrinter, _numpy_known_constants, \ + _numpy_known_functions, _scipy_known_constants, _scipy_known_functions +from sympy.tensor.array.expressions.from_matrix_to_array import convert_matrix_to_array + +from sympy.testing.pytest import skip, raises +from sympy.external import import_module + +np = import_module('numpy') +jax = import_module('jax') + +if np: + deafult_float_info = np.finfo(np.array([]).dtype) + NUMPY_DEFAULT_EPSILON = deafult_float_info.eps + +def test_numpy_piecewise_regression(): + """ + NumPyPrinter needs to print Piecewise()'s choicelist as a list to avoid + breaking compatibility with numpy 1.8. This is not necessary in numpy 1.9+. + See gh-9747 and gh-9749 for details. + """ + printer = NumPyPrinter() + p = Piecewise((1, x < 0), (0, True)) + assert printer.doprint(p) == \ + 'numpy.select([numpy.less(x, 0),True], [1,0], default=numpy.nan)' + assert printer.module_imports == {'numpy': {'select', 'less', 'nan'}} + +def test_numpy_logaddexp(): + lae = logaddexp(a, b) + assert NumPyPrinter().doprint(lae) == 'numpy.logaddexp(a, b)' + lae2 = logaddexp2(a, b) + assert NumPyPrinter().doprint(lae2) == 'numpy.logaddexp2(a, b)' + + +def test_sum(): + if not np: + skip("NumPy not installed") + + s = Sum(x ** i, (i, a, b)) + f = lambdify((a, b, x), s, 'numpy') + + a_, b_ = 0, 10 + x_ = np.linspace(-1, +1, 10) + assert np.allclose(f(a_, b_, x_), sum(x_ ** i_ for i_ in range(a_, b_ + 1))) + + s = Sum(i * x, (i, a, b)) + f = lambdify((a, b, x), s, 'numpy') + + a_, b_ = 0, 10 + x_ = np.linspace(-1, +1, 10) + assert np.allclose(f(a_, b_, x_), sum(i_ * x_ for i_ in range(a_, b_ + 1))) + + +def test_multiple_sums(): + if not np: + skip("NumPy not installed") + + s = Sum((x + j) * i, (i, a, b), (j, c, d)) + f = lambdify((a, b, c, d, x), s, 'numpy') + + a_, b_ = 0, 10 + c_, d_ = 11, 21 + x_ = np.linspace(-1, +1, 10) + assert np.allclose(f(a_, b_, c_, d_, x_), + sum((x_ + j_) * i_ for i_ in range(a_, b_ + 1) for j_ in range(c_, d_ + 1))) + + +def test_codegen_einsum(): + if not np: + skip("NumPy not installed") + + M = MatrixSymbol("M", 2, 2) + N = MatrixSymbol("N", 2, 2) + + cg = convert_matrix_to_array(M * N) + f = lambdify((M, N), cg, 'numpy') + + ma = np.array([[1, 2], [3, 4]]) + mb = np.array([[1,-2], [-1, 3]]) + assert (f(ma, mb) == np.matmul(ma, mb)).all() + + +def test_codegen_extra(): + if not np: + skip("NumPy not installed") + + M = MatrixSymbol("M", 2, 2) + N = MatrixSymbol("N", 2, 2) + P = MatrixSymbol("P", 2, 2) + Q = MatrixSymbol("Q", 2, 2) + ma = np.array([[1, 2], [3, 4]]) + mb = np.array([[1,-2], [-1, 3]]) + mc = np.array([[2, 0], [1, 2]]) + md = np.array([[1,-1], [4, 7]]) + + cg = ArrayTensorProduct(M, N) + f = lambdify((M, N), cg, 'numpy') + assert (f(ma, mb) == np.einsum(ma, [0, 1], mb, [2, 3])).all() + + cg = ArrayAdd(M, N) + f = lambdify((M, N), cg, 'numpy') + assert (f(ma, mb) == ma+mb).all() + + cg = ArrayAdd(M, N, P) + f = lambdify((M, N, P), cg, 'numpy') + assert (f(ma, mb, mc) == ma+mb+mc).all() + + cg = ArrayAdd(M, N, P, Q) + f = lambdify((M, N, P, Q), cg, 'numpy') + assert (f(ma, mb, mc, md) == ma+mb+mc+md).all() + + cg = PermuteDims(M, [1, 0]) + f = lambdify((M,), cg, 'numpy') + assert (f(ma) == ma.T).all() + + cg = PermuteDims(ArrayTensorProduct(M, N), [1, 2, 3, 0]) + f = lambdify((M, N), cg, 'numpy') + assert (f(ma, mb) == np.transpose(np.einsum(ma, [0, 1], mb, [2, 3]), (1, 2, 3, 0))).all() + + cg = ArrayDiagonal(ArrayTensorProduct(M, N), (1, 2)) + f = lambdify((M, N), cg, 'numpy') + assert (f(ma, mb) == np.diagonal(np.einsum(ma, [0, 1], mb, [2, 3]), axis1=1, axis2=2)).all() + + +def test_relational(): + if not np: + skip("NumPy not installed") + + e = Equality(x, 1) + + f = lambdify((x,), e) + x_ = np.array([0, 1, 2]) + assert np.array_equal(f(x_), [False, True, False]) + + e = Unequality(x, 1) + + f = lambdify((x,), e) + x_ = np.array([0, 1, 2]) + assert np.array_equal(f(x_), [True, False, True]) + + e = (x < 1) + + f = lambdify((x,), e) + x_ = np.array([0, 1, 2]) + assert np.array_equal(f(x_), [True, False, False]) + + e = (x <= 1) + + f = lambdify((x,), e) + x_ = np.array([0, 1, 2]) + assert np.array_equal(f(x_), [True, True, False]) + + e = (x > 1) + + f = lambdify((x,), e) + x_ = np.array([0, 1, 2]) + assert np.array_equal(f(x_), [False, False, True]) + + e = (x >= 1) + + f = lambdify((x,), e) + x_ = np.array([0, 1, 2]) + assert np.array_equal(f(x_), [False, True, True]) + + +def test_mod(): + if not np: + skip("NumPy not installed") + + e = Mod(a, b) + f = lambdify((a, b), e) + + a_ = np.array([0, 1, 2, 3]) + b_ = 2 + assert np.array_equal(f(a_, b_), [0, 1, 0, 1]) + + a_ = np.array([0, 1, 2, 3]) + b_ = np.array([2, 2, 2, 2]) + assert np.array_equal(f(a_, b_), [0, 1, 0, 1]) + + a_ = np.array([2, 3, 4, 5]) + b_ = np.array([2, 3, 4, 5]) + assert np.array_equal(f(a_, b_), [0, 0, 0, 0]) + + +def test_pow(): + if not np: + skip('NumPy not installed') + + expr = Pow(2, -1, evaluate=False) + f = lambdify([], expr, 'numpy') + assert f() == 0.5 + + +def test_expm1(): + if not np: + skip("NumPy not installed") + + f = lambdify((a,), expm1(a), 'numpy') + assert abs(f(1e-10) - 1e-10 - 5e-21) <= 1e-10 * NUMPY_DEFAULT_EPSILON + + +def test_log1p(): + if not np: + skip("NumPy not installed") + + f = lambdify((a,), log1p(a), 'numpy') + assert abs(f(1e-99) - 1e-99) <= 1e-99 * NUMPY_DEFAULT_EPSILON + +def test_hypot(): + if not np: + skip("NumPy not installed") + assert abs(lambdify((a, b), hypot(a, b), 'numpy')(3, 4) - 5) <= NUMPY_DEFAULT_EPSILON + +def test_log10(): + if not np: + skip("NumPy not installed") + assert abs(lambdify((a,), log10(a), 'numpy')(100) - 2) <= NUMPY_DEFAULT_EPSILON + + +def test_exp2(): + if not np: + skip("NumPy not installed") + assert abs(lambdify((a,), exp2(a), 'numpy')(5) - 32) <= NUMPY_DEFAULT_EPSILON + + +def test_log2(): + if not np: + skip("NumPy not installed") + assert abs(lambdify((a,), log2(a), 'numpy')(256) - 8) <= NUMPY_DEFAULT_EPSILON + + +def test_Sqrt(): + if not np: + skip("NumPy not installed") + assert abs(lambdify((a,), Sqrt(a), 'numpy')(4) - 2) <= NUMPY_DEFAULT_EPSILON + + +def test_sqrt(): + if not np: + skip("NumPy not installed") + assert abs(lambdify((a,), sqrt(a), 'numpy')(4) - 2) <= NUMPY_DEFAULT_EPSILON + + +def test_matsolve(): + if not np: + skip("NumPy not installed") + + M = MatrixSymbol("M", 3, 3) + x = MatrixSymbol("x", 3, 1) + + expr = M**(-1) * x + x + matsolve_expr = MatrixSolve(M, x) + x + + f = lambdify((M, x), expr) + f_matsolve = lambdify((M, x), matsolve_expr) + + m0 = np.array([[1, 2, 3], [3, 2, 5], [5, 6, 7]]) + assert np.linalg.matrix_rank(m0) == 3 + + x0 = np.array([3, 4, 5]) + + assert np.allclose(f_matsolve(m0, x0), f(m0, x0)) + + +def test_16857(): + if not np: + skip("NumPy not installed") + + a_1 = MatrixSymbol('a_1', 10, 3) + a_2 = MatrixSymbol('a_2', 10, 3) + a_3 = MatrixSymbol('a_3', 10, 3) + a_4 = MatrixSymbol('a_4', 10, 3) + A = BlockMatrix([[a_1, a_2], [a_3, a_4]]) + assert A.shape == (20, 6) + + printer = NumPyPrinter() + assert printer.doprint(A) == 'numpy.block([[a_1, a_2], [a_3, a_4]])' + + +def test_issue_17006(): + if not np: + skip("NumPy not installed") + + M = MatrixSymbol("M", 2, 2) + + f = lambdify(M, M + Identity(2)) + ma = np.array([[1, 2], [3, 4]]) + mr = np.array([[2, 2], [3, 5]]) + + assert (f(ma) == mr).all() + + from sympy.core.symbol import symbols + n = symbols('n', integer=True) + N = MatrixSymbol("M", n, n) + raises(NotImplementedError, lambda: lambdify(N, N + Identity(n))) + +def test_jax_tuple_compatibility(): + if not jax: + skip("Jax not installed") + + x, y, z = symbols('x y z') + expr = Max(x, y, z) + Min(x, y, z) + func = lambdify((x, y, z), expr, 'jax') + input_tuple1, input_tuple2 = (1, 2, 3), (4, 5, 6) + input_array1, input_array2 = jax.numpy.asarray(input_tuple1), jax.numpy.asarray(input_tuple2) + assert np.allclose(func(*input_tuple1), func(*input_array1)) + assert np.allclose(func(*input_tuple2), func(*input_array2)) + +def test_numpy_array(): + assert NumPyPrinter().doprint(Array(((1, 2), (3, 5)))) == 'numpy.array([[1, 2], [3, 5]])' + assert NumPyPrinter().doprint(Array((1, 2))) == 'numpy.array((1, 2))' + +def test_numpy_known_funcs_consts(): + assert _numpy_known_constants['NaN'] == 'numpy.nan' + assert _numpy_known_constants['EulerGamma'] == 'numpy.euler_gamma' + + assert _numpy_known_functions['acos'] == 'numpy.arccos' + assert _numpy_known_functions['log'] == 'numpy.log' + +def test_scipy_known_funcs_consts(): + assert _scipy_known_constants['GoldenRatio'] == 'scipy.constants.golden_ratio' + assert _scipy_known_constants['Pi'] == 'scipy.constants.pi' + + assert _scipy_known_functions['erf'] == 'scipy.special.erf' + assert _scipy_known_functions['factorial'] == 'scipy.special.factorial' + +def test_numpy_print_methods(): + prntr = NumPyPrinter() + assert hasattr(prntr, '_print_acos') + assert hasattr(prntr, '_print_log') + +def test_scipy_print_methods(): + prntr = SciPyPrinter() + assert hasattr(prntr, '_print_acos') + assert hasattr(prntr, '_print_log') + assert hasattr(prntr, '_print_erf') + assert hasattr(prntr, '_print_factorial') + assert hasattr(prntr, '_print_chebyshevt') + k = Symbol('k', integer=True, nonnegative=True) + x = Symbol('x', real=True) + assert prntr.doprint(polygamma(k, x)) == "scipy.special.polygamma(k, x)" + assert prntr.doprint(Si(x)) == "scipy.special.sici(x)[0]" + assert prntr.doprint(Ci(x)) == "scipy.special.sici(x)[1]" diff --git a/lib/python3.10/site-packages/sympy/printing/tests/test_octave.py b/lib/python3.10/site-packages/sympy/printing/tests/test_octave.py new file mode 100644 index 0000000000000000000000000000000000000000..279b300b950e4b3251347e30258b1ddd09d7d598 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/printing/tests/test_octave.py @@ -0,0 +1,515 @@ +from sympy.core import (S, pi, oo, symbols, Function, Rational, Integer, + Tuple, Symbol, EulerGamma, GoldenRatio, Catalan, + Lambda, Mul, Pow, Mod, Eq, Ne, Le, Lt, Gt, Ge) +from sympy.codegen.matrix_nodes import MatrixSolve +from sympy.functions import (arg, atan2, bernoulli, beta, ceiling, chebyshevu, + chebyshevt, conjugate, DiracDelta, exp, expint, + factorial, floor, harmonic, Heaviside, im, + laguerre, LambertW, log, Max, Min, Piecewise, + polylog, re, RisingFactorial, sign, sinc, sqrt, + zeta, binomial, legendre, dirichlet_eta, + riemann_xi) +from sympy.functions import (sin, cos, tan, cot, sec, csc, asin, acos, acot, + atan, asec, acsc, sinh, cosh, tanh, coth, csch, + sech, asinh, acosh, atanh, acoth, asech, acsch) +from sympy.testing.pytest import raises, XFAIL +from sympy.utilities.lambdify import implemented_function +from sympy.matrices import (eye, Matrix, MatrixSymbol, Identity, + HadamardProduct, SparseMatrix, HadamardPower) +from sympy.functions.special.bessel import (jn, yn, besselj, bessely, besseli, + besselk, hankel1, hankel2, airyai, + airybi, airyaiprime, airybiprime) +from sympy.functions.special.gamma_functions import (gamma, lowergamma, + uppergamma, loggamma, + polygamma) +from sympy.functions.special.error_functions import (Chi, Ci, erf, erfc, erfi, + erfcinv, erfinv, fresnelc, + fresnels, li, Shi, Si, Li, + erf2, Ei) +from sympy.printing.octave import octave_code, octave_code as mcode + +x, y, z = symbols('x,y,z') + + +def test_Integer(): + assert mcode(Integer(67)) == "67" + assert mcode(Integer(-1)) == "-1" + + +def test_Rational(): + assert mcode(Rational(3, 7)) == "3/7" + assert mcode(Rational(18, 9)) == "2" + assert mcode(Rational(3, -7)) == "-3/7" + assert mcode(Rational(-3, -7)) == "3/7" + assert mcode(x + Rational(3, 7)) == "x + 3/7" + assert mcode(Rational(3, 7)*x) == "3*x/7" + + +def test_Relational(): + assert mcode(Eq(x, y)) == "x == y" + assert mcode(Ne(x, y)) == "x != y" + assert mcode(Le(x, y)) == "x <= y" + assert mcode(Lt(x, y)) == "x < y" + assert mcode(Gt(x, y)) == "x > y" + assert mcode(Ge(x, y)) == "x >= y" + + +def test_Function(): + assert mcode(sin(x) ** cos(x)) == "sin(x).^cos(x)" + assert mcode(sign(x)) == "sign(x)" + assert mcode(exp(x)) == "exp(x)" + assert mcode(log(x)) == "log(x)" + assert mcode(factorial(x)) == "factorial(x)" + assert mcode(floor(x)) == "floor(x)" + assert mcode(atan2(y, x)) == "atan2(y, x)" + assert mcode(beta(x, y)) == 'beta(x, y)' + assert mcode(polylog(x, y)) == 'polylog(x, y)' + assert mcode(harmonic(x)) == 'harmonic(x)' + assert mcode(bernoulli(x)) == "bernoulli(x)" + assert mcode(bernoulli(x, y)) == "bernoulli(x, y)" + assert mcode(legendre(x, y)) == "legendre(x, y)" + + +def test_Function_change_name(): + assert mcode(abs(x)) == "abs(x)" + assert mcode(ceiling(x)) == "ceil(x)" + assert mcode(arg(x)) == "angle(x)" + assert mcode(im(x)) == "imag(x)" + assert mcode(re(x)) == "real(x)" + assert mcode(conjugate(x)) == "conj(x)" + assert mcode(chebyshevt(y, x)) == "chebyshevT(y, x)" + assert mcode(chebyshevu(y, x)) == "chebyshevU(y, x)" + assert mcode(laguerre(x, y)) == "laguerreL(x, y)" + assert mcode(Chi(x)) == "coshint(x)" + assert mcode(Shi(x)) == "sinhint(x)" + assert mcode(Ci(x)) == "cosint(x)" + assert mcode(Si(x)) == "sinint(x)" + assert mcode(li(x)) == "logint(x)" + assert mcode(loggamma(x)) == "gammaln(x)" + assert mcode(polygamma(x, y)) == "psi(x, y)" + assert mcode(RisingFactorial(x, y)) == "pochhammer(x, y)" + assert mcode(DiracDelta(x)) == "dirac(x)" + assert mcode(DiracDelta(x, 3)) == "dirac(3, x)" + assert mcode(Heaviside(x)) == "heaviside(x, 1/2)" + assert mcode(Heaviside(x, y)) == "heaviside(x, y)" + assert mcode(binomial(x, y)) == "bincoeff(x, y)" + assert mcode(Mod(x, y)) == "mod(x, y)" + + +def test_minmax(): + assert mcode(Max(x, y) + Min(x, y)) == "max(x, y) + min(x, y)" + assert mcode(Max(x, y, z)) == "max(x, max(y, z))" + assert mcode(Min(x, y, z)) == "min(x, min(y, z))" + + +def test_Pow(): + assert mcode(x**3) == "x.^3" + assert mcode(x**(y**3)) == "x.^(y.^3)" + assert mcode(x**Rational(2, 3)) == 'x.^(2/3)' + g = implemented_function('g', Lambda(x, 2*x)) + assert mcode(1/(g(x)*3.5)**(x - y**x)/(x**2 + y)) == \ + "(3.5*2*x).^(-x + y.^x)./(x.^2 + y)" + # For issue 14160 + assert mcode(Mul(-2, x, Pow(Mul(y,y,evaluate=False), -1, evaluate=False), + evaluate=False)) == '-2*x./(y.*y)' + + +def test_basic_ops(): + assert mcode(x*y) == "x.*y" + assert mcode(x + y) == "x + y" + assert mcode(x - y) == "x - y" + assert mcode(-x) == "-x" + + +def test_1_over_x_and_sqrt(): + # 1.0 and 0.5 would do something different in regular StrPrinter, + # but these are exact in IEEE floating point so no different here. + assert mcode(1/x) == '1./x' + assert mcode(x**-1) == mcode(x**-1.0) == '1./x' + assert mcode(1/sqrt(x)) == '1./sqrt(x)' + assert mcode(x**-S.Half) == mcode(x**-0.5) == '1./sqrt(x)' + assert mcode(sqrt(x)) == 'sqrt(x)' + assert mcode(x**S.Half) == mcode(x**0.5) == 'sqrt(x)' + assert mcode(1/pi) == '1/pi' + assert mcode(pi**-1) == mcode(pi**-1.0) == '1/pi' + assert mcode(pi**-0.5) == '1/sqrt(pi)' + + +def test_mix_number_mult_symbols(): + assert mcode(3*x) == "3*x" + assert mcode(pi*x) == "pi*x" + assert mcode(3/x) == "3./x" + assert mcode(pi/x) == "pi./x" + assert mcode(x/3) == "x/3" + assert mcode(x/pi) == "x/pi" + assert mcode(x*y) == "x.*y" + assert mcode(3*x*y) == "3*x.*y" + assert mcode(3*pi*x*y) == "3*pi*x.*y" + assert mcode(x/y) == "x./y" + assert mcode(3*x/y) == "3*x./y" + assert mcode(x*y/z) == "x.*y./z" + assert mcode(x/y*z) == "x.*z./y" + assert mcode(1/x/y) == "1./(x.*y)" + assert mcode(2*pi*x/y/z) == "2*pi*x./(y.*z)" + assert mcode(3*pi/x) == "3*pi./x" + assert mcode(S(3)/5) == "3/5" + assert mcode(S(3)/5*x) == "3*x/5" + assert mcode(x/y/z) == "x./(y.*z)" + assert mcode((x+y)/z) == "(x + y)./z" + assert mcode((x+y)/(z+x)) == "(x + y)./(x + z)" + assert mcode((x+y)/EulerGamma) == "(x + y)/%s" % EulerGamma.evalf(17) + assert mcode(x/3/pi) == "x/(3*pi)" + assert mcode(S(3)/5*x*y/pi) == "3*x.*y/(5*pi)" + + +def test_mix_number_pow_symbols(): + assert mcode(pi**3) == 'pi^3' + assert mcode(x**2) == 'x.^2' + assert mcode(x**(pi**3)) == 'x.^(pi^3)' + assert mcode(x**y) == 'x.^y' + assert mcode(x**(y**z)) == 'x.^(y.^z)' + assert mcode((x**y)**z) == '(x.^y).^z' + + +def test_imag(): + I = S('I') + assert mcode(I) == "1i" + assert mcode(5*I) == "5i" + assert mcode((S(3)/2)*I) == "3*1i/2" + assert mcode(3+4*I) == "3 + 4i" + assert mcode(sqrt(3)*I) == "sqrt(3)*1i" + + +def test_constants(): + assert mcode(pi) == "pi" + assert mcode(oo) == "inf" + assert mcode(-oo) == "-inf" + assert mcode(S.NegativeInfinity) == "-inf" + assert mcode(S.NaN) == "NaN" + assert mcode(S.Exp1) == "exp(1)" + assert mcode(exp(1)) == "exp(1)" + + +def test_constants_other(): + assert mcode(2*GoldenRatio) == "2*(1+sqrt(5))/2" + assert mcode(2*Catalan) == "2*%s" % Catalan.evalf(17) + assert mcode(2*EulerGamma) == "2*%s" % EulerGamma.evalf(17) + + +def test_boolean(): + assert mcode(x & y) == "x & y" + assert mcode(x | y) == "x | y" + assert mcode(~x) == "~x" + assert mcode(x & y & z) == "x & y & z" + assert mcode(x | y | z) == "x | y | z" + assert mcode((x & y) | z) == "z | x & y" + assert mcode((x | y) & z) == "z & (x | y)" + + +def test_KroneckerDelta(): + from sympy.functions import KroneckerDelta + assert mcode(KroneckerDelta(x, y)) == "double(x == y)" + assert mcode(KroneckerDelta(x, y + 1)) == "double(x == (y + 1))" + assert mcode(KroneckerDelta(2**x, y)) == "double((2.^x) == y)" + + +def test_Matrices(): + assert mcode(Matrix(1, 1, [10])) == "10" + A = Matrix([[1, sin(x/2), abs(x)], + [0, 1, pi], + [0, exp(1), ceiling(x)]]); + expected = "[1 sin(x/2) abs(x); 0 1 pi; 0 exp(1) ceil(x)]" + assert mcode(A) == expected + # row and columns + assert mcode(A[:,0]) == "[1; 0; 0]" + assert mcode(A[0,:]) == "[1 sin(x/2) abs(x)]" + # empty matrices + assert mcode(Matrix(0, 0, [])) == '[]' + assert mcode(Matrix(0, 3, [])) == 'zeros(0, 3)' + # annoying to read but correct + assert mcode(Matrix([[x, x - y, -y]])) == "[x x - y -y]" + + +def test_vector_entries_hadamard(): + # For a row or column, user might to use the other dimension + A = Matrix([[1, sin(2/x), 3*pi/x/5]]) + assert mcode(A) == "[1 sin(2./x) 3*pi./(5*x)]" + assert mcode(A.T) == "[1; sin(2./x); 3*pi./(5*x)]" + + +@XFAIL +def test_Matrices_entries_not_hadamard(): + # For Matrix with col >= 2, row >= 2, they need to be scalars + # FIXME: is it worth worrying about this? Its not wrong, just + # leave it user's responsibility to put scalar data for x. + A = Matrix([[1, sin(2/x), 3*pi/x/5], [1, 2, x*y]]) + expected = ("[1 sin(2/x) 3*pi/(5*x);\n" + "1 2 x*y]") # <- we give x.*y + assert mcode(A) == expected + + +def test_MatrixSymbol(): + n = Symbol('n', integer=True) + A = MatrixSymbol('A', n, n) + B = MatrixSymbol('B', n, n) + assert mcode(A*B) == "A*B" + assert mcode(B*A) == "B*A" + assert mcode(2*A*B) == "2*A*B" + assert mcode(B*2*A) == "2*B*A" + assert mcode(A*(B + 3*Identity(n))) == "A*(3*eye(n) + B)" + assert mcode(A**(x**2)) == "A^(x.^2)" + assert mcode(A**3) == "A^3" + assert mcode(A**S.Half) == "A^(1/2)" + + +def test_MatrixSolve(): + n = Symbol('n', integer=True) + A = MatrixSymbol('A', n, n) + x = MatrixSymbol('x', n, 1) + assert mcode(MatrixSolve(A, x)) == "A \\ x" + +def test_special_matrices(): + assert mcode(6*Identity(3)) == "6*eye(3)" + + +def test_containers(): + assert mcode([1, 2, 3, [4, 5, [6, 7]], 8, [9, 10], 11]) == \ + "{1, 2, 3, {4, 5, {6, 7}}, 8, {9, 10}, 11}" + assert mcode((1, 2, (3, 4))) == "{1, 2, {3, 4}}" + assert mcode([1]) == "{1}" + assert mcode((1,)) == "{1}" + assert mcode(Tuple(*[1, 2, 3])) == "{1, 2, 3}" + assert mcode((1, x*y, (3, x**2))) == "{1, x.*y, {3, x.^2}}" + # scalar, matrix, empty matrix and empty list + assert mcode((1, eye(3), Matrix(0, 0, []), [])) == "{1, [1 0 0; 0 1 0; 0 0 1], [], {}}" + + +def test_octave_noninline(): + source = mcode((x+y)/Catalan, assign_to='me', inline=False) + expected = ( + "Catalan = %s;\n" + "me = (x + y)/Catalan;" + ) % Catalan.evalf(17) + assert source == expected + + +def test_octave_piecewise(): + expr = Piecewise((x, x < 1), (x**2, True)) + assert mcode(expr) == "((x < 1).*(x) + (~(x < 1)).*(x.^2))" + assert mcode(expr, assign_to="r") == ( + "r = ((x < 1).*(x) + (~(x < 1)).*(x.^2));") + assert mcode(expr, assign_to="r", inline=False) == ( + "if (x < 1)\n" + " r = x;\n" + "else\n" + " r = x.^2;\n" + "end") + expr = Piecewise((x**2, x < 1), (x**3, x < 2), (x**4, x < 3), (x**5, True)) + expected = ("((x < 1).*(x.^2) + (~(x < 1)).*( ...\n" + "(x < 2).*(x.^3) + (~(x < 2)).*( ...\n" + "(x < 3).*(x.^4) + (~(x < 3)).*(x.^5))))") + assert mcode(expr) == expected + assert mcode(expr, assign_to="r") == "r = " + expected + ";" + assert mcode(expr, assign_to="r", inline=False) == ( + "if (x < 1)\n" + " r = x.^2;\n" + "elseif (x < 2)\n" + " r = x.^3;\n" + "elseif (x < 3)\n" + " r = x.^4;\n" + "else\n" + " r = x.^5;\n" + "end") + # Check that Piecewise without a True (default) condition error + expr = Piecewise((x, x < 1), (x**2, x > 1), (sin(x), x > 0)) + raises(ValueError, lambda: mcode(expr)) + + +def test_octave_piecewise_times_const(): + pw = Piecewise((x, x < 1), (x**2, True)) + assert mcode(2*pw) == "2*((x < 1).*(x) + (~(x < 1)).*(x.^2))" + assert mcode(pw/x) == "((x < 1).*(x) + (~(x < 1)).*(x.^2))./x" + assert mcode(pw/(x*y)) == "((x < 1).*(x) + (~(x < 1)).*(x.^2))./(x.*y)" + assert mcode(pw/3) == "((x < 1).*(x) + (~(x < 1)).*(x.^2))/3" + + +def test_octave_matrix_assign_to(): + A = Matrix([[1, 2, 3]]) + assert mcode(A, assign_to='a') == "a = [1 2 3];" + A = Matrix([[1, 2], [3, 4]]) + assert mcode(A, assign_to='A') == "A = [1 2; 3 4];" + + +def test_octave_matrix_assign_to_more(): + # assigning to Symbol or MatrixSymbol requires lhs/rhs match + A = Matrix([[1, 2, 3]]) + B = MatrixSymbol('B', 1, 3) + C = MatrixSymbol('C', 2, 3) + assert mcode(A, assign_to=B) == "B = [1 2 3];" + raises(ValueError, lambda: mcode(A, assign_to=x)) + raises(ValueError, lambda: mcode(A, assign_to=C)) + + +def test_octave_matrix_1x1(): + A = Matrix([[3]]) + B = MatrixSymbol('B', 1, 1) + C = MatrixSymbol('C', 1, 2) + assert mcode(A, assign_to=B) == "B = 3;" + # FIXME? + #assert mcode(A, assign_to=x) == "x = 3;" + raises(ValueError, lambda: mcode(A, assign_to=C)) + + +def test_octave_matrix_elements(): + A = Matrix([[x, 2, x*y]]) + assert mcode(A[0, 0]**2 + A[0, 1] + A[0, 2]) == "x.^2 + x.*y + 2" + A = MatrixSymbol('AA', 1, 3) + assert mcode(A) == "AA" + assert mcode(A[0, 0]**2 + sin(A[0,1]) + A[0,2]) == \ + "sin(AA(1, 2)) + AA(1, 1).^2 + AA(1, 3)" + assert mcode(sum(A)) == "AA(1, 1) + AA(1, 2) + AA(1, 3)" + + +def test_octave_boolean(): + assert mcode(True) == "true" + assert mcode(S.true) == "true" + assert mcode(False) == "false" + assert mcode(S.false) == "false" + + +def test_octave_not_supported(): + with raises(NotImplementedError): + mcode(S.ComplexInfinity) + f = Function('f') + assert mcode(f(x).diff(x), strict=False) == ( + "% Not supported in Octave:\n" + "% Derivative\n" + "Derivative(f(x), x)" + ) + + +def test_octave_not_supported_not_on_whitelist(): + from sympy.functions.special.polynomials import assoc_laguerre + with raises(NotImplementedError): + mcode(assoc_laguerre(x, y, z)) + + +def test_octave_expint(): + assert mcode(expint(1, x)) == "expint(x)" + with raises(NotImplementedError): + mcode(expint(2, x)) + assert mcode(expint(y, x), strict=False) == ( + "% Not supported in Octave:\n" + "% expint\n" + "expint(y, x)" + ) + + +def test_trick_indent_with_end_else_words(): + # words starting with "end" or "else" do not confuse the indenter + t1 = S('endless'); + t2 = S('elsewhere'); + pw = Piecewise((t1, x < 0), (t2, x <= 1), (1, True)) + assert mcode(pw, inline=False) == ( + "if (x < 0)\n" + " endless\n" + "elseif (x <= 1)\n" + " elsewhere\n" + "else\n" + " 1\n" + "end") + + +def test_hadamard(): + A = MatrixSymbol('A', 3, 3) + B = MatrixSymbol('B', 3, 3) + v = MatrixSymbol('v', 3, 1) + h = MatrixSymbol('h', 1, 3) + C = HadamardProduct(A, B) + n = Symbol('n') + assert mcode(C) == "A.*B" + assert mcode(C*v) == "(A.*B)*v" + assert mcode(h*C*v) == "h*(A.*B)*v" + assert mcode(C*A) == "(A.*B)*A" + # mixing Hadamard and scalar strange b/c we vectorize scalars + assert mcode(C*x*y) == "(x.*y)*(A.*B)" + + # Testing HadamardPower: + assert mcode(HadamardPower(A, n)) == "A.**n" + assert mcode(HadamardPower(A, 1+n)) == "A.**(n + 1)" + assert mcode(HadamardPower(A*B.T, 1+n)) == "(A*B.T).**(n + 1)" + + +def test_sparse(): + M = SparseMatrix(5, 6, {}) + M[2, 2] = 10; + M[1, 2] = 20; + M[1, 3] = 22; + M[0, 3] = 30; + M[3, 0] = x*y; + assert mcode(M) == ( + "sparse([4 2 3 1 2], [1 3 3 4 4], [x.*y 20 10 30 22], 5, 6)" + ) + + +def test_sinc(): + assert mcode(sinc(x)) == 'sinc(x/pi)' + assert mcode(sinc(x + 3)) == 'sinc((x + 3)/pi)' + assert mcode(sinc(pi*(x + 3))) == 'sinc(x + 3)' + + +def test_trigfun(): + for f in (sin, cos, tan, cot, sec, csc, asin, acos, acot, atan, asec, acsc, + sinh, cosh, tanh, coth, csch, sech, asinh, acosh, atanh, acoth, + asech, acsch): + assert octave_code(f(x) == f.__name__ + '(x)') + + +def test_specfun(): + n = Symbol('n') + for f in [besselj, bessely, besseli, besselk]: + assert octave_code(f(n, x)) == f.__name__ + '(n, x)' + for f in (erfc, erfi, erf, erfinv, erfcinv, fresnelc, fresnels, gamma): + assert octave_code(f(x)) == f.__name__ + '(x)' + assert octave_code(hankel1(n, x)) == 'besselh(n, 1, x)' + assert octave_code(hankel2(n, x)) == 'besselh(n, 2, x)' + assert octave_code(airyai(x)) == 'airy(0, x)' + assert octave_code(airyaiprime(x)) == 'airy(1, x)' + assert octave_code(airybi(x)) == 'airy(2, x)' + assert octave_code(airybiprime(x)) == 'airy(3, x)' + assert octave_code(uppergamma(n, x)) == '(gammainc(x, n, \'upper\').*gamma(n))' + assert octave_code(lowergamma(n, x)) == '(gammainc(x, n).*gamma(n))' + assert octave_code(z**lowergamma(n, x)) == 'z.^(gammainc(x, n).*gamma(n))' + assert octave_code(jn(n, x)) == 'sqrt(2)*sqrt(pi)*sqrt(1./x).*besselj(n + 1/2, x)/2' + assert octave_code(yn(n, x)) == 'sqrt(2)*sqrt(pi)*sqrt(1./x).*bessely(n + 1/2, x)/2' + assert octave_code(LambertW(x)) == 'lambertw(x)' + assert octave_code(LambertW(x, n)) == 'lambertw(n, x)' + + # Automatic rewrite + assert octave_code(Ei(x)) == '(logint(exp(x)))' + assert octave_code(dirichlet_eta(x)) == '(((x == 1).*(log(2)) + (~(x == 1)).*((1 - 2.^(1 - x)).*zeta(x))))' + assert octave_code(riemann_xi(x)) == '(pi.^(-x/2).*x.*(x - 1).*gamma(x/2).*zeta(x)/2)' + + +def test_MatrixElement_printing(): + # test cases for issue #11821 + A = MatrixSymbol("A", 1, 3) + B = MatrixSymbol("B", 1, 3) + C = MatrixSymbol("C", 1, 3) + + assert mcode(A[0, 0]) == "A(1, 1)" + assert mcode(3 * A[0, 0]) == "3*A(1, 1)" + + F = C[0, 0].subs(C, A - B) + assert mcode(F) == "(A - B)(1, 1)" + + +def test_zeta_printing_issue_14820(): + assert octave_code(zeta(x)) == 'zeta(x)' + with raises(NotImplementedError): + octave_code(zeta(x, y)) + + +def test_automatic_rewrite(): + assert octave_code(Li(x)) == '(logint(x) - logint(2))' + assert octave_code(erf2(x, y)) == '(-erf(x) + erf(y))' diff --git a/lib/python3.10/site-packages/sympy/printing/tests/test_precedence.py b/lib/python3.10/site-packages/sympy/printing/tests/test_precedence.py new file mode 100644 index 0000000000000000000000000000000000000000..372a5b0356b7a7473ecf595df45ae31c3bfaff71 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/printing/tests/test_precedence.py @@ -0,0 +1,89 @@ +from sympy.concrete.products import Product +from sympy.concrete.summations import Sum +from sympy.core.function import Derivative +from sympy.core.numbers import Integer, Rational, Float, oo +from sympy.core.relational import Rel +from sympy.core.symbol import symbols +from sympy.functions import sin +from sympy.integrals.integrals import Integral +from sympy.series.order import Order + +from sympy.printing.precedence import precedence, PRECEDENCE + +x, y = symbols("x,y") + + +def test_Add(): + assert precedence(x + y) == PRECEDENCE["Add"] + assert precedence(x*y + 1) == PRECEDENCE["Add"] + + +def test_Function(): + assert precedence(sin(x)) == PRECEDENCE["Func"] + +def test_Derivative(): + assert precedence(Derivative(x, y)) == PRECEDENCE["Atom"] + +def test_Integral(): + assert precedence(Integral(x, y)) == PRECEDENCE["Atom"] + + +def test_Mul(): + assert precedence(x*y) == PRECEDENCE["Mul"] + assert precedence(-x*y) == PRECEDENCE["Add"] + + +def test_Number(): + assert precedence(Integer(0)) == PRECEDENCE["Atom"] + assert precedence(Integer(1)) == PRECEDENCE["Atom"] + assert precedence(Integer(-1)) == PRECEDENCE["Add"] + assert precedence(Integer(10)) == PRECEDENCE["Atom"] + assert precedence(Rational(5, 2)) == PRECEDENCE["Mul"] + assert precedence(Rational(-5, 2)) == PRECEDENCE["Add"] + assert precedence(Float(5)) == PRECEDENCE["Atom"] + assert precedence(Float(-5)) == PRECEDENCE["Add"] + assert precedence(oo) == PRECEDENCE["Atom"] + assert precedence(-oo) == PRECEDENCE["Add"] + + +def test_Order(): + assert precedence(Order(x)) == PRECEDENCE["Atom"] + + +def test_Pow(): + assert precedence(x**y) == PRECEDENCE["Pow"] + assert precedence(-x**y) == PRECEDENCE["Add"] + assert precedence(x**-y) == PRECEDENCE["Pow"] + + +def test_Product(): + assert precedence(Product(x, (x, y, y + 1))) == PRECEDENCE["Atom"] + + +def test_Relational(): + assert precedence(Rel(x + y, y, "<")) == PRECEDENCE["Relational"] + + +def test_Sum(): + assert precedence(Sum(x, (x, y, y + 1))) == PRECEDENCE["Atom"] + + +def test_Symbol(): + assert precedence(x) == PRECEDENCE["Atom"] + + +def test_And_Or(): + # precedence relations between logical operators, ... + assert precedence(x & y) > precedence(x | y) + assert precedence(~y) > precedence(x & y) + # ... and with other operators (cfr. other programming languages) + assert precedence(x + y) > precedence(x | y) + assert precedence(x + y) > precedence(x & y) + assert precedence(x*y) > precedence(x | y) + assert precedence(x*y) > precedence(x & y) + assert precedence(~y) > precedence(x*y) + assert precedence(~y) > precedence(x - y) + # double checks + assert precedence(x & y) == PRECEDENCE["And"] + assert precedence(x | y) == PRECEDENCE["Or"] + assert precedence(~y) == PRECEDENCE["Not"] diff --git a/lib/python3.10/site-packages/sympy/printing/tests/test_preview.py b/lib/python3.10/site-packages/sympy/printing/tests/test_preview.py new file mode 100644 index 0000000000000000000000000000000000000000..91771ceb0466d6b0fee00570426713d02da14872 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/printing/tests/test_preview.py @@ -0,0 +1,38 @@ +# -*- coding: utf-8 -*- + +from sympy.core.relational import Eq +from sympy.core.symbol import Symbol +from sympy.functions.elementary.piecewise import Piecewise +from sympy.printing.preview import preview + +from io import BytesIO + + +def test_preview(): + x = Symbol('x') + obj = BytesIO() + try: + preview(x, output='png', viewer='BytesIO', outputbuffer=obj) + except RuntimeError: + pass # latex not installed on CI server + + +def test_preview_unicode_symbol(): + # issue 9107 + a = Symbol('α') + obj = BytesIO() + try: + preview(a, output='png', viewer='BytesIO', outputbuffer=obj) + except RuntimeError: + pass # latex not installed on CI server + + +def test_preview_latex_construct_in_expr(): + # see PR 9801 + x = Symbol('x') + pw = Piecewise((1, Eq(x, 0)), (0, True)) + obj = BytesIO() + try: + preview(pw, output='png', viewer='BytesIO', outputbuffer=obj) + except RuntimeError: + pass # latex not installed on CI server diff --git a/lib/python3.10/site-packages/sympy/printing/tests/test_pycode.py b/lib/python3.10/site-packages/sympy/printing/tests/test_pycode.py new file mode 100644 index 0000000000000000000000000000000000000000..e15648ebe2d771152cd0323f42be9cb20c45d467 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/printing/tests/test_pycode.py @@ -0,0 +1,429 @@ +from sympy.codegen import Assignment +from sympy.codegen.ast import none +from sympy.codegen.cfunctions import expm1, log1p +from sympy.codegen.scipy_nodes import cosm1 +from sympy.codegen.matrix_nodes import MatrixSolve +from sympy.core import Expr, Mod, symbols, Eq, Le, Gt, zoo, oo, Rational, Pow +from sympy.core.numbers import pi +from sympy.core.singleton import S +from sympy.functions import acos, KroneckerDelta, Piecewise, sign, sqrt, Min, Max, cot, acsch, asec, coth, sec +from sympy.logic import And, Or +from sympy.matrices import SparseMatrix, MatrixSymbol, Identity +from sympy.printing.pycode import ( + MpmathPrinter, PythonCodePrinter, pycode, SymPyPrinter +) +from sympy.printing.tensorflow import TensorflowPrinter +from sympy.printing.numpy import NumPyPrinter, SciPyPrinter +from sympy.testing.pytest import raises, skip +from sympy.tensor import IndexedBase, Idx +from sympy.tensor.array.expressions.array_expressions import ArraySymbol, ArrayDiagonal, ArrayContraction, ZeroArray, OneArray +from sympy.external import import_module +from sympy.functions.special.gamma_functions import loggamma + + +x, y, z = symbols('x y z') +p = IndexedBase("p") + + +def test_PythonCodePrinter(): + prntr = PythonCodePrinter() + + assert not prntr.module_imports + + assert prntr.doprint(x**y) == 'x**y' + assert prntr.doprint(Mod(x, 2)) == 'x % 2' + assert prntr.doprint(-Mod(x, y)) == '-(x % y)' + assert prntr.doprint(Mod(-x, y)) == '(-x) % y' + assert prntr.doprint(And(x, y)) == 'x and y' + assert prntr.doprint(Or(x, y)) == 'x or y' + assert prntr.doprint(1/(x+y)) == '1/(x + y)' + assert not prntr.module_imports + + assert prntr.doprint(pi) == 'math.pi' + assert prntr.module_imports == {'math': {'pi'}} + + assert prntr.doprint(x**Rational(1, 2)) == 'math.sqrt(x)' + assert prntr.doprint(sqrt(x)) == 'math.sqrt(x)' + assert prntr.module_imports == {'math': {'pi', 'sqrt'}} + + assert prntr.doprint(acos(x)) == 'math.acos(x)' + assert prntr.doprint(cot(x)) == '(1/math.tan(x))' + assert prntr.doprint(coth(x)) == '((math.exp(x) + math.exp(-x))/(math.exp(x) - math.exp(-x)))' + assert prntr.doprint(asec(x)) == '(math.acos(1/x))' + assert prntr.doprint(acsch(x)) == '(math.log(math.sqrt(1 + x**(-2)) + 1/x))' + + assert prntr.doprint(Assignment(x, 2)) == 'x = 2' + assert prntr.doprint(Piecewise((1, Eq(x, 0)), + (2, x>6))) == '((1) if (x == 0) else (2) if (x > 6) else None)' + assert prntr.doprint(Piecewise((2, Le(x, 0)), + (3, Gt(x, 0)), evaluate=False)) == '((2) if (x <= 0) else'\ + ' (3) if (x > 0) else None)' + assert prntr.doprint(sign(x)) == '(0.0 if x == 0 else math.copysign(1, x))' + assert prntr.doprint(p[0, 1]) == 'p[0, 1]' + assert prntr.doprint(KroneckerDelta(x,y)) == '(1 if x == y else 0)' + + assert prntr.doprint((2,3)) == "(2, 3)" + assert prntr.doprint([2,3]) == "[2, 3]" + + assert prntr.doprint(Min(x, y)) == "min(x, y)" + assert prntr.doprint(Max(x, y)) == "max(x, y)" + + +def test_PythonCodePrinter_standard(): + prntr = PythonCodePrinter() + + assert prntr.standard == 'python3' + + raises(ValueError, lambda: PythonCodePrinter({'standard':'python4'})) + + +def test_MpmathPrinter(): + p = MpmathPrinter() + assert p.doprint(sign(x)) == 'mpmath.sign(x)' + assert p.doprint(Rational(1, 2)) == 'mpmath.mpf(1)/mpmath.mpf(2)' + + assert p.doprint(S.Exp1) == 'mpmath.e' + assert p.doprint(S.Pi) == 'mpmath.pi' + assert p.doprint(S.GoldenRatio) == 'mpmath.phi' + assert p.doprint(S.EulerGamma) == 'mpmath.euler' + assert p.doprint(S.NaN) == 'mpmath.nan' + assert p.doprint(S.Infinity) == 'mpmath.inf' + assert p.doprint(S.NegativeInfinity) == 'mpmath.ninf' + assert p.doprint(loggamma(x)) == 'mpmath.loggamma(x)' + + +def test_NumPyPrinter(): + from sympy.core.function import Lambda + from sympy.matrices.expressions.adjoint import Adjoint + from sympy.matrices.expressions.diagonal import (DiagMatrix, DiagonalMatrix, DiagonalOf) + from sympy.matrices.expressions.funcmatrix import FunctionMatrix + from sympy.matrices.expressions.hadamard import HadamardProduct + from sympy.matrices.expressions.kronecker import KroneckerProduct + from sympy.matrices.expressions.special import (OneMatrix, ZeroMatrix) + from sympy.abc import a, b + p = NumPyPrinter() + assert p.doprint(sign(x)) == 'numpy.sign(x)' + A = MatrixSymbol("A", 2, 2) + B = MatrixSymbol("B", 2, 2) + C = MatrixSymbol("C", 1, 5) + D = MatrixSymbol("D", 3, 4) + assert p.doprint(A**(-1)) == "numpy.linalg.inv(A)" + assert p.doprint(A**5) == "numpy.linalg.matrix_power(A, 5)" + assert p.doprint(Identity(3)) == "numpy.eye(3)" + + u = MatrixSymbol('x', 2, 1) + v = MatrixSymbol('y', 2, 1) + assert p.doprint(MatrixSolve(A, u)) == 'numpy.linalg.solve(A, x)' + assert p.doprint(MatrixSolve(A, u) + v) == 'numpy.linalg.solve(A, x) + y' + + assert p.doprint(ZeroMatrix(2, 3)) == "numpy.zeros((2, 3))" + assert p.doprint(OneMatrix(2, 3)) == "numpy.ones((2, 3))" + assert p.doprint(FunctionMatrix(4, 5, Lambda((a, b), a + b))) == \ + "numpy.fromfunction(lambda a, b: a + b, (4, 5))" + assert p.doprint(HadamardProduct(A, B)) == "numpy.multiply(A, B)" + assert p.doprint(KroneckerProduct(A, B)) == "numpy.kron(A, B)" + assert p.doprint(Adjoint(A)) == "numpy.conjugate(numpy.transpose(A))" + assert p.doprint(DiagonalOf(A)) == "numpy.reshape(numpy.diag(A), (-1, 1))" + assert p.doprint(DiagMatrix(C)) == "numpy.diagflat(C)" + assert p.doprint(DiagonalMatrix(D)) == "numpy.multiply(D, numpy.eye(3, 4))" + + # Workaround for numpy negative integer power errors + assert p.doprint(x**-1) == 'x**(-1.0)' + assert p.doprint(x**-2) == 'x**(-2.0)' + + expr = Pow(2, -1, evaluate=False) + assert p.doprint(expr) == "2**(-1.0)" + + assert p.doprint(S.Exp1) == 'numpy.e' + assert p.doprint(S.Pi) == 'numpy.pi' + assert p.doprint(S.EulerGamma) == 'numpy.euler_gamma' + assert p.doprint(S.NaN) == 'numpy.nan' + assert p.doprint(S.Infinity) == 'numpy.inf' + assert p.doprint(S.NegativeInfinity) == '-numpy.inf' + + # Function rewriting operator precedence fix + assert p.doprint(sec(x)**2) == '(numpy.cos(x)**(-1.0))**2' + + +def test_issue_18770(): + numpy = import_module('numpy') + if not numpy: + skip("numpy not installed.") + + from sympy.functions.elementary.miscellaneous import (Max, Min) + from sympy.utilities.lambdify import lambdify + + expr1 = Min(0.1*x + 3, x + 1, 0.5*x + 1) + func = lambdify(x, expr1, "numpy") + assert (func(numpy.linspace(0, 3, 3)) == [1.0, 1.75, 2.5 ]).all() + assert func(4) == 3 + + expr1 = Max(x**2, x**3) + func = lambdify(x,expr1, "numpy") + assert (func(numpy.linspace(-1, 2, 4)) == [1, 0, 1, 8] ).all() + assert func(4) == 64 + + +def test_SciPyPrinter(): + p = SciPyPrinter() + expr = acos(x) + assert 'numpy' not in p.module_imports + assert p.doprint(expr) == 'numpy.arccos(x)' + assert 'numpy' in p.module_imports + assert not any(m.startswith('scipy') for m in p.module_imports) + smat = SparseMatrix(2, 5, {(0, 1): 3}) + assert p.doprint(smat) == \ + 'scipy.sparse.coo_matrix(([3], ([0], [1])), shape=(2, 5))' + assert 'scipy.sparse' in p.module_imports + + assert p.doprint(S.GoldenRatio) == 'scipy.constants.golden_ratio' + assert p.doprint(S.Pi) == 'scipy.constants.pi' + assert p.doprint(S.Exp1) == 'numpy.e' + + +def test_pycode_reserved_words(): + s1, s2 = symbols('if else') + raises(ValueError, lambda: pycode(s1 + s2, error_on_reserved=True)) + py_str = pycode(s1 + s2) + assert py_str in ('else_ + if_', 'if_ + else_') + + +def test_issue_20762(): + # Make sure pycode removes curly braces from subscripted variables + a_b, b, a_11 = symbols('a_{b} b a_{11}') + expr = a_b*b + assert pycode(expr) == 'a_b*b' + expr = a_11*b + assert pycode(expr) == 'a_11*b' + + +def test_sqrt(): + prntr = PythonCodePrinter() + assert prntr._print_Pow(sqrt(x), rational=False) == 'math.sqrt(x)' + assert prntr._print_Pow(1/sqrt(x), rational=False) == '1/math.sqrt(x)' + + prntr = PythonCodePrinter({'standard' : 'python3'}) + assert prntr._print_Pow(sqrt(x), rational=True) == 'x**(1/2)' + assert prntr._print_Pow(1/sqrt(x), rational=True) == 'x**(-1/2)' + + prntr = MpmathPrinter() + assert prntr._print_Pow(sqrt(x), rational=False) == 'mpmath.sqrt(x)' + assert prntr._print_Pow(sqrt(x), rational=True) == \ + "x**(mpmath.mpf(1)/mpmath.mpf(2))" + + prntr = NumPyPrinter() + assert prntr._print_Pow(sqrt(x), rational=False) == 'numpy.sqrt(x)' + assert prntr._print_Pow(sqrt(x), rational=True) == 'x**(1/2)' + + prntr = SciPyPrinter() + assert prntr._print_Pow(sqrt(x), rational=False) == 'numpy.sqrt(x)' + assert prntr._print_Pow(sqrt(x), rational=True) == 'x**(1/2)' + + prntr = SymPyPrinter() + assert prntr._print_Pow(sqrt(x), rational=False) == 'sympy.sqrt(x)' + assert prntr._print_Pow(sqrt(x), rational=True) == 'x**(1/2)' + + +def test_frac(): + from sympy.functions.elementary.integers import frac + + expr = frac(x) + prntr = NumPyPrinter() + assert prntr.doprint(expr) == 'numpy.mod(x, 1)' + + prntr = SciPyPrinter() + assert prntr.doprint(expr) == 'numpy.mod(x, 1)' + + prntr = PythonCodePrinter() + assert prntr.doprint(expr) == 'x % 1' + + prntr = MpmathPrinter() + assert prntr.doprint(expr) == 'mpmath.frac(x)' + + prntr = SymPyPrinter() + assert prntr.doprint(expr) == 'sympy.functions.elementary.integers.frac(x)' + + +class CustomPrintedObject(Expr): + def _numpycode(self, printer): + return 'numpy' + + def _mpmathcode(self, printer): + return 'mpmath' + + +def test_printmethod(): + obj = CustomPrintedObject() + assert NumPyPrinter().doprint(obj) == 'numpy' + assert MpmathPrinter().doprint(obj) == 'mpmath' + + +def test_codegen_ast_nodes(): + assert pycode(none) == 'None' + + +def test_issue_14283(): + prntr = PythonCodePrinter() + + assert prntr.doprint(zoo) == "math.nan" + assert prntr.doprint(-oo) == "float('-inf')" + + +def test_NumPyPrinter_print_seq(): + n = NumPyPrinter() + + assert n._print_seq(range(2)) == '(0, 1,)' + + +def test_issue_16535_16536(): + from sympy.functions.special.gamma_functions import (lowergamma, uppergamma) + + a = symbols('a') + expr1 = lowergamma(a, x) + expr2 = uppergamma(a, x) + + prntr = SciPyPrinter() + assert prntr.doprint(expr1) == 'scipy.special.gamma(a)*scipy.special.gammainc(a, x)' + assert prntr.doprint(expr2) == 'scipy.special.gamma(a)*scipy.special.gammaincc(a, x)' + + p_numpy = NumPyPrinter() + p_pycode = PythonCodePrinter({'strict': False}) + + for expr in [expr1, expr2]: + with raises(NotImplementedError): + p_numpy.doprint(expr1) + assert "Not supported" in p_pycode.doprint(expr) + + +def test_Integral(): + from sympy.functions.elementary.exponential import exp + from sympy.integrals.integrals import Integral + + single = Integral(exp(-x), (x, 0, oo)) + double = Integral(x**2*exp(x*y), (x, -z, z), (y, 0, z)) + indefinite = Integral(x**2, x) + evaluateat = Integral(x**2, (x, 1)) + + prntr = SciPyPrinter() + assert prntr.doprint(single) == 'scipy.integrate.quad(lambda x: numpy.exp(-x), 0, numpy.inf)[0]' + assert prntr.doprint(double) == 'scipy.integrate.nquad(lambda x, y: x**2*numpy.exp(x*y), ((-z, z), (0, z)))[0]' + raises(NotImplementedError, lambda: prntr.doprint(indefinite)) + raises(NotImplementedError, lambda: prntr.doprint(evaluateat)) + + prntr = MpmathPrinter() + assert prntr.doprint(single) == 'mpmath.quad(lambda x: mpmath.exp(-x), (0, mpmath.inf))' + assert prntr.doprint(double) == 'mpmath.quad(lambda x, y: x**2*mpmath.exp(x*y), (-z, z), (0, z))' + raises(NotImplementedError, lambda: prntr.doprint(indefinite)) + raises(NotImplementedError, lambda: prntr.doprint(evaluateat)) + + +def test_fresnel_integrals(): + from sympy.functions.special.error_functions import (fresnelc, fresnels) + + expr1 = fresnelc(x) + expr2 = fresnels(x) + + prntr = SciPyPrinter() + assert prntr.doprint(expr1) == 'scipy.special.fresnel(x)[1]' + assert prntr.doprint(expr2) == 'scipy.special.fresnel(x)[0]' + + p_numpy = NumPyPrinter() + p_pycode = PythonCodePrinter() + p_mpmath = MpmathPrinter() + for expr in [expr1, expr2]: + with raises(NotImplementedError): + p_numpy.doprint(expr) + with raises(NotImplementedError): + p_pycode.doprint(expr) + + assert p_mpmath.doprint(expr1) == 'mpmath.fresnelc(x)' + assert p_mpmath.doprint(expr2) == 'mpmath.fresnels(x)' + + +def test_beta(): + from sympy.functions.special.beta_functions import beta + + expr = beta(x, y) + + prntr = SciPyPrinter() + assert prntr.doprint(expr) == 'scipy.special.beta(x, y)' + + prntr = NumPyPrinter() + assert prntr.doprint(expr) == '(math.gamma(x)*math.gamma(y)/math.gamma(x + y))' + + prntr = PythonCodePrinter() + assert prntr.doprint(expr) == '(math.gamma(x)*math.gamma(y)/math.gamma(x + y))' + + prntr = PythonCodePrinter({'allow_unknown_functions': True}) + assert prntr.doprint(expr) == '(math.gamma(x)*math.gamma(y)/math.gamma(x + y))' + + prntr = MpmathPrinter() + assert prntr.doprint(expr) == 'mpmath.beta(x, y)' + +def test_airy(): + from sympy.functions.special.bessel import (airyai, airybi) + + expr1 = airyai(x) + expr2 = airybi(x) + + prntr = SciPyPrinter() + assert prntr.doprint(expr1) == 'scipy.special.airy(x)[0]' + assert prntr.doprint(expr2) == 'scipy.special.airy(x)[2]' + + prntr = NumPyPrinter({'strict': False}) + assert "Not supported" in prntr.doprint(expr1) + assert "Not supported" in prntr.doprint(expr2) + + prntr = PythonCodePrinter({'strict': False}) + assert "Not supported" in prntr.doprint(expr1) + assert "Not supported" in prntr.doprint(expr2) + +def test_airy_prime(): + from sympy.functions.special.bessel import (airyaiprime, airybiprime) + + expr1 = airyaiprime(x) + expr2 = airybiprime(x) + + prntr = SciPyPrinter() + assert prntr.doprint(expr1) == 'scipy.special.airy(x)[1]' + assert prntr.doprint(expr2) == 'scipy.special.airy(x)[3]' + + prntr = NumPyPrinter({'strict': False}) + assert "Not supported" in prntr.doprint(expr1) + assert "Not supported" in prntr.doprint(expr2) + + prntr = PythonCodePrinter({'strict': False}) + assert "Not supported" in prntr.doprint(expr1) + assert "Not supported" in prntr.doprint(expr2) + + +def test_numerical_accuracy_functions(): + prntr = SciPyPrinter() + assert prntr.doprint(expm1(x)) == 'numpy.expm1(x)' + assert prntr.doprint(log1p(x)) == 'numpy.log1p(x)' + assert prntr.doprint(cosm1(x)) == 'scipy.special.cosm1(x)' + +def test_array_printer(): + A = ArraySymbol('A', (4,4,6,6,6)) + I = IndexedBase('I') + i,j,k = Idx('i', (0,1)), Idx('j', (2,3)), Idx('k', (4,5)) + + prntr = NumPyPrinter() + assert prntr.doprint(ZeroArray(5)) == 'numpy.zeros((5,))' + assert prntr.doprint(OneArray(5)) == 'numpy.ones((5,))' + assert prntr.doprint(ArrayContraction(A, [2,3])) == 'numpy.einsum("abccd->abd", A)' + assert prntr.doprint(I) == 'I' + assert prntr.doprint(ArrayDiagonal(A, [2,3,4])) == 'numpy.einsum("abccc->abc", A)' + assert prntr.doprint(ArrayDiagonal(A, [0,1], [2,3])) == 'numpy.einsum("aabbc->cab", A)' + assert prntr.doprint(ArrayContraction(A, [2], [3])) == 'numpy.einsum("abcde->abe", A)' + assert prntr.doprint(Assignment(I[i,j,k], I[i,j,k])) == 'I = I' + + prntr = TensorflowPrinter() + assert prntr.doprint(ZeroArray(5)) == 'tensorflow.zeros((5,))' + assert prntr.doprint(OneArray(5)) == 'tensorflow.ones((5,))' + assert prntr.doprint(ArrayContraction(A, [2,3])) == 'tensorflow.linalg.einsum("abccd->abd", A)' + assert prntr.doprint(I) == 'I' + assert prntr.doprint(ArrayDiagonal(A, [2,3,4])) == 'tensorflow.linalg.einsum("abccc->abc", A)' + assert prntr.doprint(ArrayDiagonal(A, [0,1], [2,3])) == 'tensorflow.linalg.einsum("aabbc->cab", A)' + assert prntr.doprint(ArrayContraction(A, [2], [3])) == 'tensorflow.linalg.einsum("abcde->abe", A)' + assert prntr.doprint(Assignment(I[i,j,k], I[i,j,k])) == 'I = I' diff --git a/lib/python3.10/site-packages/sympy/printing/tests/test_python.py b/lib/python3.10/site-packages/sympy/printing/tests/test_python.py new file mode 100644 index 0000000000000000000000000000000000000000..fb94a662be90934a672d08b3de44a22e2580d8b6 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/printing/tests/test_python.py @@ -0,0 +1,203 @@ +from sympy.core.function import (Derivative, Function) +from sympy.core.numbers import (I, Rational, oo, pi) +from sympy.core.relational import (Eq, Ge, Gt, Le, Lt, Ne) +from sympy.core.symbol import (Symbol, symbols) +from sympy.functions.elementary.complexes import (Abs, conjugate) +from sympy.functions.elementary.exponential import (exp, log) +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import sin +from sympy.integrals.integrals import Integral +from sympy.matrices.dense import Matrix +from sympy.series.limits import limit + +from sympy.printing.python import python + +from sympy.testing.pytest import raises, XFAIL + +x, y = symbols('x,y') +th = Symbol('theta') +ph = Symbol('phi') + + +def test_python_basic(): + # Simple numbers/symbols + assert python(-Rational(1)/2) == "e = Rational(-1, 2)" + assert python(-Rational(13)/22) == "e = Rational(-13, 22)" + assert python(oo) == "e = oo" + + # Powers + assert python(x**2) == "x = Symbol(\'x\')\ne = x**2" + assert python(1/x) == "x = Symbol('x')\ne = 1/x" + assert python(y*x**-2) == "y = Symbol('y')\nx = Symbol('x')\ne = y/x**2" + assert python( + x**Rational(-5, 2)) == "x = Symbol('x')\ne = x**Rational(-5, 2)" + + # Sums of terms + assert python(x**2 + x + 1) in [ + "x = Symbol('x')\ne = 1 + x + x**2", + "x = Symbol('x')\ne = x + x**2 + 1", + "x = Symbol('x')\ne = x**2 + x + 1", ] + assert python(1 - x) in [ + "x = Symbol('x')\ne = 1 - x", + "x = Symbol('x')\ne = -x + 1"] + assert python(1 - 2*x) in [ + "x = Symbol('x')\ne = 1 - 2*x", + "x = Symbol('x')\ne = -2*x + 1"] + assert python(1 - Rational(3, 2)*y/x) in [ + "y = Symbol('y')\nx = Symbol('x')\ne = 1 - 3/2*y/x", + "y = Symbol('y')\nx = Symbol('x')\ne = -3/2*y/x + 1", + "y = Symbol('y')\nx = Symbol('x')\ne = 1 - 3*y/(2*x)"] + + # Multiplication + assert python(x/y) == "x = Symbol('x')\ny = Symbol('y')\ne = x/y" + assert python(-x/y) == "x = Symbol('x')\ny = Symbol('y')\ne = -x/y" + assert python((x + 2)/y) in [ + "y = Symbol('y')\nx = Symbol('x')\ne = 1/y*(2 + x)", + "y = Symbol('y')\nx = Symbol('x')\ne = 1/y*(x + 2)", + "x = Symbol('x')\ny = Symbol('y')\ne = 1/y*(2 + x)", + "x = Symbol('x')\ny = Symbol('y')\ne = (2 + x)/y", + "x = Symbol('x')\ny = Symbol('y')\ne = (x + 2)/y"] + assert python((1 + x)*y) in [ + "y = Symbol('y')\nx = Symbol('x')\ne = y*(1 + x)", + "y = Symbol('y')\nx = Symbol('x')\ne = y*(x + 1)", ] + + # Check for proper placement of negative sign + assert python(-5*x/(x + 10)) == "x = Symbol('x')\ne = -5*x/(x + 10)" + assert python(1 - Rational(3, 2)*(x + 1)) in [ + "x = Symbol('x')\ne = Rational(-3, 2)*x + Rational(-1, 2)", + "x = Symbol('x')\ne = -3*x/2 + Rational(-1, 2)", + "x = Symbol('x')\ne = -3*x/2 + Rational(-1, 2)" + ] + + +def test_python_keyword_symbol_name_escaping(): + # Check for escaping of keywords + assert python( + 5*Symbol("lambda")) == "lambda_ = Symbol('lambda')\ne = 5*lambda_" + assert (python(5*Symbol("lambda") + 7*Symbol("lambda_")) == + "lambda__ = Symbol('lambda')\nlambda_ = Symbol('lambda_')\ne = 7*lambda_ + 5*lambda__") + assert (python(5*Symbol("for") + Function("for_")(8)) == + "for__ = Symbol('for')\nfor_ = Function('for_')\ne = 5*for__ + for_(8)") + + +def test_python_keyword_function_name_escaping(): + assert python( + 5*Function("for")(8)) == "for_ = Function('for')\ne = 5*for_(8)" + + +def test_python_relational(): + assert python(Eq(x, y)) == "x = Symbol('x')\ny = Symbol('y')\ne = Eq(x, y)" + assert python(Ge(x, y)) == "x = Symbol('x')\ny = Symbol('y')\ne = x >= y" + assert python(Le(x, y)) == "x = Symbol('x')\ny = Symbol('y')\ne = x <= y" + assert python(Gt(x, y)) == "x = Symbol('x')\ny = Symbol('y')\ne = x > y" + assert python(Lt(x, y)) == "x = Symbol('x')\ny = Symbol('y')\ne = x < y" + assert python(Ne(x/(y + 1), y**2)) in [ + "x = Symbol('x')\ny = Symbol('y')\ne = Ne(x/(1 + y), y**2)", + "x = Symbol('x')\ny = Symbol('y')\ne = Ne(x/(y + 1), y**2)"] + + +def test_python_functions(): + # Simple + assert python(2*x + exp(x)) in "x = Symbol('x')\ne = 2*x + exp(x)" + assert python(sqrt(2)) == 'e = sqrt(2)' + assert python(2**Rational(1, 3)) == 'e = 2**Rational(1, 3)' + assert python(sqrt(2 + pi)) == 'e = sqrt(2 + pi)' + assert python((2 + pi)**Rational(1, 3)) == 'e = (2 + pi)**Rational(1, 3)' + assert python(2**Rational(1, 4)) == 'e = 2**Rational(1, 4)' + assert python(Abs(x)) == "x = Symbol('x')\ne = Abs(x)" + assert python( + Abs(x/(x**2 + 1))) in ["x = Symbol('x')\ne = Abs(x/(1 + x**2))", + "x = Symbol('x')\ne = Abs(x/(x**2 + 1))"] + + # Univariate/Multivariate functions + f = Function('f') + assert python(f(x)) == "x = Symbol('x')\nf = Function('f')\ne = f(x)" + assert python(f(x, y)) == "x = Symbol('x')\ny = Symbol('y')\nf = Function('f')\ne = f(x, y)" + assert python(f(x/(y + 1), y)) in [ + "x = Symbol('x')\ny = Symbol('y')\nf = Function('f')\ne = f(x/(1 + y), y)", + "x = Symbol('x')\ny = Symbol('y')\nf = Function('f')\ne = f(x/(y + 1), y)"] + + # Nesting of square roots + assert python(sqrt((sqrt(x + 1)) + 1)) in [ + "x = Symbol('x')\ne = sqrt(1 + sqrt(1 + x))", + "x = Symbol('x')\ne = sqrt(sqrt(x + 1) + 1)"] + + # Nesting of powers + assert python((((x + 1)**Rational(1, 3)) + 1)**Rational(1, 3)) in [ + "x = Symbol('x')\ne = (1 + (1 + x)**Rational(1, 3))**Rational(1, 3)", + "x = Symbol('x')\ne = ((x + 1)**Rational(1, 3) + 1)**Rational(1, 3)"] + + # Function powers + assert python(sin(x)**2) == "x = Symbol('x')\ne = sin(x)**2" + + +@XFAIL +def test_python_functions_conjugates(): + a, b = map(Symbol, 'ab') + assert python( conjugate(a + b*I) ) == '_ _\na - I*b' + assert python( conjugate(exp(a + b*I)) ) == ' _ _\n a - I*b\ne ' + + +def test_python_derivatives(): + # Simple + f_1 = Derivative(log(x), x, evaluate=False) + assert python(f_1) == "x = Symbol('x')\ne = Derivative(log(x), x)" + + f_2 = Derivative(log(x), x, evaluate=False) + x + assert python(f_2) == "x = Symbol('x')\ne = x + Derivative(log(x), x)" + + # Multiple symbols + f_3 = Derivative(log(x) + x**2, x, y, evaluate=False) + assert python(f_3) == \ + "x = Symbol('x')\ny = Symbol('y')\ne = Derivative(x**2 + log(x), x, y)" + + f_4 = Derivative(2*x*y, y, x, evaluate=False) + x**2 + assert python(f_4) in [ + "x = Symbol('x')\ny = Symbol('y')\ne = x**2 + Derivative(2*x*y, y, x)", + "x = Symbol('x')\ny = Symbol('y')\ne = Derivative(2*x*y, y, x) + x**2"] + + +def test_python_integrals(): + # Simple + f_1 = Integral(log(x), x) + assert python(f_1) == "x = Symbol('x')\ne = Integral(log(x), x)" + + f_2 = Integral(x**2, x) + assert python(f_2) == "x = Symbol('x')\ne = Integral(x**2, x)" + + # Double nesting of pow + f_3 = Integral(x**(2**x), x) + assert python(f_3) == "x = Symbol('x')\ne = Integral(x**(2**x), x)" + + # Definite integrals + f_4 = Integral(x**2, (x, 1, 2)) + assert python(f_4) == "x = Symbol('x')\ne = Integral(x**2, (x, 1, 2))" + + f_5 = Integral(x**2, (x, Rational(1, 2), 10)) + assert python( + f_5) == "x = Symbol('x')\ne = Integral(x**2, (x, Rational(1, 2), 10))" + + # Nested integrals + f_6 = Integral(x**2*y**2, x, y) + assert python(f_6) == "x = Symbol('x')\ny = Symbol('y')\ne = Integral(x**2*y**2, x, y)" + + +def test_python_matrix(): + p = python(Matrix([[x**2+1, 1], [y, x+y]])) + s = "x = Symbol('x')\ny = Symbol('y')\ne = MutableDenseMatrix([[x**2 + 1, 1], [y, x + y]])" + assert p == s + +def test_python_limits(): + assert python(limit(x, x, oo)) == 'e = oo' + assert python(limit(x**2, x, 0)) == 'e = 0' + +def test_issue_20762(): + # Make sure Python removes curly braces from subscripted variables + a_b = Symbol('a_{b}') + b = Symbol('b') + expr = a_b*b + assert python(expr) == "a_b = Symbol('a_{b}')\nb = Symbol('b')\ne = a_b*b" + + +def test_settings(): + raises(TypeError, lambda: python(x, method="garbage")) diff --git a/lib/python3.10/site-packages/sympy/printing/tests/test_rcode.py b/lib/python3.10/site-packages/sympy/printing/tests/test_rcode.py new file mode 100644 index 0000000000000000000000000000000000000000..a83235b0654c6bf24c30846dbf68678d29cd3c80 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/printing/tests/test_rcode.py @@ -0,0 +1,476 @@ +from sympy.core import (S, pi, oo, Symbol, symbols, Rational, Integer, + GoldenRatio, EulerGamma, Catalan, Lambda, Dummy) +from sympy.functions import (Piecewise, sin, cos, Abs, exp, ceiling, sqrt, + gamma, sign, Max, Min, factorial, beta) +from sympy.core.relational import (Eq, Ge, Gt, Le, Lt, Ne) +from sympy.sets import Range +from sympy.logic import ITE +from sympy.codegen import For, aug_assign, Assignment +from sympy.testing.pytest import raises +from sympy.printing.rcode import RCodePrinter +from sympy.utilities.lambdify import implemented_function +from sympy.tensor import IndexedBase, Idx +from sympy.matrices import Matrix, MatrixSymbol + +from sympy.printing.rcode import rcode + +x, y, z = symbols('x,y,z') + + +def test_printmethod(): + class fabs(Abs): + def _rcode(self, printer): + return "abs(%s)" % printer._print(self.args[0]) + + assert rcode(fabs(x)) == "abs(x)" + + +def test_rcode_sqrt(): + assert rcode(sqrt(x)) == "sqrt(x)" + assert rcode(x**0.5) == "sqrt(x)" + assert rcode(sqrt(x)) == "sqrt(x)" + + +def test_rcode_Pow(): + assert rcode(x**3) == "x^3" + assert rcode(x**(y**3)) == "x^(y^3)" + g = implemented_function('g', Lambda(x, 2*x)) + assert rcode(1/(g(x)*3.5)**(x - y**x)/(x**2 + y)) == \ + "(3.5*2*x)^(-x + y^x)/(x^2 + y)" + assert rcode(x**-1.0) == '1.0/x' + assert rcode(x**Rational(2, 3)) == 'x^(2.0/3.0)' + _cond_cfunc = [(lambda base, exp: exp.is_integer, "dpowi"), + (lambda base, exp: not exp.is_integer, "pow")] + assert rcode(x**3, user_functions={'Pow': _cond_cfunc}) == 'dpowi(x, 3)' + assert rcode(x**3.2, user_functions={'Pow': _cond_cfunc}) == 'pow(x, 3.2)' + + +def test_rcode_Max(): + # Test for gh-11926 + assert rcode(Max(x,x*x),user_functions={"Max":"my_max", "Pow":"my_pow"}) == 'my_max(x, my_pow(x, 2))' + + +def test_rcode_constants_mathh(): + assert rcode(exp(1)) == "exp(1)" + assert rcode(pi) == "pi" + assert rcode(oo) == "Inf" + assert rcode(-oo) == "-Inf" + + +def test_rcode_constants_other(): + assert rcode(2*GoldenRatio) == "GoldenRatio = 1.61803398874989;\n2*GoldenRatio" + assert rcode( + 2*Catalan) == "Catalan = 0.915965594177219;\n2*Catalan" + assert rcode(2*EulerGamma) == "EulerGamma = 0.577215664901533;\n2*EulerGamma" + + +def test_rcode_Rational(): + assert rcode(Rational(3, 7)) == "3.0/7.0" + assert rcode(Rational(18, 9)) == "2" + assert rcode(Rational(3, -7)) == "-3.0/7.0" + assert rcode(Rational(-3, -7)) == "3.0/7.0" + assert rcode(x + Rational(3, 7)) == "x + 3.0/7.0" + assert rcode(Rational(3, 7)*x) == "(3.0/7.0)*x" + + +def test_rcode_Integer(): + assert rcode(Integer(67)) == "67" + assert rcode(Integer(-1)) == "-1" + + +def test_rcode_functions(): + assert rcode(sin(x) ** cos(x)) == "sin(x)^cos(x)" + assert rcode(factorial(x) + gamma(y)) == "factorial(x) + gamma(y)" + assert rcode(beta(Min(x, y), Max(x, y))) == "beta(min(x, y), max(x, y))" + + +def test_rcode_inline_function(): + x = symbols('x') + g = implemented_function('g', Lambda(x, 2*x)) + assert rcode(g(x)) == "2*x" + g = implemented_function('g', Lambda(x, 2*x/Catalan)) + assert rcode( + g(x)) == "Catalan = %s;\n2*x/Catalan" % Catalan.n() + A = IndexedBase('A') + i = Idx('i', symbols('n', integer=True)) + g = implemented_function('g', Lambda(x, x*(1 + x)*(2 + x))) + res=rcode(g(A[i]), assign_to=A[i]) + ref=( + "for (i in 1:n){\n" + " A[i] = (A[i] + 1)*(A[i] + 2)*A[i];\n" + "}" + ) + assert res == ref + + +def test_rcode_exceptions(): + assert rcode(ceiling(x)) == "ceiling(x)" + assert rcode(Abs(x)) == "abs(x)" + assert rcode(gamma(x)) == "gamma(x)" + + +def test_rcode_user_functions(): + x = symbols('x', integer=False) + n = symbols('n', integer=True) + custom_functions = { + "ceiling": "myceil", + "Abs": [(lambda x: not x.is_integer, "fabs"), (lambda x: x.is_integer, "abs")], + } + assert rcode(ceiling(x), user_functions=custom_functions) == "myceil(x)" + assert rcode(Abs(x), user_functions=custom_functions) == "fabs(x)" + assert rcode(Abs(n), user_functions=custom_functions) == "abs(n)" + + +def test_rcode_boolean(): + assert rcode(True) == "True" + assert rcode(S.true) == "True" + assert rcode(False) == "False" + assert rcode(S.false) == "False" + assert rcode(x & y) == "x & y" + assert rcode(x | y) == "x | y" + assert rcode(~x) == "!x" + assert rcode(x & y & z) == "x & y & z" + assert rcode(x | y | z) == "x | y | z" + assert rcode((x & y) | z) == "z | x & y" + assert rcode((x | y) & z) == "z & (x | y)" + +def test_rcode_Relational(): + assert rcode(Eq(x, y)) == "x == y" + assert rcode(Ne(x, y)) == "x != y" + assert rcode(Le(x, y)) == "x <= y" + assert rcode(Lt(x, y)) == "x < y" + assert rcode(Gt(x, y)) == "x > y" + assert rcode(Ge(x, y)) == "x >= y" + + +def test_rcode_Piecewise(): + expr = Piecewise((x, x < 1), (x**2, True)) + res=rcode(expr) + ref="ifelse(x < 1,x,x^2)" + assert res == ref + tau=Symbol("tau") + res=rcode(expr,tau) + ref="tau = ifelse(x < 1,x,x^2);" + assert res == ref + + expr = 2*Piecewise((x, x < 1), (x**2, x<2), (x**3,True)) + assert rcode(expr) == "2*ifelse(x < 1,x,ifelse(x < 2,x^2,x^3))" + res = rcode(expr, assign_to='c') + assert res == "c = 2*ifelse(x < 1,x,ifelse(x < 2,x^2,x^3));" + + # Check that Piecewise without a True (default) condition error + #expr = Piecewise((x, x < 1), (x**2, x > 1), (sin(x), x > 0)) + #raises(ValueError, lambda: rcode(expr)) + expr = 2*Piecewise((x, x < 1), (x**2, x<2)) + assert(rcode(expr))== "2*ifelse(x < 1,x,ifelse(x < 2,x^2,NA))" + + +def test_rcode_sinc(): + from sympy.functions.elementary.trigonometric import sinc + expr = sinc(x) + res = rcode(expr) + ref = "(ifelse(x != 0,sin(x)/x,1))" + assert res == ref + + +def test_rcode_Piecewise_deep(): + p = rcode(2*Piecewise((x, x < 1), (x + 1, x < 2), (x**2, True))) + assert p == "2*ifelse(x < 1,x,ifelse(x < 2,x + 1,x^2))" + expr = x*y*z + x**2 + y**2 + Piecewise((0, x < 0.5), (1, True)) + cos(z) - 1 + p = rcode(expr) + ref="x^2 + x*y*z + y^2 + ifelse(x < 0.5,0,1) + cos(z) - 1" + assert p == ref + + ref="c = x^2 + x*y*z + y^2 + ifelse(x < 0.5,0,1) + cos(z) - 1;" + p = rcode(expr, assign_to='c') + assert p == ref + + +def test_rcode_ITE(): + expr = ITE(x < 1, y, z) + p = rcode(expr) + ref="ifelse(x < 1,y,z)" + assert p == ref + + +def test_rcode_settings(): + raises(TypeError, lambda: rcode(sin(x), method="garbage")) + + +def test_rcode_Indexed(): + n, m, o = symbols('n m o', integer=True) + i, j, k = Idx('i', n), Idx('j', m), Idx('k', o) + p = RCodePrinter() + p._not_r = set() + + x = IndexedBase('x')[j] + assert p._print_Indexed(x) == 'x[j]' + A = IndexedBase('A')[i, j] + assert p._print_Indexed(A) == 'A[i, j]' + B = IndexedBase('B')[i, j, k] + assert p._print_Indexed(B) == 'B[i, j, k]' + + assert p._not_r == set() + +def test_rcode_Indexed_without_looking_for_contraction(): + len_y = 5 + y = IndexedBase('y', shape=(len_y,)) + x = IndexedBase('x', shape=(len_y,)) + Dy = IndexedBase('Dy', shape=(len_y-1,)) + i = Idx('i', len_y-1) + e=Eq(Dy[i], (y[i+1]-y[i])/(x[i+1]-x[i])) + code0 = rcode(e.rhs, assign_to=e.lhs, contract=False) + assert code0 == 'Dy[i] = (y[%s] - y[i])/(x[%s] - x[i]);' % (i + 1, i + 1) + + +def test_rcode_loops_matrix_vector(): + n, m = symbols('n m', integer=True) + A = IndexedBase('A') + x = IndexedBase('x') + y = IndexedBase('y') + i = Idx('i', m) + j = Idx('j', n) + + s = ( + 'for (i in 1:m){\n' + ' y[i] = 0;\n' + '}\n' + 'for (i in 1:m){\n' + ' for (j in 1:n){\n' + ' y[i] = A[i, j]*x[j] + y[i];\n' + ' }\n' + '}' + ) + c = rcode(A[i, j]*x[j], assign_to=y[i]) + assert c == s + + +def test_dummy_loops(): + # the following line could also be + # [Dummy(s, integer=True) for s in 'im'] + # or [Dummy(integer=True) for s in 'im'] + i, m = symbols('i m', integer=True, cls=Dummy) + x = IndexedBase('x') + y = IndexedBase('y') + i = Idx(i, m) + + expected = ( + 'for (i_%(icount)i in 1:m_%(mcount)i){\n' + ' y[i_%(icount)i] = x[i_%(icount)i];\n' + '}' + ) % {'icount': i.label.dummy_index, 'mcount': m.dummy_index} + code = rcode(x[i], assign_to=y[i]) + assert code == expected + + +def test_rcode_loops_add(): + n, m = symbols('n m', integer=True) + A = IndexedBase('A') + x = IndexedBase('x') + y = IndexedBase('y') + z = IndexedBase('z') + i = Idx('i', m) + j = Idx('j', n) + + s = ( + 'for (i in 1:m){\n' + ' y[i] = x[i] + z[i];\n' + '}\n' + 'for (i in 1:m){\n' + ' for (j in 1:n){\n' + ' y[i] = A[i, j]*x[j] + y[i];\n' + ' }\n' + '}' + ) + c = rcode(A[i, j]*x[j] + x[i] + z[i], assign_to=y[i]) + assert c == s + + +def test_rcode_loops_multiple_contractions(): + n, m, o, p = symbols('n m o p', integer=True) + a = IndexedBase('a') + b = IndexedBase('b') + y = IndexedBase('y') + i = Idx('i', m) + j = Idx('j', n) + k = Idx('k', o) + l = Idx('l', p) + + s = ( + 'for (i in 1:m){\n' + ' y[i] = 0;\n' + '}\n' + 'for (i in 1:m){\n' + ' for (j in 1:n){\n' + ' for (k in 1:o){\n' + ' for (l in 1:p){\n' + ' y[i] = a[i, j, k, l]*b[j, k, l] + y[i];\n' + ' }\n' + ' }\n' + ' }\n' + '}' + ) + c = rcode(b[j, k, l]*a[i, j, k, l], assign_to=y[i]) + assert c == s + + +def test_rcode_loops_addfactor(): + n, m, o, p = symbols('n m o p', integer=True) + a = IndexedBase('a') + b = IndexedBase('b') + c = IndexedBase('c') + y = IndexedBase('y') + i = Idx('i', m) + j = Idx('j', n) + k = Idx('k', o) + l = Idx('l', p) + + s = ( + 'for (i in 1:m){\n' + ' y[i] = 0;\n' + '}\n' + 'for (i in 1:m){\n' + ' for (j in 1:n){\n' + ' for (k in 1:o){\n' + ' for (l in 1:p){\n' + ' y[i] = (a[i, j, k, l] + b[i, j, k, l])*c[j, k, l] + y[i];\n' + ' }\n' + ' }\n' + ' }\n' + '}' + ) + c = rcode((a[i, j, k, l] + b[i, j, k, l])*c[j, k, l], assign_to=y[i]) + assert c == s + + +def test_rcode_loops_multiple_terms(): + n, m, o, p = symbols('n m o p', integer=True) + a = IndexedBase('a') + b = IndexedBase('b') + c = IndexedBase('c') + y = IndexedBase('y') + i = Idx('i', m) + j = Idx('j', n) + k = Idx('k', o) + + s0 = ( + 'for (i in 1:m){\n' + ' y[i] = 0;\n' + '}\n' + ) + s1 = ( + 'for (i in 1:m){\n' + ' for (j in 1:n){\n' + ' for (k in 1:o){\n' + ' y[i] = b[j]*b[k]*c[i, j, k] + y[i];\n' + ' }\n' + ' }\n' + '}\n' + ) + s2 = ( + 'for (i in 1:m){\n' + ' for (k in 1:o){\n' + ' y[i] = a[i, k]*b[k] + y[i];\n' + ' }\n' + '}\n' + ) + s3 = ( + 'for (i in 1:m){\n' + ' for (j in 1:n){\n' + ' y[i] = a[i, j]*b[j] + y[i];\n' + ' }\n' + '}\n' + ) + c = rcode( + b[j]*a[i, j] + b[k]*a[i, k] + b[j]*b[k]*c[i, j, k], assign_to=y[i]) + + ref={} + ref[0] = s0 + s1 + s2 + s3[:-1] + ref[1] = s0 + s1 + s3 + s2[:-1] + ref[2] = s0 + s2 + s1 + s3[:-1] + ref[3] = s0 + s2 + s3 + s1[:-1] + ref[4] = s0 + s3 + s1 + s2[:-1] + ref[5] = s0 + s3 + s2 + s1[:-1] + + assert (c == ref[0] or + c == ref[1] or + c == ref[2] or + c == ref[3] or + c == ref[4] or + c == ref[5]) + + +def test_dereference_printing(): + expr = x + y + sin(z) + z + assert rcode(expr, dereference=[z]) == "x + y + (*z) + sin((*z))" + + +def test_Matrix_printing(): + # Test returning a Matrix + mat = Matrix([x*y, Piecewise((2 + x, y>0), (y, True)), sin(z)]) + A = MatrixSymbol('A', 3, 1) + p = rcode(mat, A) + assert p == ( + "A[0] = x*y;\n" + "A[1] = ifelse(y > 0,x + 2,y);\n" + "A[2] = sin(z);") + # Test using MatrixElements in expressions + expr = Piecewise((2*A[2, 0], x > 0), (A[2, 0], True)) + sin(A[1, 0]) + A[0, 0] + p = rcode(expr) + assert p == ("ifelse(x > 0,2*A[2],A[2]) + sin(A[1]) + A[0]") + # Test using MatrixElements in a Matrix + q = MatrixSymbol('q', 5, 1) + M = MatrixSymbol('M', 3, 3) + m = Matrix([[sin(q[1,0]), 0, cos(q[2,0])], + [q[1,0] + q[2,0], q[3, 0], 5], + [2*q[4, 0]/q[1,0], sqrt(q[0,0]) + 4, 0]]) + assert rcode(m, M) == ( + "M[0] = sin(q[1]);\n" + "M[1] = 0;\n" + "M[2] = cos(q[2]);\n" + "M[3] = q[1] + q[2];\n" + "M[4] = q[3];\n" + "M[5] = 5;\n" + "M[6] = 2*q[4]/q[1];\n" + "M[7] = sqrt(q[0]) + 4;\n" + "M[8] = 0;") + + +def test_rcode_sgn(): + + expr = sign(x) * y + assert rcode(expr) == 'y*sign(x)' + p = rcode(expr, 'z') + assert p == 'z = y*sign(x);' + + p = rcode(sign(2 * x + x**2) * x + x**2) + assert p == "x^2 + x*sign(x^2 + 2*x)" + + expr = sign(cos(x)) + p = rcode(expr) + assert p == 'sign(cos(x))' + +def test_rcode_Assignment(): + assert rcode(Assignment(x, y + z)) == 'x = y + z;' + assert rcode(aug_assign(x, '+', y + z)) == 'x += y + z;' + + +def test_rcode_For(): + f = For(x, Range(0, 10, 2), [aug_assign(y, '*', x)]) + sol = rcode(f) + assert sol == ("for(x in seq(from=0, to=9, by=2){\n" + " y *= x;\n" + "}") + + +def test_MatrixElement_printing(): + # test cases for issue #11821 + A = MatrixSymbol("A", 1, 3) + B = MatrixSymbol("B", 1, 3) + C = MatrixSymbol("C", 1, 3) + + assert(rcode(A[0, 0]) == "A[0]") + assert(rcode(3 * A[0, 0]) == "3*A[0]") + + F = C[0, 0].subs(C, A - B) + assert(rcode(F) == "(A - B)[0]") diff --git a/lib/python3.10/site-packages/sympy/printing/tests/test_repr.py b/lib/python3.10/site-packages/sympy/printing/tests/test_repr.py new file mode 100644 index 0000000000000000000000000000000000000000..da58883b4fb027ed82db842a0a1ce5f76a49a8bb --- /dev/null +++ b/lib/python3.10/site-packages/sympy/printing/tests/test_repr.py @@ -0,0 +1,382 @@ +from __future__ import annotations +from typing import Any + +from sympy.external.gmpy import GROUND_TYPES +from sympy.testing.pytest import raises, warns_deprecated_sympy +from sympy.assumptions.ask import Q +from sympy.core.function import (Function, WildFunction) +from sympy.core.numbers import (AlgebraicNumber, Float, Integer, Rational) +from sympy.core.singleton import S +from sympy.core.symbol import (Dummy, Symbol, Wild, symbols) +from sympy.core.sympify import sympify +from sympy.functions.elementary.complexes import Abs +from sympy.functions.elementary.miscellaneous import (root, sqrt) +from sympy.functions.elementary.trigonometric import sin +from sympy.functions.special.delta_functions import Heaviside +from sympy.logic.boolalg import (false, true) +from sympy.matrices.dense import (Matrix, ones) +from sympy.matrices.expressions.matexpr import MatrixSymbol +from sympy.matrices.immutable import ImmutableDenseMatrix +from sympy.combinatorics import Cycle, Permutation +from sympy.core.symbol import Str +from sympy.geometry import Point, Ellipse +from sympy.printing import srepr +from sympy.polys import ring, field, ZZ, QQ, lex, grlex, Poly +from sympy.polys.polyclasses import DMP +from sympy.polys.agca.extensions import FiniteExtension + +x, y = symbols('x,y') + +# eval(srepr(expr)) == expr has to succeed in the right environment. The right +# environment is the scope of "from sympy import *" for most cases. +ENV: dict[str, Any] = {"Str": Str} +exec("from sympy import *", ENV) + + +def sT(expr, string, import_stmt=None, **kwargs): + """ + sT := sreprTest + + Tests that srepr delivers the expected string and that + the condition eval(srepr(expr))==expr holds. + """ + if import_stmt is None: + ENV2 = ENV + else: + ENV2 = ENV.copy() + exec(import_stmt, ENV2) + + assert srepr(expr, **kwargs) == string + assert eval(string, ENV2) == expr + + +def test_printmethod(): + class R(Abs): + def _sympyrepr(self, printer): + return "foo(%s)" % printer._print(self.args[0]) + assert srepr(R(x)) == "foo(Symbol('x'))" + + +def test_Add(): + sT(x + y, "Add(Symbol('x'), Symbol('y'))") + assert srepr(x**2 + 1, order='lex') == "Add(Pow(Symbol('x'), Integer(2)), Integer(1))" + assert srepr(x**2 + 1, order='old') == "Add(Integer(1), Pow(Symbol('x'), Integer(2)))" + assert srepr(sympify('x + 3 - 2', evaluate=False), order='none') == "Add(Symbol('x'), Integer(3), Mul(Integer(-1), Integer(2)))" + + +def test_more_than_255_args_issue_10259(): + from sympy.core.add import Add + from sympy.core.mul import Mul + for op in (Add, Mul): + expr = op(*symbols('x:256')) + assert eval(srepr(expr)) == expr + + +def test_Function(): + sT(Function("f")(x), "Function('f')(Symbol('x'))") + # test unapplied Function + sT(Function('f'), "Function('f')") + + sT(sin(x), "sin(Symbol('x'))") + sT(sin, "sin") + + +def test_Heaviside(): + sT(Heaviside(x), "Heaviside(Symbol('x'))") + sT(Heaviside(x, 1), "Heaviside(Symbol('x'), Integer(1))") + + +def test_Geometry(): + sT(Point(0, 0), "Point2D(Integer(0), Integer(0))") + sT(Ellipse(Point(0, 0), 5, 1), + "Ellipse(Point2D(Integer(0), Integer(0)), Integer(5), Integer(1))") + # TODO more tests + + +def test_Singletons(): + sT(S.Catalan, 'Catalan') + sT(S.ComplexInfinity, 'zoo') + sT(S.EulerGamma, 'EulerGamma') + sT(S.Exp1, 'E') + sT(S.GoldenRatio, 'GoldenRatio') + sT(S.TribonacciConstant, 'TribonacciConstant') + sT(S.Half, 'Rational(1, 2)') + sT(S.ImaginaryUnit, 'I') + sT(S.Infinity, 'oo') + sT(S.NaN, 'nan') + sT(S.NegativeInfinity, '-oo') + sT(S.NegativeOne, 'Integer(-1)') + sT(S.One, 'Integer(1)') + sT(S.Pi, 'pi') + sT(S.Zero, 'Integer(0)') + sT(S.Complexes, 'Complexes') + sT(S.EmptySequence, 'EmptySequence') + sT(S.EmptySet, 'EmptySet') + # sT(S.IdentityFunction, 'Lambda(_x, _x)') + sT(S.Naturals, 'Naturals') + sT(S.Naturals0, 'Naturals0') + sT(S.Rationals, 'Rationals') + sT(S.Reals, 'Reals') + sT(S.UniversalSet, 'UniversalSet') + + +def test_Integer(): + sT(Integer(4), "Integer(4)") + + +def test_list(): + sT([x, Integer(4)], "[Symbol('x'), Integer(4)]") + + +def test_Matrix(): + for cls, name in [(Matrix, "MutableDenseMatrix"), (ImmutableDenseMatrix, "ImmutableDenseMatrix")]: + sT(cls([[x**+1, 1], [y, x + y]]), + "%s([[Symbol('x'), Integer(1)], [Symbol('y'), Add(Symbol('x'), Symbol('y'))]])" % name) + + sT(cls(), "%s([])" % name) + + sT(cls([[x**+1, 1], [y, x + y]]), "%s([[Symbol('x'), Integer(1)], [Symbol('y'), Add(Symbol('x'), Symbol('y'))]])" % name) + + +def test_empty_Matrix(): + sT(ones(0, 3), "MutableDenseMatrix(0, 3, [])") + sT(ones(4, 0), "MutableDenseMatrix(4, 0, [])") + sT(ones(0, 0), "MutableDenseMatrix([])") + + +def test_Rational(): + sT(Rational(1, 3), "Rational(1, 3)") + sT(Rational(-1, 3), "Rational(-1, 3)") + + +def test_Float(): + sT(Float('1.23', dps=3), "Float('1.22998', precision=13)") + sT(Float('1.23456789', dps=9), "Float('1.23456788994', precision=33)") + sT(Float('1.234567890123456789', dps=19), + "Float('1.234567890123456789013', precision=66)") + sT(Float('0.60038617995049726', dps=15), + "Float('0.60038617995049726', precision=53)") + + sT(Float('1.23', precision=13), "Float('1.22998', precision=13)") + sT(Float('1.23456789', precision=33), + "Float('1.23456788994', precision=33)") + sT(Float('1.234567890123456789', precision=66), + "Float('1.234567890123456789013', precision=66)") + sT(Float('0.60038617995049726', precision=53), + "Float('0.60038617995049726', precision=53)") + + sT(Float('0.60038617995049726', 15), + "Float('0.60038617995049726', precision=53)") + + +def test_Symbol(): + sT(x, "Symbol('x')") + sT(y, "Symbol('y')") + sT(Symbol('x', negative=True), "Symbol('x', negative=True)") + + +def test_Symbol_two_assumptions(): + x = Symbol('x', negative=0, integer=1) + # order could vary + s1 = "Symbol('x', integer=True, negative=False)" + s2 = "Symbol('x', negative=False, integer=True)" + assert srepr(x) in (s1, s2) + assert eval(srepr(x), ENV) == x + + +def test_Symbol_no_special_commutative_treatment(): + sT(Symbol('x'), "Symbol('x')") + sT(Symbol('x', commutative=False), "Symbol('x', commutative=False)") + sT(Symbol('x', commutative=0), "Symbol('x', commutative=False)") + sT(Symbol('x', commutative=True), "Symbol('x', commutative=True)") + sT(Symbol('x', commutative=1), "Symbol('x', commutative=True)") + + +def test_Wild(): + sT(Wild('x', even=True), "Wild('x', even=True)") + + +def test_Dummy(): + d = Dummy('d') + sT(d, "Dummy('d', dummy_index=%s)" % str(d.dummy_index)) + + +def test_Dummy_assumption(): + d = Dummy('d', nonzero=True) + assert d == eval(srepr(d)) + s1 = "Dummy('d', dummy_index=%s, nonzero=True)" % str(d.dummy_index) + s2 = "Dummy('d', nonzero=True, dummy_index=%s)" % str(d.dummy_index) + assert srepr(d) in (s1, s2) + + +def test_Dummy_from_Symbol(): + # should not get the full dictionary of assumptions + n = Symbol('n', integer=True) + d = n.as_dummy() + assert srepr(d + ) == "Dummy('n', dummy_index=%s)" % str(d.dummy_index) + + +def test_tuple(): + sT((x,), "(Symbol('x'),)") + sT((x, y), "(Symbol('x'), Symbol('y'))") + + +def test_WildFunction(): + sT(WildFunction('w'), "WildFunction('w')") + + +def test_settins(): + raises(TypeError, lambda: srepr(x, method="garbage")) + + +def test_Mul(): + sT(3*x**3*y, "Mul(Integer(3), Pow(Symbol('x'), Integer(3)), Symbol('y'))") + assert srepr(3*x**3*y, order='old') == "Mul(Integer(3), Symbol('y'), Pow(Symbol('x'), Integer(3)))" + assert srepr(sympify('(x+4)*2*x*7', evaluate=False), order='none') == "Mul(Add(Symbol('x'), Integer(4)), Integer(2), Symbol('x'), Integer(7))" + + +def test_AlgebraicNumber(): + a = AlgebraicNumber(sqrt(2)) + sT(a, "AlgebraicNumber(Pow(Integer(2), Rational(1, 2)), [Integer(1), Integer(0)])") + a = AlgebraicNumber(root(-2, 3)) + sT(a, "AlgebraicNumber(Pow(Integer(-2), Rational(1, 3)), [Integer(1), Integer(0)])") + + +def test_PolyRing(): + assert srepr(ring("x", ZZ, lex)[0]) == "PolyRing((Symbol('x'),), ZZ, lex)" + assert srepr(ring("x,y", QQ, grlex)[0]) == "PolyRing((Symbol('x'), Symbol('y')), QQ, grlex)" + assert srepr(ring("x,y,z", ZZ["t"], lex)[0]) == "PolyRing((Symbol('x'), Symbol('y'), Symbol('z')), ZZ[t], lex)" + + +def test_FracField(): + assert srepr(field("x", ZZ, lex)[0]) == "FracField((Symbol('x'),), ZZ, lex)" + assert srepr(field("x,y", QQ, grlex)[0]) == "FracField((Symbol('x'), Symbol('y')), QQ, grlex)" + assert srepr(field("x,y,z", ZZ["t"], lex)[0]) == "FracField((Symbol('x'), Symbol('y'), Symbol('z')), ZZ[t], lex)" + + +def test_PolyElement(): + R, x, y = ring("x,y", ZZ) + assert srepr(3*x**2*y + 1) == "PolyElement(PolyRing((Symbol('x'), Symbol('y')), ZZ, lex), [((2, 1), 3), ((0, 0), 1)])" + + +def test_FracElement(): + F, x, y = field("x,y", ZZ) + assert srepr((3*x**2*y + 1)/(x - y**2)) == "FracElement(FracField((Symbol('x'), Symbol('y')), ZZ, lex), [((2, 1), 3), ((0, 0), 1)], [((1, 0), 1), ((0, 2), -1)])" + + +def test_FractionField(): + assert srepr(QQ.frac_field(x)) == \ + "FractionField(FracField((Symbol('x'),), QQ, lex))" + assert srepr(QQ.frac_field(x, y, order=grlex)) == \ + "FractionField(FracField((Symbol('x'), Symbol('y')), QQ, grlex))" + + +def test_PolynomialRingBase(): + assert srepr(ZZ.old_poly_ring(x)) == \ + "GlobalPolynomialRing(ZZ, Symbol('x'))" + assert srepr(ZZ[x].old_poly_ring(y)) == \ + "GlobalPolynomialRing(ZZ[x], Symbol('y'))" + assert srepr(QQ.frac_field(x).old_poly_ring(y)) == \ + "GlobalPolynomialRing(FractionField(FracField((Symbol('x'),), QQ, lex)), Symbol('y'))" + + +def test_DMP(): + p1 = DMP([1, 2], ZZ) + p2 = ZZ.old_poly_ring(x)([1, 2]) + if GROUND_TYPES != 'flint': + assert srepr(p1) == "DMP_Python([1, 2], ZZ)" + assert srepr(p2) == "DMP_Python([1, 2], ZZ)" + else: + assert srepr(p1) == "DUP_Flint([1, 2], ZZ)" + assert srepr(p2) == "DUP_Flint([1, 2], ZZ)" + + +def test_FiniteExtension(): + assert srepr(FiniteExtension(Poly(x**2 + 1, x))) == \ + "FiniteExtension(Poly(x**2 + 1, x, domain='ZZ'))" + + +def test_ExtensionElement(): + A = FiniteExtension(Poly(x**2 + 1, x)) + if GROUND_TYPES != 'flint': + ans = "ExtElem(DMP_Python([1, 0], ZZ), FiniteExtension(Poly(x**2 + 1, x, domain='ZZ')))" + else: + ans = "ExtElem(DUP_Flint([1, 0], ZZ), FiniteExtension(Poly(x**2 + 1, x, domain='ZZ')))" + assert srepr(A.generator) == ans + +def test_BooleanAtom(): + assert srepr(true) == "true" + assert srepr(false) == "false" + + +def test_Integers(): + sT(S.Integers, "Integers") + + +def test_Naturals(): + sT(S.Naturals, "Naturals") + + +def test_Naturals0(): + sT(S.Naturals0, "Naturals0") + + +def test_Reals(): + sT(S.Reals, "Reals") + + +def test_matrix_expressions(): + n = symbols('n', integer=True) + A = MatrixSymbol("A", n, n) + B = MatrixSymbol("B", n, n) + sT(A, "MatrixSymbol(Str('A'), Symbol('n', integer=True), Symbol('n', integer=True))") + sT(A*B, "MatMul(MatrixSymbol(Str('A'), Symbol('n', integer=True), Symbol('n', integer=True)), MatrixSymbol(Str('B'), Symbol('n', integer=True), Symbol('n', integer=True)))") + sT(A + B, "MatAdd(MatrixSymbol(Str('A'), Symbol('n', integer=True), Symbol('n', integer=True)), MatrixSymbol(Str('B'), Symbol('n', integer=True), Symbol('n', integer=True)))") + + +def test_Cycle(): + # FIXME: sT fails because Cycle is not immutable and calling srepr(Cycle(1, 2)) + # adds keys to the Cycle dict (GH-17661) + #import_stmt = "from sympy.combinatorics import Cycle" + #sT(Cycle(1, 2), "Cycle(1, 2)", import_stmt) + assert srepr(Cycle(1, 2)) == "Cycle(1, 2)" + + +def test_Permutation(): + import_stmt = "from sympy.combinatorics import Permutation" + sT(Permutation(1, 2)(3, 4), "Permutation([0, 2, 1, 4, 3])", import_stmt, perm_cyclic=False) + sT(Permutation(1, 2)(3, 4), "Permutation(1, 2)(3, 4)", import_stmt, perm_cyclic=True) + + with warns_deprecated_sympy(): + old_print_cyclic = Permutation.print_cyclic + Permutation.print_cyclic = False + sT(Permutation(1, 2)(3, 4), "Permutation([0, 2, 1, 4, 3])", import_stmt) + Permutation.print_cyclic = old_print_cyclic + +def test_dict(): + from sympy.abc import x, y, z + d = {} + assert srepr(d) == "{}" + d = {x: y} + assert srepr(d) == "{Symbol('x'): Symbol('y')}" + d = {x: y, y: z} + assert srepr(d) in ( + "{Symbol('x'): Symbol('y'), Symbol('y'): Symbol('z')}", + "{Symbol('y'): Symbol('z'), Symbol('x'): Symbol('y')}", + ) + d = {x: {y: z}} + assert srepr(d) == "{Symbol('x'): {Symbol('y'): Symbol('z')}}" + +def test_set(): + from sympy.abc import x, y + s = set() + assert srepr(s) == "set()" + s = {x, y} + assert srepr(s) in ("{Symbol('x'), Symbol('y')}", "{Symbol('y'), Symbol('x')}") + +def test_Predicate(): + sT(Q.even, "Q.even") + +def test_AppliedPredicate(): + sT(Q.even(Symbol('z')), "AppliedPredicate(Q.even, Symbol('z'))") diff --git a/lib/python3.10/site-packages/sympy/printing/tests/test_rust.py b/lib/python3.10/site-packages/sympy/printing/tests/test_rust.py new file mode 100644 index 0000000000000000000000000000000000000000..1c2a443422bb08562523eb7fdcf98f6cda287b43 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/printing/tests/test_rust.py @@ -0,0 +1,360 @@ +from sympy.core import (S, pi, oo, symbols, Rational, Integer, + GoldenRatio, EulerGamma, Catalan, Lambda, Dummy, + Eq, Ne, Le, Lt, Gt, Ge, Mod) +from sympy.functions import (Piecewise, sin, cos, Abs, exp, ceiling, sqrt, + sign, floor) +from sympy.logic import ITE +from sympy.testing.pytest import raises +from sympy.utilities.lambdify import implemented_function +from sympy.tensor import IndexedBase, Idx +from sympy.matrices import MatrixSymbol, SparseMatrix, Matrix + +from sympy.printing.rust import rust_code + +x, y, z = symbols('x,y,z') + + +def test_Integer(): + assert rust_code(Integer(42)) == "42" + assert rust_code(Integer(-56)) == "-56" + + +def test_Relational(): + assert rust_code(Eq(x, y)) == "x == y" + assert rust_code(Ne(x, y)) == "x != y" + assert rust_code(Le(x, y)) == "x <= y" + assert rust_code(Lt(x, y)) == "x < y" + assert rust_code(Gt(x, y)) == "x > y" + assert rust_code(Ge(x, y)) == "x >= y" + + +def test_Rational(): + assert rust_code(Rational(3, 7)) == "3_f64/7.0" + assert rust_code(Rational(18, 9)) == "2" + assert rust_code(Rational(3, -7)) == "-3_f64/7.0" + assert rust_code(Rational(-3, -7)) == "3_f64/7.0" + assert rust_code(x + Rational(3, 7)) == "x + 3_f64/7.0" + assert rust_code(Rational(3, 7)*x) == "(3_f64/7.0)*x" + + +def test_basic_ops(): + assert rust_code(x + y) == "x + y" + assert rust_code(x - y) == "x - y" + assert rust_code(x * y) == "x*y" + assert rust_code(x / y) == "x/y" + assert rust_code(-x) == "-x" + + +def test_printmethod(): + class fabs(Abs): + def _rust_code(self, printer): + return "%s.fabs()" % printer._print(self.args[0]) + assert rust_code(fabs(x)) == "x.fabs()" + a = MatrixSymbol("a", 1, 3) + assert rust_code(a[0,0]) == 'a[0]' + + +def test_Functions(): + assert rust_code(sin(x) ** cos(x)) == "x.sin().powf(x.cos())" + assert rust_code(abs(x)) == "x.abs()" + assert rust_code(ceiling(x)) == "x.ceil()" + assert rust_code(floor(x)) == "x.floor()" + + # Automatic rewrite + assert rust_code(Mod(x, 3)) == 'x - 3*((1_f64/3.0)*x).floor()' + + +def test_Pow(): + assert rust_code(1/x) == "x.recip()" + assert rust_code(x**-1) == rust_code(x**-1.0) == "x.recip()" + assert rust_code(sqrt(x)) == "x.sqrt()" + assert rust_code(x**S.Half) == rust_code(x**0.5) == "x.sqrt()" + + assert rust_code(1/sqrt(x)) == "x.sqrt().recip()" + assert rust_code(x**-S.Half) == rust_code(x**-0.5) == "x.sqrt().recip()" + + assert rust_code(1/pi) == "PI.recip()" + assert rust_code(pi**-1) == rust_code(pi**-1.0) == "PI.recip()" + assert rust_code(pi**-0.5) == "PI.sqrt().recip()" + + assert rust_code(x**Rational(1, 3)) == "x.cbrt()" + assert rust_code(2**x) == "x.exp2()" + assert rust_code(exp(x)) == "x.exp()" + assert rust_code(x**3) == "x.powi(3)" + assert rust_code(x**(y**3)) == "x.powf(y.powi(3))" + assert rust_code(x**Rational(2, 3)) == "x.powf(2_f64/3.0)" + + g = implemented_function('g', Lambda(x, 2*x)) + assert rust_code(1/(g(x)*3.5)**(x - y**x)/(x**2 + y)) == \ + "(3.5*2*x).powf(-x + y.powf(x))/(x.powi(2) + y)" + _cond_cfunc = [(lambda base, exp: exp.is_integer, "dpowi", 1), + (lambda base, exp: not exp.is_integer, "pow", 1)] + assert rust_code(x**3, user_functions={'Pow': _cond_cfunc}) == 'x.dpowi(3)' + assert rust_code(x**3.2, user_functions={'Pow': _cond_cfunc}) == 'x.pow(3.2)' + + +def test_constants(): + assert rust_code(pi) == "PI" + assert rust_code(oo) == "INFINITY" + assert rust_code(S.Infinity) == "INFINITY" + assert rust_code(-oo) == "NEG_INFINITY" + assert rust_code(S.NegativeInfinity) == "NEG_INFINITY" + assert rust_code(S.NaN) == "NAN" + assert rust_code(exp(1)) == "E" + assert rust_code(S.Exp1) == "E" + + +def test_constants_other(): + assert rust_code(2*GoldenRatio) == "const GoldenRatio: f64 = %s;\n2*GoldenRatio" % GoldenRatio.evalf(17) + assert rust_code( + 2*Catalan) == "const Catalan: f64 = %s;\n2*Catalan" % Catalan.evalf(17) + assert rust_code(2*EulerGamma) == "const EulerGamma: f64 = %s;\n2*EulerGamma" % EulerGamma.evalf(17) + + +def test_boolean(): + assert rust_code(True) == "true" + assert rust_code(S.true) == "true" + assert rust_code(False) == "false" + assert rust_code(S.false) == "false" + assert rust_code(x & y) == "x && y" + assert rust_code(x | y) == "x || y" + assert rust_code(~x) == "!x" + assert rust_code(x & y & z) == "x && y && z" + assert rust_code(x | y | z) == "x || y || z" + assert rust_code((x & y) | z) == "z || x && y" + assert rust_code((x | y) & z) == "z && (x || y)" + + +def test_Piecewise(): + expr = Piecewise((x, x < 1), (x + 2, True)) + assert rust_code(expr) == ( + "if (x < 1) {\n" + " x\n" + "} else {\n" + " x + 2\n" + "}") + assert rust_code(expr, assign_to="r") == ( + "r = if (x < 1) {\n" + " x\n" + "} else {\n" + " x + 2\n" + "};") + assert rust_code(expr, assign_to="r", inline=True) == ( + "r = if (x < 1) { x } else { x + 2 };") + expr = Piecewise((x, x < 1), (x + 1, x < 5), (x + 2, True)) + assert rust_code(expr, inline=True) == ( + "if (x < 1) { x } else if (x < 5) { x + 1 } else { x + 2 }") + assert rust_code(expr, assign_to="r", inline=True) == ( + "r = if (x < 1) { x } else if (x < 5) { x + 1 } else { x + 2 };") + assert rust_code(expr, assign_to="r") == ( + "r = if (x < 1) {\n" + " x\n" + "} else if (x < 5) {\n" + " x + 1\n" + "} else {\n" + " x + 2\n" + "};") + expr = 2*Piecewise((x, x < 1), (x + 1, x < 5), (x + 2, True)) + assert rust_code(expr, inline=True) == ( + "2*if (x < 1) { x } else if (x < 5) { x + 1 } else { x + 2 }") + expr = 2*Piecewise((x, x < 1), (x + 1, x < 5), (x + 2, True)) - 42 + assert rust_code(expr, inline=True) == ( + "2*if (x < 1) { x } else if (x < 5) { x + 1 } else { x + 2 } - 42") + # Check that Piecewise without a True (default) condition error + expr = Piecewise((x, x < 1), (x**2, x > 1), (sin(x), x > 0)) + raises(ValueError, lambda: rust_code(expr)) + + +def test_dereference_printing(): + expr = x + y + sin(z) + z + assert rust_code(expr, dereference=[z]) == "x + y + (*z) + (*z).sin()" + + +def test_sign(): + expr = sign(x) * y + assert rust_code(expr) == "y*x.signum()" + assert rust_code(expr, assign_to='r') == "r = y*x.signum();" + + expr = sign(x + y) + 42 + assert rust_code(expr) == "(x + y).signum() + 42" + assert rust_code(expr, assign_to='r') == "r = (x + y).signum() + 42;" + + expr = sign(cos(x)) + assert rust_code(expr) == "x.cos().signum()" + + +def test_reserved_words(): + + x, y = symbols("x if") + + expr = sin(y) + assert rust_code(expr) == "if_.sin()" + assert rust_code(expr, dereference=[y]) == "(*if_).sin()" + assert rust_code(expr, reserved_word_suffix='_unreserved') == "if_unreserved.sin()" + + with raises(ValueError): + rust_code(expr, error_on_reserved=True) + + +def test_ITE(): + expr = ITE(x < 1, y, z) + assert rust_code(expr) == ( + "if (x < 1) {\n" + " y\n" + "} else {\n" + " z\n" + "}") + + +def test_Indexed(): + n, m, o = symbols('n m o', integer=True) + i, j, k = Idx('i', n), Idx('j', m), Idx('k', o) + + x = IndexedBase('x')[j] + assert rust_code(x) == "x[j]" + + A = IndexedBase('A')[i, j] + assert rust_code(A) == "A[m*i + j]" + + B = IndexedBase('B')[i, j, k] + assert rust_code(B) == "B[m*o*i + o*j + k]" + + +def test_dummy_loops(): + i, m = symbols('i m', integer=True, cls=Dummy) + x = IndexedBase('x') + y = IndexedBase('y') + i = Idx(i, m) + + assert rust_code(x[i], assign_to=y[i]) == ( + "for i in 0..m {\n" + " y[i] = x[i];\n" + "}") + + +def test_loops(): + m, n = symbols('m n', integer=True) + A = IndexedBase('A') + x = IndexedBase('x') + y = IndexedBase('y') + z = IndexedBase('z') + i = Idx('i', m) + j = Idx('j', n) + + assert rust_code(A[i, j]*x[j], assign_to=y[i]) == ( + "for i in 0..m {\n" + " y[i] = 0;\n" + "}\n" + "for i in 0..m {\n" + " for j in 0..n {\n" + " y[i] = A[n*i + j]*x[j] + y[i];\n" + " }\n" + "}") + + assert rust_code(A[i, j]*x[j] + x[i] + z[i], assign_to=y[i]) == ( + "for i in 0..m {\n" + " y[i] = x[i] + z[i];\n" + "}\n" + "for i in 0..m {\n" + " for j in 0..n {\n" + " y[i] = A[n*i + j]*x[j] + y[i];\n" + " }\n" + "}") + + +def test_loops_multiple_contractions(): + n, m, o, p = symbols('n m o p', integer=True) + a = IndexedBase('a') + b = IndexedBase('b') + y = IndexedBase('y') + i = Idx('i', m) + j = Idx('j', n) + k = Idx('k', o) + l = Idx('l', p) + + assert rust_code(b[j, k, l]*a[i, j, k, l], assign_to=y[i]) == ( + "for i in 0..m {\n" + " y[i] = 0;\n" + "}\n" + "for i in 0..m {\n" + " for j in 0..n {\n" + " for k in 0..o {\n" + " for l in 0..p {\n" + " y[i] = a[%s]*b[%s] + y[i];\n" % (i*n*o*p + j*o*p + k*p + l, j*o*p + k*p + l) +\ + " }\n" + " }\n" + " }\n" + "}") + + +def test_loops_addfactor(): + m, n, o, p = symbols('m n o p', integer=True) + a = IndexedBase('a') + b = IndexedBase('b') + c = IndexedBase('c') + y = IndexedBase('y') + i = Idx('i', m) + j = Idx('j', n) + k = Idx('k', o) + l = Idx('l', p) + + code = rust_code((a[i, j, k, l] + b[i, j, k, l])*c[j, k, l], assign_to=y[i]) + assert code == ( + "for i in 0..m {\n" + " y[i] = 0;\n" + "}\n" + "for i in 0..m {\n" + " for j in 0..n {\n" + " for k in 0..o {\n" + " for l in 0..p {\n" + " y[i] = (a[%s] + b[%s])*c[%s] + y[i];\n" % (i*n*o*p + j*o*p + k*p + l, i*n*o*p + j*o*p + k*p + l, j*o*p + k*p + l) +\ + " }\n" + " }\n" + " }\n" + "}") + + +def test_settings(): + raises(TypeError, lambda: rust_code(sin(x), method="garbage")) + + +def test_inline_function(): + x = symbols('x') + g = implemented_function('g', Lambda(x, 2*x)) + assert rust_code(g(x)) == "2*x" + + g = implemented_function('g', Lambda(x, 2*x/Catalan)) + assert rust_code(g(x)) == ( + "const Catalan: f64 = %s;\n2*x/Catalan" % Catalan.evalf(17)) + + A = IndexedBase('A') + i = Idx('i', symbols('n', integer=True)) + g = implemented_function('g', Lambda(x, x*(1 + x)*(2 + x))) + assert rust_code(g(A[i]), assign_to=A[i]) == ( + "for i in 0..n {\n" + " A[i] = (A[i] + 1)*(A[i] + 2)*A[i];\n" + "}") + + +def test_user_functions(): + x = symbols('x', integer=False) + n = symbols('n', integer=True) + custom_functions = { + "ceiling": "ceil", + "Abs": [(lambda x: not x.is_integer, "fabs", 4), (lambda x: x.is_integer, "abs", 4)], + } + assert rust_code(ceiling(x), user_functions=custom_functions) == "x.ceil()" + assert rust_code(Abs(x), user_functions=custom_functions) == "fabs(x)" + assert rust_code(Abs(n), user_functions=custom_functions) == "abs(n)" + + +def test_matrix(): + assert rust_code(Matrix([1, 2, 3])) == '[1, 2, 3]' + with raises(ValueError): + rust_code(Matrix([[1, 2, 3]])) + + +def test_sparse_matrix(): + # gh-15791 + with raises(NotImplementedError): + rust_code(SparseMatrix([[1, 2, 3]])) diff --git a/lib/python3.10/site-packages/sympy/printing/tests/test_smtlib.py b/lib/python3.10/site-packages/sympy/printing/tests/test_smtlib.py new file mode 100644 index 0000000000000000000000000000000000000000..23566f707beaf04e8196a0b221b28f4fbff6fe23 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/printing/tests/test_smtlib.py @@ -0,0 +1,553 @@ +import contextlib +import itertools +import re +import typing +from enum import Enum +from typing import Callable + +import sympy +from sympy import Add, Implies, sqrt +from sympy.core import Mul, Pow +from sympy.core import (S, pi, symbols, Function, Rational, Integer, + Symbol, Eq, Ne, Le, Lt, Gt, Ge) +from sympy.functions import Piecewise, exp, sin, cos +from sympy.assumptions.ask import Q +from sympy.printing.smtlib import smtlib_code +from sympy.testing.pytest import raises, Failed + +x, y, z = symbols('x,y,z') + + +class _W(Enum): + DEFAULTING_TO_FLOAT = re.compile("Could not infer type of `.+`. Defaulting to float.", re.I) + WILL_NOT_DECLARE = re.compile("Non-Symbol/Function `.+` will not be declared.", re.I) + WILL_NOT_ASSERT = re.compile("Non-Boolean expression `.+` will not be asserted. Converting to SMTLib verbatim.", re.I) + + +@contextlib.contextmanager +def _check_warns(expected: typing.Iterable[_W]): + warns: typing.List[str] = [] + log_warn = warns.append + yield log_warn + + errors = [] + for i, (w, e) in enumerate(itertools.zip_longest(warns, expected)): + if not e: + errors += [f"[{i}] Received unexpected warning `{w}`."] + elif not w: + errors += [f"[{i}] Did not receive expected warning `{e.name}`."] + elif not e.value.match(w): + errors += [f"[{i}] Warning `{w}` does not match expected {e.name}."] + + if errors: raise Failed('\n'.join(errors)) + + +def test_Integer(): + with _check_warns([_W.WILL_NOT_ASSERT] * 2) as w: + assert smtlib_code(Integer(67), log_warn=w) == "67" + assert smtlib_code(Integer(-1), log_warn=w) == "-1" + with _check_warns([]) as w: + assert smtlib_code(Integer(67)) == "67" + assert smtlib_code(Integer(-1)) == "-1" + + +def test_Rational(): + with _check_warns([_W.WILL_NOT_ASSERT] * 4) as w: + assert smtlib_code(Rational(3, 7), log_warn=w) == "(/ 3 7)" + assert smtlib_code(Rational(18, 9), log_warn=w) == "2" + assert smtlib_code(Rational(3, -7), log_warn=w) == "(/ -3 7)" + assert smtlib_code(Rational(-3, -7), log_warn=w) == "(/ 3 7)" + + with _check_warns([_W.DEFAULTING_TO_FLOAT, _W.WILL_NOT_ASSERT] * 2) as w: + assert smtlib_code(x + Rational(3, 7), auto_declare=False, log_warn=w) == "(+ (/ 3 7) x)" + assert smtlib_code(Rational(3, 7) * x, log_warn=w) == "(declare-const x Real)\n" \ + "(* (/ 3 7) x)" + + +def test_Relational(): + with _check_warns([_W.DEFAULTING_TO_FLOAT] * 12) as w: + assert smtlib_code(Eq(x, y), auto_declare=False, log_warn=w) == "(assert (= x y))" + assert smtlib_code(Ne(x, y), auto_declare=False, log_warn=w) == "(assert (not (= x y)))" + assert smtlib_code(Le(x, y), auto_declare=False, log_warn=w) == "(assert (<= x y))" + assert smtlib_code(Lt(x, y), auto_declare=False, log_warn=w) == "(assert (< x y))" + assert smtlib_code(Gt(x, y), auto_declare=False, log_warn=w) == "(assert (> x y))" + assert smtlib_code(Ge(x, y), auto_declare=False, log_warn=w) == "(assert (>= x y))" + + +def test_AppliedBinaryRelation(): + with _check_warns([_W.DEFAULTING_TO_FLOAT] * 12) as w: + assert smtlib_code(Q.eq(x, y), auto_declare=False, log_warn=w) == "(assert (= x y))" + assert smtlib_code(Q.ne(x, y), auto_declare=False, log_warn=w) == "(assert (not (= x y)))" + assert smtlib_code(Q.lt(x, y), auto_declare=False, log_warn=w) == "(assert (< x y))" + assert smtlib_code(Q.le(x, y), auto_declare=False, log_warn=w) == "(assert (<= x y))" + assert smtlib_code(Q.gt(x, y), auto_declare=False, log_warn=w) == "(assert (> x y))" + assert smtlib_code(Q.ge(x, y), auto_declare=False, log_warn=w) == "(assert (>= x y))" + + raises(ValueError, lambda: smtlib_code(Q.complex(x), log_warn=w)) + + +def test_AppliedPredicate(): + with _check_warns([_W.DEFAULTING_TO_FLOAT] * 6) as w: + assert smtlib_code(Q.positive(x), auto_declare=False, log_warn=w) == "(assert (> x 0))" + assert smtlib_code(Q.negative(x), auto_declare=False, log_warn=w) == "(assert (< x 0))" + assert smtlib_code(Q.zero(x), auto_declare=False, log_warn=w) == "(assert (= x 0))" + assert smtlib_code(Q.nonpositive(x), auto_declare=False, log_warn=w) == "(assert (<= x 0))" + assert smtlib_code(Q.nonnegative(x), auto_declare=False, log_warn=w) == "(assert (>= x 0))" + assert smtlib_code(Q.nonzero(x), auto_declare=False, log_warn=w) == "(assert (not (= x 0)))" + +def test_Function(): + with _check_warns([_W.DEFAULTING_TO_FLOAT, _W.WILL_NOT_ASSERT]) as w: + assert smtlib_code(sin(x) ** cos(x), auto_declare=False, log_warn=w) == "(pow (sin x) (cos x))" + + with _check_warns([_W.WILL_NOT_ASSERT]) as w: + assert smtlib_code( + abs(x), + symbol_table={x: int, y: bool}, + known_types={int: "INTEGER_TYPE"}, + known_functions={sympy.Abs: "ABSOLUTE_VALUE_OF"}, + log_warn=w + ) == "(declare-const x INTEGER_TYPE)\n" \ + "(ABSOLUTE_VALUE_OF x)" + + my_fun1 = Function('f1') + with _check_warns([_W.WILL_NOT_ASSERT]) as w: + assert smtlib_code( + my_fun1(x), + symbol_table={my_fun1: Callable[[bool], float]}, + log_warn=w + ) == "(declare-const x Bool)\n" \ + "(declare-fun f1 (Bool) Real)\n" \ + "(f1 x)" + + with _check_warns([]) as w: + assert smtlib_code( + my_fun1(x), + symbol_table={my_fun1: Callable[[bool], bool]}, + log_warn=w + ) == "(declare-const x Bool)\n" \ + "(declare-fun f1 (Bool) Bool)\n" \ + "(assert (f1 x))" + + assert smtlib_code( + Eq(my_fun1(x, z), y), + symbol_table={my_fun1: Callable[[int, bool], bool]}, + log_warn=w + ) == "(declare-const x Int)\n" \ + "(declare-const y Bool)\n" \ + "(declare-const z Bool)\n" \ + "(declare-fun f1 (Int Bool) Bool)\n" \ + "(assert (= (f1 x z) y))" + + assert smtlib_code( + Eq(my_fun1(x, z), y), + symbol_table={my_fun1: Callable[[int, bool], bool]}, + known_functions={my_fun1: "MY_KNOWN_FUN", Eq: '=='}, + log_warn=w + ) == "(declare-const x Int)\n" \ + "(declare-const y Bool)\n" \ + "(declare-const z Bool)\n" \ + "(assert (== (MY_KNOWN_FUN x z) y))" + + with _check_warns([_W.DEFAULTING_TO_FLOAT] * 3) as w: + assert smtlib_code( + Eq(my_fun1(x, z), y), + known_functions={my_fun1: "MY_KNOWN_FUN", Eq: '=='}, + log_warn=w + ) == "(declare-const x Real)\n" \ + "(declare-const y Real)\n" \ + "(declare-const z Real)\n" \ + "(assert (== (MY_KNOWN_FUN x z) y))" + + +def test_Pow(): + with _check_warns([_W.DEFAULTING_TO_FLOAT, _W.WILL_NOT_ASSERT]) as w: + assert smtlib_code(x ** 3, auto_declare=False, log_warn=w) == "(pow x 3)" + with _check_warns([_W.DEFAULTING_TO_FLOAT, _W.DEFAULTING_TO_FLOAT, _W.WILL_NOT_ASSERT]) as w: + assert smtlib_code(x ** (y ** 3), auto_declare=False, log_warn=w) == "(pow x (pow y 3))" + with _check_warns([_W.DEFAULTING_TO_FLOAT, _W.WILL_NOT_ASSERT]) as w: + assert smtlib_code(x ** Rational(2, 3), auto_declare=False, log_warn=w) == '(pow x (/ 2 3))' + + a = Symbol('a', integer=True) + b = Symbol('b', real=True) + c = Symbol('c') + + def g(x): return 2 * x + + # if x=1, y=2, then expr=2.333... + expr = 1 / (g(a) * 3.5) ** (a - b ** a) / (a ** 2 + b) + + with _check_warns([]) as w: + assert smtlib_code( + [ + Eq(a < 2, c), + Eq(b > a, c), + c & True, + Eq(expr, 2 + Rational(1, 3)) + ], + log_warn=w + ) == '(declare-const a Int)\n' \ + '(declare-const b Real)\n' \ + '(declare-const c Bool)\n' \ + '(assert (= (< a 2) c))\n' \ + '(assert (= (> b a) c))\n' \ + '(assert c)\n' \ + '(assert (= ' \ + '(* (pow (* 7.0 a) (+ (pow b a) (* -1 a))) (pow (+ b (pow a 2)) -1)) ' \ + '(/ 7 3)' \ + '))' + + with _check_warns([_W.DEFAULTING_TO_FLOAT, _W.WILL_NOT_ASSERT]) as w: + assert smtlib_code( + Mul(-2, c, Pow(Mul(b, b, evaluate=False), -1, evaluate=False), evaluate=False), + log_warn=w + ) == '(declare-const b Real)\n' \ + '(declare-const c Real)\n' \ + '(* -2 c (pow (* b b) -1))' + + +def test_basic_ops(): + with _check_warns([_W.DEFAULTING_TO_FLOAT, _W.DEFAULTING_TO_FLOAT, _W.WILL_NOT_ASSERT]) as w: + assert smtlib_code(x * y, auto_declare=False, log_warn=w) == "(* x y)" + + with _check_warns([_W.DEFAULTING_TO_FLOAT, _W.DEFAULTING_TO_FLOAT, _W.WILL_NOT_ASSERT]) as w: + assert smtlib_code(x + y, auto_declare=False, log_warn=w) == "(+ x y)" + + # with _check_warns([_SmtlibWarnings.DEFAULTING_TO_FLOAT, _SmtlibWarnings.DEFAULTING_TO_FLOAT, _SmtlibWarnings.WILL_NOT_ASSERT]) as w: + # todo: implement re-write, currently does '(+ x (* -1 y))' instead + # assert smtlib_code(x - y, auto_declare=False, log_warn=w) == "(- x y)" + + with _check_warns([_W.DEFAULTING_TO_FLOAT, _W.WILL_NOT_ASSERT]) as w: + assert smtlib_code(-x, auto_declare=False, log_warn=w) == "(* -1 x)" + + +def test_quantifier_extensions(): + from sympy.logic.boolalg import Boolean + from sympy import Interval, Tuple, sympify + + # start For-all quantifier class example + class ForAll(Boolean): + def _smtlib(self, printer): + bound_symbol_declarations = [ + printer._s_expr(sym.name, [ + printer._known_types[printer.symbol_table[sym]], + Interval(start, end) + ]) for sym, start, end in self.limits + ] + return printer._s_expr('forall', [ + printer._s_expr('', bound_symbol_declarations), + self.function + ]) + + @property + def bound_symbols(self): + return {s for s, _, _ in self.limits} + + @property + def free_symbols(self): + bound_symbol_names = {s.name for s in self.bound_symbols} + return { + s for s in self.function.free_symbols + if s.name not in bound_symbol_names + } + + def __new__(cls, *args): + limits = [sympify(a) for a in args if isinstance(a, (tuple, Tuple))] + function = [sympify(a) for a in args if isinstance(a, Boolean)] + assert len(limits) + len(function) == len(args) + assert len(function) == 1 + function = function[0] + + if isinstance(function, ForAll): return ForAll.__new__( + ForAll, *(limits + function.limits), function.function + ) + inst = Boolean.__new__(cls) + inst._args = tuple(limits + [function]) + inst.limits = limits + inst.function = function + return inst + + # end For-All Quantifier class example + + f = Function('f') + with _check_warns([_W.DEFAULTING_TO_FLOAT]) as w: + assert smtlib_code( + ForAll((x, -42, +21), Eq(f(x), f(x))), + symbol_table={f: Callable[[float], float]}, + log_warn=w + ) == '(assert (forall ( (x Real [-42, 21])) true))' + + with _check_warns([_W.DEFAULTING_TO_FLOAT] * 2) as w: + assert smtlib_code( + ForAll( + (x, -42, +21), (y, -100, 3), + Implies(Eq(x, y), Eq(f(x), f(y))) + ), + symbol_table={f: Callable[[float], float]}, + log_warn=w + ) == '(declare-fun f (Real) Real)\n' \ + '(assert (' \ + 'forall ( (x Real [-42, 21]) (y Real [-100, 3])) ' \ + '(=> (= x y) (= (f x) (f y)))' \ + '))' + + a = Symbol('a', integer=True) + b = Symbol('b', real=True) + c = Symbol('c') + + with _check_warns([]) as w: + assert smtlib_code( + ForAll( + (a, 2, 100), ForAll( + (b, 2, 100), + Implies(a < b, sqrt(a) < b) | c + )), + log_warn=w + ) == '(declare-const c Bool)\n' \ + '(assert (forall ( (a Int [2, 100]) (b Real [2, 100])) ' \ + '(or c (=> (< a b) (< (pow a (/ 1 2)) b)))' \ + '))' + + +def test_mix_number_mult_symbols(): + with _check_warns([_W.WILL_NOT_ASSERT]) as w: + assert smtlib_code( + 1 / pi, + known_constants={pi: "MY_PI"}, + log_warn=w + ) == '(pow MY_PI -1)' + + with _check_warns([_W.WILL_NOT_ASSERT]) as w: + assert smtlib_code( + [ + Eq(pi, 3.14, evaluate=False), + 1 / pi, + ], + known_constants={pi: "MY_PI"}, + log_warn=w + ) == '(assert (= MY_PI 3.14))\n' \ + '(pow MY_PI -1)' + + with _check_warns([_W.WILL_NOT_ASSERT]) as w: + assert smtlib_code( + Add(S.Zero, S.One, S.NegativeOne, S.Half, + S.Exp1, S.Pi, S.GoldenRatio, evaluate=False), + known_constants={ + S.Pi: 'p', S.GoldenRatio: 'g', + S.Exp1: 'e' + }, + known_functions={ + Add: 'plus', + exp: 'exp' + }, + precision=3, + log_warn=w + ) == '(plus 0 1 -1 (/ 1 2) (exp 1) p g)' + + with _check_warns([_W.WILL_NOT_ASSERT]) as w: + assert smtlib_code( + Add(S.Zero, S.One, S.NegativeOne, S.Half, + S.Exp1, S.Pi, S.GoldenRatio, evaluate=False), + known_constants={ + S.Pi: 'p' + }, + known_functions={ + Add: 'plus', + exp: 'exp' + }, + precision=3, + log_warn=w + ) == '(plus 0 1 -1 (/ 1 2) (exp 1) p 1.62)' + + with _check_warns([_W.WILL_NOT_ASSERT]) as w: + assert smtlib_code( + Add(S.Zero, S.One, S.NegativeOne, S.Half, + S.Exp1, S.Pi, S.GoldenRatio, evaluate=False), + known_functions={Add: 'plus'}, + precision=3, + log_warn=w + ) == '(plus 0 1 -1 (/ 1 2) 2.72 3.14 1.62)' + + with _check_warns([_W.WILL_NOT_ASSERT]) as w: + assert smtlib_code( + Add(S.Zero, S.One, S.NegativeOne, S.Half, + S.Exp1, S.Pi, S.GoldenRatio, evaluate=False), + known_constants={S.Exp1: 'e'}, + known_functions={Add: 'plus'}, + precision=3, + log_warn=w + ) == '(plus 0 1 -1 (/ 1 2) e 3.14 1.62)' + + +def test_boolean(): + with _check_warns([]) as w: + assert smtlib_code(x & y, log_warn=w) == '(declare-const x Bool)\n' \ + '(declare-const y Bool)\n' \ + '(assert (and x y))' + assert smtlib_code(x | y, log_warn=w) == '(declare-const x Bool)\n' \ + '(declare-const y Bool)\n' \ + '(assert (or x y))' + assert smtlib_code(~x, log_warn=w) == '(declare-const x Bool)\n' \ + '(assert (not x))' + assert smtlib_code(x & y & z, log_warn=w) == '(declare-const x Bool)\n' \ + '(declare-const y Bool)\n' \ + '(declare-const z Bool)\n' \ + '(assert (and x y z))' + + with _check_warns([_W.DEFAULTING_TO_FLOAT]) as w: + assert smtlib_code((x & ~y) | (z > 3), log_warn=w) == '(declare-const x Bool)\n' \ + '(declare-const y Bool)\n' \ + '(declare-const z Real)\n' \ + '(assert (or (> z 3) (and x (not y))))' + + f = Function('f') + g = Function('g') + h = Function('h') + with _check_warns([_W.DEFAULTING_TO_FLOAT]) as w: + assert smtlib_code( + [Gt(f(x), y), + Lt(y, g(z))], + symbol_table={ + f: Callable[[bool], int], g: Callable[[bool], int], + }, log_warn=w + ) == '(declare-const x Bool)\n' \ + '(declare-const y Real)\n' \ + '(declare-const z Bool)\n' \ + '(declare-fun f (Bool) Int)\n' \ + '(declare-fun g (Bool) Int)\n' \ + '(assert (> (f x) y))\n' \ + '(assert (< y (g z)))' + + with _check_warns([]) as w: + assert smtlib_code( + [Eq(f(x), y), + Lt(y, g(z))], + symbol_table={ + f: Callable[[bool], int], g: Callable[[bool], int], + }, log_warn=w + ) == '(declare-const x Bool)\n' \ + '(declare-const y Int)\n' \ + '(declare-const z Bool)\n' \ + '(declare-fun f (Bool) Int)\n' \ + '(declare-fun g (Bool) Int)\n' \ + '(assert (= (f x) y))\n' \ + '(assert (< y (g z)))' + + with _check_warns([]) as w: + assert smtlib_code( + [Eq(f(x), y), + Eq(g(f(x)), z), + Eq(h(g(f(x))), x)], + symbol_table={ + f: Callable[[float], int], + g: Callable[[int], bool], + h: Callable[[bool], float] + }, + log_warn=w + ) == '(declare-const x Real)\n' \ + '(declare-const y Int)\n' \ + '(declare-const z Bool)\n' \ + '(declare-fun f (Real) Int)\n' \ + '(declare-fun g (Int) Bool)\n' \ + '(declare-fun h (Bool) Real)\n' \ + '(assert (= (f x) y))\n' \ + '(assert (= (g (f x)) z))\n' \ + '(assert (= (h (g (f x))) x))' + + +# todo: make smtlib_code support arrays +# def test_containers(): +# assert julia_code([1, 2, 3, [4, 5, [6, 7]], 8, [9, 10], 11]) == \ +# "Any[1, 2, 3, Any[4, 5, Any[6, 7]], 8, Any[9, 10], 11]" +# assert julia_code((1, 2, (3, 4))) == "(1, 2, (3, 4))" +# assert julia_code([1]) == "Any[1]" +# assert julia_code((1,)) == "(1,)" +# assert julia_code(Tuple(*[1, 2, 3])) == "(1, 2, 3)" +# assert julia_code((1, x * y, (3, x ** 2))) == "(1, x .* y, (3, x .^ 2))" +# # scalar, matrix, empty matrix and empty list +# assert julia_code((1, eye(3), Matrix(0, 0, []), [])) == "(1, [1 0 0;\n0 1 0;\n0 0 1], zeros(0, 0), Any[])" + +def test_smtlib_piecewise(): + with _check_warns([_W.DEFAULTING_TO_FLOAT, _W.WILL_NOT_ASSERT]) as w: + assert smtlib_code( + Piecewise((x, x < 1), + (x ** 2, True)), + auto_declare=False, + log_warn=w + ) == '(ite (< x 1) x (pow x 2))' + + with _check_warns([_W.DEFAULTING_TO_FLOAT, _W.WILL_NOT_ASSERT]) as w: + assert smtlib_code( + Piecewise((x ** 2, x < 1), + (x ** 3, x < 2), + (x ** 4, x < 3), + (x ** 5, True)), + auto_declare=False, + log_warn=w + ) == '(ite (< x 1) (pow x 2) ' \ + '(ite (< x 2) (pow x 3) ' \ + '(ite (< x 3) (pow x 4) ' \ + '(pow x 5))))' + + # Check that Piecewise without a True (default) condition error + expr = Piecewise((x, x < 1), (x ** 2, x > 1), (sin(x), x > 0)) + with _check_warns([_W.DEFAULTING_TO_FLOAT, _W.WILL_NOT_ASSERT]) as w: + raises(AssertionError, lambda: smtlib_code(expr, log_warn=w)) + + +def test_smtlib_piecewise_times_const(): + pw = Piecewise((x, x < 1), (x ** 2, True)) + with _check_warns([_W.DEFAULTING_TO_FLOAT, _W.WILL_NOT_ASSERT]) as w: + assert smtlib_code(2 * pw, log_warn=w) == '(declare-const x Real)\n(* 2 (ite (< x 1) x (pow x 2)))' + with _check_warns([_W.DEFAULTING_TO_FLOAT, _W.WILL_NOT_ASSERT]) as w: + assert smtlib_code(pw / x, log_warn=w) == '(declare-const x Real)\n(* (pow x -1) (ite (< x 1) x (pow x 2)))' + with _check_warns([_W.DEFAULTING_TO_FLOAT, _W.DEFAULTING_TO_FLOAT, _W.WILL_NOT_ASSERT]) as w: + assert smtlib_code(pw / (x * y), log_warn=w) == '(declare-const x Real)\n(declare-const y Real)\n(* (pow x -1) (pow y -1) (ite (< x 1) x (pow x 2)))' + with _check_warns([_W.DEFAULTING_TO_FLOAT, _W.WILL_NOT_ASSERT]) as w: + assert smtlib_code(pw / 3, log_warn=w) == '(declare-const x Real)\n(* (/ 1 3) (ite (< x 1) x (pow x 2)))' + + +# todo: make smtlib_code support arrays / matrices ? +# def test_smtlib_matrix_assign_to(): +# A = Matrix([[1, 2, 3]]) +# assert smtlib_code(A, assign_to='a') == "a = [1 2 3]" +# A = Matrix([[1, 2], [3, 4]]) +# assert smtlib_code(A, assign_to='A') == "A = [1 2;\n3 4]" + +# def test_julia_matrix_1x1(): +# A = Matrix([[3]]) +# B = MatrixSymbol('B', 1, 1) +# C = MatrixSymbol('C', 1, 2) +# assert julia_code(A, assign_to=B) == "B = [3]" +# raises(ValueError, lambda: julia_code(A, assign_to=C)) + +# def test_julia_matrix_elements(): +# A = Matrix([[x, 2, x * y]]) +# assert julia_code(A[0, 0] ** 2 + A[0, 1] + A[0, 2]) == "x .^ 2 + x .* y + 2" +# A = MatrixSymbol('AA', 1, 3) +# assert julia_code(A) == "AA" +# assert julia_code(A[0, 0] ** 2 + sin(A[0, 1]) + A[0, 2]) == \ +# "sin(AA[1,2]) + AA[1,1] .^ 2 + AA[1,3]" +# assert julia_code(sum(A)) == "AA[1,1] + AA[1,2] + AA[1,3]" + +def test_smtlib_boolean(): + with _check_warns([]) as w: + assert smtlib_code(True, auto_assert=False, log_warn=w) == 'true' + assert smtlib_code(True, log_warn=w) == '(assert true)' + assert smtlib_code(S.true, log_warn=w) == '(assert true)' + assert smtlib_code(S.false, log_warn=w) == '(assert false)' + assert smtlib_code(False, log_warn=w) == '(assert false)' + assert smtlib_code(False, auto_assert=False, log_warn=w) == 'false' + + +def test_not_supported(): + f = Function('f') + with _check_warns([_W.DEFAULTING_TO_FLOAT, _W.WILL_NOT_ASSERT]) as w: + raises(KeyError, lambda: smtlib_code(f(x).diff(x), symbol_table={f: Callable[[float], float]}, log_warn=w)) + with _check_warns([_W.WILL_NOT_ASSERT]) as w: + raises(KeyError, lambda: smtlib_code(S.ComplexInfinity, log_warn=w)) + + +def test_Float(): + assert smtlib_code(0.0) == "0.0" + assert smtlib_code(0.000000000000000003) == '(* 3.0 (pow 10 -18))' + assert smtlib_code(5.3) == "5.3" diff --git a/lib/python3.10/site-packages/sympy/printing/tests/test_str.py b/lib/python3.10/site-packages/sympy/printing/tests/test_str.py new file mode 100644 index 0000000000000000000000000000000000000000..96ecb66a0794fc8ed48f6ccfcfe234e9f851e1ad --- /dev/null +++ b/lib/python3.10/site-packages/sympy/printing/tests/test_str.py @@ -0,0 +1,1199 @@ +from sympy import MatAdd +from sympy.algebras.quaternion import Quaternion +from sympy.assumptions.ask import Q +from sympy.calculus.accumulationbounds import AccumBounds +from sympy.combinatorics.partitions import Partition +from sympy.concrete.summations import (Sum, summation) +from sympy.core.add import Add +from sympy.core.containers import (Dict, Tuple) +from sympy.core.expr import UnevaluatedExpr, Expr +from sympy.core.function import (Derivative, Function, Lambda, Subs, WildFunction) +from sympy.core.mul import Mul +from sympy.core import (Catalan, EulerGamma, GoldenRatio, TribonacciConstant) +from sympy.core.numbers import (E, Float, I, Integer, Rational, nan, oo, pi, zoo) +from sympy.core.parameters import _exp_is_pow +from sympy.core.power import Pow +from sympy.core.relational import (Eq, Rel, Ne) +from sympy.core.singleton import S +from sympy.core.symbol import (Dummy, Symbol, Wild, symbols) +from sympy.functions.combinatorial.factorials import (factorial, factorial2, subfactorial) +from sympy.functions.elementary.complexes import Abs +from sympy.functions.elementary.exponential import exp +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import (cos, sin) +from sympy.functions.special.delta_functions import Heaviside +from sympy.functions.special.zeta_functions import zeta +from sympy.integrals.integrals import Integral +from sympy.logic.boolalg import (Equivalent, false, true, Xor) +from sympy.matrices.dense import Matrix +from sympy.matrices.expressions.matexpr import MatrixSymbol +from sympy.matrices.expressions import Identity +from sympy.matrices.expressions.slice import MatrixSlice +from sympy.matrices import SparseMatrix +from sympy.polys.polytools import factor +from sympy.series.limits import Limit +from sympy.series.order import O +from sympy.sets.sets import (Complement, FiniteSet, Interval, SymmetricDifference) +from sympy.stats import (Covariance, Expectation, Probability, Variance) +from sympy.stats.rv import RandomSymbol +from sympy.external import import_module +from sympy.physics.control.lti import TransferFunction, Series, Parallel, \ + Feedback, TransferFunctionMatrix, MIMOSeries, MIMOParallel, MIMOFeedback +from sympy.physics.units import second, joule +from sympy.polys import (Poly, rootof, RootSum, groebner, ring, field, ZZ, QQ, + ZZ_I, QQ_I, lex, grlex) +from sympy.geometry import Point, Circle, Polygon, Ellipse, Triangle +from sympy.tensor import NDimArray +from sympy.tensor.array.expressions.array_expressions import ArraySymbol, ArrayElement + +from sympy.testing.pytest import raises, warns_deprecated_sympy + +from sympy.printing import sstr, sstrrepr, StrPrinter +from sympy.physics.quantum.trace import Tr + +x, y, z, w, t = symbols('x,y,z,w,t') +d = Dummy('d') + + +def test_printmethod(): + class R(Abs): + def _sympystr(self, printer): + return "foo(%s)" % printer._print(self.args[0]) + assert sstr(R(x)) == "foo(x)" + + class R(Abs): + def _sympystr(self, printer): + return "foo" + assert sstr(R(x)) == "foo" + + +def test_Abs(): + assert str(Abs(x)) == "Abs(x)" + assert str(Abs(Rational(1, 6))) == "1/6" + assert str(Abs(Rational(-1, 6))) == "1/6" + + +def test_Add(): + assert str(x + y) == "x + y" + assert str(x + 1) == "x + 1" + assert str(x + x**2) == "x**2 + x" + assert str(Add(0, 1, evaluate=False)) == "0 + 1" + assert str(Add(0, 0, 1, evaluate=False)) == "0 + 0 + 1" + assert str(1.0*x) == "1.0*x" + assert str(5 + x + y + x*y + x**2 + y**2) == "x**2 + x*y + x + y**2 + y + 5" + assert str(1 + x + x**2/2 + x**3/3) == "x**3/3 + x**2/2 + x + 1" + assert str(2*x - 7*x**2 + 2 + 3*y) == "-7*x**2 + 2*x + 3*y + 2" + assert str(x - y) == "x - y" + assert str(2 - x) == "2 - x" + assert str(x - 2) == "x - 2" + assert str(x - y - z - w) == "-w + x - y - z" + assert str(x - z*y**2*z*w) == "-w*y**2*z**2 + x" + assert str(x - 1*y*x*y) == "-x*y**2 + x" + assert str(sin(x).series(x, 0, 15)) == "x - x**3/6 + x**5/120 - x**7/5040 + x**9/362880 - x**11/39916800 + x**13/6227020800 + O(x**15)" + assert str(Add(Add(-w, x, evaluate=False), Add(-y, z, evaluate=False), evaluate=False)) == "(-w + x) + (-y + z)" + assert str(Add(Add(-x, -y, evaluate=False), -z, evaluate=False)) == "-z + (-x - y)" + assert str(Add(Add(Add(-x, -y, evaluate=False), -z, evaluate=False), -t, evaluate=False)) == "-t + (-z + (-x - y))" + + +def test_Catalan(): + assert str(Catalan) == "Catalan" + + +def test_ComplexInfinity(): + assert str(zoo) == "zoo" + + +def test_Derivative(): + assert str(Derivative(x, y)) == "Derivative(x, y)" + assert str(Derivative(x**2, x, evaluate=False)) == "Derivative(x**2, x)" + assert str(Derivative( + x**2/y, x, y, evaluate=False)) == "Derivative(x**2/y, x, y)" + + +def test_dict(): + assert str({1: 1 + x}) == sstr({1: 1 + x}) == "{1: x + 1}" + assert str({1: x**2, 2: y*x}) in ("{1: x**2, 2: x*y}", "{2: x*y, 1: x**2}") + assert sstr({1: x**2, 2: y*x}) == "{1: x**2, 2: x*y}" + + +def test_Dict(): + assert str(Dict({1: 1 + x})) == sstr({1: 1 + x}) == "{1: x + 1}" + assert str(Dict({1: x**2, 2: y*x})) in ( + "{1: x**2, 2: x*y}", "{2: x*y, 1: x**2}") + assert sstr(Dict({1: x**2, 2: y*x})) == "{1: x**2, 2: x*y}" + + +def test_Dummy(): + assert str(d) == "_d" + assert str(d + x) == "_d + x" + + +def test_EulerGamma(): + assert str(EulerGamma) == "EulerGamma" + + +def test_Exp(): + assert str(E) == "E" + with _exp_is_pow(True): + assert str(exp(x)) == "E**x" + + +def test_factorial(): + n = Symbol('n', integer=True) + assert str(factorial(-2)) == "zoo" + assert str(factorial(0)) == "1" + assert str(factorial(7)) == "5040" + assert str(factorial(n)) == "factorial(n)" + assert str(factorial(2*n)) == "factorial(2*n)" + assert str(factorial(factorial(n))) == 'factorial(factorial(n))' + assert str(factorial(factorial2(n))) == 'factorial(factorial2(n))' + assert str(factorial2(factorial(n))) == 'factorial2(factorial(n))' + assert str(factorial2(factorial2(n))) == 'factorial2(factorial2(n))' + assert str(subfactorial(3)) == "2" + assert str(subfactorial(n)) == "subfactorial(n)" + assert str(subfactorial(2*n)) == "subfactorial(2*n)" + + +def test_Function(): + f = Function('f') + fx = f(x) + w = WildFunction('w') + assert str(f) == "f" + assert str(fx) == "f(x)" + assert str(w) == "w_" + + +def test_Geometry(): + assert sstr(Point(0, 0)) == 'Point2D(0, 0)' + assert sstr(Circle(Point(0, 0), 3)) == 'Circle(Point2D(0, 0), 3)' + assert sstr(Ellipse(Point(1, 2), 3, 4)) == 'Ellipse(Point2D(1, 2), 3, 4)' + assert sstr(Triangle(Point(1, 1), Point(7, 8), Point(0, -1))) == \ + 'Triangle(Point2D(1, 1), Point2D(7, 8), Point2D(0, -1))' + assert sstr(Polygon(Point(5, 6), Point(-2, -3), Point(0, 0), Point(4, 7))) == \ + 'Polygon(Point2D(5, 6), Point2D(-2, -3), Point2D(0, 0), Point2D(4, 7))' + assert sstr(Triangle(Point(0, 0), Point(1, 0), Point(0, 1)), sympy_integers=True) == \ + 'Triangle(Point2D(S(0), S(0)), Point2D(S(1), S(0)), Point2D(S(0), S(1)))' + assert sstr(Ellipse(Point(1, 2), 3, 4), sympy_integers=True) == \ + 'Ellipse(Point2D(S(1), S(2)), S(3), S(4))' + + +def test_GoldenRatio(): + assert str(GoldenRatio) == "GoldenRatio" + + +def test_Heaviside(): + assert str(Heaviside(x)) == str(Heaviside(x, S.Half)) == "Heaviside(x)" + assert str(Heaviside(x, 1)) == "Heaviside(x, 1)" + + +def test_TribonacciConstant(): + assert str(TribonacciConstant) == "TribonacciConstant" + + +def test_ImaginaryUnit(): + assert str(I) == "I" + + +def test_Infinity(): + assert str(oo) == "oo" + assert str(oo*I) == "oo*I" + + +def test_Integer(): + assert str(Integer(-1)) == "-1" + assert str(Integer(1)) == "1" + assert str(Integer(-3)) == "-3" + assert str(Integer(0)) == "0" + assert str(Integer(25)) == "25" + + +def test_Integral(): + assert str(Integral(sin(x), y)) == "Integral(sin(x), y)" + assert str(Integral(sin(x), (y, 0, 1))) == "Integral(sin(x), (y, 0, 1))" + + +def test_Interval(): + n = (S.NegativeInfinity, 1, 2, S.Infinity) + for i in range(len(n)): + for j in range(i + 1, len(n)): + for l in (True, False): + for r in (True, False): + ival = Interval(n[i], n[j], l, r) + assert S(str(ival)) == ival + + +def test_AccumBounds(): + a = Symbol('a', real=True) + assert str(AccumBounds(0, a)) == "AccumBounds(0, a)" + assert str(AccumBounds(0, 1)) == "AccumBounds(0, 1)" + + +def test_Lambda(): + assert str(Lambda(d, d**2)) == "Lambda(_d, _d**2)" + # issue 2908 + assert str(Lambda((), 1)) == "Lambda((), 1)" + assert str(Lambda((), x)) == "Lambda((), x)" + assert str(Lambda((x, y), x+y)) == "Lambda((x, y), x + y)" + assert str(Lambda(((x, y),), x+y)) == "Lambda(((x, y),), x + y)" + + +def test_Limit(): + assert str(Limit(sin(x)/x, x, y)) == "Limit(sin(x)/x, x, y, dir='+')" + assert str(Limit(1/x, x, 0)) == "Limit(1/x, x, 0, dir='+')" + assert str( + Limit(sin(x)/x, x, y, dir="-")) == "Limit(sin(x)/x, x, y, dir='-')" + + +def test_list(): + assert str([x]) == sstr([x]) == "[x]" + assert str([x**2, x*y + 1]) == sstr([x**2, x*y + 1]) == "[x**2, x*y + 1]" + assert str([x**2, [y + x]]) == sstr([x**2, [y + x]]) == "[x**2, [x + y]]" + + +def test_Matrix_str(): + M = Matrix([[x**+1, 1], [y, x + y]]) + assert str(M) == "Matrix([[x, 1], [y, x + y]])" + assert sstr(M) == "Matrix([\n[x, 1],\n[y, x + y]])" + M = Matrix([[1]]) + assert str(M) == sstr(M) == "Matrix([[1]])" + M = Matrix([[1, 2]]) + assert str(M) == sstr(M) == "Matrix([[1, 2]])" + M = Matrix() + assert str(M) == sstr(M) == "Matrix(0, 0, [])" + M = Matrix(0, 1, lambda i, j: 0) + assert str(M) == sstr(M) == "Matrix(0, 1, [])" + + +def test_Mul(): + assert str(x/y) == "x/y" + assert str(y/x) == "y/x" + assert str(x/y/z) == "x/(y*z)" + assert str((x + 1)/(y + 2)) == "(x + 1)/(y + 2)" + assert str(2*x/3) == '2*x/3' + assert str(-2*x/3) == '-2*x/3' + assert str(-1.0*x) == '-1.0*x' + assert str(1.0*x) == '1.0*x' + assert str(Mul(0, 1, evaluate=False)) == '0*1' + assert str(Mul(1, 0, evaluate=False)) == '1*0' + assert str(Mul(1, 1, evaluate=False)) == '1*1' + assert str(Mul(1, 1, 1, evaluate=False)) == '1*1*1' + assert str(Mul(1, 2, evaluate=False)) == '1*2' + assert str(Mul(1, S.Half, evaluate=False)) == '1*(1/2)' + assert str(Mul(1, 1, S.Half, evaluate=False)) == '1*1*(1/2)' + assert str(Mul(1, 1, 2, 3, x, evaluate=False)) == '1*1*2*3*x' + assert str(Mul(1, -1, evaluate=False)) == '1*(-1)' + assert str(Mul(-1, 1, evaluate=False)) == '-1*1' + assert str(Mul(4, 3, 2, 1, 0, y, x, evaluate=False)) == '4*3*2*1*0*y*x' + assert str(Mul(4, 3, 2, 1+z, 0, y, x, evaluate=False)) == '4*3*2*(z + 1)*0*y*x' + assert str(Mul(Rational(2, 3), Rational(5, 7), evaluate=False)) == '(2/3)*(5/7)' + # For issue 14160 + assert str(Mul(-2, x, Pow(Mul(y,y,evaluate=False), -1, evaluate=False), + evaluate=False)) == '-2*x/(y*y)' + # issue 21537 + assert str(Mul(x, Pow(1/y, -1, evaluate=False), evaluate=False)) == 'x/(1/y)' + + # Issue 24108 + from sympy.core.parameters import evaluate + with evaluate(False): + assert str(Mul(Pow(Integer(2), Integer(-1)), Add(Integer(-1), Mul(Integer(-1), Integer(1))))) == "(-1 - 1*1)/2" + + class CustomClass1(Expr): + is_commutative = True + + class CustomClass2(Expr): + is_commutative = True + cc1 = CustomClass1() + cc2 = CustomClass2() + assert str(Rational(2)*cc1) == '2*CustomClass1()' + assert str(cc1*Rational(2)) == '2*CustomClass1()' + assert str(cc1*Float("1.5")) == '1.5*CustomClass1()' + assert str(cc2*Rational(2)) == '2*CustomClass2()' + assert str(cc2*Rational(2)*cc1) == '2*CustomClass1()*CustomClass2()' + assert str(cc1*Rational(2)*cc2) == '2*CustomClass1()*CustomClass2()' + + +def test_NaN(): + assert str(nan) == "nan" + + +def test_NegativeInfinity(): + assert str(-oo) == "-oo" + +def test_Order(): + assert str(O(x)) == "O(x)" + assert str(O(x**2)) == "O(x**2)" + assert str(O(x*y)) == "O(x*y, x, y)" + assert str(O(x, x)) == "O(x)" + assert str(O(x, (x, 0))) == "O(x)" + assert str(O(x, (x, oo))) == "O(x, (x, oo))" + assert str(O(x, x, y)) == "O(x, x, y)" + assert str(O(x, x, y)) == "O(x, x, y)" + assert str(O(x, (x, oo), (y, oo))) == "O(x, (x, oo), (y, oo))" + + +def test_Permutation_Cycle(): + from sympy.combinatorics import Permutation, Cycle + + # general principle: economically, canonically show all moved elements + # and the size of the permutation. + + for p, s in [ + (Cycle(), + '()'), + (Cycle(2), + '(2)'), + (Cycle(2, 1), + '(1 2)'), + (Cycle(1, 2)(5)(6, 7)(10), + '(1 2)(6 7)(10)'), + (Cycle(3, 4)(1, 2)(3, 4), + '(1 2)(4)'), + ]: + assert sstr(p) == s + + for p, s in [ + (Permutation([]), + 'Permutation([])'), + (Permutation([], size=1), + 'Permutation([0])'), + (Permutation([], size=2), + 'Permutation([0, 1])'), + (Permutation([], size=10), + 'Permutation([], size=10)'), + (Permutation([1, 0, 2]), + 'Permutation([1, 0, 2])'), + (Permutation([1, 0, 2, 3, 4, 5]), + 'Permutation([1, 0], size=6)'), + (Permutation([1, 0, 2, 3, 4, 5], size=10), + 'Permutation([1, 0], size=10)'), + ]: + assert sstr(p, perm_cyclic=False) == s + + for p, s in [ + (Permutation([]), + '()'), + (Permutation([], size=1), + '(0)'), + (Permutation([], size=2), + '(1)'), + (Permutation([], size=10), + '(9)'), + (Permutation([1, 0, 2]), + '(2)(0 1)'), + (Permutation([1, 0, 2, 3, 4, 5]), + '(5)(0 1)'), + (Permutation([1, 0, 2, 3, 4, 5], size=10), + '(9)(0 1)'), + (Permutation([0, 1, 3, 2, 4, 5], size=10), + '(9)(2 3)'), + ]: + assert sstr(p) == s + + + with warns_deprecated_sympy(): + old_print_cyclic = Permutation.print_cyclic + Permutation.print_cyclic = False + assert sstr(Permutation([1, 0, 2])) == 'Permutation([1, 0, 2])' + Permutation.print_cyclic = old_print_cyclic + +def test_Pi(): + assert str(pi) == "pi" + + +def test_Poly(): + assert str(Poly(0, x)) == "Poly(0, x, domain='ZZ')" + assert str(Poly(1, x)) == "Poly(1, x, domain='ZZ')" + assert str(Poly(x, x)) == "Poly(x, x, domain='ZZ')" + + assert str(Poly(2*x + 1, x)) == "Poly(2*x + 1, x, domain='ZZ')" + assert str(Poly(2*x - 1, x)) == "Poly(2*x - 1, x, domain='ZZ')" + + assert str(Poly(-1, x)) == "Poly(-1, x, domain='ZZ')" + assert str(Poly(-x, x)) == "Poly(-x, x, domain='ZZ')" + + assert str(Poly(-2*x + 1, x)) == "Poly(-2*x + 1, x, domain='ZZ')" + assert str(Poly(-2*x - 1, x)) == "Poly(-2*x - 1, x, domain='ZZ')" + + assert str(Poly(x - 1, x)) == "Poly(x - 1, x, domain='ZZ')" + assert str(Poly(2*x + x**5, x)) == "Poly(x**5 + 2*x, x, domain='ZZ')" + + assert str(Poly(3**(2*x), 3**x)) == "Poly((3**x)**2, 3**x, domain='ZZ')" + assert str(Poly((x**2)**x)) == "Poly(((x**2)**x), (x**2)**x, domain='ZZ')" + + assert str(Poly((x + y)**3, (x + y), expand=False) + ) == "Poly((x + y)**3, x + y, domain='ZZ')" + assert str(Poly((x - 1)**2, (x - 1), expand=False) + ) == "Poly((x - 1)**2, x - 1, domain='ZZ')" + + assert str( + Poly(x**2 + 1 + y, x)) == "Poly(x**2 + y + 1, x, domain='ZZ[y]')" + assert str( + Poly(x**2 - 1 + y, x)) == "Poly(x**2 + y - 1, x, domain='ZZ[y]')" + + assert str(Poly(x**2 + I*x, x)) == "Poly(x**2 + I*x, x, domain='ZZ_I')" + assert str(Poly(x**2 - I*x, x)) == "Poly(x**2 - I*x, x, domain='ZZ_I')" + + assert str(Poly(-x*y*z + x*y - 1, x, y, z) + ) == "Poly(-x*y*z + x*y - 1, x, y, z, domain='ZZ')" + assert str(Poly(-w*x**21*y**7*z + (1 + w)*z**3 - 2*x*z + 1, x, y, z)) == \ + "Poly(-w*x**21*y**7*z - 2*x*z + (w + 1)*z**3 + 1, x, y, z, domain='ZZ[w]')" + + assert str(Poly(x**2 + 1, x, modulus=2)) == "Poly(x**2 + 1, x, modulus=2)" + assert str(Poly(2*x**2 + 3*x + 4, x, modulus=17)) == "Poly(2*x**2 + 3*x + 4, x, modulus=17)" + + +def test_PolyRing(): + assert str(ring("x", ZZ, lex)[0]) == "Polynomial ring in x over ZZ with lex order" + assert str(ring("x,y", QQ, grlex)[0]) == "Polynomial ring in x, y over QQ with grlex order" + assert str(ring("x,y,z", ZZ["t"], lex)[0]) == "Polynomial ring in x, y, z over ZZ[t] with lex order" + + +def test_FracField(): + assert str(field("x", ZZ, lex)[0]) == "Rational function field in x over ZZ with lex order" + assert str(field("x,y", QQ, grlex)[0]) == "Rational function field in x, y over QQ with grlex order" + assert str(field("x,y,z", ZZ["t"], lex)[0]) == "Rational function field in x, y, z over ZZ[t] with lex order" + + +def test_PolyElement(): + Ruv, u,v = ring("u,v", ZZ) + Rxyz, x,y,z = ring("x,y,z", Ruv) + Rx_zzi, xz = ring("x", ZZ_I) + + assert str(x - x) == "0" + assert str(x - 1) == "x - 1" + assert str(x + 1) == "x + 1" + assert str(x**2) == "x**2" + assert str(x**(-2)) == "x**(-2)" + assert str(x**QQ(1, 2)) == "x**(1/2)" + + assert str((u**2 + 3*u*v + 1)*x**2*y + u + 1) == "(u**2 + 3*u*v + 1)*x**2*y + u + 1" + assert str((u**2 + 3*u*v + 1)*x**2*y + (u + 1)*x) == "(u**2 + 3*u*v + 1)*x**2*y + (u + 1)*x" + assert str((u**2 + 3*u*v + 1)*x**2*y + (u + 1)*x + 1) == "(u**2 + 3*u*v + 1)*x**2*y + (u + 1)*x + 1" + assert str((-u**2 + 3*u*v - 1)*x**2*y - (u + 1)*x - 1) == "-(u**2 - 3*u*v + 1)*x**2*y - (u + 1)*x - 1" + + assert str(-(v**2 + v + 1)*x + 3*u*v + 1) == "-(v**2 + v + 1)*x + 3*u*v + 1" + assert str(-(v**2 + v + 1)*x - 3*u*v + 1) == "-(v**2 + v + 1)*x - 3*u*v + 1" + + assert str((1+I)*xz + 2) == "(1 + 1*I)*x + (2 + 0*I)" + + +def test_FracElement(): + Fuv, u,v = field("u,v", ZZ) + Fxyzt, x,y,z,t = field("x,y,z,t", Fuv) + Rx_zzi, xz = field("x", QQ_I) + i = QQ_I(0, 1) + + assert str(x - x) == "0" + assert str(x - 1) == "x - 1" + assert str(x + 1) == "x + 1" + + assert str(x/3) == "x/3" + assert str(x/z) == "x/z" + assert str(x*y/z) == "x*y/z" + assert str(x/(z*t)) == "x/(z*t)" + assert str(x*y/(z*t)) == "x*y/(z*t)" + + assert str((x - 1)/y) == "(x - 1)/y" + assert str((x + 1)/y) == "(x + 1)/y" + assert str((-x - 1)/y) == "(-x - 1)/y" + assert str((x + 1)/(y*z)) == "(x + 1)/(y*z)" + assert str(-y/(x + 1)) == "-y/(x + 1)" + assert str(y*z/(x + 1)) == "y*z/(x + 1)" + + assert str(((u + 1)*x*y + 1)/((v - 1)*z - 1)) == "((u + 1)*x*y + 1)/((v - 1)*z - 1)" + assert str(((u + 1)*x*y + 1)/((v - 1)*z - t*u*v - 1)) == "((u + 1)*x*y + 1)/((v - 1)*z - u*v*t - 1)" + + assert str((1+i)/xz) == "(1 + 1*I)/x" + assert str(((1+i)*xz - i)/xz) == "((1 + 1*I)*x + (0 + -1*I))/x" + + +def test_GaussianInteger(): + assert str(ZZ_I(1, 0)) == "1" + assert str(ZZ_I(-1, 0)) == "-1" + assert str(ZZ_I(0, 1)) == "I" + assert str(ZZ_I(0, -1)) == "-I" + assert str(ZZ_I(0, 2)) == "2*I" + assert str(ZZ_I(0, -2)) == "-2*I" + assert str(ZZ_I(1, 1)) == "1 + I" + assert str(ZZ_I(-1, -1)) == "-1 - I" + assert str(ZZ_I(-1, -2)) == "-1 - 2*I" + + +def test_GaussianRational(): + assert str(QQ_I(1, 0)) == "1" + assert str(QQ_I(QQ(2, 3), 0)) == "2/3" + assert str(QQ_I(0, QQ(2, 3))) == "2*I/3" + assert str(QQ_I(QQ(1, 2), QQ(-2, 3))) == "1/2 - 2*I/3" + + +def test_Pow(): + assert str(x**-1) == "1/x" + assert str(x**-2) == "x**(-2)" + assert str(x**2) == "x**2" + assert str((x + y)**-1) == "1/(x + y)" + assert str((x + y)**-2) == "(x + y)**(-2)" + assert str((x + y)**2) == "(x + y)**2" + assert str((x + y)**(1 + x)) == "(x + y)**(x + 1)" + assert str(x**Rational(1, 3)) == "x**(1/3)" + assert str(1/x**Rational(1, 3)) == "x**(-1/3)" + assert str(sqrt(sqrt(x))) == "x**(1/4)" + # not the same as x**-1 + assert str(x**-1.0) == 'x**(-1.0)' + # see issue #2860 + assert str(Pow(S(2), -1.0, evaluate=False)) == '2**(-1.0)' + + +def test_sqrt(): + assert str(sqrt(x)) == "sqrt(x)" + assert str(sqrt(x**2)) == "sqrt(x**2)" + assert str(1/sqrt(x)) == "1/sqrt(x)" + assert str(1/sqrt(x**2)) == "1/sqrt(x**2)" + assert str(y/sqrt(x)) == "y/sqrt(x)" + assert str(x**0.5) == "x**0.5" + assert str(1/x**0.5) == "x**(-0.5)" + + +def test_Rational(): + n1 = Rational(1, 4) + n2 = Rational(1, 3) + n3 = Rational(2, 4) + n4 = Rational(2, -4) + n5 = Rational(0) + n7 = Rational(3) + n8 = Rational(-3) + assert str(n1*n2) == "1/12" + assert str(n1*n2) == "1/12" + assert str(n3) == "1/2" + assert str(n1*n3) == "1/8" + assert str(n1 + n3) == "3/4" + assert str(n1 + n2) == "7/12" + assert str(n1 + n4) == "-1/4" + assert str(n4*n4) == "1/4" + assert str(n4 + n2) == "-1/6" + assert str(n4 + n5) == "-1/2" + assert str(n4*n5) == "0" + assert str(n3 + n4) == "0" + assert str(n1**n7) == "1/64" + assert str(n2**n7) == "1/27" + assert str(n2**n8) == "27" + assert str(n7**n8) == "1/27" + assert str(Rational("-25")) == "-25" + assert str(Rational("1.25")) == "5/4" + assert str(Rational("-2.6e-2")) == "-13/500" + assert str(S("25/7")) == "25/7" + assert str(S("-123/569")) == "-123/569" + assert str(S("0.1[23]", rational=1)) == "61/495" + assert str(S("5.1[666]", rational=1)) == "31/6" + assert str(S("-5.1[666]", rational=1)) == "-31/6" + assert str(S("0.[9]", rational=1)) == "1" + assert str(S("-0.[9]", rational=1)) == "-1" + + assert str(sqrt(Rational(1, 4))) == "1/2" + assert str(sqrt(Rational(1, 36))) == "1/6" + + assert str((123**25) ** Rational(1, 25)) == "123" + assert str((123**25 + 1)**Rational(1, 25)) != "123" + assert str((123**25 - 1)**Rational(1, 25)) != "123" + assert str((123**25 - 1)**Rational(1, 25)) != "122" + + assert str(sqrt(Rational(81, 36))**3) == "27/8" + assert str(1/sqrt(Rational(81, 36))**3) == "8/27" + + assert str(sqrt(-4)) == str(2*I) + assert str(2**Rational(1, 10**10)) == "2**(1/10000000000)" + + assert sstr(Rational(2, 3), sympy_integers=True) == "S(2)/3" + x = Symbol("x") + assert sstr(x**Rational(2, 3), sympy_integers=True) == "x**(S(2)/3)" + assert sstr(Eq(x, Rational(2, 3)), sympy_integers=True) == "Eq(x, S(2)/3)" + assert sstr(Limit(x, x, Rational(7, 2)), sympy_integers=True) == \ + "Limit(x, x, S(7)/2, dir='+')" + + +def test_Float(): + # NOTE dps is the whole number of decimal digits + assert str(Float('1.23', dps=1 + 2)) == '1.23' + assert str(Float('1.23456789', dps=1 + 8)) == '1.23456789' + assert str( + Float('1.234567890123456789', dps=1 + 18)) == '1.234567890123456789' + assert str(pi.evalf(1 + 2)) == '3.14' + assert str(pi.evalf(1 + 14)) == '3.14159265358979' + assert str(pi.evalf(1 + 64)) == ('3.141592653589793238462643383279' + '5028841971693993751058209749445923') + assert str(pi.round(-1)) == '0.0' + assert str((pi**400 - (pi**400).round(1)).n(2)) == '-0.e+88' + assert sstr(Float("100"), full_prec=False, min=-2, max=2) == '1.0e+2' + assert sstr(Float("100"), full_prec=False, min=-2, max=3) == '100.0' + assert sstr(Float("0.1"), full_prec=False, min=-2, max=3) == '0.1' + assert sstr(Float("0.099"), min=-2, max=3) == '9.90000000000000e-2' + + +def test_Relational(): + assert str(Rel(x, y, "<")) == "x < y" + assert str(Rel(x + y, y, "==")) == "Eq(x + y, y)" + assert str(Rel(x, y, "!=")) == "Ne(x, y)" + assert str(Eq(x, 1) | Eq(x, 2)) == "Eq(x, 1) | Eq(x, 2)" + assert str(Ne(x, 1) & Ne(x, 2)) == "Ne(x, 1) & Ne(x, 2)" + + +def test_AppliedBinaryRelation(): + assert str(Q.eq(x, y)) == "Q.eq(x, y)" + assert str(Q.ne(x, y)) == "Q.ne(x, y)" + + +def test_CRootOf(): + assert str(rootof(x**5 + 2*x - 1, 0)) == "CRootOf(x**5 + 2*x - 1, 0)" + + +def test_RootSum(): + f = x**5 + 2*x - 1 + + assert str( + RootSum(f, Lambda(z, z), auto=False)) == "RootSum(x**5 + 2*x - 1)" + assert str(RootSum(f, Lambda( + z, z**2), auto=False)) == "RootSum(x**5 + 2*x - 1, Lambda(z, z**2))" + + +def test_GroebnerBasis(): + assert str(groebner( + [], x, y)) == "GroebnerBasis([], x, y, domain='ZZ', order='lex')" + + F = [x**2 - 3*y - x + 1, y**2 - 2*x + y - 1] + + assert str(groebner(F, order='grlex')) == \ + "GroebnerBasis([x**2 - x - 3*y + 1, y**2 - 2*x + y - 1], x, y, domain='ZZ', order='grlex')" + assert str(groebner(F, order='lex')) == \ + "GroebnerBasis([2*x - y**2 - y + 1, y**4 + 2*y**3 - 3*y**2 - 16*y + 7], x, y, domain='ZZ', order='lex')" + +def test_set(): + assert sstr(set()) == 'set()' + assert sstr(frozenset()) == 'frozenset()' + + assert sstr({1}) == '{1}' + assert sstr(frozenset([1])) == 'frozenset({1})' + assert sstr({1, 2, 3}) == '{1, 2, 3}' + assert sstr(frozenset([1, 2, 3])) == 'frozenset({1, 2, 3})' + + assert sstr( + {1, x, x**2, x**3, x**4}) == '{1, x, x**2, x**3, x**4}' + assert sstr( + frozenset([1, x, x**2, x**3, x**4])) == 'frozenset({1, x, x**2, x**3, x**4})' + + +def test_SparseMatrix(): + M = SparseMatrix([[x**+1, 1], [y, x + y]]) + assert str(M) == "Matrix([[x, 1], [y, x + y]])" + assert sstr(M) == "Matrix([\n[x, 1],\n[y, x + y]])" + + +def test_Sum(): + assert str(summation(cos(3*z), (z, x, y))) == "Sum(cos(3*z), (z, x, y))" + assert str(Sum(x*y**2, (x, -2, 2), (y, -5, 5))) == \ + "Sum(x*y**2, (x, -2, 2), (y, -5, 5))" + + +def test_Symbol(): + assert str(y) == "y" + assert str(x) == "x" + e = x + assert str(e) == "x" + + +def test_tuple(): + assert str((x,)) == sstr((x,)) == "(x,)" + assert str((x + y, 1 + x)) == sstr((x + y, 1 + x)) == "(x + y, x + 1)" + assert str((x + y, ( + 1 + x, x**2))) == sstr((x + y, (1 + x, x**2))) == "(x + y, (x + 1, x**2))" + + +def test_Series_str(): + tf1 = TransferFunction(x*y**2 - z, y**3 - t**3, y) + tf2 = TransferFunction(x - y, x + y, y) + tf3 = TransferFunction(t*x**2 - t**w*x + w, t - y, y) + assert str(Series(tf1, tf2)) == \ + "Series(TransferFunction(x*y**2 - z, -t**3 + y**3, y), TransferFunction(x - y, x + y, y))" + assert str(Series(tf1, tf2, tf3)) == \ + "Series(TransferFunction(x*y**2 - z, -t**3 + y**3, y), TransferFunction(x - y, x + y, y), TransferFunction(t*x**2 - t**w*x + w, t - y, y))" + assert str(Series(-tf2, tf1)) == \ + "Series(TransferFunction(-x + y, x + y, y), TransferFunction(x*y**2 - z, -t**3 + y**3, y))" + + +def test_MIMOSeries_str(): + tf1 = TransferFunction(x*y**2 - z, y**3 - t**3, y) + tf2 = TransferFunction(x - y, x + y, y) + tfm_1 = TransferFunctionMatrix([[tf1, tf2], [tf2, tf1]]) + tfm_2 = TransferFunctionMatrix([[tf2, tf1], [tf1, tf2]]) + assert str(MIMOSeries(tfm_1, tfm_2)) == \ + "MIMOSeries(TransferFunctionMatrix(((TransferFunction(x*y**2 - z, -t**3 + y**3, y), TransferFunction(x - y, x + y, y)), "\ + "(TransferFunction(x - y, x + y, y), TransferFunction(x*y**2 - z, -t**3 + y**3, y)))), "\ + "TransferFunctionMatrix(((TransferFunction(x - y, x + y, y), TransferFunction(x*y**2 - z, -t**3 + y**3, y)), "\ + "(TransferFunction(x*y**2 - z, -t**3 + y**3, y), TransferFunction(x - y, x + y, y)))))" + + +def test_TransferFunction_str(): + tf1 = TransferFunction(x - 1, x + 1, x) + assert str(tf1) == "TransferFunction(x - 1, x + 1, x)" + tf2 = TransferFunction(x + 1, 2 - y, x) + assert str(tf2) == "TransferFunction(x + 1, 2 - y, x)" + tf3 = TransferFunction(y, y**2 + 2*y + 3, y) + assert str(tf3) == "TransferFunction(y, y**2 + 2*y + 3, y)" + + +def test_Parallel_str(): + tf1 = TransferFunction(x*y**2 - z, y**3 - t**3, y) + tf2 = TransferFunction(x - y, x + y, y) + tf3 = TransferFunction(t*x**2 - t**w*x + w, t - y, y) + assert str(Parallel(tf1, tf2)) == \ + "Parallel(TransferFunction(x*y**2 - z, -t**3 + y**3, y), TransferFunction(x - y, x + y, y))" + assert str(Parallel(tf1, tf2, tf3)) == \ + "Parallel(TransferFunction(x*y**2 - z, -t**3 + y**3, y), TransferFunction(x - y, x + y, y), TransferFunction(t*x**2 - t**w*x + w, t - y, y))" + assert str(Parallel(-tf2, tf1)) == \ + "Parallel(TransferFunction(-x + y, x + y, y), TransferFunction(x*y**2 - z, -t**3 + y**3, y))" + + +def test_MIMOParallel_str(): + tf1 = TransferFunction(x*y**2 - z, y**3 - t**3, y) + tf2 = TransferFunction(x - y, x + y, y) + tfm_1 = TransferFunctionMatrix([[tf1, tf2], [tf2, tf1]]) + tfm_2 = TransferFunctionMatrix([[tf2, tf1], [tf1, tf2]]) + assert str(MIMOParallel(tfm_1, tfm_2)) == \ + "MIMOParallel(TransferFunctionMatrix(((TransferFunction(x*y**2 - z, -t**3 + y**3, y), TransferFunction(x - y, x + y, y)), "\ + "(TransferFunction(x - y, x + y, y), TransferFunction(x*y**2 - z, -t**3 + y**3, y)))), "\ + "TransferFunctionMatrix(((TransferFunction(x - y, x + y, y), TransferFunction(x*y**2 - z, -t**3 + y**3, y)), "\ + "(TransferFunction(x*y**2 - z, -t**3 + y**3, y), TransferFunction(x - y, x + y, y)))))" + + +def test_Feedback_str(): + tf1 = TransferFunction(x*y**2 - z, y**3 - t**3, y) + tf2 = TransferFunction(x - y, x + y, y) + tf3 = TransferFunction(t*x**2 - t**w*x + w, t - y, y) + assert str(Feedback(tf1*tf2, tf3)) == \ + "Feedback(Series(TransferFunction(x*y**2 - z, -t**3 + y**3, y), TransferFunction(x - y, x + y, y)), " \ + "TransferFunction(t*x**2 - t**w*x + w, t - y, y), -1)" + assert str(Feedback(tf1, TransferFunction(1, 1, y), 1)) == \ + "Feedback(TransferFunction(x*y**2 - z, -t**3 + y**3, y), TransferFunction(1, 1, y), 1)" + + +def test_MIMOFeedback_str(): + tf1 = TransferFunction(x**2 - y**3, y - z, x) + tf2 = TransferFunction(y - x, z + y, x) + tfm_1 = TransferFunctionMatrix([[tf2, tf1], [tf1, tf2]]) + tfm_2 = TransferFunctionMatrix([[tf1, tf2], [tf2, tf1]]) + assert (str(MIMOFeedback(tfm_1, tfm_2)) \ + == "MIMOFeedback(TransferFunctionMatrix(((TransferFunction(-x + y, y + z, x), TransferFunction(x**2 - y**3, y - z, x))," \ + " (TransferFunction(x**2 - y**3, y - z, x), TransferFunction(-x + y, y + z, x)))), " \ + "TransferFunctionMatrix(((TransferFunction(x**2 - y**3, y - z, x), " \ + "TransferFunction(-x + y, y + z, x)), (TransferFunction(-x + y, y + z, x), TransferFunction(x**2 - y**3, y - z, x)))), -1)") + assert (str(MIMOFeedback(tfm_1, tfm_2, 1)) \ + == "MIMOFeedback(TransferFunctionMatrix(((TransferFunction(-x + y, y + z, x), TransferFunction(x**2 - y**3, y - z, x)), " \ + "(TransferFunction(x**2 - y**3, y - z, x), TransferFunction(-x + y, y + z, x)))), " \ + "TransferFunctionMatrix(((TransferFunction(x**2 - y**3, y - z, x), TransferFunction(-x + y, y + z, x)), "\ + "(TransferFunction(-x + y, y + z, x), TransferFunction(x**2 - y**3, y - z, x)))), 1)") + + +def test_TransferFunctionMatrix_str(): + tf1 = TransferFunction(x*y**2 - z, y**3 - t**3, y) + tf2 = TransferFunction(x - y, x + y, y) + tf3 = TransferFunction(t*x**2 - t**w*x + w, t - y, y) + assert str(TransferFunctionMatrix([[tf1], [tf2]])) == \ + "TransferFunctionMatrix(((TransferFunction(x*y**2 - z, -t**3 + y**3, y),), (TransferFunction(x - y, x + y, y),)))" + assert str(TransferFunctionMatrix([[tf1, tf2], [tf3, tf2]])) == \ + "TransferFunctionMatrix(((TransferFunction(x*y**2 - z, -t**3 + y**3, y), TransferFunction(x - y, x + y, y)), (TransferFunction(t*x**2 - t**w*x + w, t - y, y), TransferFunction(x - y, x + y, y))))" + + +def test_Quaternion_str_printer(): + q = Quaternion(x, y, z, t) + assert str(q) == "x + y*i + z*j + t*k" + q = Quaternion(x,y,z,x*t) + assert str(q) == "x + y*i + z*j + t*x*k" + q = Quaternion(x,y,z,x+t) + assert str(q) == "x + y*i + z*j + (t + x)*k" + + +def test_Quantity_str(): + assert sstr(second, abbrev=True) == "s" + assert sstr(joule, abbrev=True) == "J" + assert str(second) == "second" + assert str(joule) == "joule" + + +def test_wild_str(): + # Check expressions containing Wild not causing infinite recursion + w = Wild('x') + assert str(w + 1) == 'x_ + 1' + assert str(exp(2**w) + 5) == 'exp(2**x_) + 5' + assert str(3*w + 1) == '3*x_ + 1' + assert str(1/w + 1) == '1 + 1/x_' + assert str(w**2 + 1) == 'x_**2 + 1' + assert str(1/(1 - w)) == '1/(1 - x_)' + + +def test_wild_matchpy(): + from sympy.utilities.matchpy_connector import WildDot, WildPlus, WildStar + + matchpy = import_module("matchpy") + + if matchpy is None: + return + + wd = WildDot('w_') + wp = WildPlus('w__') + ws = WildStar('w___') + + assert str(wd) == 'w_' + assert str(wp) == 'w__' + assert str(ws) == 'w___' + + assert str(wp/ws + 2**wd) == '2**w_ + w__/w___' + assert str(sin(wd)*cos(wp)*sqrt(ws)) == 'sqrt(w___)*sin(w_)*cos(w__)' + + +def test_zeta(): + assert str(zeta(3)) == "zeta(3)" + + +def test_issue_3101(): + e = x - y + a = str(e) + b = str(e) + assert a == b + + +def test_issue_3103(): + e = -2*sqrt(x) - y/sqrt(x)/2 + assert str(e) not in ["(-2)*x**1/2(-1/2)*x**(-1/2)*y", + "-2*x**1/2(-1/2)*x**(-1/2)*y", "-2*x**1/2-1/2*x**-1/2*w"] + assert str(e) == "-2*sqrt(x) - y/(2*sqrt(x))" + + +def test_issue_4021(): + e = Integral(x, x) + 1 + assert str(e) == 'Integral(x, x) + 1' + + +def test_sstrrepr(): + assert sstr('abc') == 'abc' + assert sstrrepr('abc') == "'abc'" + + e = ['a', 'b', 'c', x] + assert sstr(e) == "[a, b, c, x]" + assert sstrrepr(e) == "['a', 'b', 'c', x]" + + +def test_infinity(): + assert sstr(oo*I) == "oo*I" + + +def test_full_prec(): + assert sstr(S("0.3"), full_prec=True) == "0.300000000000000" + assert sstr(S("0.3"), full_prec="auto") == "0.300000000000000" + assert sstr(S("0.3"), full_prec=False) == "0.3" + assert sstr(S("0.3")*x, full_prec=True) in [ + "0.300000000000000*x", + "x*0.300000000000000" + ] + assert sstr(S("0.3")*x, full_prec="auto") in [ + "0.3*x", + "x*0.3" + ] + assert sstr(S("0.3")*x, full_prec=False) in [ + "0.3*x", + "x*0.3" + ] + + +def test_noncommutative(): + A, B, C = symbols('A,B,C', commutative=False) + + assert sstr(A*B*C**-1) == "A*B*C**(-1)" + assert sstr(C**-1*A*B) == "C**(-1)*A*B" + assert sstr(A*C**-1*B) == "A*C**(-1)*B" + assert sstr(sqrt(A)) == "sqrt(A)" + assert sstr(1/sqrt(A)) == "A**(-1/2)" + + +def test_empty_printer(): + str_printer = StrPrinter() + assert str_printer.emptyPrinter("foo") == "foo" + assert str_printer.emptyPrinter(x*y) == "x*y" + assert str_printer.emptyPrinter(32) == "32" + + +def test_settings(): + raises(TypeError, lambda: sstr(S(4), method="garbage")) + + +def test_RandomDomain(): + from sympy.stats import Normal, Die, Exponential, pspace, where + X = Normal('x1', 0, 1) + assert str(where(X > 0)) == "Domain: (0 < x1) & (x1 < oo)" + + D = Die('d1', 6) + assert str(where(D > 4)) == "Domain: Eq(d1, 5) | Eq(d1, 6)" + + A = Exponential('a', 1) + B = Exponential('b', 1) + assert str(pspace(Tuple(A, B)).domain) == "Domain: (0 <= a) & (0 <= b) & (a < oo) & (b < oo)" + + +def test_FiniteSet(): + assert str(FiniteSet(*range(1, 51))) == ( + '{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17,' + ' 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34,' + ' 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50}' + ) + assert str(FiniteSet(*range(1, 6))) == '{1, 2, 3, 4, 5}' + assert str(FiniteSet(*[x*y, x**2])) == '{x**2, x*y}' + assert str(FiniteSet(FiniteSet(FiniteSet(x, y), 5), FiniteSet(x,y), 5) + ) == 'FiniteSet(5, FiniteSet(5, {x, y}), {x, y})' + + +def test_Partition(): + assert str(Partition(FiniteSet(x, y), {z})) == 'Partition({z}, {x, y})' + +def test_UniversalSet(): + assert str(S.UniversalSet) == 'UniversalSet' + + +def test_PrettyPoly(): + F = QQ.frac_field(x, y) + R = QQ[x, y] + assert sstr(F.convert(x/(x + y))) == sstr(x/(x + y)) + assert sstr(R.convert(x + y)) == sstr(x + y) + + +def test_categories(): + from sympy.categories import (Object, NamedMorphism, + IdentityMorphism, Category) + + A = Object("A") + B = Object("B") + + f = NamedMorphism(A, B, "f") + id_A = IdentityMorphism(A) + + K = Category("K") + + assert str(A) == 'Object("A")' + assert str(f) == 'NamedMorphism(Object("A"), Object("B"), "f")' + assert str(id_A) == 'IdentityMorphism(Object("A"))' + + assert str(K) == 'Category("K")' + + +def test_Tr(): + A, B = symbols('A B', commutative=False) + t = Tr(A*B) + assert str(t) == 'Tr(A*B)' + + +def test_issue_6387(): + assert str(factor(-3.0*z + 3)) == '-3.0*(1.0*z - 1.0)' + + +def test_MatMul_MatAdd(): + X, Y = MatrixSymbol("X", 2, 2), MatrixSymbol("Y", 2, 2) + assert str(2*(X + Y)) == "2*X + 2*Y" + + assert str(I*X) == "I*X" + assert str(-I*X) == "-I*X" + assert str((1 + I)*X) == '(1 + I)*X' + assert str(-(1 + I)*X) == '(-1 - I)*X' + assert str(MatAdd(MatAdd(X, Y), MatAdd(X, Y))) == '(X + Y) + (X + Y)' + + +def test_MatrixSlice(): + n = Symbol('n', integer=True) + X = MatrixSymbol('X', n, n) + Y = MatrixSymbol('Y', 10, 10) + Z = MatrixSymbol('Z', 10, 10) + + assert str(MatrixSlice(X, (None, None, None), (None, None, None))) == 'X[:, :]' + assert str(X[x:x + 1, y:y + 1]) == 'X[x:x + 1, y:y + 1]' + assert str(X[x:x + 1:2, y:y + 1:2]) == 'X[x:x + 1:2, y:y + 1:2]' + assert str(X[:x, y:]) == 'X[:x, y:]' + assert str(X[:x, y:]) == 'X[:x, y:]' + assert str(X[x:, :y]) == 'X[x:, :y]' + assert str(X[x:y, z:w]) == 'X[x:y, z:w]' + assert str(X[x:y:t, w:t:x]) == 'X[x:y:t, w:t:x]' + assert str(X[x::y, t::w]) == 'X[x::y, t::w]' + assert str(X[:x:y, :t:w]) == 'X[:x:y, :t:w]' + assert str(X[::x, ::y]) == 'X[::x, ::y]' + assert str(MatrixSlice(X, (0, None, None), (0, None, None))) == 'X[:, :]' + assert str(MatrixSlice(X, (None, n, None), (None, n, None))) == 'X[:, :]' + assert str(MatrixSlice(X, (0, n, None), (0, n, None))) == 'X[:, :]' + assert str(MatrixSlice(X, (0, n, 2), (0, n, 2))) == 'X[::2, ::2]' + assert str(X[1:2:3, 4:5:6]) == 'X[1:2:3, 4:5:6]' + assert str(X[1:3:5, 4:6:8]) == 'X[1:3:5, 4:6:8]' + assert str(X[1:10:2]) == 'X[1:10:2, :]' + assert str(Y[:5, 1:9:2]) == 'Y[:5, 1:9:2]' + assert str(Y[:5, 1:10:2]) == 'Y[:5, 1::2]' + assert str(Y[5, :5:2]) == 'Y[5:6, :5:2]' + assert str(X[0:1, 0:1]) == 'X[:1, :1]' + assert str(X[0:1:2, 0:1:2]) == 'X[:1:2, :1:2]' + assert str((Y + Z)[2:, 2:]) == '(Y + Z)[2:, 2:]' + +def test_true_false(): + assert str(true) == repr(true) == sstr(true) == "True" + assert str(false) == repr(false) == sstr(false) == "False" + +def test_Equivalent(): + assert str(Equivalent(y, x)) == "Equivalent(x, y)" + +def test_Xor(): + assert str(Xor(y, x, evaluate=False)) == "x ^ y" + +def test_Complement(): + assert str(Complement(S.Reals, S.Naturals)) == 'Complement(Reals, Naturals)' + +def test_SymmetricDifference(): + assert str(SymmetricDifference(Interval(2, 3), Interval(3, 4),evaluate=False)) == \ + 'SymmetricDifference(Interval(2, 3), Interval(3, 4))' + + +def test_UnevaluatedExpr(): + a, b = symbols("a b") + expr1 = 2*UnevaluatedExpr(a+b) + assert str(expr1) == "2*(a + b)" + + +def test_MatrixElement_printing(): + # test cases for issue #11821 + A = MatrixSymbol("A", 1, 3) + B = MatrixSymbol("B", 1, 3) + C = MatrixSymbol("C", 1, 3) + + assert(str(A[0, 0]) == "A[0, 0]") + assert(str(3 * A[0, 0]) == "3*A[0, 0]") + + F = C[0, 0].subs(C, A - B) + assert str(F) == "(A - B)[0, 0]" + + +def test_MatrixSymbol_printing(): + A = MatrixSymbol("A", 3, 3) + B = MatrixSymbol("B", 3, 3) + + assert str(A - A*B - B) == "A - A*B - B" + assert str(A*B - (A+B)) == "-A + A*B - B" + assert str(A**(-1)) == "A**(-1)" + assert str(A**3) == "A**3" + + +def test_MatrixExpressions(): + n = Symbol('n', integer=True) + X = MatrixSymbol('X', n, n) + + assert str(X) == "X" + + # Apply function elementwise (`ElementwiseApplyFunc`): + + expr = (X.T*X).applyfunc(sin) + assert str(expr) == 'Lambda(_d, sin(_d)).(X.T*X)' + + lamda = Lambda(x, 1/x) + expr = (n*X).applyfunc(lamda) + assert str(expr) == 'Lambda(x, 1/x).(n*X)' + + +def test_Subs_printing(): + assert str(Subs(x, (x,), (1,))) == 'Subs(x, x, 1)' + assert str(Subs(x + y, (x, y), (1, 2))) == 'Subs(x + y, (x, y), (1, 2))' + + +def test_issue_15716(): + e = Integral(factorial(x), (x, -oo, oo)) + assert e.as_terms() == ([(e, ((1.0, 0.0), (1,), ()))], [e]) + + +def test_str_special_matrices(): + from sympy.matrices import Identity, ZeroMatrix, OneMatrix + assert str(Identity(4)) == 'I' + assert str(ZeroMatrix(2, 2)) == '0' + assert str(OneMatrix(2, 2)) == '1' + + +def test_issue_14567(): + assert factorial(Sum(-1, (x, 0, 0))) + y # doesn't raise an error + + +def test_issue_21823(): + assert str(Partition([1, 2])) == 'Partition({1, 2})' + assert str(Partition({1, 2})) == 'Partition({1, 2})' + + +def test_issue_22689(): + assert str(Mul(Pow(x,-2, evaluate=False), Pow(3,-1,evaluate=False), evaluate=False)) == "1/(x**2*3)" + + +def test_issue_21119_21460(): + ss = lambda x: str(S(x, evaluate=False)) + assert ss('4/2') == '4/2' + assert ss('4/-2') == '4/(-2)' + assert ss('-4/2') == '-4/2' + assert ss('-4/-2') == '-4/(-2)' + assert ss('-2*3/-1') == '-2*3/(-1)' + assert ss('-2*3/-1/2') == '-2*3/(-1*2)' + assert ss('4/2/1') == '4/(2*1)' + assert ss('-2/-1/2') == '-2/(-1*2)' + assert ss('2*3*4**(-2*3)') == '2*3/4**(2*3)' + assert ss('2*3*1*4**(-2*3)') == '2*3*1/4**(2*3)' + + +def test_Str(): + from sympy.core.symbol import Str + assert str(Str('x')) == 'x' + assert sstrrepr(Str('x')) == "Str('x')" + + +def test_diffgeom(): + from sympy.diffgeom import Manifold, Patch, CoordSystem, BaseScalarField + x,y = symbols('x y', real=True) + m = Manifold('M', 2) + assert str(m) == "M" + p = Patch('P', m) + assert str(p) == "P" + rect = CoordSystem('rect', p, [x, y]) + assert str(rect) == "rect" + b = BaseScalarField(rect, 0) + assert str(b) == "x" + +def test_NDimArray(): + assert sstr(NDimArray(1.0), full_prec=True) == '1.00000000000000' + assert sstr(NDimArray(1.0), full_prec=False) == '1.0' + assert sstr(NDimArray([1.0, 2.0]), full_prec=True) == '[1.00000000000000, 2.00000000000000]' + assert sstr(NDimArray([1.0, 2.0]), full_prec=False) == '[1.0, 2.0]' + +def test_Predicate(): + assert sstr(Q.even) == 'Q.even' + +def test_AppliedPredicate(): + assert sstr(Q.even(x)) == 'Q.even(x)' + +def test_printing_str_array_expressions(): + assert sstr(ArraySymbol("A", (2, 3, 4))) == "A" + assert sstr(ArrayElement("A", (2, 1/(1-x), 0))) == "A[2, 1/(1 - x), 0]" + M = MatrixSymbol("M", 3, 3) + N = MatrixSymbol("N", 3, 3) + assert sstr(ArrayElement(M*N, [x, 0])) == "(M*N)[x, 0]" + +def test_printing_stats(): + # issue 24132 + x = RandomSymbol("x") + y = RandomSymbol("y") + z1 = Probability(x > 0)*Identity(2) + z2 = Expectation(x)*Identity(2) + z3 = Variance(x)*Identity(2) + z4 = Covariance(x, y) * Identity(2) + + assert str(z1) == "Probability(x > 0)*I" + assert str(z2) == "Expectation(x)*I" + assert str(z3) == "Variance(x)*I" + assert str(z4) == "Covariance(x, y)*I" + assert z1.is_commutative == False + assert z2.is_commutative == False + assert z3.is_commutative == False + assert z4.is_commutative == False + assert z2._eval_is_commutative() == False + assert z3._eval_is_commutative() == False + assert z4._eval_is_commutative() == False diff --git a/lib/python3.10/site-packages/sympy/printing/tests/test_tableform.py b/lib/python3.10/site-packages/sympy/printing/tests/test_tableform.py new file mode 100644 index 0000000000000000000000000000000000000000..05802dd104a12f2f53d137167ecf31d201ff8dfc --- /dev/null +++ b/lib/python3.10/site-packages/sympy/printing/tests/test_tableform.py @@ -0,0 +1,182 @@ +from sympy.core.singleton import S +from sympy.printing.tableform import TableForm +from sympy.printing.latex import latex +from sympy.abc import x +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import sin +from sympy.testing.pytest import raises + +from textwrap import dedent + + +def test_TableForm(): + s = str(TableForm([["a", "b"], ["c", "d"], ["e", 0]], + headings="automatic")) + assert s == ( + ' | 1 2\n' + '-------\n' + '1 | a b\n' + '2 | c d\n' + '3 | e ' + ) + s = str(TableForm([["a", "b"], ["c", "d"], ["e", 0]], + headings="automatic", wipe_zeros=False)) + assert s == dedent('''\ + | 1 2 + ------- + 1 | a b + 2 | c d + 3 | e 0''') + s = str(TableForm([[x**2, "b"], ["c", x**2], ["e", "f"]], + headings=("automatic", None))) + assert s == ( + '1 | x**2 b \n' + '2 | c x**2\n' + '3 | e f ' + ) + s = str(TableForm([["a", "b"], ["c", "d"], ["e", "f"]], + headings=(None, "automatic"))) + assert s == dedent('''\ + 1 2 + --- + a b + c d + e f''') + s = str(TableForm([[5, 7], [4, 2], [10, 3]], + headings=[["Group A", "Group B", "Group C"], ["y1", "y2"]])) + assert s == ( + ' | y1 y2\n' + '---------------\n' + 'Group A | 5 7 \n' + 'Group B | 4 2 \n' + 'Group C | 10 3 ' + ) + raises( + ValueError, + lambda: + TableForm( + [[5, 7], [4, 2], [10, 3]], + headings=[["Group A", "Group B", "Group C"], ["y1", "y2"]], + alignments="middle") + ) + s = str(TableForm([[5, 7], [4, 2], [10, 3]], + headings=[["Group A", "Group B", "Group C"], ["y1", "y2"]], + alignments="right")) + assert s == dedent('''\ + | y1 y2 + --------------- + Group A | 5 7 + Group B | 4 2 + Group C | 10 3''') + + # other alignment permutations + d = [[1, 100], [100, 1]] + s = TableForm(d, headings=(('xxx', 'x'), None), alignments='l') + assert str(s) == ( + 'xxx | 1 100\n' + ' x | 100 1 ' + ) + s = TableForm(d, headings=(('xxx', 'x'), None), alignments='lr') + assert str(s) == dedent('''\ + xxx | 1 100 + x | 100 1''') + s = TableForm(d, headings=(('xxx', 'x'), None), alignments='clr') + assert str(s) == dedent('''\ + xxx | 1 100 + x | 100 1''') + + s = TableForm(d, headings=(('xxx', 'x'), None)) + assert str(s) == ( + 'xxx | 1 100\n' + ' x | 100 1 ' + ) + + raises(ValueError, lambda: TableForm(d, alignments='clr')) + + #pad + s = str(TableForm([[None, "-", 2], [1]], pad='?')) + assert s == dedent('''\ + ? - 2 + 1 ? ?''') + + +def test_TableForm_latex(): + s = latex(TableForm([[0, x**3], ["c", S.One/4], [sqrt(x), sin(x**2)]], + wipe_zeros=True, headings=("automatic", "automatic"))) + assert s == ( + '\\begin{tabular}{r l l}\n' + ' & 1 & 2 \\\\\n' + '\\hline\n' + '1 & & $x^{3}$ \\\\\n' + '2 & $c$ & $\\frac{1}{4}$ \\\\\n' + '3 & $\\sqrt{x}$ & $\\sin{\\left(x^{2} \\right)}$ \\\\\n' + '\\end{tabular}' + ) + s = latex(TableForm([[0, x**3], ["c", S.One/4], [sqrt(x), sin(x**2)]], + wipe_zeros=True, headings=("automatic", "automatic"), alignments='l')) + assert s == ( + '\\begin{tabular}{r l l}\n' + ' & 1 & 2 \\\\\n' + '\\hline\n' + '1 & & $x^{3}$ \\\\\n' + '2 & $c$ & $\\frac{1}{4}$ \\\\\n' + '3 & $\\sqrt{x}$ & $\\sin{\\left(x^{2} \\right)}$ \\\\\n' + '\\end{tabular}' + ) + s = latex(TableForm([[0, x**3], ["c", S.One/4], [sqrt(x), sin(x**2)]], + wipe_zeros=True, headings=("automatic", "automatic"), alignments='l'*3)) + assert s == ( + '\\begin{tabular}{l l l}\n' + ' & 1 & 2 \\\\\n' + '\\hline\n' + '1 & & $x^{3}$ \\\\\n' + '2 & $c$ & $\\frac{1}{4}$ \\\\\n' + '3 & $\\sqrt{x}$ & $\\sin{\\left(x^{2} \\right)}$ \\\\\n' + '\\end{tabular}' + ) + s = latex(TableForm([["a", x**3], ["c", S.One/4], [sqrt(x), sin(x**2)]], + headings=("automatic", "automatic"))) + assert s == ( + '\\begin{tabular}{r l l}\n' + ' & 1 & 2 \\\\\n' + '\\hline\n' + '1 & $a$ & $x^{3}$ \\\\\n' + '2 & $c$ & $\\frac{1}{4}$ \\\\\n' + '3 & $\\sqrt{x}$ & $\\sin{\\left(x^{2} \\right)}$ \\\\\n' + '\\end{tabular}' + ) + s = latex(TableForm([["a", x**3], ["c", S.One/4], [sqrt(x), sin(x**2)]], + formats=['(%s)', None], headings=("automatic", "automatic"))) + assert s == ( + '\\begin{tabular}{r l l}\n' + ' & 1 & 2 \\\\\n' + '\\hline\n' + '1 & (a) & $x^{3}$ \\\\\n' + '2 & (c) & $\\frac{1}{4}$ \\\\\n' + '3 & (sqrt(x)) & $\\sin{\\left(x^{2} \\right)}$ \\\\\n' + '\\end{tabular}' + ) + + def neg_in_paren(x, i, j): + if i % 2: + return ('(%s)' if x < 0 else '%s') % x + else: + pass # use default print + s = latex(TableForm([[-1, 2], [-3, 4]], + formats=[neg_in_paren]*2, headings=("automatic", "automatic"))) + assert s == ( + '\\begin{tabular}{r l l}\n' + ' & 1 & 2 \\\\\n' + '\\hline\n' + '1 & -1 & 2 \\\\\n' + '2 & (-3) & 4 \\\\\n' + '\\end{tabular}' + ) + s = latex(TableForm([["a", x**3], ["c", S.One/4], [sqrt(x), sin(x**2)]])) + assert s == ( + '\\begin{tabular}{l l}\n' + '$a$ & $x^{3}$ \\\\\n' + '$c$ & $\\frac{1}{4}$ \\\\\n' + '$\\sqrt{x}$ & $\\sin{\\left(x^{2} \\right)}$ \\\\\n' + '\\end{tabular}' + ) diff --git a/lib/python3.10/site-packages/sympy/printing/tests/test_tensorflow.py b/lib/python3.10/site-packages/sympy/printing/tests/test_tensorflow.py new file mode 100644 index 0000000000000000000000000000000000000000..d511e3b6c7fc0840331f2dcfacd5bbb11002758c --- /dev/null +++ b/lib/python3.10/site-packages/sympy/printing/tests/test_tensorflow.py @@ -0,0 +1,465 @@ +import random +from sympy.core.function import Derivative +from sympy.core.symbol import symbols +from sympy.tensor.array.expressions.array_expressions import ArrayTensorProduct, ArrayAdd, \ + PermuteDims, ArrayDiagonal +from sympy.core.relational import Eq, Ne, Ge, Gt, Le, Lt +from sympy.external import import_module +from sympy.functions import \ + Abs, ceiling, exp, floor, sign, sin, asin, sqrt, cos, \ + acos, tan, atan, atan2, cosh, acosh, sinh, asinh, tanh, atanh, \ + re, im, arg, erf, loggamma, log +from sympy.matrices import Matrix, MatrixBase, eye, randMatrix +from sympy.matrices.expressions import \ + Determinant, HadamardProduct, Inverse, MatrixSymbol, Trace +from sympy.printing.tensorflow import tensorflow_code +from sympy.tensor.array.expressions.from_matrix_to_array import convert_matrix_to_array +from sympy.utilities.lambdify import lambdify +from sympy.testing.pytest import skip +from sympy.testing.pytest import XFAIL + + +tf = tensorflow = import_module("tensorflow") + +if tensorflow: + # Hide Tensorflow warnings + import os + os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' + + +M = MatrixSymbol("M", 3, 3) +N = MatrixSymbol("N", 3, 3) +P = MatrixSymbol("P", 3, 3) +Q = MatrixSymbol("Q", 3, 3) + +x, y, z, t = symbols("x y z t") + +if tf is not None: + llo = [list(range(i, i+3)) for i in range(0, 9, 3)] + m3x3 = tf.constant(llo) + m3x3sympy = Matrix(llo) + + +def _compare_tensorflow_matrix(variables, expr, use_float=False): + f = lambdify(variables, expr, 'tensorflow') + if not use_float: + random_matrices = [randMatrix(v.rows, v.cols) for v in variables] + else: + random_matrices = [randMatrix(v.rows, v.cols)/100. for v in variables] + + graph = tf.Graph() + r = None + with graph.as_default(): + random_variables = [eval(tensorflow_code(i)) for i in random_matrices] + session = tf.compat.v1.Session(graph=graph) + r = session.run(f(*random_variables)) + + e = expr.subs(dict(zip(variables, random_matrices))) + e = e.doit() + if e.is_Matrix: + if not isinstance(e, MatrixBase): + e = e.as_explicit() + e = e.tolist() + + if not use_float: + assert (r == e).all() + else: + r = [i for row in r for i in row] + e = [i for row in e for i in row] + assert all( + abs(a-b) < 10**-(4-int(log(abs(a), 10))) for a, b in zip(r, e)) + + +# Creating a custom inverse test. +# See https://github.com/sympy/sympy/issues/18469 +def _compare_tensorflow_matrix_inverse(variables, expr, use_float=False): + f = lambdify(variables, expr, 'tensorflow') + if not use_float: + random_matrices = [eye(v.rows, v.cols)*4 for v in variables] + else: + random_matrices = [eye(v.rows, v.cols)*3.14 for v in variables] + + graph = tf.Graph() + r = None + with graph.as_default(): + random_variables = [eval(tensorflow_code(i)) for i in random_matrices] + session = tf.compat.v1.Session(graph=graph) + r = session.run(f(*random_variables)) + + e = expr.subs(dict(zip(variables, random_matrices))) + e = e.doit() + if e.is_Matrix: + if not isinstance(e, MatrixBase): + e = e.as_explicit() + e = e.tolist() + + if not use_float: + assert (r == e).all() + else: + r = [i for row in r for i in row] + e = [i for row in e for i in row] + assert all( + abs(a-b) < 10**-(4-int(log(abs(a), 10))) for a, b in zip(r, e)) + + +def _compare_tensorflow_matrix_scalar(variables, expr): + f = lambdify(variables, expr, 'tensorflow') + random_matrices = [ + randMatrix(v.rows, v.cols).evalf() / 100 for v in variables] + + graph = tf.Graph() + r = None + with graph.as_default(): + random_variables = [eval(tensorflow_code(i)) for i in random_matrices] + session = tf.compat.v1.Session(graph=graph) + r = session.run(f(*random_variables)) + + e = expr.subs(dict(zip(variables, random_matrices))) + e = e.doit() + assert abs(r-e) < 10**-6 + + +def _compare_tensorflow_scalar( + variables, expr, rng=lambda: random.randint(0, 10)): + f = lambdify(variables, expr, 'tensorflow') + rvs = [rng() for v in variables] + + graph = tf.Graph() + r = None + with graph.as_default(): + tf_rvs = [eval(tensorflow_code(i)) for i in rvs] + session = tf.compat.v1.Session(graph=graph) + r = session.run(f(*tf_rvs)) + + e = expr.subs(dict(zip(variables, rvs))).evalf().doit() + assert abs(r-e) < 10**-6 + + +def _compare_tensorflow_relational( + variables, expr, rng=lambda: random.randint(0, 10)): + f = lambdify(variables, expr, 'tensorflow') + rvs = [rng() for v in variables] + + graph = tf.Graph() + r = None + with graph.as_default(): + tf_rvs = [eval(tensorflow_code(i)) for i in rvs] + session = tf.compat.v1.Session(graph=graph) + r = session.run(f(*tf_rvs)) + + e = expr.subs(dict(zip(variables, rvs))).doit() + assert r == e + + +def test_tensorflow_printing(): + assert tensorflow_code(eye(3)) == \ + "tensorflow.constant([[1, 0, 0], [0, 1, 0], [0, 0, 1]])" + + expr = Matrix([[x, sin(y)], [exp(z), -t]]) + assert tensorflow_code(expr) == \ + "tensorflow.Variable(" \ + "[[x, tensorflow.math.sin(y)]," \ + " [tensorflow.math.exp(z), -t]])" + + +# This (random) test is XFAIL because it fails occasionally +# See https://github.com/sympy/sympy/issues/18469 +@XFAIL +def test_tensorflow_math(): + if not tf: + skip("TensorFlow not installed") + + expr = Abs(x) + assert tensorflow_code(expr) == "tensorflow.math.abs(x)" + _compare_tensorflow_scalar((x,), expr) + + expr = sign(x) + assert tensorflow_code(expr) == "tensorflow.math.sign(x)" + _compare_tensorflow_scalar((x,), expr) + + expr = ceiling(x) + assert tensorflow_code(expr) == "tensorflow.math.ceil(x)" + _compare_tensorflow_scalar((x,), expr, rng=lambda: random.random()) + + expr = floor(x) + assert tensorflow_code(expr) == "tensorflow.math.floor(x)" + _compare_tensorflow_scalar((x,), expr, rng=lambda: random.random()) + + expr = exp(x) + assert tensorflow_code(expr) == "tensorflow.math.exp(x)" + _compare_tensorflow_scalar((x,), expr, rng=lambda: random.random()) + + expr = sqrt(x) + assert tensorflow_code(expr) == "tensorflow.math.sqrt(x)" + _compare_tensorflow_scalar((x,), expr, rng=lambda: random.random()) + + expr = x ** 4 + assert tensorflow_code(expr) == "tensorflow.math.pow(x, 4)" + _compare_tensorflow_scalar((x,), expr, rng=lambda: random.random()) + + expr = cos(x) + assert tensorflow_code(expr) == "tensorflow.math.cos(x)" + _compare_tensorflow_scalar((x,), expr, rng=lambda: random.random()) + + expr = acos(x) + assert tensorflow_code(expr) == "tensorflow.math.acos(x)" + _compare_tensorflow_scalar((x,), expr, rng=lambda: random.uniform(0, 0.95)) + + expr = sin(x) + assert tensorflow_code(expr) == "tensorflow.math.sin(x)" + _compare_tensorflow_scalar((x,), expr, rng=lambda: random.random()) + + expr = asin(x) + assert tensorflow_code(expr) == "tensorflow.math.asin(x)" + _compare_tensorflow_scalar((x,), expr, rng=lambda: random.random()) + + expr = tan(x) + assert tensorflow_code(expr) == "tensorflow.math.tan(x)" + _compare_tensorflow_scalar((x,), expr, rng=lambda: random.random()) + + expr = atan(x) + assert tensorflow_code(expr) == "tensorflow.math.atan(x)" + _compare_tensorflow_scalar((x,), expr, rng=lambda: random.random()) + + expr = atan2(y, x) + assert tensorflow_code(expr) == "tensorflow.math.atan2(y, x)" + _compare_tensorflow_scalar((y, x), expr, rng=lambda: random.random()) + + expr = cosh(x) + assert tensorflow_code(expr) == "tensorflow.math.cosh(x)" + _compare_tensorflow_scalar((x,), expr, rng=lambda: random.random()) + + expr = acosh(x) + assert tensorflow_code(expr) == "tensorflow.math.acosh(x)" + _compare_tensorflow_scalar((x,), expr, rng=lambda: random.uniform(1, 2)) + + expr = sinh(x) + assert tensorflow_code(expr) == "tensorflow.math.sinh(x)" + _compare_tensorflow_scalar((x,), expr, rng=lambda: random.uniform(1, 2)) + + expr = asinh(x) + assert tensorflow_code(expr) == "tensorflow.math.asinh(x)" + _compare_tensorflow_scalar((x,), expr, rng=lambda: random.uniform(1, 2)) + + expr = tanh(x) + assert tensorflow_code(expr) == "tensorflow.math.tanh(x)" + _compare_tensorflow_scalar((x,), expr, rng=lambda: random.uniform(1, 2)) + + expr = atanh(x) + assert tensorflow_code(expr) == "tensorflow.math.atanh(x)" + _compare_tensorflow_scalar( + (x,), expr, rng=lambda: random.uniform(-.5, .5)) + + expr = erf(x) + assert tensorflow_code(expr) == "tensorflow.math.erf(x)" + _compare_tensorflow_scalar( + (x,), expr, rng=lambda: random.random()) + + expr = loggamma(x) + assert tensorflow_code(expr) == "tensorflow.math.lgamma(x)" + _compare_tensorflow_scalar( + (x,), expr, rng=lambda: random.random()) + + +def test_tensorflow_complexes(): + assert tensorflow_code(re(x)) == "tensorflow.math.real(x)" + assert tensorflow_code(im(x)) == "tensorflow.math.imag(x)" + assert tensorflow_code(arg(x)) == "tensorflow.math.angle(x)" + + +def test_tensorflow_relational(): + if not tf: + skip("TensorFlow not installed") + + expr = Eq(x, y) + assert tensorflow_code(expr) == "tensorflow.math.equal(x, y)" + _compare_tensorflow_relational((x, y), expr) + + expr = Ne(x, y) + assert tensorflow_code(expr) == "tensorflow.math.not_equal(x, y)" + _compare_tensorflow_relational((x, y), expr) + + expr = Ge(x, y) + assert tensorflow_code(expr) == "tensorflow.math.greater_equal(x, y)" + _compare_tensorflow_relational((x, y), expr) + + expr = Gt(x, y) + assert tensorflow_code(expr) == "tensorflow.math.greater(x, y)" + _compare_tensorflow_relational((x, y), expr) + + expr = Le(x, y) + assert tensorflow_code(expr) == "tensorflow.math.less_equal(x, y)" + _compare_tensorflow_relational((x, y), expr) + + expr = Lt(x, y) + assert tensorflow_code(expr) == "tensorflow.math.less(x, y)" + _compare_tensorflow_relational((x, y), expr) + + +# This (random) test is XFAIL because it fails occasionally +# See https://github.com/sympy/sympy/issues/18469 +@XFAIL +def test_tensorflow_matrices(): + if not tf: + skip("TensorFlow not installed") + + expr = M + assert tensorflow_code(expr) == "M" + _compare_tensorflow_matrix((M,), expr) + + expr = M + N + assert tensorflow_code(expr) == "tensorflow.math.add(M, N)" + _compare_tensorflow_matrix((M, N), expr) + + expr = M * N + assert tensorflow_code(expr) == "tensorflow.linalg.matmul(M, N)" + _compare_tensorflow_matrix((M, N), expr) + + expr = HadamardProduct(M, N) + assert tensorflow_code(expr) == "tensorflow.math.multiply(M, N)" + _compare_tensorflow_matrix((M, N), expr) + + expr = M*N*P*Q + assert tensorflow_code(expr) == \ + "tensorflow.linalg.matmul(" \ + "tensorflow.linalg.matmul(" \ + "tensorflow.linalg.matmul(M, N), P), Q)" + _compare_tensorflow_matrix((M, N, P, Q), expr) + + expr = M**3 + assert tensorflow_code(expr) == \ + "tensorflow.linalg.matmul(tensorflow.linalg.matmul(M, M), M)" + _compare_tensorflow_matrix((M,), expr) + + expr = Trace(M) + assert tensorflow_code(expr) == "tensorflow.linalg.trace(M)" + _compare_tensorflow_matrix((M,), expr) + + expr = Determinant(M) + assert tensorflow_code(expr) == "tensorflow.linalg.det(M)" + _compare_tensorflow_matrix_scalar((M,), expr) + + expr = Inverse(M) + assert tensorflow_code(expr) == "tensorflow.linalg.inv(M)" + _compare_tensorflow_matrix_inverse((M,), expr, use_float=True) + + expr = M.T + assert tensorflow_code(expr, tensorflow_version='1.14') == \ + "tensorflow.linalg.matrix_transpose(M)" + assert tensorflow_code(expr, tensorflow_version='1.13') == \ + "tensorflow.matrix_transpose(M)" + + _compare_tensorflow_matrix((M,), expr) + + +def test_codegen_einsum(): + if not tf: + skip("TensorFlow not installed") + + graph = tf.Graph() + with graph.as_default(): + session = tf.compat.v1.Session(graph=graph) + + M = MatrixSymbol("M", 2, 2) + N = MatrixSymbol("N", 2, 2) + + cg = convert_matrix_to_array(M * N) + f = lambdify((M, N), cg, 'tensorflow') + + ma = tf.constant([[1, 2], [3, 4]]) + mb = tf.constant([[1,-2], [-1, 3]]) + y = session.run(f(ma, mb)) + c = session.run(tf.matmul(ma, mb)) + assert (y == c).all() + + +def test_codegen_extra(): + if not tf: + skip("TensorFlow not installed") + + graph = tf.Graph() + with graph.as_default(): + session = tf.compat.v1.Session() + + M = MatrixSymbol("M", 2, 2) + N = MatrixSymbol("N", 2, 2) + P = MatrixSymbol("P", 2, 2) + Q = MatrixSymbol("Q", 2, 2) + ma = tf.constant([[1, 2], [3, 4]]) + mb = tf.constant([[1,-2], [-1, 3]]) + mc = tf.constant([[2, 0], [1, 2]]) + md = tf.constant([[1,-1], [4, 7]]) + + cg = ArrayTensorProduct(M, N) + assert tensorflow_code(cg) == \ + 'tensorflow.linalg.einsum("ab,cd", M, N)' + f = lambdify((M, N), cg, 'tensorflow') + y = session.run(f(ma, mb)) + c = session.run(tf.einsum("ij,kl", ma, mb)) + assert (y == c).all() + + cg = ArrayAdd(M, N) + assert tensorflow_code(cg) == 'tensorflow.math.add(M, N)' + f = lambdify((M, N), cg, 'tensorflow') + y = session.run(f(ma, mb)) + c = session.run(ma + mb) + assert (y == c).all() + + cg = ArrayAdd(M, N, P) + assert tensorflow_code(cg) == \ + 'tensorflow.math.add(tensorflow.math.add(M, N), P)' + f = lambdify((M, N, P), cg, 'tensorflow') + y = session.run(f(ma, mb, mc)) + c = session.run(ma + mb + mc) + assert (y == c).all() + + cg = ArrayAdd(M, N, P, Q) + assert tensorflow_code(cg) == \ + 'tensorflow.math.add(' \ + 'tensorflow.math.add(tensorflow.math.add(M, N), P), Q)' + f = lambdify((M, N, P, Q), cg, 'tensorflow') + y = session.run(f(ma, mb, mc, md)) + c = session.run(ma + mb + mc + md) + assert (y == c).all() + + cg = PermuteDims(M, [1, 0]) + assert tensorflow_code(cg) == 'tensorflow.transpose(M, [1, 0])' + f = lambdify((M,), cg, 'tensorflow') + y = session.run(f(ma)) + c = session.run(tf.transpose(ma)) + assert (y == c).all() + + cg = PermuteDims(ArrayTensorProduct(M, N), [1, 2, 3, 0]) + assert tensorflow_code(cg) == \ + 'tensorflow.transpose(' \ + 'tensorflow.linalg.einsum("ab,cd", M, N), [1, 2, 3, 0])' + f = lambdify((M, N), cg, 'tensorflow') + y = session.run(f(ma, mb)) + c = session.run(tf.transpose(tf.einsum("ab,cd", ma, mb), [1, 2, 3, 0])) + assert (y == c).all() + + cg = ArrayDiagonal(ArrayTensorProduct(M, N), (1, 2)) + assert tensorflow_code(cg) == \ + 'tensorflow.linalg.einsum("ab,bc->acb", M, N)' + f = lambdify((M, N), cg, 'tensorflow') + y = session.run(f(ma, mb)) + c = session.run(tf.einsum("ab,bc->acb", ma, mb)) + assert (y == c).all() + + +def test_MatrixElement_printing(): + A = MatrixSymbol("A", 1, 3) + B = MatrixSymbol("B", 1, 3) + C = MatrixSymbol("C", 1, 3) + + assert tensorflow_code(A[0, 0]) == "A[0, 0]" + assert tensorflow_code(3 * A[0, 0]) == "3*A[0, 0]" + + F = C[0, 0].subs(C, A - B) + assert tensorflow_code(F) == "(tensorflow.math.add((-1)*B, A))[0, 0]" + + +def test_tensorflow_Derivative(): + expr = Derivative(sin(x), x) + assert tensorflow_code(expr) == \ + "tensorflow.gradients(tensorflow.math.sin(x), x)[0]" diff --git a/lib/python3.10/site-packages/sympy/printing/tests/test_theanocode.py b/lib/python3.10/site-packages/sympy/printing/tests/test_theanocode.py new file mode 100644 index 0000000000000000000000000000000000000000..6ff40f78cb4de16149cb5e780756b7e32b574b71 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/printing/tests/test_theanocode.py @@ -0,0 +1,639 @@ +""" +Important note on tests in this module - the Theano printing functions use a +global cache by default, which means that tests using it will modify global +state and thus not be independent from each other. Instead of using the "cache" +keyword argument each time, this module uses the theano_code_ and +theano_function_ functions defined below which default to using a new, empty +cache instead. +""" + +import logging + +from sympy.external import import_module +from sympy.testing.pytest import raises, SKIP, warns_deprecated_sympy + +theanologger = logging.getLogger('theano.configdefaults') +theanologger.setLevel(logging.CRITICAL) +theano = import_module('theano') +theanologger.setLevel(logging.WARNING) + + +if theano: + import numpy as np + ts = theano.scalar + tt = theano.tensor + xt, yt, zt = [tt.scalar(name, 'floatX') for name in 'xyz'] + Xt, Yt, Zt = [tt.tensor('floatX', (False, False), name=n) for n in 'XYZ'] +else: + #bin/test will not execute any tests now + disabled = True + +import sympy as sy +from sympy.core.singleton import S +from sympy.abc import x, y, z, t +from sympy.printing.theanocode import (theano_code, dim_handling, + theano_function) + + +# Default set of matrix symbols for testing - make square so we can both +# multiply and perform elementwise operations between them. +X, Y, Z = [sy.MatrixSymbol(n, 4, 4) for n in 'XYZ'] + +# For testing AppliedUndef +f_t = sy.Function('f')(t) + + +def theano_code_(expr, **kwargs): + """ Wrapper for theano_code that uses a new, empty cache by default. """ + kwargs.setdefault('cache', {}) + with warns_deprecated_sympy(): + return theano_code(expr, **kwargs) + +def theano_function_(inputs, outputs, **kwargs): + """ Wrapper for theano_function that uses a new, empty cache by default. """ + kwargs.setdefault('cache', {}) + with warns_deprecated_sympy(): + return theano_function(inputs, outputs, **kwargs) + + +def fgraph_of(*exprs): + """ Transform SymPy expressions into Theano Computation. + + Parameters + ========== + exprs + SymPy expressions + + Returns + ======= + theano.gof.FunctionGraph + """ + outs = list(map(theano_code_, exprs)) + ins = theano.gof.graph.inputs(outs) + ins, outs = theano.gof.graph.clone(ins, outs) + return theano.gof.FunctionGraph(ins, outs) + + +def theano_simplify(fgraph): + """ Simplify a Theano Computation. + + Parameters + ========== + fgraph : theano.gof.FunctionGraph + + Returns + ======= + theano.gof.FunctionGraph + """ + mode = theano.compile.get_default_mode().excluding("fusion") + fgraph = fgraph.clone() + mode.optimizer.optimize(fgraph) + return fgraph + + +def theq(a, b): + """ Test two Theano objects for equality. + + Also accepts numeric types and lists/tuples of supported types. + + Note - debugprint() has a bug where it will accept numeric types but does + not respect the "file" argument and in this case and instead prints the number + to stdout and returns an empty string. This can lead to tests passing where + they should fail because any two numbers will always compare as equal. To + prevent this we treat numbers as a separate case. + """ + numeric_types = (int, float, np.number) + a_is_num = isinstance(a, numeric_types) + b_is_num = isinstance(b, numeric_types) + + # Compare numeric types using regular equality + if a_is_num or b_is_num: + if not (a_is_num and b_is_num): + return False + + return a == b + + # Compare sequences element-wise + a_is_seq = isinstance(a, (tuple, list)) + b_is_seq = isinstance(b, (tuple, list)) + + if a_is_seq or b_is_seq: + if not (a_is_seq and b_is_seq) or type(a) != type(b): + return False + + return list(map(theq, a)) == list(map(theq, b)) + + # Otherwise, assume debugprint() can handle it + astr = theano.printing.debugprint(a, file='str') + bstr = theano.printing.debugprint(b, file='str') + + # Check for bug mentioned above + for argname, argval, argstr in [('a', a, astr), ('b', b, bstr)]: + if argstr == '': + raise TypeError( + 'theano.printing.debugprint(%s) returned empty string ' + '(%s is instance of %r)' + % (argname, argname, type(argval)) + ) + + return astr == bstr + + +def test_example_symbols(): + """ + Check that the example symbols in this module print to their Theano + equivalents, as many of the other tests depend on this. + """ + assert theq(xt, theano_code_(x)) + assert theq(yt, theano_code_(y)) + assert theq(zt, theano_code_(z)) + assert theq(Xt, theano_code_(X)) + assert theq(Yt, theano_code_(Y)) + assert theq(Zt, theano_code_(Z)) + + +def test_Symbol(): + """ Test printing a Symbol to a theano variable. """ + xx = theano_code_(x) + assert isinstance(xx, (tt.TensorVariable, ts.ScalarVariable)) + assert xx.broadcastable == () + assert xx.name == x.name + + xx2 = theano_code_(x, broadcastables={x: (False,)}) + assert xx2.broadcastable == (False,) + assert xx2.name == x.name + +def test_MatrixSymbol(): + """ Test printing a MatrixSymbol to a theano variable. """ + XX = theano_code_(X) + assert isinstance(XX, tt.TensorVariable) + assert XX.broadcastable == (False, False) + +@SKIP # TODO - this is currently not checked but should be implemented +def test_MatrixSymbol_wrong_dims(): + """ Test MatrixSymbol with invalid broadcastable. """ + bcs = [(), (False,), (True,), (True, False), (False, True,), (True, True)] + for bc in bcs: + with raises(ValueError): + theano_code_(X, broadcastables={X: bc}) + +def test_AppliedUndef(): + """ Test printing AppliedUndef instance, which works similarly to Symbol. """ + ftt = theano_code_(f_t) + assert isinstance(ftt, tt.TensorVariable) + assert ftt.broadcastable == () + assert ftt.name == 'f_t' + + +def test_add(): + expr = x + y + comp = theano_code_(expr) + assert comp.owner.op == theano.tensor.add + +def test_trig(): + assert theq(theano_code_(sy.sin(x)), tt.sin(xt)) + assert theq(theano_code_(sy.tan(x)), tt.tan(xt)) + +def test_many(): + """ Test printing a complex expression with multiple symbols. """ + expr = sy.exp(x**2 + sy.cos(y)) * sy.log(2*z) + comp = theano_code_(expr) + expected = tt.exp(xt**2 + tt.cos(yt)) * tt.log(2*zt) + assert theq(comp, expected) + + +def test_dtype(): + """ Test specifying specific data types through the dtype argument. """ + for dtype in ['float32', 'float64', 'int8', 'int16', 'int32', 'int64']: + assert theano_code_(x, dtypes={x: dtype}).type.dtype == dtype + + # "floatX" type + assert theano_code_(x, dtypes={x: 'floatX'}).type.dtype in ('float32', 'float64') + + # Type promotion + assert theano_code_(x + 1, dtypes={x: 'float32'}).type.dtype == 'float32' + assert theano_code_(x + y, dtypes={x: 'float64', y: 'float32'}).type.dtype == 'float64' + + +def test_broadcastables(): + """ Test the "broadcastables" argument when printing symbol-like objects. """ + + # No restrictions on shape + for s in [x, f_t]: + for bc in [(), (False,), (True,), (False, False), (True, False)]: + assert theano_code_(s, broadcastables={s: bc}).broadcastable == bc + + # TODO - matrix broadcasting? + +def test_broadcasting(): + """ Test "broadcastable" attribute after applying element-wise binary op. """ + + expr = x + y + + cases = [ + [(), (), ()], + [(False,), (False,), (False,)], + [(True,), (False,), (False,)], + [(False, True), (False, False), (False, False)], + [(True, False), (False, False), (False, False)], + ] + + for bc1, bc2, bc3 in cases: + comp = theano_code_(expr, broadcastables={x: bc1, y: bc2}) + assert comp.broadcastable == bc3 + + +def test_MatMul(): + expr = X*Y*Z + expr_t = theano_code_(expr) + assert isinstance(expr_t.owner.op, tt.Dot) + assert theq(expr_t, Xt.dot(Yt).dot(Zt)) + +def test_Transpose(): + assert isinstance(theano_code_(X.T).owner.op, tt.DimShuffle) + +def test_MatAdd(): + expr = X+Y+Z + assert isinstance(theano_code_(expr).owner.op, tt.Elemwise) + + +def test_Rationals(): + assert theq(theano_code_(sy.Integer(2) / 3), tt.true_div(2, 3)) + assert theq(theano_code_(S.Half), tt.true_div(1, 2)) + +def test_Integers(): + assert theano_code_(sy.Integer(3)) == 3 + +def test_factorial(): + n = sy.Symbol('n') + assert theano_code_(sy.factorial(n)) + +def test_Derivative(): + simp = lambda expr: theano_simplify(fgraph_of(expr)) + assert theq(simp(theano_code_(sy.Derivative(sy.sin(x), x, evaluate=False))), + simp(theano.grad(tt.sin(xt), xt))) + + +def test_theano_function_simple(): + """ Test theano_function() with single output. """ + f = theano_function_([x, y], [x+y]) + assert f(2, 3) == 5 + +def test_theano_function_multi(): + """ Test theano_function() with multiple outputs. """ + f = theano_function_([x, y], [x+y, x-y]) + o1, o2 = f(2, 3) + assert o1 == 5 + assert o2 == -1 + +def test_theano_function_numpy(): + """ Test theano_function() vs Numpy implementation. """ + f = theano_function_([x, y], [x+y], dim=1, + dtypes={x: 'float64', y: 'float64'}) + assert np.linalg.norm(f([1, 2], [3, 4]) - np.asarray([4, 6])) < 1e-9 + + f = theano_function_([x, y], [x+y], dtypes={x: 'float64', y: 'float64'}, + dim=1) + xx = np.arange(3).astype('float64') + yy = 2*np.arange(3).astype('float64') + assert np.linalg.norm(f(xx, yy) - 3*np.arange(3)) < 1e-9 + + +def test_theano_function_matrix(): + m = sy.Matrix([[x, y], [z, x + y + z]]) + expected = np.array([[1.0, 2.0], [3.0, 1.0 + 2.0 + 3.0]]) + f = theano_function_([x, y, z], [m]) + np.testing.assert_allclose(f(1.0, 2.0, 3.0), expected) + f = theano_function_([x, y, z], [m], scalar=True) + np.testing.assert_allclose(f(1.0, 2.0, 3.0), expected) + f = theano_function_([x, y, z], [m, m]) + assert isinstance(f(1.0, 2.0, 3.0), type([])) + np.testing.assert_allclose(f(1.0, 2.0, 3.0)[0], expected) + np.testing.assert_allclose(f(1.0, 2.0, 3.0)[1], expected) + +def test_dim_handling(): + assert dim_handling([x], dim=2) == {x: (False, False)} + assert dim_handling([x, y], dims={x: 1, y: 2}) == {x: (False, True), + y: (False, False)} + assert dim_handling([x], broadcastables={x: (False,)}) == {x: (False,)} + +def test_theano_function_kwargs(): + """ + Test passing additional kwargs from theano_function() to theano.function(). + """ + import numpy as np + f = theano_function_([x, y, z], [x+y], dim=1, on_unused_input='ignore', + dtypes={x: 'float64', y: 'float64', z: 'float64'}) + assert np.linalg.norm(f([1, 2], [3, 4], [0, 0]) - np.asarray([4, 6])) < 1e-9 + + f = theano_function_([x, y, z], [x+y], + dtypes={x: 'float64', y: 'float64', z: 'float64'}, + dim=1, on_unused_input='ignore') + xx = np.arange(3).astype('float64') + yy = 2*np.arange(3).astype('float64') + zz = 2*np.arange(3).astype('float64') + assert np.linalg.norm(f(xx, yy, zz) - 3*np.arange(3)) < 1e-9 + +def test_theano_function_scalar(): + """ Test the "scalar" argument to theano_function(). """ + + args = [ + ([x, y], [x + y], None, [0]), # Single 0d output + ([X, Y], [X + Y], None, [2]), # Single 2d output + ([x, y], [x + y], {x: 0, y: 1}, [1]), # Single 1d output + ([x, y], [x + y, x - y], None, [0, 0]), # Two 0d outputs + ([x, y, X, Y], [x + y, X + Y], None, [0, 2]), # One 0d output, one 2d + ] + + # Create and test functions with and without the scalar setting + for inputs, outputs, in_dims, out_dims in args: + for scalar in [False, True]: + + f = theano_function_(inputs, outputs, dims=in_dims, scalar=scalar) + + # Check the theano_function attribute is set whether wrapped or not + assert isinstance(f.theano_function, theano.compile.function_module.Function) + + # Feed in inputs of the appropriate size and get outputs + in_values = [ + np.ones([1 if bc else 5 for bc in i.type.broadcastable]) + for i in f.theano_function.input_storage + ] + out_values = f(*in_values) + if not isinstance(out_values, list): + out_values = [out_values] + + # Check output types and shapes + assert len(out_dims) == len(out_values) + for d, value in zip(out_dims, out_values): + + if scalar and d == 0: + # Should have been converted to a scalar value + assert isinstance(value, np.number) + + else: + # Otherwise should be an array + assert isinstance(value, np.ndarray) + assert value.ndim == d + +def test_theano_function_bad_kwarg(): + """ + Passing an unknown keyword argument to theano_function() should raise an + exception. + """ + raises(Exception, lambda : theano_function_([x], [x+1], foobar=3)) + + +def test_slice(): + assert theano_code_(slice(1, 2, 3)) == slice(1, 2, 3) + + def theq_slice(s1, s2): + for attr in ['start', 'stop', 'step']: + a1 = getattr(s1, attr) + a2 = getattr(s2, attr) + if a1 is None or a2 is None: + if not (a1 is None or a2 is None): + return False + elif not theq(a1, a2): + return False + return True + + dtypes = {x: 'int32', y: 'int32'} + assert theq_slice(theano_code_(slice(x, y), dtypes=dtypes), slice(xt, yt)) + assert theq_slice(theano_code_(slice(1, x, 3), dtypes=dtypes), slice(1, xt, 3)) + +def test_MatrixSlice(): + from theano import Constant + + cache = {} + + n = sy.Symbol('n', integer=True) + X = sy.MatrixSymbol('X', n, n) + + Y = X[1:2:3, 4:5:6] + Yt = theano_code_(Y, cache=cache) + + s = ts.Scalar('int64') + assert tuple(Yt.owner.op.idx_list) == (slice(s, s, s), slice(s, s, s)) + assert Yt.owner.inputs[0] == theano_code_(X, cache=cache) + # == doesn't work in theano like it does in SymPy. You have to use + # equals. + assert all(Yt.owner.inputs[i].equals(Constant(s, i)) for i in range(1, 7)) + + k = sy.Symbol('k') + theano_code_(k, dtypes={k: 'int32'}) + start, stop, step = 4, k, 2 + Y = X[start:stop:step] + Yt = theano_code_(Y, dtypes={n: 'int32', k: 'int32'}) + # assert Yt.owner.op.idx_list[0].stop == kt + +def test_BlockMatrix(): + n = sy.Symbol('n', integer=True) + A, B, C, D = [sy.MatrixSymbol(name, n, n) for name in 'ABCD'] + At, Bt, Ct, Dt = map(theano_code_, (A, B, C, D)) + Block = sy.BlockMatrix([[A, B], [C, D]]) + Blockt = theano_code_(Block) + solutions = [tt.join(0, tt.join(1, At, Bt), tt.join(1, Ct, Dt)), + tt.join(1, tt.join(0, At, Ct), tt.join(0, Bt, Dt))] + assert any(theq(Blockt, solution) for solution in solutions) + +@SKIP +def test_BlockMatrix_Inverse_execution(): + k, n = 2, 4 + dtype = 'float32' + A = sy.MatrixSymbol('A', n, k) + B = sy.MatrixSymbol('B', n, n) + inputs = A, B + output = B.I*A + + cutsizes = {A: [(n//2, n//2), (k//2, k//2)], + B: [(n//2, n//2), (n//2, n//2)]} + cutinputs = [sy.blockcut(i, *cutsizes[i]) for i in inputs] + cutoutput = output.subs(dict(zip(inputs, cutinputs))) + + dtypes = dict(zip(inputs, [dtype]*len(inputs))) + f = theano_function_(inputs, [output], dtypes=dtypes, cache={}) + fblocked = theano_function_(inputs, [sy.block_collapse(cutoutput)], + dtypes=dtypes, cache={}) + + ninputs = [np.random.rand(*x.shape).astype(dtype) for x in inputs] + ninputs = [np.arange(n*k).reshape(A.shape).astype(dtype), + np.eye(n).astype(dtype)] + ninputs[1] += np.ones(B.shape)*1e-5 + + assert np.allclose(f(*ninputs), fblocked(*ninputs), rtol=1e-5) + +def test_DenseMatrix(): + t = sy.Symbol('theta') + for MatrixType in [sy.Matrix, sy.ImmutableMatrix]: + X = MatrixType([[sy.cos(t), -sy.sin(t)], [sy.sin(t), sy.cos(t)]]) + tX = theano_code_(X) + assert isinstance(tX, tt.TensorVariable) + assert tX.owner.op == tt.join_ + + +def test_cache_basic(): + """ Test single symbol-like objects are cached when printed by themselves. """ + + # Pairs of objects which should be considered equivalent with respect to caching + pairs = [ + (x, sy.Symbol('x')), + (X, sy.MatrixSymbol('X', *X.shape)), + (f_t, sy.Function('f')(sy.Symbol('t'))), + ] + + for s1, s2 in pairs: + cache = {} + st = theano_code_(s1, cache=cache) + + # Test hit with same instance + assert theano_code_(s1, cache=cache) is st + + # Test miss with same instance but new cache + assert theano_code_(s1, cache={}) is not st + + # Test hit with different but equivalent instance + assert theano_code_(s2, cache=cache) is st + +def test_global_cache(): + """ Test use of the global cache. """ + from sympy.printing.theanocode import global_cache + + backup = dict(global_cache) + try: + # Temporarily empty global cache + global_cache.clear() + + for s in [x, X, f_t]: + with warns_deprecated_sympy(): + st = theano_code(s) + assert theano_code(s) is st + + finally: + # Restore global cache + global_cache.update(backup) + +def test_cache_types_distinct(): + """ + Test that symbol-like objects of different types (Symbol, MatrixSymbol, + AppliedUndef) are distinguished by the cache even if they have the same + name. + """ + symbols = [sy.Symbol('f_t'), sy.MatrixSymbol('f_t', 4, 4), f_t] + + cache = {} # Single shared cache + printed = {} + + for s in symbols: + st = theano_code_(s, cache=cache) + assert st not in printed.values() + printed[s] = st + + # Check all printed objects are distinct + assert len(set(map(id, printed.values()))) == len(symbols) + + # Check retrieving + for s, st in printed.items(): + with warns_deprecated_sympy(): + assert theano_code(s, cache=cache) is st + +def test_symbols_are_created_once(): + """ + Test that a symbol is cached and reused when it appears in an expression + more than once. + """ + expr = sy.Add(x, x, evaluate=False) + comp = theano_code_(expr) + + assert theq(comp, xt + xt) + assert not theq(comp, xt + theano_code_(x)) + +def test_cache_complex(): + """ + Test caching on a complicated expression with multiple symbols appearing + multiple times. + """ + expr = x ** 2 + (y - sy.exp(x)) * sy.sin(z - x * y) + symbol_names = {s.name for s in expr.free_symbols} + expr_t = theano_code_(expr) + + # Iterate through variables in the Theano computational graph that the + # printed expression depends on + seen = set() + for v in theano.gof.graph.ancestors([expr_t]): + # Owner-less, non-constant variables should be our symbols + if v.owner is None and not isinstance(v, theano.gof.graph.Constant): + # Check it corresponds to a symbol and appears only once + assert v.name in symbol_names + assert v.name not in seen + seen.add(v.name) + + # Check all were present + assert seen == symbol_names + + +def test_Piecewise(): + # A piecewise linear + expr = sy.Piecewise((0, x<0), (x, x<2), (1, True)) # ___/III + result = theano_code_(expr) + assert result.owner.op == tt.switch + + expected = tt.switch(xt<0, 0, tt.switch(xt<2, xt, 1)) + assert theq(result, expected) + + expr = sy.Piecewise((x, x < 0)) + result = theano_code_(expr) + expected = tt.switch(xt < 0, xt, np.nan) + assert theq(result, expected) + + expr = sy.Piecewise((0, sy.And(x>0, x<2)), \ + (x, sy.Or(x>2, x<0))) + result = theano_code_(expr) + expected = tt.switch(tt.and_(xt>0,xt<2), 0, \ + tt.switch(tt.or_(xt>2, xt<0), xt, np.nan)) + assert theq(result, expected) + + +def test_Relationals(): + assert theq(theano_code_(sy.Eq(x, y)), tt.eq(xt, yt)) + # assert theq(theano_code_(sy.Ne(x, y)), tt.neq(xt, yt)) # TODO - implement + assert theq(theano_code_(x > y), xt > yt) + assert theq(theano_code_(x < y), xt < yt) + assert theq(theano_code_(x >= y), xt >= yt) + assert theq(theano_code_(x <= y), xt <= yt) + + +def test_complexfunctions(): + with warns_deprecated_sympy(): + xt, yt = theano_code_(x, dtypes={x:'complex128'}), theano_code_(y, dtypes={y: 'complex128'}) + from sympy.functions.elementary.complexes import conjugate + from theano.tensor import as_tensor_variable as atv + from theano.tensor import complex as cplx + with warns_deprecated_sympy(): + assert theq(theano_code_(y*conjugate(x)), yt*(xt.conj())) + assert theq(theano_code_((1+2j)*x), xt*(atv(1.0)+atv(2.0)*cplx(0,1))) + + +def test_constantfunctions(): + with warns_deprecated_sympy(): + tf = theano_function_([],[1+1j]) + assert(tf()==1+1j) + + +def test_Exp1(): + """ + Test that exp(1) prints without error and evaluates close to SymPy's E + """ + # sy.exp(1) should yield same instance of E as sy.E (singleton), but extra + # check added for sanity + e_a = sy.exp(1) + e_b = sy.E + + np.testing.assert_allclose(float(e_a), np.e) + np.testing.assert_allclose(float(e_b), np.e) + + e = theano_code_(e_a) + np.testing.assert_allclose(float(e_a), e.eval()) + + e = theano_code_(e_b) + np.testing.assert_allclose(float(e_b), e.eval()) diff --git a/lib/python3.10/site-packages/sympy/printing/tests/test_tree.py b/lib/python3.10/site-packages/sympy/printing/tests/test_tree.py new file mode 100644 index 0000000000000000000000000000000000000000..cf116d0cac5d38f225815fcd2d4ac90cd0dd96d7 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/printing/tests/test_tree.py @@ -0,0 +1,196 @@ +from sympy.printing.tree import tree +from sympy.testing.pytest import XFAIL + + +# Remove this flag after making _assumptions cache deterministic. +@XFAIL +def test_print_tree_MatAdd(): + from sympy.matrices.expressions import MatrixSymbol + A = MatrixSymbol('A', 3, 3) + B = MatrixSymbol('B', 3, 3) + + test_str = [ + 'MatAdd: A + B\n', + 'algebraic: False\n', + 'commutative: False\n', + 'complex: False\n', + 'composite: False\n', + 'even: False\n', + 'extended_negative: False\n', + 'extended_nonnegative: False\n', + 'extended_nonpositive: False\n', + 'extended_nonzero: False\n', + 'extended_positive: False\n', + 'extended_real: False\n', + 'imaginary: False\n', + 'integer: False\n', + 'irrational: False\n', + 'negative: False\n', + 'noninteger: False\n', + 'nonnegative: False\n', + 'nonpositive: False\n', + 'nonzero: False\n', + 'odd: False\n', + 'positive: False\n', + 'prime: False\n', + 'rational: False\n', + 'real: False\n', + 'transcendental: False\n', + 'zero: False\n', + '+-MatrixSymbol: A\n', + '| algebraic: False\n', + '| commutative: False\n', + '| complex: False\n', + '| composite: False\n', + '| even: False\n', + '| extended_negative: False\n', + '| extended_nonnegative: False\n', + '| extended_nonpositive: False\n', + '| extended_nonzero: False\n', + '| extended_positive: False\n', + '| extended_real: False\n', + '| imaginary: False\n', + '| integer: False\n', + '| irrational: False\n', + '| negative: False\n', + '| noninteger: False\n', + '| nonnegative: False\n', + '| nonpositive: False\n', + '| nonzero: False\n', + '| odd: False\n', + '| positive: False\n', + '| prime: False\n', + '| rational: False\n', + '| real: False\n', + '| transcendental: False\n', + '| zero: False\n', + '| +-Symbol: A\n', + '| | commutative: True\n', + '| +-Integer: 3\n', + '| | algebraic: True\n', + '| | commutative: True\n', + '| | complex: True\n', + '| | extended_negative: False\n', + '| | extended_nonnegative: True\n', + '| | extended_real: True\n', + '| | finite: True\n', + '| | hermitian: True\n', + '| | imaginary: False\n', + '| | infinite: False\n', + '| | integer: True\n', + '| | irrational: False\n', + '| | negative: False\n', + '| | noninteger: False\n', + '| | nonnegative: True\n', + '| | rational: True\n', + '| | real: True\n', + '| | transcendental: False\n', + '| +-Integer: 3\n', + '| algebraic: True\n', + '| commutative: True\n', + '| complex: True\n', + '| extended_negative: False\n', + '| extended_nonnegative: True\n', + '| extended_real: True\n', + '| finite: True\n', + '| hermitian: True\n', + '| imaginary: False\n', + '| infinite: False\n', + '| integer: True\n', + '| irrational: False\n', + '| negative: False\n', + '| noninteger: False\n', + '| nonnegative: True\n', + '| rational: True\n', + '| real: True\n', + '| transcendental: False\n', + '+-MatrixSymbol: B\n', + ' algebraic: False\n', + ' commutative: False\n', + ' complex: False\n', + ' composite: False\n', + ' even: False\n', + ' extended_negative: False\n', + ' extended_nonnegative: False\n', + ' extended_nonpositive: False\n', + ' extended_nonzero: False\n', + ' extended_positive: False\n', + ' extended_real: False\n', + ' imaginary: False\n', + ' integer: False\n', + ' irrational: False\n', + ' negative: False\n', + ' noninteger: False\n', + ' nonnegative: False\n', + ' nonpositive: False\n', + ' nonzero: False\n', + ' odd: False\n', + ' positive: False\n', + ' prime: False\n', + ' rational: False\n', + ' real: False\n', + ' transcendental: False\n', + ' zero: False\n', + ' +-Symbol: B\n', + ' | commutative: True\n', + ' +-Integer: 3\n', + ' | algebraic: True\n', + ' | commutative: True\n', + ' | complex: True\n', + ' | extended_negative: False\n', + ' | extended_nonnegative: True\n', + ' | extended_real: True\n', + ' | finite: True\n', + ' | hermitian: True\n', + ' | imaginary: False\n', + ' | infinite: False\n', + ' | integer: True\n', + ' | irrational: False\n', + ' | negative: False\n', + ' | noninteger: False\n', + ' | nonnegative: True\n', + ' | rational: True\n', + ' | real: True\n', + ' | transcendental: False\n', + ' +-Integer: 3\n', + ' algebraic: True\n', + ' commutative: True\n', + ' complex: True\n', + ' extended_negative: False\n', + ' extended_nonnegative: True\n', + ' extended_real: True\n', + ' finite: True\n', + ' hermitian: True\n', + ' imaginary: False\n', + ' infinite: False\n', + ' integer: True\n', + ' irrational: False\n', + ' negative: False\n', + ' noninteger: False\n', + ' nonnegative: True\n', + ' rational: True\n', + ' real: True\n', + ' transcendental: False\n' + ] + + assert tree(A + B) == "".join(test_str) + + +def test_print_tree_MatAdd_noassumptions(): + from sympy.matrices.expressions import MatrixSymbol + A = MatrixSymbol('A', 3, 3) + B = MatrixSymbol('B', 3, 3) + + test_str = \ +"""MatAdd: A + B ++-MatrixSymbol: A +| +-Str: A +| +-Integer: 3 +| +-Integer: 3 ++-MatrixSymbol: B + +-Str: B + +-Integer: 3 + +-Integer: 3 +""" + + assert tree(A + B, assumptions=False) == test_str diff --git a/lib/python3.10/site-packages/sympy/sandbox/__pycache__/__init__.cpython-310.pyc b/lib/python3.10/site-packages/sympy/sandbox/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5d2cb5b78221449919934e3f9b5f604d4cc0637a Binary files /dev/null and b/lib/python3.10/site-packages/sympy/sandbox/__pycache__/__init__.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/sandbox/__pycache__/indexed_integrals.cpython-310.pyc b/lib/python3.10/site-packages/sympy/sandbox/__pycache__/indexed_integrals.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e6edd7d606c7097d308cfab2d7ef990330ccd030 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/sandbox/__pycache__/indexed_integrals.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/sandbox/tests/__init__.py b/lib/python3.10/site-packages/sympy/sandbox/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/lib/python3.10/site-packages/sympy/sandbox/tests/__pycache__/__init__.cpython-310.pyc b/lib/python3.10/site-packages/sympy/sandbox/tests/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..986bb1b07257c868edcc9842bc18f330f2b45d11 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/sandbox/tests/__pycache__/__init__.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/sandbox/tests/__pycache__/test_indexed_integrals.cpython-310.pyc b/lib/python3.10/site-packages/sympy/sandbox/tests/__pycache__/test_indexed_integrals.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7d53fdb4659bada49d2220c35411d9b720e961d6 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/sandbox/tests/__pycache__/test_indexed_integrals.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/sandbox/tests/test_indexed_integrals.py b/lib/python3.10/site-packages/sympy/sandbox/tests/test_indexed_integrals.py new file mode 100644 index 0000000000000000000000000000000000000000..61b98f0ffec29e026f6dfe8e16fde8b5818b0b09 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/sandbox/tests/test_indexed_integrals.py @@ -0,0 +1,25 @@ +from sympy.sandbox.indexed_integrals import IndexedIntegral +from sympy.core.symbol import symbols +from sympy.functions.elementary.trigonometric import (cos, sin) +from sympy.tensor.indexed import (Idx, IndexedBase) + + +def test_indexed_integrals(): + A = IndexedBase('A') + i, j = symbols('i j', integer=True) + a1, a2 = symbols('a1:3', cls=Idx) + assert isinstance(a1, Idx) + + assert IndexedIntegral(1, A[i]).doit() == A[i] + assert IndexedIntegral(A[i], A[i]).doit() == A[i] ** 2 / 2 + assert IndexedIntegral(A[j], A[i]).doit() == A[i] * A[j] + assert IndexedIntegral(A[i] * A[j], A[i]).doit() == A[i] ** 2 * A[j] / 2 + assert IndexedIntegral(sin(A[i]), A[i]).doit() == -cos(A[i]) + assert IndexedIntegral(sin(A[j]), A[i]).doit() == sin(A[j]) * A[i] + + assert IndexedIntegral(1, A[a1]).doit() == A[a1] + assert IndexedIntegral(A[a1], A[a1]).doit() == A[a1] ** 2 / 2 + assert IndexedIntegral(A[a2], A[a1]).doit() == A[a1] * A[a2] + assert IndexedIntegral(A[a1] * A[a2], A[a1]).doit() == A[a1] ** 2 * A[a2] / 2 + assert IndexedIntegral(sin(A[a1]), A[a1]).doit() == -cos(A[a1]) + assert IndexedIntegral(sin(A[a2]), A[a1]).doit() == sin(A[a2]) * A[a1] diff --git a/lib/python3.10/site-packages/sympy/series/__pycache__/__init__.cpython-310.pyc b/lib/python3.10/site-packages/sympy/series/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..203c7c72d0de94aa97c2bc468349d555bb971238 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/series/__pycache__/__init__.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/series/__pycache__/acceleration.cpython-310.pyc b/lib/python3.10/site-packages/sympy/series/__pycache__/acceleration.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1ac9d1d9e7f565e9e28ef892125ad838825b4497 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/series/__pycache__/acceleration.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/series/__pycache__/approximants.cpython-310.pyc b/lib/python3.10/site-packages/sympy/series/__pycache__/approximants.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e671a8d37b9deb8ce50d76372ccec603046335c6 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/series/__pycache__/approximants.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/series/__pycache__/aseries.cpython-310.pyc b/lib/python3.10/site-packages/sympy/series/__pycache__/aseries.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a8c5c54cafd5d528544ae01da45d632a7da17cc9 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/series/__pycache__/aseries.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/series/__pycache__/formal.cpython-310.pyc b/lib/python3.10/site-packages/sympy/series/__pycache__/formal.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a1e2b5017d5f99a7d980ad8ced5c983a1134327a Binary files /dev/null and b/lib/python3.10/site-packages/sympy/series/__pycache__/formal.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/series/__pycache__/fourier.cpython-310.pyc b/lib/python3.10/site-packages/sympy/series/__pycache__/fourier.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..17f9c1382256c6dc0d70166ff0623f0dc27bba79 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/series/__pycache__/fourier.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/series/__pycache__/gruntz.cpython-310.pyc b/lib/python3.10/site-packages/sympy/series/__pycache__/gruntz.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..daeadb6dbd1c30a44ed6c89df5e900f0f6a9db50 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/series/__pycache__/gruntz.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/series/__pycache__/kauers.cpython-310.pyc b/lib/python3.10/site-packages/sympy/series/__pycache__/kauers.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a91e198734685445623665a49bb4491ce0c61dd4 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/series/__pycache__/kauers.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/series/__pycache__/limits.cpython-310.pyc b/lib/python3.10/site-packages/sympy/series/__pycache__/limits.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5d84e03f6aa43cc5dfdb94e618b6102543063752 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/series/__pycache__/limits.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/series/__pycache__/limitseq.cpython-310.pyc b/lib/python3.10/site-packages/sympy/series/__pycache__/limitseq.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fc21618e4f1e2c2abe27116b600a59776ebe19e8 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/series/__pycache__/limitseq.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/series/__pycache__/order.cpython-310.pyc b/lib/python3.10/site-packages/sympy/series/__pycache__/order.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1e3db5d08ef70ff4530cf02049ef51703c44e07d Binary files /dev/null and b/lib/python3.10/site-packages/sympy/series/__pycache__/order.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/series/__pycache__/residues.cpython-310.pyc b/lib/python3.10/site-packages/sympy/series/__pycache__/residues.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..941368234f3691479f5cfe7c77a8d63555e687a6 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/series/__pycache__/residues.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/series/__pycache__/sequences.cpython-310.pyc b/lib/python3.10/site-packages/sympy/series/__pycache__/sequences.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b58b291210183690809f012a87a3afea1671fb0b Binary files /dev/null and b/lib/python3.10/site-packages/sympy/series/__pycache__/sequences.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/series/__pycache__/series.cpython-310.pyc b/lib/python3.10/site-packages/sympy/series/__pycache__/series.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..45fed8a034da531715a5bf850905be4d7d1995e4 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/series/__pycache__/series.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/series/__pycache__/series_class.cpython-310.pyc b/lib/python3.10/site-packages/sympy/series/__pycache__/series_class.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1833c081c15057fbee722b32de049c2590129f7d Binary files /dev/null and b/lib/python3.10/site-packages/sympy/series/__pycache__/series_class.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/series/benchmarks/__init__.py b/lib/python3.10/site-packages/sympy/series/benchmarks/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/lib/python3.10/site-packages/sympy/series/benchmarks/__pycache__/__init__.cpython-310.pyc b/lib/python3.10/site-packages/sympy/series/benchmarks/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3f48f4aefb2a82bc50451ec9e5226fff8e364baf Binary files /dev/null and b/lib/python3.10/site-packages/sympy/series/benchmarks/__pycache__/__init__.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/series/benchmarks/__pycache__/bench_limit.cpython-310.pyc b/lib/python3.10/site-packages/sympy/series/benchmarks/__pycache__/bench_limit.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9d2eb982fbde38a006fa623dd9b52816f0d76596 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/series/benchmarks/__pycache__/bench_limit.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/series/benchmarks/__pycache__/bench_order.cpython-310.pyc b/lib/python3.10/site-packages/sympy/series/benchmarks/__pycache__/bench_order.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b9de2595da5847afe0d9657ceab69dc93a5bc55a Binary files /dev/null and b/lib/python3.10/site-packages/sympy/series/benchmarks/__pycache__/bench_order.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/series/benchmarks/bench_limit.py b/lib/python3.10/site-packages/sympy/series/benchmarks/bench_limit.py new file mode 100644 index 0000000000000000000000000000000000000000..eafc28328848dad4b3ea433537971f5785253afe --- /dev/null +++ b/lib/python3.10/site-packages/sympy/series/benchmarks/bench_limit.py @@ -0,0 +1,9 @@ +from sympy.core.numbers import oo +from sympy.core.symbol import Symbol +from sympy.series.limits import limit + +x = Symbol('x') + + +def timeit_limit_1x(): + limit(1/x, x, oo) diff --git a/lib/python3.10/site-packages/sympy/series/benchmarks/bench_order.py b/lib/python3.10/site-packages/sympy/series/benchmarks/bench_order.py new file mode 100644 index 0000000000000000000000000000000000000000..1c85fa173dfc2a478792de8ab816c23ba9d408ef --- /dev/null +++ b/lib/python3.10/site-packages/sympy/series/benchmarks/bench_order.py @@ -0,0 +1,10 @@ +from sympy.core.add import Add +from sympy.core.symbol import Symbol +from sympy.series.order import O + +x = Symbol('x') +l = [x**i for i in range(1000)] +l.append(O(x**1001)) + +def timeit_order_1x(): + Add(*l) diff --git a/lib/python3.10/site-packages/sympy/series/tests/__init__.py b/lib/python3.10/site-packages/sympy/series/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/lib/python3.10/site-packages/sympy/series/tests/__pycache__/__init__.cpython-310.pyc b/lib/python3.10/site-packages/sympy/series/tests/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..16b56bcff4149b0f36b1ffdc0a6b7ebea4dca4f0 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/series/tests/__pycache__/__init__.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/series/tests/__pycache__/test_approximants.cpython-310.pyc b/lib/python3.10/site-packages/sympy/series/tests/__pycache__/test_approximants.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..be70900ee1e5f7c5647c117394640605cd55a41a Binary files /dev/null and b/lib/python3.10/site-packages/sympy/series/tests/__pycache__/test_approximants.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/series/tests/__pycache__/test_aseries.cpython-310.pyc b/lib/python3.10/site-packages/sympy/series/tests/__pycache__/test_aseries.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ac02066ce8e8d706f4b46bb28876b792a6ca4d9e Binary files /dev/null and b/lib/python3.10/site-packages/sympy/series/tests/__pycache__/test_aseries.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/series/tests/__pycache__/test_demidovich.cpython-310.pyc b/lib/python3.10/site-packages/sympy/series/tests/__pycache__/test_demidovich.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8b7ded9c36fc1d2e563de0c433e9c6f9730529c3 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/series/tests/__pycache__/test_demidovich.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/series/tests/__pycache__/test_formal.cpython-310.pyc b/lib/python3.10/site-packages/sympy/series/tests/__pycache__/test_formal.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e8fcc6fb4c6d5c6d379a55f580104c23a2275e53 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/series/tests/__pycache__/test_formal.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/series/tests/__pycache__/test_fourier.cpython-310.pyc b/lib/python3.10/site-packages/sympy/series/tests/__pycache__/test_fourier.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3024cdb14ad59ce44714ff99c1ce95a2eb99b1d9 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/series/tests/__pycache__/test_fourier.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/series/tests/__pycache__/test_gruntz.cpython-310.pyc b/lib/python3.10/site-packages/sympy/series/tests/__pycache__/test_gruntz.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f3b50bf7f88d022728f1727a8f71ddf532e179dc Binary files /dev/null and b/lib/python3.10/site-packages/sympy/series/tests/__pycache__/test_gruntz.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/series/tests/__pycache__/test_kauers.cpython-310.pyc b/lib/python3.10/site-packages/sympy/series/tests/__pycache__/test_kauers.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4e1b0c172e9409ebb7145ec5ab480146a4dcb834 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/series/tests/__pycache__/test_kauers.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/series/tests/__pycache__/test_limits.cpython-310.pyc b/lib/python3.10/site-packages/sympy/series/tests/__pycache__/test_limits.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a1870bd302614d18eea93347d321b46c4cb91d8f Binary files /dev/null and b/lib/python3.10/site-packages/sympy/series/tests/__pycache__/test_limits.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/series/tests/__pycache__/test_limitseq.cpython-310.pyc b/lib/python3.10/site-packages/sympy/series/tests/__pycache__/test_limitseq.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1b11886bfd524984fa10ea65ee728160d7147c29 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/series/tests/__pycache__/test_limitseq.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/series/tests/__pycache__/test_lseries.cpython-310.pyc b/lib/python3.10/site-packages/sympy/series/tests/__pycache__/test_lseries.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d01b389e2914ffc3d1bcace9150941acf01f91c9 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/series/tests/__pycache__/test_lseries.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/series/tests/__pycache__/test_nseries.cpython-310.pyc b/lib/python3.10/site-packages/sympy/series/tests/__pycache__/test_nseries.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..aeb19577a208712f5260c5d83eb35144b473480f Binary files /dev/null and b/lib/python3.10/site-packages/sympy/series/tests/__pycache__/test_nseries.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/series/tests/__pycache__/test_order.cpython-310.pyc b/lib/python3.10/site-packages/sympy/series/tests/__pycache__/test_order.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..56b446222b8811f23850be33918a1bbfc8d2e81f Binary files /dev/null and b/lib/python3.10/site-packages/sympy/series/tests/__pycache__/test_order.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/series/tests/__pycache__/test_residues.cpython-310.pyc b/lib/python3.10/site-packages/sympy/series/tests/__pycache__/test_residues.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6fe8426ecae2d4beab727d0226572e98e59e18b6 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/series/tests/__pycache__/test_residues.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/series/tests/__pycache__/test_sequences.cpython-310.pyc b/lib/python3.10/site-packages/sympy/series/tests/__pycache__/test_sequences.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5801d3e2ae139d3a43437418f6c5af6277772436 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/series/tests/__pycache__/test_sequences.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/series/tests/__pycache__/test_series.cpython-310.pyc b/lib/python3.10/site-packages/sympy/series/tests/__pycache__/test_series.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..de4a67620dabb0811539623a11f946779cfd7f0d Binary files /dev/null and b/lib/python3.10/site-packages/sympy/series/tests/__pycache__/test_series.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/series/tests/test_approximants.py b/lib/python3.10/site-packages/sympy/series/tests/test_approximants.py new file mode 100644 index 0000000000000000000000000000000000000000..9c03d2ce38add99b0dce8725b6c8d8844b31f76b --- /dev/null +++ b/lib/python3.10/site-packages/sympy/series/tests/test_approximants.py @@ -0,0 +1,23 @@ +from sympy.series import approximants +from sympy.core.symbol import symbols +from sympy.functions.combinatorial.factorials import binomial +from sympy.functions.combinatorial.numbers import (fibonacci, lucas) + + +def test_approximants(): + x, t = symbols("x,t") + g = [lucas(k) for k in range(16)] + assert list(approximants(g)) == ( + [2, -4/(x - 2), (5*x - 2)/(3*x - 1), (x - 2)/(x**2 + x - 1)] ) + g = [lucas(k)+fibonacci(k+2) for k in range(16)] + assert list(approximants(g)) == ( + [3, -3/(x - 1), (3*x - 3)/(2*x - 1), -3/(x**2 + x - 1)] ) + g = [lucas(k)**2 for k in range(16)] + assert list(approximants(g)) == ( + [4, -16/(x - 4), (35*x - 4)/(9*x - 1), (37*x - 28)/(13*x**2 + 11*x - 7), + (50*x**2 + 63*x - 52)/(37*x**2 + 19*x - 13), + (-x**2 - 7*x + 4)/(x**3 - 2*x**2 - 2*x + 1)] ) + p = [sum(binomial(k,i)*x**i for i in range(k+1)) for k in range(16)] + y = approximants(p, t, simplify=True) + assert next(y) == 1 + assert next(y) == -1/(t*(x + 1) - 1) diff --git a/lib/python3.10/site-packages/sympy/series/tests/test_aseries.py b/lib/python3.10/site-packages/sympy/series/tests/test_aseries.py new file mode 100644 index 0000000000000000000000000000000000000000..055d6b8aef23212a8c8f19475f537a5a2b9e2b1b --- /dev/null +++ b/lib/python3.10/site-packages/sympy/series/tests/test_aseries.py @@ -0,0 +1,56 @@ +from sympy.core.function import PoleError +from sympy.core.numbers import oo +from sympy.core.singleton import S +from sympy.core.symbol import Symbol +from sympy.functions.elementary.exponential import (exp, log) +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import (cos, sin) +from sympy.series.order import O +from sympy.abc import x + +from sympy.testing.pytest import raises + +def test_simple(): + # Gruntz' theses pp. 91 to 96 + # 6.6 + e = sin(1/x + exp(-x)) - sin(1/x) + assert e.aseries(x) == (1/(24*x**4) - 1/(2*x**2) + 1 + O(x**(-6), (x, oo)))*exp(-x) + + e = exp(x) * (exp(1/x + exp(-x)) - exp(1/x)) + assert e.aseries(x, n=4) == 1/(6*x**3) + 1/(2*x**2) + 1/x + 1 + O(x**(-4), (x, oo)) + + e = exp(exp(x) / (1 - 1/x)) + assert e.aseries(x) == exp(exp(x) / (1 - 1/x)) + + # The implementation of bound in aseries is incorrect currently. This test + # should be commented out when that is fixed. + # assert e.aseries(x, bound=3) == exp(exp(x) / x**2)*exp(exp(x) / x)*exp(-exp(x) + exp(x)/(1 - 1/x) - \ + # exp(x) / x - exp(x) / x**2) * exp(exp(x)) + + e = exp(sin(1/x + exp(-exp(x)))) - exp(sin(1/x)) + assert e.aseries(x, n=4) == (-1/(2*x**3) + 1/x + 1 + O(x**(-4), (x, oo)))*exp(-exp(x)) + + e3 = lambda x:exp(exp(exp(x))) + e = e3(x)/e3(x - 1/e3(x)) + assert e.aseries(x, n=3) == 1 + exp(x + exp(x))*exp(-exp(exp(x)))\ + + ((-exp(x)/2 - S.Half)*exp(x + exp(x))\ + + exp(2*x + 2*exp(x))/2)*exp(-2*exp(exp(x))) + O(exp(-3*exp(exp(x))), (x, oo)) + + e = exp(exp(x)) * (exp(sin(1/x + 1/exp(exp(x)))) - exp(sin(1/x))) + assert e.aseries(x, n=4) == -1/(2*x**3) + 1/x + 1 + O(x**(-4), (x, oo)) + + n = Symbol('n', integer=True) + e = (sqrt(n)*log(n)**2*exp(sqrt(log(n))*log(log(n))**2*exp(sqrt(log(log(n)))*log(log(log(n)))**3)))/n + assert e.aseries(n) == \ + exp(exp(sqrt(log(log(n)))*log(log(log(n)))**3)*sqrt(log(n))*log(log(n))**2)*log(n)**2/sqrt(n) + + +def test_hierarchical(): + e = sin(1/x + exp(-x)) + assert e.aseries(x, n=3, hir=True) == -exp(-2*x)*sin(1/x)/2 + \ + exp(-x)*cos(1/x) + sin(1/x) + O(exp(-3*x), (x, oo)) + + e = sin(x) * cos(exp(-x)) + assert e.aseries(x, hir=True) == exp(-4*x)*sin(x)/24 - \ + exp(-2*x)*sin(x)/2 + sin(x) + O(exp(-6*x), (x, oo)) + raises(PoleError, lambda: e.aseries(x)) diff --git a/lib/python3.10/site-packages/sympy/series/tests/test_demidovich.py b/lib/python3.10/site-packages/sympy/series/tests/test_demidovich.py new file mode 100644 index 0000000000000000000000000000000000000000..98cafbae6f019dd3d97d306099d5780ed2f37f04 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/series/tests/test_demidovich.py @@ -0,0 +1,143 @@ +from sympy.core.numbers import (Rational, oo, pi) +from sympy.core.singleton import S +from sympy.core.symbol import Symbol +from sympy.functions.elementary.exponential import (exp, log) +from sympy.functions.elementary.miscellaneous import (root, sqrt) +from sympy.functions.elementary.trigonometric import (asin, cos, sin, tan) +from sympy.polys.rationaltools import together +from sympy.series.limits import limit + +# Numbers listed with the tests refer to problem numbers in the book +# "Anti-demidovich, problemas resueltos, Ed. URSS" + +x = Symbol("x") + + +def test_leadterm(): + assert (3 + 2*x**(log(3)/log(2) - 1)).leadterm(x) == (3, 0) + + +def root3(x): + return root(x, 3) + + +def root4(x): + return root(x, 4) + + +def test_Limits_simple_0(): + assert limit((2**(x + 1) + 3**(x + 1))/(2**x + 3**x), x, oo) == 3 # 175 + + +def test_Limits_simple_1(): + assert limit((x + 1)*(x + 2)*(x + 3)/x**3, x, oo) == 1 # 172 + assert limit(sqrt(x + 1) - sqrt(x), x, oo) == 0 # 179 + assert limit((2*x - 3)*(3*x + 5)*(4*x - 6)/(3*x**3 + x - 1), x, oo) == 8 # Primjer 1 + assert limit(x/root3(x**3 + 10), x, oo) == 1 # Primjer 2 + assert limit((x + 1)**2/(x**2 + 1), x, oo) == 1 # 181 + + +def test_Limits_simple_2(): + assert limit(1000*x/(x**2 - 1), x, oo) == 0 # 182 + assert limit((x**2 - 5*x + 1)/(3*x + 7), x, oo) is oo # 183 + assert limit((2*x**2 - x + 3)/(x**3 - 8*x + 5), x, oo) == 0 # 184 + assert limit((2*x**2 - 3*x - 4)/sqrt(x**4 + 1), x, oo) == 2 # 186 + assert limit((2*x + 3)/(x + root3(x)), x, oo) == 2 # 187 + assert limit(x**2/(10 + x*sqrt(x)), x, oo) is oo # 188 + assert limit(root3(x**2 + 1)/(x + 1), x, oo) == 0 # 189 + assert limit(sqrt(x)/sqrt(x + sqrt(x + sqrt(x))), x, oo) == 1 # 190 + + +def test_Limits_simple_3a(): + a = Symbol('a') + #issue 3513 + assert together(limit((x**2 - (a + 1)*x + a)/(x**3 - a**3), x, a)) == \ + (a - 1)/(3*a**2) # 196 + + +def test_Limits_simple_3b(): + h = Symbol("h") + assert limit(((x + h)**3 - x**3)/h, h, 0) == 3*x**2 # 197 + assert limit((1/(1 - x) - 3/(1 - x**3)), x, 1) == -1 # 198 + assert limit((sqrt(1 + x) - 1)/(root3(1 + x) - 1), x, 0) == Rational(3)/2 # Primer 4 + assert limit((sqrt(x) - 1)/(x - 1), x, 1) == Rational(1)/2 # 199 + assert limit((sqrt(x) - 8)/(root3(x) - 4), x, 64) == 3 # 200 + assert limit((root3(x) - 1)/(root4(x) - 1), x, 1) == Rational(4)/3 # 201 + assert limit( + (root3(x**2) - 2*root3(x) + 1)/(x - 1)**2, x, 1) == Rational(1)/9 # 202 + + +def test_Limits_simple_4a(): + a = Symbol('a') + assert limit((sqrt(x) - sqrt(a))/(x - a), x, a) == 1/(2*sqrt(a)) # Primer 5 + assert limit((sqrt(x) - 1)/(root3(x) - 1), x, 1) == Rational(3, 2) # 205 + assert limit((sqrt(1 + x) - sqrt(1 - x))/x, x, 0) == 1 # 207 + assert limit(sqrt(x**2 - 5*x + 6) - x, x, oo) == Rational(-5, 2) # 213 + + +def test_limits_simple_4aa(): + assert limit(x*(sqrt(x**2 + 1) - x), x, oo) == Rational(1)/2 # 214 + + +def test_Limits_simple_4b(): + #issue 3511 + assert limit(x - root3(x**3 - 1), x, oo) == 0 # 215 + + +def test_Limits_simple_4c(): + assert limit(log(1 + exp(x))/x, x, -oo) == 0 # 267a + assert limit(log(1 + exp(x))/x, x, oo) == 1 # 267b + + +def test_bounded(): + assert limit(sin(x)/x, x, oo) == 0 # 216b + assert limit(x*sin(1/x), x, 0) == 0 # 227a + + +def test_f1a(): + #issue 3508: + assert limit((sin(2*x)/x)**(1 + x), x, 0) == 2 # Primer 7 + + +def test_f1a2(): + #issue 3509: + assert limit(((x - 1)/(x + 1))**x, x, oo) == exp(-2) # Primer 9 + + +def test_f1b(): + m = Symbol("m") + n = Symbol("n") + h = Symbol("h") + a = Symbol("a") + assert limit(sin(x)/x, x, 2) == sin(2)/2 # 216a + assert limit(sin(3*x)/x, x, 0) == 3 # 217 + assert limit(sin(5*x)/sin(2*x), x, 0) == Rational(5, 2) # 218 + assert limit(sin(pi*x)/sin(3*pi*x), x, 0) == Rational(1, 3) # 219 + assert limit(x*sin(pi/x), x, oo) == pi # 220 + assert limit((1 - cos(x))/x**2, x, 0) == S.Half # 221 + assert limit(x*sin(1/x), x, oo) == 1 # 227b + assert limit((cos(m*x) - cos(n*x))/x**2, x, 0) == -m**2/2 + n**2/2 # 232 + assert limit((tan(x) - sin(x))/x**3, x, 0) == S.Half # 233 + assert limit((x - sin(2*x))/(x + sin(3*x)), x, 0) == -Rational(1, 4) # 237 + assert limit((1 - sqrt(cos(x)))/x**2, x, 0) == Rational(1, 4) # 239 + assert limit((sqrt(1 + sin(x)) - sqrt(1 - sin(x)))/x, x, 0) == 1 # 240 + + assert limit((1 + h/x)**x, x, oo) == exp(h) # Primer 9 + assert limit((sin(x) - sin(a))/(x - a), x, a) == cos(a) # 222, *176 + assert limit((cos(x) - cos(a))/(x - a), x, a) == -sin(a) # 223 + assert limit((sin(x + h) - sin(x))/h, h, 0) == cos(x) # 225 + + +def test_f2a(): + assert limit(((x + 1)/(2*x + 1))**(x**2), x, oo) == 0 # Primer 8 + + +def test_f2(): + assert limit((sqrt( + cos(x)) - root3(cos(x)))/(sin(x)**2), x, 0) == -Rational(1, 12) # *184 + + +def test_f3(): + a = Symbol('a') + #issue 3504 + assert limit(asin(a*x)/x, x, 0) == a diff --git a/lib/python3.10/site-packages/sympy/series/tests/test_formal.py b/lib/python3.10/site-packages/sympy/series/tests/test_formal.py new file mode 100644 index 0000000000000000000000000000000000000000..0a28418fbb326ce3aa973eecaf8c3b1231f6c767 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/series/tests/test_formal.py @@ -0,0 +1,618 @@ +from sympy.concrete.summations import Sum +from sympy.core.add import Add +from sympy.core.function import (Derivative, Function) +from sympy.core.mul import Mul +from sympy.core.numbers import (I, Rational, oo, pi) +from sympy.core.singleton import S +from sympy.core.symbol import symbols +from sympy.functions.combinatorial.factorials import factorial +from sympy.functions.elementary.exponential import (exp, log) +from sympy.functions.elementary.hyperbolic import (acosh, asech) +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import (acos, asin, atan, cos, sin) +from sympy.functions.special.bessel import airyai +from sympy.functions.special.error_functions import erf +from sympy.functions.special.gamma_functions import gamma +from sympy.integrals.integrals import integrate +from sympy.series.formal import fps +from sympy.series.order import O +from sympy.series.formal import (rational_algorithm, FormalPowerSeries, + FormalPowerSeriesProduct, FormalPowerSeriesCompose, + FormalPowerSeriesInverse, simpleDE, + rational_independent, exp_re, hyper_re) +from sympy.testing.pytest import raises, XFAIL, slow + +x, y, z = symbols('x y z') +n, m, k = symbols('n m k', integer=True) +f, r = Function('f'), Function('r') + + +def test_rational_algorithm(): + f = 1 / ((x - 1)**2 * (x - 2)) + assert rational_algorithm(f, x, k) == \ + (-2**(-k - 1) + 1 - (factorial(k + 1) / factorial(k)), 0, 0) + + f = (1 + x + x**2 + x**3) / ((x - 1) * (x - 2)) + assert rational_algorithm(f, x, k) == \ + (-15*2**(-k - 1) + 4, x + 4, 0) + + f = z / (y*m - m*x - y*x + x**2) + assert rational_algorithm(f, x, k) == \ + (((-y**(-k - 1)*z) / (y - m)) + ((m**(-k - 1)*z) / (y - m)), 0, 0) + + f = x / (1 - x - x**2) + assert rational_algorithm(f, x, k) is None + assert rational_algorithm(f, x, k, full=True) == \ + (((Rational(-1, 2) + sqrt(5)/2)**(-k - 1) * + (-sqrt(5)/10 + S.Half)) + + ((-sqrt(5)/2 - S.Half)**(-k - 1) * + (sqrt(5)/10 + S.Half)), 0, 0) + + f = 1 / (x**2 + 2*x + 2) + assert rational_algorithm(f, x, k) is None + assert rational_algorithm(f, x, k, full=True) == \ + ((I*(-1 + I)**(-k - 1)) / 2 - (I*(-1 - I)**(-k - 1)) / 2, 0, 0) + + f = log(1 + x) + assert rational_algorithm(f, x, k) == \ + (-(-1)**(-k) / k, 0, 1) + + f = atan(x) + assert rational_algorithm(f, x, k) is None + assert rational_algorithm(f, x, k, full=True) == \ + (((I*I**(-k)) / 2 - (I*(-I)**(-k)) / 2) / k, 0, 1) + + f = x*atan(x) - log(1 + x**2) / 2 + assert rational_algorithm(f, x, k) is None + assert rational_algorithm(f, x, k, full=True) == \ + (((I*I**(-k + 1)) / 2 - (I*(-I)**(-k + 1)) / 2) / + (k*(k - 1)), 0, 2) + + f = log((1 + x) / (1 - x)) / 2 - atan(x) + assert rational_algorithm(f, x, k) is None + assert rational_algorithm(f, x, k, full=True) == \ + ((-(-1)**(-k) / 2 - (I*I**(-k)) / 2 + (I*(-I)**(-k)) / 2 + + S.Half) / k, 0, 1) + + assert rational_algorithm(cos(x), x, k) is None + + +def test_rational_independent(): + ri = rational_independent + assert ri([], x) == [] + assert ri([cos(x), sin(x)], x) == [cos(x), sin(x)] + assert ri([x**2, sin(x), x*sin(x), x**3], x) == \ + [x**3 + x**2, x*sin(x) + sin(x)] + assert ri([S.One, x*log(x), log(x), sin(x)/x, cos(x), sin(x), x], x) == \ + [x + 1, x*log(x) + log(x), sin(x)/x + sin(x), cos(x)] + + +def test_simpleDE(): + # Tests just the first valid DE + for DE in simpleDE(exp(x), x, f): + assert DE == (-f(x) + Derivative(f(x), x), 1) + break + for DE in simpleDE(sin(x), x, f): + assert DE == (f(x) + Derivative(f(x), x, x), 2) + break + for DE in simpleDE(log(1 + x), x, f): + assert DE == ((x + 1)*Derivative(f(x), x, 2) + Derivative(f(x), x), 2) + break + for DE in simpleDE(asin(x), x, f): + assert DE == (x*Derivative(f(x), x) + (x**2 - 1)*Derivative(f(x), x, x), + 2) + break + for DE in simpleDE(exp(x)*sin(x), x, f): + assert DE == (2*f(x) - 2*Derivative(f(x)) + Derivative(f(x), x, x), 2) + break + for DE in simpleDE(((1 + x)/(1 - x))**n, x, f): + assert DE == (2*n*f(x) + (x**2 - 1)*Derivative(f(x), x), 1) + break + for DE in simpleDE(airyai(x), x, f): + assert DE == (-x*f(x) + Derivative(f(x), x, x), 2) + break + + +def test_exp_re(): + d = -f(x) + Derivative(f(x), x) + assert exp_re(d, r, k) == -r(k) + r(k + 1) + + d = f(x) + Derivative(f(x), x, x) + assert exp_re(d, r, k) == r(k) + r(k + 2) + + d = f(x) + Derivative(f(x), x) + Derivative(f(x), x, x) + assert exp_re(d, r, k) == r(k) + r(k + 1) + r(k + 2) + + d = Derivative(f(x), x) + Derivative(f(x), x, x) + assert exp_re(d, r, k) == r(k) + r(k + 1) + + d = Derivative(f(x), x, 3) + Derivative(f(x), x, 4) + Derivative(f(x)) + assert exp_re(d, r, k) == r(k) + r(k + 2) + r(k + 3) + + +def test_hyper_re(): + d = f(x) + Derivative(f(x), x, x) + assert hyper_re(d, r, k) == r(k) + (k+1)*(k+2)*r(k + 2) + + d = -x*f(x) + Derivative(f(x), x, x) + assert hyper_re(d, r, k) == (k + 2)*(k + 3)*r(k + 3) - r(k) + + d = 2*f(x) - 2*Derivative(f(x), x) + Derivative(f(x), x, x) + assert hyper_re(d, r, k) == \ + (-2*k - 2)*r(k + 1) + (k + 1)*(k + 2)*r(k + 2) + 2*r(k) + + d = 2*n*f(x) + (x**2 - 1)*Derivative(f(x), x) + assert hyper_re(d, r, k) == \ + k*r(k) + 2*n*r(k + 1) + (-k - 2)*r(k + 2) + + d = (x**10 + 4)*Derivative(f(x), x) + x*(x**10 - 1)*Derivative(f(x), x, x) + assert hyper_re(d, r, k) == \ + (k*(k - 1) + k)*r(k) + (4*k - (k + 9)*(k + 10) + 40)*r(k + 10) + + d = ((x**2 - 1)*Derivative(f(x), x, 3) + 3*x*Derivative(f(x), x, x) + + Derivative(f(x), x)) + assert hyper_re(d, r, k) == \ + ((k*(k - 2)*(k - 1) + 3*k*(k - 1) + k)*r(k) + + (-k*(k + 1)*(k + 2))*r(k + 2)) + + +def test_fps(): + assert fps(1) == 1 + assert fps(2, x) == 2 + assert fps(2, x, dir='+') == 2 + assert fps(2, x, dir='-') == 2 + assert fps(1/x + 1/x**2) == 1/x + 1/x**2 + assert fps(log(1 + x), hyper=False, rational=False) == log(1 + x) + + f = fps(x**2 + x + 1) + assert isinstance(f, FormalPowerSeries) + assert f.function == x**2 + x + 1 + assert f[0] == 1 + assert f[2] == x**2 + assert f.truncate(4) == x**2 + x + 1 + O(x**4) + assert f.polynomial() == x**2 + x + 1 + + f = fps(log(1 + x)) + assert isinstance(f, FormalPowerSeries) + assert f.function == log(1 + x) + assert f.subs(x, y) == f + assert f[:5] == [0, x, -x**2/2, x**3/3, -x**4/4] + assert f.as_leading_term(x) == x + assert f.polynomial(6) == x - x**2/2 + x**3/3 - x**4/4 + x**5/5 + + k = f.ak.variables[0] + assert f.infinite == Sum((-(-1)**(-k)*x**k)/k, (k, 1, oo)) + + ft, s = f.truncate(n=None), f[:5] + for i, t in enumerate(ft): + if i == 5: + break + assert s[i] == t + + f = sin(x).fps(x) + assert isinstance(f, FormalPowerSeries) + assert f.truncate() == x - x**3/6 + x**5/120 + O(x**6) + + raises(NotImplementedError, lambda: fps(y*x)) + raises(ValueError, lambda: fps(x, dir=0)) + + +@slow +def test_fps__rational(): + assert fps(1/x) == (1/x) + assert fps((x**2 + x + 1) / x**3, dir=-1) == (x**2 + x + 1) / x**3 + + f = 1 / ((x - 1)**2 * (x - 2)) + assert fps(f, x).truncate() == \ + (Rational(-1, 2) - x*Rational(5, 4) - 17*x**2/8 - 49*x**3/16 - 129*x**4/32 - + 321*x**5/64 + O(x**6)) + + f = (1 + x + x**2 + x**3) / ((x - 1) * (x - 2)) + assert fps(f, x).truncate() == \ + (S.Half + x*Rational(5, 4) + 17*x**2/8 + 49*x**3/16 + 113*x**4/32 + + 241*x**5/64 + O(x**6)) + + f = x / (1 - x - x**2) + assert fps(f, x, full=True).truncate() == \ + x + x**2 + 2*x**3 + 3*x**4 + 5*x**5 + O(x**6) + + f = 1 / (x**2 + 2*x + 2) + assert fps(f, x, full=True).truncate() == \ + S.Half - x/2 + x**2/4 - x**4/8 + x**5/8 + O(x**6) + + f = log(1 + x) + assert fps(f, x).truncate() == \ + x - x**2/2 + x**3/3 - x**4/4 + x**5/5 + O(x**6) + assert fps(f, x, dir=1).truncate() == fps(f, x, dir=-1).truncate() + assert fps(f, x, 2).truncate() == \ + (log(3) - Rational(2, 3) - (x - 2)**2/18 + (x - 2)**3/81 - + (x - 2)**4/324 + (x - 2)**5/1215 + x/3 + O((x - 2)**6, (x, 2))) + assert fps(f, x, 2, dir=-1).truncate() == \ + (log(3) - Rational(2, 3) - (-x + 2)**2/18 - (-x + 2)**3/81 - + (-x + 2)**4/324 - (-x + 2)**5/1215 + x/3 + O((x - 2)**6, (x, 2))) + + f = atan(x) + assert fps(f, x, full=True).truncate() == x - x**3/3 + x**5/5 + O(x**6) + assert fps(f, x, full=True, dir=1).truncate() == \ + fps(f, x, full=True, dir=-1).truncate() + assert fps(f, x, 2, full=True).truncate() == \ + (atan(2) - Rational(2, 5) - 2*(x - 2)**2/25 + 11*(x - 2)**3/375 - + 6*(x - 2)**4/625 + 41*(x - 2)**5/15625 + x/5 + O((x - 2)**6, (x, 2))) + assert fps(f, x, 2, full=True, dir=-1).truncate() == \ + (atan(2) - Rational(2, 5) - 2*(-x + 2)**2/25 - 11*(-x + 2)**3/375 - + 6*(-x + 2)**4/625 - 41*(-x + 2)**5/15625 + x/5 + O((x - 2)**6, (x, 2))) + + f = x*atan(x) - log(1 + x**2) / 2 + assert fps(f, x, full=True).truncate() == x**2/2 - x**4/12 + O(x**6) + + f = log((1 + x) / (1 - x)) / 2 - atan(x) + assert fps(f, x, full=True).truncate(n=10) == 2*x**3/3 + 2*x**7/7 + O(x**10) + + +@slow +def test_fps__hyper(): + f = sin(x) + assert fps(f, x).truncate() == x - x**3/6 + x**5/120 + O(x**6) + + f = cos(x) + assert fps(f, x).truncate() == 1 - x**2/2 + x**4/24 + O(x**6) + + f = exp(x) + assert fps(f, x).truncate() == \ + 1 + x + x**2/2 + x**3/6 + x**4/24 + x**5/120 + O(x**6) + + f = atan(x) + assert fps(f, x).truncate() == x - x**3/3 + x**5/5 + O(x**6) + + f = exp(acos(x)) + assert fps(f, x).truncate() == \ + (exp(pi/2) - x*exp(pi/2) + x**2*exp(pi/2)/2 - x**3*exp(pi/2)/3 + + 5*x**4*exp(pi/2)/24 - x**5*exp(pi/2)/6 + O(x**6)) + + f = exp(acosh(x)) + assert fps(f, x).truncate() == I + x - I*x**2/2 - I*x**4/8 + O(x**6) + + f = atan(1/x) + assert fps(f, x).truncate() == pi/2 - x + x**3/3 - x**5/5 + O(x**6) + + f = x*atan(x) - log(1 + x**2) / 2 + assert fps(f, x, rational=False).truncate() == x**2/2 - x**4/12 + O(x**6) + + f = log(1 + x) + assert fps(f, x, rational=False).truncate() == \ + x - x**2/2 + x**3/3 - x**4/4 + x**5/5 + O(x**6) + + f = airyai(x**2) + assert fps(f, x).truncate() == \ + (3**Rational(5, 6)*gamma(Rational(1, 3))/(6*pi) - + 3**Rational(2, 3)*x**2/(3*gamma(Rational(1, 3))) + O(x**6)) + + f = exp(x)*sin(x) + assert fps(f, x).truncate() == x + x**2 + x**3/3 - x**5/30 + O(x**6) + + f = exp(x)*sin(x)/x + assert fps(f, x).truncate() == 1 + x + x**2/3 - x**4/30 - x**5/90 + O(x**6) + + f = sin(x) * cos(x) + assert fps(f, x).truncate() == x - 2*x**3/3 + 2*x**5/15 + O(x**6) + + +def test_fps_shift(): + f = x**-5*sin(x) + assert fps(f, x).truncate() == \ + 1/x**4 - 1/(6*x**2) + Rational(1, 120) - x**2/5040 + x**4/362880 + O(x**6) + + f = x**2*atan(x) + assert fps(f, x, rational=False).truncate() == \ + x**3 - x**5/3 + O(x**6) + + f = cos(sqrt(x))*x + assert fps(f, x).truncate() == \ + x - x**2/2 + x**3/24 - x**4/720 + x**5/40320 + O(x**6) + + f = x**2*cos(sqrt(x)) + assert fps(f, x).truncate() == \ + x**2 - x**3/2 + x**4/24 - x**5/720 + O(x**6) + + +def test_fps__Add_expr(): + f = x*atan(x) - log(1 + x**2) / 2 + assert fps(f, x).truncate() == x**2/2 - x**4/12 + O(x**6) + + f = sin(x) + cos(x) - exp(x) + log(1 + x) + assert fps(f, x).truncate() == x - 3*x**2/2 - x**4/4 + x**5/5 + O(x**6) + + f = 1/x + sin(x) + assert fps(f, x).truncate() == 1/x + x - x**3/6 + x**5/120 + O(x**6) + + f = sin(x) - cos(x) + 1/(x - 1) + assert fps(f, x).truncate() == \ + -2 - x**2/2 - 7*x**3/6 - 25*x**4/24 - 119*x**5/120 + O(x**6) + + +def test_fps__asymptotic(): + f = exp(x) + assert fps(f, x, oo) == f + assert fps(f, x, -oo).truncate() == O(1/x**6, (x, oo)) + + f = erf(x) + assert fps(f, x, oo).truncate() == 1 + O(1/x**6, (x, oo)) + assert fps(f, x, -oo).truncate() == -1 + O(1/x**6, (x, oo)) + + f = atan(x) + assert fps(f, x, oo, full=True).truncate() == \ + -1/(5*x**5) + 1/(3*x**3) - 1/x + pi/2 + O(1/x**6, (x, oo)) + assert fps(f, x, -oo, full=True).truncate() == \ + -1/(5*x**5) + 1/(3*x**3) - 1/x - pi/2 + O(1/x**6, (x, oo)) + + f = log(1 + x) + assert fps(f, x, oo) != \ + (-1/(5*x**5) - 1/(4*x**4) + 1/(3*x**3) - 1/(2*x**2) + 1/x - log(1/x) + + O(1/x**6, (x, oo))) + assert fps(f, x, -oo) != \ + (-1/(5*x**5) - 1/(4*x**4) + 1/(3*x**3) - 1/(2*x**2) + 1/x + I*pi - + log(-1/x) + O(1/x**6, (x, oo))) + + +def test_fps__fractional(): + f = sin(sqrt(x)) / x + assert fps(f, x).truncate() == \ + (1/sqrt(x) - sqrt(x)/6 + x**Rational(3, 2)/120 - + x**Rational(5, 2)/5040 + x**Rational(7, 2)/362880 - + x**Rational(9, 2)/39916800 + x**Rational(11, 2)/6227020800 + O(x**6)) + + f = sin(sqrt(x)) * x + assert fps(f, x).truncate() == \ + (x**Rational(3, 2) - x**Rational(5, 2)/6 + x**Rational(7, 2)/120 - + x**Rational(9, 2)/5040 + x**Rational(11, 2)/362880 + O(x**6)) + + f = atan(sqrt(x)) / x**2 + assert fps(f, x).truncate() == \ + (x**Rational(-3, 2) - x**Rational(-1, 2)/3 + x**S.Half/5 - + x**Rational(3, 2)/7 + x**Rational(5, 2)/9 - x**Rational(7, 2)/11 + + x**Rational(9, 2)/13 - x**Rational(11, 2)/15 + O(x**6)) + + f = exp(sqrt(x)) + assert fps(f, x).truncate().expand() == \ + (1 + x/2 + x**2/24 + x**3/720 + x**4/40320 + x**5/3628800 + sqrt(x) + + x**Rational(3, 2)/6 + x**Rational(5, 2)/120 + x**Rational(7, 2)/5040 + + x**Rational(9, 2)/362880 + x**Rational(11, 2)/39916800 + O(x**6)) + + f = exp(sqrt(x))*x + assert fps(f, x).truncate().expand() == \ + (x + x**2/2 + x**3/24 + x**4/720 + x**5/40320 + x**Rational(3, 2) + + x**Rational(5, 2)/6 + x**Rational(7, 2)/120 + x**Rational(9, 2)/5040 + + x**Rational(11, 2)/362880 + O(x**6)) + + +def test_fps__logarithmic_singularity(): + f = log(1 + 1/x) + assert fps(f, x) != \ + -log(x) + x - x**2/2 + x**3/3 - x**4/4 + x**5/5 + O(x**6) + assert fps(f, x, rational=False) != \ + -log(x) + x - x**2/2 + x**3/3 - x**4/4 + x**5/5 + O(x**6) + + +@XFAIL +def test_fps__logarithmic_singularity_fail(): + f = asech(x) # Algorithms for computing limits probably needs improvemnts + assert fps(f, x) == log(2) - log(x) - x**2/4 - 3*x**4/64 + O(x**6) + + +def test_fps_symbolic(): + f = x**n*sin(x**2) + assert fps(f, x).truncate(8) == x**(n + 2) - x**(n + 6)/6 + O(x**(n + 8), x) + + f = x**n*log(1 + x) + fp = fps(f, x) + k = fp.ak.variables[0] + assert fp.infinite == \ + Sum((-(-1)**(-k)*x**(k + n))/k, (k, 1, oo)) + + f = (x - 2)**n*log(1 + x) + assert fps(f, x, 2).truncate() == \ + ((x - 2)**n*log(3) + (x - 2)**(n + 1)/3 - (x - 2)**(n + 2)/18 + (x - 2)**(n + 3)/81 - + (x - 2)**(n + 4)/324 + (x - 2)**(n + 5)/1215 + O((x - 2)**(n + 6), (x, 2))) + + f = x**(n - 2)*cos(x) + assert fps(f, x).truncate() == \ + (x**(n - 2) - x**n/2 + x**(n + 2)/24 + O(x**(n + 4), x)) + + f = x**(n - 2)*sin(x) + x**n*exp(x) + assert fps(f, x).truncate() == \ + (x**(n - 1) + x**(n + 1) + x**(n + 2)/2 + x**n + + x**(n + 4)/24 + x**(n + 5)/60 + O(x**(n + 6), x)) + + f = x**n*atan(x) + assert fps(f, x, oo).truncate() == \ + (-x**(n - 5)/5 + x**(n - 3)/3 + x**n*(pi/2 - 1/x) + + O((1/x)**(-n)/x**6, (x, oo))) + + f = x**(n/2)*cos(x) + assert fps(f, x).truncate() == \ + x**(n/2) - x**(n/2 + 2)/2 + x**(n/2 + 4)/24 + O(x**(n/2 + 6), x) + + f = x**(n + m)*sin(x) + assert fps(f, x).truncate() == \ + x**(m + n + 1) - x**(m + n + 3)/6 + x**(m + n + 5)/120 + O(x**(m + n + 6), x) + + +def test_fps__slow(): + f = x*exp(x)*sin(2*x) # TODO: rsolve needs improvement + assert fps(f, x).truncate() == 2*x**2 + 2*x**3 - x**4/3 - x**5 + O(x**6) + + +def test_fps__operations(): + f1, f2 = fps(sin(x)), fps(cos(x)) + + fsum = f1 + f2 + assert fsum.function == sin(x) + cos(x) + assert fsum.truncate() == \ + 1 + x - x**2/2 - x**3/6 + x**4/24 + x**5/120 + O(x**6) + + fsum = f1 + 1 + assert fsum.function == sin(x) + 1 + assert fsum.truncate() == 1 + x - x**3/6 + x**5/120 + O(x**6) + + fsum = 1 + f2 + assert fsum.function == cos(x) + 1 + assert fsum.truncate() == 2 - x**2/2 + x**4/24 + O(x**6) + + assert (f1 + x) == Add(f1, x) + + assert -f2.truncate() == -1 + x**2/2 - x**4/24 + O(x**6) + assert (f1 - f1) is S.Zero + + fsub = f1 - f2 + assert fsub.function == sin(x) - cos(x) + assert fsub.truncate() == \ + -1 + x + x**2/2 - x**3/6 - x**4/24 + x**5/120 + O(x**6) + + fsub = f1 - 1 + assert fsub.function == sin(x) - 1 + assert fsub.truncate() == -1 + x - x**3/6 + x**5/120 + O(x**6) + + fsub = 1 - f2 + assert fsub.function == -cos(x) + 1 + assert fsub.truncate() == x**2/2 - x**4/24 + O(x**6) + + raises(ValueError, lambda: f1 + fps(exp(x), dir=-1)) + raises(ValueError, lambda: f1 + fps(exp(x), x0=1)) + + fm = f1 * 3 + + assert fm.function == 3*sin(x) + assert fm.truncate() == 3*x - x**3/2 + x**5/40 + O(x**6) + + fm = 3 * f2 + + assert fm.function == 3*cos(x) + assert fm.truncate() == 3 - 3*x**2/2 + x**4/8 + O(x**6) + + assert (f1 * f2) == Mul(f1, f2) + assert (f1 * x) == Mul(f1, x) + + fd = f1.diff() + assert fd.function == cos(x) + assert fd.truncate() == 1 - x**2/2 + x**4/24 + O(x**6) + + fd = f2.diff() + assert fd.function == -sin(x) + assert fd.truncate() == -x + x**3/6 - x**5/120 + O(x**6) + + fd = f2.diff().diff() + assert fd.function == -cos(x) + assert fd.truncate() == -1 + x**2/2 - x**4/24 + O(x**6) + + f3 = fps(exp(sqrt(x))) + fd = f3.diff() + assert fd.truncate().expand() == \ + (1/(2*sqrt(x)) + S.Half + x/12 + x**2/240 + x**3/10080 + x**4/725760 + + x**5/79833600 + sqrt(x)/4 + x**Rational(3, 2)/48 + x**Rational(5, 2)/1440 + + x**Rational(7, 2)/80640 + x**Rational(9, 2)/7257600 + x**Rational(11, 2)/958003200 + + O(x**6)) + + assert f1.integrate((x, 0, 1)) == -cos(1) + 1 + assert integrate(f1, (x, 0, 1)) == -cos(1) + 1 + + fi = integrate(f1, x) + assert fi.function == -cos(x) + assert fi.truncate() == -1 + x**2/2 - x**4/24 + O(x**6) + + fi = f2.integrate(x) + assert fi.function == sin(x) + assert fi.truncate() == x - x**3/6 + x**5/120 + O(x**6) + +def test_fps__product(): + f1, f2, f3 = fps(sin(x)), fps(exp(x)), fps(cos(x)) + + raises(ValueError, lambda: f1.product(exp(x), x)) + raises(ValueError, lambda: f1.product(fps(exp(x), dir=-1), x, 4)) + raises(ValueError, lambda: f1.product(fps(exp(x), x0=1), x, 4)) + raises(ValueError, lambda: f1.product(fps(exp(y)), x, 4)) + + fprod = f1.product(f2, x) + assert isinstance(fprod, FormalPowerSeriesProduct) + assert isinstance(fprod.ffps, FormalPowerSeries) + assert isinstance(fprod.gfps, FormalPowerSeries) + assert fprod.f == sin(x) + assert fprod.g == exp(x) + assert fprod.function == sin(x) * exp(x) + assert fprod._eval_terms(4) == x + x**2 + x**3/3 + assert fprod.truncate(4) == x + x**2 + x**3/3 + O(x**4) + assert fprod.polynomial(4) == x + x**2 + x**3/3 + + raises(NotImplementedError, lambda: fprod._eval_term(5)) + raises(NotImplementedError, lambda: fprod.infinite) + raises(NotImplementedError, lambda: fprod._eval_derivative(x)) + raises(NotImplementedError, lambda: fprod.integrate(x)) + + assert f1.product(f3, x)._eval_terms(4) == x - 2*x**3/3 + assert f1.product(f3, x).truncate(4) == x - 2*x**3/3 + O(x**4) + + +def test_fps__compose(): + f1, f2, f3 = fps(exp(x)), fps(sin(x)), fps(cos(x)) + + raises(ValueError, lambda: f1.compose(sin(x), x)) + raises(ValueError, lambda: f1.compose(fps(sin(x), dir=-1), x, 4)) + raises(ValueError, lambda: f1.compose(fps(sin(x), x0=1), x, 4)) + raises(ValueError, lambda: f1.compose(fps(sin(y)), x, 4)) + + raises(ValueError, lambda: f1.compose(f3, x)) + raises(ValueError, lambda: f2.compose(f3, x)) + + fcomp = f1.compose(f2, x) + assert isinstance(fcomp, FormalPowerSeriesCompose) + assert isinstance(fcomp.ffps, FormalPowerSeries) + assert isinstance(fcomp.gfps, FormalPowerSeries) + assert fcomp.f == exp(x) + assert fcomp.g == sin(x) + assert fcomp.function == exp(sin(x)) + assert fcomp._eval_terms(6) == 1 + x + x**2/2 - x**4/8 - x**5/15 + assert fcomp.truncate() == 1 + x + x**2/2 - x**4/8 - x**5/15 + O(x**6) + assert fcomp.truncate(5) == 1 + x + x**2/2 - x**4/8 + O(x**5) + + raises(NotImplementedError, lambda: fcomp._eval_term(5)) + raises(NotImplementedError, lambda: fcomp.infinite) + raises(NotImplementedError, lambda: fcomp._eval_derivative(x)) + raises(NotImplementedError, lambda: fcomp.integrate(x)) + + assert f1.compose(f2, x).truncate(4) == 1 + x + x**2/2 + O(x**4) + assert f1.compose(f2, x).truncate(8) == \ + 1 + x + x**2/2 - x**4/8 - x**5/15 - x**6/240 + x**7/90 + O(x**8) + assert f1.compose(f2, x).truncate(6) == \ + 1 + x + x**2/2 - x**4/8 - x**5/15 + O(x**6) + + assert f2.compose(f2, x).truncate(4) == x - x**3/3 + O(x**4) + assert f2.compose(f2, x).truncate(8) == x - x**3/3 + x**5/10 - 8*x**7/315 + O(x**8) + assert f2.compose(f2, x).truncate(6) == x - x**3/3 + x**5/10 + O(x**6) + + +def test_fps__inverse(): + f1, f2, f3 = fps(sin(x)), fps(exp(x)), fps(cos(x)) + + raises(ValueError, lambda: f1.inverse(x)) + + finv = f2.inverse(x) + assert isinstance(finv, FormalPowerSeriesInverse) + assert isinstance(finv.ffps, FormalPowerSeries) + raises(ValueError, lambda: finv.gfps) + + assert finv.f == exp(x) + assert finv.function == exp(-x) + assert finv._eval_terms(5) == 1 - x + x**2/2 - x**3/6 + x**4/24 + assert finv.truncate() == 1 - x + x**2/2 - x**3/6 + x**4/24 - x**5/120 + O(x**6) + assert finv.truncate(5) == 1 - x + x**2/2 - x**3/6 + x**4/24 + O(x**5) + + raises(NotImplementedError, lambda: finv._eval_term(5)) + raises(ValueError, lambda: finv.g) + raises(NotImplementedError, lambda: finv.infinite) + raises(NotImplementedError, lambda: finv._eval_derivative(x)) + raises(NotImplementedError, lambda: finv.integrate(x)) + + assert f2.inverse(x).truncate(8) == \ + 1 - x + x**2/2 - x**3/6 + x**4/24 - x**5/120 + x**6/720 - x**7/5040 + O(x**8) + + assert f3.inverse(x).truncate() == 1 + x**2/2 + 5*x**4/24 + O(x**6) + assert f3.inverse(x).truncate(8) == 1 + x**2/2 + 5*x**4/24 + 61*x**6/720 + O(x**8) diff --git a/lib/python3.10/site-packages/sympy/series/tests/test_fourier.py b/lib/python3.10/site-packages/sympy/series/tests/test_fourier.py new file mode 100644 index 0000000000000000000000000000000000000000..e3f206af3cc0c43e78065d8a1b788bf5138131bd --- /dev/null +++ b/lib/python3.10/site-packages/sympy/series/tests/test_fourier.py @@ -0,0 +1,165 @@ +from sympy.core.add import Add +from sympy.core.numbers import (Rational, oo, pi) +from sympy.core.singleton import S +from sympy.core.symbol import symbols +from sympy.functions.elementary.exponential import (exp, log) +from sympy.functions.elementary.piecewise import Piecewise +from sympy.functions.elementary.trigonometric import (cos, sin, sinc, tan) +from sympy.series.fourier import fourier_series +from sympy.series.fourier import FourierSeries +from sympy.testing.pytest import raises +from functools import lru_cache + +x, y, z = symbols('x y z') + +# Don't declare these during import because they are slow +@lru_cache() +def _get_examples(): + fo = fourier_series(x, (x, -pi, pi)) + fe = fourier_series(x**2, (-pi, pi)) + fp = fourier_series(Piecewise((0, x < 0), (pi, True)), (x, -pi, pi)) + return fo, fe, fp + + +def test_FourierSeries(): + fo, fe, fp = _get_examples() + + assert fourier_series(1, (-pi, pi)) == 1 + assert (Piecewise((0, x < 0), (pi, True)). + fourier_series((x, -pi, pi)).truncate()) == fp.truncate() + assert isinstance(fo, FourierSeries) + assert fo.function == x + assert fo.x == x + assert fo.period == (-pi, pi) + + assert fo.term(3) == 2*sin(3*x) / 3 + assert fe.term(3) == -4*cos(3*x) / 9 + assert fp.term(3) == 2*sin(3*x) / 3 + + assert fo.as_leading_term(x) == 2*sin(x) + assert fe.as_leading_term(x) == pi**2 / 3 + assert fp.as_leading_term(x) == pi / 2 + + assert fo.truncate() == 2*sin(x) - sin(2*x) + (2*sin(3*x) / 3) + assert fe.truncate() == -4*cos(x) + cos(2*x) + pi**2 / 3 + assert fp.truncate() == 2*sin(x) + (2*sin(3*x) / 3) + pi / 2 + + fot = fo.truncate(n=None) + s = [0, 2*sin(x), -sin(2*x)] + for i, t in enumerate(fot): + if i == 3: + break + assert s[i] == t + + def _check_iter(f, i): + for ind, t in enumerate(f): + assert t == f[ind] + if ind == i: + break + + _check_iter(fo, 3) + _check_iter(fe, 3) + _check_iter(fp, 3) + + assert fo.subs(x, x**2) == fo + + raises(ValueError, lambda: fourier_series(x, (0, 1, 2))) + raises(ValueError, lambda: fourier_series(x, (x, 0, oo))) + raises(ValueError, lambda: fourier_series(x*y, (0, oo))) + + +def test_FourierSeries_2(): + p = Piecewise((0, x < 0), (x, True)) + f = fourier_series(p, (x, -2, 2)) + + assert f.term(3) == (2*sin(3*pi*x / 2) / (3*pi) - + 4*cos(3*pi*x / 2) / (9*pi**2)) + assert f.truncate() == (2*sin(pi*x / 2) / pi - sin(pi*x) / pi - + 4*cos(pi*x / 2) / pi**2 + S.Half) + + +def test_square_wave(): + """Test if fourier_series approximates discontinuous function correctly.""" + square_wave = Piecewise((1, x < pi), (-1, True)) + s = fourier_series(square_wave, (x, 0, 2*pi)) + + assert s.truncate(3) == 4 / pi * sin(x) + 4 / (3 * pi) * sin(3 * x) + \ + 4 / (5 * pi) * sin(5 * x) + assert s.sigma_approximation(4) == 4 / pi * sin(x) * sinc(pi / 4) + \ + 4 / (3 * pi) * sin(3 * x) * sinc(3 * pi / 4) + + +def test_sawtooth_wave(): + s = fourier_series(x, (x, 0, pi)) + assert s.truncate(4) == \ + pi/2 - sin(2*x) - sin(4*x)/2 - sin(6*x)/3 + s = fourier_series(x, (x, 0, 1)) + assert s.truncate(4) == \ + S.Half - sin(2*pi*x)/pi - sin(4*pi*x)/(2*pi) - sin(6*pi*x)/(3*pi) + + +def test_FourierSeries__operations(): + fo, fe, fp = _get_examples() + + fes = fe.scale(-1).shift(pi**2) + assert fes.truncate() == 4*cos(x) - cos(2*x) + 2*pi**2 / 3 + + assert fp.shift(-pi/2).truncate() == (2*sin(x) + (2*sin(3*x) / 3) + + (2*sin(5*x) / 5)) + + fos = fo.scale(3) + assert fos.truncate() == 6*sin(x) - 3*sin(2*x) + 2*sin(3*x) + + fx = fe.scalex(2).shiftx(1) + assert fx.truncate() == -4*cos(2*x + 2) + cos(4*x + 4) + pi**2 / 3 + + fl = fe.scalex(3).shift(-pi).scalex(2).shiftx(1).scale(4) + assert fl.truncate() == (-16*cos(6*x + 6) + 4*cos(12*x + 12) - + 4*pi + 4*pi**2 / 3) + + raises(ValueError, lambda: fo.shift(x)) + raises(ValueError, lambda: fo.shiftx(sin(x))) + raises(ValueError, lambda: fo.scale(x*y)) + raises(ValueError, lambda: fo.scalex(x**2)) + + +def test_FourierSeries__neg(): + fo, fe, fp = _get_examples() + + assert (-fo).truncate() == -2*sin(x) + sin(2*x) - (2*sin(3*x) / 3) + assert (-fe).truncate() == +4*cos(x) - cos(2*x) - pi**2 / 3 + + +def test_FourierSeries__add__sub(): + fo, fe, fp = _get_examples() + + assert fo + fo == fo.scale(2) + assert fo - fo == 0 + assert -fe - fe == fe.scale(-2) + + assert (fo + fe).truncate() == 2*sin(x) - sin(2*x) - 4*cos(x) + cos(2*x) \ + + pi**2 / 3 + assert (fo - fe).truncate() == 2*sin(x) - sin(2*x) + 4*cos(x) - cos(2*x) \ + - pi**2 / 3 + + assert isinstance(fo + 1, Add) + + raises(ValueError, lambda: fo + fourier_series(x, (x, 0, 2))) + + +def test_FourierSeries_finite(): + + assert fourier_series(sin(x)).truncate(1) == sin(x) + # assert type(fourier_series(sin(x)*log(x))).truncate() == FourierSeries + # assert type(fourier_series(sin(x**2+6))).truncate() == FourierSeries + assert fourier_series(sin(x)*log(y)*exp(z),(x,pi,-pi)).truncate() == sin(x)*log(y)*exp(z) + assert fourier_series(sin(x)**6).truncate(oo) == -15*cos(2*x)/32 + 3*cos(4*x)/16 - cos(6*x)/32 \ + + Rational(5, 16) + assert fourier_series(sin(x) ** 6).truncate() == -15 * cos(2 * x) / 32 + 3 * cos(4 * x) / 16 \ + + Rational(5, 16) + assert fourier_series(sin(4*x+3) + cos(3*x+4)).truncate(oo) == -sin(4)*sin(3*x) + sin(4*x)*cos(3) \ + + cos(4)*cos(3*x) + sin(3)*cos(4*x) + assert fourier_series(sin(x)+cos(x)*tan(x)).truncate(oo) == 2*sin(x) + assert fourier_series(cos(pi*x), (x, -1, 1)).truncate(oo) == cos(pi*x) + assert fourier_series(cos(3*pi*x + 4) - sin(4*pi*x)*log(pi*y), (x, -1, 1)).truncate(oo) == -log(pi*y)*sin(4*pi*x)\ + - sin(4)*sin(3*pi*x) + cos(4)*cos(3*pi*x) diff --git a/lib/python3.10/site-packages/sympy/series/tests/test_gruntz.py b/lib/python3.10/site-packages/sympy/series/tests/test_gruntz.py new file mode 100644 index 0000000000000000000000000000000000000000..4565c876085b04e7f521bcf79571daf4a93ad653 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/series/tests/test_gruntz.py @@ -0,0 +1,482 @@ +from sympy.core import EulerGamma +from sympy.core.numbers import (E, I, Integer, Rational, oo, pi) +from sympy.core.singleton import S +from sympy.core.symbol import Symbol +from sympy.functions.elementary.exponential import (exp, log) +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import (acot, atan, cos, sin) +from sympy.functions.elementary.complexes import sign as _sign +from sympy.functions.special.error_functions import (Ei, erf) +from sympy.functions.special.gamma_functions import (digamma, gamma, loggamma) +from sympy.functions.special.zeta_functions import zeta +from sympy.polys.polytools import cancel +from sympy.functions.elementary.hyperbolic import cosh, coth, sinh, tanh +from sympy.series.gruntz import compare, mrv, rewrite, mrv_leadterm, gruntz, \ + sign +from sympy.testing.pytest import XFAIL, skip, slow + +""" +This test suite is testing the limit algorithm using the bottom up approach. +See the documentation in limits2.py. The algorithm itself is highly recursive +by nature, so "compare" is logically the lowest part of the algorithm, yet in +some sense it's the most complex part, because it needs to calculate a limit +to return the result. + +Nevertheless, the rest of the algorithm depends on compare working correctly. +""" + +x = Symbol('x', real=True) +m = Symbol('m', real=True) + + +runslow = False + + +def _sskip(): + if not runslow: + skip("slow") + + +@slow +def test_gruntz_evaluation(): + # Gruntz' thesis pp. 122 to 123 + # 8.1 + assert gruntz(exp(x)*(exp(1/x - exp(-x)) - exp(1/x)), x, oo) == -1 + # 8.2 + assert gruntz(exp(x)*(exp(1/x + exp(-x) + exp(-x**2)) + - exp(1/x - exp(-exp(x)))), x, oo) == 1 + # 8.3 + assert gruntz(exp(exp(x - exp(-x))/(1 - 1/x)) - exp(exp(x)), x, oo) is oo + # 8.5 + assert gruntz(exp(exp(exp(x + exp(-x)))) / exp(exp(exp(x))), x, oo) is oo + # 8.6 + assert gruntz(exp(exp(exp(x))) / exp(exp(exp(x - exp(-exp(x))))), + x, oo) is oo + # 8.7 + assert gruntz(exp(exp(exp(x))) / exp(exp(exp(x - exp(-exp(exp(x)))))), + x, oo) == 1 + # 8.8 + assert gruntz(exp(exp(x)) / exp(exp(x - exp(-exp(exp(x))))), x, oo) == 1 + # 8.9 + assert gruntz(log(x)**2 * exp(sqrt(log(x))*(log(log(x)))**2 + * exp(sqrt(log(log(x))) * (log(log(log(x))))**3)) / sqrt(x), + x, oo) == 0 + # 8.10 + assert gruntz((x*log(x)*(log(x*exp(x) - x**2))**2) + / (log(log(x**2 + 2*exp(exp(3*x**3*log(x)))))), x, oo) == Rational(1, 3) + # 8.11 + assert gruntz((exp(x*exp(-x)/(exp(-x) + exp(-2*x**2/(x + 1)))) - exp(x))/x, + x, oo) == -exp(2) + # 8.12 + assert gruntz((3**x + 5**x)**(1/x), x, oo) == 5 + # 8.13 + assert gruntz(x/log(x**(log(x**(log(2)/log(x))))), x, oo) is oo + # 8.14 + assert gruntz(exp(exp(2*log(x**5 + x)*log(log(x)))) + / exp(exp(10*log(x)*log(log(x)))), x, oo) is oo + # 8.15 + assert gruntz(exp(exp(Rational(5, 2)*x**Rational(-5, 7) + Rational(21, 8)*x**Rational(6, 11) + + 2*x**(-8) + Rational(54, 17)*x**Rational(49, 45)))**8 + / log(log(-log(Rational(4, 3)*x**Rational(-5, 14))))**Rational(7, 6), x, oo) is oo + # 8.16 + assert gruntz((exp(4*x*exp(-x)/(1/exp(x) + 1/exp(2*x**2/(x + 1)))) - exp(x)) + / exp(x)**4, x, oo) == 1 + # 8.17 + assert gruntz(exp(x*exp(-x)/(exp(-x) + exp(-2*x**2/(x + 1))))/exp(x), x, oo) \ + == 1 + # 8.19 + assert gruntz(log(x)*(log(log(x) + log(log(x))) - log(log(x))) + / (log(log(x) + log(log(log(x))))), x, oo) == 1 + # 8.20 + assert gruntz(exp((log(log(x + exp(log(x)*log(log(x)))))) + / (log(log(log(exp(x) + x + log(x)))))), x, oo) == E + # Another + assert gruntz(exp(exp(exp(x + exp(-x)))) / exp(exp(x)), x, oo) is oo + + +def test_gruntz_evaluation_slow(): + _sskip() + # 8.4 + assert gruntz(exp(exp(exp(x)/(1 - 1/x))) + - exp(exp(exp(x)/(1 - 1/x - log(x)**(-log(x))))), x, oo) is -oo + # 8.18 + assert gruntz((exp(exp(-x/(1 + exp(-x))))*exp(-x/(1 + exp(-x/(1 + exp(-x))))) + *exp(exp(-x + exp(-x/(1 + exp(-x)))))) + / (exp(-x/(1 + exp(-x))))**2 - exp(x) + x, x, oo) == 2 + + +@slow +def test_gruntz_eval_special(): + # Gruntz, p. 126 + assert gruntz(exp(x)*(sin(1/x + exp(-x)) - sin(1/x + exp(-x**2))), x, oo) == 1 + assert gruntz((erf(x - exp(-exp(x))) - erf(x)) * exp(exp(x)) * exp(x**2), + x, oo) == -2/sqrt(pi) + assert gruntz(exp(exp(x)) * (exp(sin(1/x + exp(-exp(x)))) - exp(sin(1/x))), + x, oo) == 1 + assert gruntz(exp(x)*(gamma(x + exp(-x)) - gamma(x)), x, oo) is oo + assert gruntz(exp(exp(digamma(digamma(x))))/x, x, oo) == exp(Rational(-1, 2)) + assert gruntz(exp(exp(digamma(log(x))))/x, x, oo) == exp(Rational(-1, 2)) + assert gruntz(digamma(digamma(digamma(x))), x, oo) is oo + assert gruntz(loggamma(loggamma(x)), x, oo) is oo + assert gruntz(((gamma(x + 1/gamma(x)) - gamma(x))/log(x) - cos(1/x)) + * x*log(x), x, oo) == Rational(-1, 2) + assert gruntz(x * (gamma(x - 1/gamma(x)) - gamma(x) + log(x)), x, oo) \ + == S.Half + assert gruntz((gamma(x + 1/gamma(x)) - gamma(x)) / log(x), x, oo) == 1 + + +def test_gruntz_eval_special_slow(): + _sskip() + assert gruntz(gamma(x + 1)/sqrt(2*pi) + - exp(-x)*(x**(x + S.Half) + x**(x - S.Half)/12), x, oo) is oo + assert gruntz(exp(exp(exp(digamma(digamma(digamma(x))))))/x, x, oo) == 0 + + +@XFAIL +def test_grunts_eval_special_slow_sometimes_fail(): + _sskip() + # XXX This sometimes fails!!! + assert gruntz(exp(gamma(x - exp(-x))*exp(1/x)) - exp(gamma(x)), x, oo) is oo + + +def test_gruntz_Ei(): + assert gruntz((Ei(x - exp(-exp(x))) - Ei(x)) *exp(-x)*exp(exp(x))*x, x, oo) == -1 + + +@XFAIL +def test_gruntz_eval_special_fail(): + # TODO zeta function series + assert gruntz( + exp((log(2) + 1)*x) * (zeta(x + exp(-x)) - zeta(x)), x, oo) == -log(2) + + # TODO 8.35 - 8.37 (bessel, max-min) + + +def test_gruntz_hyperbolic(): + assert gruntz(cosh(x), x, oo) is oo + assert gruntz(cosh(x), x, -oo) is oo + assert gruntz(sinh(x), x, oo) is oo + assert gruntz(sinh(x), x, -oo) is -oo + assert gruntz(2*cosh(x)*exp(x), x, oo) is oo + assert gruntz(2*cosh(x)*exp(x), x, -oo) == 1 + assert gruntz(2*sinh(x)*exp(x), x, oo) is oo + assert gruntz(2*sinh(x)*exp(x), x, -oo) == -1 + assert gruntz(tanh(x), x, oo) == 1 + assert gruntz(tanh(x), x, -oo) == -1 + assert gruntz(coth(x), x, oo) == 1 + assert gruntz(coth(x), x, -oo) == -1 + + +def test_compare1(): + assert compare(2, x, x) == "<" + assert compare(x, exp(x), x) == "<" + assert compare(exp(x), exp(x**2), x) == "<" + assert compare(exp(x**2), exp(exp(x)), x) == "<" + assert compare(1, exp(exp(x)), x) == "<" + + assert compare(x, 2, x) == ">" + assert compare(exp(x), x, x) == ">" + assert compare(exp(x**2), exp(x), x) == ">" + assert compare(exp(exp(x)), exp(x**2), x) == ">" + assert compare(exp(exp(x)), 1, x) == ">" + + assert compare(2, 3, x) == "=" + assert compare(3, -5, x) == "=" + assert compare(2, -5, x) == "=" + + assert compare(x, x**2, x) == "=" + assert compare(x**2, x**3, x) == "=" + assert compare(x**3, 1/x, x) == "=" + assert compare(1/x, x**m, x) == "=" + assert compare(x**m, -x, x) == "=" + + assert compare(exp(x), exp(-x), x) == "=" + assert compare(exp(-x), exp(2*x), x) == "=" + assert compare(exp(2*x), exp(x)**2, x) == "=" + assert compare(exp(x)**2, exp(x + exp(-x)), x) == "=" + assert compare(exp(x), exp(x + exp(-x)), x) == "=" + + assert compare(exp(x**2), 1/exp(x**2), x) == "=" + + +def test_compare2(): + assert compare(exp(x), x**5, x) == ">" + assert compare(exp(x**2), exp(x)**2, x) == ">" + assert compare(exp(x), exp(x + exp(-x)), x) == "=" + assert compare(exp(x + exp(-x)), exp(x), x) == "=" + assert compare(exp(x + exp(-x)), exp(-x), x) == "=" + assert compare(exp(-x), x, x) == ">" + assert compare(x, exp(-x), x) == "<" + assert compare(exp(x + 1/x), x, x) == ">" + assert compare(exp(-exp(x)), exp(x), x) == ">" + assert compare(exp(exp(-exp(x)) + x), exp(-exp(x)), x) == "<" + + +def test_compare3(): + assert compare(exp(exp(x)), exp(x + exp(-exp(x))), x) == ">" + + +def test_sign1(): + assert sign(Rational(0), x) == 0 + assert sign(Rational(3), x) == 1 + assert sign(Rational(-5), x) == -1 + assert sign(log(x), x) == 1 + assert sign(exp(-x), x) == 1 + assert sign(exp(x), x) == 1 + assert sign(-exp(x), x) == -1 + assert sign(3 - 1/x, x) == 1 + assert sign(-3 - 1/x, x) == -1 + assert sign(sin(1/x), x) == 1 + assert sign((x**Integer(2)), x) == 1 + assert sign(x**2, x) == 1 + assert sign(x**5, x) == 1 + + +def test_sign2(): + assert sign(x, x) == 1 + assert sign(-x, x) == -1 + y = Symbol("y", positive=True) + assert sign(y, x) == 1 + assert sign(-y, x) == -1 + assert sign(y*x, x) == 1 + assert sign(-y*x, x) == -1 + + +def mmrv(a, b): + return set(mrv(a, b)[0].keys()) + + +def test_mrv1(): + assert mmrv(x, x) == {x} + assert mmrv(x + 1/x, x) == {x} + assert mmrv(x**2, x) == {x} + assert mmrv(log(x), x) == {x} + assert mmrv(exp(x), x) == {exp(x)} + assert mmrv(exp(-x), x) == {exp(-x)} + assert mmrv(exp(x**2), x) == {exp(x**2)} + assert mmrv(-exp(1/x), x) == {x} + assert mmrv(exp(x + 1/x), x) == {exp(x + 1/x)} + + +def test_mrv2a(): + assert mmrv(exp(x + exp(-exp(x))), x) == {exp(-exp(x))} + assert mmrv(exp(x + exp(-x)), x) == {exp(x + exp(-x)), exp(-x)} + assert mmrv(exp(1/x + exp(-x)), x) == {exp(-x)} + +#sometimes infinite recursion due to log(exp(x**2)) not simplifying + + +def test_mrv2b(): + assert mmrv(exp(x + exp(-x**2)), x) == {exp(-x**2)} + +#sometimes infinite recursion due to log(exp(x**2)) not simplifying + + +def test_mrv2c(): + assert mmrv( + exp(-x + 1/x**2) - exp(x + 1/x), x) == {exp(x + 1/x), exp(1/x**2 - x)} + +#sometimes infinite recursion due to log(exp(x**2)) not simplifying + + +def test_mrv3(): + assert mmrv(exp(x**2) + x*exp(x) + log(x)**x/x, x) == {exp(x**2)} + assert mmrv( + exp(x)*(exp(1/x + exp(-x)) - exp(1/x)), x) == {exp(x), exp(-x)} + assert mmrv(log( + x**2 + 2*exp(exp(3*x**3*log(x)))), x) == {exp(exp(3*x**3*log(x)))} + assert mmrv(log(x - log(x))/log(x), x) == {x} + assert mmrv( + (exp(1/x - exp(-x)) - exp(1/x))*exp(x), x) == {exp(x), exp(-x)} + assert mmrv( + 1/exp(-x + exp(-x)) - exp(x), x) == {exp(x), exp(-x), exp(x - exp(-x))} + assert mmrv(log(log(x*exp(x*exp(x)) + 1)), x) == {exp(x*exp(x))} + assert mmrv(exp(exp(log(log(x) + 1/x))), x) == {x} + + +def test_mrv4(): + ln = log + assert mmrv((ln(ln(x) + ln(ln(x))) - ln(ln(x)))/ln(ln(x) + ln(ln(ln(x))))*ln(x), + x) == {x} + assert mmrv(log(log(x*exp(x*exp(x)) + 1)) - exp(exp(log(log(x) + 1/x))), x) == \ + {exp(x*exp(x))} + + +def mrewrite(a, b, c): + return rewrite(a[1], a[0], b, c) + + +def test_rewrite1(): + e = exp(x) + assert mrewrite(mrv(e, x), x, m) == (1/m, -x) + e = exp(x**2) + assert mrewrite(mrv(e, x), x, m) == (1/m, -x**2) + e = exp(x + 1/x) + assert mrewrite(mrv(e, x), x, m) == (1/m, -x - 1/x) + e = 1/exp(-x + exp(-x)) - exp(x) + assert mrewrite(mrv(e, x), x, m) == (1/(m*exp(m)) - 1/m, -x) + + +def test_rewrite2(): + e = exp(x)*log(log(exp(x))) + assert mmrv(e, x) == {exp(x)} + assert mrewrite(mrv(e, x), x, m) == (1/m*log(x), -x) + +#sometimes infinite recursion due to log(exp(x**2)) not simplifying + + +def test_rewrite3(): + e = exp(-x + 1/x**2) - exp(x + 1/x) + #both of these are correct and should be equivalent: + assert mrewrite(mrv(e, x), x, m) in [(-1/m + m*exp( + 1/x + 1/x**2), -x - 1/x), (m - 1/m*exp(1/x + x**(-2)), x**(-2) - x)] + + +def test_mrv_leadterm1(): + assert mrv_leadterm(-exp(1/x), x) == (-1, 0) + assert mrv_leadterm(1/exp(-x + exp(-x)) - exp(x), x) == (-1, 0) + assert mrv_leadterm( + (exp(1/x - exp(-x)) - exp(1/x))*exp(x), x) == (-exp(1/x), 0) + + +def test_mrv_leadterm2(): + #Gruntz: p51, 3.25 + assert mrv_leadterm((log(exp(x) + x) - x)/log(exp(x) + log(x))*exp(x), x) == \ + (1, 0) + + +def test_mrv_leadterm3(): + #Gruntz: p56, 3.27 + assert mmrv(exp(-x + exp(-x)*exp(-x*log(x))), x) == {exp(-x - x*log(x))} + assert mrv_leadterm(exp(-x + exp(-x)*exp(-x*log(x))), x) == (exp(-x), 0) + + +def test_limit1(): + assert gruntz(x, x, oo) is oo + assert gruntz(x, x, -oo) is -oo + assert gruntz(-x, x, oo) is -oo + assert gruntz(x**2, x, -oo) is oo + assert gruntz(-x**2, x, oo) is -oo + assert gruntz(x*log(x), x, 0, dir="+") == 0 + assert gruntz(1/x, x, oo) == 0 + assert gruntz(exp(x), x, oo) is oo + assert gruntz(-exp(x), x, oo) is -oo + assert gruntz(exp(x)/x, x, oo) is oo + assert gruntz(1/x - exp(-x), x, oo) == 0 + assert gruntz(x + 1/x, x, oo) is oo + + +def test_limit2(): + assert gruntz(x**x, x, 0, dir="+") == 1 + assert gruntz((exp(x) - 1)/x, x, 0) == 1 + assert gruntz(1 + 1/x, x, oo) == 1 + assert gruntz(-exp(1/x), x, oo) == -1 + assert gruntz(x + exp(-x), x, oo) is oo + assert gruntz(x + exp(-x**2), x, oo) is oo + assert gruntz(x + exp(-exp(x)), x, oo) is oo + assert gruntz(13 + 1/x - exp(-x), x, oo) == 13 + + +def test_limit3(): + a = Symbol('a') + assert gruntz(x - log(1 + exp(x)), x, oo) == 0 + assert gruntz(x - log(a + exp(x)), x, oo) == 0 + assert gruntz(exp(x)/(1 + exp(x)), x, oo) == 1 + assert gruntz(exp(x)/(a + exp(x)), x, oo) == 1 + + +def test_limit4(): + #issue 3463 + assert gruntz((3**x + 5**x)**(1/x), x, oo) == 5 + #issue 3463 + assert gruntz((3**(1/x) + 5**(1/x))**x, x, 0) == 5 + + +@XFAIL +def test_MrvTestCase_page47_ex3_21(): + h = exp(-x/(1 + exp(-x))) + expr = exp(h)*exp(-x/(1 + h))*exp(exp(-x + h))/h**2 - exp(x) + x + assert mmrv(expr, x) == {1/h, exp(-x), exp(x), exp(x - h), exp(x/(1 + h))} + + +def test_gruntz_I(): + y = Symbol("y") + assert gruntz(I*x, x, oo) == I*oo + assert gruntz(y*I*x, x, oo) == y*I*oo + assert gruntz(y*3*I*x, x, oo) == y*I*oo + assert gruntz(y*3*sin(I)*x, x, oo).simplify().rewrite(_sign) == _sign(y)*I*oo + + +def test_issue_4814(): + assert gruntz((x + 1)**(1/log(x + 1)), x, oo) == E + + +def test_intractable(): + assert gruntz(1/gamma(x), x, oo) == 0 + assert gruntz(1/loggamma(x), x, oo) == 0 + assert gruntz(gamma(x)/loggamma(x), x, oo) is oo + assert gruntz(exp(gamma(x))/gamma(x), x, oo) is oo + assert gruntz(gamma(x), x, 3) == 2 + assert gruntz(gamma(Rational(1, 7) + 1/x), x, oo) == gamma(Rational(1, 7)) + assert gruntz(log(x**x)/log(gamma(x)), x, oo) == 1 + assert gruntz(log(gamma(gamma(x)))/exp(x), x, oo) is oo + + +def test_aseries_trig(): + assert cancel(gruntz(1/log(atan(x)), x, oo) + - 1/(log(pi) + log(S.Half))) == 0 + assert gruntz(1/acot(x), x, -oo) is -oo + + +def test_exp_log_series(): + assert gruntz(x/log(log(x*exp(x))), x, oo) is oo + + +def test_issue_3644(): + assert gruntz(((x**7 + x + 1)/(2**x + x**2))**(-1/x), x, oo) == 2 + + +def test_issue_6843(): + n = Symbol('n', integer=True, positive=True) + r = (n + 1)*x**(n + 1)/(x**(n + 1) - 1) - x/(x - 1) + assert gruntz(r, x, 1).simplify() == n/2 + + +def test_issue_4190(): + assert gruntz(x - gamma(1/x), x, oo) == S.EulerGamma + + +@XFAIL +def test_issue_5172(): + n = Symbol('n') + r = Symbol('r', positive=True) + c = Symbol('c') + p = Symbol('p', positive=True) + m = Symbol('m', negative=True) + expr = ((2*n*(n - r + 1)/(n + r*(n - r + 1)))**c + \ + (r - 1)*(n*(n - r + 2)/(n + r*(n - r + 1)))**c - n)/(n**c - n) + expr = expr.subs(c, c + 1) + assert gruntz(expr.subs(c, m), n, oo) == 1 + # fail: + assert gruntz(expr.subs(c, p), n, oo).simplify() == \ + (2**(p + 1) + r - 1)/(r + 1)**(p + 1) + + +def test_issue_4109(): + assert gruntz(1/gamma(x), x, 0) == 0 + assert gruntz(x*gamma(x), x, 0) == 1 + + +def test_issue_6682(): + assert gruntz(exp(2*Ei(-x))/x**2, x, 0) == exp(2*EulerGamma) + + +def test_issue_7096(): + from sympy.functions import sign + assert gruntz(x**-pi, x, 0, dir='-') == oo*sign((-1)**(-pi)) + +def test_issue_24210_25885(): + eq = exp(x)/(1+1/x)**x**2 + ans = sqrt(E) + assert gruntz(eq, x, oo) == ans + assert gruntz(1/eq, x, oo) == 1/ans diff --git a/lib/python3.10/site-packages/sympy/series/tests/test_kauers.py b/lib/python3.10/site-packages/sympy/series/tests/test_kauers.py new file mode 100644 index 0000000000000000000000000000000000000000..bfb9044b33416bc38879649b258150ba2906250c --- /dev/null +++ b/lib/python3.10/site-packages/sympy/series/tests/test_kauers.py @@ -0,0 +1,23 @@ +from sympy.series.kauers import finite_diff +from sympy.series.kauers import finite_diff_kauers +from sympy.abc import x, y, z, m, n, w +from sympy.core.numbers import pi +from sympy.functions.elementary.trigonometric import (cos, sin) +from sympy.concrete.summations import Sum + + +def test_finite_diff(): + assert finite_diff(x**2 + 2*x + 1, x) == 2*x + 3 + assert finite_diff(y**3 + 2*y**2 + 3*y + 5, y) == 3*y**2 + 7*y + 6 + assert finite_diff(z**2 - 2*z + 3, z) == 2*z - 1 + assert finite_diff(w**2 + 3*w - 2, w) == 2*w + 4 + assert finite_diff(sin(x), x, pi/6) == -sin(x) + sin(x + pi/6) + assert finite_diff(cos(y), y, pi/3) == -cos(y) + cos(y + pi/3) + assert finite_diff(x**2 - 2*x + 3, x, 2) == 4*x + assert finite_diff(n**2 - 2*n + 3, n, 3) == 6*n + 3 + +def test_finite_diff_kauers(): + assert finite_diff_kauers(Sum(x**2, (x, 1, n))) == (n + 1)**2 + assert finite_diff_kauers(Sum(y, (y, 1, m))) == (m + 1) + assert finite_diff_kauers(Sum((x*y), (x, 1, m), (y, 1, n))) == (m + 1)*(n + 1) + assert finite_diff_kauers(Sum((x*y**2), (x, 1, m), (y, 1, n))) == (n + 1)**2*(m + 1) diff --git a/lib/python3.10/site-packages/sympy/series/tests/test_limits.py b/lib/python3.10/site-packages/sympy/series/tests/test_limits.py new file mode 100644 index 0000000000000000000000000000000000000000..21777c15e65cddf54ab53062cc3d2b58fadef114 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/series/tests/test_limits.py @@ -0,0 +1,1414 @@ +from itertools import product + +from sympy.concrete.summations import Sum +from sympy.core.function import (Function, diff) +from sympy.core import EulerGamma, GoldenRatio +from sympy.core.mod import Mod +from sympy.core.numbers import (E, I, Rational, oo, pi, zoo) +from sympy.core.singleton import S +from sympy.core.symbol import (Symbol, symbols) +from sympy.functions.combinatorial.numbers import fibonacci +from sympy.functions.combinatorial.factorials import (binomial, factorial, subfactorial) +from sympy.functions.elementary.complexes import (Abs, re, sign) +from sympy.functions.elementary.exponential import (LambertW, exp, log) +from sympy.functions.elementary.hyperbolic import (atanh, asinh, acosh, acoth, acsch, asech, tanh, sinh) +from sympy.functions.elementary.integers import (ceiling, floor, frac) +from sympy.functions.elementary.miscellaneous import (cbrt, real_root, sqrt) +from sympy.functions.elementary.piecewise import Piecewise +from sympy.functions.elementary.trigonometric import (acos, acot, acsc, asec, asin, + atan, cos, cot, csc, sec, sin, tan) +from sympy.functions.special.bessel import (besseli, bessely, besselj, besselk) +from sympy.functions.special.error_functions import (Ei, erf, erfc, erfi, fresnelc, fresnels) +from sympy.functions.special.gamma_functions import (digamma, gamma, uppergamma) +from sympy.functions.special.hyper import meijerg +from sympy.integrals.integrals import (Integral, integrate) +from sympy.series.limits import (Limit, limit) +from sympy.simplify.simplify import (logcombine, simplify) +from sympy.simplify.hyperexpand import hyperexpand + +from sympy.calculus.accumulationbounds import AccumBounds +from sympy.core.mul import Mul +from sympy.series.limits import heuristics +from sympy.series.order import Order +from sympy.testing.pytest import XFAIL, raises + +from sympy import elliptic_e, elliptic_k + +from sympy.abc import x, y, z, k +n = Symbol('n', integer=True, positive=True) + + +def test_basic1(): + assert limit(x, x, oo) is oo + assert limit(x, x, -oo) is -oo + assert limit(-x, x, oo) is -oo + assert limit(x**2, x, -oo) is oo + assert limit(-x**2, x, oo) is -oo + assert limit(x*log(x), x, 0, dir="+") == 0 + assert limit(1/x, x, oo) == 0 + assert limit(exp(x), x, oo) is oo + assert limit(-exp(x), x, oo) is -oo + assert limit(exp(x)/x, x, oo) is oo + assert limit(1/x - exp(-x), x, oo) == 0 + assert limit(x + 1/x, x, oo) is oo + assert limit(x - x**2, x, oo) is -oo + assert limit((1 + x)**(1 + sqrt(2)), x, 0) == 1 + assert limit((1 + x)**oo, x, 0) == Limit((x + 1)**oo, x, 0) + assert limit((1 + x)**oo, x, 0, dir='-') == Limit((x + 1)**oo, x, 0, dir='-') + assert limit((1 + x + y)**oo, x, 0, dir='-') == Limit((1 + x + y)**oo, x, 0, dir='-') + assert limit(y/x/log(x), x, 0) == -oo*sign(y) + assert limit(cos(x + y)/x, x, 0) == sign(cos(y))*oo + assert limit(gamma(1/x + 3), x, oo) == 2 + assert limit(S.NaN, x, -oo) is S.NaN + assert limit(Order(2)*x, x, S.NaN) is S.NaN + assert limit(1/(x - 1), x, 1, dir="+") is oo + assert limit(1/(x - 1), x, 1, dir="-") is -oo + assert limit(1/(5 - x)**3, x, 5, dir="+") is -oo + assert limit(1/(5 - x)**3, x, 5, dir="-") is oo + assert limit(1/sin(x), x, pi, dir="+") is -oo + assert limit(1/sin(x), x, pi, dir="-") is oo + assert limit(1/cos(x), x, pi/2, dir="+") is -oo + assert limit(1/cos(x), x, pi/2, dir="-") is oo + assert limit(1/tan(x**3), x, (2*pi)**Rational(1, 3), dir="+") is oo + assert limit(1/tan(x**3), x, (2*pi)**Rational(1, 3), dir="-") is -oo + assert limit(1/cot(x)**3, x, (pi*Rational(3, 2)), dir="+") is -oo + assert limit(1/cot(x)**3, x, (pi*Rational(3, 2)), dir="-") is oo + assert limit(tan(x), x, oo) == AccumBounds(S.NegativeInfinity, S.Infinity) + assert limit(cot(x), x, oo) == AccumBounds(S.NegativeInfinity, S.Infinity) + assert limit(sec(x), x, oo) == AccumBounds(S.NegativeInfinity, S.Infinity) + assert limit(csc(x), x, oo) == AccumBounds(S.NegativeInfinity, S.Infinity) + + # test bi-directional limits + assert limit(sin(x)/x, x, 0, dir="+-") == 1 + assert limit(x**2, x, 0, dir="+-") == 0 + assert limit(1/x**2, x, 0, dir="+-") is oo + + # test failing bi-directional limits + assert limit(1/x, x, 0, dir="+-") is zoo + # approaching 0 + # from dir="+" + assert limit(1 + 1/x, x, 0) is oo + # from dir='-' + # Add + assert limit(1 + 1/x, x, 0, dir='-') is -oo + # Pow + assert limit(x**(-2), x, 0, dir='-') is oo + assert limit(x**(-3), x, 0, dir='-') is -oo + assert limit(1/sqrt(x), x, 0, dir='-') == (-oo)*I + assert limit(x**2, x, 0, dir='-') == 0 + assert limit(sqrt(x), x, 0, dir='-') == 0 + assert limit(x**-pi, x, 0, dir='-') == -oo*(-1)**(1 - pi) + assert limit((1 + cos(x))**oo, x, 0) == Limit((cos(x) + 1)**oo, x, 0) + + # test pull request 22491 + assert limit(1/asin(x), x, 0, dir = '+') == oo + assert limit(1/asin(x), x, 0, dir = '-') == -oo + assert limit(1/sinh(x), x, 0, dir = '+') == oo + assert limit(1/sinh(x), x, 0, dir = '-') == -oo + assert limit(log(1/x) + 1/sin(x), x, 0, dir = '+') == oo + assert limit(log(1/x) + 1/x, x, 0, dir = '+') == oo + + +def test_basic2(): + assert limit(x**x, x, 0, dir="+") == 1 + assert limit((exp(x) - 1)/x, x, 0) == 1 + assert limit(1 + 1/x, x, oo) == 1 + assert limit(-exp(1/x), x, oo) == -1 + assert limit(x + exp(-x), x, oo) is oo + assert limit(x + exp(-x**2), x, oo) is oo + assert limit(x + exp(-exp(x)), x, oo) is oo + assert limit(13 + 1/x - exp(-x), x, oo) == 13 + + +def test_basic3(): + assert limit(1/x, x, 0, dir="+") is oo + assert limit(1/x, x, 0, dir="-") is -oo + + +def test_basic4(): + assert limit(2*x + y*x, x, 0) == 0 + assert limit(2*x + y*x, x, 1) == 2 + y + assert limit(2*x**8 + y*x**(-3), x, -2) == 512 - y/8 + assert limit(sqrt(x + 1) - sqrt(x), x, oo) == 0 + assert integrate(1/(x**3 + 1), (x, 0, oo)) == 2*pi*sqrt(3)/9 + + +def test_log(): + # https://github.com/sympy/sympy/issues/21598 + a, b, c = symbols('a b c', positive=True) + A = log(a/b) - (log(a) - log(b)) + assert A.limit(a, oo) == 0 + assert (A * c).limit(a, oo) == 0 + + tau, x = symbols('tau x', positive=True) + # The value of manualintegrate in the issue + expr = tau**2*((tau - 1)*(tau + 1)*log(x + 1)/(tau**2 + 1)**2 + 1/((tau**2\ + + 1)*(x + 1)) - (-2*tau*atan(x/tau) + (tau**2/2 - 1/2)*log(tau**2\ + + x**2))/(tau**2 + 1)**2) + assert limit(expr, x, oo) == pi*tau**3/(tau**2 + 1)**2 + + +def test_piecewise(): + # https://github.com/sympy/sympy/issues/18363 + assert limit((real_root(x - 6, 3) + 2)/(x + 2), x, -2, '+') == Rational(1, 12) + + +def test_piecewise2(): + func1 = 2*sqrt(x)*Piecewise(((4*x - 2)/Abs(sqrt(4 - 4*(2*x - 1)**2)), 4*x - 2\ + >= 0), ((2 - 4*x)/Abs(sqrt(4 - 4*(2*x - 1)**2)), True)) + func2 = Piecewise((x**2/2, x <= 0.5), (x/2 - 0.125, True)) + func3 = Piecewise(((x - 9) / 5, x < -1), ((x - 9) / 5, x > 4), (sqrt(Abs(x - 3)), True)) + assert limit(func1, x, 0) == 1 + assert limit(func2, x, 0) == 0 + assert limit(func3, x, -1) == 2 + + +def test_basic5(): + class my(Function): + @classmethod + def eval(cls, arg): + if arg is S.Infinity: + return S.NaN + assert limit(my(x), x, oo) == Limit(my(x), x, oo) + + +def test_issue_3885(): + assert limit(x*y + x*z, z, 2) == x*y + 2*x + + +def test_Limit(): + assert Limit(sin(x)/x, x, 0) != 1 + assert Limit(sin(x)/x, x, 0).doit() == 1 + assert Limit(x, x, 0, dir='+-').args == (x, x, 0, Symbol('+-')) + + +def test_floor(): + assert limit(floor(x), x, -2, "+") == -2 + assert limit(floor(x), x, -2, "-") == -3 + assert limit(floor(x), x, -1, "+") == -1 + assert limit(floor(x), x, -1, "-") == -2 + assert limit(floor(x), x, 0, "+") == 0 + assert limit(floor(x), x, 0, "-") == -1 + assert limit(floor(x), x, 1, "+") == 1 + assert limit(floor(x), x, 1, "-") == 0 + assert limit(floor(x), x, 2, "+") == 2 + assert limit(floor(x), x, 2, "-") == 1 + assert limit(floor(x), x, 248, "+") == 248 + assert limit(floor(x), x, 248, "-") == 247 + + # https://github.com/sympy/sympy/issues/14478 + assert limit(x*floor(3/x)/2, x, 0, '+') == Rational(3, 2) + assert limit(floor(x + 1/2) - floor(x), x, oo) == AccumBounds(-S.Half, S(3)/2) + + # test issue 9158 + assert limit(floor(atan(x)), x, oo) == 1 + assert limit(floor(atan(x)), x, -oo) == -2 + assert limit(ceiling(atan(x)), x, oo) == 2 + assert limit(ceiling(atan(x)), x, -oo) == -1 + + +def test_floor_requires_robust_assumptions(): + assert limit(floor(sin(x)), x, 0, "+") == 0 + assert limit(floor(sin(x)), x, 0, "-") == -1 + assert limit(floor(cos(x)), x, 0, "+") == 0 + assert limit(floor(cos(x)), x, 0, "-") == 0 + assert limit(floor(5 + sin(x)), x, 0, "+") == 5 + assert limit(floor(5 + sin(x)), x, 0, "-") == 4 + assert limit(floor(5 + cos(x)), x, 0, "+") == 5 + assert limit(floor(5 + cos(x)), x, 0, "-") == 5 + + +def test_ceiling(): + assert limit(ceiling(x), x, -2, "+") == -1 + assert limit(ceiling(x), x, -2, "-") == -2 + assert limit(ceiling(x), x, -1, "+") == 0 + assert limit(ceiling(x), x, -1, "-") == -1 + assert limit(ceiling(x), x, 0, "+") == 1 + assert limit(ceiling(x), x, 0, "-") == 0 + assert limit(ceiling(x), x, 1, "+") == 2 + assert limit(ceiling(x), x, 1, "-") == 1 + assert limit(ceiling(x), x, 2, "+") == 3 + assert limit(ceiling(x), x, 2, "-") == 2 + assert limit(ceiling(x), x, 248, "+") == 249 + assert limit(ceiling(x), x, 248, "-") == 248 + + # https://github.com/sympy/sympy/issues/14478 + assert limit(x*ceiling(3/x)/2, x, 0, '+') == Rational(3, 2) + assert limit(ceiling(x + 1/2) - ceiling(x), x, oo) == AccumBounds(-S.Half, S(3)/2) + + +def test_ceiling_requires_robust_assumptions(): + assert limit(ceiling(sin(x)), x, 0, "+") == 1 + assert limit(ceiling(sin(x)), x, 0, "-") == 0 + assert limit(ceiling(cos(x)), x, 0, "+") == 1 + assert limit(ceiling(cos(x)), x, 0, "-") == 1 + assert limit(ceiling(5 + sin(x)), x, 0, "+") == 6 + assert limit(ceiling(5 + sin(x)), x, 0, "-") == 5 + assert limit(ceiling(5 + cos(x)), x, 0, "+") == 6 + assert limit(ceiling(5 + cos(x)), x, 0, "-") == 6 + + +def test_frac(): + assert limit(frac(x), x, oo) == AccumBounds(0, 1) + assert limit(frac(x)**(1/x), x, oo) == AccumBounds(0, 1) + assert limit(frac(x)**(1/x), x, -oo) == AccumBounds(1, oo) + assert limit(frac(x)**x, x, oo) == AccumBounds(0, oo) # wolfram gives (0, 1) + assert limit(frac(sin(x)), x, 0, "+") == 0 + assert limit(frac(sin(x)), x, 0, "-") == 1 + assert limit(frac(cos(x)), x, 0, "+-") == 1 + assert limit(frac(x**2), x, 0, "+-") == 0 + raises(ValueError, lambda: limit(frac(x), x, 0, '+-')) + assert limit(frac(-2*x + 1), x, 0, "+") == 1 + assert limit(frac(-2*x + 1), x, 0, "-") == 0 + assert limit(frac(x + S.Half), x, 0, "+-") == S(1)/2 + assert limit(frac(1/x), x, 0) == AccumBounds(0, 1) + + +def test_issue_14355(): + assert limit(floor(sin(x)/x), x, 0, '+') == 0 + assert limit(floor(sin(x)/x), x, 0, '-') == 0 + # test comment https://github.com/sympy/sympy/issues/14355#issuecomment-372121314 + assert limit(floor(-tan(x)/x), x, 0, '+') == -2 + assert limit(floor(-tan(x)/x), x, 0, '-') == -2 + + +def test_atan(): + x = Symbol("x", real=True) + assert limit(atan(x)*sin(1/x), x, 0) == 0 + assert limit(atan(x) + sqrt(x + 1) - sqrt(x), x, oo) == pi/2 + + +def test_set_signs(): + assert limit(abs(x), x, 0) == 0 + assert limit(abs(sin(x)), x, 0) == 0 + assert limit(abs(cos(x)), x, 0) == 1 + assert limit(abs(sin(x + 1)), x, 0) == sin(1) + + # https://github.com/sympy/sympy/issues/9449 + assert limit((Abs(x + y) - Abs(x - y))/(2*x), x, 0) == sign(y) + + # https://github.com/sympy/sympy/issues/12398 + assert limit(Abs(log(x)/x**3), x, oo) == 0 + assert limit(x*(Abs(log(x)/x**3)/Abs(log(x + 1)/(x + 1)**3) - 1), x, oo) == 3 + + # https://github.com/sympy/sympy/issues/18501 + assert limit(Abs(log(x - 1)**3 - 1), x, 1, '+') == oo + + # https://github.com/sympy/sympy/issues/18997 + assert limit(Abs(log(x)), x, 0) == oo + assert limit(Abs(log(Abs(x))), x, 0) == oo + + # https://github.com/sympy/sympy/issues/19026 + z = Symbol('z', positive=True) + assert limit(Abs(log(z) + 1)/log(z), z, oo) == 1 + + # https://github.com/sympy/sympy/issues/20704 + assert limit(z*(Abs(1/z + y) - Abs(y - 1/z))/2, z, 0) == 0 + + # https://github.com/sympy/sympy/issues/21606 + assert limit(cos(z)/sign(z), z, pi, '-') == -1 + + +def test_heuristic(): + x = Symbol("x", real=True) + assert heuristics(sin(1/x) + atan(x), x, 0, '+') == AccumBounds(-1, 1) + assert limit(log(2 + sqrt(atan(x))*sqrt(sin(1/x))), x, 0) == log(2) + + +def test_issue_3871(): + z = Symbol("z", positive=True) + f = -1/z*exp(-z*x) + assert limit(f, x, oo) == 0 + assert f.limit(x, oo) == 0 + + +def test_exponential(): + n = Symbol('n') + x = Symbol('x', real=True) + assert limit((1 + x/n)**n, n, oo) == exp(x) + assert limit((1 + x/(2*n))**n, n, oo) == exp(x/2) + assert limit((1 + x/(2*n + 1))**n, n, oo) == exp(x/2) + assert limit(((x - 1)/(x + 1))**x, x, oo) == exp(-2) + assert limit(1 + (1 + 1/x)**x, x, oo) == 1 + S.Exp1 + assert limit((2 + 6*x)**x/(6*x)**x, x, oo) == exp(S('1/3')) + + +def test_exponential2(): + n = Symbol('n') + assert limit((1 + x/(n + sin(n)))**n, n, oo) == exp(x) + + +def test_doit(): + f = Integral(2 * x, x) + l = Limit(f, x, oo) + assert l.doit() is oo + + +def test_series_AccumBounds(): + assert limit(sin(k) - sin(k + 1), k, oo) == AccumBounds(-2, 2) + assert limit(cos(k) - cos(k + 1) + 1, k, oo) == AccumBounds(-1, 3) + + # not the exact bound + assert limit(sin(k) - sin(k)*cos(k), k, oo) == AccumBounds(-2, 2) + + # test for issue #9934 + lo = (-3 + cos(1))/2 + hi = (1 + cos(1))/2 + t1 = Mul(AccumBounds(lo, hi), 1/(-1 + cos(1)), evaluate=False) + assert limit(simplify(Sum(cos(n).rewrite(exp), (n, 0, k)).doit().rewrite(sin)), k, oo) == t1 + + t2 = Mul(AccumBounds(-1 + sin(1)/2, sin(1)/2 + 1), 1/(1 - cos(1))) + assert limit(simplify(Sum(sin(n).rewrite(exp), (n, 0, k)).doit().rewrite(sin)), k, oo) == t2 + + assert limit(((sin(x) + 1)/2)**x, x, oo) == AccumBounds(0, oo) # wolfram says 0 + + # https://github.com/sympy/sympy/issues/12312 + e = 2**(-x)*(sin(x) + 1)**x + assert limit(e, x, oo) == AccumBounds(0, oo) + + +def test_bessel_functions_at_infinity(): + # Pull Request 23844 implements limits for all bessel and modified bessel + # functions approaching infinity along any direction i.e. abs(z0) tends to oo + + assert limit(besselj(1, x), x, oo) == 0 + assert limit(besselj(1, x), x, -oo) == 0 + assert limit(besselj(1, x), x, I*oo) == oo*I + assert limit(besselj(1, x), x, -I*oo) == -oo*I + assert limit(bessely(1, x), x, oo) == 0 + assert limit(bessely(1, x), x, -oo) == 0 + assert limit(bessely(1, x), x, I*oo) == -oo + assert limit(bessely(1, x), x, -I*oo) == -oo + assert limit(besseli(1, x), x, oo) == oo + assert limit(besseli(1, x), x, -oo) == -oo + assert limit(besseli(1, x), x, I*oo) == 0 + assert limit(besseli(1, x), x, -I*oo) == 0 + assert limit(besselk(1, x), x, oo) == 0 + assert limit(besselk(1, x), x, -oo) == -oo*I + assert limit(besselk(1, x), x, I*oo) == 0 + assert limit(besselk(1, x), x, -I*oo) == 0 + + # test issue 14874 + assert limit(besselk(0, x), x, oo) == 0 + + +@XFAIL +def test_doit2(): + f = Integral(2 * x, x) + l = Limit(f, x, oo) + # limit() breaks on the contained Integral. + assert l.doit(deep=False) == l + + +def test_issue_2929(): + assert limit((x * exp(x))/(exp(x) - 1), x, -oo) == 0 + + +def test_issue_3792(): + assert limit((1 - cos(x))/x**2, x, S.Half) == 4 - 4*cos(S.Half) + assert limit(sin(sin(x + 1) + 1), x, 0) == sin(1 + sin(1)) + assert limit(abs(sin(x + 1) + 1), x, 0) == 1 + sin(1) + + +def test_issue_4090(): + assert limit(1/(x + 3), x, 2) == Rational(1, 5) + assert limit(1/(x + pi), x, 2) == S.One/(2 + pi) + assert limit(log(x)/(x**2 + 3), x, 2) == log(2)/7 + assert limit(log(x)/(x**2 + pi), x, 2) == log(2)/(4 + pi) + + +def test_issue_4547(): + assert limit(cot(x), x, 0, dir='+') is oo + assert limit(cot(x), x, pi/2, dir='+') == 0 + + +def test_issue_5164(): + assert limit(x**0.5, x, oo) == oo**0.5 is oo + assert limit(x**0.5, x, 16) == 4 # Should this be a float? + assert limit(x**0.5, x, 0) == 0 + assert limit(x**(-0.5), x, oo) == 0 + assert limit(x**(-0.5), x, 4) == S.Half # Should this be a float? + + +def test_issue_5383(): + func = (1.0 * 1 + 1.0 * x)**(1.0 * 1 / x) + assert limit(func, x, 0) == E + + +def test_issue_14793(): + expr = ((x + S(1)/2) * log(x) - x + log(2*pi)/2 - \ + log(factorial(x)) + S(1)/(12*x))*x**3 + assert limit(expr, x, oo) == S(1)/360 + + +def test_issue_5183(): + # using list(...) so py.test can recalculate values + tests = list(product([x, -x], + [-1, 1], + [2, 3, S.Half, Rational(2, 3)], + ['-', '+'])) + results = (oo, oo, -oo, oo, -oo*I, oo, -oo*(-1)**Rational(1, 3), oo, + 0, 0, 0, 0, 0, 0, 0, 0, + oo, oo, oo, -oo, oo, -oo*I, oo, -oo*(-1)**Rational(1, 3), + 0, 0, 0, 0, 0, 0, 0, 0) + assert len(tests) == len(results) + for i, (args, res) in enumerate(zip(tests, results)): + y, s, e, d = args + eq = y**(s*e) + try: + assert limit(eq, x, 0, dir=d) == res + except AssertionError: + if 0: # change to 1 if you want to see the failing tests + print() + print(i, res, eq, d, limit(eq, x, 0, dir=d)) + else: + assert None + + +def test_issue_5184(): + assert limit(sin(x)/x, x, oo) == 0 + assert limit(atan(x), x, oo) == pi/2 + assert limit(gamma(x), x, oo) is oo + assert limit(cos(x)/x, x, oo) == 0 + assert limit(gamma(x), x, S.Half) == sqrt(pi) + + r = Symbol('r', real=True) + assert limit(r*sin(1/r), r, 0) == 0 + + +def test_issue_5229(): + assert limit((1 + y)**(1/y) - S.Exp1, y, 0) == 0 + + +def test_issue_4546(): + # using list(...) so py.test can recalculate values + tests = list(product([cot, tan], + [-pi/2, 0, pi/2, pi, pi*Rational(3, 2)], + ['-', '+'])) + results = (0, 0, -oo, oo, 0, 0, -oo, oo, 0, 0, + oo, -oo, 0, 0, oo, -oo, 0, 0, oo, -oo) + assert len(tests) == len(results) + for i, (args, res) in enumerate(zip(tests, results)): + f, l, d = args + eq = f(x) + try: + assert limit(eq, x, l, dir=d) == res + except AssertionError: + if 0: # change to 1 if you want to see the failing tests + print() + print(i, res, eq, l, d, limit(eq, x, l, dir=d)) + else: + assert None + + +def test_issue_3934(): + assert limit((1 + x**log(3))**(1/x), x, 0) == 1 + assert limit((5**(1/x) + 3**(1/x))**x, x, 0) == 5 + + +def test_calculate_series(): + # NOTE + # The calculate_series method is being deprecated and is no longer responsible + # for result being returned. The mrv_leadterm function now uses simple leadterm + # calls rather than calculate_series. + + # needs gruntz calculate_series to go to n = 32 + assert limit(x**Rational(77, 3)/(1 + x**Rational(77, 3)), x, oo) == 1 + # needs gruntz calculate_series to go to n = 128 + assert limit(x**101.1/(1 + x**101.1), x, oo) == 1 + + +def test_issue_5955(): + assert limit((x**16)/(1 + x**16), x, oo) == 1 + assert limit((x**100)/(1 + x**100), x, oo) == 1 + assert limit((x**1885)/(1 + x**1885), x, oo) == 1 + assert limit((x**1000/((x + 1)**1000 + exp(-x))), x, oo) == 1 + + +def test_newissue(): + assert limit(exp(1/sin(x))/exp(cot(x)), x, 0) == 1 + + +def test_extended_real_line(): + assert limit(x - oo, x, oo) == Limit(x - oo, x, oo) + assert limit(1/(x + sin(x)) - oo, x, 0) == Limit(1/(x + sin(x)) - oo, x, 0) + assert limit(oo/x, x, oo) == Limit(oo/x, x, oo) + assert limit(x - oo + 1/x, x, oo) == Limit(x - oo + 1/x, x, oo) + + +@XFAIL +def test_order_oo(): + x = Symbol('x', positive=True) + assert Order(x)*oo != Order(1, x) + assert limit(oo/(x**2 - 4), x, oo) is oo + + +def test_issue_5436(): + raises(NotImplementedError, lambda: limit(exp(x*y), x, oo)) + raises(NotImplementedError, lambda: limit(exp(-x*y), x, oo)) + + +def test_Limit_dir(): + raises(TypeError, lambda: Limit(x, x, 0, dir=0)) + raises(ValueError, lambda: Limit(x, x, 0, dir='0')) + + +def test_polynomial(): + assert limit((x + 1)**1000/((x + 1)**1000 + 1), x, oo) == 1 + assert limit((x + 1)**1000/((x + 1)**1000 + 1), x, -oo) == 1 + + +def test_rational(): + assert limit(1/y - (1/(y + x) + x/(y + x)/y)/z, x, oo) == (z - 1)/(y*z) + assert limit(1/y - (1/(y + x) + x/(y + x)/y)/z, x, -oo) == (z - 1)/(y*z) + + +def test_issue_5740(): + assert limit(log(x)*z - log(2*x)*y, x, 0) == oo*sign(y - z) + + +def test_issue_6366(): + n = Symbol('n', integer=True, positive=True) + r = (n + 1)*x**(n + 1)/(x**(n + 1) - 1) - x/(x - 1) + assert limit(r, x, 1).cancel() == n/2 + + +def test_factorial(): + f = factorial(x) + assert limit(f, x, oo) is oo + assert limit(x/f, x, oo) == 0 + # see Stirling's approximation: + # https://en.wikipedia.org/wiki/Stirling's_approximation + assert limit(f/(sqrt(2*pi*x)*(x/E)**x), x, oo) == 1 + assert limit(f, x, -oo) == gamma(-oo) + + +def test_issue_6560(): + e = (5*x**3/4 - x*Rational(3, 4) + (y*(3*x**2/2 - S.Half) + + 35*x**4/8 - 15*x**2/4 + Rational(3, 8))/(2*(y + 1))) + assert limit(e, y, oo) == 5*x**3/4 + 3*x**2/4 - 3*x/4 - Rational(1, 4) + +@XFAIL +def test_issue_5172(): + n = Symbol('n') + r = Symbol('r', positive=True) + c = Symbol('c') + p = Symbol('p', positive=True) + m = Symbol('m', negative=True) + expr = ((2*n*(n - r + 1)/(n + r*(n - r + 1)))**c + + (r - 1)*(n*(n - r + 2)/(n + r*(n - r + 1)))**c - n)/(n**c - n) + expr = expr.subs(c, c + 1) + raises(NotImplementedError, lambda: limit(expr, n, oo)) + assert limit(expr.subs(c, m), n, oo) == 1 + assert limit(expr.subs(c, p), n, oo).simplify() == \ + (2**(p + 1) + r - 1)/(r + 1)**(p + 1) + + +def test_issue_7088(): + a = Symbol('a') + assert limit(sqrt(x/(x + a)), x, oo) == 1 + + +def test_branch_cuts(): + assert limit(asin(I*x + 2), x, 0) == pi - asin(2) + assert limit(asin(I*x + 2), x, 0, '-') == asin(2) + assert limit(asin(I*x - 2), x, 0) == -asin(2) + assert limit(asin(I*x - 2), x, 0, '-') == -pi + asin(2) + assert limit(acos(I*x + 2), x, 0) == -acos(2) + assert limit(acos(I*x + 2), x, 0, '-') == acos(2) + assert limit(acos(I*x - 2), x, 0) == acos(-2) + assert limit(acos(I*x - 2), x, 0, '-') == 2*pi - acos(-2) + assert limit(atan(x + 2*I), x, 0) == I*atanh(2) + assert limit(atan(x + 2*I), x, 0, '-') == -pi + I*atanh(2) + assert limit(atan(x - 2*I), x, 0) == pi - I*atanh(2) + assert limit(atan(x - 2*I), x, 0, '-') == -I*atanh(2) + assert limit(atan(1/x), x, 0) == pi/2 + assert limit(atan(1/x), x, 0, '-') == -pi/2 + assert limit(atan(x), x, oo) == pi/2 + assert limit(atan(x), x, -oo) == -pi/2 + assert limit(acot(x + S(1)/2*I), x, 0) == pi - I*acoth(S(1)/2) + assert limit(acot(x + S(1)/2*I), x, 0, '-') == -I*acoth(S(1)/2) + assert limit(acot(x - S(1)/2*I), x, 0) == I*acoth(S(1)/2) + assert limit(acot(x - S(1)/2*I), x, 0, '-') == -pi + I*acoth(S(1)/2) + assert limit(acot(x), x, 0) == pi/2 + assert limit(acot(x), x, 0, '-') == -pi/2 + assert limit(asec(I*x + S(1)/2), x, 0) == asec(S(1)/2) + assert limit(asec(I*x + S(1)/2), x, 0, '-') == -asec(S(1)/2) + assert limit(asec(I*x - S(1)/2), x, 0) == 2*pi - asec(-S(1)/2) + assert limit(asec(I*x - S(1)/2), x, 0, '-') == asec(-S(1)/2) + assert limit(acsc(I*x + S(1)/2), x, 0) == acsc(S(1)/2) + assert limit(acsc(I*x + S(1)/2), x, 0, '-') == pi - acsc(S(1)/2) + assert limit(acsc(I*x - S(1)/2), x, 0) == -pi + acsc(S(1)/2) + assert limit(acsc(I*x - S(1)/2), x, 0, '-') == -acsc(S(1)/2) + + assert limit(log(I*x - 1), x, 0) == I*pi + assert limit(log(I*x - 1), x, 0, '-') == -I*pi + assert limit(log(-I*x - 1), x, 0) == -I*pi + assert limit(log(-I*x - 1), x, 0, '-') == I*pi + + assert limit(sqrt(I*x - 1), x, 0) == I + assert limit(sqrt(I*x - 1), x, 0, '-') == -I + assert limit(sqrt(-I*x - 1), x, 0) == -I + assert limit(sqrt(-I*x - 1), x, 0, '-') == I + + assert limit(cbrt(I*x - 1), x, 0) == (-1)**(S(1)/3) + assert limit(cbrt(I*x - 1), x, 0, '-') == -(-1)**(S(2)/3) + assert limit(cbrt(-I*x - 1), x, 0) == -(-1)**(S(2)/3) + assert limit(cbrt(-I*x - 1), x, 0, '-') == (-1)**(S(1)/3) + + +def test_issue_6364(): + a = Symbol('a') + e = z/(1 - sqrt(1 + z)*sin(a)**2 - sqrt(1 - z)*cos(a)**2) + assert limit(e, z, 0) == 1/(cos(a)**2 - S.Half) + + +def test_issue_6682(): + assert limit(exp(2*Ei(-x))/x**2, x, 0) == exp(2*EulerGamma) + + +def test_issue_4099(): + a = Symbol('a') + assert limit(a/x, x, 0) == oo*sign(a) + assert limit(-a/x, x, 0) == -oo*sign(a) + assert limit(-a*x, x, oo) == -oo*sign(a) + assert limit(a*x, x, oo) == oo*sign(a) + + +def test_issue_4503(): + dx = Symbol('dx') + assert limit((sqrt(1 + exp(x + dx)) - sqrt(1 + exp(x)))/dx, dx, 0) == \ + exp(x)/(2*sqrt(exp(x) + 1)) + + +def test_issue_6052(): + G = meijerg((), (), (1,), (0,), -x) + g = hyperexpand(G) + assert limit(g, x, 0, '+-') == 0 + assert limit(g, x, oo) == -oo + + +def test_issue_7224(): + expr = sqrt(x)*besseli(1,sqrt(8*x)) + assert limit(x*diff(expr, x, x)/expr, x, 0) == 2 + assert limit(x*diff(expr, x, x)/expr, x, 1).evalf() == 2.0 + + +def test_issue_8208(): + assert limit(n**(Rational(1, 1e9) - 1), n, oo) == 0 + + +def test_issue_8229(): + assert limit((x**Rational(1, 4) - 2)/(sqrt(x) - 4)**Rational(2, 3), x, 16) == 0 + + +def test_issue_8433(): + d, t = symbols('d t', positive=True) + assert limit(erf(1 - t/d), t, oo) == -1 + + +def test_issue_8481(): + k = Symbol('k', integer=True, nonnegative=True) + lamda = Symbol('lamda', positive=True) + assert limit(lamda**k * exp(-lamda) / factorial(k), k, oo) == 0 + + +def test_issue_8462(): + assert limit(binomial(n, n/2), n, oo) == oo + assert limit(binomial(n, n/2) * 3 ** (-n), n, oo) == 0 + + +def test_issue_8634(): + n = Symbol('n', integer=True, positive=True) + x = Symbol('x') + assert limit(x**n, x, -oo) == oo*sign((-1)**n) + + +def test_issue_8635_18176(): + x = Symbol('x', real=True) + k = Symbol('k', positive=True) + assert limit(x**n - x**(n - 0), x, oo) == 0 + assert limit(x**n - x**(n - 5), x, oo) == oo + assert limit(x**n - x**(n - 2.5), x, oo) == oo + assert limit(x**n - x**(n - k - 1), x, oo) == oo + x = Symbol('x', positive=True) + assert limit(x**n - x**(n - 1), x, oo) == oo + assert limit(x**n - x**(n + 2), x, oo) == -oo + + +def test_issue_8730(): + assert limit(subfactorial(x), x, oo) is oo + + +def test_issue_9252(): + n = Symbol('n', integer=True) + c = Symbol('c', positive=True) + assert limit((log(n))**(n/log(n)) / (1 + c)**n, n, oo) == 0 + # limit should depend on the value of c + raises(NotImplementedError, lambda: limit((log(n))**(n/log(n)) / c**n, n, oo)) + + +def test_issue_9558(): + assert limit(sin(x)**15, x, 0, '-') == 0 + + +def test_issue_10801(): + # make sure limits work with binomial + assert limit(16**k / (k * binomial(2*k, k)**2), k, oo) == pi + + +def test_issue_10976(): + s, x = symbols('s x', real=True) + assert limit(erf(s*x)/erf(s), s, 0) == x + + +def test_issue_9041(): + assert limit(factorial(n) / ((n/exp(1))**n * sqrt(2*pi*n)), n, oo) == 1 + + +def test_issue_9205(): + x, y, a = symbols('x, y, a') + assert Limit(x, x, a).free_symbols == {a} + assert Limit(x, x, a, '-').free_symbols == {a} + assert Limit(x + y, x + y, a).free_symbols == {a} + assert Limit(-x**2 + y, x**2, a).free_symbols == {y, a} + + +def test_issue_9471(): + assert limit(((27**(log(n,3)))/n**3),n,oo) == 1 + assert limit(((27**(log(n,3)+1))/n**3),n,oo) == 27 + + +def test_issue_10382(): + assert limit(fibonacci(n + 1)/fibonacci(n), n, oo) == GoldenRatio + + +def test_issue_11496(): + assert limit(erfc(log(1/x)), x, oo) == 2 + + +def test_issue_11879(): + assert simplify(limit(((x+y)**n-x**n)/y, y, 0)) == n*x**(n-1) + + +def test_limit_with_Float(): + k = symbols("k") + assert limit(1.0 ** k, k, oo) == 1 + assert limit(0.3*1.0**k, k, oo) == Rational(3, 10) + + +def test_issue_10610(): + assert limit(3**x*3**(-x - 1)*(x + 1)**2/x**2, x, oo) == Rational(1, 3) + + +def test_issue_10868(): + assert limit(log(x) + asech(x), x, 0, '+') == log(2) + assert limit(log(x) + asech(x), x, 0, '-') == log(2) + 2*I*pi + raises(ValueError, lambda: limit(log(x) + asech(x), x, 0, '+-')) + assert limit(log(x) + asech(x), x, oo) == oo + assert limit(log(x) + acsch(x), x, 0, '+') == log(2) + assert limit(log(x) + acsch(x), x, 0, '-') == -oo + raises(ValueError, lambda: limit(log(x) + acsch(x), x, 0, '+-')) + assert limit(log(x) + acsch(x), x, oo) == oo + + +def test_issue_6599(): + assert limit((n + cos(n))/n, n, oo) == 1 + + +def test_issue_12555(): + assert limit((3**x + 2* x**10) / (x**10 + exp(x)), x, -oo) == 2 + assert limit((3**x + 2* x**10) / (x**10 + exp(x)), x, oo) is oo + + +def test_issue_12769(): + r, z, x = symbols('r z x', real=True) + a, b, s0, K, F0, s, T = symbols('a b s0 K F0 s T', positive=True, real=True) + fx = (F0**b*K**b*r*s0 - sqrt((F0**2*K**(2*b)*a**2*(b - 1) + \ + F0**(2*b)*K**2*a**2*(b - 1) + F0**(2*b)*K**(2*b)*s0**2*(b - 1)*(b**2 - 2*b + 1) - \ + 2*F0**(2*b)*K**(b + 1)*a*r*s0*(b**2 - 2*b + 1) + \ + 2*F0**(b + 1)*K**(2*b)*a*r*s0*(b**2 - 2*b + 1) - \ + 2*F0**(b + 1)*K**(b + 1)*a**2*(b - 1))/((b - 1)*(b**2 - 2*b + 1))))*(b*r - b - r + 1) + + assert fx.subs(K, F0).factor(deep=True) == limit(fx, K, F0).factor(deep=True) + + +def test_issue_13332(): + assert limit(sqrt(30)*5**(-5*x - 1)*(46656*x)**x*(5*x + 2)**(5*x + 5*S.Half) * + (6*x + 2)**(-6*x - 5*S.Half), x, oo) == Rational(25, 36) + + +def test_issue_12564(): + assert limit(x**2 + x*sin(x) + cos(x), x, -oo) is oo + assert limit(x**2 + x*sin(x) + cos(x), x, oo) is oo + assert limit(((x + cos(x))**2).expand(), x, oo) is oo + assert limit(((x + sin(x))**2).expand(), x, oo) is oo + assert limit(((x + cos(x))**2).expand(), x, -oo) is oo + assert limit(((x + sin(x))**2).expand(), x, -oo) is oo + + +def test_issue_14456(): + raises(NotImplementedError, lambda: Limit(exp(x), x, zoo).doit()) + raises(NotImplementedError, lambda: Limit(x**2/(x+1), x, zoo).doit()) + + +def test_issue_14411(): + assert limit(3*sec(4*pi*x - x/3), x, 3*pi/(24*pi - 2)) is -oo + + +def test_issue_13382(): + assert limit(x*(((x + 1)**2 + 1)/(x**2 + 1) - 1), x, oo) == 2 + + +def test_issue_13403(): + assert limit(x*(-1 + (x + log(x + 1) + 1)/(x + log(x))), x, oo) == 1 + + +def test_issue_13416(): + assert limit((-x**3*log(x)**3 + (x - 1)*(x + 1)**2*log(x + 1)**3)/(x**2*log(x)**3), x, oo) == 1 + + +def test_issue_13462(): + assert limit(n**2*(2*n*(-(1 - 1/(2*n))**x + 1) - x - (-x**2/4 + x/4)/n), n, oo) == x**3/24 - x**2/8 + x/12 + + +def test_issue_13750(): + a = Symbol('a') + assert limit(erf(a - x), x, oo) == -1 + assert limit(erf(sqrt(x) - x), x, oo) == -1 + + +def test_issue_14276(): + assert isinstance(limit(sin(x)**log(x), x, oo), Limit) + assert isinstance(limit(sin(x)**cos(x), x, oo), Limit) + assert isinstance(limit(sin(log(cos(x))), x, oo), Limit) + assert limit((1 + 1/(x**2 + cos(x)))**(x**2 + x), x, oo) == E + + +def test_issue_14514(): + assert limit((1/(log(x)**log(x)))**(1/x), x, oo) == 1 + + +def test_issues_14525(): + assert limit(sin(x)**2 - cos(x) + tan(x)*csc(x), x, oo) == AccumBounds(S.NegativeInfinity, S.Infinity) + assert limit(sin(x)**2 - cos(x) + sin(x)*cot(x), x, oo) == AccumBounds(S.NegativeInfinity, S.Infinity) + assert limit(cot(x) - tan(x)**2, x, oo) == AccumBounds(S.NegativeInfinity, S.Infinity) + assert limit(cos(x) - tan(x)**2, x, oo) == AccumBounds(S.NegativeInfinity, S.One) + assert limit(sin(x) - tan(x)**2, x, oo) == AccumBounds(S.NegativeInfinity, S.One) + assert limit(cos(x)**2 - tan(x)**2, x, oo) == AccumBounds(S.NegativeInfinity, S.One) + assert limit(tan(x)**2 + sin(x)**2 - cos(x), x, oo) == AccumBounds(-S.One, S.Infinity) + + +def test_issue_14574(): + assert limit(sqrt(x)*cos(x - x**2) / (x + 1), x, oo) == 0 + + +def test_issue_10102(): + assert limit(fresnels(x), x, oo) == S.Half + assert limit(3 + fresnels(x), x, oo) == 3 + S.Half + assert limit(5*fresnels(x), x, oo) == Rational(5, 2) + assert limit(fresnelc(x), x, oo) == S.Half + assert limit(fresnels(x), x, -oo) == Rational(-1, 2) + assert limit(4*fresnelc(x), x, -oo) == -2 + + +def test_issue_14377(): + raises(NotImplementedError, lambda: limit(exp(I*x)*sin(pi*x), x, oo)) + + +def test_issue_15146(): + e = (x/2) * (-2*x**3 - 2*(x**3 - 1) * x**2 * digamma(x**3 + 1) + \ + 2*(x**3 - 1) * x**2 * digamma(x**3 + x + 1) + x + 3) + assert limit(e, x, oo) == S(1)/3 + + +def test_issue_15202(): + e = (2**x*(2 + 2**(-x)*(-2*2**x + x + 2))/(x + 1))**(x + 1) + assert limit(e, x, oo) == exp(1) + + e = (log(x, 2)**7 + 10*x*factorial(x) + 5**x) / (factorial(x + 1) + 3*factorial(x) + 10**x) + assert limit(e, x, oo) == 10 + + +def test_issue_15282(): + assert limit((x**2000 - (x + 1)**2000) / x**1999, x, oo) == -2000 + + +def test_issue_15984(): + assert limit((-x + log(exp(x) + 1))/x, x, oo, dir='-') == 0 + + +def test_issue_13571(): + assert limit(uppergamma(x, 1) / gamma(x), x, oo) == 1 + + +def test_issue_13575(): + assert limit(acos(erfi(x)), x, 1) == acos(erfi(S.One)) + + +def test_issue_17325(): + assert Limit(sin(x)/x, x, 0, dir="+-").doit() == 1 + assert Limit(x**2, x, 0, dir="+-").doit() == 0 + assert Limit(1/x**2, x, 0, dir="+-").doit() is oo + assert Limit(1/x, x, 0, dir="+-").doit() is zoo + + +def test_issue_10978(): + assert LambertW(x).limit(x, 0) == 0 + + +def test_issue_14313_comment(): + assert limit(floor(n/2), n, oo) is oo + + +@XFAIL +def test_issue_15323(): + d = ((1 - 1/x)**x).diff(x) + assert limit(d, x, 1, dir='+') == 1 + + +def test_issue_12571(): + assert limit(-LambertW(-log(x))/log(x), x, 1) == 1 + + +def test_issue_14590(): + assert limit((x**3*((x + 1)/x)**x)/((x + 1)*(x + 2)*(x + 3)), x, oo) == exp(1) + + +def test_issue_14393(): + a, b = symbols('a b') + assert limit((x**b - y**b)/(x**a - y**a), x, y) == b*y**(-a + b)/a + + +def test_issue_14556(): + assert limit(factorial(n + 1)**(1/(n + 1)) - factorial(n)**(1/n), n, oo) == exp(-1) + + +def test_issue_14811(): + assert limit(((1 + ((S(2)/3)**(x + 1)))**(2**x))/(2**((S(4)/3)**(x - 1))), x, oo) == oo + + +def test_issue_16222(): + assert limit(exp(x), x, 1000000000) == exp(1000000000) + + +def test_issue_16714(): + assert limit(((x**(x + 1) + (x + 1)**x) / x**(x + 1))**x, x, oo) == exp(exp(1)) + + +def test_issue_16722(): + z = symbols('z', positive=True) + assert limit(binomial(n + z, n)*n**-z, n, oo) == 1/gamma(z + 1) + z = symbols('z', positive=True, integer=True) + assert limit(binomial(n + z, n)*n**-z, n, oo) == 1/gamma(z + 1) + + +def test_issue_17431(): + assert limit(((n + 1) + 1) / (((n + 1) + 2) * factorial(n + 1)) * + (n + 2) * factorial(n) / (n + 1), n, oo) == 0 + assert limit((n + 2)**2*factorial(n)/((n + 1)*(n + 3)*factorial(n + 1)) + , n, oo) == 0 + assert limit((n + 1) * factorial(n) / (n * factorial(n + 1)), n, oo) == 0 + + +def test_issue_17671(): + assert limit(Ei(-log(x)) - log(log(x))/x, x, 1) == EulerGamma + + +def test_issue_17751(): + a, b, c, x = symbols('a b c x', positive=True) + assert limit((a + 1)*x - sqrt((a + 1)**2*x**2 + b*x + c), x, oo) == -b/(2*a + 2) + + +def test_issue_17792(): + assert limit(factorial(n)/sqrt(n)*(exp(1)/n)**n, n, oo) == sqrt(2)*sqrt(pi) + + +def test_issue_18118(): + assert limit(sign(sin(x)), x, 0, "-") == -1 + assert limit(sign(sin(x)), x, 0, "+") == 1 + + +def test_issue_18306(): + assert limit(sin(sqrt(x))/sqrt(sin(x)), x, 0, '+') == 1 + + +def test_issue_18378(): + assert limit(log(exp(3*x) + x)/log(exp(x) + x**100), x, oo) == 3 + + +def test_issue_18399(): + assert limit((1 - S(1)/2*x)**(3*x), x, oo) is zoo + assert limit((-x)**x, x, oo) is zoo + + +def test_issue_18442(): + assert limit(tan(x)**(2**(sqrt(pi))), x, oo, dir='-') == Limit(tan(x)**(2**(sqrt(pi))), x, oo, dir='-') + + +def test_issue_18452(): + assert limit(abs(log(x))**x, x, 0) == 1 + assert limit(abs(log(x))**x, x, 0, "-") == 1 + + +def test_issue_18473(): + assert limit(sin(x)**(1/x), x, oo) == Limit(sin(x)**(1/x), x, oo, dir='-') + assert limit(cos(x)**(1/x), x, oo) == Limit(cos(x)**(1/x), x, oo, dir='-') + assert limit(tan(x)**(1/x), x, oo) == Limit(tan(x)**(1/x), x, oo, dir='-') + assert limit((cos(x) + 2)**(1/x), x, oo) == 1 + assert limit((sin(x) + 10)**(1/x), x, oo) == 1 + assert limit((cos(x) - 2)**(1/x), x, oo) == Limit((cos(x) - 2)**(1/x), x, oo, dir='-') + assert limit((cos(x) + 1)**(1/x), x, oo) == AccumBounds(0, 1) + assert limit((tan(x)**2)**(2/x) , x, oo) == AccumBounds(0, oo) + assert limit((sin(x)**2)**(1/x), x, oo) == AccumBounds(0, 1) + # Tests for issue #23751 + assert limit((cos(x) + 1)**(1/x), x, -oo) == AccumBounds(1, oo) + assert limit((sin(x)**2)**(1/x), x, -oo) == AccumBounds(1, oo) + assert limit((tan(x)**2)**(2/x) , x, -oo) == AccumBounds(0, oo) + + +def test_issue_18482(): + assert limit((2*exp(3*x)/(exp(2*x) + 1))**(1/x), x, oo) == exp(1) + + +def test_issue_18508(): + assert limit(sin(x)/sqrt(1-cos(x)), x, 0) == sqrt(2) + assert limit(sin(x)/sqrt(1-cos(x)), x, 0, dir='+') == sqrt(2) + assert limit(sin(x)/sqrt(1-cos(x)), x, 0, dir='-') == -sqrt(2) + + +def test_issue_18521(): + raises(NotImplementedError, lambda: limit(exp((2 - n) * x), x, oo)) + + +def test_issue_18969(): + a, b = symbols('a b', positive=True) + assert limit(LambertW(a), a, b) == LambertW(b) + assert limit(exp(LambertW(a)), a, b) == exp(LambertW(b)) + + +def test_issue_18992(): + assert limit(n/(factorial(n)**(1/n)), n, oo) == exp(1) + + +def test_issue_19067(): + x = Symbol('x') + assert limit(gamma(x)/(gamma(x - 1)*gamma(x + 2)), x, 0) == -1 + + +def test_issue_19586(): + assert limit(x**(2**x*3**(-x)), x, oo) == 1 + + +def test_issue_13715(): + n = Symbol('n') + p = Symbol('p', zero=True) + assert limit(n + p, n, 0) == 0 + + +def test_issue_15055(): + assert limit(n**3*((-n - 1)*sin(1/n) + (n + 2)*sin(1/(n + 1)))/(-n + 1), n, oo) == 1 + + +def test_issue_16708(): + m, vi = symbols('m vi', positive=True) + B, ti, d = symbols('B ti d') + assert limit((B*ti*vi - sqrt(m)*sqrt(-2*B*d*vi + m*(vi)**2) + m*vi)/(B*vi), B, 0) == (d + ti*vi)/vi + + +def test_issue_19154(): + assert limit(besseli(1, 3 *x)/(x *besseli(1, x)**3), x , oo) == 2*sqrt(3)*pi/3 + assert limit(besseli(1, 3 *x)/(x *besseli(1, x)**3), x , -oo) == -2*sqrt(3)*pi/3 + + +def test_issue_19453(): + beta = Symbol("beta", positive=True) + h = Symbol("h", positive=True) + m = Symbol("m", positive=True) + w = Symbol("omega", positive=True) + g = Symbol("g", positive=True) + + e = exp(1) + q = 3*h**2*beta*g*e**(0.5*h*beta*w) + p = m**2*w**2 + s = e**(h*beta*w) - 1 + Z = -q/(4*p*s) - q/(2*p*s**2) - q*(e**(h*beta*w) + 1)/(2*p*s**3)\ + + e**(0.5*h*beta*w)/s + E = -diff(log(Z), beta) + + assert limit(E - 0.5*h*w, beta, oo) == 0 + assert limit(E.simplify() - 0.5*h*w, beta, oo) == 0 + + +def test_issue_19739(): + assert limit((-S(1)/4)**x, x, oo) == 0 + + +def test_issue_19766(): + assert limit(2**(-x)*sqrt(4**(x + 1) + 1), x, oo) == 2 + + +def test_issue_19770(): + m = Symbol('m') + # the result is not 0 for non-real m + assert limit(cos(m*x)/x, x, oo) == Limit(cos(m*x)/x, x, oo, dir='-') + m = Symbol('m', real=True) + # can be improved to give the correct result 0 + assert limit(cos(m*x)/x, x, oo) == Limit(cos(m*x)/x, x, oo, dir='-') + m = Symbol('m', nonzero=True) + assert limit(cos(m*x), x, oo) == AccumBounds(-1, 1) + assert limit(cos(m*x)/x, x, oo) == 0 + + +def test_issue_7535(): + assert limit(tan(x)/sin(tan(x)), x, pi/2) == Limit(tan(x)/sin(tan(x)), x, pi/2, dir='+') + assert limit(tan(x)/sin(tan(x)), x, pi/2, dir='-') == Limit(tan(x)/sin(tan(x)), x, pi/2, dir='-') + assert limit(tan(x)/sin(tan(x)), x, pi/2, dir='+-') == Limit(tan(x)/sin(tan(x)), x, pi/2, dir='+-') + assert limit(sin(tan(x)),x,pi/2) == AccumBounds(-1, 1) + assert -oo*(1/sin(-oo)) == AccumBounds(-oo, oo) + assert oo*(1/sin(oo)) == AccumBounds(-oo, oo) + assert oo*(1/sin(-oo)) == AccumBounds(-oo, oo) + assert -oo*(1/sin(oo)) == AccumBounds(-oo, oo) + + +def test_issue_20365(): + assert limit(((x + 1)**(1/x) - E)/x, x, 0) == -E/2 + + +def test_issue_21031(): + assert limit(((1 + x)**(1/x) - (1 + 2*x)**(1/(2*x)))/asin(x), x, 0) == E/2 + + +def test_issue_21038(): + assert limit(sin(pi*x)/(3*x - 12), x, 4) == pi/3 + + +def test_issue_20578(): + expr = abs(x) * sin(1/x) + assert limit(expr,x,0,'+') == 0 + assert limit(expr,x,0,'-') == 0 + assert limit(expr,x,0,'+-') == 0 + + +def test_issue_21227(): + f = log(x) + + assert f.nseries(x, logx=y) == y + assert f.nseries(x, logx=-x) == -x + + f = log(-log(x)) + + assert f.nseries(x, logx=y) == log(-y) + assert f.nseries(x, logx=-x) == log(x) + + f = log(log(x)) + + assert f.nseries(x, logx=y) == log(y) + assert f.nseries(x, logx=-x) == log(-x) + assert f.nseries(x, logx=x) == log(x) + + f = log(log(log(1/x))) + + assert f.nseries(x, logx=y) == log(log(-y)) + assert f.nseries(x, logx=-y) == log(log(y)) + assert f.nseries(x, logx=x) == log(log(-x)) + assert f.nseries(x, logx=-x) == log(log(x)) + + +def test_issue_21415(): + exp = (x-1)*cos(1/(x-1)) + assert exp.limit(x,1) == 0 + assert exp.expand().limit(x,1) == 0 + + +def test_issue_21530(): + assert limit(sinh(n + 1)/sinh(n), n, oo) == E + + +def test_issue_21550(): + r = (sqrt(5) - 1)/2 + assert limit((x - r)/(x**2 + x - 1), x, r) == sqrt(5)/5 + + +def test_issue_21661(): + out = limit((x**(x + 1) * (log(x) + 1) + 1) / x, x, 11) + assert out == S(3138428376722)/11 + 285311670611*log(11) + + +def test_issue_21701(): + assert limit((besselj(z, x)/x**z).subs(z, 7), x, 0) == S(1)/645120 + + +def test_issue_21721(): + a = Symbol('a', real=True) + I = integrate(1/(pi*(1 + (x - a)**2)), x) + assert I.limit(x, oo) == S.Half + + +def test_issue_21756(): + term = (1 - exp(-2*I*pi*z))/(1 - exp(-2*I*pi*z/5)) + assert term.limit(z, 0) == 5 + assert re(term).limit(z, 0) == 5 + + +def test_issue_21785(): + a = Symbol('a') + assert sqrt((-a**2 + x**2)/(1 - x**2)).limit(a, 1, '-') == I + + +def test_issue_22181(): + assert limit((-1)**x * 2**(-x), x, oo) == 0 + + +def test_issue_22220(): + e1 = sqrt(30)*atan(sqrt(30)*tan(x/2)/6)/30 + e2 = sqrt(30)*I*(-log(sqrt(2)*tan(x/2) - 2*sqrt(15)*I/5) + + +log(sqrt(2)*tan(x/2) + 2*sqrt(15)*I/5))/60 + + assert limit(e1, x, -pi) == -sqrt(30)*pi/60 + assert limit(e2, x, -pi) == -sqrt(30)*pi/30 + + assert limit(e1, x, -pi, '-') == sqrt(30)*pi/60 + assert limit(e2, x, -pi, '-') == 0 + + # test https://github.com/sympy/sympy/issues/22220#issuecomment-972727694 + expr = log(x - I) - log(-x - I) + expr2 = logcombine(expr, force=True) + assert limit(expr, x, oo) == limit(expr2, x, oo) == I*pi + + # test https://github.com/sympy/sympy/issues/22220#issuecomment-1077618340 + expr = expr = (-log(tan(x/2) - I) +log(tan(x/2) + I)) + assert limit(expr, x, pi, '+') == 2*I*pi + assert limit(expr, x, pi, '-') == 0 + + +def test_issue_22334(): + k, n = symbols('k, n', positive=True) + assert limit((n+1)**k/((n+1)**(k+1) - (n)**(k+1)), n, oo) == 1/(k + 1) + assert limit((n+1)**k/((n+1)**(k+1) - (n)**(k+1)).expand(), n, oo) == 1/(k + 1) + assert limit((n+1)**k/(n*(-n**k + (n + 1)**k) + (n + 1)**k), n, oo) == 1/(k + 1) + + +def test_sympyissue_22986(): + assert limit(acosh(1 + 1/x)*sqrt(x), x, oo) == sqrt(2) + + +def test_issue_23231(): + f = (2**x - 2**(-x))/(2**x + 2**(-x)) + assert limit(f, x, -oo) == -1 + + +def test_issue_23596(): + assert integrate(((1 + x)/x**2)*exp(-1/x), (x, 0, oo)) == oo + + +def test_issue_23752(): + expr1 = sqrt(-I*x**2 + x - 3) + expr2 = sqrt(-I*x**2 + I*x - 3) + assert limit(expr1, x, 0, '+') == -sqrt(3)*I + assert limit(expr1, x, 0, '-') == -sqrt(3)*I + assert limit(expr2, x, 0, '+') == sqrt(3)*I + assert limit(expr2, x, 0, '-') == -sqrt(3)*I + + +def test_issue_24276(): + fx = log(tan(pi/2*tanh(x))).diff(x) + assert fx.limit(x, oo) == 2 + assert fx.simplify().limit(x, oo) == 2 + assert fx.rewrite(sin).limit(x, oo) == 2 + assert fx.rewrite(sin).simplify().limit(x, oo) == 2 + +def test_issue_25230(): + a = Symbol('a', real = True) + b = Symbol('b', positive = True) + c = Symbol('c', negative = True) + n = Symbol('n', integer = True) + raises(NotImplementedError, lambda: limit(Mod(x, a), x, a)) + assert limit(Mod(x, b), x, n*b, '+') == 0 + assert limit(Mod(x, b), x, n*b, '-') == b + assert limit(Mod(x, c), x, n*c, '+') == c + assert limit(Mod(x, c), x, n*c, '-') == 0 + + +def test_issue_25582(): + + assert limit(asin(exp(x)), x, oo, '-') == -oo*I + assert limit(acos(exp(x)), x, oo, '-') == oo*I + assert limit(atan(exp(x)), x, oo, '-') == pi/2 + assert limit(acot(exp(x)), x, oo, '-') == 0 + assert limit(asec(exp(x)), x, oo, '-') == pi/2 + assert limit(acsc(exp(x)), x, oo, '-') == 0 + + +def test_issue_25847(): + #atan + assert limit(atan(sin(x)/x), x, 0, '+-') == pi/4 + assert limit(atan(exp(1/x)), x, 0, '+') == pi/2 + assert limit(atan(exp(1/x)), x, 0, '-') == 0 + + #asin + assert limit(asin(sin(x)/x), x, 0, '+-') == pi/2 + assert limit(asin(exp(1/x)), x, 0, '+') == -oo*I + assert limit(asin(exp(1/x)), x, 0, '-') == 0 + + #acos + assert limit(acos(sin(x)/x), x, 0, '+-') == 0 + assert limit(acos(exp(1/x)), x, 0, '+') == oo*I + assert limit(acos(exp(1/x)), x, 0, '-') == pi/2 + + #acot + assert limit(acot(sin(x)/x), x, 0, '+-') == pi/4 + assert limit(acot(exp(1/x)), x, 0, '+') == 0 + assert limit(acot(exp(1/x)), x, 0, '-') == pi/2 + + #asec + assert limit(asec(sin(x)/x), x, 0, '+-') == 0 + assert limit(asec(exp(1/x)), x, 0, '+') == pi/2 + assert limit(asec(exp(1/x)), x, 0, '-') == oo*I + + #acsc + assert limit(acsc(sin(x)/x), x, 0, '+-') == pi/2 + assert limit(acsc(exp(1/x)), x, 0, '+') == 0 + assert limit(acsc(exp(1/x)), x, 0, '-') == -oo*I + + #atanh + assert limit(atanh(sin(x)/x), x, 0, '+-') == oo + assert limit(atanh(exp(1/x)), x, 0, '+') == -I*pi/2 + assert limit(atanh(exp(1/x)), x, 0, '-') == 0 + + #asinh + assert limit(asinh(sin(x)/x), x, 0, '+-') == log(1 + sqrt(2)) + assert limit(asinh(exp(1/x)), x, 0, '+') == oo + assert limit(asinh(exp(1/x)), x, 0, '-') == 0 + + #acosh + assert limit(acosh(sin(x)/x), x, 0, '+-') == 0 + assert limit(acosh(exp(1/x)), x, 0, '+') == oo + assert limit(acosh(exp(1/x)), x, 0, '-') == I*pi/2 + + #acoth + assert limit(acoth(sin(x)/x), x, 0, '+-') == oo + assert limit(acoth(exp(1/x)), x, 0, '+') == 0 + assert limit(acoth(exp(1/x)), x, 0, '-') == -I*pi/2 + + #asech + assert limit(asech(sin(x)/x), x, 0, '+-') == 0 + assert limit(asech(exp(1/x)), x, 0, '+') == I*pi/2 + assert limit(asech(exp(1/x)), x, 0, '-') == oo + + #acsch + assert limit(acsch(sin(x)/x), x, 0, '+-') == log(1 + sqrt(2)) + assert limit(acsch(exp(1/x)), x, 0, '+') == 0 + assert limit(acsch(exp(1/x)), x, 0, '-') == oo + + +def test_issue_26040(): + assert limit(besseli(0, x + 1)/besseli(0, x), x, oo) == S.Exp1 + + +def test_issue_26250(): + e = elliptic_e(4*x/(x**2 + 2*x + 1)) + k = elliptic_k(4*x/(x**2 + 2*x + 1)) + e1 = ((1-3*x**2)*e**2/2 - (x**2-2*x+1)*e*k/2) + e2 = pi**2*(x**8 - 2*x**7 - x**6 + 4*x**5 - x**4 - 2*x**3 + x**2) + assert limit(e1/e2, x, 0) == -S(1)/8 diff --git a/lib/python3.10/site-packages/sympy/series/tests/test_limitseq.py b/lib/python3.10/site-packages/sympy/series/tests/test_limitseq.py new file mode 100644 index 0000000000000000000000000000000000000000..362bb0397feb0ec63929920855c81279eca0bd6a --- /dev/null +++ b/lib/python3.10/site-packages/sympy/series/tests/test_limitseq.py @@ -0,0 +1,177 @@ +from sympy.concrete.summations import Sum +from sympy.core.add import Add +from sympy.core.numbers import (I, Rational, oo, pi) +from sympy.core.singleton import S +from sympy.core.symbol import (Symbol, symbols) +from sympy.functions.combinatorial.factorials import (binomial, factorial, subfactorial) +from sympy.functions.combinatorial.numbers import (fibonacci, harmonic) +from sympy.functions.elementary.exponential import (exp, log) +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import (cos, sin) +from sympy.functions.special.gamma_functions import gamma +from sympy.series.limitseq import limit_seq +from sympy.series.limitseq import difference_delta as dd +from sympy.testing.pytest import raises, XFAIL +from sympy.calculus.accumulationbounds import AccumulationBounds + +n, m, k = symbols('n m k', integer=True) + + +def test_difference_delta(): + e = n*(n + 1) + e2 = e * k + + assert dd(e) == 2*n + 2 + assert dd(e2, n, 2) == k*(4*n + 6) + + raises(ValueError, lambda: dd(e2)) + raises(ValueError, lambda: dd(e2, n, oo)) + + +def test_difference_delta__Sum(): + e = Sum(1/k, (k, 1, n)) + assert dd(e, n) == 1/(n + 1) + assert dd(e, n, 5) == Add(*[1/(i + n + 1) for i in range(5)]) + + e = Sum(1/k, (k, 1, 3*n)) + assert dd(e, n) == Add(*[1/(i + 3*n + 1) for i in range(3)]) + + e = n * Sum(1/k, (k, 1, n)) + assert dd(e, n) == 1 + Sum(1/k, (k, 1, n)) + + e = Sum(1/k, (k, 1, n), (m, 1, n)) + assert dd(e, n) == harmonic(n) + + +def test_difference_delta__Add(): + e = n + n*(n + 1) + assert dd(e, n) == 2*n + 3 + assert dd(e, n, 2) == 4*n + 8 + + e = n + Sum(1/k, (k, 1, n)) + assert dd(e, n) == 1 + 1/(n + 1) + assert dd(e, n, 5) == 5 + Add(*[1/(i + n + 1) for i in range(5)]) + + +def test_difference_delta__Pow(): + e = 4**n + assert dd(e, n) == 3*4**n + assert dd(e, n, 2) == 15*4**n + + e = 4**(2*n) + assert dd(e, n) == 15*4**(2*n) + assert dd(e, n, 2) == 255*4**(2*n) + + e = n**4 + assert dd(e, n) == (n + 1)**4 - n**4 + + e = n**n + assert dd(e, n) == (n + 1)**(n + 1) - n**n + + +def test_limit_seq(): + e = binomial(2*n, n) / Sum(binomial(2*k, k), (k, 1, n)) + assert limit_seq(e) == S(3) / 4 + assert limit_seq(e, m) == e + + e = (5*n**3 + 3*n**2 + 4) / (3*n**3 + 4*n - 5) + assert limit_seq(e, n) == S(5) / 3 + + e = (harmonic(n) * Sum(harmonic(k), (k, 1, n))) / (n * harmonic(2*n)**2) + assert limit_seq(e, n) == 1 + + e = Sum(k**2 * Sum(2**m/m, (m, 1, k)), (k, 1, n)) / (2**n*n) + assert limit_seq(e, n) == 4 + + e = (Sum(binomial(3*k, k) * binomial(5*k, k), (k, 1, n)) / + (binomial(3*n, n) * binomial(5*n, n))) + assert limit_seq(e, n) == S(84375) / 83351 + + e = Sum(harmonic(k)**2/k, (k, 1, 2*n)) / harmonic(n)**3 + assert limit_seq(e, n) == S.One / 3 + + raises(ValueError, lambda: limit_seq(e * m)) + + +def test_alternating_sign(): + assert limit_seq((-1)**n/n**2, n) == 0 + assert limit_seq((-2)**(n+1)/(n + 3**n), n) == 0 + assert limit_seq((2*n + (-1)**n)/(n + 1), n) == 2 + assert limit_seq(sin(pi*n), n) == 0 + assert limit_seq(cos(2*pi*n), n) == 1 + assert limit_seq((S.NegativeOne/5)**n, n) == 0 + assert limit_seq((Rational(-1, 5))**n, n) == 0 + assert limit_seq((I/3)**n, n) == 0 + assert limit_seq(sqrt(n)*(I/2)**n, n) == 0 + assert limit_seq(n**7*(I/3)**n, n) == 0 + assert limit_seq(n/(n + 1) + (I/2)**n, n) == 1 + + +def test_accum_bounds(): + assert limit_seq((-1)**n, n) == AccumulationBounds(-1, 1) + assert limit_seq(cos(pi*n), n) == AccumulationBounds(-1, 1) + assert limit_seq(sin(pi*n/2)**2, n) == AccumulationBounds(0, 1) + assert limit_seq(2*(-3)**n/(n + 3**n), n) == AccumulationBounds(-2, 2) + assert limit_seq(3*n/(n + 1) + 2*(-1)**n, n) == AccumulationBounds(1, 5) + + +def test_limitseq_sum(): + from sympy.abc import x, y, z + assert limit_seq(Sum(1/x, (x, 1, y)) - log(y), y) == S.EulerGamma + assert limit_seq(Sum(1/x, (x, 1, y)) - 1/y, y) is S.Infinity + assert (limit_seq(binomial(2*x, x) / Sum(binomial(2*y, y), (y, 1, x)), x) == + S(3) / 4) + assert (limit_seq(Sum(y**2 * Sum(2**z/z, (z, 1, y)), (y, 1, x)) / + (2**x*x), x) == 4) + + +def test_issue_9308(): + assert limit_seq(subfactorial(n)/factorial(n), n) == exp(-1) + + +def test_issue_10382(): + n = Symbol('n', integer=True) + assert limit_seq(fibonacci(n+1)/fibonacci(n), n).together() == S.GoldenRatio + + +def test_issue_11672(): + assert limit_seq(Rational(-1, 2)**n, n) == 0 + + +def test_issue_14196(): + k, n = symbols('k, n', positive=True) + m = Symbol('m') + assert limit_seq(Sum(m**k, (m, 1, n)).doit()/(n**(k + 1)), n) == 1/(k + 1) + + +def test_issue_16735(): + assert limit_seq(5**n/factorial(n), n) == 0 + + +def test_issue_19868(): + assert limit_seq(1/gamma(n + S.One/2), n) == 0 + + +@XFAIL +def test_limit_seq_fail(): + # improve Summation algorithm or add ad-hoc criteria + e = (harmonic(n)**3 * Sum(1/harmonic(k), (k, 1, n)) / + (n * Sum(harmonic(k)/k, (k, 1, n)))) + assert limit_seq(e, n) == 2 + + # No unique dominant term + e = (Sum(2**k * binomial(2*k, k) / k**2, (k, 1, n)) / + (Sum(2**k/k*2, (k, 1, n)) * Sum(binomial(2*k, k), (k, 1, n)))) + assert limit_seq(e, n) == S(3) / 7 + + # Simplifications of summations needs to be improved. + e = n**3*Sum(2**k/k**2, (k, 1, n))**2 / (2**n * Sum(2**k/k, (k, 1, n))) + assert limit_seq(e, n) == 2 + + e = (harmonic(n) * Sum(2**k/k, (k, 1, n)) / + (n * Sum(2**k*harmonic(k)/k**2, (k, 1, n)))) + assert limit_seq(e, n) == 1 + + e = (Sum(2**k*factorial(k) / k**2, (k, 1, 2*n)) / + (Sum(4**k/k**2, (k, 1, n)) * Sum(factorial(k), (k, 1, 2*n)))) + assert limit_seq(e, n) == S(3) / 16 diff --git a/lib/python3.10/site-packages/sympy/series/tests/test_lseries.py b/lib/python3.10/site-packages/sympy/series/tests/test_lseries.py new file mode 100644 index 0000000000000000000000000000000000000000..42d327bf60c76eebdc4570d631efef4bc84b58e3 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/series/tests/test_lseries.py @@ -0,0 +1,65 @@ +from sympy.core.numbers import E +from sympy.core.singleton import S +from sympy.functions.elementary.exponential import exp +from sympy.functions.elementary.hyperbolic import tanh +from sympy.functions.elementary.trigonometric import (cos, sin) +from sympy.series.order import Order +from sympy.abc import x, y + + +def test_sin(): + e = sin(x).lseries(x) + assert next(e) == x + assert next(e) == -x**3/6 + assert next(e) == x**5/120 + + +def test_cos(): + e = cos(x).lseries(x) + assert next(e) == 1 + assert next(e) == -x**2/2 + assert next(e) == x**4/24 + + +def test_exp(): + e = exp(x).lseries(x) + assert next(e) == 1 + assert next(e) == x + assert next(e) == x**2/2 + assert next(e) == x**3/6 + + +def test_exp2(): + e = exp(cos(x)).lseries(x) + assert next(e) == E + assert next(e) == -E*x**2/2 + assert next(e) == E*x**4/6 + assert next(e) == -31*E*x**6/720 + + +def test_simple(): + assert list(x.lseries()) == [x] + assert list(S.One.lseries(x)) == [1] + assert not next((x/(x + y)).lseries(y)).has(Order) + + +def test_issue_5183(): + s = (x + 1/x).lseries() + assert list(s) == [1/x, x] + assert next((x + x**2).lseries()) == x + assert next(((1 + x)**7).lseries(x)) == 1 + assert next((sin(x + y)).series(x, n=3).lseries(y)) == x + # it would be nice if all terms were grouped, but in the + # following case that would mean that all the terms would have + # to be known since, for example, every term has a constant in it. + s = ((1 + x)**7).series(x, 1, n=None) + assert [next(s) for i in range(2)] == [128, -448 + 448*x] + + +def test_issue_6999(): + s = tanh(x).lseries(x, 1) + assert next(s) == tanh(1) + assert next(s) == x - (x - 1)*tanh(1)**2 - 1 + assert next(s) == -(x - 1)**2*tanh(1) + (x - 1)**2*tanh(1)**3 + assert next(s) == -(x - 1)**3*tanh(1)**4 - (x - 1)**3/3 + \ + 4*(x - 1)**3*tanh(1)**2/3 diff --git a/lib/python3.10/site-packages/sympy/series/tests/test_nseries.py b/lib/python3.10/site-packages/sympy/series/tests/test_nseries.py new file mode 100644 index 0000000000000000000000000000000000000000..a2f20add82d3e858e2ce145fc9fcd4a6548a48cc --- /dev/null +++ b/lib/python3.10/site-packages/sympy/series/tests/test_nseries.py @@ -0,0 +1,557 @@ +from sympy.calculus.util import AccumBounds +from sympy.core.function import (Derivative, PoleError) +from sympy.core.numbers import (E, I, Integer, Rational, pi) +from sympy.core.singleton import S +from sympy.core.symbol import (Symbol, symbols) +from sympy.functions.elementary.complexes import sign +from sympy.functions.elementary.exponential import (exp, log) +from sympy.functions.elementary.hyperbolic import (acosh, acoth, asinh, atanh, cosh, coth, sinh, tanh) +from sympy.functions.elementary.integers import (ceiling, floor, frac) +from sympy.functions.elementary.miscellaneous import (cbrt, sqrt) +from sympy.functions.elementary.trigonometric import (asin, cos, cot, sin, tan) +from sympy.series.limits import limit +from sympy.series.order import O +from sympy.abc import x, y, z + +from sympy.testing.pytest import raises, XFAIL + + +def test_simple_1(): + assert x.nseries(x, n=5) == x + assert y.nseries(x, n=5) == y + assert (1/(x*y)).nseries(y, n=5) == 1/(x*y) + assert Rational(3, 4).nseries(x, n=5) == Rational(3, 4) + assert x.nseries() == x + + +def test_mul_0(): + assert (x*log(x)).nseries(x, n=5) == x*log(x) + + +def test_mul_1(): + assert (x*log(2 + x)).nseries(x, n=5) == x*log(2) + x**2/2 - x**3/8 + \ + x**4/24 + O(x**5) + assert (x*log(1 + x)).nseries( + x, n=5) == x**2 - x**3/2 + x**4/3 + O(x**5) + + +def test_pow_0(): + assert (x**2).nseries(x, n=5) == x**2 + assert (1/x).nseries(x, n=5) == 1/x + assert (1/x**2).nseries(x, n=5) == 1/x**2 + assert (x**Rational(2, 3)).nseries(x, n=5) == (x**Rational(2, 3)) + assert (sqrt(x)**3).nseries(x, n=5) == (sqrt(x)**3) + + +def test_pow_1(): + assert ((1 + x)**2).nseries(x, n=5) == x**2 + 2*x + 1 + + # https://github.com/sympy/sympy/issues/21075 + assert ((sqrt(x) + 1)**2).nseries(x) == 2*sqrt(x) + x + 1 + assert ((sqrt(x) + cbrt(x))**2).nseries(x) == 2*x**Rational(5, 6)\ + + x**Rational(2, 3) + x + + +def test_geometric_1(): + assert (1/(1 - x)).nseries(x, n=5) == 1 + x + x**2 + x**3 + x**4 + O(x**5) + assert (x/(1 - x)).nseries(x, n=6) == x + x**2 + x**3 + x**4 + x**5 + O(x**6) + assert (x**3/(1 - x)).nseries(x, n=8) == x**3 + x**4 + x**5 + x**6 + \ + x**7 + O(x**8) + + +def test_sqrt_1(): + assert sqrt(1 + x).nseries(x, n=5) == 1 + x/2 - x**2/8 + x**3/16 - 5*x**4/128 + O(x**5) + + +def test_exp_1(): + assert exp(x).nseries(x, n=5) == 1 + x + x**2/2 + x**3/6 + x**4/24 + O(x**5) + assert exp(x).nseries(x, n=12) == 1 + x + x**2/2 + x**3/6 + x**4/24 + x**5/120 + \ + x**6/720 + x**7/5040 + x**8/40320 + x**9/362880 + x**10/3628800 + \ + x**11/39916800 + O(x**12) + assert exp(1/x).nseries(x, n=5) == exp(1/x) + assert exp(1/(1 + x)).nseries(x, n=4) == \ + (E*(1 - x - 13*x**3/6 + 3*x**2/2)).expand() + O(x**4) + assert exp(2 + x).nseries(x, n=5) == \ + (exp(2)*(1 + x + x**2/2 + x**3/6 + x**4/24)).expand() + O(x**5) + + +def test_exp_sqrt_1(): + assert exp(1 + sqrt(x)).nseries(x, n=3) == \ + (exp(1)*(1 + sqrt(x) + x/2 + sqrt(x)*x/6)).expand() + O(sqrt(x)**3) + + +def test_power_x_x1(): + assert (exp(x*log(x))).nseries(x, n=4) == \ + 1 + x*log(x) + x**2*log(x)**2/2 + x**3*log(x)**3/6 + O(x**4*log(x)**4) + + +def test_power_x_x2(): + assert (x**x).nseries(x, n=4) == \ + 1 + x*log(x) + x**2*log(x)**2/2 + x**3*log(x)**3/6 + O(x**4*log(x)**4) + + +def test_log_singular1(): + assert log(1 + 1/x).nseries(x, n=5) == x - log(x) - x**2/2 + x**3/3 - \ + x**4/4 + O(x**5) + + +def test_log_power1(): + e = 1 / (1/x + x ** (log(3)/log(2))) + assert e.nseries(x, n=5) == -x**(log(3)/log(2) + 2) + x + O(x**5) + + +def test_log_series(): + l = Symbol('l') + e = 1/(1 - log(x)) + assert e.nseries(x, n=5, logx=l) == 1/(1 - l) + + +def test_log2(): + e = log(-1/x) + assert e.nseries(x, n=5) == -log(x) + log(-1) + + +def test_log3(): + l = Symbol('l') + e = 1/log(-1/x) + assert e.nseries(x, n=4, logx=l) == 1/(-l + log(-1)) + + +def test_series1(): + e = sin(x) + assert e.nseries(x, 0, 0) != 0 + assert e.nseries(x, 0, 0) == O(1, x) + assert e.nseries(x, 0, 1) == O(x, x) + assert e.nseries(x, 0, 2) == x + O(x**2, x) + assert e.nseries(x, 0, 3) == x + O(x**3, x) + assert e.nseries(x, 0, 4) == x - x**3/6 + O(x**4, x) + + e = (exp(x) - 1)/x + assert e.nseries(x, 0, 3) == 1 + x/2 + x**2/6 + O(x**3) + + assert x.nseries(x, 0, 2) == x + + +@XFAIL +def test_series1_failing(): + assert x.nseries(x, 0, 0) == O(1, x) + assert x.nseries(x, 0, 1) == O(x, x) + + +def test_seriesbug1(): + assert (1/x).nseries(x, 0, 3) == 1/x + assert (x + 1/x).nseries(x, 0, 3) == x + 1/x + + +def test_series2x(): + assert ((x + 1)**(-2)).nseries(x, 0, 4) == 1 - 2*x + 3*x**2 - 4*x**3 + O(x**4, x) + assert ((x + 1)**(-1)).nseries(x, 0, 4) == 1 - x + x**2 - x**3 + O(x**4, x) + assert ((x + 1)**0).nseries(x, 0, 3) == 1 + assert ((x + 1)**1).nseries(x, 0, 3) == 1 + x + assert ((x + 1)**2).nseries(x, 0, 3) == x**2 + 2*x + 1 + assert ((x + 1)**3).nseries(x, 0, 3) == 1 + 3*x + 3*x**2 + O(x**3) + + assert (1/(1 + x)).nseries(x, 0, 4) == 1 - x + x**2 - x**3 + O(x**4, x) + assert (x + 3/(1 + 2*x)).nseries(x, 0, 4) == 3 - 5*x + 12*x**2 - 24*x**3 + O(x**4, x) + + assert ((1/x + 1)**3).nseries(x, 0, 3) == 1 + 3/x + 3/x**2 + x**(-3) + assert (1/(1 + 1/x)).nseries(x, 0, 4) == x - x**2 + x**3 - O(x**4, x) + assert (1/(1 + 1/x**2)).nseries(x, 0, 6) == x**2 - x**4 + O(x**6, x) + + +def test_bug2(): # 1/log(0)*log(0) problem + w = Symbol("w") + e = (w**(-1) + w**( + -log(3)*log(2)**(-1)))**(-1)*(3*w**(-log(3)*log(2)**(-1)) + 2*w**(-1)) + e = e.expand() + assert e.nseries(w, 0, 4).subs(w, 0) == 3 + + +def test_exp(): + e = (1 + x)**(1/x) + assert e.nseries(x, n=3) == exp(1) - x*exp(1)/2 + 11*exp(1)*x**2/24 + O(x**3) + + +def test_exp2(): + w = Symbol("w") + e = w**(1 - log(x)/(log(2) + log(x))) + logw = Symbol("logw") + assert e.nseries( + w, 0, 1, logx=logw) == exp(logw*log(2)/(log(x) + log(2))) + + +def test_bug3(): + e = (2/x + 3/x**2)/(1/x + 1/x**2) + assert e.nseries(x, n=3) == 3 - x + x**2 + O(x**3) + + +def test_generalexponent(): + p = 2 + e = (2/x + 3/x**p)/(1/x + 1/x**p) + assert e.nseries(x, 0, 3) == 3 - x + x**2 + O(x**3) + p = S.Half + e = (2/x + 3/x**p)/(1/x + 1/x**p) + assert e.nseries(x, 0, 2) == 2 - x + sqrt(x) + x**(S(3)/2) + O(x**2) + + e = 1 + sqrt(x) + assert e.nseries(x, 0, 4) == 1 + sqrt(x) + +# more complicated example + + +def test_genexp_x(): + e = 1/(1 + sqrt(x)) + assert e.nseries(x, 0, 2) == \ + 1 + x - sqrt(x) - sqrt(x)**3 + O(x**2, x) + +# more complicated example + + +def test_genexp_x2(): + p = Rational(3, 2) + e = (2/x + 3/x**p)/(1/x + 1/x**p) + assert e.nseries(x, 0, 3) == 3 + x + x**2 - sqrt(x) - x**(S(3)/2) - x**(S(5)/2) + O(x**3) + + +def test_seriesbug2(): + w = Symbol("w") + #simple case (1): + e = ((2*w)/w)**(1 + w) + assert e.nseries(w, 0, 1) == 2 + O(w, w) + assert e.nseries(w, 0, 1).subs(w, 0) == 2 + + +def test_seriesbug2b(): + w = Symbol("w") + #test sin + e = sin(2*w)/w + assert e.nseries(w, 0, 3) == 2 - 4*w**2/3 + O(w**3) + + +def test_seriesbug2d(): + w = Symbol("w", real=True) + e = log(sin(2*w)/w) + assert e.series(w, n=5) == log(2) - 2*w**2/3 - 4*w**4/45 + O(w**5) + + +def test_seriesbug2c(): + w = Symbol("w", real=True) + #more complicated case, but sin(x)~x, so the result is the same as in (1) + e = (sin(2*w)/w)**(1 + w) + assert e.series(w, 0, 1) == 2 + O(w) + assert e.series(w, 0, 3) == 2 + 2*w*log(2) + \ + w**2*(Rational(-4, 3) + log(2)**2) + O(w**3) + assert e.series(w, 0, 2).subs(w, 0) == 2 + + +def test_expbug4(): + x = Symbol("x", real=True) + assert (log( + sin(2*x)/x)*(1 + x)).series(x, 0, 2) == log(2) + x*log(2) + O(x**2, x) + assert exp( + log(sin(2*x)/x)*(1 + x)).series(x, 0, 2) == 2 + 2*x*log(2) + O(x**2) + + assert exp(log(2) + O(x)).nseries(x, 0, 2) == 2 + O(x) + assert ((2 + O(x))**(1 + x)).nseries(x, 0, 2) == 2 + O(x) + + +def test_logbug4(): + assert log(2 + O(x)).nseries(x, 0, 2) == log(2) + O(x, x) + + +def test_expbug5(): + assert exp(log(1 + x)/x).nseries(x, n=3) == exp(1) + -exp(1)*x/2 + 11*exp(1)*x**2/24 + O(x**3) + + assert exp(O(x)).nseries(x, 0, 2) == 1 + O(x) + + +def test_sinsinbug(): + assert sin(sin(x)).nseries(x, 0, 8) == x - x**3/3 + x**5/10 - 8*x**7/315 + O(x**8) + + +def test_issue_3258(): + a = x/(exp(x) - 1) + assert a.nseries(x, 0, 5) == 1 - x/2 - x**4/720 + x**2/12 + O(x**5) + + +def test_issue_3204(): + x = Symbol("x", nonnegative=True) + f = sin(x**3)**Rational(1, 3) + assert f.nseries(x, 0, 17) == x - x**7/18 - x**13/3240 + O(x**17) + + +def test_issue_3224(): + f = sqrt(1 - sqrt(y)) + assert f.nseries(y, 0, 2) == 1 - sqrt(y)/2 - y/8 - sqrt(y)**3/16 + O(y**2) + + +def test_issue_3463(): + w, i = symbols('w,i') + r = log(5)/log(3) + p = w**(-1 + r) + e = 1/x*(-log(w**(1 + r)) + log(w + w**r)) + e_ser = -r*log(w)/x + p/x - p**2/(2*x) + O(w) + assert e.nseries(w, n=1) == e_ser + + +def test_sin(): + assert sin(8*x).nseries(x, n=4) == 8*x - 256*x**3/3 + O(x**4) + assert sin(x + y).nseries(x, n=1) == sin(y) + O(x) + assert sin(x + y).nseries(x, n=2) == sin(y) + cos(y)*x + O(x**2) + assert sin(x + y).nseries(x, n=5) == sin(y) + cos(y)*x - sin(y)*x**2/2 - \ + cos(y)*x**3/6 + sin(y)*x**4/24 + O(x**5) + + +def test_issue_3515(): + e = sin(8*x)/x + assert e.nseries(x, n=6) == 8 - 256*x**2/3 + 4096*x**4/15 + O(x**6) + + +def test_issue_3505(): + e = sin(x)**(-4)*(sqrt(cos(x))*sin(x)**2 - + cos(x)**Rational(1, 3)*sin(x)**2) + assert e.nseries(x, n=9) == Rational(-1, 12) - 7*x**2/288 - \ + 43*x**4/10368 - 1123*x**6/2488320 + 377*x**8/29859840 + O(x**9) + + +def test_issue_3501(): + a = Symbol("a") + e = x**(-2)*(x*sin(a + x) - x*sin(a)) + assert e.nseries(x, n=6) == cos(a) - sin(a)*x/2 - cos(a)*x**2/6 + \ + x**3*sin(a)/24 + x**4*cos(a)/120 - x**5*sin(a)/720 + O(x**6) + e = x**(-2)*(x*cos(a + x) - x*cos(a)) + assert e.nseries(x, n=6) == -sin(a) - cos(a)*x/2 + sin(a)*x**2/6 + \ + cos(a)*x**3/24 - x**4*sin(a)/120 - x**5*cos(a)/720 + O(x**6) + + +def test_issue_3502(): + e = sin(5*x)/sin(2*x) + assert e.nseries(x, n=2) == Rational(5, 2) + O(x**2) + assert e.nseries(x, n=6) == \ + Rational(5, 2) - 35*x**2/4 + 329*x**4/48 + O(x**6) + + +def test_issue_3503(): + e = sin(2 + x)/(2 + x) + assert e.nseries(x, n=2) == sin(2)/2 + x*cos(2)/2 - x*sin(2)/4 + O(x**2) + + +def test_issue_3506(): + e = (x + sin(3*x))**(-2)*(x*(x + sin(3*x)) - (x + sin(3*x))*sin(2*x)) + assert e.nseries(x, n=7) == \ + Rational(-1, 4) + 5*x**2/96 + 91*x**4/768 + 11117*x**6/129024 + O(x**7) + + +def test_issue_3508(): + x = Symbol("x", real=True) + assert log(sin(x)).series(x, n=5) == log(x) - x**2/6 - x**4/180 + O(x**5) + e = -log(x) + x*(-log(x) + log(sin(2*x))) + log(sin(2*x)) + assert e.series(x, n=5) == \ + log(2) + log(2)*x - 2*x**2/3 - 2*x**3/3 - 4*x**4/45 + O(x**5) + + +def test_issue_3507(): + e = x**(-4)*(x**2 - x**2*sqrt(cos(x))) + assert e.nseries(x, n=9) == \ + Rational(1, 4) + x**2/96 + 19*x**4/5760 + 559*x**6/645120 + 29161*x**8/116121600 + O(x**9) + + +def test_issue_3639(): + assert sin(cos(x)).nseries(x, n=5) == \ + sin(1) - x**2*cos(1)/2 - x**4*sin(1)/8 + x**4*cos(1)/24 + O(x**5) + + +def test_hyperbolic(): + assert sinh(x).nseries(x, n=6) == x + x**3/6 + x**5/120 + O(x**6) + assert cosh(x).nseries(x, n=5) == 1 + x**2/2 + x**4/24 + O(x**5) + assert tanh(x).nseries(x, n=6) == x - x**3/3 + 2*x**5/15 + O(x**6) + assert coth(x).nseries(x, n=6) == \ + 1/x - x**3/45 + x/3 + 2*x**5/945 + O(x**6) + assert asinh(x).nseries(x, n=6) == x - x**3/6 + 3*x**5/40 + O(x**6) + assert acosh(x).nseries(x, n=6) == \ + pi*I/2 - I*x - 3*I*x**5/40 - I*x**3/6 + O(x**6) + assert atanh(x).nseries(x, n=6) == x + x**3/3 + x**5/5 + O(x**6) + assert acoth(x).nseries(x, n=6) == -I*pi/2 + x + x**3/3 + x**5/5 + O(x**6) + + +def test_series2(): + w = Symbol("w", real=True) + x = Symbol("x", real=True) + e = w**(-2)*(w*exp(1/x - w) - w*exp(1/x)) + assert e.nseries(w, n=4) == -exp(1/x) + w*exp(1/x)/2 - w**2*exp(1/x)/6 + w**3*exp(1/x)/24 + O(w**4) + + +def test_series3(): + w = Symbol("w", real=True) + e = w**(-6)*(w**3*tan(w) - w**3*sin(w)) + assert e.nseries(w, n=8) == Integer(1)/2 + w**2/8 + 13*w**4/240 + 529*w**6/24192 + O(w**8) + + +def test_bug4(): + w = Symbol("w") + e = x/(w**4 + x**2*w**4 + 2*x*w**4)*w**4 + assert e.nseries(w, n=2).removeO().expand() in [x/(1 + 2*x + x**2), + 1/(1 + x/2 + 1/x/2)/2, 1/x/(1 + 2/x + x**(-2))] + + +def test_bug5(): + w = Symbol("w") + l = Symbol('l') + e = (-log(w) + log(1 + w*log(x)))**(-2)*w**(-2)*((-log(w) + + log(1 + x*w))*(-log(w) + log(1 + w*log(x)))*w - x*(-log(w) + + log(1 + w*log(x)))*w) + assert e.nseries(w, n=0, logx=l) == x/w/l + 1/w + O(1, w) + assert e.nseries(w, n=1, logx=l) == x/w/l + 1/w - x/l + 1/l*log(x) \ + + x*log(x)/l**2 + O(w) + + +def test_issue_4115(): + assert (sin(x)/(1 - cos(x))).nseries(x, n=1) == 2/x + O(x) + assert (sin(x)**2/(1 - cos(x))).nseries(x, n=1) == 2 + O(x) + + +def test_pole(): + raises(PoleError, lambda: sin(1/x).series(x, 0, 5)) + raises(PoleError, lambda: sin(1 + 1/x).series(x, 0, 5)) + raises(PoleError, lambda: (x*sin(1/x)).series(x, 0, 5)) + + +def test_expsinbug(): + assert exp(sin(x)).series(x, 0, 0) == O(1, x) + assert exp(sin(x)).series(x, 0, 1) == 1 + O(x) + assert exp(sin(x)).series(x, 0, 2) == 1 + x + O(x**2) + assert exp(sin(x)).series(x, 0, 3) == 1 + x + x**2/2 + O(x**3) + assert exp(sin(x)).series(x, 0, 4) == 1 + x + x**2/2 + O(x**4) + assert exp(sin(x)).series(x, 0, 5) == 1 + x + x**2/2 - x**4/8 + O(x**5) + + +def test_floor(): + x = Symbol('x') + assert floor(x).series(x) == 0 + assert floor(-x).series(x) == -1 + assert floor(sin(x)).series(x) == 0 + assert floor(sin(-x)).series(x) == -1 + assert floor(x**3).series(x) == 0 + assert floor(-x**3).series(x) == -1 + assert floor(cos(x)).series(x) == 0 + assert floor(cos(-x)).series(x) == 0 + assert floor(5 + sin(x)).series(x) == 5 + assert floor(5 + sin(-x)).series(x) == 4 + + assert floor(x).series(x, 2) == 2 + assert floor(-x).series(x, 2) == -3 + + x = Symbol('x', negative=True) + assert floor(x + 1.5).series(x) == 1 + + +def test_frac(): + assert frac(x).series(x, cdir=1) == x + assert frac(x).series(x, cdir=-1) == 1 + x + assert frac(2*x + 1).series(x, cdir=1) == 2*x + assert frac(2*x + 1).series(x, cdir=-1) == 1 + 2*x + assert frac(x**2).series(x, cdir=1) == x**2 + assert frac(x**2).series(x, cdir=-1) == x**2 + assert frac(sin(x) + 5).series(x, cdir=1) == x - x**3/6 + x**5/120 + O(x**6) + assert frac(sin(x) + 5).series(x, cdir=-1) == 1 + x - x**3/6 + x**5/120 + O(x**6) + assert frac(sin(x) + S.Half).series(x) == S.Half + x - x**3/6 + x**5/120 + O(x**6) + assert frac(x**8).series(x, cdir=1) == O(x**6) + assert frac(1/x).series(x) == AccumBounds(0, 1) + O(x**6) + + +def test_ceiling(): + assert ceiling(x).series(x) == 1 + assert ceiling(-x).series(x) == 0 + assert ceiling(sin(x)).series(x) == 1 + assert ceiling(sin(-x)).series(x) == 0 + assert ceiling(1 - cos(x)).series(x) == 1 + assert ceiling(1 - cos(-x)).series(x) == 1 + assert ceiling(x).series(x, 2) == 3 + assert ceiling(-x).series(x, 2) == -2 + + +def test_abs(): + a = Symbol('a') + assert abs(x).nseries(x, n=4) == x + assert abs(-x).nseries(x, n=4) == x + assert abs(x + 1).nseries(x, n=4) == x + 1 + assert abs(sin(x)).nseries(x, n=4) == x - Rational(1, 6)*x**3 + O(x**4) + assert abs(sin(-x)).nseries(x, n=4) == x - Rational(1, 6)*x**3 + O(x**4) + assert abs(x - a).nseries(x, 1) == -a*sign(1 - a) + (x - 1)*sign(1 - a) + sign(1 - a) + + +def test_dir(): + assert abs(x).series(x, 0, dir="+") == x + assert abs(x).series(x, 0, dir="-") == -x + assert floor(x + 2).series(x, 0, dir='+') == 2 + assert floor(x + 2).series(x, 0, dir='-') == 1 + assert floor(x + 2.2).series(x, 0, dir='-') == 2 + assert ceiling(x + 2.2).series(x, 0, dir='-') == 3 + assert sin(x + y).series(x, 0, dir='-') == sin(x + y).series(x, 0, dir='+') + + +def test_cdir(): + assert abs(x).series(x, 0, cdir=1) == x + assert abs(x).series(x, 0, cdir=-1) == -x + assert floor(x + 2).series(x, 0, cdir=1) == 2 + assert floor(x + 2).series(x, 0, cdir=-1) == 1 + assert floor(x + 2.2).series(x, 0, cdir=1) == 2 + assert ceiling(x + 2.2).series(x, 0, cdir=-1) == 3 + assert sin(x + y).series(x, 0, cdir=-1) == sin(x + y).series(x, 0, cdir=1) + + +def test_issue_3504(): + a = Symbol("a") + e = asin(a*x)/x + assert e.series(x, 4, n=2).removeO() == \ + (x - 4)*(a/(4*sqrt(-16*a**2 + 1)) - asin(4*a)/16) + asin(4*a)/4 + + +def test_issue_4441(): + a, b = symbols('a,b') + f = 1/(1 + a*x) + assert f.series(x, 0, 5) == 1 - a*x + a**2*x**2 - a**3*x**3 + \ + a**4*x**4 + O(x**5) + f = 1/(1 + (a + b)*x) + assert f.series(x, 0, 3) == 1 + x*(-a - b)\ + + x**2*(a + b)**2 + O(x**3) + + +def test_issue_4329(): + assert tan(x).series(x, pi/2, n=3).removeO() == \ + -pi/6 + x/3 - 1/(x - pi/2) + assert cot(x).series(x, pi, n=3).removeO() == \ + -x/3 + pi/3 + 1/(x - pi) + assert limit(tan(x)**tan(2*x), x, pi/4) == exp(-1) + + +def test_issue_5183(): + assert abs(x + x**2).series(n=1) == O(x) + assert abs(x + x**2).series(n=2) == x + O(x**2) + assert ((1 + x)**2).series(x, n=6) == x**2 + 2*x + 1 + assert (1 + 1/x).series() == 1 + 1/x + assert Derivative(exp(x).series(), x).doit() == \ + 1 + x + x**2/2 + x**3/6 + x**4/24 + O(x**5) + + +def test_issue_5654(): + a = Symbol('a') + assert (1/(x**2+a**2)**2).nseries(x, x0=I*a, n=0) == \ + -I/(4*a**3*(-I*a + x)) - 1/(4*a**2*(-I*a + x)**2) + O(1, (x, I*a)) + assert (1/(x**2+a**2)**2).nseries(x, x0=I*a, n=1) == 3/(16*a**4) \ + -I/(4*a**3*(-I*a + x)) - 1/(4*a**2*(-I*a + x)**2) + O(-I*a + x, (x, I*a)) + + +def test_issue_5925(): + sx = sqrt(x + z).series(z, 0, 1) + sxy = sqrt(x + y + z).series(z, 0, 1) + s1, s2 = sx.subs(x, x + y), sxy + assert (s1 - s2).expand().removeO().simplify() == 0 + + sx = sqrt(x + z).series(z, 0, 1) + sxy = sqrt(x + y + z).series(z, 0, 1) + assert sxy.subs({x:1, y:2}) == sx.subs(x, 3) + + +def test_exp_2(): + assert exp(x**3).nseries(x, 0, 14) == 1 + x**3 + x**6/2 + x**9/6 + x**12/24 + O(x**14) diff --git a/lib/python3.10/site-packages/sympy/series/tests/test_order.py b/lib/python3.10/site-packages/sympy/series/tests/test_order.py new file mode 100644 index 0000000000000000000000000000000000000000..dd4cd9938d6ebbc4d8d915e09ec6e9c02c6fe599 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/series/tests/test_order.py @@ -0,0 +1,477 @@ +from sympy.core.add import Add +from sympy.core.function import (Function, expand) +from sympy.core.numbers import (I, Rational, nan, oo, pi) +from sympy.core.singleton import S +from sympy.core.symbol import (Symbol, symbols) +from sympy.functions.elementary.complexes import (conjugate, transpose) +from sympy.functions.elementary.exponential import (exp, log) +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import (cos, sin) +from sympy.integrals.integrals import Integral +from sympy.series.order import O, Order +from sympy.core.expr import unchanged +from sympy.testing.pytest import raises +from sympy.abc import w, x, y, z + + +def test_caching_bug(): + #needs to be a first test, so that all caches are clean + #cache it + O(w) + #and test that this won't raise an exception + O(w**(-1/x/log(3)*log(5)), w) + + +def test_free_symbols(): + assert Order(1).free_symbols == set() + assert Order(x).free_symbols == {x} + assert Order(1, x).free_symbols == {x} + assert Order(x*y).free_symbols == {x, y} + assert Order(x, x, y).free_symbols == {x, y} + + +def test_simple_1(): + o = Rational(0) + assert Order(2*x) == Order(x) + assert Order(x)*3 == Order(x) + assert -28*Order(x) == Order(x) + assert Order(Order(x)) == Order(x) + assert Order(Order(x), y) == Order(Order(x), x, y) + assert Order(-23) == Order(1) + assert Order(exp(x)) == Order(1, x) + assert Order(exp(1/x)).expr == exp(1/x) + assert Order(x*exp(1/x)).expr == x*exp(1/x) + assert Order(x**(o/3)).expr == x**(o/3) + assert Order(x**(o*Rational(5, 3))).expr == x**(o*Rational(5, 3)) + assert Order(x**2 + x + y, x) == O(1, x) + assert Order(x**2 + x + y, y) == O(1, y) + raises(ValueError, lambda: Order(exp(x), x, x)) + raises(TypeError, lambda: Order(x, 2 - x)) + + +def test_simple_2(): + assert Order(2*x)*x == Order(x**2) + assert Order(2*x)/x == Order(1, x) + assert Order(2*x)*x*exp(1/x) == Order(x**2*exp(1/x)) + assert (Order(2*x)*x*exp(1/x)/log(x)**3).expr == x**2*exp(1/x)*log(x)**-3 + + +def test_simple_3(): + assert Order(x) + x == Order(x) + assert Order(x) + 2 == 2 + Order(x) + assert Order(x) + x**2 == Order(x) + assert Order(x) + 1/x == 1/x + Order(x) + assert Order(1/x) + 1/x**2 == 1/x**2 + Order(1/x) + assert Order(x) + exp(1/x) == Order(x) + exp(1/x) + + +def test_simple_4(): + assert Order(x)**2 == Order(x**2) + + +def test_simple_5(): + assert Order(x) + Order(x**2) == Order(x) + assert Order(x) + Order(x**-2) == Order(x**-2) + assert Order(x) + Order(1/x) == Order(1/x) + + +def test_simple_6(): + assert Order(x) - Order(x) == Order(x) + assert Order(x) + Order(1) == Order(1) + assert Order(x) + Order(x**2) == Order(x) + assert Order(1/x) + Order(1) == Order(1/x) + assert Order(x) + Order(exp(1/x)) == Order(exp(1/x)) + assert Order(x**3) + Order(exp(2/x)) == Order(exp(2/x)) + assert Order(x**-3) + Order(exp(2/x)) == Order(exp(2/x)) + + +def test_simple_7(): + assert 1 + O(1) == O(1) + assert 2 + O(1) == O(1) + assert x + O(1) == O(1) + assert 1/x + O(1) == 1/x + O(1) + + +def test_simple_8(): + assert O(sqrt(-x)) == O(sqrt(x)) + assert O(x**2*sqrt(x)) == O(x**Rational(5, 2)) + assert O(x**3*sqrt(-(-x)**3)) == O(x**Rational(9, 2)) + assert O(x**Rational(3, 2)*sqrt((-x)**3)) == O(x**3) + assert O(x*(-2*x)**(I/2)) == O(x*(-x)**(I/2)) + + +def test_as_expr_variables(): + assert Order(x).as_expr_variables(None) == (x, ((x, 0),)) + assert Order(x).as_expr_variables(((x, 0),)) == (x, ((x, 0),)) + assert Order(y).as_expr_variables(((x, 0),)) == (y, ((x, 0), (y, 0))) + assert Order(y).as_expr_variables(((x, 0), (y, 0))) == (y, ((x, 0), (y, 0))) + + +def test_contains_0(): + assert Order(1, x).contains(Order(1, x)) + assert Order(1, x).contains(Order(1)) + assert Order(1).contains(Order(1, x)) is False + + +def test_contains_1(): + assert Order(x).contains(Order(x)) + assert Order(x).contains(Order(x**2)) + assert not Order(x**2).contains(Order(x)) + assert not Order(x).contains(Order(1/x)) + assert not Order(1/x).contains(Order(exp(1/x))) + assert not Order(x).contains(Order(exp(1/x))) + assert Order(1/x).contains(Order(x)) + assert Order(exp(1/x)).contains(Order(x)) + assert Order(exp(1/x)).contains(Order(1/x)) + assert Order(exp(1/x)).contains(Order(exp(1/x))) + assert Order(exp(2/x)).contains(Order(exp(1/x))) + assert not Order(exp(1/x)).contains(Order(exp(2/x))) + + +def test_contains_2(): + assert Order(x).contains(Order(y)) is None + assert Order(x).contains(Order(y*x)) + assert Order(y*x).contains(Order(x)) + assert Order(y).contains(Order(x*y)) + assert Order(x).contains(Order(y**2*x)) + + +def test_contains_3(): + assert Order(x*y**2).contains(Order(x**2*y)) is None + assert Order(x**2*y).contains(Order(x*y**2)) is None + + +def test_contains_4(): + assert Order(sin(1/x**2)).contains(Order(cos(1/x**2))) is True + assert Order(cos(1/x**2)).contains(Order(sin(1/x**2))) is True + + +def test_contains(): + assert Order(1, x) not in Order(1) + assert Order(1) in Order(1, x) + raises(TypeError, lambda: Order(x*y**2) in Order(x**2*y)) + + +def test_add_1(): + assert Order(x + x) == Order(x) + assert Order(3*x - 2*x**2) == Order(x) + assert Order(1 + x) == Order(1, x) + assert Order(1 + 1/x) == Order(1/x) + # TODO : A better output for Order(log(x) + 1/log(x)) + # could be Order(log(x)). Currently Order for expressions + # where all arguments would involve a log term would fall + # in this category and outputs for these should be improved. + assert Order(log(x) + 1/log(x)) == Order((log(x)**2 + 1)/log(x)) + assert Order(exp(1/x) + x) == Order(exp(1/x)) + assert Order(exp(1/x) + 1/x**20) == Order(exp(1/x)) + + +def test_ln_args(): + assert O(log(x)) + O(log(2*x)) == O(log(x)) + assert O(log(x)) + O(log(x**3)) == O(log(x)) + assert O(log(x*y)) + O(log(x) + log(y)) == O(log(x) + log(y), x, y) + + +def test_multivar_0(): + assert Order(x*y).expr == x*y + assert Order(x*y**2).expr == x*y**2 + assert Order(x*y, x).expr == x + assert Order(x*y**2, y).expr == y**2 + assert Order(x*y*z).expr == x*y*z + assert Order(x/y).expr == x/y + assert Order(x*exp(1/y)).expr == x*exp(1/y) + assert Order(exp(x)*exp(1/y)).expr == exp(x)*exp(1/y) + + +def test_multivar_0a(): + assert Order(exp(1/x)*exp(1/y)).expr == exp(1/x)*exp(1/y) + + +def test_multivar_1(): + assert Order(x + y).expr == x + y + assert Order(x + 2*y).expr == x + y + assert (Order(x + y) + x).expr == (x + y) + assert (Order(x + y) + x**2) == Order(x + y) + assert (Order(x + y) + 1/x) == 1/x + Order(x + y) + assert Order(x**2 + y*x).expr == x**2 + y*x + + +def test_multivar_2(): + assert Order(x**2*y + y**2*x, x, y).expr == x**2*y + y**2*x + + +def test_multivar_mul_1(): + assert Order(x + y)*x == Order(x**2 + y*x, x, y) + + +def test_multivar_3(): + assert (Order(x) + Order(y)).args in [ + (Order(x), Order(y)), + (Order(y), Order(x))] + assert Order(x) + Order(y) + Order(x + y) == Order(x + y) + assert (Order(x**2*y) + Order(y**2*x)).args in [ + (Order(x*y**2), Order(y*x**2)), + (Order(y*x**2), Order(x*y**2))] + assert (Order(x**2*y) + Order(y*x)) == Order(x*y) + + +def test_issue_3468(): + y = Symbol('y', negative=True) + z = Symbol('z', complex=True) + + # check that Order does not modify assumptions about symbols + Order(x) + Order(y) + Order(z) + + assert x.is_positive is None + assert y.is_positive is False + assert z.is_positive is None + + +def test_leading_order(): + assert (x + 1 + 1/x**5).extract_leading_order(x) == ((1/x**5, O(1/x**5)),) + assert (1 + 1/x).extract_leading_order(x) == ((1/x, O(1/x)),) + assert (1 + x).extract_leading_order(x) == ((1, O(1, x)),) + assert (1 + x**2).extract_leading_order(x) == ((1, O(1, x)),) + assert (2 + x**2).extract_leading_order(x) == ((2, O(1, x)),) + assert (x + x**2).extract_leading_order(x) == ((x, O(x)),) + + +def test_leading_order2(): + assert set((2 + pi + x**2).extract_leading_order(x)) == {(pi, O(1, x)), + (S(2), O(1, x))} + assert set((2*x + pi*x + x**2).extract_leading_order(x)) == {(2*x, O(x)), + (x*pi, O(x))} + + +def test_order_leadterm(): + assert O(x**2)._eval_as_leading_term(x) == O(x**2) + + +def test_order_symbols(): + e = x*y*sin(x)*Integral(x, (x, 1, 2)) + assert O(e) == O(x**2*y, x, y) + assert O(e, x) == O(x**2) + + +def test_nan(): + assert O(nan) is nan + assert not O(x).contains(nan) + + +def test_O1(): + assert O(1, x) * x == O(x) + assert O(1, y) * x == O(1, y) + + +def test_getn(): + # other lines are tested incidentally by the suite + assert O(x).getn() == 1 + assert O(x/log(x)).getn() == 1 + assert O(x**2/log(x)**2).getn() == 2 + assert O(x*log(x)).getn() == 1 + raises(NotImplementedError, lambda: (O(x) + O(y)).getn()) + + +def test_diff(): + assert O(x**2).diff(x) == O(x) + + +def test_getO(): + assert (x).getO() is None + assert (x).removeO() == x + assert (O(x)).getO() == O(x) + assert (O(x)).removeO() == 0 + assert (z + O(x) + O(y)).getO() == O(x) + O(y) + assert (z + O(x) + O(y)).removeO() == z + raises(NotImplementedError, lambda: (O(x) + O(y)).getn()) + + +def test_leading_term(): + from sympy.functions.special.gamma_functions import digamma + assert O(1/digamma(1/x)) == O(1/log(x)) + + +def test_eval(): + assert Order(x).subs(Order(x), 1) == 1 + assert Order(x).subs(x, y) == Order(y) + assert Order(x).subs(y, x) == Order(x) + assert Order(x).subs(x, x + y) == Order(x + y, (x, -y)) + assert (O(1)**x).is_Pow + + +def test_issue_4279(): + a, b = symbols('a b') + assert O(a, a, b) + O(1, a, b) == O(1, a, b) + assert O(b, a, b) + O(1, a, b) == O(1, a, b) + assert O(a + b, a, b) + O(1, a, b) == O(1, a, b) + assert O(1, a, b) + O(a, a, b) == O(1, a, b) + assert O(1, a, b) + O(b, a, b) == O(1, a, b) + assert O(1, a, b) + O(a + b, a, b) == O(1, a, b) + + +def test_issue_4855(): + assert 1/O(1) != O(1) + assert 1/O(x) != O(1/x) + assert 1/O(x, (x, oo)) != O(1/x, (x, oo)) + + f = Function('f') + assert 1/O(f(x)) != O(1/x) + + +def test_order_conjugate_transpose(): + x = Symbol('x', real=True) + y = Symbol('y', imaginary=True) + assert conjugate(Order(x)) == Order(conjugate(x)) + assert conjugate(Order(y)) == Order(conjugate(y)) + assert conjugate(Order(x**2)) == Order(conjugate(x)**2) + assert conjugate(Order(y**2)) == Order(conjugate(y)**2) + assert transpose(Order(x)) == Order(transpose(x)) + assert transpose(Order(y)) == Order(transpose(y)) + assert transpose(Order(x**2)) == Order(transpose(x)**2) + assert transpose(Order(y**2)) == Order(transpose(y)**2) + + +def test_order_noncommutative(): + A = Symbol('A', commutative=False) + assert Order(A + A*x, x) == Order(1, x) + assert (A + A*x)*Order(x) == Order(x) + assert (A*x)*Order(x) == Order(x**2, x) + assert expand((1 + Order(x))*A*A*x) == A*A*x + Order(x**2, x) + assert expand((A*A + Order(x))*x) == A*A*x + Order(x**2, x) + assert expand((A + Order(x))*A*x) == A*A*x + Order(x**2, x) + + +def test_issue_6753(): + assert (1 + x**2)**10000*O(x) == O(x) + + +def test_order_at_infinity(): + assert Order(1 + x, (x, oo)) == Order(x, (x, oo)) + assert Order(3*x, (x, oo)) == Order(x, (x, oo)) + assert Order(x, (x, oo))*3 == Order(x, (x, oo)) + assert -28*Order(x, (x, oo)) == Order(x, (x, oo)) + assert Order(Order(x, (x, oo)), (x, oo)) == Order(x, (x, oo)) + assert Order(Order(x, (x, oo)), (y, oo)) == Order(x, (x, oo), (y, oo)) + assert Order(3, (x, oo)) == Order(1, (x, oo)) + assert Order(x**2 + x + y, (x, oo)) == O(x**2, (x, oo)) + assert Order(x**2 + x + y, (y, oo)) == O(y, (y, oo)) + + assert Order(2*x, (x, oo))*x == Order(x**2, (x, oo)) + assert Order(2*x, (x, oo))/x == Order(1, (x, oo)) + assert Order(2*x, (x, oo))*x*exp(1/x) == Order(x**2*exp(1/x), (x, oo)) + assert Order(2*x, (x, oo))*x*exp(1/x)/log(x)**3 == Order(x**2*exp(1/x)*log(x)**-3, (x, oo)) + + assert Order(x, (x, oo)) + 1/x == 1/x + Order(x, (x, oo)) == Order(x, (x, oo)) + assert Order(x, (x, oo)) + 1 == 1 + Order(x, (x, oo)) == Order(x, (x, oo)) + assert Order(x, (x, oo)) + x == x + Order(x, (x, oo)) == Order(x, (x, oo)) + assert Order(x, (x, oo)) + x**2 == x**2 + Order(x, (x, oo)) + assert Order(1/x, (x, oo)) + 1/x**2 == 1/x**2 + Order(1/x, (x, oo)) == Order(1/x, (x, oo)) + assert Order(x, (x, oo)) + exp(1/x) == exp(1/x) + Order(x, (x, oo)) + + assert Order(x, (x, oo))**2 == Order(x**2, (x, oo)) + + assert Order(x, (x, oo)) + Order(x**2, (x, oo)) == Order(x**2, (x, oo)) + assert Order(x, (x, oo)) + Order(x**-2, (x, oo)) == Order(x, (x, oo)) + assert Order(x, (x, oo)) + Order(1/x, (x, oo)) == Order(x, (x, oo)) + + assert Order(x, (x, oo)) - Order(x, (x, oo)) == Order(x, (x, oo)) + assert Order(x, (x, oo)) + Order(1, (x, oo)) == Order(x, (x, oo)) + assert Order(x, (x, oo)) + Order(x**2, (x, oo)) == Order(x**2, (x, oo)) + assert Order(1/x, (x, oo)) + Order(1, (x, oo)) == Order(1, (x, oo)) + assert Order(x, (x, oo)) + Order(exp(1/x), (x, oo)) == Order(x, (x, oo)) + assert Order(x**3, (x, oo)) + Order(exp(2/x), (x, oo)) == Order(x**3, (x, oo)) + assert Order(x**-3, (x, oo)) + Order(exp(2/x), (x, oo)) == Order(exp(2/x), (x, oo)) + + # issue 7207 + assert Order(exp(x), (x, oo)).expr == Order(2*exp(x), (x, oo)).expr == exp(x) + assert Order(y**x, (x, oo)).expr == Order(2*y**x, (x, oo)).expr == exp(x*log(y)) + + # issue 19545 + assert Order(1/x - 3/(3*x + 2), (x, oo)).expr == x**(-2) + +def test_mixing_order_at_zero_and_infinity(): + assert (Order(x, (x, 0)) + Order(x, (x, oo))).is_Add + assert Order(x, (x, 0)) + Order(x, (x, oo)) == Order(x, (x, oo)) + Order(x, (x, 0)) + assert Order(Order(x, (x, oo))) == Order(x, (x, oo)) + + # not supported (yet) + raises(NotImplementedError, lambda: Order(x, (x, 0))*Order(x, (x, oo))) + raises(NotImplementedError, lambda: Order(x, (x, oo))*Order(x, (x, 0))) + raises(NotImplementedError, lambda: Order(Order(x, (x, oo)), y)) + raises(NotImplementedError, lambda: Order(Order(x), (x, oo))) + + +def test_order_at_some_point(): + assert Order(x, (x, 1)) == Order(1, (x, 1)) + assert Order(2*x - 2, (x, 1)) == Order(x - 1, (x, 1)) + assert Order(-x + 1, (x, 1)) == Order(x - 1, (x, 1)) + assert Order(x - 1, (x, 1))**2 == Order((x - 1)**2, (x, 1)) + assert Order(x - 2, (x, 2)) - O(x - 2, (x, 2)) == Order(x - 2, (x, 2)) + + +def test_order_subs_limits(): + # issue 3333 + assert (1 + Order(x)).subs(x, 1/x) == 1 + Order(1/x, (x, oo)) + assert (1 + Order(x)).limit(x, 0) == 1 + # issue 5769 + assert ((x + Order(x**2))/x).limit(x, 0) == 1 + + assert Order(x**2).subs(x, y - 1) == Order((y - 1)**2, (y, 1)) + assert Order(10*x**2, (x, 2)).subs(x, y - 1) == Order(1, (y, 3)) + + +def test_issue_9351(): + assert exp(x).series(x, 10, 1) == exp(10) + Order(x - 10, (x, 10)) + + +def test_issue_9192(): + assert O(1)*O(1) == O(1) + assert O(1)**O(1) == O(1) + + +def test_issue_9910(): + assert O(x*log(x) + sin(x), (x, oo)) == O(x*log(x), (x, oo)) + + +def test_performance_of_adding_order(): + l = [x**i for i in range(1000)] + l.append(O(x**1001)) + assert Add(*l).subs(x,1) == O(1) + +def test_issue_14622(): + assert (x**(-4) + x**(-3) + x**(-1) + O(x**(-6), (x, oo))).as_numer_denom() == ( + x**4 + x**5 + x**7 + O(x**2, (x, oo)), x**8) + assert (x**3 + O(x**2, (x, oo))).is_Add + assert O(x**2, (x, oo)).contains(x**3) is False + assert O(x, (x, oo)).contains(O(x, (x, 0))) is None + assert O(x, (x, 0)).contains(O(x, (x, oo))) is None + raises(NotImplementedError, lambda: O(x**3).contains(x**w)) + + +def test_issue_15539(): + assert O(1/x**2 + 1/x**4, (x, -oo)) == O(1/x**2, (x, -oo)) + assert O(1/x**4 + exp(x), (x, -oo)) == O(1/x**4, (x, -oo)) + assert O(1/x**4 + exp(-x), (x, -oo)) == O(exp(-x), (x, -oo)) + assert O(1/x, (x, oo)).subs(x, -x) == O(-1/x, (x, -oo)) + +def test_issue_18606(): + assert unchanged(Order, 0) + + +def test_issue_22165(): + assert O(log(x)).contains(2) + + +def test_issue_23231(): + # This test checks Order for expressions having + # arguments containing variables in exponents/powers. + assert O(x**x + 2**x, (x, oo)) == O(exp(x*log(x)), (x, oo)) + assert O(x**x + x**2, (x, oo)) == O(exp(x*log(x)), (x, oo)) + assert O(x**x + 1/x**2, (x, oo)) == O(exp(x*log(x)), (x, oo)) + assert O(2**x + 3**x , (x, oo)) == O(exp(x*log(3)), (x, oo)) + + +def test_issue_9917(): + assert O(x*sin(x) + 1, (x, oo)) == O(x, (x, oo)) diff --git a/lib/python3.10/site-packages/sympy/series/tests/test_residues.py b/lib/python3.10/site-packages/sympy/series/tests/test_residues.py new file mode 100644 index 0000000000000000000000000000000000000000..9f7d075a56500d008e3c8b46c1fda5db890fd76a --- /dev/null +++ b/lib/python3.10/site-packages/sympy/series/tests/test_residues.py @@ -0,0 +1,101 @@ +from sympy.core.function import Function +from sympy.core.numbers import (I, Rational, pi) +from sympy.core.singleton import S +from sympy.core.symbol import Symbol +from sympy.functions.combinatorial.factorials import factorial +from sympy.functions.elementary.exponential import (exp, log) +from sympy.functions.elementary.hyperbolic import tanh +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import (cot, sin, tan) +from sympy.series.residues import residue +from sympy.testing.pytest import XFAIL, raises +from sympy.abc import x, z, a, s, k + + +def test_basic1(): + assert residue(1/x, x, 0) == 1 + assert residue(-2/x, x, 0) == -2 + assert residue(81/x, x, 0) == 81 + assert residue(1/x**2, x, 0) == 0 + assert residue(0, x, 0) == 0 + assert residue(5, x, 0) == 0 + assert residue(x, x, 0) == 0 + assert residue(x**2, x, 0) == 0 + + +def test_basic2(): + assert residue(1/x, x, 1) == 0 + assert residue(-2/x, x, 1) == 0 + assert residue(81/x, x, -1) == 0 + assert residue(1/x**2, x, 1) == 0 + assert residue(0, x, 1) == 0 + assert residue(5, x, 1) == 0 + assert residue(x, x, 1) == 0 + assert residue(x**2, x, 5) == 0 + + +def test_f(): + f = Function("f") + assert residue(f(x)/x**5, x, 0) == f(x).diff(x, 4).subs(x, 0)/24 + + +def test_functions(): + assert residue(1/sin(x), x, 0) == 1 + assert residue(2/sin(x), x, 0) == 2 + assert residue(1/sin(x)**2, x, 0) == 0 + assert residue(1/sin(x)**5, x, 0) == Rational(3, 8) + + +def test_expressions(): + assert residue(1/(x + 1), x, 0) == 0 + assert residue(1/(x + 1), x, -1) == 1 + assert residue(1/(x**2 + 1), x, -1) == 0 + assert residue(1/(x**2 + 1), x, I) == -I/2 + assert residue(1/(x**2 + 1), x, -I) == I/2 + assert residue(1/(x**4 + 1), x, 0) == 0 + assert residue(1/(x**4 + 1), x, exp(I*pi/4)).equals(-(Rational(1, 4) + I/4)/sqrt(2)) + assert residue(1/(x**2 + a**2)**2, x, a*I) == -I/4/a**3 + + +@XFAIL +def test_expressions_failing(): + n = Symbol('n', integer=True, positive=True) + assert residue(exp(z)/(z - pi*I/4*a)**n, z, I*pi*a) == \ + exp(I*pi*a/4)/factorial(n - 1) + + +def test_NotImplemented(): + raises(NotImplementedError, lambda: residue(exp(1/z), z, 0)) + + +def test_bug(): + assert residue(2**(z)*(s + z)*(1 - s - z)/z**2, z, 0) == \ + 1 + s*log(2) - s**2*log(2) - 2*s + + +def test_issue_5654(): + assert residue(1/(x**2 + a**2)**2, x, a*I) == -I/(4*a**3) + assert residue(1/s*1/(z - exp(s)), s, 0) == 1/(z - 1) + assert residue((1 + k)/s*1/(z - exp(s)), s, 0) == k/(z - 1) + 1/(z - 1) + + +def test_issue_6499(): + assert residue(1/(exp(z) - 1), z, 0) == 1 + + +def test_issue_14037(): + assert residue(sin(x**50)/x**51, x, 0) == 1 + + +def test_issue_21176(): + f = x**2*cot(pi*x)/(x**4 + 1) + assert residue(f, x, -sqrt(2)/2 - sqrt(2)*I/2).cancel().together(deep=True)\ + == sqrt(2)*(1 - I)/(8*tan(sqrt(2)*pi*(1 + I)/2)) + + +def test_issue_21177(): + r = -sqrt(3)*tanh(sqrt(3)*pi/2)/3 + a = residue(cot(pi*x)/((x - 1)*(x - 2) + 1), x, S(3)/2 - sqrt(3)*I/2) + b = residue(cot(pi*x)/(x**2 - 3*x + 3), x, S(3)/2 - sqrt(3)*I/2) + assert a == r + assert (b - a).cancel() == 0 diff --git a/lib/python3.10/site-packages/sympy/series/tests/test_sequences.py b/lib/python3.10/site-packages/sympy/series/tests/test_sequences.py new file mode 100644 index 0000000000000000000000000000000000000000..61e276ad67982f0a9877de3548d70238976d28a5 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/series/tests/test_sequences.py @@ -0,0 +1,312 @@ +from sympy.core.containers import Tuple +from sympy.core.function import Function +from sympy.core.numbers import oo, Rational +from sympy.core.singleton import S +from sympy.core.symbol import symbols, Symbol +from sympy.functions.combinatorial.numbers import tribonacci, fibonacci +from sympy.functions.elementary.exponential import exp +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import cos, sin +from sympy.series import EmptySequence +from sympy.series.sequences import (SeqMul, SeqAdd, SeqPer, SeqFormula, + sequence) +from sympy.sets.sets import Interval +from sympy.tensor.indexed import Indexed, Idx +from sympy.series.sequences import SeqExpr, SeqExprOp, RecursiveSeq +from sympy.testing.pytest import raises, slow + +x, y, z = symbols('x y z') +n, m = symbols('n m') + + +def test_EmptySequence(): + assert S.EmptySequence is EmptySequence + + assert S.EmptySequence.interval is S.EmptySet + assert S.EmptySequence.length is S.Zero + + assert list(S.EmptySequence) == [] + + +def test_SeqExpr(): + #SeqExpr is a baseclass and does not take care of + #ensuring all arguments are Basics hence the use of + #Tuple(...) here. + s = SeqExpr(Tuple(1, n, y), Tuple(x, 0, 10)) + + assert isinstance(s, SeqExpr) + assert s.gen == (1, n, y) + assert s.interval == Interval(0, 10) + assert s.start == 0 + assert s.stop == 10 + assert s.length == 11 + assert s.variables == (x,) + + assert SeqExpr(Tuple(1, 2, 3), Tuple(x, 0, oo)).length is oo + + +def test_SeqPer(): + s = SeqPer((1, n, 3), (x, 0, 5)) + + assert isinstance(s, SeqPer) + assert s.periodical == Tuple(1, n, 3) + assert s.period == 3 + assert s.coeff(3) == 1 + assert s.free_symbols == {n} + + assert list(s) == [1, n, 3, 1, n, 3] + assert s[:] == [1, n, 3, 1, n, 3] + assert SeqPer((1, n, 3), (x, -oo, 0))[0:6] == [1, n, 3, 1, n, 3] + + raises(ValueError, lambda: SeqPer((1, 2, 3), (0, 1, 2))) + raises(ValueError, lambda: SeqPer((1, 2, 3), (x, -oo, oo))) + raises(ValueError, lambda: SeqPer(n**2, (0, oo))) + + assert SeqPer((n, n**2, n**3), (m, 0, oo))[:6] == \ + [n, n**2, n**3, n, n**2, n**3] + assert SeqPer((n, n**2, n**3), (n, 0, oo))[:6] == [0, 1, 8, 3, 16, 125] + assert SeqPer((n, m), (n, 0, oo))[:6] == [0, m, 2, m, 4, m] + + +def test_SeqFormula(): + s = SeqFormula(n**2, (n, 0, 5)) + + assert isinstance(s, SeqFormula) + assert s.formula == n**2 + assert s.coeff(3) == 9 + + assert list(s) == [i**2 for i in range(6)] + assert s[:] == [i**2 for i in range(6)] + assert SeqFormula(n**2, (n, -oo, 0))[0:6] == [i**2 for i in range(6)] + + assert SeqFormula(n**2, (0, oo)) == SeqFormula(n**2, (n, 0, oo)) + + assert SeqFormula(n**2, (0, m)).subs(m, x) == SeqFormula(n**2, (0, x)) + assert SeqFormula(m*n**2, (n, 0, oo)).subs(m, x) == \ + SeqFormula(x*n**2, (n, 0, oo)) + + raises(ValueError, lambda: SeqFormula(n**2, (0, 1, 2))) + raises(ValueError, lambda: SeqFormula(n**2, (n, -oo, oo))) + raises(ValueError, lambda: SeqFormula(m*n**2, (0, oo))) + + seq = SeqFormula(x*(y**2 + z), (z, 1, 100)) + assert seq.expand() == SeqFormula(x*y**2 + x*z, (z, 1, 100)) + seq = SeqFormula(sin(x*(y**2 + z)),(z, 1, 100)) + assert seq.expand(trig=True) == SeqFormula(sin(x*y**2)*cos(x*z) + sin(x*z)*cos(x*y**2), (z, 1, 100)) + assert seq.expand() == SeqFormula(sin(x*y**2 + x*z), (z, 1, 100)) + assert seq.expand(trig=False) == SeqFormula(sin(x*y**2 + x*z), (z, 1, 100)) + seq = SeqFormula(exp(x*(y**2 + z)), (z, 1, 100)) + assert seq.expand() == SeqFormula(exp(x*y**2)*exp(x*z), (z, 1, 100)) + assert seq.expand(power_exp=False) == SeqFormula(exp(x*y**2 + x*z), (z, 1, 100)) + assert seq.expand(mul=False, power_exp=False) == SeqFormula(exp(x*(y**2 + z)), (z, 1, 100)) + +def test_sequence(): + form = SeqFormula(n**2, (n, 0, 5)) + per = SeqPer((1, 2, 3), (n, 0, 5)) + inter = SeqFormula(n**2) + + assert sequence(n**2, (n, 0, 5)) == form + assert sequence((1, 2, 3), (n, 0, 5)) == per + assert sequence(n**2) == inter + + +def test_SeqExprOp(): + form = SeqFormula(n**2, (n, 0, 10)) + per = SeqPer((1, 2, 3), (m, 5, 10)) + + s = SeqExprOp(form, per) + assert s.gen == (n**2, (1, 2, 3)) + assert s.interval == Interval(5, 10) + assert s.start == 5 + assert s.stop == 10 + assert s.length == 6 + assert s.variables == (n, m) + + +def test_SeqAdd(): + per = SeqPer((1, 2, 3), (n, 0, oo)) + form = SeqFormula(n**2) + + per_bou = SeqPer((1, 2), (n, 1, 5)) + form_bou = SeqFormula(n**2, (6, 10)) + form_bou2 = SeqFormula(n**2, (1, 5)) + + assert SeqAdd() == S.EmptySequence + assert SeqAdd(S.EmptySequence) == S.EmptySequence + assert SeqAdd(per) == per + assert SeqAdd(per, S.EmptySequence) == per + assert SeqAdd(per_bou, form_bou) == S.EmptySequence + + s = SeqAdd(per_bou, form_bou2, evaluate=False) + assert s.args == (form_bou2, per_bou) + assert s[:] == [2, 6, 10, 18, 26] + assert list(s) == [2, 6, 10, 18, 26] + + assert isinstance(SeqAdd(per, per_bou, evaluate=False), SeqAdd) + + s1 = SeqAdd(per, per_bou) + assert isinstance(s1, SeqPer) + assert s1 == SeqPer((2, 4, 4, 3, 3, 5), (n, 1, 5)) + s2 = SeqAdd(form, form_bou) + assert isinstance(s2, SeqFormula) + assert s2 == SeqFormula(2*n**2, (6, 10)) + + assert SeqAdd(form, form_bou, per) == \ + SeqAdd(per, SeqFormula(2*n**2, (6, 10))) + assert SeqAdd(form, SeqAdd(form_bou, per)) == \ + SeqAdd(per, SeqFormula(2*n**2, (6, 10))) + assert SeqAdd(per, SeqAdd(form, form_bou), evaluate=False) == \ + SeqAdd(per, SeqFormula(2*n**2, (6, 10))) + + assert SeqAdd(SeqPer((1, 2), (n, 0, oo)), SeqPer((1, 2), (m, 0, oo))) == \ + SeqPer((2, 4), (n, 0, oo)) + + +def test_SeqMul(): + per = SeqPer((1, 2, 3), (n, 0, oo)) + form = SeqFormula(n**2) + + per_bou = SeqPer((1, 2), (n, 1, 5)) + form_bou = SeqFormula(n**2, (n, 6, 10)) + form_bou2 = SeqFormula(n**2, (1, 5)) + + assert SeqMul() == S.EmptySequence + assert SeqMul(S.EmptySequence) == S.EmptySequence + assert SeqMul(per) == per + assert SeqMul(per, S.EmptySequence) == S.EmptySequence + assert SeqMul(per_bou, form_bou) == S.EmptySequence + + s = SeqMul(per_bou, form_bou2, evaluate=False) + assert s.args == (form_bou2, per_bou) + assert s[:] == [1, 8, 9, 32, 25] + assert list(s) == [1, 8, 9, 32, 25] + + assert isinstance(SeqMul(per, per_bou, evaluate=False), SeqMul) + + s1 = SeqMul(per, per_bou) + assert isinstance(s1, SeqPer) + assert s1 == SeqPer((1, 4, 3, 2, 2, 6), (n, 1, 5)) + s2 = SeqMul(form, form_bou) + assert isinstance(s2, SeqFormula) + assert s2 == SeqFormula(n**4, (6, 10)) + + assert SeqMul(form, form_bou, per) == \ + SeqMul(per, SeqFormula(n**4, (6, 10))) + assert SeqMul(form, SeqMul(form_bou, per)) == \ + SeqMul(per, SeqFormula(n**4, (6, 10))) + assert SeqMul(per, SeqMul(form, form_bou2, + evaluate=False), evaluate=False) == \ + SeqMul(form, per, form_bou2, evaluate=False) + + assert SeqMul(SeqPer((1, 2), (n, 0, oo)), SeqPer((1, 2), (n, 0, oo))) == \ + SeqPer((1, 4), (n, 0, oo)) + + +def test_add(): + per = SeqPer((1, 2), (n, 0, oo)) + form = SeqFormula(n**2) + + assert per + (SeqPer((2, 3))) == SeqPer((3, 5), (n, 0, oo)) + assert form + SeqFormula(n**3) == SeqFormula(n**2 + n**3) + + assert per + form == SeqAdd(per, form) + + raises(TypeError, lambda: per + n) + raises(TypeError, lambda: n + per) + + +def test_sub(): + per = SeqPer((1, 2), (n, 0, oo)) + form = SeqFormula(n**2) + + assert per - (SeqPer((2, 3))) == SeqPer((-1, -1), (n, 0, oo)) + assert form - (SeqFormula(n**3)) == SeqFormula(n**2 - n**3) + + assert per - form == SeqAdd(per, -form) + + raises(TypeError, lambda: per - n) + raises(TypeError, lambda: n - per) + + +def test_mul__coeff_mul(): + assert SeqPer((1, 2), (n, 0, oo)).coeff_mul(2) == SeqPer((2, 4), (n, 0, oo)) + assert SeqFormula(n**2).coeff_mul(2) == SeqFormula(2*n**2) + assert S.EmptySequence.coeff_mul(100) == S.EmptySequence + + assert SeqPer((1, 2), (n, 0, oo)) * (SeqPer((2, 3))) == \ + SeqPer((2, 6), (n, 0, oo)) + assert SeqFormula(n**2) * SeqFormula(n**3) == SeqFormula(n**5) + + assert S.EmptySequence * SeqFormula(n**2) == S.EmptySequence + assert SeqFormula(n**2) * S.EmptySequence == S.EmptySequence + + raises(TypeError, lambda: sequence(n**2) * n) + raises(TypeError, lambda: n * sequence(n**2)) + + +def test_neg(): + assert -SeqPer((1, -2), (n, 0, oo)) == SeqPer((-1, 2), (n, 0, oo)) + assert -SeqFormula(n**2) == SeqFormula(-n**2) + + +def test_operations(): + per = SeqPer((1, 2), (n, 0, oo)) + per2 = SeqPer((2, 4), (n, 0, oo)) + form = SeqFormula(n**2) + form2 = SeqFormula(n**3) + + assert per + form + form2 == SeqAdd(per, form, form2) + assert per + form - form2 == SeqAdd(per, form, -form2) + assert per + form - S.EmptySequence == SeqAdd(per, form) + assert per + per2 + form == SeqAdd(SeqPer((3, 6), (n, 0, oo)), form) + assert S.EmptySequence - per == -per + assert form + form == SeqFormula(2*n**2) + + assert per * form * form2 == SeqMul(per, form, form2) + assert form * form == SeqFormula(n**4) + assert form * -form == SeqFormula(-n**4) + + assert form * (per + form2) == SeqMul(form, SeqAdd(per, form2)) + assert form * (per + per) == SeqMul(form, per2) + + assert form.coeff_mul(m) == SeqFormula(m*n**2, (n, 0, oo)) + assert per.coeff_mul(m) == SeqPer((m, 2*m), (n, 0, oo)) + + +def test_Idx_limits(): + i = symbols('i', cls=Idx) + r = Indexed('r', i) + + assert SeqFormula(r, (i, 0, 5))[:] == [r.subs(i, j) for j in range(6)] + assert SeqPer((1, 2), (i, 0, 5))[:] == [1, 2, 1, 2, 1, 2] + + +@slow +def test_find_linear_recurrence(): + assert sequence((0, 1, 1, 2, 3, 5, 8, 13, 21, 34, 55), \ + (n, 0, 10)).find_linear_recurrence(11) == [1, 1] + assert sequence((1, 2, 4, 7, 28, 128, 582, 2745, 13021, 61699, 292521, \ + 1387138), (n, 0, 11)).find_linear_recurrence(12) == [5, -2, 6, -11] + assert sequence(x*n**3+y*n, (n, 0, oo)).find_linear_recurrence(10) \ + == [4, -6, 4, -1] + assert sequence(x**n, (n,0,20)).find_linear_recurrence(21) == [x] + assert sequence((1,2,3)).find_linear_recurrence(10, 5) == [0, 0, 1] + assert sequence(((1 + sqrt(5))/2)**n + \ + (-(1 + sqrt(5))/2)**(-n)).find_linear_recurrence(10) == [1, 1] + assert sequence(x*((1 + sqrt(5))/2)**n + y*(-(1 + sqrt(5))/2)**(-n), \ + (n,0,oo)).find_linear_recurrence(10) == [1, 1] + assert sequence((1,2,3,4,6),(n, 0, 4)).find_linear_recurrence(5) == [] + assert sequence((2,3,4,5,6,79),(n, 0, 5)).find_linear_recurrence(6,gfvar=x) \ + == ([], None) + assert sequence((2,3,4,5,8,30),(n, 0, 5)).find_linear_recurrence(6,gfvar=x) \ + == ([Rational(19, 2), -20, Rational(27, 2)], (-31*x**2 + 32*x - 4)/(27*x**3 - 40*x**2 + 19*x -2)) + assert sequence(fibonacci(n)).find_linear_recurrence(30,gfvar=x) \ + == ([1, 1], -x/(x**2 + x - 1)) + assert sequence(tribonacci(n)).find_linear_recurrence(30,gfvar=x) \ + == ([1, 1, 1], -x/(x**3 + x**2 + x - 1)) + +def test_RecursiveSeq(): + y = Function('y') + n = Symbol('n') + fib = RecursiveSeq(y(n - 1) + y(n - 2), y(n), n, [0, 1]) + assert fib.coeff(3) == 2 diff --git a/lib/python3.10/site-packages/sympy/series/tests/test_series.py b/lib/python3.10/site-packages/sympy/series/tests/test_series.py new file mode 100644 index 0000000000000000000000000000000000000000..2adeef40f8a2864862590f4bc172bc76f3b83e65 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/series/tests/test_series.py @@ -0,0 +1,404 @@ +from sympy.core.evalf import N +from sympy.core.function import (Derivative, Function, PoleError, Subs) +from sympy.core.numbers import (E, Float, Rational, oo, pi, I) +from sympy.core.singleton import S +from sympy.core.symbol import (Symbol, symbols) +from sympy.functions.elementary.exponential import (LambertW, exp, log) +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import (atan, cos, sin) +from sympy.functions.special.gamma_functions import gamma +from sympy.integrals.integrals import Integral, integrate +from sympy.series.order import O +from sympy.series.series import series +from sympy.abc import x, y, n, k +from sympy.testing.pytest import raises +from sympy.core import EulerGamma + + +def test_sin(): + e1 = sin(x).series(x, 0) + e2 = series(sin(x), x, 0) + assert e1 == e2 + + +def test_cos(): + e1 = cos(x).series(x, 0) + e2 = series(cos(x), x, 0) + assert e1 == e2 + + +def test_exp(): + e1 = exp(x).series(x, 0) + e2 = series(exp(x), x, 0) + assert e1 == e2 + + +def test_exp2(): + e1 = exp(cos(x)).series(x, 0) + e2 = series(exp(cos(x)), x, 0) + assert e1 == e2 + + +def test_issue_5223(): + assert series(1, x) == 1 + assert next(S.Zero.lseries(x)) == 0 + assert cos(x).series() == cos(x).series(x) + raises(ValueError, lambda: cos(x + y).series()) + raises(ValueError, lambda: x.series(dir="")) + + assert (cos(x).series(x, 1) - + cos(x + 1).series(x).subs(x, x - 1)).removeO() == 0 + e = cos(x).series(x, 1, n=None) + assert [next(e) for i in range(2)] == [cos(1), -((x - 1)*sin(1))] + e = cos(x).series(x, 1, n=None, dir='-') + assert [next(e) for i in range(2)] == [cos(1), (1 - x)*sin(1)] + # the following test is exact so no need for x -> x - 1 replacement + assert abs(x).series(x, 1, dir='-') == x + assert exp(x).series(x, 1, dir='-', n=3).removeO() == \ + E - E*(-x + 1) + E*(-x + 1)**2/2 + + D = Derivative + assert D(x**2 + x**3*y**2, x, 2, y, 1).series(x).doit() == 12*x*y + assert next(D(cos(x), x).lseries()) == D(1, x) + assert D( + exp(x), x).series(n=3) == D(1, x) + D(x, x) + D(x**2/2, x) + D(x**3/6, x) + O(x**3) + + assert Integral(x, (x, 1, 3), (y, 1, x)).series(x) == -4 + 4*x + + assert (1 + x + O(x**2)).getn() == 2 + assert (1 + x).getn() is None + + raises(PoleError, lambda: ((1/sin(x))**oo).series()) + logx = Symbol('logx') + assert ((sin(x))**y).nseries(x, n=1, logx=logx) == \ + exp(y*logx) + O(x*exp(y*logx), x) + + assert sin(1/x).series(x, oo, n=5) == 1/x - 1/(6*x**3) + O(x**(-5), (x, oo)) + assert abs(x).series(x, oo, n=5, dir='+') == x + assert abs(x).series(x, -oo, n=5, dir='-') == -x + assert abs(-x).series(x, oo, n=5, dir='+') == x + assert abs(-x).series(x, -oo, n=5, dir='-') == -x + + assert exp(x*log(x)).series(n=3) == \ + 1 + x*log(x) + x**2*log(x)**2/2 + O(x**3*log(x)**3) + # XXX is this right? If not, fix "ngot > n" handling in expr. + p = Symbol('p', positive=True) + assert exp(sqrt(p)**3*log(p)).series(n=3) == \ + 1 + p**S('3/2')*log(p) + O(p**3*log(p)**3) + + assert exp(sin(x)*log(x)).series(n=2) == 1 + x*log(x) + O(x**2*log(x)**2) + + +def test_issue_6350(): + expr = integrate(exp(k*(y**3 - 3*y)), (y, 0, oo), conds='none') + assert expr.series(k, 0, 3) == -(-1)**(S(2)/3)*sqrt(3)*gamma(S(1)/3)**2*gamma(S(2)/3)/(6*pi*k**(S(1)/3)) - \ + sqrt(3)*k*gamma(-S(2)/3)*gamma(-S(1)/3)/(6*pi) - \ + (-1)**(S(1)/3)*sqrt(3)*k**(S(1)/3)*gamma(-S(1)/3)*gamma(S(1)/3)*gamma(S(2)/3)/(6*pi) - \ + (-1)**(S(2)/3)*sqrt(3)*k**(S(5)/3)*gamma(S(1)/3)**2*gamma(S(2)/3)/(4*pi) - \ + (-1)**(S(1)/3)*sqrt(3)*k**(S(7)/3)*gamma(-S(1)/3)*gamma(S(1)/3)*gamma(S(2)/3)/(8*pi) + O(k**3) + + +def test_issue_11313(): + assert Integral(cos(x), x).series(x) == sin(x).series(x) + assert Derivative(sin(x), x).series(x, n=3).doit() == cos(x).series(x, n=3) + + assert Derivative(x**3, x).as_leading_term(x) == 3*x**2 + assert Derivative(x**3, y).as_leading_term(x) == 0 + assert Derivative(sin(x), x).as_leading_term(x) == 1 + assert Derivative(cos(x), x).as_leading_term(x) == -x + + # This result is equivalent to zero, zero is not return because + # `Expr.series` doesn't currently detect an `x` in its `free_symbol`s. + assert Derivative(1, x).as_leading_term(x) == Derivative(1, x) + + assert Derivative(exp(x), x).series(x).doit() == exp(x).series(x) + assert 1 + Integral(exp(x), x).series(x) == exp(x).series(x) + + assert Derivative(log(x), x).series(x).doit() == (1/x).series(x) + assert Integral(log(x), x).series(x) == Integral(log(x), x).doit().series(x).removeO() + + +def test_series_of_Subs(): + from sympy.abc import z + + subs1 = Subs(sin(x), x, y) + subs2 = Subs(sin(x) * cos(z), x, y) + subs3 = Subs(sin(x * z), (x, z), (y, x)) + + assert subs1.series(x) == subs1 + subs1_series = (Subs(x, x, y) + Subs(-x**3/6, x, y) + + Subs(x**5/120, x, y) + O(y**6)) + assert subs1.series() == subs1_series + assert subs1.series(y) == subs1_series + assert subs1.series(z) == subs1 + assert subs2.series(z) == (Subs(z**4*sin(x)/24, x, y) + + Subs(-z**2*sin(x)/2, x, y) + Subs(sin(x), x, y) + O(z**6)) + assert subs3.series(x).doit() == subs3.doit().series(x) + assert subs3.series(z).doit() == sin(x*y) + + raises(ValueError, lambda: Subs(x + 2*y, y, z).series()) + assert Subs(x + y, y, z).series(x).doit() == x + z + + +def test_issue_3978(): + f = Function('f') + assert f(x).series(x, 0, 3, dir='-') == \ + f(0) + x*Subs(Derivative(f(x), x), x, 0) + \ + x**2*Subs(Derivative(f(x), x, x), x, 0)/2 + O(x**3) + assert f(x).series(x, 0, 3) == \ + f(0) + x*Subs(Derivative(f(x), x), x, 0) + \ + x**2*Subs(Derivative(f(x), x, x), x, 0)/2 + O(x**3) + assert f(x**2).series(x, 0, 3) == \ + f(0) + x**2*Subs(Derivative(f(x), x), x, 0) + O(x**3) + assert f(x**2+1).series(x, 0, 3) == \ + f(1) + x**2*Subs(Derivative(f(x), x), x, 1) + O(x**3) + + class TestF(Function): + pass + + assert TestF(x).series(x, 0, 3) == TestF(0) + \ + x*Subs(Derivative(TestF(x), x), x, 0) + \ + x**2*Subs(Derivative(TestF(x), x, x), x, 0)/2 + O(x**3) + +from sympy.series.acceleration import richardson, shanks +from sympy.concrete.summations import Sum +from sympy.core.numbers import Integer + + +def test_acceleration(): + e = (1 + 1/n)**n + assert round(richardson(e, n, 10, 20).evalf(), 10) == round(E.evalf(), 10) + + A = Sum(Integer(-1)**(k + 1) / k, (k, 1, n)) + assert round(shanks(A, n, 25).evalf(), 4) == round(log(2).evalf(), 4) + assert round(shanks(A, n, 25, 5).evalf(), 10) == round(log(2).evalf(), 10) + + +def test_issue_5852(): + assert series(1/cos(x/log(x)), x, 0) == 1 + x**2/(2*log(x)**2) + \ + 5*x**4/(24*log(x)**4) + O(x**6) + + +def test_issue_4583(): + assert cos(1 + x + x**2).series(x, 0, 5) == cos(1) - x*sin(1) + \ + x**2*(-sin(1) - cos(1)/2) + x**3*(-cos(1) + sin(1)/6) + \ + x**4*(-11*cos(1)/24 + sin(1)/2) + O(x**5) + + +def test_issue_6318(): + eq = (1/x)**Rational(2, 3) + assert (eq + 1).as_leading_term(x) == eq + + +def test_x_is_base_detection(): + eq = (x**2)**Rational(2, 3) + assert eq.series() == x**Rational(4, 3) + + +def test_issue_7203(): + assert series(cos(x), x, pi, 3) == \ + -1 + (x - pi)**2/2 + O((x - pi)**3, (x, pi)) + + +def test_exp_product_positive_factors(): + a, b = symbols('a, b', positive=True) + x = a * b + assert series(exp(x), x, n=8) == 1 + a*b + a**2*b**2/2 + \ + a**3*b**3/6 + a**4*b**4/24 + a**5*b**5/120 + a**6*b**6/720 + \ + a**7*b**7/5040 + O(a**8*b**8, a, b) + + +def test_issue_8805(): + assert series(1, n=8) == 1 + + +def test_issue_9549(): + y = (x**2 + x + 1) / (x**3 + x**2) + assert series(y, x, oo) == x**(-5) - 1/x**4 + x**(-3) + 1/x + O(x**(-6), (x, oo)) + + +def test_issue_10761(): + assert series(1/(x**-2 + x**-3), x, 0) == x**3 - x**4 + x**5 + O(x**6) + + +def test_issue_12578(): + y = (1 - 1/(x/2 - 1/(2*x))**4)**(S(1)/8) + assert y.series(x, 0, n=17) == 1 - 2*x**4 - 8*x**6 - 34*x**8 - 152*x**10 - 714*x**12 - \ + 3472*x**14 - 17318*x**16 + O(x**17) + + +def test_issue_12791(): + beta = symbols('beta', positive=True) + theta, varphi = symbols('theta varphi', real=True) + + expr = (-beta**2*varphi*sin(theta) + beta**2*cos(theta) + \ + beta*varphi*sin(theta) - beta*cos(theta) - beta + 1)/(beta*cos(theta) - 1)**2 + + sol = (0.5/(0.5*cos(theta) - 1.0)**2 - 0.25*cos(theta)/(0.5*cos(theta) - 1.0)**2 + + (beta - 0.5)*(-0.25*varphi*sin(2*theta) - 1.5*cos(theta) + + 0.25*cos(2*theta) + 1.25)/((0.5*cos(theta) - 1.0)**2*(0.5*cos(theta) - 1.0)) + + 0.25*varphi*sin(theta)/(0.5*cos(theta) - 1.0)**2 + + O((beta - S.Half)**2, (beta, S.Half))) + + assert expr.series(beta, 0.5, 2).trigsimp() == sol + + +def test_issue_14384(): + x, a = symbols('x a') + assert series(x**a, x) == x**a + assert series(x**(-2*a), x) == x**(-2*a) + assert series(exp(a*log(x)), x) == exp(a*log(x)) + raises(PoleError, lambda: series(x**I, x)) + raises(PoleError, lambda: series(x**(I + 1), x)) + raises(PoleError, lambda: series(exp(I*log(x)), x)) + + +def test_issue_14885(): + assert series(x**Rational(-3, 2)*exp(x), x, 0) == (x**Rational(-3, 2) + 1/sqrt(x) + + sqrt(x)/2 + x**Rational(3, 2)/6 + x**Rational(5, 2)/24 + x**Rational(7, 2)/120 + + x**Rational(9, 2)/720 + x**Rational(11, 2)/5040 + O(x**6)) + + +def test_issue_15539(): + assert series(atan(x), x, -oo) == (-1/(5*x**5) + 1/(3*x**3) - 1/x - pi/2 + + O(x**(-6), (x, -oo))) + assert series(atan(x), x, oo) == (-1/(5*x**5) + 1/(3*x**3) - 1/x + pi/2 + + O(x**(-6), (x, oo))) + + +def test_issue_7259(): + assert series(LambertW(x), x) == x - x**2 + 3*x**3/2 - 8*x**4/3 + 125*x**5/24 + O(x**6) + assert series(LambertW(x**2), x, n=8) == x**2 - x**4 + 3*x**6/2 + O(x**8) + assert series(LambertW(sin(x)), x, n=4) == x - x**2 + 4*x**3/3 + O(x**4) + +def test_issue_11884(): + assert cos(x).series(x, 1, n=1) == cos(1) + O(x - 1, (x, 1)) + + +def test_issue_18008(): + y = x*(1 + x*(1 - x))/((1 + x*(1 - x)) - (1 - x)*(1 - x)) + assert y.series(x, oo, n=4) == -9/(32*x**3) - 3/(16*x**2) - 1/(8*x) + S(1)/4 + x/2 + \ + O(x**(-4), (x, oo)) + + +def test_issue_18842(): + f = log(x/(1 - x)) + assert f.series(x, 0.491, n=1).removeO().nsimplify() == \ + -S(180019443780011)/5000000000000000 + + +def test_issue_19534(): + dt = symbols('dt', real=True) + expr = 16*dt*(0.125*dt*(2.0*dt + 1.0) + 0.875*dt + 1.0)/45 + \ + 49*dt*(-0.049335189898860408029*dt*(2.0*dt + 1.0) + \ + 0.29601113939316244817*dt*(0.125*dt*(2.0*dt + 1.0) + 0.875*dt + 1.0) - \ + 0.12564355335492979587*dt*(0.074074074074074074074*dt*(2.0*dt + 1.0) + \ + 0.2962962962962962963*dt*(0.125*dt*(2.0*dt + 1.0) + 0.875*dt + 1.0) + \ + 0.96296296296296296296*dt + 1.0) + 0.051640768506639183825*dt + \ + dt*(1/2 - sqrt(21)/14) + 1.0)/180 + 49*dt*(-0.23637909581542530626*dt*(2.0*dt + 1.0) - \ + 0.74817562366625959291*dt*(0.125*dt*(2.0*dt + 1.0) + 0.875*dt + 1.0) + \ + 0.88085458023927036857*dt*(0.074074074074074074074*dt*(2.0*dt + 1.0) + \ + 0.2962962962962962963*dt*(0.125*dt*(2.0*dt + 1.0) + 0.875*dt + 1.0) + \ + 0.96296296296296296296*dt + 1.0) + \ + 2.1165151389911680013*dt*(-0.049335189898860408029*dt*(2.0*dt + 1.0) + \ + 0.29601113939316244817*dt*(0.125*dt*(2.0*dt + 1.0) + 0.875*dt + 1.0) - \ + 0.12564355335492979587*dt*(0.074074074074074074074*dt*(2.0*dt + 1.0) + \ + 0.2962962962962962963*dt*(0.125*dt*(2.0*dt + 1.0) + 0.875*dt + 1.0) + \ + 0.96296296296296296296*dt + 1.0) + 0.22431393315265061193*dt + 1.0) - \ + 1.1854881643947648988*dt + dt*(sqrt(21)/14 + 1/2) + 1.0)/180 + \ + dt*(0.66666666666666666667*dt*(2.0*dt + 1.0) + \ + 6.0173399699313066769*dt*(0.125*dt*(2.0*dt + 1.0) + 0.875*dt + 1.0) - \ + 4.1117044797036320069*dt*(0.074074074074074074074*dt*(2.0*dt + 1.0) + \ + 0.2962962962962962963*dt*(0.125*dt*(2.0*dt + 1.0) + 0.875*dt + 1.0) + \ + 0.96296296296296296296*dt + 1.0) - \ + 7.0189140975801991157*dt*(-0.049335189898860408029*dt*(2.0*dt + 1.0) + \ + 0.29601113939316244817*dt*(0.125*dt*(2.0*dt + 1.0) + 0.875*dt + 1.0) - \ + 0.12564355335492979587*dt*(0.074074074074074074074*dt*(2.0*dt + 1.0) + \ + 0.2962962962962962963*dt*(0.125*dt*(2.0*dt + 1.0) + 0.875*dt + 1.0) + \ + 0.96296296296296296296*dt + 1.0) + 0.22431393315265061193*dt + 1.0) + \ + 0.94010945196161777522*dt*(-0.23637909581542530626*dt*(2.0*dt + 1.0) - \ + 0.74817562366625959291*dt*(0.125*dt*(2.0*dt + 1.0) + 0.875*dt + 1.0) + \ + 0.88085458023927036857*dt*(0.074074074074074074074*dt*(2.0*dt + 1.0) + \ + 0.2962962962962962963*dt*(0.125*dt*(2.0*dt + 1.0) + 0.875*dt + 1.0) + \ + 0.96296296296296296296*dt + 1.0) + \ + 2.1165151389911680013*dt*(-0.049335189898860408029*dt*(2.0*dt + 1.0) + \ + 0.29601113939316244817*dt*(0.125*dt*(2.0*dt + 1.0) + 0.875*dt + 1.0) - \ + 0.12564355335492979587*dt*(0.074074074074074074074*dt*(2.0*dt + 1.0) + \ + 0.2962962962962962963*dt*(0.125*dt*(2.0*dt + 1.0) + 0.875*dt + 1.0) + \ + 0.96296296296296296296*dt + 1.0) + 0.22431393315265061193*dt + 1.0) - \ + 0.35816132904077632692*dt + 1.0) + 5.5065024887242400038*dt + 1.0)/20 + dt/20 + 1 + + assert N(expr.series(dt, 0, 8), 20) == ( + - Float('0.00092592592592592596126289', precision=70) * dt**7 + + Float('0.0027777777777777783174695', precision=70) * dt**6 + + Float('0.016666666666666656027029', precision=70) * dt**5 + + Float('0.083333333333333300951828', precision=70) * dt**4 + + Float('0.33333333333333337034077', precision=70) * dt**3 + + Float('1.0', precision=70) * dt**2 + + Float('1.0', precision=70) * dt + + Float('1.0', precision=70) + ) + + +def test_issue_11407(): + a, b, c, x = symbols('a b c x') + assert series(sqrt(a + b + c*x), x, 0, 1) == sqrt(a + b) + O(x) + assert series(sqrt(a + b + c + c*x), x, 0, 1) == sqrt(a + b + c) + O(x) + + +def test_issue_14037(): + assert (sin(x**50)/x**51).series(x, n=0) == 1/x + O(1, x) + + +def test_issue_20551(): + expr = (exp(x)/x).series(x, n=None) + terms = [ next(expr) for i in range(3) ] + assert terms == [1/x, 1, x/2] + + +def test_issue_20697(): + p_0, p_1, p_2, p_3, b_0, b_1, b_2 = symbols('p_0 p_1 p_2 p_3 b_0 b_1 b_2') + Q = (p_0 + (p_1 + (p_2 + p_3/y)/y)/y)/(1 + ((p_3/(b_0*y) + (b_0*p_2\ + - b_1*p_3)/b_0**2)/y + (b_0**2*p_1 - b_0*b_1*p_2 - p_3*(b_0*b_2\ + - b_1**2))/b_0**3)/y) + assert Q.series(y, n=3).ratsimp() == b_2*y**2 + b_1*y + b_0 + O(y**3) + + +def test_issue_21245(): + fi = (1 + sqrt(5))/2 + assert (1/(1 - x - x**2)).series(x, 1/fi, 1).factor() == \ + (-4812 - 2152*sqrt(5) + 1686*x + 754*sqrt(5)*x\ + + O((x - 2/(1 + sqrt(5)))**2, (x, 2/(1 + sqrt(5)))))/((1 + sqrt(5))\ + *(20 + 9*sqrt(5))**2*(x + sqrt(5)*x - 2)) + + +def test_issue_21938(): + expr = sin(1/x + exp(-x)) - sin(1/x) + assert expr.series(x, oo) == (1/(24*x**4) - 1/(2*x**2) + 1 + O(x**(-6), (x, oo)))*exp(-x) + + +def test_issue_23432(): + expr = 1/sqrt(1 - x**2) + result = expr.series(x, 0.5) + assert result.is_Add and len(result.args) == 7 + + +def test_issue_23727(): + res = series(sqrt(1 - x**2), x, 0.1) + assert res.is_Add == True + + +def test_issue_24266(): + #type1: exp(f(x)) + assert (exp(-I*pi*(2*x+1))).series(x, 0, 3) == -1 + 2*I*pi*x + 2*pi**2*x**2 + O(x**3) + assert (exp(-I*pi*(2*x+1))*gamma(1+x)).series(x, 0, 3) == -1 + x*(EulerGamma + 2*I*pi) + \ + x**2*(-EulerGamma**2/2 + 23*pi**2/12 - 2*EulerGamma*I*pi) + O(x**3) + + #type2: c**f(x) + assert ((2*I)**(-I*pi*(2*x+1))).series(x, 0, 2) == exp(pi**2/2 - I*pi*log(2)) + \ + x*(pi**2*exp(pi**2/2 - I*pi*log(2)) - 2*I*pi*exp(pi**2/2 - I*pi*log(2))*log(2)) + O(x**2) + assert ((2)**(-I*pi*(2*x+1))).series(x, 0, 2) == exp(-I*pi*log(2)) - 2*I*pi*x*exp(-I*pi*log(2))*log(2) + O(x**2) + + #type3: f(y)**g(x) + assert ((y)**(I*pi*(2*x+1))).series(x, 0, 2) == exp(I*pi*log(y)) + 2*I*pi*x*exp(I*pi*log(y))*log(y) + O(x**2) + assert ((I*y)**(I*pi*(2*x+1))).series(x, 0, 2) == exp(I*pi*log(I*y)) + 2*I*pi*x*exp(I*pi*log(I*y))*log(I*y) + O(x**2) diff --git a/lib/python3.10/site-packages/sympy/sets/__pycache__/__init__.cpython-310.pyc b/lib/python3.10/site-packages/sympy/sets/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..97fd6d8de983ea999d6048c35865422de5a25e4e Binary files /dev/null and b/lib/python3.10/site-packages/sympy/sets/__pycache__/__init__.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/sets/__pycache__/conditionset.cpython-310.pyc b/lib/python3.10/site-packages/sympy/sets/__pycache__/conditionset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8182977c85826cadc4390efe70167f598a87a758 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/sets/__pycache__/conditionset.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/sets/__pycache__/contains.cpython-310.pyc b/lib/python3.10/site-packages/sympy/sets/__pycache__/contains.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..abbdb73f3fa2325319a9ddc1ad6495dd009ba9c9 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/sets/__pycache__/contains.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/sets/__pycache__/fancysets.cpython-310.pyc b/lib/python3.10/site-packages/sympy/sets/__pycache__/fancysets.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0a5d990c961195b7cfb1b15c93807c3f92324dc0 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/sets/__pycache__/fancysets.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/sets/__pycache__/ordinals.cpython-310.pyc b/lib/python3.10/site-packages/sympy/sets/__pycache__/ordinals.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7b5a718461403c16cf9b32bd4fe81a4f9d5386c2 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/sets/__pycache__/ordinals.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/sets/__pycache__/powerset.cpython-310.pyc b/lib/python3.10/site-packages/sympy/sets/__pycache__/powerset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9b9d0add281313af17f7fdccbefb6af721a7c64e Binary files /dev/null and b/lib/python3.10/site-packages/sympy/sets/__pycache__/powerset.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/sets/__pycache__/setexpr.cpython-310.pyc b/lib/python3.10/site-packages/sympy/sets/__pycache__/setexpr.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..72335584e2cee9f1f0751e6c702231427077011a Binary files /dev/null and b/lib/python3.10/site-packages/sympy/sets/__pycache__/setexpr.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/sets/__pycache__/sets.cpython-310.pyc b/lib/python3.10/site-packages/sympy/sets/__pycache__/sets.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4070bcb96ae89e22ffbd9d82195a9c7072eb4c46 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/sets/__pycache__/sets.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/sets/handlers/__init__.py b/lib/python3.10/site-packages/sympy/sets/handlers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/lib/python3.10/site-packages/sympy/sets/handlers/__pycache__/__init__.cpython-310.pyc b/lib/python3.10/site-packages/sympy/sets/handlers/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..671c4d2f75a0d7960f9e9126b12c9279932d4a29 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/sets/handlers/__pycache__/__init__.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/sets/handlers/__pycache__/add.cpython-310.pyc b/lib/python3.10/site-packages/sympy/sets/handlers/__pycache__/add.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d868fa793ac62240e98825ade3f7dc9b8205607c Binary files /dev/null and b/lib/python3.10/site-packages/sympy/sets/handlers/__pycache__/add.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/sets/handlers/__pycache__/comparison.cpython-310.pyc b/lib/python3.10/site-packages/sympy/sets/handlers/__pycache__/comparison.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ff1bba95cefa1c994767b90dacde42d2b0d532ac Binary files /dev/null and b/lib/python3.10/site-packages/sympy/sets/handlers/__pycache__/comparison.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/sets/handlers/__pycache__/functions.cpython-310.pyc b/lib/python3.10/site-packages/sympy/sets/handlers/__pycache__/functions.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cd8b57b2e032c35d5ae473004ea0e81792eb63f5 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/sets/handlers/__pycache__/functions.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/sets/handlers/__pycache__/intersection.cpython-310.pyc b/lib/python3.10/site-packages/sympy/sets/handlers/__pycache__/intersection.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6b33acc37b67e0542e4802e3eb5dc0b0f91eb9ae Binary files /dev/null and b/lib/python3.10/site-packages/sympy/sets/handlers/__pycache__/intersection.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/sets/handlers/__pycache__/issubset.cpython-310.pyc b/lib/python3.10/site-packages/sympy/sets/handlers/__pycache__/issubset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b92bf38a44e6783e6bf466bde8af959fb9c20ba2 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/sets/handlers/__pycache__/issubset.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/sets/handlers/__pycache__/mul.cpython-310.pyc b/lib/python3.10/site-packages/sympy/sets/handlers/__pycache__/mul.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c0d0e72d6e90e96f91a19a29980e3eca623b985d Binary files /dev/null and b/lib/python3.10/site-packages/sympy/sets/handlers/__pycache__/mul.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/sets/handlers/__pycache__/power.cpython-310.pyc b/lib/python3.10/site-packages/sympy/sets/handlers/__pycache__/power.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..87aff5016f5dcf8e09c847c18a112c0a6463a9e5 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/sets/handlers/__pycache__/power.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/sets/handlers/__pycache__/union.cpython-310.pyc b/lib/python3.10/site-packages/sympy/sets/handlers/__pycache__/union.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6b7b9553d8694fc6e9446801963256ff7aec2fb7 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/sets/handlers/__pycache__/union.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/sets/handlers/add.py b/lib/python3.10/site-packages/sympy/sets/handlers/add.py new file mode 100644 index 0000000000000000000000000000000000000000..8c07b25ed19d21febffd6b23a92b34b787179f44 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/sets/handlers/add.py @@ -0,0 +1,79 @@ +from sympy.core.numbers import oo, Infinity, NegativeInfinity +from sympy.core.singleton import S +from sympy.core import Basic, Expr +from sympy.multipledispatch import Dispatcher +from sympy.sets import Interval, FiniteSet + + + +# XXX: The functions in this module are clearly not tested and are broken in a +# number of ways. + +_set_add = Dispatcher('_set_add') +_set_sub = Dispatcher('_set_sub') + + +@_set_add.register(Basic, Basic) +def _(x, y): + return None + + +@_set_add.register(Expr, Expr) +def _(x, y): + return x+y + + +@_set_add.register(Interval, Interval) +def _(x, y): + """ + Additions in interval arithmetic + https://en.wikipedia.org/wiki/Interval_arithmetic + """ + return Interval(x.start + y.start, x.end + y.end, + x.left_open or y.left_open, x.right_open or y.right_open) + + +@_set_add.register(Interval, Infinity) +def _(x, y): + if x.start is S.NegativeInfinity: + return Interval(-oo, oo) + return FiniteSet({S.Infinity}) + +@_set_add.register(Interval, NegativeInfinity) +def _(x, y): + if x.end is S.Infinity: + return Interval(-oo, oo) + return FiniteSet({S.NegativeInfinity}) + + +@_set_sub.register(Basic, Basic) +def _(x, y): + return None + + +@_set_sub.register(Expr, Expr) +def _(x, y): + return x-y + + +@_set_sub.register(Interval, Interval) +def _(x, y): + """ + Subtractions in interval arithmetic + https://en.wikipedia.org/wiki/Interval_arithmetic + """ + return Interval(x.start - y.end, x.end - y.start, + x.left_open or y.right_open, x.right_open or y.left_open) + + +@_set_sub.register(Interval, Infinity) +def _(x, y): + if x.start is S.NegativeInfinity: + return Interval(-oo, oo) + return FiniteSet(-oo) + +@_set_sub.register(Interval, NegativeInfinity) +def _(x, y): + if x.start is S.NegativeInfinity: + return Interval(-oo, oo) + return FiniteSet(-oo) diff --git a/lib/python3.10/site-packages/sympy/sets/handlers/comparison.py b/lib/python3.10/site-packages/sympy/sets/handlers/comparison.py new file mode 100644 index 0000000000000000000000000000000000000000..b64d1a2a22e15d09f6f10fb4fef730163d468d45 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/sets/handlers/comparison.py @@ -0,0 +1,53 @@ +from sympy.core.relational import Eq, is_eq +from sympy.core.basic import Basic +from sympy.core.logic import fuzzy_and, fuzzy_bool +from sympy.logic.boolalg import And +from sympy.multipledispatch import dispatch +from sympy.sets.sets import tfn, ProductSet, Interval, FiniteSet, Set + + +@dispatch(Interval, FiniteSet) # type:ignore +def _eval_is_eq(lhs, rhs): # noqa: F811 + return False + + +@dispatch(FiniteSet, Interval) # type:ignore +def _eval_is_eq(lhs, rhs): # noqa: F811 + return False + + +@dispatch(Interval, Interval) # type:ignore +def _eval_is_eq(lhs, rhs): # noqa: F811 + return And(Eq(lhs.left, rhs.left), + Eq(lhs.right, rhs.right), + lhs.left_open == rhs.left_open, + lhs.right_open == rhs.right_open) + +@dispatch(FiniteSet, FiniteSet) # type:ignore +def _eval_is_eq(lhs, rhs): # noqa: F811 + def all_in_both(): + s_set = set(lhs.args) + o_set = set(rhs.args) + yield fuzzy_and(lhs._contains(e) for e in o_set - s_set) + yield fuzzy_and(rhs._contains(e) for e in s_set - o_set) + + return tfn[fuzzy_and(all_in_both())] + + +@dispatch(ProductSet, ProductSet) # type:ignore +def _eval_is_eq(lhs, rhs): # noqa: F811 + if len(lhs.sets) != len(rhs.sets): + return False + + eqs = (is_eq(x, y) for x, y in zip(lhs.sets, rhs.sets)) + return tfn[fuzzy_and(map(fuzzy_bool, eqs))] + + +@dispatch(Set, Basic) # type:ignore +def _eval_is_eq(lhs, rhs): # noqa: F811 + return False + + +@dispatch(Set, Set) # type:ignore +def _eval_is_eq(lhs, rhs): # noqa: F811 + return tfn[fuzzy_and(a.is_subset(b) for a, b in [(lhs, rhs), (rhs, lhs)])] diff --git a/lib/python3.10/site-packages/sympy/sets/handlers/functions.py b/lib/python3.10/site-packages/sympy/sets/handlers/functions.py new file mode 100644 index 0000000000000000000000000000000000000000..2529dbfd458451d7d09e91c717b170df77b1d9fe --- /dev/null +++ b/lib/python3.10/site-packages/sympy/sets/handlers/functions.py @@ -0,0 +1,262 @@ +from sympy.core.singleton import S +from sympy.sets.sets import Set +from sympy.calculus.singularities import singularities +from sympy.core import Expr, Add +from sympy.core.function import Lambda, FunctionClass, diff, expand_mul +from sympy.core.numbers import Float, oo +from sympy.core.symbol import Dummy, symbols, Wild +from sympy.functions.elementary.exponential import exp, log +from sympy.functions.elementary.miscellaneous import Min, Max +from sympy.logic.boolalg import true +from sympy.multipledispatch import Dispatcher +from sympy.sets import (imageset, Interval, FiniteSet, Union, ImageSet, + Intersection, Range, Complement) +from sympy.sets.sets import EmptySet, is_function_invertible_in_set +from sympy.sets.fancysets import Integers, Naturals, Reals +from sympy.functions.elementary.exponential import match_real_imag + + +_x, _y = symbols("x y") + +FunctionUnion = (FunctionClass, Lambda) + +_set_function = Dispatcher('_set_function') + + +@_set_function.register(FunctionClass, Set) +def _(f, x): + return None + +@_set_function.register(FunctionUnion, FiniteSet) +def _(f, x): + return FiniteSet(*map(f, x)) + +@_set_function.register(Lambda, Interval) +def _(f, x): + from sympy.solvers.solveset import solveset + from sympy.series import limit + # TODO: handle functions with infinitely many solutions (eg, sin, tan) + # TODO: handle multivariate functions + + expr = f.expr + if len(expr.free_symbols) > 1 or len(f.variables) != 1: + return + var = f.variables[0] + if not var.is_real: + if expr.subs(var, Dummy(real=True)).is_real is False: + return + + if expr.is_Piecewise: + result = S.EmptySet + domain_set = x + for (p_expr, p_cond) in expr.args: + if p_cond is true: + intrvl = domain_set + else: + intrvl = p_cond.as_set() + intrvl = Intersection(domain_set, intrvl) + + if p_expr.is_Number: + image = FiniteSet(p_expr) + else: + image = imageset(Lambda(var, p_expr), intrvl) + result = Union(result, image) + + # remove the part which has been `imaged` + domain_set = Complement(domain_set, intrvl) + if domain_set is S.EmptySet: + break + return result + + if not x.start.is_comparable or not x.end.is_comparable: + return + + try: + from sympy.polys.polyutils import _nsort + sing = list(singularities(expr, var, x)) + if len(sing) > 1: + sing = _nsort(sing) + except NotImplementedError: + return + + if x.left_open: + _start = limit(expr, var, x.start, dir="+") + elif x.start not in sing: + _start = f(x.start) + if x.right_open: + _end = limit(expr, var, x.end, dir="-") + elif x.end not in sing: + _end = f(x.end) + + if len(sing) == 0: + soln_expr = solveset(diff(expr, var), var) + if not (isinstance(soln_expr, FiniteSet) + or soln_expr is S.EmptySet): + return + solns = list(soln_expr) + + extr = [_start, _end] + [f(i) for i in solns + if i.is_real and i in x] + start, end = Min(*extr), Max(*extr) + + left_open, right_open = False, False + if _start <= _end: + # the minimum or maximum value can occur simultaneously + # on both the edge of the interval and in some interior + # point + if start == _start and start not in solns: + left_open = x.left_open + if end == _end and end not in solns: + right_open = x.right_open + else: + if start == _end and start not in solns: + left_open = x.right_open + if end == _start and end not in solns: + right_open = x.left_open + + return Interval(start, end, left_open, right_open) + else: + return imageset(f, Interval(x.start, sing[0], + x.left_open, True)) + \ + Union(*[imageset(f, Interval(sing[i], sing[i + 1], True, True)) + for i in range(0, len(sing) - 1)]) + \ + imageset(f, Interval(sing[-1], x.end, True, x.right_open)) + +@_set_function.register(FunctionClass, Interval) +def _(f, x): + if f == exp: + return Interval(exp(x.start), exp(x.end), x.left_open, x.right_open) + elif f == log: + return Interval(log(x.start), log(x.end), x.left_open, x.right_open) + return ImageSet(Lambda(_x, f(_x)), x) + +@_set_function.register(FunctionUnion, Union) +def _(f, x): + return Union(*(imageset(f, arg) for arg in x.args)) + +@_set_function.register(FunctionUnion, Intersection) +def _(f, x): + # If the function is invertible, intersect the maps of the sets. + if is_function_invertible_in_set(f, x): + return Intersection(*(imageset(f, arg) for arg in x.args)) + else: + return ImageSet(Lambda(_x, f(_x)), x) + +@_set_function.register(FunctionUnion, EmptySet) +def _(f, x): + return x + +@_set_function.register(FunctionUnion, Set) +def _(f, x): + return ImageSet(Lambda(_x, f(_x)), x) + +@_set_function.register(FunctionUnion, Range) +def _(f, self): + if not self: + return S.EmptySet + if not isinstance(f.expr, Expr): + return + if self.size == 1: + return FiniteSet(f(self[0])) + if f is S.IdentityFunction: + return self + + x = f.variables[0] + expr = f.expr + # handle f that is linear in f's variable + if x not in expr.free_symbols or x in expr.diff(x).free_symbols: + return + if self.start.is_finite: + F = f(self.step*x + self.start) # for i in range(len(self)) + else: + F = f(-self.step*x + self[-1]) + F = expand_mul(F) + if F != expr: + return imageset(x, F, Range(self.size)) + +@_set_function.register(FunctionUnion, Integers) +def _(f, self): + expr = f.expr + if not isinstance(expr, Expr): + return + + n = f.variables[0] + if expr == abs(n): + return S.Naturals0 + + # f(x) + c and f(-x) + c cover the same integers + # so choose the form that has the fewest negatives + c = f(0) + fx = f(n) - c + f_x = f(-n) - c + neg_count = lambda e: sum(_.could_extract_minus_sign() + for _ in Add.make_args(e)) + if neg_count(f_x) < neg_count(fx): + expr = f_x + c + + a = Wild('a', exclude=[n]) + b = Wild('b', exclude=[n]) + match = expr.match(a*n + b) + if match and match[a] and ( + not match[a].atoms(Float) and + not match[b].atoms(Float)): + # canonical shift + a, b = match[a], match[b] + if a in [1, -1]: + # drop integer addends in b + nonint = [] + for bi in Add.make_args(b): + if not bi.is_integer: + nonint.append(bi) + b = Add(*nonint) + if b.is_number and a.is_real: + # avoid Mod for complex numbers, #11391 + br, bi = match_real_imag(b) + if br and br.is_comparable and a.is_comparable: + br %= a + b = br + S.ImaginaryUnit*bi + elif b.is_number and a.is_imaginary: + br, bi = match_real_imag(b) + ai = a/S.ImaginaryUnit + if bi and bi.is_comparable and ai.is_comparable: + bi %= ai + b = br + S.ImaginaryUnit*bi + expr = a*n + b + + if expr != f.expr: + return ImageSet(Lambda(n, expr), S.Integers) + + +@_set_function.register(FunctionUnion, Naturals) +def _(f, self): + expr = f.expr + if not isinstance(expr, Expr): + return + + x = f.variables[0] + if not expr.free_symbols - {x}: + if expr == abs(x): + if self is S.Naturals: + return self + return S.Naturals0 + step = expr.coeff(x) + c = expr.subs(x, 0) + if c.is_Integer and step.is_Integer and expr == step*x + c: + if self is S.Naturals: + c += step + if step > 0: + if step == 1: + if c == 0: + return S.Naturals0 + elif c == 1: + return S.Naturals + return Range(c, oo, step) + return Range(c, -oo, step) + + +@_set_function.register(FunctionUnion, Reals) +def _(f, self): + expr = f.expr + if not isinstance(expr, Expr): + return + return _set_function(f, Interval(-oo, oo)) diff --git a/lib/python3.10/site-packages/sympy/sets/handlers/intersection.py b/lib/python3.10/site-packages/sympy/sets/handlers/intersection.py new file mode 100644 index 0000000000000000000000000000000000000000..fcb9309ef3e9d2722ab1bfe664f1d1644f17da5d --- /dev/null +++ b/lib/python3.10/site-packages/sympy/sets/handlers/intersection.py @@ -0,0 +1,533 @@ +from sympy.core.basic import _aresame +from sympy.core.function import Lambda, expand_complex +from sympy.core.mul import Mul +from sympy.core.numbers import ilcm, Float +from sympy.core.relational import Eq +from sympy.core.singleton import S +from sympy.core.symbol import (Dummy, symbols) +from sympy.core.sorting import ordered +from sympy.functions.elementary.complexes import sign +from sympy.functions.elementary.integers import floor, ceiling +from sympy.sets.fancysets import ComplexRegion +from sympy.sets.sets import (FiniteSet, Intersection, Interval, Set, Union) +from sympy.multipledispatch import Dispatcher +from sympy.sets.conditionset import ConditionSet +from sympy.sets.fancysets import (Integers, Naturals, Reals, Range, + ImageSet, Rationals) +from sympy.sets.sets import EmptySet, UniversalSet, imageset, ProductSet +from sympy.simplify.radsimp import numer + + +intersection_sets = Dispatcher('intersection_sets') + + +@intersection_sets.register(ConditionSet, ConditionSet) +def _(a, b): + return None + +@intersection_sets.register(ConditionSet, Set) +def _(a, b): + return ConditionSet(a.sym, a.condition, Intersection(a.base_set, b)) + +@intersection_sets.register(Naturals, Integers) +def _(a, b): + return a + +@intersection_sets.register(Naturals, Naturals) +def _(a, b): + return a if a is S.Naturals else b + +@intersection_sets.register(Interval, Naturals) +def _(a, b): + return intersection_sets(b, a) + +@intersection_sets.register(ComplexRegion, Set) +def _(self, other): + if other.is_ComplexRegion: + # self in rectangular form + if (not self.polar) and (not other.polar): + return ComplexRegion(Intersection(self.sets, other.sets)) + + # self in polar form + elif self.polar and other.polar: + r1, theta1 = self.a_interval, self.b_interval + r2, theta2 = other.a_interval, other.b_interval + new_r_interval = Intersection(r1, r2) + new_theta_interval = Intersection(theta1, theta2) + + # 0 and 2*Pi means the same + if ((2*S.Pi in theta1 and S.Zero in theta2) or + (2*S.Pi in theta2 and S.Zero in theta1)): + new_theta_interval = Union(new_theta_interval, + FiniteSet(0)) + return ComplexRegion(new_r_interval*new_theta_interval, + polar=True) + + + if other.is_subset(S.Reals): + new_interval = [] + x = symbols("x", cls=Dummy, real=True) + + # self in rectangular form + if not self.polar: + for element in self.psets: + if S.Zero in element.args[1]: + new_interval.append(element.args[0]) + new_interval = Union(*new_interval) + return Intersection(new_interval, other) + + # self in polar form + elif self.polar: + for element in self.psets: + if S.Zero in element.args[1]: + new_interval.append(element.args[0]) + if S.Pi in element.args[1]: + new_interval.append(ImageSet(Lambda(x, -x), element.args[0])) + if S.Zero in element.args[0]: + new_interval.append(FiniteSet(0)) + new_interval = Union(*new_interval) + return Intersection(new_interval, other) + +@intersection_sets.register(Integers, Reals) +def _(a, b): + return a + +@intersection_sets.register(Range, Interval) +def _(a, b): + # Check that there are no symbolic arguments + if not all(i.is_number for i in a.args + b.args[:2]): + return + + # In case of null Range, return an EmptySet. + if a.size == 0: + return S.EmptySet + + # trim down to self's size, and represent + # as a Range with step 1. + start = ceiling(max(b.inf, a.inf)) + if start not in b: + start += 1 + end = floor(min(b.sup, a.sup)) + if end not in b: + end -= 1 + return intersection_sets(a, Range(start, end + 1)) + +@intersection_sets.register(Range, Naturals) +def _(a, b): + return intersection_sets(a, Interval(b.inf, S.Infinity)) + +@intersection_sets.register(Range, Range) +def _(a, b): + # Check that there are no symbolic range arguments + if not all(all(v.is_number for v in r.args) for r in [a, b]): + return None + + # non-overlap quick exits + if not b: + return S.EmptySet + if not a: + return S.EmptySet + if b.sup < a.inf: + return S.EmptySet + if b.inf > a.sup: + return S.EmptySet + + # work with finite end at the start + r1 = a + if r1.start.is_infinite: + r1 = r1.reversed + r2 = b + if r2.start.is_infinite: + r2 = r2.reversed + + # If both ends are infinite then it means that one Range is just the set + # of all integers (the step must be 1). + if r1.start.is_infinite: + return b + if r2.start.is_infinite: + return a + + from sympy.solvers.diophantine.diophantine import diop_linear + + # this equation represents the values of the Range; + # it's a linear equation + eq = lambda r, i: r.start + i*r.step + + # we want to know when the two equations might + # have integer solutions so we use the diophantine + # solver + va, vb = diop_linear(eq(r1, Dummy('a')) - eq(r2, Dummy('b'))) + + # check for no solution + no_solution = va is None and vb is None + if no_solution: + return S.EmptySet + + # there is a solution + # ------------------- + + # find the coincident point, c + a0 = va.as_coeff_Add()[0] + c = eq(r1, a0) + + # find the first point, if possible, in each range + # since c may not be that point + def _first_finite_point(r1, c): + if c == r1.start: + return c + # st is the signed step we need to take to + # get from c to r1.start + st = sign(r1.start - c)*step + # use Range to calculate the first point: + # we want to get as close as possible to + # r1.start; the Range will not be null since + # it will at least contain c + s1 = Range(c, r1.start + st, st)[-1] + if s1 == r1.start: + pass + else: + # if we didn't hit r1.start then, if the + # sign of st didn't match the sign of r1.step + # we are off by one and s1 is not in r1 + if sign(r1.step) != sign(st): + s1 -= st + if s1 not in r1: + return + return s1 + + # calculate the step size of the new Range + step = abs(ilcm(r1.step, r2.step)) + s1 = _first_finite_point(r1, c) + if s1 is None: + return S.EmptySet + s2 = _first_finite_point(r2, c) + if s2 is None: + return S.EmptySet + + # replace the corresponding start or stop in + # the original Ranges with these points; the + # result must have at least one point since + # we know that s1 and s2 are in the Ranges + def _updated_range(r, first): + st = sign(r.step)*step + if r.start.is_finite: + rv = Range(first, r.stop, st) + else: + rv = Range(r.start, first + st, st) + return rv + r1 = _updated_range(a, s1) + r2 = _updated_range(b, s2) + + # work with them both in the increasing direction + if sign(r1.step) < 0: + r1 = r1.reversed + if sign(r2.step) < 0: + r2 = r2.reversed + + # return clipped Range with positive step; it + # can't be empty at this point + start = max(r1.start, r2.start) + stop = min(r1.stop, r2.stop) + return Range(start, stop, step) + + +@intersection_sets.register(Range, Integers) +def _(a, b): + return a + + +@intersection_sets.register(Range, Rationals) +def _(a, b): + return a + + +@intersection_sets.register(ImageSet, Set) +def _(self, other): + from sympy.solvers.diophantine import diophantine + + # Only handle the straight-forward univariate case + if (len(self.lamda.variables) > 1 + or self.lamda.signature != self.lamda.variables): + return None + base_set = self.base_sets[0] + + # Intersection between ImageSets with Integers as base set + # For {f(n) : n in Integers} & {g(m) : m in Integers} we solve the + # diophantine equations f(n)=g(m). + # If the solutions for n are {h(t) : t in Integers} then we return + # {f(h(t)) : t in integers}. + # If the solutions for n are {n_1, n_2, ..., n_k} then we return + # {f(n_i) : 1 <= i <= k}. + if base_set is S.Integers: + gm = None + if isinstance(other, ImageSet) and other.base_sets == (S.Integers,): + gm = other.lamda.expr + var = other.lamda.variables[0] + # Symbol of second ImageSet lambda must be distinct from first + m = Dummy('m') + gm = gm.subs(var, m) + elif other is S.Integers: + m = gm = Dummy('m') + if gm is not None: + fn = self.lamda.expr + n = self.lamda.variables[0] + try: + solns = list(diophantine(fn - gm, syms=(n, m), permute=True)) + except (TypeError, NotImplementedError): + # TypeError if equation not polynomial with rational coeff. + # NotImplementedError if correct format but no solver. + return + # 3 cases are possible for solns: + # - empty set, + # - one or more parametric (infinite) solutions, + # - a finite number of (non-parametric) solution couples. + # Among those, there is one type of solution set that is + # not helpful here: multiple parametric solutions. + if len(solns) == 0: + return S.EmptySet + elif any(s.free_symbols for tupl in solns for s in tupl): + if len(solns) == 1: + soln, solm = solns[0] + (t,) = soln.free_symbols + expr = fn.subs(n, soln.subs(t, n)).expand() + return imageset(Lambda(n, expr), S.Integers) + else: + return + else: + return FiniteSet(*(fn.subs(n, s[0]) for s in solns)) + + if other == S.Reals: + from sympy.solvers.solvers import denoms, solve_linear + + def _solution_union(exprs, sym): + # return a union of linear solutions to i in expr; + # if i cannot be solved, use a ConditionSet for solution + sols = [] + for i in exprs: + x, xis = solve_linear(i, 0, [sym]) + if x == sym: + sols.append(FiniteSet(xis)) + else: + sols.append(ConditionSet(sym, Eq(i, 0))) + return Union(*sols) + + f = self.lamda.expr + n = self.lamda.variables[0] + + n_ = Dummy(n.name, real=True) + f_ = f.subs(n, n_) + + re, im = f_.as_real_imag() + im = expand_complex(im) + + re = re.subs(n_, n) + im = im.subs(n_, n) + ifree = im.free_symbols + lam = Lambda(n, re) + if im.is_zero: + # allow re-evaluation + # of self in this case to make + # the result canonical + pass + elif im.is_zero is False: + return S.EmptySet + elif ifree != {n}: + return None + else: + # univarite imaginary part in same variable; + # use numer instead of as_numer_denom to keep + # this as fast as possible while still handling + # simple cases + base_set &= _solution_union( + Mul.make_args(numer(im)), n) + # exclude values that make denominators 0 + base_set -= _solution_union(denoms(f), n) + return imageset(lam, base_set) + + elif isinstance(other, Interval): + from sympy.solvers.solveset import (invert_real, invert_complex, + solveset) + + f = self.lamda.expr + n = self.lamda.variables[0] + new_inf, new_sup = None, None + new_lopen, new_ropen = other.left_open, other.right_open + + if f.is_real: + inverter = invert_real + else: + inverter = invert_complex + + g1, h1 = inverter(f, other.inf, n) + g2, h2 = inverter(f, other.sup, n) + + if all(isinstance(i, FiniteSet) for i in (h1, h2)): + if g1 == n: + if len(h1) == 1: + new_inf = h1.args[0] + if g2 == n: + if len(h2) == 1: + new_sup = h2.args[0] + # TODO: Design a technique to handle multiple-inverse + # functions + + # Any of the new boundary values cannot be determined + if any(i is None for i in (new_sup, new_inf)): + return + + + range_set = S.EmptySet + + if all(i.is_real for i in (new_sup, new_inf)): + # this assumes continuity of underlying function + # however fixes the case when it is decreasing + if new_inf > new_sup: + new_inf, new_sup = new_sup, new_inf + new_interval = Interval(new_inf, new_sup, new_lopen, new_ropen) + range_set = base_set.intersect(new_interval) + else: + if other.is_subset(S.Reals): + solutions = solveset(f, n, S.Reals) + if not isinstance(range_set, (ImageSet, ConditionSet)): + range_set = solutions.intersect(other) + else: + return + + if range_set is S.EmptySet: + return S.EmptySet + elif isinstance(range_set, Range) and range_set.size is not S.Infinity: + range_set = FiniteSet(*list(range_set)) + + if range_set is not None: + return imageset(Lambda(n, f), range_set) + return + else: + return + + +@intersection_sets.register(ProductSet, ProductSet) +def _(a, b): + if len(b.args) != len(a.args): + return S.EmptySet + return ProductSet(*(i.intersect(j) for i, j in zip(a.sets, b.sets))) + + +@intersection_sets.register(Interval, Interval) +def _(a, b): + # handle (-oo, oo) + infty = S.NegativeInfinity, S.Infinity + if a == Interval(*infty): + l, r = a.left, a.right + if l.is_real or l in infty or r.is_real or r in infty: + return b + + # We can't intersect [0,3] with [x,6] -- we don't know if x>0 or x<0 + if not a._is_comparable(b): + return None + + empty = False + + if a.start <= b.end and b.start <= a.end: + # Get topology right. + if a.start < b.start: + start = b.start + left_open = b.left_open + elif a.start > b.start: + start = a.start + left_open = a.left_open + else: + start = a.start + if not _aresame(a.start, b.start): + # For example Integer(2) != Float(2) + # Prefer the Float boundary because Floats should be + # contagious in calculations. + if b.start.has(Float) and not a.start.has(Float): + start = b.start + elif a.start.has(Float) and not b.start.has(Float): + start = a.start + else: + #this is to ensure that if Eq(a.start, b.start) but + #type(a.start) != type(b.start) the order of a and b + #does not matter for the result + start = list(ordered([a,b]))[0].start + left_open = a.left_open or b.left_open + + if a.end < b.end: + end = a.end + right_open = a.right_open + elif a.end > b.end: + end = b.end + right_open = b.right_open + else: + # see above for logic with start + end = a.end + if not _aresame(a.end, b.end): + if b.end.has(Float) and not a.end.has(Float): + end = b.end + elif a.end.has(Float) and not b.end.has(Float): + end = a.end + else: + end = list(ordered([a,b]))[0].end + right_open = a.right_open or b.right_open + + if end - start == 0 and (left_open or right_open): + empty = True + else: + empty = True + + if empty: + return S.EmptySet + + return Interval(start, end, left_open, right_open) + +@intersection_sets.register(EmptySet, Set) +def _(a, b): + return S.EmptySet + +@intersection_sets.register(UniversalSet, Set) +def _(a, b): + return b + +@intersection_sets.register(FiniteSet, FiniteSet) +def _(a, b): + return FiniteSet(*(a._elements & b._elements)) + +@intersection_sets.register(FiniteSet, Set) +def _(a, b): + try: + return FiniteSet(*[el for el in a if el in b]) + except TypeError: + return None # could not evaluate `el in b` due to symbolic ranges. + +@intersection_sets.register(Set, Set) +def _(a, b): + return None + +@intersection_sets.register(Integers, Rationals) +def _(a, b): + return a + +@intersection_sets.register(Naturals, Rationals) +def _(a, b): + return a + +@intersection_sets.register(Rationals, Reals) +def _(a, b): + return a + +def _intlike_interval(a, b): + try: + if b._inf is S.NegativeInfinity and b._sup is S.Infinity: + return a + s = Range(max(a.inf, ceiling(b.left)), floor(b.right) + 1) + return intersection_sets(s, b) # take out endpoints if open interval + except ValueError: + return None + +@intersection_sets.register(Integers, Interval) +def _(a, b): + return _intlike_interval(a, b) + +@intersection_sets.register(Naturals, Interval) +def _(a, b): + return _intlike_interval(a, b) diff --git a/lib/python3.10/site-packages/sympy/sets/handlers/issubset.py b/lib/python3.10/site-packages/sympy/sets/handlers/issubset.py new file mode 100644 index 0000000000000000000000000000000000000000..cc23e8bf56f1743cd7f08452dd09a0acf981f5da --- /dev/null +++ b/lib/python3.10/site-packages/sympy/sets/handlers/issubset.py @@ -0,0 +1,144 @@ +from sympy.core.singleton import S +from sympy.core.symbol import Symbol +from sympy.core.logic import fuzzy_and, fuzzy_bool, fuzzy_not, fuzzy_or +from sympy.core.relational import Eq +from sympy.sets.sets import FiniteSet, Interval, Set, Union, ProductSet +from sympy.sets.fancysets import Complexes, Reals, Range, Rationals +from sympy.multipledispatch import Dispatcher + + +_inf_sets = [S.Naturals, S.Naturals0, S.Integers, S.Rationals, S.Reals, S.Complexes] + + +is_subset_sets = Dispatcher('is_subset_sets') + + +@is_subset_sets.register(Set, Set) +def _(a, b): + return None + +@is_subset_sets.register(Interval, Interval) +def _(a, b): + # This is correct but can be made more comprehensive... + if fuzzy_bool(a.start < b.start): + return False + if fuzzy_bool(a.end > b.end): + return False + if (b.left_open and not a.left_open and fuzzy_bool(Eq(a.start, b.start))): + return False + if (b.right_open and not a.right_open and fuzzy_bool(Eq(a.end, b.end))): + return False + +@is_subset_sets.register(Interval, FiniteSet) +def _(a_interval, b_fs): + # An Interval can only be a subset of a finite set if it is finite + # which can only happen if it has zero measure. + if fuzzy_not(a_interval.measure.is_zero): + return False + +@is_subset_sets.register(Interval, Union) +def _(a_interval, b_u): + if all(isinstance(s, (Interval, FiniteSet)) for s in b_u.args): + intervals = [s for s in b_u.args if isinstance(s, Interval)] + if all(fuzzy_bool(a_interval.start < s.start) for s in intervals): + return False + if all(fuzzy_bool(a_interval.end > s.end) for s in intervals): + return False + if a_interval.measure.is_nonzero: + no_overlap = lambda s1, s2: fuzzy_or([ + fuzzy_bool(s1.end <= s2.start), + fuzzy_bool(s1.start >= s2.end), + ]) + if all(no_overlap(s, a_interval) for s in intervals): + return False + +@is_subset_sets.register(Range, Range) +def _(a, b): + if a.step == b.step == 1: + return fuzzy_and([fuzzy_bool(a.start >= b.start), + fuzzy_bool(a.stop <= b.stop)]) + +@is_subset_sets.register(Range, Interval) +def _(a_range, b_interval): + if a_range.step.is_positive: + if b_interval.left_open and a_range.inf.is_finite: + cond_left = a_range.inf > b_interval.left + else: + cond_left = a_range.inf >= b_interval.left + if b_interval.right_open and a_range.sup.is_finite: + cond_right = a_range.sup < b_interval.right + else: + cond_right = a_range.sup <= b_interval.right + return fuzzy_and([cond_left, cond_right]) + +@is_subset_sets.register(Range, FiniteSet) +def _(a_range, b_finiteset): + try: + a_size = a_range.size + except ValueError: + # symbolic Range of unknown size + return None + if a_size > len(b_finiteset): + return False + elif any(arg.has(Symbol) for arg in a_range.args): + return fuzzy_and(b_finiteset.contains(x) for x in a_range) + else: + # Checking A \ B == EmptySet is more efficient than repeated naive + # membership checks on an arbitrary FiniteSet. + a_set = set(a_range) + b_remaining = len(b_finiteset) + # Symbolic expressions and numbers of unknown type (integer or not) are + # all counted as "candidates", i.e. *potentially* matching some a in + # a_range. + cnt_candidate = 0 + for b in b_finiteset: + if b.is_Integer: + a_set.discard(b) + elif fuzzy_not(b.is_integer): + pass + else: + cnt_candidate += 1 + b_remaining -= 1 + if len(a_set) > b_remaining + cnt_candidate: + return False + if len(a_set) == 0: + return True + return None + +@is_subset_sets.register(Interval, Range) +def _(a_interval, b_range): + if a_interval.measure.is_extended_nonzero: + return False + +@is_subset_sets.register(Interval, Rationals) +def _(a_interval, b_rationals): + if a_interval.measure.is_extended_nonzero: + return False + +@is_subset_sets.register(Range, Complexes) +def _(a, b): + return True + +@is_subset_sets.register(Complexes, Interval) +def _(a, b): + return False + +@is_subset_sets.register(Complexes, Range) +def _(a, b): + return False + +@is_subset_sets.register(Complexes, Rationals) +def _(a, b): + return False + +@is_subset_sets.register(Rationals, Reals) +def _(a, b): + return True + +@is_subset_sets.register(Rationals, Range) +def _(a, b): + return False + +@is_subset_sets.register(ProductSet, FiniteSet) +def _(a_ps, b_fs): + return fuzzy_and(b_fs.contains(x) for x in a_ps) diff --git a/lib/python3.10/site-packages/sympy/sets/handlers/mul.py b/lib/python3.10/site-packages/sympy/sets/handlers/mul.py new file mode 100644 index 0000000000000000000000000000000000000000..0dedc8068b7973fd4cb6fbf2854e5fa671d188de --- /dev/null +++ b/lib/python3.10/site-packages/sympy/sets/handlers/mul.py @@ -0,0 +1,79 @@ +from sympy.core import Basic, Expr +from sympy.core.numbers import oo +from sympy.core.symbol import symbols +from sympy.multipledispatch import Dispatcher +from sympy.sets.setexpr import set_mul +from sympy.sets.sets import Interval, Set + + +_x, _y = symbols("x y") + + +_set_mul = Dispatcher('_set_mul') +_set_div = Dispatcher('_set_div') + + +@_set_mul.register(Basic, Basic) +def _(x, y): + return None + +@_set_mul.register(Set, Set) +def _(x, y): + return None + +@_set_mul.register(Expr, Expr) +def _(x, y): + return x*y + +@_set_mul.register(Interval, Interval) +def _(x, y): + """ + Multiplications in interval arithmetic + https://en.wikipedia.org/wiki/Interval_arithmetic + """ + # TODO: some intervals containing 0 and oo will fail as 0*oo returns nan. + comvals = ( + (x.start * y.start, bool(x.left_open or y.left_open)), + (x.start * y.end, bool(x.left_open or y.right_open)), + (x.end * y.start, bool(x.right_open or y.left_open)), + (x.end * y.end, bool(x.right_open or y.right_open)), + ) + # TODO: handle symbolic intervals + minval, minopen = min(comvals) + maxval, maxopen = max(comvals) + return Interval( + minval, + maxval, + minopen, + maxopen + ) + +@_set_div.register(Basic, Basic) +def _(x, y): + return None + +@_set_div.register(Expr, Expr) +def _(x, y): + return x/y + +@_set_div.register(Set, Set) +def _(x, y): + return None + +@_set_div.register(Interval, Interval) +def _(x, y): + """ + Divisions in interval arithmetic + https://en.wikipedia.org/wiki/Interval_arithmetic + """ + if (y.start*y.end).is_negative: + return Interval(-oo, oo) + if y.start == 0: + s2 = oo + else: + s2 = 1/y.start + if y.end == 0: + s1 = -oo + else: + s1 = 1/y.end + return set_mul(x, Interval(s1, s2, y.right_open, y.left_open)) diff --git a/lib/python3.10/site-packages/sympy/sets/handlers/power.py b/lib/python3.10/site-packages/sympy/sets/handlers/power.py new file mode 100644 index 0000000000000000000000000000000000000000..3cad4ee49ab27770143bc121d1fbcd024bf01548 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/sets/handlers/power.py @@ -0,0 +1,107 @@ +from sympy.core import Basic, Expr +from sympy.core.function import Lambda +from sympy.core.numbers import oo, Infinity, NegativeInfinity, Zero, Integer +from sympy.core.singleton import S +from sympy.core.symbol import symbols +from sympy.functions.elementary.miscellaneous import (Max, Min) +from sympy.sets.fancysets import ImageSet +from sympy.sets.setexpr import set_div +from sympy.sets.sets import Set, Interval, FiniteSet, Union +from sympy.multipledispatch import Dispatcher + + +_x, _y = symbols("x y") + + +_set_pow = Dispatcher('_set_pow') + + +@_set_pow.register(Basic, Basic) +def _(x, y): + return None + +@_set_pow.register(Set, Set) +def _(x, y): + return ImageSet(Lambda((_x, _y), (_x ** _y)), x, y) + +@_set_pow.register(Expr, Expr) +def _(x, y): + return x**y + +@_set_pow.register(Interval, Zero) +def _(x, z): + return FiniteSet(S.One) + +@_set_pow.register(Interval, Integer) +def _(x, exponent): + """ + Powers in interval arithmetic + https://en.wikipedia.org/wiki/Interval_arithmetic + """ + s1 = x.start**exponent + s2 = x.end**exponent + if ((s2 > s1) if exponent > 0 else (x.end > -x.start)) == True: + left_open = x.left_open + right_open = x.right_open + # TODO: handle unevaluated condition. + sleft = s2 + else: + # TODO: `s2 > s1` could be unevaluated. + left_open = x.right_open + right_open = x.left_open + sleft = s1 + + if x.start.is_positive: + return Interval( + Min(s1, s2), + Max(s1, s2), left_open, right_open) + elif x.end.is_negative: + return Interval( + Min(s1, s2), + Max(s1, s2), left_open, right_open) + + # Case where x.start < 0 and x.end > 0: + if exponent.is_odd: + if exponent.is_negative: + if x.start.is_zero: + return Interval(s2, oo, x.right_open) + if x.end.is_zero: + return Interval(-oo, s1, True, x.left_open) + return Union(Interval(-oo, s1, True, x.left_open), Interval(s2, oo, x.right_open)) + else: + return Interval(s1, s2, x.left_open, x.right_open) + elif exponent.is_even: + if exponent.is_negative: + if x.start.is_zero: + return Interval(s2, oo, x.right_open) + if x.end.is_zero: + return Interval(s1, oo, x.left_open) + return Interval(0, oo) + else: + return Interval(S.Zero, sleft, S.Zero not in x, left_open) + +@_set_pow.register(Interval, Infinity) +def _(b, e): + # TODO: add logic for open intervals? + if b.start.is_nonnegative: + if b.end < 1: + return FiniteSet(S.Zero) + if b.start > 1: + return FiniteSet(S.Infinity) + return Interval(0, oo) + elif b.end.is_negative: + if b.start > -1: + return FiniteSet(S.Zero) + if b.end < -1: + return FiniteSet(-oo, oo) + return Interval(-oo, oo) + else: + if b.start > -1: + if b.end < 1: + return FiniteSet(S.Zero) + return Interval(0, oo) + return Interval(-oo, oo) + +@_set_pow.register(Interval, NegativeInfinity) +def _(b, e): + return _set_pow(set_div(S.One, b), oo) diff --git a/lib/python3.10/site-packages/sympy/sets/handlers/union.py b/lib/python3.10/site-packages/sympy/sets/handlers/union.py new file mode 100644 index 0000000000000000000000000000000000000000..75d867b49969ae2aeea76155dbaae7e05c1a6847 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/sets/handlers/union.py @@ -0,0 +1,147 @@ +from sympy.core.singleton import S +from sympy.core.sympify import sympify +from sympy.functions.elementary.miscellaneous import Min, Max +from sympy.sets.sets import (EmptySet, FiniteSet, Intersection, + Interval, ProductSet, Set, Union, UniversalSet) +from sympy.sets.fancysets import (ComplexRegion, Naturals, Naturals0, + Integers, Rationals, Reals) +from sympy.multipledispatch import Dispatcher + + +union_sets = Dispatcher('union_sets') + + +@union_sets.register(Naturals0, Naturals) +def _(a, b): + return a + +@union_sets.register(Rationals, Naturals) +def _(a, b): + return a + +@union_sets.register(Rationals, Naturals0) +def _(a, b): + return a + +@union_sets.register(Reals, Naturals) +def _(a, b): + return a + +@union_sets.register(Reals, Naturals0) +def _(a, b): + return a + +@union_sets.register(Reals, Rationals) +def _(a, b): + return a + +@union_sets.register(Integers, Set) +def _(a, b): + intersect = Intersection(a, b) + if intersect == a: + return b + elif intersect == b: + return a + +@union_sets.register(ComplexRegion, Set) +def _(a, b): + if b.is_subset(S.Reals): + # treat a subset of reals as a complex region + b = ComplexRegion.from_real(b) + + if b.is_ComplexRegion: + # a in rectangular form + if (not a.polar) and (not b.polar): + return ComplexRegion(Union(a.sets, b.sets)) + # a in polar form + elif a.polar and b.polar: + return ComplexRegion(Union(a.sets, b.sets), polar=True) + return None + +@union_sets.register(EmptySet, Set) +def _(a, b): + return b + + +@union_sets.register(UniversalSet, Set) +def _(a, b): + return a + +@union_sets.register(ProductSet, ProductSet) +def _(a, b): + if b.is_subset(a): + return a + if len(b.sets) != len(a.sets): + return None + if len(a.sets) == 2: + a1, a2 = a.sets + b1, b2 = b.sets + if a1 == b1: + return a1 * Union(a2, b2) + if a2 == b2: + return Union(a1, b1) * a2 + return None + +@union_sets.register(ProductSet, Set) +def _(a, b): + if b.is_subset(a): + return a + return None + +@union_sets.register(Interval, Interval) +def _(a, b): + if a._is_comparable(b): + # Non-overlapping intervals + end = Min(a.end, b.end) + start = Max(a.start, b.start) + if (end < start or + (end == start and (end not in a and end not in b))): + return None + else: + start = Min(a.start, b.start) + end = Max(a.end, b.end) + + left_open = ((a.start != start or a.left_open) and + (b.start != start or b.left_open)) + right_open = ((a.end != end or a.right_open) and + (b.end != end or b.right_open)) + return Interval(start, end, left_open, right_open) + +@union_sets.register(Interval, UniversalSet) +def _(a, b): + return S.UniversalSet + +@union_sets.register(Interval, Set) +def _(a, b): + # If I have open end points and these endpoints are contained in b + # But only in case, when endpoints are finite. Because + # interval does not contain oo or -oo. + open_left_in_b_and_finite = (a.left_open and + sympify(b.contains(a.start)) is S.true and + a.start.is_finite) + open_right_in_b_and_finite = (a.right_open and + sympify(b.contains(a.end)) is S.true and + a.end.is_finite) + if open_left_in_b_and_finite or open_right_in_b_and_finite: + # Fill in my end points and return + open_left = a.left_open and a.start not in b + open_right = a.right_open and a.end not in b + new_a = Interval(a.start, a.end, open_left, open_right) + return {new_a, b} + return None + +@union_sets.register(FiniteSet, FiniteSet) +def _(a, b): + return FiniteSet(*(a._elements | b._elements)) + +@union_sets.register(FiniteSet, Set) +def _(a, b): + # If `b` set contains one of my elements, remove it from `a` + if any(b.contains(x) == True for x in a): + return { + FiniteSet(*[x for x in a if b.contains(x) != True]), b} + return None + +@union_sets.register(Set, Set) +def _(a, b): + return None diff --git a/lib/python3.10/site-packages/sympy/sets/tests/__init__.py b/lib/python3.10/site-packages/sympy/sets/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/lib/python3.10/site-packages/sympy/sets/tests/__pycache__/__init__.cpython-310.pyc b/lib/python3.10/site-packages/sympy/sets/tests/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6cf4b7045f24e12a71cacefc36a967b0c989a5d8 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/sets/tests/__pycache__/__init__.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/sets/tests/__pycache__/test_conditionset.cpython-310.pyc b/lib/python3.10/site-packages/sympy/sets/tests/__pycache__/test_conditionset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c616a12cf36c10d08c4f3d73279ee3c5cc2fbf04 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/sets/tests/__pycache__/test_conditionset.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/sets/tests/__pycache__/test_contains.cpython-310.pyc b/lib/python3.10/site-packages/sympy/sets/tests/__pycache__/test_contains.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0353b518cf0850709da7957d26a0c022ebc89ee7 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/sets/tests/__pycache__/test_contains.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/sets/tests/__pycache__/test_fancysets.cpython-310.pyc b/lib/python3.10/site-packages/sympy/sets/tests/__pycache__/test_fancysets.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..db3dcb302f9d151f29b2c863a9287e382063d4ae Binary files /dev/null and b/lib/python3.10/site-packages/sympy/sets/tests/__pycache__/test_fancysets.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/sets/tests/__pycache__/test_ordinals.cpython-310.pyc b/lib/python3.10/site-packages/sympy/sets/tests/__pycache__/test_ordinals.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..10877603b182456aacc0eb73eac7e6d5db27e98e Binary files /dev/null and b/lib/python3.10/site-packages/sympy/sets/tests/__pycache__/test_ordinals.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/sets/tests/__pycache__/test_powerset.cpython-310.pyc b/lib/python3.10/site-packages/sympy/sets/tests/__pycache__/test_powerset.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cd2afd76865cf123fc0d253a705908175f804200 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/sets/tests/__pycache__/test_powerset.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/sets/tests/__pycache__/test_setexpr.cpython-310.pyc b/lib/python3.10/site-packages/sympy/sets/tests/__pycache__/test_setexpr.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..33d953a5e1d56848ae82a3c17a456eecf782bd5a Binary files /dev/null and b/lib/python3.10/site-packages/sympy/sets/tests/__pycache__/test_setexpr.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/sets/tests/__pycache__/test_sets.cpython-310.pyc b/lib/python3.10/site-packages/sympy/sets/tests/__pycache__/test_sets.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..96398a1392a084948792bcd87ec74a9024623878 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/sets/tests/__pycache__/test_sets.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/sets/tests/test_conditionset.py b/lib/python3.10/site-packages/sympy/sets/tests/test_conditionset.py new file mode 100644 index 0000000000000000000000000000000000000000..4818246f306afd46a09a2cbea1faab858a9e7806 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/sets/tests/test_conditionset.py @@ -0,0 +1,294 @@ +from sympy.core.expr import unchanged +from sympy.sets import (ConditionSet, Intersection, FiniteSet, + EmptySet, Union, Contains, ImageSet) +from sympy.sets.sets import SetKind +from sympy.core.function import (Function, Lambda) +from sympy.core.mod import Mod +from sympy.core.kind import NumberKind +from sympy.core.numbers import (oo, pi) +from sympy.core.relational import (Eq, Ne) +from sympy.core.singleton import S +from sympy.core.symbol import (Symbol, symbols) +from sympy.functions.elementary.complexes import Abs +from sympy.functions.elementary.trigonometric import (asin, sin) +from sympy.logic.boolalg import And +from sympy.matrices.dense import Matrix +from sympy.matrices.expressions.matexpr import MatrixSymbol +from sympy.sets.sets import Interval +from sympy.testing.pytest import raises, warns_deprecated_sympy + + +w = Symbol('w') +x = Symbol('x') +y = Symbol('y') +z = Symbol('z') +f = Function('f') + + +def test_CondSet(): + sin_sols_principal = ConditionSet(x, Eq(sin(x), 0), + Interval(0, 2*pi, False, True)) + assert pi in sin_sols_principal + assert pi/2 not in sin_sols_principal + assert 3*pi not in sin_sols_principal + assert oo not in sin_sols_principal + assert 5 in ConditionSet(x, x**2 > 4, S.Reals) + assert 1 not in ConditionSet(x, x**2 > 4, S.Reals) + # in this case, 0 is not part of the base set so + # it can't be in any subset selected by the condition + assert 0 not in ConditionSet(x, y > 5, Interval(1, 7)) + # since 'in' requires a true/false, the following raises + # an error because the given value provides no information + # for the condition to evaluate (since the condition does + # not depend on the dummy symbol): the result is `y > 5`. + # In this case, ConditionSet is just acting like + # Piecewise((Interval(1, 7), y > 5), (S.EmptySet, True)). + raises(TypeError, lambda: 6 in ConditionSet(x, y > 5, + Interval(1, 7))) + + X = MatrixSymbol('X', 2, 2) + matrix_set = ConditionSet(X, Eq(X*Matrix([[1, 1], [1, 1]]), X)) + Y = Matrix([[0, 0], [0, 0]]) + assert matrix_set.contains(Y).doit() is S.true + Z = Matrix([[1, 2], [3, 4]]) + assert matrix_set.contains(Z).doit() is S.false + + assert isinstance(ConditionSet(x, x < 1, {x, y}).base_set, + FiniteSet) + raises(TypeError, lambda: ConditionSet(x, x + 1, {x, y})) + raises(TypeError, lambda: ConditionSet(x, x, 1)) + + I = S.Integers + U = S.UniversalSet + C = ConditionSet + assert C(x, False, I) is S.EmptySet + assert C(x, True, I) is I + assert C(x, x < 1, C(x, x < 2, I) + ) == C(x, (x < 1) & (x < 2), I) + assert C(y, y < 1, C(x, y < 2, I) + ) == C(x, (x < 1) & (y < 2), I), C(y, y < 1, C(x, y < 2, I)) + assert C(y, y < 1, C(x, x < 2, I) + ) == C(y, (y < 1) & (y < 2), I) + assert C(y, y < 1, C(x, y < x, I) + ) == C(x, (x < 1) & (y < x), I) + assert unchanged(C, y, x < 1, C(x, y < x, I)) + assert ConditionSet(x, x < 1).base_set is U + # arg checking is not done at instantiation but this + # will raise an error when containment is tested + assert ConditionSet((x,), x < 1).base_set is U + + c = ConditionSet((x, y), x < y, I**2) + assert (1, 2) in c + assert (1, pi) not in c + + raises(TypeError, lambda: C(x, x > 1, C((x, y), x > 1, I**2))) + # signature mismatch since only 3 args are accepted + raises(TypeError, lambda: C((x, y), x + y < 2, U, U)) + + +def test_CondSet_intersect(): + input_conditionset = ConditionSet(x, x**2 > 4, Interval(1, 4, False, + False)) + other_domain = Interval(0, 3, False, False) + output_conditionset = ConditionSet(x, x**2 > 4, Interval( + 1, 3, False, False)) + assert Intersection(input_conditionset, other_domain + ) == output_conditionset + + +def test_issue_9849(): + assert ConditionSet(x, Eq(x, x), S.Naturals + ) is S.Naturals + assert ConditionSet(x, Eq(Abs(sin(x)), -1), S.Naturals + ) == S.EmptySet + + +def test_simplified_FiniteSet_in_CondSet(): + assert ConditionSet(x, And(x < 1, x > -3), FiniteSet(0, 1, 2) + ) == FiniteSet(0) + assert ConditionSet(x, x < 0, FiniteSet(0, 1, 2)) == EmptySet + assert ConditionSet(x, And(x < -3), EmptySet) == EmptySet + y = Symbol('y') + assert (ConditionSet(x, And(x > 0), FiniteSet(-1, 0, 1, y)) == + Union(FiniteSet(1), ConditionSet(x, And(x > 0), FiniteSet(y)))) + assert (ConditionSet(x, Eq(Mod(x, 3), 1), FiniteSet(1, 4, 2, y)) == + Union(FiniteSet(1, 4), ConditionSet(x, Eq(Mod(x, 3), 1), + FiniteSet(y)))) + + +def test_free_symbols(): + assert ConditionSet(x, Eq(y, 0), FiniteSet(z) + ).free_symbols == {y, z} + assert ConditionSet(x, Eq(x, 0), FiniteSet(z) + ).free_symbols == {z} + assert ConditionSet(x, Eq(x, 0), FiniteSet(x, z) + ).free_symbols == {x, z} + assert ConditionSet(x, Eq(x, 0), ImageSet(Lambda(y, y**2), + S.Integers)).free_symbols == set() + + +def test_bound_symbols(): + assert ConditionSet(x, Eq(y, 0), FiniteSet(z) + ).bound_symbols == [x] + assert ConditionSet(x, Eq(x, 0), FiniteSet(x, y) + ).bound_symbols == [x] + assert ConditionSet(x, x < 10, ImageSet(Lambda(y, y**2), S.Integers) + ).bound_symbols == [x] + assert ConditionSet(x, x < 10, ConditionSet(y, y > 1, S.Integers) + ).bound_symbols == [x] + + +def test_as_dummy(): + _0, _1 = symbols('_0 _1') + assert ConditionSet(x, x < 1, Interval(y, oo) + ).as_dummy() == ConditionSet(_0, _0 < 1, Interval(y, oo)) + assert ConditionSet(x, x < 1, Interval(x, oo) + ).as_dummy() == ConditionSet(_0, _0 < 1, Interval(x, oo)) + assert ConditionSet(x, x < 1, ImageSet(Lambda(y, y**2), S.Integers) + ).as_dummy() == ConditionSet( + _0, _0 < 1, ImageSet(Lambda(_0, _0**2), S.Integers)) + e = ConditionSet((x, y), x <= y, S.Reals**2) + assert e.bound_symbols == [x, y] + assert e.as_dummy() == ConditionSet((_0, _1), _0 <= _1, S.Reals**2) + assert e.as_dummy() == ConditionSet((y, x), y <= x, S.Reals**2 + ).as_dummy() + + +def test_subs_CondSet(): + s = FiniteSet(z, y) + c = ConditionSet(x, x < 2, s) + assert c.subs(x, y) == c + assert c.subs(z, y) == ConditionSet(x, x < 2, FiniteSet(y)) + assert c.xreplace({x: y}) == ConditionSet(y, y < 2, s) + + assert ConditionSet(x, x < y, s + ).subs(y, w) == ConditionSet(x, x < w, s.subs(y, w)) + # if the user uses assumptions that cause the condition + # to evaluate, that can't be helped from SymPy's end + n = Symbol('n', negative=True) + assert ConditionSet(n, 0 < n, S.Integers) is S.EmptySet + p = Symbol('p', positive=True) + assert ConditionSet(n, n < y, S.Integers + ).subs(n, x) == ConditionSet(n, n < y, S.Integers) + raises(ValueError, lambda: ConditionSet( + x + 1, x < 1, S.Integers)) + assert ConditionSet( + p, n < x, Interval(-5, 5)).subs(x, p) == Interval(-5, 5), ConditionSet( + p, n < x, Interval(-5, 5)).subs(x, p) + assert ConditionSet( + n, n < x, Interval(-oo, 0)).subs(x, p + ) == Interval(-oo, 0) + + assert ConditionSet(f(x), f(x) < 1, {w, z} + ).subs(f(x), y) == ConditionSet(f(x), f(x) < 1, {w, z}) + + # issue 17341 + k = Symbol('k') + img1 = ImageSet(Lambda(k, 2*k*pi + asin(y)), S.Integers) + img2 = ImageSet(Lambda(k, 2*k*pi + asin(S.One/3)), S.Integers) + assert ConditionSet(x, Contains( + y, Interval(-1,1)), img1).subs(y, S.One/3).dummy_eq(img2) + + assert (0, 1) in ConditionSet((x, y), x + y < 3, S.Integers**2) + + raises(TypeError, lambda: ConditionSet(n, n < -10, Interval(0, 10))) + + +def test_subs_CondSet_tebr(): + with warns_deprecated_sympy(): + assert ConditionSet((x, y), {x + 1, x + y}, S.Reals**2) == \ + ConditionSet((x, y), Eq(x + 1, 0) & Eq(x + y, 0), S.Reals**2) + + +def test_dummy_eq(): + C = ConditionSet + I = S.Integers + c = C(x, x < 1, I) + assert c.dummy_eq(C(y, y < 1, I)) + assert c.dummy_eq(1) == False + assert c.dummy_eq(C(x, x < 1, S.Reals)) == False + + c1 = ConditionSet((x, y), Eq(x + 1, 0) & Eq(x + y, 0), S.Reals**2) + c2 = ConditionSet((x, y), Eq(x + 1, 0) & Eq(x + y, 0), S.Reals**2) + c3 = ConditionSet((x, y), Eq(x + 1, 0) & Eq(x + y, 0), S.Complexes**2) + assert c1.dummy_eq(c2) + assert c1.dummy_eq(c3) is False + assert c.dummy_eq(c1) is False + assert c1.dummy_eq(c) is False + + # issue 19496 + m = Symbol('m') + n = Symbol('n') + a = Symbol('a') + d1 = ImageSet(Lambda(m, m*pi), S.Integers) + d2 = ImageSet(Lambda(n, n*pi), S.Integers) + c1 = ConditionSet(x, Ne(a, 0), d1) + c2 = ConditionSet(x, Ne(a, 0), d2) + assert c1.dummy_eq(c2) + + +def test_contains(): + assert 6 in ConditionSet(x, x > 5, Interval(1, 7)) + assert (8 in ConditionSet(x, y > 5, Interval(1, 7))) is False + # `in` should give True or False; in this case there is not + # enough information for that result + raises(TypeError, + lambda: 6 in ConditionSet(x, y > 5, Interval(1, 7))) + # here, there is enough information but the comparison is + # not defined + raises(TypeError, lambda: 0 in ConditionSet(x, 1/x >= 0, S.Reals)) + assert ConditionSet(x, y > 5, Interval(1, 7) + ).contains(6) == (y > 5) + assert ConditionSet(x, y > 5, Interval(1, 7) + ).contains(8) is S.false + assert ConditionSet(x, y > 5, Interval(1, 7) + ).contains(w) == And(Contains(w, Interval(1, 7)), y > 5) + # This returns an unevaluated Contains object + # because 1/0 should not be defined for 1 and 0 in the context of + # reals. + assert ConditionSet(x, 1/x >= 0, S.Reals).contains(0) == \ + Contains(0, ConditionSet(x, 1/x >= 0, S.Reals), evaluate=False) + c = ConditionSet((x, y), x + y > 1, S.Integers**2) + assert not c.contains(1) + assert c.contains((2, 1)) + assert not c.contains((0, 1)) + c = ConditionSet((w, (x, y)), w + x + y > 1, S.Integers*S.Integers**2) + assert not c.contains(1) + assert not c.contains((1, 2)) + assert not c.contains(((1, 2), 3)) + assert not c.contains(((1, 2), (3, 4))) + assert c.contains((1, (3, 4))) + + +def test_as_relational(): + assert ConditionSet((x, y), x > 1, S.Integers**2).as_relational((x, y) + ) == (x > 1) & Contains(x, S.Integers) & Contains(y, S.Integers) + assert ConditionSet(x, x > 1, S.Integers).as_relational(x + ) == Contains(x, S.Integers) & (x > 1) + + +def test_flatten(): + """Tests whether there is basic denesting functionality""" + inner = ConditionSet(x, sin(x) + x > 0) + outer = ConditionSet(x, Contains(x, inner), S.Reals) + assert outer == ConditionSet(x, sin(x) + x > 0, S.Reals) + + inner = ConditionSet(y, sin(y) + y > 0) + outer = ConditionSet(x, Contains(y, inner), S.Reals) + assert outer != ConditionSet(x, sin(x) + x > 0, S.Reals) + + inner = ConditionSet(x, sin(x) + x > 0).intersect(Interval(-1, 1)) + outer = ConditionSet(x, Contains(x, inner), S.Reals) + assert outer == ConditionSet(x, sin(x) + x > 0, Interval(-1, 1)) + + +def test_duplicate(): + from sympy.core.function import BadSignatureError + # test coverage for line 95 in conditionset.py, check for duplicates in symbols + dup = symbols('a,a') + raises(BadSignatureError, lambda: ConditionSet(dup, x < 0)) + + +def test_SetKind_ConditionSet(): + assert ConditionSet(x, Eq(sin(x), 0), Interval(0, 2*pi)).kind is SetKind(NumberKind) + assert ConditionSet(x, x < 0).kind is SetKind(NumberKind) diff --git a/lib/python3.10/site-packages/sympy/sets/tests/test_contains.py b/lib/python3.10/site-packages/sympy/sets/tests/test_contains.py new file mode 100644 index 0000000000000000000000000000000000000000..bb6b98940946f98bf377aad6810f5b32eb6dd069 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/sets/tests/test_contains.py @@ -0,0 +1,52 @@ +from sympy.core.expr import unchanged +from sympy.core.numbers import oo +from sympy.core.relational import Eq +from sympy.core.singleton import S +from sympy.core.symbol import Symbol +from sympy.sets.contains import Contains +from sympy.sets.sets import (FiniteSet, Interval) +from sympy.testing.pytest import raises + + +def test_contains_basic(): + raises(TypeError, lambda: Contains(S.Integers, 1)) + assert Contains(2, S.Integers) is S.true + assert Contains(-2, S.Naturals) is S.false + + i = Symbol('i', integer=True) + assert Contains(i, S.Naturals) == Contains(i, S.Naturals, evaluate=False) + + +def test_issue_6194(): + x = Symbol('x') + assert unchanged(Contains, x, Interval(0, 1)) + assert Interval(0, 1).contains(x) == (S.Zero <= x) & (x <= 1) + assert Contains(x, FiniteSet(0)) != S.false + assert Contains(x, Interval(1, 1)) != S.false + assert Contains(x, S.Integers) != S.false + + +def test_issue_10326(): + assert Contains(oo, Interval(-oo, oo)) == False + assert Contains(-oo, Interval(-oo, oo)) == False + + +def test_binary_symbols(): + x = Symbol('x') + y = Symbol('y') + z = Symbol('z') + assert Contains(x, FiniteSet(y, Eq(z, True)) + ).binary_symbols == {y, z} + + +def test_as_set(): + x = Symbol('x') + y = Symbol('y') + assert Contains(x, FiniteSet(y)).as_set() == FiniteSet(y) + assert Contains(x, S.Integers).as_set() == S.Integers + assert Contains(x, S.Reals).as_set() == S.Reals + + +def test_type_error(): + # Pass in a parameter not of type "set" + raises(TypeError, lambda: Contains(2, None)) diff --git a/lib/python3.10/site-packages/sympy/sets/tests/test_fancysets.py b/lib/python3.10/site-packages/sympy/sets/tests/test_fancysets.py new file mode 100644 index 0000000000000000000000000000000000000000..b23c2a99fce0af5bfe7c667185465ee417de19ce --- /dev/null +++ b/lib/python3.10/site-packages/sympy/sets/tests/test_fancysets.py @@ -0,0 +1,1313 @@ + +from sympy.core.expr import unchanged +from sympy.sets.contains import Contains +from sympy.sets.fancysets import (ImageSet, Range, normalize_theta_set, + ComplexRegion) +from sympy.sets.sets import (FiniteSet, Interval, Union, imageset, + Intersection, ProductSet, SetKind) +from sympy.sets.conditionset import ConditionSet +from sympy.simplify.simplify import simplify +from sympy.core.basic import Basic +from sympy.core.containers import Tuple, TupleKind +from sympy.core.function import Lambda +from sympy.core.kind import NumberKind +from sympy.core.numbers import (I, Rational, oo, pi) +from sympy.core.relational import Eq +from sympy.core.singleton import S +from sympy.core.symbol import (Dummy, Symbol, symbols) +from sympy.functions.elementary.complexes import Abs +from sympy.functions.elementary.exponential import (exp, log) +from sympy.functions.elementary.integers import floor +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import (cos, sin, tan) +from sympy.logic.boolalg import And +from sympy.matrices.dense import eye +from sympy.testing.pytest import XFAIL, raises +from sympy.abc import x, y, t, z +from sympy.core.mod import Mod + +import itertools + + +def test_naturals(): + N = S.Naturals + assert 5 in N + assert -5 not in N + assert 5.5 not in N + ni = iter(N) + a, b, c, d = next(ni), next(ni), next(ni), next(ni) + assert (a, b, c, d) == (1, 2, 3, 4) + assert isinstance(a, Basic) + + assert N.intersect(Interval(-5, 5)) == Range(1, 6) + assert N.intersect(Interval(-5, 5, True, True)) == Range(1, 5) + + assert N.boundary == N + assert N.is_open == False + assert N.is_closed == True + + assert N.inf == 1 + assert N.sup is oo + assert not N.contains(oo) + for s in (S.Naturals0, S.Naturals): + assert s.intersection(S.Reals) is s + assert s.is_subset(S.Reals) + + assert N.as_relational(x) == And(Eq(floor(x), x), x >= 1, x < oo) + + +def test_naturals0(): + N = S.Naturals0 + assert 0 in N + assert -1 not in N + assert next(iter(N)) == 0 + assert not N.contains(oo) + assert N.contains(sin(x)) == Contains(sin(x), N) + + +def test_integers(): + Z = S.Integers + assert 5 in Z + assert -5 in Z + assert 5.5 not in Z + assert not Z.contains(oo) + assert not Z.contains(-oo) + + zi = iter(Z) + a, b, c, d = next(zi), next(zi), next(zi), next(zi) + assert (a, b, c, d) == (0, 1, -1, 2) + assert isinstance(a, Basic) + + assert Z.intersect(Interval(-5, 5)) == Range(-5, 6) + assert Z.intersect(Interval(-5, 5, True, True)) == Range(-4, 5) + assert Z.intersect(Interval(5, S.Infinity)) == Range(5, S.Infinity) + assert Z.intersect(Interval.Lopen(5, S.Infinity)) == Range(6, S.Infinity) + + assert Z.inf is -oo + assert Z.sup is oo + + assert Z.boundary == Z + assert Z.is_open == False + assert Z.is_closed == True + + assert Z.as_relational(x) == And(Eq(floor(x), x), -oo < x, x < oo) + + +def test_ImageSet(): + raises(ValueError, lambda: ImageSet(x, S.Integers)) + assert ImageSet(Lambda(x, 1), S.Integers) == FiniteSet(1) + assert ImageSet(Lambda(x, y), S.Integers) == {y} + assert ImageSet(Lambda(x, 1), S.EmptySet) == S.EmptySet + empty = Intersection(FiniteSet(log(2)/pi), S.Integers) + assert unchanged(ImageSet, Lambda(x, 1), empty) # issue #17471 + squares = ImageSet(Lambda(x, x**2), S.Naturals) + assert 4 in squares + assert 5 not in squares + assert FiniteSet(*range(10)).intersect(squares) == FiniteSet(1, 4, 9) + + assert 16 not in squares.intersect(Interval(0, 10)) + + si = iter(squares) + a, b, c, d = next(si), next(si), next(si), next(si) + assert (a, b, c, d) == (1, 4, 9, 16) + + harmonics = ImageSet(Lambda(x, 1/x), S.Naturals) + assert Rational(1, 5) in harmonics + assert Rational(.25) in harmonics + assert harmonics.contains(.25) == Contains( + 0.25, ImageSet(Lambda(x, 1/x), S.Naturals), evaluate=False) + assert Rational(.3) not in harmonics + assert (1, 2) not in harmonics + + assert harmonics.is_iterable + + assert imageset(x, -x, Interval(0, 1)) == Interval(-1, 0) + + assert ImageSet(Lambda(x, x**2), Interval(0, 2)).doit() == Interval(0, 4) + assert ImageSet(Lambda((x, y), 2*x), {4}, {3}).doit() == FiniteSet(8) + assert (ImageSet(Lambda((x, y), x+y), {1, 2, 3}, {10, 20, 30}).doit() == + FiniteSet(11, 12, 13, 21, 22, 23, 31, 32, 33)) + + c = Interval(1, 3) * Interval(1, 3) + assert Tuple(2, 6) in ImageSet(Lambda(((x, y),), (x, 2*y)), c) + assert Tuple(2, S.Half) in ImageSet(Lambda(((x, y),), (x, 1/y)), c) + assert Tuple(2, -2) not in ImageSet(Lambda(((x, y),), (x, y**2)), c) + assert Tuple(2, -2) in ImageSet(Lambda(((x, y),), (x, -2)), c) + c3 = ProductSet(Interval(3, 7), Interval(8, 11), Interval(5, 9)) + assert Tuple(8, 3, 9) in ImageSet(Lambda(((t, y, x),), (y, t, x)), c3) + assert Tuple(Rational(1, 8), 3, 9) in ImageSet(Lambda(((t, y, x),), (1/y, t, x)), c3) + assert 2/pi not in ImageSet(Lambda(((x, y),), 2/x), c) + assert 2/S(100) not in ImageSet(Lambda(((x, y),), 2/x), c) + assert Rational(2, 3) in ImageSet(Lambda(((x, y),), 2/x), c) + + S1 = imageset(lambda x, y: x + y, S.Integers, S.Naturals) + assert S1.base_pset == ProductSet(S.Integers, S.Naturals) + assert S1.base_sets == (S.Integers, S.Naturals) + + # Passing a set instead of a FiniteSet shouldn't raise + assert unchanged(ImageSet, Lambda(x, x**2), {1, 2, 3}) + + S2 = ImageSet(Lambda(((x, y),), x+y), {(1, 2), (3, 4)}) + assert 3 in S2.doit() + # FIXME: This doesn't yet work: + #assert 3 in S2 + assert S2._contains(3) is None + + raises(TypeError, lambda: ImageSet(Lambda(x, x**2), 1)) + + +def test_image_is_ImageSet(): + assert isinstance(imageset(x, sqrt(sin(x)), Range(5)), ImageSet) + + +def test_halfcircle(): + r, th = symbols('r, theta', real=True) + L = Lambda(((r, th),), (r*cos(th), r*sin(th))) + halfcircle = ImageSet(L, Interval(0, 1)*Interval(0, pi)) + + assert (1, 0) in halfcircle + assert (0, -1) not in halfcircle + assert (0, 0) in halfcircle + assert halfcircle._contains((r, 0)) is None + assert not halfcircle.is_iterable + + +@XFAIL +def test_halfcircle_fail(): + r, th = symbols('r, theta', real=True) + L = Lambda(((r, th),), (r*cos(th), r*sin(th))) + halfcircle = ImageSet(L, Interval(0, 1)*Interval(0, pi)) + assert (r, 2*pi) not in halfcircle + + +def test_ImageSet_iterator_not_injective(): + L = Lambda(x, x - x % 2) # produces 0, 2, 2, 4, 4, 6, 6, ... + evens = ImageSet(L, S.Naturals) + i = iter(evens) + # No repeats here + assert (next(i), next(i), next(i), next(i)) == (0, 2, 4, 6) + + +def test_inf_Range_len(): + raises(ValueError, lambda: len(Range(0, oo, 2))) + assert Range(0, oo, 2).size is S.Infinity + assert Range(0, -oo, -2).size is S.Infinity + assert Range(oo, 0, -2).size is S.Infinity + assert Range(-oo, 0, 2).size is S.Infinity + + +def test_Range_set(): + empty = Range(0) + + assert Range(5) == Range(0, 5) == Range(0, 5, 1) + + r = Range(10, 20, 2) + assert 12 in r + assert 8 not in r + assert 11 not in r + assert 30 not in r + + assert list(Range(0, 5)) == list(range(5)) + assert list(Range(5, 0, -1)) == list(range(5, 0, -1)) + + + assert Range(5, 15).sup == 14 + assert Range(5, 15).inf == 5 + assert Range(15, 5, -1).sup == 15 + assert Range(15, 5, -1).inf == 6 + assert Range(10, 67, 10).sup == 60 + assert Range(60, 7, -10).inf == 10 + + assert len(Range(10, 38, 10)) == 3 + + assert Range(0, 0, 5) == empty + assert Range(oo, oo, 1) == empty + assert Range(oo, 1, 1) == empty + assert Range(-oo, 1, -1) == empty + assert Range(1, oo, -1) == empty + assert Range(1, -oo, 1) == empty + assert Range(1, -4, oo) == empty + ip = symbols('ip', positive=True) + assert Range(0, ip, -1) == empty + assert Range(0, -ip, 1) == empty + assert Range(1, -4, -oo) == Range(1, 2) + assert Range(1, 4, oo) == Range(1, 2) + assert Range(-oo, oo).size == oo + assert Range(oo, -oo, -1).size == oo + raises(ValueError, lambda: Range(-oo, oo, 2)) + raises(ValueError, lambda: Range(x, pi, y)) + raises(ValueError, lambda: Range(x, y, 0)) + + assert 5 in Range(0, oo, 5) + assert -5 in Range(-oo, 0, 5) + assert oo not in Range(0, oo) + ni = symbols('ni', integer=False) + assert ni not in Range(oo) + u = symbols('u', integer=None) + assert Range(oo).contains(u) is not False + inf = symbols('inf', infinite=True) + assert inf not in Range(-oo, oo) + raises(ValueError, lambda: Range(0, oo, 2)[-1]) + raises(ValueError, lambda: Range(0, -oo, -2)[-1]) + assert Range(-oo, 1, 1)[-1] is S.Zero + assert Range(oo, 1, -1)[-1] == 2 + assert inf not in Range(oo) + assert Range(1, 10, 1)[-1] == 9 + assert all(i.is_Integer for i in Range(0, -1, 1)) + it = iter(Range(-oo, 0, 2)) + raises(TypeError, lambda: next(it)) + + assert empty.intersect(S.Integers) == empty + assert Range(-1, 10, 1).intersect(S.Complexes) == Range(-1, 10, 1) + assert Range(-1, 10, 1).intersect(S.Reals) == Range(-1, 10, 1) + assert Range(-1, 10, 1).intersect(S.Rationals) == Range(-1, 10, 1) + assert Range(-1, 10, 1).intersect(S.Integers) == Range(-1, 10, 1) + assert Range(-1, 10, 1).intersect(S.Naturals) == Range(1, 10, 1) + assert Range(-1, 10, 1).intersect(S.Naturals0) == Range(0, 10, 1) + + # test slicing + assert Range(1, 10, 1)[5] == 6 + assert Range(1, 12, 2)[5] == 11 + assert Range(1, 10, 1)[-1] == 9 + assert Range(1, 10, 3)[-1] == 7 + raises(ValueError, lambda: Range(oo,0,-1)[1:3:0]) + raises(ValueError, lambda: Range(oo,0,-1)[:1]) + raises(ValueError, lambda: Range(1, oo)[-2]) + raises(ValueError, lambda: Range(-oo, 1)[2]) + raises(IndexError, lambda: Range(10)[-20]) + raises(IndexError, lambda: Range(10)[20]) + raises(ValueError, lambda: Range(2, -oo, -2)[2:2:0]) + assert Range(2, -oo, -2)[2:2:2] == empty + assert Range(2, -oo, -2)[:2:2] == Range(2, -2, -4) + raises(ValueError, lambda: Range(-oo, 4, 2)[:2:2]) + assert Range(-oo, 4, 2)[::-2] == Range(2, -oo, -4) + raises(ValueError, lambda: Range(-oo, 4, 2)[::2]) + assert Range(oo, 2, -2)[::] == Range(oo, 2, -2) + assert Range(-oo, 4, 2)[:-2:-2] == Range(2, 0, -4) + assert Range(-oo, 4, 2)[:-2:2] == Range(-oo, 0, 4) + raises(ValueError, lambda: Range(-oo, 4, 2)[:0:-2]) + raises(ValueError, lambda: Range(-oo, 4, 2)[:2:-2]) + assert Range(-oo, 4, 2)[-2::-2] == Range(0, -oo, -4) + raises(ValueError, lambda: Range(-oo, 4, 2)[-2:0:-2]) + raises(ValueError, lambda: Range(-oo, 4, 2)[0::2]) + assert Range(oo, 2, -2)[0::] == Range(oo, 2, -2) + raises(ValueError, lambda: Range(-oo, 4, 2)[0:-2:2]) + assert Range(oo, 2, -2)[0:-2:] == Range(oo, 6, -2) + raises(ValueError, lambda: Range(oo, 2, -2)[0:2:]) + raises(ValueError, lambda: Range(-oo, 4, 2)[2::-1]) + assert Range(-oo, 4, 2)[-2::2] == Range(0, 4, 4) + assert Range(oo, 0, -2)[-10:0:2] == empty + raises(ValueError, lambda: Range(oo, 0, -2)[0]) + raises(ValueError, lambda: Range(oo, 0, -2)[-10:10:2]) + raises(ValueError, lambda: Range(oo, 0, -2)[0::-2]) + assert Range(oo, 0, -2)[0:-4:-2] == empty + assert Range(oo, 0, -2)[:0:2] == empty + raises(ValueError, lambda: Range(oo, 0, -2)[:1:-1]) + + # test empty Range + assert Range(x, x, y) == empty + assert empty.reversed == empty + assert 0 not in empty + assert list(empty) == [] + assert len(empty) == 0 + assert empty.size is S.Zero + assert empty.intersect(FiniteSet(0)) is S.EmptySet + assert bool(empty) is False + raises(IndexError, lambda: empty[0]) + assert empty[:0] == empty + raises(NotImplementedError, lambda: empty.inf) + raises(NotImplementedError, lambda: empty.sup) + assert empty.as_relational(x) is S.false + + AB = [None] + list(range(12)) + for R in [ + Range(1, 10), + Range(1, 10, 2), + ]: + r = list(R) + for a, b, c in itertools.product(AB, AB, [-3, -1, None, 1, 3]): + for reverse in range(2): + r = list(reversed(r)) + R = R.reversed + result = list(R[a:b:c]) + ans = r[a:b:c] + txt = ('\n%s[%s:%s:%s] = %s -> %s' % ( + R, a, b, c, result, ans)) + check = ans == result + assert check, txt + + assert Range(1, 10, 1).boundary == Range(1, 10, 1) + + for r in (Range(1, 10, 2), Range(1, oo, 2)): + rev = r.reversed + assert r.inf == rev.inf and r.sup == rev.sup + assert r.step == -rev.step + + builtin_range = range + + raises(TypeError, lambda: Range(builtin_range(1))) + assert S(builtin_range(10)) == Range(10) + assert S(builtin_range(1000000000000)) == Range(1000000000000) + + # test Range.as_relational + assert Range(1, 4).as_relational(x) == (x >= 1) & (x <= 3) & Eq(Mod(x, 1), 0) + assert Range(oo, 1, -2).as_relational(x) == (x >= 3) & (x < oo) & Eq(Mod(x + 1, -2), 0) + + +def test_Range_symbolic(): + # symbolic Range + xr = Range(x, x + 4, 5) + sr = Range(x, y, t) + i = Symbol('i', integer=True) + ip = Symbol('i', integer=True, positive=True) + ipr = Range(ip) + inr = Range(0, -ip, -1) + ir = Range(i, i + 19, 2) + ir2 = Range(i, i*8, 3*i) + i = Symbol('i', integer=True) + inf = symbols('inf', infinite=True) + raises(ValueError, lambda: Range(inf)) + raises(ValueError, lambda: Range(inf, 0, -1)) + raises(ValueError, lambda: Range(inf, inf, 1)) + raises(ValueError, lambda: Range(1, 1, inf)) + # args + assert xr.args == (x, x + 5, 5) + assert sr.args == (x, y, t) + assert ir.args == (i, i + 20, 2) + assert ir2.args == (i, 10*i, 3*i) + # reversed + raises(ValueError, lambda: xr.reversed) + raises(ValueError, lambda: sr.reversed) + assert ipr.reversed.args == (ip - 1, -1, -1) + assert inr.reversed.args == (-ip + 1, 1, 1) + assert ir.reversed.args == (i + 18, i - 2, -2) + assert ir2.reversed.args == (7*i, -2*i, -3*i) + # contains + assert inf not in sr + assert inf not in ir + assert 0 in ipr + assert 0 in inr + raises(TypeError, lambda: 1 in ipr) + raises(TypeError, lambda: -1 in inr) + assert .1 not in sr + assert .1 not in ir + assert i + 1 not in ir + assert i + 2 in ir + raises(TypeError, lambda: x in xr) # XXX is this what contains is supposed to do? + raises(TypeError, lambda: 1 in sr) # XXX is this what contains is supposed to do? + # iter + raises(ValueError, lambda: next(iter(xr))) + raises(ValueError, lambda: next(iter(sr))) + assert next(iter(ir)) == i + assert next(iter(ir2)) == i + assert sr.intersect(S.Integers) == sr + assert sr.intersect(FiniteSet(x)) == Intersection({x}, sr) + raises(ValueError, lambda: sr[:2]) + raises(ValueError, lambda: xr[0]) + raises(ValueError, lambda: sr[0]) + # len + assert len(ir) == ir.size == 10 + assert len(ir2) == ir2.size == 3 + raises(ValueError, lambda: len(xr)) + raises(ValueError, lambda: xr.size) + raises(ValueError, lambda: len(sr)) + raises(ValueError, lambda: sr.size) + # bool + assert bool(Range(0)) == False + assert bool(xr) + assert bool(ir) + assert bool(ipr) + assert bool(inr) + raises(ValueError, lambda: bool(sr)) + raises(ValueError, lambda: bool(ir2)) + # inf + raises(ValueError, lambda: xr.inf) + raises(ValueError, lambda: sr.inf) + assert ipr.inf == 0 + assert inr.inf == -ip + 1 + assert ir.inf == i + raises(ValueError, lambda: ir2.inf) + # sup + raises(ValueError, lambda: xr.sup) + raises(ValueError, lambda: sr.sup) + assert ipr.sup == ip - 1 + assert inr.sup == 0 + assert ir.inf == i + raises(ValueError, lambda: ir2.sup) + # getitem + raises(ValueError, lambda: xr[0]) + raises(ValueError, lambda: sr[0]) + raises(ValueError, lambda: sr[-1]) + raises(ValueError, lambda: sr[:2]) + assert ir[:2] == Range(i, i + 4, 2) + assert ir[0] == i + assert ir[-2] == i + 16 + assert ir[-1] == i + 18 + assert ir2[:2] == Range(i, 7*i, 3*i) + assert ir2[0] == i + assert ir2[-2] == 4*i + assert ir2[-1] == 7*i + raises(ValueError, lambda: Range(i)[-1]) + assert ipr[0] == ipr.inf == 0 + assert ipr[-1] == ipr.sup == ip - 1 + assert inr[0] == inr.sup == 0 + assert inr[-1] == inr.inf == -ip + 1 + raises(ValueError, lambda: ipr[-2]) + assert ir.inf == i + assert ir.sup == i + 18 + raises(ValueError, lambda: Range(i).inf) + # as_relational + assert ir.as_relational(x) == ((x >= i) & (x <= i + 18) & + Eq(Mod(-i + x, 2), 0)) + assert ir2.as_relational(x) == Eq( + Mod(-i + x, 3*i), 0) & (((x >= i) & (x <= 7*i) & (3*i >= 1)) | + ((x <= i) & (x >= 7*i) & (3*i <= -1))) + assert Range(i, i + 1).as_relational(x) == Eq(x, i) + assert sr.as_relational(z) == Eq( + Mod(t, 1), 0) & Eq(Mod(x, 1), 0) & Eq(Mod(-x + z, t), 0 + ) & (((z >= x) & (z <= -t + y) & (t >= 1)) | + ((z <= x) & (z >= -t + y) & (t <= -1))) + assert xr.as_relational(z) == Eq(z, x) & Eq(Mod(x, 1), 0) + # symbols can clash if user wants (but it must be integer) + assert xr.as_relational(x) == Eq(Mod(x, 1), 0) + # contains() for symbolic values (issue #18146) + e = Symbol('e', integer=True, even=True) + o = Symbol('o', integer=True, odd=True) + assert Range(5).contains(i) == And(i >= 0, i <= 4) + assert Range(1).contains(i) == Eq(i, 0) + assert Range(-oo, 5, 1).contains(i) == (i <= 4) + assert Range(-oo, oo).contains(i) == True + assert Range(0, 8, 2).contains(i) == Contains(i, Range(0, 8, 2)) + assert Range(0, 8, 2).contains(e) == And(e >= 0, e <= 6) + assert Range(0, 8, 2).contains(2*i) == And(2*i >= 0, 2*i <= 6) + assert Range(0, 8, 2).contains(o) == False + assert Range(1, 9, 2).contains(e) == False + assert Range(1, 9, 2).contains(o) == And(o >= 1, o <= 7) + assert Range(8, 0, -2).contains(o) == False + assert Range(9, 1, -2).contains(o) == And(o >= 3, o <= 9) + assert Range(-oo, 8, 2).contains(i) == Contains(i, Range(-oo, 8, 2)) + + +def test_range_range_intersection(): + for a, b, r in [ + (Range(0), Range(1), S.EmptySet), + (Range(3), Range(4, oo), S.EmptySet), + (Range(3), Range(-3, -1), S.EmptySet), + (Range(1, 3), Range(0, 3), Range(1, 3)), + (Range(1, 3), Range(1, 4), Range(1, 3)), + (Range(1, oo, 2), Range(2, oo, 2), S.EmptySet), + (Range(0, oo, 2), Range(oo), Range(0, oo, 2)), + (Range(0, oo, 2), Range(100), Range(0, 100, 2)), + (Range(2, oo, 2), Range(oo), Range(2, oo, 2)), + (Range(0, oo, 2), Range(5, 6), S.EmptySet), + (Range(2, 80, 1), Range(55, 71, 4), Range(55, 71, 4)), + (Range(0, 6, 3), Range(-oo, 5, 3), S.EmptySet), + (Range(0, oo, 2), Range(5, oo, 3), Range(8, oo, 6)), + (Range(4, 6, 2), Range(2, 16, 7), S.EmptySet),]: + assert a.intersect(b) == r + assert a.intersect(b.reversed) == r + assert a.reversed.intersect(b) == r + assert a.reversed.intersect(b.reversed) == r + a, b = b, a + assert a.intersect(b) == r + assert a.intersect(b.reversed) == r + assert a.reversed.intersect(b) == r + assert a.reversed.intersect(b.reversed) == r + + +def test_range_interval_intersection(): + p = symbols('p', positive=True) + assert isinstance(Range(3).intersect(Interval(p, p + 2)), Intersection) + assert Range(4).intersect(Interval(0, 3)) == Range(4) + assert Range(4).intersect(Interval(-oo, oo)) == Range(4) + assert Range(4).intersect(Interval(1, oo)) == Range(1, 4) + assert Range(4).intersect(Interval(1.1, oo)) == Range(2, 4) + assert Range(4).intersect(Interval(0.1, 3)) == Range(1, 4) + assert Range(4).intersect(Interval(0.1, 3.1)) == Range(1, 4) + assert Range(4).intersect(Interval.open(0, 3)) == Range(1, 3) + assert Range(4).intersect(Interval.open(0.1, 0.5)) is S.EmptySet + assert Interval(-1, 5).intersect(S.Complexes) == Interval(-1, 5) + assert Interval(-1, 5).intersect(S.Reals) == Interval(-1, 5) + assert Interval(-1, 5).intersect(S.Integers) == Range(-1, 6) + assert Interval(-1, 5).intersect(S.Naturals) == Range(1, 6) + assert Interval(-1, 5).intersect(S.Naturals0) == Range(0, 6) + + # Null Range intersections + assert Range(0).intersect(Interval(0.2, 0.8)) is S.EmptySet + assert Range(0).intersect(Interval(-oo, oo)) is S.EmptySet + + +def test_range_is_finite_set(): + assert Range(-100, 100).is_finite_set is True + assert Range(2, oo).is_finite_set is False + assert Range(-oo, 50).is_finite_set is False + assert Range(-oo, oo).is_finite_set is False + assert Range(oo, -oo).is_finite_set is True + assert Range(0, 0).is_finite_set is True + assert Range(oo, oo).is_finite_set is True + assert Range(-oo, -oo).is_finite_set is True + n = Symbol('n', integer=True) + m = Symbol('m', integer=True) + assert Range(n, n + 49).is_finite_set is True + assert Range(n, 0).is_finite_set is True + assert Range(-3, n + 7).is_finite_set is True + assert Range(n, m).is_finite_set is True + assert Range(n + m, m - n).is_finite_set is True + assert Range(n, n + m + n).is_finite_set is True + assert Range(n, oo).is_finite_set is False + assert Range(-oo, n).is_finite_set is False + assert Range(n, -oo).is_finite_set is True + assert Range(oo, n).is_finite_set is True + + +def test_Range_is_iterable(): + assert Range(-100, 100).is_iterable is True + assert Range(2, oo).is_iterable is False + assert Range(-oo, 50).is_iterable is False + assert Range(-oo, oo).is_iterable is False + assert Range(oo, -oo).is_iterable is True + assert Range(0, 0).is_iterable is True + assert Range(oo, oo).is_iterable is True + assert Range(-oo, -oo).is_iterable is True + n = Symbol('n', integer=True) + m = Symbol('m', integer=True) + p = Symbol('p', integer=True, positive=True) + assert Range(n, n + 49).is_iterable is True + assert Range(n, 0).is_iterable is False + assert Range(-3, n + 7).is_iterable is False + assert Range(-3, p + 7).is_iterable is False # Should work with better __iter__ + assert Range(n, m).is_iterable is False + assert Range(n + m, m - n).is_iterable is False + assert Range(n, n + m + n).is_iterable is False + assert Range(n, oo).is_iterable is False + assert Range(-oo, n).is_iterable is False + x = Symbol('x') + assert Range(x, x + 49).is_iterable is False + assert Range(x, 0).is_iterable is False + assert Range(-3, x + 7).is_iterable is False + assert Range(x, m).is_iterable is False + assert Range(x + m, m - x).is_iterable is False + assert Range(x, x + m + x).is_iterable is False + assert Range(x, oo).is_iterable is False + assert Range(-oo, x).is_iterable is False + + +def test_Integers_eval_imageset(): + ans = ImageSet(Lambda(x, 2*x + Rational(3, 7)), S.Integers) + im = imageset(Lambda(x, -2*x + Rational(3, 7)), S.Integers) + assert im == ans + im = imageset(Lambda(x, -2*x - Rational(11, 7)), S.Integers) + assert im == ans + y = Symbol('y') + L = imageset(x, 2*x + y, S.Integers) + assert y + 4 in L + a, b, c = 0.092, 0.433, 0.341 + assert a in imageset(x, a + c*x, S.Integers) + assert b in imageset(x, b + c*x, S.Integers) + + _x = symbols('x', negative=True) + eq = _x**2 - _x + 1 + assert imageset(_x, eq, S.Integers).lamda.expr == _x**2 + _x + 1 + eq = 3*_x - 1 + assert imageset(_x, eq, S.Integers).lamda.expr == 3*_x + 2 + + assert imageset(x, (x, 1/x), S.Integers) == \ + ImageSet(Lambda(x, (x, 1/x)), S.Integers) + + +def test_Range_eval_imageset(): + a, b, c = symbols('a b c') + assert imageset(x, a*(x + b) + c, Range(3)) == \ + imageset(x, a*x + a*b + c, Range(3)) + eq = (x + 1)**2 + assert imageset(x, eq, Range(3)).lamda.expr == eq + eq = a*(x + b) + c + r = Range(3, -3, -2) + imset = imageset(x, eq, r) + assert imset.lamda.expr != eq + assert list(imset) == [eq.subs(x, i).expand() for i in list(r)] + + +def test_fun(): + assert (FiniteSet(*ImageSet(Lambda(x, sin(pi*x/4)), + Range(-10, 11))) == FiniteSet(-1, -sqrt(2)/2, 0, sqrt(2)/2, 1)) + + +def test_Range_is_empty(): + i = Symbol('i', integer=True) + n = Symbol('n', negative=True, integer=True) + p = Symbol('p', positive=True, integer=True) + + assert Range(0).is_empty + assert not Range(1).is_empty + assert Range(1, 0).is_empty + assert not Range(-1, 0).is_empty + assert Range(i).is_empty is None + assert Range(n).is_empty + assert Range(p).is_empty is False + assert Range(n, 0).is_empty is False + assert Range(n, p).is_empty is False + assert Range(p, n).is_empty + assert Range(n, -1).is_empty is None + assert Range(p, n, -1).is_empty is False + + +def test_Reals(): + assert 5 in S.Reals + assert S.Pi in S.Reals + assert -sqrt(2) in S.Reals + assert (2, 5) not in S.Reals + assert sqrt(-1) not in S.Reals + assert S.Reals == Interval(-oo, oo) + assert S.Reals != Interval(0, oo) + assert S.Reals.is_subset(Interval(-oo, oo)) + assert S.Reals.intersect(Range(-oo, oo)) == Range(-oo, oo) + assert S.ComplexInfinity not in S.Reals + assert S.NaN not in S.Reals + assert x + S.ComplexInfinity not in S.Reals + + +def test_Complex(): + assert 5 in S.Complexes + assert 5 + 4*I in S.Complexes + assert S.Pi in S.Complexes + assert -sqrt(2) in S.Complexes + assert -I in S.Complexes + assert sqrt(-1) in S.Complexes + assert S.Complexes.intersect(S.Reals) == S.Reals + assert S.Complexes.union(S.Reals) == S.Complexes + assert S.Complexes == ComplexRegion(S.Reals*S.Reals) + assert (S.Complexes == ComplexRegion(Interval(1, 2)*Interval(3, 4))) == False + assert str(S.Complexes) == "Complexes" + assert repr(S.Complexes) == "Complexes" + + +def take(n, iterable): + "Return first n items of the iterable as a list" + return list(itertools.islice(iterable, n)) + + +def test_intersections(): + assert S.Integers.intersect(S.Reals) == S.Integers + assert 5 in S.Integers.intersect(S.Reals) + assert 5 in S.Integers.intersect(S.Reals) + assert -5 not in S.Naturals.intersect(S.Reals) + assert 5.5 not in S.Integers.intersect(S.Reals) + assert 5 in S.Integers.intersect(Interval(3, oo)) + assert -5 in S.Integers.intersect(Interval(-oo, 3)) + assert all(x.is_Integer + for x in take(10, S.Integers.intersect(Interval(3, oo)) )) + + +def test_infinitely_indexed_set_1(): + from sympy.abc import n, m + assert imageset(Lambda(n, n), S.Integers) == imageset(Lambda(m, m), S.Integers) + + assert imageset(Lambda(n, 2*n), S.Integers).intersect( + imageset(Lambda(m, 2*m + 1), S.Integers)) is S.EmptySet + + assert imageset(Lambda(n, 2*n), S.Integers).intersect( + imageset(Lambda(n, 2*n + 1), S.Integers)) is S.EmptySet + + assert imageset(Lambda(m, 2*m), S.Integers).intersect( + imageset(Lambda(n, 3*n), S.Integers)).dummy_eq( + ImageSet(Lambda(t, 6*t), S.Integers)) + + assert imageset(x, x/2 + Rational(1, 3), S.Integers).intersect(S.Integers) is S.EmptySet + assert imageset(x, x/2 + S.Half, S.Integers).intersect(S.Integers) is S.Integers + + # https://github.com/sympy/sympy/issues/17355 + S53 = ImageSet(Lambda(n, 5*n + 3), S.Integers) + assert S53.intersect(S.Integers) == S53 + + +def test_infinitely_indexed_set_2(): + from sympy.abc import n + a = Symbol('a', integer=True) + assert imageset(Lambda(n, n), S.Integers) == \ + imageset(Lambda(n, n + a), S.Integers) + assert imageset(Lambda(n, n + pi), S.Integers) == \ + imageset(Lambda(n, n + a + pi), S.Integers) + assert imageset(Lambda(n, n), S.Integers) == \ + imageset(Lambda(n, -n + a), S.Integers) + assert imageset(Lambda(n, -6*n), S.Integers) == \ + ImageSet(Lambda(n, 6*n), S.Integers) + assert imageset(Lambda(n, 2*n + pi), S.Integers) == \ + ImageSet(Lambda(n, 2*n + pi - 2), S.Integers) + + +def test_imageset_intersect_real(): + from sympy.abc import n + assert imageset(Lambda(n, n + (n - 1)*(n + 1)*I), S.Integers).intersect(S.Reals) == FiniteSet(-1, 1) + im = (n - 1)*(n + S.Half) + assert imageset(Lambda(n, n + im*I), S.Integers + ).intersect(S.Reals) == FiniteSet(1) + assert imageset(Lambda(n, n + im*(n + 1)*I), S.Naturals0 + ).intersect(S.Reals) == FiniteSet(1) + assert imageset(Lambda(n, n/2 + im.expand()*I), S.Integers + ).intersect(S.Reals) == ImageSet(Lambda(x, x/2), ConditionSet( + n, Eq(n**2 - n/2 - S(1)/2, 0), S.Integers)) + assert imageset(Lambda(n, n/(1/n - 1) + im*(n + 1)*I), S.Integers + ).intersect(S.Reals) == FiniteSet(S.Half) + assert imageset(Lambda(n, n/(n - 6) + + (n - 3)*(n + 1)*I/(2*n + 2)), S.Integers).intersect( + S.Reals) == FiniteSet(-1) + assert imageset(Lambda(n, n/(n**2 - 9) + + (n - 3)*(n + 1)*I/(2*n + 2)), S.Integers).intersect( + S.Reals) is S.EmptySet + s = ImageSet( + Lambda(n, -I*(I*(2*pi*n - pi/4) + log(Abs(sqrt(-I))))), + S.Integers) + # s is unevaluated, but after intersection the result + # should be canonical + assert s.intersect(S.Reals) == imageset( + Lambda(n, 2*n*pi - pi/4), S.Integers) == ImageSet( + Lambda(n, 2*pi*n + pi*Rational(7, 4)), S.Integers) + + +def test_imageset_intersect_interval(): + from sympy.abc import n + f1 = ImageSet(Lambda(n, n*pi), S.Integers) + f2 = ImageSet(Lambda(n, 2*n), Interval(0, pi)) + f3 = ImageSet(Lambda(n, 2*n*pi + pi/2), S.Integers) + # complex expressions + f4 = ImageSet(Lambda(n, n*I*pi), S.Integers) + f5 = ImageSet(Lambda(n, 2*I*n*pi + pi/2), S.Integers) + # non-linear expressions + f6 = ImageSet(Lambda(n, log(n)), S.Integers) + f7 = ImageSet(Lambda(n, n**2), S.Integers) + f8 = ImageSet(Lambda(n, Abs(n)), S.Integers) + f9 = ImageSet(Lambda(n, exp(n)), S.Naturals0) + + assert f1.intersect(Interval(-1, 1)) == FiniteSet(0) + assert f1.intersect(Interval(0, 2*pi, False, True)) == FiniteSet(0, pi) + assert f2.intersect(Interval(1, 2)) == Interval(1, 2) + assert f3.intersect(Interval(-1, 1)) == S.EmptySet + assert f3.intersect(Interval(-5, 5)) == FiniteSet(pi*Rational(-3, 2), pi/2) + assert f4.intersect(Interval(-1, 1)) == FiniteSet(0) + assert f4.intersect(Interval(1, 2)) == S.EmptySet + assert f5.intersect(Interval(0, 1)) == S.EmptySet + assert f6.intersect(Interval(0, 1)) == FiniteSet(S.Zero, log(2)) + assert f7.intersect(Interval(0, 10)) == Intersection(f7, Interval(0, 10)) + assert f8.intersect(Interval(0, 2)) == Intersection(f8, Interval(0, 2)) + assert f9.intersect(Interval(1, 2)) == Intersection(f9, Interval(1, 2)) + + +def test_imageset_intersect_diophantine(): + from sympy.abc import m, n + # Check that same lambda variable for both ImageSets is handled correctly + img1 = ImageSet(Lambda(n, 2*n + 1), S.Integers) + img2 = ImageSet(Lambda(n, 4*n + 1), S.Integers) + assert img1.intersect(img2) == img2 + # Empty solution set returned by diophantine: + assert ImageSet(Lambda(n, 2*n), S.Integers).intersect( + ImageSet(Lambda(n, 2*n + 1), S.Integers)) == S.EmptySet + # Check intersection with S.Integers: + assert ImageSet(Lambda(n, 9/n + 20*n/3), S.Integers).intersect( + S.Integers) == FiniteSet(-61, -23, 23, 61) + # Single solution (2, 3) for diophantine solution: + assert ImageSet(Lambda(n, (n - 2)**2), S.Integers).intersect( + ImageSet(Lambda(n, -(n - 3)**2), S.Integers)) == FiniteSet(0) + # Single parametric solution for diophantine solution: + assert ImageSet(Lambda(n, n**2 + 5), S.Integers).intersect( + ImageSet(Lambda(m, 2*m), S.Integers)).dummy_eq(ImageSet( + Lambda(n, 4*n**2 + 4*n + 6), S.Integers)) + # 4 non-parametric solution couples for dioph. equation: + assert ImageSet(Lambda(n, n**2 - 9), S.Integers).intersect( + ImageSet(Lambda(m, -m**2), S.Integers)) == FiniteSet(-9, 0) + # Double parametric solution for diophantine solution: + assert ImageSet(Lambda(m, m**2 + 40), S.Integers).intersect( + ImageSet(Lambda(n, 41*n), S.Integers)).dummy_eq(Intersection( + ImageSet(Lambda(m, m**2 + 40), S.Integers), + ImageSet(Lambda(n, 41*n), S.Integers))) + # Check that diophantine returns *all* (8) solutions (permute=True) + assert ImageSet(Lambda(n, n**4 - 2**4), S.Integers).intersect( + ImageSet(Lambda(m, -m**4 + 3**4), S.Integers)) == FiniteSet(0, 65) + assert ImageSet(Lambda(n, pi/12 + n*5*pi/12), S.Integers).intersect( + ImageSet(Lambda(n, 7*pi/12 + n*11*pi/12), S.Integers)).dummy_eq(ImageSet( + Lambda(n, 55*pi*n/12 + 17*pi/4), S.Integers)) + # TypeError raised by diophantine (#18081) + assert ImageSet(Lambda(n, n*log(2)), S.Integers).intersection( + S.Integers).dummy_eq(Intersection(ImageSet( + Lambda(n, n*log(2)), S.Integers), S.Integers)) + # NotImplementedError raised by diophantine (no solver for cubic_thue) + assert ImageSet(Lambda(n, n**3 + 1), S.Integers).intersect( + ImageSet(Lambda(n, n**3), S.Integers)).dummy_eq(Intersection( + ImageSet(Lambda(n, n**3 + 1), S.Integers), + ImageSet(Lambda(n, n**3), S.Integers))) + + +def test_infinitely_indexed_set_3(): + from sympy.abc import n, m + assert imageset(Lambda(m, 2*pi*m), S.Integers).intersect( + imageset(Lambda(n, 3*pi*n), S.Integers)).dummy_eq( + ImageSet(Lambda(t, 6*pi*t), S.Integers)) + assert imageset(Lambda(n, 2*n + 1), S.Integers) == \ + imageset(Lambda(n, 2*n - 1), S.Integers) + assert imageset(Lambda(n, 3*n + 2), S.Integers) == \ + imageset(Lambda(n, 3*n - 1), S.Integers) + + +def test_ImageSet_simplification(): + from sympy.abc import n, m + assert imageset(Lambda(n, n), S.Integers) == S.Integers + assert imageset(Lambda(n, sin(n)), + imageset(Lambda(m, tan(m)), S.Integers)) == \ + imageset(Lambda(m, sin(tan(m))), S.Integers) + assert imageset(n, 1 + 2*n, S.Naturals) == Range(3, oo, 2) + assert imageset(n, 1 + 2*n, S.Naturals0) == Range(1, oo, 2) + assert imageset(n, 1 - 2*n, S.Naturals) == Range(-1, -oo, -2) + + +def test_ImageSet_contains(): + assert (2, S.Half) in imageset(x, (x, 1/x), S.Integers) + assert imageset(x, x + I*3, S.Integers).intersection(S.Reals) is S.EmptySet + i = Dummy(integer=True) + q = imageset(x, x + I*y, S.Integers).intersection(S.Reals) + assert q.subs(y, I*i).intersection(S.Integers) is S.Integers + q = imageset(x, x + I*y/x, S.Integers).intersection(S.Reals) + assert q.subs(y, 0) is S.Integers + assert q.subs(y, I*i*x).intersection(S.Integers) is S.Integers + z = cos(1)**2 + sin(1)**2 - 1 + q = imageset(x, x + I*z, S.Integers).intersection(S.Reals) + assert q is not S.EmptySet + + +def test_ComplexRegion_contains(): + r = Symbol('r', real=True) + # contains in ComplexRegion + a = Interval(2, 3) + b = Interval(4, 6) + c = Interval(7, 9) + c1 = ComplexRegion(a*b) + c2 = ComplexRegion(Union(a*b, c*a)) + assert 2.5 + 4.5*I in c1 + assert 2 + 4*I in c1 + assert 3 + 4*I in c1 + assert 8 + 2.5*I in c2 + assert 2.5 + 6.1*I not in c1 + assert 4.5 + 3.2*I not in c1 + assert c1.contains(x) == Contains(x, c1, evaluate=False) + assert c1.contains(r) == False + assert c2.contains(x) == Contains(x, c2, evaluate=False) + assert c2.contains(r) == False + + r1 = Interval(0, 1) + theta1 = Interval(0, 2*S.Pi) + c3 = ComplexRegion(r1*theta1, polar=True) + assert (0.5 + I*6/10) in c3 + assert (S.Half + I*6/10) in c3 + assert (S.Half + .6*I) in c3 + assert (0.5 + .6*I) in c3 + assert I in c3 + assert 1 in c3 + assert 0 in c3 + assert 1 + I not in c3 + assert 1 - I not in c3 + assert c3.contains(x) == Contains(x, c3, evaluate=False) + assert c3.contains(r + 2*I) == Contains( + r + 2*I, c3, evaluate=False) # is in fact False + assert c3.contains(1/(1 + r**2)) == Contains( + 1/(1 + r**2), c3, evaluate=False) # is in fact True + + r2 = Interval(0, 3) + theta2 = Interval(pi, 2*pi, left_open=True) + c4 = ComplexRegion(r2*theta2, polar=True) + assert c4.contains(0) == True + assert c4.contains(2 + I) == False + assert c4.contains(-2 + I) == False + assert c4.contains(-2 - I) == True + assert c4.contains(2 - I) == True + assert c4.contains(-2) == False + assert c4.contains(2) == True + assert c4.contains(x) == Contains(x, c4, evaluate=False) + assert c4.contains(3/(1 + r**2)) == Contains( + 3/(1 + r**2), c4, evaluate=False) # is in fact True + + raises(ValueError, lambda: ComplexRegion(r1*theta1, polar=2)) + + +def test_symbolic_Range(): + n = Symbol('n') + raises(ValueError, lambda: Range(n)[0]) + raises(IndexError, lambda: Range(n, n)[0]) + raises(ValueError, lambda: Range(n, n+1)[0]) + raises(ValueError, lambda: Range(n).size) + + n = Symbol('n', integer=True) + raises(ValueError, lambda: Range(n)[0]) + raises(IndexError, lambda: Range(n, n)[0]) + assert Range(n, n+1)[0] == n + raises(ValueError, lambda: Range(n).size) + assert Range(n, n+1).size == 1 + + n = Symbol('n', integer=True, nonnegative=True) + raises(ValueError, lambda: Range(n)[0]) + raises(IndexError, lambda: Range(n, n)[0]) + assert Range(n+1)[0] == 0 + assert Range(n, n+1)[0] == n + assert Range(n).size == n + assert Range(n+1).size == n+1 + assert Range(n, n+1).size == 1 + + n = Symbol('n', integer=True, positive=True) + assert Range(n)[0] == 0 + assert Range(n, n+1)[0] == n + assert Range(n).size == n + assert Range(n, n+1).size == 1 + + m = Symbol('m', integer=True, positive=True) + + assert Range(n, n+m)[0] == n + assert Range(n, n+m).size == m + assert Range(n, n+1).size == 1 + assert Range(n, n+m, 2).size == floor(m/2) + + m = Symbol('m', integer=True, positive=True, even=True) + assert Range(n, n+m, 2).size == m/2 + + +def test_issue_18400(): + n = Symbol('n', integer=True) + raises(ValueError, lambda: imageset(lambda x: x*2, Range(n))) + + n = Symbol('n', integer=True, positive=True) + # No exception + assert imageset(lambda x: x*2, Range(n)) == imageset(lambda x: x*2, Range(n)) + + +def test_ComplexRegion_intersect(): + # Polar form + X_axis = ComplexRegion(Interval(0, oo)*FiniteSet(0, S.Pi), polar=True) + + unit_disk = ComplexRegion(Interval(0, 1)*Interval(0, 2*S.Pi), polar=True) + upper_half_unit_disk = ComplexRegion(Interval(0, 1)*Interval(0, S.Pi), polar=True) + upper_half_disk = ComplexRegion(Interval(0, oo)*Interval(0, S.Pi), polar=True) + lower_half_disk = ComplexRegion(Interval(0, oo)*Interval(S.Pi, 2*S.Pi), polar=True) + right_half_disk = ComplexRegion(Interval(0, oo)*Interval(-S.Pi/2, S.Pi/2), polar=True) + first_quad_disk = ComplexRegion(Interval(0, oo)*Interval(0, S.Pi/2), polar=True) + + assert upper_half_disk.intersect(unit_disk) == upper_half_unit_disk + assert right_half_disk.intersect(first_quad_disk) == first_quad_disk + assert upper_half_disk.intersect(right_half_disk) == first_quad_disk + assert upper_half_disk.intersect(lower_half_disk) == X_axis + + c1 = ComplexRegion(Interval(0, 4)*Interval(0, 2*S.Pi), polar=True) + assert c1.intersect(Interval(1, 5)) == Interval(1, 4) + assert c1.intersect(Interval(4, 9)) == FiniteSet(4) + assert c1.intersect(Interval(5, 12)) is S.EmptySet + + # Rectangular form + X_axis = ComplexRegion(Interval(-oo, oo)*FiniteSet(0)) + + unit_square = ComplexRegion(Interval(-1, 1)*Interval(-1, 1)) + upper_half_unit_square = ComplexRegion(Interval(-1, 1)*Interval(0, 1)) + upper_half_plane = ComplexRegion(Interval(-oo, oo)*Interval(0, oo)) + lower_half_plane = ComplexRegion(Interval(-oo, oo)*Interval(-oo, 0)) + right_half_plane = ComplexRegion(Interval(0, oo)*Interval(-oo, oo)) + first_quad_plane = ComplexRegion(Interval(0, oo)*Interval(0, oo)) + + assert upper_half_plane.intersect(unit_square) == upper_half_unit_square + assert right_half_plane.intersect(first_quad_plane) == first_quad_plane + assert upper_half_plane.intersect(right_half_plane) == first_quad_plane + assert upper_half_plane.intersect(lower_half_plane) == X_axis + + c1 = ComplexRegion(Interval(-5, 5)*Interval(-10, 10)) + assert c1.intersect(Interval(2, 7)) == Interval(2, 5) + assert c1.intersect(Interval(5, 7)) == FiniteSet(5) + assert c1.intersect(Interval(6, 9)) is S.EmptySet + + # unevaluated object + C1 = ComplexRegion(Interval(0, 1)*Interval(0, 2*S.Pi), polar=True) + C2 = ComplexRegion(Interval(-1, 1)*Interval(-1, 1)) + assert C1.intersect(C2) == Intersection(C1, C2, evaluate=False) + + +def test_ComplexRegion_union(): + # Polar form + c1 = ComplexRegion(Interval(0, 1)*Interval(0, 2*S.Pi), polar=True) + c2 = ComplexRegion(Interval(0, 1)*Interval(0, S.Pi), polar=True) + c3 = ComplexRegion(Interval(0, oo)*Interval(0, S.Pi), polar=True) + c4 = ComplexRegion(Interval(0, oo)*Interval(S.Pi, 2*S.Pi), polar=True) + + p1 = Union(Interval(0, 1)*Interval(0, 2*S.Pi), Interval(0, 1)*Interval(0, S.Pi)) + p2 = Union(Interval(0, oo)*Interval(0, S.Pi), Interval(0, oo)*Interval(S.Pi, 2*S.Pi)) + + assert c1.union(c2) == ComplexRegion(p1, polar=True) + assert c3.union(c4) == ComplexRegion(p2, polar=True) + + # Rectangular form + c5 = ComplexRegion(Interval(2, 5)*Interval(6, 9)) + c6 = ComplexRegion(Interval(4, 6)*Interval(10, 12)) + c7 = ComplexRegion(Interval(0, 10)*Interval(-10, 0)) + c8 = ComplexRegion(Interval(12, 16)*Interval(14, 20)) + + p3 = Union(Interval(2, 5)*Interval(6, 9), Interval(4, 6)*Interval(10, 12)) + p4 = Union(Interval(0, 10)*Interval(-10, 0), Interval(12, 16)*Interval(14, 20)) + + assert c5.union(c6) == ComplexRegion(p3) + assert c7.union(c8) == ComplexRegion(p4) + + assert c1.union(Interval(2, 4)) == Union(c1, Interval(2, 4), evaluate=False) + assert c5.union(Interval(2, 4)) == Union(c5, ComplexRegion.from_real(Interval(2, 4))) + + +def test_ComplexRegion_from_real(): + c1 = ComplexRegion(Interval(0, 1) * Interval(0, 2 * S.Pi), polar=True) + + raises(ValueError, lambda: c1.from_real(c1)) + assert c1.from_real(Interval(-1, 1)) == ComplexRegion(Interval(-1, 1) * FiniteSet(0), False) + + +def test_ComplexRegion_measure(): + a, b = Interval(2, 5), Interval(4, 8) + theta1, theta2 = Interval(0, 2*S.Pi), Interval(0, S.Pi) + c1 = ComplexRegion(a*b) + c2 = ComplexRegion(Union(a*theta1, b*theta2), polar=True) + + assert c1.measure == 12 + assert c2.measure == 9*pi + + +def test_normalize_theta_set(): + # Interval + assert normalize_theta_set(Interval(pi, 2*pi)) == \ + Union(FiniteSet(0), Interval.Ropen(pi, 2*pi)) + assert normalize_theta_set(Interval(pi*Rational(9, 2), 5*pi)) == Interval(pi/2, pi) + assert normalize_theta_set(Interval(pi*Rational(-3, 2), pi/2)) == Interval.Ropen(0, 2*pi) + assert normalize_theta_set(Interval.open(pi*Rational(-3, 2), pi/2)) == \ + Union(Interval.Ropen(0, pi/2), Interval.open(pi/2, 2*pi)) + assert normalize_theta_set(Interval.open(pi*Rational(-7, 2), pi*Rational(-3, 2))) == \ + Union(Interval.Ropen(0, pi/2), Interval.open(pi/2, 2*pi)) + assert normalize_theta_set(Interval(-pi/2, pi/2)) == \ + Union(Interval(0, pi/2), Interval.Ropen(pi*Rational(3, 2), 2*pi)) + assert normalize_theta_set(Interval.open(-pi/2, pi/2)) == \ + Union(Interval.Ropen(0, pi/2), Interval.open(pi*Rational(3, 2), 2*pi)) + assert normalize_theta_set(Interval(-4*pi, 3*pi)) == Interval.Ropen(0, 2*pi) + assert normalize_theta_set(Interval(pi*Rational(-3, 2), -pi/2)) == Interval(pi/2, pi*Rational(3, 2)) + assert normalize_theta_set(Interval.open(0, 2*pi)) == Interval.open(0, 2*pi) + assert normalize_theta_set(Interval.Ropen(-pi/2, pi/2)) == \ + Union(Interval.Ropen(0, pi/2), Interval.Ropen(pi*Rational(3, 2), 2*pi)) + assert normalize_theta_set(Interval.Lopen(-pi/2, pi/2)) == \ + Union(Interval(0, pi/2), Interval.open(pi*Rational(3, 2), 2*pi)) + assert normalize_theta_set(Interval(-pi/2, pi/2)) == \ + Union(Interval(0, pi/2), Interval.Ropen(pi*Rational(3, 2), 2*pi)) + assert normalize_theta_set(Interval.open(4*pi, pi*Rational(9, 2))) == Interval.open(0, pi/2) + assert normalize_theta_set(Interval.Lopen(4*pi, pi*Rational(9, 2))) == Interval.Lopen(0, pi/2) + assert normalize_theta_set(Interval.Ropen(4*pi, pi*Rational(9, 2))) == Interval.Ropen(0, pi/2) + assert normalize_theta_set(Interval.open(3*pi, 5*pi)) == \ + Union(Interval.Ropen(0, pi), Interval.open(pi, 2*pi)) + + # FiniteSet + assert normalize_theta_set(FiniteSet(0, pi, 3*pi)) == FiniteSet(0, pi) + assert normalize_theta_set(FiniteSet(0, pi/2, pi, 2*pi)) == FiniteSet(0, pi/2, pi) + assert normalize_theta_set(FiniteSet(0, -pi/2, -pi, -2*pi)) == FiniteSet(0, pi, pi*Rational(3, 2)) + assert normalize_theta_set(FiniteSet(pi*Rational(-3, 2), pi/2)) == \ + FiniteSet(pi/2) + assert normalize_theta_set(FiniteSet(2*pi)) == FiniteSet(0) + + # Unions + assert normalize_theta_set(Union(Interval(0, pi/3), Interval(pi/2, pi))) == \ + Union(Interval(0, pi/3), Interval(pi/2, pi)) + assert normalize_theta_set(Union(Interval(0, pi), Interval(2*pi, pi*Rational(7, 3)))) == \ + Interval(0, pi) + + # ValueError for non-real sets + raises(ValueError, lambda: normalize_theta_set(S.Complexes)) + + # NotImplementedError for subset of reals + raises(NotImplementedError, lambda: normalize_theta_set(Interval(0, 1))) + + # NotImplementedError without pi as coefficient + raises(NotImplementedError, lambda: normalize_theta_set(Interval(1, 2*pi))) + raises(NotImplementedError, lambda: normalize_theta_set(Interval(2*pi, 10))) + raises(NotImplementedError, lambda: normalize_theta_set(FiniteSet(0, 3, 3*pi))) + + +def test_ComplexRegion_FiniteSet(): + x, y, z, a, b, c = symbols('x y z a b c') + + # Issue #9669 + assert ComplexRegion(FiniteSet(a, b, c)*FiniteSet(x, y, z)) == \ + FiniteSet(a + I*x, a + I*y, a + I*z, b + I*x, b + I*y, + b + I*z, c + I*x, c + I*y, c + I*z) + assert ComplexRegion(FiniteSet(2)*FiniteSet(3)) == FiniteSet(2 + 3*I) + + +def test_union_RealSubSet(): + assert (S.Complexes).union(Interval(1, 2)) == S.Complexes + assert (S.Complexes).union(S.Integers) == S.Complexes + + +def test_SetKind_fancySet(): + G = lambda *args: ImageSet(Lambda(x, x ** 2), *args) + assert G(Interval(1, 4)).kind is SetKind(NumberKind) + assert G(FiniteSet(1, 4)).kind is SetKind(NumberKind) + assert S.Rationals.kind is SetKind(NumberKind) + assert S.Naturals.kind is SetKind(NumberKind) + assert S.Integers.kind is SetKind(NumberKind) + assert Range(3).kind is SetKind(NumberKind) + a = Interval(2, 3) + b = Interval(4, 6) + c1 = ComplexRegion(a*b) + assert c1.kind is SetKind(TupleKind(NumberKind, NumberKind)) + + +def test_issue_9980(): + c1 = ComplexRegion(Interval(1, 2)*Interval(2, 3)) + c2 = ComplexRegion(Interval(1, 5)*Interval(1, 3)) + R = Union(c1, c2) + assert simplify(R) == ComplexRegion(Union(Interval(1, 2)*Interval(2, 3), \ + Interval(1, 5)*Interval(1, 3)), False) + assert c1.func(*c1.args) == c1 + assert R.func(*R.args) == R + + +def test_issue_11732(): + interval12 = Interval(1, 2) + finiteset1234 = FiniteSet(1, 2, 3, 4) + pointComplex = Tuple(1, 5) + + assert (interval12 in S.Naturals) == False + assert (interval12 in S.Naturals0) == False + assert (interval12 in S.Integers) == False + assert (interval12 in S.Complexes) == False + + assert (finiteset1234 in S.Naturals) == False + assert (finiteset1234 in S.Naturals0) == False + assert (finiteset1234 in S.Integers) == False + assert (finiteset1234 in S.Complexes) == False + + assert (pointComplex in S.Naturals) == False + assert (pointComplex in S.Naturals0) == False + assert (pointComplex in S.Integers) == False + assert (pointComplex in S.Complexes) == True + + +def test_issue_11730(): + unit = Interval(0, 1) + square = ComplexRegion(unit ** 2) + + assert Union(S.Complexes, FiniteSet(oo)) != S.Complexes + assert Union(S.Complexes, FiniteSet(eye(4))) != S.Complexes + assert Union(unit, square) == square + assert Intersection(S.Reals, square) == unit + + +def test_issue_11938(): + unit = Interval(0, 1) + ival = Interval(1, 2) + cr1 = ComplexRegion(ival * unit) + + assert Intersection(cr1, S.Reals) == ival + assert Intersection(cr1, unit) == FiniteSet(1) + + arg1 = Interval(0, S.Pi) + arg2 = FiniteSet(S.Pi) + arg3 = Interval(S.Pi / 4, 3 * S.Pi / 4) + cp1 = ComplexRegion(unit * arg1, polar=True) + cp2 = ComplexRegion(unit * arg2, polar=True) + cp3 = ComplexRegion(unit * arg3, polar=True) + + assert Intersection(cp1, S.Reals) == Interval(-1, 1) + assert Intersection(cp2, S.Reals) == Interval(-1, 0) + assert Intersection(cp3, S.Reals) == FiniteSet(0) + + +def test_issue_11914(): + a, b = Interval(0, 1), Interval(0, pi) + c, d = Interval(2, 3), Interval(pi, 3 * pi / 2) + cp1 = ComplexRegion(a * b, polar=True) + cp2 = ComplexRegion(c * d, polar=True) + + assert -3 in cp1.union(cp2) + assert -3 in cp2.union(cp1) + assert -5 not in cp1.union(cp2) + + +def test_issue_9543(): + assert ImageSet(Lambda(x, x**2), S.Naturals).is_subset(S.Reals) + + +def test_issue_16871(): + assert ImageSet(Lambda(x, x), FiniteSet(1)) == {1} + assert ImageSet(Lambda(x, x - 3), S.Integers + ).intersection(S.Integers) is S.Integers + + +@XFAIL +def test_issue_16871b(): + assert ImageSet(Lambda(x, x - 3), S.Integers).is_subset(S.Integers) + + +def test_issue_18050(): + assert imageset(Lambda(x, I*x + 1), S.Integers + ) == ImageSet(Lambda(x, I*x + 1), S.Integers) + assert imageset(Lambda(x, 3*I*x + 4 + 8*I), S.Integers + ) == ImageSet(Lambda(x, 3*I*x + 4 + 2*I), S.Integers) + # no 'Mod' for next 2 tests: + assert imageset(Lambda(x, 2*x + 3*I), S.Integers + ) == ImageSet(Lambda(x, 2*x + 3*I), S.Integers) + r = Symbol('r', positive=True) + assert imageset(Lambda(x, r*x + 10), S.Integers + ) == ImageSet(Lambda(x, r*x + 10), S.Integers) + # reduce real part: + assert imageset(Lambda(x, 3*x + 8 + 5*I), S.Integers + ) == ImageSet(Lambda(x, 3*x + 2 + 5*I), S.Integers) + + +def test_Rationals(): + assert S.Integers.is_subset(S.Rationals) + assert S.Naturals.is_subset(S.Rationals) + assert S.Naturals0.is_subset(S.Rationals) + assert S.Rationals.is_subset(S.Reals) + assert S.Rationals.inf is -oo + assert S.Rationals.sup is oo + it = iter(S.Rationals) + assert [next(it) for i in range(12)] == [ + 0, 1, -1, S.Half, 2, Rational(-1, 2), -2, + Rational(1, 3), 3, Rational(-1, 3), -3, Rational(2, 3)] + assert Basic() not in S.Rationals + assert S.Half in S.Rationals + assert S.Rationals.contains(0.5) == Contains( + 0.5, S.Rationals, evaluate=False) + assert 2 in S.Rationals + r = symbols('r', rational=True) + assert r in S.Rationals + raises(TypeError, lambda: x in S.Rationals) + # issue #18134: + assert S.Rationals.boundary == S.Reals + assert S.Rationals.closure == S.Reals + assert S.Rationals.is_open == False + assert S.Rationals.is_closed == False + + +def test_NZQRC_unions(): + # check that all trivial number set unions are simplified: + nbrsets = (S.Naturals, S.Naturals0, S.Integers, S.Rationals, + S.Reals, S.Complexes) + unions = (Union(a, b) for a in nbrsets for b in nbrsets) + assert all(u.is_Union is False for u in unions) + + +def test_imageset_intersection(): + n = Dummy() + s = ImageSet(Lambda(n, -I*(I*(2*pi*n - pi/4) + + log(Abs(sqrt(-I))))), S.Integers) + assert s.intersect(S.Reals) == ImageSet( + Lambda(n, 2*pi*n + pi*Rational(7, 4)), S.Integers) + + +def test_issue_17858(): + assert 1 in Range(-oo, oo) + assert 0 in Range(oo, -oo, -1) + assert oo not in Range(-oo, oo) + assert -oo not in Range(-oo, oo) + +def test_issue_17859(): + r = Range(-oo,oo) + raises(ValueError,lambda: r[::2]) + raises(ValueError, lambda: r[::-2]) + r = Range(oo,-oo,-1) + raises(ValueError,lambda: r[::2]) + raises(ValueError, lambda: r[::-2]) diff --git a/lib/python3.10/site-packages/sympy/sets/tests/test_ordinals.py b/lib/python3.10/site-packages/sympy/sets/tests/test_ordinals.py new file mode 100644 index 0000000000000000000000000000000000000000..973ca329586f3e904f9377c44022c266f81c805c --- /dev/null +++ b/lib/python3.10/site-packages/sympy/sets/tests/test_ordinals.py @@ -0,0 +1,67 @@ +from sympy.sets.ordinals import Ordinal, OmegaPower, ord0, omega +from sympy.testing.pytest import raises + +def test_string_ordinals(): + assert str(omega) == 'w' + assert str(Ordinal(OmegaPower(5, 3), OmegaPower(3, 2))) == 'w**5*3 + w**3*2' + assert str(Ordinal(OmegaPower(5, 3), OmegaPower(0, 5))) == 'w**5*3 + 5' + assert str(Ordinal(OmegaPower(1, 3), OmegaPower(0, 5))) == 'w*3 + 5' + assert str(Ordinal(OmegaPower(omega + 1, 1), OmegaPower(3, 2))) == 'w**(w + 1) + w**3*2' + +def test_addition_with_integers(): + assert 3 + Ordinal(OmegaPower(5, 3)) == Ordinal(OmegaPower(5, 3)) + assert Ordinal(OmegaPower(5, 3))+3 == Ordinal(OmegaPower(5, 3), OmegaPower(0, 3)) + assert Ordinal(OmegaPower(5, 3), OmegaPower(0, 2))+3 == \ + Ordinal(OmegaPower(5, 3), OmegaPower(0, 5)) + + +def test_addition_with_ordinals(): + assert Ordinal(OmegaPower(5, 3), OmegaPower(3, 2)) + Ordinal(OmegaPower(3, 3)) == \ + Ordinal(OmegaPower(5, 3), OmegaPower(3, 5)) + assert Ordinal(OmegaPower(5, 3), OmegaPower(3, 2)) + Ordinal(OmegaPower(4, 2)) == \ + Ordinal(OmegaPower(5, 3), OmegaPower(4, 2)) + assert Ordinal(OmegaPower(omega, 2), OmegaPower(3, 2)) + Ordinal(OmegaPower(4, 2)) == \ + Ordinal(OmegaPower(omega, 2), OmegaPower(4, 2)) + +def test_comparison(): + assert Ordinal(OmegaPower(5, 3)) > Ordinal(OmegaPower(4, 3), OmegaPower(2, 1)) + assert Ordinal(OmegaPower(5, 3), OmegaPower(3, 2)) < Ordinal(OmegaPower(5, 4)) + assert Ordinal(OmegaPower(5, 4)) < Ordinal(OmegaPower(5, 5), OmegaPower(4, 1)) + + assert Ordinal(OmegaPower(5, 3), OmegaPower(3, 2)) == \ + Ordinal(OmegaPower(5, 3), OmegaPower(3, 2)) + assert not Ordinal(OmegaPower(5, 3), OmegaPower(3, 2)) == Ordinal(OmegaPower(5, 3)) + assert Ordinal(OmegaPower(omega, 3)) > Ordinal(OmegaPower(5, 3)) + +def test_multiplication_with_integers(): + w = omega + assert 3*w == w + assert w*9 == Ordinal(OmegaPower(1, 9)) + +def test_multiplication(): + w = omega + assert w*(w + 1) == w*w + w + assert (w + 1)*(w + 1) == w*w + w + 1 + assert w*1 == w + assert 1*w == w + assert w*ord0 == ord0 + assert ord0*w == ord0 + assert w**w == w * w**w + assert (w**w)*w*w == w**(w + 2) + +def test_exponentiation(): + w = omega + assert w**2 == w*w + assert w**3 == w*w*w + assert w**(w + 1) == Ordinal(OmegaPower(omega + 1, 1)) + assert (w**w)*(w**w) == w**(w*2) + +def test_comapre_not_instance(): + w = OmegaPower(omega + 1, 1) + assert(not (w == None)) + assert(not (w < 5)) + raises(TypeError, lambda: w < 6.66) + +def test_is_successort(): + w = Ordinal(OmegaPower(5, 1)) + assert not w.is_successor_ordinal diff --git a/lib/python3.10/site-packages/sympy/sets/tests/test_powerset.py b/lib/python3.10/site-packages/sympy/sets/tests/test_powerset.py new file mode 100644 index 0000000000000000000000000000000000000000..2e3a407d565f6b9537a296af103ec0a4e137cff9 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/sets/tests/test_powerset.py @@ -0,0 +1,141 @@ +from sympy.core.expr import unchanged +from sympy.core.singleton import S +from sympy.core.symbol import Symbol +from sympy.sets.contains import Contains +from sympy.sets.fancysets import Interval +from sympy.sets.powerset import PowerSet +from sympy.sets.sets import FiniteSet +from sympy.testing.pytest import raises, XFAIL + + +def test_powerset_creation(): + assert unchanged(PowerSet, FiniteSet(1, 2)) + assert unchanged(PowerSet, S.EmptySet) + raises(ValueError, lambda: PowerSet(123)) + assert unchanged(PowerSet, S.Reals) + assert unchanged(PowerSet, S.Integers) + + +def test_powerset_rewrite_FiniteSet(): + assert PowerSet(FiniteSet(1, 2)).rewrite(FiniteSet) == \ + FiniteSet(S.EmptySet, FiniteSet(1), FiniteSet(2), FiniteSet(1, 2)) + assert PowerSet(S.EmptySet).rewrite(FiniteSet) == FiniteSet(S.EmptySet) + assert PowerSet(S.Naturals).rewrite(FiniteSet) == PowerSet(S.Naturals) + + +def test_finiteset_rewrite_powerset(): + assert FiniteSet(S.EmptySet).rewrite(PowerSet) == PowerSet(S.EmptySet) + assert FiniteSet( + S.EmptySet, FiniteSet(1), + FiniteSet(2), FiniteSet(1, 2)).rewrite(PowerSet) == \ + PowerSet(FiniteSet(1, 2)) + assert FiniteSet(1, 2, 3).rewrite(PowerSet) == FiniteSet(1, 2, 3) + + +def test_powerset__contains__(): + subset_series = [ + S.EmptySet, + FiniteSet(1, 2), + S.Naturals, + S.Naturals0, + S.Integers, + S.Rationals, + S.Reals, + S.Complexes] + + l = len(subset_series) + for i in range(l): + for j in range(l): + if i <= j: + assert subset_series[i] in \ + PowerSet(subset_series[j], evaluate=False) + else: + assert subset_series[i] not in \ + PowerSet(subset_series[j], evaluate=False) + + +@XFAIL +def test_failing_powerset__contains__(): + # XXX These are failing when evaluate=True, + # but using unevaluated PowerSet works fine. + assert FiniteSet(1, 2) not in PowerSet(S.EmptySet).rewrite(FiniteSet) + assert S.Naturals not in PowerSet(S.EmptySet).rewrite(FiniteSet) + assert S.Naturals not in PowerSet(FiniteSet(1, 2)).rewrite(FiniteSet) + assert S.Naturals0 not in PowerSet(S.EmptySet).rewrite(FiniteSet) + assert S.Naturals0 not in PowerSet(FiniteSet(1, 2)).rewrite(FiniteSet) + assert S.Integers not in PowerSet(S.EmptySet).rewrite(FiniteSet) + assert S.Integers not in PowerSet(FiniteSet(1, 2)).rewrite(FiniteSet) + assert S.Rationals not in PowerSet(S.EmptySet).rewrite(FiniteSet) + assert S.Rationals not in PowerSet(FiniteSet(1, 2)).rewrite(FiniteSet) + assert S.Reals not in PowerSet(S.EmptySet).rewrite(FiniteSet) + assert S.Reals not in PowerSet(FiniteSet(1, 2)).rewrite(FiniteSet) + assert S.Complexes not in PowerSet(S.EmptySet).rewrite(FiniteSet) + assert S.Complexes not in PowerSet(FiniteSet(1, 2)).rewrite(FiniteSet) + + +def test_powerset__len__(): + A = PowerSet(S.EmptySet, evaluate=False) + assert len(A) == 1 + A = PowerSet(A, evaluate=False) + assert len(A) == 2 + A = PowerSet(A, evaluate=False) + assert len(A) == 4 + A = PowerSet(A, evaluate=False) + assert len(A) == 16 + + +def test_powerset__iter__(): + a = PowerSet(FiniteSet(1, 2)).__iter__() + assert next(a) == S.EmptySet + assert next(a) == FiniteSet(1) + assert next(a) == FiniteSet(2) + assert next(a) == FiniteSet(1, 2) + + a = PowerSet(S.Naturals).__iter__() + assert next(a) == S.EmptySet + assert next(a) == FiniteSet(1) + assert next(a) == FiniteSet(2) + assert next(a) == FiniteSet(1, 2) + assert next(a) == FiniteSet(3) + assert next(a) == FiniteSet(1, 3) + assert next(a) == FiniteSet(2, 3) + assert next(a) == FiniteSet(1, 2, 3) + + +def test_powerset_contains(): + A = PowerSet(FiniteSet(1), evaluate=False) + assert A.contains(2) == Contains(2, A) + + x = Symbol('x') + + A = PowerSet(FiniteSet(x), evaluate=False) + assert A.contains(FiniteSet(1)) == Contains(FiniteSet(1), A) + + +def test_powerset_method(): + # EmptySet + A = FiniteSet() + pset = A.powerset() + assert len(pset) == 1 + assert pset == FiniteSet(S.EmptySet) + + # FiniteSets + A = FiniteSet(1, 2) + pset = A.powerset() + assert len(pset) == 2**len(A) + assert pset == FiniteSet(FiniteSet(), FiniteSet(1), + FiniteSet(2), A) + # Not finite sets + A = Interval(0, 1) + assert A.powerset() == PowerSet(A) + +def test_is_subset(): + # covers line 101-102 + # initialize powerset(1), which is a subset of powerset(1,2) + subset = PowerSet(FiniteSet(1)) + pset = PowerSet(FiniteSet(1, 2)) + bad_set = PowerSet(FiniteSet(2, 3)) + # assert "subset" is subset of pset == True + assert subset.is_subset(pset) + # assert "bad_set" is subset of pset == False + assert not pset.is_subset(bad_set) diff --git a/lib/python3.10/site-packages/sympy/sets/tests/test_setexpr.py b/lib/python3.10/site-packages/sympy/sets/tests/test_setexpr.py new file mode 100644 index 0000000000000000000000000000000000000000..faab1261c8d3e86901b04d30e8bc94de31642b93 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/sets/tests/test_setexpr.py @@ -0,0 +1,317 @@ +from sympy.sets.setexpr import SetExpr +from sympy.sets import Interval, FiniteSet, Intersection, ImageSet, Union + +from sympy.core.expr import Expr +from sympy.core.function import Lambda +from sympy.core.numbers import (I, Rational, oo) +from sympy.core.singleton import S +from sympy.core.symbol import (Dummy, Symbol, symbols) +from sympy.functions.elementary.exponential import (exp, log) +from sympy.functions.elementary.miscellaneous import (Max, Min, sqrt) +from sympy.functions.elementary.trigonometric import cos +from sympy.sets.sets import Set + + +a, x = symbols("a, x") +_d = Dummy("d") + + +def test_setexpr(): + se = SetExpr(Interval(0, 1)) + assert isinstance(se.set, Set) + assert isinstance(se, Expr) + + +def test_scalar_funcs(): + assert SetExpr(Interval(0, 1)).set == Interval(0, 1) + a, b = Symbol('a', real=True), Symbol('b', real=True) + a, b = 1, 2 + # TODO: add support for more functions in the future: + for f in [exp, log]: + input_se = f(SetExpr(Interval(a, b))) + output = input_se.set + expected = Interval(Min(f(a), f(b)), Max(f(a), f(b))) + assert output == expected + + +def test_Add_Mul(): + assert (SetExpr(Interval(0, 1)) + 1).set == Interval(1, 2) + assert (SetExpr(Interval(0, 1))*2).set == Interval(0, 2) + + +def test_Pow(): + assert (SetExpr(Interval(0, 2))**2).set == Interval(0, 4) + + +def test_compound(): + assert (exp(SetExpr(Interval(0, 1))*2 + 1)).set == \ + Interval(exp(1), exp(3)) + + +def test_Interval_Interval(): + assert (SetExpr(Interval(1, 2)) + SetExpr(Interval(10, 20))).set == \ + Interval(11, 22) + assert (SetExpr(Interval(1, 2))*SetExpr(Interval(10, 20))).set == \ + Interval(10, 40) + + +def test_FiniteSet_FiniteSet(): + assert (SetExpr(FiniteSet(1, 2, 3)) + SetExpr(FiniteSet(1, 2))).set == \ + FiniteSet(2, 3, 4, 5) + assert (SetExpr(FiniteSet(1, 2, 3))*SetExpr(FiniteSet(1, 2))).set == \ + FiniteSet(1, 2, 3, 4, 6) + + +def test_Interval_FiniteSet(): + assert (SetExpr(FiniteSet(1, 2)) + SetExpr(Interval(0, 10))).set == \ + Interval(1, 12) + + +def test_Many_Sets(): + assert (SetExpr(Interval(0, 1)) + + SetExpr(Interval(2, 3)) + + SetExpr(FiniteSet(10, 11, 12))).set == Interval(12, 16) + + +def test_same_setexprs_are_not_identical(): + a = SetExpr(FiniteSet(0, 1)) + b = SetExpr(FiniteSet(0, 1)) + assert (a + b).set == FiniteSet(0, 1, 2) + + # Cannot detect the set being the same: + # assert (a + a).set == FiniteSet(0, 2) + + +def test_Interval_arithmetic(): + i12cc = SetExpr(Interval(1, 2)) + i12lo = SetExpr(Interval.Lopen(1, 2)) + i12ro = SetExpr(Interval.Ropen(1, 2)) + i12o = SetExpr(Interval.open(1, 2)) + + n23cc = SetExpr(Interval(-2, 3)) + n23lo = SetExpr(Interval.Lopen(-2, 3)) + n23ro = SetExpr(Interval.Ropen(-2, 3)) + n23o = SetExpr(Interval.open(-2, 3)) + + n3n2cc = SetExpr(Interval(-3, -2)) + + assert i12cc + i12cc == SetExpr(Interval(2, 4)) + assert i12cc - i12cc == SetExpr(Interval(-1, 1)) + assert i12cc*i12cc == SetExpr(Interval(1, 4)) + assert i12cc/i12cc == SetExpr(Interval(S.Half, 2)) + assert i12cc**2 == SetExpr(Interval(1, 4)) + assert i12cc**3 == SetExpr(Interval(1, 8)) + + assert i12lo + i12ro == SetExpr(Interval.open(2, 4)) + assert i12lo - i12ro == SetExpr(Interval.Lopen(-1, 1)) + assert i12lo*i12ro == SetExpr(Interval.open(1, 4)) + assert i12lo/i12ro == SetExpr(Interval.Lopen(S.Half, 2)) + assert i12lo + i12lo == SetExpr(Interval.Lopen(2, 4)) + assert i12lo - i12lo == SetExpr(Interval.open(-1, 1)) + assert i12lo*i12lo == SetExpr(Interval.Lopen(1, 4)) + assert i12lo/i12lo == SetExpr(Interval.open(S.Half, 2)) + assert i12lo + i12cc == SetExpr(Interval.Lopen(2, 4)) + assert i12lo - i12cc == SetExpr(Interval.Lopen(-1, 1)) + assert i12lo*i12cc == SetExpr(Interval.Lopen(1, 4)) + assert i12lo/i12cc == SetExpr(Interval.Lopen(S.Half, 2)) + assert i12lo + i12o == SetExpr(Interval.open(2, 4)) + assert i12lo - i12o == SetExpr(Interval.open(-1, 1)) + assert i12lo*i12o == SetExpr(Interval.open(1, 4)) + assert i12lo/i12o == SetExpr(Interval.open(S.Half, 2)) + assert i12lo**2 == SetExpr(Interval.Lopen(1, 4)) + assert i12lo**3 == SetExpr(Interval.Lopen(1, 8)) + + assert i12ro + i12ro == SetExpr(Interval.Ropen(2, 4)) + assert i12ro - i12ro == SetExpr(Interval.open(-1, 1)) + assert i12ro*i12ro == SetExpr(Interval.Ropen(1, 4)) + assert i12ro/i12ro == SetExpr(Interval.open(S.Half, 2)) + assert i12ro + i12cc == SetExpr(Interval.Ropen(2, 4)) + assert i12ro - i12cc == SetExpr(Interval.Ropen(-1, 1)) + assert i12ro*i12cc == SetExpr(Interval.Ropen(1, 4)) + assert i12ro/i12cc == SetExpr(Interval.Ropen(S.Half, 2)) + assert i12ro + i12o == SetExpr(Interval.open(2, 4)) + assert i12ro - i12o == SetExpr(Interval.open(-1, 1)) + assert i12ro*i12o == SetExpr(Interval.open(1, 4)) + assert i12ro/i12o == SetExpr(Interval.open(S.Half, 2)) + assert i12ro**2 == SetExpr(Interval.Ropen(1, 4)) + assert i12ro**3 == SetExpr(Interval.Ropen(1, 8)) + + assert i12o + i12lo == SetExpr(Interval.open(2, 4)) + assert i12o - i12lo == SetExpr(Interval.open(-1, 1)) + assert i12o*i12lo == SetExpr(Interval.open(1, 4)) + assert i12o/i12lo == SetExpr(Interval.open(S.Half, 2)) + assert i12o + i12ro == SetExpr(Interval.open(2, 4)) + assert i12o - i12ro == SetExpr(Interval.open(-1, 1)) + assert i12o*i12ro == SetExpr(Interval.open(1, 4)) + assert i12o/i12ro == SetExpr(Interval.open(S.Half, 2)) + assert i12o + i12cc == SetExpr(Interval.open(2, 4)) + assert i12o - i12cc == SetExpr(Interval.open(-1, 1)) + assert i12o*i12cc == SetExpr(Interval.open(1, 4)) + assert i12o/i12cc == SetExpr(Interval.open(S.Half, 2)) + assert i12o**2 == SetExpr(Interval.open(1, 4)) + assert i12o**3 == SetExpr(Interval.open(1, 8)) + + assert n23cc + n23cc == SetExpr(Interval(-4, 6)) + assert n23cc - n23cc == SetExpr(Interval(-5, 5)) + assert n23cc*n23cc == SetExpr(Interval(-6, 9)) + assert n23cc/n23cc == SetExpr(Interval.open(-oo, oo)) + assert n23cc + n23ro == SetExpr(Interval.Ropen(-4, 6)) + assert n23cc - n23ro == SetExpr(Interval.Lopen(-5, 5)) + assert n23cc*n23ro == SetExpr(Interval.Ropen(-6, 9)) + assert n23cc/n23ro == SetExpr(Interval.Lopen(-oo, oo)) + assert n23cc + n23lo == SetExpr(Interval.Lopen(-4, 6)) + assert n23cc - n23lo == SetExpr(Interval.Ropen(-5, 5)) + assert n23cc*n23lo == SetExpr(Interval(-6, 9)) + assert n23cc/n23lo == SetExpr(Interval.open(-oo, oo)) + assert n23cc + n23o == SetExpr(Interval.open(-4, 6)) + assert n23cc - n23o == SetExpr(Interval.open(-5, 5)) + assert n23cc*n23o == SetExpr(Interval.open(-6, 9)) + assert n23cc/n23o == SetExpr(Interval.open(-oo, oo)) + assert n23cc**2 == SetExpr(Interval(0, 9)) + assert n23cc**3 == SetExpr(Interval(-8, 27)) + + n32cc = SetExpr(Interval(-3, 2)) + n32lo = SetExpr(Interval.Lopen(-3, 2)) + n32ro = SetExpr(Interval.Ropen(-3, 2)) + assert n32cc*n32lo == SetExpr(Interval.Ropen(-6, 9)) + assert n32cc*n32cc == SetExpr(Interval(-6, 9)) + assert n32lo*n32cc == SetExpr(Interval.Ropen(-6, 9)) + assert n32cc*n32ro == SetExpr(Interval(-6, 9)) + assert n32lo*n32ro == SetExpr(Interval.Ropen(-6, 9)) + assert n32cc/n32lo == SetExpr(Interval.Ropen(-oo, oo)) + assert i12cc/n32lo == SetExpr(Interval.Ropen(-oo, oo)) + + assert n3n2cc**2 == SetExpr(Interval(4, 9)) + assert n3n2cc**3 == SetExpr(Interval(-27, -8)) + + assert n23cc + i12cc == SetExpr(Interval(-1, 5)) + assert n23cc - i12cc == SetExpr(Interval(-4, 2)) + assert n23cc*i12cc == SetExpr(Interval(-4, 6)) + assert n23cc/i12cc == SetExpr(Interval(-2, 3)) + + +def test_SetExpr_Intersection(): + x, y, z, w = symbols("x y z w") + set1 = Interval(x, y) + set2 = Interval(w, z) + inter = Intersection(set1, set2) + se = SetExpr(inter) + assert exp(se).set == Intersection( + ImageSet(Lambda(x, exp(x)), set1), + ImageSet(Lambda(x, exp(x)), set2)) + assert cos(se).set == ImageSet(Lambda(x, cos(x)), inter) + + +def test_SetExpr_Interval_div(): + # TODO: some expressions cannot be calculated due to bugs (currently + # commented): + assert SetExpr(Interval(-3, -2))/SetExpr(Interval(-2, 1)) == SetExpr(Interval(-oo, oo)) + assert SetExpr(Interval(2, 3))/SetExpr(Interval(-2, 2)) == SetExpr(Interval(-oo, oo)) + + assert SetExpr(Interval(-3, -2))/SetExpr(Interval(0, 4)) == SetExpr(Interval(-oo, Rational(-1, 2))) + assert SetExpr(Interval(2, 4))/SetExpr(Interval(-3, 0)) == SetExpr(Interval(-oo, Rational(-2, 3))) + assert SetExpr(Interval(2, 4))/SetExpr(Interval(0, 3)) == SetExpr(Interval(Rational(2, 3), oo)) + + # assert SetExpr(Interval(0, 1))/SetExpr(Interval(0, 1)) == SetExpr(Interval(0, oo)) + # assert SetExpr(Interval(-1, 0))/SetExpr(Interval(0, 1)) == SetExpr(Interval(-oo, 0)) + assert SetExpr(Interval(-1, 2))/SetExpr(Interval(-2, 2)) == SetExpr(Interval(-oo, oo)) + + assert 1/SetExpr(Interval(-1, 2)) == SetExpr(Union(Interval(-oo, -1), Interval(S.Half, oo))) + + assert 1/SetExpr(Interval(0, 2)) == SetExpr(Interval(S.Half, oo)) + assert (-1)/SetExpr(Interval(0, 2)) == SetExpr(Interval(-oo, Rational(-1, 2))) + assert 1/SetExpr(Interval(-oo, 0)) == SetExpr(Interval.open(-oo, 0)) + assert 1/SetExpr(Interval(-1, 0)) == SetExpr(Interval(-oo, -1)) + # assert (-2)/SetExpr(Interval(-oo, 0)) == SetExpr(Interval(0, oo)) + # assert 1/SetExpr(Interval(-oo, -1)) == SetExpr(Interval(-1, 0)) + + # assert SetExpr(Interval(1, 2))/a == Mul(SetExpr(Interval(1, 2)), 1/a, evaluate=False) + + # assert SetExpr(Interval(1, 2))/0 == SetExpr(Interval(1, 2))*zoo + # assert SetExpr(Interval(1, oo))/oo == SetExpr(Interval(0, oo)) + # assert SetExpr(Interval(1, oo))/(-oo) == SetExpr(Interval(-oo, 0)) + # assert SetExpr(Interval(-oo, -1))/oo == SetExpr(Interval(-oo, 0)) + # assert SetExpr(Interval(-oo, -1))/(-oo) == SetExpr(Interval(0, oo)) + # assert SetExpr(Interval(-oo, oo))/oo == SetExpr(Interval(-oo, oo)) + # assert SetExpr(Interval(-oo, oo))/(-oo) == SetExpr(Interval(-oo, oo)) + # assert SetExpr(Interval(-1, oo))/oo == SetExpr(Interval(0, oo)) + # assert SetExpr(Interval(-1, oo))/(-oo) == SetExpr(Interval(-oo, 0)) + # assert SetExpr(Interval(-oo, 1))/oo == SetExpr(Interval(-oo, 0)) + # assert SetExpr(Interval(-oo, 1))/(-oo) == SetExpr(Interval(0, oo)) + + +def test_SetExpr_Interval_pow(): + assert SetExpr(Interval(0, 2))**2 == SetExpr(Interval(0, 4)) + assert SetExpr(Interval(-1, 1))**2 == SetExpr(Interval(0, 1)) + assert SetExpr(Interval(1, 2))**2 == SetExpr(Interval(1, 4)) + assert SetExpr(Interval(-1, 2))**3 == SetExpr(Interval(-1, 8)) + assert SetExpr(Interval(-1, 1))**0 == SetExpr(FiniteSet(1)) + + + assert SetExpr(Interval(1, 2))**Rational(5, 2) == SetExpr(Interval(1, 4*sqrt(2))) + #assert SetExpr(Interval(-1, 2))**Rational(1, 3) == SetExpr(Interval(-1, 2**Rational(1, 3))) + #assert SetExpr(Interval(0, 2))**S.Half == SetExpr(Interval(0, sqrt(2))) + + #assert SetExpr(Interval(-4, 2))**Rational(2, 3) == SetExpr(Interval(0, 2*2**Rational(1, 3))) + + #assert SetExpr(Interval(-1, 5))**S.Half == SetExpr(Interval(0, sqrt(5))) + #assert SetExpr(Interval(-oo, 2))**S.Half == SetExpr(Interval(0, sqrt(2))) + #assert SetExpr(Interval(-2, 3))**(Rational(-1, 4)) == SetExpr(Interval(0, oo)) + + assert SetExpr(Interval(1, 5))**(-2) == SetExpr(Interval(Rational(1, 25), 1)) + assert SetExpr(Interval(-1, 3))**(-2) == SetExpr(Interval(0, oo)) + + assert SetExpr(Interval(0, 2))**(-2) == SetExpr(Interval(Rational(1, 4), oo)) + assert SetExpr(Interval(-1, 2))**(-3) == SetExpr(Union(Interval(-oo, -1), Interval(Rational(1, 8), oo))) + assert SetExpr(Interval(-3, -2))**(-3) == SetExpr(Interval(Rational(-1, 8), Rational(-1, 27))) + assert SetExpr(Interval(-3, -2))**(-2) == SetExpr(Interval(Rational(1, 9), Rational(1, 4))) + #assert SetExpr(Interval(0, oo))**S.Half == SetExpr(Interval(0, oo)) + #assert SetExpr(Interval(-oo, -1))**Rational(1, 3) == SetExpr(Interval(-oo, -1)) + #assert SetExpr(Interval(-2, 3))**(Rational(-1, 3)) == SetExpr(Interval(-oo, oo)) + + assert SetExpr(Interval(-oo, 0))**(-2) == SetExpr(Interval.open(0, oo)) + assert SetExpr(Interval(-2, 0))**(-2) == SetExpr(Interval(Rational(1, 4), oo)) + + assert SetExpr(Interval(Rational(1, 3), S.Half))**oo == SetExpr(FiniteSet(0)) + assert SetExpr(Interval(0, S.Half))**oo == SetExpr(FiniteSet(0)) + assert SetExpr(Interval(S.Half, 1))**oo == SetExpr(Interval(0, oo)) + assert SetExpr(Interval(0, 1))**oo == SetExpr(Interval(0, oo)) + assert SetExpr(Interval(2, 3))**oo == SetExpr(FiniteSet(oo)) + assert SetExpr(Interval(1, 2))**oo == SetExpr(Interval(0, oo)) + assert SetExpr(Interval(S.Half, 3))**oo == SetExpr(Interval(0, oo)) + assert SetExpr(Interval(Rational(-1, 3), Rational(-1, 4)))**oo == SetExpr(FiniteSet(0)) + assert SetExpr(Interval(-1, Rational(-1, 2)))**oo == SetExpr(Interval(-oo, oo)) + assert SetExpr(Interval(-3, -2))**oo == SetExpr(FiniteSet(-oo, oo)) + assert SetExpr(Interval(-2, -1))**oo == SetExpr(Interval(-oo, oo)) + assert SetExpr(Interval(-2, Rational(-1, 2)))**oo == SetExpr(Interval(-oo, oo)) + assert SetExpr(Interval(Rational(-1, 2), S.Half))**oo == SetExpr(FiniteSet(0)) + assert SetExpr(Interval(Rational(-1, 2), 1))**oo == SetExpr(Interval(0, oo)) + assert SetExpr(Interval(Rational(-2, 3), 2))**oo == SetExpr(Interval(0, oo)) + assert SetExpr(Interval(-1, 1))**oo == SetExpr(Interval(-oo, oo)) + assert SetExpr(Interval(-1, S.Half))**oo == SetExpr(Interval(-oo, oo)) + assert SetExpr(Interval(-1, 2))**oo == SetExpr(Interval(-oo, oo)) + assert SetExpr(Interval(-2, S.Half))**oo == SetExpr(Interval(-oo, oo)) + + assert (SetExpr(Interval(1, 2))**x).dummy_eq(SetExpr(ImageSet(Lambda(_d, _d**x), Interval(1, 2)))) + + assert SetExpr(Interval(2, 3))**(-oo) == SetExpr(FiniteSet(0)) + assert SetExpr(Interval(0, 2))**(-oo) == SetExpr(Interval(0, oo)) + assert (SetExpr(Interval(-1, 2))**(-oo)).dummy_eq(SetExpr(ImageSet(Lambda(_d, _d**(-oo)), Interval(-1, 2)))) + + +def test_SetExpr_Integers(): + assert SetExpr(S.Integers) + 1 == SetExpr(S.Integers) + assert (SetExpr(S.Integers) + I).dummy_eq( + SetExpr(ImageSet(Lambda(_d, _d + I), S.Integers))) + assert SetExpr(S.Integers)*(-1) == SetExpr(S.Integers) + assert (SetExpr(S.Integers)*2).dummy_eq( + SetExpr(ImageSet(Lambda(_d, 2*_d), S.Integers))) + assert (SetExpr(S.Integers)*I).dummy_eq( + SetExpr(ImageSet(Lambda(_d, I*_d), S.Integers))) + # issue #18050: + assert SetExpr(S.Integers)._eval_func(Lambda(x, I*x + 1)).dummy_eq( + SetExpr(ImageSet(Lambda(_d, I*_d + 1), S.Integers))) + # needs improvement: + assert (SetExpr(S.Integers)*I + 1).dummy_eq( + SetExpr(ImageSet(Lambda(x, x + 1), + ImageSet(Lambda(_d, _d*I), S.Integers)))) diff --git a/lib/python3.10/site-packages/sympy/sets/tests/test_sets.py b/lib/python3.10/site-packages/sympy/sets/tests/test_sets.py new file mode 100644 index 0000000000000000000000000000000000000000..657ab19a90eb88ca48f266f7a5cf050504caed43 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/sets/tests/test_sets.py @@ -0,0 +1,1753 @@ +from sympy.concrete.summations import Sum +from sympy.core.add import Add +from sympy.core.containers import TupleKind +from sympy.core.function import Lambda +from sympy.core.kind import NumberKind, UndefinedKind +from sympy.core.numbers import (Float, I, Rational, nan, oo, pi, zoo) +from sympy.core.power import Pow +from sympy.core.singleton import S +from sympy.core.symbol import (Symbol, symbols) +from sympy.core.sympify import sympify +from sympy.functions.elementary.miscellaneous import (Max, Min, sqrt) +from sympy.functions.elementary.piecewise import Piecewise +from sympy.functions.elementary.trigonometric import (cos, sin) +from sympy.logic.boolalg import (false, true) +from sympy.matrices.kind import MatrixKind +from sympy.matrices.dense import Matrix +from sympy.polys.rootoftools import rootof +from sympy.sets.contains import Contains +from sympy.sets.fancysets import (ImageSet, Range) +from sympy.sets.sets import (Complement, DisjointUnion, FiniteSet, Intersection, Interval, ProductSet, Set, SymmetricDifference, Union, imageset, SetKind) +from mpmath import mpi + +from sympy.core.expr import unchanged +from sympy.core.relational import Eq, Ne, Le, Lt, LessThan +from sympy.logic import And, Or, Xor +from sympy.testing.pytest import raises, XFAIL, warns_deprecated_sympy +from sympy.utilities.iterables import cartes + +from sympy.abc import x, y, z, m, n + +EmptySet = S.EmptySet + +def test_imageset(): + ints = S.Integers + assert imageset(x, x - 1, S.Naturals) is S.Naturals0 + assert imageset(x, x + 1, S.Naturals0) is S.Naturals + assert imageset(x, abs(x), S.Naturals0) is S.Naturals0 + assert imageset(x, abs(x), S.Naturals) is S.Naturals + assert imageset(x, abs(x), S.Integers) is S.Naturals0 + # issue 16878a + r = symbols('r', real=True) + assert imageset(x, (x, x), S.Reals)._contains((1, r)) == None + assert imageset(x, (x, x), S.Reals)._contains((1, 2)) == False + assert (r, r) in imageset(x, (x, x), S.Reals) + assert 1 + I in imageset(x, x + I, S.Reals) + assert {1} not in imageset(x, (x,), S.Reals) + assert (1, 1) not in imageset(x, (x,), S.Reals) + raises(TypeError, lambda: imageset(x, ints)) + raises(ValueError, lambda: imageset(x, y, z, ints)) + raises(ValueError, lambda: imageset(Lambda(x, cos(x)), y)) + assert (1, 2) in imageset(Lambda((x, y), (x, y)), ints, ints) + raises(ValueError, lambda: imageset(Lambda(x, x), ints, ints)) + assert imageset(cos, ints) == ImageSet(Lambda(x, cos(x)), ints) + def f(x): + return cos(x) + assert imageset(f, ints) == imageset(x, cos(x), ints) + f = lambda x: cos(x) + assert imageset(f, ints) == ImageSet(Lambda(x, cos(x)), ints) + assert imageset(x, 1, ints) == FiniteSet(1) + assert imageset(x, y, ints) == {y} + assert imageset((x, y), (1, z), ints, S.Reals) == {(1, z)} + clash = Symbol('x', integer=true) + assert (str(imageset(lambda x: x + clash, Interval(-2, 1)).lamda.expr) + in ('x0 + x', 'x + x0')) + x1, x2 = symbols("x1, x2") + assert imageset(lambda x, y: + Add(x, y), Interval(1, 2), Interval(2, 3)).dummy_eq( + ImageSet(Lambda((x1, x2), x1 + x2), + Interval(1, 2), Interval(2, 3))) + + +def test_is_empty(): + for s in [S.Naturals, S.Naturals0, S.Integers, S.Rationals, S.Reals, + S.UniversalSet]: + assert s.is_empty is False + + assert S.EmptySet.is_empty is True + + +def test_is_finiteset(): + for s in [S.Naturals, S.Naturals0, S.Integers, S.Rationals, S.Reals, + S.UniversalSet]: + assert s.is_finite_set is False + + assert S.EmptySet.is_finite_set is True + + assert FiniteSet(1, 2).is_finite_set is True + assert Interval(1, 2).is_finite_set is False + assert Interval(x, y).is_finite_set is None + assert ProductSet(FiniteSet(1), FiniteSet(2)).is_finite_set is True + assert ProductSet(FiniteSet(1), Interval(1, 2)).is_finite_set is False + assert ProductSet(FiniteSet(1), Interval(x, y)).is_finite_set is None + assert Union(Interval(0, 1), Interval(2, 3)).is_finite_set is False + assert Union(FiniteSet(1), Interval(2, 3)).is_finite_set is False + assert Union(FiniteSet(1), FiniteSet(2)).is_finite_set is True + assert Union(FiniteSet(1), Interval(x, y)).is_finite_set is None + assert Intersection(Interval(x, y), FiniteSet(1)).is_finite_set is True + assert Intersection(Interval(x, y), Interval(1, 2)).is_finite_set is None + assert Intersection(FiniteSet(x), FiniteSet(y)).is_finite_set is True + assert Complement(FiniteSet(1), Interval(x, y)).is_finite_set is True + assert Complement(Interval(x, y), FiniteSet(1)).is_finite_set is None + assert Complement(Interval(1, 2), FiniteSet(x)).is_finite_set is False + assert DisjointUnion(Interval(-5, 3), FiniteSet(x, y)).is_finite_set is False + assert DisjointUnion(S.EmptySet, FiniteSet(x, y), S.EmptySet).is_finite_set is True + + +def test_deprecated_is_EmptySet(): + with warns_deprecated_sympy(): + S.EmptySet.is_EmptySet + + with warns_deprecated_sympy(): + FiniteSet(1).is_EmptySet + + +def test_interval_arguments(): + assert Interval(0, oo) == Interval(0, oo, False, True) + assert Interval(0, oo).right_open is true + assert Interval(-oo, 0) == Interval(-oo, 0, True, False) + assert Interval(-oo, 0).left_open is true + assert Interval(oo, -oo) == S.EmptySet + assert Interval(oo, oo) == S.EmptySet + assert Interval(-oo, -oo) == S.EmptySet + assert Interval(oo, x) == S.EmptySet + assert Interval(oo, oo) == S.EmptySet + assert Interval(x, -oo) == S.EmptySet + assert Interval(x, x) == {x} + + assert isinstance(Interval(1, 1), FiniteSet) + e = Sum(x, (x, 1, 3)) + assert isinstance(Interval(e, e), FiniteSet) + + assert Interval(1, 0) == S.EmptySet + assert Interval(1, 1).measure == 0 + + assert Interval(1, 1, False, True) == S.EmptySet + assert Interval(1, 1, True, False) == S.EmptySet + assert Interval(1, 1, True, True) == S.EmptySet + + + assert isinstance(Interval(0, Symbol('a')), Interval) + assert Interval(Symbol('a', positive=True), 0) == S.EmptySet + raises(ValueError, lambda: Interval(0, S.ImaginaryUnit)) + raises(ValueError, lambda: Interval(0, Symbol('z', extended_real=False))) + raises(ValueError, lambda: Interval(x, x + S.ImaginaryUnit)) + + raises(NotImplementedError, lambda: Interval(0, 1, And(x, y))) + raises(NotImplementedError, lambda: Interval(0, 1, False, And(x, y))) + raises(NotImplementedError, lambda: Interval(0, 1, z, And(x, y))) + + +def test_interval_symbolic_end_points(): + a = Symbol('a', real=True) + + assert Union(Interval(0, a), Interval(0, 3)).sup == Max(a, 3) + assert Union(Interval(a, 0), Interval(-3, 0)).inf == Min(-3, a) + + assert Interval(0, a).contains(1) == LessThan(1, a) + + +def test_interval_is_empty(): + x, y = symbols('x, y') + r = Symbol('r', real=True) + p = Symbol('p', positive=True) + n = Symbol('n', negative=True) + nn = Symbol('nn', nonnegative=True) + assert Interval(1, 2).is_empty == False + assert Interval(3, 3).is_empty == False # FiniteSet + assert Interval(r, r).is_empty == False # FiniteSet + assert Interval(r, r + nn).is_empty == False + assert Interval(x, x).is_empty == False + assert Interval(1, oo).is_empty == False + assert Interval(-oo, oo).is_empty == False + assert Interval(-oo, 1).is_empty == False + assert Interval(x, y).is_empty == None + assert Interval(r, oo).is_empty == False # real implies finite + assert Interval(n, 0).is_empty == False + assert Interval(n, 0, left_open=True).is_empty == False + assert Interval(p, 0).is_empty == True # EmptySet + assert Interval(nn, 0).is_empty == None + assert Interval(n, p).is_empty == False + assert Interval(0, p, left_open=True).is_empty == False + assert Interval(0, p, right_open=True).is_empty == False + assert Interval(0, nn, left_open=True).is_empty == None + assert Interval(0, nn, right_open=True).is_empty == None + + +def test_union(): + assert Union(Interval(1, 2), Interval(2, 3)) == Interval(1, 3) + assert Union(Interval(1, 2), Interval(2, 3, True)) == Interval(1, 3) + assert Union(Interval(1, 3), Interval(2, 4)) == Interval(1, 4) + assert Union(Interval(1, 2), Interval(1, 3)) == Interval(1, 3) + assert Union(Interval(1, 3), Interval(1, 2)) == Interval(1, 3) + assert Union(Interval(1, 3, False, True), Interval(1, 2)) == \ + Interval(1, 3, False, True) + assert Union(Interval(1, 3), Interval(1, 2, False, True)) == Interval(1, 3) + assert Union(Interval(1, 2, True), Interval(1, 3)) == Interval(1, 3) + assert Union(Interval(1, 2, True), Interval(1, 3, True)) == \ + Interval(1, 3, True) + assert Union(Interval(1, 2, True), Interval(1, 3, True, True)) == \ + Interval(1, 3, True, True) + assert Union(Interval(1, 2, True, True), Interval(1, 3, True)) == \ + Interval(1, 3, True) + assert Union(Interval(1, 3), Interval(2, 3)) == Interval(1, 3) + assert Union(Interval(1, 3, False, True), Interval(2, 3)) == \ + Interval(1, 3) + assert Union(Interval(1, 2, False, True), Interval(2, 3, True)) != \ + Interval(1, 3) + assert Union(Interval(1, 2), S.EmptySet) == Interval(1, 2) + assert Union(S.EmptySet) == S.EmptySet + + assert Union(Interval(0, 1), *[FiniteSet(1.0/n) for n in range(1, 10)]) == \ + Interval(0, 1) + # issue #18241: + x = Symbol('x') + assert Union(Interval(0, 1), FiniteSet(1, x)) == Union( + Interval(0, 1), FiniteSet(x)) + assert unchanged(Union, Interval(0, 1), FiniteSet(2, x)) + + assert Interval(1, 2).union(Interval(2, 3)) == \ + Interval(1, 2) + Interval(2, 3) + + assert Interval(1, 2).union(Interval(2, 3)) == Interval(1, 3) + + assert Union(Set()) == Set() + + assert FiniteSet(1) + FiniteSet(2) + FiniteSet(3) == FiniteSet(1, 2, 3) + assert FiniteSet('ham') + FiniteSet('eggs') == FiniteSet('ham', 'eggs') + assert FiniteSet(1, 2, 3) + S.EmptySet == FiniteSet(1, 2, 3) + + assert FiniteSet(1, 2, 3) & FiniteSet(2, 3, 4) == FiniteSet(2, 3) + assert FiniteSet(1, 2, 3) | FiniteSet(2, 3, 4) == FiniteSet(1, 2, 3, 4) + + assert FiniteSet(1, 2, 3) & S.EmptySet == S.EmptySet + assert FiniteSet(1, 2, 3) | S.EmptySet == FiniteSet(1, 2, 3) + + x = Symbol("x") + y = Symbol("y") + z = Symbol("z") + assert S.EmptySet | FiniteSet(x, FiniteSet(y, z)) == \ + FiniteSet(x, FiniteSet(y, z)) + + # Test that Intervals and FiniteSets play nicely + assert Interval(1, 3) + FiniteSet(2) == Interval(1, 3) + assert Interval(1, 3, True, True) + FiniteSet(3) == \ + Interval(1, 3, True, False) + X = Interval(1, 3) + FiniteSet(5) + Y = Interval(1, 2) + FiniteSet(3) + XandY = X.intersect(Y) + assert 2 in X and 3 in X and 3 in XandY + assert XandY.is_subset(X) and XandY.is_subset(Y) + + raises(TypeError, lambda: Union(1, 2, 3)) + + assert X.is_iterable is False + + # issue 7843 + assert Union(S.EmptySet, FiniteSet(-sqrt(-I), sqrt(-I))) == \ + FiniteSet(-sqrt(-I), sqrt(-I)) + + assert Union(S.Reals, S.Integers) == S.Reals + + +def test_union_iter(): + # Use Range because it is ordered + u = Union(Range(3), Range(5), Range(4), evaluate=False) + + # Round robin + assert list(u) == [0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 4] + + +def test_union_is_empty(): + assert (Interval(x, y) + FiniteSet(1)).is_empty == False + assert (Interval(x, y) + Interval(-x, y)).is_empty == None + + +def test_difference(): + assert Interval(1, 3) - Interval(1, 2) == Interval(2, 3, True) + assert Interval(1, 3) - Interval(2, 3) == Interval(1, 2, False, True) + assert Interval(1, 3, True) - Interval(2, 3) == Interval(1, 2, True, True) + assert Interval(1, 3, True) - Interval(2, 3, True) == \ + Interval(1, 2, True, False) + assert Interval(0, 2) - FiniteSet(1) == \ + Union(Interval(0, 1, False, True), Interval(1, 2, True, False)) + + # issue #18119 + assert S.Reals - FiniteSet(I) == S.Reals + assert S.Reals - FiniteSet(-I, I) == S.Reals + assert Interval(0, 10) - FiniteSet(-I, I) == Interval(0, 10) + assert Interval(0, 10) - FiniteSet(1, I) == Union( + Interval.Ropen(0, 1), Interval.Lopen(1, 10)) + assert S.Reals - FiniteSet(1, 2 + I, x, y**2) == Complement( + Union(Interval.open(-oo, 1), Interval.open(1, oo)), FiniteSet(x, y**2), + evaluate=False) + + assert FiniteSet(1, 2, 3) - FiniteSet(2) == FiniteSet(1, 3) + assert FiniteSet('ham', 'eggs') - FiniteSet('eggs') == FiniteSet('ham') + assert FiniteSet(1, 2, 3, 4) - Interval(2, 10, True, False) == \ + FiniteSet(1, 2) + assert FiniteSet(1, 2, 3, 4) - S.EmptySet == FiniteSet(1, 2, 3, 4) + assert Union(Interval(0, 2), FiniteSet(2, 3, 4)) - Interval(1, 3) == \ + Union(Interval(0, 1, False, True), FiniteSet(4)) + + assert -1 in S.Reals - S.Naturals + + +def test_Complement(): + A = FiniteSet(1, 3, 4) + B = FiniteSet(3, 4) + C = Interval(1, 3) + D = Interval(1, 2) + + assert Complement(A, B, evaluate=False).is_iterable is True + assert Complement(A, C, evaluate=False).is_iterable is True + assert Complement(C, D, evaluate=False).is_iterable is None + + assert FiniteSet(*Complement(A, B, evaluate=False)) == FiniteSet(1) + assert FiniteSet(*Complement(A, C, evaluate=False)) == FiniteSet(4) + raises(TypeError, lambda: FiniteSet(*Complement(C, A, evaluate=False))) + + assert Complement(Interval(1, 3), Interval(1, 2)) == Interval(2, 3, True) + assert Complement(FiniteSet(1, 3, 4), FiniteSet(3, 4)) == FiniteSet(1) + assert Complement(Union(Interval(0, 2), FiniteSet(2, 3, 4)), + Interval(1, 3)) == \ + Union(Interval(0, 1, False, True), FiniteSet(4)) + + assert 3 not in Complement(Interval(0, 5), Interval(1, 4), evaluate=False) + assert -1 in Complement(S.Reals, S.Naturals, evaluate=False) + assert 1 not in Complement(S.Reals, S.Naturals, evaluate=False) + + assert Complement(S.Integers, S.UniversalSet) == EmptySet + assert S.UniversalSet.complement(S.Integers) == EmptySet + + assert (0 not in S.Reals.intersect(S.Integers - FiniteSet(0))) + + assert S.EmptySet - S.Integers == S.EmptySet + + assert (S.Integers - FiniteSet(0)) - FiniteSet(1) == S.Integers - FiniteSet(0, 1) + + assert S.Reals - Union(S.Naturals, FiniteSet(pi)) == \ + Intersection(S.Reals - S.Naturals, S.Reals - FiniteSet(pi)) + # issue 12712 + assert Complement(FiniteSet(x, y, 2), Interval(-10, 10)) == \ + Complement(FiniteSet(x, y), Interval(-10, 10)) + + A = FiniteSet(*symbols('a:c')) + B = FiniteSet(*symbols('d:f')) + assert unchanged(Complement, ProductSet(A, A), B) + + A2 = ProductSet(A, A) + B3 = ProductSet(B, B, B) + assert A2 - B3 == A2 + assert B3 - A2 == B3 + + +def test_set_operations_nonsets(): + '''Tests that e.g. FiniteSet(1) * 2 raises TypeError''' + ops = [ + lambda a, b: a + b, + lambda a, b: a - b, + lambda a, b: a * b, + lambda a, b: a / b, + lambda a, b: a // b, + lambda a, b: a | b, + lambda a, b: a & b, + lambda a, b: a ^ b, + # FiniteSet(1) ** 2 gives a ProductSet + #lambda a, b: a ** b, + ] + Sx = FiniteSet(x) + Sy = FiniteSet(y) + sets = [ + {1}, + FiniteSet(1), + Interval(1, 2), + Union(Sx, Interval(1, 2)), + Intersection(Sx, Sy), + Complement(Sx, Sy), + ProductSet(Sx, Sy), + S.EmptySet, + ] + nums = [0, 1, 2, S(0), S(1), S(2)] + + for si in sets: + for ni in nums: + for op in ops: + raises(TypeError, lambda : op(si, ni)) + raises(TypeError, lambda : op(ni, si)) + raises(TypeError, lambda: si ** object()) + raises(TypeError, lambda: si ** {1}) + + +def test_complement(): + assert Complement({1, 2}, {1}) == {2} + assert Interval(0, 1).complement(S.Reals) == \ + Union(Interval(-oo, 0, True, True), Interval(1, oo, True, True)) + assert Interval(0, 1, True, False).complement(S.Reals) == \ + Union(Interval(-oo, 0, True, False), Interval(1, oo, True, True)) + assert Interval(0, 1, False, True).complement(S.Reals) == \ + Union(Interval(-oo, 0, True, True), Interval(1, oo, False, True)) + assert Interval(0, 1, True, True).complement(S.Reals) == \ + Union(Interval(-oo, 0, True, False), Interval(1, oo, False, True)) + + assert S.UniversalSet.complement(S.EmptySet) == S.EmptySet + assert S.UniversalSet.complement(S.Reals) == S.EmptySet + assert S.UniversalSet.complement(S.UniversalSet) == S.EmptySet + + assert S.EmptySet.complement(S.Reals) == S.Reals + + assert Union(Interval(0, 1), Interval(2, 3)).complement(S.Reals) == \ + Union(Interval(-oo, 0, True, True), Interval(1, 2, True, True), + Interval(3, oo, True, True)) + + assert FiniteSet(0).complement(S.Reals) == \ + Union(Interval(-oo, 0, True, True), Interval(0, oo, True, True)) + + assert (FiniteSet(5) + Interval(S.NegativeInfinity, + 0)).complement(S.Reals) == \ + Interval(0, 5, True, True) + Interval(5, S.Infinity, True, True) + + assert FiniteSet(1, 2, 3).complement(S.Reals) == \ + Interval(S.NegativeInfinity, 1, True, True) + \ + Interval(1, 2, True, True) + Interval(2, 3, True, True) +\ + Interval(3, S.Infinity, True, True) + + assert FiniteSet(x).complement(S.Reals) == Complement(S.Reals, FiniteSet(x)) + + assert FiniteSet(0, x).complement(S.Reals) == Complement(Interval(-oo, 0, True, True) + + Interval(0, oo, True, True) + , FiniteSet(x), evaluate=False) + + square = Interval(0, 1) * Interval(0, 1) + notsquare = square.complement(S.Reals*S.Reals) + + assert all(pt in square for pt in [(0, 0), (.5, .5), (1, 0), (1, 1)]) + assert not any( + pt in notsquare for pt in [(0, 0), (.5, .5), (1, 0), (1, 1)]) + assert not any(pt in square for pt in [(-1, 0), (1.5, .5), (10, 10)]) + assert all(pt in notsquare for pt in [(-1, 0), (1.5, .5), (10, 10)]) + + +def test_intersect1(): + assert all(S.Integers.intersection(i) is i for i in + (S.Naturals, S.Naturals0)) + assert all(i.intersection(S.Integers) is i for i in + (S.Naturals, S.Naturals0)) + s = S.Naturals0 + assert S.Naturals.intersection(s) is S.Naturals + assert s.intersection(S.Naturals) is S.Naturals + x = Symbol('x') + assert Interval(0, 2).intersect(Interval(1, 2)) == Interval(1, 2) + assert Interval(0, 2).intersect(Interval(1, 2, True)) == \ + Interval(1, 2, True) + assert Interval(0, 2, True).intersect(Interval(1, 2)) == \ + Interval(1, 2, False, False) + assert Interval(0, 2, True, True).intersect(Interval(1, 2)) == \ + Interval(1, 2, False, True) + assert Interval(0, 2).intersect(Union(Interval(0, 1), Interval(2, 3))) == \ + Union(Interval(0, 1), Interval(2, 2)) + + assert FiniteSet(1, 2).intersect(FiniteSet(1, 2, 3)) == FiniteSet(1, 2) + assert FiniteSet(1, 2, x).intersect(FiniteSet(x)) == FiniteSet(x) + assert FiniteSet('ham', 'eggs').intersect(FiniteSet('ham')) == \ + FiniteSet('ham') + assert FiniteSet(1, 2, 3, 4, 5).intersect(S.EmptySet) == S.EmptySet + + assert Interval(0, 5).intersect(FiniteSet(1, 3)) == FiniteSet(1, 3) + assert Interval(0, 1, True, True).intersect(FiniteSet(1)) == S.EmptySet + + assert Union(Interval(0, 1), Interval(2, 3)).intersect(Interval(1, 2)) == \ + Union(Interval(1, 1), Interval(2, 2)) + assert Union(Interval(0, 1), Interval(2, 3)).intersect(Interval(0, 2)) == \ + Union(Interval(0, 1), Interval(2, 2)) + assert Union(Interval(0, 1), Interval(2, 3)).intersect(Interval(1, 2, True, True)) == \ + S.EmptySet + assert Union(Interval(0, 1), Interval(2, 3)).intersect(S.EmptySet) == \ + S.EmptySet + assert Union(Interval(0, 5), FiniteSet('ham')).intersect(FiniteSet(2, 3, 4, 5, 6)) == \ + Intersection(FiniteSet(2, 3, 4, 5, 6), Union(FiniteSet('ham'), Interval(0, 5))) + assert Intersection(FiniteSet(1, 2, 3), Interval(2, x), Interval(3, y)) == \ + Intersection(FiniteSet(3), Interval(2, x), Interval(3, y), evaluate=False) + assert Intersection(FiniteSet(1, 2), Interval(0, 3), Interval(x, y)) == \ + Intersection({1, 2}, Interval(x, y), evaluate=False) + assert Intersection(FiniteSet(1, 2, 4), Interval(0, 3), Interval(x, y)) == \ + Intersection({1, 2}, Interval(x, y), evaluate=False) + # XXX: Is the real=True necessary here? + # https://github.com/sympy/sympy/issues/17532 + m, n = symbols('m, n', real=True) + assert Intersection(FiniteSet(m), FiniteSet(m, n), Interval(m, m+1)) == \ + FiniteSet(m) + + # issue 8217 + assert Intersection(FiniteSet(x), FiniteSet(y)) == \ + Intersection(FiniteSet(x), FiniteSet(y), evaluate=False) + assert FiniteSet(x).intersect(S.Reals) == \ + Intersection(S.Reals, FiniteSet(x), evaluate=False) + + # tests for the intersection alias + assert Interval(0, 5).intersection(FiniteSet(1, 3)) == FiniteSet(1, 3) + assert Interval(0, 1, True, True).intersection(FiniteSet(1)) == S.EmptySet + + assert Union(Interval(0, 1), Interval(2, 3)).intersection(Interval(1, 2)) == \ + Union(Interval(1, 1), Interval(2, 2)) + + # canonical boundary selected + a = sqrt(2*sqrt(6) + 5) + b = sqrt(2) + sqrt(3) + assert Interval(a, 4).intersection(Interval(b, 5)) == Interval(b, 4) + assert Interval(1, a).intersection(Interval(0, b)) == Interval(1, b) + + +def test_intersection_interval_float(): + # intersection of Intervals with mixed Rational/Float boundaries should + # lead to Float boundaries in all cases regardless of which Interval is + # open or closed. + typs = [ + (Interval, Interval, Interval), + (Interval, Interval.open, Interval.open), + (Interval, Interval.Lopen, Interval.Lopen), + (Interval, Interval.Ropen, Interval.Ropen), + (Interval.open, Interval.open, Interval.open), + (Interval.open, Interval.Lopen, Interval.open), + (Interval.open, Interval.Ropen, Interval.open), + (Interval.Lopen, Interval.Lopen, Interval.Lopen), + (Interval.Lopen, Interval.Ropen, Interval.open), + (Interval.Ropen, Interval.Ropen, Interval.Ropen), + ] + + as_float = lambda a1, a2: a2 if isinstance(a2, float) else a1 + + for t1, t2, t3 in typs: + for t1i, t2i in [(t1, t2), (t2, t1)]: + for a1, a2, b1, b2 in cartes([2, 2.0], [2, 2.0], [3, 3.0], [3, 3.0]): + I1 = t1(a1, b1) + I2 = t2(a2, b2) + I3 = t3(as_float(a1, a2), as_float(b1, b2)) + assert I1.intersect(I2) == I3 + + +def test_intersection(): + # iterable + i = Intersection(FiniteSet(1, 2, 3), Interval(2, 5), evaluate=False) + assert i.is_iterable + assert set(i) == {S(2), S(3)} + + # challenging intervals + x = Symbol('x', real=True) + i = Intersection(Interval(0, 3), Interval(x, 6)) + assert (5 in i) is False + raises(TypeError, lambda: 2 in i) + + # Singleton special cases + assert Intersection(Interval(0, 1), S.EmptySet) == S.EmptySet + assert Intersection(Interval(-oo, oo), Interval(-oo, x)) == Interval(-oo, x) + + # Products + line = Interval(0, 5) + i = Intersection(line**2, line**3, evaluate=False) + assert (2, 2) not in i + assert (2, 2, 2) not in i + raises(TypeError, lambda: list(i)) + + a = Intersection(Intersection(S.Integers, S.Naturals, evaluate=False), S.Reals, evaluate=False) + assert a._argset == frozenset([Intersection(S.Naturals, S.Integers, evaluate=False), S.Reals]) + + assert Intersection(S.Complexes, FiniteSet(S.ComplexInfinity)) == S.EmptySet + + # issue 12178 + assert Intersection() == S.UniversalSet + + # issue 16987 + assert Intersection({1}, {1}, {x}) == Intersection({1}, {x}) + + +def test_issue_9623(): + n = Symbol('n') + + a = S.Reals + b = Interval(0, oo) + c = FiniteSet(n) + + assert Intersection(a, b, c) == Intersection(b, c) + assert Intersection(Interval(1, 2), Interval(3, 4), FiniteSet(n)) == EmptySet + + +def test_is_disjoint(): + assert Interval(0, 2).is_disjoint(Interval(1, 2)) == False + assert Interval(0, 2).is_disjoint(Interval(3, 4)) == True + + +def test_ProductSet__len__(): + A = FiniteSet(1, 2) + B = FiniteSet(1, 2, 3) + assert ProductSet(A).__len__() == 2 + assert ProductSet(A).__len__() is not S(2) + assert ProductSet(A, B).__len__() == 6 + assert ProductSet(A, B).__len__() is not S(6) + + +def test_ProductSet(): + # ProductSet is always a set of Tuples + assert ProductSet(S.Reals) == S.Reals ** 1 + assert ProductSet(S.Reals, S.Reals) == S.Reals ** 2 + assert ProductSet(S.Reals, S.Reals, S.Reals) == S.Reals ** 3 + + assert ProductSet(S.Reals) != S.Reals + assert ProductSet(S.Reals, S.Reals) == S.Reals * S.Reals + assert ProductSet(S.Reals, S.Reals, S.Reals) != S.Reals * S.Reals * S.Reals + assert ProductSet(S.Reals, S.Reals, S.Reals) == (S.Reals * S.Reals * S.Reals).flatten() + + assert 1 not in ProductSet(S.Reals) + assert (1,) in ProductSet(S.Reals) + + assert 1 not in ProductSet(S.Reals, S.Reals) + assert (1, 2) in ProductSet(S.Reals, S.Reals) + assert (1, I) not in ProductSet(S.Reals, S.Reals) + + assert (1, 2, 3) in ProductSet(S.Reals, S.Reals, S.Reals) + assert (1, 2, 3) in S.Reals ** 3 + assert (1, 2, 3) not in S.Reals * S.Reals * S.Reals + assert ((1, 2), 3) in S.Reals * S.Reals * S.Reals + assert (1, (2, 3)) not in S.Reals * S.Reals * S.Reals + assert (1, (2, 3)) in S.Reals * (S.Reals * S.Reals) + + assert ProductSet() == FiniteSet(()) + assert ProductSet(S.Reals, S.EmptySet) == S.EmptySet + + # See GH-17458 + + for ni in range(5): + Rn = ProductSet(*(S.Reals,) * ni) + assert (1,) * ni in Rn + assert 1 not in Rn + + assert (S.Reals * S.Reals) * S.Reals != S.Reals * (S.Reals * S.Reals) + + S1 = S.Reals + S2 = S.Integers + x1 = pi + x2 = 3 + assert x1 in S1 + assert x2 in S2 + assert (x1, x2) in S1 * S2 + S3 = S1 * S2 + x3 = (x1, x2) + assert x3 in S3 + assert (x3, x3) in S3 * S3 + assert x3 + x3 not in S3 * S3 + + raises(ValueError, lambda: S.Reals**-1) + with warns_deprecated_sympy(): + ProductSet(FiniteSet(s) for s in range(2)) + raises(TypeError, lambda: ProductSet(None)) + + S1 = FiniteSet(1, 2) + S2 = FiniteSet(3, 4) + S3 = ProductSet(S1, S2) + assert (S3.as_relational(x, y) + == And(S1.as_relational(x), S2.as_relational(y)) + == And(Or(Eq(x, 1), Eq(x, 2)), Or(Eq(y, 3), Eq(y, 4)))) + raises(ValueError, lambda: S3.as_relational(x)) + raises(ValueError, lambda: S3.as_relational(x, 1)) + raises(ValueError, lambda: ProductSet(Interval(0, 1)).as_relational(x, y)) + + Z2 = ProductSet(S.Integers, S.Integers) + assert Z2.contains((1, 2)) is S.true + assert Z2.contains((1,)) is S.false + assert Z2.contains(x) == Contains(x, Z2, evaluate=False) + assert Z2.contains(x).subs(x, 1) is S.false + assert Z2.contains((x, 1)).subs(x, 2) is S.true + assert Z2.contains((x, y)) == Contains(x, S.Integers) & Contains(y, S.Integers) + assert unchanged(Contains, (x, y), Z2) + assert Contains((1, 2), Z2) is S.true + + +def test_ProductSet_of_single_arg_is_not_arg(): + assert unchanged(ProductSet, Interval(0, 1)) + assert unchanged(ProductSet, ProductSet(Interval(0, 1))) + + +def test_ProductSet_is_empty(): + assert ProductSet(S.Integers, S.Reals).is_empty == False + assert ProductSet(Interval(x, 1), S.Reals).is_empty == None + + +def test_interval_subs(): + a = Symbol('a', real=True) + + assert Interval(0, a).subs(a, 2) == Interval(0, 2) + assert Interval(a, 0).subs(a, 2) == S.EmptySet + + +def test_interval_to_mpi(): + assert Interval(0, 1).to_mpi() == mpi(0, 1) + assert Interval(0, 1, True, False).to_mpi() == mpi(0, 1) + assert type(Interval(0, 1).to_mpi()) == type(mpi(0, 1)) + + +def test_set_evalf(): + assert Interval(S(11)/64, S.Half).evalf() == Interval( + Float('0.171875'), Float('0.5')) + assert Interval(x, S.Half, right_open=True).evalf() == Interval( + x, Float('0.5'), right_open=True) + assert Interval(-oo, S.Half).evalf() == Interval(-oo, Float('0.5')) + assert FiniteSet(2, x).evalf() == FiniteSet(Float('2.0'), x) + + +def test_measure(): + a = Symbol('a', real=True) + + assert Interval(1, 3).measure == 2 + assert Interval(0, a).measure == a + assert Interval(1, a).measure == a - 1 + + assert Union(Interval(1, 2), Interval(3, 4)).measure == 2 + assert Union(Interval(1, 2), Interval(3, 4), FiniteSet(5, 6, 7)).measure \ + == 2 + + assert FiniteSet(1, 2, oo, a, -oo, -5).measure == 0 + + assert S.EmptySet.measure == 0 + + square = Interval(0, 10) * Interval(0, 10) + offsetsquare = Interval(5, 15) * Interval(5, 15) + band = Interval(-oo, oo) * Interval(2, 4) + + assert square.measure == offsetsquare.measure == 100 + assert (square + offsetsquare).measure == 175 # there is some overlap + assert (square - offsetsquare).measure == 75 + assert (square * FiniteSet(1, 2, 3)).measure == 0 + assert (square.intersect(band)).measure == 20 + assert (square + band).measure is oo + assert (band * FiniteSet(1, 2, 3)).measure is nan + + +def test_is_subset(): + assert Interval(0, 1).is_subset(Interval(0, 2)) is True + assert Interval(0, 3).is_subset(Interval(0, 2)) is False + assert Interval(0, 1).is_subset(FiniteSet(0, 1)) is False + + assert FiniteSet(1, 2).is_subset(FiniteSet(1, 2, 3, 4)) + assert FiniteSet(4, 5).is_subset(FiniteSet(1, 2, 3, 4)) is False + assert FiniteSet(1).is_subset(Interval(0, 2)) + assert FiniteSet(1, 2).is_subset(Interval(0, 2, True, True)) is False + assert (Interval(1, 2) + FiniteSet(3)).is_subset( + Interval(0, 2, False, True) + FiniteSet(2, 3)) + + assert Interval(3, 4).is_subset(Union(Interval(0, 1), Interval(2, 5))) is True + assert Interval(3, 6).is_subset(Union(Interval(0, 1), Interval(2, 5))) is False + + assert FiniteSet(1, 2, 3, 4).is_subset(Interval(0, 5)) is True + assert S.EmptySet.is_subset(FiniteSet(1, 2, 3)) is True + + assert Interval(0, 1).is_subset(S.EmptySet) is False + assert S.EmptySet.is_subset(S.EmptySet) is True + + raises(ValueError, lambda: S.EmptySet.is_subset(1)) + + # tests for the issubset alias + assert FiniteSet(1, 2, 3, 4).issubset(Interval(0, 5)) is True + assert S.EmptySet.issubset(FiniteSet(1, 2, 3)) is True + + assert S.Naturals.is_subset(S.Integers) + assert S.Naturals0.is_subset(S.Integers) + + assert FiniteSet(x).is_subset(FiniteSet(y)) is None + assert FiniteSet(x).is_subset(FiniteSet(y).subs(y, x)) is True + assert FiniteSet(x).is_subset(FiniteSet(y).subs(y, x+1)) is False + + assert Interval(0, 1).is_subset(Interval(0, 1, left_open=True)) is False + assert Interval(-2, 3).is_subset(Union(Interval(-oo, -2), Interval(3, oo))) is False + + n = Symbol('n', integer=True) + assert Range(-3, 4, 1).is_subset(FiniteSet(-10, 10)) is False + assert Range(S(10)**100).is_subset(FiniteSet(0, 1, 2)) is False + assert Range(6, 0, -2).is_subset(FiniteSet(2, 4, 6)) is True + assert Range(1, oo).is_subset(FiniteSet(1, 2)) is False + assert Range(-oo, 1).is_subset(FiniteSet(1)) is False + assert Range(3).is_subset(FiniteSet(0, 1, n)) is None + assert Range(n, n + 2).is_subset(FiniteSet(n, n + 1)) is True + assert Range(5).is_subset(Interval(0, 4, right_open=True)) is False + #issue 19513 + assert imageset(Lambda(n, 1/n), S.Integers).is_subset(S.Reals) is None + +def test_is_proper_subset(): + assert Interval(0, 1).is_proper_subset(Interval(0, 2)) is True + assert Interval(0, 3).is_proper_subset(Interval(0, 2)) is False + assert S.EmptySet.is_proper_subset(FiniteSet(1, 2, 3)) is True + + raises(ValueError, lambda: Interval(0, 1).is_proper_subset(0)) + + +def test_is_superset(): + assert Interval(0, 1).is_superset(Interval(0, 2)) == False + assert Interval(0, 3).is_superset(Interval(0, 2)) + + assert FiniteSet(1, 2).is_superset(FiniteSet(1, 2, 3, 4)) == False + assert FiniteSet(4, 5).is_superset(FiniteSet(1, 2, 3, 4)) == False + assert FiniteSet(1).is_superset(Interval(0, 2)) == False + assert FiniteSet(1, 2).is_superset(Interval(0, 2, True, True)) == False + assert (Interval(1, 2) + FiniteSet(3)).is_superset( + Interval(0, 2, False, True) + FiniteSet(2, 3)) == False + + assert Interval(3, 4).is_superset(Union(Interval(0, 1), Interval(2, 5))) == False + + assert FiniteSet(1, 2, 3, 4).is_superset(Interval(0, 5)) == False + assert S.EmptySet.is_superset(FiniteSet(1, 2, 3)) == False + + assert Interval(0, 1).is_superset(S.EmptySet) == True + assert S.EmptySet.is_superset(S.EmptySet) == True + + raises(ValueError, lambda: S.EmptySet.is_superset(1)) + + # tests for the issuperset alias + assert Interval(0, 1).issuperset(S.EmptySet) == True + assert S.EmptySet.issuperset(S.EmptySet) == True + + +def test_is_proper_superset(): + assert Interval(0, 1).is_proper_superset(Interval(0, 2)) is False + assert Interval(0, 3).is_proper_superset(Interval(0, 2)) is True + assert FiniteSet(1, 2, 3).is_proper_superset(S.EmptySet) is True + + raises(ValueError, lambda: Interval(0, 1).is_proper_superset(0)) + + +def test_contains(): + assert Interval(0, 2).contains(1) is S.true + assert Interval(0, 2).contains(3) is S.false + assert Interval(0, 2, True, False).contains(0) is S.false + assert Interval(0, 2, True, False).contains(2) is S.true + assert Interval(0, 2, False, True).contains(0) is S.true + assert Interval(0, 2, False, True).contains(2) is S.false + assert Interval(0, 2, True, True).contains(0) is S.false + assert Interval(0, 2, True, True).contains(2) is S.false + + assert (Interval(0, 2) in Interval(0, 2)) is False + + assert FiniteSet(1, 2, 3).contains(2) is S.true + assert FiniteSet(1, 2, Symbol('x')).contains(Symbol('x')) is S.true + + assert FiniteSet(y)._contains(x) == Eq(y, x, evaluate=False) + raises(TypeError, lambda: x in FiniteSet(y)) + assert FiniteSet({x, y})._contains({x}) == Eq({x, y}, {x}, evaluate=False) + assert FiniteSet({x, y}).subs(y, x)._contains({x}) is S.true + assert FiniteSet({x, y}).subs(y, x+1)._contains({x}) is S.false + + # issue 8197 + from sympy.abc import a, b + assert FiniteSet(b).contains(-a) == Eq(b, -a) + assert FiniteSet(b).contains(a) == Eq(b, a) + assert FiniteSet(a).contains(1) == Eq(a, 1) + raises(TypeError, lambda: 1 in FiniteSet(a)) + + # issue 8209 + rad1 = Pow(Pow(2, Rational(1, 3)) - 1, Rational(1, 3)) + rad2 = Pow(Rational(1, 9), Rational(1, 3)) - Pow(Rational(2, 9), Rational(1, 3)) + Pow(Rational(4, 9), Rational(1, 3)) + s1 = FiniteSet(rad1) + s2 = FiniteSet(rad2) + assert s1 - s2 == S.EmptySet + + items = [1, 2, S.Infinity, S('ham'), -1.1] + fset = FiniteSet(*items) + assert all(item in fset for item in items) + assert all(fset.contains(item) is S.true for item in items) + + assert Union(Interval(0, 1), Interval(2, 5)).contains(3) is S.true + assert Union(Interval(0, 1), Interval(2, 5)).contains(6) is S.false + assert Union(Interval(0, 1), FiniteSet(2, 5)).contains(3) is S.false + + assert S.EmptySet.contains(1) is S.false + assert FiniteSet(rootof(x**3 + x - 1, 0)).contains(S.Infinity) is S.false + + assert rootof(x**5 + x**3 + 1, 0) in S.Reals + assert not rootof(x**5 + x**3 + 1, 1) in S.Reals + + # non-bool results + assert Union(Interval(1, 2), Interval(3, 4)).contains(x) == \ + Or(And(S.One <= x, x <= 2), And(S(3) <= x, x <= 4)) + assert Intersection(Interval(1, x), Interval(2, 3)).contains(y) == \ + And(y <= 3, y <= x, S.One <= y, S(2) <= y) + + assert (S.Complexes).contains(S.ComplexInfinity) == S.false + + +def test_interval_symbolic(): + x = Symbol('x') + e = Interval(0, 1) + assert e.contains(x) == And(S.Zero <= x, x <= 1) + raises(TypeError, lambda: x in e) + e = Interval(0, 1, True, True) + assert e.contains(x) == And(S.Zero < x, x < 1) + c = Symbol('c', real=False) + assert Interval(x, x + 1).contains(c) == False + e = Symbol('e', extended_real=True) + assert Interval(-oo, oo).contains(e) == And( + S.NegativeInfinity < e, e < S.Infinity) + + +def test_union_contains(): + x = Symbol('x') + i1 = Interval(0, 1) + i2 = Interval(2, 3) + i3 = Union(i1, i2) + assert i3.as_relational(x) == Or(And(S.Zero <= x, x <= 1), And(S(2) <= x, x <= 3)) + raises(TypeError, lambda: x in i3) + e = i3.contains(x) + assert e == i3.as_relational(x) + assert e.subs(x, -0.5) is false + assert e.subs(x, 0.5) is true + assert e.subs(x, 1.5) is false + assert e.subs(x, 2.5) is true + assert e.subs(x, 3.5) is false + + U = Interval(0, 2, True, True) + Interval(10, oo) + FiniteSet(-1, 2, 5, 6) + assert all(el not in U for el in [0, 4, -oo]) + assert all(el in U for el in [2, 5, 10]) + + +def test_is_number(): + assert Interval(0, 1).is_number is False + assert Set().is_number is False + + +def test_Interval_is_left_unbounded(): + assert Interval(3, 4).is_left_unbounded is False + assert Interval(-oo, 3).is_left_unbounded is True + assert Interval(Float("-inf"), 3).is_left_unbounded is True + + +def test_Interval_is_right_unbounded(): + assert Interval(3, 4).is_right_unbounded is False + assert Interval(3, oo).is_right_unbounded is True + assert Interval(3, Float("+inf")).is_right_unbounded is True + + +def test_Interval_as_relational(): + x = Symbol('x') + + assert Interval(-1, 2, False, False).as_relational(x) == \ + And(Le(-1, x), Le(x, 2)) + assert Interval(-1, 2, True, False).as_relational(x) == \ + And(Lt(-1, x), Le(x, 2)) + assert Interval(-1, 2, False, True).as_relational(x) == \ + And(Le(-1, x), Lt(x, 2)) + assert Interval(-1, 2, True, True).as_relational(x) == \ + And(Lt(-1, x), Lt(x, 2)) + + assert Interval(-oo, 2, right_open=False).as_relational(x) == And(Lt(-oo, x), Le(x, 2)) + assert Interval(-oo, 2, right_open=True).as_relational(x) == And(Lt(-oo, x), Lt(x, 2)) + + assert Interval(-2, oo, left_open=False).as_relational(x) == And(Le(-2, x), Lt(x, oo)) + assert Interval(-2, oo, left_open=True).as_relational(x) == And(Lt(-2, x), Lt(x, oo)) + + assert Interval(-oo, oo).as_relational(x) == And(Lt(-oo, x), Lt(x, oo)) + x = Symbol('x', real=True) + y = Symbol('y', real=True) + assert Interval(x, y).as_relational(x) == (x <= y) + assert Interval(y, x).as_relational(x) == (y <= x) + + +def test_Finite_as_relational(): + x = Symbol('x') + y = Symbol('y') + + assert FiniteSet(1, 2).as_relational(x) == Or(Eq(x, 1), Eq(x, 2)) + assert FiniteSet(y, -5).as_relational(x) == Or(Eq(x, y), Eq(x, -5)) + + +def test_Union_as_relational(): + x = Symbol('x') + assert (Interval(0, 1) + FiniteSet(2)).as_relational(x) == \ + Or(And(Le(0, x), Le(x, 1)), Eq(x, 2)) + assert (Interval(0, 1, True, True) + FiniteSet(1)).as_relational(x) == \ + And(Lt(0, x), Le(x, 1)) + assert Or(x < 0, x > 0).as_set().as_relational(x) == \ + And((x > -oo), (x < oo), Ne(x, 0)) + assert (Interval.Ropen(1, 3) + Interval.Lopen(3, 5) + ).as_relational(x) == And(Ne(x,3),(x>=1),(x<=5)) + + +def test_Intersection_as_relational(): + x = Symbol('x') + assert (Intersection(Interval(0, 1), FiniteSet(2), + evaluate=False).as_relational(x) + == And(And(Le(0, x), Le(x, 1)), Eq(x, 2))) + + +def test_Complement_as_relational(): + x = Symbol('x') + expr = Complement(Interval(0, 1), FiniteSet(2), evaluate=False) + assert expr.as_relational(x) == \ + And(Le(0, x), Le(x, 1), Ne(x, 2)) + + +@XFAIL +def test_Complement_as_relational_fail(): + x = Symbol('x') + expr = Complement(Interval(0, 1), FiniteSet(2), evaluate=False) + # XXX This example fails because 0 <= x changes to x >= 0 + # during the evaluation. + assert expr.as_relational(x) == \ + (0 <= x) & (x <= 1) & Ne(x, 2) + + +def test_SymmetricDifference_as_relational(): + x = Symbol('x') + expr = SymmetricDifference(Interval(0, 1), FiniteSet(2), evaluate=False) + assert expr.as_relational(x) == Xor(Eq(x, 2), Le(0, x) & Le(x, 1)) + + +def test_EmptySet(): + assert S.EmptySet.as_relational(Symbol('x')) is S.false + assert S.EmptySet.intersect(S.UniversalSet) == S.EmptySet + assert S.EmptySet.boundary == S.EmptySet + + +def test_finite_basic(): + x = Symbol('x') + A = FiniteSet(1, 2, 3) + B = FiniteSet(3, 4, 5) + AorB = Union(A, B) + AandB = A.intersect(B) + assert A.is_subset(AorB) and B.is_subset(AorB) + assert AandB.is_subset(A) + assert AandB == FiniteSet(3) + + assert A.inf == 1 and A.sup == 3 + assert AorB.inf == 1 and AorB.sup == 5 + assert FiniteSet(x, 1, 5).sup == Max(x, 5) + assert FiniteSet(x, 1, 5).inf == Min(x, 1) + + # issue 7335 + assert FiniteSet(S.EmptySet) != S.EmptySet + assert FiniteSet(FiniteSet(1, 2, 3)) != FiniteSet(1, 2, 3) + assert FiniteSet((1, 2, 3)) != FiniteSet(1, 2, 3) + + # Ensure a variety of types can exist in a FiniteSet + assert FiniteSet((1, 2), A, -5, x, 'eggs', x**2) + + assert (A > B) is False + assert (A >= B) is False + assert (A < B) is False + assert (A <= B) is False + assert AorB > A and AorB > B + assert AorB >= A and AorB >= B + assert A >= A and A <= A + assert A >= AandB and B >= AandB + assert A > AandB and B > AandB + + +def test_product_basic(): + H, T = 'H', 'T' + unit_line = Interval(0, 1) + d6 = FiniteSet(1, 2, 3, 4, 5, 6) + d4 = FiniteSet(1, 2, 3, 4) + coin = FiniteSet(H, T) + + square = unit_line * unit_line + + assert (0, 0) in square + assert 0 not in square + assert (H, T) in coin ** 2 + assert (.5, .5, .5) in (square * unit_line).flatten() + assert ((.5, .5), .5) in square * unit_line + assert (H, 3, 3) in (coin * d6 * d6).flatten() + assert ((H, 3), 3) in coin * d6 * d6 + HH, TT = sympify(H), sympify(T) + assert set(coin**2) == {(HH, HH), (HH, TT), (TT, HH), (TT, TT)} + + assert (d4*d4).is_subset(d6*d6) + + assert square.complement(Interval(-oo, oo)*Interval(-oo, oo)) == Union( + (Interval(-oo, 0, True, True) + + Interval(1, oo, True, True))*Interval(-oo, oo), + Interval(-oo, oo)*(Interval(-oo, 0, True, True) + + Interval(1, oo, True, True))) + + assert (Interval(-5, 5)**3).is_subset(Interval(-10, 10)**3) + assert not (Interval(-10, 10)**3).is_subset(Interval(-5, 5)**3) + assert not (Interval(-5, 5)**2).is_subset(Interval(-10, 10)**3) + + assert (Interval(.2, .5)*FiniteSet(.5)).is_subset(square) # segment in square + + assert len(coin*coin*coin) == 8 + assert len(S.EmptySet*S.EmptySet) == 0 + assert len(S.EmptySet*coin) == 0 + raises(TypeError, lambda: len(coin*Interval(0, 2))) + + +def test_real(): + x = Symbol('x', real=True) + + I = Interval(0, 5) + J = Interval(10, 20) + A = FiniteSet(1, 2, 30, x, S.Pi) + B = FiniteSet(-4, 0) + C = FiniteSet(100) + D = FiniteSet('Ham', 'Eggs') + + assert all(s.is_subset(S.Reals) for s in [I, J, A, B, C]) + assert not D.is_subset(S.Reals) + assert all((a + b).is_subset(S.Reals) for a in [I, J, A, B, C] for b in [I, J, A, B, C]) + assert not any((a + D).is_subset(S.Reals) for a in [I, J, A, B, C, D]) + + assert not (I + A + D).is_subset(S.Reals) + + +def test_supinf(): + x = Symbol('x', real=True) + y = Symbol('y', real=True) + + assert (Interval(0, 1) + FiniteSet(2)).sup == 2 + assert (Interval(0, 1) + FiniteSet(2)).inf == 0 + assert (Interval(0, 1) + FiniteSet(x)).sup == Max(1, x) + assert (Interval(0, 1) + FiniteSet(x)).inf == Min(0, x) + assert FiniteSet(5, 1, x).sup == Max(5, x) + assert FiniteSet(5, 1, x).inf == Min(1, x) + assert FiniteSet(5, 1, x, y).sup == Max(5, x, y) + assert FiniteSet(5, 1, x, y).inf == Min(1, x, y) + assert FiniteSet(5, 1, x, y, S.Infinity, S.NegativeInfinity).sup == \ + S.Infinity + assert FiniteSet(5, 1, x, y, S.Infinity, S.NegativeInfinity).inf == \ + S.NegativeInfinity + assert FiniteSet('Ham', 'Eggs').sup == Max('Ham', 'Eggs') + + +def test_universalset(): + U = S.UniversalSet + x = Symbol('x') + assert U.as_relational(x) is S.true + assert U.union(Interval(2, 4)) == U + + assert U.intersect(Interval(2, 4)) == Interval(2, 4) + assert U.measure is S.Infinity + assert U.boundary == S.EmptySet + assert U.contains(0) is S.true + + +def test_Union_of_ProductSets_shares(): + line = Interval(0, 2) + points = FiniteSet(0, 1, 2) + assert Union(line * line, line * points) == line * line + + +def test_Interval_free_symbols(): + # issue 6211 + assert Interval(0, 1).free_symbols == set() + x = Symbol('x', real=True) + assert Interval(0, x).free_symbols == {x} + + +def test_image_interval(): + x = Symbol('x', real=True) + a = Symbol('a', real=True) + assert imageset(x, 2*x, Interval(-2, 1)) == Interval(-4, 2) + assert imageset(x, 2*x, Interval(-2, 1, True, False)) == \ + Interval(-4, 2, True, False) + assert imageset(x, x**2, Interval(-2, 1, True, False)) == \ + Interval(0, 4, False, True) + assert imageset(x, x**2, Interval(-2, 1)) == Interval(0, 4) + assert imageset(x, x**2, Interval(-2, 1, True, False)) == \ + Interval(0, 4, False, True) + assert imageset(x, x**2, Interval(-2, 1, True, True)) == \ + Interval(0, 4, False, True) + assert imageset(x, (x - 2)**2, Interval(1, 3)) == Interval(0, 1) + assert imageset(x, 3*x**4 - 26*x**3 + 78*x**2 - 90*x, Interval(0, 4)) == \ + Interval(-35, 0) # Multiple Maxima + assert imageset(x, x + 1/x, Interval(-oo, oo)) == Interval(-oo, -2) \ + + Interval(2, oo) # Single Infinite discontinuity + assert imageset(x, 1/x + 1/(x-1)**2, Interval(0, 2, True, False)) == \ + Interval(Rational(3, 2), oo, False) # Multiple Infinite discontinuities + + # Test for Python lambda + assert imageset(lambda x: 2*x, Interval(-2, 1)) == Interval(-4, 2) + + assert imageset(Lambda(x, a*x), Interval(0, 1)) == \ + ImageSet(Lambda(x, a*x), Interval(0, 1)) + + assert imageset(Lambda(x, sin(cos(x))), Interval(0, 1)) == \ + ImageSet(Lambda(x, sin(cos(x))), Interval(0, 1)) + + +def test_image_piecewise(): + f = Piecewise((x, x <= -1), (1/x**2, x <= 5), (x**3, True)) + f1 = Piecewise((0, x <= 1), (1, x <= 2), (2, True)) + assert imageset(x, f, Interval(-5, 5)) == Union(Interval(-5, -1), Interval(Rational(1, 25), oo)) + assert imageset(x, f1, Interval(1, 2)) == FiniteSet(0, 1) + + +@XFAIL # See: https://github.com/sympy/sympy/pull/2723#discussion_r8659826 +def test_image_Intersection(): + x = Symbol('x', real=True) + y = Symbol('y', real=True) + assert imageset(x, x**2, Interval(-2, 0).intersect(Interval(x, y))) == \ + Interval(0, 4).intersect(Interval(Min(x**2, y**2), Max(x**2, y**2))) + + +def test_image_FiniteSet(): + x = Symbol('x', real=True) + assert imageset(x, 2*x, FiniteSet(1, 2, 3)) == FiniteSet(2, 4, 6) + + +def test_image_Union(): + x = Symbol('x', real=True) + assert imageset(x, x**2, Interval(-2, 0) + FiniteSet(1, 2, 3)) == \ + (Interval(0, 4) + FiniteSet(9)) + + +def test_image_EmptySet(): + x = Symbol('x', real=True) + assert imageset(x, 2*x, S.EmptySet) == S.EmptySet + + +def test_issue_5724_7680(): + assert I not in S.Reals # issue 7680 + assert Interval(-oo, oo).contains(I) is S.false + + +def test_boundary(): + assert FiniteSet(1).boundary == FiniteSet(1) + assert all(Interval(0, 1, left_open, right_open).boundary == FiniteSet(0, 1) + for left_open in (true, false) for right_open in (true, false)) + + +def test_boundary_Union(): + assert (Interval(0, 1) + Interval(2, 3)).boundary == FiniteSet(0, 1, 2, 3) + assert ((Interval(0, 1, False, True) + + Interval(1, 2, True, False)).boundary == FiniteSet(0, 1, 2)) + + assert (Interval(0, 1) + FiniteSet(2)).boundary == FiniteSet(0, 1, 2) + assert Union(Interval(0, 10), Interval(5, 15), evaluate=False).boundary \ + == FiniteSet(0, 15) + + assert Union(Interval(0, 10), Interval(0, 1), evaluate=False).boundary \ + == FiniteSet(0, 10) + assert Union(Interval(0, 10, True, True), + Interval(10, 15, True, True), evaluate=False).boundary \ + == FiniteSet(0, 10, 15) + + +@XFAIL +def test_union_boundary_of_joining_sets(): + """ Testing the boundary of unions is a hard problem """ + assert Union(Interval(0, 10), Interval(10, 15), evaluate=False).boundary \ + == FiniteSet(0, 15) + + +def test_boundary_ProductSet(): + open_square = Interval(0, 1, True, True) ** 2 + assert open_square.boundary == (FiniteSet(0, 1) * Interval(0, 1) + + Interval(0, 1) * FiniteSet(0, 1)) + + second_square = Interval(1, 2, True, True) * Interval(0, 1, True, True) + assert (open_square + second_square).boundary == ( + FiniteSet(0, 1) * Interval(0, 1) + + FiniteSet(1, 2) * Interval(0, 1) + + Interval(0, 1) * FiniteSet(0, 1) + + Interval(1, 2) * FiniteSet(0, 1)) + + +def test_boundary_ProductSet_line(): + line_in_r2 = Interval(0, 1) * FiniteSet(0) + assert line_in_r2.boundary == line_in_r2 + + +def test_is_open(): + assert Interval(0, 1, False, False).is_open is False + assert Interval(0, 1, True, False).is_open is False + assert Interval(0, 1, True, True).is_open is True + assert FiniteSet(1, 2, 3).is_open is False + + +def test_is_closed(): + assert Interval(0, 1, False, False).is_closed is True + assert Interval(0, 1, True, False).is_closed is False + assert FiniteSet(1, 2, 3).is_closed is True + + +def test_closure(): + assert Interval(0, 1, False, True).closure == Interval(0, 1, False, False) + + +def test_interior(): + assert Interval(0, 1, False, True).interior == Interval(0, 1, True, True) + + +def test_issue_7841(): + raises(TypeError, lambda: x in S.Reals) + + +def test_Eq(): + assert Eq(Interval(0, 1), Interval(0, 1)) + assert Eq(Interval(0, 1), Interval(0, 2)) == False + + s1 = FiniteSet(0, 1) + s2 = FiniteSet(1, 2) + + assert Eq(s1, s1) + assert Eq(s1, s2) == False + + assert Eq(s1*s2, s1*s2) + assert Eq(s1*s2, s2*s1) == False + + assert unchanged(Eq, FiniteSet({x, y}), FiniteSet({x})) + assert Eq(FiniteSet({x, y}).subs(y, x), FiniteSet({x})) is S.true + assert Eq(FiniteSet({x, y}), FiniteSet({x})).subs(y, x) is S.true + assert Eq(FiniteSet({x, y}).subs(y, x+1), FiniteSet({x})) is S.false + assert Eq(FiniteSet({x, y}), FiniteSet({x})).subs(y, x+1) is S.false + + assert Eq(ProductSet({1}, {2}), Interval(1, 2)) is S.false + assert Eq(ProductSet({1}), ProductSet({1}, {2})) is S.false + + assert Eq(FiniteSet(()), FiniteSet(1)) is S.false + assert Eq(ProductSet(), FiniteSet(1)) is S.false + + i1 = Interval(0, 1) + i2 = Interval(x, y) + assert unchanged(Eq, ProductSet(i1, i1), ProductSet(i2, i2)) + + +def test_SymmetricDifference(): + A = FiniteSet(0, 1, 2, 3, 4, 5) + B = FiniteSet(2, 4, 6, 8, 10) + C = Interval(8, 10) + + assert SymmetricDifference(A, B, evaluate=False).is_iterable is True + assert SymmetricDifference(A, C, evaluate=False).is_iterable is None + assert FiniteSet(*SymmetricDifference(A, B, evaluate=False)) == \ + FiniteSet(0, 1, 3, 5, 6, 8, 10) + raises(TypeError, + lambda: FiniteSet(*SymmetricDifference(A, C, evaluate=False))) + + assert SymmetricDifference(FiniteSet(0, 1, 2, 3, 4, 5), \ + FiniteSet(2, 4, 6, 8, 10)) == FiniteSet(0, 1, 3, 5, 6, 8, 10) + assert SymmetricDifference(FiniteSet(2, 3, 4), FiniteSet(2, 3, 4 ,5)) \ + == FiniteSet(5) + assert FiniteSet(1, 2, 3, 4, 5) ^ FiniteSet(1, 2, 5, 6) == \ + FiniteSet(3, 4, 6) + assert Set(S(1), S(2), S(3)) ^ Set(S(2), S(3), S(4)) == Union(Set(S(1), S(2), S(3)) - Set(S(2), S(3), S(4)), \ + Set(S(2), S(3), S(4)) - Set(S(1), S(2), S(3))) + assert Interval(0, 4) ^ Interval(2, 5) == Union(Interval(0, 4) - \ + Interval(2, 5), Interval(2, 5) - Interval(0, 4)) + + +def test_issue_9536(): + from sympy.functions.elementary.exponential import log + a = Symbol('a', real=True) + assert FiniteSet(log(a)).intersect(S.Reals) == Intersection(S.Reals, FiniteSet(log(a))) + + +def test_issue_9637(): + n = Symbol('n') + a = FiniteSet(n) + b = FiniteSet(2, n) + assert Complement(S.Reals, a) == Complement(S.Reals, a, evaluate=False) + assert Complement(Interval(1, 3), a) == Complement(Interval(1, 3), a, evaluate=False) + assert Complement(Interval(1, 3), b) == \ + Complement(Union(Interval(1, 2, False, True), Interval(2, 3, True, False)), a) + assert Complement(a, S.Reals) == Complement(a, S.Reals, evaluate=False) + assert Complement(a, Interval(1, 3)) == Complement(a, Interval(1, 3), evaluate=False) + + +def test_issue_9808(): + # See https://github.com/sympy/sympy/issues/16342 + assert Complement(FiniteSet(y), FiniteSet(1)) == Complement(FiniteSet(y), FiniteSet(1), evaluate=False) + assert Complement(FiniteSet(1, 2, x), FiniteSet(x, y, 2, 3)) == \ + Complement(FiniteSet(1), FiniteSet(y), evaluate=False) + + +def test_issue_9956(): + assert Union(Interval(-oo, oo), FiniteSet(1)) == Interval(-oo, oo) + assert Interval(-oo, oo).contains(1) is S.true + + +def test_issue_Symbol_inter(): + i = Interval(0, oo) + r = S.Reals + mat = Matrix([0, 0, 0]) + assert Intersection(r, i, FiniteSet(m), FiniteSet(m, n)) == \ + Intersection(i, FiniteSet(m)) + assert Intersection(FiniteSet(1, m, n), FiniteSet(m, n, 2), i) == \ + Intersection(i, FiniteSet(m, n)) + assert Intersection(FiniteSet(m, n, x), FiniteSet(m, z), r) == \ + Intersection(Intersection({m, z}, {m, n, x}), r) + assert Intersection(FiniteSet(m, n, 3), FiniteSet(m, n, x), r) == \ + Intersection(FiniteSet(3, m, n), FiniteSet(m, n, x), r, evaluate=False) + assert Intersection(FiniteSet(m, n, 3), FiniteSet(m, n, 2, 3), r) == \ + Intersection(FiniteSet(3, m, n), r) + assert Intersection(r, FiniteSet(mat, 2, n), FiniteSet(0, mat, n)) == \ + Intersection(r, FiniteSet(n)) + assert Intersection(FiniteSet(sin(x), cos(x)), FiniteSet(sin(x), cos(x), 1), r) == \ + Intersection(r, FiniteSet(sin(x), cos(x))) + assert Intersection(FiniteSet(x**2, 1, sin(x)), FiniteSet(x**2, 2, sin(x)), r) == \ + Intersection(r, FiniteSet(x**2, sin(x))) + + +def test_issue_11827(): + assert S.Naturals0**4 + + +def test_issue_10113(): + f = x**2/(x**2 - 4) + assert imageset(x, f, S.Reals) == Union(Interval(-oo, 0), Interval(1, oo, True, True)) + assert imageset(x, f, Interval(-2, 2)) == Interval(-oo, 0) + assert imageset(x, f, Interval(-2, 3)) == Union(Interval(-oo, 0), Interval(Rational(9, 5), oo)) + + +def test_issue_10248(): + raises( + TypeError, lambda: list(Intersection(S.Reals, FiniteSet(x))) + ) + A = Symbol('A', real=True) + assert list(Intersection(S.Reals, FiniteSet(A))) == [A] + + +def test_issue_9447(): + a = Interval(0, 1) + Interval(2, 3) + assert Complement(S.UniversalSet, a) == Complement( + S.UniversalSet, Union(Interval(0, 1), Interval(2, 3)), evaluate=False) + assert Complement(S.Naturals, a) == Complement( + S.Naturals, Union(Interval(0, 1), Interval(2, 3)), evaluate=False) + + +def test_issue_10337(): + assert (FiniteSet(2) == 3) is False + assert (FiniteSet(2) != 3) is True + raises(TypeError, lambda: FiniteSet(2) < 3) + raises(TypeError, lambda: FiniteSet(2) <= 3) + raises(TypeError, lambda: FiniteSet(2) > 3) + raises(TypeError, lambda: FiniteSet(2) >= 3) + + +def test_issue_10326(): + bad = [ + EmptySet, + FiniteSet(1), + Interval(1, 2), + S.ComplexInfinity, + S.ImaginaryUnit, + S.Infinity, + S.NaN, + S.NegativeInfinity, + ] + interval = Interval(0, 5) + for i in bad: + assert i not in interval + + x = Symbol('x', real=True) + nr = Symbol('nr', extended_real=False) + assert x + 1 in Interval(x, x + 4) + assert nr not in Interval(x, x + 4) + assert Interval(1, 2) in FiniteSet(Interval(0, 5), Interval(1, 2)) + assert Interval(-oo, oo).contains(oo) is S.false + assert Interval(-oo, oo).contains(-oo) is S.false + + +def test_issue_2799(): + U = S.UniversalSet + a = Symbol('a', real=True) + inf_interval = Interval(a, oo) + R = S.Reals + + assert U + inf_interval == inf_interval + U + assert U + R == R + U + assert R + inf_interval == inf_interval + R + + +def test_issue_9706(): + assert Interval(-oo, 0).closure == Interval(-oo, 0, True, False) + assert Interval(0, oo).closure == Interval(0, oo, False, True) + assert Interval(-oo, oo).closure == Interval(-oo, oo) + + +def test_issue_8257(): + reals_plus_infinity = Union(Interval(-oo, oo), FiniteSet(oo)) + reals_plus_negativeinfinity = Union(Interval(-oo, oo), FiniteSet(-oo)) + assert Interval(-oo, oo) + FiniteSet(oo) == reals_plus_infinity + assert FiniteSet(oo) + Interval(-oo, oo) == reals_plus_infinity + assert Interval(-oo, oo) + FiniteSet(-oo) == reals_plus_negativeinfinity + assert FiniteSet(-oo) + Interval(-oo, oo) == reals_plus_negativeinfinity + + +def test_issue_10931(): + assert S.Integers - S.Integers == EmptySet + assert S.Integers - S.Reals == EmptySet + + +def test_issue_11174(): + soln = Intersection(Interval(-oo, oo), FiniteSet(-x), evaluate=False) + assert Intersection(FiniteSet(-x), S.Reals) == soln + + soln = Intersection(S.Reals, FiniteSet(x), evaluate=False) + assert Intersection(FiniteSet(x), S.Reals) == soln + + +def test_issue_18505(): + assert ImageSet(Lambda(n, sqrt(pi*n/2 - 1 + pi/2)), S.Integers).contains(0) == \ + Contains(0, ImageSet(Lambda(n, sqrt(pi*n/2 - 1 + pi/2)), S.Integers)) + + +def test_finite_set_intersection(): + # The following should not produce recursion errors + # Note: some of these are not completely correct. See + # https://github.com/sympy/sympy/issues/16342. + assert Intersection(FiniteSet(-oo, x), FiniteSet(x)) == FiniteSet(x) + assert Intersection._handle_finite_sets([FiniteSet(-oo, x), FiniteSet(0, x)]) == FiniteSet(x) + + assert Intersection._handle_finite_sets([FiniteSet(-oo, x), FiniteSet(x)]) == FiniteSet(x) + assert Intersection._handle_finite_sets([FiniteSet(2, 3, x, y), FiniteSet(1, 2, x)]) == \ + Intersection._handle_finite_sets([FiniteSet(1, 2, x), FiniteSet(2, 3, x, y)]) == \ + Intersection(FiniteSet(1, 2, x), FiniteSet(2, 3, x, y)) == \ + Intersection(FiniteSet(1, 2, x), FiniteSet(2, x, y)) + + assert FiniteSet(1+x-y) & FiniteSet(1) == \ + FiniteSet(1) & FiniteSet(1+x-y) == \ + Intersection(FiniteSet(1+x-y), FiniteSet(1), evaluate=False) + + assert FiniteSet(1) & FiniteSet(x) == FiniteSet(x) & FiniteSet(1) == \ + Intersection(FiniteSet(1), FiniteSet(x), evaluate=False) + + assert FiniteSet({x}) & FiniteSet({x, y}) == \ + Intersection(FiniteSet({x}), FiniteSet({x, y}), evaluate=False) + + +def test_union_intersection_constructor(): + # The actual exception does not matter here, so long as these fail + sets = [FiniteSet(1), FiniteSet(2)] + raises(Exception, lambda: Union(sets)) + raises(Exception, lambda: Intersection(sets)) + raises(Exception, lambda: Union(tuple(sets))) + raises(Exception, lambda: Intersection(tuple(sets))) + raises(Exception, lambda: Union(i for i in sets)) + raises(Exception, lambda: Intersection(i for i in sets)) + + # Python sets are treated the same as FiniteSet + # The union of a single set (of sets) is the set (of sets) itself + assert Union(set(sets)) == FiniteSet(*sets) + assert Intersection(set(sets)) == FiniteSet(*sets) + + assert Union({1}, {2}) == FiniteSet(1, 2) + assert Intersection({1, 2}, {2, 3}) == FiniteSet(2) + + +def test_Union_contains(): + assert zoo not in Union( + Interval.open(-oo, 0), Interval.open(0, oo)) + + +@XFAIL +def test_issue_16878b(): + # in intersection_sets for (ImageSet, Set) there is no code + # that handles the base_set of S.Reals like there is + # for Integers + assert imageset(x, (x, x), S.Reals).is_subset(S.Reals**2) is True + +def test_DisjointUnion(): + assert DisjointUnion(FiniteSet(1, 2, 3), FiniteSet(1, 2, 3), FiniteSet(1, 2, 3)).rewrite(Union) == (FiniteSet(1, 2, 3) * FiniteSet(0, 1, 2)) + assert DisjointUnion(Interval(1, 3), Interval(2, 4)).rewrite(Union) == Union(Interval(1, 3) * FiniteSet(0), Interval(2, 4) * FiniteSet(1)) + assert DisjointUnion(Interval(0, 5), Interval(0, 5)).rewrite(Union) == Union(Interval(0, 5) * FiniteSet(0), Interval(0, 5) * FiniteSet(1)) + assert DisjointUnion(Interval(-1, 2), S.EmptySet, S.EmptySet).rewrite(Union) == Interval(-1, 2) * FiniteSet(0) + assert DisjointUnion(Interval(-1, 2)).rewrite(Union) == Interval(-1, 2) * FiniteSet(0) + assert DisjointUnion(S.EmptySet, Interval(-1, 2), S.EmptySet).rewrite(Union) == Interval(-1, 2) * FiniteSet(1) + assert DisjointUnion(Interval(-oo, oo)).rewrite(Union) == Interval(-oo, oo) * FiniteSet(0) + assert DisjointUnion(S.EmptySet).rewrite(Union) == S.EmptySet + assert DisjointUnion().rewrite(Union) == S.EmptySet + raises(TypeError, lambda: DisjointUnion(Symbol('n'))) + + x = Symbol("x") + y = Symbol("y") + z = Symbol("z") + assert DisjointUnion(FiniteSet(x), FiniteSet(y, z)).rewrite(Union) == (FiniteSet(x) * FiniteSet(0)) + (FiniteSet(y, z) * FiniteSet(1)) + +def test_DisjointUnion_is_empty(): + assert DisjointUnion(S.EmptySet).is_empty is True + assert DisjointUnion(S.EmptySet, S.EmptySet).is_empty is True + assert DisjointUnion(S.EmptySet, FiniteSet(1, 2, 3)).is_empty is False + +def test_DisjointUnion_is_iterable(): + assert DisjointUnion(S.Integers, S.Naturals, S.Rationals).is_iterable is True + assert DisjointUnion(S.EmptySet, S.Reals).is_iterable is False + assert DisjointUnion(FiniteSet(1, 2, 3), S.EmptySet, FiniteSet(x, y)).is_iterable is True + assert DisjointUnion(S.EmptySet, S.EmptySet).is_iterable is False + +def test_DisjointUnion_contains(): + assert (0, 0) in DisjointUnion(FiniteSet(0, 1, 2), FiniteSet(0, 1, 2), FiniteSet(0, 1, 2)) + assert (0, 1) in DisjointUnion(FiniteSet(0, 1, 2), FiniteSet(0, 1, 2), FiniteSet(0, 1, 2)) + assert (0, 2) in DisjointUnion(FiniteSet(0, 1, 2), FiniteSet(0, 1, 2), FiniteSet(0, 1, 2)) + assert (1, 0) in DisjointUnion(FiniteSet(0, 1, 2), FiniteSet(0, 1, 2), FiniteSet(0, 1, 2)) + assert (1, 1) in DisjointUnion(FiniteSet(0, 1, 2), FiniteSet(0, 1, 2), FiniteSet(0, 1, 2)) + assert (1, 2) in DisjointUnion(FiniteSet(0, 1, 2), FiniteSet(0, 1, 2), FiniteSet(0, 1, 2)) + assert (2, 0) in DisjointUnion(FiniteSet(0, 1, 2), FiniteSet(0, 1, 2), FiniteSet(0, 1, 2)) + assert (2, 1) in DisjointUnion(FiniteSet(0, 1, 2), FiniteSet(0, 1, 2), FiniteSet(0, 1, 2)) + assert (2, 2) in DisjointUnion(FiniteSet(0, 1, 2), FiniteSet(0, 1, 2), FiniteSet(0, 1, 2)) + assert (0, 1, 2) not in DisjointUnion(FiniteSet(0, 1, 2), FiniteSet(0, 1, 2), FiniteSet(0, 1, 2)) + assert (0, 0.5) not in DisjointUnion(FiniteSet(0.5)) + assert (0, 5) not in DisjointUnion(FiniteSet(0, 1, 2), FiniteSet(0, 1, 2), FiniteSet(0, 1, 2)) + assert (x, 0) in DisjointUnion(FiniteSet(x, y, z), S.EmptySet, FiniteSet(y)) + assert (y, 0) in DisjointUnion(FiniteSet(x, y, z), S.EmptySet, FiniteSet(y)) + assert (z, 0) in DisjointUnion(FiniteSet(x, y, z), S.EmptySet, FiniteSet(y)) + assert (y, 2) in DisjointUnion(FiniteSet(x, y, z), S.EmptySet, FiniteSet(y)) + assert (0.5, 0) in DisjointUnion(Interval(0, 1), Interval(0, 2)) + assert (0.5, 1) in DisjointUnion(Interval(0, 1), Interval(0, 2)) + assert (1.5, 0) not in DisjointUnion(Interval(0, 1), Interval(0, 2)) + assert (1.5, 1) in DisjointUnion(Interval(0, 1), Interval(0, 2)) + +def test_DisjointUnion_iter(): + D = DisjointUnion(FiniteSet(3, 5, 7, 9), FiniteSet(x, y, z)) + it = iter(D) + L1 = [(x, 1), (y, 1), (z, 1)] + L2 = [(3, 0), (5, 0), (7, 0), (9, 0)] + nxt = next(it) + assert nxt in L2 + L2.remove(nxt) + nxt = next(it) + assert nxt in L1 + L1.remove(nxt) + nxt = next(it) + assert nxt in L2 + L2.remove(nxt) + nxt = next(it) + assert nxt in L1 + L1.remove(nxt) + nxt = next(it) + assert nxt in L2 + L2.remove(nxt) + nxt = next(it) + assert nxt in L1 + L1.remove(nxt) + nxt = next(it) + assert nxt in L2 + L2.remove(nxt) + raises(StopIteration, lambda: next(it)) + + raises(ValueError, lambda: iter(DisjointUnion(Interval(0, 1), S.EmptySet))) + +def test_DisjointUnion_len(): + assert len(DisjointUnion(FiniteSet(3, 5, 7, 9), FiniteSet(x, y, z))) == 7 + assert len(DisjointUnion(S.EmptySet, S.EmptySet, FiniteSet(x, y, z), S.EmptySet)) == 3 + raises(ValueError, lambda: len(DisjointUnion(Interval(0, 1), S.EmptySet))) + +def test_SetKind_ProductSet(): + p = ProductSet(FiniteSet(Matrix([1, 2])), FiniteSet(Matrix([1, 2]))) + mk = MatrixKind(NumberKind) + k = SetKind(TupleKind(mk, mk)) + assert p.kind is k + assert ProductSet(Interval(1, 2), FiniteSet(Matrix([1, 2]))).kind is SetKind(TupleKind(NumberKind, mk)) + +def test_SetKind_Interval(): + assert Interval(1, 2).kind is SetKind(NumberKind) + +def test_SetKind_EmptySet_UniversalSet(): + assert S.UniversalSet.kind is SetKind(UndefinedKind) + assert EmptySet.kind is SetKind() + +def test_SetKind_FiniteSet(): + assert FiniteSet(1, Matrix([1, 2])).kind is SetKind(UndefinedKind) + assert FiniteSet(1, 2).kind is SetKind(NumberKind) + +def test_SetKind_Unions(): + assert Union(FiniteSet(Matrix([1, 2])), Interval(1, 2)).kind is SetKind(UndefinedKind) + assert Union(Interval(1, 2), Interval(1, 7)).kind is SetKind(NumberKind) + +def test_SetKind_DisjointUnion(): + A = FiniteSet(1, 2, 3) + B = Interval(0, 5) + assert DisjointUnion(A, B).kind is SetKind(NumberKind) + +def test_SetKind_evaluate_False(): + U = lambda *args: Union(*args, evaluate=False) + assert U({1}, EmptySet).kind is SetKind(NumberKind) + assert U(Interval(1, 2), EmptySet).kind is SetKind(NumberKind) + assert U({1}, S.UniversalSet).kind is SetKind(UndefinedKind) + assert U(Interval(1, 2), Interval(4, 5), + FiniteSet(1)).kind is SetKind(NumberKind) + I = lambda *args: Intersection(*args, evaluate=False) + assert I({1}, S.UniversalSet).kind is SetKind(NumberKind) + assert I({1}, EmptySet).kind is SetKind() + C = lambda *args: Complement(*args, evaluate=False) + assert C(S.UniversalSet, {1, 2, 4, 5}).kind is SetKind(UndefinedKind) + assert C({1, 2, 3, 4, 5}, EmptySet).kind is SetKind(NumberKind) + assert C(EmptySet, {1, 2, 3, 4, 5}).kind is SetKind() + +def test_SetKind_ImageSet_Special(): + f = ImageSet(Lambda(n, n ** 2), Interval(1, 4)) + assert (f - FiniteSet(3)).kind is SetKind(NumberKind) + assert (f + Interval(16, 17)).kind is SetKind(NumberKind) + assert (f + FiniteSet(17)).kind is SetKind(NumberKind) + +def test_issue_20089(): + B = FiniteSet(FiniteSet(1, 2), FiniteSet(1)) + assert 1 not in B + assert 1.0 not in B + assert not Eq(1, FiniteSet(1, 2)) + assert FiniteSet(1) in B + A = FiniteSet(1, 2) + assert A in B + assert B.issubset(B) + assert not A.issubset(B) + assert 1 in A + C = FiniteSet(FiniteSet(1, 2), FiniteSet(1), 1, 2) + assert A.issubset(C) + assert B.issubset(C) + +def test_issue_19378(): + a = FiniteSet(1, 2) + b = ProductSet(a, a) + c = FiniteSet((1, 1), (1, 2), (2, 1), (2, 2)) + assert b.is_subset(c) is True + d = FiniteSet(1) + assert b.is_subset(d) is False + assert Eq(c, b).simplify() is S.true + assert Eq(a, c).simplify() is S.false + assert Eq({1}, {x}).simplify() == Eq({1}, {x}) + +def test_intersection_symbolic(): + n = Symbol('n') + # These should not throw an error + assert isinstance(Intersection(Range(n), Range(100)), Intersection) + assert isinstance(Intersection(Range(n), Interval(1, 100)), Intersection) + assert isinstance(Intersection(Range(100), Interval(1, n)), Intersection) + + +@XFAIL +def test_intersection_symbolic_failing(): + n = Symbol('n', integer=True, positive=True) + assert Intersection(Range(10, n), Range(4, 500, 5)) == Intersection( + Range(14, n), Range(14, 500, 5)) + assert Intersection(Interval(10, n), Range(4, 500, 5)) == Intersection( + Interval(14, n), Range(14, 500, 5)) + + +def test_issue_20379(): + #https://github.com/sympy/sympy/issues/20379 + x = pi - 3.14159265358979 + assert FiniteSet(x).evalf(2) == FiniteSet(Float('3.23108914886517e-15', 2)) + +def test_finiteset_simplify(): + S = FiniteSet(1, cos(1)**2 + sin(1)**2) + assert S.simplify() == {1} + +def test_issue_14336(): + #https://github.com/sympy/sympy/issues/14336 + U = S.Complexes + x = Symbol("x") + U -= U.intersect(Ne(x, 1).as_set()) + U -= U.intersect(S.true.as_set()) + +def test_issue_9855(): + #https://github.com/sympy/sympy/issues/9855 + x, y, z = symbols('x, y, z', real=True) + s1 = Interval(1, x) & Interval(y, 2) + s2 = Interval(1, 2) + assert s1.is_subset(s2) == None diff --git a/lib/python3.10/site-packages/sympy/simplify/__pycache__/__init__.cpython-310.pyc b/lib/python3.10/site-packages/sympy/simplify/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b9f31b58d0034f032300ac599d0585de496613ce Binary files /dev/null and b/lib/python3.10/site-packages/sympy/simplify/__pycache__/__init__.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/simplify/__pycache__/combsimp.cpython-310.pyc b/lib/python3.10/site-packages/sympy/simplify/__pycache__/combsimp.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1800c8d5161f9c829050e06b81e485eddb6b41ba Binary files /dev/null and b/lib/python3.10/site-packages/sympy/simplify/__pycache__/combsimp.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/simplify/__pycache__/cse_main.cpython-310.pyc b/lib/python3.10/site-packages/sympy/simplify/__pycache__/cse_main.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0c55d99889c83b507af278eba5c3bec42457cd21 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/simplify/__pycache__/cse_main.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/simplify/__pycache__/cse_opts.cpython-310.pyc b/lib/python3.10/site-packages/sympy/simplify/__pycache__/cse_opts.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..82a26623bf6f50f12cb1d69eec3d6bb2663701f5 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/simplify/__pycache__/cse_opts.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/simplify/__pycache__/epathtools.cpython-310.pyc b/lib/python3.10/site-packages/sympy/simplify/__pycache__/epathtools.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9ecb6de7f8a26f8d9755f3c59a2e04e2a38f2725 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/simplify/__pycache__/epathtools.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/simplify/__pycache__/fu.cpython-310.pyc b/lib/python3.10/site-packages/sympy/simplify/__pycache__/fu.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6d7ea34d2a1b6139179f30ac70232f526d92518a Binary files /dev/null and b/lib/python3.10/site-packages/sympy/simplify/__pycache__/fu.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/simplify/__pycache__/gammasimp.cpython-310.pyc b/lib/python3.10/site-packages/sympy/simplify/__pycache__/gammasimp.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b8d5f6ae38022c75509a036289062629f51f5784 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/simplify/__pycache__/gammasimp.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/simplify/__pycache__/hyperexpand.cpython-310.pyc b/lib/python3.10/site-packages/sympy/simplify/__pycache__/hyperexpand.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..257fd1fdc105625e14cd6f8cbd63a45b05194f43 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/simplify/__pycache__/hyperexpand.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/simplify/__pycache__/hyperexpand_doc.cpython-310.pyc b/lib/python3.10/site-packages/sympy/simplify/__pycache__/hyperexpand_doc.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..68b4b9a73ae97db0ef945cd060ae3344ae8ee0c2 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/simplify/__pycache__/hyperexpand_doc.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/simplify/__pycache__/powsimp.cpython-310.pyc b/lib/python3.10/site-packages/sympy/simplify/__pycache__/powsimp.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1b7a975064b78b267195c8a05907afa40227ddec Binary files /dev/null and b/lib/python3.10/site-packages/sympy/simplify/__pycache__/powsimp.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/simplify/__pycache__/radsimp.cpython-310.pyc b/lib/python3.10/site-packages/sympy/simplify/__pycache__/radsimp.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f9b0bae41f18c37657cf314a1fce757a257fbe4a Binary files /dev/null and b/lib/python3.10/site-packages/sympy/simplify/__pycache__/radsimp.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/simplify/__pycache__/ratsimp.cpython-310.pyc b/lib/python3.10/site-packages/sympy/simplify/__pycache__/ratsimp.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..65480dbee65981a26a405ab39f5d1a6793bf2c4b Binary files /dev/null and b/lib/python3.10/site-packages/sympy/simplify/__pycache__/ratsimp.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/simplify/__pycache__/simplify.cpython-310.pyc b/lib/python3.10/site-packages/sympy/simplify/__pycache__/simplify.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7d8e928dc9e8c8d2ec9121d43a50cab98f8ad3ec Binary files /dev/null and b/lib/python3.10/site-packages/sympy/simplify/__pycache__/simplify.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/simplify/__pycache__/sqrtdenest.cpython-310.pyc b/lib/python3.10/site-packages/sympy/simplify/__pycache__/sqrtdenest.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9289c0cd30e5e00bd09f9c5814047c6d0de55f28 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/simplify/__pycache__/sqrtdenest.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/simplify/__pycache__/traversaltools.cpython-310.pyc b/lib/python3.10/site-packages/sympy/simplify/__pycache__/traversaltools.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d9580b62f395898f668c9478d4234a03a8ebe809 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/simplify/__pycache__/traversaltools.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/simplify/__pycache__/trigsimp.cpython-310.pyc b/lib/python3.10/site-packages/sympy/simplify/__pycache__/trigsimp.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..824ec115fe203c39f36c529baaf9ab563b832146 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/simplify/__pycache__/trigsimp.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/simplify/tests/__init__.py b/lib/python3.10/site-packages/sympy/simplify/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/lib/python3.10/site-packages/sympy/simplify/tests/__pycache__/__init__.cpython-310.pyc b/lib/python3.10/site-packages/sympy/simplify/tests/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f925c7c4d91bdd9a96cd766813aff9752894b08d Binary files /dev/null and b/lib/python3.10/site-packages/sympy/simplify/tests/__pycache__/__init__.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/simplify/tests/__pycache__/test_combsimp.cpython-310.pyc b/lib/python3.10/site-packages/sympy/simplify/tests/__pycache__/test_combsimp.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..22f1345714d7b2ca059f38d932480493775c6e42 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/simplify/tests/__pycache__/test_combsimp.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/simplify/tests/__pycache__/test_cse.cpython-310.pyc b/lib/python3.10/site-packages/sympy/simplify/tests/__pycache__/test_cse.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5d1ff8bf95fac3a66ecb35c2c72437c5a6c2f739 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/simplify/tests/__pycache__/test_cse.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/simplify/tests/__pycache__/test_epathtools.cpython-310.pyc b/lib/python3.10/site-packages/sympy/simplify/tests/__pycache__/test_epathtools.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5e7e9a707ca2a18b4d4ab38504835a8de1c3da02 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/simplify/tests/__pycache__/test_epathtools.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/simplify/tests/__pycache__/test_fu.cpython-310.pyc b/lib/python3.10/site-packages/sympy/simplify/tests/__pycache__/test_fu.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e9d4c2766aa6cac14faf06ee2130223e74b771f2 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/simplify/tests/__pycache__/test_fu.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/simplify/tests/__pycache__/test_function.cpython-310.pyc b/lib/python3.10/site-packages/sympy/simplify/tests/__pycache__/test_function.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2daa50494eaf391172060f960feb3499023b0a37 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/simplify/tests/__pycache__/test_function.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/simplify/tests/__pycache__/test_gammasimp.cpython-310.pyc b/lib/python3.10/site-packages/sympy/simplify/tests/__pycache__/test_gammasimp.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1d388e5195801ada946904b485224baea787d840 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/simplify/tests/__pycache__/test_gammasimp.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/simplify/tests/__pycache__/test_hyperexpand.cpython-310.pyc b/lib/python3.10/site-packages/sympy/simplify/tests/__pycache__/test_hyperexpand.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..684fd1e89bfd58f959cd6fb967865c486332ac5c Binary files /dev/null and b/lib/python3.10/site-packages/sympy/simplify/tests/__pycache__/test_hyperexpand.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/simplify/tests/__pycache__/test_powsimp.cpython-310.pyc b/lib/python3.10/site-packages/sympy/simplify/tests/__pycache__/test_powsimp.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ffe349d6d3378a3c2f91f4dd2fab25e69e7ac7c3 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/simplify/tests/__pycache__/test_powsimp.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/simplify/tests/__pycache__/test_radsimp.cpython-310.pyc b/lib/python3.10/site-packages/sympy/simplify/tests/__pycache__/test_radsimp.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..41a535f527e94deea18ef38e2733641beadbf9e6 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/simplify/tests/__pycache__/test_radsimp.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/simplify/tests/__pycache__/test_ratsimp.cpython-310.pyc b/lib/python3.10/site-packages/sympy/simplify/tests/__pycache__/test_ratsimp.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..089828f818baf054f88e00f278420e2f4cbed185 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/simplify/tests/__pycache__/test_ratsimp.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/simplify/tests/__pycache__/test_rewrite.cpython-310.pyc b/lib/python3.10/site-packages/sympy/simplify/tests/__pycache__/test_rewrite.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d7cf940aa0a820427aded7bbc4672b1e9725cb45 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/simplify/tests/__pycache__/test_rewrite.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/simplify/tests/__pycache__/test_simplify.cpython-310.pyc b/lib/python3.10/site-packages/sympy/simplify/tests/__pycache__/test_simplify.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0d2c5ff3c5ef995126eba3f80d64ffd7057a952c Binary files /dev/null and b/lib/python3.10/site-packages/sympy/simplify/tests/__pycache__/test_simplify.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/simplify/tests/__pycache__/test_sqrtdenest.cpython-310.pyc b/lib/python3.10/site-packages/sympy/simplify/tests/__pycache__/test_sqrtdenest.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c7fdd5b5acda2b2eb2a3775d8626873d1ceec16c Binary files /dev/null and b/lib/python3.10/site-packages/sympy/simplify/tests/__pycache__/test_sqrtdenest.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/simplify/tests/__pycache__/test_trigsimp.cpython-310.pyc b/lib/python3.10/site-packages/sympy/simplify/tests/__pycache__/test_trigsimp.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a8e6d407835408d28f32958b67f1d87bd03325f5 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/simplify/tests/__pycache__/test_trigsimp.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/simplify/tests/test_combsimp.py b/lib/python3.10/site-packages/sympy/simplify/tests/test_combsimp.py new file mode 100644 index 0000000000000000000000000000000000000000..e56758a005fbb013c2b6ea4121b16c3434a54b03 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/simplify/tests/test_combsimp.py @@ -0,0 +1,75 @@ +from sympy.core.numbers import Rational +from sympy.core.symbol import symbols +from sympy.functions.combinatorial.factorials import (FallingFactorial, RisingFactorial, binomial, factorial) +from sympy.functions.special.gamma_functions import gamma +from sympy.simplify.combsimp import combsimp +from sympy.abc import x + + +def test_combsimp(): + k, m, n = symbols('k m n', integer = True) + + assert combsimp(factorial(n)) == factorial(n) + assert combsimp(binomial(n, k)) == binomial(n, k) + + assert combsimp(factorial(n)/factorial(n - 3)) == n*(-1 + n)*(-2 + n) + assert combsimp(binomial(n + 1, k + 1)/binomial(n, k)) == (1 + n)/(1 + k) + + assert combsimp(binomial(3*n + 4, n + 1)/binomial(3*n + 1, n)) == \ + Rational(3, 2)*((3*n + 2)*(3*n + 4)/((n + 1)*(2*n + 3))) + + assert combsimp(factorial(n)**2/factorial(n - 3)) == \ + factorial(n)*n*(-1 + n)*(-2 + n) + assert combsimp(factorial(n)*binomial(n + 1, k + 1)/binomial(n, k)) == \ + factorial(n + 1)/(1 + k) + + assert combsimp(gamma(n + 3)) == factorial(n + 2) + + assert combsimp(factorial(x)) == gamma(x + 1) + + # issue 9699 + assert combsimp((n + 1)*factorial(n)) == factorial(n + 1) + assert combsimp(factorial(n)/n) == factorial(n-1) + + # issue 6658 + assert combsimp(binomial(n, n - k)) == binomial(n, k) + + # issue 6341, 7135 + assert combsimp(factorial(n)/(factorial(k)*factorial(n - k))) == \ + binomial(n, k) + assert combsimp(factorial(k)*factorial(n - k)/factorial(n)) == \ + 1/binomial(n, k) + assert combsimp(factorial(2*n)/factorial(n)**2) == binomial(2*n, n) + assert combsimp(factorial(2*n)*factorial(k)*factorial(n - k)/ + factorial(n)**3) == binomial(2*n, n)/binomial(n, k) + + assert combsimp(factorial(n*(1 + n) - n**2 - n)) == 1 + + assert combsimp(6*FallingFactorial(-4, n)/factorial(n)) == \ + (-1)**n*(n + 1)*(n + 2)*(n + 3) + assert combsimp(6*FallingFactorial(-4, n - 1)/factorial(n - 1)) == \ + (-1)**(n - 1)*n*(n + 1)*(n + 2) + assert combsimp(6*FallingFactorial(-4, n - 3)/factorial(n - 3)) == \ + (-1)**(n - 3)*n*(n - 1)*(n - 2) + assert combsimp(6*FallingFactorial(-4, -n - 1)/factorial(-n - 1)) == \ + -(-1)**(-n - 1)*n*(n - 1)*(n - 2) + + assert combsimp(6*RisingFactorial(4, n)/factorial(n)) == \ + (n + 1)*(n + 2)*(n + 3) + assert combsimp(6*RisingFactorial(4, n - 1)/factorial(n - 1)) == \ + n*(n + 1)*(n + 2) + assert combsimp(6*RisingFactorial(4, n - 3)/factorial(n - 3)) == \ + n*(n - 1)*(n - 2) + assert combsimp(6*RisingFactorial(4, -n - 1)/factorial(-n - 1)) == \ + -n*(n - 1)*(n - 2) + + +def test_issue_6878(): + n = symbols('n', integer=True) + assert combsimp(RisingFactorial(-10, n)) == 3628800*(-1)**n/factorial(10 - n) + + +def test_issue_14528(): + p = symbols("p", integer=True, positive=True) + assert combsimp(binomial(1,p)) == 1/(factorial(p)*factorial(1-p)) + assert combsimp(factorial(2-p)) == factorial(2-p) diff --git a/lib/python3.10/site-packages/sympy/simplify/tests/test_cse.py b/lib/python3.10/site-packages/sympy/simplify/tests/test_cse.py new file mode 100644 index 0000000000000000000000000000000000000000..4e3955252ee65b1f7d949c23822508b7ef0a0dd9 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/simplify/tests/test_cse.py @@ -0,0 +1,758 @@ +from functools import reduce +import itertools +from operator import add + +from sympy.codegen.matrix_nodes import MatrixSolve +from sympy.core.add import Add +from sympy.core.containers import Tuple +from sympy.core.expr import UnevaluatedExpr +from sympy.core.function import Function +from sympy.core.mul import Mul +from sympy.core.power import Pow +from sympy.core.relational import Eq +from sympy.core.singleton import S +from sympy.core.symbol import (Symbol, symbols) +from sympy.core.sympify import sympify +from sympy.functions.elementary.exponential import exp +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.piecewise import Piecewise +from sympy.functions.elementary.trigonometric import (cos, sin) +from sympy.matrices.dense import Matrix +from sympy.matrices.expressions import Inverse, MatAdd, MatMul, Transpose +from sympy.polys.rootoftools import CRootOf +from sympy.series.order import O +from sympy.simplify.cse_main import cse +from sympy.simplify.simplify import signsimp +from sympy.tensor.indexed import (Idx, IndexedBase) + +from sympy.core.function import count_ops +from sympy.simplify.cse_opts import sub_pre, sub_post +from sympy.functions.special.hyper import meijerg +from sympy.simplify import cse_main, cse_opts +from sympy.utilities.iterables import subsets +from sympy.testing.pytest import XFAIL, raises +from sympy.matrices import (MutableDenseMatrix, MutableSparseMatrix, + ImmutableDenseMatrix, ImmutableSparseMatrix) +from sympy.matrices.expressions import MatrixSymbol + + +w, x, y, z = symbols('w,x,y,z') +x0, x1, x2, x3, x4, x5, x6, x7, x8, x9, x10, x11, x12 = symbols('x:13') + + +def test_numbered_symbols(): + ns = cse_main.numbered_symbols(prefix='y') + assert list(itertools.islice( + ns, 0, 10)) == [Symbol('y%s' % i) for i in range(0, 10)] + ns = cse_main.numbered_symbols(prefix='y') + assert list(itertools.islice( + ns, 10, 20)) == [Symbol('y%s' % i) for i in range(10, 20)] + ns = cse_main.numbered_symbols() + assert list(itertools.islice( + ns, 0, 10)) == [Symbol('x%s' % i) for i in range(0, 10)] + +# Dummy "optimization" functions for testing. + + +def opt1(expr): + return expr + y + + +def opt2(expr): + return expr*z + + +def test_preprocess_for_cse(): + assert cse_main.preprocess_for_cse(x, [(opt1, None)]) == x + y + assert cse_main.preprocess_for_cse(x, [(None, opt1)]) == x + assert cse_main.preprocess_for_cse(x, [(None, None)]) == x + assert cse_main.preprocess_for_cse(x, [(opt1, opt2)]) == x + y + assert cse_main.preprocess_for_cse( + x, [(opt1, None), (opt2, None)]) == (x + y)*z + + +def test_postprocess_for_cse(): + assert cse_main.postprocess_for_cse(x, [(opt1, None)]) == x + assert cse_main.postprocess_for_cse(x, [(None, opt1)]) == x + y + assert cse_main.postprocess_for_cse(x, [(None, None)]) == x + assert cse_main.postprocess_for_cse(x, [(opt1, opt2)]) == x*z + # Note the reverse order of application. + assert cse_main.postprocess_for_cse( + x, [(None, opt1), (None, opt2)]) == x*z + y + + +def test_cse_single(): + # Simple substitution. + e = Add(Pow(x + y, 2), sqrt(x + y)) + substs, reduced = cse([e]) + assert substs == [(x0, x + y)] + assert reduced == [sqrt(x0) + x0**2] + + subst42, (red42,) = cse([42]) # issue_15082 + assert len(subst42) == 0 and red42 == 42 + subst_half, (red_half,) = cse([0.5]) + assert len(subst_half) == 0 and red_half == 0.5 + + +def test_cse_single2(): + # Simple substitution, test for being able to pass the expression directly + e = Add(Pow(x + y, 2), sqrt(x + y)) + substs, reduced = cse(e) + assert substs == [(x0, x + y)] + assert reduced == [sqrt(x0) + x0**2] + substs, reduced = cse(Matrix([[1]])) + assert isinstance(reduced[0], Matrix) + + subst42, (red42,) = cse(42) # issue 15082 + assert len(subst42) == 0 and red42 == 42 + subst_half, (red_half,) = cse(0.5) # issue 15082 + assert len(subst_half) == 0 and red_half == 0.5 + + +def test_cse_not_possible(): + # No substitution possible. + e = Add(x, y) + substs, reduced = cse([e]) + assert substs == [] + assert reduced == [x + y] + # issue 6329 + eq = (meijerg((1, 2), (y, 4), (5,), [], x) + + meijerg((1, 3), (y, 4), (5,), [], x)) + assert cse(eq) == ([], [eq]) + + +def test_nested_substitution(): + # Substitution within a substitution. + e = Add(Pow(w*x + y, 2), sqrt(w*x + y)) + substs, reduced = cse([e]) + assert substs == [(x0, w*x + y)] + assert reduced == [sqrt(x0) + x0**2] + + +def test_subtraction_opt(): + # Make sure subtraction is optimized. + e = (x - y)*(z - y) + exp((x - y)*(z - y)) + substs, reduced = cse( + [e], optimizations=[(cse_opts.sub_pre, cse_opts.sub_post)]) + assert substs == [(x0, (x - y)*(y - z))] + assert reduced == [-x0 + exp(-x0)] + e = -(x - y)*(z - y) + exp(-(x - y)*(z - y)) + substs, reduced = cse( + [e], optimizations=[(cse_opts.sub_pre, cse_opts.sub_post)]) + assert substs == [(x0, (x - y)*(y - z))] + assert reduced == [x0 + exp(x0)] + # issue 4077 + n = -1 + 1/x + e = n/x/(-n)**2 - 1/n/x + assert cse(e, optimizations=[(cse_opts.sub_pre, cse_opts.sub_post)]) == \ + ([], [0]) + assert cse(((w + x + y + z)*(w - y - z))/(w + x)**3) == \ + ([(x0, w + x), (x1, y + z)], [(w - x1)*(x0 + x1)/x0**3]) + + +def test_multiple_expressions(): + e1 = (x + y)*z + e2 = (x + y)*w + substs, reduced = cse([e1, e2]) + assert substs == [(x0, x + y)] + assert reduced == [x0*z, x0*w] + l = [w*x*y + z, w*y] + substs, reduced = cse(l) + rsubsts, _ = cse(reversed(l)) + assert substs == rsubsts + assert reduced == [z + x*x0, x0] + l = [w*x*y, w*x*y + z, w*y] + substs, reduced = cse(l) + rsubsts, _ = cse(reversed(l)) + assert substs == rsubsts + assert reduced == [x1, x1 + z, x0] + l = [(x - z)*(y - z), x - z, y - z] + substs, reduced = cse(l) + rsubsts, _ = cse(reversed(l)) + assert substs == [(x0, -z), (x1, x + x0), (x2, x0 + y)] + assert rsubsts == [(x0, -z), (x1, x0 + y), (x2, x + x0)] + assert reduced == [x1*x2, x1, x2] + l = [w*y + w + x + y + z, w*x*y] + assert cse(l) == ([(x0, w*y)], [w + x + x0 + y + z, x*x0]) + assert cse([x + y, x + y + z]) == ([(x0, x + y)], [x0, z + x0]) + assert cse([x + y, x + z]) == ([], [x + y, x + z]) + assert cse([x*y, z + x*y, x*y*z + 3]) == \ + ([(x0, x*y)], [x0, z + x0, 3 + x0*z]) + + +@XFAIL # CSE of non-commutative Mul terms is disabled +def test_non_commutative_cse(): + A, B, C = symbols('A B C', commutative=False) + l = [A*B*C, A*C] + assert cse(l) == ([], l) + l = [A*B*C, A*B] + assert cse(l) == ([(x0, A*B)], [x0*C, x0]) + + +# Test if CSE of non-commutative Mul terms is disabled +def test_bypass_non_commutatives(): + A, B, C = symbols('A B C', commutative=False) + l = [A*B*C, A*C] + assert cse(l) == ([], l) + l = [A*B*C, A*B] + assert cse(l) == ([], l) + l = [B*C, A*B*C] + assert cse(l) == ([], l) + + +@XFAIL # CSE fails when replacing non-commutative sub-expressions +def test_non_commutative_order(): + A, B, C = symbols('A B C', commutative=False) + x0 = symbols('x0', commutative=False) + l = [B+C, A*(B+C)] + assert cse(l) == ([(x0, B+C)], [x0, A*x0]) + + +@XFAIL # Worked in gh-11232, but was reverted due to performance considerations +def test_issue_10228(): + assert cse([x*y**2 + x*y]) == ([(x0, x*y)], [x0*y + x0]) + assert cse([x + y, 2*x + y]) == ([(x0, x + y)], [x0, x + x0]) + assert cse((w + 2*x + y + z, w + x + 1)) == ( + [(x0, w + x)], [x0 + x + y + z, x0 + 1]) + assert cse(((w + x + y + z)*(w - x))/(w + x)) == ( + [(x0, w + x)], [(x0 + y + z)*(w - x)/x0]) + a, b, c, d, f, g, j, m = symbols('a, b, c, d, f, g, j, m') + exprs = (d*g**2*j*m, 4*a*f*g*m, a*b*c*f**2) + assert cse(exprs) == ( + [(x0, g*m), (x1, a*f)], [d*g*j*x0, 4*x0*x1, b*c*f*x1] +) + +@XFAIL +def test_powers(): + assert cse(x*y**2 + x*y) == ([(x0, x*y)], [x0*y + x0]) + + +def test_issue_4498(): + assert cse(w/(x - y) + z/(y - x), optimizations='basic') == \ + ([], [(w - z)/(x - y)]) + + +def test_issue_4020(): + assert cse(x**5 + x**4 + x**3 + x**2, optimizations='basic') \ + == ([(x0, x**2)], [x0*(x**3 + x + x0 + 1)]) + + +def test_issue_4203(): + assert cse(sin(x**x)/x**x) == ([(x0, x**x)], [sin(x0)/x0]) + + +def test_issue_6263(): + e = Eq(x*(-x + 1) + x*(x - 1), 0) + assert cse(e, optimizations='basic') == ([], [True]) + + +def test_issue_25043(): + c = symbols("c") + x = symbols("x0", real=True) + cse_expr = cse(c*x**2 + c*(x**4 - x**2))[-1][-1] + free = cse_expr.free_symbols + assert len(free) == len({i.name for i in free}) + + +def test_dont_cse_tuples(): + from sympy.core.function import Subs + f = Function("f") + g = Function("g") + + name_val, (expr,) = cse( + Subs(f(x, y), (x, y), (0, 1)) + + Subs(g(x, y), (x, y), (0, 1))) + + assert name_val == [] + assert expr == (Subs(f(x, y), (x, y), (0, 1)) + + Subs(g(x, y), (x, y), (0, 1))) + + name_val, (expr,) = cse( + Subs(f(x, y), (x, y), (0, x + y)) + + Subs(g(x, y), (x, y), (0, x + y))) + + assert name_val == [(x0, x + y)] + assert expr == Subs(f(x, y), (x, y), (0, x0)) + \ + Subs(g(x, y), (x, y), (0, x0)) + + +def test_pow_invpow(): + assert cse(1/x**2 + x**2) == \ + ([(x0, x**2)], [x0 + 1/x0]) + assert cse(x**2 + (1 + 1/x**2)/x**2) == \ + ([(x0, x**2), (x1, 1/x0)], [x0 + x1*(x1 + 1)]) + assert cse(1/x**2 + (1 + 1/x**2)*x**2) == \ + ([(x0, x**2), (x1, 1/x0)], [x0*(x1 + 1) + x1]) + assert cse(cos(1/x**2) + sin(1/x**2)) == \ + ([(x0, x**(-2))], [sin(x0) + cos(x0)]) + assert cse(cos(x**2) + sin(x**2)) == \ + ([(x0, x**2)], [sin(x0) + cos(x0)]) + assert cse(y/(2 + x**2) + z/x**2/y) == \ + ([(x0, x**2)], [y/(x0 + 2) + z/(x0*y)]) + assert cse(exp(x**2) + x**2*cos(1/x**2)) == \ + ([(x0, x**2)], [x0*cos(1/x0) + exp(x0)]) + assert cse((1 + 1/x**2)/x**2) == \ + ([(x0, x**(-2))], [x0*(x0 + 1)]) + assert cse(x**(2*y) + x**(-2*y)) == \ + ([(x0, x**(2*y))], [x0 + 1/x0]) + + +def test_postprocess(): + eq = (x + 1 + exp((x + 1)/(y + 1)) + cos(y + 1)) + assert cse([eq, Eq(x, z + 1), z - 2, (z + 1)*(x + 1)], + postprocess=cse_main.cse_separate) == \ + [[(x0, y + 1), (x2, z + 1), (x, x2), (x1, x + 1)], + [x1 + exp(x1/x0) + cos(x0), z - 2, x1*x2]] + + +def test_issue_4499(): + # previously, this gave 16 constants + from sympy.abc import a, b + B = Function('B') + G = Function('G') + t = Tuple(* + (a, a + S.Half, 2*a, b, 2*a - b + 1, (sqrt(z)/2)**(-2*a + 1)*B(2*a - + b, sqrt(z))*B(b - 1, sqrt(z))*G(b)*G(2*a - b + 1), + sqrt(z)*(sqrt(z)/2)**(-2*a + 1)*B(b, sqrt(z))*B(2*a - b, + sqrt(z))*G(b)*G(2*a - b + 1), sqrt(z)*(sqrt(z)/2)**(-2*a + 1)*B(b - 1, + sqrt(z))*B(2*a - b + 1, sqrt(z))*G(b)*G(2*a - b + 1), + (sqrt(z)/2)**(-2*a + 1)*B(b, sqrt(z))*B(2*a - b + 1, + sqrt(z))*G(b)*G(2*a - b + 1), 1, 0, S.Half, z/2, -b + 1, -2*a + b, + -2*a)) + c = cse(t) + ans = ( + [(x0, 2*a), (x1, -b + x0), (x2, x1 + 1), (x3, b - 1), (x4, sqrt(z)), + (x5, B(x3, x4)), (x6, (x4/2)**(1 - x0)*G(b)*G(x2)), (x7, x6*B(x1, x4)), + (x8, B(b, x4)), (x9, x6*B(x2, x4))], + [(a, a + S.Half, x0, b, x2, x5*x7, x4*x7*x8, x4*x5*x9, x8*x9, + 1, 0, S.Half, z/2, -x3, -x1, -x0)]) + assert ans == c + + +def test_issue_6169(): + r = CRootOf(x**6 - 4*x**5 - 2, 1) + assert cse(r) == ([], [r]) + # and a check that the right thing is done with the new + # mechanism + assert sub_post(sub_pre((-x - y)*z - x - y)) == -z*(x + y) - x - y + + +def test_cse_Indexed(): + len_y = 5 + y = IndexedBase('y', shape=(len_y,)) + x = IndexedBase('x', shape=(len_y,)) + i = Idx('i', len_y-1) + + expr1 = (y[i+1]-y[i])/(x[i+1]-x[i]) + expr2 = 1/(x[i+1]-x[i]) + replacements, reduced_exprs = cse([expr1, expr2]) + assert len(replacements) > 0 + + +def test_cse_MatrixSymbol(): + # MatrixSymbols have non-Basic args, so make sure that works + A = MatrixSymbol("A", 3, 3) + assert cse(A) == ([], [A]) + + n = symbols('n', integer=True) + B = MatrixSymbol("B", n, n) + assert cse(B) == ([], [B]) + + assert cse(A[0] * A[0]) == ([], [A[0]*A[0]]) + + assert cse(A[0,0]*A[0,1] + A[0,0]*A[0,1]*A[0,2]) == ([(x0, A[0, 0]*A[0, 1])], [x0*A[0, 2] + x0]) + +def test_cse_MatrixExpr(): + A = MatrixSymbol('A', 3, 3) + y = MatrixSymbol('y', 3, 1) + + expr1 = (A.T*A).I * A * y + expr2 = (A.T*A) * A * y + replacements, reduced_exprs = cse([expr1, expr2]) + assert len(replacements) > 0 + + replacements, reduced_exprs = cse([expr1 + expr2, expr1]) + assert replacements + + replacements, reduced_exprs = cse([A**2, A + A**2]) + assert replacements + + +def test_Piecewise(): + f = Piecewise((-z + x*y, Eq(y, 0)), (-z - x*y, True)) + ans = cse(f) + actual_ans = ([(x0, x*y)], + [Piecewise((x0 - z, Eq(y, 0)), (-z - x0, True))]) + assert ans == actual_ans + + +def test_ignore_order_terms(): + eq = exp(x).series(x,0,3) + sin(y+x**3) - 1 + assert cse(eq) == ([], [sin(x**3 + y) + x + x**2/2 + O(x**3)]) + + +def test_name_conflict(): + z1 = x0 + y + z2 = x2 + x3 + l = [cos(z1) + z1, cos(z2) + z2, x0 + x2] + substs, reduced = cse(l) + assert [e.subs(reversed(substs)) for e in reduced] == l + + +def test_name_conflict_cust_symbols(): + z1 = x0 + y + z2 = x2 + x3 + l = [cos(z1) + z1, cos(z2) + z2, x0 + x2] + substs, reduced = cse(l, symbols("x:10")) + assert [e.subs(reversed(substs)) for e in reduced] == l + + +def test_symbols_exhausted_error(): + l = cos(x+y)+x+y+cos(w+y)+sin(w+y) + sym = [x, y, z] + with raises(ValueError): + cse(l, symbols=sym) + + +def test_issue_7840(): + # daveknippers' example + C393 = sympify( \ + 'Piecewise((C391 - 1.65, C390 < 0.5), (Piecewise((C391 - 1.65, \ + C391 > 2.35), (C392, True)), True))' + ) + C391 = sympify( \ + 'Piecewise((2.05*C390**(-1.03), C390 < 0.5), (2.5*C390**(-0.625), True))' + ) + C393 = C393.subs('C391',C391) + # simple substitution + sub = {} + sub['C390'] = 0.703451854 + sub['C392'] = 1.01417794 + ss_answer = C393.subs(sub) + # cse + substitutions,new_eqn = cse(C393) + for pair in substitutions: + sub[pair[0].name] = pair[1].subs(sub) + cse_answer = new_eqn[0].subs(sub) + # both methods should be the same + assert ss_answer == cse_answer + + # GitRay's example + expr = sympify( + "Piecewise((Symbol('ON'), Equality(Symbol('mode'), Symbol('ON'))), \ + (Piecewise((Piecewise((Symbol('OFF'), StrictLessThan(Symbol('x'), \ + Symbol('threshold'))), (Symbol('ON'), true)), Equality(Symbol('mode'), \ + Symbol('AUTO'))), (Symbol('OFF'), true)), true))" + ) + substitutions, new_eqn = cse(expr) + # this Piecewise should be exactly the same + assert new_eqn[0] == expr + # there should not be any replacements + assert len(substitutions) < 1 + + +def test_issue_8891(): + for cls in (MutableDenseMatrix, MutableSparseMatrix, + ImmutableDenseMatrix, ImmutableSparseMatrix): + m = cls(2, 2, [x + y, 0, 0, 0]) + res = cse([x + y, m]) + ans = ([(x0, x + y)], [x0, cls([[x0, 0], [0, 0]])]) + assert res == ans + assert isinstance(res[1][-1], cls) + + +def test_issue_11230(): + # a specific test that always failed + a, b, f, k, l, i = symbols('a b f k l i') + p = [a*b*f*k*l, a*i*k**2*l, f*i*k**2*l] + R, C = cse(p) + assert not any(i.is_Mul for a in C for i in a.args) + + # random tests for the issue + from sympy.core.random import choice + from sympy.core.function import expand_mul + s = symbols('a:m') + # 35 Mul tests, none of which should ever fail + ex = [Mul(*[choice(s) for i in range(5)]) for i in range(7)] + for p in subsets(ex, 3): + p = list(p) + R, C = cse(p) + assert not any(i.is_Mul for a in C for i in a.args) + for ri in reversed(R): + for i in range(len(C)): + C[i] = C[i].subs(*ri) + assert p == C + # 35 Add tests, none of which should ever fail + ex = [Add(*[choice(s[:7]) for i in range(5)]) for i in range(7)] + for p in subsets(ex, 3): + p = list(p) + R, C = cse(p) + assert not any(i.is_Add for a in C for i in a.args) + for ri in reversed(R): + for i in range(len(C)): + C[i] = C[i].subs(*ri) + # use expand_mul to handle cases like this: + # p = [a + 2*b + 2*e, 2*b + c + 2*e, b + 2*c + 2*g] + # x0 = 2*(b + e) is identified giving a rebuilt p that + # is now `[a + 2*(b + e), c + 2*(b + e), b + 2*c + 2*g]` + assert p == [expand_mul(i) for i in C] + + +@XFAIL +def test_issue_11577(): + def check(eq): + r, c = cse(eq) + assert eq.count_ops() >= \ + len(r) + sum(i[1].count_ops() for i in r) + \ + count_ops(c) + + eq = x**5*y**2 + x**5*y + x**5 + assert cse(eq) == ( + [(x0, x**4), (x1, x*y)], [x**5 + x0*x1*y + x0*x1]) + # ([(x0, x**5*y)], [x0*y + x0 + x**5]) or + # ([(x0, x**5)], [x0*y**2 + x0*y + x0]) + check(eq) + + eq = x**2/(y + 1)**2 + x/(y + 1) + assert cse(eq) == ( + [(x0, y + 1)], [x**2/x0**2 + x/x0]) + # ([(x0, x/(y + 1))], [x0**2 + x0]) + check(eq) + + +def test_hollow_rejection(): + eq = [x + 3, x + 4] + assert cse(eq) == ([], eq) + + +def test_cse_ignore(): + exprs = [exp(y)*(3*y + 3*sqrt(x+1)), exp(y)*(5*y + 5*sqrt(x+1))] + subst1, red1 = cse(exprs) + assert any(y in sub.free_symbols for _, sub in subst1), "cse failed to identify any term with y" + + subst2, red2 = cse(exprs, ignore=(y,)) # y is not allowed in substitutions + assert not any(y in sub.free_symbols for _, sub in subst2), "Sub-expressions containing y must be ignored" + assert any(sub - sqrt(x + 1) == 0 for _, sub in subst2), "cse failed to identify sqrt(x + 1) as sub-expression" + +def test_cse_ignore_issue_15002(): + l = [ + w*exp(x)*exp(-z), + exp(y)*exp(x)*exp(-z) + ] + substs, reduced = cse(l, ignore=(x,)) + rl = [e.subs(reversed(substs)) for e in reduced] + assert rl == l + + +def test_cse_unevaluated(): + xp1 = UnevaluatedExpr(x + 1) + # This used to cause RecursionError + [(x0, ue)], [red] = cse([(-1 - xp1) / (1 - xp1)]) + if ue == xp1: + assert red == (-1 - x0) / (1 - x0) + elif ue == -xp1: + assert red == (-1 + x0) / (1 + x0) + else: + msg = f'Expected common subexpression {xp1} or {-xp1}, instead got {ue}' + assert False, msg + + +def test_cse__performance(): + nexprs, nterms = 3, 20 + x = symbols('x:%d' % nterms) + exprs = [ + reduce(add, [x[j]*(-1)**(i+j) for j in range(nterms)]) + for i in range(nexprs) + ] + assert (exprs[0] + exprs[1]).simplify() == 0 + subst, red = cse(exprs) + assert len(subst) > 0, "exprs[0] == -exprs[2], i.e. a CSE" + for i, e in enumerate(red): + assert (e.subs(reversed(subst)) - exprs[i]).simplify() == 0 + + +def test_issue_12070(): + exprs = [x + y, 2 + x + y, x + y + z, 3 + x + y + z] + subst, red = cse(exprs) + assert 6 >= (len(subst) + sum(v.count_ops() for k, v in subst) + + count_ops(red)) + + +def test_issue_13000(): + eq = x/(-4*x**2 + y**2) + cse_eq = cse(eq)[1][0] + assert cse_eq == eq + + +def test_issue_18203(): + eq = CRootOf(x**5 + 11*x - 2, 0) + CRootOf(x**5 + 11*x - 2, 1) + assert cse(eq) == ([], [eq]) + + +def test_unevaluated_mul(): + eq = Mul(x + y, x + y, evaluate=False) + assert cse(eq) == ([(x0, x + y)], [x0**2]) + +def test_cse_release_variables(): + from sympy.simplify.cse_main import cse_release_variables + _0, _1, _2, _3, _4 = symbols('_:5') + eqs = [(x + y - 1)**2, x, + x + y, (x + y)/(2*x + 1) + (x + y - 1)**2, + (2*x + 1)**(x + y)] + r, e = cse(eqs, postprocess=cse_release_variables) + # this can change in keeping with the intention of the function + assert r, e == ([ + (x0, x + y), (x1, (x0 - 1)**2), (x2, 2*x + 1), + (_3, x0/x2 + x1), (_4, x2**x0), (x2, None), (_0, x1), + (x1, None), (_2, x0), (x0, None), (_1, x)], (_0, _1, _2, _3, _4)) + r.reverse() + r = [(s, v) for s, v in r if v is not None] + assert eqs == [i.subs(r) for i in e] + +def test_cse_list(): + _cse = lambda x: cse(x, list=False) + assert _cse(x) == ([], x) + assert _cse('x') == ([], 'x') + it = [x] + for c in (list, tuple, set): + assert _cse(c(it)) == ([], c(it)) + #Tuple works different from tuple: + assert _cse(Tuple(*it)) == ([], Tuple(*it)) + d = {x: 1} + assert _cse(d) == ([], d) + +def test_issue_18991(): + A = MatrixSymbol('A', 2, 2) + assert signsimp(-A * A - A) == -A * A - A + + +def test_unevaluated_Mul(): + m = [Mul(1, 2, evaluate=False)] + assert cse(m) == ([], m) + + +def test_cse_matrix_expression_inverse(): + A = ImmutableDenseMatrix(symbols('A:4')).reshape(2, 2) + x = Inverse(A) + cse_expr = cse(x) + assert cse_expr == ([], [Inverse(A)]) + + +def test_cse_matrix_expression_matmul_inverse(): + A = ImmutableDenseMatrix(symbols('A:4')).reshape(2, 2) + b = ImmutableDenseMatrix(symbols('b:2')) + x = MatMul(Inverse(A), b) + cse_expr = cse(x) + assert cse_expr == ([], [x]) + + +def test_cse_matrix_negate_matrix(): + A = ImmutableDenseMatrix(symbols('A:4')).reshape(2, 2) + x = MatMul(S.NegativeOne, A) + cse_expr = cse(x) + assert cse_expr == ([], [x]) + + +def test_cse_matrix_negate_matmul_not_extracted(): + A = ImmutableDenseMatrix(symbols('A:4')).reshape(2, 2) + B = ImmutableDenseMatrix(symbols('B:4')).reshape(2, 2) + x = MatMul(S.NegativeOne, A, B) + cse_expr = cse(x) + assert cse_expr == ([], [x]) + + +@XFAIL # No simplification rule for nested associative operations +def test_cse_matrix_nested_matmul_collapsed(): + A = ImmutableDenseMatrix(symbols('A:4')).reshape(2, 2) + B = ImmutableDenseMatrix(symbols('B:4')).reshape(2, 2) + x = MatMul(S.NegativeOne, MatMul(A, B)) + cse_expr = cse(x) + assert cse_expr == ([], [MatMul(S.NegativeOne, A, B)]) + + +def test_cse_matrix_optimize_out_single_argument_mul(): + A = ImmutableDenseMatrix(symbols('A:4')).reshape(2, 2) + x = MatMul(MatMul(MatMul(A))) + cse_expr = cse(x) + assert cse_expr == ([], [A]) + + +@XFAIL # Multiple simplification passed not supported in CSE +def test_cse_matrix_optimize_out_single_argument_mul_combined(): + A = ImmutableDenseMatrix(symbols('A:4')).reshape(2, 2) + x = MatAdd(MatMul(MatMul(MatMul(A))), MatMul(MatMul(A)), MatMul(A), A) + cse_expr = cse(x) + assert cse_expr == ([], [MatMul(4, A)]) + + +def test_cse_matrix_optimize_out_single_argument_add(): + A = ImmutableDenseMatrix(symbols('A:4')).reshape(2, 2) + x = MatAdd(MatAdd(MatAdd(MatAdd(A)))) + cse_expr = cse(x) + assert cse_expr == ([], [A]) + + +@XFAIL # Multiple simplification passed not supported in CSE +def test_cse_matrix_optimize_out_single_argument_add_combined(): + A = ImmutableDenseMatrix(symbols('A:4')).reshape(2, 2) + x = MatMul(MatAdd(MatAdd(MatAdd(A))), MatAdd(MatAdd(A)), MatAdd(A), A) + cse_expr = cse(x) + assert cse_expr == ([], [MatMul(4, A)]) + + +def test_cse_matrix_expression_matrix_solve(): + A = ImmutableDenseMatrix(symbols('A:4')).reshape(2, 2) + b = ImmutableDenseMatrix(symbols('b:2')) + x = MatrixSolve(A, b) + cse_expr = cse(x) + assert cse_expr == ([], [x]) + + +def test_cse_matrix_matrix_expression(): + X = ImmutableDenseMatrix(symbols('X:4')).reshape(2, 2) + y = ImmutableDenseMatrix(symbols('y:2')) + b = MatMul(Inverse(MatMul(Transpose(X), X)), Transpose(X), y) + cse_expr = cse(b) + x0 = MatrixSymbol('x0', 2, 2) + reduced_expr_expected = MatMul(Inverse(MatMul(x0, X)), x0, y) + assert cse_expr == ([(x0, Transpose(X))], [reduced_expr_expected]) + + +def test_cse_matrix_kalman_filter(): + """Kalman Filter example from Matthew Rocklin's SciPy 2013 talk. + + Talk titled: "Matrix Expressions and BLAS/LAPACK; SciPy 2013 Presentation" + + Video: https://pyvideo.org/scipy-2013/matrix-expressions-and-blaslapack-scipy-2013-pr.html + + Notes + ===== + + Equations are: + + new_mu = mu + Sigma*H.T * (R + H*Sigma*H.T).I * (H*mu - data) + = MatAdd(mu, MatMul(Sigma, Transpose(H), Inverse(MatAdd(R, MatMul(H, Sigma, Transpose(H)))), MatAdd(MatMul(H, mu), MatMul(S.NegativeOne, data)))) + new_Sigma = Sigma - Sigma*H.T * (R + H*Sigma*H.T).I * H * Sigma + = MatAdd(Sigma, MatMul(S.NegativeOne, Sigma, Transpose(H)), Inverse(MatAdd(R, MatMul(H*Sigma*Transpose(H)))), H, Sigma)) + + """ + N = 2 + mu = ImmutableDenseMatrix(symbols(f'mu:{N}')) + Sigma = ImmutableDenseMatrix(symbols(f'Sigma:{N * N}')).reshape(N, N) + H = ImmutableDenseMatrix(symbols(f'H:{N * N}')).reshape(N, N) + R = ImmutableDenseMatrix(symbols(f'R:{N * N}')).reshape(N, N) + data = ImmutableDenseMatrix(symbols(f'data:{N}')) + new_mu = MatAdd(mu, MatMul(Sigma, Transpose(H), Inverse(MatAdd(R, MatMul(H, Sigma, Transpose(H)))), MatAdd(MatMul(H, mu), MatMul(S.NegativeOne, data)))) + new_Sigma = MatAdd(Sigma, MatMul(S.NegativeOne, Sigma, Transpose(H), Inverse(MatAdd(R, MatMul(H, Sigma, Transpose(H)))), H, Sigma)) + cse_expr = cse([new_mu, new_Sigma]) + x0 = MatrixSymbol('x0', N, N) + x1 = MatrixSymbol('x1', N, N) + replacements_expected = [ + (x0, Transpose(H)), + (x1, Inverse(MatAdd(R, MatMul(H, Sigma, x0)))), + ] + reduced_exprs_expected = [ + MatAdd(mu, MatMul(Sigma, x0, x1, MatAdd(MatMul(H, mu), MatMul(S.NegativeOne, data)))), + MatAdd(Sigma, MatMul(S.NegativeOne, Sigma, x0, x1, H, Sigma)), + ] + assert cse_expr == (replacements_expected, reduced_exprs_expected) diff --git a/lib/python3.10/site-packages/sympy/simplify/tests/test_epathtools.py b/lib/python3.10/site-packages/sympy/simplify/tests/test_epathtools.py new file mode 100644 index 0000000000000000000000000000000000000000..a8bb47b2f2ff624077ab9905677b181c587ab5a7 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/simplify/tests/test_epathtools.py @@ -0,0 +1,90 @@ +"""Tests for tools for manipulation of expressions using paths. """ + +from sympy.simplify.epathtools import epath, EPath +from sympy.testing.pytest import raises + +from sympy.core.numbers import E +from sympy.functions.elementary.trigonometric import (cos, sin) +from sympy.abc import x, y, z, t + + +def test_epath_select(): + expr = [((x, 1, t), 2), ((3, y, 4), z)] + + assert epath("/*", expr) == [((x, 1, t), 2), ((3, y, 4), z)] + assert epath("/*/*", expr) == [(x, 1, t), 2, (3, y, 4), z] + assert epath("/*/*/*", expr) == [x, 1, t, 3, y, 4] + assert epath("/*/*/*/*", expr) == [] + + assert epath("/[:]", expr) == [((x, 1, t), 2), ((3, y, 4), z)] + assert epath("/[:]/[:]", expr) == [(x, 1, t), 2, (3, y, 4), z] + assert epath("/[:]/[:]/[:]", expr) == [x, 1, t, 3, y, 4] + assert epath("/[:]/[:]/[:]/[:]", expr) == [] + + assert epath("/*/[:]", expr) == [(x, 1, t), 2, (3, y, 4), z] + + assert epath("/*/[0]", expr) == [(x, 1, t), (3, y, 4)] + assert epath("/*/[1]", expr) == [2, z] + assert epath("/*/[2]", expr) == [] + + assert epath("/*/int", expr) == [2] + assert epath("/*/Symbol", expr) == [z] + assert epath("/*/tuple", expr) == [(x, 1, t), (3, y, 4)] + assert epath("/*/__iter__?", expr) == [(x, 1, t), (3, y, 4)] + + assert epath("/*/int|tuple", expr) == [(x, 1, t), 2, (3, y, 4)] + assert epath("/*/Symbol|tuple", expr) == [(x, 1, t), (3, y, 4), z] + assert epath("/*/int|Symbol|tuple", expr) == [(x, 1, t), 2, (3, y, 4), z] + + assert epath("/*/int|__iter__?", expr) == [(x, 1, t), 2, (3, y, 4)] + assert epath("/*/Symbol|__iter__?", expr) == [(x, 1, t), (3, y, 4), z] + assert epath( + "/*/int|Symbol|__iter__?", expr) == [(x, 1, t), 2, (3, y, 4), z] + + assert epath("/*/[0]/int", expr) == [1, 3, 4] + assert epath("/*/[0]/Symbol", expr) == [x, t, y] + + assert epath("/*/[0]/int[1:]", expr) == [1, 4] + assert epath("/*/[0]/Symbol[1:]", expr) == [t, y] + + assert epath("/Symbol", x + y + z + 1) == [x, y, z] + assert epath("/*/*/Symbol", t + sin(x + 1) + cos(x + y + E)) == [x, x, y] + + +def test_epath_apply(): + expr = [((x, 1, t), 2), ((3, y, 4), z)] + func = lambda expr: expr**2 + + assert epath("/*", expr, list) == [[(x, 1, t), 2], [(3, y, 4), z]] + + assert epath("/*/[0]", expr, list) == [([x, 1, t], 2), ([3, y, 4], z)] + assert epath("/*/[1]", expr, func) == [((x, 1, t), 4), ((3, y, 4), z**2)] + assert epath("/*/[2]", expr, list) == expr + + assert epath("/*/[0]/int", expr, func) == [((x, 1, t), 2), ((9, y, 16), z)] + assert epath("/*/[0]/Symbol", expr, func) == [((x**2, 1, t**2), 2), + ((3, y**2, 4), z)] + assert epath( + "/*/[0]/int[1:]", expr, func) == [((x, 1, t), 2), ((3, y, 16), z)] + assert epath("/*/[0]/Symbol[1:]", expr, func) == [((x, 1, t**2), + 2), ((3, y**2, 4), z)] + + assert epath("/Symbol", x + y + z + 1, func) == x**2 + y**2 + z**2 + 1 + assert epath("/*/*/Symbol", t + sin(x + 1) + cos(x + y + E), func) == \ + t + sin(x**2 + 1) + cos(x**2 + y**2 + E) + + +def test_EPath(): + assert EPath("/*/[0]")._path == "/*/[0]" + assert EPath(EPath("/*/[0]"))._path == "/*/[0]" + assert isinstance(epath("/*/[0]"), EPath) is True + + assert repr(EPath("/*/[0]")) == "EPath('/*/[0]')" + + raises(ValueError, lambda: EPath("")) + raises(ValueError, lambda: EPath("/")) + raises(ValueError, lambda: EPath("/|x")) + raises(ValueError, lambda: EPath("/[")) + raises(ValueError, lambda: EPath("/[0]%")) + + raises(NotImplementedError, lambda: EPath("Symbol")) diff --git a/lib/python3.10/site-packages/sympy/simplify/tests/test_fu.py b/lib/python3.10/site-packages/sympy/simplify/tests/test_fu.py new file mode 100644 index 0000000000000000000000000000000000000000..2de2126b7333195fceeffe72dc9cb642e7eba9a9 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/simplify/tests/test_fu.py @@ -0,0 +1,492 @@ +from sympy.core.add import Add +from sympy.core.mul import Mul +from sympy.core.numbers import (I, Rational, pi) +from sympy.core.parameters import evaluate +from sympy.core.singleton import S +from sympy.core.symbol import (Dummy, Symbol, symbols) +from sympy.functions.elementary.hyperbolic import (cosh, coth, csch, sech, sinh, tanh) +from sympy.functions.elementary.miscellaneous import (root, sqrt) +from sympy.functions.elementary.trigonometric import (cos, cot, csc, sec, sin, tan) +from sympy.simplify.powsimp import powsimp +from sympy.simplify.fu import ( + L, TR1, TR10, TR10i, TR11, _TR11, TR12, TR12i, TR13, TR14, TR15, TR16, + TR111, TR2, TR2i, TR3, TR4, TR5, TR6, TR7, TR8, TR9, TRmorrie, _TR56 as T, + TRpower, hyper_as_trig, fu, process_common_addends, trig_split, + as_f_sign_1) +from sympy.core.random import verify_numerically +from sympy.abc import a, b, c, x, y, z + + +def test_TR1(): + assert TR1(2*csc(x) + sec(x)) == 1/cos(x) + 2/sin(x) + + +def test_TR2(): + assert TR2(tan(x)) == sin(x)/cos(x) + assert TR2(cot(x)) == cos(x)/sin(x) + assert TR2(tan(tan(x) - sin(x)/cos(x))) == 0 + + +def test_TR2i(): + # just a reminder that ratios of powers only simplify if both + # numerator and denominator satisfy the condition that each + # has a positive base or an integer exponent; e.g. the following, + # at y=-1, x=1/2 gives sqrt(2)*I != -sqrt(2)*I + assert powsimp(2**x/y**x) != (2/y)**x + + assert TR2i(sin(x)/cos(x)) == tan(x) + assert TR2i(sin(x)*sin(y)/cos(x)) == tan(x)*sin(y) + assert TR2i(1/(sin(x)/cos(x))) == 1/tan(x) + assert TR2i(1/(sin(x)*sin(y)/cos(x))) == 1/tan(x)/sin(y) + assert TR2i(sin(x)/2/(cos(x) + 1)) == sin(x)/(cos(x) + 1)/2 + + assert TR2i(sin(x)/2/(cos(x) + 1), half=True) == tan(x/2)/2 + assert TR2i(sin(1)/(cos(1) + 1), half=True) == tan(S.Half) + assert TR2i(sin(2)/(cos(2) + 1), half=True) == tan(1) + assert TR2i(sin(4)/(cos(4) + 1), half=True) == tan(2) + assert TR2i(sin(5)/(cos(5) + 1), half=True) == tan(5*S.Half) + assert TR2i((cos(1) + 1)/sin(1), half=True) == 1/tan(S.Half) + assert TR2i((cos(2) + 1)/sin(2), half=True) == 1/tan(1) + assert TR2i((cos(4) + 1)/sin(4), half=True) == 1/tan(2) + assert TR2i((cos(5) + 1)/sin(5), half=True) == 1/tan(5*S.Half) + assert TR2i((cos(1) + 1)**(-a)*sin(1)**a, half=True) == tan(S.Half)**a + assert TR2i((cos(2) + 1)**(-a)*sin(2)**a, half=True) == tan(1)**a + assert TR2i((cos(4) + 1)**(-a)*sin(4)**a, half=True) == (cos(4) + 1)**(-a)*sin(4)**a + assert TR2i((cos(5) + 1)**(-a)*sin(5)**a, half=True) == (cos(5) + 1)**(-a)*sin(5)**a + assert TR2i((cos(1) + 1)**a*sin(1)**(-a), half=True) == tan(S.Half)**(-a) + assert TR2i((cos(2) + 1)**a*sin(2)**(-a), half=True) == tan(1)**(-a) + assert TR2i((cos(4) + 1)**a*sin(4)**(-a), half=True) == (cos(4) + 1)**a*sin(4)**(-a) + assert TR2i((cos(5) + 1)**a*sin(5)**(-a), half=True) == (cos(5) + 1)**a*sin(5)**(-a) + + i = symbols('i', integer=True) + assert TR2i(((cos(5) + 1)**i*sin(5)**(-i)), half=True) == tan(5*S.Half)**(-i) + assert TR2i(1/((cos(5) + 1)**i*sin(5)**(-i)), half=True) == tan(5*S.Half)**i + + +def test_TR3(): + assert TR3(cos(y - x*(y - x))) == cos(x*(x - y) + y) + assert cos(pi/2 + x) == -sin(x) + assert cos(30*pi/2 + x) == -cos(x) + + for f in (cos, sin, tan, cot, csc, sec): + i = f(pi*Rational(3, 7)) + j = TR3(i) + assert verify_numerically(i, j) and i.func != j.func + + with evaluate(False): + eq = cos(9*pi/22) + assert eq.has(9*pi) and TR3(eq) == sin(pi/11) + + +def test_TR4(): + for i in [0, pi/6, pi/4, pi/3, pi/2]: + with evaluate(False): + eq = cos(i) + assert isinstance(eq, cos) and TR4(eq) == cos(i) + + +def test__TR56(): + h = lambda x: 1 - x + assert T(sin(x)**3, sin, cos, h, 4, False) == sin(x)*(-cos(x)**2 + 1) + assert T(sin(x)**10, sin, cos, h, 4, False) == sin(x)**10 + assert T(sin(x)**6, sin, cos, h, 6, False) == (-cos(x)**2 + 1)**3 + assert T(sin(x)**6, sin, cos, h, 6, True) == sin(x)**6 + assert T(sin(x)**8, sin, cos, h, 10, True) == (-cos(x)**2 + 1)**4 + + # issue 17137 + assert T(sin(x)**I, sin, cos, h, 4, True) == sin(x)**I + assert T(sin(x)**(2*I + 1), sin, cos, h, 4, True) == sin(x)**(2*I + 1) + + +def test_TR5(): + assert TR5(sin(x)**2) == -cos(x)**2 + 1 + assert TR5(sin(x)**-2) == sin(x)**(-2) + assert TR5(sin(x)**4) == (-cos(x)**2 + 1)**2 + + +def test_TR6(): + assert TR6(cos(x)**2) == -sin(x)**2 + 1 + assert TR6(cos(x)**-2) == cos(x)**(-2) + assert TR6(cos(x)**4) == (-sin(x)**2 + 1)**2 + + +def test_TR7(): + assert TR7(cos(x)**2) == cos(2*x)/2 + S.Half + assert TR7(cos(x)**2 + 1) == cos(2*x)/2 + Rational(3, 2) + + +def test_TR8(): + assert TR8(cos(2)*cos(3)) == cos(5)/2 + cos(1)/2 + assert TR8(cos(2)*sin(3)) == sin(5)/2 + sin(1)/2 + assert TR8(sin(2)*sin(3)) == -cos(5)/2 + cos(1)/2 + assert TR8(sin(1)*sin(2)*sin(3)) == sin(4)/4 - sin(6)/4 + sin(2)/4 + assert TR8(cos(2)*cos(3)*cos(4)*cos(5)) == \ + cos(4)/4 + cos(10)/8 + cos(2)/8 + cos(8)/8 + cos(14)/8 + \ + cos(6)/8 + Rational(1, 8) + assert TR8(cos(2)*cos(3)*cos(4)*cos(5)*cos(6)) == \ + cos(10)/8 + cos(4)/8 + 3*cos(2)/16 + cos(16)/16 + cos(8)/8 + \ + cos(14)/16 + cos(20)/16 + cos(12)/16 + Rational(1, 16) + cos(6)/8 + assert TR8(sin(pi*Rational(3, 7))**2*cos(pi*Rational(3, 7))**2/(16*sin(pi/7)**2)) == Rational(1, 64) + +def test_TR9(): + a = S.Half + b = 3*a + assert TR9(a) == a + assert TR9(cos(1) + cos(2)) == 2*cos(a)*cos(b) + assert TR9(cos(1) - cos(2)) == 2*sin(a)*sin(b) + assert TR9(sin(1) - sin(2)) == -2*sin(a)*cos(b) + assert TR9(sin(1) + sin(2)) == 2*sin(b)*cos(a) + assert TR9(cos(1) + 2*sin(1) + 2*sin(2)) == cos(1) + 4*sin(b)*cos(a) + assert TR9(cos(4) + cos(2) + 2*cos(1)*cos(3)) == 4*cos(1)*cos(3) + assert TR9((cos(4) + cos(2))/cos(3)/2 + cos(3)) == 2*cos(1)*cos(2) + assert TR9(cos(3) + cos(4) + cos(5) + cos(6)) == \ + 4*cos(S.Half)*cos(1)*cos(Rational(9, 2)) + assert TR9(cos(3) + cos(3)*cos(2)) == cos(3) + cos(2)*cos(3) + assert TR9(-cos(y) + cos(x*y)) == -2*sin(x*y/2 - y/2)*sin(x*y/2 + y/2) + assert TR9(-sin(y) + sin(x*y)) == 2*sin(x*y/2 - y/2)*cos(x*y/2 + y/2) + c = cos(x) + s = sin(x) + for si in ((1, 1), (1, -1), (-1, 1), (-1, -1)): + for a in ((c, s), (s, c), (cos(x), cos(x*y)), (sin(x), sin(x*y))): + args = zip(si, a) + ex = Add(*[Mul(*ai) for ai in args]) + t = TR9(ex) + assert not (a[0].func == a[1].func and ( + not verify_numerically(ex, t.expand(trig=True)) or t.is_Add) + or a[1].func != a[0].func and ex != t) + + +def test_TR10(): + assert TR10(cos(a + b)) == -sin(a)*sin(b) + cos(a)*cos(b) + assert TR10(sin(a + b)) == sin(a)*cos(b) + sin(b)*cos(a) + assert TR10(sin(a + b + c)) == \ + (-sin(a)*sin(b) + cos(a)*cos(b))*sin(c) + \ + (sin(a)*cos(b) + sin(b)*cos(a))*cos(c) + assert TR10(cos(a + b + c)) == \ + (-sin(a)*sin(b) + cos(a)*cos(b))*cos(c) - \ + (sin(a)*cos(b) + sin(b)*cos(a))*sin(c) + + +def test_TR10i(): + assert TR10i(cos(1)*cos(3) + sin(1)*sin(3)) == cos(2) + assert TR10i(cos(1)*cos(3) - sin(1)*sin(3)) == cos(4) + assert TR10i(cos(1)*sin(3) - sin(1)*cos(3)) == sin(2) + assert TR10i(cos(1)*sin(3) + sin(1)*cos(3)) == sin(4) + assert TR10i(cos(1)*sin(3) + sin(1)*cos(3) + 7) == sin(4) + 7 + assert TR10i(cos(1)*sin(3) + sin(1)*cos(3) + cos(3)) == cos(3) + sin(4) + assert TR10i(2*cos(1)*sin(3) + 2*sin(1)*cos(3) + cos(3)) == \ + 2*sin(4) + cos(3) + assert TR10i(cos(2)*cos(3) + sin(2)*(cos(1)*sin(2) + cos(2)*sin(1))) == \ + cos(1) + eq = (cos(2)*cos(3) + sin(2)*( + cos(1)*sin(2) + cos(2)*sin(1)))*cos(5) + sin(1)*sin(5) + assert TR10i(eq) == TR10i(eq.expand()) == cos(4) + assert TR10i(sqrt(2)*cos(x)*x + sqrt(6)*sin(x)*x) == \ + 2*sqrt(2)*x*sin(x + pi/6) + assert TR10i(cos(x)/sqrt(6) + sin(x)/sqrt(2) + + cos(x)/sqrt(6)/3 + sin(x)/sqrt(2)/3) == 4*sqrt(6)*sin(x + pi/6)/9 + assert TR10i(cos(x)/sqrt(6) + sin(x)/sqrt(2) + + cos(y)/sqrt(6)/3 + sin(y)/sqrt(2)/3) == \ + sqrt(6)*sin(x + pi/6)/3 + sqrt(6)*sin(y + pi/6)/9 + assert TR10i(cos(x) + sqrt(3)*sin(x) + 2*sqrt(3)*cos(x + pi/6)) == 4*cos(x) + assert TR10i(cos(x) + sqrt(3)*sin(x) + + 2*sqrt(3)*cos(x + pi/6) + 4*sin(x)) == 4*sqrt(2)*sin(x + pi/4) + assert TR10i(cos(2)*sin(3) + sin(2)*cos(4)) == \ + sin(2)*cos(4) + sin(3)*cos(2) + + A = Symbol('A', commutative=False) + assert TR10i(sqrt(2)*cos(x)*A + sqrt(6)*sin(x)*A) == \ + 2*sqrt(2)*sin(x + pi/6)*A + + + c = cos(x) + s = sin(x) + h = sin(y) + r = cos(y) + for si in ((1, 1), (1, -1), (-1, 1), (-1, -1)): + for argsi in ((c*r, s*h), (c*h, s*r)): # explicit 2-args + args = zip(si, argsi) + ex = Add(*[Mul(*ai) for ai in args]) + t = TR10i(ex) + assert not (ex - t.expand(trig=True) or t.is_Add) + + c = cos(x) + s = sin(x) + h = sin(pi/6) + r = cos(pi/6) + for si in ((1, 1), (1, -1), (-1, 1), (-1, -1)): + for argsi in ((c*r, s*h), (c*h, s*r)): # induced + args = zip(si, argsi) + ex = Add(*[Mul(*ai) for ai in args]) + t = TR10i(ex) + assert not (ex - t.expand(trig=True) or t.is_Add) + + +def test_TR11(): + + assert TR11(sin(2*x)) == 2*sin(x)*cos(x) + assert TR11(sin(4*x)) == 4*((-sin(x)**2 + cos(x)**2)*sin(x)*cos(x)) + assert TR11(sin(x*Rational(4, 3))) == \ + 4*((-sin(x/3)**2 + cos(x/3)**2)*sin(x/3)*cos(x/3)) + + assert TR11(cos(2*x)) == -sin(x)**2 + cos(x)**2 + assert TR11(cos(4*x)) == \ + (-sin(x)**2 + cos(x)**2)**2 - 4*sin(x)**2*cos(x)**2 + + assert TR11(cos(2)) == cos(2) + + assert TR11(cos(pi*Rational(3, 7)), pi*Rational(2, 7)) == -cos(pi*Rational(2, 7))**2 + sin(pi*Rational(2, 7))**2 + assert TR11(cos(4), 2) == -sin(2)**2 + cos(2)**2 + assert TR11(cos(6), 2) == cos(6) + assert TR11(sin(x)/cos(x/2), x/2) == 2*sin(x/2) + +def test__TR11(): + + assert _TR11(sin(x/3)*sin(2*x)*sin(x/4)/(cos(x/6)*cos(x/8))) == \ + 4*sin(x/8)*sin(x/6)*sin(2*x),_TR11(sin(x/3)*sin(2*x)*sin(x/4)/(cos(x/6)*cos(x/8))) + assert _TR11(sin(x/3)/cos(x/6)) == 2*sin(x/6) + + assert _TR11(cos(x/6)/sin(x/3)) == 1/(2*sin(x/6)) + assert _TR11(sin(2*x)*cos(x/8)/sin(x/4)) == sin(2*x)/(2*sin(x/8)), _TR11(sin(2*x)*cos(x/8)/sin(x/4)) + assert _TR11(sin(x)/sin(x/2)) == 2*cos(x/2) + + +def test_TR12(): + assert TR12(tan(x + y)) == (tan(x) + tan(y))/(-tan(x)*tan(y) + 1) + assert TR12(tan(x + y + z)) ==\ + (tan(z) + (tan(x) + tan(y))/(-tan(x)*tan(y) + 1))/( + 1 - (tan(x) + tan(y))*tan(z)/(-tan(x)*tan(y) + 1)) + assert TR12(tan(x*y)) == tan(x*y) + + +def test_TR13(): + assert TR13(tan(3)*tan(2)) == -tan(2)/tan(5) - tan(3)/tan(5) + 1 + assert TR13(cot(3)*cot(2)) == 1 + cot(3)*cot(5) + cot(2)*cot(5) + assert TR13(tan(1)*tan(2)*tan(3)) == \ + (-tan(2)/tan(5) - tan(3)/tan(5) + 1)*tan(1) + assert TR13(tan(1)*tan(2)*cot(3)) == \ + (-tan(2)/tan(3) + 1 - tan(1)/tan(3))*cot(3) + + +def test_L(): + assert L(cos(x) + sin(x)) == 2 + + +def test_fu(): + + assert fu(sin(50)**2 + cos(50)**2 + sin(pi/6)) == Rational(3, 2) + assert fu(sqrt(6)*cos(x) + sqrt(2)*sin(x)) == 2*sqrt(2)*sin(x + pi/3) + + + eq = sin(x)**4 - cos(y)**2 + sin(y)**2 + 2*cos(x)**2 + assert fu(eq) == cos(x)**4 - 2*cos(y)**2 + 2 + + assert fu(S.Half - cos(2*x)/2) == sin(x)**2 + + assert fu(sin(a)*(cos(b) - sin(b)) + cos(a)*(sin(b) + cos(b))) == \ + sqrt(2)*sin(a + b + pi/4) + + assert fu(sqrt(3)*cos(x)/2 + sin(x)/2) == sin(x + pi/3) + + assert fu(1 - sin(2*x)**2/4 - sin(y)**2 - cos(x)**4) == \ + -cos(x)**2 + cos(y)**2 + + assert fu(cos(pi*Rational(4, 9))) == sin(pi/18) + assert fu(cos(pi/9)*cos(pi*Rational(2, 9))*cos(pi*Rational(3, 9))*cos(pi*Rational(4, 9))) == Rational(1, 16) + + assert fu( + tan(pi*Rational(7, 18)) + tan(pi*Rational(5, 18)) - sqrt(3)*tan(pi*Rational(5, 18))*tan(pi*Rational(7, 18))) == \ + -sqrt(3) + + assert fu(tan(1)*tan(2)) == tan(1)*tan(2) + + expr = Mul(*[cos(2**i) for i in range(10)]) + assert fu(expr) == sin(1024)/(1024*sin(1)) + + # issue #18059: + assert fu(cos(x) + sqrt(sin(x)**2)) == cos(x) + sqrt(sin(x)**2) + + assert fu((-14*sin(x)**3 + 35*sin(x) + 6*sqrt(3)*cos(x)**3 + 9*sqrt(3)*cos(x))/((cos(2*x) + 4))) == \ + 7*sin(x) + 3*sqrt(3)*cos(x) + + +def test_objective(): + assert fu(sin(x)/cos(x), measure=lambda x: x.count_ops()) == \ + tan(x) + assert fu(sin(x)/cos(x), measure=lambda x: -x.count_ops()) == \ + sin(x)/cos(x) + + +def test_process_common_addends(): + # this tests that the args are not evaluated as they are given to do + # and that key2 works when key1 is False + do = lambda x: Add(*[i**(i%2) for i in x.args]) + assert process_common_addends(Add(*[1, 2, 3, 4], evaluate=False), do, + key2=lambda x: x%2, key1=False) == 1**1 + 3**1 + 2**0 + 4**0 + + +def test_trig_split(): + assert trig_split(cos(x), cos(y)) == (1, 1, 1, x, y, True) + assert trig_split(2*cos(x), -2*cos(y)) == (2, 1, -1, x, y, True) + assert trig_split(cos(x)*sin(y), cos(y)*sin(y)) == \ + (sin(y), 1, 1, x, y, True) + + assert trig_split(cos(x), -sqrt(3)*sin(x), two=True) == \ + (2, 1, -1, x, pi/6, False) + assert trig_split(cos(x), sin(x), two=True) == \ + (sqrt(2), 1, 1, x, pi/4, False) + assert trig_split(cos(x), -sin(x), two=True) == \ + (sqrt(2), 1, -1, x, pi/4, False) + assert trig_split(sqrt(2)*cos(x), -sqrt(6)*sin(x), two=True) == \ + (2*sqrt(2), 1, -1, x, pi/6, False) + assert trig_split(-sqrt(6)*cos(x), -sqrt(2)*sin(x), two=True) == \ + (-2*sqrt(2), 1, 1, x, pi/3, False) + assert trig_split(cos(x)/sqrt(6), sin(x)/sqrt(2), two=True) == \ + (sqrt(6)/3, 1, 1, x, pi/6, False) + assert trig_split(-sqrt(6)*cos(x)*sin(y), + -sqrt(2)*sin(x)*sin(y), two=True) == \ + (-2*sqrt(2)*sin(y), 1, 1, x, pi/3, False) + + assert trig_split(cos(x), sin(x)) is None + assert trig_split(cos(x), sin(z)) is None + assert trig_split(2*cos(x), -sin(x)) is None + assert trig_split(cos(x), -sqrt(3)*sin(x)) is None + assert trig_split(cos(x)*cos(y), sin(x)*sin(z)) is None + assert trig_split(cos(x)*cos(y), sin(x)*sin(y)) is None + assert trig_split(-sqrt(6)*cos(x), sqrt(2)*sin(x)*sin(y), two=True) is \ + None + + assert trig_split(sqrt(3)*sqrt(x), cos(3), two=True) is None + assert trig_split(sqrt(3)*root(x, 3), sin(3)*cos(2), two=True) is None + assert trig_split(cos(5)*cos(6), cos(7)*sin(5), two=True) is None + + +def test_TRmorrie(): + assert TRmorrie(7*Mul(*[cos(i) for i in range(10)])) == \ + 7*sin(12)*sin(16)*cos(5)*cos(7)*cos(9)/(64*sin(1)*sin(3)) + assert TRmorrie(x) == x + assert TRmorrie(2*x) == 2*x + e = cos(pi/7)*cos(pi*Rational(2, 7))*cos(pi*Rational(4, 7)) + assert TR8(TRmorrie(e)) == Rational(-1, 8) + e = Mul(*[cos(2**i*pi/17) for i in range(1, 17)]) + assert TR8(TR3(TRmorrie(e))) == Rational(1, 65536) + # issue 17063 + eq = cos(x)/cos(x/2) + assert TRmorrie(eq) == eq + # issue #20430 + eq = cos(x/2)*sin(x/2)*cos(x)**3 + assert TRmorrie(eq) == sin(2*x)*cos(x)**2/4 + + +def test_TRpower(): + assert TRpower(1/sin(x)**2) == 1/sin(x)**2 + assert TRpower(cos(x)**3*sin(x/2)**4) == \ + (3*cos(x)/4 + cos(3*x)/4)*(-cos(x)/2 + cos(2*x)/8 + Rational(3, 8)) + for k in range(2, 8): + assert verify_numerically(sin(x)**k, TRpower(sin(x)**k)) + assert verify_numerically(cos(x)**k, TRpower(cos(x)**k)) + + +def test_hyper_as_trig(): + from sympy.simplify.fu import _osborne, _osbornei + + eq = sinh(x)**2 + cosh(x)**2 + t, f = hyper_as_trig(eq) + assert f(fu(t)) == cosh(2*x) + e, f = hyper_as_trig(tanh(x + y)) + assert f(TR12(e)) == (tanh(x) + tanh(y))/(tanh(x)*tanh(y) + 1) + + d = Dummy() + assert _osborne(sinh(x), d) == I*sin(x*d) + assert _osborne(tanh(x), d) == I*tan(x*d) + assert _osborne(coth(x), d) == cot(x*d)/I + assert _osborne(cosh(x), d) == cos(x*d) + assert _osborne(sech(x), d) == sec(x*d) + assert _osborne(csch(x), d) == csc(x*d)/I + for func in (sinh, cosh, tanh, coth, sech, csch): + h = func(pi) + assert _osbornei(_osborne(h, d), d) == h + # /!\ the _osborne functions are not meant to work + # in the o(i(trig, d), d) direction so we just check + # that they work as they are supposed to work + assert _osbornei(cos(x*y + z), y) == cosh(x + z*I) + assert _osbornei(sin(x*y + z), y) == sinh(x + z*I)/I + assert _osbornei(tan(x*y + z), y) == tanh(x + z*I)/I + assert _osbornei(cot(x*y + z), y) == coth(x + z*I)*I + assert _osbornei(sec(x*y + z), y) == sech(x + z*I) + assert _osbornei(csc(x*y + z), y) == csch(x + z*I)*I + + +def test_TR12i(): + ta, tb, tc = [tan(i) for i in (a, b, c)] + assert TR12i((ta + tb)/(-ta*tb + 1)) == tan(a + b) + assert TR12i((ta + tb)/(ta*tb - 1)) == -tan(a + b) + assert TR12i((-ta - tb)/(ta*tb - 1)) == tan(a + b) + eq = (ta + tb)/(-ta*tb + 1)**2*(-3*ta - 3*tc)/(2*(ta*tc - 1)) + assert TR12i(eq.expand()) == \ + -3*tan(a + b)*tan(a + c)/(tan(a) + tan(b) - 1)/2 + assert TR12i(tan(x)/sin(x)) == tan(x)/sin(x) + eq = (ta + cos(2))/(-ta*tb + 1) + assert TR12i(eq) == eq + eq = (ta + tb + 2)**2/(-ta*tb + 1) + assert TR12i(eq) == eq + eq = ta/(-ta*tb + 1) + assert TR12i(eq) == eq + eq = (((ta + tb)*(a + 1)).expand())**2/(ta*tb - 1) + assert TR12i(eq) == -(a + 1)**2*tan(a + b) + + +def test_TR14(): + eq = (cos(x) - 1)*(cos(x) + 1) + ans = -sin(x)**2 + assert TR14(eq) == ans + assert TR14(1/eq) == 1/ans + assert TR14((cos(x) - 1)**2*(cos(x) + 1)**2) == ans**2 + assert TR14((cos(x) - 1)**2*(cos(x) + 1)**3) == ans**2*(cos(x) + 1) + assert TR14((cos(x) - 1)**3*(cos(x) + 1)**2) == ans**2*(cos(x) - 1) + eq = (cos(x) - 1)**y*(cos(x) + 1)**y + assert TR14(eq) == eq + eq = (cos(x) - 2)**y*(cos(x) + 1) + assert TR14(eq) == eq + eq = (tan(x) - 2)**2*(cos(x) + 1) + assert TR14(eq) == eq + i = symbols('i', integer=True) + assert TR14((cos(x) - 1)**i*(cos(x) + 1)**i) == ans**i + assert TR14((sin(x) - 1)**i*(sin(x) + 1)**i) == (-cos(x)**2)**i + # could use extraction in this case + eq = (cos(x) - 1)**(i + 1)*(cos(x) + 1)**i + assert TR14(eq) in [(cos(x) - 1)*ans**i, eq] + + assert TR14((sin(x) - 1)*(sin(x) + 1)) == -cos(x)**2 + p1 = (cos(x) + 1)*(cos(x) - 1) + p2 = (cos(y) - 1)*2*(cos(y) + 1) + p3 = (3*(cos(y) - 1))*(3*(cos(y) + 1)) + assert TR14(p1*p2*p3*(x - 1)) == -18*((x - 1)*sin(x)**2*sin(y)**4) + + +def test_TR15_16_17(): + assert TR15(1 - 1/sin(x)**2) == -cot(x)**2 + assert TR16(1 - 1/cos(x)**2) == -tan(x)**2 + assert TR111(1 - 1/tan(x)**2) == 1 - cot(x)**2 + + +def test_as_f_sign_1(): + assert as_f_sign_1(x + 1) == (1, x, 1) + assert as_f_sign_1(x - 1) == (1, x, -1) + assert as_f_sign_1(-x + 1) == (-1, x, -1) + assert as_f_sign_1(-x - 1) == (-1, x, 1) + assert as_f_sign_1(2*x + 2) == (2, x, 1) + assert as_f_sign_1(x*y - y) == (y, x, -1) + assert as_f_sign_1(-x*y + y) == (-y, x, -1) + + +def test_issue_25590(): + A = Symbol('A', commutative=False) + B = Symbol('B', commutative=False) + + assert TR8(2*cos(x)*sin(x)*B*A) == sin(2*x)*B*A + assert TR13(tan(2)*tan(3)*B*A) == (-tan(2)/tan(5) - tan(3)/tan(5) + 1)*B*A + + # XXX The result may not be optimal than + # sin(2*x)*B*A + cos(x)**2 and may change in the future + assert (2*cos(x)*sin(x)*B*A + cos(x)**2).simplify() == sin(2*x)*B*A + cos(2*x)/2 + S.One/2 diff --git a/lib/python3.10/site-packages/sympy/simplify/tests/test_function.py b/lib/python3.10/site-packages/sympy/simplify/tests/test_function.py new file mode 100644 index 0000000000000000000000000000000000000000..441b9faf1bb3c5e7f2279b2a61066d050e45f773 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/simplify/tests/test_function.py @@ -0,0 +1,54 @@ +""" Unit tests for Hyper_Function""" +from sympy.core import symbols, Dummy, Tuple, S, Rational +from sympy.functions import hyper + +from sympy.simplify.hyperexpand import Hyper_Function + +def test_attrs(): + a, b = symbols('a, b', cls=Dummy) + f = Hyper_Function([2, a], [b]) + assert f.ap == Tuple(2, a) + assert f.bq == Tuple(b) + assert f.args == (Tuple(2, a), Tuple(b)) + assert f.sizes == (2, 1) + +def test_call(): + a, b, x = symbols('a, b, x', cls=Dummy) + f = Hyper_Function([2, a], [b]) + assert f(x) == hyper([2, a], [b], x) + +def test_has(): + a, b, c = symbols('a, b, c', cls=Dummy) + f = Hyper_Function([2, -a], [b]) + assert f.has(a) + assert f.has(Tuple(b)) + assert not f.has(c) + +def test_eq(): + assert Hyper_Function([1], []) == Hyper_Function([1], []) + assert (Hyper_Function([1], []) != Hyper_Function([1], [])) is False + assert Hyper_Function([1], []) != Hyper_Function([2], []) + assert Hyper_Function([1], []) != Hyper_Function([1, 2], []) + assert Hyper_Function([1], []) != Hyper_Function([1], [2]) + +def test_gamma(): + assert Hyper_Function([2, 3], [-1]).gamma == 0 + assert Hyper_Function([-2, -3], [-1]).gamma == 2 + n = Dummy(integer=True) + assert Hyper_Function([-1, n, 1], []).gamma == 1 + assert Hyper_Function([-1, -n, 1], []).gamma == 1 + p = Dummy(integer=True, positive=True) + assert Hyper_Function([-1, p, 1], []).gamma == 1 + assert Hyper_Function([-1, -p, 1], []).gamma == 2 + +def test_suitable_origin(): + assert Hyper_Function((S.Half,), (Rational(3, 2),))._is_suitable_origin() is True + assert Hyper_Function((S.Half,), (S.Half,))._is_suitable_origin() is False + assert Hyper_Function((S.Half,), (Rational(-1, 2),))._is_suitable_origin() is False + assert Hyper_Function((S.Half,), (0,))._is_suitable_origin() is False + assert Hyper_Function((S.Half,), (-1, 1,))._is_suitable_origin() is False + assert Hyper_Function((S.Half, 0), (1,))._is_suitable_origin() is False + assert Hyper_Function((S.Half, 1), + (2, Rational(-2, 3)))._is_suitable_origin() is True + assert Hyper_Function((S.Half, 1), + (2, Rational(-2, 3), Rational(3, 2)))._is_suitable_origin() is True diff --git a/lib/python3.10/site-packages/sympy/simplify/tests/test_gammasimp.py b/lib/python3.10/site-packages/sympy/simplify/tests/test_gammasimp.py new file mode 100644 index 0000000000000000000000000000000000000000..e4c73093250b279510e3c2274db22818a9adffd8 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/simplify/tests/test_gammasimp.py @@ -0,0 +1,127 @@ +from sympy.core.function import Function +from sympy.core.numbers import (Rational, pi) +from sympy.core.singleton import S +from sympy.core.symbol import symbols +from sympy.functions.combinatorial.factorials import (rf, binomial, factorial) +from sympy.functions.elementary.exponential import exp +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.piecewise import Piecewise +from sympy.functions.elementary.trigonometric import (cos, sin) +from sympy.functions.special.gamma_functions import gamma +from sympy.simplify.gammasimp import gammasimp +from sympy.simplify.powsimp import powsimp +from sympy.simplify.simplify import simplify + +from sympy.abc import x, y, n, k + + +def test_gammasimp(): + R = Rational + + # was part of test_combsimp_gamma() in test_combsimp.py + assert gammasimp(gamma(x)) == gamma(x) + assert gammasimp(gamma(x + 1)/x) == gamma(x) + assert gammasimp(gamma(x)/(x - 1)) == gamma(x - 1) + assert gammasimp(x*gamma(x)) == gamma(x + 1) + assert gammasimp((x + 1)*gamma(x + 1)) == gamma(x + 2) + assert gammasimp(gamma(x + y)*(x + y)) == gamma(x + y + 1) + assert gammasimp(x/gamma(x + 1)) == 1/gamma(x) + assert gammasimp((x + 1)**2/gamma(x + 2)) == (x + 1)/gamma(x + 1) + assert gammasimp(x*gamma(x) + gamma(x + 3)/(x + 2)) == \ + (x + 2)*gamma(x + 1) + + assert gammasimp(gamma(2*x)*x) == gamma(2*x + 1)/2 + assert gammasimp(gamma(2*x)/(x - S.Half)) == 2*gamma(2*x - 1) + + assert gammasimp(gamma(x)*gamma(1 - x)) == pi/sin(pi*x) + assert gammasimp(gamma(x)*gamma(-x)) == -pi/(x*sin(pi*x)) + assert gammasimp(1/gamma(x + 3)/gamma(1 - x)) == \ + sin(pi*x)/(pi*x*(x + 1)*(x + 2)) + + assert gammasimp(factorial(n + 2)) == gamma(n + 3) + assert gammasimp(binomial(n, k)) == \ + gamma(n + 1)/(gamma(k + 1)*gamma(-k + n + 1)) + + assert powsimp(gammasimp( + gamma(x)*gamma(x + S.Half)*gamma(y)/gamma(x + y))) == \ + 2**(-2*x + 1)*sqrt(pi)*gamma(2*x)*gamma(y)/gamma(x + y) + assert gammasimp(1/gamma(x)/gamma(x - Rational(1, 3))/gamma(x + Rational(1, 3))) == \ + 3**(3*x - Rational(3, 2))/(2*pi*gamma(3*x - 1)) + assert simplify( + gamma(S.Half + x/2)*gamma(1 + x/2)/gamma(1 + x)/sqrt(pi)*2**x) == 1 + assert gammasimp(gamma(Rational(-1, 4))*gamma(Rational(-3, 4))) == 16*sqrt(2)*pi/3 + + assert powsimp(gammasimp(gamma(2*x)/gamma(x))) == \ + 2**(2*x - 1)*gamma(x + S.Half)/sqrt(pi) + + # issue 6792 + e = (-gamma(k)*gamma(k + 2) + gamma(k + 1)**2)/gamma(k)**2 + assert gammasimp(e) == -k + assert gammasimp(1/e) == -1/k + e = (gamma(x) + gamma(x + 1))/gamma(x) + assert gammasimp(e) == x + 1 + assert gammasimp(1/e) == 1/(x + 1) + e = (gamma(x) + gamma(x + 2))*(gamma(x - 1) + gamma(x))/gamma(x) + assert gammasimp(e) == (x**2 + x + 1)*gamma(x + 1)/(x - 1) + e = (-gamma(k)*gamma(k + 2) + gamma(k + 1)**2)/gamma(k)**2 + assert gammasimp(e**2) == k**2 + assert gammasimp(e**2/gamma(k + 1)) == k/gamma(k) + a = R(1, 2) + R(1, 3) + b = a + R(1, 3) + assert gammasimp(gamma(2*k)/gamma(k)*gamma(k + a)*gamma(k + b) + ) == 3*2**(2*k + 1)*3**(-3*k - 2)*sqrt(pi)*gamma(3*k + R(3, 2))/2 + + # issue 9699 + assert gammasimp((x + 1)*factorial(x)/gamma(y)) == gamma(x + 2)/gamma(y) + assert gammasimp(rf(x + n, k)*binomial(n, k)).simplify() == Piecewise( + (gamma(n + 1)*gamma(k + n + x)/(gamma(k + 1)*gamma(n + x)*gamma(-k + n + 1)), n > -x), + ((-1)**k*gamma(n + 1)*gamma(-n - x + 1)/(gamma(k + 1)*gamma(-k + n + 1)*gamma(-k - n - x + 1)), True)) + + A, B = symbols('A B', commutative=False) + assert gammasimp(e*B*A) == gammasimp(e)*B*A + + # check iteration + assert gammasimp(gamma(2*k)/gamma(k)*gamma(-k - R(1, 2))) == ( + -2**(2*k + 1)*sqrt(pi)/(2*((2*k + 1)*cos(pi*k)))) + assert gammasimp( + gamma(k)*gamma(k + R(1, 3))*gamma(k + R(2, 3))/gamma(k*R(3, 2))) == ( + 3*2**(3*k + 1)*3**(-3*k - S.Half)*sqrt(pi)*gamma(k*R(3, 2) + S.Half)/2) + + # issue 6153 + assert gammasimp(gamma(Rational(1, 4))/gamma(Rational(5, 4))) == 4 + + # was part of test_combsimp() in test_combsimp.py + assert gammasimp(binomial(n + 2, k + S.Half)) == gamma(n + 3)/ \ + (gamma(k + R(3, 2))*gamma(-k + n + R(5, 2))) + assert gammasimp(binomial(n + 2, k + 2.0)) == \ + gamma(n + 3)/(gamma(k + 3.0)*gamma(-k + n + 1)) + + # issue 11548 + assert gammasimp(binomial(0, x)) == sin(pi*x)/(pi*x) + + e = gamma(n + Rational(1, 3))*gamma(n + R(2, 3)) + assert gammasimp(e) == e + assert gammasimp(gamma(4*n + S.Half)/gamma(2*n - R(3, 4))) == \ + 2**(4*n - R(5, 2))*(8*n - 3)*gamma(2*n + R(3, 4))/sqrt(pi) + + i, m = symbols('i m', integer = True) + e = gamma(exp(i)) + assert gammasimp(e) == e + e = gamma(m + 3) + assert gammasimp(e) == e + e = gamma(m + 1)/(gamma(i + 1)*gamma(-i + m + 1)) + assert gammasimp(e) == e + + p = symbols("p", integer=True, positive=True) + assert gammasimp(gamma(-p + 4)) == gamma(-p + 4) + + +def test_issue_22606(): + fx = Function('f')(x) + eq = x + gamma(y) + # seems like ans should be `eq`, not `(x*y + gamma(y + 1))/y` + ans = gammasimp(eq) + assert gammasimp(eq.subs(x, fx)).subs(fx, x) == ans + assert gammasimp(eq.subs(x, cos(x))).subs(cos(x), x) == ans + assert 1/gammasimp(1/eq) == ans + assert gammasimp(fx.subs(x, eq)).args[0] == ans diff --git a/lib/python3.10/site-packages/sympy/simplify/tests/test_hyperexpand.py b/lib/python3.10/site-packages/sympy/simplify/tests/test_hyperexpand.py new file mode 100644 index 0000000000000000000000000000000000000000..c703c228a13201de13cfd4c3413fc75a2cf5bdb6 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/simplify/tests/test_hyperexpand.py @@ -0,0 +1,1063 @@ +from sympy.core.random import randrange + +from sympy.simplify.hyperexpand import (ShiftA, ShiftB, UnShiftA, UnShiftB, + MeijerShiftA, MeijerShiftB, MeijerShiftC, MeijerShiftD, + MeijerUnShiftA, MeijerUnShiftB, MeijerUnShiftC, + MeijerUnShiftD, + ReduceOrder, reduce_order, apply_operators, + devise_plan, make_derivative_operator, Formula, + hyperexpand, Hyper_Function, G_Function, + reduce_order_meijer, + build_hypergeometric_formula) +from sympy.concrete.summations import Sum +from sympy.core.containers import Tuple +from sympy.core.expr import Expr +from sympy.core.numbers import I +from sympy.core.singleton import S +from sympy.core.symbol import symbols +from sympy.functions.combinatorial.factorials import binomial +from sympy.functions.elementary.piecewise import Piecewise +from sympy.functions.special.hyper import (hyper, meijerg) +from sympy.abc import z, a, b, c +from sympy.testing.pytest import XFAIL, raises, slow, tooslow +from sympy.core.random import verify_numerically as tn + +from sympy.core.numbers import (Rational, pi) +from sympy.functions.elementary.exponential import (exp, exp_polar, log) +from sympy.functions.elementary.hyperbolic import atanh +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import (asin, cos, sin) +from sympy.functions.special.bessel import besseli +from sympy.functions.special.error_functions import erf +from sympy.functions.special.gamma_functions import (gamma, lowergamma) + + +def test_branch_bug(): + assert hyperexpand(hyper((Rational(-1, 3), S.Half), (Rational(2, 3), Rational(3, 2)), -z)) == \ + -z**S('1/3')*lowergamma(exp_polar(I*pi)/3, z)/5 \ + + sqrt(pi)*erf(sqrt(z))/(5*sqrt(z)) + assert hyperexpand(meijerg([Rational(7, 6), 1], [], [Rational(2, 3)], [Rational(1, 6), 0], z)) == \ + 2*z**S('2/3')*(2*sqrt(pi)*erf(sqrt(z))/sqrt(z) - 2*lowergamma( + Rational(2, 3), z)/z**S('2/3'))*gamma(Rational(2, 3))/gamma(Rational(5, 3)) + + +def test_hyperexpand(): + # Luke, Y. L. (1969), The Special Functions and Their Approximations, + # Volume 1, section 6.2 + + assert hyperexpand(hyper([], [], z)) == exp(z) + assert hyperexpand(hyper([1, 1], [2], -z)*z) == log(1 + z) + assert hyperexpand(hyper([], [S.Half], -z**2/4)) == cos(z) + assert hyperexpand(z*hyper([], [S('3/2')], -z**2/4)) == sin(z) + assert hyperexpand(hyper([S('1/2'), S('1/2')], [S('3/2')], z**2)*z) \ + == asin(z) + assert isinstance(Sum(binomial(2, z)*z**2, (z, 0, a)).doit(), Expr) + + +def can_do(ap, bq, numerical=True, div=1, lowerplane=False): + r = hyperexpand(hyper(ap, bq, z)) + if r.has(hyper): + return False + if not numerical: + return True + repl = {} + randsyms = r.free_symbols - {z} + while randsyms: + # Only randomly generated parameters are checked. + for n, ai in enumerate(randsyms): + repl[ai] = randcplx(n)/div + if not any(b.is_Integer and b <= 0 for b in Tuple(*bq).subs(repl)): + break + [a, b, c, d] = [2, -1, 3, 1] + if lowerplane: + [a, b, c, d] = [2, -2, 3, -1] + return tn( + hyper(ap, bq, z).subs(repl), + r.replace(exp_polar, exp).subs(repl), + z, a=a, b=b, c=c, d=d) + + +def test_roach(): + # Kelly B. Roach. Meijer G Function Representations. + # Section "Gallery" + assert can_do([S.Half], [Rational(9, 2)]) + assert can_do([], [1, Rational(5, 2), 4]) + assert can_do([Rational(-1, 2), 1, 2], [3, 4]) + assert can_do([Rational(1, 3)], [Rational(-2, 3), Rational(-1, 2), S.Half, 1]) + assert can_do([Rational(-3, 2), Rational(-1, 2)], [Rational(-5, 2), 1]) + assert can_do([Rational(-3, 2), ], [Rational(-1, 2), S.Half]) # shine-integral + assert can_do([Rational(-3, 2), Rational(-1, 2)], [2]) # elliptic integrals + + +@XFAIL +def test_roach_fail(): + assert can_do([Rational(-1, 2), 1], [Rational(1, 4), S.Half, Rational(3, 4)]) # PFDD + assert can_do([Rational(3, 2)], [Rational(5, 2), 5]) # struve function + assert can_do([Rational(-1, 2), S.Half, 1], [Rational(3, 2), Rational(5, 2)]) # polylog, pfdd + assert can_do([1, 2, 3], [S.Half, 4]) # XXX ? + assert can_do([S.Half], [Rational(-1, 3), Rational(-1, 2), Rational(-2, 3)]) # PFDD ? + +# For the long table tests, see end of file + + +def test_polynomial(): + from sympy.core.numbers import oo + assert hyperexpand(hyper([], [-1], z)) is oo + assert hyperexpand(hyper([-2], [-1], z)) is oo + assert hyperexpand(hyper([0, 0], [-1], z)) == 1 + assert can_do([-5, -2, randcplx(), randcplx()], [-10, randcplx()]) + assert hyperexpand(hyper((-1, 1), (-2,), z)) == 1 + z/2 + + +def test_hyperexpand_bases(): + assert hyperexpand(hyper([2], [a], z)) == \ + a + z**(-a + 1)*(-a**2 + 3*a + z*(a - 1) - 2)*exp(z)* \ + lowergamma(a - 1, z) - 1 + # TODO [a+1, aRational(-1, 2)], [2*a] + assert hyperexpand(hyper([1, 2], [3], z)) == -2/z - 2*log(-z + 1)/z**2 + assert hyperexpand(hyper([S.Half, 2], [Rational(3, 2)], z)) == \ + -1/(2*z - 2) + atanh(sqrt(z))/sqrt(z)/2 + assert hyperexpand(hyper([S.Half, S.Half], [Rational(5, 2)], z)) == \ + (-3*z + 3)/4/(z*sqrt(-z + 1)) \ + + (6*z - 3)*asin(sqrt(z))/(4*z**Rational(3, 2)) + assert hyperexpand(hyper([1, 2], [Rational(3, 2)], z)) == -1/(2*z - 2) \ + - asin(sqrt(z))/(sqrt(z)*(2*z - 2)*sqrt(-z + 1)) + assert hyperexpand(hyper([Rational(-1, 2) - 1, 1, 2], [S.Half, 3], z)) == \ + sqrt(z)*(z*Rational(6, 7) - Rational(6, 5))*atanh(sqrt(z)) \ + + (-30*z**2 + 32*z - 6)/35/z - 6*log(-z + 1)/(35*z**2) + assert hyperexpand(hyper([1 + S.Half, 1, 1], [2, 2], z)) == \ + -4*log(sqrt(-z + 1)/2 + S.Half)/z + # TODO hyperexpand(hyper([a], [2*a + 1], z)) + # TODO [S.Half, a], [Rational(3, 2), a+1] + assert hyperexpand(hyper([2], [b, 1], z)) == \ + z**(-b/2 + S.Half)*besseli(b - 1, 2*sqrt(z))*gamma(b) \ + + z**(-b/2 + 1)*besseli(b, 2*sqrt(z))*gamma(b) + # TODO [a], [a - S.Half, 2*a] + + +def test_hyperexpand_parametric(): + assert hyperexpand(hyper([a, S.Half + a], [S.Half], z)) \ + == (1 + sqrt(z))**(-2*a)/2 + (1 - sqrt(z))**(-2*a)/2 + assert hyperexpand(hyper([a, Rational(-1, 2) + a], [2*a], z)) \ + == 2**(2*a - 1)*((-z + 1)**S.Half + 1)**(-2*a + 1) + + +def test_shifted_sum(): + from sympy.simplify.simplify import simplify + assert simplify(hyperexpand(z**4*hyper([2], [3, S('3/2')], -z**2))) \ + == z*sin(2*z) + (-z**2 + S.Half)*cos(2*z) - S.Half + + +def _randrat(): + """ Steer clear of integers. """ + return S(randrange(25) + 10)/50 + + +def randcplx(offset=-1): + """ Polys is not good with real coefficients. """ + return _randrat() + I*_randrat() + I*(1 + offset) + + +@slow +def test_formulae(): + from sympy.simplify.hyperexpand import FormulaCollection + formulae = FormulaCollection().formulae + for formula in formulae: + h = formula.func(formula.z) + rep = {} + for n, sym in enumerate(formula.symbols): + rep[sym] = randcplx(n) + + # NOTE hyperexpand returns truly branched functions. We know we are + # on the main sheet, but numerical evaluation can still go wrong + # (e.g. if exp_polar cannot be evalf'd). + # Just replace all exp_polar by exp, this usually works. + + # first test if the closed-form is actually correct + h = h.subs(rep) + closed_form = formula.closed_form.subs(rep).rewrite('nonrepsmall') + z = formula.z + assert tn(h, closed_form.replace(exp_polar, exp), z) + + # now test the computed matrix + cl = (formula.C * formula.B)[0].subs(rep).rewrite('nonrepsmall') + assert tn(closed_form.replace( + exp_polar, exp), cl.replace(exp_polar, exp), z) + deriv1 = z*formula.B.applyfunc(lambda t: t.rewrite( + 'nonrepsmall')).diff(z) + deriv2 = formula.M * formula.B + for d1, d2 in zip(deriv1, deriv2): + assert tn(d1.subs(rep).replace(exp_polar, exp), + d2.subs(rep).rewrite('nonrepsmall').replace(exp_polar, exp), z) + + +def test_meijerg_formulae(): + from sympy.simplify.hyperexpand import MeijerFormulaCollection + formulae = MeijerFormulaCollection().formulae + for sig in formulae: + for formula in formulae[sig]: + g = meijerg(formula.func.an, formula.func.ap, + formula.func.bm, formula.func.bq, + formula.z) + rep = {} + for sym in formula.symbols: + rep[sym] = randcplx() + + # first test if the closed-form is actually correct + g = g.subs(rep) + closed_form = formula.closed_form.subs(rep) + z = formula.z + assert tn(g, closed_form, z) + + # now test the computed matrix + cl = (formula.C * formula.B)[0].subs(rep) + assert tn(closed_form, cl, z) + deriv1 = z*formula.B.diff(z) + deriv2 = formula.M * formula.B + for d1, d2 in zip(deriv1, deriv2): + assert tn(d1.subs(rep), d2.subs(rep), z) + + +def op(f): + return z*f.diff(z) + + +def test_plan(): + assert devise_plan(Hyper_Function([0], ()), + Hyper_Function([0], ()), z) == [] + with raises(ValueError): + devise_plan(Hyper_Function([1], ()), Hyper_Function((), ()), z) + with raises(ValueError): + devise_plan(Hyper_Function([2], [1]), Hyper_Function([2], [2]), z) + with raises(ValueError): + devise_plan(Hyper_Function([2], []), Hyper_Function([S("1/2")], []), z) + + # We cannot use pi/(10000 + n) because polys is insanely slow. + a1, a2, b1 = (randcplx(n) for n in range(3)) + b1 += 2*I + h = hyper([a1, a2], [b1], z) + + h2 = hyper((a1 + 1, a2), [b1], z) + assert tn(apply_operators(h, + devise_plan(Hyper_Function((a1 + 1, a2), [b1]), + Hyper_Function((a1, a2), [b1]), z), op), + h2, z) + + h2 = hyper((a1 + 1, a2 - 1), [b1], z) + assert tn(apply_operators(h, + devise_plan(Hyper_Function((a1 + 1, a2 - 1), [b1]), + Hyper_Function((a1, a2), [b1]), z), op), + h2, z) + + +def test_plan_derivatives(): + a1, a2, a3 = 1, 2, S('1/2') + b1, b2 = 3, S('5/2') + h = Hyper_Function((a1, a2, a3), (b1, b2)) + h2 = Hyper_Function((a1 + 1, a2 + 1, a3 + 2), (b1 + 1, b2 + 1)) + ops = devise_plan(h2, h, z) + f = Formula(h, z, h(z), []) + deriv = make_derivative_operator(f.M, z) + assert tn((apply_operators(f.C, ops, deriv)*f.B)[0], h2(z), z) + + h2 = Hyper_Function((a1, a2 - 1, a3 - 2), (b1 - 1, b2 - 1)) + ops = devise_plan(h2, h, z) + assert tn((apply_operators(f.C, ops, deriv)*f.B)[0], h2(z), z) + + +def test_reduction_operators(): + a1, a2, b1 = (randcplx(n) for n in range(3)) + h = hyper([a1], [b1], z) + + assert ReduceOrder(2, 0) is None + assert ReduceOrder(2, -1) is None + assert ReduceOrder(1, S('1/2')) is None + + h2 = hyper((a1, a2), (b1, a2), z) + assert tn(ReduceOrder(a2, a2).apply(h, op), h2, z) + + h2 = hyper((a1, a2 + 1), (b1, a2), z) + assert tn(ReduceOrder(a2 + 1, a2).apply(h, op), h2, z) + + h2 = hyper((a2 + 4, a1), (b1, a2), z) + assert tn(ReduceOrder(a2 + 4, a2).apply(h, op), h2, z) + + # test several step order reduction + ap = (a2 + 4, a1, b1 + 1) + bq = (a2, b1, b1) + func, ops = reduce_order(Hyper_Function(ap, bq)) + assert func.ap == (a1,) + assert func.bq == (b1,) + assert tn(apply_operators(h, ops, op), hyper(ap, bq, z), z) + + +def test_shift_operators(): + a1, a2, b1, b2, b3 = (randcplx(n) for n in range(5)) + h = hyper((a1, a2), (b1, b2, b3), z) + + raises(ValueError, lambda: ShiftA(0)) + raises(ValueError, lambda: ShiftB(1)) + + assert tn(ShiftA(a1).apply(h, op), hyper((a1 + 1, a2), (b1, b2, b3), z), z) + assert tn(ShiftA(a2).apply(h, op), hyper((a1, a2 + 1), (b1, b2, b3), z), z) + assert tn(ShiftB(b1).apply(h, op), hyper((a1, a2), (b1 - 1, b2, b3), z), z) + assert tn(ShiftB(b2).apply(h, op), hyper((a1, a2), (b1, b2 - 1, b3), z), z) + assert tn(ShiftB(b3).apply(h, op), hyper((a1, a2), (b1, b2, b3 - 1), z), z) + + +def test_ushift_operators(): + a1, a2, b1, b2, b3 = (randcplx(n) for n in range(5)) + h = hyper((a1, a2), (b1, b2, b3), z) + + raises(ValueError, lambda: UnShiftA((1,), (), 0, z)) + raises(ValueError, lambda: UnShiftB((), (-1,), 0, z)) + raises(ValueError, lambda: UnShiftA((1,), (0, -1, 1), 0, z)) + raises(ValueError, lambda: UnShiftB((0, 1), (1,), 0, z)) + + s = UnShiftA((a1, a2), (b1, b2, b3), 0, z) + assert tn(s.apply(h, op), hyper((a1 - 1, a2), (b1, b2, b3), z), z) + s = UnShiftA((a1, a2), (b1, b2, b3), 1, z) + assert tn(s.apply(h, op), hyper((a1, a2 - 1), (b1, b2, b3), z), z) + + s = UnShiftB((a1, a2), (b1, b2, b3), 0, z) + assert tn(s.apply(h, op), hyper((a1, a2), (b1 + 1, b2, b3), z), z) + s = UnShiftB((a1, a2), (b1, b2, b3), 1, z) + assert tn(s.apply(h, op), hyper((a1, a2), (b1, b2 + 1, b3), z), z) + s = UnShiftB((a1, a2), (b1, b2, b3), 2, z) + assert tn(s.apply(h, op), hyper((a1, a2), (b1, b2, b3 + 1), z), z) + + +def can_do_meijer(a1, a2, b1, b2, numeric=True): + """ + This helper function tries to hyperexpand() the meijer g-function + corresponding to the parameters a1, a2, b1, b2. + It returns False if this expansion still contains g-functions. + If numeric is True, it also tests the so-obtained formula numerically + (at random values) and returns False if the test fails. + Else it returns True. + """ + from sympy.core.function import expand + from sympy.functions.elementary.complexes import unpolarify + r = hyperexpand(meijerg(a1, a2, b1, b2, z)) + if r.has(meijerg): + return False + # NOTE hyperexpand() returns a truly branched function, whereas numerical + # evaluation only works on the main branch. Since we are evaluating on + # the main branch, this should not be a problem, but expressions like + # exp_polar(I*pi/2*x)**a are evaluated incorrectly. We thus have to get + # rid of them. The expand heuristically does this... + r = unpolarify(expand(r, force=True, power_base=True, power_exp=False, + mul=False, log=False, multinomial=False, basic=False)) + + if not numeric: + return True + + repl = {} + for n, ai in enumerate(meijerg(a1, a2, b1, b2, z).free_symbols - {z}): + repl[ai] = randcplx(n) + return tn(meijerg(a1, a2, b1, b2, z).subs(repl), r.subs(repl), z) + + +@slow +def test_meijerg_expand(): + from sympy.simplify.gammasimp import gammasimp + from sympy.simplify.simplify import simplify + # from mpmath docs + assert hyperexpand(meijerg([[], []], [[0], []], -z)) == exp(z) + + assert hyperexpand(meijerg([[1, 1], []], [[1], [0]], z)) == \ + log(z + 1) + assert hyperexpand(meijerg([[1, 1], []], [[1], [1]], z)) == \ + z/(z + 1) + assert hyperexpand(meijerg([[], []], [[S.Half], [0]], (z/2)**2)) \ + == sin(z)/sqrt(pi) + assert hyperexpand(meijerg([[], []], [[0], [S.Half]], (z/2)**2)) \ + == cos(z)/sqrt(pi) + assert can_do_meijer([], [a], [a - 1, a - S.Half], []) + assert can_do_meijer([], [], [a/2], [-a/2], False) # branches... + assert can_do_meijer([a], [b], [a], [b, a - 1]) + + # wikipedia + assert hyperexpand(meijerg([1], [], [], [0], z)) == \ + Piecewise((0, abs(z) < 1), (1, abs(1/z) < 1), + (meijerg([1], [], [], [0], z), True)) + assert hyperexpand(meijerg([], [1], [0], [], z)) == \ + Piecewise((1, abs(z) < 1), (0, abs(1/z) < 1), + (meijerg([], [1], [0], [], z), True)) + + # The Special Functions and their Approximations + assert can_do_meijer([], [], [a + b/2], [a, a - b/2, a + S.Half]) + assert can_do_meijer( + [], [], [a], [b], False) # branches only agree for small z + assert can_do_meijer([], [S.Half], [a], [-a]) + assert can_do_meijer([], [], [a, b], []) + assert can_do_meijer([], [], [a, b], []) + assert can_do_meijer([], [], [a, a + S.Half], [b, b + S.Half]) + assert can_do_meijer([], [], [a, -a], [0, S.Half], False) # dito + assert can_do_meijer([], [], [a, a + S.Half, b, b + S.Half], []) + assert can_do_meijer([S.Half], [], [0], [a, -a]) + assert can_do_meijer([S.Half], [], [a], [0, -a], False) # dito + assert can_do_meijer([], [a - S.Half], [a, b], [a - S.Half], False) + assert can_do_meijer([], [a + S.Half], [a + b, a - b, a], [], False) + assert can_do_meijer([a + S.Half], [], [b, 2*a - b, a], [], False) + + # This for example is actually zero. + assert can_do_meijer([], [], [], [a, b]) + + # Testing a bug: + assert hyperexpand(meijerg([0, 2], [], [], [-1, 1], z)) == \ + Piecewise((0, abs(z) < 1), + (z*(1 - 1/z**2)/2, abs(1/z) < 1), + (meijerg([0, 2], [], [], [-1, 1], z), True)) + + # Test that the simplest possible answer is returned: + assert gammasimp(simplify(hyperexpand( + meijerg([1], [1 - a], [-a/2, -a/2 + S.Half], [], 1/z)))) == \ + -2*sqrt(pi)*(sqrt(z + 1) + 1)**a/a + + # Test that hyper is returned + assert hyperexpand(meijerg([1], [], [a], [0, 0], z)) == hyper( + (a,), (a + 1, a + 1), z*exp_polar(I*pi))*z**a*gamma(a)/gamma(a + 1)**2 + + # Test place option + f = meijerg(((0, 1), ()), ((S.Half,), (0,)), z**2) + assert hyperexpand(f) == sqrt(pi)/sqrt(1 + z**(-2)) + assert hyperexpand(f, place=0) == sqrt(pi)*z/sqrt(z**2 + 1) + + +def test_meijerg_lookup(): + from sympy.functions.special.error_functions import (Ci, Si) + from sympy.functions.special.gamma_functions import uppergamma + assert hyperexpand(meijerg([a], [], [b, a], [], z)) == \ + z**b*exp(z)*gamma(-a + b + 1)*uppergamma(a - b, z) + assert hyperexpand(meijerg([0], [], [0, 0], [], z)) == \ + exp(z)*uppergamma(0, z) + assert can_do_meijer([a], [], [b, a + 1], []) + assert can_do_meijer([a], [], [b + 2, a], []) + assert can_do_meijer([a], [], [b - 2, a], []) + + assert hyperexpand(meijerg([a], [], [a, a, a - S.Half], [], z)) == \ + -sqrt(pi)*z**(a - S.Half)*(2*cos(2*sqrt(z))*(Si(2*sqrt(z)) - pi/2) + - 2*sin(2*sqrt(z))*Ci(2*sqrt(z))) == \ + hyperexpand(meijerg([a], [], [a, a - S.Half, a], [], z)) == \ + hyperexpand(meijerg([a], [], [a - S.Half, a, a], [], z)) + assert can_do_meijer([a - 1], [], [a + 2, a - Rational(3, 2), a + 1], []) + + +@XFAIL +def test_meijerg_expand_fail(): + # These basically test hyper([], [1/2 - a, 1/2 + 1, 1/2], z), + # which is *very* messy. But since the meijer g actually yields a + # sum of bessel functions, things can sometimes be simplified a lot and + # are then put into tables... + assert can_do_meijer([], [], [a + S.Half], [a, a - b/2, a + b/2]) + assert can_do_meijer([], [], [0, S.Half], [a, -a]) + assert can_do_meijer([], [], [3*a - S.Half, a, -a - S.Half], [a - S.Half]) + assert can_do_meijer([], [], [0, a - S.Half, -a - S.Half], [S.Half]) + assert can_do_meijer([], [], [a, b + S.Half, b], [2*b - a]) + assert can_do_meijer([], [], [a, b + S.Half, b, 2*b - a]) + assert can_do_meijer([S.Half], [], [-a, a], [0]) + + +@slow +def test_meijerg(): + # carefully set up the parameters. + # NOTE: this used to fail sometimes. I believe it is fixed, but if you + # hit an inexplicable test failure here, please let me know the seed. + a1, a2 = (randcplx(n) - 5*I - n*I for n in range(2)) + b1, b2 = (randcplx(n) + 5*I + n*I for n in range(2)) + b3, b4, b5, a3, a4, a5 = (randcplx() for n in range(6)) + g = meijerg([a1], [a3, a4], [b1], [b3, b4], z) + + assert ReduceOrder.meijer_minus(3, 4) is None + assert ReduceOrder.meijer_plus(4, 3) is None + + g2 = meijerg([a1, a2], [a3, a4], [b1], [b3, b4, a2], z) + assert tn(ReduceOrder.meijer_plus(a2, a2).apply(g, op), g2, z) + + g2 = meijerg([a1, a2], [a3, a4], [b1], [b3, b4, a2 + 1], z) + assert tn(ReduceOrder.meijer_plus(a2, a2 + 1).apply(g, op), g2, z) + + g2 = meijerg([a1, a2 - 1], [a3, a4], [b1], [b3, b4, a2 + 2], z) + assert tn(ReduceOrder.meijer_plus(a2 - 1, a2 + 2).apply(g, op), g2, z) + + g2 = meijerg([a1], [a3, a4, b2 - 1], [b1, b2 + 2], [b3, b4], z) + assert tn(ReduceOrder.meijer_minus( + b2 + 2, b2 - 1).apply(g, op), g2, z, tol=1e-6) + + # test several-step reduction + an = [a1, a2] + bq = [b3, b4, a2 + 1] + ap = [a3, a4, b2 - 1] + bm = [b1, b2 + 1] + niq, ops = reduce_order_meijer(G_Function(an, ap, bm, bq)) + assert niq.an == (a1,) + assert set(niq.ap) == {a3, a4} + assert niq.bm == (b1,) + assert set(niq.bq) == {b3, b4} + assert tn(apply_operators(g, ops, op), meijerg(an, ap, bm, bq, z), z) + + +def test_meijerg_shift_operators(): + # carefully set up the parameters. XXX this still fails sometimes + a1, a2, a3, a4, a5, b1, b2, b3, b4, b5 = (randcplx(n) for n in range(10)) + g = meijerg([a1], [a3, a4], [b1], [b3, b4], z) + + assert tn(MeijerShiftA(b1).apply(g, op), + meijerg([a1], [a3, a4], [b1 + 1], [b3, b4], z), z) + assert tn(MeijerShiftB(a1).apply(g, op), + meijerg([a1 - 1], [a3, a4], [b1], [b3, b4], z), z) + assert tn(MeijerShiftC(b3).apply(g, op), + meijerg([a1], [a3, a4], [b1], [b3 + 1, b4], z), z) + assert tn(MeijerShiftD(a3).apply(g, op), + meijerg([a1], [a3 - 1, a4], [b1], [b3, b4], z), z) + + s = MeijerUnShiftA([a1], [a3, a4], [b1], [b3, b4], 0, z) + assert tn( + s.apply(g, op), meijerg([a1], [a3, a4], [b1 - 1], [b3, b4], z), z) + + s = MeijerUnShiftC([a1], [a3, a4], [b1], [b3, b4], 0, z) + assert tn( + s.apply(g, op), meijerg([a1], [a3, a4], [b1], [b3 - 1, b4], z), z) + + s = MeijerUnShiftB([a1], [a3, a4], [b1], [b3, b4], 0, z) + assert tn( + s.apply(g, op), meijerg([a1 + 1], [a3, a4], [b1], [b3, b4], z), z) + + s = MeijerUnShiftD([a1], [a3, a4], [b1], [b3, b4], 0, z) + assert tn( + s.apply(g, op), meijerg([a1], [a3 + 1, a4], [b1], [b3, b4], z), z) + + +@slow +def test_meijerg_confluence(): + def t(m, a, b): + from sympy.core.sympify import sympify + a, b = sympify([a, b]) + m_ = m + m = hyperexpand(m) + if not m == Piecewise((a, abs(z) < 1), (b, abs(1/z) < 1), (m_, True)): + return False + if not (m.args[0].args[0] == a and m.args[1].args[0] == b): + return False + z0 = randcplx()/10 + if abs(m.subs(z, z0).n() - a.subs(z, z0).n()).n() > 1e-10: + return False + if abs(m.subs(z, 1/z0).n() - b.subs(z, 1/z0).n()).n() > 1e-10: + return False + return True + + assert t(meijerg([], [1, 1], [0, 0], [], z), -log(z), 0) + assert t(meijerg( + [], [3, 1], [0, 0], [], z), -z**2/4 + z - log(z)/2 - Rational(3, 4), 0) + assert t(meijerg([], [3, 1], [-1, 0], [], z), + z**2/12 - z/2 + log(z)/2 + Rational(1, 4) + 1/(6*z), 0) + assert t(meijerg([], [1, 1, 1, 1], [0, 0, 0, 0], [], z), -log(z)**3/6, 0) + assert t(meijerg([1, 1], [], [], [0, 0], z), 0, -log(1/z)) + assert t(meijerg([1, 1], [2, 2], [1, 1], [0, 0], z), + -z*log(z) + 2*z, -log(1/z) + 2) + assert t(meijerg([S.Half], [1, 1], [0, 0], [Rational(3, 2)], z), log(z)/2 - 1, 0) + + def u(an, ap, bm, bq): + m = meijerg(an, ap, bm, bq, z) + m2 = hyperexpand(m, allow_hyper=True) + if m2.has(meijerg) and not (m2.is_Piecewise and len(m2.args) == 3): + return False + return tn(m, m2, z) + assert u([], [1], [0, 0], []) + assert u([1, 1], [], [], [0]) + assert u([1, 1], [2, 2, 5], [1, 1, 6], [0, 0]) + assert u([1, 1], [2, 2, 5], [1, 1, 6], [0]) + + +def test_meijerg_with_Floats(): + # see issue #10681 + from sympy.polys.domains.realfield import RR + f = meijerg(((3.0, 1), ()), ((Rational(3, 2),), (0,)), z) + a = -2.3632718012073 + g = a*z**Rational(3, 2)*hyper((-0.5, Rational(3, 2)), (Rational(5, 2),), z*exp_polar(I*pi)) + assert RR.almosteq((hyperexpand(f)/g).n(), 1.0, 1e-12) + + +def test_lerchphi(): + from sympy.functions.special.zeta_functions import (lerchphi, polylog) + from sympy.simplify.gammasimp import gammasimp + assert hyperexpand(hyper([1, a], [a + 1], z)/a) == lerchphi(z, 1, a) + assert hyperexpand( + hyper([1, a, a], [a + 1, a + 1], z)/a**2) == lerchphi(z, 2, a) + assert hyperexpand(hyper([1, a, a, a], [a + 1, a + 1, a + 1], z)/a**3) == \ + lerchphi(z, 3, a) + assert hyperexpand(hyper([1] + [a]*10, [a + 1]*10, z)/a**10) == \ + lerchphi(z, 10, a) + assert gammasimp(hyperexpand(meijerg([0, 1 - a], [], [0], + [-a], exp_polar(-I*pi)*z))) == lerchphi(z, 1, a) + assert gammasimp(hyperexpand(meijerg([0, 1 - a, 1 - a], [], [0], + [-a, -a], exp_polar(-I*pi)*z))) == lerchphi(z, 2, a) + assert gammasimp(hyperexpand(meijerg([0, 1 - a, 1 - a, 1 - a], [], [0], + [-a, -a, -a], exp_polar(-I*pi)*z))) == lerchphi(z, 3, a) + + assert hyperexpand(z*hyper([1, 1], [2], z)) == -log(1 + -z) + assert hyperexpand(z*hyper([1, 1, 1], [2, 2], z)) == polylog(2, z) + assert hyperexpand(z*hyper([1, 1, 1, 1], [2, 2, 2], z)) == polylog(3, z) + + assert hyperexpand(hyper([1, a, 1 + S.Half], [a + 1, S.Half], z)) == \ + -2*a/(z - 1) + (-2*a**2 + a)*lerchphi(z, 1, a) + + # Now numerical tests. These make sure reductions etc are carried out + # correctly + + # a rational function (polylog at negative integer order) + assert can_do([2, 2, 2], [1, 1]) + + # NOTE these contain log(1-x) etc ... better make sure we have |z| < 1 + # reduction of order for polylog + assert can_do([1, 1, 1, b + 5], [2, 2, b], div=10) + + # reduction of order for lerchphi + # XXX lerchphi in mpmath is flaky + assert can_do( + [1, a, a, a, b + 5], [a + 1, a + 1, a + 1, b], numerical=False) + + # test a bug + from sympy.functions.elementary.complexes import Abs + assert hyperexpand(hyper([S.Half, S.Half, S.Half, 1], + [Rational(3, 2), Rational(3, 2), Rational(3, 2)], Rational(1, 4))) == \ + Abs(-polylog(3, exp_polar(I*pi)/2) + polylog(3, S.Half)) + + +def test_partial_simp(): + # First test that hypergeometric function formulae work. + a, b, c, d, e = (randcplx() for _ in range(5)) + for func in [Hyper_Function([a, b, c], [d, e]), + Hyper_Function([], [a, b, c, d, e])]: + f = build_hypergeometric_formula(func) + z = f.z + assert f.closed_form == func(z) + deriv1 = f.B.diff(z)*z + deriv2 = f.M*f.B + for func1, func2 in zip(deriv1, deriv2): + assert tn(func1, func2, z) + + # Now test that formulae are partially simplified. + a, b, z = symbols('a b z') + assert hyperexpand(hyper([3, a], [1, b], z)) == \ + (-a*b/2 + a*z/2 + 2*a)*hyper([a + 1], [b], z) \ + + (a*b/2 - 2*a + 1)*hyper([a], [b], z) + assert tn( + hyperexpand(hyper([3, d], [1, e], z)), hyper([3, d], [1, e], z), z) + assert hyperexpand(hyper([3], [1, a, b], z)) == \ + hyper((), (a, b), z) \ + + z*hyper((), (a + 1, b), z)/(2*a) \ + - z*(b - 4)*hyper((), (a + 1, b + 1), z)/(2*a*b) + assert tn( + hyperexpand(hyper([3], [1, d, e], z)), hyper([3], [1, d, e], z), z) + + +def test_hyperexpand_special(): + assert hyperexpand(hyper([a, b], [c], 1)) == \ + gamma(c)*gamma(c - a - b)/gamma(c - a)/gamma(c - b) + assert hyperexpand(hyper([a, b], [1 + a - b], -1)) == \ + gamma(1 + a/2)*gamma(1 + a - b)/gamma(1 + a)/gamma(1 + a/2 - b) + assert hyperexpand(hyper([a, b], [1 + b - a], -1)) == \ + gamma(1 + b/2)*gamma(1 + b - a)/gamma(1 + b)/gamma(1 + b/2 - a) + assert hyperexpand(meijerg([1 - z - a/2], [1 - z + a/2], [b/2], [-b/2], 1)) == \ + gamma(1 - 2*z)*gamma(z + a/2 + b/2)/gamma(1 - z + a/2 - b/2) \ + /gamma(1 - z - a/2 + b/2)/gamma(1 - z + a/2 + b/2) + assert hyperexpand(hyper([a], [b], 0)) == 1 + assert hyper([a], [b], 0) != 0 + + +def test_Mod1_behavior(): + from sympy.core.symbol import Symbol + from sympy.simplify.simplify import simplify + n = Symbol('n', integer=True) + # Note: this should not hang. + assert simplify(hyperexpand(meijerg([1], [], [n + 1], [0], z))) == \ + lowergamma(n + 1, z) + + +@slow +def test_prudnikov_misc(): + assert can_do([1, (3 + I)/2, (3 - I)/2], [Rational(3, 2), 2]) + assert can_do([S.Half, a - 1], [Rational(3, 2), a + 1], lowerplane=True) + assert can_do([], [b + 1]) + assert can_do([a], [a - 1, b + 1]) + + assert can_do([a], [a - S.Half, 2*a]) + assert can_do([a], [a - S.Half, 2*a + 1]) + assert can_do([a], [a - S.Half, 2*a - 1]) + assert can_do([a], [a + S.Half, 2*a]) + assert can_do([a], [a + S.Half, 2*a + 1]) + assert can_do([a], [a + S.Half, 2*a - 1]) + assert can_do([S.Half], [b, 2 - b]) + assert can_do([S.Half], [b, 3 - b]) + assert can_do([1], [2, b]) + + assert can_do([a, a + S.Half], [2*a, b, 2*a - b + 1]) + assert can_do([a, a + S.Half], [S.Half, 2*a, 2*a + S.Half]) + assert can_do([a], [a + 1], lowerplane=True) # lowergamma + + +def test_prudnikov_1(): + # A. P. Prudnikov, Yu. A. Brychkov and O. I. Marichev (1990). + # Integrals and Series: More Special Functions, Vol. 3,. + # Gordon and Breach Science Publisher + + # 7.3.1 + assert can_do([a, -a], [S.Half]) + assert can_do([a, 1 - a], [S.Half]) + assert can_do([a, 1 - a], [Rational(3, 2)]) + assert can_do([a, 2 - a], [S.Half]) + assert can_do([a, 2 - a], [Rational(3, 2)]) + assert can_do([a, 2 - a], [Rational(3, 2)]) + assert can_do([a, a + S.Half], [2*a - 1]) + assert can_do([a, a + S.Half], [2*a]) + assert can_do([a, a + S.Half], [2*a + 1]) + assert can_do([a, a + S.Half], [S.Half]) + assert can_do([a, a + S.Half], [Rational(3, 2)]) + assert can_do([a, a/2 + 1], [a/2]) + assert can_do([1, b], [2]) + assert can_do([1, b], [b + 1], numerical=False) # Lerch Phi + # NOTE: branches are complicated for |z| > 1 + + assert can_do([a], [2*a]) + assert can_do([a], [2*a + 1]) + assert can_do([a], [2*a - 1]) + + +@slow +def test_prudnikov_2(): + h = S.Half + assert can_do([-h, -h], [h]) + assert can_do([-h, h], [3*h]) + assert can_do([-h, h], [5*h]) + assert can_do([-h, h], [7*h]) + assert can_do([-h, 1], [h]) + + for p in [-h, h]: + for n in [-h, h, 1, 3*h, 2, 5*h, 3, 7*h, 4]: + for m in [-h, h, 3*h, 5*h, 7*h]: + assert can_do([p, n], [m]) + for n in [1, 2, 3, 4]: + for m in [1, 2, 3, 4]: + assert can_do([p, n], [m]) + + +def test_prudnikov_3(): + h = S.Half + assert can_do([Rational(1, 4), Rational(3, 4)], [h]) + assert can_do([Rational(1, 4), Rational(3, 4)], [3*h]) + assert can_do([Rational(1, 3), Rational(2, 3)], [3*h]) + assert can_do([Rational(3, 4), Rational(5, 4)], [h]) + assert can_do([Rational(3, 4), Rational(5, 4)], [3*h]) + + +@tooslow +def test_prudnikov_3_slow(): + # XXX: This is marked as tooslow and hence skipped in CI. None of the + # individual cases below fails or hangs. Some cases are slow and the loops + # below generate 280 different cases. Is it really necessary to test all + # 280 cases here? + h = S.Half + for p in [1, 2, 3, 4]: + for n in [-h, h, 1, 3*h, 2, 5*h, 3, 7*h, 4, 9*h]: + for m in [1, 3*h, 2, 5*h, 3, 7*h, 4]: + assert can_do([p, m], [n]) + + +@slow +def test_prudnikov_4(): + h = S.Half + for p in [3*h, 5*h, 7*h]: + for n in [-h, h, 3*h, 5*h, 7*h]: + for m in [3*h, 2, 5*h, 3, 7*h, 4]: + assert can_do([p, m], [n]) + for n in [1, 2, 3, 4]: + for m in [2, 3, 4]: + assert can_do([p, m], [n]) + + +@slow +def test_prudnikov_5(): + h = S.Half + + for p in [1, 2, 3]: + for q in range(p, 4): + for r in [1, 2, 3]: + for s in range(r, 4): + assert can_do([-h, p, q], [r, s]) + + for p in [h, 1, 3*h, 2, 5*h, 3]: + for q in [h, 3*h, 5*h]: + for r in [h, 3*h, 5*h]: + for s in [h, 3*h, 5*h]: + if s <= q and s <= r: + assert can_do([-h, p, q], [r, s]) + + for p in [h, 1, 3*h, 2, 5*h, 3]: + for q in [1, 2, 3]: + for r in [h, 3*h, 5*h]: + for s in [1, 2, 3]: + assert can_do([-h, p, q], [r, s]) + + +@slow +def test_prudnikov_6(): + h = S.Half + + for m in [3*h, 5*h]: + for n in [1, 2, 3]: + for q in [h, 1, 2]: + for p in [1, 2, 3]: + assert can_do([h, q, p], [m, n]) + for q in [1, 2, 3]: + for p in [3*h, 5*h]: + assert can_do([h, q, p], [m, n]) + + for q in [1, 2]: + for p in [1, 2, 3]: + for m in [1, 2, 3]: + for n in [1, 2, 3]: + assert can_do([h, q, p], [m, n]) + + assert can_do([h, h, 5*h], [3*h, 3*h]) + assert can_do([h, 1, 5*h], [3*h, 3*h]) + assert can_do([h, 2, 2], [1, 3]) + + # pages 435 to 457 contain more PFDD and stuff like this + + +@slow +def test_prudnikov_7(): + assert can_do([3], [6]) + + h = S.Half + for n in [h, 3*h, 5*h, 7*h]: + assert can_do([-h], [n]) + for m in [-h, h, 1, 3*h, 2, 5*h, 3, 7*h, 4]: # HERE + for n in [-h, h, 3*h, 5*h, 7*h, 1, 2, 3, 4]: + assert can_do([m], [n]) + + +@slow +def test_prudnikov_8(): + h = S.Half + + # 7.12.2 + for ai in [1, 2, 3]: + for bi in [1, 2, 3]: + for ci in range(1, ai + 1): + for di in [h, 1, 3*h, 2, 5*h, 3]: + assert can_do([ai, bi], [ci, di]) + for bi in [3*h, 5*h]: + for ci in [h, 1, 3*h, 2, 5*h, 3]: + for di in [1, 2, 3]: + assert can_do([ai, bi], [ci, di]) + + for ai in [-h, h, 3*h, 5*h]: + for bi in [1, 2, 3]: + for ci in [h, 1, 3*h, 2, 5*h, 3]: + for di in [1, 2, 3]: + assert can_do([ai, bi], [ci, di]) + for bi in [h, 3*h, 5*h]: + for ci in [h, 3*h, 5*h, 3]: + for di in [h, 1, 3*h, 2, 5*h, 3]: + if ci <= bi: + assert can_do([ai, bi], [ci, di]) + + +def test_prudnikov_9(): + # 7.13.1 [we have a general formula ... so this is a bit pointless] + for i in range(9): + assert can_do([], [(S(i) + 1)/2]) + for i in range(5): + assert can_do([], [-(2*S(i) + 1)/2]) + + +@slow +def test_prudnikov_10(): + # 7.14.2 + h = S.Half + for p in [-h, h, 1, 3*h, 2, 5*h, 3, 7*h, 4]: + for m in [1, 2, 3, 4]: + for n in range(m, 5): + assert can_do([p], [m, n]) + + for p in [1, 2, 3, 4]: + for n in [h, 3*h, 5*h, 7*h]: + for m in [1, 2, 3, 4]: + assert can_do([p], [n, m]) + + for p in [3*h, 5*h, 7*h]: + for m in [h, 1, 2, 5*h, 3, 7*h, 4]: + assert can_do([p], [h, m]) + assert can_do([p], [3*h, m]) + + for m in [h, 1, 2, 5*h, 3, 7*h, 4]: + assert can_do([7*h], [5*h, m]) + + assert can_do([Rational(-1, 2)], [S.Half, S.Half]) # shine-integral shi + + +def test_prudnikov_11(): + # 7.15 + assert can_do([a, a + S.Half], [2*a, b, 2*a - b]) + assert can_do([a, a + S.Half], [Rational(3, 2), 2*a, 2*a - S.Half]) + + assert can_do([Rational(1, 4), Rational(3, 4)], [S.Half, S.Half, 1]) + assert can_do([Rational(5, 4), Rational(3, 4)], [Rational(3, 2), S.Half, 2]) + assert can_do([Rational(5, 4), Rational(3, 4)], [Rational(3, 2), Rational(3, 2), 1]) + assert can_do([Rational(5, 4), Rational(7, 4)], [Rational(3, 2), Rational(5, 2), 2]) + + assert can_do([1, 1], [Rational(3, 2), 2, 2]) # cosh-integral chi + + +def test_prudnikov_12(): + # 7.16 + assert can_do( + [], [a, a + S.Half, 2*a], False) # branches only agree for some z! + assert can_do([], [a, a + S.Half, 2*a + 1], False) # dito + assert can_do([], [S.Half, a, a + S.Half]) + assert can_do([], [Rational(3, 2), a, a + S.Half]) + + assert can_do([], [Rational(1, 4), S.Half, Rational(3, 4)]) + assert can_do([], [S.Half, S.Half, 1]) + assert can_do([], [S.Half, Rational(3, 2), 1]) + assert can_do([], [Rational(3, 4), Rational(3, 2), Rational(5, 4)]) + assert can_do([], [1, 1, Rational(3, 2)]) + assert can_do([], [1, 2, Rational(3, 2)]) + assert can_do([], [1, Rational(3, 2), Rational(3, 2)]) + assert can_do([], [Rational(5, 4), Rational(3, 2), Rational(7, 4)]) + assert can_do([], [2, Rational(3, 2), Rational(3, 2)]) + + +@slow +def test_prudnikov_2F1(): + h = S.Half + # Elliptic integrals + for p in [-h, h]: + for m in [h, 3*h, 5*h, 7*h]: + for n in [1, 2, 3, 4]: + assert can_do([p, m], [n]) + + +@XFAIL +def test_prudnikov_fail_2F1(): + assert can_do([a, b], [b + 1]) # incomplete beta function + assert can_do([-1, b], [c]) # Poly. also -2, -3 etc + + # TODO polys + + # Legendre functions: + assert can_do([a, b], [a + b + S.Half]) + assert can_do([a, b], [a + b - S.Half]) + assert can_do([a, b], [a + b + Rational(3, 2)]) + assert can_do([a, b], [(a + b + 1)/2]) + assert can_do([a, b], [(a + b)/2 + 1]) + assert can_do([a, b], [a - b + 1]) + assert can_do([a, b], [a - b + 2]) + assert can_do([a, b], [2*b]) + assert can_do([a, b], [S.Half]) + assert can_do([a, b], [Rational(3, 2)]) + assert can_do([a, 1 - a], [c]) + assert can_do([a, 2 - a], [c]) + assert can_do([a, 3 - a], [c]) + assert can_do([a, a + S.Half], [c]) + assert can_do([1, b], [c]) + assert can_do([1, b], [Rational(3, 2)]) + + assert can_do([Rational(1, 4), Rational(3, 4)], [1]) + + # PFDD + o = S.One + assert can_do([o/8, 1], [o/8*9]) + assert can_do([o/6, 1], [o/6*7]) + assert can_do([o/6, 1], [o/6*13]) + assert can_do([o/5, 1], [o/5*6]) + assert can_do([o/5, 1], [o/5*11]) + assert can_do([o/4, 1], [o/4*5]) + assert can_do([o/4, 1], [o/4*9]) + assert can_do([o/3, 1], [o/3*4]) + assert can_do([o/3, 1], [o/3*7]) + assert can_do([o/8*3, 1], [o/8*11]) + assert can_do([o/5*2, 1], [o/5*7]) + assert can_do([o/5*2, 1], [o/5*12]) + assert can_do([o/5*3, 1], [o/5*8]) + assert can_do([o/5*3, 1], [o/5*13]) + assert can_do([o/8*5, 1], [o/8*13]) + assert can_do([o/4*3, 1], [o/4*7]) + assert can_do([o/4*3, 1], [o/4*11]) + assert can_do([o/3*2, 1], [o/3*5]) + assert can_do([o/3*2, 1], [o/3*8]) + assert can_do([o/5*4, 1], [o/5*9]) + assert can_do([o/5*4, 1], [o/5*14]) + assert can_do([o/6*5, 1], [o/6*11]) + assert can_do([o/6*5, 1], [o/6*17]) + assert can_do([o/8*7, 1], [o/8*15]) + + +@XFAIL +def test_prudnikov_fail_3F2(): + assert can_do([a, a + Rational(1, 3), a + Rational(2, 3)], [Rational(1, 3), Rational(2, 3)]) + assert can_do([a, a + Rational(1, 3), a + Rational(2, 3)], [Rational(2, 3), Rational(4, 3)]) + assert can_do([a, a + Rational(1, 3), a + Rational(2, 3)], [Rational(4, 3), Rational(5, 3)]) + + # page 421 + assert can_do([a, a + Rational(1, 3), a + Rational(2, 3)], [a*Rational(3, 2), (3*a + 1)/2]) + + # pages 422 ... + assert can_do([Rational(-1, 2), S.Half, S.Half], [1, 1]) # elliptic integrals + assert can_do([Rational(-1, 2), S.Half, 1], [Rational(3, 2), Rational(3, 2)]) + # TODO LOTS more + + # PFDD + assert can_do([Rational(1, 8), Rational(3, 8), 1], [Rational(9, 8), Rational(11, 8)]) + assert can_do([Rational(1, 8), Rational(5, 8), 1], [Rational(9, 8), Rational(13, 8)]) + assert can_do([Rational(1, 8), Rational(7, 8), 1], [Rational(9, 8), Rational(15, 8)]) + assert can_do([Rational(1, 6), Rational(1, 3), 1], [Rational(7, 6), Rational(4, 3)]) + assert can_do([Rational(1, 6), Rational(2, 3), 1], [Rational(7, 6), Rational(5, 3)]) + assert can_do([Rational(1, 6), Rational(2, 3), 1], [Rational(5, 3), Rational(13, 6)]) + assert can_do([S.Half, 1, 1], [Rational(1, 4), Rational(3, 4)]) + # LOTS more + + +@XFAIL +def test_prudnikov_fail_other(): + # 7.11.2 + + # 7.12.1 + assert can_do([1, a], [b, 1 - 2*a + b]) # ??? + + # 7.14.2 + assert can_do([Rational(-1, 2)], [S.Half, 1]) # struve + assert can_do([1], [S.Half, S.Half]) # struve + assert can_do([Rational(1, 4)], [S.Half, Rational(5, 4)]) # PFDD + assert can_do([Rational(3, 4)], [Rational(3, 2), Rational(7, 4)]) # PFDD + assert can_do([1], [Rational(1, 4), Rational(3, 4)]) # PFDD + assert can_do([1], [Rational(3, 4), Rational(5, 4)]) # PFDD + assert can_do([1], [Rational(5, 4), Rational(7, 4)]) # PFDD + # TODO LOTS more + + # 7.15.2 + assert can_do([S.Half, 1], [Rational(3, 4), Rational(5, 4), Rational(3, 2)]) # PFDD + assert can_do([S.Half, 1], [Rational(7, 4), Rational(5, 4), Rational(3, 2)]) # PFDD + + # 7.16.1 + assert can_do([], [Rational(1, 3), S(2/3)]) # PFDD + assert can_do([], [Rational(2, 3), S(4/3)]) # PFDD + assert can_do([], [Rational(5, 3), S(4/3)]) # PFDD + + # XXX this does not *evaluate* right?? + assert can_do([], [a, a + S.Half, 2*a - 1]) + + +def test_bug(): + h = hyper([-1, 1], [z], -1) + assert hyperexpand(h) == (z + 1)/z + + +def test_omgissue_203(): + h = hyper((-5, -3, -4), (-6, -6), 1) + assert hyperexpand(h) == Rational(1, 30) + h = hyper((-6, -7, -5), (-6, -6), 1) + assert hyperexpand(h) == Rational(-1, 6) diff --git a/lib/python3.10/site-packages/sympy/simplify/tests/test_powsimp.py b/lib/python3.10/site-packages/sympy/simplify/tests/test_powsimp.py new file mode 100644 index 0000000000000000000000000000000000000000..fdae6bfc1b26e560abdfca626b059a1ce77aa0a5 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/simplify/tests/test_powsimp.py @@ -0,0 +1,366 @@ +from sympy.core.function import Function +from sympy.core.mul import Mul +from sympy.core.numbers import (E, I, Rational, oo, pi) +from sympy.core.singleton import S +from sympy.core.symbol import (Dummy, Symbol, symbols) +from sympy.functions.elementary.exponential import (exp, log) +from sympy.functions.elementary.miscellaneous import (root, sqrt) +from sympy.functions.elementary.trigonometric import sin +from sympy.functions.special.gamma_functions import gamma +from sympy.functions.special.hyper import hyper +from sympy.matrices.expressions.matexpr import MatrixSymbol +from sympy.simplify.powsimp import (powdenest, powsimp) +from sympy.simplify.simplify import (signsimp, simplify) +from sympy.core.symbol import Str + +from sympy.abc import x, y, z, a, b + + +def test_powsimp(): + x, y, z, n = symbols('x,y,z,n') + f = Function('f') + assert powsimp( 4**x * 2**(-x) * 2**(-x) ) == 1 + assert powsimp( (-4)**x * (-2)**(-x) * 2**(-x) ) == 1 + + assert powsimp( + f(4**x * 2**(-x) * 2**(-x)) ) == f(4**x * 2**(-x) * 2**(-x)) + assert powsimp( f(4**x * 2**(-x) * 2**(-x)), deep=True ) == f(1) + assert exp(x)*exp(y) == exp(x)*exp(y) + assert powsimp(exp(x)*exp(y)) == exp(x + y) + assert powsimp(exp(x)*exp(y)*2**x*2**y) == (2*E)**(x + y) + assert powsimp(exp(x)*exp(y)*2**x*2**y, combine='exp') == \ + exp(x + y)*2**(x + y) + assert powsimp(exp(x)*exp(y)*exp(2)*sin(x) + sin(y) + 2**x*2**y) == \ + exp(2 + x + y)*sin(x) + sin(y) + 2**(x + y) + assert powsimp(sin(exp(x)*exp(y))) == sin(exp(x)*exp(y)) + assert powsimp(sin(exp(x)*exp(y)), deep=True) == sin(exp(x + y)) + assert powsimp(x**2*x**y) == x**(2 + y) + # This should remain factored, because 'exp' with deep=True is supposed + # to act like old automatic exponent combining. + assert powsimp((1 + E*exp(E))*exp(-E), combine='exp', deep=True) == \ + (1 + exp(1 + E))*exp(-E) + assert powsimp((1 + E*exp(E))*exp(-E), deep=True) == \ + (1 + exp(1 + E))*exp(-E) + assert powsimp((1 + E*exp(E))*exp(-E)) == (1 + exp(1 + E))*exp(-E) + assert powsimp((1 + E*exp(E))*exp(-E), combine='exp') == \ + (1 + exp(1 + E))*exp(-E) + assert powsimp((1 + E*exp(E))*exp(-E), combine='base') == \ + (1 + E*exp(E))*exp(-E) + x, y = symbols('x,y', nonnegative=True) + n = Symbol('n', real=True) + assert powsimp(y**n * (y/x)**(-n)) == x**n + assert powsimp(x**(x**(x*y)*y**(x*y))*y**(x**(x*y)*y**(x*y)), deep=True) \ + == (x*y)**(x*y)**(x*y) + assert powsimp(2**(2**(2*x)*x), deep=False) == 2**(2**(2*x)*x) + assert powsimp(2**(2**(2*x)*x), deep=True) == 2**(x*4**x) + assert powsimp( + exp(-x + exp(-x)*exp(-x*log(x))), deep=False, combine='exp') == \ + exp(-x + exp(-x)*exp(-x*log(x))) + assert powsimp( + exp(-x + exp(-x)*exp(-x*log(x))), deep=False, combine='exp') == \ + exp(-x + exp(-x)*exp(-x*log(x))) + assert powsimp((x + y)/(3*z), deep=False, combine='exp') == (x + y)/(3*z) + assert powsimp((x/3 + y/3)/z, deep=True, combine='exp') == (x/3 + y/3)/z + assert powsimp(exp(x)/(1 + exp(x)*exp(y)), deep=True) == \ + exp(x)/(1 + exp(x + y)) + assert powsimp(x*y**(z**x*z**y), deep=True) == x*y**(z**(x + y)) + assert powsimp((z**x*z**y)**x, deep=True) == (z**(x + y))**x + assert powsimp(x*(z**x*z**y)**x, deep=True) == x*(z**(x + y))**x + p = symbols('p', positive=True) + assert powsimp((1/x)**log(2)/x) == (1/x)**(1 + log(2)) + assert powsimp((1/p)**log(2)/p) == p**(-1 - log(2)) + + # coefficient of exponent can only be simplified for positive bases + assert powsimp(2**(2*x)) == 4**x + assert powsimp((-1)**(2*x)) == (-1)**(2*x) + i = symbols('i', integer=True) + assert powsimp((-1)**(2*i)) == 1 + assert powsimp((-1)**(-x)) != (-1)**x # could be 1/((-1)**x), but is not + # force=True overrides assumptions + assert powsimp((-1)**(2*x), force=True) == 1 + + # rational exponents allow combining of negative terms + w, n, m = symbols('w n m', negative=True) + e = i/a # not a rational exponent if `a` is unknown + ex = w**e*n**e*m**e + assert powsimp(ex) == m**(i/a)*n**(i/a)*w**(i/a) + e = i/3 + ex = w**e*n**e*m**e + assert powsimp(ex) == (-1)**i*(-m*n*w)**(i/3) + e = (3 + i)/i + ex = w**e*n**e*m**e + assert powsimp(ex) == (-1)**(3*e)*(-m*n*w)**e + + eq = x**(a*Rational(2, 3)) + # eq != (x**a)**(2/3) (try x = -1 and a = 3 to see) + assert powsimp(eq).exp == eq.exp == a*Rational(2, 3) + # powdenest goes the other direction + assert powsimp(2**(2*x)) == 4**x + + assert powsimp(exp(p/2)) == exp(p/2) + + # issue 6368 + eq = Mul(*[sqrt(Dummy(imaginary=True)) for i in range(3)]) + assert powsimp(eq) == eq and eq.is_Mul + + assert all(powsimp(e) == e for e in (sqrt(x**a), sqrt(x**2))) + + # issue 8836 + assert str( powsimp(exp(I*pi/3)*root(-1,3)) ) == '(-1)**(2/3)' + + # issue 9183 + assert powsimp(-0.1**x) == -0.1**x + + # issue 10095 + assert powsimp((1/(2*E))**oo) == (exp(-1)/2)**oo + + # PR 13131 + eq = sin(2*x)**2*sin(2.0*x)**2 + assert powsimp(eq) == eq + + # issue 14615 + assert powsimp(x**2*y**3*(x*y**2)**Rational(3, 2) + ) == x*y*(x*y**2)**Rational(5, 2) + + +def test_powsimp_negated_base(): + assert powsimp((-x + y)/sqrt(x - y)) == -sqrt(x - y) + assert powsimp((-x + y)*(-z + y)/sqrt(x - y)/sqrt(z - y)) == sqrt(x - y)*sqrt(z - y) + p = symbols('p', positive=True) + reps = {p: 2, a: S.Half} + assert powsimp((-p)**a/p**a).subs(reps) == ((-1)**a).subs(reps) + assert powsimp((-p)**a*p**a).subs(reps) == ((-p**2)**a).subs(reps) + n = symbols('n', negative=True) + reps = {p: -2, a: S.Half} + assert powsimp((-n)**a/n**a).subs(reps) == (-1)**(-a).subs(a, S.Half) + assert powsimp((-n)**a*n**a).subs(reps) == ((-n**2)**a).subs(reps) + # if x is 0 then the lhs is 0**a*oo**a which is not (-1)**a + eq = (-x)**a/x**a + assert powsimp(eq) == eq + + +def test_powsimp_nc(): + x, y, z = symbols('x,y,z') + A, B, C = symbols('A B C', commutative=False) + + assert powsimp(A**x*A**y, combine='all') == A**(x + y) + assert powsimp(A**x*A**y, combine='base') == A**x*A**y + assert powsimp(A**x*A**y, combine='exp') == A**(x + y) + + assert powsimp(A**x*B**x, combine='all') == A**x*B**x + assert powsimp(A**x*B**x, combine='base') == A**x*B**x + assert powsimp(A**x*B**x, combine='exp') == A**x*B**x + + assert powsimp(B**x*A**x, combine='all') == B**x*A**x + assert powsimp(B**x*A**x, combine='base') == B**x*A**x + assert powsimp(B**x*A**x, combine='exp') == B**x*A**x + + assert powsimp(A**x*A**y*A**z, combine='all') == A**(x + y + z) + assert powsimp(A**x*A**y*A**z, combine='base') == A**x*A**y*A**z + assert powsimp(A**x*A**y*A**z, combine='exp') == A**(x + y + z) + + assert powsimp(A**x*B**x*C**x, combine='all') == A**x*B**x*C**x + assert powsimp(A**x*B**x*C**x, combine='base') == A**x*B**x*C**x + assert powsimp(A**x*B**x*C**x, combine='exp') == A**x*B**x*C**x + + assert powsimp(B**x*A**x*C**x, combine='all') == B**x*A**x*C**x + assert powsimp(B**x*A**x*C**x, combine='base') == B**x*A**x*C**x + assert powsimp(B**x*A**x*C**x, combine='exp') == B**x*A**x*C**x + + +def test_issue_6440(): + assert powsimp(16*2**a*8**b) == 2**(a + 3*b + 4) + + +def test_powdenest(): + x, y = symbols('x,y') + p, q = symbols('p q', positive=True) + i, j = symbols('i,j', integer=True) + + assert powdenest(x) == x + assert powdenest(x + 2*(x**(a*Rational(2, 3)))**(3*x)) == (x + 2*(x**(a*Rational(2, 3)))**(3*x)) + assert powdenest((exp(a*Rational(2, 3)))**(3*x)) # -X-> (exp(a/3))**(6*x) + assert powdenest((x**(a*Rational(2, 3)))**(3*x)) == ((x**(a*Rational(2, 3)))**(3*x)) + assert powdenest(exp(3*x*log(2))) == 2**(3*x) + assert powdenest(sqrt(p**2)) == p + eq = p**(2*i)*q**(4*i) + assert powdenest(eq) == (p*q**2)**(2*i) + # -X-> (x**x)**i*(x**x)**j == x**(x*(i + j)) + assert powdenest((x**x)**(i + j)) + assert powdenest(exp(3*y*log(x))) == x**(3*y) + assert powdenest(exp(y*(log(a) + log(b)))) == (a*b)**y + assert powdenest(exp(3*(log(a) + log(b)))) == a**3*b**3 + assert powdenest(((x**(2*i))**(3*y))**x) == ((x**(2*i))**(3*y))**x + assert powdenest(((x**(2*i))**(3*y))**x, force=True) == x**(6*i*x*y) + assert powdenest(((x**(a*Rational(2, 3)))**(3*y/i))**x) == \ + (((x**(a*Rational(2, 3)))**(3*y/i))**x) + assert powdenest((x**(2*i)*y**(4*i))**z, force=True) == (x*y**2)**(2*i*z) + assert powdenest((p**(2*i)*q**(4*i))**j) == (p*q**2)**(2*i*j) + e = ((p**(2*a))**(3*y))**x + assert powdenest(e) == e + e = ((x**2*y**4)**a)**(x*y) + assert powdenest(e) == e + e = (((x**2*y**4)**a)**(x*y))**3 + assert powdenest(e) == ((x**2*y**4)**a)**(3*x*y) + assert powdenest((((x**2*y**4)**a)**(x*y)), force=True) == \ + (x*y**2)**(2*a*x*y) + assert powdenest((((x**2*y**4)**a)**(x*y))**3, force=True) == \ + (x*y**2)**(6*a*x*y) + assert powdenest((x**2*y**6)**i) != (x*y**3)**(2*i) + x, y = symbols('x,y', positive=True) + assert powdenest((x**2*y**6)**i) == (x*y**3)**(2*i) + + assert powdenest((x**(i*Rational(2, 3))*y**(i/2))**(2*i)) == (x**Rational(4, 3)*y)**(i**2) + assert powdenest(sqrt(x**(2*i)*y**(6*i))) == (x*y**3)**i + + assert powdenest(4**x) == 2**(2*x) + assert powdenest((4**x)**y) == 2**(2*x*y) + assert powdenest(4**x*y) == 2**(2*x)*y + + +def test_powdenest_polar(): + x, y, z = symbols('x y z', polar=True) + a, b, c = symbols('a b c') + assert powdenest((x*y*z)**a) == x**a*y**a*z**a + assert powdenest((x**a*y**b)**c) == x**(a*c)*y**(b*c) + assert powdenest(((x**a)**b*y**c)**c) == x**(a*b*c)*y**(c**2) + + +def test_issue_5805(): + arg = ((gamma(x)*hyper((), (), x))*pi)**2 + assert powdenest(arg) == (pi*gamma(x)*hyper((), (), x))**2 + assert arg.is_positive is None + + +def test_issue_9324_powsimp_on_matrix_symbol(): + M = MatrixSymbol('M', 10, 10) + expr = powsimp(M, deep=True) + assert expr == M + assert expr.args[0] == Str('M') + + +def test_issue_6367(): + z = -5*sqrt(2)/(2*sqrt(2*sqrt(29) + 29)) + sqrt(-sqrt(29)/29 + S.Half) + assert Mul(*[powsimp(a) for a in Mul.make_args(z.normal())]) == 0 + assert powsimp(z.normal()) == 0 + assert simplify(z) == 0 + assert powsimp(sqrt(2 + sqrt(3))*sqrt(2 - sqrt(3)) + 1) == 2 + assert powsimp(z) != 0 + + +def test_powsimp_polar(): + from sympy.functions.elementary.complexes import polar_lift + from sympy.functions.elementary.exponential import exp_polar + x, y, z = symbols('x y z') + p, q, r = symbols('p q r', polar=True) + + assert (polar_lift(-1))**(2*x) == exp_polar(2*pi*I*x) + assert powsimp(p**x * q**x) == (p*q)**x + assert p**x * (1/p)**x == 1 + assert (1/p)**x == p**(-x) + + assert exp_polar(x)*exp_polar(y) == exp_polar(x)*exp_polar(y) + assert powsimp(exp_polar(x)*exp_polar(y)) == exp_polar(x + y) + assert powsimp(exp_polar(x)*exp_polar(y)*p**x*p**y) == \ + (p*exp_polar(1))**(x + y) + assert powsimp(exp_polar(x)*exp_polar(y)*p**x*p**y, combine='exp') == \ + exp_polar(x + y)*p**(x + y) + assert powsimp( + exp_polar(x)*exp_polar(y)*exp_polar(2)*sin(x) + sin(y) + p**x*p**y) \ + == p**(x + y) + sin(x)*exp_polar(2 + x + y) + sin(y) + assert powsimp(sin(exp_polar(x)*exp_polar(y))) == \ + sin(exp_polar(x)*exp_polar(y)) + assert powsimp(sin(exp_polar(x)*exp_polar(y)), deep=True) == \ + sin(exp_polar(x + y)) + + +def test_issue_5728(): + b = x*sqrt(y) + a = sqrt(b) + c = sqrt(sqrt(x)*y) + assert powsimp(a*b) == sqrt(b)**3 + assert powsimp(a*b**2*sqrt(y)) == sqrt(y)*a**5 + assert powsimp(a*x**2*c**3*y) == c**3*a**5 + assert powsimp(a*x*c**3*y**2) == c**7*a + assert powsimp(x*c**3*y**2) == c**7 + assert powsimp(x*c**3*y) == x*y*c**3 + assert powsimp(sqrt(x)*c**3*y) == c**5 + assert powsimp(sqrt(x)*a**3*sqrt(y)) == sqrt(x)*sqrt(y)*a**3 + assert powsimp(Mul(sqrt(x)*c**3*sqrt(y), y, evaluate=False)) == \ + sqrt(x)*sqrt(y)**3*c**3 + assert powsimp(a**2*a*x**2*y) == a**7 + + # symbolic powers work, too + b = x**y*y + a = b*sqrt(b) + assert a.is_Mul is True + assert powsimp(a) == sqrt(b)**3 + + # as does exp + a = x*exp(y*Rational(2, 3)) + assert powsimp(a*sqrt(a)) == sqrt(a)**3 + assert powsimp(a**2*sqrt(a)) == sqrt(a)**5 + assert powsimp(a**2*sqrt(sqrt(a))) == sqrt(sqrt(a))**9 + + +def test_issue_from_PR1599(): + n1, n2, n3, n4 = symbols('n1 n2 n3 n4', negative=True) + assert (powsimp(sqrt(n1)*sqrt(n2)*sqrt(n3)) == + -I*sqrt(-n1)*sqrt(-n2)*sqrt(-n3)) + assert (powsimp(root(n1, 3)*root(n2, 3)*root(n3, 3)*root(n4, 3)) == + -(-1)**Rational(1, 3)* + (-n1)**Rational(1, 3)*(-n2)**Rational(1, 3)*(-n3)**Rational(1, 3)*(-n4)**Rational(1, 3)) + + +def test_issue_10195(): + a = Symbol('a', integer=True) + l = Symbol('l', even=True, nonzero=True) + n = Symbol('n', odd=True) + e_x = (-1)**(n/2 - S.Half) - (-1)**(n*Rational(3, 2) - S.Half) + assert powsimp((-1)**(l/2)) == I**l + assert powsimp((-1)**(n/2)) == I**n + assert powsimp((-1)**(n*Rational(3, 2))) == -I**n + assert powsimp(e_x) == (-1)**(n/2 - S.Half) + (-1)**(n*Rational(3, 2) + + S.Half) + assert powsimp((-1)**(a*Rational(3, 2))) == (-I)**a + +def test_issue_15709(): + assert powsimp(3**x*Rational(2, 3)) == 2*3**(x-1) + assert powsimp(2*3**x/3) == 2*3**(x-1) + + +def test_issue_11981(): + x, y = symbols('x y', commutative=False) + assert powsimp((x*y)**2 * (y*x)**2) == (x*y)**2 * (y*x)**2 + + +def test_issue_17524(): + a = symbols("a", real=True) + e = (-1 - a**2)*sqrt(1 + a**2) + assert signsimp(powsimp(e)) == signsimp(e) == -(a**2 + 1)**(S(3)/2) + + +def test_issue_19627(): + # if you use force the user must verify + assert powdenest(sqrt(sin(x)**2), force=True) == sin(x) + assert powdenest((x**(S.Half/y))**(2*y), force=True) == x + from sympy.core.function import expand_power_base + e = 1 - a + expr = (exp(z/e)*x**(b/e)*y**((1 - b)/e))**e + assert powdenest(expand_power_base(expr, force=True), force=True + ) == x**b*y**(1 - b)*exp(z) + + +def test_issue_22546(): + p1, p2 = symbols('p1, p2', positive=True) + ref = powsimp(p1**z/p2**z) + e = z + 1 + ans = ref.subs(z, e) + assert ans.is_Pow + assert powsimp(p1**e/p2**e) == ans + i = symbols('i', integer=True) + ref = powsimp(x**i/y**i) + e = i + 1 + ans = ref.subs(i, e) + assert ans.is_Pow + assert powsimp(x**e/y**e) == ans diff --git a/lib/python3.10/site-packages/sympy/simplify/tests/test_radsimp.py b/lib/python3.10/site-packages/sympy/simplify/tests/test_radsimp.py new file mode 100644 index 0000000000000000000000000000000000000000..fabea1f1acb63c1e7845e82bcfd2a41c6bf97e67 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/simplify/tests/test_radsimp.py @@ -0,0 +1,490 @@ +from sympy.core.add import Add +from sympy.core.function import (Derivative, Function, diff) +from sympy.core.mul import Mul +from sympy.core.numbers import (I, Rational) +from sympy.core.power import Pow +from sympy.core.singleton import S +from sympy.core.symbol import (Symbol, Wild, symbols) +from sympy.functions.elementary.complexes import Abs +from sympy.functions.elementary.exponential import (exp, log) +from sympy.functions.elementary.miscellaneous import (root, sqrt) +from sympy.functions.elementary.trigonometric import (cos, sin) +from sympy.polys.polytools import factor +from sympy.series.order import O +from sympy.simplify.radsimp import (collect, collect_const, fraction, radsimp, rcollect) + +from sympy.core.expr import unchanged +from sympy.core.mul import _unevaluated_Mul as umul +from sympy.simplify.radsimp import (_unevaluated_Add, + collect_sqrt, fraction_expand, collect_abs) +from sympy.testing.pytest import raises + +from sympy.abc import x, y, z, a, b, c, d + + +def test_radsimp(): + r2 = sqrt(2) + r3 = sqrt(3) + r5 = sqrt(5) + r7 = sqrt(7) + assert fraction(radsimp(1/r2)) == (sqrt(2), 2) + assert radsimp(1/(1 + r2)) == \ + -1 + sqrt(2) + assert radsimp(1/(r2 + r3)) == \ + -sqrt(2) + sqrt(3) + assert fraction(radsimp(1/(1 + r2 + r3))) == \ + (-sqrt(6) + sqrt(2) + 2, 4) + assert fraction(radsimp(1/(r2 + r3 + r5))) == \ + (-sqrt(30) + 2*sqrt(3) + 3*sqrt(2), 12) + assert fraction(radsimp(1/(1 + r2 + r3 + r5))) == ( + (-34*sqrt(10) - 26*sqrt(15) - 55*sqrt(3) - 61*sqrt(2) + 14*sqrt(30) + + 93 + 46*sqrt(6) + 53*sqrt(5), 71)) + assert fraction(radsimp(1/(r2 + r3 + r5 + r7))) == ( + (-50*sqrt(42) - 133*sqrt(5) - 34*sqrt(70) - 145*sqrt(3) + 22*sqrt(105) + + 185*sqrt(2) + 62*sqrt(30) + 135*sqrt(7), 215)) + z = radsimp(1/(1 + r2/3 + r3/5 + r5 + r7)) + assert len((3616791619821680643598*z).args) == 16 + assert radsimp(1/z) == 1/z + assert radsimp(1/z, max_terms=20).expand() == 1 + r2/3 + r3/5 + r5 + r7 + assert radsimp(1/(r2*3)) == \ + sqrt(2)/6 + assert radsimp(1/(r2*a + r3 + r5 + r7)) == ( + (8*sqrt(2)*a**7 - 8*sqrt(7)*a**6 - 8*sqrt(5)*a**6 - 8*sqrt(3)*a**6 - + 180*sqrt(2)*a**5 + 8*sqrt(30)*a**5 + 8*sqrt(42)*a**5 + 8*sqrt(70)*a**5 + - 24*sqrt(105)*a**4 + 84*sqrt(3)*a**4 + 100*sqrt(5)*a**4 + + 116*sqrt(7)*a**4 - 72*sqrt(70)*a**3 - 40*sqrt(42)*a**3 - + 8*sqrt(30)*a**3 + 782*sqrt(2)*a**3 - 462*sqrt(3)*a**2 - + 302*sqrt(7)*a**2 - 254*sqrt(5)*a**2 + 120*sqrt(105)*a**2 - + 795*sqrt(2)*a - 62*sqrt(30)*a + 82*sqrt(42)*a + 98*sqrt(70)*a - + 118*sqrt(105) + 59*sqrt(7) + 295*sqrt(5) + 531*sqrt(3))/(16*a**8 - + 480*a**6 + 3128*a**4 - 6360*a**2 + 3481)) + assert radsimp(1/(r2*a + r2*b + r3 + r7)) == ( + (sqrt(2)*a*(a + b)**2 - 5*sqrt(2)*a + sqrt(42)*a + sqrt(2)*b*(a + + b)**2 - 5*sqrt(2)*b + sqrt(42)*b - sqrt(7)*(a + b)**2 - sqrt(3)*(a + + b)**2 - 2*sqrt(3) + 2*sqrt(7))/(2*a**4 + 8*a**3*b + 12*a**2*b**2 - + 20*a**2 + 8*a*b**3 - 40*a*b + 2*b**4 - 20*b**2 + 8)) + assert radsimp(1/(r2*a + r2*b + r2*c + r2*d)) == \ + sqrt(2)/(2*a + 2*b + 2*c + 2*d) + assert radsimp(1/(1 + r2*a + r2*b + r2*c + r2*d)) == ( + (sqrt(2)*a + sqrt(2)*b + sqrt(2)*c + sqrt(2)*d - 1)/(2*a**2 + 4*a*b + + 4*a*c + 4*a*d + 2*b**2 + 4*b*c + 4*b*d + 2*c**2 + 4*c*d + 2*d**2 - 1)) + assert radsimp((y**2 - x)/(y - sqrt(x))) == \ + sqrt(x) + y + assert radsimp(-(y**2 - x)/(y - sqrt(x))) == \ + -(sqrt(x) + y) + assert radsimp(1/(1 - I + a*I)) == \ + (-I*a + 1 + I)/(a**2 - 2*a + 2) + assert radsimp(1/((-x + y)*(x - sqrt(y)))) == \ + (-x - sqrt(y))/((x - y)*(x**2 - y)) + e = (3 + 3*sqrt(2))*x*(3*x - 3*sqrt(y)) + assert radsimp(e) == x*(3 + 3*sqrt(2))*(3*x - 3*sqrt(y)) + assert radsimp(1/e) == ( + (-9*x + 9*sqrt(2)*x - 9*sqrt(y) + 9*sqrt(2)*sqrt(y))/(9*x*(9*x**2 - + 9*y))) + assert radsimp(1 + 1/(1 + sqrt(3))) == \ + Mul(S.Half, -1 + sqrt(3), evaluate=False) + 1 + A = symbols("A", commutative=False) + assert radsimp(x**2 + sqrt(2)*x**2 - sqrt(2)*x*A) == \ + x**2 + sqrt(2)*x**2 - sqrt(2)*x*A + assert radsimp(1/sqrt(5 + 2 * sqrt(6))) == -sqrt(2) + sqrt(3) + assert radsimp(1/sqrt(5 + 2 * sqrt(6))**3) == -(-sqrt(3) + sqrt(2))**3 + + # issue 6532 + assert fraction(radsimp(1/sqrt(x))) == (sqrt(x), x) + assert fraction(radsimp(1/sqrt(2*x + 3))) == (sqrt(2*x + 3), 2*x + 3) + assert fraction(radsimp(1/sqrt(2*(x + 3)))) == (sqrt(2*x + 6), 2*x + 6) + + # issue 5994 + e = S('-(2 + 2*sqrt(2) + 4*2**(1/4))/' + '(1 + 2**(3/4) + 3*2**(1/4) + 3*sqrt(2))') + assert radsimp(e).expand() == -2*2**Rational(3, 4) - 2*2**Rational(1, 4) + 2 + 2*sqrt(2) + + # issue 5986 (modifications to radimp didn't initially recognize this so + # the test is included here) + assert radsimp(1/(-sqrt(5)/2 - S.Half + (-sqrt(5)/2 - S.Half)**2)) == 1 + + # from issue 5934 + eq = ( + (-240*sqrt(2)*sqrt(sqrt(5) + 5)*sqrt(8*sqrt(5) + 40) - + 360*sqrt(2)*sqrt(-8*sqrt(5) + 40)*sqrt(-sqrt(5) + 5) - + 120*sqrt(10)*sqrt(-8*sqrt(5) + 40)*sqrt(-sqrt(5) + 5) + + 120*sqrt(2)*sqrt(-sqrt(5) + 5)*sqrt(8*sqrt(5) + 40) + + 120*sqrt(2)*sqrt(-8*sqrt(5) + 40)*sqrt(sqrt(5) + 5) + + 120*sqrt(10)*sqrt(-sqrt(5) + 5)*sqrt(8*sqrt(5) + 40) + + 120*sqrt(10)*sqrt(-8*sqrt(5) + 40)*sqrt(sqrt(5) + 5))/(-36000 - + 7200*sqrt(5) + (12*sqrt(10)*sqrt(sqrt(5) + 5) + + 24*sqrt(10)*sqrt(-sqrt(5) + 5))**2)) + assert radsimp(eq) is S.NaN # it's 0/0 + + # work with normal form + e = 1/sqrt(sqrt(7)/7 + 2*sqrt(2) + 3*sqrt(3) + 5*sqrt(5)) + 3 + assert radsimp(e) == ( + -sqrt(sqrt(7) + 14*sqrt(2) + 21*sqrt(3) + + 35*sqrt(5))*(-11654899*sqrt(35) - 1577436*sqrt(210) - 1278438*sqrt(15) + - 1346996*sqrt(10) + 1635060*sqrt(6) + 5709765 + 7539830*sqrt(14) + + 8291415*sqrt(21))/1300423175 + 3) + + # obey power rules + base = sqrt(3) - sqrt(2) + assert radsimp(1/base**3) == (sqrt(3) + sqrt(2))**3 + assert radsimp(1/(-base)**3) == -(sqrt(2) + sqrt(3))**3 + assert radsimp(1/(-base)**x) == (-base)**(-x) + assert radsimp(1/base**x) == (sqrt(2) + sqrt(3))**x + assert radsimp(root(1/(-1 - sqrt(2)), -x)) == (-1)**(-1/x)*(1 + sqrt(2))**(1/x) + + # recurse + e = cos(1/(1 + sqrt(2))) + assert radsimp(e) == cos(-sqrt(2) + 1) + assert radsimp(e/2) == cos(-sqrt(2) + 1)/2 + assert radsimp(1/e) == 1/cos(-sqrt(2) + 1) + assert radsimp(2/e) == 2/cos(-sqrt(2) + 1) + assert fraction(radsimp(e/sqrt(x))) == (sqrt(x)*cos(-sqrt(2)+1), x) + + # test that symbolic denominators are not processed + r = 1 + sqrt(2) + assert radsimp(x/r, symbolic=False) == -x*(-sqrt(2) + 1) + assert radsimp(x/(y + r), symbolic=False) == x/(y + 1 + sqrt(2)) + assert radsimp(x/(y + r)/r, symbolic=False) == \ + -x*(-sqrt(2) + 1)/(y + 1 + sqrt(2)) + + # issue 7408 + eq = sqrt(x)/sqrt(y) + assert radsimp(eq) == umul(sqrt(x), sqrt(y), 1/y) + assert radsimp(eq, symbolic=False) == eq + + # issue 7498 + assert radsimp(sqrt(x)/sqrt(y)**3) == umul(sqrt(x), sqrt(y**3), 1/y**3) + + # for coverage + eq = sqrt(x)/y**2 + assert radsimp(eq) == eq + + +def test_radsimp_issue_3214(): + c, p = symbols('c p', positive=True) + s = sqrt(c**2 - p**2) + b = (c + I*p - s)/(c + I*p + s) + assert radsimp(b) == -I*(c + I*p - sqrt(c**2 - p**2))**2/(2*c*p) + + +def test_collect_1(): + """Collect with respect to Symbol""" + x, y, z, n = symbols('x,y,z,n') + assert collect(1, x) == 1 + assert collect( x + y*x, x ) == x * (1 + y) + assert collect( x + x**2, x ) == x + x**2 + assert collect( x**2 + y*x**2, x ) == (x**2)*(1 + y) + assert collect( x**2 + y*x, x ) == x*y + x**2 + assert collect( 2*x**2 + y*x**2 + 3*x*y, [x] ) == x**2*(2 + y) + 3*x*y + assert collect( 2*x**2 + y*x**2 + 3*x*y, [y] ) == 2*x**2 + y*(x**2 + 3*x) + + assert collect( ((1 + y + x)**4).expand(), x) == ((1 + y)**4).expand() + \ + x*(4*(1 + y)**3).expand() + x**2*(6*(1 + y)**2).expand() + \ + x**3*(4*(1 + y)).expand() + x**4 + # symbols can be given as any iterable + expr = x + y + assert collect(expr, expr.free_symbols) == expr + assert collect(x*exp(x) + sin(x)*y + sin(x)*2 + 3*x, x, exact=None + ) == x*exp(x) + 3*x + (y + 2)*sin(x) + assert collect(x*exp(x) + sin(x)*y + sin(x)*2 + 3*x + y*x + + y*x*exp(x), x, exact=None + ) == x*exp(x)*(y + 1) + (3 + y)*x + (y + 2)*sin(x) + + +def test_collect_2(): + """Collect with respect to a sum""" + a, b, x = symbols('a,b,x') + assert collect(a*(cos(x) + sin(x)) + b*(cos(x) + sin(x)), + sin(x) + cos(x)) == (a + b)*(cos(x) + sin(x)) + + +def test_collect_3(): + """Collect with respect to a product""" + a, b, c = symbols('a,b,c') + f = Function('f') + x, y, z, n = symbols('x,y,z,n') + + assert collect(-x/8 + x*y, -x) == x*(y - Rational(1, 8)) + + assert collect( 1 + x*(y**2), x*y ) == 1 + x*(y**2) + assert collect( x*y + a*x*y, x*y) == x*y*(1 + a) + assert collect( 1 + x*y + a*x*y, x*y) == 1 + x*y*(1 + a) + assert collect(a*x*f(x) + b*(x*f(x)), x*f(x)) == x*(a + b)*f(x) + + assert collect(a*x*log(x) + b*(x*log(x)), x*log(x)) == x*(a + b)*log(x) + assert collect(a*x**2*log(x)**2 + b*(x*log(x))**2, x*log(x)) == \ + x**2*log(x)**2*(a + b) + + # with respect to a product of three symbols + assert collect(y*x*z + a*x*y*z, x*y*z) == (1 + a)*x*y*z + + +def test_collect_4(): + """Collect with respect to a power""" + a, b, c, x = symbols('a,b,c,x') + + assert collect(a*x**c + b*x**c, x**c) == x**c*(a + b) + # issue 6096: 2 stays with c (unless c is integer or x is positive0 + assert collect(a*x**(2*c) + b*x**(2*c), x**c) == x**(2*c)*(a + b) + + +def test_collect_5(): + """Collect with respect to a tuple""" + a, x, y, z, n = symbols('a,x,y,z,n') + assert collect(x**2*y**4 + z*(x*y**2)**2 + z + a*z, [x*y**2, z]) in [ + z*(1 + a + x**2*y**4) + x**2*y**4, + z*(1 + a) + x**2*y**4*(1 + z) ] + assert collect((1 + (x + y) + (x + y)**2).expand(), + [x, y]) == 1 + y + x*(1 + 2*y) + x**2 + y**2 + + +def test_collect_pr19431(): + """Unevaluated collect with respect to a product""" + a = symbols('a') + assert collect(a**2*(a**2 + 1), a**2, evaluate=False)[a**2] == (a**2 + 1) + + +def test_collect_D(): + D = Derivative + f = Function('f') + x, a, b = symbols('x,a,b') + fx = D(f(x), x) + fxx = D(f(x), x, x) + + assert collect(a*fx + b*fx, fx) == (a + b)*fx + assert collect(a*D(fx, x) + b*D(fx, x), fx) == (a + b)*D(fx, x) + assert collect(a*fxx + b*fxx, fx) == (a + b)*D(fx, x) + # issue 4784 + assert collect(5*f(x) + 3*fx, fx) == 5*f(x) + 3*fx + assert collect(f(x) + f(x)*diff(f(x), x) + x*diff(f(x), x)*f(x), f(x).diff(x)) == \ + (x*f(x) + f(x))*D(f(x), x) + f(x) + assert collect(f(x) + f(x)*diff(f(x), x) + x*diff(f(x), x)*f(x), f(x).diff(x), exact=True) == \ + (x*f(x) + f(x))*D(f(x), x) + f(x) + assert collect(1/f(x) + 1/f(x)*diff(f(x), x) + x*diff(f(x), x)/f(x), f(x).diff(x), exact=True) == \ + (1/f(x) + x/f(x))*D(f(x), x) + 1/f(x) + e = (1 + x*fx + fx)/f(x) + assert collect(e.expand(), fx) == fx*(x/f(x) + 1/f(x)) + 1/f(x) + + +def test_collect_func(): + f = ((x + a + 1)**3).expand() + + assert collect(f, x) == a**3 + 3*a**2 + 3*a + x**3 + x**2*(3*a + 3) + \ + x*(3*a**2 + 6*a + 3) + 1 + assert collect(f, x, factor) == x**3 + 3*x**2*(a + 1) + 3*x*(a + 1)**2 + \ + (a + 1)**3 + + assert collect(f, x, evaluate=False) == { + S.One: a**3 + 3*a**2 + 3*a + 1, + x: 3*a**2 + 6*a + 3, x**2: 3*a + 3, + x**3: 1 + } + + assert collect(f, x, factor, evaluate=False) == { + S.One: (a + 1)**3, x: 3*(a + 1)**2, + x**2: umul(S(3), a + 1), x**3: 1} + + +def test_collect_order(): + a, b, x, t = symbols('a,b,x,t') + + assert collect(t + t*x + t*x**2 + O(x**3), t) == t*(1 + x + x**2 + O(x**3)) + assert collect(t + t*x + x**2 + O(x**3), t) == \ + t*(1 + x + O(x**3)) + x**2 + O(x**3) + + f = a*x + b*x + c*x**2 + d*x**2 + O(x**3) + g = x*(a + b) + x**2*(c + d) + O(x**3) + + assert collect(f, x) == g + assert collect(f, x, distribute_order_term=False) == g + + f = sin(a + b).series(b, 0, 10) + + assert collect(f, [sin(a), cos(a)]) == \ + sin(a)*cos(b).series(b, 0, 10) + cos(a)*sin(b).series(b, 0, 10) + assert collect(f, [sin(a), cos(a)], distribute_order_term=False) == \ + sin(a)*cos(b).series(b, 0, 10).removeO() + \ + cos(a)*sin(b).series(b, 0, 10).removeO() + O(b**10) + + +def test_rcollect(): + assert rcollect((x**2*y + x*y + x + y)/(x + y), y) == \ + (x + y*(1 + x + x**2))/(x + y) + assert rcollect(sqrt(-((x + 1)*(y + 1))), z) == sqrt(-((x + 1)*(y + 1))) + + +def test_collect_D_0(): + D = Derivative + f = Function('f') + x, a, b = symbols('x,a,b') + fxx = D(f(x), x, x) + + assert collect(a*fxx + b*fxx, fxx) == (a + b)*fxx + + +def test_collect_Wild(): + """Collect with respect to functions with Wild argument""" + a, b, x, y = symbols('a b x y') + f = Function('f') + w1 = Wild('.1') + w2 = Wild('.2') + assert collect(f(x) + a*f(x), f(w1)) == (1 + a)*f(x) + assert collect(f(x, y) + a*f(x, y), f(w1)) == f(x, y) + a*f(x, y) + assert collect(f(x, y) + a*f(x, y), f(w1, w2)) == (1 + a)*f(x, y) + assert collect(f(x, y) + a*f(x, y), f(w1, w1)) == f(x, y) + a*f(x, y) + assert collect(f(x, x) + a*f(x, x), f(w1, w1)) == (1 + a)*f(x, x) + assert collect(a*(x + 1)**y + (x + 1)**y, w1**y) == (1 + a)*(x + 1)**y + assert collect(a*(x + 1)**y + (x + 1)**y, w1**b) == \ + a*(x + 1)**y + (x + 1)**y + assert collect(a*(x + 1)**y + (x + 1)**y, (x + 1)**w2) == \ + (1 + a)*(x + 1)**y + assert collect(a*(x + 1)**y + (x + 1)**y, w1**w2) == (1 + a)*(x + 1)**y + + +def test_collect_const(): + # coverage not provided by above tests + assert collect_const(2*sqrt(3) + 4*a*sqrt(5)) == \ + 2*(2*sqrt(5)*a + sqrt(3)) # let the primitive reabsorb + assert collect_const(2*sqrt(3) + 4*a*sqrt(5), sqrt(3)) == \ + 2*sqrt(3) + 4*a*sqrt(5) + assert collect_const(sqrt(2)*(1 + sqrt(2)) + sqrt(3) + x*sqrt(2)) == \ + sqrt(2)*(x + 1 + sqrt(2)) + sqrt(3) + + # issue 5290 + assert collect_const(2*x + 2*y + 1, 2) == \ + collect_const(2*x + 2*y + 1) == \ + Add(S.One, Mul(2, x + y, evaluate=False), evaluate=False) + assert collect_const(-y - z) == Mul(-1, y + z, evaluate=False) + assert collect_const(2*x - 2*y - 2*z, 2) == \ + Mul(2, x - y - z, evaluate=False) + assert collect_const(2*x - 2*y - 2*z, -2) == \ + _unevaluated_Add(2*x, Mul(-2, y + z, evaluate=False)) + + # this is why the content_primitive is used + eq = (sqrt(15 + 5*sqrt(2))*x + sqrt(3 + sqrt(2))*y)*2 + assert collect_sqrt(eq + 2) == \ + 2*sqrt(sqrt(2) + 3)*(sqrt(5)*x + y) + 2 + + # issue 16296 + assert collect_const(a + b + x/2 + y/2) == a + b + Mul(S.Half, x + y, evaluate=False) + + +def test_issue_13143(): + f = Function('f') + fx = f(x).diff(x) + e = f(x) + fx + f(x)*fx + # collect function before derivative + assert collect(e, Wild('w')) == f(x)*(fx + 1) + fx + e = f(x) + f(x)*fx + x*fx*f(x) + assert collect(e, fx) == (x*f(x) + f(x))*fx + f(x) + assert collect(e, f(x)) == (x*fx + fx + 1)*f(x) + e = f(x) + fx + f(x)*fx + assert collect(e, [f(x), fx]) == f(x)*(1 + fx) + fx + assert collect(e, [fx, f(x)]) == fx*(1 + f(x)) + f(x) + + +def test_issue_6097(): + assert collect(a*y**(2.0*x) + b*y**(2.0*x), y**x) == (a + b)*(y**x)**2.0 + assert collect(a*2**(2.0*x) + b*2**(2.0*x), 2**x) == (a + b)*(2**x)**2.0 + + +def test_fraction_expand(): + eq = (x + y)*y/x + assert eq.expand(frac=True) == fraction_expand(eq) == (x*y + y**2)/x + assert eq.expand() == y + y**2/x + + +def test_fraction(): + x, y, z = map(Symbol, 'xyz') + A = Symbol('A', commutative=False) + + assert fraction(S.Half) == (1, 2) + + assert fraction(x) == (x, 1) + assert fraction(1/x) == (1, x) + assert fraction(x/y) == (x, y) + assert fraction(x/2) == (x, 2) + + assert fraction(x*y/z) == (x*y, z) + assert fraction(x/(y*z)) == (x, y*z) + + assert fraction(1/y**2) == (1, y**2) + assert fraction(x/y**2) == (x, y**2) + + assert fraction((x**2 + 1)/y) == (x**2 + 1, y) + assert fraction(x*(y + 1)/y**7) == (x*(y + 1), y**7) + + assert fraction(exp(-x), exact=True) == (exp(-x), 1) + assert fraction((1/(x + y))/2, exact=True) == (1, Mul(2,(x + y), evaluate=False)) + + assert fraction(x*A/y) == (x*A, y) + assert fraction(x*A**-1/y) == (x*A**-1, y) + + n = symbols('n', negative=True) + assert fraction(exp(n)) == (1, exp(-n)) + assert fraction(exp(-n)) == (exp(-n), 1) + + p = symbols('p', positive=True) + assert fraction(exp(-p)*log(p), exact=True) == (exp(-p)*log(p), 1) + + m = Mul(1, 1, S.Half, evaluate=False) + assert fraction(m) == (1, 2) + assert fraction(m, exact=True) == (Mul(1, 1, evaluate=False), 2) + + m = Mul(1, 1, S.Half, S.Half, Pow(1, -1, evaluate=False), evaluate=False) + assert fraction(m) == (1, 4) + assert fraction(m, exact=True) == \ + (Mul(1, 1, evaluate=False), Mul(2, 2, 1, evaluate=False)) + + +def test_issue_5615(): + aA, Re, a, b, D = symbols('aA Re a b D') + e = ((D**3*a + b*aA**3)/Re).expand() + assert collect(e, [aA**3/Re, a]) == e + + +def test_issue_5933(): + from sympy.geometry.polygon import (Polygon, RegularPolygon) + from sympy.simplify.radsimp import denom + x = Polygon(*RegularPolygon((0, 0), 1, 5).vertices).centroid.x + assert abs(denom(x).n()) > 1e-12 + assert abs(denom(radsimp(x))) > 1e-12 # in case simplify didn't handle it + + +def test_issue_14608(): + a, b = symbols('a b', commutative=False) + x, y = symbols('x y') + raises(AttributeError, lambda: collect(a*b + b*a, a)) + assert collect(x*y + y*(x+1), a) == x*y + y*(x+1) + assert collect(x*y + y*(x+1) + a*b + b*a, y) == y*(2*x + 1) + a*b + b*a + + +def test_collect_abs(): + s = abs(x) + abs(y) + assert collect_abs(s) == s + assert unchanged(Mul, abs(x), abs(y)) + ans = Abs(x*y) + assert isinstance(ans, Abs) + assert collect_abs(abs(x)*abs(y)) == ans + assert collect_abs(1 + exp(abs(x)*abs(y))) == 1 + exp(ans) + + # See https://github.com/sympy/sympy/issues/12910 + p = Symbol('p', positive=True) + assert collect_abs(p/abs(1-p)).is_commutative is True + + +def test_issue_19149(): + eq = exp(3*x/4) + assert collect(eq, exp(x)) == eq + +def test_issue_19719(): + a, b = symbols('a, b') + expr = a**2 * (b + 1) + (7 + 1/b)/a + collected = collect(expr, (a**2, 1/a), evaluate=False) + # Would return {_Dummy_20**(-2): b + 1, 1/a: 7 + 1/b} without xreplace + assert collected == {a**2: b + 1, 1/a: 7 + 1/b} + + +def test_issue_21355(): + assert radsimp(1/(x + sqrt(x**2))) == 1/(x + sqrt(x**2)) + assert radsimp(1/(x - sqrt(x**2))) == 1/(x - sqrt(x**2)) diff --git a/lib/python3.10/site-packages/sympy/simplify/tests/test_ratsimp.py b/lib/python3.10/site-packages/sympy/simplify/tests/test_ratsimp.py new file mode 100644 index 0000000000000000000000000000000000000000..14e84fd2b227518baff1bda4e5b27ecc40a8bcdd --- /dev/null +++ b/lib/python3.10/site-packages/sympy/simplify/tests/test_ratsimp.py @@ -0,0 +1,78 @@ +from sympy.core.numbers import (Rational, pi) +from sympy.functions.elementary.exponential import log +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.special.error_functions import erf +from sympy.polys.domains import GF +from sympy.simplify.ratsimp import (ratsimp, ratsimpmodprime) + +from sympy.abc import x, y, z, t, a, b, c, d, e + + +def test_ratsimp(): + f, g = 1/x + 1/y, (x + y)/(x*y) + + assert f != g and ratsimp(f) == g + + f, g = 1/(1 + 1/x), 1 - 1/(x + 1) + + assert f != g and ratsimp(f) == g + + f, g = x/(x + y) + y/(x + y), 1 + + assert f != g and ratsimp(f) == g + + f, g = -x - y - y**2/(x + y) + x**2/(x + y), -2*y + + assert f != g and ratsimp(f) == g + + f = (a*c*x*y + a*c*z - b*d*x*y - b*d*z - b*t*x*y - b*t*x - b*t*z + + e*x)/(x*y + z) + G = [a*c - b*d - b*t + (-b*t*x + e*x)/(x*y + z), + a*c - b*d - b*t - ( b*t*x - e*x)/(x*y + z)] + + assert f != g and ratsimp(f) in G + + A = sqrt(pi) + + B = log(erf(x) - 1) + C = log(erf(x) + 1) + + D = 8 - 8*erf(x) + + f = A*B/D - A*C/D + A*C*erf(x)/D - A*B*erf(x)/D + 2*A/D + + assert ratsimp(f) == A*B/8 - A*C/8 - A/(4*erf(x) - 4) + + +def test_ratsimpmodprime(): + a = y**5 + x + y + b = x - y + F = [x*y**5 - x - y] + assert ratsimpmodprime(a/b, F, x, y, order='lex') == \ + (-x**2 - x*y - x - y) / (-x**2 + x*y) + + a = x + y**2 - 2 + b = x + y**2 - y - 1 + F = [x*y - 1] + assert ratsimpmodprime(a/b, F, x, y, order='lex') == \ + (1 + y - x)/(y - x) + + a = 5*x**3 + 21*x**2 + 4*x*y + 23*x + 12*y + 15 + b = 7*x**3 - y*x**2 + 31*x**2 + 2*x*y + 15*y + 37*x + 21 + F = [x**2 + y**2 - 1] + assert ratsimpmodprime(a/b, F, x, y, order='lex') == \ + (1 + 5*y - 5*x)/(8*y - 6*x) + + a = x*y - x - 2*y + 4 + b = x + y**2 - 2*y + F = [x - 2, y - 3] + assert ratsimpmodprime(a/b, F, x, y, order='lex') == \ + Rational(2, 5) + + # Test a bug where denominators would be dropped + assert ratsimpmodprime(x, [y - 2*x], order='lex') == \ + y/2 + + a = (x**5 + 2*x**4 + 2*x**3 + 2*x**2 + x + 2/x + x**(-2)) + assert ratsimpmodprime(a, [x + 1], domain=GF(2)) == 1 + assert ratsimpmodprime(a, [x + 1], domain=GF(3)) == -1 diff --git a/lib/python3.10/site-packages/sympy/simplify/tests/test_rewrite.py b/lib/python3.10/site-packages/sympy/simplify/tests/test_rewrite.py new file mode 100644 index 0000000000000000000000000000000000000000..56d2fb7a85bd959bd4accc2f36127429efbdbe70 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/simplify/tests/test_rewrite.py @@ -0,0 +1,31 @@ +from sympy.core.numbers import I +from sympy.core.symbol import symbols +from sympy.functions.elementary.exponential import exp +from sympy.functions.elementary.trigonometric import (cos, cot, sin) +from sympy.testing.pytest import _both_exp_pow + +x, y, z, n = symbols('x,y,z,n') + + +@_both_exp_pow +def test_has(): + assert cot(x).has(x) + assert cot(x).has(cot) + assert not cot(x).has(sin) + assert sin(x).has(x) + assert sin(x).has(sin) + assert not sin(x).has(cot) + assert exp(x).has(exp) + + +@_both_exp_pow +def test_sin_exp_rewrite(): + assert sin(x).rewrite(sin, exp) == -I/2*(exp(I*x) - exp(-I*x)) + assert sin(x).rewrite(sin, exp).rewrite(exp, sin) == sin(x) + assert cos(x).rewrite(cos, exp).rewrite(exp, cos) == cos(x) + assert (sin(5*y) - sin( + 2*x)).rewrite(sin, exp).rewrite(exp, sin) == sin(5*y) - sin(2*x) + assert sin(x + y).rewrite(sin, exp).rewrite(exp, sin) == sin(x + y) + assert cos(x + y).rewrite(cos, exp).rewrite(exp, cos) == cos(x + y) + # This next test currently passes... not clear whether it should or not? + assert cos(x).rewrite(cos, exp).rewrite(exp, sin) == cos(x) diff --git a/lib/python3.10/site-packages/sympy/simplify/tests/test_simplify.py b/lib/python3.10/site-packages/sympy/simplify/tests/test_simplify.py new file mode 100644 index 0000000000000000000000000000000000000000..f4392b6693757a95fb2c1df562908e7bd0d79b8f --- /dev/null +++ b/lib/python3.10/site-packages/sympy/simplify/tests/test_simplify.py @@ -0,0 +1,1087 @@ +from sympy.concrete.summations import Sum +from sympy.core.add import Add +from sympy.core.basic import Basic +from sympy.core.expr import unchanged +from sympy.core.function import (count_ops, diff, expand, expand_multinomial, Function, Derivative) +from sympy.core.mul import Mul, _keep_coeff +from sympy.core import GoldenRatio +from sympy.core.numbers import (E, Float, I, oo, pi, Rational, zoo) +from sympy.core.relational import (Eq, Lt, Gt, Ge, Le) +from sympy.core.singleton import S +from sympy.core.symbol import (Symbol, symbols) +from sympy.core.sympify import sympify +from sympy.functions.combinatorial.factorials import (binomial, factorial) +from sympy.functions.elementary.complexes import (Abs, sign) +from sympy.functions.elementary.exponential import (exp, exp_polar, log) +from sympy.functions.elementary.hyperbolic import (cosh, csch, sinh) +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.piecewise import Piecewise +from sympy.functions.elementary.trigonometric import (acos, asin, atan, cos, sin, sinc, tan) +from sympy.functions.special.error_functions import erf +from sympy.functions.special.gamma_functions import gamma +from sympy.functions.special.hyper import hyper +from sympy.functions.special.tensor_functions import KroneckerDelta +from sympy.geometry.polygon import rad +from sympy.integrals.integrals import (Integral, integrate) +from sympy.logic.boolalg import (And, Or) +from sympy.matrices.dense import (Matrix, eye) +from sympy.matrices.expressions.matexpr import MatrixSymbol +from sympy.polys.polytools import (factor, Poly) +from sympy.simplify.simplify import (besselsimp, hypersimp, inversecombine, logcombine, nsimplify, nthroot, posify, separatevars, signsimp, simplify) +from sympy.solvers.solvers import solve + +from sympy.testing.pytest import XFAIL, slow, _both_exp_pow +from sympy.abc import x, y, z, t, a, b, c, d, e, f, g, h, i, n + + +def test_issue_7263(): + assert abs((simplify(30.8**2 - 82.5**2 * sin(rad(11.6))**2)).evalf() - \ + 673.447451402970) < 1e-12 + + +def test_factorial_simplify(): + # There are more tests in test_factorials.py. + x = Symbol('x') + assert simplify(factorial(x)/x) == gamma(x) + assert simplify(factorial(factorial(x))) == factorial(factorial(x)) + + +def test_simplify_expr(): + x, y, z, k, n, m, w, s, A = symbols('x,y,z,k,n,m,w,s,A') + f = Function('f') + + assert all(simplify(tmp) == tmp for tmp in [I, E, oo, x, -x, -oo, -E, -I]) + + e = 1/x + 1/y + assert e != (x + y)/(x*y) + assert simplify(e) == (x + y)/(x*y) + + e = A**2*s**4/(4*pi*k*m**3) + assert simplify(e) == e + + e = (4 + 4*x - 2*(2 + 2*x))/(2 + 2*x) + assert simplify(e) == 0 + + e = (-4*x*y**2 - 2*y**3 - 2*x**2*y)/(x + y)**2 + assert simplify(e) == -2*y + + e = -x - y - (x + y)**(-1)*y**2 + (x + y)**(-1)*x**2 + assert simplify(e) == -2*y + + e = (x + x*y)/x + assert simplify(e) == 1 + y + + e = (f(x) + y*f(x))/f(x) + assert simplify(e) == 1 + y + + e = (2 * (1/n - cos(n * pi)/n))/pi + assert simplify(e) == (-cos(pi*n) + 1)/(pi*n)*2 + + e = integrate(1/(x**3 + 1), x).diff(x) + assert simplify(e) == 1/(x**3 + 1) + + e = integrate(x/(x**2 + 3*x + 1), x).diff(x) + assert simplify(e) == x/(x**2 + 3*x + 1) + + f = Symbol('f') + A = Matrix([[2*k - m*w**2, -k], [-k, k - m*w**2]]).inv() + assert simplify((A*Matrix([0, f]))[1] - + (-f*(2*k - m*w**2)/(k**2 - (k - m*w**2)*(2*k - m*w**2)))) == 0 + + f = -x + y/(z + t) + z*x/(z + t) + z*a/(z + t) + t*x/(z + t) + assert simplify(f) == (y + a*z)/(z + t) + + # issue 10347 + expr = -x*(y**2 - 1)*(2*y**2*(x**2 - 1)/(a*(x**2 - y**2)**2) + (x**2 - 1) + /(a*(x**2 - y**2)))/(a*(x**2 - y**2)) + x*(-2*x**2*sqrt(-x**2*y**2 + x**2 + + y**2 - 1)*sin(z)/(a*(x**2 - y**2)**2) - x**2*sqrt(-x**2*y**2 + x**2 + + y**2 - 1)*sin(z)/(a*(x**2 - 1)*(x**2 - y**2)) + (x**2*sqrt((-x**2 + 1)* + (y**2 - 1))*sqrt(-x**2*y**2 + x**2 + y**2 - 1)*sin(z)/(x**2 - 1) + sqrt( + (-x**2 + 1)*(y**2 - 1))*(x*(-x*y**2 + x)/sqrt(-x**2*y**2 + x**2 + y**2 - + 1) + sqrt(-x**2*y**2 + x**2 + y**2 - 1))*sin(z))/(a*sqrt((-x**2 + 1)*( + y**2 - 1))*(x**2 - y**2)))*sqrt(-x**2*y**2 + x**2 + y**2 - 1)*sin(z)/(a* + (x**2 - y**2)) + x*(-2*x**2*sqrt(-x**2*y**2 + x**2 + y**2 - 1)*cos(z)/(a* + (x**2 - y**2)**2) - x**2*sqrt(-x**2*y**2 + x**2 + y**2 - 1)*cos(z)/(a* + (x**2 - 1)*(x**2 - y**2)) + (x**2*sqrt((-x**2 + 1)*(y**2 - 1))*sqrt(-x**2 + *y**2 + x**2 + y**2 - 1)*cos(z)/(x**2 - 1) + x*sqrt((-x**2 + 1)*(y**2 - + 1))*(-x*y**2 + x)*cos(z)/sqrt(-x**2*y**2 + x**2 + y**2 - 1) + sqrt((-x**2 + + 1)*(y**2 - 1))*sqrt(-x**2*y**2 + x**2 + y**2 - 1)*cos(z))/(a*sqrt((-x**2 + + 1)*(y**2 - 1))*(x**2 - y**2)))*sqrt(-x**2*y**2 + x**2 + y**2 - 1)*cos( + z)/(a*(x**2 - y**2)) - y*sqrt((-x**2 + 1)*(y**2 - 1))*(-x*y*sqrt(-x**2* + y**2 + x**2 + y**2 - 1)*sin(z)/(a*(x**2 - y**2)*(y**2 - 1)) + 2*x*y*sqrt( + -x**2*y**2 + x**2 + y**2 - 1)*sin(z)/(a*(x**2 - y**2)**2) + (x*y*sqrt(( + -x**2 + 1)*(y**2 - 1))*sqrt(-x**2*y**2 + x**2 + y**2 - 1)*sin(z)/(y**2 - + 1) + x*sqrt((-x**2 + 1)*(y**2 - 1))*(-x**2*y + y)*sin(z)/sqrt(-x**2*y**2 + + x**2 + y**2 - 1))/(a*sqrt((-x**2 + 1)*(y**2 - 1))*(x**2 - y**2)))*sin( + z)/(a*(x**2 - y**2)) + y*(x**2 - 1)*(-2*x*y*(x**2 - 1)/(a*(x**2 - y**2) + **2) + 2*x*y/(a*(x**2 - y**2)))/(a*(x**2 - y**2)) + y*(x**2 - 1)*(y**2 - + 1)*(-x*y*sqrt(-x**2*y**2 + x**2 + y**2 - 1)*cos(z)/(a*(x**2 - y**2)*(y**2 + - 1)) + 2*x*y*sqrt(-x**2*y**2 + x**2 + y**2 - 1)*cos(z)/(a*(x**2 - y**2) + **2) + (x*y*sqrt((-x**2 + 1)*(y**2 - 1))*sqrt(-x**2*y**2 + x**2 + y**2 - + 1)*cos(z)/(y**2 - 1) + x*sqrt((-x**2 + 1)*(y**2 - 1))*(-x**2*y + y)*cos( + z)/sqrt(-x**2*y**2 + x**2 + y**2 - 1))/(a*sqrt((-x**2 + 1)*(y**2 - 1) + )*(x**2 - y**2)))*cos(z)/(a*sqrt((-x**2 + 1)*(y**2 - 1))*(x**2 - y**2) + ) - x*sqrt((-x**2 + 1)*(y**2 - 1))*sqrt(-x**2*y**2 + x**2 + y**2 - 1)*sin( + z)**2/(a**2*(x**2 - 1)*(x**2 - y**2)*(y**2 - 1)) - x*sqrt((-x**2 + 1)*( + y**2 - 1))*sqrt(-x**2*y**2 + x**2 + y**2 - 1)*cos(z)**2/(a**2*(x**2 - 1)*( + x**2 - y**2)*(y**2 - 1)) + assert simplify(expr) == 2*x/(a**2*(x**2 - y**2)) + + #issue 17631 + assert simplify('((-1/2)*Boole(True)*Boole(False)-1)*Boole(True)') == \ + Mul(sympify('(2 + Boole(True)*Boole(False))'), sympify('-Boole(True)/2')) + + A, B = symbols('A,B', commutative=False) + + assert simplify(A*B - B*A) == A*B - B*A + assert simplify(A/(1 + y/x)) == x*A/(x + y) + assert simplify(A*(1/x + 1/y)) == A/x + A/y #(x + y)*A/(x*y) + + assert simplify(log(2) + log(3)) == log(6) + assert simplify(log(2*x) - log(2)) == log(x) + + assert simplify(hyper([], [], x)) == exp(x) + + +def test_issue_3557(): + f_1 = x*a + y*b + z*c - 1 + f_2 = x*d + y*e + z*f - 1 + f_3 = x*g + y*h + z*i - 1 + + solutions = solve([f_1, f_2, f_3], x, y, z, simplify=False) + + assert simplify(solutions[y]) == \ + (a*i + c*d + f*g - a*f - c*g - d*i)/ \ + (a*e*i + b*f*g + c*d*h - a*f*h - b*d*i - c*e*g) + + +def test_simplify_other(): + assert simplify(sin(x)**2 + cos(x)**2) == 1 + assert simplify(gamma(x + 1)/gamma(x)) == x + assert simplify(sin(x)**2 + cos(x)**2 + factorial(x)/gamma(x)) == 1 + x + assert simplify( + Eq(sin(x)**2 + cos(x)**2, factorial(x)/gamma(x))) == Eq(x, 1) + nc = symbols('nc', commutative=False) + assert simplify(x + x*nc) == x*(1 + nc) + # issue 6123 + # f = exp(-I*(k*sqrt(t) + x/(2*sqrt(t)))**2) + # ans = integrate(f, (k, -oo, oo), conds='none') + ans = I*(-pi*x*exp(I*pi*Rational(-3, 4) + I*x**2/(4*t))*erf(x*exp(I*pi*Rational(-3, 4))/ + (2*sqrt(t)))/(2*sqrt(t)) + pi*x*exp(I*pi*Rational(-3, 4) + I*x**2/(4*t))/ + (2*sqrt(t)))*exp(-I*x**2/(4*t))/(sqrt(pi)*x) - I*sqrt(pi) * \ + (-erf(x*exp(I*pi/4)/(2*sqrt(t))) + 1)*exp(I*pi/4)/(2*sqrt(t)) + assert simplify(ans) == -(-1)**Rational(3, 4)*sqrt(pi)/sqrt(t) + # issue 6370 + assert simplify(2**(2 + x)/4) == 2**x + + +@_both_exp_pow +def test_simplify_complex(): + cosAsExp = cos(x)._eval_rewrite_as_exp(x) + tanAsExp = tan(x)._eval_rewrite_as_exp(x) + assert simplify(cosAsExp*tanAsExp) == sin(x) # issue 4341 + + # issue 10124 + assert simplify(exp(Matrix([[0, -1], [1, 0]]))) == Matrix([[cos(1), + -sin(1)], [sin(1), cos(1)]]) + + +def test_simplify_ratio(): + # roots of x**3-3*x+5 + roots = ['(1/2 - sqrt(3)*I/2)*(sqrt(21)/2 + 5/2)**(1/3) + 1/((1/2 - ' + 'sqrt(3)*I/2)*(sqrt(21)/2 + 5/2)**(1/3))', + '1/((1/2 + sqrt(3)*I/2)*(sqrt(21)/2 + 5/2)**(1/3)) + ' + '(1/2 + sqrt(3)*I/2)*(sqrt(21)/2 + 5/2)**(1/3)', + '-(sqrt(21)/2 + 5/2)**(1/3) - 1/(sqrt(21)/2 + 5/2)**(1/3)'] + + for r in roots: + r = S(r) + assert count_ops(simplify(r, ratio=1)) <= count_ops(r) + # If ratio=oo, simplify() is always applied: + assert simplify(r, ratio=oo) is not r + + +def test_simplify_measure(): + measure1 = lambda expr: len(str(expr)) + measure2 = lambda expr: -count_ops(expr) + # Return the most complicated result + expr = (x + 1)/(x + sin(x)**2 + cos(x)**2) + assert measure1(simplify(expr, measure=measure1)) <= measure1(expr) + assert measure2(simplify(expr, measure=measure2)) <= measure2(expr) + + expr2 = Eq(sin(x)**2 + cos(x)**2, 1) + assert measure1(simplify(expr2, measure=measure1)) <= measure1(expr2) + assert measure2(simplify(expr2, measure=measure2)) <= measure2(expr2) + + +def test_simplify_rational(): + expr = 2**x*2.**y + assert simplify(expr, rational = True) == 2**(x+y) + assert simplify(expr, rational = None) == 2.0**(x+y) + assert simplify(expr, rational = False) == expr + assert simplify('0.9 - 0.8 - 0.1', rational = True) == 0 + + +def test_simplify_issue_1308(): + assert simplify(exp(Rational(-1, 2)) + exp(Rational(-3, 2))) == \ + (1 + E)*exp(Rational(-3, 2)) + + +def test_issue_5652(): + assert simplify(E + exp(-E)) == exp(-E) + E + n = symbols('n', commutative=False) + assert simplify(n + n**(-n)) == n + n**(-n) + + +def test_simplify_fail1(): + x = Symbol('x') + y = Symbol('y') + e = (x + y)**2/(-4*x*y**2 - 2*y**3 - 2*x**2*y) + assert simplify(e) == 1 / (-2*y) + + +def test_nthroot(): + assert nthroot(90 + 34*sqrt(7), 3) == sqrt(7) + 3 + q = 1 + sqrt(2) - 2*sqrt(3) + sqrt(6) + sqrt(7) + assert nthroot(expand_multinomial(q**3), 3) == q + assert nthroot(41 + 29*sqrt(2), 5) == 1 + sqrt(2) + assert nthroot(-41 - 29*sqrt(2), 5) == -1 - sqrt(2) + expr = 1320*sqrt(10) + 4216 + 2576*sqrt(6) + 1640*sqrt(15) + assert nthroot(expr, 5) == 1 + sqrt(6) + sqrt(15) + q = 1 + sqrt(2) + sqrt(3) + sqrt(5) + assert expand_multinomial(nthroot(expand_multinomial(q**5), 5)) == q + q = 1 + sqrt(2) + 7*sqrt(6) + 2*sqrt(10) + assert nthroot(expand_multinomial(q**5), 5, 8) == q + q = 1 + sqrt(2) - 2*sqrt(3) + 1171*sqrt(6) + assert nthroot(expand_multinomial(q**3), 3) == q + assert nthroot(expand_multinomial(q**6), 6) == q + + +def test_nthroot1(): + q = 1 + sqrt(2) + sqrt(3) + S.One/10**20 + p = expand_multinomial(q**5) + assert nthroot(p, 5) == q + q = 1 + sqrt(2) + sqrt(3) + S.One/10**30 + p = expand_multinomial(q**5) + assert nthroot(p, 5) == q + + +@_both_exp_pow +def test_separatevars(): + x, y, z, n = symbols('x,y,z,n') + assert separatevars(2*n*x*z + 2*x*y*z) == 2*x*z*(n + y) + assert separatevars(x*z + x*y*z) == x*z*(1 + y) + assert separatevars(pi*x*z + pi*x*y*z) == pi*x*z*(1 + y) + assert separatevars(x*y**2*sin(x) + x*sin(x)*sin(y)) == \ + x*(sin(y) + y**2)*sin(x) + assert separatevars(x*exp(x + y) + x*exp(x)) == x*(1 + exp(y))*exp(x) + assert separatevars((x*(y + 1))**z).is_Pow # != x**z*(1 + y)**z + assert separatevars(1 + x + y + x*y) == (x + 1)*(y + 1) + assert separatevars(y/pi*exp(-(z - x)/cos(n))) == \ + y*exp(x/cos(n))*exp(-z/cos(n))/pi + assert separatevars((x + y)*(x - y) + y**2 + 2*x + 1) == (x + 1)**2 + # issue 4858 + p = Symbol('p', positive=True) + assert separatevars(sqrt(p**2 + x*p**2)) == p*sqrt(1 + x) + assert separatevars(sqrt(y*(p**2 + x*p**2))) == p*sqrt(y*(1 + x)) + assert separatevars(sqrt(y*(p**2 + x*p**2)), force=True) == \ + p*sqrt(y)*sqrt(1 + x) + # issue 4865 + assert separatevars(sqrt(x*y)).is_Pow + assert separatevars(sqrt(x*y), force=True) == sqrt(x)*sqrt(y) + # issue 4957 + # any type sequence for symbols is fine + assert separatevars(((2*x + 2)*y), dict=True, symbols=()) == \ + {'coeff': 1, x: 2*x + 2, y: y} + # separable + assert separatevars(((2*x + 2)*y), dict=True, symbols=[x]) == \ + {'coeff': y, x: 2*x + 2} + assert separatevars(((2*x + 2)*y), dict=True, symbols=[]) == \ + {'coeff': 1, x: 2*x + 2, y: y} + assert separatevars(((2*x + 2)*y), dict=True) == \ + {'coeff': 1, x: 2*x + 2, y: y} + assert separatevars(((2*x + 2)*y), dict=True, symbols=None) == \ + {'coeff': y*(2*x + 2)} + # not separable + assert separatevars(3, dict=True) is None + assert separatevars(2*x + y, dict=True, symbols=()) is None + assert separatevars(2*x + y, dict=True) is None + assert separatevars(2*x + y, dict=True, symbols=None) == {'coeff': 2*x + y} + # issue 4808 + n, m = symbols('n,m', commutative=False) + assert separatevars(m + n*m) == (1 + n)*m + assert separatevars(x + x*n) == x*(1 + n) + # issue 4910 + f = Function('f') + assert separatevars(f(x) + x*f(x)) == f(x) + x*f(x) + # a noncommutable object present + eq = x*(1 + hyper((), (), y*z)) + assert separatevars(eq) == eq + + s = separatevars(abs(x*y)) + assert s == abs(x)*abs(y) and s.is_Mul + z = cos(1)**2 + sin(1)**2 - 1 + a = abs(x*z) + s = separatevars(a) + assert not a.is_Mul and s.is_Mul and s == abs(x)*abs(z) + s = separatevars(abs(x*y*z)) + assert s == abs(x)*abs(y)*abs(z) + + # abs(x+y)/abs(z) would be better but we test this here to + # see that it doesn't raise + assert separatevars(abs((x+y)/z)) == abs((x+y)/z) + + +def test_separatevars_advanced_factor(): + x, y, z = symbols('x,y,z') + assert separatevars(1 + log(x)*log(y) + log(x) + log(y)) == \ + (log(x) + 1)*(log(y) + 1) + assert separatevars(1 + x - log(z) - x*log(z) - exp(y)*log(z) - + x*exp(y)*log(z) + x*exp(y) + exp(y)) == \ + -((x + 1)*(log(z) - 1)*(exp(y) + 1)) + x, y = symbols('x,y', positive=True) + assert separatevars(1 + log(x**log(y)) + log(x*y)) == \ + (log(x) + 1)*(log(y) + 1) + + +def test_hypersimp(): + n, k = symbols('n,k', integer=True) + + assert hypersimp(factorial(k), k) == k + 1 + assert hypersimp(factorial(k**2), k) is None + + assert hypersimp(1/factorial(k), k) == 1/(k + 1) + + assert hypersimp(2**k/factorial(k)**2, k) == 2/(k + 1)**2 + + assert hypersimp(binomial(n, k), k) == (n - k)/(k + 1) + assert hypersimp(binomial(n + 1, k), k) == (n - k + 1)/(k + 1) + + term = (4*k + 1)*factorial(k)/factorial(2*k + 1) + assert hypersimp(term, k) == S.Half*((4*k + 5)/(3 + 14*k + 8*k**2)) + + term = 1/((2*k - 1)*factorial(2*k + 1)) + assert hypersimp(term, k) == (k - S.Half)/((k + 1)*(2*k + 1)*(2*k + 3)) + + term = binomial(n, k)*(-1)**k/factorial(k) + assert hypersimp(term, k) == (k - n)/(k + 1)**2 + + +def test_nsimplify(): + x = Symbol("x") + assert nsimplify(0) == 0 + assert nsimplify(-1) == -1 + assert nsimplify(1) == 1 + assert nsimplify(1 + x) == 1 + x + assert nsimplify(2.7) == Rational(27, 10) + assert nsimplify(1 - GoldenRatio) == (1 - sqrt(5))/2 + assert nsimplify((1 + sqrt(5))/4, [GoldenRatio]) == GoldenRatio/2 + assert nsimplify(2/GoldenRatio, [GoldenRatio]) == 2*GoldenRatio - 2 + assert nsimplify(exp(pi*I*Rational(5, 3), evaluate=False)) == \ + sympify('1/2 - sqrt(3)*I/2') + assert nsimplify(sin(pi*Rational(3, 5), evaluate=False)) == \ + sympify('sqrt(sqrt(5)/8 + 5/8)') + assert nsimplify(sqrt(atan('1', evaluate=False))*(2 + I), [pi]) == \ + sqrt(pi) + sqrt(pi)/2*I + assert nsimplify(2 + exp(2*atan('1/4')*I)) == sympify('49/17 + 8*I/17') + assert nsimplify(pi, tolerance=0.01) == Rational(22, 7) + assert nsimplify(pi, tolerance=0.001) == Rational(355, 113) + assert nsimplify(0.33333, tolerance=1e-4) == Rational(1, 3) + assert nsimplify(2.0**(1/3.), tolerance=0.001) == Rational(635, 504) + assert nsimplify(2.0**(1/3.), tolerance=0.001, full=True) == \ + 2**Rational(1, 3) + assert nsimplify(x + .5, rational=True) == S.Half + x + assert nsimplify(1/.3 + x, rational=True) == Rational(10, 3) + x + assert nsimplify(log(3).n(), rational=True) == \ + sympify('109861228866811/100000000000000') + assert nsimplify(Float(0.272198261287950), [pi, log(2)]) == pi*log(2)/8 + assert nsimplify(Float(0.272198261287950).n(3), [pi, log(2)]) == \ + -pi/4 - log(2) + Rational(7, 4) + assert nsimplify(x/7.0) == x/7 + assert nsimplify(pi/1e2) == pi/100 + assert nsimplify(pi/1e2, rational=False) == pi/100.0 + assert nsimplify(pi/1e-7) == 10000000*pi + assert not nsimplify( + factor(-3.0*z**2*(z**2)**(-2.5) + 3*(z**2)**(-1.5))).atoms(Float) + e = x**0.0 + assert e.is_Pow and nsimplify(x**0.0) == 1 + assert nsimplify(3.333333, tolerance=0.1, rational=True) == Rational(10, 3) + assert nsimplify(3.333333, tolerance=0.01, rational=True) == Rational(10, 3) + assert nsimplify(3.666666, tolerance=0.1, rational=True) == Rational(11, 3) + assert nsimplify(3.666666, tolerance=0.01, rational=True) == Rational(11, 3) + assert nsimplify(33, tolerance=10, rational=True) == Rational(33) + assert nsimplify(33.33, tolerance=10, rational=True) == Rational(30) + assert nsimplify(37.76, tolerance=10, rational=True) == Rational(40) + assert nsimplify(-203.1) == Rational(-2031, 10) + assert nsimplify(.2, tolerance=0) == Rational(1, 5) + assert nsimplify(-.2, tolerance=0) == Rational(-1, 5) + assert nsimplify(.2222, tolerance=0) == Rational(1111, 5000) + assert nsimplify(-.2222, tolerance=0) == Rational(-1111, 5000) + # issue 7211, PR 4112 + assert nsimplify(S(2e-8)) == Rational(1, 50000000) + # issue 7322 direct test + assert nsimplify(1e-42, rational=True) != 0 + # issue 10336 + inf = Float('inf') + infs = (-oo, oo, inf, -inf) + for zi in infs: + ans = sign(zi)*oo + assert nsimplify(zi) == ans + assert nsimplify(zi + x) == x + ans + + assert nsimplify(0.33333333, rational=True, rational_conversion='exact') == Rational(0.33333333) + + # Make sure nsimplify on expressions uses full precision + assert nsimplify(pi.evalf(100)*x, rational_conversion='exact').evalf(100) == pi.evalf(100)*x + + +def test_issue_9448(): + tmp = sympify("1/(1 - (-1)**(2/3) - (-1)**(1/3)) + 1/(1 + (-1)**(2/3) + (-1)**(1/3))") + assert nsimplify(tmp) == S.Half + + +def test_extract_minus_sign(): + x = Symbol("x") + y = Symbol("y") + a = Symbol("a") + b = Symbol("b") + assert simplify(-x/-y) == x/y + assert simplify(-x/y) == -x/y + assert simplify(x/y) == x/y + assert simplify(x/-y) == -x/y + assert simplify(-x/0) == zoo*x + assert simplify(Rational(-5, 0)) is zoo + assert simplify(-a*x/(-y - b)) == a*x/(b + y) + + +def test_diff(): + x = Symbol("x") + y = Symbol("y") + f = Function("f") + g = Function("g") + assert simplify(g(x).diff(x)*f(x).diff(x) - f(x).diff(x)*g(x).diff(x)) == 0 + assert simplify(2*f(x)*f(x).diff(x) - diff(f(x)**2, x)) == 0 + assert simplify(diff(1/f(x), x) + f(x).diff(x)/f(x)**2) == 0 + assert simplify(f(x).diff(x, y) - f(x).diff(y, x)) == 0 + + +def test_logcombine_1(): + x, y = symbols("x,y") + a = Symbol("a") + z, w = symbols("z,w", positive=True) + b = Symbol("b", real=True) + assert logcombine(log(x) + 2*log(y)) == log(x) + 2*log(y) + assert logcombine(log(x) + 2*log(y), force=True) == log(x*y**2) + assert logcombine(a*log(w) + log(z)) == a*log(w) + log(z) + assert logcombine(b*log(z) + b*log(x)) == log(z**b) + b*log(x) + assert logcombine(b*log(z) - log(w)) == log(z**b/w) + assert logcombine(log(x)*log(z)) == log(x)*log(z) + assert logcombine(log(w)*log(x)) == log(w)*log(x) + assert logcombine(cos(-2*log(z) + b*log(w))) in [cos(log(w**b/z**2)), + cos(log(z**2/w**b))] + assert logcombine(log(log(x) - log(y)) - log(z), force=True) == \ + log(log(x/y)/z) + assert logcombine((2 + I)*log(x), force=True) == (2 + I)*log(x) + assert logcombine((x**2 + log(x) - log(y))/(x*y), force=True) == \ + (x**2 + log(x/y))/(x*y) + # the following could also give log(z*x**log(y**2)), what we + # are testing is that a canonical result is obtained + assert logcombine(log(x)*2*log(y) + log(z), force=True) == \ + log(z*y**log(x**2)) + assert logcombine((x*y + sqrt(x**4 + y**4) + log(x) - log(y))/(pi*x**Rational(2, 3)* + sqrt(y)**3), force=True) == ( + x*y + sqrt(x**4 + y**4) + log(x/y))/(pi*x**Rational(2, 3)*y**Rational(3, 2)) + assert logcombine(gamma(-log(x/y))*acos(-log(x/y)), force=True) == \ + acos(-log(x/y))*gamma(-log(x/y)) + + assert logcombine(2*log(z)*log(w)*log(x) + log(z) + log(w)) == \ + log(z**log(w**2))*log(x) + log(w*z) + assert logcombine(3*log(w) + 3*log(z)) == log(w**3*z**3) + assert logcombine(x*(y + 1) + log(2) + log(3)) == x*(y + 1) + log(6) + assert logcombine((x + y)*log(w) + (-x - y)*log(3)) == (x + y)*log(w/3) + # a single unknown can combine + assert logcombine(log(x) + log(2)) == log(2*x) + eq = log(abs(x)) + log(abs(y)) + assert logcombine(eq) == eq + reps = {x: 0, y: 0} + assert log(abs(x)*abs(y)).subs(reps) != eq.subs(reps) + + +def test_logcombine_complex_coeff(): + i = Integral((sin(x**2) + cos(x**3))/x, x) + assert logcombine(i, force=True) == i + assert logcombine(i + 2*log(x), force=True) == \ + i + log(x**2) + + +def test_issue_5950(): + x, y = symbols("x,y", positive=True) + assert logcombine(log(3) - log(2)) == log(Rational(3,2), evaluate=False) + assert logcombine(log(x) - log(y)) == log(x/y) + assert logcombine(log(Rational(3,2), evaluate=False) - log(2)) == \ + log(Rational(3,4), evaluate=False) + + +def test_posify(): + x = symbols('x') + + assert str(posify( + x + + Symbol('p', positive=True) + + Symbol('n', negative=True))) == '(_x + n + p, {_x: x})' + + eq, rep = posify(1/x) + assert log(eq).expand().subs(rep) == -log(x) + assert str(posify([x, 1 + x])) == '([_x, _x + 1], {_x: x})' + + p = symbols('p', positive=True) + n = symbols('n', negative=True) + orig = [x, n, p] + modified, reps = posify(orig) + assert str(modified) == '[_x, n, p]' + assert [w.subs(reps) for w in modified] == orig + + assert str(Integral(posify(1/x + y)[0], (y, 1, 3)).expand()) == \ + 'Integral(1/_x, (y, 1, 3)) + Integral(_y, (y, 1, 3))' + assert str(Sum(posify(1/x**n)[0], (n,1,3)).expand()) == \ + 'Sum(_x**(-n), (n, 1, 3))' + + # issue 16438 + k = Symbol('k', finite=True) + eq, rep = posify(k) + assert eq.assumptions0 == {'positive': True, 'zero': False, 'imaginary': False, + 'nonpositive': False, 'commutative': True, 'hermitian': True, 'real': True, 'nonzero': True, + 'nonnegative': True, 'negative': False, 'complex': True, 'finite': True, + 'infinite': False, 'extended_real':True, 'extended_negative': False, + 'extended_nonnegative': True, 'extended_nonpositive': False, + 'extended_nonzero': True, 'extended_positive': True} + + +def test_issue_4194(): + # simplify should call cancel + f = Function('f') + assert simplify((4*x + 6*f(y))/(2*x + 3*f(y))) == 2 + + +@XFAIL +def test_simplify_float_vs_integer(): + # Test for issue 4473: + # https://github.com/sympy/sympy/issues/4473 + assert simplify(x**2.0 - x**2) == 0 + assert simplify(x**2 - x**2.0) == 0 + + +def test_as_content_primitive(): + assert (x/2 + y).as_content_primitive() == (S.Half, x + 2*y) + assert (x/2 + y).as_content_primitive(clear=False) == (S.One, x/2 + y) + assert (y*(x/2 + y)).as_content_primitive() == (S.Half, y*(x + 2*y)) + assert (y*(x/2 + y)).as_content_primitive(clear=False) == (S.One, y*(x/2 + y)) + + # although the _as_content_primitive methods do not alter the underlying structure, + # the as_content_primitive function will touch up the expression and join + # bases that would otherwise have not been joined. + assert (x*(2 + 2*x)*(3*x + 3)**2).as_content_primitive() == \ + (18, x*(x + 1)**3) + assert (2 + 2*x + 2*y*(3 + 3*y)).as_content_primitive() == \ + (2, x + 3*y*(y + 1) + 1) + assert ((2 + 6*x)**2).as_content_primitive() == \ + (4, (3*x + 1)**2) + assert ((2 + 6*x)**(2*y)).as_content_primitive() == \ + (1, (_keep_coeff(S(2), (3*x + 1)))**(2*y)) + assert (5 + 10*x + 2*y*(3 + 3*y)).as_content_primitive() == \ + (1, 10*x + 6*y*(y + 1) + 5) + assert (5*(x*(1 + y)) + 2*x*(3 + 3*y)).as_content_primitive() == \ + (11, x*(y + 1)) + assert ((5*(x*(1 + y)) + 2*x*(3 + 3*y))**2).as_content_primitive() == \ + (121, x**2*(y + 1)**2) + assert (y**2).as_content_primitive() == \ + (1, y**2) + assert (S.Infinity).as_content_primitive() == (1, oo) + eq = x**(2 + y) + assert (eq).as_content_primitive() == (1, eq) + assert (S.Half**(2 + x)).as_content_primitive() == (Rational(1, 4), 2**-x) + assert (Rational(-1, 2)**(2 + x)).as_content_primitive() == \ + (Rational(1, 4), (Rational(-1, 2))**x) + assert (Rational(-1, 2)**(2 + x)).as_content_primitive() == \ + (Rational(1, 4), Rational(-1, 2)**x) + assert (4**((1 + y)/2)).as_content_primitive() == (2, 4**(y/2)) + assert (3**((1 + y)/2)).as_content_primitive() == \ + (1, 3**(Mul(S.Half, 1 + y, evaluate=False))) + assert (5**Rational(3, 4)).as_content_primitive() == (1, 5**Rational(3, 4)) + assert (5**Rational(7, 4)).as_content_primitive() == (5, 5**Rational(3, 4)) + assert Add(z*Rational(5, 7), 0.5*x, y*Rational(3, 2), evaluate=False).as_content_primitive() == \ + (Rational(1, 14), 7.0*x + 21*y + 10*z) + assert (2**Rational(3, 4) + 2**Rational(1, 4)*sqrt(3)).as_content_primitive(radical=True) == \ + (1, 2**Rational(1, 4)*(sqrt(2) + sqrt(3))) + + +def test_signsimp(): + e = x*(-x + 1) + x*(x - 1) + assert signsimp(Eq(e, 0)) is S.true + assert Abs(x - 1) == Abs(1 - x) + assert signsimp(y - x) == y - x + assert signsimp(y - x, evaluate=False) == Mul(-1, x - y, evaluate=False) + + +def test_besselsimp(): + from sympy.functions.special.bessel import (besseli, besselj, bessely) + from sympy.integrals.transforms import cosine_transform + assert besselsimp(exp(-I*pi*y/2)*besseli(y, z*exp_polar(I*pi/2))) == \ + besselj(y, z) + assert besselsimp(exp(-I*pi*a/2)*besseli(a, 2*sqrt(x)*exp_polar(I*pi/2))) == \ + besselj(a, 2*sqrt(x)) + assert besselsimp(sqrt(2)*sqrt(pi)*x**Rational(1, 4)*exp(I*pi/4)*exp(-I*pi*a/2) * + besseli(Rational(-1, 2), sqrt(x)*exp_polar(I*pi/2)) * + besseli(a, sqrt(x)*exp_polar(I*pi/2))/2) == \ + besselj(a, sqrt(x)) * cos(sqrt(x)) + assert besselsimp(besseli(Rational(-1, 2), z)) == \ + sqrt(2)*cosh(z)/(sqrt(pi)*sqrt(z)) + assert besselsimp(besseli(a, z*exp_polar(-I*pi/2))) == \ + exp(-I*pi*a/2)*besselj(a, z) + assert cosine_transform(1/t*sin(a/t), t, y) == \ + sqrt(2)*sqrt(pi)*besselj(0, 2*sqrt(a)*sqrt(y))/2 + + assert besselsimp(x**2*(a*(-2*besselj(5*I, x) + besselj(-2 + 5*I, x) + + besselj(2 + 5*I, x)) + b*(-2*bessely(5*I, x) + bessely(-2 + 5*I, x) + + bessely(2 + 5*I, x)))/4 + x*(a*(besselj(-1 + 5*I, x)/2 - besselj(1 + 5*I, x)/2) + + b*(bessely(-1 + 5*I, x)/2 - bessely(1 + 5*I, x)/2)) + (x**2 + 25)*(a*besselj(5*I, x) + + b*bessely(5*I, x))) == 0 + + assert besselsimp(81*x**2*(a*(besselj(Rational(-5, 3), 9*x) - 2*besselj(Rational(1, 3), 9*x) + besselj(Rational(7, 3), 9*x)) + + b*(bessely(Rational(-5, 3), 9*x) - 2*bessely(Rational(1, 3), 9*x) + bessely(Rational(7, 3), 9*x)))/4 + x*(a*(9*besselj(Rational(-2, 3), 9*x)/2 + - 9*besselj(Rational(4, 3), 9*x)/2) + b*(9*bessely(Rational(-2, 3), 9*x)/2 - 9*bessely(Rational(4, 3), 9*x)/2)) + + (81*x**2 - Rational(1, 9))*(a*besselj(Rational(1, 3), 9*x) + b*bessely(Rational(1, 3), 9*x))) == 0 + + assert besselsimp(besselj(a-1,x) + besselj(a+1, x) - 2*a*besselj(a, x)/x) == 0 + + assert besselsimp(besselj(a-1,x) + besselj(a+1, x) + besselj(a, x)) == (2*a + x)*besselj(a, x)/x + + assert besselsimp(x**2* besselj(a,x) + x**3*besselj(a+1, x) + besselj(a+2, x)) == \ + 2*a*x*besselj(a + 1, x) + x**3*besselj(a + 1, x) - x**2*besselj(a + 2, x) + 2*x*besselj(a + 1, x) + besselj(a + 2, x) + +def test_Piecewise(): + e1 = x*(x + y) - y*(x + y) + e2 = sin(x)**2 + cos(x)**2 + e3 = expand((x + y)*y/x) + s1 = simplify(e1) + s2 = simplify(e2) + s3 = simplify(e3) + assert simplify(Piecewise((e1, x < e2), (e3, True))) == \ + Piecewise((s1, x < s2), (s3, True)) + + +def test_polymorphism(): + class A(Basic): + def _eval_simplify(x, **kwargs): + return S.One + + a = A(S(5), S(2)) + assert simplify(a) == 1 + + +def test_issue_from_PR1599(): + n1, n2, n3, n4 = symbols('n1 n2 n3 n4', negative=True) + assert simplify(I*sqrt(n1)) == -sqrt(-n1) + + +def test_issue_6811(): + eq = (x + 2*y)*(2*x + 2) + assert simplify(eq) == (x + 1)*(x + 2*y)*2 + # reject the 2-arg Mul -- these are a headache for test writing + assert simplify(eq.expand()) == \ + 2*x**2 + 4*x*y + 2*x + 4*y + + +def test_issue_6920(): + e = [cos(x) + I*sin(x), cos(x) - I*sin(x), + cosh(x) - sinh(x), cosh(x) + sinh(x)] + ok = [exp(I*x), exp(-I*x), exp(-x), exp(x)] + # wrap in f to show that the change happens wherever ei occurs + f = Function('f') + assert [simplify(f(ei)).args[0] for ei in e] == ok + + +def test_issue_7001(): + from sympy.abc import r, R + assert simplify(-(r*Piecewise((pi*Rational(4, 3), r <= R), + (-8*pi*R**3/(3*r**3), True)) + 2*Piecewise((pi*r*Rational(4, 3), r <= R), + (4*pi*R**3/(3*r**2), True)))/(4*pi*r)) == \ + Piecewise((-1, r <= R), (0, True)) + + +def test_inequality_no_auto_simplify(): + # no simplify on creation but can be simplified + lhs = cos(x)**2 + sin(x)**2 + rhs = 2 + e = Lt(lhs, rhs, evaluate=False) + assert e is not S.true + assert simplify(e) + + +def test_issue_9398(): + from sympy.core.numbers import Number + from sympy.polys.polytools import cancel + assert cancel(1e-14) != 0 + assert cancel(1e-14*I) != 0 + + assert simplify(1e-14) != 0 + assert simplify(1e-14*I) != 0 + + assert (I*Number(1.)*Number(10)**Number(-14)).simplify() != 0 + + assert cancel(1e-20) != 0 + assert cancel(1e-20*I) != 0 + + assert simplify(1e-20) != 0 + assert simplify(1e-20*I) != 0 + + assert cancel(1e-100) != 0 + assert cancel(1e-100*I) != 0 + + assert simplify(1e-100) != 0 + assert simplify(1e-100*I) != 0 + + f = Float("1e-1000") + assert cancel(f) != 0 + assert cancel(f*I) != 0 + + assert simplify(f) != 0 + assert simplify(f*I) != 0 + + +def test_issue_9324_simplify(): + M = MatrixSymbol('M', 10, 10) + e = M[0, 0] + M[5, 4] + 1304 + assert simplify(e) == e + + +def test_issue_9817_simplify(): + # simplify on trace of substituted explicit quadratic form of matrix + # expressions (a scalar) should return without errors (AttributeError) + # See issue #9817 and #9190 for the original bug more discussion on this + from sympy.matrices.expressions import Identity, trace + v = MatrixSymbol('v', 3, 1) + A = MatrixSymbol('A', 3, 3) + x = Matrix([i + 1 for i in range(3)]) + X = Identity(3) + quadratic = v.T * A * v + assert simplify((trace(quadratic.as_explicit())).xreplace({v:x, A:X})) == 14 + + +def test_issue_13474(): + x = Symbol('x') + assert simplify(x + csch(sinc(1))) == x + csch(sinc(1)) + + +@_both_exp_pow +def test_simplify_function_inverse(): + # "inverse" attribute does not guarantee that f(g(x)) is x + # so this simplification should not happen automatically. + # See issue #12140 + x, y = symbols('x, y') + g = Function('g') + + class f(Function): + def inverse(self, argindex=1): + return g + + assert simplify(f(g(x))) == f(g(x)) + assert inversecombine(f(g(x))) == x + assert simplify(f(g(x)), inverse=True) == x + assert simplify(f(g(sin(x)**2 + cos(x)**2)), inverse=True) == 1 + assert simplify(f(g(x, y)), inverse=True) == f(g(x, y)) + assert unchanged(asin, sin(x)) + assert simplify(asin(sin(x))) == asin(sin(x)) + assert simplify(2*asin(sin(3*x)), inverse=True) == 6*x + assert simplify(log(exp(x))) == log(exp(x)) + assert simplify(log(exp(x)), inverse=True) == x + assert simplify(exp(log(x)), inverse=True) == x + assert simplify(log(exp(x), 2), inverse=True) == x/log(2) + assert simplify(log(exp(x), 2, evaluate=False), inverse=True) == x/log(2) + + +def test_clear_coefficients(): + from sympy.simplify.simplify import clear_coefficients + assert clear_coefficients(4*y*(6*x + 3)) == (y*(2*x + 1), 0) + assert clear_coefficients(4*y*(6*x + 3) - 2) == (y*(2*x + 1), Rational(1, 6)) + assert clear_coefficients(4*y*(6*x + 3) - 2, x) == (y*(2*x + 1), x/12 + Rational(1, 6)) + assert clear_coefficients(sqrt(2) - 2) == (sqrt(2), 2) + assert clear_coefficients(4*sqrt(2) - 2) == (sqrt(2), S.Half) + assert clear_coefficients(S(3), x) == (0, x - 3) + assert clear_coefficients(S.Infinity, x) == (S.Infinity, x) + assert clear_coefficients(-S.Pi, x) == (S.Pi, -x) + assert clear_coefficients(2 - S.Pi/3, x) == (pi, -3*x + 6) + +def test_nc_simplify(): + from sympy.simplify.simplify import nc_simplify + from sympy.matrices.expressions import MatPow, Identity + from sympy.core import Pow + from functools import reduce + + a, b, c, d = symbols('a b c d', commutative = False) + x = Symbol('x') + A = MatrixSymbol("A", x, x) + B = MatrixSymbol("B", x, x) + C = MatrixSymbol("C", x, x) + D = MatrixSymbol("D", x, x) + subst = {a: A, b: B, c: C, d:D} + funcs = {Add: lambda x,y: x+y, Mul: lambda x,y: x*y } + + def _to_matrix(expr): + if expr in subst: + return subst[expr] + if isinstance(expr, Pow): + return MatPow(_to_matrix(expr.args[0]), expr.args[1]) + elif isinstance(expr, (Add, Mul)): + return reduce(funcs[expr.func],[_to_matrix(a) for a in expr.args]) + else: + return expr*Identity(x) + + def _check(expr, simplified, deep=True, matrix=True): + assert nc_simplify(expr, deep=deep) == simplified + assert expand(expr) == expand(simplified) + if matrix: + m_simp = _to_matrix(simplified).doit(inv_expand=False) + assert nc_simplify(_to_matrix(expr), deep=deep) == m_simp + + _check(a*b*a*b*a*b*c*(a*b)**3*c, ((a*b)**3*c)**2) + _check(a*b*(a*b)**-2*a*b, 1) + _check(a**2*b*a*b*a*b*(a*b)**-1, a*(a*b)**2, matrix=False) + _check(b*a*b**2*a*b**2*a*b**2, b*(a*b**2)**3) + _check(a*b*a**2*b*a**2*b*a**3, (a*b*a)**3*a**2) + _check(a**2*b*a**4*b*a**4*b*a**2, (a**2*b*a**2)**3) + _check(a**3*b*a**4*b*a**4*b*a, a**3*(b*a**4)**3*a**-3) + _check(a*b*a*b + a*b*c*x*a*b*c, (a*b)**2 + x*(a*b*c)**2) + _check(a*b*a*b*c*a*b*a*b*c, ((a*b)**2*c)**2) + _check(b**-1*a**-1*(a*b)**2, a*b) + _check(a**-1*b*c**-1, (c*b**-1*a)**-1) + expr = a**3*b*a**4*b*a**4*b*a**2*b*a**2*(b*a**2)**2*b*a**2*b*a**2 + for _ in range(10): + expr *= a*b + _check(expr, a**3*(b*a**4)**2*(b*a**2)**6*(a*b)**10) + _check((a*b*a*b)**2, (a*b*a*b)**2, deep=False) + _check(a*b*(c*d)**2, a*b*(c*d)**2) + expr = b**-1*(a**-1*b**-1 - a**-1*c*b**-1)**-1*a**-1 + assert nc_simplify(expr) == (1-c)**-1 + # commutative expressions should be returned without an error + assert nc_simplify(2*x**2) == 2*x**2 + +def test_issue_15965(): + A = Sum(z*x**y, (x, 1, a)) + anew = z*Sum(x**y, (x, 1, a)) + B = Integral(x*y, x) + bdo = x**2*y/2 + assert simplify(A + B) == anew + bdo + assert simplify(A) == anew + assert simplify(B) == bdo + assert simplify(B, doit=False) == y*Integral(x, x) + + +def test_issue_17137(): + assert simplify(cos(x)**I) == cos(x)**I + assert simplify(cos(x)**(2 + 3*I)) == cos(x)**(2 + 3*I) + + +def test_issue_21869(): + x = Symbol('x', real=True) + y = Symbol('y', real=True) + expr = And(Eq(x**2, 4), Le(x, y)) + assert expr.simplify() == expr + + expr = And(Eq(x**2, 4), Eq(x, 2)) + assert expr.simplify() == Eq(x, 2) + + expr = And(Eq(x**3, x**2), Eq(x, 1)) + assert expr.simplify() == Eq(x, 1) + + expr = And(Eq(sin(x), x**2), Eq(x, 0)) + assert expr.simplify() == Eq(x, 0) + + expr = And(Eq(x**3, x**2), Eq(x, 2)) + assert expr.simplify() == S.false + + expr = And(Eq(y, x**2), Eq(x, 1)) + assert expr.simplify() == And(Eq(y,1), Eq(x, 1)) + + expr = And(Eq(y**2, 1), Eq(y, x**2), Eq(x, 1)) + assert expr.simplify() == And(Eq(y,1), Eq(x, 1)) + + expr = And(Eq(y**2, 4), Eq(y, 2*x**2), Eq(x, 1)) + assert expr.simplify() == And(Eq(y,2), Eq(x, 1)) + + expr = And(Eq(y**2, 4), Eq(y, x**2), Eq(x, 1)) + assert expr.simplify() == S.false + + +def test_issue_7971_21740(): + z = Integral(x, (x, 1, 1)) + assert z != 0 + assert simplify(z) is S.Zero + assert simplify(S.Zero) is S.Zero + z = simplify(Float(0)) + assert z is not S.Zero and z == 0.0 + + +@slow +def test_issue_17141_slow(): + # Should not give RecursionError + assert simplify((2**acos(I+1)**2).rewrite('log')) == 2**((pi + 2*I*log(-1 + + sqrt(1 - 2*I) + I))**2/4) + + +def test_issue_17141(): + # Check that there is no RecursionError + assert simplify(x**(1 / acos(I))) == x**(2/(pi - 2*I*log(1 + sqrt(2)))) + assert simplify(acos(-I)**2*acos(I)**2) == \ + log(1 + sqrt(2))**4 + pi**2*log(1 + sqrt(2))**2/2 + pi**4/16 + assert simplify(2**acos(I)**2) == 2**((pi - 2*I*log(1 + sqrt(2)))**2/4) + p = 2**acos(I+1)**2 + assert simplify(p) == p + + +def test_simplify_kroneckerdelta(): + i, j = symbols("i j") + K = KroneckerDelta + + assert simplify(K(i, j)) == K(i, j) + assert simplify(K(0, j)) == K(0, j) + assert simplify(K(i, 0)) == K(i, 0) + + assert simplify(K(0, j).rewrite(Piecewise) * K(1, j)) == 0 + assert simplify(K(1, i) + Piecewise((1, Eq(j, 2)), (0, True))) == K(1, i) + K(2, j) + + # issue 17214 + assert simplify(K(0, j) * K(1, j)) == 0 + + n = Symbol('n', integer=True) + assert simplify(K(0, n) * K(1, n)) == 0 + + M = Matrix(4, 4, lambda i, j: K(j - i, n) if i <= j else 0) + assert simplify(M**2) == Matrix([[K(0, n), 0, K(1, n), 0], + [0, K(0, n), 0, K(1, n)], + [0, 0, K(0, n), 0], + [0, 0, 0, K(0, n)]]) + assert simplify(eye(1) * KroneckerDelta(0, n) * + KroneckerDelta(1, n)) == Matrix([[0]]) + + assert simplify(S.Infinity * KroneckerDelta(0, n) * + KroneckerDelta(1, n)) is S.NaN + + +def test_issue_17292(): + assert simplify(abs(x)/abs(x**2)) == 1/abs(x) + # this is bigger than the issue: check that deep processing works + assert simplify(5*abs((x**2 - 1)/(x - 1))) == 5*Abs(x + 1) + + +def test_issue_19822(): + expr = And(Gt(n-2, 1), Gt(n, 1)) + assert simplify(expr) == Gt(n, 3) + + +def test_issue_18645(): + expr = And(Ge(x, 3), Le(x, 3)) + assert simplify(expr) == Eq(x, 3) + expr = And(Eq(x, 3), Le(x, 3)) + assert simplify(expr) == Eq(x, 3) + + +@XFAIL +def test_issue_18642(): + i = Symbol("i", integer=True) + n = Symbol("n", integer=True) + expr = And(Eq(i, 2 * n), Le(i, 2*n -1)) + assert simplify(expr) == S.false + + +@XFAIL +def test_issue_18389(): + n = Symbol("n", integer=True) + expr = Eq(n, 0) | (n >= 1) + assert simplify(expr) == Ge(n, 0) + + +def test_issue_8373(): + x = Symbol('x', real=True) + assert simplify(Or(x < 1, x >= 1)) == S.true + + +def test_issue_7950(): + expr = And(Eq(x, 1), Eq(x, 2)) + assert simplify(expr) == S.false + + +def test_issue_22020(): + expr = I*pi/2 -oo + assert simplify(expr) == expr + # Used to throw an error + + +def test_issue_19484(): + assert simplify(sign(x) * Abs(x)) == x + + e = x + sign(x + x**3) + assert simplify(Abs(x + x**3)*e) == x**3 + x*Abs(x**3 + x) + x + + e = x**2 + sign(x**3 + 1) + assert simplify(Abs(x**3 + 1) * e) == x**3 + x**2*Abs(x**3 + 1) + 1 + + f = Function('f') + e = x + sign(x + f(x)**3) + assert simplify(Abs(x + f(x)**3) * e) == x*Abs(x + f(x)**3) + x + f(x)**3 + + +def test_issue_23543(): + # Used to give an error + x, y, z = symbols("x y z", commutative=False) + assert (x*(y + z/2)).simplify() == x*(2*y + z)/2 + + +def test_issue_11004(): + + def f(n): + return sqrt(2*pi*n) * (n/E)**n + + def m(n, k): + return f(n) / (f(n/k)**k) + + def p(n,k): + return m(n, k) / (k**n) + + N, k = symbols('N k') + half = Float('0.5', 4) + z = log(p(n, k) / p(n, k + 1)).expand(force=True) + r = simplify(z.subs(n, N).n(4)) + assert r == ( + half*k*log(k) + - half*k*log(k + 1) + + half*log(N) + - half*log(k + 1) + + Float(0.9189224, 4) + ) + + +def test_issue_19161(): + polynomial = Poly('x**2').simplify() + assert (polynomial-x**2).simplify() == 0 + + +def test_issue_22210(): + d = Symbol('d', integer=True) + expr = 2*Derivative(sin(x), (x, d)) + assert expr.simplify() == expr + + +def test_reduce_inverses_nc_pow(): + x, y = symbols("x y", commutative=True) + Z = symbols("Z", commutative=False) + assert simplify(2**Z * y**Z) == 2**Z * y**Z + assert simplify(x**Z * y**Z) == x**Z * y**Z + x, y = symbols("x y", positive=True) + assert expand((x*y)**Z) == x**Z * y**Z + assert simplify(x**Z * y**Z) == expand((x*y)**Z) + +def test_nc_recursion_coeff(): + X = symbols("X", commutative = False) + assert (2 * cos(pi/3) * X).simplify() == X + assert (2.0 * cos(pi/3) * X).simplify() == X diff --git a/lib/python3.10/site-packages/sympy/simplify/tests/test_sqrtdenest.py b/lib/python3.10/site-packages/sympy/simplify/tests/test_sqrtdenest.py new file mode 100644 index 0000000000000000000000000000000000000000..41c771bb2055a1199d349ae3649f33927d79313a --- /dev/null +++ b/lib/python3.10/site-packages/sympy/simplify/tests/test_sqrtdenest.py @@ -0,0 +1,204 @@ +from sympy.core.mul import Mul +from sympy.core.numbers import (I, Integer, Rational) +from sympy.core.symbol import Symbol +from sympy.functions.elementary.miscellaneous import (root, sqrt) +from sympy.functions.elementary.trigonometric import cos +from sympy.integrals.integrals import Integral +from sympy.simplify.sqrtdenest import sqrtdenest +from sympy.simplify.sqrtdenest import ( + _subsets as subsets, _sqrt_numeric_denest) + +r2, r3, r5, r6, r7, r10, r15, r29 = [sqrt(x) for x in (2, 3, 5, 6, 7, 10, + 15, 29)] + + +def test_sqrtdenest(): + d = {sqrt(5 + 2 * r6): r2 + r3, + sqrt(5. + 2 * r6): sqrt(5. + 2 * r6), + sqrt(5. + 4*sqrt(5 + 2 * r6)): sqrt(5.0 + 4*r2 + 4*r3), + sqrt(r2): sqrt(r2), + sqrt(5 + r7): sqrt(5 + r7), + sqrt(3 + sqrt(5 + 2*r7)): + 3*r2*(5 + 2*r7)**Rational(1, 4)/(2*sqrt(6 + 3*r7)) + + r2*sqrt(6 + 3*r7)/(2*(5 + 2*r7)**Rational(1, 4)), + sqrt(3 + 2*r3): 3**Rational(3, 4)*(r6/2 + 3*r2/2)/3} + for i in d: + assert sqrtdenest(i) == d[i], i + + +def test_sqrtdenest2(): + assert sqrtdenest(sqrt(16 - 2*r29 + 2*sqrt(55 - 10*r29))) == \ + r5 + sqrt(11 - 2*r29) + e = sqrt(-r5 + sqrt(-2*r29 + 2*sqrt(-10*r29 + 55) + 16)) + assert sqrtdenest(e) == root(-2*r29 + 11, 4) + r = sqrt(1 + r7) + assert sqrtdenest(sqrt(1 + r)) == sqrt(1 + r) + e = sqrt(((1 + sqrt(1 + 2*sqrt(3 + r2 + r5)))**2).expand()) + assert sqrtdenest(e) == 1 + sqrt(1 + 2*sqrt(r2 + r5 + 3)) + + assert sqrtdenest(sqrt(5*r3 + 6*r2)) == \ + sqrt(2)*root(3, 4) + root(3, 4)**3 + + assert sqrtdenest(sqrt(((1 + r5 + sqrt(1 + r3))**2).expand())) == \ + 1 + r5 + sqrt(1 + r3) + + assert sqrtdenest(sqrt(((1 + r5 + r7 + sqrt(1 + r3))**2).expand())) == \ + 1 + sqrt(1 + r3) + r5 + r7 + + e = sqrt(((1 + cos(2) + cos(3) + sqrt(1 + r3))**2).expand()) + assert sqrtdenest(e) == cos(3) + cos(2) + 1 + sqrt(1 + r3) + + e = sqrt(-2*r10 + 2*r2*sqrt(-2*r10 + 11) + 14) + assert sqrtdenest(e) == sqrt(-2*r10 - 2*r2 + 4*r5 + 14) + + # check that the result is not more complicated than the input + z = sqrt(-2*r29 + cos(2) + 2*sqrt(-10*r29 + 55) + 16) + assert sqrtdenest(z) == z + + assert sqrtdenest(sqrt(r6 + sqrt(15))) == sqrt(r6 + sqrt(15)) + + z = sqrt(15 - 2*sqrt(31) + 2*sqrt(55 - 10*r29)) + assert sqrtdenest(z) == z + + +def test_sqrtdenest_rec(): + assert sqrtdenest(sqrt(-4*sqrt(14) - 2*r6 + 4*sqrt(21) + 33)) == \ + -r2 + r3 + 2*r7 + assert sqrtdenest(sqrt(-28*r7 - 14*r5 + 4*sqrt(35) + 82)) == \ + -7 + r5 + 2*r7 + assert sqrtdenest(sqrt(6*r2/11 + 2*sqrt(22)/11 + 6*sqrt(11)/11 + 2)) == \ + sqrt(11)*(r2 + 3 + sqrt(11))/11 + assert sqrtdenest(sqrt(468*r3 + 3024*r2 + 2912*r6 + 19735)) == \ + 9*r3 + 26 + 56*r6 + z = sqrt(-490*r3 - 98*sqrt(115) - 98*sqrt(345) - 2107) + assert sqrtdenest(z) == sqrt(-1)*(7*r5 + 7*r15 + 7*sqrt(23)) + z = sqrt(-4*sqrt(14) - 2*r6 + 4*sqrt(21) + 34) + assert sqrtdenest(z) == z + assert sqrtdenest(sqrt(-8*r2 - 2*r5 + 18)) == -r10 + 1 + r2 + r5 + assert sqrtdenest(sqrt(8*r2 + 2*r5 - 18)) == \ + sqrt(-1)*(-r10 + 1 + r2 + r5) + assert sqrtdenest(sqrt(8*r2/3 + 14*r5/3 + Rational(154, 9))) == \ + -r10/3 + r2 + r5 + 3 + assert sqrtdenest(sqrt(sqrt(2*r6 + 5) + sqrt(2*r7 + 8))) == \ + sqrt(1 + r2 + r3 + r7) + assert sqrtdenest(sqrt(4*r15 + 8*r5 + 12*r3 + 24)) == 1 + r3 + r5 + r15 + + w = 1 + r2 + r3 + r5 + r7 + assert sqrtdenest(sqrt((w**2).expand())) == w + z = sqrt((w**2).expand() + 1) + assert sqrtdenest(z) == z + + z = sqrt(2*r10 + 6*r2 + 4*r5 + 12 + 10*r15 + 30*r3) + assert sqrtdenest(z) == z + + +def test_issue_6241(): + z = sqrt( -320 + 32*sqrt(5) + 64*r15) + assert sqrtdenest(z) == z + + +def test_sqrtdenest3(): + z = sqrt(13 - 2*r10 + 2*r2*sqrt(-2*r10 + 11)) + assert sqrtdenest(z) == -1 + r2 + r10 + assert sqrtdenest(z, max_iter=1) == -1 + sqrt(2) + sqrt(10) + z = sqrt(sqrt(r2 + 2) + 2) + assert sqrtdenest(z) == z + assert sqrtdenest(sqrt(-2*r10 + 4*r2*sqrt(-2*r10 + 11) + 20)) == \ + sqrt(-2*r10 - 4*r2 + 8*r5 + 20) + assert sqrtdenest(sqrt((112 + 70*r2) + (46 + 34*r2)*r5)) == \ + r10 + 5 + 4*r2 + 3*r5 + z = sqrt(5 + sqrt(2*r6 + 5)*sqrt(-2*r29 + 2*sqrt(-10*r29 + 55) + 16)) + r = sqrt(-2*r29 + 11) + assert sqrtdenest(z) == sqrt(r2*r + r3*r + r10 + r15 + 5) + + n = sqrt(2*r6/7 + 2*r7/7 + 2*sqrt(42)/7 + 2) + d = sqrt(16 - 2*r29 + 2*sqrt(55 - 10*r29)) + assert sqrtdenest(n/d) == r7*(1 + r6 + r7)/(Mul(7, (sqrt(-2*r29 + 11) + r5), + evaluate=False)) + + +def test_sqrtdenest4(): + # see Denest_en.pdf in https://github.com/sympy/sympy/issues/3192 + z = sqrt(8 - r2*sqrt(5 - r5) - sqrt(3)*(1 + r5)) + z1 = sqrtdenest(z) + c = sqrt(-r5 + 5) + z1 = ((-r15*c - r3*c + c + r5*c - r6 - r2 + r10 + sqrt(30))/4).expand() + assert sqrtdenest(z) == z1 + + z = sqrt(2*r2*sqrt(r2 + 2) + 5*r2 + 4*sqrt(r2 + 2) + 8) + assert sqrtdenest(z) == r2 + sqrt(r2 + 2) + 2 + + w = 2 + r2 + r3 + (1 + r3)*sqrt(2 + r2 + 5*r3) + z = sqrt((w**2).expand()) + assert sqrtdenest(z) == w.expand() + + +def test_sqrt_symbolic_denest(): + x = Symbol('x') + z = sqrt(((1 + sqrt(sqrt(2 + x) + 3))**2).expand()) + assert sqrtdenest(z) == sqrt((1 + sqrt(sqrt(2 + x) + 3))**2) + z = sqrt(((1 + sqrt(sqrt(2 + cos(1)) + 3))**2).expand()) + assert sqrtdenest(z) == 1 + sqrt(sqrt(2 + cos(1)) + 3) + z = ((1 + cos(2))**4 + 1).expand() + assert sqrtdenest(z) == z + z = sqrt(((1 + sqrt(sqrt(2 + cos(3*x)) + 3))**2 + 1).expand()) + assert sqrtdenest(z) == z + c = cos(3) + c2 = c**2 + assert sqrtdenest(sqrt(2*sqrt(1 + r3)*c + c2 + 1 + r3*c2)) == \ + -1 - sqrt(1 + r3)*c + ra = sqrt(1 + r3) + z = sqrt(20*ra*sqrt(3 + 3*r3) + 12*r3*ra*sqrt(3 + 3*r3) + 64*r3 + 112) + assert sqrtdenest(z) == z + + +def test_issue_5857(): + from sympy.abc import x, y + z = sqrt(1/(4*r3 + 7) + 1) + ans = (r2 + r6)/(r3 + 2) + assert sqrtdenest(z) == ans + assert sqrtdenest(1 + z) == 1 + ans + assert sqrtdenest(Integral(z + 1, (x, 1, 2))) == \ + Integral(1 + ans, (x, 1, 2)) + assert sqrtdenest(x + sqrt(y)) == x + sqrt(y) + ans = (r2 + r6)/(r3 + 2) + assert sqrtdenest(z) == ans + assert sqrtdenest(1 + z) == 1 + ans + assert sqrtdenest(Integral(z + 1, (x, 1, 2))) == \ + Integral(1 + ans, (x, 1, 2)) + assert sqrtdenest(x + sqrt(y)) == x + sqrt(y) + + +def test_subsets(): + assert subsets(1) == [[1]] + assert subsets(4) == [ + [1, 0, 0, 0], [0, 1, 0, 0], [1, 1, 0, 0], [0, 0, 1, 0], [1, 0, 1, 0], + [0, 1, 1, 0], [1, 1, 1, 0], [0, 0, 0, 1], [1, 0, 0, 1], [0, 1, 0, 1], + [1, 1, 0, 1], [0, 0, 1, 1], [1, 0, 1, 1], [0, 1, 1, 1], [1, 1, 1, 1]] + + +def test_issue_5653(): + assert sqrtdenest( + sqrt(2 + sqrt(2 + sqrt(2)))) == sqrt(2 + sqrt(2 + sqrt(2))) + +def test_issue_12420(): + assert sqrtdenest((3 - sqrt(2)*sqrt(4 + 3*I) + 3*I)/2) == I + e = 3 - sqrt(2)*sqrt(4 + I) + 3*I + assert sqrtdenest(e) == e + +def test_sqrt_ratcomb(): + assert sqrtdenest(sqrt(1 + r3) + sqrt(3 + 3*r3) - sqrt(10 + 6*r3)) == 0 + +def test_issue_18041(): + e = -sqrt(-2 + 2*sqrt(3)*I) + assert sqrtdenest(e) == -1 - sqrt(3)*I + +def test_issue_19914(): + a = Integer(-8) + b = Integer(-1) + r = Integer(63) + d2 = a*a - b*b*r + + assert _sqrt_numeric_denest(a, b, r, d2) == \ + sqrt(14)*I/2 + 3*sqrt(2)*I/2 + assert sqrtdenest(sqrt(-8-sqrt(63))) == sqrt(14)*I/2 + 3*sqrt(2)*I/2 diff --git a/lib/python3.10/site-packages/sympy/simplify/tests/test_trigsimp.py b/lib/python3.10/site-packages/sympy/simplify/tests/test_trigsimp.py new file mode 100644 index 0000000000000000000000000000000000000000..ea091ec8a6c7d654405968e3d035c2bbe02ccdf7 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/simplify/tests/test_trigsimp.py @@ -0,0 +1,520 @@ +from itertools import product +from sympy.core.function import (Subs, count_ops, diff, expand) +from sympy.core.numbers import (E, I, Rational, pi) +from sympy.core.singleton import S +from sympy.core.symbol import (Symbol, symbols) +from sympy.functions.elementary.exponential import (exp, log) +from sympy.functions.elementary.hyperbolic import (cosh, coth, sinh, tanh) +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.piecewise import Piecewise +from sympy.functions.elementary.trigonometric import (cos, cot, sin, tan) +from sympy.functions.elementary.trigonometric import (acos, asin, atan2) +from sympy.functions.elementary.trigonometric import (asec, acsc) +from sympy.functions.elementary.trigonometric import (acot, atan) +from sympy.integrals.integrals import integrate +from sympy.matrices.dense import Matrix +from sympy.simplify.simplify import simplify +from sympy.simplify.trigsimp import (exptrigsimp, trigsimp) + +from sympy.testing.pytest import XFAIL + +from sympy.abc import x, y + + + +def test_trigsimp1(): + x, y = symbols('x,y') + + assert trigsimp(1 - sin(x)**2) == cos(x)**2 + assert trigsimp(1 - cos(x)**2) == sin(x)**2 + assert trigsimp(sin(x)**2 + cos(x)**2) == 1 + assert trigsimp(1 + tan(x)**2) == 1/cos(x)**2 + assert trigsimp(1/cos(x)**2 - 1) == tan(x)**2 + assert trigsimp(1/cos(x)**2 - tan(x)**2) == 1 + assert trigsimp(1 + cot(x)**2) == 1/sin(x)**2 + assert trigsimp(1/sin(x)**2 - 1) == 1/tan(x)**2 + assert trigsimp(1/sin(x)**2 - cot(x)**2) == 1 + + assert trigsimp(5*cos(x)**2 + 5*sin(x)**2) == 5 + assert trigsimp(5*cos(x/2)**2 + 2*sin(x/2)**2) == 3*cos(x)/2 + Rational(7, 2) + + assert trigsimp(sin(x)/cos(x)) == tan(x) + assert trigsimp(2*tan(x)*cos(x)) == 2*sin(x) + assert trigsimp(cot(x)**3*sin(x)**3) == cos(x)**3 + assert trigsimp(y*tan(x)**2/sin(x)**2) == y/cos(x)**2 + assert trigsimp(cot(x)/cos(x)) == 1/sin(x) + + assert trigsimp(sin(x + y) + sin(x - y)) == 2*sin(x)*cos(y) + assert trigsimp(sin(x + y) - sin(x - y)) == 2*sin(y)*cos(x) + assert trigsimp(cos(x + y) + cos(x - y)) == 2*cos(x)*cos(y) + assert trigsimp(cos(x + y) - cos(x - y)) == -2*sin(x)*sin(y) + assert trigsimp(tan(x + y) - tan(x)/(1 - tan(x)*tan(y))) == \ + sin(y)/(-sin(y)*tan(x) + cos(y)) # -tan(y)/(tan(x)*tan(y) - 1) + + assert trigsimp(sinh(x + y) + sinh(x - y)) == 2*sinh(x)*cosh(y) + assert trigsimp(sinh(x + y) - sinh(x - y)) == 2*sinh(y)*cosh(x) + assert trigsimp(cosh(x + y) + cosh(x - y)) == 2*cosh(x)*cosh(y) + assert trigsimp(cosh(x + y) - cosh(x - y)) == 2*sinh(x)*sinh(y) + assert trigsimp(tanh(x + y) - tanh(x)/(1 + tanh(x)*tanh(y))) == \ + sinh(y)/(sinh(y)*tanh(x) + cosh(y)) + + assert trigsimp(cos(0.12345)**2 + sin(0.12345)**2) == 1.0 + e = 2*sin(x)**2 + 2*cos(x)**2 + assert trigsimp(log(e)) == log(2) + + +def test_trigsimp1a(): + assert trigsimp(sin(2)**2*cos(3)*exp(2)/cos(2)**2) == tan(2)**2*cos(3)*exp(2) + assert trigsimp(tan(2)**2*cos(3)*exp(2)*cos(2)**2) == sin(2)**2*cos(3)*exp(2) + assert trigsimp(cot(2)*cos(3)*exp(2)*sin(2)) == cos(3)*exp(2)*cos(2) + assert trigsimp(tan(2)*cos(3)*exp(2)/sin(2)) == cos(3)*exp(2)/cos(2) + assert trigsimp(cot(2)*cos(3)*exp(2)/cos(2)) == cos(3)*exp(2)/sin(2) + assert trigsimp(cot(2)*cos(3)*exp(2)*tan(2)) == cos(3)*exp(2) + assert trigsimp(sinh(2)*cos(3)*exp(2)/cosh(2)) == tanh(2)*cos(3)*exp(2) + assert trigsimp(tanh(2)*cos(3)*exp(2)*cosh(2)) == sinh(2)*cos(3)*exp(2) + assert trigsimp(coth(2)*cos(3)*exp(2)*sinh(2)) == cosh(2)*cos(3)*exp(2) + assert trigsimp(tanh(2)*cos(3)*exp(2)/sinh(2)) == cos(3)*exp(2)/cosh(2) + assert trigsimp(coth(2)*cos(3)*exp(2)/cosh(2)) == cos(3)*exp(2)/sinh(2) + assert trigsimp(coth(2)*cos(3)*exp(2)*tanh(2)) == cos(3)*exp(2) + + +def test_trigsimp2(): + x, y = symbols('x,y') + assert trigsimp(cos(x)**2*sin(y)**2 + cos(x)**2*cos(y)**2 + sin(x)**2, + recursive=True) == 1 + assert trigsimp(sin(x)**2*sin(y)**2 + sin(x)**2*cos(y)**2 + cos(x)**2, + recursive=True) == 1 + assert trigsimp( + Subs(x, x, sin(y)**2 + cos(y)**2)) == Subs(x, x, 1) + + +def test_issue_4373(): + x = Symbol("x") + assert abs(trigsimp(2.0*sin(x)**2 + 2.0*cos(x)**2) - 2.0) < 1e-10 + + +def test_trigsimp3(): + x, y = symbols('x,y') + assert trigsimp(sin(x)/cos(x)) == tan(x) + assert trigsimp(sin(x)**2/cos(x)**2) == tan(x)**2 + assert trigsimp(sin(x)**3/cos(x)**3) == tan(x)**3 + assert trigsimp(sin(x)**10/cos(x)**10) == tan(x)**10 + + assert trigsimp(cos(x)/sin(x)) == 1/tan(x) + assert trigsimp(cos(x)**2/sin(x)**2) == 1/tan(x)**2 + assert trigsimp(cos(x)**10/sin(x)**10) == 1/tan(x)**10 + + assert trigsimp(tan(x)) == trigsimp(sin(x)/cos(x)) + + +def test_issue_4661(): + a, x, y = symbols('a x y') + eq = -4*sin(x)**4 + 4*cos(x)**4 - 8*cos(x)**2 + assert trigsimp(eq) == -4 + n = sin(x)**6 + 4*sin(x)**4*cos(x)**2 + 5*sin(x)**2*cos(x)**4 + 2*cos(x)**6 + d = -sin(x)**2 - 2*cos(x)**2 + assert simplify(n/d) == -1 + assert trigsimp(-2*cos(x)**2 + cos(x)**4 - sin(x)**4) == -1 + eq = (- sin(x)**3/4)*cos(x) + (cos(x)**3/4)*sin(x) - sin(2*x)*cos(2*x)/8 + assert trigsimp(eq) == 0 + + +def test_issue_4494(): + a, b = symbols('a b') + eq = sin(a)**2*sin(b)**2 + cos(a)**2*cos(b)**2*tan(a)**2 + cos(a)**2 + assert trigsimp(eq) == 1 + + +def test_issue_5948(): + a, x, y = symbols('a x y') + assert trigsimp(diff(integrate(cos(x)/sin(x)**7, x), x)) == \ + cos(x)/sin(x)**7 + + +def test_issue_4775(): + a, x, y = symbols('a x y') + assert trigsimp(sin(x)*cos(y)+cos(x)*sin(y)) == sin(x + y) + assert trigsimp(sin(x)*cos(y)+cos(x)*sin(y)+3) == sin(x + y) + 3 + + +def test_issue_4280(): + a, x, y = symbols('a x y') + assert trigsimp(cos(x)**2 + cos(y)**2*sin(x)**2 + sin(y)**2*sin(x)**2) == 1 + assert trigsimp(a**2*sin(x)**2 + a**2*cos(y)**2*cos(x)**2 + a**2*cos(x)**2*sin(y)**2) == a**2 + assert trigsimp(a**2*cos(y)**2*sin(x)**2 + a**2*sin(y)**2*sin(x)**2) == a**2*sin(x)**2 + + +def test_issue_3210(): + eqs = (sin(2)*cos(3) + sin(3)*cos(2), + -sin(2)*sin(3) + cos(2)*cos(3), + sin(2)*cos(3) - sin(3)*cos(2), + sin(2)*sin(3) + cos(2)*cos(3), + sin(2)*sin(3) + cos(2)*cos(3) + cos(2), + sinh(2)*cosh(3) + sinh(3)*cosh(2), + sinh(2)*sinh(3) + cosh(2)*cosh(3), + ) + assert [trigsimp(e) for e in eqs] == [ + sin(5), + cos(5), + -sin(1), + cos(1), + cos(1) + cos(2), + sinh(5), + cosh(5), + ] + + +def test_trigsimp_issues(): + a, x, y = symbols('a x y') + + # issue 4625 - factor_terms works, too + assert trigsimp(sin(x)**3 + cos(x)**2*sin(x)) == sin(x) + + # issue 5948 + assert trigsimp(diff(integrate(cos(x)/sin(x)**3, x), x)) == \ + cos(x)/sin(x)**3 + assert trigsimp(diff(integrate(sin(x)/cos(x)**3, x), x)) == \ + sin(x)/cos(x)**3 + + # check integer exponents + e = sin(x)**y/cos(x)**y + assert trigsimp(e) == e + assert trigsimp(e.subs(y, 2)) == tan(x)**2 + assert trigsimp(e.subs(x, 1)) == tan(1)**y + + # check for multiple patterns + assert (cos(x)**2/sin(x)**2*cos(y)**2/sin(y)**2).trigsimp() == \ + 1/tan(x)**2/tan(y)**2 + assert trigsimp(cos(x)/sin(x)*cos(x+y)/sin(x+y)) == \ + 1/(tan(x)*tan(x + y)) + + eq = cos(2)*(cos(3) + 1)**2/(cos(3) - 1)**2 + assert trigsimp(eq) == eq.factor() # factor makes denom (-1 + cos(3))**2 + assert trigsimp(cos(2)*(cos(3) + 1)**2*(cos(3) - 1)**2) == \ + cos(2)*sin(3)**4 + + # issue 6789; this generates an expression that formerly caused + # trigsimp to hang + assert cot(x).equals(tan(x)) is False + + # nan or the unchanged expression is ok, but not sin(1) + z = cos(x)**2 + sin(x)**2 - 1 + z1 = tan(x)**2 - 1/cot(x)**2 + n = (1 + z1/z) + assert trigsimp(sin(n)) != sin(1) + eq = x*(n - 1) - x*n + assert trigsimp(eq) is S.NaN + assert trigsimp(eq, recursive=True) is S.NaN + assert trigsimp(1).is_Integer + + assert trigsimp(-sin(x)**4 - 2*sin(x)**2*cos(x)**2 - cos(x)**4) == -1 + + +def test_trigsimp_issue_2515(): + x = Symbol('x') + assert trigsimp(x*cos(x)*tan(x)) == x*sin(x) + assert trigsimp(-sin(x) + cos(x)*tan(x)) == 0 + + +def test_trigsimp_issue_3826(): + assert trigsimp(tan(2*x).expand(trig=True)) == tan(2*x) + + +def test_trigsimp_issue_4032(): + n = Symbol('n', integer=True, positive=True) + assert trigsimp(2**(n/2)*cos(pi*n/4)/2 + 2**(n - 1)/2) == \ + 2**(n/2)*cos(pi*n/4)/2 + 2**n/4 + + +def test_trigsimp_issue_7761(): + assert trigsimp(cosh(pi/4)) == cosh(pi/4) + + +def test_trigsimp_noncommutative(): + x, y = symbols('x,y') + A, B = symbols('A,B', commutative=False) + + assert trigsimp(A - A*sin(x)**2) == A*cos(x)**2 + assert trigsimp(A - A*cos(x)**2) == A*sin(x)**2 + assert trigsimp(A*sin(x)**2 + A*cos(x)**2) == A + assert trigsimp(A + A*tan(x)**2) == A/cos(x)**2 + assert trigsimp(A/cos(x)**2 - A) == A*tan(x)**2 + assert trigsimp(A/cos(x)**2 - A*tan(x)**2) == A + assert trigsimp(A + A*cot(x)**2) == A/sin(x)**2 + assert trigsimp(A/sin(x)**2 - A) == A/tan(x)**2 + assert trigsimp(A/sin(x)**2 - A*cot(x)**2) == A + + assert trigsimp(y*A*cos(x)**2 + y*A*sin(x)**2) == y*A + + assert trigsimp(A*sin(x)/cos(x)) == A*tan(x) + assert trigsimp(A*tan(x)*cos(x)) == A*sin(x) + assert trigsimp(A*cot(x)**3*sin(x)**3) == A*cos(x)**3 + assert trigsimp(y*A*tan(x)**2/sin(x)**2) == y*A/cos(x)**2 + assert trigsimp(A*cot(x)/cos(x)) == A/sin(x) + + assert trigsimp(A*sin(x + y) + A*sin(x - y)) == 2*A*sin(x)*cos(y) + assert trigsimp(A*sin(x + y) - A*sin(x - y)) == 2*A*sin(y)*cos(x) + assert trigsimp(A*cos(x + y) + A*cos(x - y)) == 2*A*cos(x)*cos(y) + assert trigsimp(A*cos(x + y) - A*cos(x - y)) == -2*A*sin(x)*sin(y) + + assert trigsimp(A*sinh(x + y) + A*sinh(x - y)) == 2*A*sinh(x)*cosh(y) + assert trigsimp(A*sinh(x + y) - A*sinh(x - y)) == 2*A*sinh(y)*cosh(x) + assert trigsimp(A*cosh(x + y) + A*cosh(x - y)) == 2*A*cosh(x)*cosh(y) + assert trigsimp(A*cosh(x + y) - A*cosh(x - y)) == 2*A*sinh(x)*sinh(y) + + assert trigsimp(A*cos(0.12345)**2 + A*sin(0.12345)**2) == 1.0*A + + +def test_hyperbolic_simp(): + x, y = symbols('x,y') + + assert trigsimp(sinh(x)**2 + 1) == cosh(x)**2 + assert trigsimp(cosh(x)**2 - 1) == sinh(x)**2 + assert trigsimp(cosh(x)**2 - sinh(x)**2) == 1 + assert trigsimp(1 - tanh(x)**2) == 1/cosh(x)**2 + assert trigsimp(1 - 1/cosh(x)**2) == tanh(x)**2 + assert trigsimp(tanh(x)**2 + 1/cosh(x)**2) == 1 + assert trigsimp(coth(x)**2 - 1) == 1/sinh(x)**2 + assert trigsimp(1/sinh(x)**2 + 1) == 1/tanh(x)**2 + assert trigsimp(coth(x)**2 - 1/sinh(x)**2) == 1 + + assert trigsimp(5*cosh(x)**2 - 5*sinh(x)**2) == 5 + assert trigsimp(5*cosh(x/2)**2 - 2*sinh(x/2)**2) == 3*cosh(x)/2 + Rational(7, 2) + + assert trigsimp(sinh(x)/cosh(x)) == tanh(x) + assert trigsimp(tanh(x)) == trigsimp(sinh(x)/cosh(x)) + assert trigsimp(cosh(x)/sinh(x)) == 1/tanh(x) + assert trigsimp(2*tanh(x)*cosh(x)) == 2*sinh(x) + assert trigsimp(coth(x)**3*sinh(x)**3) == cosh(x)**3 + assert trigsimp(y*tanh(x)**2/sinh(x)**2) == y/cosh(x)**2 + assert trigsimp(coth(x)/cosh(x)) == 1/sinh(x) + + for a in (pi/6*I, pi/4*I, pi/3*I): + assert trigsimp(sinh(a)*cosh(x) + cosh(a)*sinh(x)) == sinh(x + a) + assert trigsimp(-sinh(a)*cosh(x) + cosh(a)*sinh(x)) == sinh(x - a) + + e = 2*cosh(x)**2 - 2*sinh(x)**2 + assert trigsimp(log(e)) == log(2) + + # issue 19535: + assert trigsimp(sqrt(cosh(x)**2 - 1)) == sqrt(sinh(x)**2) + + assert trigsimp(cosh(x)**2*cosh(y)**2 - cosh(x)**2*sinh(y)**2 - sinh(x)**2, + recursive=True) == 1 + assert trigsimp(sinh(x)**2*sinh(y)**2 - sinh(x)**2*cosh(y)**2 + cosh(x)**2, + recursive=True) == 1 + + assert abs(trigsimp(2.0*cosh(x)**2 - 2.0*sinh(x)**2) - 2.0) < 1e-10 + + assert trigsimp(sinh(x)**2/cosh(x)**2) == tanh(x)**2 + assert trigsimp(sinh(x)**3/cosh(x)**3) == tanh(x)**3 + assert trigsimp(sinh(x)**10/cosh(x)**10) == tanh(x)**10 + assert trigsimp(cosh(x)**3/sinh(x)**3) == 1/tanh(x)**3 + + assert trigsimp(cosh(x)/sinh(x)) == 1/tanh(x) + assert trigsimp(cosh(x)**2/sinh(x)**2) == 1/tanh(x)**2 + assert trigsimp(cosh(x)**10/sinh(x)**10) == 1/tanh(x)**10 + + assert trigsimp(x*cosh(x)*tanh(x)) == x*sinh(x) + assert trigsimp(-sinh(x) + cosh(x)*tanh(x)) == 0 + + assert tan(x) != 1/cot(x) # cot doesn't auto-simplify + + assert trigsimp(tan(x) - 1/cot(x)) == 0 + assert trigsimp(3*tanh(x)**7 - 2/coth(x)**7) == tanh(x)**7 + + +def test_trigsimp_groebner(): + from sympy.simplify.trigsimp import trigsimp_groebner + + c = cos(x) + s = sin(x) + ex = (4*s*c + 12*s + 5*c**3 + 21*c**2 + 23*c + 15)/( + -s*c**2 + 2*s*c + 15*s + 7*c**3 + 31*c**2 + 37*c + 21) + resnum = (5*s - 5*c + 1) + resdenom = (8*s - 6*c) + results = [resnum/resdenom, (-resnum)/(-resdenom)] + assert trigsimp_groebner(ex) in results + assert trigsimp_groebner(s/c, hints=[tan]) == tan(x) + assert trigsimp_groebner(c*s) == c*s + assert trigsimp((-s + 1)/c + c/(-s + 1), + method='groebner') == 2/c + assert trigsimp((-s + 1)/c + c/(-s + 1), + method='groebner', polynomial=True) == 2/c + + # Test quick=False works + assert trigsimp_groebner(ex, hints=[2]) in results + assert trigsimp_groebner(ex, hints=[int(2)]) in results + + # test "I" + assert trigsimp_groebner(sin(I*x)/cos(I*x), hints=[tanh]) == I*tanh(x) + + # test hyperbolic / sums + assert trigsimp_groebner((tanh(x)+tanh(y))/(1+tanh(x)*tanh(y)), + hints=[(tanh, x, y)]) == tanh(x + y) + + +def test_issue_2827_trigsimp_methods(): + measure1 = lambda expr: len(str(expr)) + measure2 = lambda expr: -count_ops(expr) + # Return the most complicated result + expr = (x + 1)/(x + sin(x)**2 + cos(x)**2) + ans = Matrix([1]) + M = Matrix([expr]) + assert trigsimp(M, method='fu', measure=measure1) == ans + assert trigsimp(M, method='fu', measure=measure2) != ans + # all methods should work with Basic expressions even if they + # aren't Expr + M = Matrix.eye(1) + assert all(trigsimp(M, method=m) == M for m in + 'fu matching groebner old'.split()) + # watch for E in exptrigsimp, not only exp() + eq = 1/sqrt(E) + E + assert exptrigsimp(eq) == eq + +def test_issue_15129_trigsimp_methods(): + t1 = Matrix([sin(Rational(1, 50)), cos(Rational(1, 50)), 0]) + t2 = Matrix([sin(Rational(1, 25)), cos(Rational(1, 25)), 0]) + t3 = Matrix([cos(Rational(1, 25)), sin(Rational(1, 25)), 0]) + r1 = t1.dot(t2) + r2 = t1.dot(t3) + assert trigsimp(r1) == cos(Rational(1, 50)) + assert trigsimp(r2) == sin(Rational(3, 50)) + +def test_exptrigsimp(): + def valid(a, b): + from sympy.core.random import verify_numerically as tn + if not (tn(a, b) and a == b): + return False + return True + + assert exptrigsimp(exp(x) + exp(-x)) == 2*cosh(x) + assert exptrigsimp(exp(x) - exp(-x)) == 2*sinh(x) + assert exptrigsimp((2*exp(x)-2*exp(-x))/(exp(x)+exp(-x))) == 2*tanh(x) + assert exptrigsimp((2*exp(2*x)-2)/(exp(2*x)+1)) == 2*tanh(x) + e = [cos(x) + I*sin(x), cos(x) - I*sin(x), + cosh(x) - sinh(x), cosh(x) + sinh(x)] + ok = [exp(I*x), exp(-I*x), exp(-x), exp(x)] + assert all(valid(i, j) for i, j in zip( + [exptrigsimp(ei) for ei in e], ok)) + + ue = [cos(x) + sin(x), cos(x) - sin(x), + cosh(x) + I*sinh(x), cosh(x) - I*sinh(x)] + assert [exptrigsimp(ei) == ei for ei in ue] + + res = [] + ok = [y*tanh(1), 1/(y*tanh(1)), I*y*tan(1), -I/(y*tan(1)), + y*tanh(x), 1/(y*tanh(x)), I*y*tan(x), -I/(y*tan(x)), + y*tanh(1 + I), 1/(y*tanh(1 + I))] + for a in (1, I, x, I*x, 1 + I): + w = exp(a) + eq = y*(w - 1/w)/(w + 1/w) + res.append(simplify(eq)) + res.append(simplify(1/eq)) + assert all(valid(i, j) for i, j in zip(res, ok)) + + for a in range(1, 3): + w = exp(a) + e = w + 1/w + s = simplify(e) + assert s == exptrigsimp(e) + assert valid(s, 2*cosh(a)) + e = w - 1/w + s = simplify(e) + assert s == exptrigsimp(e) + assert valid(s, 2*sinh(a)) + +def test_exptrigsimp_noncommutative(): + a,b = symbols('a b', commutative=False) + x = Symbol('x', commutative=True) + assert exp(a + x) == exptrigsimp(exp(a)*exp(x)) + p = exp(a)*exp(b) - exp(b)*exp(a) + assert p == exptrigsimp(p) != 0 + +def test_powsimp_on_numbers(): + assert 2**(Rational(1, 3) - 2) == 2**Rational(1, 3)/4 + + +@XFAIL +def test_issue_6811_fail(): + # from doc/src/modules/physics/mechanics/examples.rst, the current `eq` + # at Line 576 (in different variables) was formerly the equivalent and + # shorter expression given below...it would be nice to get the short one + # back again + xp, y, x, z = symbols('xp, y, x, z') + eq = 4*(-19*sin(x)*y + 5*sin(3*x)*y + 15*cos(2*x)*z - 21*z)*xp/(9*cos(x) - 5*cos(3*x)) + assert trigsimp(eq) == -2*(2*cos(x)*tan(x)*y + 3*z)*xp/cos(x) + + +def test_Piecewise(): + e1 = x*(x + y) - y*(x + y) + e2 = sin(x)**2 + cos(x)**2 + e3 = expand((x + y)*y/x) + # s1 = simplify(e1) + s2 = simplify(e2) + # s3 = simplify(e3) + + # trigsimp tries not to touch non-trig containing args + assert trigsimp(Piecewise((e1, e3 < e2), (e3, True))) == \ + Piecewise((e1, e3 < s2), (e3, True)) + + +def test_issue_21594(): + assert simplify(exp(Rational(1,2)) + exp(Rational(-1,2))) == cosh(S.Half)*2 + + +def test_trigsimp_old(): + x, y = symbols('x,y') + + assert trigsimp(1 - sin(x)**2, old=True) == cos(x)**2 + assert trigsimp(1 - cos(x)**2, old=True) == sin(x)**2 + assert trigsimp(sin(x)**2 + cos(x)**2, old=True) == 1 + assert trigsimp(1 + tan(x)**2, old=True) == 1/cos(x)**2 + assert trigsimp(1/cos(x)**2 - 1, old=True) == tan(x)**2 + assert trigsimp(1/cos(x)**2 - tan(x)**2, old=True) == 1 + assert trigsimp(1 + cot(x)**2, old=True) == 1/sin(x)**2 + assert trigsimp(1/sin(x)**2 - cot(x)**2, old=True) == 1 + + assert trigsimp(5*cos(x)**2 + 5*sin(x)**2, old=True) == 5 + + assert trigsimp(sin(x)/cos(x), old=True) == tan(x) + assert trigsimp(2*tan(x)*cos(x), old=True) == 2*sin(x) + assert trigsimp(cot(x)**3*sin(x)**3, old=True) == cos(x)**3 + assert trigsimp(y*tan(x)**2/sin(x)**2, old=True) == y/cos(x)**2 + assert trigsimp(cot(x)/cos(x), old=True) == 1/sin(x) + + assert trigsimp(sin(x + y) + sin(x - y), old=True) == 2*sin(x)*cos(y) + assert trigsimp(sin(x + y) - sin(x - y), old=True) == 2*sin(y)*cos(x) + assert trigsimp(cos(x + y) + cos(x - y), old=True) == 2*cos(x)*cos(y) + assert trigsimp(cos(x + y) - cos(x - y), old=True) == -2*sin(x)*sin(y) + + assert trigsimp(sinh(x + y) + sinh(x - y), old=True) == 2*sinh(x)*cosh(y) + assert trigsimp(sinh(x + y) - sinh(x - y), old=True) == 2*sinh(y)*cosh(x) + assert trigsimp(cosh(x + y) + cosh(x - y), old=True) == 2*cosh(x)*cosh(y) + assert trigsimp(cosh(x + y) - cosh(x - y), old=True) == 2*sinh(x)*sinh(y) + + assert trigsimp(cos(0.12345)**2 + sin(0.12345)**2, old=True) == 1.0 + + assert trigsimp(sin(x)/cos(x), old=True, method='combined') == tan(x) + assert trigsimp(sin(x)/cos(x), old=True, method='groebner') == sin(x)/cos(x) + assert trigsimp(sin(x)/cos(x), old=True, method='groebner', hints=[tan]) == tan(x) + + assert trigsimp(1-sin(sin(x)**2+cos(x)**2)**2, old=True, deep=True) == cos(1)**2 + + +def test_trigsimp_inverse(): + alpha = symbols('alpha') + s, c = sin(alpha), cos(alpha) + + for finv in [asin, acos, asec, acsc, atan, acot]: + f = finv.inverse(None) + assert alpha == trigsimp(finv(f(alpha)), inverse=True) + + # test atan2(cos, sin), atan2(sin, cos), etc... + for a, b in [[c, s], [s, c]]: + for i, j in product([-1, 1], repeat=2): + angle = atan2(i*b, j*a) + angle_inverted = trigsimp(angle, inverse=True) + assert angle_inverted != angle # assures simplification happened + assert sin(angle_inverted) == trigsimp(sin(angle)) + assert cos(angle_inverted) == trigsimp(cos(angle)) diff --git a/lib/python3.10/site-packages/sympy/solvers/__pycache__/__init__.cpython-310.pyc b/lib/python3.10/site-packages/sympy/solvers/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..448a0aae3a97802a49f3cc62ea77fe85c5840992 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/solvers/__pycache__/__init__.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/solvers/__pycache__/bivariate.cpython-310.pyc b/lib/python3.10/site-packages/sympy/solvers/__pycache__/bivariate.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..671362b3fdaac94b42a7a28d2a669749976b3960 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/solvers/__pycache__/bivariate.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/solvers/__pycache__/decompogen.cpython-310.pyc b/lib/python3.10/site-packages/sympy/solvers/__pycache__/decompogen.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d1b2a37b7ef09bb9e5b6f7ebf468123c3ce45861 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/solvers/__pycache__/decompogen.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/solvers/__pycache__/deutils.cpython-310.pyc b/lib/python3.10/site-packages/sympy/solvers/__pycache__/deutils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3d55222c1c7f643679bf05f8eeb7dfdd004fc9fa Binary files /dev/null and b/lib/python3.10/site-packages/sympy/solvers/__pycache__/deutils.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/solvers/__pycache__/inequalities.cpython-310.pyc b/lib/python3.10/site-packages/sympy/solvers/__pycache__/inequalities.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4f4baf18b60bd7b8bd4e78006afa3fb9bd17161c Binary files /dev/null and b/lib/python3.10/site-packages/sympy/solvers/__pycache__/inequalities.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/solvers/__pycache__/pde.cpython-310.pyc b/lib/python3.10/site-packages/sympy/solvers/__pycache__/pde.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..92f38b2ae15ed1cf0263e99a90fef72812d05922 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/solvers/__pycache__/pde.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/solvers/__pycache__/polysys.cpython-310.pyc b/lib/python3.10/site-packages/sympy/solvers/__pycache__/polysys.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..29ca79d3eac2bfc78f671b133fa9f21598d90275 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/solvers/__pycache__/polysys.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/solvers/__pycache__/recurr.cpython-310.pyc b/lib/python3.10/site-packages/sympy/solvers/__pycache__/recurr.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8ae6679865fd8f1d44a5f401384f254a39fcc58e Binary files /dev/null and b/lib/python3.10/site-packages/sympy/solvers/__pycache__/recurr.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/solvers/__pycache__/simplex.cpython-310.pyc b/lib/python3.10/site-packages/sympy/solvers/__pycache__/simplex.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..277e8f863cb5f59ff4973a2d917b20aec1dff4af Binary files /dev/null and b/lib/python3.10/site-packages/sympy/solvers/__pycache__/simplex.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/solvers/benchmarks/__init__.py b/lib/python3.10/site-packages/sympy/solvers/benchmarks/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/lib/python3.10/site-packages/sympy/solvers/benchmarks/__pycache__/__init__.cpython-310.pyc b/lib/python3.10/site-packages/sympy/solvers/benchmarks/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b4cbe1f49aed9d20a553cbd8c43cf8ab0ba2f3d0 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/solvers/benchmarks/__pycache__/__init__.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/solvers/benchmarks/__pycache__/bench_solvers.cpython-310.pyc b/lib/python3.10/site-packages/sympy/solvers/benchmarks/__pycache__/bench_solvers.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a81516fad3c6d540037576ee9fb9723d72d76be2 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/solvers/benchmarks/__pycache__/bench_solvers.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/solvers/benchmarks/bench_solvers.py b/lib/python3.10/site-packages/sympy/solvers/benchmarks/bench_solvers.py new file mode 100644 index 0000000000000000000000000000000000000000..d18102873f7efcde1d111e0e8eca12e208f94663 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/solvers/benchmarks/bench_solvers.py @@ -0,0 +1,12 @@ +from sympy.core.symbol import Symbol +from sympy.matrices.dense import (eye, zeros) +from sympy.solvers.solvers import solve_linear_system + +N = 8 +M = zeros(N, N + 1) +M[:, :N] = eye(N) +S = [Symbol('A%i' % i) for i in range(N)] + + +def timeit_linsolve_trivial(): + solve_linear_system(M, *S) diff --git a/lib/python3.10/site-packages/sympy/solvers/diophantine/__init__.py b/lib/python3.10/site-packages/sympy/solvers/diophantine/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..23c21242208d6f520c130250ecdce43382b9d868 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/solvers/diophantine/__init__.py @@ -0,0 +1,5 @@ +from .diophantine import diophantine, classify_diop, diop_solve + +__all__ = [ + 'diophantine', 'classify_diop', 'diop_solve' +] diff --git a/lib/python3.10/site-packages/sympy/solvers/diophantine/__pycache__/__init__.cpython-310.pyc b/lib/python3.10/site-packages/sympy/solvers/diophantine/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8899d42d255e7257a0662c8822f91332e798b2f2 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/solvers/diophantine/__pycache__/__init__.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/solvers/diophantine/diophantine.py b/lib/python3.10/site-packages/sympy/solvers/diophantine/diophantine.py new file mode 100644 index 0000000000000000000000000000000000000000..3df4fe9b0df137828233a9243d2e1e604af309fd --- /dev/null +++ b/lib/python3.10/site-packages/sympy/solvers/diophantine/diophantine.py @@ -0,0 +1,3960 @@ +from __future__ import annotations + +from sympy.core.add import Add +from sympy.core.assumptions import check_assumptions +from sympy.core.containers import Tuple +from sympy.core.exprtools import factor_terms +from sympy.core.function import _mexpand +from sympy.core.mul import Mul +from sympy.core.numbers import Rational, int_valued +from sympy.core.intfunc import igcdex, ilcm, igcd, integer_nthroot, isqrt +from sympy.core.relational import Eq +from sympy.core.singleton import S +from sympy.core.sorting import default_sort_key, ordered +from sympy.core.symbol import Symbol, symbols +from sympy.core.sympify import _sympify +from sympy.external.gmpy import jacobi, remove, invert, iroot +from sympy.functions.elementary.complexes import sign +from sympy.functions.elementary.integers import floor +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.matrices.dense import MutableDenseMatrix as Matrix +from sympy.ntheory.factor_ import divisors, factorint, perfect_power +from sympy.ntheory.generate import nextprime +from sympy.ntheory.primetest import is_square, isprime +from sympy.ntheory.modular import symmetric_residue +from sympy.ntheory.residue_ntheory import sqrt_mod, sqrt_mod_iter +from sympy.polys.polyerrors import GeneratorsNeeded +from sympy.polys.polytools import Poly, factor_list +from sympy.simplify.simplify import signsimp +from sympy.solvers.solveset import solveset_real +from sympy.utilities import numbered_symbols +from sympy.utilities.misc import as_int, filldedent +from sympy.utilities.iterables import (is_sequence, subsets, permute_signs, + signed_permutations, ordered_partitions) + + +# these are imported with 'from sympy.solvers.diophantine import * +__all__ = ['diophantine', 'classify_diop'] + + +class DiophantineSolutionSet(set): + """ + Container for a set of solutions to a particular diophantine equation. + + The base representation is a set of tuples representing each of the solutions. + + Parameters + ========== + + symbols : list + List of free symbols in the original equation. + parameters: list + List of parameters to be used in the solution. + + Examples + ======== + + Adding solutions: + + >>> from sympy.solvers.diophantine.diophantine import DiophantineSolutionSet + >>> from sympy.abc import x, y, t, u + >>> s1 = DiophantineSolutionSet([x, y], [t, u]) + >>> s1 + set() + >>> s1.add((2, 3)) + >>> s1.add((-1, u)) + >>> s1 + {(-1, u), (2, 3)} + >>> s2 = DiophantineSolutionSet([x, y], [t, u]) + >>> s2.add((3, 4)) + >>> s1.update(*s2) + >>> s1 + {(-1, u), (2, 3), (3, 4)} + + Conversion of solutions into dicts: + + >>> list(s1.dict_iterator()) + [{x: -1, y: u}, {x: 2, y: 3}, {x: 3, y: 4}] + + Substituting values: + + >>> s3 = DiophantineSolutionSet([x, y], [t, u]) + >>> s3.add((t**2, t + u)) + >>> s3 + {(t**2, t + u)} + >>> s3.subs({t: 2, u: 3}) + {(4, 5)} + >>> s3.subs(t, -1) + {(1, u - 1)} + >>> s3.subs(t, 3) + {(9, u + 3)} + + Evaluation at specific values. Positional arguments are given in the same order as the parameters: + + >>> s3(-2, 3) + {(4, 1)} + >>> s3(5) + {(25, u + 5)} + >>> s3(None, 2) + {(t**2, t + 2)} + """ + + def __init__(self, symbols_seq, parameters): + super().__init__() + + if not is_sequence(symbols_seq): + raise ValueError("Symbols must be given as a sequence.") + + if not is_sequence(parameters): + raise ValueError("Parameters must be given as a sequence.") + + self.symbols = tuple(symbols_seq) + self.parameters = tuple(parameters) + + def add(self, solution): + if len(solution) != len(self.symbols): + raise ValueError("Solution should have a length of %s, not %s" % (len(self.symbols), len(solution))) + # make solution canonical wrt sign (i.e. no -x unless x is also present as an arg) + args = set(solution) + for i in range(len(solution)): + x = solution[i] + if not type(x) is int and (-x).is_Symbol and -x not in args: + solution = [_.subs(-x, x) for _ in solution] + super().add(Tuple(*solution)) + + def update(self, *solutions): + for solution in solutions: + self.add(solution) + + def dict_iterator(self): + for solution in ordered(self): + yield dict(zip(self.symbols, solution)) + + def subs(self, *args, **kwargs): + result = DiophantineSolutionSet(self.symbols, self.parameters) + for solution in self: + result.add(solution.subs(*args, **kwargs)) + return result + + def __call__(self, *args): + if len(args) > len(self.parameters): + raise ValueError("Evaluation should have at most %s values, not %s" % (len(self.parameters), len(args))) + rep = {p: v for p, v in zip(self.parameters, args) if v is not None} + return self.subs(rep) + + +class DiophantineEquationType: + """ + Internal representation of a particular diophantine equation type. + + Parameters + ========== + + equation : + The diophantine equation that is being solved. + free_symbols : list (optional) + The symbols being solved for. + + Attributes + ========== + + total_degree : + The maximum of the degrees of all terms in the equation + homogeneous : + Does the equation contain a term of degree 0 + homogeneous_order : + Does the equation contain any coefficient that is in the symbols being solved for + dimension : + The number of symbols being solved for + """ + name = None # type: str + + def __init__(self, equation, free_symbols=None): + self.equation = _sympify(equation).expand(force=True) + + if free_symbols is not None: + self.free_symbols = free_symbols + else: + self.free_symbols = list(self.equation.free_symbols) + self.free_symbols.sort(key=default_sort_key) + + if not self.free_symbols: + raise ValueError('equation should have 1 or more free symbols') + + self.coeff = self.equation.as_coefficients_dict() + if not all(int_valued(c) for c in self.coeff.values()): + raise TypeError("Coefficients should be Integers") + + self.total_degree = Poly(self.equation).total_degree() + self.homogeneous = 1 not in self.coeff + self.homogeneous_order = not (set(self.coeff) & set(self.free_symbols)) + self.dimension = len(self.free_symbols) + self._parameters = None + + def matches(self): + """ + Determine whether the given equation can be matched to the particular equation type. + """ + return False + + @property + def n_parameters(self): + return self.dimension + + @property + def parameters(self): + if self._parameters is None: + self._parameters = symbols('t_:%i' % (self.n_parameters,), integer=True) + return self._parameters + + def solve(self, parameters=None, limit=None) -> DiophantineSolutionSet: + raise NotImplementedError('No solver has been written for %s.' % self.name) + + def pre_solve(self, parameters=None): + if not self.matches(): + raise ValueError("This equation does not match the %s equation type." % self.name) + + if parameters is not None: + if len(parameters) != self.n_parameters: + raise ValueError("Expected %s parameter(s) but got %s" % (self.n_parameters, len(parameters))) + + self._parameters = parameters + + +class Univariate(DiophantineEquationType): + """ + Representation of a univariate diophantine equation. + + A univariate diophantine equation is an equation of the form + `a_{0} + a_{1}x + a_{2}x^2 + .. + a_{n}x^n = 0` where `a_{1}, a_{2}, ..a_{n}` are + integer constants and `x` is an integer variable. + + Examples + ======== + + >>> from sympy.solvers.diophantine.diophantine import Univariate + >>> from sympy.abc import x + >>> Univariate((x - 2)*(x - 3)**2).solve() # solves equation (x - 2)*(x - 3)**2 == 0 + {(2,), (3,)} + + """ + + name = 'univariate' + + def matches(self): + return self.dimension == 1 + + def solve(self, parameters=None, limit=None): + self.pre_solve(parameters) + + result = DiophantineSolutionSet(self.free_symbols, parameters=self.parameters) + for i in solveset_real(self.equation, self.free_symbols[0]).intersect(S.Integers): + result.add((i,)) + return result + + +class Linear(DiophantineEquationType): + """ + Representation of a linear diophantine equation. + + A linear diophantine equation is an equation of the form `a_{1}x_{1} + + a_{2}x_{2} + .. + a_{n}x_{n} = 0` where `a_{1}, a_{2}, ..a_{n}` are + integer constants and `x_{1}, x_{2}, ..x_{n}` are integer variables. + + Examples + ======== + + >>> from sympy.solvers.diophantine.diophantine import Linear + >>> from sympy.abc import x, y, z + >>> l1 = Linear(2*x - 3*y - 5) + >>> l1.matches() # is this equation linear + True + >>> l1.solve() # solves equation 2*x - 3*y - 5 == 0 + {(3*t_0 - 5, 2*t_0 - 5)} + + Here x = -3*t_0 - 5 and y = -2*t_0 - 5 + + >>> Linear(2*x - 3*y - 4*z -3).solve() + {(t_0, 2*t_0 + 4*t_1 + 3, -t_0 - 3*t_1 - 3)} + + """ + + name = 'linear' + + def matches(self): + return self.total_degree == 1 + + def solve(self, parameters=None, limit=None): + self.pre_solve(parameters) + + coeff = self.coeff + var = self.free_symbols + + if 1 in coeff: + # negate coeff[] because input is of the form: ax + by + c == 0 + # but is used as: ax + by == -c + c = -coeff[1] + else: + c = 0 + + result = DiophantineSolutionSet(var, parameters=self.parameters) + params = result.parameters + + if len(var) == 1: + q, r = divmod(c, coeff[var[0]]) + if not r: + result.add((q,)) + return result + + ''' + base_solution_linear() can solve diophantine equations of the form: + + a*x + b*y == c + + We break down multivariate linear diophantine equations into a + series of bivariate linear diophantine equations which can then + be solved individually by base_solution_linear(). + + Consider the following: + + a_0*x_0 + a_1*x_1 + a_2*x_2 == c + + which can be re-written as: + + a_0*x_0 + g_0*y_0 == c + + where + + g_0 == gcd(a_1, a_2) + + and + + y == (a_1*x_1)/g_0 + (a_2*x_2)/g_0 + + This leaves us with two binary linear diophantine equations. + For the first equation: + + a == a_0 + b == g_0 + c == c + + For the second: + + a == a_1/g_0 + b == a_2/g_0 + c == the solution we find for y_0 in the first equation. + + The arrays A and B are the arrays of integers used for + 'a' and 'b' in each of the n-1 bivariate equations we solve. + ''' + + A = [coeff[v] for v in var] + B = [] + if len(var) > 2: + B.append(igcd(A[-2], A[-1])) + A[-2] = A[-2] // B[0] + A[-1] = A[-1] // B[0] + for i in range(len(A) - 3, 0, -1): + gcd = igcd(B[0], A[i]) + B[0] = B[0] // gcd + A[i] = A[i] // gcd + B.insert(0, gcd) + B.append(A[-1]) + + ''' + Consider the trivariate linear equation: + + 4*x_0 + 6*x_1 + 3*x_2 == 2 + + This can be re-written as: + + 4*x_0 + 3*y_0 == 2 + + where + + y_0 == 2*x_1 + x_2 + (Note that gcd(3, 6) == 3) + + The complete integral solution to this equation is: + + x_0 == 2 + 3*t_0 + y_0 == -2 - 4*t_0 + + where 't_0' is any integer. + + Now that we have a solution for 'x_0', find 'x_1' and 'x_2': + + 2*x_1 + x_2 == -2 - 4*t_0 + + We can then solve for '-2' and '-4' independently, + and combine the results: + + 2*x_1a + x_2a == -2 + x_1a == 0 + t_0 + x_2a == -2 - 2*t_0 + + 2*x_1b + x_2b == -4*t_0 + x_1b == 0*t_0 + t_1 + x_2b == -4*t_0 - 2*t_1 + + ==> + + x_1 == t_0 + t_1 + x_2 == -2 - 6*t_0 - 2*t_1 + + where 't_0' and 't_1' are any integers. + + Note that: + + 4*(2 + 3*t_0) + 6*(t_0 + t_1) + 3*(-2 - 6*t_0 - 2*t_1) == 2 + + for any integral values of 't_0', 't_1'; as required. + + This method is generalised for many variables, below. + + ''' + solutions = [] + for Ai, Bi in zip(A, B): + tot_x, tot_y = [], [] + + for arg in Add.make_args(c): + if arg.is_Integer: + # example: 5 -> k = 5 + k, p = arg, S.One + pnew = params[0] + else: # arg is a Mul or Symbol + # example: 3*t_1 -> k = 3 + # example: t_0 -> k = 1 + k, p = arg.as_coeff_Mul() + pnew = params[params.index(p) + 1] + + sol = sol_x, sol_y = base_solution_linear(k, Ai, Bi, pnew) + + if p is S.One: + if None in sol: + return result + else: + # convert a + b*pnew -> a*p + b*pnew + if isinstance(sol_x, Add): + sol_x = sol_x.args[0]*p + sol_x.args[1] + if isinstance(sol_y, Add): + sol_y = sol_y.args[0]*p + sol_y.args[1] + + tot_x.append(sol_x) + tot_y.append(sol_y) + + solutions.append(Add(*tot_x)) + c = Add(*tot_y) + + solutions.append(c) + result.add(solutions) + return result + + +class BinaryQuadratic(DiophantineEquationType): + """ + Representation of a binary quadratic diophantine equation. + + A binary quadratic diophantine equation is an equation of the + form `Ax^2 + Bxy + Cy^2 + Dx + Ey + F = 0`, where `A, B, C, D, E, + F` are integer constants and `x` and `y` are integer variables. + + Examples + ======== + + >>> from sympy.abc import x, y + >>> from sympy.solvers.diophantine.diophantine import BinaryQuadratic + >>> b1 = BinaryQuadratic(x**3 + y**2 + 1) + >>> b1.matches() + False + >>> b2 = BinaryQuadratic(x**2 + y**2 + 2*x + 2*y + 2) + >>> b2.matches() + True + >>> b2.solve() + {(-1, -1)} + + References + ========== + + .. [1] Methods to solve Ax^2 + Bxy + Cy^2 + Dx + Ey + F = 0, [online], + Available: https://www.alpertron.com.ar/METHODS.HTM + .. [2] Solving the equation ax^2+ bxy + cy^2 + dx + ey + f= 0, [online], + Available: https://web.archive.org/web/20160323033111/http://www.jpr2718.org/ax2p.pdf + + """ + + name = 'binary_quadratic' + + def matches(self): + return self.total_degree == 2 and self.dimension == 2 + + def solve(self, parameters=None, limit=None) -> DiophantineSolutionSet: + self.pre_solve(parameters) + + var = self.free_symbols + coeff = self.coeff + + x, y = var + + A = coeff[x**2] + B = coeff[x*y] + C = coeff[y**2] + D = coeff[x] + E = coeff[y] + F = coeff[S.One] + + A, B, C, D, E, F = [as_int(i) for i in _remove_gcd(A, B, C, D, E, F)] + + # (1) Simple-Hyperbolic case: A = C = 0, B != 0 + # In this case equation can be converted to (Bx + E)(By + D) = DE - BF + # We consider two cases; DE - BF = 0 and DE - BF != 0 + # More details, https://www.alpertron.com.ar/METHODS.HTM#SHyperb + + result = DiophantineSolutionSet(var, self.parameters) + t, u = result.parameters + + discr = B**2 - 4*A*C + if A == 0 and C == 0 and B != 0: + + if D*E - B*F == 0: + q, r = divmod(E, B) + if not r: + result.add((-q, t)) + q, r = divmod(D, B) + if not r: + result.add((t, -q)) + else: + div = divisors(D*E - B*F) + div = div + [-term for term in div] + for d in div: + x0, r = divmod(d - E, B) + if not r: + q, r = divmod(D*E - B*F, d) + if not r: + y0, r = divmod(q - D, B) + if not r: + result.add((x0, y0)) + + # (2) Parabolic case: B**2 - 4*A*C = 0 + # There are two subcases to be considered in this case. + # sqrt(c)D - sqrt(a)E = 0 and sqrt(c)D - sqrt(a)E != 0 + # More Details, https://www.alpertron.com.ar/METHODS.HTM#Parabol + + elif discr == 0: + + if A == 0: + s = BinaryQuadratic(self.equation, free_symbols=[y, x]).solve(parameters=[t, u]) + for soln in s: + result.add((soln[1], soln[0])) + + else: + g = sign(A)*igcd(A, C) + a = A // g + c = C // g + e = sign(B / A) + + sqa = isqrt(a) + sqc = isqrt(c) + _c = e*sqc*D - sqa*E + if not _c: + z = Symbol("z", real=True) + eq = sqa*g*z**2 + D*z + sqa*F + roots = solveset_real(eq, z).intersect(S.Integers) + for root in roots: + ans = diop_solve(sqa*x + e*sqc*y - root) + result.add((ans[0], ans[1])) + + elif int_valued(c): + solve_x = lambda u: -e*sqc*g*_c*t**2 - (E + 2*e*sqc*g*u)*t \ + - (e*sqc*g*u**2 + E*u + e*sqc*F) // _c + + solve_y = lambda u: sqa*g*_c*t**2 + (D + 2*sqa*g*u)*t \ + + (sqa*g*u**2 + D*u + sqa*F) // _c + + for z0 in range(0, abs(_c)): + # Check if the coefficients of y and x obtained are integers or not + if (divisible(sqa*g*z0**2 + D*z0 + sqa*F, _c) and + divisible(e*sqc*g*z0**2 + E*z0 + e*sqc*F, _c)): + result.add((solve_x(z0), solve_y(z0))) + + # (3) Method used when B**2 - 4*A*C is a square, is described in p. 6 of the below paper + # by John P. Robertson. + # https://web.archive.org/web/20160323033111/http://www.jpr2718.org/ax2p.pdf + + elif is_square(discr): + if A != 0: + r = sqrt(discr) + u, v = symbols("u, v", integer=True) + eq = _mexpand( + 4*A*r*u*v + 4*A*D*(B*v + r*u + r*v - B*u) + + 2*A*4*A*E*(u - v) + 4*A*r*4*A*F) + + solution = diop_solve(eq, t) + + for s0, t0 in solution: + + num = B*t0 + r*s0 + r*t0 - B*s0 + x_0 = S(num) / (4*A*r) + y_0 = S(s0 - t0) / (2*r) + if isinstance(s0, Symbol) or isinstance(t0, Symbol): + if len(check_param(x_0, y_0, 4*A*r, parameters)) > 0: + ans = check_param(x_0, y_0, 4*A*r, parameters) + result.update(*ans) + elif x_0.is_Integer and y_0.is_Integer: + if is_solution_quad(var, coeff, x_0, y_0): + result.add((x_0, y_0)) + + else: + s = BinaryQuadratic(self.equation, free_symbols=var[::-1]).solve(parameters=[t, u]) # Interchange x and y + while s: + result.add(s.pop()[::-1]) # and solution <--------+ + + # (4) B**2 - 4*A*C > 0 and B**2 - 4*A*C not a square or B**2 - 4*A*C < 0 + + else: + + P, Q = _transformation_to_DN(var, coeff) + D, N = _find_DN(var, coeff) + solns_pell = diop_DN(D, N) + + if D < 0: + for x0, y0 in solns_pell: + for x in [-x0, x0]: + for y in [-y0, y0]: + s = P*Matrix([x, y]) + Q + try: + result.add([as_int(_) for _ in s]) + except ValueError: + pass + else: + # In this case equation can be transformed into a Pell equation + + solns_pell = set(solns_pell) + solns_pell.update((-X, -Y) for X, Y in list(solns_pell)) + + a = diop_DN(D, 1) + T = a[0][0] + U = a[0][1] + + if all(int_valued(_) for _ in P[:4] + Q[:2]): + for r, s in solns_pell: + _a = (r + s*sqrt(D))*(T + U*sqrt(D))**t + _b = (r - s*sqrt(D))*(T - U*sqrt(D))**t + x_n = _mexpand(S(_a + _b) / 2) + y_n = _mexpand(S(_a - _b) / (2*sqrt(D))) + s = P*Matrix([x_n, y_n]) + Q + result.add(s) + + else: + L = ilcm(*[_.q for _ in P[:4] + Q[:2]]) + + k = 1 + + T_k = T + U_k = U + + while (T_k - 1) % L != 0 or U_k % L != 0: + T_k, U_k = T_k*T + D*U_k*U, T_k*U + U_k*T + k += 1 + + for X, Y in solns_pell: + + for i in range(k): + if all(int_valued(_) for _ in P*Matrix([X, Y]) + Q): + _a = (X + sqrt(D)*Y)*(T_k + sqrt(D)*U_k)**t + _b = (X - sqrt(D)*Y)*(T_k - sqrt(D)*U_k)**t + Xt = S(_a + _b) / 2 + Yt = S(_a - _b) / (2*sqrt(D)) + s = P*Matrix([Xt, Yt]) + Q + result.add(s) + + X, Y = X*T + D*U*Y, X*U + Y*T + + return result + + +class InhomogeneousTernaryQuadratic(DiophantineEquationType): + """ + + Representation of an inhomogeneous ternary quadratic. + + No solver is currently implemented for this equation type. + + """ + + name = 'inhomogeneous_ternary_quadratic' + + def matches(self): + if not (self.total_degree == 2 and self.dimension == 3): + return False + if not self.homogeneous: + return False + return not self.homogeneous_order + + +class HomogeneousTernaryQuadraticNormal(DiophantineEquationType): + """ + Representation of a homogeneous ternary quadratic normal diophantine equation. + + Examples + ======== + + >>> from sympy.abc import x, y, z + >>> from sympy.solvers.diophantine.diophantine import HomogeneousTernaryQuadraticNormal + >>> HomogeneousTernaryQuadraticNormal(4*x**2 - 5*y**2 + z**2).solve() + {(1, 2, 4)} + + """ + + name = 'homogeneous_ternary_quadratic_normal' + + def matches(self): + if not (self.total_degree == 2 and self.dimension == 3): + return False + if not self.homogeneous: + return False + if not self.homogeneous_order: + return False + + nonzero = [k for k in self.coeff if self.coeff[k]] + return len(nonzero) == 3 and all(i**2 in nonzero for i in self.free_symbols) + + def solve(self, parameters=None, limit=None) -> DiophantineSolutionSet: + self.pre_solve(parameters) + + var = self.free_symbols + coeff = self.coeff + + x, y, z = var + + a = coeff[x**2] + b = coeff[y**2] + c = coeff[z**2] + + (sqf_of_a, sqf_of_b, sqf_of_c), (a_1, b_1, c_1), (a_2, b_2, c_2) = \ + sqf_normal(a, b, c, steps=True) + + A = -a_2*c_2 + B = -b_2*c_2 + + result = DiophantineSolutionSet(var, parameters=self.parameters) + + # If following two conditions are satisfied then there are no solutions + if A < 0 and B < 0: + return result + + if ( + sqrt_mod(-b_2*c_2, a_2) is None or + sqrt_mod(-c_2*a_2, b_2) is None or + sqrt_mod(-a_2*b_2, c_2) is None): + return result + + z_0, x_0, y_0 = descent(A, B) + + z_0, q = _rational_pq(z_0, abs(c_2)) + x_0 *= q + y_0 *= q + + x_0, y_0, z_0 = _remove_gcd(x_0, y_0, z_0) + + # Holzer reduction + if sign(a) == sign(b): + x_0, y_0, z_0 = holzer(x_0, y_0, z_0, abs(a_2), abs(b_2), abs(c_2)) + elif sign(a) == sign(c): + x_0, z_0, y_0 = holzer(x_0, z_0, y_0, abs(a_2), abs(c_2), abs(b_2)) + else: + y_0, z_0, x_0 = holzer(y_0, z_0, x_0, abs(b_2), abs(c_2), abs(a_2)) + + x_0 = reconstruct(b_1, c_1, x_0) + y_0 = reconstruct(a_1, c_1, y_0) + z_0 = reconstruct(a_1, b_1, z_0) + + sq_lcm = ilcm(sqf_of_a, sqf_of_b, sqf_of_c) + + x_0 = abs(x_0*sq_lcm // sqf_of_a) + y_0 = abs(y_0*sq_lcm // sqf_of_b) + z_0 = abs(z_0*sq_lcm // sqf_of_c) + + result.add(_remove_gcd(x_0, y_0, z_0)) + return result + + +class HomogeneousTernaryQuadratic(DiophantineEquationType): + """ + Representation of a homogeneous ternary quadratic diophantine equation. + + Examples + ======== + + >>> from sympy.abc import x, y, z + >>> from sympy.solvers.diophantine.diophantine import HomogeneousTernaryQuadratic + >>> HomogeneousTernaryQuadratic(x**2 + y**2 - 3*z**2 + x*y).solve() + {(-1, 2, 1)} + >>> HomogeneousTernaryQuadratic(3*x**2 + y**2 - 3*z**2 + 5*x*y + y*z).solve() + {(3, 12, 13)} + + """ + + name = 'homogeneous_ternary_quadratic' + + def matches(self): + if not (self.total_degree == 2 and self.dimension == 3): + return False + if not self.homogeneous: + return False + if not self.homogeneous_order: + return False + + nonzero = [k for k in self.coeff if self.coeff[k]] + return not (len(nonzero) == 3 and all(i**2 in nonzero for i in self.free_symbols)) + + def solve(self, parameters=None, limit=None): + self.pre_solve(parameters) + + _var = self.free_symbols + coeff = self.coeff + + x, y, z = _var + var = [x, y, z] + + # Equations of the form B*x*y + C*z*x + E*y*z = 0 and At least two of the + # coefficients A, B, C are non-zero. + # There are infinitely many solutions for the equation. + # Ex: (0, 0, t), (0, t, 0), (t, 0, 0) + # Equation can be re-written as y*(B*x + E*z) = -C*x*z and we can find rather + # unobvious solutions. Set y = -C and B*x + E*z = x*z. The latter can be solved by + # using methods for binary quadratic diophantine equations. Let's select the + # solution which minimizes |x| + |z| + + result = DiophantineSolutionSet(var, parameters=self.parameters) + + def unpack_sol(sol): + if len(sol) > 0: + return list(sol)[0] + return None, None, None + + if not any(coeff[i**2] for i in var): + if coeff[x*z]: + sols = diophantine(coeff[x*y]*x + coeff[y*z]*z - x*z) + s = min(sols, key=lambda r: abs(r[0]) + abs(r[1])) + result.add(_remove_gcd(s[0], -coeff[x*z], s[1])) + return result + + var[0], var[1] = _var[1], _var[0] + y_0, x_0, z_0 = unpack_sol(_diop_ternary_quadratic(var, coeff)) + if x_0 is not None: + result.add((x_0, y_0, z_0)) + return result + + if coeff[x**2] == 0: + # If the coefficient of x is zero change the variables + if coeff[y**2] == 0: + var[0], var[2] = _var[2], _var[0] + z_0, y_0, x_0 = unpack_sol(_diop_ternary_quadratic(var, coeff)) + + else: + var[0], var[1] = _var[1], _var[0] + y_0, x_0, z_0 = unpack_sol(_diop_ternary_quadratic(var, coeff)) + + else: + if coeff[x*y] or coeff[x*z]: + # Apply the transformation x --> X - (B*y + C*z)/(2*A) + A = coeff[x**2] + B = coeff[x*y] + C = coeff[x*z] + D = coeff[y**2] + E = coeff[y*z] + F = coeff[z**2] + + _coeff = {} + + _coeff[x**2] = 4*A**2 + _coeff[y**2] = 4*A*D - B**2 + _coeff[z**2] = 4*A*F - C**2 + _coeff[y*z] = 4*A*E - 2*B*C + _coeff[x*y] = 0 + _coeff[x*z] = 0 + + x_0, y_0, z_0 = unpack_sol(_diop_ternary_quadratic(var, _coeff)) + + if x_0 is None: + return result + + p, q = _rational_pq(B*y_0 + C*z_0, 2*A) + x_0, y_0, z_0 = x_0*q - p, y_0*q, z_0*q + + elif coeff[z*y] != 0: + if coeff[y**2] == 0: + if coeff[z**2] == 0: + # Equations of the form A*x**2 + E*yz = 0. + A = coeff[x**2] + E = coeff[y*z] + + b, a = _rational_pq(-E, A) + + x_0, y_0, z_0 = b, a, b + + else: + # Ax**2 + E*y*z + F*z**2 = 0 + var[0], var[2] = _var[2], _var[0] + z_0, y_0, x_0 = unpack_sol(_diop_ternary_quadratic(var, coeff)) + + else: + # A*x**2 + D*y**2 + E*y*z + F*z**2 = 0, C may be zero + var[0], var[1] = _var[1], _var[0] + y_0, x_0, z_0 = unpack_sol(_diop_ternary_quadratic(var, coeff)) + + else: + # Ax**2 + D*y**2 + F*z**2 = 0, C may be zero + x_0, y_0, z_0 = unpack_sol(_diop_ternary_quadratic_normal(var, coeff)) + + if x_0 is None: + return result + + result.add(_remove_gcd(x_0, y_0, z_0)) + return result + + +class InhomogeneousGeneralQuadratic(DiophantineEquationType): + """ + + Representation of an inhomogeneous general quadratic. + + No solver is currently implemented for this equation type. + + """ + + name = 'inhomogeneous_general_quadratic' + + def matches(self): + if not (self.total_degree == 2 and self.dimension >= 3): + return False + if not self.homogeneous_order: + return True + # there may be Pow keys like x**2 or Mul keys like x*y + return any(k.is_Mul for k in self.coeff) and not self.homogeneous + + +class HomogeneousGeneralQuadratic(DiophantineEquationType): + """ + + Representation of a homogeneous general quadratic. + + No solver is currently implemented for this equation type. + + """ + + name = 'homogeneous_general_quadratic' + + def matches(self): + if not (self.total_degree == 2 and self.dimension >= 3): + return False + if not self.homogeneous_order: + return False + # there may be Pow keys like x**2 or Mul keys like x*y + return any(k.is_Mul for k in self.coeff) and self.homogeneous + + +class GeneralSumOfSquares(DiophantineEquationType): + r""" + Representation of the diophantine equation + + `x_{1}^2 + x_{2}^2 + . . . + x_{n}^2 - k = 0`. + + Details + ======= + + When `n = 3` if `k = 4^a(8m + 7)` for some `a, m \in Z` then there will be + no solutions. Refer [1]_ for more details. + + Examples + ======== + + >>> from sympy.solvers.diophantine.diophantine import GeneralSumOfSquares + >>> from sympy.abc import a, b, c, d, e + >>> GeneralSumOfSquares(a**2 + b**2 + c**2 + d**2 + e**2 - 2345).solve() + {(15, 22, 22, 24, 24)} + + By default only 1 solution is returned. Use the `limit` keyword for more: + + >>> sorted(GeneralSumOfSquares(a**2 + b**2 + c**2 + d**2 + e**2 - 2345).solve(limit=3)) + [(15, 22, 22, 24, 24), (16, 19, 24, 24, 24), (16, 20, 22, 23, 26)] + + References + ========== + + .. [1] Representing an integer as a sum of three squares, [online], + Available: + https://www.proofwiki.org/wiki/Integer_as_Sum_of_Three_Squares + """ + + name = 'general_sum_of_squares' + + def matches(self): + if not (self.total_degree == 2 and self.dimension >= 3): + return False + if not self.homogeneous_order: + return False + if any(k.is_Mul for k in self.coeff): + return False + return all(self.coeff[k] == 1 for k in self.coeff if k != 1) + + def solve(self, parameters=None, limit=1): + self.pre_solve(parameters) + + var = self.free_symbols + k = -int(self.coeff[1]) + n = self.dimension + + result = DiophantineSolutionSet(var, parameters=self.parameters) + + if k < 0 or limit < 1: + return result + + signs = [-1 if x.is_nonpositive else 1 for x in var] + negs = signs.count(-1) != 0 + + took = 0 + for t in sum_of_squares(k, n, zeros=True): + if negs: + result.add([signs[i]*j for i, j in enumerate(t)]) + else: + result.add(t) + took += 1 + if took == limit: + break + return result + + +class GeneralPythagorean(DiophantineEquationType): + """ + Representation of the general pythagorean equation, + `a_{1}^2x_{1}^2 + a_{2}^2x_{2}^2 + . . . + a_{n}^2x_{n}^2 - a_{n + 1}^2x_{n + 1}^2 = 0`. + + Examples + ======== + + >>> from sympy.solvers.diophantine.diophantine import GeneralPythagorean + >>> from sympy.abc import a, b, c, d, e, x, y, z, t + >>> GeneralPythagorean(a**2 + b**2 + c**2 - d**2).solve() + {(t_0**2 + t_1**2 - t_2**2, 2*t_0*t_2, 2*t_1*t_2, t_0**2 + t_1**2 + t_2**2)} + >>> GeneralPythagorean(9*a**2 - 4*b**2 + 16*c**2 + 25*d**2 + e**2).solve(parameters=[x, y, z, t]) + {(-10*t**2 + 10*x**2 + 10*y**2 + 10*z**2, 15*t**2 + 15*x**2 + 15*y**2 + 15*z**2, 15*t*x, 12*t*y, 60*t*z)} + """ + + name = 'general_pythagorean' + + def matches(self): + if not (self.total_degree == 2 and self.dimension >= 3): + return False + if not self.homogeneous_order: + return False + if any(k.is_Mul for k in self.coeff): + return False + if all(self.coeff[k] == 1 for k in self.coeff if k != 1): + return False + if not all(is_square(abs(self.coeff[k])) for k in self.coeff): + return False + # all but one has the same sign + # e.g. 4*x**2 + y**2 - 4*z**2 + return abs(sum(sign(self.coeff[k]) for k in self.coeff)) == self.dimension - 2 + + @property + def n_parameters(self): + return self.dimension - 1 + + def solve(self, parameters=None, limit=1): + self.pre_solve(parameters) + + coeff = self.coeff + var = self.free_symbols + n = self.dimension + + if sign(coeff[var[0] ** 2]) + sign(coeff[var[1] ** 2]) + sign(coeff[var[2] ** 2]) < 0: + for key in coeff.keys(): + coeff[key] = -coeff[key] + + result = DiophantineSolutionSet(var, parameters=self.parameters) + + index = 0 + + for i, v in enumerate(var): + if sign(coeff[v ** 2]) == -1: + index = i + + m = result.parameters + + ith = sum(m_i ** 2 for m_i in m) + L = [ith - 2 * m[n - 2] ** 2] + L.extend([2 * m[i] * m[n - 2] for i in range(n - 2)]) + sol = L[:index] + [ith] + L[index:] + + lcm = 1 + for i, v in enumerate(var): + if i == index or (index > 0 and i == 0) or (index == 0 and i == 1): + lcm = ilcm(lcm, sqrt(abs(coeff[v ** 2]))) + else: + s = sqrt(coeff[v ** 2]) + lcm = ilcm(lcm, s if _odd(s) else s // 2) + + for i, v in enumerate(var): + sol[i] = (lcm * sol[i]) / sqrt(abs(coeff[v ** 2])) + + result.add(sol) + return result + + +class CubicThue(DiophantineEquationType): + """ + Representation of a cubic Thue diophantine equation. + + A cubic Thue diophantine equation is a polynomial of the form + `f(x, y) = r` of degree 3, where `x` and `y` are integers + and `r` is a rational number. + + No solver is currently implemented for this equation type. + + Examples + ======== + + >>> from sympy.abc import x, y + >>> from sympy.solvers.diophantine.diophantine import CubicThue + >>> c1 = CubicThue(x**3 + y**2 + 1) + >>> c1.matches() + True + + """ + + name = 'cubic_thue' + + def matches(self): + return self.total_degree == 3 and self.dimension == 2 + + +class GeneralSumOfEvenPowers(DiophantineEquationType): + """ + Representation of the diophantine equation + + `x_{1}^e + x_{2}^e + . . . + x_{n}^e - k = 0` + + where `e` is an even, integer power. + + Examples + ======== + + >>> from sympy.solvers.diophantine.diophantine import GeneralSumOfEvenPowers + >>> from sympy.abc import a, b + >>> GeneralSumOfEvenPowers(a**4 + b**4 - (2**4 + 3**4)).solve() + {(2, 3)} + + """ + + name = 'general_sum_of_even_powers' + + def matches(self): + if not self.total_degree > 3: + return False + if self.total_degree % 2 != 0: + return False + if not all(k.is_Pow and k.exp == self.total_degree for k in self.coeff if k != 1): + return False + return all(self.coeff[k] == 1 for k in self.coeff if k != 1) + + def solve(self, parameters=None, limit=1): + self.pre_solve(parameters) + + var = self.free_symbols + coeff = self.coeff + + p = None + for q in coeff.keys(): + if q.is_Pow and coeff[q]: + p = q.exp + + k = len(var) + n = -coeff[1] + + result = DiophantineSolutionSet(var, parameters=self.parameters) + + if n < 0 or limit < 1: + return result + + sign = [-1 if x.is_nonpositive else 1 for x in var] + negs = sign.count(-1) != 0 + + took = 0 + for t in power_representation(n, p, k): + if negs: + result.add([sign[i]*j for i, j in enumerate(t)]) + else: + result.add(t) + took += 1 + if took == limit: + break + return result + +# these types are known (but not necessarily handled) +# note that order is important here (in the current solver state) +all_diop_classes = [ + Linear, + Univariate, + BinaryQuadratic, + InhomogeneousTernaryQuadratic, + HomogeneousTernaryQuadraticNormal, + HomogeneousTernaryQuadratic, + InhomogeneousGeneralQuadratic, + HomogeneousGeneralQuadratic, + GeneralSumOfSquares, + GeneralPythagorean, + CubicThue, + GeneralSumOfEvenPowers, +] + +diop_known = {diop_class.name for diop_class in all_diop_classes} + + +def _remove_gcd(*x): + try: + g = igcd(*x) + except ValueError: + fx = list(filter(None, x)) + if len(fx) < 2: + return x + g = igcd(*[i.as_content_primitive()[0] for i in fx]) + except TypeError: + raise TypeError('_remove_gcd(a,b,c) or _remove_gcd(*container)') + if g == 1: + return x + return tuple([i//g for i in x]) + + +def _rational_pq(a, b): + # return `(numer, denom)` for a/b; sign in numer and gcd removed + return _remove_gcd(sign(b)*a, abs(b)) + + +def _nint_or_floor(p, q): + # return nearest int to p/q; in case of tie return floor(p/q) + w, r = divmod(p, q) + if abs(r) <= abs(q)//2: + return w + return w + 1 + + +def _odd(i): + return i % 2 != 0 + + +def _even(i): + return i % 2 == 0 + + +def diophantine(eq, param=symbols("t", integer=True), syms=None, + permute=False): + """ + Simplify the solution procedure of diophantine equation ``eq`` by + converting it into a product of terms which should equal zero. + + Explanation + =========== + + For example, when solving, `x^2 - y^2 = 0` this is treated as + `(x + y)(x - y) = 0` and `x + y = 0` and `x - y = 0` are solved + independently and combined. Each term is solved by calling + ``diop_solve()``. (Although it is possible to call ``diop_solve()`` + directly, one must be careful to pass an equation in the correct + form and to interpret the output correctly; ``diophantine()`` is + the public-facing function to use in general.) + + Output of ``diophantine()`` is a set of tuples. The elements of the + tuple are the solutions for each variable in the equation and + are arranged according to the alphabetic ordering of the variables. + e.g. For an equation with two variables, `a` and `b`, the first + element of the tuple is the solution for `a` and the second for `b`. + + Usage + ===== + + ``diophantine(eq, t, syms)``: Solve the diophantine + equation ``eq``. + ``t`` is the optional parameter to be used by ``diop_solve()``. + ``syms`` is an optional list of symbols which determines the + order of the elements in the returned tuple. + + By default, only the base solution is returned. If ``permute`` is set to + True then permutations of the base solution and/or permutations of the + signs of the values will be returned when applicable. + + Details + ======= + + ``eq`` should be an expression which is assumed to be zero. + ``t`` is the parameter to be used in the solution. + + Examples + ======== + + >>> from sympy import diophantine + >>> from sympy.abc import a, b + >>> eq = a**4 + b**4 - (2**4 + 3**4) + >>> diophantine(eq) + {(2, 3)} + >>> diophantine(eq, permute=True) + {(-3, -2), (-3, 2), (-2, -3), (-2, 3), (2, -3), (2, 3), (3, -2), (3, 2)} + + >>> from sympy.abc import x, y, z + >>> diophantine(x**2 - y**2) + {(t_0, -t_0), (t_0, t_0)} + + >>> diophantine(x*(2*x + 3*y - z)) + {(0, n1, n2), (t_0, t_1, 2*t_0 + 3*t_1)} + >>> diophantine(x**2 + 3*x*y + 4*x) + {(0, n1), (-3*t_0 - 4, t_0)} + + See Also + ======== + + diop_solve + sympy.utilities.iterables.permute_signs + sympy.utilities.iterables.signed_permutations + """ + + eq = _sympify(eq) + + if isinstance(eq, Eq): + eq = eq.lhs - eq.rhs + + try: + var = list(eq.expand(force=True).free_symbols) + var.sort(key=default_sort_key) + if syms: + if not is_sequence(syms): + raise TypeError( + 'syms should be given as a sequence, e.g. a list') + syms = [i for i in syms if i in var] + if syms != var: + dict_sym_index = dict(zip(syms, range(len(syms)))) + return {tuple([t[dict_sym_index[i]] for i in var]) + for t in diophantine(eq, param, permute=permute)} + n, d = eq.as_numer_denom() + if n.is_number: + return set() + if not d.is_number: + dsol = diophantine(d) + good = diophantine(n) - dsol + return {s for s in good if _mexpand(d.subs(zip(var, s)))} + eq = factor_terms(n) + assert not eq.is_number + eq = eq.as_independent(*var, as_Add=False)[1] + p = Poly(eq) + assert not any(g.is_number for g in p.gens) + eq = p.as_expr() + assert eq.is_polynomial() + except (GeneratorsNeeded, AssertionError): + raise TypeError(filldedent(''' + Equation should be a polynomial with Rational coefficients.''')) + + # permute only sign + do_permute_signs = False + # permute sign and values + do_permute_signs_var = False + # permute few signs + permute_few_signs = False + try: + # if we know that factoring should not be attempted, skip + # the factoring step + v, c, t = classify_diop(eq) + + # check for permute sign + if permute: + len_var = len(v) + permute_signs_for = [ + GeneralSumOfSquares.name, + GeneralSumOfEvenPowers.name] + permute_signs_check = [ + HomogeneousTernaryQuadratic.name, + HomogeneousTernaryQuadraticNormal.name, + BinaryQuadratic.name] + if t in permute_signs_for: + do_permute_signs_var = True + elif t in permute_signs_check: + # if all the variables in eq have even powers + # then do_permute_sign = True + if len_var == 3: + var_mul = list(subsets(v, 2)) + # here var_mul is like [(x, y), (x, z), (y, z)] + xy_coeff = True + x_coeff = True + var1_mul_var2 = (a[0]*a[1] for a in var_mul) + # if coeff(y*z), coeff(y*x), coeff(x*z) is not 0 then + # `xy_coeff` => True and do_permute_sign => False. + # Means no permuted solution. + for v1_mul_v2 in var1_mul_var2: + try: + coeff = c[v1_mul_v2] + except KeyError: + coeff = 0 + xy_coeff = bool(xy_coeff) and bool(coeff) + var_mul = list(subsets(v, 1)) + # here var_mul is like [(x,), (y, )] + for v1 in var_mul: + try: + coeff = c[v1[0]] + except KeyError: + coeff = 0 + x_coeff = bool(x_coeff) and bool(coeff) + if not any((xy_coeff, x_coeff)): + # means only x**2, y**2, z**2, const is present + do_permute_signs = True + elif not x_coeff: + permute_few_signs = True + elif len_var == 2: + var_mul = list(subsets(v, 2)) + # here var_mul is like [(x, y)] + xy_coeff = True + x_coeff = True + var1_mul_var2 = (x[0]*x[1] for x in var_mul) + for v1_mul_v2 in var1_mul_var2: + try: + coeff = c[v1_mul_v2] + except KeyError: + coeff = 0 + xy_coeff = bool(xy_coeff) and bool(coeff) + var_mul = list(subsets(v, 1)) + # here var_mul is like [(x,), (y, )] + for v1 in var_mul: + try: + coeff = c[v1[0]] + except KeyError: + coeff = 0 + x_coeff = bool(x_coeff) and bool(coeff) + if not any((xy_coeff, x_coeff)): + # means only x**2, y**2 and const is present + # so we can get more soln by permuting this soln. + do_permute_signs = True + elif not x_coeff: + # when coeff(x), coeff(y) is not present then signs of + # x, y can be permuted such that their sign are same + # as sign of x*y. + # e.g 1. (x_val,y_val)=> (x_val,y_val), (-x_val,-y_val) + # 2. (-x_vall, y_val)=> (-x_val,y_val), (x_val,-y_val) + permute_few_signs = True + if t == 'general_sum_of_squares': + # trying to factor such expressions will sometimes hang + terms = [(eq, 1)] + else: + raise TypeError + except (TypeError, NotImplementedError): + fl = factor_list(eq) + if fl[0].is_Rational and fl[0] != 1: + return diophantine(eq/fl[0], param=param, syms=syms, permute=permute) + terms = fl[1] + + sols = set() + + for term in terms: + + base, _ = term + var_t, _, eq_type = classify_diop(base, _dict=False) + _, base = signsimp(base, evaluate=False).as_coeff_Mul() + solution = diop_solve(base, param) + + if eq_type in [ + Linear.name, + HomogeneousTernaryQuadratic.name, + HomogeneousTernaryQuadraticNormal.name, + GeneralPythagorean.name]: + sols.add(merge_solution(var, var_t, solution)) + + elif eq_type in [ + BinaryQuadratic.name, + GeneralSumOfSquares.name, + GeneralSumOfEvenPowers.name, + Univariate.name]: + sols.update(merge_solution(var, var_t, sol) for sol in solution) + + else: + raise NotImplementedError('unhandled type: %s' % eq_type) + + # remove null merge results + if () in sols: + sols.remove(()) + null = tuple([0]*len(var)) + # if there is no solution, return trivial solution + if not sols and eq.subs(zip(var, null)).is_zero: + sols.add(null) + final_soln = set() + for sol in sols: + if all(int_valued(s) for s in sol): + if do_permute_signs: + permuted_sign = set(permute_signs(sol)) + final_soln.update(permuted_sign) + elif permute_few_signs: + lst = list(permute_signs(sol)) + lst = list(filter(lambda x: x[0]*x[1] == sol[1]*sol[0], lst)) + permuted_sign = set(lst) + final_soln.update(permuted_sign) + elif do_permute_signs_var: + permuted_sign_var = set(signed_permutations(sol)) + final_soln.update(permuted_sign_var) + else: + final_soln.add(sol) + else: + final_soln.add(sol) + return final_soln + + +def merge_solution(var, var_t, solution): + """ + This is used to construct the full solution from the solutions of sub + equations. + + Explanation + =========== + + For example when solving the equation `(x - y)(x^2 + y^2 - z^2) = 0`, + solutions for each of the equations `x - y = 0` and `x^2 + y^2 - z^2` are + found independently. Solutions for `x - y = 0` are `(x, y) = (t, t)`. But + we should introduce a value for z when we output the solution for the + original equation. This function converts `(t, t)` into `(t, t, n_{1})` + where `n_{1}` is an integer parameter. + """ + sol = [] + + if None in solution: + return () + + solution = iter(solution) + params = numbered_symbols("n", integer=True, start=1) + for v in var: + if v in var_t: + sol.append(next(solution)) + else: + sol.append(next(params)) + + for val, symb in zip(sol, var): + if check_assumptions(val, **symb.assumptions0) is False: + return () + + return tuple(sol) + + +def _diop_solve(eq, params=None): + for diop_type in all_diop_classes: + if diop_type(eq).matches(): + return diop_type(eq).solve(parameters=params) + + +def diop_solve(eq, param=symbols("t", integer=True)): + """ + Solves the diophantine equation ``eq``. + + Explanation + =========== + + Unlike ``diophantine()``, factoring of ``eq`` is not attempted. Uses + ``classify_diop()`` to determine the type of the equation and calls + the appropriate solver function. + + Use of ``diophantine()`` is recommended over other helper functions. + ``diop_solve()`` can return either a set or a tuple depending on the + nature of the equation. + + Usage + ===== + + ``diop_solve(eq, t)``: Solve diophantine equation, ``eq`` using ``t`` + as a parameter if needed. + + Details + ======= + + ``eq`` should be an expression which is assumed to be zero. + ``t`` is a parameter to be used in the solution. + + Examples + ======== + + >>> from sympy.solvers.diophantine import diop_solve + >>> from sympy.abc import x, y, z, w + >>> diop_solve(2*x + 3*y - 5) + (3*t_0 - 5, 5 - 2*t_0) + >>> diop_solve(4*x + 3*y - 4*z + 5) + (t_0, 8*t_0 + 4*t_1 + 5, 7*t_0 + 3*t_1 + 5) + >>> diop_solve(x + 3*y - 4*z + w - 6) + (t_0, t_0 + t_1, 6*t_0 + 5*t_1 + 4*t_2 - 6, 5*t_0 + 4*t_1 + 3*t_2 - 6) + >>> diop_solve(x**2 + y**2 - 5) + {(-2, -1), (-2, 1), (-1, -2), (-1, 2), (1, -2), (1, 2), (2, -1), (2, 1)} + + + See Also + ======== + + diophantine() + """ + var, coeff, eq_type = classify_diop(eq, _dict=False) + + if eq_type == Linear.name: + return diop_linear(eq, param) + + elif eq_type == BinaryQuadratic.name: + return diop_quadratic(eq, param) + + elif eq_type == HomogeneousTernaryQuadratic.name: + return diop_ternary_quadratic(eq, parameterize=True) + + elif eq_type == HomogeneousTernaryQuadraticNormal.name: + return diop_ternary_quadratic_normal(eq, parameterize=True) + + elif eq_type == GeneralPythagorean.name: + return diop_general_pythagorean(eq, param) + + elif eq_type == Univariate.name: + return diop_univariate(eq) + + elif eq_type == GeneralSumOfSquares.name: + return diop_general_sum_of_squares(eq, limit=S.Infinity) + + elif eq_type == GeneralSumOfEvenPowers.name: + return diop_general_sum_of_even_powers(eq, limit=S.Infinity) + + if eq_type is not None and eq_type not in diop_known: + raise ValueError(filldedent(''' + Although this type of equation was identified, it is not yet + handled. It should, however, be listed in `diop_known` at the + top of this file. Developers should see comments at the end of + `classify_diop`. + ''')) # pragma: no cover + else: + raise NotImplementedError( + 'No solver has been written for %s.' % eq_type) + + +def classify_diop(eq, _dict=True): + # docstring supplied externally + + matched = False + diop_type = None + for diop_class in all_diop_classes: + diop_type = diop_class(eq) + if diop_type.matches(): + matched = True + break + + if matched: + return diop_type.free_symbols, dict(diop_type.coeff) if _dict else diop_type.coeff, diop_type.name + + # new diop type instructions + # -------------------------- + # if this error raises and the equation *can* be classified, + # * it should be identified in the if-block above + # * the type should be added to the diop_known + # if a solver can be written for it, + # * a dedicated handler should be written (e.g. diop_linear) + # * it should be passed to that handler in diop_solve + raise NotImplementedError(filldedent(''' + This equation is not yet recognized or else has not been + simplified sufficiently to put it in a form recognized by + diop_classify().''')) + + +classify_diop.func_doc = ( # type: ignore + ''' + Helper routine used by diop_solve() to find information about ``eq``. + + Explanation + =========== + + Returns a tuple containing the type of the diophantine equation + along with the variables (free symbols) and their coefficients. + Variables are returned as a list and coefficients are returned + as a dict with the key being the respective term and the constant + term is keyed to 1. The type is one of the following: + + * %s + + Usage + ===== + + ``classify_diop(eq)``: Return variables, coefficients and type of the + ``eq``. + + Details + ======= + + ``eq`` should be an expression which is assumed to be zero. + ``_dict`` is for internal use: when True (default) a dict is returned, + otherwise a defaultdict which supplies 0 for missing keys is returned. + + Examples + ======== + + >>> from sympy.solvers.diophantine import classify_diop + >>> from sympy.abc import x, y, z, w, t + >>> classify_diop(4*x + 6*y - 4) + ([x, y], {1: -4, x: 4, y: 6}, 'linear') + >>> classify_diop(x + 3*y -4*z + 5) + ([x, y, z], {1: 5, x: 1, y: 3, z: -4}, 'linear') + >>> classify_diop(x**2 + y**2 - x*y + x + 5) + ([x, y], {1: 5, x: 1, x**2: 1, y**2: 1, x*y: -1}, 'binary_quadratic') + ''' % ('\n * '.join(sorted(diop_known)))) + + +def diop_linear(eq, param=symbols("t", integer=True)): + """ + Solves linear diophantine equations. + + A linear diophantine equation is an equation of the form `a_{1}x_{1} + + a_{2}x_{2} + .. + a_{n}x_{n} = 0` where `a_{1}, a_{2}, ..a_{n}` are + integer constants and `x_{1}, x_{2}, ..x_{n}` are integer variables. + + Usage + ===== + + ``diop_linear(eq)``: Returns a tuple containing solutions to the + diophantine equation ``eq``. Values in the tuple is arranged in the same + order as the sorted variables. + + Details + ======= + + ``eq`` is a linear diophantine equation which is assumed to be zero. + ``param`` is the parameter to be used in the solution. + + Examples + ======== + + >>> from sympy.solvers.diophantine.diophantine import diop_linear + >>> from sympy.abc import x, y, z + >>> diop_linear(2*x - 3*y - 5) # solves equation 2*x - 3*y - 5 == 0 + (3*t_0 - 5, 2*t_0 - 5) + + Here x = -3*t_0 - 5 and y = -2*t_0 - 5 + + >>> diop_linear(2*x - 3*y - 4*z -3) + (t_0, 2*t_0 + 4*t_1 + 3, -t_0 - 3*t_1 - 3) + + See Also + ======== + + diop_quadratic(), diop_ternary_quadratic(), diop_general_pythagorean(), + diop_general_sum_of_squares() + """ + var, coeff, diop_type = classify_diop(eq, _dict=False) + + if diop_type == Linear.name: + parameters = None + if param is not None: + parameters = symbols('%s_0:%i' % (param, len(var)), integer=True) + + result = Linear(eq).solve(parameters=parameters) + + if param is None: + result = result(*[0]*len(result.parameters)) + + if len(result) > 0: + return list(result)[0] + else: + return tuple([None]*len(result.parameters)) + + +def base_solution_linear(c, a, b, t=None): + """ + Return the base solution for the linear equation, `ax + by = c`. + + Explanation + =========== + + Used by ``diop_linear()`` to find the base solution of a linear + Diophantine equation. If ``t`` is given then the parametrized solution is + returned. + + Usage + ===== + + ``base_solution_linear(c, a, b, t)``: ``a``, ``b``, ``c`` are coefficients + in `ax + by = c` and ``t`` is the parameter to be used in the solution. + + Examples + ======== + + >>> from sympy.solvers.diophantine.diophantine import base_solution_linear + >>> from sympy.abc import t + >>> base_solution_linear(5, 2, 3) # equation 2*x + 3*y = 5 + (-5, 5) + >>> base_solution_linear(0, 5, 7) # equation 5*x + 7*y = 0 + (0, 0) + >>> base_solution_linear(5, 2, 3, t) # equation 2*x + 3*y = 5 + (3*t - 5, 5 - 2*t) + >>> base_solution_linear(0, 5, 7, t) # equation 5*x + 7*y = 0 + (7*t, -5*t) + """ + a, b, c = _remove_gcd(a, b, c) + + if c == 0: + if t is None: + return (0, 0) + if b < 0: + t = -t + return (b*t, -a*t) + + x0, y0, d = igcdex(abs(a), abs(b)) + x0 *= sign(a) + y0 *= sign(b) + if c % d: + return (None, None) + if t is None: + return (c*x0, c*y0) + if b < 0: + t = -t + return (c*x0 + b*t, c*y0 - a*t) + + +def diop_univariate(eq): + """ + Solves a univariate diophantine equations. + + Explanation + =========== + + A univariate diophantine equation is an equation of the form + `a_{0} + a_{1}x + a_{2}x^2 + .. + a_{n}x^n = 0` where `a_{1}, a_{2}, ..a_{n}` are + integer constants and `x` is an integer variable. + + Usage + ===== + + ``diop_univariate(eq)``: Returns a set containing solutions to the + diophantine equation ``eq``. + + Details + ======= + + ``eq`` is a univariate diophantine equation which is assumed to be zero. + + Examples + ======== + + >>> from sympy.solvers.diophantine.diophantine import diop_univariate + >>> from sympy.abc import x + >>> diop_univariate((x - 2)*(x - 3)**2) # solves equation (x - 2)*(x - 3)**2 == 0 + {(2,), (3,)} + + """ + var, coeff, diop_type = classify_diop(eq, _dict=False) + + if diop_type == Univariate.name: + return {(int(i),) for i in solveset_real( + eq, var[0]).intersect(S.Integers)} + + +def divisible(a, b): + """ + Returns `True` if ``a`` is divisible by ``b`` and `False` otherwise. + """ + return not a % b + + +def diop_quadratic(eq, param=symbols("t", integer=True)): + """ + Solves quadratic diophantine equations. + + i.e. equations of the form `Ax^2 + Bxy + Cy^2 + Dx + Ey + F = 0`. Returns a + set containing the tuples `(x, y)` which contains the solutions. If there + are no solutions then `(None, None)` is returned. + + Usage + ===== + + ``diop_quadratic(eq, param)``: ``eq`` is a quadratic binary diophantine + equation. ``param`` is used to indicate the parameter to be used in the + solution. + + Details + ======= + + ``eq`` should be an expression which is assumed to be zero. + ``param`` is a parameter to be used in the solution. + + Examples + ======== + + >>> from sympy.abc import x, y, t + >>> from sympy.solvers.diophantine.diophantine import diop_quadratic + >>> diop_quadratic(x**2 + y**2 + 2*x + 2*y + 2, t) + {(-1, -1)} + + References + ========== + + .. [1] Methods to solve Ax^2 + Bxy + Cy^2 + Dx + Ey + F = 0, [online], + Available: https://www.alpertron.com.ar/METHODS.HTM + .. [2] Solving the equation ax^2+ bxy + cy^2 + dx + ey + f= 0, [online], + Available: https://web.archive.org/web/20160323033111/http://www.jpr2718.org/ax2p.pdf + + See Also + ======== + + diop_linear(), diop_ternary_quadratic(), diop_general_sum_of_squares(), + diop_general_pythagorean() + """ + var, coeff, diop_type = classify_diop(eq, _dict=False) + + if diop_type == BinaryQuadratic.name: + if param is not None: + parameters = [param, Symbol("u", integer=True)] + else: + parameters = None + return set(BinaryQuadratic(eq).solve(parameters=parameters)) + + +def is_solution_quad(var, coeff, u, v): + """ + Check whether `(u, v)` is solution to the quadratic binary diophantine + equation with the variable list ``var`` and coefficient dictionary + ``coeff``. + + Not intended for use by normal users. + """ + reps = dict(zip(var, (u, v))) + eq = Add(*[j*i.xreplace(reps) for i, j in coeff.items()]) + return _mexpand(eq) == 0 + + +def diop_DN(D, N, t=symbols("t", integer=True)): + """ + Solves the equation `x^2 - Dy^2 = N`. + + Explanation + =========== + + Mainly concerned with the case `D > 0, D` is not a perfect square, + which is the same as the generalized Pell equation. The LMM + algorithm [1]_ is used to solve this equation. + + Returns one solution tuple, (`x, y)` for each class of the solutions. + Other solutions of the class can be constructed according to the + values of ``D`` and ``N``. + + Usage + ===== + + ``diop_DN(D, N, t)``: D and N are integers as in `x^2 - Dy^2 = N` and + ``t`` is the parameter to be used in the solutions. + + Details + ======= + + ``D`` and ``N`` correspond to D and N in the equation. + ``t`` is the parameter to be used in the solutions. + + Examples + ======== + + >>> from sympy.solvers.diophantine.diophantine import diop_DN + >>> diop_DN(13, -4) # Solves equation x**2 - 13*y**2 = -4 + [(3, 1), (393, 109), (36, 10)] + + The output can be interpreted as follows: There are three fundamental + solutions to the equation `x^2 - 13y^2 = -4` given by (3, 1), (393, 109) + and (36, 10). Each tuple is in the form (x, y), i.e. solution (3, 1) means + that `x = 3` and `y = 1`. + + >>> diop_DN(986, 1) # Solves equation x**2 - 986*y**2 = 1 + [(49299, 1570)] + + See Also + ======== + + find_DN(), diop_bf_DN() + + References + ========== + + .. [1] Solving the generalized Pell equation x**2 - D*y**2 = N, John P. + Robertson, July 31, 2004, Pages 16 - 17. [online], Available: + https://web.archive.org/web/20160323033128/http://www.jpr2718.org/pell.pdf + """ + if D < 0: + if N == 0: + return [(0, 0)] + if N < 0: + return [] + # N > 0: + sol = [] + for d in divisors(square_factor(N), generator=True): + for x, y in cornacchia(1, int(-D), int(N // d**2)): + sol.append((d*x, d*y)) + if D == -1: + sol.append((d*y, d*x)) + return sol + + if D == 0: + if N < 0: + return [] + if N == 0: + return [(0, t)] + sN, _exact = integer_nthroot(N, 2) + if _exact: + return [(sN, t)] + return [] + + # D > 0 + sD, _exact = integer_nthroot(D, 2) + if _exact: + if N == 0: + return [(sD*t, t)] + + sol = [] + for y in range(floor(sign(N)*(N - 1)/(2*sD)) + 1): + try: + sq, _exact = integer_nthroot(D*y**2 + N, 2) + except ValueError: + _exact = False + if _exact: + sol.append((sq, y)) + return sol + + if 1 < N**2 < D: + # It is much faster to call `_special_diop_DN`. + return _special_diop_DN(D, N) + + if N == 0: + return [(0, 0)] + + sol = [] + if abs(N) == 1: + pqa = PQa(0, 1, D) + *_, prev_B, prev_G = next(pqa) + for j, (*_, a, _, _B, _G) in enumerate(pqa): + if a == 2*sD: + break + prev_B, prev_G = _B, _G + if j % 2: + if N == 1: + sol.append((prev_G, prev_B)) + return sol + if N == -1: + return [(prev_G, prev_B)] + for _ in range(j): + *_, _B, _G = next(pqa) + return [(_G, _B)] + + for f in divisors(square_factor(N), generator=True): + m = N // f**2 + am = abs(m) + for sqm in sqrt_mod(D, am, all_roots=True): + z = symmetric_residue(sqm, am) + pqa = PQa(z, am, D) + *_, prev_B, prev_G = next(pqa) + for _ in range(length(z, am, D) - 1): + _, q, *_, _B, _G = next(pqa) + if abs(q) == 1: + if prev_G**2 - D*prev_B**2 == m: + sol.append((f*prev_G, f*prev_B)) + elif a := diop_DN(D, -1): + sol.append((f*(prev_G*a[0][0] + prev_B*D*a[0][1]), + f*(prev_G*a[0][1] + prev_B*a[0][0]))) + break + prev_B, prev_G = _B, _G + return sol + + +def _special_diop_DN(D, N): + """ + Solves the equation `x^2 - Dy^2 = N` for the special case where + `1 < N**2 < D` and `D` is not a perfect square. + It is better to call `diop_DN` rather than this function, as + the former checks the condition `1 < N**2 < D`, and calls the latter only + if appropriate. + + Usage + ===== + + WARNING: Internal method. Do not call directly! + + ``_special_diop_DN(D, N)``: D and N are integers as in `x^2 - Dy^2 = N`. + + Details + ======= + + ``D`` and ``N`` correspond to D and N in the equation. + + Examples + ======== + + >>> from sympy.solvers.diophantine.diophantine import _special_diop_DN + >>> _special_diop_DN(13, -3) # Solves equation x**2 - 13*y**2 = -3 + [(7, 2), (137, 38)] + + The output can be interpreted as follows: There are two fundamental + solutions to the equation `x^2 - 13y^2 = -3` given by (7, 2) and + (137, 38). Each tuple is in the form (x, y), i.e. solution (7, 2) means + that `x = 7` and `y = 2`. + + >>> _special_diop_DN(2445, -20) # Solves equation x**2 - 2445*y**2 = -20 + [(445, 9), (17625560, 356454), (698095554475, 14118073569)] + + See Also + ======== + + diop_DN() + + References + ========== + + .. [1] Section 4.4.4 of the following book: + Quadratic Diophantine Equations, T. Andreescu and D. Andrica, + Springer, 2015. + """ + + # The following assertion was removed for efficiency, with the understanding + # that this method is not called directly. The parent method, `diop_DN` + # is responsible for performing the appropriate checks. + # + # assert (1 < N**2 < D) and (not integer_nthroot(D, 2)[1]) + + sqrt_D = isqrt(D) + F = {N // f**2: f for f in divisors(square_factor(abs(N)), generator=True)} + P = 0 + Q = 1 + G0, G1 = 0, 1 + B0, B1 = 1, 0 + + solutions = [] + while True: + for _ in range(2): + a = (P + sqrt_D) // Q + P = a*Q - P + Q = (D - P**2) // Q + G0, G1 = G1, a*G1 + G0 + B0, B1 = B1, a*B1 + B0 + if (s := G1**2 - D*B1**2) in F: + f = F[s] + solutions.append((f*G1, f*B1)) + if Q == 1: + break + return solutions + + +def cornacchia(a:int, b:int, m:int) -> set[tuple[int, int]]: + r""" + Solves `ax^2 + by^2 = m` where `\gcd(a, b) = 1 = gcd(a, m)` and `a, b > 0`. + + Explanation + =========== + + Uses the algorithm due to Cornacchia. The method only finds primitive + solutions, i.e. ones with `\gcd(x, y) = 1`. So this method cannot be used to + find the solutions of `x^2 + y^2 = 20` since the only solution to former is + `(x, y) = (4, 2)` and it is not primitive. When `a = b`, only the + solutions with `x \leq y` are found. For more details, see the References. + + Examples + ======== + + >>> from sympy.solvers.diophantine.diophantine import cornacchia + >>> cornacchia(2, 3, 35) # equation 2x**2 + 3y**2 = 35 + {(2, 3), (4, 1)} + >>> cornacchia(1, 1, 25) # equation x**2 + y**2 = 25 + {(4, 3)} + + References + =========== + + .. [1] A. Nitaj, "L'algorithme de Cornacchia" + .. [2] Solving the diophantine equation ax**2 + by**2 = m by Cornacchia's + method, [online], Available: + http://www.numbertheory.org/php/cornacchia.html + + See Also + ======== + + sympy.utilities.iterables.signed_permutations + """ + # Assume gcd(a, b) = gcd(a, m) = 1 and a, b > 0 but no error checking + sols = set() + for t in sqrt_mod_iter(-b*invert(a, m), m): + if t < m // 2: + continue + u, r = m, t + while (m1 := m - a*r**2) <= 0: + u, r = r, u % r + m1, _r = divmod(m1, b) + if _r: + continue + s, _exact = iroot(m1, 2) + if _exact: + if a == b and r < s: + r, s = s, r + sols.add((int(r), int(s))) + return sols + + +def PQa(P_0, Q_0, D): + r""" + Returns useful information needed to solve the Pell equation. + + Explanation + =========== + + There are six sequences of integers defined related to the continued + fraction representation of `\\frac{P + \sqrt{D}}{Q}`, namely {`P_{i}`}, + {`Q_{i}`}, {`a_{i}`},{`A_{i}`}, {`B_{i}`}, {`G_{i}`}. ``PQa()`` Returns + these values as a 6-tuple in the same order as mentioned above. Refer [1]_ + for more detailed information. + + Usage + ===== + + ``PQa(P_0, Q_0, D)``: ``P_0``, ``Q_0`` and ``D`` are integers corresponding + to `P_{0}`, `Q_{0}` and `D` in the continued fraction + `\\frac{P_{0} + \sqrt{D}}{Q_{0}}`. + Also it's assumed that `P_{0}^2 == D mod(|Q_{0}|)` and `D` is square free. + + Examples + ======== + + >>> from sympy.solvers.diophantine.diophantine import PQa + >>> pqa = PQa(13, 4, 5) # (13 + sqrt(5))/4 + >>> next(pqa) # (P_0, Q_0, a_0, A_0, B_0, G_0) + (13, 4, 3, 3, 1, -1) + >>> next(pqa) # (P_1, Q_1, a_1, A_1, B_1, G_1) + (-1, 1, 1, 4, 1, 3) + + References + ========== + + .. [1] Solving the generalized Pell equation x^2 - Dy^2 = N, John P. + Robertson, July 31, 2004, Pages 4 - 8. https://web.archive.org/web/20160323033128/http://www.jpr2718.org/pell.pdf + """ + sqD = isqrt(D) + A2 = B1 = 0 + A1 = B2 = 1 + G1 = Q_0 + G2 = -P_0 + P_i = P_0 + Q_i = Q_0 + + while True: + a_i = (P_i + sqD) // Q_i + A1, A2 = a_i*A1 + A2, A1 + B1, B2 = a_i*B1 + B2, B1 + G1, G2 = a_i*G1 + G2, G1 + yield P_i, Q_i, a_i, A1, B1, G1 + + P_i = a_i*Q_i - P_i + Q_i = (D - P_i**2) // Q_i + + +def diop_bf_DN(D, N, t=symbols("t", integer=True)): + r""" + Uses brute force to solve the equation, `x^2 - Dy^2 = N`. + + Explanation + =========== + + Mainly concerned with the generalized Pell equation which is the case when + `D > 0, D` is not a perfect square. For more information on the case refer + [1]_. Let `(t, u)` be the minimal positive solution of the equation + `x^2 - Dy^2 = 1`. Then this method requires + `\sqrt{\\frac{\mid N \mid (t \pm 1)}{2D}}` to be small. + + Usage + ===== + + ``diop_bf_DN(D, N, t)``: ``D`` and ``N`` are coefficients in + `x^2 - Dy^2 = N` and ``t`` is the parameter to be used in the solutions. + + Details + ======= + + ``D`` and ``N`` correspond to D and N in the equation. + ``t`` is the parameter to be used in the solutions. + + Examples + ======== + + >>> from sympy.solvers.diophantine.diophantine import diop_bf_DN + >>> diop_bf_DN(13, -4) + [(3, 1), (-3, 1), (36, 10)] + >>> diop_bf_DN(986, 1) + [(49299, 1570)] + + See Also + ======== + + diop_DN() + + References + ========== + + .. [1] Solving the generalized Pell equation x**2 - D*y**2 = N, John P. + Robertson, July 31, 2004, Page 15. https://web.archive.org/web/20160323033128/http://www.jpr2718.org/pell.pdf + """ + D = as_int(D) + N = as_int(N) + + sol = [] + a = diop_DN(D, 1) + u = a[0][0] + + if N == 0: + if D < 0: + return [(0, 0)] + if D == 0: + return [(0, t)] + sD, _exact = integer_nthroot(D, 2) + if _exact: + return [(sD*t, t), (-sD*t, t)] + return [(0, 0)] + + if abs(N) == 1: + return diop_DN(D, N) + + if N > 1: + L1 = 0 + L2 = integer_nthroot(int(N*(u - 1)/(2*D)), 2)[0] + 1 + else: # N < -1 + L1, _exact = integer_nthroot(-int(N/D), 2) + if not _exact: + L1 += 1 + L2 = integer_nthroot(-int(N*(u + 1)/(2*D)), 2)[0] + 1 + + for y in range(L1, L2): + try: + x, _exact = integer_nthroot(N + D*y**2, 2) + except ValueError: + _exact = False + if _exact: + sol.append((x, y)) + if not equivalent(x, y, -x, y, D, N): + sol.append((-x, y)) + + return sol + + +def equivalent(u, v, r, s, D, N): + """ + Returns True if two solutions `(u, v)` and `(r, s)` of `x^2 - Dy^2 = N` + belongs to the same equivalence class and False otherwise. + + Explanation + =========== + + Two solutions `(u, v)` and `(r, s)` to the above equation fall to the same + equivalence class iff both `(ur - Dvs)` and `(us - vr)` are divisible by + `N`. See reference [1]_. No test is performed to test whether `(u, v)` and + `(r, s)` are actually solutions to the equation. User should take care of + this. + + Usage + ===== + + ``equivalent(u, v, r, s, D, N)``: `(u, v)` and `(r, s)` are two solutions + of the equation `x^2 - Dy^2 = N` and all parameters involved are integers. + + Examples + ======== + + >>> from sympy.solvers.diophantine.diophantine import equivalent + >>> equivalent(18, 5, -18, -5, 13, -1) + True + >>> equivalent(3, 1, -18, 393, 109, -4) + False + + References + ========== + + .. [1] Solving the generalized Pell equation x**2 - D*y**2 = N, John P. + Robertson, July 31, 2004, Page 12. https://web.archive.org/web/20160323033128/http://www.jpr2718.org/pell.pdf + + """ + return divisible(u*r - D*v*s, N) and divisible(u*s - v*r, N) + + +def length(P, Q, D): + r""" + Returns the (length of aperiodic part + length of periodic part) of + continued fraction representation of `\\frac{P + \sqrt{D}}{Q}`. + + It is important to remember that this does NOT return the length of the + periodic part but the sum of the lengths of the two parts as mentioned + above. + + Usage + ===== + + ``length(P, Q, D)``: ``P``, ``Q`` and ``D`` are integers corresponding to + the continued fraction `\\frac{P + \sqrt{D}}{Q}`. + + Details + ======= + + ``P``, ``D`` and ``Q`` corresponds to P, D and Q in the continued fraction, + `\\frac{P + \sqrt{D}}{Q}`. + + Examples + ======== + + >>> from sympy.solvers.diophantine.diophantine import length + >>> length(-2, 4, 5) # (-2 + sqrt(5))/4 + 3 + >>> length(-5, 4, 17) # (-5 + sqrt(17))/4 + 4 + + See Also + ======== + sympy.ntheory.continued_fraction.continued_fraction_periodic + """ + from sympy.ntheory.continued_fraction import continued_fraction_periodic + v = continued_fraction_periodic(P, Q, D) + if isinstance(v[-1], list): + rpt = len(v[-1]) + nonrpt = len(v) - 1 + else: + rpt = 0 + nonrpt = len(v) + return rpt + nonrpt + + +def transformation_to_DN(eq): + """ + This function transforms general quadratic, + `ax^2 + bxy + cy^2 + dx + ey + f = 0` + to more easy to deal with `X^2 - DY^2 = N` form. + + Explanation + =========== + + This is used to solve the general quadratic equation by transforming it to + the latter form. Refer to [1]_ for more detailed information on the + transformation. This function returns a tuple (A, B) where A is a 2 X 2 + matrix and B is a 2 X 1 matrix such that, + + Transpose([x y]) = A * Transpose([X Y]) + B + + Usage + ===== + + ``transformation_to_DN(eq)``: where ``eq`` is the quadratic to be + transformed. + + Examples + ======== + + >>> from sympy.abc import x, y + >>> from sympy.solvers.diophantine.diophantine import transformation_to_DN + >>> A, B = transformation_to_DN(x**2 - 3*x*y - y**2 - 2*y + 1) + >>> A + Matrix([ + [1/26, 3/26], + [ 0, 1/13]]) + >>> B + Matrix([ + [-6/13], + [-4/13]]) + + A, B returned are such that Transpose((x y)) = A * Transpose((X Y)) + B. + Substituting these values for `x` and `y` and a bit of simplifying work + will give an equation of the form `x^2 - Dy^2 = N`. + + >>> from sympy.abc import X, Y + >>> from sympy import Matrix, simplify + >>> u = (A*Matrix([X, Y]) + B)[0] # Transformation for x + >>> u + X/26 + 3*Y/26 - 6/13 + >>> v = (A*Matrix([X, Y]) + B)[1] # Transformation for y + >>> v + Y/13 - 4/13 + + Next we will substitute these formulas for `x` and `y` and do + ``simplify()``. + + >>> eq = simplify((x**2 - 3*x*y - y**2 - 2*y + 1).subs(zip((x, y), (u, v)))) + >>> eq + X**2/676 - Y**2/52 + 17/13 + + By multiplying the denominator appropriately, we can get a Pell equation + in the standard form. + + >>> eq * 676 + X**2 - 13*Y**2 + 884 + + If only the final equation is needed, ``find_DN()`` can be used. + + See Also + ======== + + find_DN() + + References + ========== + + .. [1] Solving the equation ax^2 + bxy + cy^2 + dx + ey + f = 0, + John P.Robertson, May 8, 2003, Page 7 - 11. + https://web.archive.org/web/20160323033111/http://www.jpr2718.org/ax2p.pdf + """ + + var, coeff, diop_type = classify_diop(eq, _dict=False) + if diop_type == BinaryQuadratic.name: + return _transformation_to_DN(var, coeff) + + +def _transformation_to_DN(var, coeff): + + x, y = var + + a = coeff[x**2] + b = coeff[x*y] + c = coeff[y**2] + d = coeff[x] + e = coeff[y] + f = coeff[1] + + a, b, c, d, e, f = [as_int(i) for i in _remove_gcd(a, b, c, d, e, f)] + + X, Y = symbols("X, Y", integer=True) + + if b: + B, C = _rational_pq(2*a, b) + A, T = _rational_pq(a, B**2) + + # eq_1 = A*B*X**2 + B*(c*T - A*C**2)*Y**2 + d*T*X + (B*e*T - d*T*C)*Y + f*T*B + coeff = {X**2: A*B, X*Y: 0, Y**2: B*(c*T - A*C**2), X: d*T, Y: B*e*T - d*T*C, 1: f*T*B} + A_0, B_0 = _transformation_to_DN([X, Y], coeff) + return Matrix(2, 2, [S.One/B, -S(C)/B, 0, 1])*A_0, Matrix(2, 2, [S.One/B, -S(C)/B, 0, 1])*B_0 + + if d: + B, C = _rational_pq(2*a, d) + A, T = _rational_pq(a, B**2) + + # eq_2 = A*X**2 + c*T*Y**2 + e*T*Y + f*T - A*C**2 + coeff = {X**2: A, X*Y: 0, Y**2: c*T, X: 0, Y: e*T, 1: f*T - A*C**2} + A_0, B_0 = _transformation_to_DN([X, Y], coeff) + return Matrix(2, 2, [S.One/B, 0, 0, 1])*A_0, Matrix(2, 2, [S.One/B, 0, 0, 1])*B_0 + Matrix([-S(C)/B, 0]) + + if e: + B, C = _rational_pq(2*c, e) + A, T = _rational_pq(c, B**2) + + # eq_3 = a*T*X**2 + A*Y**2 + f*T - A*C**2 + coeff = {X**2: a*T, X*Y: 0, Y**2: A, X: 0, Y: 0, 1: f*T - A*C**2} + A_0, B_0 = _transformation_to_DN([X, Y], coeff) + return Matrix(2, 2, [1, 0, 0, S.One/B])*A_0, Matrix(2, 2, [1, 0, 0, S.One/B])*B_0 + Matrix([0, -S(C)/B]) + + # TODO: pre-simplification: Not necessary but may simplify + # the equation. + return Matrix(2, 2, [S.One/a, 0, 0, 1]), Matrix([0, 0]) + + +def find_DN(eq): + """ + This function returns a tuple, `(D, N)` of the simplified form, + `x^2 - Dy^2 = N`, corresponding to the general quadratic, + `ax^2 + bxy + cy^2 + dx + ey + f = 0`. + + Solving the general quadratic is then equivalent to solving the equation + `X^2 - DY^2 = N` and transforming the solutions by using the transformation + matrices returned by ``transformation_to_DN()``. + + Usage + ===== + + ``find_DN(eq)``: where ``eq`` is the quadratic to be transformed. + + Examples + ======== + + >>> from sympy.abc import x, y + >>> from sympy.solvers.diophantine.diophantine import find_DN + >>> find_DN(x**2 - 3*x*y - y**2 - 2*y + 1) + (13, -884) + + Interpretation of the output is that we get `X^2 -13Y^2 = -884` after + transforming `x^2 - 3xy - y^2 - 2y + 1` using the transformation returned + by ``transformation_to_DN()``. + + See Also + ======== + + transformation_to_DN() + + References + ========== + + .. [1] Solving the equation ax^2 + bxy + cy^2 + dx + ey + f = 0, + John P.Robertson, May 8, 2003, Page 7 - 11. + https://web.archive.org/web/20160323033111/http://www.jpr2718.org/ax2p.pdf + """ + var, coeff, diop_type = classify_diop(eq, _dict=False) + if diop_type == BinaryQuadratic.name: + return _find_DN(var, coeff) + + +def _find_DN(var, coeff): + + x, y = var + X, Y = symbols("X, Y", integer=True) + A, B = _transformation_to_DN(var, coeff) + + u = (A*Matrix([X, Y]) + B)[0] + v = (A*Matrix([X, Y]) + B)[1] + eq = x**2*coeff[x**2] + x*y*coeff[x*y] + y**2*coeff[y**2] + x*coeff[x] + y*coeff[y] + coeff[1] + + simplified = _mexpand(eq.subs(zip((x, y), (u, v)))) + + coeff = simplified.as_coefficients_dict() + + return -coeff[Y**2]/coeff[X**2], -coeff[1]/coeff[X**2] + + +def check_param(x, y, a, params): + """ + If there is a number modulo ``a`` such that ``x`` and ``y`` are both + integers, then return a parametric representation for ``x`` and ``y`` + else return (None, None). + + Here ``x`` and ``y`` are functions of ``t``. + """ + from sympy.simplify.simplify import clear_coefficients + + if x.is_number and not x.is_Integer: + return DiophantineSolutionSet([x, y], parameters=params) + + if y.is_number and not y.is_Integer: + return DiophantineSolutionSet([x, y], parameters=params) + + m, n = symbols("m, n", integer=True) + c, p = (m*x + n*y).as_content_primitive() + if a % c.q: + return DiophantineSolutionSet([x, y], parameters=params) + + # clear_coefficients(mx + b, R)[1] -> (R - b)/m + eq = clear_coefficients(x, m)[1] - clear_coefficients(y, n)[1] + junk, eq = eq.as_content_primitive() + + return _diop_solve(eq, params=params) + + +def diop_ternary_quadratic(eq, parameterize=False): + """ + Solves the general quadratic ternary form, + `ax^2 + by^2 + cz^2 + fxy + gyz + hxz = 0`. + + Returns a tuple `(x, y, z)` which is a base solution for the above + equation. If there are no solutions, `(None, None, None)` is returned. + + Usage + ===== + + ``diop_ternary_quadratic(eq)``: Return a tuple containing a basic solution + to ``eq``. + + Details + ======= + + ``eq`` should be an homogeneous expression of degree two in three variables + and it is assumed to be zero. + + Examples + ======== + + >>> from sympy.abc import x, y, z + >>> from sympy.solvers.diophantine.diophantine import diop_ternary_quadratic + >>> diop_ternary_quadratic(x**2 + 3*y**2 - z**2) + (1, 0, 1) + >>> diop_ternary_quadratic(4*x**2 + 5*y**2 - z**2) + (1, 0, 2) + >>> diop_ternary_quadratic(45*x**2 - 7*y**2 - 8*x*y - z**2) + (28, 45, 105) + >>> diop_ternary_quadratic(x**2 - 49*y**2 - z**2 + 13*z*y -8*x*y) + (9, 1, 5) + """ + var, coeff, diop_type = classify_diop(eq, _dict=False) + + if diop_type in ( + HomogeneousTernaryQuadratic.name, + HomogeneousTernaryQuadraticNormal.name): + sol = _diop_ternary_quadratic(var, coeff) + if len(sol) > 0: + x_0, y_0, z_0 = list(sol)[0] + else: + x_0, y_0, z_0 = None, None, None + + if parameterize: + return _parametrize_ternary_quadratic( + (x_0, y_0, z_0), var, coeff) + return x_0, y_0, z_0 + + +def _diop_ternary_quadratic(_var, coeff): + eq = sum(i*coeff[i] for i in coeff) + if HomogeneousTernaryQuadratic(eq).matches(): + return HomogeneousTernaryQuadratic(eq, free_symbols=_var).solve() + elif HomogeneousTernaryQuadraticNormal(eq).matches(): + return HomogeneousTernaryQuadraticNormal(eq, free_symbols=_var).solve() + + +def transformation_to_normal(eq): + """ + Returns the transformation Matrix that converts a general ternary + quadratic equation ``eq`` (`ax^2 + by^2 + cz^2 + dxy + eyz + fxz`) + to a form without cross terms: `ax^2 + by^2 + cz^2 = 0`. This is + not used in solving ternary quadratics; it is only implemented for + the sake of completeness. + """ + var, coeff, diop_type = classify_diop(eq, _dict=False) + + if diop_type in ( + "homogeneous_ternary_quadratic", + "homogeneous_ternary_quadratic_normal"): + return _transformation_to_normal(var, coeff) + + +def _transformation_to_normal(var, coeff): + + _var = list(var) # copy + x, y, z = var + + if not any(coeff[i**2] for i in var): + # https://math.stackexchange.com/questions/448051/transform-quadratic-ternary-form-to-normal-form/448065#448065 + a = coeff[x*y] + b = coeff[y*z] + c = coeff[x*z] + swap = False + if not a: # b can't be 0 or else there aren't 3 vars + swap = True + a, b = b, a + T = Matrix(((1, 1, -b/a), (1, -1, -c/a), (0, 0, 1))) + if swap: + T.row_swap(0, 1) + T.col_swap(0, 1) + return T + + if coeff[x**2] == 0: + # If the coefficient of x is zero change the variables + if coeff[y**2] == 0: + _var[0], _var[2] = var[2], var[0] + T = _transformation_to_normal(_var, coeff) + T.row_swap(0, 2) + T.col_swap(0, 2) + return T + + _var[0], _var[1] = var[1], var[0] + T = _transformation_to_normal(_var, coeff) + T.row_swap(0, 1) + T.col_swap(0, 1) + return T + + # Apply the transformation x --> X - (B*Y + C*Z)/(2*A) + if coeff[x*y] != 0 or coeff[x*z] != 0: + A = coeff[x**2] + B = coeff[x*y] + C = coeff[x*z] + D = coeff[y**2] + E = coeff[y*z] + F = coeff[z**2] + + _coeff = {} + + _coeff[x**2] = 4*A**2 + _coeff[y**2] = 4*A*D - B**2 + _coeff[z**2] = 4*A*F - C**2 + _coeff[y*z] = 4*A*E - 2*B*C + _coeff[x*y] = 0 + _coeff[x*z] = 0 + + T_0 = _transformation_to_normal(_var, _coeff) + return Matrix(3, 3, [1, S(-B)/(2*A), S(-C)/(2*A), 0, 1, 0, 0, 0, 1])*T_0 + + elif coeff[y*z] != 0: + if coeff[y**2] == 0: + if coeff[z**2] == 0: + # Equations of the form A*x**2 + E*yz = 0. + # Apply transformation y -> Y + Z ans z -> Y - Z + return Matrix(3, 3, [1, 0, 0, 0, 1, 1, 0, 1, -1]) + + # Ax**2 + E*y*z + F*z**2 = 0 + _var[0], _var[2] = var[2], var[0] + T = _transformation_to_normal(_var, coeff) + T.row_swap(0, 2) + T.col_swap(0, 2) + return T + + # A*x**2 + D*y**2 + E*y*z + F*z**2 = 0, F may be zero + _var[0], _var[1] = var[1], var[0] + T = _transformation_to_normal(_var, coeff) + T.row_swap(0, 1) + T.col_swap(0, 1) + return T + + return Matrix.eye(3) + + +def parametrize_ternary_quadratic(eq): + """ + Returns the parametrized general solution for the ternary quadratic + equation ``eq`` which has the form + `ax^2 + by^2 + cz^2 + fxy + gyz + hxz = 0`. + + Examples + ======== + + >>> from sympy import Tuple, ordered + >>> from sympy.abc import x, y, z + >>> from sympy.solvers.diophantine.diophantine import parametrize_ternary_quadratic + + The parametrized solution may be returned with three parameters: + + >>> parametrize_ternary_quadratic(2*x**2 + y**2 - 2*z**2) + (p**2 - 2*q**2, -2*p**2 + 4*p*q - 4*p*r - 4*q**2, p**2 - 4*p*q + 2*q**2 - 4*q*r) + + There might also be only two parameters: + + >>> parametrize_ternary_quadratic(4*x**2 + 2*y**2 - 3*z**2) + (2*p**2 - 3*q**2, -4*p**2 + 12*p*q - 6*q**2, 4*p**2 - 8*p*q + 6*q**2) + + Notes + ===== + + Consider ``p`` and ``q`` in the previous 2-parameter + solution and observe that more than one solution can be represented + by a given pair of parameters. If `p` and ``q`` are not coprime, this is + trivially true since the common factor will also be a common factor of the + solution values. But it may also be true even when ``p`` and + ``q`` are coprime: + + >>> sol = Tuple(*_) + >>> p, q = ordered(sol.free_symbols) + >>> sol.subs([(p, 3), (q, 2)]) + (6, 12, 12) + >>> sol.subs([(q, 1), (p, 1)]) + (-1, 2, 2) + >>> sol.subs([(q, 0), (p, 1)]) + (2, -4, 4) + >>> sol.subs([(q, 1), (p, 0)]) + (-3, -6, 6) + + Except for sign and a common factor, these are equivalent to + the solution of (1, 2, 2). + + References + ========== + + .. [1] The algorithmic resolution of Diophantine equations, Nigel P. Smart, + London Mathematical Society Student Texts 41, Cambridge University + Press, Cambridge, 1998. + + """ + var, coeff, diop_type = classify_diop(eq, _dict=False) + + if diop_type in ( + "homogeneous_ternary_quadratic", + "homogeneous_ternary_quadratic_normal"): + x_0, y_0, z_0 = list(_diop_ternary_quadratic(var, coeff))[0] + return _parametrize_ternary_quadratic( + (x_0, y_0, z_0), var, coeff) + + +def _parametrize_ternary_quadratic(solution, _var, coeff): + # called for a*x**2 + b*y**2 + c*z**2 + d*x*y + e*y*z + f*x*z = 0 + assert 1 not in coeff + + x_0, y_0, z_0 = solution + + v = list(_var) # copy + + if x_0 is None: + return (None, None, None) + + if solution.count(0) >= 2: + # if there are 2 zeros the equation reduces + # to k*X**2 == 0 where X is x, y, or z so X must + # be zero, too. So there is only the trivial + # solution. + return (None, None, None) + + if x_0 == 0: + v[0], v[1] = v[1], v[0] + y_p, x_p, z_p = _parametrize_ternary_quadratic( + (y_0, x_0, z_0), v, coeff) + return x_p, y_p, z_p + + x, y, z = v + r, p, q = symbols("r, p, q", integer=True) + + eq = sum(k*v for k, v in coeff.items()) + eq_1 = _mexpand(eq.subs(zip( + (x, y, z), (r*x_0, r*y_0 + p, r*z_0 + q)))) + A, B = eq_1.as_independent(r, as_Add=True) + + + x = A*x_0 + y = (A*y_0 - _mexpand(B/r*p)) + z = (A*z_0 - _mexpand(B/r*q)) + + return _remove_gcd(x, y, z) + + +def diop_ternary_quadratic_normal(eq, parameterize=False): + """ + Solves the quadratic ternary diophantine equation, + `ax^2 + by^2 + cz^2 = 0`. + + Explanation + =========== + + Here the coefficients `a`, `b`, and `c` should be non zero. Otherwise the + equation will be a quadratic binary or univariate equation. If solvable, + returns a tuple `(x, y, z)` that satisfies the given equation. If the + equation does not have integer solutions, `(None, None, None)` is returned. + + Usage + ===== + + ``diop_ternary_quadratic_normal(eq)``: where ``eq`` is an equation of the form + `ax^2 + by^2 + cz^2 = 0`. + + Examples + ======== + + >>> from sympy.abc import x, y, z + >>> from sympy.solvers.diophantine.diophantine import diop_ternary_quadratic_normal + >>> diop_ternary_quadratic_normal(x**2 + 3*y**2 - z**2) + (1, 0, 1) + >>> diop_ternary_quadratic_normal(4*x**2 + 5*y**2 - z**2) + (1, 0, 2) + >>> diop_ternary_quadratic_normal(34*x**2 - 3*y**2 - 301*z**2) + (4, 9, 1) + """ + var, coeff, diop_type = classify_diop(eq, _dict=False) + if diop_type == HomogeneousTernaryQuadraticNormal.name: + sol = _diop_ternary_quadratic_normal(var, coeff) + if len(sol) > 0: + x_0, y_0, z_0 = list(sol)[0] + else: + x_0, y_0, z_0 = None, None, None + if parameterize: + return _parametrize_ternary_quadratic( + (x_0, y_0, z_0), var, coeff) + return x_0, y_0, z_0 + + +def _diop_ternary_quadratic_normal(var, coeff): + eq = sum(i * coeff[i] for i in coeff) + return HomogeneousTernaryQuadraticNormal(eq, free_symbols=var).solve() + + +def sqf_normal(a, b, c, steps=False): + """ + Return `a', b', c'`, the coefficients of the square-free normal + form of `ax^2 + by^2 + cz^2 = 0`, where `a', b', c'` are pairwise + prime. If `steps` is True then also return three tuples: + `sq`, `sqf`, and `(a', b', c')` where `sq` contains the square + factors of `a`, `b` and `c` after removing the `gcd(a, b, c)`; + `sqf` contains the values of `a`, `b` and `c` after removing + both the `gcd(a, b, c)` and the square factors. + + The solutions for `ax^2 + by^2 + cz^2 = 0` can be + recovered from the solutions of `a'x^2 + b'y^2 + c'z^2 = 0`. + + Examples + ======== + + >>> from sympy.solvers.diophantine.diophantine import sqf_normal + >>> sqf_normal(2 * 3**2 * 5, 2 * 5 * 11, 2 * 7**2 * 11) + (11, 1, 5) + >>> sqf_normal(2 * 3**2 * 5, 2 * 5 * 11, 2 * 7**2 * 11, True) + ((3, 1, 7), (5, 55, 11), (11, 1, 5)) + + References + ========== + + .. [1] Legendre's Theorem, Legrange's Descent, + https://public.csusm.edu/aitken_html/notes/legendre.pdf + + + See Also + ======== + + reconstruct() + """ + ABC = _remove_gcd(a, b, c) + sq = tuple(square_factor(i) for i in ABC) + sqf = A, B, C = tuple([i//j**2 for i,j in zip(ABC, sq)]) + pc = igcd(A, B) + A /= pc + B /= pc + pa = igcd(B, C) + B /= pa + C /= pa + pb = igcd(A, C) + A /= pb + B /= pb + + A *= pa + B *= pb + C *= pc + + if steps: + return (sq, sqf, (A, B, C)) + else: + return A, B, C + + +def square_factor(a): + r""" + Returns an integer `c` s.t. `a = c^2k, \ c,k \in Z`. Here `k` is square + free. `a` can be given as an integer or a dictionary of factors. + + Examples + ======== + + >>> from sympy.solvers.diophantine.diophantine import square_factor + >>> square_factor(24) + 2 + >>> square_factor(-36*3) + 6 + >>> square_factor(1) + 1 + >>> square_factor({3: 2, 2: 1, -1: 1}) # -18 + 3 + + See Also + ======== + sympy.ntheory.factor_.core + """ + f = a if isinstance(a, dict) else factorint(a) + return Mul(*[p**(e//2) for p, e in f.items()]) + + +def reconstruct(A, B, z): + """ + Reconstruct the `z` value of an equivalent solution of `ax^2 + by^2 + cz^2` + from the `z` value of a solution of the square-free normal form of the + equation, `a'*x^2 + b'*y^2 + c'*z^2`, where `a'`, `b'` and `c'` are square + free and `gcd(a', b', c') == 1`. + """ + f = factorint(igcd(A, B)) + for p, e in f.items(): + if e != 1: + raise ValueError('a and b should be square-free') + z *= p + return z + + +def ldescent(A, B): + """ + Return a non-trivial solution to `w^2 = Ax^2 + By^2` using + Lagrange's method; return None if there is no such solution. + + Parameters + ========== + + A : Integer + B : Integer + non-zero integer + + Returns + ======= + + (int, int, int) | None : a tuple `(w_0, x_0, y_0)` which is a solution to the above equation. + + Examples + ======== + + >>> from sympy.solvers.diophantine.diophantine import ldescent + >>> ldescent(1, 1) # w^2 = x^2 + y^2 + (1, 1, 0) + >>> ldescent(4, -7) # w^2 = 4x^2 - 7y^2 + (2, -1, 0) + + This means that `x = -1, y = 0` and `w = 2` is a solution to the equation + `w^2 = 4x^2 - 7y^2` + + >>> ldescent(5, -1) # w^2 = 5x^2 - y^2 + (2, 1, -1) + + References + ========== + + .. [1] The algorithmic resolution of Diophantine equations, Nigel P. Smart, + London Mathematical Society Student Texts 41, Cambridge University + Press, Cambridge, 1998. + .. [2] Cremona, J. E., Rusin, D. (2003). Efficient Solution of Rational Conics. + Mathematics of Computation, 72(243), 1417-1441. + https://doi.org/10.1090/S0025-5718-02-01480-1 + """ + if A == 0 or B == 0: + raise ValueError("A and B must be non-zero integers") + if abs(A) > abs(B): + w, y, x = ldescent(B, A) + return w, x, y + if A == 1: + return (1, 1, 0) + if B == 1: + return (1, 0, 1) + if B == -1: # and A == -1 + return + + r = sqrt_mod(A, B) + if r is None: + return + Q = (r**2 - A) // B + if Q == 0: + return r, -1, 0 + for i in divisors(Q): + d, _exact = integer_nthroot(abs(Q) // i, 2) + if _exact: + B_0 = sign(Q)*i + W, X, Y = ldescent(A, B_0) + return _remove_gcd(-A*X + r*W, r*X - W, Y*B_0*d) + + +def descent(A, B): + """ + Returns a non-trivial solution, (x, y, z), to `x^2 = Ay^2 + Bz^2` + using Lagrange's descent method with lattice-reduction. `A` and `B` + are assumed to be valid for such a solution to exist. + + This is faster than the normal Lagrange's descent algorithm because + the Gaussian reduction is used. + + Examples + ======== + + >>> from sympy.solvers.diophantine.diophantine import descent + >>> descent(3, 1) # x**2 = 3*y**2 + z**2 + (1, 0, 1) + + `(x, y, z) = (1, 0, 1)` is a solution to the above equation. + + >>> descent(41, -113) + (-16, -3, 1) + + References + ========== + + .. [1] Cremona, J. E., Rusin, D. (2003). Efficient Solution of Rational Conics. + Mathematics of Computation, 72(243), 1417-1441. + https://doi.org/10.1090/S0025-5718-02-01480-1 + """ + if abs(A) > abs(B): + x, y, z = descent(B, A) + return x, z, y + + if B == 1: + return (1, 0, 1) + if A == 1: + return (1, 1, 0) + if B == -A: + return (0, 1, 1) + if B == A: + x, z, y = descent(-1, A) + return (A*y, z, x) + + w = sqrt_mod(A, B) + x_0, z_0 = gaussian_reduce(w, A, B) + + t = (x_0**2 - A*z_0**2) // B + t_2 = square_factor(t) + t_1 = t // t_2**2 + + x_1, z_1, y_1 = descent(A, t_1) + + return _remove_gcd(x_0*x_1 + A*z_0*z_1, z_0*x_1 + x_0*z_1, t_1*t_2*y_1) + + +def gaussian_reduce(w:int, a:int, b:int) -> tuple[int, int]: + r""" + Returns a reduced solution `(x, z)` to the congruence + `X^2 - aZ^2 \equiv 0 \pmod{b}` so that `x^2 + |a|z^2` is as small as possible. + Here ``w`` is a solution of the congruence `x^2 \equiv a \pmod{b}`. + + This function is intended to be used only for ``descent()``. + + Explanation + =========== + + The Gaussian reduction can find the shortest vector for any norm. + So we define the special norm for the vectors `u = (u_1, u_2)` and `v = (v_1, v_2)` as follows. + + .. math :: + u \cdot v := (wu_1 + bu_2)(wv_1 + bv_2) + |a|u_1v_1 + + Note that, given the mapping `f: (u_1, u_2) \to (wu_1 + bu_2, u_1)`, + `f((u_1,u_2))` is the solution to `X^2 - aZ^2 \equiv 0 \pmod{b}`. + In other words, finding the shortest vector in this norm will yield a solution with smaller `X^2 + |a|Z^2`. + The algorithm starts from basis vectors `(0, 1)` and `(1, 0)` + (corresponding to solutions `(b, 0)` and `(w, 1)`, respectively) and finds the shortest vector. + The shortest vector does not necessarily correspond to the smallest solution, + but since ``descent()`` only wants the smallest possible solution, it is sufficient. + + Parameters + ========== + + w : int + ``w`` s.t. `w^2 \equiv a \pmod{b}` + a : int + square-free nonzero integer + b : int + square-free nonzero integer + + Examples + ======== + + >>> from sympy.solvers.diophantine.diophantine import gaussian_reduce + >>> from sympy.ntheory.residue_ntheory import sqrt_mod + >>> a, b = 19, 101 + >>> gaussian_reduce(sqrt_mod(a, b), a, b) # 1**2 - 19*(-4)**2 = -303 + (1, -4) + >>> a, b = 11, 14 + >>> x, z = gaussian_reduce(sqrt_mod(a, b), a, b) + >>> (x**2 - a*z**2) % b == 0 + True + + It does not always return the smallest solution. + + >>> a, b = 6, 95 + >>> min_x, min_z = 1, 4 + >>> x, z = gaussian_reduce(sqrt_mod(a, b), a, b) + >>> (x**2 - a*z**2) % b == 0 and (min_x**2 - a*min_z**2) % b == 0 + True + >>> min_x**2 + abs(a)*min_z**2 < x**2 + abs(a)*z**2 + True + + References + ========== + + .. [1] Gaussian lattice Reduction [online]. Available: + https://web.archive.org/web/20201021115213/http://home.ie.cuhk.edu.hk/~wkshum/wordpress/?p=404 + .. [2] Cremona, J. E., Rusin, D. (2003). Efficient Solution of Rational Conics. + Mathematics of Computation, 72(243), 1417-1441. + https://doi.org/10.1090/S0025-5718-02-01480-1 + """ + a = abs(a) + def _dot(u, v): + return u[0]*v[0] + a*u[1]*v[1] + + u = (b, 0) + v = (w, 1) if b*w >= 0 else (-w, -1) + # i.e., _dot(u, v) >= 0 + + if b**2 < w**2 + a: + u, v = v, u + # i.e., norm(u) >= norm(v), where norm(u) := sqrt(_dot(u, u)) + + while _dot(u, u) > (dv := _dot(v, v)): + k = _dot(u, v) // dv + u, v = v, (u[0] - k*v[0], u[1] - k*v[1]) + c = (v[0] - u[0], v[1] - u[1]) + if _dot(c, c) <= _dot(u, u) <= 2*_dot(u, v): + return c + return u + + +def holzer(x, y, z, a, b, c): + r""" + Simplify the solution `(x, y, z)` of the equation + `ax^2 + by^2 = cz^2` with `a, b, c > 0` and `z^2 \geq \mid ab \mid` to + a new reduced solution `(x', y', z')` such that `z'^2 \leq \mid ab \mid`. + + The algorithm is an interpretation of Mordell's reduction as described + on page 8 of Cremona and Rusin's paper [1]_ and the work of Mordell in + reference [2]_. + + References + ========== + + .. [1] Cremona, J. E., Rusin, D. (2003). Efficient Solution of Rational Conics. + Mathematics of Computation, 72(243), 1417-1441. + https://doi.org/10.1090/S0025-5718-02-01480-1 + .. [2] Diophantine Equations, L. J. Mordell, page 48. + + """ + + if _odd(c): + k = 2*c + else: + k = c//2 + + small = a*b*c + step = 0 + while True: + t1, t2, t3 = a*x**2, b*y**2, c*z**2 + # check that it's a solution + if t1 + t2 != t3: + if step == 0: + raise ValueError('bad starting solution') + break + x_0, y_0, z_0 = x, y, z + if max(t1, t2, t3) <= small: + # Holzer condition + break + + uv = u, v = base_solution_linear(k, y_0, -x_0) + if None in uv: + break + + p, q = -(a*u*x_0 + b*v*y_0), c*z_0 + r = Rational(p, q) + if _even(c): + w = _nint_or_floor(p, q) + assert abs(w - r) <= S.Half + else: + w = p//q # floor + if _odd(a*u + b*v + c*w): + w += 1 + assert abs(w - r) <= S.One + + A = (a*u**2 + b*v**2 + c*w**2) + B = (a*u*x_0 + b*v*y_0 + c*w*z_0) + x = Rational(x_0*A - 2*u*B, k) + y = Rational(y_0*A - 2*v*B, k) + z = Rational(z_0*A - 2*w*B, k) + assert all(i.is_Integer for i in (x, y, z)) + step += 1 + + return tuple([int(i) for i in (x_0, y_0, z_0)]) + + +def diop_general_pythagorean(eq, param=symbols("m", integer=True)): + """ + Solves the general pythagorean equation, + `a_{1}^2x_{1}^2 + a_{2}^2x_{2}^2 + . . . + a_{n}^2x_{n}^2 - a_{n + 1}^2x_{n + 1}^2 = 0`. + + Returns a tuple which contains a parametrized solution to the equation, + sorted in the same order as the input variables. + + Usage + ===== + + ``diop_general_pythagorean(eq, param)``: where ``eq`` is a general + pythagorean equation which is assumed to be zero and ``param`` is the base + parameter used to construct other parameters by subscripting. + + Examples + ======== + + >>> from sympy.solvers.diophantine.diophantine import diop_general_pythagorean + >>> from sympy.abc import a, b, c, d, e + >>> diop_general_pythagorean(a**2 + b**2 + c**2 - d**2) + (m1**2 + m2**2 - m3**2, 2*m1*m3, 2*m2*m3, m1**2 + m2**2 + m3**2) + >>> diop_general_pythagorean(9*a**2 - 4*b**2 + 16*c**2 + 25*d**2 + e**2) + (10*m1**2 + 10*m2**2 + 10*m3**2 - 10*m4**2, 15*m1**2 + 15*m2**2 + 15*m3**2 + 15*m4**2, 15*m1*m4, 12*m2*m4, 60*m3*m4) + """ + var, coeff, diop_type = classify_diop(eq, _dict=False) + + if diop_type == GeneralPythagorean.name: + if param is None: + params = None + else: + params = symbols('%s1:%i' % (param, len(var)), integer=True) + return list(GeneralPythagorean(eq).solve(parameters=params))[0] + + +def diop_general_sum_of_squares(eq, limit=1): + r""" + Solves the equation `x_{1}^2 + x_{2}^2 + . . . + x_{n}^2 - k = 0`. + + Returns at most ``limit`` number of solutions. + + Usage + ===== + + ``general_sum_of_squares(eq, limit)`` : Here ``eq`` is an expression which + is assumed to be zero. Also, ``eq`` should be in the form, + `x_{1}^2 + x_{2}^2 + . . . + x_{n}^2 - k = 0`. + + Details + ======= + + When `n = 3` if `k = 4^a(8m + 7)` for some `a, m \in Z` then there will be + no solutions. Refer to [1]_ for more details. + + Examples + ======== + + >>> from sympy.solvers.diophantine.diophantine import diop_general_sum_of_squares + >>> from sympy.abc import a, b, c, d, e + >>> diop_general_sum_of_squares(a**2 + b**2 + c**2 + d**2 + e**2 - 2345) + {(15, 22, 22, 24, 24)} + + Reference + ========= + + .. [1] Representing an integer as a sum of three squares, [online], + Available: + https://www.proofwiki.org/wiki/Integer_as_Sum_of_Three_Squares + """ + var, coeff, diop_type = classify_diop(eq, _dict=False) + + if diop_type == GeneralSumOfSquares.name: + return set(GeneralSumOfSquares(eq).solve(limit=limit)) + + +def diop_general_sum_of_even_powers(eq, limit=1): + """ + Solves the equation `x_{1}^e + x_{2}^e + . . . + x_{n}^e - k = 0` + where `e` is an even, integer power. + + Returns at most ``limit`` number of solutions. + + Usage + ===== + + ``general_sum_of_even_powers(eq, limit)`` : Here ``eq`` is an expression which + is assumed to be zero. Also, ``eq`` should be in the form, + `x_{1}^e + x_{2}^e + . . . + x_{n}^e - k = 0`. + + Examples + ======== + + >>> from sympy.solvers.diophantine.diophantine import diop_general_sum_of_even_powers + >>> from sympy.abc import a, b + >>> diop_general_sum_of_even_powers(a**4 + b**4 - (2**4 + 3**4)) + {(2, 3)} + + See Also + ======== + + power_representation + """ + var, coeff, diop_type = classify_diop(eq, _dict=False) + + if diop_type == GeneralSumOfEvenPowers.name: + return set(GeneralSumOfEvenPowers(eq).solve(limit=limit)) + + +## Functions below this comment can be more suitably grouped under +## an Additive number theory module rather than the Diophantine +## equation module. + + +def partition(n, k=None, zeros=False): + """ + Returns a generator that can be used to generate partitions of an integer + `n`. + + Explanation + =========== + + A partition of `n` is a set of positive integers which add up to `n`. For + example, partitions of 3 are 3, 1 + 2, 1 + 1 + 1. A partition is returned + as a tuple. If ``k`` equals None, then all possible partitions are returned + irrespective of their size, otherwise only the partitions of size ``k`` are + returned. If the ``zero`` parameter is set to True then a suitable + number of zeros are added at the end of every partition of size less than + ``k``. + + ``zero`` parameter is considered only if ``k`` is not None. When the + partitions are over, the last `next()` call throws the ``StopIteration`` + exception, so this function should always be used inside a try - except + block. + + Details + ======= + + ``partition(n, k)``: Here ``n`` is a positive integer and ``k`` is the size + of the partition which is also positive integer. + + Examples + ======== + + >>> from sympy.solvers.diophantine.diophantine import partition + >>> f = partition(5) + >>> next(f) + (1, 1, 1, 1, 1) + >>> next(f) + (1, 1, 1, 2) + >>> g = partition(5, 3) + >>> next(g) + (1, 1, 3) + >>> next(g) + (1, 2, 2) + >>> g = partition(5, 3, zeros=True) + >>> next(g) + (0, 0, 5) + + """ + if not zeros or k is None: + for i in ordered_partitions(n, k): + yield tuple(i) + else: + for m in range(1, k + 1): + for i in ordered_partitions(n, m): + i = tuple(i) + yield (0,)*(k - len(i)) + i + + +def prime_as_sum_of_two_squares(p): + """ + Represent a prime `p` as a unique sum of two squares; this can + only be done if the prime is congruent to 1 mod 4. + + Parameters + ========== + + p : Integer + A prime that is congruent to 1 mod 4 + + Returns + ======= + + (int, int) | None : Pair of positive integers ``(x, y)`` satisfying ``x**2 + y**2 = p``. + None if ``p`` is not congruent to 1 mod 4. + + Raises + ====== + + ValueError + If ``p`` is not prime number + + Examples + ======== + + >>> from sympy.solvers.diophantine.diophantine import prime_as_sum_of_two_squares + >>> prime_as_sum_of_two_squares(7) # can't be done + >>> prime_as_sum_of_two_squares(5) + (1, 2) + + Reference + ========= + + .. [1] Representing a number as a sum of four squares, [online], + Available: https://schorn.ch/lagrange.html + + See Also + ======== + + sum_of_squares + + """ + p = as_int(p) + if p % 4 != 1: + return + if not isprime(p): + raise ValueError("p should be a prime number") + + if p % 8 == 5: + # Legendre symbol (2/p) == -1 if p % 8 in [3, 5] + b = 2 + elif p % 12 == 5: + # Legendre symbol (3/p) == -1 if p % 12 in [5, 7] + b = 3 + elif p % 5 in [2, 3]: + # Legendre symbol (5/p) == -1 if p % 5 in [2, 3] + b = 5 + else: + b = 7 + while jacobi(b, p) == 1: + b = nextprime(b) + + b = pow(b, p >> 2, p) + a = p + while b**2 > p: + a, b = b, a % b + return (int(a % b), int(b)) # convert from long + + +def sum_of_three_squares(n): + r""" + Returns a 3-tuple $(a, b, c)$ such that $a^2 + b^2 + c^2 = n$ and + $a, b, c \geq 0$. + + Returns None if $n = 4^a(8m + 7)$ for some `a, m \in \mathbb{Z}`. See + [1]_ for more details. + + Parameters + ========== + + n : Integer + non-negative integer + + Returns + ======= + + (int, int, int) | None : 3-tuple non-negative integers ``(a, b, c)`` satisfying ``a**2 + b**2 + c**2 = n``. + a,b,c are sorted in ascending order. ``None`` if no such ``(a,b,c)``. + + Raises + ====== + + ValueError + If ``n`` is a negative integer + + Examples + ======== + + >>> from sympy.solvers.diophantine.diophantine import sum_of_three_squares + >>> sum_of_three_squares(44542) + (18, 37, 207) + + References + ========== + + .. [1] Representing a number as a sum of three squares, [online], + Available: https://schorn.ch/lagrange.html + + See Also + ======== + + power_representation : + ``sum_of_three_squares(n)`` is one of the solutions output by ``power_representation(n, 2, 3, zeros=True)`` + + """ + # https://math.stackexchange.com/questions/483101/rabin-and-shallit-algorithm/651425#651425 + # discusses these numbers (except for 1, 2, 3) as the exceptions of H&L's conjecture that + # Every sufficiently large number n is either a square or the sum of a prime and a square. + special = {1: (0, 0, 1), 2: (0, 1, 1), 3: (1, 1, 1), 10: (0, 1, 3), 34: (3, 3, 4), + 58: (0, 3, 7), 85: (0, 6, 7), 130: (0, 3, 11), 214: (3, 6, 13), 226: (8, 9, 9), + 370: (8, 9, 15), 526: (6, 7, 21), 706: (15, 15, 16), 730: (0, 1, 27), + 1414: (6, 17, 33), 1906: (13, 21, 36), 2986: (21, 32, 39), 9634: (56, 57, 57)} + n = as_int(n) + if n < 0: + raise ValueError("n should be a non-negative integer") + if n == 0: + return (0, 0, 0) + n, v = remove(n, 4) + v = 1 << v + if n % 8 == 7: + return + if n in special: + return tuple([v*i for i in special[n]]) + + s, _exact = integer_nthroot(n, 2) + if _exact: + return (0, 0, v*s) + if n % 8 == 3: + if not s % 2: + s -= 1 + for x in range(s, -1, -2): + N = (n - x**2) // 2 + if isprime(N): + # n % 8 == 3 and x % 2 == 1 => N % 4 == 1 + y, z = prime_as_sum_of_two_squares(N) + return tuple(sorted([v*x, v*(y + z), v*abs(y - z)])) + # We will never reach this point because there must be a solution. + assert False + + # assert n % 4 in [1, 2] + if not((n % 2) ^ (s % 2)): + s -= 1 + for x in range(s, -1, -2): + N = n - x**2 + if isprime(N): + # assert N % 4 == 1 + y, z = prime_as_sum_of_two_squares(N) + return tuple(sorted([v*x, v*y, v*z])) + # We will never reach this point because there must be a solution. + assert False + + +def sum_of_four_squares(n): + r""" + Returns a 4-tuple `(a, b, c, d)` such that `a^2 + b^2 + c^2 + d^2 = n`. + Here `a, b, c, d \geq 0`. + + Parameters + ========== + + n : Integer + non-negative integer + + Returns + ======= + + (int, int, int, int) : 4-tuple non-negative integers ``(a, b, c, d)`` satisfying ``a**2 + b**2 + c**2 + d**2 = n``. + a,b,c,d are sorted in ascending order. + + Raises + ====== + + ValueError + If ``n`` is a negative integer + + Examples + ======== + + >>> from sympy.solvers.diophantine.diophantine import sum_of_four_squares + >>> sum_of_four_squares(3456) + (8, 8, 32, 48) + >>> sum_of_four_squares(1294585930293) + (0, 1234, 2161, 1137796) + + References + ========== + + .. [1] Representing a number as a sum of four squares, [online], + Available: https://schorn.ch/lagrange.html + + See Also + ======== + + power_representation : + ``sum_of_four_squares(n)`` is one of the solutions output by ``power_representation(n, 2, 4, zeros=True)`` + + """ + n = as_int(n) + if n < 0: + raise ValueError("n should be a non-negative integer") + if n == 0: + return (0, 0, 0, 0) + # remove factors of 4 since a solution in terms of 3 squares is + # going to be returned; this is also done in sum_of_three_squares, + # but it needs to be done here to select d + n, v = remove(n, 4) + v = 1 << v + if n % 8 == 7: + d = 2 + n = n - 4 + elif n % 8 in (2, 6): + d = 1 + n = n - 1 + else: + d = 0 + x, y, z = sum_of_three_squares(n) # sorted + return tuple(sorted([v*d, v*x, v*y, v*z])) + + +def power_representation(n, p, k, zeros=False): + r""" + Returns a generator for finding k-tuples of integers, + `(n_{1}, n_{2}, . . . n_{k})`, such that + `n = n_{1}^p + n_{2}^p + . . . n_{k}^p`. + + Usage + ===== + + ``power_representation(n, p, k, zeros)``: Represent non-negative number + ``n`` as a sum of ``k`` ``p``\ th powers. If ``zeros`` is true, then the + solutions is allowed to contain zeros. + + Examples + ======== + + >>> from sympy.solvers.diophantine.diophantine import power_representation + + Represent 1729 as a sum of two cubes: + + >>> f = power_representation(1729, 3, 2) + >>> next(f) + (9, 10) + >>> next(f) + (1, 12) + + If the flag `zeros` is True, the solution may contain tuples with + zeros; any such solutions will be generated after the solutions + without zeros: + + >>> list(power_representation(125, 2, 3, zeros=True)) + [(5, 6, 8), (3, 4, 10), (0, 5, 10), (0, 2, 11)] + + For even `p` the `permute_sign` function can be used to get all + signed values: + + >>> from sympy.utilities.iterables import permute_signs + >>> list(permute_signs((1, 12))) + [(1, 12), (-1, 12), (1, -12), (-1, -12)] + + All possible signed permutations can also be obtained: + + >>> from sympy.utilities.iterables import signed_permutations + >>> list(signed_permutations((1, 12))) + [(1, 12), (-1, 12), (1, -12), (-1, -12), (12, 1), (-12, 1), (12, -1), (-12, -1)] + """ + n, p, k = [as_int(i) for i in (n, p, k)] + + if n < 0: + if p % 2: + for t in power_representation(-n, p, k, zeros): + yield tuple(-i for i in t) + return + + if p < 1 or k < 1: + raise ValueError(filldedent(''' + Expecting positive integers for `(p, k)`, but got `(%s, %s)`''' + % (p, k))) + + if n == 0: + if zeros: + yield (0,)*k + return + + if k == 1: + if p == 1: + yield (n,) + elif n == 1: + yield (1,) + else: + be = perfect_power(n) + if be: + b, e = be + d, r = divmod(e, p) + if not r: + yield (b**d,) + return + + if p == 1: + yield from partition(n, k, zeros=zeros) + return + + if p == 2: + if k == 3: + n, v = remove(n, 4) + if v: + v = 1 << v + for t in power_representation(n, p, k, zeros): + yield tuple(i*v for i in t) + return + feasible = _can_do_sum_of_squares(n, k) + if not feasible: + return + if not zeros: + if n > 33 and k >= 5 and k <= n and n - k in ( + 13, 10, 7, 5, 4, 2, 1): + '''Todd G. Will, "When Is n^2 a Sum of k Squares?", [online]. + Available: https://www.maa.org/sites/default/files/Will-MMz-201037918.pdf''' + return + # quick tests since feasibility includes the possiblity of 0 + if k == 4 and (n in (1, 3, 5, 9, 11, 17, 29, 41) or remove(n, 4)[0] in (2, 6, 14)): + # A000534 + return + if k == 3 and n in (1, 2, 5, 10, 13, 25, 37, 58, 85, 130): # or n = some number >= 5*10**10 + # A051952 + return + if feasible is not True: # it's prime and k == 2 + yield prime_as_sum_of_two_squares(n) + return + + if k == 2 and p > 2: + be = perfect_power(n) + if be and be[1] % p == 0: + return # Fermat: a**n + b**n = c**n has no solution for n > 2 + + if n >= k: + a = integer_nthroot(n - (k - 1), p)[0] + for t in pow_rep_recursive(a, k, n, [], p): + yield tuple(reversed(t)) + + if zeros: + a = integer_nthroot(n, p)[0] + for i in range(1, k): + for t in pow_rep_recursive(a, i, n, [], p): + yield tuple(reversed(t + (0,)*(k - i))) + + +sum_of_powers = power_representation + + +def pow_rep_recursive(n_i, k, n_remaining, terms, p): + # Invalid arguments + if n_i <= 0 or k <= 0: + return + + # No solutions may exist + if n_remaining < k: + return + if k * pow(n_i, p) < n_remaining: + return + + if k == 0 and n_remaining == 0: + yield tuple(terms) + + elif k == 1: + # next_term^p must equal to n_remaining + next_term, exact = integer_nthroot(n_remaining, p) + if exact and next_term <= n_i: + yield tuple(terms + [next_term]) + return + + else: + # TODO: Fall back to diop_DN when k = 2 + if n_i >= 1 and k > 0: + for next_term in range(1, n_i + 1): + residual = n_remaining - pow(next_term, p) + if residual < 0: + break + yield from pow_rep_recursive(next_term, k - 1, residual, terms + [next_term], p) + + +def sum_of_squares(n, k, zeros=False): + """Return a generator that yields the k-tuples of nonnegative + values, the squares of which sum to n. If zeros is False (default) + then the solution will not contain zeros. The nonnegative + elements of a tuple are sorted. + + * If k == 1 and n is square, (n,) is returned. + + * If k == 2 then n can only be written as a sum of squares if + every prime in the factorization of n that has the form + 4*k + 3 has an even multiplicity. If n is prime then + it can only be written as a sum of two squares if it is + in the form 4*k + 1. + + * if k == 3 then n can be written as a sum of squares if it does + not have the form 4**m*(8*k + 7). + + * all integers can be written as the sum of 4 squares. + + * if k > 4 then n can be partitioned and each partition can + be written as a sum of 4 squares; if n is not evenly divisible + by 4 then n can be written as a sum of squares only if the + an additional partition can be written as sum of squares. + For example, if k = 6 then n is partitioned into two parts, + the first being written as a sum of 4 squares and the second + being written as a sum of 2 squares -- which can only be + done if the condition above for k = 2 can be met, so this will + automatically reject certain partitions of n. + + Examples + ======== + + >>> from sympy.solvers.diophantine.diophantine import sum_of_squares + >>> list(sum_of_squares(25, 2)) + [(3, 4)] + >>> list(sum_of_squares(25, 2, True)) + [(3, 4), (0, 5)] + >>> list(sum_of_squares(25, 4)) + [(1, 2, 2, 4)] + + See Also + ======== + + sympy.utilities.iterables.signed_permutations + """ + yield from power_representation(n, 2, k, zeros) + + +def _can_do_sum_of_squares(n, k): + """Return True if n can be written as the sum of k squares, + False if it cannot, or 1 if ``k == 2`` and ``n`` is prime (in which + case it *can* be written as a sum of two squares). A False + is returned only if it cannot be written as ``k``-squares, even + if 0s are allowed. + """ + if k < 1: + return False + if n < 0: + return False + if n == 0: + return True + if k == 1: + return is_square(n) + if k == 2: + if n in (1, 2): + return True + if isprime(n): + if n % 4 == 1: + return 1 # signal that it was prime + return False + # n is a composite number + # we can proceed iff no prime factor in the form 4*k + 3 + # has an odd multiplicity + return all(p % 4 !=3 or m % 2 == 0 for p, m in factorint(n).items()) + if k == 3: + return remove(n, 4)[0] % 8 != 7 + # every number can be written as a sum of 4 squares; for k > 4 partitions + # can be 0 + return True diff --git a/lib/python3.10/site-packages/sympy/solvers/diophantine/tests/__init__.py b/lib/python3.10/site-packages/sympy/solvers/diophantine/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/lib/python3.10/site-packages/sympy/solvers/diophantine/tests/__pycache__/__init__.cpython-310.pyc b/lib/python3.10/site-packages/sympy/solvers/diophantine/tests/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..920c187eea0a84189fd8546adc331c9bb04e0d0b Binary files /dev/null and b/lib/python3.10/site-packages/sympy/solvers/diophantine/tests/__pycache__/__init__.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/solvers/diophantine/tests/__pycache__/test_diophantine.cpython-310.pyc b/lib/python3.10/site-packages/sympy/solvers/diophantine/tests/__pycache__/test_diophantine.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..241bf7b25bd8ad86c0d30f78d24dad2024d06796 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/solvers/diophantine/tests/__pycache__/test_diophantine.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/solvers/diophantine/tests/test_diophantine.py b/lib/python3.10/site-packages/sympy/solvers/diophantine/tests/test_diophantine.py new file mode 100644 index 0000000000000000000000000000000000000000..094770b7bba795aef306ea71831388414fde935e --- /dev/null +++ b/lib/python3.10/site-packages/sympy/solvers/diophantine/tests/test_diophantine.py @@ -0,0 +1,1051 @@ +from sympy.core.add import Add +from sympy.core.mul import Mul +from sympy.core.numbers import (Rational, oo, pi) +from sympy.core.relational import Eq +from sympy.core.singleton import S +from sympy.core.symbol import symbols +from sympy.matrices.dense import Matrix +from sympy.ntheory.factor_ import factorint +from sympy.simplify.powsimp import powsimp +from sympy.core.function import _mexpand +from sympy.core.sorting import default_sort_key, ordered +from sympy.functions.elementary.trigonometric import sin +from sympy.solvers.diophantine import diophantine +from sympy.solvers.diophantine.diophantine import (diop_DN, + diop_solve, diop_ternary_quadratic_normal, + diop_general_pythagorean, diop_ternary_quadratic, diop_linear, + diop_quadratic, diop_general_sum_of_squares, diop_general_sum_of_even_powers, + descent, diop_bf_DN, divisible, equivalent, find_DN, ldescent, length, + reconstruct, partition, power_representation, + prime_as_sum_of_two_squares, square_factor, sum_of_four_squares, + sum_of_three_squares, transformation_to_DN, transformation_to_normal, + classify_diop, base_solution_linear, cornacchia, sqf_normal, gaussian_reduce, holzer, + check_param, parametrize_ternary_quadratic, sum_of_powers, sum_of_squares, + _diop_ternary_quadratic_normal, _nint_or_floor, + _odd, _even, _remove_gcd, _can_do_sum_of_squares, DiophantineSolutionSet, GeneralPythagorean, + BinaryQuadratic) + +from sympy.testing.pytest import slow, raises, XFAIL +from sympy.utilities.iterables import ( + signed_permutations) + +a, b, c, d, p, q, x, y, z, w, t, u, v, X, Y, Z = symbols( + "a, b, c, d, p, q, x, y, z, w, t, u, v, X, Y, Z", integer=True) +t_0, t_1, t_2, t_3, t_4, t_5, t_6 = symbols("t_:7", integer=True) +m1, m2, m3 = symbols('m1:4', integer=True) +n1 = symbols('n1', integer=True) + + +def diop_simplify(eq): + return _mexpand(powsimp(_mexpand(eq))) + + +def test_input_format(): + raises(TypeError, lambda: diophantine(sin(x))) + raises(TypeError, lambda: diophantine(x/pi - 3)) + + +def test_nosols(): + # diophantine should sympify eq so that these are equivalent + assert diophantine(3) == set() + assert diophantine(S(3)) == set() + + +def test_univariate(): + assert diop_solve((x - 1)*(x - 2)**2) == {(1,), (2,)} + assert diop_solve((x - 1)*(x - 2)) == {(1,), (2,)} + + +def test_classify_diop(): + raises(TypeError, lambda: classify_diop(x**2/3 - 1)) + raises(ValueError, lambda: classify_diop(1)) + raises(NotImplementedError, lambda: classify_diop(w*x*y*z - 1)) + raises(NotImplementedError, lambda: classify_diop(x**3 + y**3 + z**4 - 90)) + assert classify_diop(14*x**2 + 15*x - 42) == ( + [x], {1: -42, x: 15, x**2: 14}, 'univariate') + assert classify_diop(x*y + z) == ( + [x, y, z], {x*y: 1, z: 1}, 'inhomogeneous_ternary_quadratic') + assert classify_diop(x*y + z + w + x**2) == ( + [w, x, y, z], {x*y: 1, w: 1, x**2: 1, z: 1}, 'inhomogeneous_general_quadratic') + assert classify_diop(x*y + x*z + x**2 + 1) == ( + [x, y, z], {x*y: 1, x*z: 1, x**2: 1, 1: 1}, 'inhomogeneous_general_quadratic') + assert classify_diop(x*y + z + w + 42) == ( + [w, x, y, z], {x*y: 1, w: 1, 1: 42, z: 1}, 'inhomogeneous_general_quadratic') + assert classify_diop(x*y + z*w) == ( + [w, x, y, z], {x*y: 1, w*z: 1}, 'homogeneous_general_quadratic') + assert classify_diop(x*y**2 + 1) == ( + [x, y], {x*y**2: 1, 1: 1}, 'cubic_thue') + assert classify_diop(x**4 + y**4 + z**4 - (1 + 16 + 81)) == ( + [x, y, z], {1: -98, x**4: 1, z**4: 1, y**4: 1}, 'general_sum_of_even_powers') + assert classify_diop(x**2 + y**2 + z**2) == ( + [x, y, z], {x**2: 1, y**2: 1, z**2: 1}, 'homogeneous_ternary_quadratic_normal') + + +def test_linear(): + assert diop_solve(x) == (0,) + assert diop_solve(1*x) == (0,) + assert diop_solve(3*x) == (0,) + assert diop_solve(x + 1) == (-1,) + assert diop_solve(2*x + 1) == (None,) + assert diop_solve(2*x + 4) == (-2,) + assert diop_solve(y + x) == (t_0, -t_0) + assert diop_solve(y + x + 0) == (t_0, -t_0) + assert diop_solve(y + x - 0) == (t_0, -t_0) + assert diop_solve(0*x - y - 5) == (-5,) + assert diop_solve(3*y + 2*x - 5) == (3*t_0 - 5, -2*t_0 + 5) + assert diop_solve(2*x - 3*y - 5) == (3*t_0 - 5, 2*t_0 - 5) + assert diop_solve(-2*x - 3*y - 5) == (3*t_0 + 5, -2*t_0 - 5) + assert diop_solve(7*x + 5*y) == (5*t_0, -7*t_0) + assert diop_solve(2*x + 4*y) == (-2*t_0, t_0) + assert diop_solve(4*x + 6*y - 4) == (3*t_0 - 2, -2*t_0 + 2) + assert diop_solve(4*x + 6*y - 3) == (None, None) + assert diop_solve(0*x + 3*y - 4*z + 5) == (4*t_0 + 5, 3*t_0 + 5) + assert diop_solve(4*x + 3*y - 4*z + 5) == (t_0, 8*t_0 + 4*t_1 + 5, 7*t_0 + 3*t_1 + 5) + assert diop_solve(4*x + 3*y - 4*z + 5, None) == (0, 5, 5) + assert diop_solve(4*x + 2*y + 8*z - 5) == (None, None, None) + assert diop_solve(5*x + 7*y - 2*z - 6) == (t_0, -3*t_0 + 2*t_1 + 6, -8*t_0 + 7*t_1 + 18) + assert diop_solve(3*x - 6*y + 12*z - 9) == (2*t_0 + 3, t_0 + 2*t_1, t_1) + assert diop_solve(6*w + 9*x + 20*y - z) == (t_0, t_1, t_1 + t_2, 6*t_0 + 29*t_1 + 20*t_2) + + # to ignore constant factors, use diophantine + raises(TypeError, lambda: diop_solve(x/2)) + + +def test_quadratic_simple_hyperbolic_case(): + # Simple Hyperbolic case: A = C = 0 and B != 0 + assert diop_solve(3*x*y + 34*x - 12*y + 1) == \ + {(-133, -11), (5, -57)} + assert diop_solve(6*x*y + 2*x + 3*y + 1) == set() + assert diop_solve(-13*x*y + 2*x - 4*y - 54) == {(27, 0)} + assert diop_solve(-27*x*y - 30*x - 12*y - 54) == {(-14, -1)} + assert diop_solve(2*x*y + 5*x + 56*y + 7) == {(-161, -3), (-47, -6), (-35, -12), + (-29, -69), (-27, 64), (-21, 7), + (-9, 1), (105, -2)} + assert diop_solve(6*x*y + 9*x + 2*y + 3) == set() + assert diop_solve(x*y + x + y + 1) == {(-1, t), (t, -1)} + assert diophantine(48*x*y) + + +def test_quadratic_elliptical_case(): + # Elliptical case: B**2 - 4AC < 0 + + assert diop_solve(42*x**2 + 8*x*y + 15*y**2 + 23*x + 17*y - 4915) == {(-11, -1)} + assert diop_solve(4*x**2 + 3*y**2 + 5*x - 11*y + 12) == set() + assert diop_solve(x**2 + y**2 + 2*x + 2*y + 2) == {(-1, -1)} + assert diop_solve(15*x**2 - 9*x*y + 14*y**2 - 23*x - 14*y - 4950) == {(-15, 6)} + assert diop_solve(10*x**2 + 12*x*y + 12*y**2 - 34) == \ + {(-1, -1), (-1, 2), (1, -2), (1, 1)} + + +def test_quadratic_parabolic_case(): + # Parabolic case: B**2 - 4AC = 0 + assert check_solutions(8*x**2 - 24*x*y + 18*y**2 + 5*x + 7*y + 16) + assert check_solutions(8*x**2 - 24*x*y + 18*y**2 + 6*x + 12*y - 6) + assert check_solutions(8*x**2 + 24*x*y + 18*y**2 + 4*x + 6*y - 7) + assert check_solutions(-4*x**2 + 4*x*y - y**2 + 2*x - 3) + assert check_solutions(x**2 + 2*x*y + y**2 + 2*x + 2*y + 1) + assert check_solutions(x**2 - 2*x*y + y**2 + 2*x + 2*y + 1) + assert check_solutions(y**2 - 41*x + 40) + + +def test_quadratic_perfect_square(): + # B**2 - 4*A*C > 0 + # B**2 - 4*A*C is a perfect square + assert check_solutions(48*x*y) + assert check_solutions(4*x**2 - 5*x*y + y**2 + 2) + assert check_solutions(-2*x**2 - 3*x*y + 2*y**2 -2*x - 17*y + 25) + assert check_solutions(12*x**2 + 13*x*y + 3*y**2 - 2*x + 3*y - 12) + assert check_solutions(8*x**2 + 10*x*y + 2*y**2 - 32*x - 13*y - 23) + assert check_solutions(4*x**2 - 4*x*y - 3*y- 8*x - 3) + assert check_solutions(- 4*x*y - 4*y**2 - 3*y- 5*x - 10) + assert check_solutions(x**2 - y**2 - 2*x - 2*y) + assert check_solutions(x**2 - 9*y**2 - 2*x - 6*y) + assert check_solutions(4*x**2 - 9*y**2 - 4*x - 12*y - 3) + + +def test_quadratic_non_perfect_square(): + # B**2 - 4*A*C is not a perfect square + # Used check_solutions() since the solutions are complex expressions involving + # square roots and exponents + assert check_solutions(x**2 - 2*x - 5*y**2) + assert check_solutions(3*x**2 - 2*y**2 - 2*x - 2*y) + assert check_solutions(x**2 - x*y - y**2 - 3*y) + assert check_solutions(x**2 - 9*y**2 - 2*x - 6*y) + assert BinaryQuadratic(x**2 + y**2 + 2*x + 2*y + 2).solve() == {(-1, -1)} + + +def test_issue_9106(): + eq = -48 - 2*x*(3*x - 1) + y*(3*y - 1) + v = (x, y) + for sol in diophantine(eq): + assert not diop_simplify(eq.xreplace(dict(zip(v, sol)))) + + +def test_issue_18138(): + eq = x**2 - x - y**2 + v = (x, y) + for sol in diophantine(eq): + assert not diop_simplify(eq.xreplace(dict(zip(v, sol)))) + + +@slow +def test_quadratic_non_perfect_slow(): + assert check_solutions(8*x**2 + 10*x*y - 2*y**2 - 32*x - 13*y - 23) + # This leads to very large numbers. + # assert check_solutions(5*x**2 - 13*x*y + y**2 - 4*x - 4*y - 15) + assert check_solutions(-3*x**2 - 2*x*y + 7*y**2 - 5*x - 7) + assert check_solutions(-4 - x + 4*x**2 - y - 3*x*y - 4*y**2) + assert check_solutions(1 + 2*x + 2*x**2 + 2*y + x*y - 2*y**2) + + +def test_DN(): + # Most of the test cases were adapted from, + # Solving the generalized Pell equation x**2 - D*y**2 = N, John P. Robertson, July 31, 2004. + # https://web.archive.org/web/20160323033128/http://www.jpr2718.org/pell.pdf + # others are verified using Wolfram Alpha. + + # Covers cases where D <= 0 or D > 0 and D is a square or N = 0 + # Solutions are straightforward in these cases. + assert diop_DN(3, 0) == [(0, 0)] + assert diop_DN(-17, -5) == [] + assert diop_DN(-19, 23) == [(2, 1)] + assert diop_DN(-13, 17) == [(2, 1)] + assert diop_DN(-15, 13) == [] + assert diop_DN(0, 5) == [] + assert diop_DN(0, 9) == [(3, t)] + assert diop_DN(9, 0) == [(3*t, t)] + assert diop_DN(16, 24) == [] + assert diop_DN(9, 180) == [(18, 4)] + assert diop_DN(9, -180) == [(12, 6)] + assert diop_DN(7, 0) == [(0, 0)] + + # When equation is x**2 + y**2 = N + # Solutions are interchangeable + assert diop_DN(-1, 5) == [(2, 1), (1, 2)] + assert diop_DN(-1, 169) == [(12, 5), (5, 12), (13, 0), (0, 13)] + + # D > 0 and D is not a square + + # N = 1 + assert diop_DN(13, 1) == [(649, 180)] + assert diop_DN(980, 1) == [(51841, 1656)] + assert diop_DN(981, 1) == [(158070671986249, 5046808151700)] + assert diop_DN(986, 1) == [(49299, 1570)] + assert diop_DN(991, 1) == [(379516400906811930638014896080, 12055735790331359447442538767)] + assert diop_DN(17, 1) == [(33, 8)] + assert diop_DN(19, 1) == [(170, 39)] + + # N = -1 + assert diop_DN(13, -1) == [(18, 5)] + assert diop_DN(991, -1) == [] + assert diop_DN(41, -1) == [(32, 5)] + assert diop_DN(290, -1) == [(17, 1)] + assert diop_DN(21257, -1) == [(13913102721304, 95427381109)] + assert diop_DN(32, -1) == [] + + # |N| > 1 + # Some tests were created using calculator at + # http://www.numbertheory.org/php/patz.html + + assert diop_DN(13, -4) == [(3, 1), (393, 109), (36, 10)] + # Source I referred returned (3, 1), (393, 109) and (-3, 1) as fundamental solutions + # So (-3, 1) and (393, 109) should be in the same equivalent class + assert equivalent(-3, 1, 393, 109, 13, -4) == True + + assert diop_DN(13, 27) == [(220, 61), (40, 11), (768, 213), (12, 3)] + assert set(diop_DN(157, 12)) == {(13, 1), (10663, 851), (579160, 46222), + (483790960, 38610722), (26277068347, 2097138361), + (21950079635497, 1751807067011)} + assert diop_DN(13, 25) == [(3245, 900)] + assert diop_DN(192, 18) == [] + assert diop_DN(23, 13) == [(-6, 1), (6, 1)] + assert diop_DN(167, 2) == [(13, 1)] + assert diop_DN(167, -2) == [] + + assert diop_DN(123, -2) == [(11, 1)] + # One calculator returned [(11, 1), (-11, 1)] but both of these are in + # the same equivalence class + assert equivalent(11, 1, -11, 1, 123, -2) + + assert diop_DN(123, -23) == [(-10, 1), (10, 1)] + + assert diop_DN(0, 0, t) == [(0, t)] + assert diop_DN(0, -1, t) == [] + + +def test_bf_pell(): + assert diop_bf_DN(13, -4) == [(3, 1), (-3, 1), (36, 10)] + assert diop_bf_DN(13, 27) == [(12, 3), (-12, 3), (40, 11), (-40, 11)] + assert diop_bf_DN(167, -2) == [] + assert diop_bf_DN(1729, 1) == [(44611924489705, 1072885712316)] + assert diop_bf_DN(89, -8) == [(9, 1), (-9, 1)] + assert diop_bf_DN(21257, -1) == [(13913102721304, 95427381109)] + assert diop_bf_DN(340, -4) == [(756, 41)] + assert diop_bf_DN(-1, 0, t) == [(0, 0)] + assert diop_bf_DN(0, 0, t) == [(0, t)] + assert diop_bf_DN(4, 0, t) == [(2*t, t), (-2*t, t)] + assert diop_bf_DN(3, 0, t) == [(0, 0)] + assert diop_bf_DN(1, -2, t) == [] + + +def test_length(): + assert length(2, 1, 0) == 1 + assert length(-2, 4, 5) == 3 + assert length(-5, 4, 17) == 4 + assert length(0, 4, 13) == 6 + assert length(7, 13, 11) == 23 + assert length(1, 6, 4) == 2 + + +def is_pell_transformation_ok(eq): + """ + Test whether X*Y, X, or Y terms are present in the equation + after transforming the equation using the transformation returned + by transformation_to_pell(). If they are not present we are good. + Moreover, coefficient of X**2 should be a divisor of coefficient of + Y**2 and the constant term. + """ + A, B = transformation_to_DN(eq) + u = (A*Matrix([X, Y]) + B)[0] + v = (A*Matrix([X, Y]) + B)[1] + simplified = diop_simplify(eq.subs(zip((x, y), (u, v)))) + + coeff = dict([reversed(t.as_independent(*[X, Y])) for t in simplified.args]) + + for term in [X*Y, X, Y]: + if term in coeff.keys(): + return False + + for term in [X**2, Y**2, 1]: + if term not in coeff.keys(): + coeff[term] = 0 + + if coeff[X**2] != 0: + return divisible(coeff[Y**2], coeff[X**2]) and \ + divisible(coeff[1], coeff[X**2]) + + return True + + +def test_transformation_to_pell(): + assert is_pell_transformation_ok(-13*x**2 - 7*x*y + y**2 + 2*x - 2*y - 14) + assert is_pell_transformation_ok(-17*x**2 + 19*x*y - 7*y**2 - 5*x - 13*y - 23) + assert is_pell_transformation_ok(x**2 - y**2 + 17) + assert is_pell_transformation_ok(-x**2 + 7*y**2 - 23) + assert is_pell_transformation_ok(25*x**2 - 45*x*y + 5*y**2 - 5*x - 10*y + 5) + assert is_pell_transformation_ok(190*x**2 + 30*x*y + y**2 - 3*y - 170*x - 130) + assert is_pell_transformation_ok(x**2 - 2*x*y -190*y**2 - 7*y - 23*x - 89) + assert is_pell_transformation_ok(15*x**2 - 9*x*y + 14*y**2 - 23*x - 14*y - 4950) + + +def test_find_DN(): + assert find_DN(x**2 - 2*x - y**2) == (1, 1) + assert find_DN(x**2 - 3*y**2 - 5) == (3, 5) + assert find_DN(x**2 - 2*x*y - 4*y**2 - 7) == (5, 7) + assert find_DN(4*x**2 - 8*x*y - y**2 - 9) == (20, 36) + assert find_DN(7*x**2 - 2*x*y - y**2 - 12) == (8, 84) + assert find_DN(-3*x**2 + 4*x*y -y**2) == (1, 0) + assert find_DN(-13*x**2 - 7*x*y + y**2 + 2*x - 2*y -14) == (101, -7825480) + + +def test_ldescent(): + # Equations which have solutions + u = ([(13, 23), (3, -11), (41, -113), (4, -7), (-7, 4), (91, -3), (1, 1), (1, -1), + (4, 32), (17, 13), (123689, 1), (19, -570)]) + for a, b in u: + w, x, y = ldescent(a, b) + assert a*x**2 + b*y**2 == w**2 + assert ldescent(-1, -1) is None + assert ldescent(2, 6) is None + + +def test_diop_ternary_quadratic_normal(): + assert check_solutions(234*x**2 - 65601*y**2 - z**2) + assert check_solutions(23*x**2 + 616*y**2 - z**2) + assert check_solutions(5*x**2 + 4*y**2 - z**2) + assert check_solutions(3*x**2 + 6*y**2 - 3*z**2) + assert check_solutions(x**2 + 3*y**2 - z**2) + assert check_solutions(4*x**2 + 5*y**2 - z**2) + assert check_solutions(x**2 + y**2 - z**2) + assert check_solutions(16*x**2 + y**2 - 25*z**2) + assert check_solutions(6*x**2 - y**2 + 10*z**2) + assert check_solutions(213*x**2 + 12*y**2 - 9*z**2) + assert check_solutions(34*x**2 - 3*y**2 - 301*z**2) + assert check_solutions(124*x**2 - 30*y**2 - 7729*z**2) + + +def is_normal_transformation_ok(eq): + A = transformation_to_normal(eq) + X, Y, Z = A*Matrix([x, y, z]) + simplified = diop_simplify(eq.subs(zip((x, y, z), (X, Y, Z)))) + + coeff = dict([reversed(t.as_independent(*[X, Y, Z])) for t in simplified.args]) + for term in [X*Y, Y*Z, X*Z]: + if term in coeff.keys(): + return False + + return True + + +def test_transformation_to_normal(): + assert is_normal_transformation_ok(x**2 + 3*y**2 + z**2 - 13*x*y - 16*y*z + 12*x*z) + assert is_normal_transformation_ok(x**2 + 3*y**2 - 100*z**2) + assert is_normal_transformation_ok(x**2 + 23*y*z) + assert is_normal_transformation_ok(3*y**2 - 100*z**2 - 12*x*y) + assert is_normal_transformation_ok(x**2 + 23*x*y - 34*y*z + 12*x*z) + assert is_normal_transformation_ok(z**2 + 34*x*y - 23*y*z + x*z) + assert is_normal_transformation_ok(x**2 + y**2 + z**2 - x*y - y*z - x*z) + assert is_normal_transformation_ok(x**2 + 2*y*z + 3*z**2) + assert is_normal_transformation_ok(x*y + 2*x*z + 3*y*z) + assert is_normal_transformation_ok(2*x*z + 3*y*z) + + +def test_diop_ternary_quadratic(): + assert check_solutions(2*x**2 + z**2 + y**2 - 4*x*y) + assert check_solutions(x**2 - y**2 - z**2 - x*y - y*z) + assert check_solutions(3*x**2 - x*y - y*z - x*z) + assert check_solutions(x**2 - y*z - x*z) + assert check_solutions(5*x**2 - 3*x*y - x*z) + assert check_solutions(4*x**2 - 5*y**2 - x*z) + assert check_solutions(3*x**2 + 2*y**2 - z**2 - 2*x*y + 5*y*z - 7*y*z) + assert check_solutions(8*x**2 - 12*y*z) + assert check_solutions(45*x**2 - 7*y**2 - 8*x*y - z**2) + assert check_solutions(x**2 - 49*y**2 - z**2 + 13*z*y -8*x*y) + assert check_solutions(90*x**2 + 3*y**2 + 5*x*y + 2*z*y + 5*x*z) + assert check_solutions(x**2 + 3*y**2 + z**2 - x*y - 17*y*z) + assert check_solutions(x**2 + 3*y**2 + z**2 - x*y - 16*y*z + 12*x*z) + assert check_solutions(x**2 + 3*y**2 + z**2 - 13*x*y - 16*y*z + 12*x*z) + assert check_solutions(x*y - 7*y*z + 13*x*z) + + assert diop_ternary_quadratic_normal(x**2 + y**2 + z**2) == (None, None, None) + assert diop_ternary_quadratic_normal(x**2 + y**2) is None + raises(ValueError, lambda: + _diop_ternary_quadratic_normal((x, y, z), + {x*y: 1, x**2: 2, y**2: 3, z**2: 0})) + eq = -2*x*y - 6*x*z + 7*y**2 - 3*y*z + 4*z**2 + assert diop_ternary_quadratic(eq) == (7, 2, 0) + assert diop_ternary_quadratic_normal(4*x**2 + 5*y**2 - z**2) == \ + (1, 0, 2) + assert diop_ternary_quadratic(x*y + 2*y*z) == \ + (-2, 0, n1) + eq = -5*x*y - 8*x*z - 3*y*z + 8*z**2 + assert parametrize_ternary_quadratic(eq) == \ + (8*p**2 - 3*p*q, -8*p*q + 8*q**2, 5*p*q) + # this cannot be tested with diophantine because it will + # factor into a product + assert diop_solve(x*y + 2*y*z) == (-2*p*q, -n1*p**2 + p**2, p*q) + + +def test_square_factor(): + assert square_factor(1) == square_factor(-1) == 1 + assert square_factor(0) == 1 + assert square_factor(5) == square_factor(-5) == 1 + assert square_factor(4) == square_factor(-4) == 2 + assert square_factor(12) == square_factor(-12) == 2 + assert square_factor(6) == 1 + assert square_factor(18) == 3 + assert square_factor(52) == 2 + assert square_factor(49) == 7 + assert square_factor(392) == 14 + assert square_factor(factorint(-12)) == 2 + + +def test_parametrize_ternary_quadratic(): + assert check_solutions(x**2 + y**2 - z**2) + assert check_solutions(x**2 + 2*x*y + z**2) + assert check_solutions(234*x**2 - 65601*y**2 - z**2) + assert check_solutions(3*x**2 + 2*y**2 - z**2 - 2*x*y + 5*y*z - 7*y*z) + assert check_solutions(x**2 - y**2 - z**2) + assert check_solutions(x**2 - 49*y**2 - z**2 + 13*z*y - 8*x*y) + assert check_solutions(8*x*y + z**2) + assert check_solutions(124*x**2 - 30*y**2 - 7729*z**2) + assert check_solutions(236*x**2 - 225*y**2 - 11*x*y - 13*y*z - 17*x*z) + assert check_solutions(90*x**2 + 3*y**2 + 5*x*y + 2*z*y + 5*x*z) + assert check_solutions(124*x**2 - 30*y**2 - 7729*z**2) + + +def test_no_square_ternary_quadratic(): + assert check_solutions(2*x*y + y*z - 3*x*z) + assert check_solutions(189*x*y - 345*y*z - 12*x*z) + assert check_solutions(23*x*y + 34*y*z) + assert check_solutions(x*y + y*z + z*x) + assert check_solutions(23*x*y + 23*y*z + 23*x*z) + + +def test_descent(): + + u = ([(13, 23), (3, -11), (41, -113), (91, -3), (1, 1), (1, -1), (17, 13), (123689, 1), (19, -570)]) + for a, b in u: + w, x, y = descent(a, b) + assert a*x**2 + b*y**2 == w**2 + # the docstring warns against bad input, so these are expected results + # - can't both be negative + raises(TypeError, lambda: descent(-1, -3)) + # A can't be zero unless B != 1 + raises(ZeroDivisionError, lambda: descent(0, 3)) + # supposed to be square-free + raises(TypeError, lambda: descent(4, 3)) + + +def test_diophantine(): + assert check_solutions((x - y)*(y - z)*(z - x)) + assert check_solutions((x - y)*(x**2 + y**2 - z**2)) + assert check_solutions((x - 3*y + 7*z)*(x**2 + y**2 - z**2)) + assert check_solutions(x**2 - 3*y**2 - 1) + assert check_solutions(y**2 + 7*x*y) + assert check_solutions(x**2 - 3*x*y + y**2) + assert check_solutions(z*(x**2 - y**2 - 15)) + assert check_solutions(x*(2*y - 2*z + 5)) + assert check_solutions((x**2 - 3*y**2 - 1)*(x**2 - y**2 - 15)) + assert check_solutions((x**2 - 3*y**2 - 1)*(y - 7*z)) + assert check_solutions((x**2 + y**2 - z**2)*(x - 7*y - 3*z + 4*w)) + # Following test case caused problems in parametric representation + # But this can be solved by factoring out y. + # No need to use methods for ternary quadratic equations. + assert check_solutions(y**2 - 7*x*y + 4*y*z) + assert check_solutions(x**2 - 2*x + 1) + + assert diophantine(x - y) == diophantine(Eq(x, y)) + # 18196 + eq = x**4 + y**4 - 97 + assert diophantine(eq, permute=True) == diophantine(-eq, permute=True) + assert diophantine(3*x*pi - 2*y*pi) == {(2*t_0, 3*t_0)} + eq = x**2 + y**2 + z**2 - 14 + base_sol = {(1, 2, 3)} + assert diophantine(eq) == base_sol + complete_soln = set(signed_permutations(base_sol.pop())) + assert diophantine(eq, permute=True) == complete_soln + + assert diophantine(x**2 + x*Rational(15, 14) - 3) == set() + # test issue 11049 + eq = 92*x**2 - 99*y**2 - z**2 + coeff = eq.as_coefficients_dict() + assert _diop_ternary_quadratic_normal((x, y, z), coeff) == \ + {(9, 7, 51)} + assert diophantine(eq) == {( + 891*p**2 + 9*q**2, -693*p**2 - 102*p*q + 7*q**2, + 5049*p**2 - 1386*p*q - 51*q**2)} + eq = 2*x**2 + 2*y**2 - z**2 + coeff = eq.as_coefficients_dict() + assert _diop_ternary_quadratic_normal((x, y, z), coeff) == \ + {(1, 1, 2)} + assert diophantine(eq) == {( + 2*p**2 - q**2, -2*p**2 + 4*p*q - q**2, + 4*p**2 - 4*p*q + 2*q**2)} + eq = 411*x**2+57*y**2-221*z**2 + coeff = eq.as_coefficients_dict() + assert _diop_ternary_quadratic_normal((x, y, z), coeff) == \ + {(2021, 2645, 3066)} + assert diophantine(eq) == \ + {(115197*p**2 - 446641*q**2, -150765*p**2 + 1355172*p*q - + 584545*q**2, 174762*p**2 - 301530*p*q + 677586*q**2)} + eq = 573*x**2+267*y**2-984*z**2 + coeff = eq.as_coefficients_dict() + assert _diop_ternary_quadratic_normal((x, y, z), coeff) == \ + {(49, 233, 127)} + assert diophantine(eq) == \ + {(4361*p**2 - 16072*q**2, -20737*p**2 + 83312*p*q - 76424*q**2, + 11303*p**2 - 41474*p*q + 41656*q**2)} + # this produces factors during reconstruction + eq = x**2 + 3*y**2 - 12*z**2 + coeff = eq.as_coefficients_dict() + assert _diop_ternary_quadratic_normal((x, y, z), coeff) == \ + {(0, 2, 1)} + assert diophantine(eq) == \ + {(24*p*q, 2*p**2 - 24*q**2, p**2 + 12*q**2)} + # solvers have not been written for every type + raises(NotImplementedError, lambda: diophantine(x*y**2 + 1)) + + # rational expressions + assert diophantine(1/x) == set() + assert diophantine(1/x + 1/y - S.Half) == {(6, 3), (-2, 1), (4, 4), (1, -2), (3, 6)} + assert diophantine(x**2 + y**2 +3*x- 5, permute=True) == \ + {(-1, 1), (-4, -1), (1, -1), (1, 1), (-4, 1), (-1, -1), (4, 1), (4, -1)} + + + #test issue 18186 + assert diophantine(y**4 + x**4 - 2**4 - 3**4, syms=(x, y), permute=True) == \ + {(-3, -2), (-3, 2), (-2, -3), (-2, 3), (2, -3), (2, 3), (3, -2), (3, 2)} + assert diophantine(y**4 + x**4 - 2**4 - 3**4, syms=(y, x), permute=True) == \ + {(-3, -2), (-3, 2), (-2, -3), (-2, 3), (2, -3), (2, 3), (3, -2), (3, 2)} + + # issue 18122 + assert check_solutions(x**2 - y) + assert check_solutions(y**2 - x) + assert diophantine((x**2 - y), t) == {(t, t**2)} + assert diophantine((y**2 - x), t) == {(t**2, t)} + + +def test_general_pythagorean(): + from sympy.abc import a, b, c, d, e + + assert check_solutions(a**2 + b**2 + c**2 - d**2) + assert check_solutions(a**2 + 4*b**2 + 4*c**2 - d**2) + assert check_solutions(9*a**2 + 4*b**2 + 4*c**2 - d**2) + assert check_solutions(9*a**2 + 4*b**2 - 25*d**2 + 4*c**2 ) + assert check_solutions(9*a**2 - 16*d**2 + 4*b**2 + 4*c**2) + assert check_solutions(-e**2 + 9*a**2 + 4*b**2 + 4*c**2 + 25*d**2) + assert check_solutions(16*a**2 - b**2 + 9*c**2 + d**2 + 25*e**2) + + assert GeneralPythagorean(a**2 + b**2 + c**2 - d**2).solve(parameters=[x, y, z]) == \ + {(x**2 + y**2 - z**2, 2*x*z, 2*y*z, x**2 + y**2 + z**2)} + + +def test_diop_general_sum_of_squares_quick(): + for i in range(3, 10): + assert check_solutions(sum(i**2 for i in symbols(':%i' % i)) - i) + + assert diop_general_sum_of_squares(x**2 + y**2 - 2) is None + assert diop_general_sum_of_squares(x**2 + y**2 + z**2 + 2) == set() + eq = x**2 + y**2 + z**2 - (1 + 4 + 9) + assert diop_general_sum_of_squares(eq) == \ + {(1, 2, 3)} + eq = u**2 + v**2 + x**2 + y**2 + z**2 - 1313 + assert len(diop_general_sum_of_squares(eq, 3)) == 3 + # issue 11016 + var = symbols(':5') + (symbols('6', negative=True),) + eq = Add(*[i**2 for i in var]) - 112 + + base_soln = {(0, 1, 1, 5, 6, -7), (1, 1, 1, 3, 6, -8), (2, 3, 3, 4, 5, -7), (0, 1, 1, 1, 3, -10), + (0, 0, 4, 4, 4, -8), (1, 2, 3, 3, 5, -8), (0, 1, 2, 3, 7, -7), (2, 2, 4, 4, 6, -6), + (1, 1, 3, 4, 6, -7), (0, 2, 3, 3, 3, -9), (0, 0, 2, 2, 2, -10), (1, 1, 2, 3, 4, -9), + (0, 1, 1, 2, 5, -9), (0, 0, 2, 6, 6, -6), (1, 3, 4, 5, 5, -6), (0, 2, 2, 2, 6, -8), + (0, 3, 3, 3, 6, -7), (0, 2, 3, 5, 5, -7), (0, 1, 5, 5, 5, -6)} + assert diophantine(eq) == base_soln + assert len(diophantine(eq, permute=True)) == 196800 + + # handle negated squares with signsimp + assert diophantine(12 - x**2 - y**2 - z**2) == {(2, 2, 2)} + # diophantine handles simplification, so classify_diop should + # not have to look for additional patterns that are removed + # by diophantine + eq = a**2 + b**2 + c**2 + d**2 - 4 + raises(NotImplementedError, lambda: classify_diop(-eq)) + + +def test_issue_23807(): + # fixes recursion error + eq = x**2 + y**2 + z**2 - 1000000 + base_soln = {(0, 0, 1000), (0, 352, 936), (480, 600, 640), (24, 640, 768), (192, 640, 744), + (192, 480, 856), (168, 224, 960), (0, 600, 800), (280, 576, 768), (152, 480, 864), + (0, 280, 960), (352, 360, 864), (424, 480, 768), (360, 480, 800), (224, 600, 768), + (96, 360, 928), (168, 576, 800), (96, 480, 872)} + + assert diophantine(eq) == base_soln + + +def test_diop_partition(): + for n in [8, 10]: + for k in range(1, 8): + for p in partition(n, k): + assert len(p) == k + assert list(partition(3, 5)) == [] + assert [list(p) for p in partition(3, 5, 1)] == [ + [0, 0, 0, 0, 3], [0, 0, 0, 1, 2], [0, 0, 1, 1, 1]] + assert list(partition(0)) == [()] + assert list(partition(1, 0)) == [()] + assert [list(i) for i in partition(3)] == [[1, 1, 1], [1, 2], [3]] + + +def test_prime_as_sum_of_two_squares(): + for i in [5, 13, 17, 29, 37, 41, 2341, 3557, 34841, 64601]: + a, b = prime_as_sum_of_two_squares(i) + assert a**2 + b**2 == i + assert prime_as_sum_of_two_squares(7) is None + ans = prime_as_sum_of_two_squares(800029) + assert ans == (450, 773) and type(ans[0]) is int + + +def test_sum_of_three_squares(): + for i in [0, 1, 2, 34, 123, 34304595905, 34304595905394941, 343045959052344, + 800, 801, 802, 803, 804, 805, 806]: + a, b, c = sum_of_three_squares(i) + assert a**2 + b**2 + c**2 == i + assert a >= 0 + + # error + raises(ValueError, lambda: sum_of_three_squares(-1)) + + assert sum_of_three_squares(7) is None + assert sum_of_three_squares((4**5)*15) is None + # if there are two zeros, there might be a solution + # with only one zero, e.g. 25 => (0, 3, 4) or + # with no zeros, e.g. 49 => (2, 3, 6) + assert sum_of_three_squares(25) == (0, 0, 5) + assert sum_of_three_squares(4) == (0, 0, 2) + + +def test_sum_of_four_squares(): + from sympy.core.random import randint + + # this should never fail + n = randint(1, 100000000000000) + assert sum(i**2 for i in sum_of_four_squares(n)) == n + + # error + raises(ValueError, lambda: sum_of_four_squares(-1)) + + for n in range(1000): + result = sum_of_four_squares(n) + assert len(result) == 4 + assert all(r >= 0 for r in result) + assert sum(r**2 for r in result) == n + assert list(result) == sorted(result) + + +def test_power_representation(): + tests = [(1729, 3, 2), (234, 2, 4), (2, 1, 2), (3, 1, 3), (5, 2, 2), (12352, 2, 4), + (32760, 2, 3)] + + for test in tests: + n, p, k = test + f = power_representation(n, p, k) + + while True: + try: + l = next(f) + assert len(l) == k + + chk_sum = 0 + for l_i in l: + chk_sum = chk_sum + l_i**p + assert chk_sum == n + + except StopIteration: + break + + assert list(power_representation(20, 2, 4, True)) == \ + [(1, 1, 3, 3), (0, 0, 2, 4)] + raises(ValueError, lambda: list(power_representation(1.2, 2, 2))) + raises(ValueError, lambda: list(power_representation(2, 0, 2))) + raises(ValueError, lambda: list(power_representation(2, 2, 0))) + assert list(power_representation(-1, 2, 2)) == [] + assert list(power_representation(1, 1, 1)) == [(1,)] + assert list(power_representation(3, 2, 1)) == [] + assert list(power_representation(4, 2, 1)) == [(2,)] + assert list(power_representation(3**4, 4, 6, zeros=True)) == \ + [(1, 2, 2, 2, 2, 2), (0, 0, 0, 0, 0, 3)] + assert list(power_representation(3**4, 4, 5, zeros=False)) == [] + assert list(power_representation(-2, 3, 2)) == [(-1, -1)] + assert list(power_representation(-2, 4, 2)) == [] + assert list(power_representation(0, 3, 2, True)) == [(0, 0)] + assert list(power_representation(0, 3, 2, False)) == [] + # when we are dealing with squares, do feasibility checks + assert len(list(power_representation(4**10*(8*10 + 7), 2, 3))) == 0 + # there will be a recursion error if these aren't recognized + big = 2**30 + for i in [13, 10, 7, 5, 4, 2, 1]: + assert list(sum_of_powers(big, 2, big - i)) == [] + + +def test_assumptions(): + """ + Test whether diophantine respects the assumptions. + """ + #Test case taken from the below so question regarding assumptions in diophantine module + #https://stackoverflow.com/questions/23301941/how-can-i-declare-natural-symbols-with-sympy + m, n = symbols('m n', integer=True, positive=True) + diof = diophantine(n**2 + m*n - 500) + assert diof == {(5, 20), (40, 10), (95, 5), (121, 4), (248, 2), (499, 1)} + + a, b = symbols('a b', integer=True, positive=False) + diof = diophantine(a*b + 2*a + 3*b - 6) + assert diof == {(-15, -3), (-9, -4), (-7, -5), (-6, -6), (-5, -8), (-4, -14)} + + +def check_solutions(eq): + """ + Determines whether solutions returned by diophantine() satisfy the original + equation. Hope to generalize this so we can remove functions like check_ternay_quadratic, + check_solutions_normal, check_solutions() + """ + s = diophantine(eq) + + factors = Mul.make_args(eq) + + var = list(eq.free_symbols) + var.sort(key=default_sort_key) + + while s: + solution = s.pop() + for f in factors: + if diop_simplify(f.subs(zip(var, solution))) == 0: + break + else: + return False + return True + + +def test_diopcoverage(): + eq = (2*x + y + 1)**2 + assert diop_solve(eq) == {(t_0, -2*t_0 - 1)} + eq = 2*x**2 + 6*x*y + 12*x + 4*y**2 + 18*y + 18 + assert diop_solve(eq) == {(t, -t - 3), (-2*t - 3, t)} + assert diop_quadratic(x + y**2 - 3) == {(-t**2 + 3, t)} + + assert diop_linear(x + y - 3) == (t_0, 3 - t_0) + + assert base_solution_linear(0, 1, 2, t=None) == (0, 0) + ans = (3*t - 1, -2*t + 1) + assert base_solution_linear(4, 8, 12, t) == ans + assert base_solution_linear(4, 8, 12, t=None) == tuple(_.subs(t, 0) for _ in ans) + + assert cornacchia(1, 1, 20) == set() + assert cornacchia(1, 1, 5) == {(2, 1)} + assert cornacchia(1, 2, 17) == {(3, 2)} + + raises(ValueError, lambda: reconstruct(4, 20, 1)) + + assert gaussian_reduce(4, 1, 3) == (1, 1) + eq = -w**2 - x**2 - y**2 + z**2 + + assert diop_general_pythagorean(eq) == \ + diop_general_pythagorean(-eq) == \ + (m1**2 + m2**2 - m3**2, 2*m1*m3, + 2*m2*m3, m1**2 + m2**2 + m3**2) + + assert len(check_param(S(3) + x/3, S(4) + x/2, S(2), [x])) == 0 + assert len(check_param(Rational(3, 2), S(4) + x, S(2), [x])) == 0 + assert len(check_param(S(4) + x, Rational(3, 2), S(2), [x])) == 0 + + assert _nint_or_floor(16, 10) == 2 + assert _odd(1) == (not _even(1)) == True + assert _odd(0) == (not _even(0)) == False + assert _remove_gcd(2, 4, 6) == (1, 2, 3) + raises(TypeError, lambda: _remove_gcd((2, 4, 6))) + assert sqf_normal(2*3**2*5, 2*5*11, 2*7**2*11) == \ + (11, 1, 5) + + # it's ok if these pass some day when the solvers are implemented + raises(NotImplementedError, lambda: diophantine(x**2 + y**2 + x*y + 2*y*z - 12)) + raises(NotImplementedError, lambda: diophantine(x**3 + y**2)) + assert diop_quadratic(x**2 + y**2 - 1**2 - 3**4) == \ + {(-9, -1), (-9, 1), (-1, -9), (-1, 9), (1, -9), (1, 9), (9, -1), (9, 1)} + + +def test_holzer(): + # if the input is good, don't let it diverge in holzer() + # (but see test_fail_holzer below) + assert holzer(2, 7, 13, 4, 79, 23) == (2, 7, 13) + + # None in uv condition met; solution is not Holzer reduced + # so this will hopefully change but is here for coverage + assert holzer(2, 6, 2, 1, 1, 10) == (2, 6, 2) + + raises(ValueError, lambda: holzer(2, 7, 14, 4, 79, 23)) + + +@XFAIL +def test_fail_holzer(): + eq = lambda x, y, z: a*x**2 + b*y**2 - c*z**2 + a, b, c = 4, 79, 23 + x, y, z = xyz = 26, 1, 11 + X, Y, Z = ans = 2, 7, 13 + assert eq(*xyz) == 0 + assert eq(*ans) == 0 + assert max(a*x**2, b*y**2, c*z**2) <= a*b*c + assert max(a*X**2, b*Y**2, c*Z**2) <= a*b*c + h = holzer(x, y, z, a, b, c) + assert h == ans # it would be nice to get the smaller soln + + +def test_issue_9539(): + assert diophantine(6*w + 9*y + 20*x - z) == \ + {(t_0, t_1, t_1 + t_2, 6*t_0 + 29*t_1 + 9*t_2)} + + +def test_issue_8943(): + assert diophantine( + 3*(x**2 + y**2 + z**2) - 14*(x*y + y*z + z*x)) == \ + {(0, 0, 0)} + + +def test_diop_sum_of_even_powers(): + eq = x**4 + y**4 + z**4 - 2673 + assert diop_solve(eq) == {(3, 6, 6), (2, 4, 7)} + assert diop_general_sum_of_even_powers(eq, 2) == {(3, 6, 6), (2, 4, 7)} + raises(NotImplementedError, lambda: diop_general_sum_of_even_powers(-eq, 2)) + neg = symbols('neg', negative=True) + eq = x**4 + y**4 + neg**4 - 2673 + assert diop_general_sum_of_even_powers(eq) == {(-3, 6, 6)} + assert diophantine(x**4 + y**4 + 2) == set() + assert diop_general_sum_of_even_powers(x**4 + y**4 - 2, limit=0) == set() + + +def test_sum_of_squares_powers(): + tru = {(0, 0, 1, 1, 11), (0, 0, 5, 7, 7), (0, 1, 3, 7, 8), (0, 1, 4, 5, 9), (0, 3, 4, 7, 7), (0, 3, 5, 5, 8), + (1, 1, 2, 6, 9), (1, 1, 6, 6, 7), (1, 2, 3, 3, 10), (1, 3, 4, 4, 9), (1, 5, 5, 6, 6), (2, 2, 3, 5, 9), + (2, 3, 5, 6, 7), (3, 3, 4, 5, 8)} + eq = u**2 + v**2 + x**2 + y**2 + z**2 - 123 + ans = diop_general_sum_of_squares(eq, oo) # allow oo to be used + assert len(ans) == 14 + assert ans == tru + + raises(ValueError, lambda: list(sum_of_squares(10, -1))) + assert list(sum_of_squares(1, 1)) == [(1,)] + assert list(sum_of_squares(1, 2)) == [] + assert list(sum_of_squares(1, 2, True)) == [(0, 1)] + assert list(sum_of_squares(-10, 2)) == [] + assert list(sum_of_squares(2, 3)) == [] + assert list(sum_of_squares(0, 3, True)) == [(0, 0, 0)] + assert list(sum_of_squares(0, 3)) == [] + assert list(sum_of_squares(4, 1)) == [(2,)] + assert list(sum_of_squares(5, 1)) == [] + assert list(sum_of_squares(50, 2)) == [(5, 5), (1, 7)] + assert list(sum_of_squares(11, 5, True)) == [ + (1, 1, 1, 2, 2), (0, 0, 1, 1, 3)] + assert list(sum_of_squares(8, 8)) == [(1, 1, 1, 1, 1, 1, 1, 1)] + + assert [len(list(sum_of_squares(i, 5, True))) for i in range(30)] == [ + 1, 1, 1, 1, 2, + 2, 1, 1, 2, 2, + 2, 2, 2, 3, 2, + 1, 3, 3, 3, 3, + 4, 3, 3, 2, 2, + 4, 4, 4, 4, 5] + assert [len(list(sum_of_squares(i, 5))) for i in range(30)] == [ + 0, 0, 0, 0, 0, + 1, 0, 0, 1, 0, + 0, 1, 0, 1, 1, + 0, 1, 1, 0, 1, + 2, 1, 1, 1, 1, + 1, 1, 1, 1, 3] + for i in range(30): + s1 = set(sum_of_squares(i, 5, True)) + assert not s1 or all(sum(j**2 for j in t) == i for t in s1) + s2 = set(sum_of_squares(i, 5)) + assert all(sum(j**2 for j in t) == i for t in s2) + + raises(ValueError, lambda: list(sum_of_powers(2, -1, 1))) + raises(ValueError, lambda: list(sum_of_powers(2, 1, -1))) + assert list(sum_of_powers(-2, 3, 2)) == [(-1, -1)] + assert list(sum_of_powers(-2, 4, 2)) == [] + assert list(sum_of_powers(2, 1, 1)) == [(2,)] + assert list(sum_of_powers(2, 1, 3, True)) == [(0, 0, 2), (0, 1, 1)] + assert list(sum_of_powers(5, 1, 2, True)) == [(0, 5), (1, 4), (2, 3)] + assert list(sum_of_powers(6, 2, 2)) == [] + assert list(sum_of_powers(3**5, 3, 1)) == [] + assert list(sum_of_powers(3**6, 3, 1)) == [(9,)] and (9**3 == 3**6) + assert list(sum_of_powers(2**1000, 5, 2)) == [] + + +def test__can_do_sum_of_squares(): + assert _can_do_sum_of_squares(3, -1) is False + assert _can_do_sum_of_squares(-3, 1) is False + assert _can_do_sum_of_squares(0, 1) + assert _can_do_sum_of_squares(4, 1) + assert _can_do_sum_of_squares(1, 2) + assert _can_do_sum_of_squares(2, 2) + assert _can_do_sum_of_squares(3, 2) is False + + +def test_diophantine_permute_sign(): + from sympy.abc import a, b, c, d, e + eq = a**4 + b**4 - (2**4 + 3**4) + base_sol = {(2, 3)} + assert diophantine(eq) == base_sol + complete_soln = set(signed_permutations(base_sol.pop())) + assert diophantine(eq, permute=True) == complete_soln + + eq = a**2 + b**2 + c**2 + d**2 + e**2 - 234 + assert len(diophantine(eq)) == 35 + assert len(diophantine(eq, permute=True)) == 62000 + soln = {(-1, -1), (-1, 2), (1, -2), (1, 1)} + assert diophantine(10*x**2 + 12*x*y + 12*y**2 - 34, permute=True) == soln + + +@XFAIL +def test_not_implemented(): + eq = x**2 + y**4 - 1**2 - 3**4 + assert diophantine(eq, syms=[x, y]) == {(9, 1), (1, 3)} + + +def test_issue_9538(): + eq = x - 3*y + 2 + assert diophantine(eq, syms=[y,x]) == {(t_0, 3*t_0 - 2)} + raises(TypeError, lambda: diophantine(eq, syms={y, x})) + + +def test_ternary_quadratic(): + # solution with 3 parameters + s = diophantine(2*x**2 + y**2 - 2*z**2) + p, q, r = ordered(S(s).free_symbols) + assert s == {( + p**2 - 2*q**2, + -2*p**2 + 4*p*q - 4*p*r - 4*q**2, + p**2 - 4*p*q + 2*q**2 - 4*q*r)} + # solution with Mul in solution + s = diophantine(x**2 + 2*y**2 - 2*z**2) + assert s == {(4*p*q, p**2 - 2*q**2, p**2 + 2*q**2)} + # solution with no Mul in solution + s = diophantine(2*x**2 + 2*y**2 - z**2) + assert s == {(2*p**2 - q**2, -2*p**2 + 4*p*q - q**2, + 4*p**2 - 4*p*q + 2*q**2)} + # reduced form when parametrized + s = diophantine(3*x**2 + 72*y**2 - 27*z**2) + assert s == {(24*p**2 - 9*q**2, 6*p*q, 8*p**2 + 3*q**2)} + assert parametrize_ternary_quadratic( + 3*x**2 + 2*y**2 - z**2 - 2*x*y + 5*y*z - 7*y*z) == ( + 2*p**2 - 2*p*q - q**2, 2*p**2 + 2*p*q - q**2, 2*p**2 - + 2*p*q + 3*q**2) + assert parametrize_ternary_quadratic( + 124*x**2 - 30*y**2 - 7729*z**2) == ( + -1410*p**2 - 363263*q**2, 2700*p**2 + 30916*p*q - + 695610*q**2, -60*p**2 + 5400*p*q + 15458*q**2) + + +def test_diophantine_solution_set(): + s1 = DiophantineSolutionSet([], []) + assert set(s1) == set() + assert s1.symbols == () + assert s1.parameters == () + raises(ValueError, lambda: s1.add((x,))) + assert list(s1.dict_iterator()) == [] + + s2 = DiophantineSolutionSet([x, y], [t, u]) + assert s2.symbols == (x, y) + assert s2.parameters == (t, u) + raises(ValueError, lambda: s2.add((1,))) + s2.add((3, 4)) + assert set(s2) == {(3, 4)} + s2.update((3, 4), (-1, u)) + assert set(s2) == {(3, 4), (-1, u)} + raises(ValueError, lambda: s1.update(s2)) + assert list(s2.dict_iterator()) == [{x: -1, y: u}, {x: 3, y: 4}] + + s3 = DiophantineSolutionSet([x, y, z], [t, u]) + assert len(s3.parameters) == 2 + s3.add((t**2 + u, t - u, 1)) + assert set(s3) == {(t**2 + u, t - u, 1)} + assert s3.subs(t, 2) == {(u + 4, 2 - u, 1)} + assert s3(2) == {(u + 4, 2 - u, 1)} + assert s3.subs({t: 7, u: 8}) == {(57, -1, 1)} + assert s3(7, 8) == {(57, -1, 1)} + assert s3.subs({t: 5}) == {(u + 25, 5 - u, 1)} + assert s3(5) == {(u + 25, 5 - u, 1)} + assert s3.subs(u, -3) == {(t**2 - 3, t + 3, 1)} + assert s3(None, -3) == {(t**2 - 3, t + 3, 1)} + assert s3.subs({t: 2, u: 8}) == {(12, -6, 1)} + assert s3(2, 8) == {(12, -6, 1)} + assert s3.subs({t: 5, u: -3}) == {(22, 8, 1)} + assert s3(5, -3) == {(22, 8, 1)} + raises(ValueError, lambda: s3.subs(x=1)) + raises(ValueError, lambda: s3.subs(1, 2, 3)) + raises(ValueError, lambda: s3.add(())) + raises(ValueError, lambda: s3.add((1, 2, 3, 4))) + raises(ValueError, lambda: s3.add((1, 2))) + raises(ValueError, lambda: s3(1, 2, 3)) + raises(TypeError, lambda: s3(t=1)) + + s4 = DiophantineSolutionSet([x, y], [t, u]) + s4.add((t, 11*t)) + s4.add((-t, 22*t)) + assert s4(0, 0) == {(0, 0)} + + +def test_quadratic_parameter_passing(): + eq = -33*x*y + 3*y**2 + solution = BinaryQuadratic(eq).solve(parameters=[t, u]) + # test that parameters are passed all the way to the final solution + assert solution == {(t, 11*t), (t, -22*t)} + assert solution(0, 0) == {(0, 0)} diff --git a/lib/python3.10/site-packages/sympy/solvers/ode/__init__.py b/lib/python3.10/site-packages/sympy/solvers/ode/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2b543425251dea6380a1860279cb6d636f3dd629 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/solvers/ode/__init__.py @@ -0,0 +1,16 @@ +from .ode import (allhints, checkinfsol, classify_ode, + constantsimp, dsolve, homogeneous_order) + +from .lie_group import infinitesimals + +from .subscheck import checkodesol + +from .systems import (canonical_odes, linear_ode_to_matrix, + linodesolve) + + +__all__ = [ + 'allhints', 'checkinfsol', 'checkodesol', 'classify_ode', 'constantsimp', + 'dsolve', 'homogeneous_order', 'infinitesimals', 'canonical_odes', 'linear_ode_to_matrix', + 'linodesolve' +] diff --git a/lib/python3.10/site-packages/sympy/solvers/ode/__pycache__/__init__.cpython-310.pyc b/lib/python3.10/site-packages/sympy/solvers/ode/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d44df73711c7bfd58afc2853de080b36eb2af11a Binary files /dev/null and b/lib/python3.10/site-packages/sympy/solvers/ode/__pycache__/__init__.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/solvers/ode/__pycache__/hypergeometric.cpython-310.pyc b/lib/python3.10/site-packages/sympy/solvers/ode/__pycache__/hypergeometric.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..461d815e7fb0ef2de395ea960f8e988ab1965a2b Binary files /dev/null and b/lib/python3.10/site-packages/sympy/solvers/ode/__pycache__/hypergeometric.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/solvers/ode/__pycache__/lie_group.cpython-310.pyc b/lib/python3.10/site-packages/sympy/solvers/ode/__pycache__/lie_group.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c8f6e1a60e9954683e3db133da56366970375050 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/solvers/ode/__pycache__/lie_group.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/solvers/ode/__pycache__/nonhomogeneous.cpython-310.pyc b/lib/python3.10/site-packages/sympy/solvers/ode/__pycache__/nonhomogeneous.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7e6ed5c68f39bd15df9db56cf248437150f5b6f6 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/solvers/ode/__pycache__/nonhomogeneous.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/solvers/ode/__pycache__/riccati.cpython-310.pyc b/lib/python3.10/site-packages/sympy/solvers/ode/__pycache__/riccati.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3f2a8c472fb6b46553a3deb7aa6de5be9f1f1926 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/solvers/ode/__pycache__/riccati.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/solvers/ode/__pycache__/subscheck.cpython-310.pyc b/lib/python3.10/site-packages/sympy/solvers/ode/__pycache__/subscheck.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5fa7e6b4e206bc326f71f84acc2c8e469fedfd23 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/solvers/ode/__pycache__/subscheck.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/solvers/ode/__pycache__/systems.cpython-310.pyc b/lib/python3.10/site-packages/sympy/solvers/ode/__pycache__/systems.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4921ad282d62d8beef84bc9ba579ad9b95ba4129 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/solvers/ode/__pycache__/systems.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/solvers/ode/hypergeometric.py b/lib/python3.10/site-packages/sympy/solvers/ode/hypergeometric.py new file mode 100644 index 0000000000000000000000000000000000000000..51a40b1cba32eabbdb120f9c4d5e3fd05dc644eb --- /dev/null +++ b/lib/python3.10/site-packages/sympy/solvers/ode/hypergeometric.py @@ -0,0 +1,272 @@ +r''' +This module contains the implementation of the 2nd_hypergeometric hint for +dsolve. This is an incomplete implementation of the algorithm described in [1]. +The algorithm solves 2nd order linear ODEs of the form + +.. math:: y'' + A(x) y' + B(x) y = 0\text{,} + +where `A` and `B` are rational functions. The algorithm should find any +solution of the form + +.. math:: y = P(x) _pF_q(..; ..;\frac{\alpha x^k + \beta}{\gamma x^k + \delta})\text{,} + +where pFq is any of 2F1, 1F1 or 0F1 and `P` is an "arbitrary function". +Currently only the 2F1 case is implemented in SymPy but the other cases are +described in the paper and could be implemented in future (contributions +welcome!). + +References +========== + +.. [1] L. Chan, E.S. Cheb-Terrab, Non-Liouvillian solutions for second order + linear ODEs, (2004). + https://arxiv.org/abs/math-ph/0402063 +''' + +from sympy.core import S, Pow +from sympy.core.function import expand +from sympy.core.relational import Eq +from sympy.core.symbol import Symbol, Wild +from sympy.functions import exp, sqrt, hyper +from sympy.integrals import Integral +from sympy.polys import roots, gcd +from sympy.polys.polytools import cancel, factor +from sympy.simplify import collect, simplify, logcombine # type: ignore +from sympy.simplify.powsimp import powdenest +from sympy.solvers.ode.ode import get_numbered_constants + + +def match_2nd_hypergeometric(eq, func): + x = func.args[0] + df = func.diff(x) + a3 = Wild('a3', exclude=[func, func.diff(x), func.diff(x, 2)]) + b3 = Wild('b3', exclude=[func, func.diff(x), func.diff(x, 2)]) + c3 = Wild('c3', exclude=[func, func.diff(x), func.diff(x, 2)]) + deq = a3*(func.diff(x, 2)) + b3*df + c3*func + r = collect(eq, + [func.diff(x, 2), func.diff(x), func]).match(deq) + if r: + if not all(val.is_polynomial() for val in r.values()): + n, d = eq.as_numer_denom() + eq = expand(n) + r = collect(eq, [func.diff(x, 2), func.diff(x), func]).match(deq) + + if r and r[a3]!=0: + A = cancel(r[b3]/r[a3]) + B = cancel(r[c3]/r[a3]) + return [A, B] + else: + return [] + + +def equivalence_hypergeometric(A, B, func): + # This method for finding the equivalence is only for 2F1 type. + # We can extend it for 1F1 and 0F1 type also. + x = func.args[0] + + # making given equation in normal form + I1 = factor(cancel(A.diff(x)/2 + A**2/4 - B)) + + # computing shifted invariant(J1) of the equation + J1 = factor(cancel(x**2*I1 + S(1)/4)) + num, dem = J1.as_numer_denom() + num = powdenest(expand(num)) + dem = powdenest(expand(dem)) + # this function will compute the different powers of variable(x) in J1. + # then it will help in finding value of k. k is power of x such that we can express + # J1 = x**k * J0(x**k) then all the powers in J0 become integers. + def _power_counting(num): + _pow = {0} + for val in num: + if val.has(x): + if isinstance(val, Pow) and val.as_base_exp()[0] == x: + _pow.add(val.as_base_exp()[1]) + elif val == x: + _pow.add(val.as_base_exp()[1]) + else: + _pow.update(_power_counting(val.args)) + return _pow + + pow_num = _power_counting((num, )) + pow_dem = _power_counting((dem, )) + pow_dem.update(pow_num) + + _pow = pow_dem + k = gcd(_pow) + + # computing I0 of the given equation + I0 = powdenest(simplify(factor(((J1/k**2) - S(1)/4)/((x**k)**2))), force=True) + I0 = factor(cancel(powdenest(I0.subs(x, x**(S(1)/k)), force=True))) + + # Before this point I0, J1 might be functions of e.g. sqrt(x) but replacing + # x with x**(1/k) should result in I0 being a rational function of x or + # otherwise the hypergeometric solver cannot be used. Note that k can be a + # non-integer rational such as 2/7. + if not I0.is_rational_function(x): + return None + + num, dem = I0.as_numer_denom() + + max_num_pow = max(_power_counting((num, ))) + dem_args = dem.args + sing_point = [] + dem_pow = [] + # calculating singular point of I0. + for arg in dem_args: + if arg.has(x): + if isinstance(arg, Pow): + # (x-a)**n + dem_pow.append(arg.as_base_exp()[1]) + sing_point.append(list(roots(arg.as_base_exp()[0], x).keys())[0]) + else: + # (x-a) type + dem_pow.append(arg.as_base_exp()[1]) + sing_point.append(list(roots(arg, x).keys())[0]) + + dem_pow.sort() + # checking if equivalence is exists or not. + + if equivalence(max_num_pow, dem_pow) == "2F1": + return {'I0':I0, 'k':k, 'sing_point':sing_point, 'type':"2F1"} + else: + return None + + +def match_2nd_2F1_hypergeometric(I, k, sing_point, func): + x = func.args[0] + a = Wild("a") + b = Wild("b") + c = Wild("c") + t = Wild("t") + s = Wild("s") + r = Wild("r") + alpha = Wild("alpha") + beta = Wild("beta") + gamma = Wild("gamma") + delta = Wild("delta") + # I0 of the standerd 2F1 equation. + I0 = ((a-b+1)*(a-b-1)*x**2 + 2*((1-a-b)*c + 2*a*b)*x + c*(c-2))/(4*x**2*(x-1)**2) + if sing_point != [0, 1]: + # If singular point is [0, 1] then we have standerd equation. + eqs = [] + sing_eqs = [-beta/alpha, -delta/gamma, (delta-beta)/(alpha-gamma)] + # making equations for the finding the mobius transformation + for i in range(3): + if i>> from sympy import Function, Eq, pprint + >>> from sympy.abc import x, y + >>> xi, eta, h = map(Function, ['xi', 'eta', 'h']) + >>> h = h(x, y) # dy/dx = h + >>> eta = eta(x, y) + >>> xi = xi(x, y) + >>> genform = Eq(eta.diff(x) + (eta.diff(y) - xi.diff(x))*h + ... - (xi.diff(y))*h**2 - xi*(h.diff(x)) - eta*(h.diff(y)), 0) + >>> pprint(genform) + /d d \ d 2 d d d + |--(eta(x, y)) - --(xi(x, y))|*h(x, y) - eta(x, y)*--(h(x, y)) - h (x, y)*--(xi(x, y)) - xi(x, y)*--(h(x, y)) + --(eta(x, y)) = 0 + \dy dx / dy dy dx dx + + Solving the above mentioned PDE is not trivial, and can be solved only by + making intelligent assumptions for `\xi` and `\eta` (heuristics). Once an + infinitesimal is found, the attempt to find more heuristics stops. This is done to + optimise the speed of solving the differential equation. If a list of all the + infinitesimals is needed, ``hint`` should be flagged as ``all``, which gives + the complete list of infinitesimals. If the infinitesimals for a particular + heuristic needs to be found, it can be passed as a flag to ``hint``. + + Examples + ======== + + >>> from sympy import Function + >>> from sympy.solvers.ode.lie_group import infinitesimals + >>> from sympy.abc import x + >>> f = Function('f') + >>> eq = f(x).diff(x) - x**2*f(x) + >>> infinitesimals(eq) + [{eta(x, f(x)): exp(x**3/3), xi(x, f(x)): 0}] + + References + ========== + + - Solving differential equations by Symmetry Groups, + John Starrett, pp. 1 - pp. 14 + + """ + + if isinstance(eq, Equality): + eq = eq.lhs - eq.rhs + if not func: + eq, func = _preprocess(eq) + variables = func.args + if len(variables) != 1: + raise ValueError("ODE's have only one independent variable") + else: + x = variables[0] + if not order: + order = ode_order(eq, func) + if order != 1: + raise NotImplementedError("Infinitesimals for only " + "first order ODE's have been implemented") + else: + df = func.diff(x) + # Matching differential equation of the form a*df + b + a = Wild('a', exclude = [df]) + b = Wild('b', exclude = [df]) + if match: # Used by lie_group hint + h = match['h'] + y = match['y'] + else: + match = collect(expand(eq), df).match(a*df + b) + if match: + h = -simplify(match[b]/match[a]) + else: + try: + sol = solve(eq, df) + except NotImplementedError: + raise NotImplementedError("Infinitesimals for the " + "first order ODE could not be found") + else: + h = sol[0] # Find infinitesimals for one solution + y = Dummy("y") + h = h.subs(func, y) + + u = Dummy("u") + hx = h.diff(x) + hy = h.diff(y) + hinv = ((1/h).subs([(x, u), (y, x)])).subs(u, y) # Inverse ODE + match = {'h': h, 'func': func, 'hx': hx, 'hy': hy, 'y': y, 'hinv': hinv} + if hint == 'all': + xieta = [] + for heuristic in lie_heuristics: + function = globals()['lie_heuristic_' + heuristic] + inflist = function(match, comp=True) + if inflist: + xieta.extend([inf for inf in inflist if inf not in xieta]) + if xieta: + return xieta + else: + raise NotImplementedError("Infinitesimals could not be found for " + "the given ODE") + + elif hint == 'default': + for heuristic in lie_heuristics: + function = globals()['lie_heuristic_' + heuristic] + xieta = function(match, comp=False) + if xieta: + return xieta + + raise NotImplementedError("Infinitesimals could not be found for" + " the given ODE") + + elif hint not in lie_heuristics: + raise ValueError("Heuristic not recognized: " + hint) + + else: + function = globals()['lie_heuristic_' + hint] + xieta = function(match, comp=True) + if xieta: + return xieta + else: + raise ValueError("Infinitesimals could not be found using the" + " given heuristic") + + +def lie_heuristic_abaco1_simple(match, comp=False): + r""" + The first heuristic uses the following four sets of + assumptions on `\xi` and `\eta` + + .. math:: \xi = 0, \eta = f(x) + + .. math:: \xi = 0, \eta = f(y) + + .. math:: \xi = f(x), \eta = 0 + + .. math:: \xi = f(y), \eta = 0 + + The success of this heuristic is determined by algebraic factorisation. + For the first assumption `\xi = 0` and `\eta` to be a function of `x`, the PDE + + .. math:: \frac{\partial \eta}{\partial x} + (\frac{\partial \eta}{\partial y} + - \frac{\partial \xi}{\partial x})*h + - \frac{\partial \xi}{\partial y}*h^{2} + - \xi*\frac{\partial h}{\partial x} - \eta*\frac{\partial h}{\partial y} = 0 + + reduces to `f'(x) - f\frac{\partial h}{\partial y} = 0` + If `\frac{\partial h}{\partial y}` is a function of `x`, then this can usually + be integrated easily. A similar idea is applied to the other 3 assumptions as well. + + + References + ========== + + - E.S Cheb-Terrab, L.G.S Duarte and L.A,C.P da Mota, Computer Algebra + Solving of First Order ODEs Using Symmetry Methods, pp. 8 + + + """ + + xieta = [] + y = match['y'] + h = match['h'] + func = match['func'] + x = func.args[0] + hx = match['hx'] + hy = match['hy'] + xi = Function('xi')(x, func) + eta = Function('eta')(x, func) + + hysym = hy.free_symbols + if y not in hysym: + try: + fx = exp(integrate(hy, x)) + except NotImplementedError: + pass + else: + inf = {xi: S.Zero, eta: fx} + if not comp: + return [inf] + if comp and inf not in xieta: + xieta.append(inf) + + factor = hy/h + facsym = factor.free_symbols + if x not in facsym: + try: + fy = exp(integrate(factor, y)) + except NotImplementedError: + pass + else: + inf = {xi: S.Zero, eta: fy.subs(y, func)} + if not comp: + return [inf] + if comp and inf not in xieta: + xieta.append(inf) + + factor = -hx/h + facsym = factor.free_symbols + if y not in facsym: + try: + fx = exp(integrate(factor, x)) + except NotImplementedError: + pass + else: + inf = {xi: fx, eta: S.Zero} + if not comp: + return [inf] + if comp and inf not in xieta: + xieta.append(inf) + + factor = -hx/(h**2) + facsym = factor.free_symbols + if x not in facsym: + try: + fy = exp(integrate(factor, y)) + except NotImplementedError: + pass + else: + inf = {xi: fy.subs(y, func), eta: S.Zero} + if not comp: + return [inf] + if comp and inf not in xieta: + xieta.append(inf) + + if xieta: + return xieta + +def lie_heuristic_abaco1_product(match, comp=False): + r""" + The second heuristic uses the following two assumptions on `\xi` and `\eta` + + .. math:: \eta = 0, \xi = f(x)*g(y) + + .. math:: \eta = f(x)*g(y), \xi = 0 + + The first assumption of this heuristic holds good if + `\frac{1}{h^{2}}\frac{\partial^2}{\partial x \partial y}\log(h)` is + separable in `x` and `y`, then the separated factors containing `x` + is `f(x)`, and `g(y)` is obtained by + + .. math:: e^{\int f\frac{\partial}{\partial x}\left(\frac{1}{f*h}\right)\,dy} + + provided `f\frac{\partial}{\partial x}\left(\frac{1}{f*h}\right)` is a function + of `y` only. + + The second assumption holds good if `\frac{dy}{dx} = h(x, y)` is rewritten as + `\frac{dy}{dx} = \frac{1}{h(y, x)}` and the same properties of the first assumption + satisfies. After obtaining `f(x)` and `g(y)`, the coordinates are again + interchanged, to get `\eta` as `f(x)*g(y)` + + + References + ========== + - E.S. Cheb-Terrab, A.D. Roche, Symmetries and First Order + ODE Patterns, pp. 7 - pp. 8 + + """ + + xieta = [] + y = match['y'] + h = match['h'] + hinv = match['hinv'] + func = match['func'] + x = func.args[0] + xi = Function('xi')(x, func) + eta = Function('eta')(x, func) + + + inf = separatevars(((log(h).diff(y)).diff(x))/h**2, dict=True, symbols=[x, y]) + if inf and inf['coeff']: + fx = inf[x] + gy = simplify(fx*((1/(fx*h)).diff(x))) + gysyms = gy.free_symbols + if x not in gysyms: + gy = exp(integrate(gy, y)) + inf = {eta: S.Zero, xi: (fx*gy).subs(y, func)} + if not comp: + return [inf] + if comp and inf not in xieta: + xieta.append(inf) + + u1 = Dummy("u1") + inf = separatevars(((log(hinv).diff(y)).diff(x))/hinv**2, dict=True, symbols=[x, y]) + if inf and inf['coeff']: + fx = inf[x] + gy = simplify(fx*((1/(fx*hinv)).diff(x))) + gysyms = gy.free_symbols + if x not in gysyms: + gy = exp(integrate(gy, y)) + etaval = fx*gy + etaval = (etaval.subs([(x, u1), (y, x)])).subs(u1, y) + inf = {eta: etaval.subs(y, func), xi: S.Zero} + if not comp: + return [inf] + if comp and inf not in xieta: + xieta.append(inf) + + if xieta: + return xieta + +def lie_heuristic_bivariate(match, comp=False): + r""" + The third heuristic assumes the infinitesimals `\xi` and `\eta` + to be bi-variate polynomials in `x` and `y`. The assumption made here + for the logic below is that `h` is a rational function in `x` and `y` + though that may not be necessary for the infinitesimals to be + bivariate polynomials. The coefficients of the infinitesimals + are found out by substituting them in the PDE and grouping similar terms + that are polynomials and since they form a linear system, solve and check + for non trivial solutions. The degree of the assumed bivariates + are increased till a certain maximum value. + + References + ========== + - Lie Groups and Differential Equations + pp. 327 - pp. 329 + + """ + + h = match['h'] + hx = match['hx'] + hy = match['hy'] + func = match['func'] + x = func.args[0] + y = match['y'] + xi = Function('xi')(x, func) + eta = Function('eta')(x, func) + + if h.is_rational_function(): + # The maximum degree that the infinitesimals can take is + # calculated by this technique. + etax, etay, etad, xix, xiy, xid = symbols("etax etay etad xix xiy xid") + ipde = etax + (etay - xix)*h - xiy*h**2 - xid*hx - etad*hy + num, denom = cancel(ipde).as_numer_denom() + deg = Poly(num, x, y).total_degree() + deta = Function('deta')(x, y) + dxi = Function('dxi')(x, y) + ipde = (deta.diff(x) + (deta.diff(y) - dxi.diff(x))*h - (dxi.diff(y))*h**2 + - dxi*hx - deta*hy) + xieq = Symbol("xi0") + etaeq = Symbol("eta0") + + for i in range(deg + 1): + if i: + xieq += Add(*[ + Symbol("xi_" + str(power) + "_" + str(i - power))*x**power*y**(i - power) + for power in range(i + 1)]) + etaeq += Add(*[ + Symbol("eta_" + str(power) + "_" + str(i - power))*x**power*y**(i - power) + for power in range(i + 1)]) + pden, denom = (ipde.subs({dxi: xieq, deta: etaeq}).doit()).as_numer_denom() + pden = expand(pden) + + # If the individual terms are monomials, the coefficients + # are grouped + if pden.is_polynomial(x, y) and pden.is_Add: + polyy = Poly(pden, x, y).as_dict() + if polyy: + symset = xieq.free_symbols.union(etaeq.free_symbols) - {x, y} + soldict = solve(polyy.values(), *symset) + if isinstance(soldict, list): + soldict = soldict[0] + if any(soldict.values()): + xired = xieq.subs(soldict) + etared = etaeq.subs(soldict) + # Scaling is done by substituting one for the parameters + # This can be any number except zero. + dict_ = dict.fromkeys(symset, 1) + inf = {eta: etared.subs(dict_).subs(y, func), + xi: xired.subs(dict_).subs(y, func)} + return [inf] + +def lie_heuristic_chi(match, comp=False): + r""" + The aim of the fourth heuristic is to find the function `\chi(x, y)` + that satisfies the PDE `\frac{d\chi}{dx} + h\frac{d\chi}{dx} + - \frac{\partial h}{\partial y}\chi = 0`. + + This assumes `\chi` to be a bivariate polynomial in `x` and `y`. By intuition, + `h` should be a rational function in `x` and `y`. The method used here is + to substitute a general binomial for `\chi` up to a certain maximum degree + is reached. The coefficients of the polynomials, are calculated by by collecting + terms of the same order in `x` and `y`. + + After finding `\chi`, the next step is to use `\eta = \xi*h + \chi`, to + determine `\xi` and `\eta`. This can be done by dividing `\chi` by `h` + which would give `-\xi` as the quotient and `\eta` as the remainder. + + + References + ========== + - E.S Cheb-Terrab, L.G.S Duarte and L.A,C.P da Mota, Computer Algebra + Solving of First Order ODEs Using Symmetry Methods, pp. 8 + + """ + + h = match['h'] + hy = match['hy'] + func = match['func'] + x = func.args[0] + y = match['y'] + xi = Function('xi')(x, func) + eta = Function('eta')(x, func) + + if h.is_rational_function(): + schi, schix, schiy = symbols("schi, schix, schiy") + cpde = schix + h*schiy - hy*schi + num, denom = cancel(cpde).as_numer_denom() + deg = Poly(num, x, y).total_degree() + + chi = Function('chi')(x, y) + chix = chi.diff(x) + chiy = chi.diff(y) + cpde = chix + h*chiy - hy*chi + chieq = Symbol("chi") + for i in range(1, deg + 1): + chieq += Add(*[ + Symbol("chi_" + str(power) + "_" + str(i - power))*x**power*y**(i - power) + for power in range(i + 1)]) + cnum, cden = cancel(cpde.subs({chi : chieq}).doit()).as_numer_denom() + cnum = expand(cnum) + if cnum.is_polynomial(x, y) and cnum.is_Add: + cpoly = Poly(cnum, x, y).as_dict() + if cpoly: + solsyms = chieq.free_symbols - {x, y} + soldict = solve(cpoly.values(), *solsyms) + if isinstance(soldict, list): + soldict = soldict[0] + if any(soldict.values()): + chieq = chieq.subs(soldict) + dict_ = dict.fromkeys(solsyms, 1) + chieq = chieq.subs(dict_) + # After finding chi, the main aim is to find out + # eta, xi by the equation eta = xi*h + chi + # One method to set xi, would be rearranging it to + # (eta/h) - xi = (chi/h). This would mean dividing + # chi by h would give -xi as the quotient and eta + # as the remainder. Thanks to Sean Vig for suggesting + # this method. + xic, etac = div(chieq, h) + inf = {eta: etac.subs(y, func), xi: -xic.subs(y, func)} + return [inf] + +def lie_heuristic_function_sum(match, comp=False): + r""" + This heuristic uses the following two assumptions on `\xi` and `\eta` + + .. math:: \eta = 0, \xi = f(x) + g(y) + + .. math:: \eta = f(x) + g(y), \xi = 0 + + The first assumption of this heuristic holds good if + + .. math:: \frac{\partial}{\partial y}[(h\frac{\partial^{2}}{ + \partial x^{2}}(h^{-1}))^{-1}] + + is separable in `x` and `y`, + + 1. The separated factors containing `y` is `\frac{\partial g}{\partial y}`. + From this `g(y)` can be determined. + 2. The separated factors containing `x` is `f''(x)`. + 3. `h\frac{\partial^{2}}{\partial x^{2}}(h^{-1})` equals + `\frac{f''(x)}{f(x) + g(y)}`. From this `f(x)` can be determined. + + The second assumption holds good if `\frac{dy}{dx} = h(x, y)` is rewritten as + `\frac{dy}{dx} = \frac{1}{h(y, x)}` and the same properties of the first + assumption satisfies. After obtaining `f(x)` and `g(y)`, the coordinates + are again interchanged, to get `\eta` as `f(x) + g(y)`. + + For both assumptions, the constant factors are separated among `g(y)` + and `f''(x)`, such that `f''(x)` obtained from 3] is the same as that + obtained from 2]. If not possible, then this heuristic fails. + + + References + ========== + - E.S. Cheb-Terrab, A.D. Roche, Symmetries and First Order + ODE Patterns, pp. 7 - pp. 8 + + """ + + xieta = [] + h = match['h'] + func = match['func'] + hinv = match['hinv'] + x = func.args[0] + y = match['y'] + xi = Function('xi')(x, func) + eta = Function('eta')(x, func) + + for odefac in [h, hinv]: + factor = odefac*((1/odefac).diff(x, 2)) + sep = separatevars((1/factor).diff(y), dict=True, symbols=[x, y]) + if sep and sep['coeff'] and sep[x].has(x) and sep[y].has(y): + k = Dummy("k") + try: + gy = k*integrate(sep[y], y) + except NotImplementedError: + pass + else: + fdd = 1/(k*sep[x]*sep['coeff']) + fx = simplify(fdd/factor - gy) + check = simplify(fx.diff(x, 2) - fdd) + if fx: + if not check: + fx = fx.subs(k, 1) + gy = (gy/k) + else: + sol = solve(check, k) + if sol: + sol = sol[0] + fx = fx.subs(k, sol) + gy = (gy/k)*sol + else: + continue + if odefac == hinv: # Inverse ODE + fx = fx.subs(x, y) + gy = gy.subs(y, x) + etaval = factor_terms(fx + gy) + if etaval.is_Mul: + etaval = Mul(*[arg for arg in etaval.args if arg.has(x, y)]) + if odefac == hinv: # Inverse ODE + inf = {eta: etaval.subs(y, func), xi : S.Zero} + else: + inf = {xi: etaval.subs(y, func), eta : S.Zero} + if not comp: + return [inf] + else: + xieta.append(inf) + + if xieta: + return xieta + +def lie_heuristic_abaco2_similar(match, comp=False): + r""" + This heuristic uses the following two assumptions on `\xi` and `\eta` + + .. math:: \eta = g(x), \xi = f(x) + + .. math:: \eta = f(y), \xi = g(y) + + For the first assumption, + + 1. First `\frac{\frac{\partial h}{\partial y}}{\frac{\partial^{2} h}{ + \partial yy}}` is calculated. Let us say this value is A + + 2. If this is constant, then `h` is matched to the form `A(x) + B(x)e^{ + \frac{y}{C}}` then, `\frac{e^{\int \frac{A(x)}{C} \,dx}}{B(x)}` gives `f(x)` + and `A(x)*f(x)` gives `g(x)` + + 3. Otherwise `\frac{\frac{\partial A}{\partial X}}{\frac{\partial A}{ + \partial Y}} = \gamma` is calculated. If + + a] `\gamma` is a function of `x` alone + + b] `\frac{\gamma\frac{\partial h}{\partial y} - \gamma'(x) - \frac{ + \partial h}{\partial x}}{h + \gamma} = G` is a function of `x` alone. + then, `e^{\int G \,dx}` gives `f(x)` and `-\gamma*f(x)` gives `g(x)` + + The second assumption holds good if `\frac{dy}{dx} = h(x, y)` is rewritten as + `\frac{dy}{dx} = \frac{1}{h(y, x)}` and the same properties of the first assumption + satisfies. After obtaining `f(x)` and `g(x)`, the coordinates are again + interchanged, to get `\xi` as `f(x^*)` and `\eta` as `g(y^*)` + + References + ========== + - E.S. Cheb-Terrab, A.D. Roche, Symmetries and First Order + ODE Patterns, pp. 10 - pp. 12 + + """ + + h = match['h'] + hx = match['hx'] + hy = match['hy'] + func = match['func'] + hinv = match['hinv'] + x = func.args[0] + y = match['y'] + xi = Function('xi')(x, func) + eta = Function('eta')(x, func) + + factor = cancel(h.diff(y)/h.diff(y, 2)) + factorx = factor.diff(x) + factory = factor.diff(y) + if not factor.has(x) and not factor.has(y): + A = Wild('A', exclude=[y]) + B = Wild('B', exclude=[y]) + C = Wild('C', exclude=[x, y]) + match = h.match(A + B*exp(y/C)) + try: + tau = exp(-integrate(match[A]/match[C]), x)/match[B] + except NotImplementedError: + pass + else: + gx = match[A]*tau + return [{xi: tau, eta: gx}] + + else: + gamma = cancel(factorx/factory) + if not gamma.has(y): + tauint = cancel((gamma*hy - gamma.diff(x) - hx)/(h + gamma)) + if not tauint.has(y): + try: + tau = exp(integrate(tauint, x)) + except NotImplementedError: + pass + else: + gx = -tau*gamma + return [{xi: tau, eta: gx}] + + factor = cancel(hinv.diff(y)/hinv.diff(y, 2)) + factorx = factor.diff(x) + factory = factor.diff(y) + if not factor.has(x) and not factor.has(y): + A = Wild('A', exclude=[y]) + B = Wild('B', exclude=[y]) + C = Wild('C', exclude=[x, y]) + match = h.match(A + B*exp(y/C)) + try: + tau = exp(-integrate(match[A]/match[C]), x)/match[B] + except NotImplementedError: + pass + else: + gx = match[A]*tau + return [{eta: tau.subs(x, func), xi: gx.subs(x, func)}] + + else: + gamma = cancel(factorx/factory) + if not gamma.has(y): + tauint = cancel((gamma*hinv.diff(y) - gamma.diff(x) - hinv.diff(x))/( + hinv + gamma)) + if not tauint.has(y): + try: + tau = exp(integrate(tauint, x)) + except NotImplementedError: + pass + else: + gx = -tau*gamma + return [{eta: tau.subs(x, func), xi: gx.subs(x, func)}] + + +def lie_heuristic_abaco2_unique_unknown(match, comp=False): + r""" + This heuristic assumes the presence of unknown functions or known functions + with non-integer powers. + + 1. A list of all functions and non-integer powers containing x and y + 2. Loop over each element `f` in the list, find `\frac{\frac{\partial f}{\partial x}}{ + \frac{\partial f}{\partial x}} = R` + + If it is separable in `x` and `y`, let `X` be the factors containing `x`. Then + + a] Check if `\xi = X` and `\eta = -\frac{X}{R}` satisfy the PDE. If yes, then return + `\xi` and `\eta` + b] Check if `\xi = \frac{-R}{X}` and `\eta = -\frac{1}{X}` satisfy the PDE. + If yes, then return `\xi` and `\eta` + + If not, then check if + + a] :math:`\xi = -R,\eta = 1` + + b] :math:`\xi = 1, \eta = -\frac{1}{R}` + + are solutions. + + References + ========== + - E.S. Cheb-Terrab, A.D. Roche, Symmetries and First Order + ODE Patterns, pp. 10 - pp. 12 + + """ + + h = match['h'] + hx = match['hx'] + hy = match['hy'] + func = match['func'] + x = func.args[0] + y = match['y'] + xi = Function('xi')(x, func) + eta = Function('eta')(x, func) + + funclist = [] + for atom in h.atoms(Pow): + base, exp = atom.as_base_exp() + if base.has(x) and base.has(y): + if not exp.is_Integer: + funclist.append(atom) + + for function in h.atoms(AppliedUndef): + syms = function.free_symbols + if x in syms and y in syms: + funclist.append(function) + + for f in funclist: + frac = cancel(f.diff(y)/f.diff(x)) + sep = separatevars(frac, dict=True, symbols=[x, y]) + if sep and sep['coeff']: + xitry1 = sep[x] + etatry1 = -1/(sep[y]*sep['coeff']) + pde1 = etatry1.diff(y)*h - xitry1.diff(x)*h - xitry1*hx - etatry1*hy + if not simplify(pde1): + return [{xi: xitry1, eta: etatry1.subs(y, func)}] + xitry2 = 1/etatry1 + etatry2 = 1/xitry1 + pde2 = etatry2.diff(x) - (xitry2.diff(y))*h**2 - xitry2*hx - etatry2*hy + if not simplify(expand(pde2)): + return [{xi: xitry2.subs(y, func), eta: etatry2}] + + else: + etatry = -1/frac + pde = etatry.diff(x) + etatry.diff(y)*h - hx - etatry*hy + if not simplify(pde): + return [{xi: S.One, eta: etatry.subs(y, func)}] + xitry = -frac + pde = -xitry.diff(x)*h -xitry.diff(y)*h**2 - xitry*hx -hy + if not simplify(expand(pde)): + return [{xi: xitry.subs(y, func), eta: S.One}] + + +def lie_heuristic_abaco2_unique_general(match, comp=False): + r""" + This heuristic finds if infinitesimals of the form `\eta = f(x)`, `\xi = g(y)` + without making any assumptions on `h`. + + The complete sequence of steps is given in the paper mentioned below. + + References + ========== + - E.S. Cheb-Terrab, A.D. Roche, Symmetries and First Order + ODE Patterns, pp. 10 - pp. 12 + + """ + hx = match['hx'] + hy = match['hy'] + func = match['func'] + x = func.args[0] + y = match['y'] + xi = Function('xi')(x, func) + eta = Function('eta')(x, func) + + A = hx.diff(y) + B = hy.diff(y) + hy**2 + C = hx.diff(x) - hx**2 + + if not (A and B and C): + return + + Ax = A.diff(x) + Ay = A.diff(y) + Axy = Ax.diff(y) + Axx = Ax.diff(x) + Ayy = Ay.diff(y) + D = simplify(2*Axy + hx*Ay - Ax*hy + (hx*hy + 2*A)*A)*A - 3*Ax*Ay + if not D: + E1 = simplify(3*Ax**2 + ((hx**2 + 2*C)*A - 2*Axx)*A) + if E1: + E2 = simplify((2*Ayy + (2*B - hy**2)*A)*A - 3*Ay**2) + if not E2: + E3 = simplify( + E1*((28*Ax + 4*hx*A)*A**3 - E1*(hy*A + Ay)) - E1.diff(x)*8*A**4) + if not E3: + etaval = cancel((4*A**3*(Ax - hx*A) + E1*(hy*A - Ay))/(S(2)*A*E1)) + if x not in etaval: + try: + etaval = exp(integrate(etaval, y)) + except NotImplementedError: + pass + else: + xival = -4*A**3*etaval/E1 + if y not in xival: + return [{xi: xival, eta: etaval.subs(y, func)}] + + else: + E1 = simplify((2*Ayy + (2*B - hy**2)*A)*A - 3*Ay**2) + if E1: + E2 = simplify( + 4*A**3*D - D**2 + E1*((2*Axx - (hx**2 + 2*C)*A)*A - 3*Ax**2)) + if not E2: + E3 = simplify( + -(A*D)*E1.diff(y) + ((E1.diff(x) - hy*D)*A + 3*Ay*D + + (A*hx - 3*Ax)*E1)*E1) + if not E3: + etaval = cancel(((A*hx - Ax)*E1 - (Ay + A*hy)*D)/(S(2)*A*D)) + if x not in etaval: + try: + etaval = exp(integrate(etaval, y)) + except NotImplementedError: + pass + else: + xival = -E1*etaval/D + if y not in xival: + return [{xi: xival, eta: etaval.subs(y, func)}] + + +def lie_heuristic_linear(match, comp=False): + r""" + This heuristic assumes + + 1. `\xi = ax + by + c` and + 2. `\eta = fx + gy + h` + + After substituting the following assumptions in the determining PDE, it + reduces to + + .. math:: f + (g - a)h - bh^{2} - (ax + by + c)\frac{\partial h}{\partial x} + - (fx + gy + c)\frac{\partial h}{\partial y} + + Solving the reduced PDE obtained, using the method of characteristics, becomes + impractical. The method followed is grouping similar terms and solving the system + of linear equations obtained. The difference between the bivariate heuristic is that + `h` need not be a rational function in this case. + + References + ========== + - E.S. Cheb-Terrab, A.D. Roche, Symmetries and First Order + ODE Patterns, pp. 10 - pp. 12 + + """ + h = match['h'] + hx = match['hx'] + hy = match['hy'] + func = match['func'] + x = func.args[0] + y = match['y'] + xi = Function('xi')(x, func) + eta = Function('eta')(x, func) + + coeffdict = {} + symbols = numbered_symbols("c", cls=Dummy) + symlist = [next(symbols) for _ in islice(symbols, 6)] + C0, C1, C2, C3, C4, C5 = symlist + pde = C3 + (C4 - C0)*h - (C0*x + C1*y + C2)*hx - (C3*x + C4*y + C5)*hy - C1*h**2 + pde, denom = pde.as_numer_denom() + pde = powsimp(expand(pde)) + if pde.is_Add: + terms = pde.args + for term in terms: + if term.is_Mul: + rem = Mul(*[m for m in term.args if not m.has(x, y)]) + xypart = term/rem + if xypart not in coeffdict: + coeffdict[xypart] = rem + else: + coeffdict[xypart] += rem + else: + if term not in coeffdict: + coeffdict[term] = S.One + else: + coeffdict[term] += S.One + + sollist = coeffdict.values() + soldict = solve(sollist, symlist) + if soldict: + if isinstance(soldict, list): + soldict = soldict[0] + subval = soldict.values() + if any(t for t in subval): + onedict = dict(zip(symlist, [1]*6)) + xival = C0*x + C1*func + C2 + etaval = C3*x + C4*func + C5 + xival = xival.subs(soldict) + etaval = etaval.subs(soldict) + xival = xival.subs(onedict) + etaval = etaval.subs(onedict) + return [{xi: xival, eta: etaval}] + + +def _lie_group_remove(coords): + r""" + This function is strictly meant for internal use by the Lie group ODE solving + method. It replaces arbitrary functions returned by pdsolve as follows: + + 1] If coords is an arbitrary function, then its argument is returned. + 2] An arbitrary function in an Add object is replaced by zero. + 3] An arbitrary function in a Mul object is replaced by one. + 4] If there is no arbitrary function coords is returned unchanged. + + Examples + ======== + + >>> from sympy.solvers.ode.lie_group import _lie_group_remove + >>> from sympy import Function + >>> from sympy.abc import x, y + >>> F = Function("F") + >>> eq = x**2*y + >>> _lie_group_remove(eq) + x**2*y + >>> eq = F(x**2*y) + >>> _lie_group_remove(eq) + x**2*y + >>> eq = x*y**2 + F(x**3) + >>> _lie_group_remove(eq) + x*y**2 + >>> eq = (F(x**3) + y)*x**4 + >>> _lie_group_remove(eq) + x**4*y + + """ + if isinstance(coords, AppliedUndef): + return coords.args[0] + elif coords.is_Add: + subfunc = coords.atoms(AppliedUndef) + if subfunc: + for func in subfunc: + coords = coords.subs(func, 0) + return coords + elif coords.is_Pow: + base, expr = coords.as_base_exp() + base = _lie_group_remove(base) + expr = _lie_group_remove(expr) + return base**expr + elif coords.is_Mul: + mulargs = [] + coordargs = coords.args + for arg in coordargs: + if not isinstance(coords, AppliedUndef): + mulargs.append(_lie_group_remove(arg)) + return Mul(*mulargs) + return coords diff --git a/lib/python3.10/site-packages/sympy/solvers/ode/nonhomogeneous.py b/lib/python3.10/site-packages/sympy/solvers/ode/nonhomogeneous.py new file mode 100644 index 0000000000000000000000000000000000000000..87ff54074871f76304a60ec0e46aa3ff999df9ec --- /dev/null +++ b/lib/python3.10/site-packages/sympy/solvers/ode/nonhomogeneous.py @@ -0,0 +1,499 @@ +r""" +This File contains helper functions for nth_linear_constant_coeff_undetermined_coefficients, +nth_linear_euler_eq_nonhomogeneous_undetermined_coefficients, +nth_linear_constant_coeff_variation_of_parameters, +and nth_linear_euler_eq_nonhomogeneous_variation_of_parameters. + +All the functions in this file are used by more than one solvers so, instead of creating +instances in other classes for using them it is better to keep it here as separate helpers. + +""" +from collections import defaultdict +from sympy.core import Add, S +from sympy.core.function import diff, expand, _mexpand, expand_mul +from sympy.core.relational import Eq +from sympy.core.sorting import default_sort_key +from sympy.core.symbol import Dummy, Wild +from sympy.functions import exp, cos, cosh, im, log, re, sin, sinh, \ + atan2, conjugate +from sympy.integrals import Integral +from sympy.polys import (Poly, RootOf, rootof, roots) +from sympy.simplify import collect, simplify, separatevars, powsimp, trigsimp # type: ignore +from sympy.utilities import numbered_symbols +from sympy.solvers.solvers import solve +from sympy.matrices import wronskian +from .subscheck import sub_func_doit +from sympy.solvers.ode.ode import get_numbered_constants + + +def _test_term(coeff, func, order): + r""" + Linear Euler ODEs have the form K*x**order*diff(y(x), x, order) = F(x), + where K is independent of x and y(x), order>= 0. + So we need to check that for each term, coeff == K*x**order from + some K. We have a few cases, since coeff may have several + different types. + """ + x = func.args[0] + f = func.func + if order < 0: + raise ValueError("order should be greater than 0") + if coeff == 0: + return True + if order == 0: + if x in coeff.free_symbols: + return False + return True + if coeff.is_Mul: + if coeff.has(f(x)): + return False + return x**order in coeff.args + elif coeff.is_Pow: + return coeff.as_base_exp() == (x, order) + elif order == 1: + return x == coeff + return False + + +def _get_euler_characteristic_eq_sols(eq, func, match_obj): + r""" + Returns the solution of homogeneous part of the linear euler ODE and + the list of roots of characteristic equation. + + The parameter ``match_obj`` is a dict of order:coeff terms, where order is the order + of the derivative on each term, and coeff is the coefficient of that derivative. + + """ + x = func.args[0] + f = func.func + + # First, set up characteristic equation. + chareq, symbol = S.Zero, Dummy('x') + + for i in match_obj: + if i >= 0: + chareq += (match_obj[i]*diff(x**symbol, x, i)*x**-symbol).expand() + + chareq = Poly(chareq, symbol) + chareqroots = [rootof(chareq, k) for k in range(chareq.degree())] + collectterms = [] + + # A generator of constants + constants = list(get_numbered_constants(eq, num=chareq.degree()*2)) + constants.reverse() + + # Create a dict root: multiplicity or charroots + charroots = defaultdict(int) + for root in chareqroots: + charroots[root] += 1 + gsol = S.Zero + ln = log + for root, multiplicity in charroots.items(): + for i in range(multiplicity): + if isinstance(root, RootOf): + gsol += (x**root) * constants.pop() + if multiplicity != 1: + raise ValueError("Value should be 1") + collectterms = [(0, root, 0)] + collectterms + elif root.is_real: + gsol += ln(x)**i*(x**root) * constants.pop() + collectterms = [(i, root, 0)] + collectterms + else: + reroot = re(root) + imroot = im(root) + gsol += ln(x)**i * (x**reroot) * ( + constants.pop() * sin(abs(imroot)*ln(x)) + + constants.pop() * cos(imroot*ln(x))) + collectterms = [(i, reroot, imroot)] + collectterms + + gsol = Eq(f(x), gsol) + + gensols = [] + # Keep track of when to use sin or cos for nonzero imroot + for i, reroot, imroot in collectterms: + if imroot == 0: + gensols.append(ln(x)**i*x**reroot) + else: + sin_form = ln(x)**i*x**reroot*sin(abs(imroot)*ln(x)) + if sin_form in gensols: + cos_form = ln(x)**i*x**reroot*cos(imroot*ln(x)) + gensols.append(cos_form) + else: + gensols.append(sin_form) + return gsol, gensols + + +def _solve_variation_of_parameters(eq, func, roots, homogen_sol, order, match_obj, simplify_flag=True): + r""" + Helper function for the method of variation of parameters and nonhomogeneous euler eq. + + See the + :py:meth:`~sympy.solvers.ode.single.NthLinearConstantCoeffVariationOfParameters` + docstring for more information on this method. + + The parameter are ``match_obj`` should be a dictionary that has the following + keys: + + ``list`` + A list of solutions to the homogeneous equation. + + ``sol`` + The general solution. + + """ + f = func.func + x = func.args[0] + r = match_obj + psol = 0 + wr = wronskian(roots, x) + + if simplify_flag: + wr = simplify(wr) # We need much better simplification for + # some ODEs. See issue 4662, for example. + # To reduce commonly occurring sin(x)**2 + cos(x)**2 to 1 + wr = trigsimp(wr, deep=True, recursive=True) + if not wr: + # The wronskian will be 0 iff the solutions are not linearly + # independent. + raise NotImplementedError("Cannot find " + str(order) + + " solutions to the homogeneous equation necessary to apply " + + "variation of parameters to " + str(eq) + " (Wronskian == 0)") + if len(roots) != order: + raise NotImplementedError("Cannot find " + str(order) + + " solutions to the homogeneous equation necessary to apply " + + "variation of parameters to " + + str(eq) + " (number of terms != order)") + negoneterm = S.NegativeOne**(order) + for i in roots: + psol += negoneterm*Integral(wronskian([sol for sol in roots if sol != i], x)*r[-1]/wr, x)*i/r[order] + negoneterm *= -1 + + if simplify_flag: + psol = simplify(psol) + psol = trigsimp(psol, deep=True) + return Eq(f(x), homogen_sol.rhs + psol) + + +def _get_const_characteristic_eq_sols(r, func, order): + r""" + Returns the roots of characteristic equation of constant coefficient + linear ODE and list of collectterms which is later on used by simplification + to use collect on solution. + + The parameter `r` is a dict of order:coeff terms, where order is the order of the + derivative on each term, and coeff is the coefficient of that derivative. + + """ + x = func.args[0] + # First, set up characteristic equation. + chareq, symbol = S.Zero, Dummy('x') + + for i in r.keys(): + if isinstance(i, str) or i < 0: + pass + else: + chareq += r[i]*symbol**i + + chareq = Poly(chareq, symbol) + # Can't just call roots because it doesn't return rootof for unsolveable + # polynomials. + chareqroots = roots(chareq, multiple=True) + if len(chareqroots) != order: + chareqroots = [rootof(chareq, k) for k in range(chareq.degree())] + + chareq_is_complex = not all(i.is_real for i in chareq.all_coeffs()) + + # Create a dict root: multiplicity or charroots + charroots = defaultdict(int) + for root in chareqroots: + charroots[root] += 1 + # We need to keep track of terms so we can run collect() at the end. + # This is necessary for constantsimp to work properly. + collectterms = [] + gensols = [] + conjugate_roots = [] # used to prevent double-use of conjugate roots + # Loop over roots in theorder provided by roots/rootof... + for root in chareqroots: + # but don't repoeat multiple roots. + if root not in charroots: + continue + multiplicity = charroots.pop(root) + for i in range(multiplicity): + if chareq_is_complex: + gensols.append(x**i*exp(root*x)) + collectterms = [(i, root, 0)] + collectterms + continue + reroot = re(root) + imroot = im(root) + if imroot.has(atan2) and reroot.has(atan2): + # Remove this condition when re and im stop returning + # circular atan2 usages. + gensols.append(x**i*exp(root*x)) + collectterms = [(i, root, 0)] + collectterms + else: + if root in conjugate_roots: + collectterms = [(i, reroot, imroot)] + collectterms + continue + if imroot == 0: + gensols.append(x**i*exp(reroot*x)) + collectterms = [(i, reroot, 0)] + collectterms + continue + conjugate_roots.append(conjugate(root)) + gensols.append(x**i*exp(reroot*x) * sin(abs(imroot) * x)) + gensols.append(x**i*exp(reroot*x) * cos( imroot * x)) + + # This ordering is important + collectterms = [(i, reroot, imroot)] + collectterms + return gensols, collectterms + + +# Ideally these kind of simplification functions shouldn't be part of solvers. +# odesimp should be improved to handle these kind of specific simplifications. +def _get_simplified_sol(sol, func, collectterms): + r""" + Helper function which collects the solution on + collectterms. Ideally this should be handled by odesimp.It is used + only when the simplify is set to True in dsolve. + + The parameter ``collectterms`` is a list of tuple (i, reroot, imroot) where `i` is + the multiplicity of the root, reroot is real part and imroot being the imaginary part. + + """ + f = func.func + x = func.args[0] + collectterms.sort(key=default_sort_key) + collectterms.reverse() + assert len(sol) == 1 and sol[0].lhs == f(x) + sol = sol[0].rhs + sol = expand_mul(sol) + for i, reroot, imroot in collectterms: + sol = collect(sol, x**i*exp(reroot*x)*sin(abs(imroot)*x)) + sol = collect(sol, x**i*exp(reroot*x)*cos(imroot*x)) + for i, reroot, imroot in collectterms: + sol = collect(sol, x**i*exp(reroot*x)) + sol = powsimp(sol) + return Eq(f(x), sol) + + +def _undetermined_coefficients_match(expr, x, func=None, eq_homogeneous=S.Zero): + r""" + Returns a trial function match if undetermined coefficients can be applied + to ``expr``, and ``None`` otherwise. + + A trial expression can be found for an expression for use with the method + of undetermined coefficients if the expression is an + additive/multiplicative combination of constants, polynomials in `x` (the + independent variable of expr), `\sin(a x + b)`, `\cos(a x + b)`, and + `e^{a x}` terms (in other words, it has a finite number of linearly + independent derivatives). + + Note that you may still need to multiply each term returned here by + sufficient `x` to make it linearly independent with the solutions to the + homogeneous equation. + + This is intended for internal use by ``undetermined_coefficients`` hints. + + SymPy currently has no way to convert `\sin^n(x) \cos^m(y)` into a sum of + only `\sin(a x)` and `\cos(b x)` terms, so these are not implemented. So, + for example, you will need to manually convert `\sin^2(x)` into `[1 + + \cos(2 x)]/2` to properly apply the method of undetermined coefficients on + it. + + Examples + ======== + + >>> from sympy import log, exp + >>> from sympy.solvers.ode.nonhomogeneous import _undetermined_coefficients_match + >>> from sympy.abc import x + >>> _undetermined_coefficients_match(9*x*exp(x) + exp(-x), x) + {'test': True, 'trialset': {x*exp(x), exp(-x), exp(x)}} + >>> _undetermined_coefficients_match(log(x), x) + {'test': False} + + """ + a = Wild('a', exclude=[x]) + b = Wild('b', exclude=[x]) + expr = powsimp(expr, combine='exp') # exp(x)*exp(2*x + 1) => exp(3*x + 1) + retdict = {} + + def _test_term(expr, x): + r""" + Test if ``expr`` fits the proper form for undetermined coefficients. + """ + if not expr.has(x): + return True + elif expr.is_Add: + return all(_test_term(i, x) for i in expr.args) + elif expr.is_Mul: + if expr.has(sin, cos): + foundtrig = False + # Make sure that there is only one trig function in the args. + # See the docstring. + for i in expr.args: + if i.has(sin, cos): + if foundtrig: + return False + else: + foundtrig = True + return all(_test_term(i, x) for i in expr.args) + elif expr.is_Function: + if expr.func in (sin, cos, exp, sinh, cosh): + if expr.args[0].match(a*x + b): + return True + else: + return False + else: + return False + elif expr.is_Pow and expr.base.is_Symbol and expr.exp.is_Integer and \ + expr.exp >= 0: + return True + elif expr.is_Pow and expr.base.is_number: + if expr.exp.match(a*x + b): + return True + else: + return False + elif expr.is_Symbol or expr.is_number: + return True + else: + return False + + def _get_trial_set(expr, x, exprs=set()): + r""" + Returns a set of trial terms for undetermined coefficients. + + The idea behind undetermined coefficients is that the terms expression + repeat themselves after a finite number of derivatives, except for the + coefficients (they are linearly dependent). So if we collect these, + we should have the terms of our trial function. + """ + def _remove_coefficient(expr, x): + r""" + Returns the expression without a coefficient. + + Similar to expr.as_independent(x)[1], except it only works + multiplicatively. + """ + term = S.One + if expr.is_Mul: + for i in expr.args: + if i.has(x): + term *= i + elif expr.has(x): + term = expr + return term + + expr = expand_mul(expr) + if expr.is_Add: + for term in expr.args: + if _remove_coefficient(term, x) in exprs: + pass + else: + exprs.add(_remove_coefficient(term, x)) + exprs = exprs.union(_get_trial_set(term, x, exprs)) + else: + term = _remove_coefficient(expr, x) + tmpset = exprs.union({term}) + oldset = set() + while tmpset != oldset: + # If you get stuck in this loop, then _test_term is probably + # broken + oldset = tmpset.copy() + expr = expr.diff(x) + term = _remove_coefficient(expr, x) + if term.is_Add: + tmpset = tmpset.union(_get_trial_set(term, x, tmpset)) + else: + tmpset.add(term) + exprs = tmpset + return exprs + + def is_homogeneous_solution(term): + r""" This function checks whether the given trialset contains any root + of homogeneous equation""" + return expand(sub_func_doit(eq_homogeneous, func, term)).is_zero + + retdict['test'] = _test_term(expr, x) + if retdict['test']: + # Try to generate a list of trial solutions that will have the + # undetermined coefficients. Note that if any of these are not linearly + # independent with any of the solutions to the homogeneous equation, + # then they will need to be multiplied by sufficient x to make them so. + # This function DOES NOT do that (it doesn't even look at the + # homogeneous equation). + temp_set = set() + for i in Add.make_args(expr): + act = _get_trial_set(i, x) + if eq_homogeneous is not S.Zero: + while any(is_homogeneous_solution(ts) for ts in act): + act = {x*ts for ts in act} + temp_set = temp_set.union(act) + + retdict['trialset'] = temp_set + return retdict + + +def _solve_undetermined_coefficients(eq, func, order, match, trialset): + r""" + Helper function for the method of undetermined coefficients. + + See the + :py:meth:`~sympy.solvers.ode.single.NthLinearConstantCoeffUndeterminedCoefficients` + docstring for more information on this method. + + The parameter ``trialset`` is the set of trial functions as returned by + ``_undetermined_coefficients_match()['trialset']``. + + The parameter ``match`` should be a dictionary that has the following + keys: + + ``list`` + A list of solutions to the homogeneous equation. + + ``sol`` + The general solution. + + """ + r = match + coeffs = numbered_symbols('a', cls=Dummy) + coefflist = [] + gensols = r['list'] + gsol = r['sol'] + f = func.func + x = func.args[0] + + if len(gensols) != order: + raise NotImplementedError("Cannot find " + str(order) + + " solutions to the homogeneous equation necessary to apply" + + " undetermined coefficients to " + str(eq) + + " (number of terms != order)") + + trialfunc = 0 + for i in trialset: + c = next(coeffs) + coefflist.append(c) + trialfunc += c*i + + eqs = sub_func_doit(eq, f(x), trialfunc) + + coeffsdict = dict(list(zip(trialset, [0]*(len(trialset) + 1)))) + + eqs = _mexpand(eqs) + + for i in Add.make_args(eqs): + s = separatevars(i, dict=True, symbols=[x]) + if coeffsdict.get(s[x]): + coeffsdict[s[x]] += s['coeff'] + else: + coeffsdict[s[x]] = s['coeff'] + + coeffvals = solve(list(coeffsdict.values()), coefflist) + + if not coeffvals: + raise NotImplementedError( + "Could not solve `%s` using the " + "method of undetermined coefficients " + "(unable to solve for coefficients)." % eq) + + psol = trialfunc.subs(coeffvals) + + return Eq(f(x), gsol.rhs + psol) diff --git a/lib/python3.10/site-packages/sympy/solvers/ode/ode.py b/lib/python3.10/site-packages/sympy/solvers/ode/ode.py new file mode 100644 index 0000000000000000000000000000000000000000..be82bb18a5c800ae3848605d636d9744a011b6bc --- /dev/null +++ b/lib/python3.10/site-packages/sympy/solvers/ode/ode.py @@ -0,0 +1,3575 @@ +r""" +This module contains :py:meth:`~sympy.solvers.ode.dsolve` and different helper +functions that it uses. + +:py:meth:`~sympy.solvers.ode.dsolve` solves ordinary differential equations. +See the docstring on the various functions for their uses. Note that partial +differential equations support is in ``pde.py``. Note that hint functions +have docstrings describing their various methods, but they are intended for +internal use. Use ``dsolve(ode, func, hint=hint)`` to solve an ODE using a +specific hint. See also the docstring on +:py:meth:`~sympy.solvers.ode.dsolve`. + +**Functions in this module** + + These are the user functions in this module: + + - :py:meth:`~sympy.solvers.ode.dsolve` - Solves ODEs. + - :py:meth:`~sympy.solvers.ode.classify_ode` - Classifies ODEs into + possible hints for :py:meth:`~sympy.solvers.ode.dsolve`. + - :py:meth:`~sympy.solvers.ode.checkodesol` - Checks if an equation is the + solution to an ODE. + - :py:meth:`~sympy.solvers.ode.homogeneous_order` - Returns the + homogeneous order of an expression. + - :py:meth:`~sympy.solvers.ode.infinitesimals` - Returns the infinitesimals + of the Lie group of point transformations of an ODE, such that it is + invariant. + - :py:meth:`~sympy.solvers.ode.checkinfsol` - Checks if the given infinitesimals + are the actual infinitesimals of a first order ODE. + + These are the non-solver helper functions that are for internal use. The + user should use the various options to + :py:meth:`~sympy.solvers.ode.dsolve` to obtain the functionality provided + by these functions: + + - :py:meth:`~sympy.solvers.ode.ode.odesimp` - Does all forms of ODE + simplification. + - :py:meth:`~sympy.solvers.ode.ode.ode_sol_simplicity` - A key function for + comparing solutions by simplicity. + - :py:meth:`~sympy.solvers.ode.constantsimp` - Simplifies arbitrary + constants. + - :py:meth:`~sympy.solvers.ode.ode.constant_renumber` - Renumber arbitrary + constants. + - :py:meth:`~sympy.solvers.ode.ode._handle_Integral` - Evaluate unevaluated + Integrals. + + See also the docstrings of these functions. + +**Currently implemented solver methods** + +The following methods are implemented for solving ordinary differential +equations. See the docstrings of the various hint functions for more +information on each (run ``help(ode)``): + + - 1st order separable differential equations. + - 1st order differential equations whose coefficients or `dx` and `dy` are + functions homogeneous of the same order. + - 1st order exact differential equations. + - 1st order linear differential equations. + - 1st order Bernoulli differential equations. + - Power series solutions for first order differential equations. + - Lie Group method of solving first order differential equations. + - 2nd order Liouville differential equations. + - Power series solutions for second order differential equations + at ordinary and regular singular points. + - `n`\th order differential equation that can be solved with algebraic + rearrangement and integration. + - `n`\th order linear homogeneous differential equation with constant + coefficients. + - `n`\th order linear inhomogeneous differential equation with constant + coefficients using the method of undetermined coefficients. + - `n`\th order linear inhomogeneous differential equation with constant + coefficients using the method of variation of parameters. + +**Philosophy behind this module** + +This module is designed to make it easy to add new ODE solving methods without +having to mess with the solving code for other methods. The idea is that +there is a :py:meth:`~sympy.solvers.ode.classify_ode` function, which takes in +an ODE and tells you what hints, if any, will solve the ODE. It does this +without attempting to solve the ODE, so it is fast. Each solving method is a +hint, and it has its own function, named ``ode_``. That function takes +in the ODE and any match expression gathered by +:py:meth:`~sympy.solvers.ode.classify_ode` and returns a solved result. If +this result has any integrals in it, the hint function will return an +unevaluated :py:class:`~sympy.integrals.integrals.Integral` class. +:py:meth:`~sympy.solvers.ode.dsolve`, which is the user wrapper function +around all of this, will then call :py:meth:`~sympy.solvers.ode.ode.odesimp` on +the result, which, among other things, will attempt to solve the equation for +the dependent variable (the function we are solving for), simplify the +arbitrary constants in the expression, and evaluate any integrals, if the hint +allows it. + +**How to add new solution methods** + +If you have an ODE that you want :py:meth:`~sympy.solvers.ode.dsolve` to be +able to solve, try to avoid adding special case code here. Instead, try +finding a general method that will solve your ODE, as well as others. This +way, the :py:mod:`~sympy.solvers.ode` module will become more robust, and +unhindered by special case hacks. WolphramAlpha and Maple's +DETools[odeadvisor] function are two resources you can use to classify a +specific ODE. It is also better for a method to work with an `n`\th order ODE +instead of only with specific orders, if possible. + +To add a new method, there are a few things that you need to do. First, you +need a hint name for your method. Try to name your hint so that it is +unambiguous with all other methods, including ones that may not be implemented +yet. If your method uses integrals, also include a ``hint_Integral`` hint. +If there is more than one way to solve ODEs with your method, include a hint +for each one, as well as a ``_best`` hint. Your ``ode__best()`` +function should choose the best using min with ``ode_sol_simplicity`` as the +key argument. See +:obj:`~sympy.solvers.ode.single.HomogeneousCoeffBest`, for example. +The function that uses your method will be called ``ode_()``, so the +hint must only use characters that are allowed in a Python function name +(alphanumeric characters and the underscore '``_``' character). Include a +function for every hint, except for ``_Integral`` hints +(:py:meth:`~sympy.solvers.ode.dsolve` takes care of those automatically). +Hint names should be all lowercase, unless a word is commonly capitalized +(such as Integral or Bernoulli). If you have a hint that you do not want to +run with ``all_Integral`` that does not have an ``_Integral`` counterpart (such +as a best hint that would defeat the purpose of ``all_Integral``), you will +need to remove it manually in the :py:meth:`~sympy.solvers.ode.dsolve` code. +See also the :py:meth:`~sympy.solvers.ode.classify_ode` docstring for +guidelines on writing a hint name. + +Determine *in general* how the solutions returned by your method compare with +other methods that can potentially solve the same ODEs. Then, put your hints +in the :py:data:`~sympy.solvers.ode.allhints` tuple in the order that they +should be called. The ordering of this tuple determines which hints are +default. Note that exceptions are ok, because it is easy for the user to +choose individual hints with :py:meth:`~sympy.solvers.ode.dsolve`. In +general, ``_Integral`` variants should go at the end of the list, and +``_best`` variants should go before the various hints they apply to. For +example, the ``undetermined_coefficients`` hint comes before the +``variation_of_parameters`` hint because, even though variation of parameters +is more general than undetermined coefficients, undetermined coefficients +generally returns cleaner results for the ODEs that it can solve than +variation of parameters does, and it does not require integration, so it is +much faster. + +Next, you need to have a match expression or a function that matches the type +of the ODE, which you should put in :py:meth:`~sympy.solvers.ode.classify_ode` +(if the match function is more than just a few lines. It should match the +ODE without solving for it as much as possible, so that +:py:meth:`~sympy.solvers.ode.classify_ode` remains fast and is not hindered by +bugs in solving code. Be sure to consider corner cases. For example, if your +solution method involves dividing by something, make sure you exclude the case +where that division will be 0. + +In most cases, the matching of the ODE will also give you the various parts +that you need to solve it. You should put that in a dictionary (``.match()`` +will do this for you), and add that as ``matching_hints['hint'] = matchdict`` +in the relevant part of :py:meth:`~sympy.solvers.ode.classify_ode`. +:py:meth:`~sympy.solvers.ode.classify_ode` will then send this to +:py:meth:`~sympy.solvers.ode.dsolve`, which will send it to your function as +the ``match`` argument. Your function should be named ``ode_(eq, func, +order, match)`. If you need to send more information, put it in the ``match`` +dictionary. For example, if you had to substitute in a dummy variable in +:py:meth:`~sympy.solvers.ode.classify_ode` to match the ODE, you will need to +pass it to your function using the `match` dict to access it. You can access +the independent variable using ``func.args[0]``, and the dependent variable +(the function you are trying to solve for) as ``func.func``. If, while trying +to solve the ODE, you find that you cannot, raise ``NotImplementedError``. +:py:meth:`~sympy.solvers.ode.dsolve` will catch this error with the ``all`` +meta-hint, rather than causing the whole routine to fail. + +Add a docstring to your function that describes the method employed. Like +with anything else in SymPy, you will need to add a doctest to the docstring, +in addition to real tests in ``test_ode.py``. Try to maintain consistency +with the other hint functions' docstrings. Add your method to the list at the +top of this docstring. Also, add your method to ``ode.rst`` in the +``docs/src`` directory, so that the Sphinx docs will pull its docstring into +the main SymPy documentation. Be sure to make the Sphinx documentation by +running ``make html`` from within the doc directory to verify that the +docstring formats correctly. + +If your solution method involves integrating, use :py:obj:`~.Integral` instead of +:py:meth:`~sympy.core.expr.Expr.integrate`. This allows the user to bypass +hard/slow integration by using the ``_Integral`` variant of your hint. In +most cases, calling :py:meth:`sympy.core.basic.Basic.doit` will integrate your +solution. If this is not the case, you will need to write special code in +:py:meth:`~sympy.solvers.ode.ode._handle_Integral`. Arbitrary constants should be +symbols named ``C1``, ``C2``, and so on. All solution methods should return +an equality instance. If you need an arbitrary number of arbitrary constants, +you can use ``constants = numbered_symbols(prefix='C', cls=Symbol, start=1)``. +If it is possible to solve for the dependent function in a general way, do so. +Otherwise, do as best as you can, but do not call solve in your +``ode_()`` function. :py:meth:`~sympy.solvers.ode.ode.odesimp` will attempt +to solve the solution for you, so you do not need to do that. Lastly, if your +ODE has a common simplification that can be applied to your solutions, you can +add a special case in :py:meth:`~sympy.solvers.ode.ode.odesimp` for it. For +example, solutions returned from the ``1st_homogeneous_coeff`` hints often +have many :obj:`~sympy.functions.elementary.exponential.log` terms, so +:py:meth:`~sympy.solvers.ode.ode.odesimp` calls +:py:meth:`~sympy.simplify.simplify.logcombine` on them (it also helps to write +the arbitrary constant as ``log(C1)`` instead of ``C1`` in this case). Also +consider common ways that you can rearrange your solution to have +:py:meth:`~sympy.solvers.ode.constantsimp` take better advantage of it. It is +better to put simplification in :py:meth:`~sympy.solvers.ode.ode.odesimp` than in +your method, because it can then be turned off with the simplify flag in +:py:meth:`~sympy.solvers.ode.dsolve`. If you have any extraneous +simplification in your function, be sure to only run it using ``if +match.get('simplify', True):``, especially if it can be slow or if it can +reduce the domain of the solution. + +Finally, as with every contribution to SymPy, your method will need to be +tested. Add a test for each method in ``test_ode.py``. Follow the +conventions there, i.e., test the solver using ``dsolve(eq, f(x), +hint=your_hint)``, and also test the solution using +:py:meth:`~sympy.solvers.ode.checkodesol` (you can put these in a separate +tests and skip/XFAIL if it runs too slow/does not work). Be sure to call your +hint specifically in :py:meth:`~sympy.solvers.ode.dsolve`, that way the test +will not be broken simply by the introduction of another matching hint. If your +method works for higher order (>1) ODEs, you will need to run ``sol = +constant_renumber(sol, 'C', 1, order)`` for each solution, where ``order`` is +the order of the ODE. This is because ``constant_renumber`` renumbers the +arbitrary constants by printing order, which is platform dependent. Try to +test every corner case of your solver, including a range of orders if it is a +`n`\th order solver, but if your solver is slow, such as if it involves hard +integration, try to keep the test run time down. + +Feel free to refactor existing hints to avoid duplicating code or creating +inconsistencies. If you can show that your method exactly duplicates an +existing method, including in the simplicity and speed of obtaining the +solutions, then you can remove the old, less general method. The existing +code is tested extensively in ``test_ode.py``, so if anything is broken, one +of those tests will surely fail. + +""" + +from sympy.core import Add, S, Mul, Pow, oo +from sympy.core.containers import Tuple +from sympy.core.expr import AtomicExpr, Expr +from sympy.core.function import (Function, Derivative, AppliedUndef, diff, + expand, expand_mul, Subs) +from sympy.core.multidimensional import vectorize +from sympy.core.numbers import nan, zoo, Number +from sympy.core.relational import Equality, Eq +from sympy.core.sorting import default_sort_key, ordered +from sympy.core.symbol import Symbol, Wild, Dummy, symbols +from sympy.core.sympify import sympify +from sympy.core.traversal import preorder_traversal + +from sympy.logic.boolalg import (BooleanAtom, BooleanTrue, + BooleanFalse) +from sympy.functions import exp, log, sqrt +from sympy.functions.combinatorial.factorials import factorial +from sympy.integrals.integrals import Integral +from sympy.polys import (Poly, terms_gcd, PolynomialError, lcm) +from sympy.polys.polytools import cancel +from sympy.series import Order +from sympy.series.series import series +from sympy.simplify import (collect, logcombine, powsimp, # type: ignore + separatevars, simplify, cse) +from sympy.simplify.radsimp import collect_const +from sympy.solvers import checksol, solve + +from sympy.utilities import numbered_symbols +from sympy.utilities.iterables import uniq, sift, iterable +from sympy.solvers.deutils import _preprocess, ode_order, _desolve + + +#: This is a list of hints in the order that they should be preferred by +#: :py:meth:`~sympy.solvers.ode.classify_ode`. In general, hints earlier in the +#: list should produce simpler solutions than those later in the list (for +#: ODEs that fit both). For now, the order of this list is based on empirical +#: observations by the developers of SymPy. +#: +#: The hint used by :py:meth:`~sympy.solvers.ode.dsolve` for a specific ODE +#: can be overridden (see the docstring). +#: +#: In general, ``_Integral`` hints are grouped at the end of the list, unless +#: there is a method that returns an unevaluable integral most of the time +#: (which go near the end of the list anyway). ``default``, ``all``, +#: ``best``, and ``all_Integral`` meta-hints should not be included in this +#: list, but ``_best`` and ``_Integral`` hints should be included. +allhints = ( + "factorable", + "nth_algebraic", + "separable", + "1st_exact", + "1st_linear", + "Bernoulli", + "1st_rational_riccati", + "Riccati_special_minus2", + "1st_homogeneous_coeff_best", + "1st_homogeneous_coeff_subs_indep_div_dep", + "1st_homogeneous_coeff_subs_dep_div_indep", + "almost_linear", + "linear_coefficients", + "separable_reduced", + "1st_power_series", + "lie_group", + "nth_linear_constant_coeff_homogeneous", + "nth_linear_euler_eq_homogeneous", + "nth_linear_constant_coeff_undetermined_coefficients", + "nth_linear_euler_eq_nonhomogeneous_undetermined_coefficients", + "nth_linear_constant_coeff_variation_of_parameters", + "nth_linear_euler_eq_nonhomogeneous_variation_of_parameters", + "Liouville", + "2nd_linear_airy", + "2nd_linear_bessel", + "2nd_hypergeometric", + "2nd_hypergeometric_Integral", + "nth_order_reducible", + "2nd_power_series_ordinary", + "2nd_power_series_regular", + "nth_algebraic_Integral", + "separable_Integral", + "1st_exact_Integral", + "1st_linear_Integral", + "Bernoulli_Integral", + "1st_homogeneous_coeff_subs_indep_div_dep_Integral", + "1st_homogeneous_coeff_subs_dep_div_indep_Integral", + "almost_linear_Integral", + "linear_coefficients_Integral", + "separable_reduced_Integral", + "nth_linear_constant_coeff_variation_of_parameters_Integral", + "nth_linear_euler_eq_nonhomogeneous_variation_of_parameters_Integral", + "Liouville_Integral", + "2nd_nonlinear_autonomous_conserved", + "2nd_nonlinear_autonomous_conserved_Integral", + ) + + + +def get_numbered_constants(eq, num=1, start=1, prefix='C'): + """ + Returns a list of constants that do not occur + in eq already. + """ + + ncs = iter_numbered_constants(eq, start, prefix) + Cs = [next(ncs) for i in range(num)] + return (Cs[0] if num == 1 else tuple(Cs)) + + +def iter_numbered_constants(eq, start=1, prefix='C'): + """ + Returns an iterator of constants that do not occur + in eq already. + """ + + if isinstance(eq, (Expr, Eq)): + eq = [eq] + elif not iterable(eq): + raise ValueError("Expected Expr or iterable but got %s" % eq) + + atom_set = set().union(*[i.free_symbols for i in eq]) + func_set = set().union(*[i.atoms(Function) for i in eq]) + if func_set: + atom_set |= {Symbol(str(f.func)) for f in func_set} + return numbered_symbols(start=start, prefix=prefix, exclude=atom_set) + + +def dsolve(eq, func=None, hint="default", simplify=True, + ics= None, xi=None, eta=None, x0=0, n=6, **kwargs): + r""" + Solves any (supported) kind of ordinary differential equation and + system of ordinary differential equations. + + For single ordinary differential equation + ========================================= + + It is classified under this when number of equation in ``eq`` is one. + **Usage** + + ``dsolve(eq, f(x), hint)`` -> Solve ordinary differential equation + ``eq`` for function ``f(x)``, using method ``hint``. + + **Details** + + ``eq`` can be any supported ordinary differential equation (see the + :py:mod:`~sympy.solvers.ode` docstring for supported methods). + This can either be an :py:class:`~sympy.core.relational.Equality`, + or an expression, which is assumed to be equal to ``0``. + + ``f(x)`` is a function of one variable whose derivatives in that + variable make up the ordinary differential equation ``eq``. In + many cases it is not necessary to provide this; it will be + autodetected (and an error raised if it could not be detected). + + ``hint`` is the solving method that you want dsolve to use. Use + ``classify_ode(eq, f(x))`` to get all of the possible hints for an + ODE. The default hint, ``default``, will use whatever hint is + returned first by :py:meth:`~sympy.solvers.ode.classify_ode`. See + Hints below for more options that you can use for hint. + + ``simplify`` enables simplification by + :py:meth:`~sympy.solvers.ode.ode.odesimp`. See its docstring for more + information. Turn this off, for example, to disable solving of + solutions for ``func`` or simplification of arbitrary constants. + It will still integrate with this hint. Note that the solution may + contain more arbitrary constants than the order of the ODE with + this option enabled. + + ``xi`` and ``eta`` are the infinitesimal functions of an ordinary + differential equation. They are the infinitesimals of the Lie group + of point transformations for which the differential equation is + invariant. The user can specify values for the infinitesimals. If + nothing is specified, ``xi`` and ``eta`` are calculated using + :py:meth:`~sympy.solvers.ode.infinitesimals` with the help of various + heuristics. + + ``ics`` is the set of initial/boundary conditions for the differential equation. + It should be given in the form of ``{f(x0): x1, f(x).diff(x).subs(x, x2): + x3}`` and so on. For power series solutions, if no initial + conditions are specified ``f(0)`` is assumed to be ``C0`` and the power + series solution is calculated about 0. + + ``x0`` is the point about which the power series solution of a differential + equation is to be evaluated. + + ``n`` gives the exponent of the dependent variable up to which the power series + solution of a differential equation is to be evaluated. + + **Hints** + + Aside from the various solving methods, there are also some meta-hints + that you can pass to :py:meth:`~sympy.solvers.ode.dsolve`: + + ``default``: + This uses whatever hint is returned first by + :py:meth:`~sympy.solvers.ode.classify_ode`. This is the + default argument to :py:meth:`~sympy.solvers.ode.dsolve`. + + ``all``: + To make :py:meth:`~sympy.solvers.ode.dsolve` apply all + relevant classification hints, use ``dsolve(ODE, func, + hint="all")``. This will return a dictionary of + ``hint:solution`` terms. If a hint causes dsolve to raise the + ``NotImplementedError``, value of that hint's key will be the + exception object raised. The dictionary will also include + some special keys: + + - ``order``: The order of the ODE. See also + :py:meth:`~sympy.solvers.deutils.ode_order` in + ``deutils.py``. + - ``best``: The simplest hint; what would be returned by + ``best`` below. + - ``best_hint``: The hint that would produce the solution + given by ``best``. If more than one hint produces the best + solution, the first one in the tuple returned by + :py:meth:`~sympy.solvers.ode.classify_ode` is chosen. + - ``default``: The solution that would be returned by default. + This is the one produced by the hint that appears first in + the tuple returned by + :py:meth:`~sympy.solvers.ode.classify_ode`. + + ``all_Integral``: + This is the same as ``all``, except if a hint also has a + corresponding ``_Integral`` hint, it only returns the + ``_Integral`` hint. This is useful if ``all`` causes + :py:meth:`~sympy.solvers.ode.dsolve` to hang because of a + difficult or impossible integral. This meta-hint will also be + much faster than ``all``, because + :py:meth:`~sympy.core.expr.Expr.integrate` is an expensive + routine. + + ``best``: + To have :py:meth:`~sympy.solvers.ode.dsolve` try all methods + and return the simplest one. This takes into account whether + the solution is solvable in the function, whether it contains + any Integral classes (i.e. unevaluatable integrals), and + which one is the shortest in size. + + See also the :py:meth:`~sympy.solvers.ode.classify_ode` docstring for + more info on hints, and the :py:mod:`~sympy.solvers.ode` docstring for + a list of all supported hints. + + **Tips** + + - You can declare the derivative of an unknown function this way: + + >>> from sympy import Function, Derivative + >>> from sympy.abc import x # x is the independent variable + >>> f = Function("f")(x) # f is a function of x + >>> # f_ will be the derivative of f with respect to x + >>> f_ = Derivative(f, x) + + - See ``test_ode.py`` for many tests, which serves also as a set of + examples for how to use :py:meth:`~sympy.solvers.ode.dsolve`. + - :py:meth:`~sympy.solvers.ode.dsolve` always returns an + :py:class:`~sympy.core.relational.Equality` class (except for the + case when the hint is ``all`` or ``all_Integral``). If possible, it + solves the solution explicitly for the function being solved for. + Otherwise, it returns an implicit solution. + - Arbitrary constants are symbols named ``C1``, ``C2``, and so on. + - Because all solutions should be mathematically equivalent, some + hints may return the exact same result for an ODE. Often, though, + two different hints will return the same solution formatted + differently. The two should be equivalent. Also note that sometimes + the values of the arbitrary constants in two different solutions may + not be the same, because one constant may have "absorbed" other + constants into it. + - Do ``help(ode.ode_)`` to get help more information on a + specific hint, where ```` is the name of a hint without + ``_Integral``. + + For system of ordinary differential equations + ============================================= + + **Usage** + ``dsolve(eq, func)`` -> Solve a system of ordinary differential + equations ``eq`` for ``func`` being list of functions including + `x(t)`, `y(t)`, `z(t)` where number of functions in the list depends + upon the number of equations provided in ``eq``. + + **Details** + + ``eq`` can be any supported system of ordinary differential equations + This can either be an :py:class:`~sympy.core.relational.Equality`, + or an expression, which is assumed to be equal to ``0``. + + ``func`` holds ``x(t)`` and ``y(t)`` being functions of one variable which + together with some of their derivatives make up the system of ordinary + differential equation ``eq``. It is not necessary to provide this; it + will be autodetected (and an error raised if it could not be detected). + + **Hints** + + The hints are formed by parameters returned by classify_sysode, combining + them give hints name used later for forming method name. + + Examples + ======== + + >>> from sympy import Function, dsolve, Eq, Derivative, sin, cos, symbols + >>> from sympy.abc import x + >>> f = Function('f') + >>> dsolve(Derivative(f(x), x, x) + 9*f(x), f(x)) + Eq(f(x), C1*sin(3*x) + C2*cos(3*x)) + + >>> eq = sin(x)*cos(f(x)) + cos(x)*sin(f(x))*f(x).diff(x) + >>> dsolve(eq, hint='1st_exact') + [Eq(f(x), -acos(C1/cos(x)) + 2*pi), Eq(f(x), acos(C1/cos(x)))] + >>> dsolve(eq, hint='almost_linear') + [Eq(f(x), -acos(C1/cos(x)) + 2*pi), Eq(f(x), acos(C1/cos(x)))] + >>> t = symbols('t') + >>> x, y = symbols('x, y', cls=Function) + >>> eq = (Eq(Derivative(x(t),t), 12*t*x(t) + 8*y(t)), Eq(Derivative(y(t),t), 21*x(t) + 7*t*y(t))) + >>> dsolve(eq) + [Eq(x(t), C1*x0(t) + C2*x0(t)*Integral(8*exp(Integral(7*t, t))*exp(Integral(12*t, t))/x0(t)**2, t)), + Eq(y(t), C1*y0(t) + C2*(y0(t)*Integral(8*exp(Integral(7*t, t))*exp(Integral(12*t, t))/x0(t)**2, t) + + exp(Integral(7*t, t))*exp(Integral(12*t, t))/x0(t)))] + >>> eq = (Eq(Derivative(x(t),t),x(t)*y(t)*sin(t)), Eq(Derivative(y(t),t),y(t)**2*sin(t))) + >>> dsolve(eq) + {Eq(x(t), -exp(C1)/(C2*exp(C1) - cos(t))), Eq(y(t), -1/(C1 - cos(t)))} + """ + if iterable(eq): + from sympy.solvers.ode.systems import dsolve_system + + # This may have to be changed in future + # when we have weakly and strongly + # connected components. This have to + # changed to show the systems that haven't + # been solved. + try: + sol = dsolve_system(eq, funcs=func, ics=ics, doit=True) + return sol[0] if len(sol) == 1 else sol + except NotImplementedError: + pass + + match = classify_sysode(eq, func) + + eq = match['eq'] + order = match['order'] + func = match['func'] + t = list(list(eq[0].atoms(Derivative))[0].atoms(Symbol))[0] + + # keep highest order term coefficient positive + for i in range(len(eq)): + for func_ in func: + if isinstance(func_, list): + pass + else: + if eq[i].coeff(diff(func[i],t,ode_order(eq[i], func[i]))).is_negative: + eq[i] = -eq[i] + match['eq'] = eq + if len(set(order.values()))!=1: + raise ValueError("It solves only those systems of equations whose orders are equal") + match['order'] = list(order.values())[0] + def recur_len(l): + return sum(recur_len(item) if isinstance(item,list) else 1 for item in l) + if recur_len(func) != len(eq): + raise ValueError("dsolve() and classify_sysode() work with " + "number of functions being equal to number of equations") + if match['type_of_equation'] is None: + raise NotImplementedError + else: + if match['is_linear'] == True: + solvefunc = globals()['sysode_linear_%(no_of_equation)seq_order%(order)s' % match] + else: + solvefunc = globals()['sysode_nonlinear_%(no_of_equation)seq_order%(order)s' % match] + sols = solvefunc(match) + if ics: + constants = Tuple(*sols).free_symbols - Tuple(*eq).free_symbols + solved_constants = solve_ics(sols, func, constants, ics) + return [sol.subs(solved_constants) for sol in sols] + return sols + else: + given_hint = hint # hint given by the user + + # See the docstring of _desolve for more details. + hints = _desolve(eq, func=func, + hint=hint, simplify=True, xi=xi, eta=eta, type='ode', ics=ics, + x0=x0, n=n, **kwargs) + eq = hints.pop('eq', eq) + all_ = hints.pop('all', False) + if all_: + retdict = {} + failed_hints = {} + gethints = classify_ode(eq, dict=True, hint='all') + orderedhints = gethints['ordered_hints'] + for hint in hints: + try: + rv = _helper_simplify(eq, hint, hints[hint], simplify) + except NotImplementedError as detail: + failed_hints[hint] = detail + else: + retdict[hint] = rv + func = hints[hint]['func'] + + retdict['best'] = min(list(retdict.values()), key=lambda x: + ode_sol_simplicity(x, func, trysolving=not simplify)) + if given_hint == 'best': + return retdict['best'] + for i in orderedhints: + if retdict['best'] == retdict.get(i, None): + retdict['best_hint'] = i + break + retdict['default'] = gethints['default'] + retdict['order'] = gethints['order'] + retdict.update(failed_hints) + return retdict + + else: + # The key 'hint' stores the hint needed to be solved for. + hint = hints['hint'] + return _helper_simplify(eq, hint, hints, simplify, ics=ics) + + +def _helper_simplify(eq, hint, match, simplify=True, ics=None, **kwargs): + r""" + Helper function of dsolve that calls the respective + :py:mod:`~sympy.solvers.ode` functions to solve for the ordinary + differential equations. This minimizes the computation in calling + :py:meth:`~sympy.solvers.deutils._desolve` multiple times. + """ + r = match + func = r['func'] + order = r['order'] + match = r[hint] + + if isinstance(match, SingleODESolver): + solvefunc = match + elif hint.endswith('_Integral'): + solvefunc = globals()['ode_' + hint[:-len('_Integral')]] + else: + solvefunc = globals()['ode_' + hint] + + free = eq.free_symbols + cons = lambda s: s.free_symbols.difference(free) + + if simplify: + # odesimp() will attempt to integrate, if necessary, apply constantsimp(), + # attempt to solve for func, and apply any other hint specific + # simplifications + if isinstance(solvefunc, SingleODESolver): + sols = solvefunc.get_general_solution() + else: + sols = solvefunc(eq, func, order, match) + if iterable(sols): + rv = [] + for s in sols: + simp = odesimp(eq, s, func, hint) + if iterable(simp): + rv.extend(simp) + else: + rv.append(simp) + else: + rv = odesimp(eq, sols, func, hint) + else: + # We still want to integrate (you can disable it separately with the hint) + if isinstance(solvefunc, SingleODESolver): + exprs = solvefunc.get_general_solution(simplify=False) + else: + match['simplify'] = False # Some hints can take advantage of this option + exprs = solvefunc(eq, func, order, match) + if isinstance(exprs, list): + rv = [_handle_Integral(expr, func, hint) for expr in exprs] + else: + rv = _handle_Integral(exprs, func, hint) + + if isinstance(rv, list): + assert all(isinstance(i, Eq) for i in rv), rv # if not => internal error + if simplify: + rv = _remove_redundant_solutions(eq, rv, order, func.args[0]) + if len(rv) == 1: + rv = rv[0] + if ics and 'power_series' not in hint: + if isinstance(rv, (Expr, Eq)): + solved_constants = solve_ics([rv], [r['func']], cons(rv), ics) + rv = rv.subs(solved_constants) + else: + rv1 = [] + for s in rv: + try: + solved_constants = solve_ics([s], [r['func']], cons(s), ics) + except ValueError: + continue + rv1.append(s.subs(solved_constants)) + if len(rv1) == 1: + return rv1[0] + rv = rv1 + return rv + + +def solve_ics(sols, funcs, constants, ics): + """ + Solve for the constants given initial conditions + + ``sols`` is a list of solutions. + + ``funcs`` is a list of functions. + + ``constants`` is a list of constants. + + ``ics`` is the set of initial/boundary conditions for the differential + equation. It should be given in the form of ``{f(x0): x1, + f(x).diff(x).subs(x, x2): x3}`` and so on. + + Returns a dictionary mapping constants to values. + ``solution.subs(constants)`` will replace the constants in ``solution``. + + Example + ======= + >>> # From dsolve(f(x).diff(x) - f(x), f(x)) + >>> from sympy import symbols, Eq, exp, Function + >>> from sympy.solvers.ode.ode import solve_ics + >>> f = Function('f') + >>> x, C1 = symbols('x C1') + >>> sols = [Eq(f(x), C1*exp(x))] + >>> funcs = [f(x)] + >>> constants = [C1] + >>> ics = {f(0): 2} + >>> solved_constants = solve_ics(sols, funcs, constants, ics) + >>> solved_constants + {C1: 2} + >>> sols[0].subs(solved_constants) + Eq(f(x), 2*exp(x)) + + """ + # Assume ics are of the form f(x0): value or Subs(diff(f(x), x, n), (x, + # x0)): value (currently checked by classify_ode). To solve, replace x + # with x0, f(x0) with value, then solve for constants. For f^(n)(x0), + # differentiate the solution n times, so that f^(n)(x) appears. + x = funcs[0].args[0] + diff_sols = [] + subs_sols = [] + diff_variables = set() + for funcarg, value in ics.items(): + if isinstance(funcarg, AppliedUndef): + x0 = funcarg.args[0] + matching_func = [f for f in funcs if f.func == funcarg.func][0] + S = sols + elif isinstance(funcarg, (Subs, Derivative)): + if isinstance(funcarg, Subs): + # Make sure it stays a subs. Otherwise subs below will produce + # a different looking term. + funcarg = funcarg.doit() + if isinstance(funcarg, Subs): + deriv = funcarg.expr + x0 = funcarg.point[0] + variables = funcarg.expr.variables + matching_func = deriv + elif isinstance(funcarg, Derivative): + deriv = funcarg + x0 = funcarg.variables[0] + variables = (x,)*len(funcarg.variables) + matching_func = deriv.subs(x0, x) + for sol in sols: + if sol.has(deriv.expr.func): + diff_sols.append(Eq(sol.lhs.diff(*variables), sol.rhs.diff(*variables))) + diff_variables.add(variables) + S = diff_sols + else: + raise NotImplementedError("Unrecognized initial condition") + + for sol in S: + if sol.has(matching_func): + sol2 = sol + sol2 = sol2.subs(x, x0) + sol2 = sol2.subs(funcarg, value) + # This check is necessary because of issue #15724 + if not isinstance(sol2, BooleanAtom) or not subs_sols: + subs_sols = [s for s in subs_sols if not isinstance(s, BooleanAtom)] + subs_sols.append(sol2) + + # TODO: Use solveset here + try: + solved_constants = solve(subs_sols, constants, dict=True) + except NotImplementedError: + solved_constants = [] + + # XXX: We can't differentiate between the solution not existing because of + # invalid initial conditions, and not existing because solve is not smart + # enough. If we could use solveset, this might be improvable, but for now, + # we use NotImplementedError in this case. + if not solved_constants: + raise ValueError("Couldn't solve for initial conditions") + + if solved_constants == True: + raise ValueError("Initial conditions did not produce any solutions for constants. Perhaps they are degenerate.") + + if len(solved_constants) > 1: + raise NotImplementedError("Initial conditions produced too many solutions for constants") + + return solved_constants[0] + +def classify_ode(eq, func=None, dict=False, ics=None, *, prep=True, xi=None, eta=None, n=None, **kwargs): + r""" + Returns a tuple of possible :py:meth:`~sympy.solvers.ode.dsolve` + classifications for an ODE. + + The tuple is ordered so that first item is the classification that + :py:meth:`~sympy.solvers.ode.dsolve` uses to solve the ODE by default. In + general, classifications at the near the beginning of the list will + produce better solutions faster than those near the end, thought there are + always exceptions. To make :py:meth:`~sympy.solvers.ode.dsolve` use a + different classification, use ``dsolve(ODE, func, + hint=)``. See also the + :py:meth:`~sympy.solvers.ode.dsolve` docstring for different meta-hints + you can use. + + If ``dict`` is true, :py:meth:`~sympy.solvers.ode.classify_ode` will + return a dictionary of ``hint:match`` expression terms. This is intended + for internal use by :py:meth:`~sympy.solvers.ode.dsolve`. Note that + because dictionaries are ordered arbitrarily, this will most likely not be + in the same order as the tuple. + + You can get help on different hints by executing + ``help(ode.ode_hintname)``, where ``hintname`` is the name of the hint + without ``_Integral``. + + See :py:data:`~sympy.solvers.ode.allhints` or the + :py:mod:`~sympy.solvers.ode` docstring for a list of all supported hints + that can be returned from :py:meth:`~sympy.solvers.ode.classify_ode`. + + Notes + ===== + + These are remarks on hint names. + + ``_Integral`` + + If a classification has ``_Integral`` at the end, it will return the + expression with an unevaluated :py:class:`~.Integral` + class in it. Note that a hint may do this anyway if + :py:meth:`~sympy.core.expr.Expr.integrate` cannot do the integral, + though just using an ``_Integral`` will do so much faster. Indeed, an + ``_Integral`` hint will always be faster than its corresponding hint + without ``_Integral`` because + :py:meth:`~sympy.core.expr.Expr.integrate` is an expensive routine. + If :py:meth:`~sympy.solvers.ode.dsolve` hangs, it is probably because + :py:meth:`~sympy.core.expr.Expr.integrate` is hanging on a tough or + impossible integral. Try using an ``_Integral`` hint or + ``all_Integral`` to get it return something. + + Note that some hints do not have ``_Integral`` counterparts. This is + because :py:func:`~sympy.integrals.integrals.integrate` is not used in + solving the ODE for those method. For example, `n`\th order linear + homogeneous ODEs with constant coefficients do not require integration + to solve, so there is no + ``nth_linear_homogeneous_constant_coeff_Integrate`` hint. You can + easily evaluate any unevaluated + :py:class:`~sympy.integrals.integrals.Integral`\s in an expression by + doing ``expr.doit()``. + + Ordinals + + Some hints contain an ordinal such as ``1st_linear``. This is to help + differentiate them from other hints, as well as from other methods + that may not be implemented yet. If a hint has ``nth`` in it, such as + the ``nth_linear`` hints, this means that the method used to applies + to ODEs of any order. + + ``indep`` and ``dep`` + + Some hints contain the words ``indep`` or ``dep``. These reference + the independent variable and the dependent function, respectively. For + example, if an ODE is in terms of `f(x)`, then ``indep`` will refer to + `x` and ``dep`` will refer to `f`. + + ``subs`` + + If a hints has the word ``subs`` in it, it means that the ODE is solved + by substituting the expression given after the word ``subs`` for a + single dummy variable. This is usually in terms of ``indep`` and + ``dep`` as above. The substituted expression will be written only in + characters allowed for names of Python objects, meaning operators will + be spelled out. For example, ``indep``/``dep`` will be written as + ``indep_div_dep``. + + ``coeff`` + + The word ``coeff`` in a hint refers to the coefficients of something + in the ODE, usually of the derivative terms. See the docstring for + the individual methods for more info (``help(ode)``). This is + contrast to ``coefficients``, as in ``undetermined_coefficients``, + which refers to the common name of a method. + + ``_best`` + + Methods that have more than one fundamental way to solve will have a + hint for each sub-method and a ``_best`` meta-classification. This + will evaluate all hints and return the best, using the same + considerations as the normal ``best`` meta-hint. + + + Examples + ======== + + >>> from sympy import Function, classify_ode, Eq + >>> from sympy.abc import x + >>> f = Function('f') + >>> classify_ode(Eq(f(x).diff(x), 0), f(x)) + ('nth_algebraic', + 'separable', + '1st_exact', + '1st_linear', + 'Bernoulli', + '1st_homogeneous_coeff_best', + '1st_homogeneous_coeff_subs_indep_div_dep', + '1st_homogeneous_coeff_subs_dep_div_indep', + '1st_power_series', 'lie_group', 'nth_linear_constant_coeff_homogeneous', + 'nth_linear_euler_eq_homogeneous', + 'nth_algebraic_Integral', 'separable_Integral', '1st_exact_Integral', + '1st_linear_Integral', 'Bernoulli_Integral', + '1st_homogeneous_coeff_subs_indep_div_dep_Integral', + '1st_homogeneous_coeff_subs_dep_div_indep_Integral') + >>> classify_ode(f(x).diff(x, 2) + 3*f(x).diff(x) + 2*f(x) - 4) + ('factorable', 'nth_linear_constant_coeff_undetermined_coefficients', + 'nth_linear_constant_coeff_variation_of_parameters', + 'nth_linear_constant_coeff_variation_of_parameters_Integral') + + """ + ics = sympify(ics) + + if func and len(func.args) != 1: + raise ValueError("dsolve() and classify_ode() only " + "work with functions of one variable, not %s" % func) + + if isinstance(eq, Equality): + eq = eq.lhs - eq.rhs + + # Some methods want the unprocessed equation + eq_orig = eq + + if prep or func is None: + eq, func_ = _preprocess(eq, func) + if func is None: + func = func_ + x = func.args[0] + f = func.func + y = Dummy('y') + terms = 5 if n is None else n + + order = ode_order(eq, f(x)) + # hint:matchdict or hint:(tuple of matchdicts) + # Also will contain "default": and "order":order items. + matching_hints = {"order": order} + + df = f(x).diff(x) + a = Wild('a', exclude=[f(x)]) + d = Wild('d', exclude=[df, f(x).diff(x, 2)]) + e = Wild('e', exclude=[df]) + n = Wild('n', exclude=[x, f(x), df]) + c1 = Wild('c1', exclude=[x]) + a3 = Wild('a3', exclude=[f(x), df, f(x).diff(x, 2)]) + b3 = Wild('b3', exclude=[f(x), df, f(x).diff(x, 2)]) + c3 = Wild('c3', exclude=[f(x), df, f(x).diff(x, 2)]) + boundary = {} # Used to extract initial conditions + C1 = Symbol("C1") + + # Preprocessing to get the initial conditions out + if ics is not None: + for funcarg in ics: + # Separating derivatives + if isinstance(funcarg, (Subs, Derivative)): + # f(x).diff(x).subs(x, 0) is a Subs, but f(x).diff(x).subs(x, + # y) is a Derivative + if isinstance(funcarg, Subs): + deriv = funcarg.expr + old = funcarg.variables[0] + new = funcarg.point[0] + elif isinstance(funcarg, Derivative): + deriv = funcarg + # No information on this. Just assume it was x + old = x + new = funcarg.variables[0] + + if (isinstance(deriv, Derivative) and isinstance(deriv.args[0], + AppliedUndef) and deriv.args[0].func == f and + len(deriv.args[0].args) == 1 and old == x and not + new.has(x) and all(i == deriv.variables[0] for i in + deriv.variables) and x not in ics[funcarg].free_symbols): + + dorder = ode_order(deriv, x) + temp = 'f' + str(dorder) + boundary.update({temp: new, temp + 'val': ics[funcarg]}) + else: + raise ValueError("Invalid boundary conditions for Derivatives") + + + # Separating functions + elif isinstance(funcarg, AppliedUndef): + if (funcarg.func == f and len(funcarg.args) == 1 and + not funcarg.args[0].has(x) and x not in ics[funcarg].free_symbols): + boundary.update({'f0': funcarg.args[0], 'f0val': ics[funcarg]}) + else: + raise ValueError("Invalid boundary conditions for Function") + + else: + raise ValueError("Enter boundary conditions of the form ics={f(point): value, f(x).diff(x, order).subs(x, point): value}") + + ode = SingleODEProblem(eq_orig, func, x, prep=prep, xi=xi, eta=eta) + user_hint = kwargs.get('hint', 'default') + # Used when dsolve is called without an explicit hint. + # We exit early to return the first valid match + early_exit = (user_hint=='default') + if user_hint.endswith('_Integral'): + user_hint = user_hint[:-len('_Integral')] + user_map = solver_map + # An explicit hint has been given to dsolve + # Skip matching code for other hints + if user_hint not in ['default', 'all', 'all_Integral', 'best'] and user_hint in solver_map: + user_map = {user_hint: solver_map[user_hint]} + + for hint in user_map: + solver = user_map[hint](ode) + if solver.matches(): + matching_hints[hint] = solver + if user_map[hint].has_integral: + matching_hints[hint + "_Integral"] = solver + if dict and early_exit: + matching_hints["default"] = hint + return matching_hints + + eq = expand(eq) + # Precondition to try remove f(x) from highest order derivative + reduced_eq = None + if eq.is_Add: + deriv_coef = eq.coeff(f(x).diff(x, order)) + if deriv_coef not in (1, 0): + r = deriv_coef.match(a*f(x)**c1) + if r and r[c1]: + den = f(x)**r[c1] + reduced_eq = Add(*[arg/den for arg in eq.args]) + if not reduced_eq: + reduced_eq = eq + + if order == 1: + + # NON-REDUCED FORM OF EQUATION matches + r = collect(eq, df, exact=True).match(d + e * df) + if r: + r['d'] = d + r['e'] = e + r['y'] = y + r[d] = r[d].subs(f(x), y) + r[e] = r[e].subs(f(x), y) + + # FIRST ORDER POWER SERIES WHICH NEEDS INITIAL CONDITIONS + # TODO: Hint first order series should match only if d/e is analytic. + # For now, only d/e and (d/e).diff(arg) is checked for existence at + # at a given point. + # This is currently done internally in ode_1st_power_series. + point = boundary.get('f0', 0) + value = boundary.get('f0val', C1) + check = cancel(r[d]/r[e]) + check1 = check.subs({x: point, y: value}) + if not check1.has(oo) and not check1.has(zoo) and \ + not check1.has(nan) and not check1.has(-oo): + check2 = (check1.diff(x)).subs({x: point, y: value}) + if not check2.has(oo) and not check2.has(zoo) and \ + not check2.has(nan) and not check2.has(-oo): + rseries = r.copy() + rseries.update({'terms': terms, 'f0': point, 'f0val': value}) + matching_hints["1st_power_series"] = rseries + + elif order == 2: + # Homogeneous second order differential equation of the form + # a3*f(x).diff(x, 2) + b3*f(x).diff(x) + c3 + # It has a definite power series solution at point x0 if, b3/a3 and c3/a3 + # are analytic at x0. + deq = a3*(f(x).diff(x, 2)) + b3*df + c3*f(x) + r = collect(reduced_eq, + [f(x).diff(x, 2), f(x).diff(x), f(x)]).match(deq) + ordinary = False + if r: + if not all(r[key].is_polynomial() for key in r): + n, d = reduced_eq.as_numer_denom() + reduced_eq = expand(n) + r = collect(reduced_eq, + [f(x).diff(x, 2), f(x).diff(x), f(x)]).match(deq) + if r and r[a3] != 0: + p = cancel(r[b3]/r[a3]) # Used below + q = cancel(r[c3]/r[a3]) # Used below + point = kwargs.get('x0', 0) + check = p.subs(x, point) + if not check.has(oo, nan, zoo, -oo): + check = q.subs(x, point) + if not check.has(oo, nan, zoo, -oo): + ordinary = True + r.update({'a3': a3, 'b3': b3, 'c3': c3, 'x0': point, 'terms': terms}) + matching_hints["2nd_power_series_ordinary"] = r + + # Checking if the differential equation has a regular singular point + # at x0. It has a regular singular point at x0, if (b3/a3)*(x - x0) + # and (c3/a3)*((x - x0)**2) are analytic at x0. + if not ordinary: + p = cancel((x - point)*p) + check = p.subs(x, point) + if not check.has(oo, nan, zoo, -oo): + q = cancel(((x - point)**2)*q) + check = q.subs(x, point) + if not check.has(oo, nan, zoo, -oo): + coeff_dict = {'p': p, 'q': q, 'x0': point, 'terms': terms} + matching_hints["2nd_power_series_regular"] = coeff_dict + + + # Order keys based on allhints. + retlist = [i for i in allhints if i in matching_hints] + if dict: + # Dictionaries are ordered arbitrarily, so make note of which + # hint would come first for dsolve(). Use an ordered dict in Py 3. + matching_hints["default"] = retlist[0] if retlist else None + matching_hints["ordered_hints"] = tuple(retlist) + return matching_hints + else: + return tuple(retlist) + + +def classify_sysode(eq, funcs=None, **kwargs): + r""" + Returns a dictionary of parameter names and values that define the system + of ordinary differential equations in ``eq``. + The parameters are further used in + :py:meth:`~sympy.solvers.ode.dsolve` for solving that system. + + Some parameter names and values are: + + 'is_linear' (boolean), which tells whether the given system is linear. + Note that "linear" here refers to the operator: terms such as ``x*diff(x,t)`` are + nonlinear, whereas terms like ``sin(t)*diff(x,t)`` are still linear operators. + + 'func' (list) contains the :py:class:`~sympy.core.function.Function`s that + appear with a derivative in the ODE, i.e. those that we are trying to solve + the ODE for. + + 'order' (dict) with the maximum derivative for each element of the 'func' + parameter. + + 'func_coeff' (dict or Matrix) with the coefficient for each triple ``(equation number, + function, order)```. The coefficients are those subexpressions that do not + appear in 'func', and hence can be considered constant for purposes of ODE + solving. The value of this parameter can also be a Matrix if the system of ODEs are + linear first order of the form X' = AX where X is the vector of dependent variables. + Here, this function returns the coefficient matrix A. + + 'eq' (list) with the equations from ``eq``, sympified and transformed into + expressions (we are solving for these expressions to be zero). + + 'no_of_equations' (int) is the number of equations (same as ``len(eq)``). + + 'type_of_equation' (string) is an internal classification of the type of + ODE. + + 'is_constant' (boolean), which tells if the system of ODEs is constant coefficient + or not. This key is temporary addition for now and is in the match dict only when + the system of ODEs is linear first order constant coefficient homogeneous. So, this + key's value is True for now if it is available else it does not exist. + + 'is_homogeneous' (boolean), which tells if the system of ODEs is homogeneous. Like the + key 'is_constant', this key is a temporary addition and it is True since this key value + is available only when the system is linear first order constant coefficient homogeneous. + + References + ========== + -https://eqworld.ipmnet.ru/en/solutions/sysode/sode-toc1.htm + -A. D. Polyanin and A. V. Manzhirov, Handbook of Mathematics for Engineers and Scientists + + Examples + ======== + + >>> from sympy import Function, Eq, symbols, diff + >>> from sympy.solvers.ode.ode import classify_sysode + >>> from sympy.abc import t + >>> f, x, y = symbols('f, x, y', cls=Function) + >>> k, l, m, n = symbols('k, l, m, n', Integer=True) + >>> x1 = diff(x(t), t) ; y1 = diff(y(t), t) + >>> x2 = diff(x(t), t, t) ; y2 = diff(y(t), t, t) + >>> eq = (Eq(x1, 12*x(t) - 6*y(t)), Eq(y1, 11*x(t) + 3*y(t))) + >>> classify_sysode(eq) + {'eq': [-12*x(t) + 6*y(t) + Derivative(x(t), t), -11*x(t) - 3*y(t) + Derivative(y(t), t)], 'func': [x(t), y(t)], + 'func_coeff': {(0, x(t), 0): -12, (0, x(t), 1): 1, (0, y(t), 0): 6, (0, y(t), 1): 0, (1, x(t), 0): -11, (1, x(t), 1): 0, (1, y(t), 0): -3, (1, y(t), 1): 1}, 'is_linear': True, 'no_of_equation': 2, 'order': {x(t): 1, y(t): 1}, 'type_of_equation': None} + >>> eq = (Eq(diff(x(t),t), 5*t*x(t) + t**2*y(t) + 2), Eq(diff(y(t),t), -t**2*x(t) + 5*t*y(t))) + >>> classify_sysode(eq) + {'eq': [-t**2*y(t) - 5*t*x(t) + Derivative(x(t), t) - 2, t**2*x(t) - 5*t*y(t) + Derivative(y(t), t)], + 'func': [x(t), y(t)], 'func_coeff': {(0, x(t), 0): -5*t, (0, x(t), 1): 1, (0, y(t), 0): -t**2, (0, y(t), 1): 0, + (1, x(t), 0): t**2, (1, x(t), 1): 0, (1, y(t), 0): -5*t, (1, y(t), 1): 1}, 'is_linear': True, 'no_of_equation': 2, + 'order': {x(t): 1, y(t): 1}, 'type_of_equation': None} + + """ + + # Sympify equations and convert iterables of equations into + # a list of equations + def _sympify(eq): + return list(map(sympify, eq if iterable(eq) else [eq])) + + eq, funcs = (_sympify(w) for w in [eq, funcs]) + for i, fi in enumerate(eq): + if isinstance(fi, Equality): + eq[i] = fi.lhs - fi.rhs + + t = list(list(eq[0].atoms(Derivative))[0].atoms(Symbol))[0] + matching_hints = {"no_of_equation":i+1} + matching_hints['eq'] = eq + if i==0: + raise ValueError("classify_sysode() works for systems of ODEs. " + "For scalar ODEs, classify_ode should be used") + + # find all the functions if not given + order = {} + if funcs==[None]: + funcs = _extract_funcs(eq) + + funcs = list(set(funcs)) + if len(funcs) != len(eq): + raise ValueError("Number of functions given is not equal to the number of equations %s" % funcs) + + # This logic of list of lists in funcs to + # be replaced later. + func_dict = {} + for func in funcs: + if not order.get(func, False): + max_order = 0 + for i, eqs_ in enumerate(eq): + order_ = ode_order(eqs_,func) + if max_order < order_: + max_order = order_ + eq_no = i + if eq_no in func_dict: + func_dict[eq_no] = [func_dict[eq_no], func] + else: + func_dict[eq_no] = func + order[func] = max_order + + funcs = [func_dict[i] for i in range(len(func_dict))] + matching_hints['func'] = funcs + for func in funcs: + if isinstance(func, list): + for func_elem in func: + if len(func_elem.args) != 1: + raise ValueError("dsolve() and classify_sysode() work with " + "functions of one variable only, not %s" % func) + else: + if func and len(func.args) != 1: + raise ValueError("dsolve() and classify_sysode() work with " + "functions of one variable only, not %s" % func) + + # find the order of all equation in system of odes + matching_hints["order"] = order + + # find coefficients of terms f(t), diff(f(t),t) and higher derivatives + # and similarly for other functions g(t), diff(g(t),t) in all equations. + # Here j denotes the equation number, funcs[l] denotes the function about + # which we are talking about and k denotes the order of function funcs[l] + # whose coefficient we are calculating. + def linearity_check(eqs, j, func, is_linear_): + for k in range(order[func] + 1): + func_coef[j, func, k] = collect(eqs.expand(), [diff(func, t, k)]).coeff(diff(func, t, k)) + if is_linear_ == True: + if func_coef[j, func, k] == 0: + if k == 0: + coef = eqs.as_independent(func, as_Add=True)[1] + for xr in range(1, ode_order(eqs,func) + 1): + coef -= eqs.as_independent(diff(func, t, xr), as_Add=True)[1] + if coef != 0: + is_linear_ = False + else: + if eqs.as_independent(diff(func, t, k), as_Add=True)[1]: + is_linear_ = False + else: + for func_ in funcs: + if isinstance(func_, list): + for elem_func_ in func_: + dep = func_coef[j, func, k].as_independent(elem_func_, as_Add=True)[1] + if dep != 0: + is_linear_ = False + else: + dep = func_coef[j, func, k].as_independent(func_, as_Add=True)[1] + if dep != 0: + is_linear_ = False + return is_linear_ + + func_coef = {} + is_linear = True + for j, eqs in enumerate(eq): + for func in funcs: + if isinstance(func, list): + for func_elem in func: + is_linear = linearity_check(eqs, j, func_elem, is_linear) + else: + is_linear = linearity_check(eqs, j, func, is_linear) + matching_hints['func_coeff'] = func_coef + matching_hints['is_linear'] = is_linear + + + if len(set(order.values())) == 1: + order_eq = list(matching_hints['order'].values())[0] + if matching_hints['is_linear'] == True: + if matching_hints['no_of_equation'] == 2: + if order_eq == 1: + type_of_equation = check_linear_2eq_order1(eq, funcs, func_coef) + else: + type_of_equation = None + # If the equation does not match up with any of the + # general case solvers in systems.py and the number + # of equations is greater than 2, then NotImplementedError + # should be raised. + else: + type_of_equation = None + + else: + if matching_hints['no_of_equation'] == 2: + if order_eq == 1: + type_of_equation = check_nonlinear_2eq_order1(eq, funcs, func_coef) + else: + type_of_equation = None + elif matching_hints['no_of_equation'] == 3: + if order_eq == 1: + type_of_equation = check_nonlinear_3eq_order1(eq, funcs, func_coef) + else: + type_of_equation = None + else: + type_of_equation = None + else: + type_of_equation = None + + matching_hints['type_of_equation'] = type_of_equation + + return matching_hints + + +def check_linear_2eq_order1(eq, func, func_coef): + x = func[0].func + y = func[1].func + fc = func_coef + t = list(list(eq[0].atoms(Derivative))[0].atoms(Symbol))[0] + r = {} + # for equations Eq(a1*diff(x(t),t), b1*x(t) + c1*y(t) + d1) + # and Eq(a2*diff(y(t),t), b2*x(t) + c2*y(t) + d2) + r['a1'] = fc[0,x(t),1] ; r['a2'] = fc[1,y(t),1] + r['b1'] = -fc[0,x(t),0]/fc[0,x(t),1] ; r['b2'] = -fc[1,x(t),0]/fc[1,y(t),1] + r['c1'] = -fc[0,y(t),0]/fc[0,x(t),1] ; r['c2'] = -fc[1,y(t),0]/fc[1,y(t),1] + forcing = [S.Zero,S.Zero] + for i in range(2): + for j in Add.make_args(eq[i]): + if not j.has(x(t), y(t)): + forcing[i] += j + if not (forcing[0].has(t) or forcing[1].has(t)): + # We can handle homogeneous case and simple constant forcings + r['d1'] = forcing[0] + r['d2'] = forcing[1] + else: + # Issue #9244: nonhomogeneous linear systems are not supported + return None + + # Conditions to check for type 6 whose equations are Eq(diff(x(t),t), f(t)*x(t) + g(t)*y(t)) and + # Eq(diff(y(t),t), a*[f(t) + a*h(t)]x(t) + a*[g(t) - h(t)]*y(t)) + p = 0 + q = 0 + p1 = cancel(r['b2']/(cancel(r['b2']/r['c2']).as_numer_denom()[0])) + p2 = cancel(r['b1']/(cancel(r['b1']/r['c1']).as_numer_denom()[0])) + for n, i in enumerate([p1, p2]): + for j in Mul.make_args(collect_const(i)): + if not j.has(t): + q = j + if q and n==0: + if ((r['b2']/j - r['b1'])/(r['c1'] - r['c2']/j)) == j: + p = 1 + elif q and n==1: + if ((r['b1']/j - r['b2'])/(r['c2'] - r['c1']/j)) == j: + p = 2 + # End of condition for type 6 + + if r['d1']!=0 or r['d2']!=0: + return None + else: + if not any(r[k].has(t) for k in 'a1 a2 b1 b2 c1 c2'.split()): + return None + else: + r['b1'] = r['b1']/r['a1'] ; r['b2'] = r['b2']/r['a2'] + r['c1'] = r['c1']/r['a1'] ; r['c2'] = r['c2']/r['a2'] + if p: + return "type6" + else: + # Equations for type 7 are Eq(diff(x(t),t), f(t)*x(t) + g(t)*y(t)) and Eq(diff(y(t),t), h(t)*x(t) + p(t)*y(t)) + return "type7" +def check_nonlinear_2eq_order1(eq, func, func_coef): + t = list(list(eq[0].atoms(Derivative))[0].atoms(Symbol))[0] + f = Wild('f') + g = Wild('g') + u, v = symbols('u, v', cls=Dummy) + def check_type(x, y): + r1 = eq[0].match(t*diff(x(t),t) - x(t) + f) + r2 = eq[1].match(t*diff(y(t),t) - y(t) + g) + if not (r1 and r2): + r1 = eq[0].match(diff(x(t),t) - x(t)/t + f/t) + r2 = eq[1].match(diff(y(t),t) - y(t)/t + g/t) + if not (r1 and r2): + r1 = (-eq[0]).match(t*diff(x(t),t) - x(t) + f) + r2 = (-eq[1]).match(t*diff(y(t),t) - y(t) + g) + if not (r1 and r2): + r1 = (-eq[0]).match(diff(x(t),t) - x(t)/t + f/t) + r2 = (-eq[1]).match(diff(y(t),t) - y(t)/t + g/t) + if r1 and r2 and not (r1[f].subs(diff(x(t),t),u).subs(diff(y(t),t),v).has(t) \ + or r2[g].subs(diff(x(t),t),u).subs(diff(y(t),t),v).has(t)): + return 'type5' + else: + return None + for func_ in func: + if isinstance(func_, list): + x = func[0][0].func + y = func[0][1].func + eq_type = check_type(x, y) + if not eq_type: + eq_type = check_type(y, x) + return eq_type + x = func[0].func + y = func[1].func + fc = func_coef + n = Wild('n', exclude=[x(t),y(t)]) + f1 = Wild('f1', exclude=[v,t]) + f2 = Wild('f2', exclude=[v,t]) + g1 = Wild('g1', exclude=[u,t]) + g2 = Wild('g2', exclude=[u,t]) + for i in range(2): + eqs = 0 + for terms in Add.make_args(eq[i]): + eqs += terms/fc[i,func[i],1] + eq[i] = eqs + r = eq[0].match(diff(x(t),t) - x(t)**n*f) + if r: + g = (diff(y(t),t) - eq[1])/r[f] + if r and not (g.has(x(t)) or g.subs(y(t),v).has(t) or r[f].subs(x(t),u).subs(y(t),v).has(t)): + return 'type1' + r = eq[0].match(diff(x(t),t) - exp(n*x(t))*f) + if r: + g = (diff(y(t),t) - eq[1])/r[f] + if r and not (g.has(x(t)) or g.subs(y(t),v).has(t) or r[f].subs(x(t),u).subs(y(t),v).has(t)): + return 'type2' + g = Wild('g') + r1 = eq[0].match(diff(x(t),t) - f) + r2 = eq[1].match(diff(y(t),t) - g) + if r1 and r2 and not (r1[f].subs(x(t),u).subs(y(t),v).has(t) or \ + r2[g].subs(x(t),u).subs(y(t),v).has(t)): + return 'type3' + r1 = eq[0].match(diff(x(t),t) - f) + r2 = eq[1].match(diff(y(t),t) - g) + num, den = ( + (r1[f].subs(x(t),u).subs(y(t),v))/ + (r2[g].subs(x(t),u).subs(y(t),v))).as_numer_denom() + R1 = num.match(f1*g1) + R2 = den.match(f2*g2) + # phi = (r1[f].subs(x(t),u).subs(y(t),v))/num + if R1 and R2: + return 'type4' + return None + + +def check_nonlinear_2eq_order2(eq, func, func_coef): + return None + +def check_nonlinear_3eq_order1(eq, func, func_coef): + x = func[0].func + y = func[1].func + z = func[2].func + fc = func_coef + t = list(list(eq[0].atoms(Derivative))[0].atoms(Symbol))[0] + u, v, w = symbols('u, v, w', cls=Dummy) + a = Wild('a', exclude=[x(t), y(t), z(t), t]) + b = Wild('b', exclude=[x(t), y(t), z(t), t]) + c = Wild('c', exclude=[x(t), y(t), z(t), t]) + f = Wild('f') + F1 = Wild('F1') + F2 = Wild('F2') + F3 = Wild('F3') + for i in range(3): + eqs = 0 + for terms in Add.make_args(eq[i]): + eqs += terms/fc[i,func[i],1] + eq[i] = eqs + r1 = eq[0].match(diff(x(t),t) - a*y(t)*z(t)) + r2 = eq[1].match(diff(y(t),t) - b*z(t)*x(t)) + r3 = eq[2].match(diff(z(t),t) - c*x(t)*y(t)) + if r1 and r2 and r3: + num1, den1 = r1[a].as_numer_denom() + num2, den2 = r2[b].as_numer_denom() + num3, den3 = r3[c].as_numer_denom() + if solve([num1*u-den1*(v-w), num2*v-den2*(w-u), num3*w-den3*(u-v)],[u, v]): + return 'type1' + r = eq[0].match(diff(x(t),t) - y(t)*z(t)*f) + if r: + r1 = collect_const(r[f]).match(a*f) + r2 = ((diff(y(t),t) - eq[1])/r1[f]).match(b*z(t)*x(t)) + r3 = ((diff(z(t),t) - eq[2])/r1[f]).match(c*x(t)*y(t)) + if r1 and r2 and r3: + num1, den1 = r1[a].as_numer_denom() + num2, den2 = r2[b].as_numer_denom() + num3, den3 = r3[c].as_numer_denom() + if solve([num1*u-den1*(v-w), num2*v-den2*(w-u), num3*w-den3*(u-v)],[u, v]): + return 'type2' + r = eq[0].match(diff(x(t),t) - (F2-F3)) + if r: + r1 = collect_const(r[F2]).match(c*F2) + r1.update(collect_const(r[F3]).match(b*F3)) + if r1: + if eq[1].has(r1[F2]) and not eq[1].has(r1[F3]): + r1[F2], r1[F3] = r1[F3], r1[F2] + r1[c], r1[b] = -r1[b], -r1[c] + r2 = eq[1].match(diff(y(t),t) - a*r1[F3] + r1[c]*F1) + if r2: + r3 = (eq[2] == diff(z(t),t) - r1[b]*r2[F1] + r2[a]*r1[F2]) + if r1 and r2 and r3: + return 'type3' + r = eq[0].match(diff(x(t),t) - z(t)*F2 + y(t)*F3) + if r: + r1 = collect_const(r[F2]).match(c*F2) + r1.update(collect_const(r[F3]).match(b*F3)) + if r1: + if eq[1].has(r1[F2]) and not eq[1].has(r1[F3]): + r1[F2], r1[F3] = r1[F3], r1[F2] + r1[c], r1[b] = -r1[b], -r1[c] + r2 = (diff(y(t),t) - eq[1]).match(a*x(t)*r1[F3] - r1[c]*z(t)*F1) + if r2: + r3 = (diff(z(t),t) - eq[2] == r1[b]*y(t)*r2[F1] - r2[a]*x(t)*r1[F2]) + if r1 and r2 and r3: + return 'type4' + r = (diff(x(t),t) - eq[0]).match(x(t)*(F2 - F3)) + if r: + r1 = collect_const(r[F2]).match(c*F2) + r1.update(collect_const(r[F3]).match(b*F3)) + if r1: + if eq[1].has(r1[F2]) and not eq[1].has(r1[F3]): + r1[F2], r1[F3] = r1[F3], r1[F2] + r1[c], r1[b] = -r1[b], -r1[c] + r2 = (diff(y(t),t) - eq[1]).match(y(t)*(a*r1[F3] - r1[c]*F1)) + if r2: + r3 = (diff(z(t),t) - eq[2] == z(t)*(r1[b]*r2[F1] - r2[a]*r1[F2])) + if r1 and r2 and r3: + return 'type5' + return None + + +def check_nonlinear_3eq_order2(eq, func, func_coef): + return None + + +@vectorize(0) +def odesimp(ode, eq, func, hint): + r""" + Simplifies solutions of ODEs, including trying to solve for ``func`` and + running :py:meth:`~sympy.solvers.ode.constantsimp`. + + It may use knowledge of the type of solution that the hint returns to + apply additional simplifications. + + It also attempts to integrate any :py:class:`~sympy.integrals.integrals.Integral`\s + in the expression, if the hint is not an ``_Integral`` hint. + + This function should have no effect on expressions returned by + :py:meth:`~sympy.solvers.ode.dsolve`, as + :py:meth:`~sympy.solvers.ode.dsolve` already calls + :py:meth:`~sympy.solvers.ode.ode.odesimp`, but the individual hint functions + do not call :py:meth:`~sympy.solvers.ode.ode.odesimp` (because the + :py:meth:`~sympy.solvers.ode.dsolve` wrapper does). Therefore, this + function is designed for mainly internal use. + + Examples + ======== + + >>> from sympy import sin, symbols, dsolve, pprint, Function + >>> from sympy.solvers.ode.ode import odesimp + >>> x, u2, C1= symbols('x,u2,C1') + >>> f = Function('f') + + >>> eq = dsolve(x*f(x).diff(x) - f(x) - x*sin(f(x)/x), f(x), + ... hint='1st_homogeneous_coeff_subs_indep_div_dep_Integral', + ... simplify=False) + >>> pprint(eq, wrap_line=False) + x + ---- + f(x) + / + | + | / 1 \ + | -|u1 + -------| + | | /1 \| + | | sin|--|| + | \ \u1// + log(f(x)) = log(C1) + | ---------------- d(u1) + | 2 + | u1 + | + / + + >>> pprint(odesimp(eq, f(x), 1, {C1}, + ... hint='1st_homogeneous_coeff_subs_indep_div_dep' + ... )) #doctest: +SKIP + x + --------- = C1 + /f(x)\ + tan|----| + \2*x / + + """ + x = func.args[0] + f = func.func + C1 = get_numbered_constants(eq, num=1) + constants = eq.free_symbols - ode.free_symbols + + # First, integrate if the hint allows it. + eq = _handle_Integral(eq, func, hint) + if hint.startswith("nth_linear_euler_eq_nonhomogeneous"): + eq = simplify(eq) + if not isinstance(eq, Equality): + raise TypeError("eq should be an instance of Equality") + + # allow simplifications under assumption that symbols are nonzero + eq = eq.xreplace((_:={i: Dummy(nonzero=True) for i in constants})).xreplace({_[i]: i for i in _}) + + # Second, clean up the arbitrary constants. + # Right now, nth linear hints can put as many as 2*order constants in an + # expression. If that number grows with another hint, the third argument + # here should be raised accordingly, or constantsimp() rewritten to handle + # an arbitrary number of constants. + eq = constantsimp(eq, constants) + + # Lastly, now that we have cleaned up the expression, try solving for func. + # When CRootOf is implemented in solve(), we will want to return a CRootOf + # every time instead of an Equality. + + # Get the f(x) on the left if possible. + if eq.rhs == func and not eq.lhs.has(func): + eq = [Eq(eq.rhs, eq.lhs)] + + # make sure we are working with lists of solutions in simplified form. + if eq.lhs == func and not eq.rhs.has(func): + # The solution is already solved + eq = [eq] + + else: + # The solution is not solved, so try to solve it + try: + floats = any(i.is_Float for i in eq.atoms(Number)) + eqsol = solve(eq, func, force=True, rational=False if floats else None) + if not eqsol: + raise NotImplementedError + except (NotImplementedError, PolynomialError): + eq = [eq] + else: + def _expand(expr): + numer, denom = expr.as_numer_denom() + + if denom.is_Add: + return expr + else: + return powsimp(expr.expand(), combine='exp', deep=True) + + # XXX: the rest of odesimp() expects each ``t`` to be in a + # specific normal form: rational expression with numerator + # expanded, but with combined exponential functions (at + # least in this setup all tests pass). + eq = [Eq(f(x), _expand(t)) for t in eqsol] + + # special simplification of the lhs. + if hint.startswith("1st_homogeneous_coeff"): + for j, eqi in enumerate(eq): + newi = logcombine(eqi, force=True) + if isinstance(newi.lhs, log) and newi.rhs == 0: + newi = Eq(newi.lhs.args[0]/C1, C1) + eq[j] = newi + + # We cleaned up the constants before solving to help the solve engine with + # a simpler expression, but the solved expression could have introduced + # things like -C1, so rerun constantsimp() one last time before returning. + for i, eqi in enumerate(eq): + eq[i] = constantsimp(eqi, constants) + eq[i] = constant_renumber(eq[i], ode.free_symbols) + + # If there is only 1 solution, return it; + # otherwise return the list of solutions. + if len(eq) == 1: + eq = eq[0] + return eq + + +def ode_sol_simplicity(sol, func, trysolving=True): + r""" + Returns an extended integer representing how simple a solution to an ODE + is. + + The following things are considered, in order from most simple to least: + + - ``sol`` is solved for ``func``. + - ``sol`` is not solved for ``func``, but can be if passed to solve (e.g., + a solution returned by ``dsolve(ode, func, simplify=False``). + - If ``sol`` is not solved for ``func``, then base the result on the + length of ``sol``, as computed by ``len(str(sol))``. + - If ``sol`` has any unevaluated :py:class:`~sympy.integrals.integrals.Integral`\s, + this will automatically be considered less simple than any of the above. + + This function returns an integer such that if solution A is simpler than + solution B by above metric, then ``ode_sol_simplicity(sola, func) < + ode_sol_simplicity(solb, func)``. + + Currently, the following are the numbers returned, but if the heuristic is + ever improved, this may change. Only the ordering is guaranteed. + + +----------------------------------------------+-------------------+ + | Simplicity | Return | + +==============================================+===================+ + | ``sol`` solved for ``func`` | ``-2`` | + +----------------------------------------------+-------------------+ + | ``sol`` not solved for ``func`` but can be | ``-1`` | + +----------------------------------------------+-------------------+ + | ``sol`` is not solved nor solvable for | ``len(str(sol))`` | + | ``func`` | | + +----------------------------------------------+-------------------+ + | ``sol`` contains an | ``oo`` | + | :obj:`~sympy.integrals.integrals.Integral` | | + +----------------------------------------------+-------------------+ + + ``oo`` here means the SymPy infinity, which should compare greater than + any integer. + + If you already know :py:meth:`~sympy.solvers.solvers.solve` cannot solve + ``sol``, you can use ``trysolving=False`` to skip that step, which is the + only potentially slow step. For example, + :py:meth:`~sympy.solvers.ode.dsolve` with the ``simplify=False`` flag + should do this. + + If ``sol`` is a list of solutions, if the worst solution in the list + returns ``oo`` it returns that, otherwise it returns ``len(str(sol))``, + that is, the length of the string representation of the whole list. + + Examples + ======== + + This function is designed to be passed to ``min`` as the key argument, + such as ``min(listofsolutions, key=lambda i: ode_sol_simplicity(i, + f(x)))``. + + >>> from sympy import symbols, Function, Eq, tan, Integral + >>> from sympy.solvers.ode.ode import ode_sol_simplicity + >>> x, C1, C2 = symbols('x, C1, C2') + >>> f = Function('f') + + >>> ode_sol_simplicity(Eq(f(x), C1*x**2), f(x)) + -2 + >>> ode_sol_simplicity(Eq(x**2 + f(x), C1), f(x)) + -1 + >>> ode_sol_simplicity(Eq(f(x), C1*Integral(2*x, x)), f(x)) + oo + >>> eq1 = Eq(f(x)/tan(f(x)/(2*x)), C1) + >>> eq2 = Eq(f(x)/tan(f(x)/(2*x) + f(x)), C2) + >>> [ode_sol_simplicity(eq, f(x)) for eq in [eq1, eq2]] + [28, 35] + >>> min([eq1, eq2], key=lambda i: ode_sol_simplicity(i, f(x))) + Eq(f(x)/tan(f(x)/(2*x)), C1) + + """ + # TODO: if two solutions are solved for f(x), we still want to be + # able to get the simpler of the two + + # See the docstring for the coercion rules. We check easier (faster) + # things here first, to save time. + + if iterable(sol): + # See if there are Integrals + for i in sol: + if ode_sol_simplicity(i, func, trysolving=trysolving) == oo: + return oo + + return len(str(sol)) + + if sol.has(Integral): + return oo + + # Next, try to solve for func. This code will change slightly when CRootOf + # is implemented in solve(). Probably a CRootOf solution should fall + # somewhere between a normal solution and an unsolvable expression. + + # First, see if they are already solved + if sol.lhs == func and not sol.rhs.has(func) or \ + sol.rhs == func and not sol.lhs.has(func): + return -2 + # We are not so lucky, try solving manually + if trysolving: + try: + sols = solve(sol, func) + if not sols: + raise NotImplementedError + except NotImplementedError: + pass + else: + return -1 + + # Finally, a naive computation based on the length of the string version + # of the expression. This may favor combined fractions because they + # will not have duplicate denominators, and may slightly favor expressions + # with fewer additions and subtractions, as those are separated by spaces + # by the printer. + + # Additional ideas for simplicity heuristics are welcome, like maybe + # checking if a equation has a larger domain, or if constantsimp has + # introduced arbitrary constants numbered higher than the order of a + # given ODE that sol is a solution of. + return len(str(sol)) + + +def _extract_funcs(eqs): + funcs = [] + for eq in eqs: + derivs = [node for node in preorder_traversal(eq) if isinstance(node, Derivative)] + func = [] + for d in derivs: + func += list(d.atoms(AppliedUndef)) + for func_ in func: + funcs.append(func_) + funcs = list(uniq(funcs)) + + return funcs + + +def _get_constant_subexpressions(expr, Cs): + Cs = set(Cs) + Ces = [] + def _recursive_walk(expr): + expr_syms = expr.free_symbols + if expr_syms and expr_syms.issubset(Cs): + Ces.append(expr) + else: + if expr.func == exp: + expr = expr.expand(mul=True) + if expr.func in (Add, Mul): + d = sift(expr.args, lambda i : i.free_symbols.issubset(Cs)) + if len(d[True]) > 1: + x = expr.func(*d[True]) + if not x.is_number: + Ces.append(x) + elif isinstance(expr, Integral): + if expr.free_symbols.issubset(Cs) and \ + all(len(x) == 3 for x in expr.limits): + Ces.append(expr) + for i in expr.args: + _recursive_walk(i) + return + _recursive_walk(expr) + return Ces + +def __remove_linear_redundancies(expr, Cs): + cnts = {i: expr.count(i) for i in Cs} + Cs = [i for i in Cs if cnts[i] > 0] + + def _linear(expr): + if isinstance(expr, Add): + xs = [i for i in Cs if expr.count(i)==cnts[i] \ + and 0 == expr.diff(i, 2)] + d = {} + for x in xs: + y = expr.diff(x) + if y not in d: + d[y]=[] + d[y].append(x) + for y in d: + if len(d[y]) > 1: + d[y].sort(key=str) + for x in d[y][1:]: + expr = expr.subs(x, 0) + return expr + + def _recursive_walk(expr): + if len(expr.args) != 0: + expr = expr.func(*[_recursive_walk(i) for i in expr.args]) + expr = _linear(expr) + return expr + + if isinstance(expr, Equality): + lhs, rhs = [_recursive_walk(i) for i in expr.args] + f = lambda i: isinstance(i, Number) or i in Cs + if isinstance(lhs, Symbol) and lhs in Cs: + rhs, lhs = lhs, rhs + if lhs.func in (Add, Symbol) and rhs.func in (Add, Symbol): + dlhs = sift([lhs] if isinstance(lhs, AtomicExpr) else lhs.args, f) + drhs = sift([rhs] if isinstance(rhs, AtomicExpr) else rhs.args, f) + for i in [True, False]: + for hs in [dlhs, drhs]: + if i not in hs: + hs[i] = [0] + # this calculation can be simplified + lhs = Add(*dlhs[False]) - Add(*drhs[False]) + rhs = Add(*drhs[True]) - Add(*dlhs[True]) + elif lhs.func in (Mul, Symbol) and rhs.func in (Mul, Symbol): + dlhs = sift([lhs] if isinstance(lhs, AtomicExpr) else lhs.args, f) + if True in dlhs: + if False not in dlhs: + dlhs[False] = [1] + lhs = Mul(*dlhs[False]) + rhs = rhs/Mul(*dlhs[True]) + return Eq(lhs, rhs) + else: + return _recursive_walk(expr) + +@vectorize(0) +def constantsimp(expr, constants): + r""" + Simplifies an expression with arbitrary constants in it. + + This function is written specifically to work with + :py:meth:`~sympy.solvers.ode.dsolve`, and is not intended for general use. + + Simplification is done by "absorbing" the arbitrary constants into other + arbitrary constants, numbers, and symbols that they are not independent + of. + + The symbols must all have the same name with numbers after it, for + example, ``C1``, ``C2``, ``C3``. The ``symbolname`` here would be + '``C``', the ``startnumber`` would be 1, and the ``endnumber`` would be 3. + If the arbitrary constants are independent of the variable ``x``, then the + independent symbol would be ``x``. There is no need to specify the + dependent function, such as ``f(x)``, because it already has the + independent symbol, ``x``, in it. + + Because terms are "absorbed" into arbitrary constants and because + constants are renumbered after simplifying, the arbitrary constants in + expr are not necessarily equal to the ones of the same name in the + returned result. + + If two or more arbitrary constants are added, multiplied, or raised to the + power of each other, they are first absorbed together into a single + arbitrary constant. Then the new constant is combined into other terms if + necessary. + + Absorption of constants is done with limited assistance: + + 1. terms of :py:class:`~sympy.core.add.Add`\s are collected to try join + constants so `e^x (C_1 \cos(x) + C_2 \cos(x))` will simplify to `e^x + C_1 \cos(x)`; + + 2. powers with exponents that are :py:class:`~sympy.core.add.Add`\s are + expanded so `e^{C_1 + x}` will be simplified to `C_1 e^x`. + + Use :py:meth:`~sympy.solvers.ode.ode.constant_renumber` to renumber constants + after simplification or else arbitrary numbers on constants may appear, + e.g. `C_1 + C_3 x`. + + In rare cases, a single constant can be "simplified" into two constants. + Every differential equation solution should have as many arbitrary + constants as the order of the differential equation. The result here will + be technically correct, but it may, for example, have `C_1` and `C_2` in + an expression, when `C_1` is actually equal to `C_2`. Use your discretion + in such situations, and also take advantage of the ability to use hints in + :py:meth:`~sympy.solvers.ode.dsolve`. + + Examples + ======== + + >>> from sympy import symbols + >>> from sympy.solvers.ode.ode import constantsimp + >>> C1, C2, C3, x, y = symbols('C1, C2, C3, x, y') + >>> constantsimp(2*C1*x, {C1, C2, C3}) + C1*x + >>> constantsimp(C1 + 2 + x, {C1, C2, C3}) + C1 + x + >>> constantsimp(C1*C2 + 2 + C2 + C3*x, {C1, C2, C3}) + C1 + C3*x + + """ + # This function works recursively. The idea is that, for Mul, + # Add, Pow, and Function, if the class has a constant in it, then + # we can simplify it, which we do by recursing down and + # simplifying up. Otherwise, we can skip that part of the + # expression. + + Cs = constants + + orig_expr = expr + + constant_subexprs = _get_constant_subexpressions(expr, Cs) + for xe in constant_subexprs: + xes = list(xe.free_symbols) + if not xes: + continue + if all(expr.count(c) == xe.count(c) for c in xes): + xes.sort(key=str) + expr = expr.subs(xe, xes[0]) + + # try to perform common sub-expression elimination of constant terms + try: + commons, rexpr = cse(expr) + commons.reverse() + rexpr = rexpr[0] + for s in commons: + cs = list(s[1].atoms(Symbol)) + if len(cs) == 1 and cs[0] in Cs and \ + cs[0] not in rexpr.atoms(Symbol) and \ + not any(cs[0] in ex for ex in commons if ex != s): + rexpr = rexpr.subs(s[0], cs[0]) + else: + rexpr = rexpr.subs(*s) + expr = rexpr + except IndexError: + pass + expr = __remove_linear_redundancies(expr, Cs) + + def _conditional_term_factoring(expr): + new_expr = terms_gcd(expr, clear=False, deep=True, expand=False) + + # we do not want to factor exponentials, so handle this separately + if new_expr.is_Mul: + infac = False + asfac = False + for m in new_expr.args: + if isinstance(m, exp): + asfac = True + elif m.is_Add: + infac = any(isinstance(fi, exp) for t in m.args + for fi in Mul.make_args(t)) + if asfac and infac: + new_expr = expr + break + return new_expr + + expr = _conditional_term_factoring(expr) + + # call recursively if more simplification is possible + if orig_expr != expr: + return constantsimp(expr, Cs) + return expr + + +def constant_renumber(expr, variables=None, newconstants=None): + r""" + Renumber arbitrary constants in ``expr`` to use the symbol names as given + in ``newconstants``. In the process, this reorders expression terms in a + standard way. + + If ``newconstants`` is not provided then the new constant names will be + ``C1``, ``C2`` etc. Otherwise ``newconstants`` should be an iterable + giving the new symbols to use for the constants in order. + + The ``variables`` argument is a list of non-constant symbols. All other + free symbols found in ``expr`` are assumed to be constants and will be + renumbered. If ``variables`` is not given then any numbered symbol + beginning with ``C`` (e.g. ``C1``) is assumed to be a constant. + + Symbols are renumbered based on ``.sort_key()``, so they should be + numbered roughly in the order that they appear in the final, printed + expression. Note that this ordering is based in part on hashes, so it can + produce different results on different machines. + + The structure of this function is very similar to that of + :py:meth:`~sympy.solvers.ode.constantsimp`. + + Examples + ======== + + >>> from sympy import symbols + >>> from sympy.solvers.ode.ode import constant_renumber + >>> x, C1, C2, C3 = symbols('x,C1:4') + >>> expr = C3 + C2*x + C1*x**2 + >>> expr + C1*x**2 + C2*x + C3 + >>> constant_renumber(expr) + C1 + C2*x + C3*x**2 + + The ``variables`` argument specifies which are constants so that the + other symbols will not be renumbered: + + >>> constant_renumber(expr, [C1, x]) + C1*x**2 + C2 + C3*x + + The ``newconstants`` argument is used to specify what symbols to use when + replacing the constants: + + >>> constant_renumber(expr, [x], newconstants=symbols('E1:4')) + E1 + E2*x + E3*x**2 + + """ + + # System of expressions + if isinstance(expr, (set, list, tuple)): + return type(expr)(constant_renumber(Tuple(*expr), + variables=variables, newconstants=newconstants)) + + # Symbols in solution but not ODE are constants + if variables is not None: + variables = set(variables) + free_symbols = expr.free_symbols + constantsymbols = list(free_symbols - variables) + # Any Cn is a constant... + else: + variables = set() + isconstant = lambda s: s.startswith('C') and s[1:].isdigit() + constantsymbols = [sym for sym in expr.free_symbols if isconstant(sym.name)] + + # Find new constants checking that they aren't already in the ODE + if newconstants is None: + iter_constants = numbered_symbols(start=1, prefix='C', exclude=variables) + else: + iter_constants = (sym for sym in newconstants if sym not in variables) + + constants_found = [] + + # make a mapping to send all constantsymbols to S.One and use + # that to make sure that term ordering is not dependent on + # the indexed value of C + C_1 = [(ci, S.One) for ci in constantsymbols] + sort_key=lambda arg: default_sort_key(arg.subs(C_1)) + + def _constant_renumber(expr): + r""" + We need to have an internal recursive function + """ + + # For system of expressions + if isinstance(expr, Tuple): + renumbered = [_constant_renumber(e) for e in expr] + return Tuple(*renumbered) + + if isinstance(expr, Equality): + return Eq( + _constant_renumber(expr.lhs), + _constant_renumber(expr.rhs)) + + if type(expr) not in (Mul, Add, Pow) and not expr.is_Function and \ + not expr.has(*constantsymbols): + # Base case, as above. Hope there aren't constants inside + # of some other class, because they won't be renumbered. + return expr + elif expr.is_Piecewise: + return expr + elif expr in constantsymbols: + if expr not in constants_found: + constants_found.append(expr) + return expr + elif expr.is_Function or expr.is_Pow: + return expr.func( + *[_constant_renumber(x) for x in expr.args]) + else: + sortedargs = list(expr.args) + sortedargs.sort(key=sort_key) + return expr.func(*[_constant_renumber(x) for x in sortedargs]) + expr = _constant_renumber(expr) + + # Don't renumber symbols present in the ODE. + constants_found = [c for c in constants_found if c not in variables] + + # Renumbering happens here + subs_dict = dict(zip(constants_found, iter_constants)) + expr = expr.subs(subs_dict, simultaneous=True) + + return expr + + +def _handle_Integral(expr, func, hint): + r""" + Converts a solution with Integrals in it into an actual solution. + + For most hints, this simply runs ``expr.doit()``. + + """ + if hint == "nth_linear_constant_coeff_homogeneous": + sol = expr + elif not hint.endswith("_Integral"): + sol = expr.doit() + else: + sol = expr + return sol + + +# XXX: Should this function maybe go somewhere else? + + +def homogeneous_order(eq, *symbols): + r""" + Returns the order `n` if `g` is homogeneous and ``None`` if it is not + homogeneous. + + Determines if a function is homogeneous and if so of what order. A + function `f(x, y, \cdots)` is homogeneous of order `n` if `f(t x, t y, + \cdots) = t^n f(x, y, \cdots)`. + + If the function is of two variables, `F(x, y)`, then `f` being homogeneous + of any order is equivalent to being able to rewrite `F(x, y)` as `G(x/y)` + or `H(y/x)`. This fact is used to solve 1st order ordinary differential + equations whose coefficients are homogeneous of the same order (see the + docstrings of + :obj:`~sympy.solvers.ode.single.HomogeneousCoeffSubsDepDivIndep` and + :obj:`~sympy.solvers.ode.single.HomogeneousCoeffSubsIndepDivDep`). + + Symbols can be functions, but every argument of the function must be a + symbol, and the arguments of the function that appear in the expression + must match those given in the list of symbols. If a declared function + appears with different arguments than given in the list of symbols, + ``None`` is returned. + + Examples + ======== + + >>> from sympy import Function, homogeneous_order, sqrt + >>> from sympy.abc import x, y + >>> f = Function('f') + >>> homogeneous_order(f(x), f(x)) is None + True + >>> homogeneous_order(f(x,y), f(y, x), x, y) is None + True + >>> homogeneous_order(f(x), f(x), x) + 1 + >>> homogeneous_order(x**2*f(x)/sqrt(x**2+f(x)**2), x, f(x)) + 2 + >>> homogeneous_order(x**2+f(x), x, f(x)) is None + True + + """ + + if not symbols: + raise ValueError("homogeneous_order: no symbols were given.") + symset = set(symbols) + eq = sympify(eq) + + # The following are not supported + if eq.has(Order, Derivative): + return None + + # These are all constants + if (eq.is_Number or + eq.is_NumberSymbol or + eq.is_number + ): + return S.Zero + + # Replace all functions with dummy variables + dum = numbered_symbols(prefix='d', cls=Dummy) + newsyms = set() + for i in [j for j in symset if getattr(j, 'is_Function')]: + iargs = set(i.args) + if iargs.difference(symset): + return None + else: + dummyvar = next(dum) + eq = eq.subs(i, dummyvar) + symset.remove(i) + newsyms.add(dummyvar) + symset.update(newsyms) + + if not eq.free_symbols & symset: + return None + + # assuming order of a nested function can only be equal to zero + if isinstance(eq, Function): + return None if homogeneous_order( + eq.args[0], *tuple(symset)) != 0 else S.Zero + + # make the replacement of x with x*t and see if t can be factored out + t = Dummy('t', positive=True) # It is sufficient that t > 0 + eqs = separatevars(eq.subs([(i, t*i) for i in symset]), [t], dict=True)[t] + if eqs is S.One: + return S.Zero # there was no term with only t + i, d = eqs.as_independent(t, as_Add=False) + b, e = d.as_base_exp() + if b == t: + return e + + +def ode_2nd_power_series_ordinary(eq, func, order, match): + r""" + Gives a power series solution to a second order homogeneous differential + equation with polynomial coefficients at an ordinary point. A homogeneous + differential equation is of the form + + .. math :: P(x)\frac{d^2y}{dx^2} + Q(x)\frac{dy}{dx} + R(x) y(x) = 0 + + For simplicity it is assumed that `P(x)`, `Q(x)` and `R(x)` are polynomials, + it is sufficient that `\frac{Q(x)}{P(x)}` and `\frac{R(x)}{P(x)}` exists at + `x_{0}`. A recurrence relation is obtained by substituting `y` as `\sum_{n=0}^\infty a_{n}x^{n}`, + in the differential equation, and equating the nth term. Using this relation + various terms can be generated. + + + Examples + ======== + + >>> from sympy import dsolve, Function, pprint + >>> from sympy.abc import x + >>> f = Function("f") + >>> eq = f(x).diff(x, 2) + f(x) + >>> pprint(dsolve(eq, hint='2nd_power_series_ordinary')) + / 4 2 \ / 2\ + |x x | | x | / 6\ + f(x) = C2*|-- - -- + 1| + C1*x*|1 - --| + O\x / + \24 2 / \ 6 / + + + References + ========== + - https://tutorial.math.lamar.edu/Classes/DE/SeriesSolutions.aspx + - George E. Simmons, "Differential Equations with Applications and + Historical Notes", p.p 176 - 184 + + """ + x = func.args[0] + f = func.func + C0, C1 = get_numbered_constants(eq, num=2) + n = Dummy("n", integer=True) + s = Wild("s") + k = Wild("k", exclude=[x]) + x0 = match['x0'] + terms = match['terms'] + p = match[match['a3']] + q = match[match['b3']] + r = match[match['c3']] + seriesdict = {} + recurr = Function("r") + + # Generating the recurrence relation which works this way: + # for the second order term the summation begins at n = 2. The coefficients + # p is multiplied with an*(n - 1)*(n - 2)*x**n-2 and a substitution is made such that + # the exponent of x becomes n. + # For example, if p is x, then the second degree recurrence term is + # an*(n - 1)*(n - 2)*x**n-1, substituting (n - 1) as n, it transforms to + # an+1*n*(n - 1)*x**n. + # A similar process is done with the first order and zeroth order term. + + coefflist = [(recurr(n), r), (n*recurr(n), q), (n*(n - 1)*recurr(n), p)] + for index, coeff in enumerate(coefflist): + if coeff[1]: + f2 = powsimp(expand((coeff[1]*(x - x0)**(n - index)).subs(x, x + x0))) + if f2.is_Add: + addargs = f2.args + else: + addargs = [f2] + for arg in addargs: + powm = arg.match(s*x**k) + term = coeff[0]*powm[s] + if not powm[k].is_Symbol: + term = term.subs(n, n - powm[k].as_independent(n)[0]) + startind = powm[k].subs(n, index) + # Seeing if the startterm can be reduced further. + # If it vanishes for n lesser than startind, it is + # equal to summation from n. + if startind: + for i in reversed(range(startind)): + if not term.subs(n, i): + seriesdict[term] = i + else: + seriesdict[term] = i + 1 + break + else: + seriesdict[term] = S.Zero + + # Stripping of terms so that the sum starts with the same number. + teq = S.Zero + suminit = seriesdict.values() + rkeys = seriesdict.keys() + req = Add(*rkeys) + if any(suminit): + maxval = max(suminit) + for term in seriesdict: + val = seriesdict[term] + if val != maxval: + for i in range(val, maxval): + teq += term.subs(n, val) + + finaldict = {} + if teq: + fargs = teq.atoms(AppliedUndef) + if len(fargs) == 1: + finaldict[fargs.pop()] = 0 + else: + maxf = max(fargs, key = lambda x: x.args[0]) + sol = solve(teq, maxf) + if isinstance(sol, list): + sol = sol[0] + finaldict[maxf] = sol + + # Finding the recurrence relation in terms of the largest term. + fargs = req.atoms(AppliedUndef) + maxf = max(fargs, key = lambda x: x.args[0]) + minf = min(fargs, key = lambda x: x.args[0]) + if minf.args[0].is_Symbol: + startiter = 0 + else: + startiter = -minf.args[0].as_independent(n)[0] + lhs = maxf + rhs = solve(req, maxf) + if isinstance(rhs, list): + rhs = rhs[0] + + # Checking how many values are already present + tcounter = len([t for t in finaldict.values() if t]) + + for _ in range(tcounter, terms - 3): # Assuming c0 and c1 to be arbitrary + check = rhs.subs(n, startiter) + nlhs = lhs.subs(n, startiter) + nrhs = check.subs(finaldict) + finaldict[nlhs] = nrhs + startiter += 1 + + # Post processing + series = C0 + C1*(x - x0) + for term in finaldict: + if finaldict[term]: + fact = term.args[0] + series += (finaldict[term].subs([(recurr(0), C0), (recurr(1), C1)])*( + x - x0)**fact) + series = collect(expand_mul(series), [C0, C1]) + Order(x**terms) + return Eq(f(x), series) + + +def ode_2nd_power_series_regular(eq, func, order, match): + r""" + Gives a power series solution to a second order homogeneous differential + equation with polynomial coefficients at a regular point. A second order + homogeneous differential equation is of the form + + .. math :: P(x)\frac{d^2y}{dx^2} + Q(x)\frac{dy}{dx} + R(x) y(x) = 0 + + A point is said to regular singular at `x0` if `x - x0\frac{Q(x)}{P(x)}` + and `(x - x0)^{2}\frac{R(x)}{P(x)}` are analytic at `x0`. For simplicity + `P(x)`, `Q(x)` and `R(x)` are assumed to be polynomials. The algorithm for + finding the power series solutions is: + + 1. Try expressing `(x - x0)P(x)` and `((x - x0)^{2})Q(x)` as power series + solutions about x0. Find `p0` and `q0` which are the constants of the + power series expansions. + 2. Solve the indicial equation `f(m) = m(m - 1) + m*p0 + q0`, to obtain the + roots `m1` and `m2` of the indicial equation. + 3. If `m1 - m2` is a non integer there exists two series solutions. If + `m1 = m2`, there exists only one solution. If `m1 - m2` is an integer, + then the existence of one solution is confirmed. The other solution may + or may not exist. + + The power series solution is of the form `x^{m}\sum_{n=0}^\infty a_{n}x^{n}`. The + coefficients are determined by the following recurrence relation. + `a_{n} = -\frac{\sum_{k=0}^{n-1} q_{n-k} + (m + k)p_{n-k}}{f(m + n)}`. For the case + in which `m1 - m2` is an integer, it can be seen from the recurrence relation + that for the lower root `m`, when `n` equals the difference of both the + roots, the denominator becomes zero. So if the numerator is not equal to zero, + a second series solution exists. + + + Examples + ======== + + >>> from sympy import dsolve, Function, pprint + >>> from sympy.abc import x + >>> f = Function("f") + >>> eq = x*(f(x).diff(x, 2)) + 2*(f(x).diff(x)) + x*f(x) + >>> pprint(dsolve(eq, hint='2nd_power_series_regular')) + / 6 4 2 \ + | x x x | + / 4 2 \ C1*|- --- + -- - -- + 1| + |x x | \ 720 24 2 / / 6\ + f(x) = C2*|--- - -- + 1| + ------------------------ + O\x / + \120 6 / x + + + References + ========== + - George E. Simmons, "Differential Equations with Applications and + Historical Notes", p.p 176 - 184 + + """ + x = func.args[0] + f = func.func + C0, C1 = get_numbered_constants(eq, num=2) + m = Dummy("m") # for solving the indicial equation + x0 = match['x0'] + terms = match['terms'] + p = match['p'] + q = match['q'] + + # Generating the indicial equation + indicial = [] + for term in [p, q]: + if not term.has(x): + indicial.append(term) + else: + term = series(term, x=x, n=1, x0=x0) + if isinstance(term, Order): + indicial.append(S.Zero) + else: + for arg in term.args: + if not arg.has(x): + indicial.append(arg) + break + + p0, q0 = indicial + sollist = solve(m*(m - 1) + m*p0 + q0, m) + if sollist and isinstance(sollist, list) and all( + sol.is_real for sol in sollist): + serdict1 = {} + serdict2 = {} + if len(sollist) == 1: + # Only one series solution exists in this case. + m1 = m2 = sollist.pop() + if terms-m1-1 <= 0: + return Eq(f(x), Order(terms)) + serdict1 = _frobenius(terms-m1-1, m1, p0, q0, p, q, x0, x, C0) + + else: + m1 = sollist[0] + m2 = sollist[1] + if m1 < m2: + m1, m2 = m2, m1 + # Irrespective of whether m1 - m2 is an integer or not, one + # Frobenius series solution exists. + serdict1 = _frobenius(terms-m1-1, m1, p0, q0, p, q, x0, x, C0) + if not (m1 - m2).is_integer: + # Second frobenius series solution exists. + serdict2 = _frobenius(terms-m2-1, m2, p0, q0, p, q, x0, x, C1) + else: + # Check if second frobenius series solution exists. + serdict2 = _frobenius(terms-m2-1, m2, p0, q0, p, q, x0, x, C1, check=m1) + + if serdict1: + finalseries1 = C0 + for key in serdict1: + power = int(key.name[1:]) + finalseries1 += serdict1[key]*(x - x0)**power + finalseries1 = (x - x0)**m1*finalseries1 + finalseries2 = S.Zero + if serdict2: + for key in serdict2: + power = int(key.name[1:]) + finalseries2 += serdict2[key]*(x - x0)**power + finalseries2 += C1 + finalseries2 = (x - x0)**m2*finalseries2 + return Eq(f(x), collect(finalseries1 + finalseries2, + [C0, C1]) + Order(x**terms)) + + +def _frobenius(n, m, p0, q0, p, q, x0, x, c, check=None): + r""" + Returns a dict with keys as coefficients and values as their values in terms of C0 + """ + n = int(n) + # In cases where m1 - m2 is not an integer + m2 = check + + d = Dummy("d") + numsyms = numbered_symbols("C", start=0) + numsyms = [next(numsyms) for i in range(n + 1)] + serlist = [] + for ser in [p, q]: + # Order term not present + if ser.is_polynomial(x) and Poly(ser, x).degree() <= n: + if x0: + ser = ser.subs(x, x + x0) + dict_ = Poly(ser, x).as_dict() + # Order term present + else: + tseries = series(ser, x=x0, n=n+1) + # Removing order + dict_ = Poly(list(ordered(tseries.args))[: -1], x).as_dict() + # Fill in with zeros, if coefficients are zero. + for i in range(n + 1): + if (i,) not in dict_: + dict_[(i,)] = S.Zero + serlist.append(dict_) + + pseries = serlist[0] + qseries = serlist[1] + indicial = d*(d - 1) + d*p0 + q0 + frobdict = {} + for i in range(1, n + 1): + num = c*(m*pseries[(i,)] + qseries[(i,)]) + for j in range(1, i): + sym = Symbol("C" + str(j)) + num += frobdict[sym]*((m + j)*pseries[(i - j,)] + qseries[(i - j,)]) + + # Checking for cases when m1 - m2 is an integer. If num equals zero + # then a second Frobenius series solution cannot be found. If num is not zero + # then set constant as zero and proceed. + if m2 is not None and i == m2 - m: + if num: + return False + else: + frobdict[numsyms[i]] = S.Zero + else: + frobdict[numsyms[i]] = -num/(indicial.subs(d, m+i)) + + return frobdict + +def _remove_redundant_solutions(eq, solns, order, var): + r""" + Remove redundant solutions from the set of solutions. + + This function is needed because otherwise dsolve can return + redundant solutions. As an example consider: + + eq = Eq((f(x).diff(x, 2))*f(x).diff(x), 0) + + There are two ways to find solutions to eq. The first is to solve f(x).diff(x, 2) = 0 + leading to solution f(x)=C1 + C2*x. The second is to solve the equation f(x).diff(x) = 0 + leading to the solution f(x) = C1. In this particular case we then see + that the second solution is a special case of the first and we do not + want to return it. + + This does not always happen. If we have + + eq = Eq((f(x)**2-4)*(f(x).diff(x)-4), 0) + + then we get the algebraic solution f(x) = [-2, 2] and the integral solution + f(x) = x + C1 and in this case the two solutions are not equivalent wrt + initial conditions so both should be returned. + """ + def is_special_case_of(soln1, soln2): + return _is_special_case_of(soln1, soln2, eq, order, var) + + unique_solns = [] + for soln1 in solns: + for soln2 in unique_solns[:]: + if is_special_case_of(soln1, soln2): + break + elif is_special_case_of(soln2, soln1): + unique_solns.remove(soln2) + else: + unique_solns.append(soln1) + + return unique_solns + +def _is_special_case_of(soln1, soln2, eq, order, var): + r""" + True if soln1 is found to be a special case of soln2 wrt some value of the + constants that appear in soln2. False otherwise. + """ + # The solutions returned by dsolve may be given explicitly or implicitly. + # We will equate the sol1=(soln1.rhs - soln1.lhs), sol2=(soln2.rhs - soln2.lhs) + # of the two solutions. + # + # Since this is supposed to hold for all x it also holds for derivatives. + # For an order n ode we should be able to differentiate + # each solution n times to get n+1 equations. + # + # We then try to solve those n+1 equations for the integrations constants + # in sol2. If we can find a solution that does not depend on x then it + # means that some value of the constants in sol1 is a special case of + # sol2 corresponding to a particular choice of the integration constants. + + # In case the solution is in implicit form we subtract the sides + soln1 = soln1.rhs - soln1.lhs + soln2 = soln2.rhs - soln2.lhs + + # Work for the series solution + if soln1.has(Order) and soln2.has(Order): + if soln1.getO() == soln2.getO(): + soln1 = soln1.removeO() + soln2 = soln2.removeO() + else: + return False + elif soln1.has(Order) or soln2.has(Order): + return False + + constants1 = soln1.free_symbols.difference(eq.free_symbols) + constants2 = soln2.free_symbols.difference(eq.free_symbols) + + constants1_new = get_numbered_constants(Tuple(soln1, soln2), len(constants1)) + if len(constants1) == 1: + constants1_new = {constants1_new} + for c_old, c_new in zip(constants1, constants1_new): + soln1 = soln1.subs(c_old, c_new) + + # n equations for sol1 = sol2, sol1'=sol2', ... + lhs = soln1 + rhs = soln2 + eqns = [Eq(lhs, rhs)] + for n in range(1, order): + lhs = lhs.diff(var) + rhs = rhs.diff(var) + eq = Eq(lhs, rhs) + eqns.append(eq) + + # BooleanTrue/False awkwardly show up for trivial equations + if any(isinstance(eq, BooleanFalse) for eq in eqns): + return False + eqns = [eq for eq in eqns if not isinstance(eq, BooleanTrue)] + + try: + constant_solns = solve(eqns, constants2) + except NotImplementedError: + return False + + # Sometimes returns a dict and sometimes a list of dicts + if isinstance(constant_solns, dict): + constant_solns = [constant_solns] + + # after solving the issue 17418, maybe we don't need the following checksol code. + for constant_soln in constant_solns: + for eq in eqns: + eq=eq.rhs-eq.lhs + if checksol(eq, constant_soln) is not True: + return False + + # If any solution gives all constants as expressions that don't depend on + # x then there exists constants for soln2 that give soln1 + for constant_soln in constant_solns: + if not any(c.has(var) for c in constant_soln.values()): + return True + + return False + + +def ode_1st_power_series(eq, func, order, match): + r""" + The power series solution is a method which gives the Taylor series expansion + to the solution of a differential equation. + + For a first order differential equation `\frac{dy}{dx} = h(x, y)`, a power + series solution exists at a point `x = x_{0}` if `h(x, y)` is analytic at `x_{0}`. + The solution is given by + + .. math:: y(x) = y(x_{0}) + \sum_{n = 1}^{\infty} \frac{F_{n}(x_{0},b)(x - x_{0})^n}{n!}, + + where `y(x_{0}) = b` is the value of y at the initial value of `x_{0}`. + To compute the values of the `F_{n}(x_{0},b)` the following algorithm is + followed, until the required number of terms are generated. + + 1. `F_1 = h(x_{0}, b)` + 2. `F_{n+1} = \frac{\partial F_{n}}{\partial x} + \frac{\partial F_{n}}{\partial y}F_{1}` + + Examples + ======== + + >>> from sympy import Function, pprint, exp, dsolve + >>> from sympy.abc import x + >>> f = Function('f') + >>> eq = exp(x)*(f(x).diff(x)) - f(x) + >>> pprint(dsolve(eq, hint='1st_power_series')) + 3 4 5 + C1*x C1*x C1*x / 6\ + f(x) = C1 + C1*x - ----- + ----- + ----- + O\x / + 6 24 60 + + + References + ========== + + - Travis W. Walker, Analytic power series technique for solving first-order + differential equations, p.p 17, 18 + + """ + x = func.args[0] + y = match['y'] + f = func.func + h = -match[match['d']]/match[match['e']] + point = match['f0'] + value = match['f0val'] + terms = match['terms'] + + # First term + F = h + if not h: + return Eq(f(x), value) + + # Initialization + series = value + if terms > 1: + hc = h.subs({x: point, y: value}) + if hc.has(oo) or hc.has(nan) or hc.has(zoo): + # Derivative does not exist, not analytic + return Eq(f(x), oo) + elif hc: + series += hc*(x - point) + + for factcount in range(2, terms): + Fnew = F.diff(x) + F.diff(y)*h + Fnewc = Fnew.subs({x: point, y: value}) + # Same logic as above + if Fnewc.has(oo) or Fnewc.has(nan) or Fnewc.has(-oo) or Fnewc.has(zoo): + return Eq(f(x), oo) + series += Fnewc*((x - point)**factcount)/factorial(factcount) + F = Fnew + series += Order(x**terms) + return Eq(f(x), series) + + +def checkinfsol(eq, infinitesimals, func=None, order=None): + r""" + This function is used to check if the given infinitesimals are the + actual infinitesimals of the given first order differential equation. + This method is specific to the Lie Group Solver of ODEs. + + As of now, it simply checks, by substituting the infinitesimals in the + partial differential equation. + + + .. math:: \frac{\partial \eta}{\partial x} + \left(\frac{\partial \eta}{\partial y} + - \frac{\partial \xi}{\partial x}\right)*h + - \frac{\partial \xi}{\partial y}*h^{2} + - \xi\frac{\partial h}{\partial x} - \eta\frac{\partial h}{\partial y} = 0 + + + where `\eta`, and `\xi` are the infinitesimals and `h(x,y) = \frac{dy}{dx}` + + The infinitesimals should be given in the form of a list of dicts + ``[{xi(x, y): inf, eta(x, y): inf}]``, corresponding to the + output of the function infinitesimals. It returns a list + of values of the form ``[(True/False, sol)]`` where ``sol`` is the value + obtained after substituting the infinitesimals in the PDE. If it + is ``True``, then ``sol`` would be 0. + + """ + if isinstance(eq, Equality): + eq = eq.lhs - eq.rhs + if not func: + eq, func = _preprocess(eq) + variables = func.args + if len(variables) != 1: + raise ValueError("ODE's have only one independent variable") + else: + x = variables[0] + if not order: + order = ode_order(eq, func) + if order != 1: + raise NotImplementedError("Lie groups solver has been implemented " + "only for first order differential equations") + else: + df = func.diff(x) + a = Wild('a', exclude = [df]) + b = Wild('b', exclude = [df]) + match = collect(expand(eq), df).match(a*df + b) + + if match: + h = -simplify(match[b]/match[a]) + else: + try: + sol = solve(eq, df) + except NotImplementedError: + raise NotImplementedError("Infinitesimals for the " + "first order ODE could not be found") + else: + h = sol[0] # Find infinitesimals for one solution + + y = Dummy('y') + h = h.subs(func, y) + xi = Function('xi')(x, y) + eta = Function('eta')(x, y) + dxi = Function('xi')(x, func) + deta = Function('eta')(x, func) + pde = (eta.diff(x) + (eta.diff(y) - xi.diff(x))*h - + (xi.diff(y))*h**2 - xi*(h.diff(x)) - eta*(h.diff(y))) + soltup = [] + for sol in infinitesimals: + tsol = {xi: S(sol[dxi]).subs(func, y), + eta: S(sol[deta]).subs(func, y)} + sol = simplify(pde.subs(tsol).doit()) + if sol: + soltup.append((False, sol.subs(y, func))) + else: + soltup.append((True, 0)) + return soltup + + +def sysode_linear_2eq_order1(match_): + x = match_['func'][0].func + y = match_['func'][1].func + func = match_['func'] + fc = match_['func_coeff'] + eq = match_['eq'] + r = {} + t = list(list(eq[0].atoms(Derivative))[0].atoms(Symbol))[0] + for i in range(2): + eq[i] = Add(*[terms/fc[i,func[i],1] for terms in Add.make_args(eq[i])]) + + # for equations Eq(a1*diff(x(t),t), a*x(t) + b*y(t) + k1) + # and Eq(a2*diff(x(t),t), c*x(t) + d*y(t) + k2) + r['a'] = -fc[0,x(t),0]/fc[0,x(t),1] + r['c'] = -fc[1,x(t),0]/fc[1,y(t),1] + r['b'] = -fc[0,y(t),0]/fc[0,x(t),1] + r['d'] = -fc[1,y(t),0]/fc[1,y(t),1] + forcing = [S.Zero,S.Zero] + for i in range(2): + for j in Add.make_args(eq[i]): + if not j.has(x(t), y(t)): + forcing[i] += j + if not (forcing[0].has(t) or forcing[1].has(t)): + r['k1'] = forcing[0] + r['k2'] = forcing[1] + else: + raise NotImplementedError("Only homogeneous problems are supported" + + " (and constant inhomogeneity)") + + if match_['type_of_equation'] == 'type6': + sol = _linear_2eq_order1_type6(x, y, t, r, eq) + if match_['type_of_equation'] == 'type7': + sol = _linear_2eq_order1_type7(x, y, t, r, eq) + return sol + +def _linear_2eq_order1_type6(x, y, t, r, eq): + r""" + The equations of this type of ode are . + + .. math:: x' = f(t) x + g(t) y + + .. math:: y' = a [f(t) + a h(t)] x + a [g(t) - h(t)] y + + This is solved by first multiplying the first equation by `-a` and adding + it to the second equation to obtain + + .. math:: y' - a x' = -a h(t) (y - a x) + + Setting `U = y - ax` and integrating the equation we arrive at + + .. math:: y - ax = C_1 e^{-a \int h(t) \,dt} + + and on substituting the value of y in first equation give rise to first order ODEs. After solving for + `x`, we can obtain `y` by substituting the value of `x` in second equation. + + """ + C1, C2, C3, C4 = get_numbered_constants(eq, num=4) + p = 0 + q = 0 + p1 = cancel(r['c']/cancel(r['c']/r['d']).as_numer_denom()[0]) + p2 = cancel(r['a']/cancel(r['a']/r['b']).as_numer_denom()[0]) + for n, i in enumerate([p1, p2]): + for j in Mul.make_args(collect_const(i)): + if not j.has(t): + q = j + if q!=0 and n==0: + if ((r['c']/j - r['a'])/(r['b'] - r['d']/j)) == j: + p = 1 + s = j + break + if q!=0 and n==1: + if ((r['a']/j - r['c'])/(r['d'] - r['b']/j)) == j: + p = 2 + s = j + break + + if p == 1: + equ = diff(x(t),t) - r['a']*x(t) - r['b']*(s*x(t) + C1*exp(-s*Integral(r['b'] - r['d']/s, t))) + hint1 = classify_ode(equ)[1] + sol1 = dsolve(equ, hint=hint1+'_Integral').rhs + sol2 = s*sol1 + C1*exp(-s*Integral(r['b'] - r['d']/s, t)) + elif p ==2: + equ = diff(y(t),t) - r['c']*y(t) - r['d']*s*y(t) + C1*exp(-s*Integral(r['d'] - r['b']/s, t)) + hint1 = classify_ode(equ)[1] + sol2 = dsolve(equ, hint=hint1+'_Integral').rhs + sol1 = s*sol2 + C1*exp(-s*Integral(r['d'] - r['b']/s, t)) + return [Eq(x(t), sol1), Eq(y(t), sol2)] + +def _linear_2eq_order1_type7(x, y, t, r, eq): + r""" + The equations of this type of ode are . + + .. math:: x' = f(t) x + g(t) y + + .. math:: y' = h(t) x + p(t) y + + Differentiating the first equation and substituting the value of `y` + from second equation will give a second-order linear equation + + .. math:: g x'' - (fg + gp + g') x' + (fgp - g^{2} h + f g' - f' g) x = 0 + + This above equation can be easily integrated if following conditions are satisfied. + + 1. `fgp - g^{2} h + f g' - f' g = 0` + + 2. `fgp - g^{2} h + f g' - f' g = ag, fg + gp + g' = bg` + + If first condition is satisfied then it is solved by current dsolve solver and in second case it becomes + a constant coefficient differential equation which is also solved by current solver. + + Otherwise if the above condition fails then, + a particular solution is assumed as `x = x_0(t)` and `y = y_0(t)` + Then the general solution is expressed as + + .. math:: x = C_1 x_0(t) + C_2 x_0(t) \int \frac{g(t) F(t) P(t)}{x_0^{2}(t)} \,dt + + .. math:: y = C_1 y_0(t) + C_2 [\frac{F(t) P(t)}{x_0(t)} + y_0(t) \int \frac{g(t) F(t) P(t)}{x_0^{2}(t)} \,dt] + + where C1 and C2 are arbitrary constants and + + .. math:: F(t) = e^{\int f(t) \,dt}, P(t) = e^{\int p(t) \,dt} + + """ + C1, C2, C3, C4 = get_numbered_constants(eq, num=4) + e1 = r['a']*r['b']*r['c'] - r['b']**2*r['c'] + r['a']*diff(r['b'],t) - diff(r['a'],t)*r['b'] + e2 = r['a']*r['c']*r['d'] - r['b']*r['c']**2 + diff(r['c'],t)*r['d'] - r['c']*diff(r['d'],t) + m1 = r['a']*r['b'] + r['b']*r['d'] + diff(r['b'],t) + m2 = r['a']*r['c'] + r['c']*r['d'] + diff(r['c'],t) + if e1 == 0: + sol1 = dsolve(r['b']*diff(x(t),t,t) - m1*diff(x(t),t)).rhs + sol2 = dsolve(diff(y(t),t) - r['c']*sol1 - r['d']*y(t)).rhs + elif e2 == 0: + sol2 = dsolve(r['c']*diff(y(t),t,t) - m2*diff(y(t),t)).rhs + sol1 = dsolve(diff(x(t),t) - r['a']*x(t) - r['b']*sol2).rhs + elif not (e1/r['b']).has(t) and not (m1/r['b']).has(t): + sol1 = dsolve(diff(x(t),t,t) - (m1/r['b'])*diff(x(t),t) - (e1/r['b'])*x(t)).rhs + sol2 = dsolve(diff(y(t),t) - r['c']*sol1 - r['d']*y(t)).rhs + elif not (e2/r['c']).has(t) and not (m2/r['c']).has(t): + sol2 = dsolve(diff(y(t),t,t) - (m2/r['c'])*diff(y(t),t) - (e2/r['c'])*y(t)).rhs + sol1 = dsolve(diff(x(t),t) - r['a']*x(t) - r['b']*sol2).rhs + else: + x0 = Function('x0')(t) # x0 and y0 being particular solutions + y0 = Function('y0')(t) + F = exp(Integral(r['a'],t)) + P = exp(Integral(r['d'],t)) + sol1 = C1*x0 + C2*x0*Integral(r['b']*F*P/x0**2, t) + sol2 = C1*y0 + C2*(F*P/x0 + y0*Integral(r['b']*F*P/x0**2, t)) + return [Eq(x(t), sol1), Eq(y(t), sol2)] + + +def sysode_nonlinear_2eq_order1(match_): + func = match_['func'] + eq = match_['eq'] + fc = match_['func_coeff'] + t = list(list(eq[0].atoms(Derivative))[0].atoms(Symbol))[0] + if match_['type_of_equation'] == 'type5': + sol = _nonlinear_2eq_order1_type5(func, t, eq) + return sol + x = func[0].func + y = func[1].func + for i in range(2): + eqs = 0 + for terms in Add.make_args(eq[i]): + eqs += terms/fc[i,func[i],1] + eq[i] = eqs + if match_['type_of_equation'] == 'type1': + sol = _nonlinear_2eq_order1_type1(x, y, t, eq) + elif match_['type_of_equation'] == 'type2': + sol = _nonlinear_2eq_order1_type2(x, y, t, eq) + elif match_['type_of_equation'] == 'type3': + sol = _nonlinear_2eq_order1_type3(x, y, t, eq) + elif match_['type_of_equation'] == 'type4': + sol = _nonlinear_2eq_order1_type4(x, y, t, eq) + return sol + + +def _nonlinear_2eq_order1_type1(x, y, t, eq): + r""" + Equations: + + .. math:: x' = x^n F(x,y) + + .. math:: y' = g(y) F(x,y) + + Solution: + + .. math:: x = \varphi(y), \int \frac{1}{g(y) F(\varphi(y),y)} \,dy = t + C_2 + + where + + if `n \neq 1` + + .. math:: \varphi = [C_1 + (1-n) \int \frac{1}{g(y)} \,dy]^{\frac{1}{1-n}} + + if `n = 1` + + .. math:: \varphi = C_1 e^{\int \frac{1}{g(y)} \,dy} + + where `C_1` and `C_2` are arbitrary constants. + + """ + C1, C2 = get_numbered_constants(eq, num=2) + n = Wild('n', exclude=[x(t),y(t)]) + f = Wild('f') + u, v = symbols('u, v') + r = eq[0].match(diff(x(t),t) - x(t)**n*f) + g = ((diff(y(t),t) - eq[1])/r[f]).subs(y(t),v) + F = r[f].subs(x(t),u).subs(y(t),v) + n = r[n] + if n!=1: + phi = (C1 + (1-n)*Integral(1/g, v))**(1/(1-n)) + else: + phi = C1*exp(Integral(1/g, v)) + phi = phi.doit() + sol2 = solve(Integral(1/(g*F.subs(u,phi)), v).doit() - t - C2, v) + sol = [] + for sols in sol2: + sol.append(Eq(x(t),phi.subs(v, sols))) + sol.append(Eq(y(t), sols)) + return sol + +def _nonlinear_2eq_order1_type2(x, y, t, eq): + r""" + Equations: + + .. math:: x' = e^{\lambda x} F(x,y) + + .. math:: y' = g(y) F(x,y) + + Solution: + + .. math:: x = \varphi(y), \int \frac{1}{g(y) F(\varphi(y),y)} \,dy = t + C_2 + + where + + if `\lambda \neq 0` + + .. math:: \varphi = -\frac{1}{\lambda} log(C_1 - \lambda \int \frac{1}{g(y)} \,dy) + + if `\lambda = 0` + + .. math:: \varphi = C_1 + \int \frac{1}{g(y)} \,dy + + where `C_1` and `C_2` are arbitrary constants. + + """ + C1, C2 = get_numbered_constants(eq, num=2) + n = Wild('n', exclude=[x(t),y(t)]) + f = Wild('f') + u, v = symbols('u, v') + r = eq[0].match(diff(x(t),t) - exp(n*x(t))*f) + g = ((diff(y(t),t) - eq[1])/r[f]).subs(y(t),v) + F = r[f].subs(x(t),u).subs(y(t),v) + n = r[n] + if n: + phi = -1/n*log(C1 - n*Integral(1/g, v)) + else: + phi = C1 + Integral(1/g, v) + phi = phi.doit() + sol2 = solve(Integral(1/(g*F.subs(u,phi)), v).doit() - t - C2, v) + sol = [] + for sols in sol2: + sol.append(Eq(x(t),phi.subs(v, sols))) + sol.append(Eq(y(t), sols)) + return sol + +def _nonlinear_2eq_order1_type3(x, y, t, eq): + r""" + Autonomous system of general form + + .. math:: x' = F(x,y) + + .. math:: y' = G(x,y) + + Assuming `y = y(x, C_1)` where `C_1` is an arbitrary constant is the general + solution of the first-order equation + + .. math:: F(x,y) y'_x = G(x,y) + + Then the general solution of the original system of equations has the form + + .. math:: \int \frac{1}{F(x,y(x,C_1))} \,dx = t + C_1 + + """ + C1, C2, C3, C4 = get_numbered_constants(eq, num=4) + v = Function('v') + u = Symbol('u') + f = Wild('f') + g = Wild('g') + r1 = eq[0].match(diff(x(t),t) - f) + r2 = eq[1].match(diff(y(t),t) - g) + F = r1[f].subs(x(t), u).subs(y(t), v(u)) + G = r2[g].subs(x(t), u).subs(y(t), v(u)) + sol2r = dsolve(Eq(diff(v(u), u), G/F)) + if isinstance(sol2r, Equality): + sol2r = [sol2r] + for sol2s in sol2r: + sol1 = solve(Integral(1/F.subs(v(u), sol2s.rhs), u).doit() - t - C2, u) + sol = [] + for sols in sol1: + sol.append(Eq(x(t), sols)) + sol.append(Eq(y(t), (sol2s.rhs).subs(u, sols))) + return sol + +def _nonlinear_2eq_order1_type4(x, y, t, eq): + r""" + Equation: + + .. math:: x' = f_1(x) g_1(y) \phi(x,y,t) + + .. math:: y' = f_2(x) g_2(y) \phi(x,y,t) + + First integral: + + .. math:: \int \frac{f_2(x)}{f_1(x)} \,dx - \int \frac{g_1(y)}{g_2(y)} \,dy = C + + where `C` is an arbitrary constant. + + On solving the first integral for `x` (resp., `y` ) and on substituting the + resulting expression into either equation of the original solution, one + arrives at a first-order equation for determining `y` (resp., `x` ). + + """ + C1, C2 = get_numbered_constants(eq, num=2) + u, v = symbols('u, v') + U, V = symbols('U, V', cls=Function) + f = Wild('f') + g = Wild('g') + f1 = Wild('f1', exclude=[v,t]) + f2 = Wild('f2', exclude=[v,t]) + g1 = Wild('g1', exclude=[u,t]) + g2 = Wild('g2', exclude=[u,t]) + r1 = eq[0].match(diff(x(t),t) - f) + r2 = eq[1].match(diff(y(t),t) - g) + num, den = ( + (r1[f].subs(x(t),u).subs(y(t),v))/ + (r2[g].subs(x(t),u).subs(y(t),v))).as_numer_denom() + R1 = num.match(f1*g1) + R2 = den.match(f2*g2) + phi = (r1[f].subs(x(t),u).subs(y(t),v))/num + F1 = R1[f1]; F2 = R2[f2] + G1 = R1[g1]; G2 = R2[g2] + sol1r = solve(Integral(F2/F1, u).doit() - Integral(G1/G2,v).doit() - C1, u) + sol2r = solve(Integral(F2/F1, u).doit() - Integral(G1/G2,v).doit() - C1, v) + sol = [] + for sols in sol1r: + sol.append(Eq(y(t), dsolve(diff(V(t),t) - F2.subs(u,sols).subs(v,V(t))*G2.subs(v,V(t))*phi.subs(u,sols).subs(v,V(t))).rhs)) + for sols in sol2r: + sol.append(Eq(x(t), dsolve(diff(U(t),t) - F1.subs(u,U(t))*G1.subs(v,sols).subs(u,U(t))*phi.subs(v,sols).subs(u,U(t))).rhs)) + return set(sol) + +def _nonlinear_2eq_order1_type5(func, t, eq): + r""" + Clairaut system of ODEs + + .. math:: x = t x' + F(x',y') + + .. math:: y = t y' + G(x',y') + + The following are solutions of the system + + `(i)` straight lines: + + .. math:: x = C_1 t + F(C_1, C_2), y = C_2 t + G(C_1, C_2) + + where `C_1` and `C_2` are arbitrary constants; + + `(ii)` envelopes of the above lines; + + `(iii)` continuously differentiable lines made up from segments of the lines + `(i)` and `(ii)`. + + """ + C1, C2 = get_numbered_constants(eq, num=2) + f = Wild('f') + g = Wild('g') + def check_type(x, y): + r1 = eq[0].match(t*diff(x(t),t) - x(t) + f) + r2 = eq[1].match(t*diff(y(t),t) - y(t) + g) + if not (r1 and r2): + r1 = eq[0].match(diff(x(t),t) - x(t)/t + f/t) + r2 = eq[1].match(diff(y(t),t) - y(t)/t + g/t) + if not (r1 and r2): + r1 = (-eq[0]).match(t*diff(x(t),t) - x(t) + f) + r2 = (-eq[1]).match(t*diff(y(t),t) - y(t) + g) + if not (r1 and r2): + r1 = (-eq[0]).match(diff(x(t),t) - x(t)/t + f/t) + r2 = (-eq[1]).match(diff(y(t),t) - y(t)/t + g/t) + return [r1, r2] + for func_ in func: + if isinstance(func_, list): + x = func[0][0].func + y = func[0][1].func + [r1, r2] = check_type(x, y) + if not (r1 and r2): + [r1, r2] = check_type(y, x) + x, y = y, x + x1 = diff(x(t),t); y1 = diff(y(t),t) + return {Eq(x(t), C1*t + r1[f].subs(x1,C1).subs(y1,C2)), Eq(y(t), C2*t + r2[g].subs(x1,C1).subs(y1,C2))} + +def sysode_nonlinear_3eq_order1(match_): + x = match_['func'][0].func + y = match_['func'][1].func + z = match_['func'][2].func + eq = match_['eq'] + t = list(list(eq[0].atoms(Derivative))[0].atoms(Symbol))[0] + if match_['type_of_equation'] == 'type1': + sol = _nonlinear_3eq_order1_type1(x, y, z, t, eq) + if match_['type_of_equation'] == 'type2': + sol = _nonlinear_3eq_order1_type2(x, y, z, t, eq) + if match_['type_of_equation'] == 'type3': + sol = _nonlinear_3eq_order1_type3(x, y, z, t, eq) + if match_['type_of_equation'] == 'type4': + sol = _nonlinear_3eq_order1_type4(x, y, z, t, eq) + if match_['type_of_equation'] == 'type5': + sol = _nonlinear_3eq_order1_type5(x, y, z, t, eq) + return sol + +def _nonlinear_3eq_order1_type1(x, y, z, t, eq): + r""" + Equations: + + .. math:: a x' = (b - c) y z, \enspace b y' = (c - a) z x, \enspace c z' = (a - b) x y + + First Integrals: + + .. math:: a x^{2} + b y^{2} + c z^{2} = C_1 + + .. math:: a^{2} x^{2} + b^{2} y^{2} + c^{2} z^{2} = C_2 + + where `C_1` and `C_2` are arbitrary constants. On solving the integrals for `y` and + `z` and on substituting the resulting expressions into the first equation of the + system, we arrives at a separable first-order equation on `x`. Similarly doing that + for other two equations, we will arrive at first order equation on `y` and `z` too. + + References + ========== + -https://eqworld.ipmnet.ru/en/solutions/sysode/sode0401.pdf + + """ + C1, C2 = get_numbered_constants(eq, num=2) + u, v, w = symbols('u, v, w') + p = Wild('p', exclude=[x(t), y(t), z(t), t]) + q = Wild('q', exclude=[x(t), y(t), z(t), t]) + s = Wild('s', exclude=[x(t), y(t), z(t), t]) + r = (diff(x(t),t) - eq[0]).match(p*y(t)*z(t)) + r.update((diff(y(t),t) - eq[1]).match(q*z(t)*x(t))) + r.update((diff(z(t),t) - eq[2]).match(s*x(t)*y(t))) + n1, d1 = r[p].as_numer_denom() + n2, d2 = r[q].as_numer_denom() + n3, d3 = r[s].as_numer_denom() + val = solve([n1*u-d1*v+d1*w, d2*u+n2*v-d2*w, d3*u-d3*v-n3*w],[u,v]) + vals = [val[v], val[u]] + c = lcm(vals[0].as_numer_denom()[1], vals[1].as_numer_denom()[1]) + b = vals[0].subs(w, c) + a = vals[1].subs(w, c) + y_x = sqrt(((c*C1-C2) - a*(c-a)*x(t)**2)/(b*(c-b))) + z_x = sqrt(((b*C1-C2) - a*(b-a)*x(t)**2)/(c*(b-c))) + z_y = sqrt(((a*C1-C2) - b*(a-b)*y(t)**2)/(c*(a-c))) + x_y = sqrt(((c*C1-C2) - b*(c-b)*y(t)**2)/(a*(c-a))) + x_z = sqrt(((b*C1-C2) - c*(b-c)*z(t)**2)/(a*(b-a))) + y_z = sqrt(((a*C1-C2) - c*(a-c)*z(t)**2)/(b*(a-b))) + sol1 = dsolve(a*diff(x(t),t) - (b-c)*y_x*z_x) + sol2 = dsolve(b*diff(y(t),t) - (c-a)*z_y*x_y) + sol3 = dsolve(c*diff(z(t),t) - (a-b)*x_z*y_z) + return [sol1, sol2, sol3] + + +def _nonlinear_3eq_order1_type2(x, y, z, t, eq): + r""" + Equations: + + .. math:: a x' = (b - c) y z f(x, y, z, t) + + .. math:: b y' = (c - a) z x f(x, y, z, t) + + .. math:: c z' = (a - b) x y f(x, y, z, t) + + First Integrals: + + .. math:: a x^{2} + b y^{2} + c z^{2} = C_1 + + .. math:: a^{2} x^{2} + b^{2} y^{2} + c^{2} z^{2} = C_2 + + where `C_1` and `C_2` are arbitrary constants. On solving the integrals for `y` and + `z` and on substituting the resulting expressions into the first equation of the + system, we arrives at a first-order differential equations on `x`. Similarly doing + that for other two equations we will arrive at first order equation on `y` and `z`. + + References + ========== + -https://eqworld.ipmnet.ru/en/solutions/sysode/sode0402.pdf + + """ + C1, C2 = get_numbered_constants(eq, num=2) + u, v, w = symbols('u, v, w') + p = Wild('p', exclude=[x(t), y(t), z(t), t]) + q = Wild('q', exclude=[x(t), y(t), z(t), t]) + s = Wild('s', exclude=[x(t), y(t), z(t), t]) + f = Wild('f') + r1 = (diff(x(t),t) - eq[0]).match(y(t)*z(t)*f) + r = collect_const(r1[f]).match(p*f) + r.update(((diff(y(t),t) - eq[1])/r[f]).match(q*z(t)*x(t))) + r.update(((diff(z(t),t) - eq[2])/r[f]).match(s*x(t)*y(t))) + n1, d1 = r[p].as_numer_denom() + n2, d2 = r[q].as_numer_denom() + n3, d3 = r[s].as_numer_denom() + val = solve([n1*u-d1*v+d1*w, d2*u+n2*v-d2*w, -d3*u+d3*v+n3*w],[u,v]) + vals = [val[v], val[u]] + c = lcm(vals[0].as_numer_denom()[1], vals[1].as_numer_denom()[1]) + a = vals[0].subs(w, c) + b = vals[1].subs(w, c) + y_x = sqrt(((c*C1-C2) - a*(c-a)*x(t)**2)/(b*(c-b))) + z_x = sqrt(((b*C1-C2) - a*(b-a)*x(t)**2)/(c*(b-c))) + z_y = sqrt(((a*C1-C2) - b*(a-b)*y(t)**2)/(c*(a-c))) + x_y = sqrt(((c*C1-C2) - b*(c-b)*y(t)**2)/(a*(c-a))) + x_z = sqrt(((b*C1-C2) - c*(b-c)*z(t)**2)/(a*(b-a))) + y_z = sqrt(((a*C1-C2) - c*(a-c)*z(t)**2)/(b*(a-b))) + sol1 = dsolve(a*diff(x(t),t) - (b-c)*y_x*z_x*r[f]) + sol2 = dsolve(b*diff(y(t),t) - (c-a)*z_y*x_y*r[f]) + sol3 = dsolve(c*diff(z(t),t) - (a-b)*x_z*y_z*r[f]) + return [sol1, sol2, sol3] + +def _nonlinear_3eq_order1_type3(x, y, z, t, eq): + r""" + Equations: + + .. math:: x' = c F_2 - b F_3, \enspace y' = a F_3 - c F_1, \enspace z' = b F_1 - a F_2 + + where `F_n = F_n(x, y, z, t)`. + + 1. First Integral: + + .. math:: a x + b y + c z = C_1, + + where C is an arbitrary constant. + + 2. If we assume function `F_n` to be independent of `t`,i.e, `F_n` = `F_n (x, y, z)` + Then, on eliminating `t` and `z` from the first two equation of the system, one + arrives at the first-order equation + + .. math:: \frac{dy}{dx} = \frac{a F_3 (x, y, z) - c F_1 (x, y, z)}{c F_2 (x, y, z) - + b F_3 (x, y, z)} + + where `z = \frac{1}{c} (C_1 - a x - b y)` + + References + ========== + -https://eqworld.ipmnet.ru/en/solutions/sysode/sode0404.pdf + + """ + C1 = get_numbered_constants(eq, num=1) + u, v, w = symbols('u, v, w') + fu, fv, fw = symbols('u, v, w', cls=Function) + p = Wild('p', exclude=[x(t), y(t), z(t), t]) + q = Wild('q', exclude=[x(t), y(t), z(t), t]) + s = Wild('s', exclude=[x(t), y(t), z(t), t]) + F1, F2, F3 = symbols('F1, F2, F3', cls=Wild) + r1 = (diff(x(t), t) - eq[0]).match(F2-F3) + r = collect_const(r1[F2]).match(s*F2) + r.update(collect_const(r1[F3]).match(q*F3)) + if eq[1].has(r[F2]) and not eq[1].has(r[F3]): + r[F2], r[F3] = r[F3], r[F2] + r[s], r[q] = -r[q], -r[s] + r.update((diff(y(t), t) - eq[1]).match(p*r[F3] - r[s]*F1)) + a = r[p]; b = r[q]; c = r[s] + F1 = r[F1].subs(x(t), u).subs(y(t),v).subs(z(t), w) + F2 = r[F2].subs(x(t), u).subs(y(t),v).subs(z(t), w) + F3 = r[F3].subs(x(t), u).subs(y(t),v).subs(z(t), w) + z_xy = (C1-a*u-b*v)/c + y_zx = (C1-a*u-c*w)/b + x_yz = (C1-b*v-c*w)/a + y_x = dsolve(diff(fv(u),u) - ((a*F3-c*F1)/(c*F2-b*F3)).subs(w,z_xy).subs(v,fv(u))).rhs + z_x = dsolve(diff(fw(u),u) - ((b*F1-a*F2)/(c*F2-b*F3)).subs(v,y_zx).subs(w,fw(u))).rhs + z_y = dsolve(diff(fw(v),v) - ((b*F1-a*F2)/(a*F3-c*F1)).subs(u,x_yz).subs(w,fw(v))).rhs + x_y = dsolve(diff(fu(v),v) - ((c*F2-b*F3)/(a*F3-c*F1)).subs(w,z_xy).subs(u,fu(v))).rhs + y_z = dsolve(diff(fv(w),w) - ((a*F3-c*F1)/(b*F1-a*F2)).subs(u,x_yz).subs(v,fv(w))).rhs + x_z = dsolve(diff(fu(w),w) - ((c*F2-b*F3)/(b*F1-a*F2)).subs(v,y_zx).subs(u,fu(w))).rhs + sol1 = dsolve(diff(fu(t),t) - (c*F2 - b*F3).subs(v,y_x).subs(w,z_x).subs(u,fu(t))).rhs + sol2 = dsolve(diff(fv(t),t) - (a*F3 - c*F1).subs(u,x_y).subs(w,z_y).subs(v,fv(t))).rhs + sol3 = dsolve(diff(fw(t),t) - (b*F1 - a*F2).subs(u,x_z).subs(v,y_z).subs(w,fw(t))).rhs + return [sol1, sol2, sol3] + +def _nonlinear_3eq_order1_type4(x, y, z, t, eq): + r""" + Equations: + + .. math:: x' = c z F_2 - b y F_3, \enspace y' = a x F_3 - c z F_1, \enspace z' = b y F_1 - a x F_2 + + where `F_n = F_n (x, y, z, t)` + + 1. First integral: + + .. math:: a x^{2} + b y^{2} + c z^{2} = C_1 + + where `C` is an arbitrary constant. + + 2. Assuming the function `F_n` is independent of `t`: `F_n = F_n (x, y, z)`. Then on + eliminating `t` and `z` from the first two equations of the system, one arrives at + the first-order equation + + .. math:: \frac{dy}{dx} = \frac{a x F_3 (x, y, z) - c z F_1 (x, y, z)} + {c z F_2 (x, y, z) - b y F_3 (x, y, z)} + + where `z = \pm \sqrt{\frac{1}{c} (C_1 - a x^{2} - b y^{2})}` + + References + ========== + -https://eqworld.ipmnet.ru/en/solutions/sysode/sode0405.pdf + + """ + C1 = get_numbered_constants(eq, num=1) + u, v, w = symbols('u, v, w') + p = Wild('p', exclude=[x(t), y(t), z(t), t]) + q = Wild('q', exclude=[x(t), y(t), z(t), t]) + s = Wild('s', exclude=[x(t), y(t), z(t), t]) + F1, F2, F3 = symbols('F1, F2, F3', cls=Wild) + r1 = eq[0].match(diff(x(t),t) - z(t)*F2 + y(t)*F3) + r = collect_const(r1[F2]).match(s*F2) + r.update(collect_const(r1[F3]).match(q*F3)) + if eq[1].has(r[F2]) and not eq[1].has(r[F3]): + r[F2], r[F3] = r[F3], r[F2] + r[s], r[q] = -r[q], -r[s] + r.update((diff(y(t),t) - eq[1]).match(p*x(t)*r[F3] - r[s]*z(t)*F1)) + a = r[p]; b = r[q]; c = r[s] + F1 = r[F1].subs(x(t),u).subs(y(t),v).subs(z(t),w) + F2 = r[F2].subs(x(t),u).subs(y(t),v).subs(z(t),w) + F3 = r[F3].subs(x(t),u).subs(y(t),v).subs(z(t),w) + x_yz = sqrt((C1 - b*v**2 - c*w**2)/a) + y_zx = sqrt((C1 - c*w**2 - a*u**2)/b) + z_xy = sqrt((C1 - a*u**2 - b*v**2)/c) + y_x = dsolve(diff(v(u),u) - ((a*u*F3-c*w*F1)/(c*w*F2-b*v*F3)).subs(w,z_xy).subs(v,v(u))).rhs + z_x = dsolve(diff(w(u),u) - ((b*v*F1-a*u*F2)/(c*w*F2-b*v*F3)).subs(v,y_zx).subs(w,w(u))).rhs + z_y = dsolve(diff(w(v),v) - ((b*v*F1-a*u*F2)/(a*u*F3-c*w*F1)).subs(u,x_yz).subs(w,w(v))).rhs + x_y = dsolve(diff(u(v),v) - ((c*w*F2-b*v*F3)/(a*u*F3-c*w*F1)).subs(w,z_xy).subs(u,u(v))).rhs + y_z = dsolve(diff(v(w),w) - ((a*u*F3-c*w*F1)/(b*v*F1-a*u*F2)).subs(u,x_yz).subs(v,v(w))).rhs + x_z = dsolve(diff(u(w),w) - ((c*w*F2-b*v*F3)/(b*v*F1-a*u*F2)).subs(v,y_zx).subs(u,u(w))).rhs + sol1 = dsolve(diff(u(t),t) - (c*w*F2 - b*v*F3).subs(v,y_x).subs(w,z_x).subs(u,u(t))).rhs + sol2 = dsolve(diff(v(t),t) - (a*u*F3 - c*w*F1).subs(u,x_y).subs(w,z_y).subs(v,v(t))).rhs + sol3 = dsolve(diff(w(t),t) - (b*v*F1 - a*u*F2).subs(u,x_z).subs(v,y_z).subs(w,w(t))).rhs + return [sol1, sol2, sol3] + +def _nonlinear_3eq_order1_type5(x, y, z, t, eq): + r""" + .. math:: x' = x (c F_2 - b F_3), \enspace y' = y (a F_3 - c F_1), \enspace z' = z (b F_1 - a F_2) + + where `F_n = F_n (x, y, z, t)` and are arbitrary functions. + + First Integral: + + .. math:: \left|x\right|^{a} \left|y\right|^{b} \left|z\right|^{c} = C_1 + + where `C` is an arbitrary constant. If the function `F_n` is independent of `t`, + then, by eliminating `t` and `z` from the first two equations of the system, one + arrives at a first-order equation. + + References + ========== + -https://eqworld.ipmnet.ru/en/solutions/sysode/sode0406.pdf + + """ + C1 = get_numbered_constants(eq, num=1) + u, v, w = symbols('u, v, w') + fu, fv, fw = symbols('u, v, w', cls=Function) + p = Wild('p', exclude=[x(t), y(t), z(t), t]) + q = Wild('q', exclude=[x(t), y(t), z(t), t]) + s = Wild('s', exclude=[x(t), y(t), z(t), t]) + F1, F2, F3 = symbols('F1, F2, F3', cls=Wild) + r1 = eq[0].match(diff(x(t), t) - x(t)*F2 + x(t)*F3) + r = collect_const(r1[F2]).match(s*F2) + r.update(collect_const(r1[F3]).match(q*F3)) + if eq[1].has(r[F2]) and not eq[1].has(r[F3]): + r[F2], r[F3] = r[F3], r[F2] + r[s], r[q] = -r[q], -r[s] + r.update((diff(y(t), t) - eq[1]).match(y(t)*(p*r[F3] - r[s]*F1))) + a = r[p]; b = r[q]; c = r[s] + F1 = r[F1].subs(x(t), u).subs(y(t), v).subs(z(t), w) + F2 = r[F2].subs(x(t), u).subs(y(t), v).subs(z(t), w) + F3 = r[F3].subs(x(t), u).subs(y(t), v).subs(z(t), w) + x_yz = (C1*v**-b*w**-c)**-a + y_zx = (C1*w**-c*u**-a)**-b + z_xy = (C1*u**-a*v**-b)**-c + y_x = dsolve(diff(fv(u), u) - ((v*(a*F3 - c*F1))/(u*(c*F2 - b*F3))).subs(w, z_xy).subs(v, fv(u))).rhs + z_x = dsolve(diff(fw(u), u) - ((w*(b*F1 - a*F2))/(u*(c*F2 - b*F3))).subs(v, y_zx).subs(w, fw(u))).rhs + z_y = dsolve(diff(fw(v), v) - ((w*(b*F1 - a*F2))/(v*(a*F3 - c*F1))).subs(u, x_yz).subs(w, fw(v))).rhs + x_y = dsolve(diff(fu(v), v) - ((u*(c*F2 - b*F3))/(v*(a*F3 - c*F1))).subs(w, z_xy).subs(u, fu(v))).rhs + y_z = dsolve(diff(fv(w), w) - ((v*(a*F3 - c*F1))/(w*(b*F1 - a*F2))).subs(u, x_yz).subs(v, fv(w))).rhs + x_z = dsolve(diff(fu(w), w) - ((u*(c*F2 - b*F3))/(w*(b*F1 - a*F2))).subs(v, y_zx).subs(u, fu(w))).rhs + sol1 = dsolve(diff(fu(t), t) - (u*(c*F2 - b*F3)).subs(v, y_x).subs(w, z_x).subs(u, fu(t))).rhs + sol2 = dsolve(diff(fv(t), t) - (v*(a*F3 - c*F1)).subs(u, x_y).subs(w, z_y).subs(v, fv(t))).rhs + sol3 = dsolve(diff(fw(t), t) - (w*(b*F1 - a*F2)).subs(u, x_z).subs(v, y_z).subs(w, fw(t))).rhs + return [sol1, sol2, sol3] + + +#This import is written at the bottom to avoid circular imports. +from .single import SingleODEProblem, SingleODESolver, solver_map diff --git a/lib/python3.10/site-packages/sympy/solvers/ode/riccati.py b/lib/python3.10/site-packages/sympy/solvers/ode/riccati.py new file mode 100644 index 0000000000000000000000000000000000000000..2ef66ed0896d39bee8fba1b74a0c93734742fc1f --- /dev/null +++ b/lib/python3.10/site-packages/sympy/solvers/ode/riccati.py @@ -0,0 +1,893 @@ +r""" +This module contains :py:meth:`~sympy.solvers.ode.riccati.solve_riccati`, +a function which gives all rational particular solutions to first order +Riccati ODEs. A general first order Riccati ODE is given by - + +.. math:: y' = b_0(x) + b_1(x)w + b_2(x)w^2 + +where `b_0, b_1` and `b_2` can be arbitrary rational functions of `x` +with `b_2 \ne 0`. When `b_2 = 0`, the equation is not a Riccati ODE +anymore and becomes a Linear ODE. Similarly, when `b_0 = 0`, the equation +is a Bernoulli ODE. The algorithm presented below can find rational +solution(s) to all ODEs with `b_2 \ne 0` that have a rational solution, +or prove that no rational solution exists for the equation. + +Background +========== + +A Riccati equation can be transformed to its normal form + +.. math:: y' + y^2 = a(x) + +using the transformation + +.. math:: y = -b_2(x) - \frac{b'_2(x)}{2 b_2(x)} - \frac{b_1(x)}{2} + +where `a(x)` is given by + +.. math:: a(x) = \frac{1}{4}\left(\frac{b_2'}{b_2} + b_1\right)^2 - \frac{1}{2}\left(\frac{b_2'}{b_2} + b_1\right)' - b_0 b_2 + +Thus, we can develop an algorithm to solve for the Riccati equation +in its normal form, which would in turn give us the solution for +the original Riccati equation. + +Algorithm +========= + +The algorithm implemented here is presented in the Ph.D thesis +"Rational and Algebraic Solutions of First-Order Algebraic ODEs" +by N. Thieu Vo. The entire thesis can be found here - +https://www3.risc.jku.at/publications/download/risc_5387/PhDThesisThieu.pdf + +We have only implemented the Rational Riccati solver (Algorithm 11, +Pg 78-82 in Thesis). Before we proceed towards the implementation +of the algorithm, a few definitions to understand are - + +1. Valuation of a Rational Function at `\infty`: + The valuation of a rational function `p(x)` at `\infty` is equal + to the difference between the degree of the denominator and the + numerator of `p(x)`. + + NOTE: A general definition of valuation of a rational function + at any value of `x` can be found in Pg 63 of the thesis, but + is not of any interest for this algorithm. + +2. Zeros and Poles of a Rational Function: + Let `a(x) = \frac{S(x)}{T(x)}, T \ne 0` be a rational function + of `x`. Then - + + a. The Zeros of `a(x)` are the roots of `S(x)`. + b. The Poles of `a(x)` are the roots of `T(x)`. However, `\infty` + can also be a pole of a(x). We say that `a(x)` has a pole at + `\infty` if `a(\frac{1}{x})` has a pole at 0. + +Every pole is associated with an order that is equal to the multiplicity +of its appearance as a root of `T(x)`. A pole is called a simple pole if +it has an order 1. Similarly, a pole is called a multiple pole if it has +an order `\ge` 2. + +Necessary Conditions +==================== + +For a Riccati equation in its normal form, + +.. math:: y' + y^2 = a(x) + +we can define + +a. A pole is called a movable pole if it is a pole of `y(x)` and is not +a pole of `a(x)`. +b. Similarly, a pole is called a non-movable pole if it is a pole of both +`y(x)` and `a(x)`. + +Then, the algorithm states that a rational solution exists only if - + +a. Every pole of `a(x)` must be either a simple pole or a multiple pole +of even order. +b. The valuation of `a(x)` at `\infty` must be even or be `\ge` 2. + +This algorithm finds all possible rational solutions for the Riccati ODE. +If no rational solutions are found, it means that no rational solutions +exist. + +The algorithm works for Riccati ODEs where the coefficients are rational +functions in the independent variable `x` with rational number coefficients +i.e. in `Q(x)`. The coefficients in the rational function cannot be floats, +irrational numbers, symbols or any other kind of expression. The reasons +for this are - + +1. When using symbols, different symbols could take the same value and this +would affect the multiplicity of poles if symbols are present here. + +2. An integer degree bound is required to calculate a polynomial solution +to an auxiliary differential equation, which in turn gives the particular +solution for the original ODE. If symbols/floats/irrational numbers are +present, we cannot determine if the expression for the degree bound is an +integer or not. + +Solution +======== + +With these definitions, we can state a general form for the solution of +the equation. `y(x)` must have the form - + +.. math:: y(x) = \sum_{i=1}^{n} \sum_{j=1}^{r_i} \frac{c_{ij}}{(x - x_i)^j} + \sum_{i=1}^{m} \frac{1}{x - \chi_i} + \sum_{i=0}^{N} d_i x^i + +where `x_1, x_2, \dots, x_n` are non-movable poles of `a(x)`, +`\chi_1, \chi_2, \dots, \chi_m` are movable poles of `a(x)`, and the values +of `N, n, r_1, r_2, \dots, r_n` can be determined from `a(x)`. The +coefficient vectors `(d_0, d_1, \dots, d_N)` and `(c_{i1}, c_{i2}, \dots, c_{i r_i})` +can be determined from `a(x)`. We will have 2 choices each of these vectors +and part of the procedure is figuring out which of the 2 should be used +to get the solution correctly. + +Implementation +============== + +In this implementation, we use ``Poly`` to represent a rational function +rather than using ``Expr`` since ``Poly`` is much faster. Since we cannot +represent rational functions directly using ``Poly``, we instead represent +a rational function with 2 ``Poly`` objects - one for its numerator and +the other for its denominator. + +The code is written to match the steps given in the thesis (Pg 82) + +Step 0 : Match the equation - +Find `b_0, b_1` and `b_2`. If `b_2 = 0` or no such functions exist, raise +an error + +Step 1 : Transform the equation to its normal form as explained in the +theory section. + +Step 2 : Initialize an empty set of solutions, ``sol``. + +Step 3 : If `a(x) = 0`, append `\frac{1}/{(x - C1)}` to ``sol``. + +Step 4 : If `a(x)` is a rational non-zero number, append `\pm \sqrt{a}` +to ``sol``. + +Step 5 : Find the poles and their multiplicities of `a(x)`. Let +the number of poles be `n`. Also find the valuation of `a(x)` at +`\infty` using ``val_at_inf``. + +NOTE: Although the algorithm considers `\infty` as a pole, it is +not mentioned if it a part of the set of finite poles. `\infty` +is NOT a part of the set of finite poles. If a pole exists at +`\infty`, we use its multiplicity to find the laurent series of +`a(x)` about `\infty`. + +Step 6 : Find `n` c-vectors (one for each pole) and 1 d-vector using +``construct_c`` and ``construct_d``. Now, determine all the ``2**(n + 1)`` +combinations of choosing between 2 choices for each of the `n` c-vectors +and 1 d-vector. + +NOTE: The equation for `d_{-1}` in Case 4 (Pg 80) has a printinig +mistake. The term `- d_N` must be replaced with `-N d_N`. The same +has been explained in the code as well. + +For each of these above combinations, do + +Step 8 : Compute `m` in ``compute_m_ybar``. `m` is the degree bound of +the polynomial solution we must find for the auxiliary equation. + +Step 9 : In ``compute_m_ybar``, compute ybar as well where ``ybar`` is +one part of y(x) - + +.. math:: \overline{y}(x) = \sum_{i=1}^{n} \sum_{j=1}^{r_i} \frac{c_{ij}}{(x - x_i)^j} + \sum_{i=0}^{N} d_i x^i + +Step 10 : If `m` is a non-negative integer - + +Step 11: Find a polynomial solution of degree `m` for the auxiliary equation. + +There are 2 cases possible - + + a. `m` is a non-negative integer: We can solve for the coefficients + in `p(x)` using Undetermined Coefficients. + + b. `m` is not a non-negative integer: In this case, we cannot find + a polynomial solution to the auxiliary equation, and hence, we ignore + this value of `m`. + +Step 12 : For each `p(x)` that exists, append `ybar + \frac{p'(x)}{p(x)}` +to ``sol``. + +Step 13 : For each solution in ``sol``, apply an inverse transformation, +so that the solutions of the original equation are found using the +solutions of the equation in its normal form. +""" + + +from itertools import product +from sympy.core import S +from sympy.core.add import Add +from sympy.core.numbers import oo, Float +from sympy.core.function import count_ops +from sympy.core.relational import Eq +from sympy.core.symbol import symbols, Symbol, Dummy +from sympy.functions import sqrt, exp +from sympy.functions.elementary.complexes import sign +from sympy.integrals.integrals import Integral +from sympy.polys.domains import ZZ +from sympy.polys.polytools import Poly +from sympy.polys.polyroots import roots +from sympy.solvers.solveset import linsolve + + +def riccati_normal(w, x, b1, b2): + """ + Given a solution `w(x)` to the equation + + .. math:: w'(x) = b_0(x) + b_1(x)*w(x) + b_2(x)*w(x)^2 + + and rational function coefficients `b_1(x)` and + `b_2(x)`, this function transforms the solution to + give a solution `y(x)` for its corresponding normal + Riccati ODE + + .. math:: y'(x) + y(x)^2 = a(x) + + using the transformation + + .. math:: y(x) = -b_2(x)*w(x) - b'_2(x)/(2*b_2(x)) - b_1(x)/2 + """ + return -b2*w - b2.diff(x)/(2*b2) - b1/2 + + +def riccati_inverse_normal(y, x, b1, b2, bp=None): + """ + Inverse transforming the solution to the normal + Riccati ODE to get the solution to the Riccati ODE. + """ + # bp is the expression which is independent of the solution + # and hence, it need not be computed again + if bp is None: + bp = -b2.diff(x)/(2*b2**2) - b1/(2*b2) + # w(x) = -y(x)/b2(x) - b2'(x)/(2*b2(x)^2) - b1(x)/(2*b2(x)) + return -y/b2 + bp + + +def riccati_reduced(eq, f, x): + """ + Convert a Riccati ODE into its corresponding + normal Riccati ODE. + """ + match, funcs = match_riccati(eq, f, x) + # If equation is not a Riccati ODE, exit + if not match: + return False + # Using the rational functions, find the expression for a(x) + b0, b1, b2 = funcs + a = -b0*b2 + b1**2/4 - b1.diff(x)/2 + 3*b2.diff(x)**2/(4*b2**2) + b1*b2.diff(x)/(2*b2) - \ + b2.diff(x, 2)/(2*b2) + # Normal form of Riccati ODE is f'(x) + f(x)^2 = a(x) + return f(x).diff(x) + f(x)**2 - a + +def linsolve_dict(eq, syms): + """ + Get the output of linsolve as a dict + """ + # Convert tuple type return value of linsolve + # to a dictionary for ease of use + sol = linsolve(eq, syms) + if not sol: + return {} + return dict(zip(syms, list(sol)[0])) + + +def match_riccati(eq, f, x): + """ + A function that matches and returns the coefficients + if an equation is a Riccati ODE + + Parameters + ========== + + eq: Equation to be matched + f: Dependent variable + x: Independent variable + + Returns + ======= + + match: True if equation is a Riccati ODE, False otherwise + funcs: [b0, b1, b2] if match is True, [] otherwise. Here, + b0, b1 and b2 are rational functions which match the equation. + """ + # Group terms based on f(x) + if isinstance(eq, Eq): + eq = eq.lhs - eq.rhs + eq = eq.expand().collect(f(x)) + cf = eq.coeff(f(x).diff(x)) + + # There must be an f(x).diff(x) term. + # eq must be an Add object since we are using the expanded + # equation and it must have atleast 2 terms (b2 != 0) + if cf != 0 and isinstance(eq, Add): + + # Divide all coefficients by the coefficient of f(x).diff(x) + # and add the terms again to get the same equation + eq = Add(*((x/cf).cancel() for x in eq.args)).collect(f(x)) + + # Match the equation with the pattern + b1 = -eq.coeff(f(x)) + b2 = -eq.coeff(f(x)**2) + b0 = (f(x).diff(x) - b1*f(x) - b2*f(x)**2 - eq).expand() + funcs = [b0, b1, b2] + + # Check if coefficients are not symbols and floats + if any(len(x.atoms(Symbol)) > 1 or len(x.atoms(Float)) for x in funcs): + return False, [] + + # If b_0(x) contains f(x), it is not a Riccati ODE + if len(b0.atoms(f)) or not all((b2 != 0, b0.is_rational_function(x), + b1.is_rational_function(x), b2.is_rational_function(x))): + return False, [] + return True, funcs + return False, [] + + +def val_at_inf(num, den, x): + # Valuation of a rational function at oo = deg(denom) - deg(numer) + return den.degree(x) - num.degree(x) + + +def check_necessary_conds(val_inf, muls): + """ + The necessary conditions for a rational solution + to exist are as follows - + + i) Every pole of a(x) must be either a simple pole + or a multiple pole of even order. + + ii) The valuation of a(x) at infinity must be even + or be greater than or equal to 2. + + Here, a simple pole is a pole with multiplicity 1 + and a multiple pole is a pole with multiplicity + greater than 1. + """ + return (val_inf >= 2 or (val_inf <= 0 and val_inf%2 == 0)) and \ + all(mul == 1 or (mul%2 == 0 and mul >= 2) for mul in muls) + + +def inverse_transform_poly(num, den, x): + """ + A function to make the substitution + x -> 1/x in a rational function that + is represented using Poly objects for + numerator and denominator. + """ + # Declare for reuse + one = Poly(1, x) + xpoly = Poly(x, x) + + # Check if degree of numerator is same as denominator + pwr = val_at_inf(num, den, x) + if pwr >= 0: + # Denominator has greater degree. Substituting x with + # 1/x would make the extra power go to the numerator + if num.expr != 0: + num = num.transform(one, xpoly) * x**pwr + den = den.transform(one, xpoly) + else: + # Numerator has greater degree. Substituting x with + # 1/x would make the extra power go to the denominator + num = num.transform(one, xpoly) + den = den.transform(one, xpoly) * x**(-pwr) + return num.cancel(den, include=True) + + +def limit_at_inf(num, den, x): + """ + Find the limit of a rational function + at oo + """ + # pwr = degree(num) - degree(den) + pwr = -val_at_inf(num, den, x) + # Numerator has a greater degree than denominator + # Limit at infinity would depend on the sign of the + # leading coefficients of numerator and denominator + if pwr > 0: + return oo*sign(num.LC()/den.LC()) + # Degree of numerator is equal to that of denominator + # Limit at infinity is just the ratio of leading coeffs + elif pwr == 0: + return num.LC()/den.LC() + # Degree of numerator is less than that of denominator + # Limit at infinity is just 0 + else: + return 0 + + +def construct_c_case_1(num, den, x, pole): + # Find the coefficient of 1/(x - pole)**2 in the + # Laurent series expansion of a(x) about pole. + num1, den1 = (num*Poly((x - pole)**2, x, extension=True)).cancel(den, include=True) + r = (num1.subs(x, pole))/(den1.subs(x, pole)) + + # If multiplicity is 2, the coefficient to be added + # in the c-vector is c = (1 +- sqrt(1 + 4*r))/2 + if r != -S(1)/4: + return [[(1 + sqrt(1 + 4*r))/2], [(1 - sqrt(1 + 4*r))/2]] + return [[S.Half]] + + +def construct_c_case_2(num, den, x, pole, mul): + # Generate the coefficients using the recurrence + # relation mentioned in (5.14) in the thesis (Pg 80) + + # r_i = mul/2 + ri = mul//2 + + # Find the Laurent series coefficients about the pole + ser = rational_laurent_series(num, den, x, pole, mul, 6) + + # Start with an empty memo to store the coefficients + # This is for the plus case + cplus = [0 for i in range(ri)] + + # Base Case + cplus[ri-1] = sqrt(ser[2*ri]) + + # Iterate backwards to find all coefficients + s = ri - 1 + sm = 0 + for s in range(ri-1, 0, -1): + sm = 0 + for j in range(s+1, ri): + sm += cplus[j-1]*cplus[ri+s-j-1] + if s!= 1: + cplus[s-1] = (ser[ri+s] - sm)/(2*cplus[ri-1]) + + # Memo for the minus case + cminus = [-x for x in cplus] + + # Find the 0th coefficient in the recurrence + cplus[0] = (ser[ri+s] - sm - ri*cplus[ri-1])/(2*cplus[ri-1]) + cminus[0] = (ser[ri+s] - sm - ri*cminus[ri-1])/(2*cminus[ri-1]) + + # Add both the plus and minus cases' coefficients + if cplus != cminus: + return [cplus, cminus] + return cplus + + +def construct_c_case_3(): + # If multiplicity is 1, the coefficient to be added + # in the c-vector is 1 (no choice) + return [[1]] + + +def construct_c(num, den, x, poles, muls): + """ + Helper function to calculate the coefficients + in the c-vector for each pole. + """ + c = [] + for pole, mul in zip(poles, muls): + c.append([]) + + # Case 3 + if mul == 1: + # Add the coefficients from Case 3 + c[-1].extend(construct_c_case_3()) + + # Case 1 + elif mul == 2: + # Add the coefficients from Case 1 + c[-1].extend(construct_c_case_1(num, den, x, pole)) + + # Case 2 + else: + # Add the coefficients from Case 2 + c[-1].extend(construct_c_case_2(num, den, x, pole, mul)) + + return c + + +def construct_d_case_4(ser, N): + # Initialize an empty vector + dplus = [0 for i in range(N+2)] + # d_N = sqrt(a_{2*N}) + dplus[N] = sqrt(ser[2*N]) + + # Use the recurrence relations to find + # the value of d_s + for s in range(N-1, -2, -1): + sm = 0 + for j in range(s+1, N): + sm += dplus[j]*dplus[N+s-j] + if s != -1: + dplus[s] = (ser[N+s] - sm)/(2*dplus[N]) + + # Coefficients for the case of d_N = -sqrt(a_{2*N}) + dminus = [-x for x in dplus] + + # The third equation in Eq 5.15 of the thesis is WRONG! + # d_N must be replaced with N*d_N in that equation. + dplus[-1] = (ser[N+s] - N*dplus[N] - sm)/(2*dplus[N]) + dminus[-1] = (ser[N+s] - N*dminus[N] - sm)/(2*dminus[N]) + + if dplus != dminus: + return [dplus, dminus] + return dplus + + +def construct_d_case_5(ser): + # List to store coefficients for plus case + dplus = [0, 0] + + # d_0 = sqrt(a_0) + dplus[0] = sqrt(ser[0]) + + # d_(-1) = a_(-1)/(2*d_0) + dplus[-1] = ser[-1]/(2*dplus[0]) + + # Coefficients for the minus case are just the negative + # of the coefficients for the positive case. + dminus = [-x for x in dplus] + + if dplus != dminus: + return [dplus, dminus] + return dplus + + +def construct_d_case_6(num, den, x): + # s_oo = lim x->0 1/x**2 * a(1/x) which is equivalent to + # s_oo = lim x->oo x**2 * a(x) + s_inf = limit_at_inf(Poly(x**2, x)*num, den, x) + + # d_(-1) = (1 +- sqrt(1 + 4*s_oo))/2 + if s_inf != -S(1)/4: + return [[(1 + sqrt(1 + 4*s_inf))/2], [(1 - sqrt(1 + 4*s_inf))/2]] + return [[S.Half]] + + +def construct_d(num, den, x, val_inf): + """ + Helper function to calculate the coefficients + in the d-vector based on the valuation of the + function at oo. + """ + N = -val_inf//2 + # Multiplicity of oo as a pole + mul = -val_inf if val_inf < 0 else 0 + ser = rational_laurent_series(num, den, x, oo, mul, 1) + + # Case 4 + if val_inf < 0: + d = construct_d_case_4(ser, N) + + # Case 5 + elif val_inf == 0: + d = construct_d_case_5(ser) + + # Case 6 + else: + d = construct_d_case_6(num, den, x) + + return d + + +def rational_laurent_series(num, den, x, r, m, n): + r""" + The function computes the Laurent series coefficients + of a rational function. + + Parameters + ========== + + num: A Poly object that is the numerator of `f(x)`. + den: A Poly object that is the denominator of `f(x)`. + x: The variable of expansion of the series. + r: The point of expansion of the series. + m: Multiplicity of r if r is a pole of `f(x)`. Should + be zero otherwise. + n: Order of the term upto which the series is expanded. + + Returns + ======= + + series: A dictionary that has power of the term as key + and coefficient of that term as value. + + Below is a basic outline of how the Laurent series of a + rational function `f(x)` about `x_0` is being calculated - + + 1. Substitute `x + x_0` in place of `x`. If `x_0` + is a pole of `f(x)`, multiply the expression by `x^m` + where `m` is the multiplicity of `x_0`. Denote the + the resulting expression as g(x). We do this substitution + so that we can now find the Laurent series of g(x) about + `x = 0`. + + 2. We can then assume that the Laurent series of `g(x)` + takes the following form - + + .. math:: g(x) = \frac{num(x)}{den(x)} = \sum_{m = 0}^{\infty} a_m x^m + + where `a_m` denotes the Laurent series coefficients. + + 3. Multiply the denominator to the RHS of the equation + and form a recurrence relation for the coefficients `a_m`. + """ + one = Poly(1, x, extension=True) + + if r == oo: + # Series at x = oo is equal to first transforming + # the function from x -> 1/x and finding the + # series at x = 0 + num, den = inverse_transform_poly(num, den, x) + r = S(0) + + if r: + # For an expansion about a non-zero point, a + # transformation from x -> x + r must be made + num = num.transform(Poly(x + r, x, extension=True), one) + den = den.transform(Poly(x + r, x, extension=True), one) + + # Remove the pole from the denominator if the series + # expansion is about one of the poles + num, den = (num*x**m).cancel(den, include=True) + + # Equate coefficients for the first terms (base case) + maxdegree = 1 + max(num.degree(), den.degree()) + syms = symbols(f'a:{maxdegree}', cls=Dummy) + diff = num - den * Poly(syms[::-1], x) + coeff_diffs = diff.all_coeffs()[::-1][:maxdegree] + (coeffs, ) = linsolve(coeff_diffs, syms) + + # Use the recursion relation for the rest + recursion = den.all_coeffs()[::-1] + div, rec_rhs = recursion[0], recursion[1:] + series = list(coeffs) + while len(series) < n: + next_coeff = Add(*(c*series[-1-n] for n, c in enumerate(rec_rhs))) / div + series.append(-next_coeff) + series = {m - i: val for i, val in enumerate(series)} + return series + +def compute_m_ybar(x, poles, choice, N): + """ + Helper function to calculate - + + 1. m - The degree bound for the polynomial + solution that must be found for the auxiliary + differential equation. + + 2. ybar - Part of the solution which can be + computed using the poles, c and d vectors. + """ + ybar = 0 + m = Poly(choice[-1][-1], x, extension=True) + + # Calculate the first (nested) summation for ybar + # as given in Step 9 of the Thesis (Pg 82) + dybar = [] + for i, polei in enumerate(poles): + for j, cij in enumerate(choice[i]): + dybar.append(cij/(x - polei)**(j + 1)) + m -=Poly(choice[i][0], x, extension=True) # can't accumulate Poly and use with Add + ybar += Add(*dybar) + + # Calculate the second summation for ybar + for i in range(N+1): + ybar += choice[-1][i]*x**i + return (m.expr, ybar) + + +def solve_aux_eq(numa, dena, numy, deny, x, m): + """ + Helper function to find a polynomial solution + of degree m for the auxiliary differential + equation. + """ + # Assume that the solution is of the type + # p(x) = C_0 + C_1*x + ... + C_{m-1}*x**(m-1) + x**m + psyms = symbols(f'C0:{m}', cls=Dummy) + K = ZZ[psyms] + psol = Poly(K.gens, x, domain=K) + Poly(x**m, x, domain=K) + + # Eq (5.16) in Thesis - Pg 81 + auxeq = (dena*(numy.diff(x)*deny - numy*deny.diff(x) + numy**2) - numa*deny**2)*psol + if m >= 1: + px = psol.diff(x) + auxeq += px*(2*numy*deny*dena) + if m >= 2: + auxeq += px.diff(x)*(deny**2*dena) + if m != 0: + # m is a non-zero integer. Find the constant terms using undetermined coefficients + return psol, linsolve_dict(auxeq.all_coeffs(), psyms), True + else: + # m == 0 . Check if 1 (x**0) is a solution to the auxiliary equation + return S.One, auxeq, auxeq == 0 + + +def remove_redundant_sols(sol1, sol2, x): + """ + Helper function to remove redundant + solutions to the differential equation. + """ + # If y1 and y2 are redundant solutions, there is + # some value of the arbitrary constant for which + # they will be equal + + syms1 = sol1.atoms(Symbol, Dummy) + syms2 = sol2.atoms(Symbol, Dummy) + num1, den1 = [Poly(e, x, extension=True) for e in sol1.together().as_numer_denom()] + num2, den2 = [Poly(e, x, extension=True) for e in sol2.together().as_numer_denom()] + # Cross multiply + e = num1*den2 - den1*num2 + # Check if there are any constants + syms = list(e.atoms(Symbol, Dummy)) + if len(syms): + # Find values of constants for which solutions are equal + redn = linsolve(e.all_coeffs(), syms) + if len(redn): + # Return the general solution over a particular solution + if len(syms1) > len(syms2): + return sol2 + # If both have constants, return the lesser complex solution + elif len(syms1) == len(syms2): + return sol1 if count_ops(syms1) >= count_ops(syms2) else sol2 + else: + return sol1 + + +def get_gen_sol_from_part_sol(part_sols, a, x): + """" + Helper function which computes the general + solution for a Riccati ODE from its particular + solutions. + + There are 3 cases to find the general solution + from the particular solutions for a Riccati ODE + depending on the number of particular solution(s) + we have - 1, 2 or 3. + + For more information, see Section 6 of + "Methods of Solution of the Riccati Differential Equation" + by D. R. Haaheim and F. M. Stein + """ + + # If no particular solutions are found, a general + # solution cannot be found + if len(part_sols) == 0: + return [] + + # In case of a single particular solution, the general + # solution can be found by using the substitution + # y = y1 + 1/z and solving a Bernoulli ODE to find z. + elif len(part_sols) == 1: + y1 = part_sols[0] + i = exp(Integral(2*y1, x)) + z = i * Integral(a/i, x) + z = z.doit() + if a == 0 or z == 0: + return y1 + return y1 + 1/z + + # In case of 2 particular solutions, the general solution + # can be found by solving a separable equation. This is + # the most common case, i.e. most Riccati ODEs have 2 + # rational particular solutions. + elif len(part_sols) == 2: + y1, y2 = part_sols + # One of them already has a constant + if len(y1.atoms(Dummy)) + len(y2.atoms(Dummy)) > 0: + u = exp(Integral(y2 - y1, x)).doit() + # Introduce a constant + else: + C1 = Dummy('C1') + u = C1*exp(Integral(y2 - y1, x)).doit() + if u == 1: + return y2 + return (y2*u - y1)/(u - 1) + + # In case of 3 particular solutions, a closed form + # of the general solution can be obtained directly + else: + y1, y2, y3 = part_sols[:3] + C1 = Dummy('C1') + return (C1 + 1)*y2*(y1 - y3)/(C1*y1 + y2 - (C1 + 1)*y3) + + +def solve_riccati(fx, x, b0, b1, b2, gensol=False): + """ + The main function that gives particular/general + solutions to Riccati ODEs that have atleast 1 + rational particular solution. + """ + # Step 1 : Convert to Normal Form + a = -b0*b2 + b1**2/4 - b1.diff(x)/2 + 3*b2.diff(x)**2/(4*b2**2) + b1*b2.diff(x)/(2*b2) - \ + b2.diff(x, 2)/(2*b2) + a_t = a.together() + num, den = [Poly(e, x, extension=True) for e in a_t.as_numer_denom()] + num, den = num.cancel(den, include=True) + + # Step 2 + presol = [] + + # Step 3 : a(x) is 0 + if num == 0: + presol.append(1/(x + Dummy('C1'))) + + # Step 4 : a(x) is a non-zero constant + elif x not in num.free_symbols.union(den.free_symbols): + presol.extend([sqrt(a), -sqrt(a)]) + + # Step 5 : Find poles and valuation at infinity + poles = roots(den, x) + poles, muls = list(poles.keys()), list(poles.values()) + val_inf = val_at_inf(num, den, x) + + if len(poles): + # Check necessary conditions (outlined in the module docstring) + if not check_necessary_conds(val_inf, muls): + raise ValueError("Rational Solution doesn't exist") + + # Step 6 + # Construct c-vectors for each singular point + c = construct_c(num, den, x, poles, muls) + + # Construct d vectors for each singular point + d = construct_d(num, den, x, val_inf) + + # Step 7 : Iterate over all possible combinations and return solutions + # For each possible combination, generate an array of 0's and 1's + # where 0 means pick 1st choice and 1 means pick the second choice. + + # NOTE: We could exit from the loop if we find 3 particular solutions, + # but it is not implemented here as - + # a. Finding 3 particular solutions is very rare. Most of the time, + # only 2 particular solutions are found. + # b. In case we exit after finding 3 particular solutions, it might + # happen that 1 or 2 of them are redundant solutions. So, instead of + # spending some more time in computing the particular solutions, + # we will end up computing the general solution from a single + # particular solution which is usually slower than computing the + # general solution from 2 or 3 particular solutions. + c.append(d) + choices = product(*c) + for choice in choices: + m, ybar = compute_m_ybar(x, poles, choice, -val_inf//2) + numy, deny = [Poly(e, x, extension=True) for e in ybar.together().as_numer_denom()] + # Step 10 : Check if a valid solution exists. If yes, also check + # if m is a non-negative integer + if m.is_nonnegative == True and m.is_integer == True: + + # Step 11 : Find polynomial solutions of degree m for the auxiliary equation + psol, coeffs, exists = solve_aux_eq(num, den, numy, deny, x, m) + + # Step 12 : If valid polynomial solution exists, append solution. + if exists: + # m == 0 case + if psol == 1 and coeffs == 0: + # p(x) = 1, so p'(x)/p(x) term need not be added + presol.append(ybar) + # m is a positive integer and there are valid coefficients + elif len(coeffs): + # Substitute the valid coefficients to get p(x) + psol = psol.xreplace(coeffs) + # y(x) = ybar(x) + p'(x)/p(x) + presol.append(ybar + psol.diff(x)/psol) + + # Remove redundant solutions from the list of existing solutions + remove = set() + for i in range(len(presol)): + for j in range(i+1, len(presol)): + rem = remove_redundant_sols(presol[i], presol[j], x) + if rem is not None: + remove.add(rem) + sols = [x for x in presol if x not in remove] + + # Step 15 : Inverse transform the solutions of the equation in normal form + bp = -b2.diff(x)/(2*b2**2) - b1/(2*b2) + + # If general solution is required, compute it from the particular solutions + if gensol: + sols = [get_gen_sol_from_part_sol(sols, a, x)] + + # Inverse transform the particular solutions + presol = [Eq(fx, riccati_inverse_normal(y, x, b1, b2, bp).cancel(extension=True)) for y in sols] + return presol diff --git a/lib/python3.10/site-packages/sympy/solvers/ode/single.py b/lib/python3.10/site-packages/sympy/solvers/ode/single.py new file mode 100644 index 0000000000000000000000000000000000000000..7d46931122f11a1592097e6a7117192d39bae10e --- /dev/null +++ b/lib/python3.10/site-packages/sympy/solvers/ode/single.py @@ -0,0 +1,2979 @@ +# +# This is the module for ODE solver classes for single ODEs. +# + +from __future__ import annotations +from typing import ClassVar, Iterator + +from .riccati import match_riccati, solve_riccati +from sympy.core import Add, S, Pow, Rational +from sympy.core.cache import cached_property +from sympy.core.exprtools import factor_terms +from sympy.core.expr import Expr +from sympy.core.function import AppliedUndef, Derivative, diff, Function, expand, Subs, _mexpand +from sympy.core.numbers import zoo +from sympy.core.relational import Equality, Eq +from sympy.core.symbol import Symbol, Dummy, Wild +from sympy.core.mul import Mul +from sympy.functions import exp, tan, log, sqrt, besselj, bessely, cbrt, airyai, airybi +from sympy.integrals import Integral +from sympy.polys import Poly +from sympy.polys.polytools import cancel, factor, degree +from sympy.simplify import collect, simplify, separatevars, logcombine, posify # type: ignore +from sympy.simplify.radsimp import fraction +from sympy.utilities import numbered_symbols +from sympy.solvers.solvers import solve +from sympy.solvers.deutils import ode_order, _preprocess +from sympy.polys.matrices.linsolve import _lin_eq2dict +from sympy.polys.solvers import PolyNonlinearError +from .hypergeometric import equivalence_hypergeometric, match_2nd_2F1_hypergeometric, \ + get_sol_2F1_hypergeometric, match_2nd_hypergeometric +from .nonhomogeneous import _get_euler_characteristic_eq_sols, _get_const_characteristic_eq_sols, \ + _solve_undetermined_coefficients, _solve_variation_of_parameters, _test_term, _undetermined_coefficients_match, \ + _get_simplified_sol +from .lie_group import _ode_lie_group + + +class ODEMatchError(NotImplementedError): + """Raised if a SingleODESolver is asked to solve an ODE it does not match""" + pass + + +class SingleODEProblem: + """Represents an ordinary differential equation (ODE) + + This class is used internally in the by dsolve and related + functions/classes so that properties of an ODE can be computed + efficiently. + + Examples + ======== + + This class is used internally by dsolve. To instantiate an instance + directly first define an ODE problem: + + >>> from sympy import Function, Symbol + >>> x = Symbol('x') + >>> f = Function('f') + >>> eq = f(x).diff(x, 2) + + Now you can create a SingleODEProblem instance and query its properties: + + >>> from sympy.solvers.ode.single import SingleODEProblem + >>> problem = SingleODEProblem(f(x).diff(x), f(x), x) + >>> problem.eq + Derivative(f(x), x) + >>> problem.func + f(x) + >>> problem.sym + x + """ + + # Instance attributes: + eq = None # type: Expr + func = None # type: AppliedUndef + sym = None # type: Symbol + _order = None # type: int + _eq_expanded = None # type: Expr + _eq_preprocessed = None # type: Expr + _eq_high_order_free = None + + def __init__(self, eq, func, sym, prep=True, **kwargs): + assert isinstance(eq, Expr) + assert isinstance(func, AppliedUndef) + assert isinstance(sym, Symbol) + assert isinstance(prep, bool) + self.eq = eq + self.func = func + self.sym = sym + self.prep = prep + self.params = kwargs + + @cached_property + def order(self) -> int: + return ode_order(self.eq, self.func) + + @cached_property + def eq_preprocessed(self) -> Expr: + return self._get_eq_preprocessed() + + @cached_property + def eq_high_order_free(self) -> Expr: + a = Wild('a', exclude=[self.func]) + c1 = Wild('c1', exclude=[self.sym]) + # Precondition to try remove f(x) from highest order derivative + reduced_eq = None + if self.eq.is_Add: + deriv_coef = self.eq.coeff(self.func.diff(self.sym, self.order)) + if deriv_coef not in (1, 0): + r = deriv_coef.match(a*self.func**c1) + if r and r[c1]: + den = self.func**r[c1] + reduced_eq = Add(*[arg/den for arg in self.eq.args]) + if not reduced_eq: + reduced_eq = expand(self.eq) + return reduced_eq + + @cached_property + def eq_expanded(self) -> Expr: + return expand(self.eq_preprocessed) + + def _get_eq_preprocessed(self) -> Expr: + if self.prep: + process_eq, process_func = _preprocess(self.eq, self.func) + if process_func != self.func: + raise ValueError + else: + process_eq = self.eq + return process_eq + + def get_numbered_constants(self, num=1, start=1, prefix='C') -> list[Symbol]: + """ + Returns a list of constants that do not occur + in eq already. + """ + ncs = self.iter_numbered_constants(start, prefix) + Cs = [next(ncs) for i in range(num)] + return Cs + + def iter_numbered_constants(self, start=1, prefix='C') -> Iterator[Symbol]: + """ + Returns an iterator of constants that do not occur + in eq already. + """ + atom_set = self.eq.free_symbols + func_set = self.eq.atoms(Function) + if func_set: + atom_set |= {Symbol(str(f.func)) for f in func_set} + return numbered_symbols(start=start, prefix=prefix, exclude=atom_set) + + @cached_property + def is_autonomous(self): + u = Dummy('u') + x = self.sym + syms = self.eq.subs(self.func, u).free_symbols + return x not in syms + + def get_linear_coefficients(self, eq, func, order): + r""" + Matches a differential equation to the linear form: + + .. math:: a_n(x) y^{(n)} + \cdots + a_1(x)y' + a_0(x) y + B(x) = 0 + + Returns a dict of order:coeff terms, where order is the order of the + derivative on each term, and coeff is the coefficient of that derivative. + The key ``-1`` holds the function `B(x)`. Returns ``None`` if the ODE is + not linear. This function assumes that ``func`` has already been checked + to be good. + + Examples + ======== + + >>> from sympy import Function, cos, sin + >>> from sympy.abc import x + >>> from sympy.solvers.ode.single import SingleODEProblem + >>> f = Function('f') + >>> eq = f(x).diff(x, 3) + 2*f(x).diff(x) + \ + ... x*f(x).diff(x, 2) + cos(x)*f(x).diff(x) + x - f(x) - \ + ... sin(x) + >>> obj = SingleODEProblem(eq, f(x), x) + >>> obj.get_linear_coefficients(eq, f(x), 3) + {-1: x - sin(x), 0: -1, 1: cos(x) + 2, 2: x, 3: 1} + >>> eq = f(x).diff(x, 3) + 2*f(x).diff(x) + \ + ... x*f(x).diff(x, 2) + cos(x)*f(x).diff(x) + x - f(x) - \ + ... sin(f(x)) + >>> obj = SingleODEProblem(eq, f(x), x) + >>> obj.get_linear_coefficients(eq, f(x), 3) == None + True + + """ + f = func.func + x = func.args[0] + symset = {Derivative(f(x), x, i) for i in range(order+1)} + try: + rhs, lhs_terms = _lin_eq2dict(eq, symset) + except PolyNonlinearError: + return None + + if rhs.has(func) or any(c.has(func) for c in lhs_terms.values()): + return None + terms = {i: lhs_terms.get(f(x).diff(x, i), S.Zero) for i in range(order+1)} + terms[-1] = rhs + return terms + + # TODO: Add methods that can be used by many ODE solvers: + # order + # is_linear() + # get_linear_coefficients() + # eq_prepared (the ODE in prepared form) + + +class SingleODESolver: + """ + Base class for Single ODE solvers. + + Subclasses should implement the _matches and _get_general_solution + methods. This class is not intended to be instantiated directly but its + subclasses are as part of dsolve. + + Examples + ======== + + You can use a subclass of SingleODEProblem to solve a particular type of + ODE. We first define a particular ODE problem: + + >>> from sympy import Function, Symbol + >>> x = Symbol('x') + >>> f = Function('f') + >>> eq = f(x).diff(x, 2) + + Now we solve this problem using the NthAlgebraic solver which is a + subclass of SingleODESolver: + + >>> from sympy.solvers.ode.single import NthAlgebraic, SingleODEProblem + >>> problem = SingleODEProblem(eq, f(x), x) + >>> solver = NthAlgebraic(problem) + >>> solver.get_general_solution() + [Eq(f(x), _C*x + _C)] + + The normal way to solve an ODE is to use dsolve (which would use + NthAlgebraic and other solvers internally). When using dsolve a number of + other things are done such as evaluating integrals, simplifying the + solution and renumbering the constants: + + >>> from sympy import dsolve + >>> dsolve(eq, hint='nth_algebraic') + Eq(f(x), C1 + C2*x) + """ + + # Subclasses should store the hint name (the argument to dsolve) in this + # attribute + hint: ClassVar[str] + + # Subclasses should define this to indicate if they support an _Integral + # hint. + has_integral: ClassVar[bool] + + # The ODE to be solved + ode_problem = None # type: SingleODEProblem + + # Cache whether or not the equation has matched the method + _matched: bool | None = None + + # Subclasses should store in this attribute the list of order(s) of ODE + # that subclass can solve or leave it to None if not specific to any order + order: list | None = None + + def __init__(self, ode_problem): + self.ode_problem = ode_problem + + def matches(self) -> bool: + if self.order is not None and self.ode_problem.order not in self.order: + self._matched = False + return self._matched + + if self._matched is None: + self._matched = self._matches() + return self._matched + + def get_general_solution(self, *, simplify: bool = True) -> list[Equality]: + if not self.matches(): + msg = "%s solver cannot solve:\n%s" + raise ODEMatchError(msg % (self.hint, self.ode_problem.eq)) + return self._get_general_solution(simplify_flag=simplify) + + def _matches(self) -> bool: + msg = "Subclasses of SingleODESolver should implement matches." + raise NotImplementedError(msg) + + def _get_general_solution(self, *, simplify_flag: bool = True) -> list[Equality]: + msg = "Subclasses of SingleODESolver should implement get_general_solution." + raise NotImplementedError(msg) + + +class SinglePatternODESolver(SingleODESolver): + '''Superclass for ODE solvers based on pattern matching''' + + def wilds(self): + prob = self.ode_problem + f = prob.func.func + x = prob.sym + order = prob.order + return self._wilds(f, x, order) + + def wilds_match(self): + match = self._wilds_match + return [match.get(w, S.Zero) for w in self.wilds()] + + def _matches(self): + eq = self.ode_problem.eq_expanded + f = self.ode_problem.func.func + x = self.ode_problem.sym + order = self.ode_problem.order + df = f(x).diff(x, order) + + if order not in [1, 2]: + return False + + pattern = self._equation(f(x), x, order) + + if not pattern.coeff(df).has(Wild): + eq = expand(eq / eq.coeff(df)) + eq = eq.collect([f(x).diff(x), f(x)], func = cancel) + + self._wilds_match = match = eq.match(pattern) + if match is not None: + return self._verify(f(x)) + return False + + def _verify(self, fx) -> bool: + return True + + def _wilds(self, f, x, order): + msg = "Subclasses of SingleODESolver should implement _wilds" + raise NotImplementedError(msg) + + def _equation(self, fx, x, order): + msg = "Subclasses of SingleODESolver should implement _equation" + raise NotImplementedError(msg) + + +class NthAlgebraic(SingleODESolver): + r""" + Solves an `n`\th order ordinary differential equation using algebra and + integrals. + + There is no general form for the kind of equation that this can solve. The + the equation is solved algebraically treating differentiation as an + invertible algebraic function. + + Examples + ======== + + >>> from sympy import Function, dsolve, Eq + >>> from sympy.abc import x + >>> f = Function('f') + >>> eq = Eq(f(x) * (f(x).diff(x)**2 - 1), 0) + >>> dsolve(eq, f(x), hint='nth_algebraic') + [Eq(f(x), 0), Eq(f(x), C1 - x), Eq(f(x), C1 + x)] + + Note that this solver can return algebraic solutions that do not have any + integration constants (f(x) = 0 in the above example). + """ + + hint = 'nth_algebraic' + has_integral = True # nth_algebraic_Integral hint + + def _matches(self): + r""" + Matches any differential equation that nth_algebraic can solve. Uses + `sympy.solve` but teaches it how to integrate derivatives. + + This involves calling `sympy.solve` and does most of the work of finding a + solution (apart from evaluating the integrals). + """ + eq = self.ode_problem.eq + func = self.ode_problem.func + var = self.ode_problem.sym + + # Derivative that solve can handle: + diffx = self._get_diffx(var) + + # Replace derivatives wrt the independent variable with diffx + def replace(eq, var): + def expand_diffx(*args): + differand, diffs = args[0], args[1:] + toreplace = differand + for v, n in diffs: + for _ in range(n): + if v == var: + toreplace = diffx(toreplace) + else: + toreplace = Derivative(toreplace, v) + return toreplace + return eq.replace(Derivative, expand_diffx) + + # Restore derivatives in solution afterwards + def unreplace(eq, var): + return eq.replace(diffx, lambda e: Derivative(e, var)) + + subs_eqn = replace(eq, var) + try: + # turn off simplification to protect Integrals that have + # _t instead of fx in them and would otherwise factor + # as t_*Integral(1, x) + solns = solve(subs_eqn, func, simplify=False) + except NotImplementedError: + solns = [] + + solns = [simplify(unreplace(soln, var)) for soln in solns] + solns = [Equality(func, soln) for soln in solns] + + self.solutions = solns + return len(solns) != 0 + + def _get_general_solution(self, *, simplify_flag: bool = True): + return self.solutions + + # This needs to produce an invertible function but the inverse depends + # which variable we are integrating with respect to. Since the class can + # be stored in cached results we need to ensure that we always get the + # same class back for each particular integration variable so we store these + # classes in a global dict: + _diffx_stored: dict[Symbol, type[Function]] = {} + + @staticmethod + def _get_diffx(var): + diffcls = NthAlgebraic._diffx_stored.get(var, None) + + if diffcls is None: + # A class that behaves like Derivative wrt var but is "invertible". + class diffx(Function): + def inverse(self): + # don't use integrate here because fx has been replaced by _t + # in the equation; integrals will not be correct while solve + # is at work. + return lambda expr: Integral(expr, var) + Dummy('C') + + diffcls = NthAlgebraic._diffx_stored.setdefault(var, diffx) + + return diffcls + + +class FirstExact(SinglePatternODESolver): + r""" + Solves 1st order exact ordinary differential equations. + + A 1st order differential equation is called exact if it is the total + differential of a function. That is, the differential equation + + .. math:: P(x, y) \,\partial{}x + Q(x, y) \,\partial{}y = 0 + + is exact if there is some function `F(x, y)` such that `P(x, y) = + \partial{}F/\partial{}x` and `Q(x, y) = \partial{}F/\partial{}y`. It can + be shown that a necessary and sufficient condition for a first order ODE + to be exact is that `\partial{}P/\partial{}y = \partial{}Q/\partial{}x`. + Then, the solution will be as given below:: + + >>> from sympy import Function, Eq, Integral, symbols, pprint + >>> x, y, t, x0, y0, C1= symbols('x,y,t,x0,y0,C1') + >>> P, Q, F= map(Function, ['P', 'Q', 'F']) + >>> pprint(Eq(Eq(F(x, y), Integral(P(t, y), (t, x0, x)) + + ... Integral(Q(x0, t), (t, y0, y))), C1)) + x y + / / + | | + F(x, y) = | P(t, y) dt + | Q(x0, t) dt = C1 + | | + / / + x0 y0 + + Where the first partials of `P` and `Q` exist and are continuous in a + simply connected region. + + A note: SymPy currently has no way to represent inert substitution on an + expression, so the hint ``1st_exact_Integral`` will return an integral + with `dy`. This is supposed to represent the function that you are + solving for. + + Examples + ======== + + >>> from sympy import Function, dsolve, cos, sin + >>> from sympy.abc import x + >>> f = Function('f') + >>> dsolve(cos(f(x)) - (x*sin(f(x)) - f(x)**2)*f(x).diff(x), + ... f(x), hint='1st_exact') + Eq(x*cos(f(x)) + f(x)**3/3, C1) + + References + ========== + + - https://en.wikipedia.org/wiki/Exact_differential_equation + - M. Tenenbaum & H. Pollard, "Ordinary Differential Equations", + Dover 1963, pp. 73 + + # indirect doctest + + """ + hint = "1st_exact" + has_integral = True + order = [1] + + def _wilds(self, f, x, order): + P = Wild('P', exclude=[f(x).diff(x)]) + Q = Wild('Q', exclude=[f(x).diff(x)]) + return P, Q + + def _equation(self, fx, x, order): + P, Q = self.wilds() + return P + Q*fx.diff(x) + + def _verify(self, fx) -> bool: + P, Q = self.wilds() + x = self.ode_problem.sym + y = Dummy('y') + + m, n = self.wilds_match() + + m = m.subs(fx, y) + n = n.subs(fx, y) + numerator = cancel(m.diff(y) - n.diff(x)) + + if numerator.is_zero: + # Is exact + return True + else: + # The following few conditions try to convert a non-exact + # differential equation into an exact one. + # References: + # 1. Differential equations with applications + # and historical notes - George E. Simmons + # 2. https://math.okstate.edu/people/binegar/2233-S99/2233-l12.pdf + + factor_n = cancel(numerator/n) + factor_m = cancel(-numerator/m) + if y not in factor_n.free_symbols: + # If (dP/dy - dQ/dx) / Q = f(x) + # then exp(integral(f(x))*equation becomes exact + factor = factor_n + integration_variable = x + elif x not in factor_m.free_symbols: + # If (dP/dy - dQ/dx) / -P = f(y) + # then exp(integral(f(y))*equation becomes exact + factor = factor_m + integration_variable = y + else: + # Couldn't convert to exact + return False + + factor = exp(Integral(factor, integration_variable)) + m *= factor + n *= factor + self._wilds_match[P] = m.subs(y, fx) + self._wilds_match[Q] = n.subs(y, fx) + return True + + def _get_general_solution(self, *, simplify_flag: bool = True): + m, n = self.wilds_match() + fx = self.ode_problem.func + x = self.ode_problem.sym + (C1,) = self.ode_problem.get_numbered_constants(num=1) + y = Dummy('y') + + m = m.subs(fx, y) + n = n.subs(fx, y) + + gen_sol = Eq(Subs(Integral(m, x) + + Integral(n - Integral(m, x).diff(y), y), y, fx), C1) + return [gen_sol] + + +class FirstLinear(SinglePatternODESolver): + r""" + Solves 1st order linear differential equations. + + These are differential equations of the form + + .. math:: dy/dx + P(x) y = Q(x)\text{.} + + These kinds of differential equations can be solved in a general way. The + integrating factor `e^{\int P(x) \,dx}` will turn the equation into a + separable equation. The general solution is:: + + >>> from sympy import Function, dsolve, Eq, pprint, diff, sin + >>> from sympy.abc import x + >>> f, P, Q = map(Function, ['f', 'P', 'Q']) + >>> genform = Eq(f(x).diff(x) + P(x)*f(x), Q(x)) + >>> pprint(genform) + d + P(x)*f(x) + --(f(x)) = Q(x) + dx + >>> pprint(dsolve(genform, f(x), hint='1st_linear_Integral')) + / / \ + | | | + | | / | / + | | | | | + | | | P(x) dx | - | P(x) dx + | | | | | + | | / | / + f(x) = |C1 + | Q(x)*e dx|*e + | | | + \ / / + + + Examples + ======== + + >>> f = Function('f') + >>> pprint(dsolve(Eq(x*diff(f(x), x) - f(x), x**2*sin(x)), + ... f(x), '1st_linear')) + f(x) = x*(C1 - cos(x)) + + References + ========== + + - https://en.wikipedia.org/wiki/Linear_differential_equation#First-order_equation_with_variable_coefficients + - M. Tenenbaum & H. Pollard, "Ordinary Differential Equations", + Dover 1963, pp. 92 + + # indirect doctest + + """ + hint = '1st_linear' + has_integral = True + order = [1] + + def _wilds(self, f, x, order): + P = Wild('P', exclude=[f(x)]) + Q = Wild('Q', exclude=[f(x), f(x).diff(x)]) + return P, Q + + def _equation(self, fx, x, order): + P, Q = self.wilds() + return fx.diff(x) + P*fx - Q + + def _get_general_solution(self, *, simplify_flag: bool = True): + P, Q = self.wilds_match() + fx = self.ode_problem.func + x = self.ode_problem.sym + (C1,) = self.ode_problem.get_numbered_constants(num=1) + gensol = Eq(fx, ((C1 + Integral(Q*exp(Integral(P, x)), x)) + * exp(-Integral(P, x)))) + return [gensol] + + +class AlmostLinear(SinglePatternODESolver): + r""" + Solves an almost-linear differential equation. + + The general form of an almost linear differential equation is + + .. math:: a(x) g'(f(x)) f'(x) + b(x) g(f(x)) + c(x) + + Here `f(x)` is the function to be solved for (the dependent variable). + The substitution `g(f(x)) = u(x)` leads to a linear differential equation + for `u(x)` of the form `a(x) u' + b(x) u + c(x) = 0`. This can be solved + for `u(x)` by the `first_linear` hint and then `f(x)` is found by solving + `g(f(x)) = u(x)`. + + See Also + ======== + :obj:`sympy.solvers.ode.single.FirstLinear` + + Examples + ======== + + >>> from sympy import dsolve, Function, pprint, sin, cos + >>> from sympy.abc import x + >>> f = Function('f') + >>> d = f(x).diff(x) + >>> eq = x*d + x*f(x) + 1 + >>> dsolve(eq, f(x), hint='almost_linear') + Eq(f(x), (C1 - Ei(x))*exp(-x)) + >>> pprint(dsolve(eq, f(x), hint='almost_linear')) + -x + f(x) = (C1 - Ei(x))*e + >>> example = cos(f(x))*f(x).diff(x) + sin(f(x)) + 1 + >>> pprint(example) + d + sin(f(x)) + cos(f(x))*--(f(x)) + 1 + dx + >>> pprint(dsolve(example, f(x), hint='almost_linear')) + / -x \ / -x \ + [f(x) = pi - asin\C1*e - 1/, f(x) = asin\C1*e - 1/] + + + References + ========== + + - Joel Moses, "Symbolic Integration - The Stormy Decade", Communications + of the ACM, Volume 14, Number 8, August 1971, pp. 558 + """ + hint = "almost_linear" + has_integral = True + order = [1] + + def _wilds(self, f, x, order): + P = Wild('P', exclude=[f(x).diff(x)]) + Q = Wild('Q', exclude=[f(x).diff(x)]) + return P, Q + + def _equation(self, fx, x, order): + P, Q = self.wilds() + return P*fx.diff(x) + Q + + def _verify(self, fx): + a, b = self.wilds_match() + c, b = b.as_independent(fx) if b.is_Add else (S.Zero, b) + # a, b and c are the function a(x), b(x) and c(x) respectively. + # c(x) is obtained by separating out b as terms with and without fx i.e, l(y) + # The following conditions checks if the given equation is an almost-linear differential equation using the fact that + # a(x)*(l(y))' / l(y)' is independent of l(y) + + if b.diff(fx) != 0 and not simplify(b.diff(fx)/a).has(fx): + self.ly = factor_terms(b).as_independent(fx, as_Add=False)[1] # Gives the term containing fx i.e., l(y) + self.ax = a / self.ly.diff(fx) + self.cx = -c # cx is taken as -c(x) to simplify expression in the solution integral + self.bx = factor_terms(b) / self.ly + return True + + return False + + def _get_general_solution(self, *, simplify_flag: bool = True): + x = self.ode_problem.sym + (C1,) = self.ode_problem.get_numbered_constants(num=1) + gensol = Eq(self.ly, ((C1 + Integral((self.cx/self.ax)*exp(Integral(self.bx/self.ax, x)), x)) + * exp(-Integral(self.bx/self.ax, x)))) + + return [gensol] + + +class Bernoulli(SinglePatternODESolver): + r""" + Solves Bernoulli differential equations. + + These are equations of the form + + .. math:: dy/dx + P(x) y = Q(x) y^n\text{, }n \ne 1`\text{.} + + The substitution `w = 1/y^{1-n}` will transform an equation of this form + into one that is linear (see the docstring of + :obj:`~sympy.solvers.ode.single.FirstLinear`). The general solution is:: + + >>> from sympy import Function, dsolve, Eq, pprint + >>> from sympy.abc import x, n + >>> f, P, Q = map(Function, ['f', 'P', 'Q']) + >>> genform = Eq(f(x).diff(x) + P(x)*f(x), Q(x)*f(x)**n) + >>> pprint(genform) + d n + P(x)*f(x) + --(f(x)) = Q(x)*f (x) + dx + >>> pprint(dsolve(genform, f(x), hint='Bernoulli_Integral'), num_columns=110) + -1 + ----- + n - 1 + // / / \ \ + || | | | | + || | / | / | / | + || | | | | | | | + || | -(n - 1)* | P(x) dx | -(n - 1)* | P(x) dx | (n - 1)* | P(x) dx| + || | | | | | | | + || | / | / | / | + f(x) = ||C1 - n* | Q(x)*e dx + | Q(x)*e dx|*e | + || | | | | + \\ / / / / + + + Note that the equation is separable when `n = 1` (see the docstring of + :obj:`~sympy.solvers.ode.single.Separable`). + + >>> pprint(dsolve(Eq(f(x).diff(x) + P(x)*f(x), Q(x)*f(x)), f(x), + ... hint='separable_Integral')) + f(x) + / + | / + | 1 | + | - dy = C1 + | (-P(x) + Q(x)) dx + | y | + | / + / + + + Examples + ======== + + >>> from sympy import Function, dsolve, Eq, pprint, log + >>> from sympy.abc import x + >>> f = Function('f') + + >>> pprint(dsolve(Eq(x*f(x).diff(x) + f(x), log(x)*f(x)**2), + ... f(x), hint='Bernoulli')) + 1 + f(x) = ----------------- + C1*x + log(x) + 1 + + References + ========== + + - https://en.wikipedia.org/wiki/Bernoulli_differential_equation + + - M. Tenenbaum & H. Pollard, "Ordinary Differential Equations", + Dover 1963, pp. 95 + + # indirect doctest + + """ + hint = "Bernoulli" + has_integral = True + order = [1] + + def _wilds(self, f, x, order): + P = Wild('P', exclude=[f(x)]) + Q = Wild('Q', exclude=[f(x)]) + n = Wild('n', exclude=[x, f(x), f(x).diff(x)]) + return P, Q, n + + def _equation(self, fx, x, order): + P, Q, n = self.wilds() + return fx.diff(x) + P*fx - Q*fx**n + + def _get_general_solution(self, *, simplify_flag: bool = True): + P, Q, n = self.wilds_match() + fx = self.ode_problem.func + x = self.ode_problem.sym + (C1,) = self.ode_problem.get_numbered_constants(num=1) + if n==1: + gensol = Eq(log(fx), ( + C1 + Integral((-P + Q), x) + )) + else: + gensol = Eq(fx**(1-n), ( + (C1 - (n - 1) * Integral(Q*exp(-n*Integral(P, x)) + * exp(Integral(P, x)), x) + ) * exp(-(1 - n)*Integral(P, x))) + ) + return [gensol] + + +class Factorable(SingleODESolver): + r""" + Solves equations having a solvable factor. + + This function is used to solve the equation having factors. Factors may be of type algebraic or ode. It + will try to solve each factor independently. Factors will be solved by calling dsolve. We will return the + list of solutions. + + Examples + ======== + + >>> from sympy import Function, dsolve, pprint + >>> from sympy.abc import x + >>> f = Function('f') + >>> eq = (f(x)**2-4)*(f(x).diff(x)+f(x)) + >>> pprint(dsolve(eq, f(x))) + -x + [f(x) = 2, f(x) = -2, f(x) = C1*e ] + + + """ + hint = "factorable" + has_integral = False + + def _matches(self): + eq_orig = self.ode_problem.eq + f = self.ode_problem.func.func + x = self.ode_problem.sym + df = f(x).diff(x) + self.eqs = [] + eq = eq_orig.collect(f(x), func = cancel) + eq = fraction(factor(eq))[0] + factors = Mul.make_args(factor(eq)) + roots = [fac.as_base_exp() for fac in factors if len(fac.args)!=0] + if len(roots)>1 or roots[0][1]>1: + for base, expo in roots: + if base.has(f(x)): + self.eqs.append(base) + if len(self.eqs)>0: + return True + roots = solve(eq, df) + if len(roots)>0: + self.eqs = [(df - root) for root in roots] + # Avoid infinite recursion + matches = self.eqs != [eq_orig] + return matches + for i in factors: + if i.has(f(x)): + self.eqs.append(i) + return len(self.eqs)>0 and len(factors)>1 + + def _get_general_solution(self, *, simplify_flag: bool = True): + func = self.ode_problem.func.func + x = self.ode_problem.sym + eqns = self.eqs + sols = [] + for eq in eqns: + try: + sol = dsolve(eq, func(x)) + except NotImplementedError: + continue + else: + if isinstance(sol, list): + sols.extend(sol) + else: + sols.append(sol) + + if sols == []: + raise NotImplementedError("The given ODE " + str(eq) + " cannot be solved by" + + " the factorable group method") + return sols + + +class RiccatiSpecial(SinglePatternODESolver): + r""" + The general Riccati equation has the form + + .. math:: dy/dx = f(x) y^2 + g(x) y + h(x)\text{.} + + While it does not have a general solution [1], the "special" form, `dy/dx + = a y^2 - b x^c`, does have solutions in many cases [2]. This routine + returns a solution for `a(dy/dx) = b y^2 + c y/x + d/x^2` that is obtained + by using a suitable change of variables to reduce it to the special form + and is valid when neither `a` nor `b` are zero and either `c` or `d` is + zero. + + >>> from sympy.abc import x, a, b, c, d + >>> from sympy import dsolve, checkodesol, pprint, Function + >>> f = Function('f') + >>> y = f(x) + >>> genform = a*y.diff(x) - (b*y**2 + c*y/x + d/x**2) + >>> sol = dsolve(genform, y, hint="Riccati_special_minus2") + >>> pprint(sol, wrap_line=False) + / / __________________ \\ + | __________________ | / 2 || + | / 2 | \/ 4*b*d - (a + c) *log(x)|| + -|a + c - \/ 4*b*d - (a + c) *tan|C1 + ----------------------------|| + \ \ 2*a // + f(x) = ------------------------------------------------------------------------ + 2*b*x + + >>> checkodesol(genform, sol, order=1)[0] + True + + References + ========== + + - https://www.maplesoft.com/support/help/Maple/view.aspx?path=odeadvisor/Riccati + - https://eqworld.ipmnet.ru/en/solutions/ode/ode0106.pdf - + https://eqworld.ipmnet.ru/en/solutions/ode/ode0123.pdf + """ + hint = "Riccati_special_minus2" + has_integral = False + order = [1] + + def _wilds(self, f, x, order): + a = Wild('a', exclude=[x, f(x), f(x).diff(x), 0]) + b = Wild('b', exclude=[x, f(x), f(x).diff(x), 0]) + c = Wild('c', exclude=[x, f(x), f(x).diff(x)]) + d = Wild('d', exclude=[x, f(x), f(x).diff(x)]) + return a, b, c, d + + def _equation(self, fx, x, order): + a, b, c, d = self.wilds() + return a*fx.diff(x) + b*fx**2 + c*fx/x + d/x**2 + + def _get_general_solution(self, *, simplify_flag: bool = True): + a, b, c, d = self.wilds_match() + fx = self.ode_problem.func + x = self.ode_problem.sym + (C1,) = self.ode_problem.get_numbered_constants(num=1) + mu = sqrt(4*d*b - (a - c)**2) + + gensol = Eq(fx, (a - c - mu*tan(mu/(2*a)*log(x) + C1))/(2*b*x)) + return [gensol] + + +class RationalRiccati(SinglePatternODESolver): + r""" + Gives general solutions to the first order Riccati differential + equations that have atleast one rational particular solution. + + .. math :: y' = b_0(x) + b_1(x) y + b_2(x) y^2 + + where `b_0`, `b_1` and `b_2` are rational functions of `x` + with `b_2 \ne 0` (`b_2 = 0` would make it a Bernoulli equation). + + Examples + ======== + + >>> from sympy import Symbol, Function, dsolve, checkodesol + >>> f = Function('f') + >>> x = Symbol('x') + + >>> eq = -x**4*f(x)**2 + x**3*f(x).diff(x) + x**2*f(x) + 20 + >>> sol = dsolve(eq, hint="1st_rational_riccati") + >>> sol + Eq(f(x), (4*C1 - 5*x**9 - 4)/(x**2*(C1 + x**9 - 1))) + >>> checkodesol(eq, sol) + (True, 0) + + References + ========== + + - Riccati ODE: https://en.wikipedia.org/wiki/Riccati_equation + - N. Thieu Vo - Rational and Algebraic Solutions of First-Order Algebraic ODEs: + Algorithm 11, pp. 78 - https://www3.risc.jku.at/publications/download/risc_5387/PhDThesisThieu.pdf + """ + has_integral = False + hint = "1st_rational_riccati" + order = [1] + + def _wilds(self, f, x, order): + b0 = Wild('b0', exclude=[f(x), f(x).diff(x)]) + b1 = Wild('b1', exclude=[f(x), f(x).diff(x)]) + b2 = Wild('b2', exclude=[f(x), f(x).diff(x)]) + return (b0, b1, b2) + + def _equation(self, fx, x, order): + b0, b1, b2 = self.wilds() + return fx.diff(x) - b0 - b1*fx - b2*fx**2 + + def _matches(self): + eq = self.ode_problem.eq_expanded + f = self.ode_problem.func.func + x = self.ode_problem.sym + order = self.ode_problem.order + + if order != 1: + return False + + match, funcs = match_riccati(eq, f, x) + if not match: + return False + _b0, _b1, _b2 = funcs + b0, b1, b2 = self.wilds() + self._wilds_match = match = {b0: _b0, b1: _b1, b2: _b2} + return True + + def _get_general_solution(self, *, simplify_flag: bool = True): + # Match the equation + b0, b1, b2 = self.wilds_match() + fx = self.ode_problem.func + x = self.ode_problem.sym + return solve_riccati(fx, x, b0, b1, b2, gensol=True) + + +class SecondNonlinearAutonomousConserved(SinglePatternODESolver): + r""" + Gives solution for the autonomous second order nonlinear + differential equation of the form + + .. math :: f''(x) = g(f(x)) + + The solution for this differential equation can be computed + by multiplying by `f'(x)` and integrating on both sides, + converting it into a first order differential equation. + + Examples + ======== + + >>> from sympy import Function, symbols, dsolve + >>> f, g = symbols('f g', cls=Function) + >>> x = symbols('x') + + >>> eq = f(x).diff(x, 2) - g(f(x)) + >>> dsolve(eq, simplify=False) + [Eq(Integral(1/sqrt(C1 + 2*Integral(g(_u), _u)), (_u, f(x))), C2 + x), + Eq(Integral(1/sqrt(C1 + 2*Integral(g(_u), _u)), (_u, f(x))), C2 - x)] + + >>> from sympy import exp, log + >>> eq = f(x).diff(x, 2) - exp(f(x)) + log(f(x)) + >>> dsolve(eq, simplify=False) + [Eq(Integral(1/sqrt(-2*_u*log(_u) + 2*_u + C1 + 2*exp(_u)), (_u, f(x))), C2 + x), + Eq(Integral(1/sqrt(-2*_u*log(_u) + 2*_u + C1 + 2*exp(_u)), (_u, f(x))), C2 - x)] + + References + ========== + + - https://eqworld.ipmnet.ru/en/solutions/ode/ode0301.pdf + """ + hint = "2nd_nonlinear_autonomous_conserved" + has_integral = True + order = [2] + + def _wilds(self, f, x, order): + fy = Wild('fy', exclude=[0, f(x).diff(x), f(x).diff(x, 2)]) + return (fy, ) + + def _equation(self, fx, x, order): + fy = self.wilds()[0] + return fx.diff(x, 2) + fy + + def _verify(self, fx): + return self.ode_problem.is_autonomous + + def _get_general_solution(self, *, simplify_flag: bool = True): + g = self.wilds_match()[0] + fx = self.ode_problem.func + x = self.ode_problem.sym + u = Dummy('u') + g = g.subs(fx, u) + C1, C2 = self.ode_problem.get_numbered_constants(num=2) + inside = -2*Integral(g, u) + C1 + lhs = Integral(1/sqrt(inside), (u, fx)) + return [Eq(lhs, C2 + x), Eq(lhs, C2 - x)] + + +class Liouville(SinglePatternODESolver): + r""" + Solves 2nd order Liouville differential equations. + + The general form of a Liouville ODE is + + .. math:: \frac{d^2 y}{dx^2} + g(y) \left(\! + \frac{dy}{dx}\!\right)^2 + h(x) + \frac{dy}{dx}\text{.} + + The general solution is: + + >>> from sympy import Function, dsolve, Eq, pprint, diff + >>> from sympy.abc import x + >>> f, g, h = map(Function, ['f', 'g', 'h']) + >>> genform = Eq(diff(f(x),x,x) + g(f(x))*diff(f(x),x)**2 + + ... h(x)*diff(f(x),x), 0) + >>> pprint(genform) + 2 2 + /d \ d d + g(f(x))*|--(f(x))| + h(x)*--(f(x)) + ---(f(x)) = 0 + \dx / dx 2 + dx + >>> pprint(dsolve(genform, f(x), hint='Liouville_Integral')) + f(x) + / / + | | + | / | / + | | | | + | - | h(x) dx | | g(y) dy + | | | | + | / | / + C1 + C2* | e dx + | e dy = 0 + | | + / / + + Examples + ======== + + >>> from sympy import Function, dsolve, Eq, pprint + >>> from sympy.abc import x + >>> f = Function('f') + >>> pprint(dsolve(diff(f(x), x, x) + diff(f(x), x)**2/f(x) + + ... diff(f(x), x)/x, f(x), hint='Liouville')) + ________________ ________________ + [f(x) = -\/ C1 + C2*log(x) , f(x) = \/ C1 + C2*log(x) ] + + References + ========== + + - Goldstein and Braun, "Advanced Methods for the Solution of Differential + Equations", pp. 98 + - https://www.maplesoft.com/support/help/Maple/view.aspx?path=odeadvisor/Liouville + + # indirect doctest + + """ + hint = "Liouville" + has_integral = True + order = [2] + + def _wilds(self, f, x, order): + d = Wild('d', exclude=[f(x).diff(x), f(x).diff(x, 2)]) + e = Wild('e', exclude=[f(x).diff(x)]) + k = Wild('k', exclude=[f(x).diff(x)]) + return d, e, k + + def _equation(self, fx, x, order): + # Liouville ODE in the form + # f(x).diff(x, 2) + g(f(x))*(f(x).diff(x))**2 + h(x)*f(x).diff(x) + # See Goldstein and Braun, "Advanced Methods for the Solution of + # Differential Equations", pg. 98 + d, e, k = self.wilds() + return d*fx.diff(x, 2) + e*fx.diff(x)**2 + k*fx.diff(x) + + def _verify(self, fx): + d, e, k = self.wilds_match() + self.y = Dummy('y') + x = self.ode_problem.sym + self.g = simplify(e/d).subs(fx, self.y) + self.h = simplify(k/d).subs(fx, self.y) + if self.y in self.h.free_symbols or x in self.g.free_symbols: + return False + return True + + def _get_general_solution(self, *, simplify_flag: bool = True): + d, e, k = self.wilds_match() + fx = self.ode_problem.func + x = self.ode_problem.sym + C1, C2 = self.ode_problem.get_numbered_constants(num=2) + int = Integral(exp(Integral(self.g, self.y)), (self.y, None, fx)) + gen_sol = Eq(int + C1*Integral(exp(-Integral(self.h, x)), x) + C2, 0) + + return [gen_sol] + + +class Separable(SinglePatternODESolver): + r""" + Solves separable 1st order differential equations. + + This is any differential equation that can be written as `P(y) + \tfrac{dy}{dx} = Q(x)`. The solution can then just be found by + rearranging terms and integrating: `\int P(y) \,dy = \int Q(x) \,dx`. + This hint uses :py:meth:`sympy.simplify.simplify.separatevars` as its back + end, so if a separable equation is not caught by this solver, it is most + likely the fault of that function. + :py:meth:`~sympy.simplify.simplify.separatevars` is + smart enough to do most expansion and factoring necessary to convert a + separable equation `F(x, y)` into the proper form `P(x)\cdot{}Q(y)`. The + general solution is:: + + >>> from sympy import Function, dsolve, Eq, pprint + >>> from sympy.abc import x + >>> a, b, c, d, f = map(Function, ['a', 'b', 'c', 'd', 'f']) + >>> genform = Eq(a(x)*b(f(x))*f(x).diff(x), c(x)*d(f(x))) + >>> pprint(genform) + d + a(x)*b(f(x))*--(f(x)) = c(x)*d(f(x)) + dx + >>> pprint(dsolve(genform, f(x), hint='separable_Integral')) + f(x) + / / + | | + | b(y) | c(x) + | ---- dy = C1 + | ---- dx + | d(y) | a(x) + | | + / / + + Examples + ======== + + >>> from sympy import Function, dsolve, Eq + >>> from sympy.abc import x + >>> f = Function('f') + >>> pprint(dsolve(Eq(f(x)*f(x).diff(x) + x, 3*x*f(x)**2), f(x), + ... hint='separable', simplify=False)) + / 2 \ 2 + log\3*f (x) - 1/ x + ---------------- = C1 + -- + 6 2 + + References + ========== + + - M. Tenenbaum & H. Pollard, "Ordinary Differential Equations", + Dover 1963, pp. 52 + + # indirect doctest + + """ + hint = "separable" + has_integral = True + order = [1] + + def _wilds(self, f, x, order): + d = Wild('d', exclude=[f(x).diff(x), f(x).diff(x, 2)]) + e = Wild('e', exclude=[f(x).diff(x)]) + return d, e + + def _equation(self, fx, x, order): + d, e = self.wilds() + return d + e*fx.diff(x) + + def _verify(self, fx): + d, e = self.wilds_match() + self.y = Dummy('y') + x = self.ode_problem.sym + d = separatevars(d.subs(fx, self.y)) + e = separatevars(e.subs(fx, self.y)) + # m1[coeff]*m1[x]*m1[y] + m2[coeff]*m2[x]*m2[y]*y' + self.m1 = separatevars(d, dict=True, symbols=(x, self.y)) + self.m2 = separatevars(e, dict=True, symbols=(x, self.y)) + if self.m1 and self.m2: + return True + return False + + def _get_match_object(self): + fx = self.ode_problem.func + x = self.ode_problem.sym + return self.m1, self.m2, x, fx + + def _get_general_solution(self, *, simplify_flag: bool = True): + m1, m2, x, fx = self._get_match_object() + (C1,) = self.ode_problem.get_numbered_constants(num=1) + int = Integral(m2['coeff']*m2[self.y]/m1[self.y], + (self.y, None, fx)) + gen_sol = Eq(int, Integral(-m1['coeff']*m1[x]/ + m2[x], x) + C1) + return [gen_sol] + + +class SeparableReduced(Separable): + r""" + Solves a differential equation that can be reduced to the separable form. + + The general form of this equation is + + .. math:: y' + (y/x) H(x^n y) = 0\text{}. + + This can be solved by substituting `u(y) = x^n y`. The equation then + reduces to the separable form `\frac{u'}{u (\mathrm{power} - H(u))} - + \frac{1}{x} = 0`. + + The general solution is: + + >>> from sympy import Function, dsolve, pprint + >>> from sympy.abc import x, n + >>> f, g = map(Function, ['f', 'g']) + >>> genform = f(x).diff(x) + (f(x)/x)*g(x**n*f(x)) + >>> pprint(genform) + / n \ + d f(x)*g\x *f(x)/ + --(f(x)) + --------------- + dx x + >>> pprint(dsolve(genform, hint='separable_reduced')) + n + x *f(x) + / + | + | 1 + | ------------ dy = C1 + log(x) + | y*(n - g(y)) + | + / + + See Also + ======== + :obj:`sympy.solvers.ode.single.Separable` + + Examples + ======== + + >>> from sympy import dsolve, Function, pprint + >>> from sympy.abc import x + >>> f = Function('f') + >>> d = f(x).diff(x) + >>> eq = (x - x**2*f(x))*d - f(x) + >>> dsolve(eq, hint='separable_reduced') + [Eq(f(x), (1 - sqrt(C1*x**2 + 1))/x), Eq(f(x), (sqrt(C1*x**2 + 1) + 1)/x)] + >>> pprint(dsolve(eq, hint='separable_reduced')) + ___________ ___________ + / 2 / 2 + 1 - \/ C1*x + 1 \/ C1*x + 1 + 1 + [f(x) = ------------------, f(x) = ------------------] + x x + + References + ========== + + - Joel Moses, "Symbolic Integration - The Stormy Decade", Communications + of the ACM, Volume 14, Number 8, August 1971, pp. 558 + """ + hint = "separable_reduced" + has_integral = True + order = [1] + + def _degree(self, expr, x): + # Made this function to calculate the degree of + # x in an expression. If expr will be of form + # x**p*y, (wheare p can be variables/rationals) then it + # will return p. + for val in expr: + if val.has(x): + if isinstance(val, Pow) and val.as_base_exp()[0] == x: + return (val.as_base_exp()[1]) + elif val == x: + return (val.as_base_exp()[1]) + else: + return self._degree(val.args, x) + return 0 + + def _powers(self, expr): + # this function will return all the different relative power of x w.r.t f(x). + # expr = x**p * f(x)**q then it will return {p/q}. + pows = set() + fx = self.ode_problem.func + x = self.ode_problem.sym + self.y = Dummy('y') + if isinstance(expr, Add): + exprs = expr.atoms(Add) + elif isinstance(expr, Mul): + exprs = expr.atoms(Mul) + elif isinstance(expr, Pow): + exprs = expr.atoms(Pow) + else: + exprs = {expr} + + for arg in exprs: + if arg.has(x): + _, u = arg.as_independent(x, fx) + pow = self._degree((u.subs(fx, self.y), ), x)/self._degree((u.subs(fx, self.y), ), self.y) + pows.add(pow) + return pows + + def _verify(self, fx): + num, den = self.wilds_match() + x = self.ode_problem.sym + factor = simplify(x/fx*num/den) + # Try representing factor in terms of x^n*y + # where n is lowest power of x in factor; + # first remove terms like sqrt(2)*3 from factor.atoms(Mul) + num, dem = factor.as_numer_denom() + num = expand(num) + dem = expand(dem) + pows = self._powers(num) + pows.update(self._powers(dem)) + pows = list(pows) + if(len(pows)==1) and pows[0]!=zoo: + self.t = Dummy('t') + self.r2 = {'t': self.t} + num = num.subs(x**pows[0]*fx, self.t) + dem = dem.subs(x**pows[0]*fx, self.t) + test = num/dem + free = test.free_symbols + if len(free) == 1 and free.pop() == self.t: + self.r2.update({'power' : pows[0], 'u' : test}) + return True + return False + return False + + def _get_match_object(self): + fx = self.ode_problem.func + x = self.ode_problem.sym + u = self.r2['u'].subs(self.r2['t'], self.y) + ycoeff = 1/(self.y*(self.r2['power'] - u)) + m1 = {self.y: 1, x: -1/x, 'coeff': 1} + m2 = {self.y: ycoeff, x: 1, 'coeff': 1} + return m1, m2, x, x**self.r2['power']*fx + + +class HomogeneousCoeffSubsDepDivIndep(SinglePatternODESolver): + r""" + Solves a 1st order differential equation with homogeneous coefficients + using the substitution `u_1 = \frac{\text{}}{\text{}}`. + + This is a differential equation + + .. math:: P(x, y) + Q(x, y) dy/dx = 0 + + such that `P` and `Q` are homogeneous and of the same order. A function + `F(x, y)` is homogeneous of order `n` if `F(x t, y t) = t^n F(x, y)`. + Equivalently, `F(x, y)` can be rewritten as `G(y/x)` or `H(x/y)`. See + also the docstring of :py:meth:`~sympy.solvers.ode.homogeneous_order`. + + If the coefficients `P` and `Q` in the differential equation above are + homogeneous functions of the same order, then it can be shown that the + substitution `y = u_1 x` (i.e. `u_1 = y/x`) will turn the differential + equation into an equation separable in the variables `x` and `u`. If + `h(u_1)` is the function that results from making the substitution `u_1 = + f(x)/x` on `P(x, f(x))` and `g(u_2)` is the function that results from the + substitution on `Q(x, f(x))` in the differential equation `P(x, f(x)) + + Q(x, f(x)) f'(x) = 0`, then the general solution is:: + + >>> from sympy import Function, dsolve, pprint + >>> from sympy.abc import x + >>> f, g, h = map(Function, ['f', 'g', 'h']) + >>> genform = g(f(x)/x) + h(f(x)/x)*f(x).diff(x) + >>> pprint(genform) + /f(x)\ /f(x)\ d + g|----| + h|----|*--(f(x)) + \ x / \ x / dx + >>> pprint(dsolve(genform, f(x), + ... hint='1st_homogeneous_coeff_subs_dep_div_indep_Integral')) + f(x) + ---- + x + / + | + | -h(u1) + log(x) = C1 + | ---------------- d(u1) + | u1*h(u1) + g(u1) + | + / + + Where `u_1 h(u_1) + g(u_1) \ne 0` and `x \ne 0`. + + See also the docstrings of + :obj:`~sympy.solvers.ode.single.HomogeneousCoeffBest` and + :obj:`~sympy.solvers.ode.single.HomogeneousCoeffSubsIndepDivDep`. + + Examples + ======== + + >>> from sympy import Function, dsolve + >>> from sympy.abc import x + >>> f = Function('f') + >>> pprint(dsolve(2*x*f(x) + (x**2 + f(x)**2)*f(x).diff(x), f(x), + ... hint='1st_homogeneous_coeff_subs_dep_div_indep', simplify=False)) + / 3 \ + |3*f(x) f (x)| + log|------ + -----| + | x 3 | + \ x / + log(x) = log(C1) - ------------------- + 3 + + References + ========== + + - https://en.wikipedia.org/wiki/Homogeneous_differential_equation + - M. Tenenbaum & H. Pollard, "Ordinary Differential Equations", + Dover 1963, pp. 59 + + # indirect doctest + + """ + hint = "1st_homogeneous_coeff_subs_dep_div_indep" + has_integral = True + order = [1] + + def _wilds(self, f, x, order): + d = Wild('d', exclude=[f(x).diff(x), f(x).diff(x, 2)]) + e = Wild('e', exclude=[f(x).diff(x)]) + return d, e + + def _equation(self, fx, x, order): + d, e = self.wilds() + return d + e*fx.diff(x) + + def _verify(self, fx): + self.d, self.e = self.wilds_match() + self.y = Dummy('y') + x = self.ode_problem.sym + self.d = separatevars(self.d.subs(fx, self.y)) + self.e = separatevars(self.e.subs(fx, self.y)) + ordera = homogeneous_order(self.d, x, self.y) + orderb = homogeneous_order(self.e, x, self.y) + if ordera == orderb and ordera is not None: + self.u = Dummy('u') + if simplify((self.d + self.u*self.e).subs({x: 1, self.y: self.u})) != 0: + return True + return False + return False + + def _get_match_object(self): + fx = self.ode_problem.func + x = self.ode_problem.sym + self.u1 = Dummy('u1') + xarg = 0 + yarg = 0 + return [self.d, self.e, fx, x, self.u, self.u1, self.y, xarg, yarg] + + def _get_general_solution(self, *, simplify_flag: bool = True): + d, e, fx, x, u, u1, y, xarg, yarg = self._get_match_object() + (C1,) = self.ode_problem.get_numbered_constants(num=1) + int = Integral( + (-e/(d + u1*e)).subs({x: 1, y: u1}), + (u1, None, fx/x)) + sol = logcombine(Eq(log(x), int + log(C1)), force=True) + gen_sol = sol.subs(fx, u).subs(((u, u - yarg), (x, x - xarg), (u, fx))) + return [gen_sol] + + +class HomogeneousCoeffSubsIndepDivDep(SinglePatternODESolver): + r""" + Solves a 1st order differential equation with homogeneous coefficients + using the substitution `u_2 = \frac{\text{}}{\text{}}`. + + This is a differential equation + + .. math:: P(x, y) + Q(x, y) dy/dx = 0 + + such that `P` and `Q` are homogeneous and of the same order. A function + `F(x, y)` is homogeneous of order `n` if `F(x t, y t) = t^n F(x, y)`. + Equivalently, `F(x, y)` can be rewritten as `G(y/x)` or `H(x/y)`. See + also the docstring of :py:meth:`~sympy.solvers.ode.homogeneous_order`. + + If the coefficients `P` and `Q` in the differential equation above are + homogeneous functions of the same order, then it can be shown that the + substitution `x = u_2 y` (i.e. `u_2 = x/y`) will turn the differential + equation into an equation separable in the variables `y` and `u_2`. If + `h(u_2)` is the function that results from making the substitution `u_2 = + x/f(x)` on `P(x, f(x))` and `g(u_2)` is the function that results from the + substitution on `Q(x, f(x))` in the differential equation `P(x, f(x)) + + Q(x, f(x)) f'(x) = 0`, then the general solution is: + + >>> from sympy import Function, dsolve, pprint + >>> from sympy.abc import x + >>> f, g, h = map(Function, ['f', 'g', 'h']) + >>> genform = g(x/f(x)) + h(x/f(x))*f(x).diff(x) + >>> pprint(genform) + / x \ / x \ d + g|----| + h|----|*--(f(x)) + \f(x)/ \f(x)/ dx + >>> pprint(dsolve(genform, f(x), + ... hint='1st_homogeneous_coeff_subs_indep_div_dep_Integral')) + x + ---- + f(x) + / + | + | -g(u1) + | ---------------- d(u1) + | u1*g(u1) + h(u1) + | + / + + f(x) = C1*e + + Where `u_1 g(u_1) + h(u_1) \ne 0` and `f(x) \ne 0`. + + See also the docstrings of + :obj:`~sympy.solvers.ode.single.HomogeneousCoeffBest` and + :obj:`~sympy.solvers.ode.single.HomogeneousCoeffSubsDepDivIndep`. + + Examples + ======== + + >>> from sympy import Function, pprint, dsolve + >>> from sympy.abc import x + >>> f = Function('f') + >>> pprint(dsolve(2*x*f(x) + (x**2 + f(x)**2)*f(x).diff(x), f(x), + ... hint='1st_homogeneous_coeff_subs_indep_div_dep', + ... simplify=False)) + / 2 \ + |3*x | + log|----- + 1| + | 2 | + \f (x) / + log(f(x)) = log(C1) - -------------- + 3 + + References + ========== + + - https://en.wikipedia.org/wiki/Homogeneous_differential_equation + - M. Tenenbaum & H. Pollard, "Ordinary Differential Equations", + Dover 1963, pp. 59 + + # indirect doctest + + """ + hint = "1st_homogeneous_coeff_subs_indep_div_dep" + has_integral = True + order = [1] + + def _wilds(self, f, x, order): + d = Wild('d', exclude=[f(x).diff(x), f(x).diff(x, 2)]) + e = Wild('e', exclude=[f(x).diff(x)]) + return d, e + + def _equation(self, fx, x, order): + d, e = self.wilds() + return d + e*fx.diff(x) + + def _verify(self, fx): + self.d, self.e = self.wilds_match() + self.y = Dummy('y') + x = self.ode_problem.sym + self.d = separatevars(self.d.subs(fx, self.y)) + self.e = separatevars(self.e.subs(fx, self.y)) + ordera = homogeneous_order(self.d, x, self.y) + orderb = homogeneous_order(self.e, x, self.y) + if ordera == orderb and ordera is not None: + self.u = Dummy('u') + if simplify((self.e + self.u*self.d).subs({x: self.u, self.y: 1})) != 0: + return True + return False + return False + + def _get_match_object(self): + fx = self.ode_problem.func + x = self.ode_problem.sym + self.u1 = Dummy('u1') + xarg = 0 + yarg = 0 + return [self.d, self.e, fx, x, self.u, self.u1, self.y, xarg, yarg] + + def _get_general_solution(self, *, simplify_flag: bool = True): + d, e, fx, x, u, u1, y, xarg, yarg = self._get_match_object() + (C1,) = self.ode_problem.get_numbered_constants(num=1) + int = Integral(simplify((-d/(e + u1*d)).subs({x: u1, y: 1})), (u1, None, x/fx)) # type: ignore + sol = logcombine(Eq(log(fx), int + log(C1)), force=True) + gen_sol = sol.subs(fx, u).subs(((u, u - yarg), (x, x - xarg), (u, fx))) + return [gen_sol] + + +class HomogeneousCoeffBest(HomogeneousCoeffSubsIndepDivDep, HomogeneousCoeffSubsDepDivIndep): + r""" + Returns the best solution to an ODE from the two hints + ``1st_homogeneous_coeff_subs_dep_div_indep`` and + ``1st_homogeneous_coeff_subs_indep_div_dep``. + + This is as determined by :py:meth:`~sympy.solvers.ode.ode.ode_sol_simplicity`. + + See the + :obj:`~sympy.solvers.ode.single.HomogeneousCoeffSubsIndepDivDep` + and + :obj:`~sympy.solvers.ode.single.HomogeneousCoeffSubsDepDivIndep` + docstrings for more information on these hints. Note that there is no + ``ode_1st_homogeneous_coeff_best_Integral`` hint. + + Examples + ======== + + >>> from sympy import Function, dsolve, pprint + >>> from sympy.abc import x + >>> f = Function('f') + >>> pprint(dsolve(2*x*f(x) + (x**2 + f(x)**2)*f(x).diff(x), f(x), + ... hint='1st_homogeneous_coeff_best', simplify=False)) + / 2 \ + |3*x | + log|----- + 1| + | 2 | + \f (x) / + log(f(x)) = log(C1) - -------------- + 3 + + References + ========== + + - https://en.wikipedia.org/wiki/Homogeneous_differential_equation + - M. Tenenbaum & H. Pollard, "Ordinary Differential Equations", + Dover 1963, pp. 59 + + # indirect doctest + + """ + hint = "1st_homogeneous_coeff_best" + has_integral = False + order = [1] + + def _verify(self, fx): + if HomogeneousCoeffSubsIndepDivDep._verify(self, fx) and HomogeneousCoeffSubsDepDivIndep._verify(self, fx): + return True + return False + + def _get_general_solution(self, *, simplify_flag: bool = True): + # There are two substitutions that solve the equation, u1=y/x and u2=x/y + # # They produce different integrals, so try them both and see which + # # one is easier + sol1 = HomogeneousCoeffSubsIndepDivDep._get_general_solution(self) + sol2 = HomogeneousCoeffSubsDepDivIndep._get_general_solution(self) + fx = self.ode_problem.func + if simplify_flag: + sol1 = odesimp(self.ode_problem.eq, *sol1, fx, "1st_homogeneous_coeff_subs_indep_div_dep") + sol2 = odesimp(self.ode_problem.eq, *sol2, fx, "1st_homogeneous_coeff_subs_dep_div_indep") + return min([sol1, sol2], key=lambda x: ode_sol_simplicity(x, fx, trysolving=not simplify)) + + +class LinearCoefficients(HomogeneousCoeffBest): + r""" + Solves a differential equation with linear coefficients. + + The general form of a differential equation with linear coefficients is + + .. math:: y' + F\left(\!\frac{a_1 x + b_1 y + c_1}{a_2 x + b_2 y + + c_2}\!\right) = 0\text{,} + + where `a_1`, `b_1`, `c_1`, `a_2`, `b_2`, `c_2` are constants and `a_1 b_2 + - a_2 b_1 \ne 0`. + + This can be solved by substituting: + + .. math:: x = x' + \frac{b_2 c_1 - b_1 c_2}{a_2 b_1 - a_1 b_2} + + y = y' + \frac{a_1 c_2 - a_2 c_1}{a_2 b_1 - a_1 + b_2}\text{.} + + This substitution reduces the equation to a homogeneous differential + equation. + + See Also + ======== + :obj:`sympy.solvers.ode.single.HomogeneousCoeffBest` + :obj:`sympy.solvers.ode.single.HomogeneousCoeffSubsIndepDivDep` + :obj:`sympy.solvers.ode.single.HomogeneousCoeffSubsDepDivIndep` + + Examples + ======== + + >>> from sympy import dsolve, Function, pprint + >>> from sympy.abc import x + >>> f = Function('f') + >>> df = f(x).diff(x) + >>> eq = (x + f(x) + 1)*df + (f(x) - 6*x + 1) + >>> dsolve(eq, hint='linear_coefficients') + [Eq(f(x), -x - sqrt(C1 + 7*x**2) - 1), Eq(f(x), -x + sqrt(C1 + 7*x**2) - 1)] + >>> pprint(dsolve(eq, hint='linear_coefficients')) + ___________ ___________ + / 2 / 2 + [f(x) = -x - \/ C1 + 7*x - 1, f(x) = -x + \/ C1 + 7*x - 1] + + + References + ========== + + - Joel Moses, "Symbolic Integration - The Stormy Decade", Communications + of the ACM, Volume 14, Number 8, August 1971, pp. 558 + """ + hint = "linear_coefficients" + has_integral = True + order = [1] + + def _wilds(self, f, x, order): + d = Wild('d', exclude=[f(x).diff(x), f(x).diff(x, 2)]) + e = Wild('e', exclude=[f(x).diff(x)]) + return d, e + + def _equation(self, fx, x, order): + d, e = self.wilds() + return d + e*fx.diff(x) + + def _verify(self, fx): + self.d, self.e = self.wilds_match() + a, b = self.wilds() + F = self.d/self.e + x = self.ode_problem.sym + params = self._linear_coeff_match(F, fx) + if params: + self.xarg, self.yarg = params + u = Dummy('u') + t = Dummy('t') + self.y = Dummy('y') + # Dummy substitution for df and f(x). + dummy_eq = self.ode_problem.eq.subs(((fx.diff(x), t), (fx, u))) + reps = ((x, x + self.xarg), (u, u + self.yarg), (t, fx.diff(x)), (u, fx)) + dummy_eq = simplify(dummy_eq.subs(reps)) + # get the re-cast values for e and d + r2 = collect(expand(dummy_eq), [fx.diff(x), fx]).match(a*fx.diff(x) + b) + if r2: + self.d, self.e = r2[b], r2[a] + orderd = homogeneous_order(self.d, x, fx) + ordere = homogeneous_order(self.e, x, fx) + if orderd == ordere and orderd is not None: + self.d = self.d.subs(fx, self.y) + self.e = self.e.subs(fx, self.y) + return True + return False + return False + + def _linear_coeff_match(self, expr, func): + r""" + Helper function to match hint ``linear_coefficients``. + + Matches the expression to the form `(a_1 x + b_1 f(x) + c_1)/(a_2 x + b_2 + f(x) + c_2)` where the following conditions hold: + + 1. `a_1`, `b_1`, `c_1`, `a_2`, `b_2`, `c_2` are Rationals; + 2. `c_1` or `c_2` are not equal to zero; + 3. `a_2 b_1 - a_1 b_2` is not equal to zero. + + Return ``xarg``, ``yarg`` where + + 1. ``xarg`` = `(b_2 c_1 - b_1 c_2)/(a_2 b_1 - a_1 b_2)` + 2. ``yarg`` = `(a_1 c_2 - a_2 c_1)/(a_2 b_1 - a_1 b_2)` + + + Examples + ======== + + >>> from sympy import Function, sin + >>> from sympy.abc import x + >>> from sympy.solvers.ode.single import LinearCoefficients + >>> f = Function('f') + >>> eq = (-25*f(x) - 8*x + 62)/(4*f(x) + 11*x - 11) + >>> obj = LinearCoefficients(eq) + >>> obj._linear_coeff_match(eq, f(x)) + (1/9, 22/9) + >>> eq = sin((-5*f(x) - 8*x + 6)/(4*f(x) + x - 1)) + >>> obj = LinearCoefficients(eq) + >>> obj._linear_coeff_match(eq, f(x)) + (19/27, 2/27) + >>> eq = sin(f(x)/x) + >>> obj = LinearCoefficients(eq) + >>> obj._linear_coeff_match(eq, f(x)) + + """ + f = func.func + x = func.args[0] + def abc(eq): + r''' + Internal function of _linear_coeff_match + that returns Rationals a, b, c + if eq is a*x + b*f(x) + c, else None. + ''' + eq = _mexpand(eq) + c = eq.as_independent(x, f(x), as_Add=True)[0] + if not c.is_Rational: + return + a = eq.coeff(x) + if not a.is_Rational: + return + b = eq.coeff(f(x)) + if not b.is_Rational: + return + if eq == a*x + b*f(x) + c: + return a, b, c + + def match(arg): + r''' + Internal function of _linear_coeff_match that returns Rationals a1, + b1, c1, a2, b2, c2 and a2*b1 - a1*b2 of the expression (a1*x + b1*f(x) + + c1)/(a2*x + b2*f(x) + c2) if one of c1 or c2 and a2*b1 - a1*b2 is + non-zero, else None. + ''' + n, d = arg.together().as_numer_denom() + m = abc(n) + if m is not None: + a1, b1, c1 = m + m = abc(d) + if m is not None: + a2, b2, c2 = m + d = a2*b1 - a1*b2 + if (c1 or c2) and d: + return a1, b1, c1, a2, b2, c2, d + + m = [fi.args[0] for fi in expr.atoms(Function) if fi.func != f and + len(fi.args) == 1 and not fi.args[0].is_Function] or {expr} + m1 = match(m.pop()) + if m1 and all(match(mi) == m1 for mi in m): + a1, b1, c1, a2, b2, c2, denom = m1 + return (b2*c1 - b1*c2)/denom, (a1*c2 - a2*c1)/denom + + def _get_match_object(self): + fx = self.ode_problem.func + x = self.ode_problem.sym + self.u1 = Dummy('u1') + u = Dummy('u') + return [self.d, self.e, fx, x, u, self.u1, self.y, self.xarg, self.yarg] + + +class NthOrderReducible(SingleODESolver): + r""" + Solves ODEs that only involve derivatives of the dependent variable using + a substitution of the form `f^n(x) = g(x)`. + + For example any second order ODE of the form `f''(x) = h(f'(x), x)` can be + transformed into a pair of 1st order ODEs `g'(x) = h(g(x), x)` and + `f'(x) = g(x)`. Usually the 1st order ODE for `g` is easier to solve. If + that gives an explicit solution for `g` then `f` is found simply by + integration. + + + Examples + ======== + + >>> from sympy import Function, dsolve, Eq + >>> from sympy.abc import x + >>> f = Function('f') + >>> eq = Eq(x*f(x).diff(x)**2 + f(x).diff(x, 2), 0) + >>> dsolve(eq, f(x), hint='nth_order_reducible') + ... # doctest: +NORMALIZE_WHITESPACE + Eq(f(x), C1 - sqrt(-1/C2)*log(-C2*sqrt(-1/C2) + x) + sqrt(-1/C2)*log(C2*sqrt(-1/C2) + x)) + + """ + hint = "nth_order_reducible" + has_integral = False + + def _matches(self): + # Any ODE that can be solved with a substitution and + # repeated integration e.g.: + # `d^2/dx^2(y) + x*d/dx(y) = constant + #f'(x) must be finite for this to work + eq = self.ode_problem.eq_preprocessed + func = self.ode_problem.func + x = self.ode_problem.sym + r""" + Matches any differential equation that can be rewritten with a smaller + order. Only derivatives of ``func`` alone, wrt a single variable, + are considered, and only in them should ``func`` appear. + """ + # ODE only handles functions of 1 variable so this affirms that state + assert len(func.args) == 1 + vc = [d.variable_count[0] for d in eq.atoms(Derivative) + if d.expr == func and len(d.variable_count) == 1] + ords = [c for v, c in vc if v == x] + if len(ords) < 2: + return False + self.smallest = min(ords) + # make sure func does not appear outside of derivatives + D = Dummy() + if eq.subs(func.diff(x, self.smallest), D).has(func): + return False + return True + + def _get_general_solution(self, *, simplify_flag: bool = True): + eq = self.ode_problem.eq + f = self.ode_problem.func.func + x = self.ode_problem.sym + n = self.smallest + # get a unique function name for g + names = [a.name for a in eq.atoms(AppliedUndef)] + while True: + name = Dummy().name + if name not in names: + g = Function(name) + break + w = f(x).diff(x, n) + geq = eq.subs(w, g(x)) + gsol = dsolve(geq, g(x)) + + if not isinstance(gsol, list): + gsol = [gsol] + + # Might be multiple solutions to the reduced ODE: + fsol = [] + for gsoli in gsol: + fsoli = dsolve(gsoli.subs(g(x), w), f(x)) # or do integration n times + fsol.append(fsoli) + + return fsol + + +class SecondHypergeometric(SingleODESolver): + r""" + Solves 2nd order linear differential equations. + + It computes special function solutions which can be expressed using the + 2F1, 1F1 or 0F1 hypergeometric functions. + + .. math:: y'' + A(x) y' + B(x) y = 0\text{,} + + where `A` and `B` are rational functions. + + These kinds of differential equations have solution of non-Liouvillian form. + + Given linear ODE can be obtained from 2F1 given by + + .. math:: (x^2 - x) y'' + ((a + b + 1) x - c) y' + b a y = 0\text{,} + + where {a, b, c} are arbitrary constants. + + Notes + ===== + + The algorithm should find any solution of the form + + .. math:: y = P(x) _pF_q(..; ..;\frac{\alpha x^k + \beta}{\gamma x^k + \delta})\text{,} + + where pFq is any of 2F1, 1F1 or 0F1 and `P` is an "arbitrary function". + Currently only the 2F1 case is implemented in SymPy but the other cases are + described in the paper and could be implemented in future (contributions + welcome!). + + + Examples + ======== + + >>> from sympy import Function, dsolve, pprint + >>> from sympy.abc import x + >>> f = Function('f') + >>> eq = (x*x - x)*f(x).diff(x,2) + (5*x - 1)*f(x).diff(x) + 4*f(x) + >>> pprint(dsolve(eq, f(x), '2nd_hypergeometric')) + _ + / / 4 \\ |_ /-1, -1 | \ + |C1 + C2*|log(x) + -----||* | | | x| + \ \ x + 1// 2 1 \ 1 | / + f(x) = -------------------------------------------- + 3 + (x - 1) + + + References + ========== + + - "Non-Liouvillian solutions for second order linear ODEs" by L. Chan, E.S. Cheb-Terrab + + """ + hint = "2nd_hypergeometric" + has_integral = True + + def _matches(self): + eq = self.ode_problem.eq_preprocessed + func = self.ode_problem.func + r = match_2nd_hypergeometric(eq, func) + self.match_object = None + if r: + A, B = r + d = equivalence_hypergeometric(A, B, func) + if d: + if d['type'] == "2F1": + self.match_object = match_2nd_2F1_hypergeometric(d['I0'], d['k'], d['sing_point'], func) + if self.match_object is not None: + self.match_object.update({'A':A, 'B':B}) + # We can extend it for 1F1 and 0F1 type also. + return self.match_object is not None + + def _get_general_solution(self, *, simplify_flag: bool = True): + eq = self.ode_problem.eq + func = self.ode_problem.func + if self.match_object['type'] == "2F1": + sol = get_sol_2F1_hypergeometric(eq, func, self.match_object) + if sol is None: + raise NotImplementedError("The given ODE " + str(eq) + " cannot be solved by" + + " the hypergeometric method") + + return [sol] + + +class NthLinearConstantCoeffHomogeneous(SingleODESolver): + r""" + Solves an `n`\th order linear homogeneous differential equation with + constant coefficients. + + This is an equation of the form + + .. math:: a_n f^{(n)}(x) + a_{n-1} f^{(n-1)}(x) + \cdots + a_1 f'(x) + + a_0 f(x) = 0\text{.} + + These equations can be solved in a general manner, by taking the roots of + the characteristic equation `a_n m^n + a_{n-1} m^{n-1} + \cdots + a_1 m + + a_0 = 0`. The solution will then be the sum of `C_n x^i e^{r x}` terms, + for each where `C_n` is an arbitrary constant, `r` is a root of the + characteristic equation and `i` is one of each from 0 to the multiplicity + of the root - 1 (for example, a root 3 of multiplicity 2 would create the + terms `C_1 e^{3 x} + C_2 x e^{3 x}`). The exponential is usually expanded + for complex roots using Euler's equation `e^{I x} = \cos(x) + I \sin(x)`. + Complex roots always come in conjugate pairs in polynomials with real + coefficients, so the two roots will be represented (after simplifying the + constants) as `e^{a x} \left(C_1 \cos(b x) + C_2 \sin(b x)\right)`. + + If SymPy cannot find exact roots to the characteristic equation, a + :py:class:`~sympy.polys.rootoftools.ComplexRootOf` instance will be return + instead. + + >>> from sympy import Function, dsolve + >>> from sympy.abc import x + >>> f = Function('f') + >>> dsolve(f(x).diff(x, 5) + 10*f(x).diff(x) - 2*f(x), f(x), + ... hint='nth_linear_constant_coeff_homogeneous') + ... # doctest: +NORMALIZE_WHITESPACE + Eq(f(x), C5*exp(x*CRootOf(_x**5 + 10*_x - 2, 0)) + + (C1*sin(x*im(CRootOf(_x**5 + 10*_x - 2, 1))) + + C2*cos(x*im(CRootOf(_x**5 + 10*_x - 2, 1))))*exp(x*re(CRootOf(_x**5 + 10*_x - 2, 1))) + + (C3*sin(x*im(CRootOf(_x**5 + 10*_x - 2, 3))) + + C4*cos(x*im(CRootOf(_x**5 + 10*_x - 2, 3))))*exp(x*re(CRootOf(_x**5 + 10*_x - 2, 3)))) + + Note that because this method does not involve integration, there is no + ``nth_linear_constant_coeff_homogeneous_Integral`` hint. + + Examples + ======== + + >>> from sympy import Function, dsolve, pprint + >>> from sympy.abc import x + >>> f = Function('f') + >>> pprint(dsolve(f(x).diff(x, 4) + 2*f(x).diff(x, 3) - + ... 2*f(x).diff(x, 2) - 6*f(x).diff(x) + 5*f(x), f(x), + ... hint='nth_linear_constant_coeff_homogeneous')) + x -2*x + f(x) = (C1 + C2*x)*e + (C3*sin(x) + C4*cos(x))*e + + References + ========== + + - https://en.wikipedia.org/wiki/Linear_differential_equation section: + Nonhomogeneous_equation_with_constant_coefficients + - M. Tenenbaum & H. Pollard, "Ordinary Differential Equations", + Dover 1963, pp. 211 + + # indirect doctest + + """ + hint = "nth_linear_constant_coeff_homogeneous" + has_integral = False + + def _matches(self): + eq = self.ode_problem.eq_high_order_free + func = self.ode_problem.func + order = self.ode_problem.order + x = self.ode_problem.sym + self.r = self.ode_problem.get_linear_coefficients(eq, func, order) + if order and self.r and not any(self.r[i].has(x) for i in self.r if i >= 0): + if not self.r[-1]: + return True + else: + return False + return False + + def _get_general_solution(self, *, simplify_flag: bool = True): + fx = self.ode_problem.func + order = self.ode_problem.order + roots, collectterms = _get_const_characteristic_eq_sols(self.r, fx, order) + # A generator of constants + constants = self.ode_problem.get_numbered_constants(num=len(roots)) + gsol = Add(*[i*j for (i, j) in zip(constants, roots)]) + gsol = Eq(fx, gsol) + if simplify_flag: + gsol = _get_simplified_sol([gsol], fx, collectterms) + + return [gsol] + + +class NthLinearConstantCoeffVariationOfParameters(SingleODESolver): + r""" + Solves an `n`\th order linear differential equation with constant + coefficients using the method of variation of parameters. + + This method works on any differential equations of the form + + .. math:: f^{(n)}(x) + a_{n-1} f^{(n-1)}(x) + \cdots + a_1 f'(x) + a_0 + f(x) = P(x)\text{.} + + This method works by assuming that the particular solution takes the form + + .. math:: \sum_{x=1}^{n} c_i(x) y_i(x)\text{,} + + where `y_i` is the `i`\th solution to the homogeneous equation. The + solution is then solved using Wronskian's and Cramer's Rule. The + particular solution is given by + + .. math:: \sum_{x=1}^n \left( \int \frac{W_i(x)}{W(x)} \,dx + \right) y_i(x) \text{,} + + where `W(x)` is the Wronskian of the fundamental system (the system of `n` + linearly independent solutions to the homogeneous equation), and `W_i(x)` + is the Wronskian of the fundamental system with the `i`\th column replaced + with `[0, 0, \cdots, 0, P(x)]`. + + This method is general enough to solve any `n`\th order inhomogeneous + linear differential equation with constant coefficients, but sometimes + SymPy cannot simplify the Wronskian well enough to integrate it. If this + method hangs, try using the + ``nth_linear_constant_coeff_variation_of_parameters_Integral`` hint and + simplifying the integrals manually. Also, prefer using + ``nth_linear_constant_coeff_undetermined_coefficients`` when it + applies, because it does not use integration, making it faster and more + reliable. + + Warning, using simplify=False with + 'nth_linear_constant_coeff_variation_of_parameters' in + :py:meth:`~sympy.solvers.ode.dsolve` may cause it to hang, because it will + not attempt to simplify the Wronskian before integrating. It is + recommended that you only use simplify=False with + 'nth_linear_constant_coeff_variation_of_parameters_Integral' for this + method, especially if the solution to the homogeneous equation has + trigonometric functions in it. + + Examples + ======== + + >>> from sympy import Function, dsolve, pprint, exp, log + >>> from sympy.abc import x + >>> f = Function('f') + >>> pprint(dsolve(f(x).diff(x, 3) - 3*f(x).diff(x, 2) + + ... 3*f(x).diff(x) - f(x) - exp(x)*log(x), f(x), + ... hint='nth_linear_constant_coeff_variation_of_parameters')) + / / / x*log(x) 11*x\\\ x + f(x) = |C1 + x*|C2 + x*|C3 + -------- - ----|||*e + \ \ \ 6 36 /// + + References + ========== + + - https://en.wikipedia.org/wiki/Variation_of_parameters + - https://planetmath.org/VariationOfParameters + - M. Tenenbaum & H. Pollard, "Ordinary Differential Equations", + Dover 1963, pp. 233 + + # indirect doctest + + """ + hint = "nth_linear_constant_coeff_variation_of_parameters" + has_integral = True + + def _matches(self): + eq = self.ode_problem.eq_high_order_free + func = self.ode_problem.func + order = self.ode_problem.order + x = self.ode_problem.sym + self.r = self.ode_problem.get_linear_coefficients(eq, func, order) + + if order and self.r and not any(self.r[i].has(x) for i in self.r if i >= 0): + if self.r[-1]: + return True + else: + return False + return False + + def _get_general_solution(self, *, simplify_flag: bool = True): + eq = self.ode_problem.eq_high_order_free + f = self.ode_problem.func.func + x = self.ode_problem.sym + order = self.ode_problem.order + roots, collectterms = _get_const_characteristic_eq_sols(self.r, f(x), order) + # A generator of constants + constants = self.ode_problem.get_numbered_constants(num=len(roots)) + homogen_sol = Add(*[i*j for (i, j) in zip(constants, roots)]) + homogen_sol = Eq(f(x), homogen_sol) + homogen_sol = _solve_variation_of_parameters(eq, f(x), roots, homogen_sol, order, self.r, simplify_flag) + if simplify_flag: + homogen_sol = _get_simplified_sol([homogen_sol], f(x), collectterms) + return [homogen_sol] + + +class NthLinearConstantCoeffUndeterminedCoefficients(SingleODESolver): + r""" + Solves an `n`\th order linear differential equation with constant + coefficients using the method of undetermined coefficients. + + This method works on differential equations of the form + + .. math:: a_n f^{(n)}(x) + a_{n-1} f^{(n-1)}(x) + \cdots + a_1 f'(x) + + a_0 f(x) = P(x)\text{,} + + where `P(x)` is a function that has a finite number of linearly + independent derivatives. + + Functions that fit this requirement are finite sums functions of the form + `a x^i e^{b x} \sin(c x + d)` or `a x^i e^{b x} \cos(c x + d)`, where `i` + is a non-negative integer and `a`, `b`, `c`, and `d` are constants. For + example any polynomial in `x`, functions like `x^2 e^{2 x}`, `x \sin(x)`, + and `e^x \cos(x)` can all be used. Products of `\sin`'s and `\cos`'s have + a finite number of derivatives, because they can be expanded into `\sin(a + x)` and `\cos(b x)` terms. However, SymPy currently cannot do that + expansion, so you will need to manually rewrite the expression in terms of + the above to use this method. So, for example, you will need to manually + convert `\sin^2(x)` into `(1 + \cos(2 x))/2` to properly apply the method + of undetermined coefficients on it. + + This method works by creating a trial function from the expression and all + of its linear independent derivatives and substituting them into the + original ODE. The coefficients for each term will be a system of linear + equations, which are be solved for and substituted, giving the solution. + If any of the trial functions are linearly dependent on the solution to + the homogeneous equation, they are multiplied by sufficient `x` to make + them linearly independent. + + Examples + ======== + + >>> from sympy import Function, dsolve, pprint, exp, cos + >>> from sympy.abc import x + >>> f = Function('f') + >>> pprint(dsolve(f(x).diff(x, 2) + 2*f(x).diff(x) + f(x) - + ... 4*exp(-x)*x**2 + cos(2*x), f(x), + ... hint='nth_linear_constant_coeff_undetermined_coefficients')) + / / 3\\ + | | x || -x 4*sin(2*x) 3*cos(2*x) + f(x) = |C1 + x*|C2 + --||*e - ---------- + ---------- + \ \ 3 // 25 25 + + References + ========== + + - https://en.wikipedia.org/wiki/Method_of_undetermined_coefficients + - M. Tenenbaum & H. Pollard, "Ordinary Differential Equations", + Dover 1963, pp. 221 + + # indirect doctest + + """ + hint = "nth_linear_constant_coeff_undetermined_coefficients" + has_integral = False + + def _matches(self): + eq = self.ode_problem.eq_high_order_free + func = self.ode_problem.func + order = self.ode_problem.order + x = self.ode_problem.sym + self.r = self.ode_problem.get_linear_coefficients(eq, func, order) + does_match = False + if order and self.r and not any(self.r[i].has(x) for i in self.r if i >= 0): + if self.r[-1]: + eq_homogeneous = Add(eq, -self.r[-1]) + undetcoeff = _undetermined_coefficients_match(self.r[-1], x, func, eq_homogeneous) + if undetcoeff['test']: + self.trialset = undetcoeff['trialset'] + does_match = True + return does_match + + def _get_general_solution(self, *, simplify_flag: bool = True): + eq = self.ode_problem.eq + f = self.ode_problem.func.func + x = self.ode_problem.sym + order = self.ode_problem.order + roots, collectterms = _get_const_characteristic_eq_sols(self.r, f(x), order) + # A generator of constants + constants = self.ode_problem.get_numbered_constants(num=len(roots)) + homogen_sol = Add(*[i*j for (i, j) in zip(constants, roots)]) + homogen_sol = Eq(f(x), homogen_sol) + self.r.update({'list': roots, 'sol': homogen_sol, 'simpliy_flag': simplify_flag}) + gsol = _solve_undetermined_coefficients(eq, f(x), order, self.r, self.trialset) + if simplify_flag: + gsol = _get_simplified_sol([gsol], f(x), collectterms) + return [gsol] + + +class NthLinearEulerEqHomogeneous(SingleODESolver): + r""" + Solves an `n`\th order linear homogeneous variable-coefficient + Cauchy-Euler equidimensional ordinary differential equation. + + This is an equation with form `0 = a_0 f(x) + a_1 x f'(x) + a_2 x^2 f''(x) + \cdots`. + + These equations can be solved in a general manner, by substituting + solutions of the form `f(x) = x^r`, and deriving a characteristic equation + for `r`. When there are repeated roots, we include extra terms of the + form `C_{r k} \ln^k(x) x^r`, where `C_{r k}` is an arbitrary integration + constant, `r` is a root of the characteristic equation, and `k` ranges + over the multiplicity of `r`. In the cases where the roots are complex, + solutions of the form `C_1 x^a \sin(b \log(x)) + C_2 x^a \cos(b \log(x))` + are returned, based on expansions with Euler's formula. The general + solution is the sum of the terms found. If SymPy cannot find exact roots + to the characteristic equation, a + :py:obj:`~.ComplexRootOf` instance will be returned + instead. + + >>> from sympy import Function, dsolve + >>> from sympy.abc import x + >>> f = Function('f') + >>> dsolve(4*x**2*f(x).diff(x, 2) + f(x), f(x), + ... hint='nth_linear_euler_eq_homogeneous') + ... # doctest: +NORMALIZE_WHITESPACE + Eq(f(x), sqrt(x)*(C1 + C2*log(x))) + + Note that because this method does not involve integration, there is no + ``nth_linear_euler_eq_homogeneous_Integral`` hint. + + The following is for internal use: + + - ``returns = 'sol'`` returns the solution to the ODE. + - ``returns = 'list'`` returns a list of linearly independent solutions, + corresponding to the fundamental solution set, for use with non + homogeneous solution methods like variation of parameters and + undetermined coefficients. Note that, though the solutions should be + linearly independent, this function does not explicitly check that. You + can do ``assert simplify(wronskian(sollist)) != 0`` to check for linear + independence. Also, ``assert len(sollist) == order`` will need to pass. + - ``returns = 'both'``, return a dictionary ``{'sol': , + 'list': }``. + + Examples + ======== + + >>> from sympy import Function, dsolve, pprint + >>> from sympy.abc import x + >>> f = Function('f') + >>> eq = f(x).diff(x, 2)*x**2 - 4*f(x).diff(x)*x + 6*f(x) + >>> pprint(dsolve(eq, f(x), + ... hint='nth_linear_euler_eq_homogeneous')) + 2 + f(x) = x *(C1 + C2*x) + + References + ========== + + - https://en.wikipedia.org/wiki/Cauchy%E2%80%93Euler_equation + - C. Bender & S. Orszag, "Advanced Mathematical Methods for Scientists and + Engineers", Springer 1999, pp. 12 + + # indirect doctest + + """ + hint = "nth_linear_euler_eq_homogeneous" + has_integral = False + + def _matches(self): + eq = self.ode_problem.eq_preprocessed + f = self.ode_problem.func.func + order = self.ode_problem.order + x = self.ode_problem.sym + match = self.ode_problem.get_linear_coefficients(eq, f(x), order) + self.r = None + does_match = False + + if order and match: + coeff = match[order] + factor = x**order / coeff + self.r = {i: factor*match[i] for i in match} + if self.r and all(_test_term(self.r[i], f(x), i) for i in + self.r if i >= 0): + if not self.r[-1]: + does_match = True + return does_match + + def _get_general_solution(self, *, simplify_flag: bool = True): + fx = self.ode_problem.func + eq = self.ode_problem.eq + homogen_sol = _get_euler_characteristic_eq_sols(eq, fx, self.r)[0] + return [homogen_sol] + + +class NthLinearEulerEqNonhomogeneousVariationOfParameters(SingleODESolver): + r""" + Solves an `n`\th order linear non homogeneous Cauchy-Euler equidimensional + ordinary differential equation using variation of parameters. + + This is an equation with form `g(x) = a_0 f(x) + a_1 x f'(x) + a_2 x^2 f''(x) + \cdots`. + + This method works by assuming that the particular solution takes the form + + .. math:: \sum_{x=1}^{n} c_i(x) y_i(x) {a_n} {x^n} \text{, } + + where `y_i` is the `i`\th solution to the homogeneous equation. The + solution is then solved using Wronskian's and Cramer's Rule. The + particular solution is given by multiplying eq given below with `a_n x^{n}` + + .. math:: \sum_{x=1}^n \left( \int \frac{W_i(x)}{W(x)} \, dx + \right) y_i(x) \text{, } + + where `W(x)` is the Wronskian of the fundamental system (the system of `n` + linearly independent solutions to the homogeneous equation), and `W_i(x)` + is the Wronskian of the fundamental system with the `i`\th column replaced + with `[0, 0, \cdots, 0, \frac{x^{- n}}{a_n} g{\left(x \right)}]`. + + This method is general enough to solve any `n`\th order inhomogeneous + linear differential equation, but sometimes SymPy cannot simplify the + Wronskian well enough to integrate it. If this method hangs, try using the + ``nth_linear_constant_coeff_variation_of_parameters_Integral`` hint and + simplifying the integrals manually. Also, prefer using + ``nth_linear_constant_coeff_undetermined_coefficients`` when it + applies, because it does not use integration, making it faster and more + reliable. + + Warning, using simplify=False with + 'nth_linear_constant_coeff_variation_of_parameters' in + :py:meth:`~sympy.solvers.ode.dsolve` may cause it to hang, because it will + not attempt to simplify the Wronskian before integrating. It is + recommended that you only use simplify=False with + 'nth_linear_constant_coeff_variation_of_parameters_Integral' for this + method, especially if the solution to the homogeneous equation has + trigonometric functions in it. + + Examples + ======== + + >>> from sympy import Function, dsolve, Derivative + >>> from sympy.abc import x + >>> f = Function('f') + >>> eq = x**2*Derivative(f(x), x, x) - 2*x*Derivative(f(x), x) + 2*f(x) - x**4 + >>> dsolve(eq, f(x), + ... hint='nth_linear_euler_eq_nonhomogeneous_variation_of_parameters').expand() + Eq(f(x), C1*x + C2*x**2 + x**4/6) + + """ + hint = "nth_linear_euler_eq_nonhomogeneous_variation_of_parameters" + has_integral = True + + def _matches(self): + eq = self.ode_problem.eq_preprocessed + f = self.ode_problem.func.func + order = self.ode_problem.order + x = self.ode_problem.sym + match = self.ode_problem.get_linear_coefficients(eq, f(x), order) + self.r = None + does_match = False + + if order and match: + coeff = match[order] + factor = x**order / coeff + self.r = {i: factor*match[i] for i in match} + if self.r and all(_test_term(self.r[i], f(x), i) for i in + self.r if i >= 0): + if self.r[-1]: + does_match = True + + return does_match + + def _get_general_solution(self, *, simplify_flag: bool = True): + eq = self.ode_problem.eq + f = self.ode_problem.func.func + x = self.ode_problem.sym + order = self.ode_problem.order + homogen_sol, roots = _get_euler_characteristic_eq_sols(eq, f(x), self.r) + self.r[-1] = self.r[-1]/self.r[order] + sol = _solve_variation_of_parameters(eq, f(x), roots, homogen_sol, order, self.r, simplify_flag) + + return [Eq(f(x), homogen_sol.rhs + (sol.rhs - homogen_sol.rhs)*self.r[order])] + + +class NthLinearEulerEqNonhomogeneousUndeterminedCoefficients(SingleODESolver): + r""" + Solves an `n`\th order linear non homogeneous Cauchy-Euler equidimensional + ordinary differential equation using undetermined coefficients. + + This is an equation with form `g(x) = a_0 f(x) + a_1 x f'(x) + a_2 x^2 f''(x) + \cdots`. + + These equations can be solved in a general manner, by substituting + solutions of the form `x = exp(t)`, and deriving a characteristic equation + of form `g(exp(t)) = b_0 f(t) + b_1 f'(t) + b_2 f''(t) \cdots` which can + be then solved by nth_linear_constant_coeff_undetermined_coefficients if + g(exp(t)) has finite number of linearly independent derivatives. + + Functions that fit this requirement are finite sums functions of the form + `a x^i e^{b x} \sin(c x + d)` or `a x^i e^{b x} \cos(c x + d)`, where `i` + is a non-negative integer and `a`, `b`, `c`, and `d` are constants. For + example any polynomial in `x`, functions like `x^2 e^{2 x}`, `x \sin(x)`, + and `e^x \cos(x)` can all be used. Products of `\sin`'s and `\cos`'s have + a finite number of derivatives, because they can be expanded into `\sin(a + x)` and `\cos(b x)` terms. However, SymPy currently cannot do that + expansion, so you will need to manually rewrite the expression in terms of + the above to use this method. So, for example, you will need to manually + convert `\sin^2(x)` into `(1 + \cos(2 x))/2` to properly apply the method + of undetermined coefficients on it. + + After replacement of x by exp(t), this method works by creating a trial function + from the expression and all of its linear independent derivatives and + substituting them into the original ODE. The coefficients for each term + will be a system of linear equations, which are be solved for and + substituted, giving the solution. If any of the trial functions are linearly + dependent on the solution to the homogeneous equation, they are multiplied + by sufficient `x` to make them linearly independent. + + Examples + ======== + + >>> from sympy import dsolve, Function, Derivative, log + >>> from sympy.abc import x + >>> f = Function('f') + >>> eq = x**2*Derivative(f(x), x, x) - 2*x*Derivative(f(x), x) + 2*f(x) - log(x) + >>> dsolve(eq, f(x), + ... hint='nth_linear_euler_eq_nonhomogeneous_undetermined_coefficients').expand() + Eq(f(x), C1*x + C2*x**2 + log(x)/2 + 3/4) + + """ + hint = "nth_linear_euler_eq_nonhomogeneous_undetermined_coefficients" + has_integral = False + + def _matches(self): + eq = self.ode_problem.eq_high_order_free + f = self.ode_problem.func.func + order = self.ode_problem.order + x = self.ode_problem.sym + match = self.ode_problem.get_linear_coefficients(eq, f(x), order) + self.r = None + does_match = False + + if order and match: + coeff = match[order] + factor = x**order / coeff + self.r = {i: factor*match[i] for i in match} + if self.r and all(_test_term(self.r[i], f(x), i) for i in + self.r if i >= 0): + if self.r[-1]: + e, re = posify(self.r[-1].subs(x, exp(x))) + undetcoeff = _undetermined_coefficients_match(e.subs(re), x) + if undetcoeff['test']: + does_match = True + return does_match + + def _get_general_solution(self, *, simplify_flag: bool = True): + f = self.ode_problem.func.func + x = self.ode_problem.sym + chareq, eq, symbol = S.Zero, S.Zero, Dummy('x') + for i in self.r.keys(): + if i >= 0: + chareq += (self.r[i]*diff(x**symbol, x, i)*x**-symbol).expand() + + for i in range(1, degree(Poly(chareq, symbol))+1): + eq += chareq.coeff(symbol**i)*diff(f(x), x, i) + + if chareq.as_coeff_add(symbol)[0]: + eq += chareq.as_coeff_add(symbol)[0]*f(x) + e, re = posify(self.r[-1].subs(x, exp(x))) + eq += e.subs(re) + + self.const_undet_instance = NthLinearConstantCoeffUndeterminedCoefficients(SingleODEProblem(eq, f(x), x)) + sol = self.const_undet_instance.get_general_solution(simplify = simplify_flag)[0] + sol = sol.subs(x, log(x)) + sol = sol.subs(f(log(x)), f(x)).expand() + + return [sol] + + +class SecondLinearBessel(SingleODESolver): + r""" + Gives solution of the Bessel differential equation + + .. math :: x^2 \frac{d^2y}{dx^2} + x \frac{dy}{dx} y(x) + (x^2-n^2) y(x) + + if `n` is integer then the solution is of the form ``Eq(f(x), C0 besselj(n,x) + + C1 bessely(n,x))`` as both the solutions are linearly independent else if + `n` is a fraction then the solution is of the form ``Eq(f(x), C0 besselj(n,x) + + C1 besselj(-n,x))`` which can also transform into ``Eq(f(x), C0 besselj(n,x) + + C1 bessely(n,x))``. + + Examples + ======== + + >>> from sympy.abc import x + >>> from sympy import Symbol + >>> v = Symbol('v', positive=True) + >>> from sympy import dsolve, Function + >>> f = Function('f') + >>> y = f(x) + >>> genform = x**2*y.diff(x, 2) + x*y.diff(x) + (x**2 - v**2)*y + >>> dsolve(genform) + Eq(f(x), C1*besselj(v, x) + C2*bessely(v, x)) + + References + ========== + + https://math24.net/bessel-differential-equation.html + + """ + hint = "2nd_linear_bessel" + has_integral = False + + def _matches(self): + eq = self.ode_problem.eq_high_order_free + f = self.ode_problem.func + order = self.ode_problem.order + x = self.ode_problem.sym + df = f.diff(x) + a = Wild('a', exclude=[f,df]) + b = Wild('b', exclude=[x, f,df]) + a4 = Wild('a4', exclude=[x,f,df]) + b4 = Wild('b4', exclude=[x,f,df]) + c4 = Wild('c4', exclude=[x,f,df]) + d4 = Wild('d4', exclude=[x,f,df]) + a3 = Wild('a3', exclude=[f, df, f.diff(x, 2)]) + b3 = Wild('b3', exclude=[f, df, f.diff(x, 2)]) + c3 = Wild('c3', exclude=[f, df, f.diff(x, 2)]) + deq = a3*(f.diff(x, 2)) + b3*df + c3*f + r = collect(eq, + [f.diff(x, 2), df, f]).match(deq) + if order == 2 and r: + if not all(r[key].is_polynomial() for key in r): + n, d = eq.as_numer_denom() + eq = expand(n) + r = collect(eq, + [f.diff(x, 2), df, f]).match(deq) + + if r and r[a3] != 0: + # leading coeff of f(x).diff(x, 2) + coeff = factor(r[a3]).match(a4*(x-b)**b4) + + if coeff: + # if coeff[b4] = 0 means constant coefficient + if coeff[b4] == 0: + return False + point = coeff[b] + else: + return False + + if point: + r[a3] = simplify(r[a3].subs(x, x+point)) + r[b3] = simplify(r[b3].subs(x, x+point)) + r[c3] = simplify(r[c3].subs(x, x+point)) + + # making a3 in the form of x**2 + r[a3] = cancel(r[a3]/(coeff[a4]*(x)**(-2+coeff[b4]))) + r[b3] = cancel(r[b3]/(coeff[a4]*(x)**(-2+coeff[b4]))) + r[c3] = cancel(r[c3]/(coeff[a4]*(x)**(-2+coeff[b4]))) + # checking if b3 is of form c*(x-b) + coeff1 = factor(r[b3]).match(a4*(x)) + if coeff1 is None: + return False + # c3 maybe of very complex form so I am simply checking (a - b) form + # if yes later I will match with the standerd form of bessel in a and b + # a, b are wild variable defined above. + _coeff2 = expand(r[c3]).match(a - b) + if _coeff2 is None: + return False + # matching with standerd form for c3 + coeff2 = factor(_coeff2[a]).match(c4**2*(x)**(2*a4)) + if coeff2 is None: + return False + + if _coeff2[b] == 0: + coeff2[d4] = 0 + else: + coeff2[d4] = factor(_coeff2[b]).match(d4**2)[d4] + + self.rn = {'n':coeff2[d4], 'a4':coeff2[c4], 'd4':coeff2[a4]} + self.rn['c4'] = coeff1[a4] + self.rn['b4'] = point + return True + return False + + def _get_general_solution(self, *, simplify_flag: bool = True): + f = self.ode_problem.func.func + x = self.ode_problem.sym + n = self.rn['n'] + a4 = self.rn['a4'] + c4 = self.rn['c4'] + d4 = self.rn['d4'] + b4 = self.rn['b4'] + n = sqrt(n**2 + Rational(1, 4)*(c4 - 1)**2) + (C1, C2) = self.ode_problem.get_numbered_constants(num=2) + return [Eq(f(x), ((x**(Rational(1-c4,2)))*(C1*besselj(n/d4,a4*x**d4/d4) + + C2*bessely(n/d4,a4*x**d4/d4))).subs(x, x-b4))] + + +class SecondLinearAiry(SingleODESolver): + r""" + Gives solution of the Airy differential equation + + .. math :: \frac{d^2y}{dx^2} + (a + b x) y(x) = 0 + + in terms of Airy special functions airyai and airybi. + + Examples + ======== + + >>> from sympy import dsolve, Function + >>> from sympy.abc import x + >>> f = Function("f") + >>> eq = f(x).diff(x, 2) - x*f(x) + >>> dsolve(eq) + Eq(f(x), C1*airyai(x) + C2*airybi(x)) + """ + hint = "2nd_linear_airy" + has_integral = False + + def _matches(self): + eq = self.ode_problem.eq_high_order_free + f = self.ode_problem.func + order = self.ode_problem.order + x = self.ode_problem.sym + df = f.diff(x) + a4 = Wild('a4', exclude=[x,f,df]) + b4 = Wild('b4', exclude=[x,f,df]) + match = self.ode_problem.get_linear_coefficients(eq, f, order) + does_match = False + if order == 2 and match and match[2] != 0: + if match[1].is_zero: + self.rn = cancel(match[0]/match[2]).match(a4+b4*x) + if self.rn and self.rn[b4] != 0: + self.rn = {'b':self.rn[a4],'m':self.rn[b4]} + does_match = True + return does_match + + def _get_general_solution(self, *, simplify_flag: bool = True): + f = self.ode_problem.func.func + x = self.ode_problem.sym + (C1, C2) = self.ode_problem.get_numbered_constants(num=2) + b = self.rn['b'] + m = self.rn['m'] + if m.is_positive: + arg = - b/cbrt(m)**2 - cbrt(m)*x + elif m.is_negative: + arg = - b/cbrt(-m)**2 + cbrt(-m)*x + else: + arg = - b/cbrt(-m)**2 + cbrt(-m)*x + + return [Eq(f(x), C1*airyai(arg) + C2*airybi(arg))] + + +class LieGroup(SingleODESolver): + r""" + This hint implements the Lie group method of solving first order differential + equations. The aim is to convert the given differential equation from the + given coordinate system into another coordinate system where it becomes + invariant under the one-parameter Lie group of translations. The converted + ODE can be easily solved by quadrature. It makes use of the + :py:meth:`sympy.solvers.ode.infinitesimals` function which returns the + infinitesimals of the transformation. + + The coordinates `r` and `s` can be found by solving the following Partial + Differential Equations. + + .. math :: \xi\frac{\partial r}{\partial x} + \eta\frac{\partial r}{\partial y} + = 0 + + .. math :: \xi\frac{\partial s}{\partial x} + \eta\frac{\partial s}{\partial y} + = 1 + + The differential equation becomes separable in the new coordinate system + + .. math :: \frac{ds}{dr} = \frac{\frac{\partial s}{\partial x} + + h(x, y)\frac{\partial s}{\partial y}}{ + \frac{\partial r}{\partial x} + h(x, y)\frac{\partial r}{\partial y}} + + After finding the solution by integration, it is then converted back to the original + coordinate system by substituting `r` and `s` in terms of `x` and `y` again. + + Examples + ======== + + >>> from sympy import Function, dsolve, exp, pprint + >>> from sympy.abc import x + >>> f = Function('f') + >>> pprint(dsolve(f(x).diff(x) + 2*x*f(x) - x*exp(-x**2), f(x), + ... hint='lie_group')) + / 2\ 2 + | x | -x + f(x) = |C1 + --|*e + \ 2 / + + + References + ========== + + - Solving differential equations by Symmetry Groups, + John Starrett, pp. 1 - pp. 14 + + """ + hint = "lie_group" + has_integral = False + + def _has_additional_params(self): + return 'xi' in self.ode_problem.params and 'eta' in self.ode_problem.params + + def _matches(self): + eq = self.ode_problem.eq + f = self.ode_problem.func.func + order = self.ode_problem.order + x = self.ode_problem.sym + df = f(x).diff(x) + y = Dummy('y') + d = Wild('d', exclude=[df, f(x).diff(x, 2)]) + e = Wild('e', exclude=[df]) + does_match = False + if self._has_additional_params() and order == 1: + xi = self.ode_problem.params['xi'] + eta = self.ode_problem.params['eta'] + self.r3 = {'xi': xi, 'eta': eta} + r = collect(eq, df, exact=True).match(d + e * df) + if r: + r['d'] = d + r['e'] = e + r['y'] = y + r[d] = r[d].subs(f(x), y) + r[e] = r[e].subs(f(x), y) + self.r3.update(r) + does_match = True + return does_match + + def _get_general_solution(self, *, simplify_flag: bool = True): + eq = self.ode_problem.eq + x = self.ode_problem.sym + func = self.ode_problem.func + order = self.ode_problem.order + df = func.diff(x) + + try: + eqsol = solve(eq, df) + except NotImplementedError: + eqsol = [] + + desols = [] + for s in eqsol: + sol = _ode_lie_group(s, func, order, match=self.r3) + if sol: + desols.extend(sol) + + if desols == []: + raise NotImplementedError("The given ODE " + str(eq) + " cannot be solved by" + + " the lie group method") + return desols + + +solver_map = { + 'factorable': Factorable, + 'nth_linear_constant_coeff_homogeneous': NthLinearConstantCoeffHomogeneous, + 'nth_linear_euler_eq_homogeneous': NthLinearEulerEqHomogeneous, + 'nth_linear_constant_coeff_undetermined_coefficients': NthLinearConstantCoeffUndeterminedCoefficients, + 'nth_linear_euler_eq_nonhomogeneous_undetermined_coefficients': NthLinearEulerEqNonhomogeneousUndeterminedCoefficients, + 'separable': Separable, + '1st_exact': FirstExact, + '1st_linear': FirstLinear, + 'Bernoulli': Bernoulli, + 'Riccati_special_minus2': RiccatiSpecial, + '1st_rational_riccati': RationalRiccati, + '1st_homogeneous_coeff_best': HomogeneousCoeffBest, + '1st_homogeneous_coeff_subs_indep_div_dep': HomogeneousCoeffSubsIndepDivDep, + '1st_homogeneous_coeff_subs_dep_div_indep': HomogeneousCoeffSubsDepDivIndep, + 'almost_linear': AlmostLinear, + 'linear_coefficients': LinearCoefficients, + 'separable_reduced': SeparableReduced, + 'nth_linear_constant_coeff_variation_of_parameters': NthLinearConstantCoeffVariationOfParameters, + 'nth_linear_euler_eq_nonhomogeneous_variation_of_parameters': NthLinearEulerEqNonhomogeneousVariationOfParameters, + 'Liouville': Liouville, + '2nd_linear_airy': SecondLinearAiry, + '2nd_linear_bessel': SecondLinearBessel, + '2nd_hypergeometric': SecondHypergeometric, + 'nth_order_reducible': NthOrderReducible, + '2nd_nonlinear_autonomous_conserved': SecondNonlinearAutonomousConserved, + 'nth_algebraic': NthAlgebraic, + 'lie_group': LieGroup, + } + +# Avoid circular import: +from .ode import dsolve, ode_sol_simplicity, odesimp, homogeneous_order diff --git a/lib/python3.10/site-packages/sympy/solvers/ode/subscheck.py b/lib/python3.10/site-packages/sympy/solvers/ode/subscheck.py new file mode 100644 index 0000000000000000000000000000000000000000..6ac7fba7d364bf599e928ccf591b5bef096576d0 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/solvers/ode/subscheck.py @@ -0,0 +1,392 @@ +from sympy.core import S, Pow +from sympy.core.function import (Derivative, AppliedUndef, diff) +from sympy.core.relational import Equality, Eq +from sympy.core.symbol import Dummy +from sympy.core.sympify import sympify + +from sympy.logic.boolalg import BooleanAtom +from sympy.functions import exp +from sympy.series import Order +from sympy.simplify.simplify import simplify, posify, besselsimp +from sympy.simplify.trigsimp import trigsimp +from sympy.simplify.sqrtdenest import sqrtdenest +from sympy.solvers import solve +from sympy.solvers.deutils import _preprocess, ode_order +from sympy.utilities.iterables import iterable, is_sequence + + +def sub_func_doit(eq, func, new): + r""" + When replacing the func with something else, we usually want the + derivative evaluated, so this function helps in making that happen. + + Examples + ======== + + >>> from sympy import Derivative, symbols, Function + >>> from sympy.solvers.ode.subscheck import sub_func_doit + >>> x, z = symbols('x, z') + >>> y = Function('y') + + >>> sub_func_doit(3*Derivative(y(x), x) - 1, y(x), x) + 2 + + >>> sub_func_doit(x*Derivative(y(x), x) - y(x)**2 + y(x), y(x), + ... 1/(x*(z + 1/x))) + x*(-1/(x**2*(z + 1/x)) + 1/(x**3*(z + 1/x)**2)) + 1/(x*(z + 1/x)) + ...- 1/(x**2*(z + 1/x)**2) + """ + reps= {func: new} + for d in eq.atoms(Derivative): + if d.expr == func: + reps[d] = new.diff(*d.variable_count) + else: + reps[d] = d.xreplace({func: new}).doit(deep=False) + return eq.xreplace(reps) + + +def checkodesol(ode, sol, func=None, order='auto', solve_for_func=True): + r""" + Substitutes ``sol`` into ``ode`` and checks that the result is ``0``. + + This works when ``func`` is one function, like `f(x)` or a list of + functions like `[f(x), g(x)]` when `ode` is a system of ODEs. ``sol`` can + be a single solution or a list of solutions. Each solution may be an + :py:class:`~sympy.core.relational.Equality` that the solution satisfies, + e.g. ``Eq(f(x), C1), Eq(f(x) + C1, 0)``; or simply an + :py:class:`~sympy.core.expr.Expr`, e.g. ``f(x) - C1``. In most cases it + will not be necessary to explicitly identify the function, but if the + function cannot be inferred from the original equation it can be supplied + through the ``func`` argument. + + If a sequence of solutions is passed, the same sort of container will be + used to return the result for each solution. + + It tries the following methods, in order, until it finds zero equivalence: + + 1. Substitute the solution for `f` in the original equation. This only + works if ``ode`` is solved for `f`. It will attempt to solve it first + unless ``solve_for_func == False``. + 2. Take `n` derivatives of the solution, where `n` is the order of + ``ode``, and check to see if that is equal to the solution. This only + works on exact ODEs. + 3. Take the 1st, 2nd, ..., `n`\th derivatives of the solution, each time + solving for the derivative of `f` of that order (this will always be + possible because `f` is a linear operator). Then back substitute each + derivative into ``ode`` in reverse order. + + This function returns a tuple. The first item in the tuple is ``True`` if + the substitution results in ``0``, and ``False`` otherwise. The second + item in the tuple is what the substitution results in. It should always + be ``0`` if the first item is ``True``. Sometimes this function will + return ``False`` even when an expression is identically equal to ``0``. + This happens when :py:meth:`~sympy.simplify.simplify.simplify` does not + reduce the expression to ``0``. If an expression returned by this + function vanishes identically, then ``sol`` really is a solution to + the ``ode``. + + If this function seems to hang, it is probably because of a hard + simplification. + + To use this function to test, test the first item of the tuple. + + Examples + ======== + + >>> from sympy import (Eq, Function, checkodesol, symbols, + ... Derivative, exp) + >>> x, C1, C2 = symbols('x,C1,C2') + >>> f, g = symbols('f g', cls=Function) + >>> checkodesol(f(x).diff(x), Eq(f(x), C1)) + (True, 0) + >>> assert checkodesol(f(x).diff(x), C1)[0] + >>> assert not checkodesol(f(x).diff(x), x)[0] + >>> checkodesol(f(x).diff(x, 2), x**2) + (False, 2) + + >>> eqs = [Eq(Derivative(f(x), x), f(x)), Eq(Derivative(g(x), x), g(x))] + >>> sol = [Eq(f(x), C1*exp(x)), Eq(g(x), C2*exp(x))] + >>> checkodesol(eqs, sol) + (True, [0, 0]) + + """ + if iterable(ode): + return checksysodesol(ode, sol, func=func) + + if not isinstance(ode, Equality): + ode = Eq(ode, 0) + if func is None: + try: + _, func = _preprocess(ode.lhs) + except ValueError: + funcs = [s.atoms(AppliedUndef) for s in ( + sol if is_sequence(sol, set) else [sol])] + funcs = set().union(*funcs) + if len(funcs) != 1: + raise ValueError( + 'must pass func arg to checkodesol for this case.') + func = funcs.pop() + if not isinstance(func, AppliedUndef) or len(func.args) != 1: + raise ValueError( + "func must be a function of one variable, not %s" % func) + if is_sequence(sol, set): + return type(sol)([checkodesol(ode, i, order=order, solve_for_func=solve_for_func) for i in sol]) + + if not isinstance(sol, Equality): + sol = Eq(func, sol) + elif sol.rhs == func: + sol = sol.reversed + + if order == 'auto': + order = ode_order(ode, func) + solved = sol.lhs == func and not sol.rhs.has(func) + if solve_for_func and not solved: + rhs = solve(sol, func) + if rhs: + eqs = [Eq(func, t) for t in rhs] + if len(rhs) == 1: + eqs = eqs[0] + return checkodesol(ode, eqs, order=order, + solve_for_func=False) + + x = func.args[0] + + # Handle series solutions here + if sol.has(Order): + assert sol.lhs == func + Oterm = sol.rhs.getO() + solrhs = sol.rhs.removeO() + + Oexpr = Oterm.expr + assert isinstance(Oexpr, Pow) + sorder = Oexpr.exp + assert Oterm == Order(x**sorder) + + odesubs = (ode.lhs-ode.rhs).subs(func, solrhs).doit().expand() + + neworder = Order(x**(sorder - order)) + odesubs = odesubs + neworder + assert odesubs.getO() == neworder + residual = odesubs.removeO() + + return (residual == 0, residual) + + s = True + testnum = 0 + while s: + if testnum == 0: + # First pass, try substituting a solved solution directly into the + # ODE. This has the highest chance of succeeding. + ode_diff = ode.lhs - ode.rhs + + if sol.lhs == func: + s = sub_func_doit(ode_diff, func, sol.rhs) + s = besselsimp(s) + else: + testnum += 1 + continue + ss = simplify(s.rewrite(exp)) + if ss: + # with the new numer_denom in power.py, if we do a simple + # expansion then testnum == 0 verifies all solutions. + s = ss.expand(force=True) + else: + s = 0 + testnum += 1 + elif testnum == 1: + # Second pass. If we cannot substitute f, try seeing if the nth + # derivative is equal, this will only work for odes that are exact, + # by definition. + s = simplify( + trigsimp(diff(sol.lhs, x, order) - diff(sol.rhs, x, order)) - + trigsimp(ode.lhs) + trigsimp(ode.rhs)) + # s2 = simplify( + # diff(sol.lhs, x, order) - diff(sol.rhs, x, order) - \ + # ode.lhs + ode.rhs) + testnum += 1 + elif testnum == 2: + # Third pass. Try solving for df/dx and substituting that into the + # ODE. Thanks to Chris Smith for suggesting this method. Many of + # the comments below are his, too. + # The method: + # - Take each of 1..n derivatives of the solution. + # - Solve each nth derivative for d^(n)f/dx^(n) + # (the differential of that order) + # - Back substitute into the ODE in decreasing order + # (i.e., n, n-1, ...) + # - Check the result for zero equivalence + if sol.lhs == func and not sol.rhs.has(func): + diffsols = {0: sol.rhs} + elif sol.rhs == func and not sol.lhs.has(func): + diffsols = {0: sol.lhs} + else: + diffsols = {} + sol = sol.lhs - sol.rhs + for i in range(1, order + 1): + # Differentiation is a linear operator, so there should always + # be 1 solution. Nonetheless, we test just to make sure. + # We only need to solve once. After that, we automatically + # have the solution to the differential in the order we want. + if i == 1: + ds = sol.diff(x) + try: + sdf = solve(ds, func.diff(x, i)) + if not sdf: + raise NotImplementedError + except NotImplementedError: + testnum += 1 + break + else: + diffsols[i] = sdf[0] + else: + # This is what the solution says df/dx should be. + diffsols[i] = diffsols[i - 1].diff(x) + + # Make sure the above didn't fail. + if testnum > 2: + continue + else: + # Substitute it into ODE to check for self consistency. + lhs, rhs = ode.lhs, ode.rhs + for i in range(order, -1, -1): + if i == 0 and 0 not in diffsols: + # We can only substitute f(x) if the solution was + # solved for f(x). + break + lhs = sub_func_doit(lhs, func.diff(x, i), diffsols[i]) + rhs = sub_func_doit(rhs, func.diff(x, i), diffsols[i]) + ode_or_bool = Eq(lhs, rhs) + ode_or_bool = simplify(ode_or_bool) + + if isinstance(ode_or_bool, (bool, BooleanAtom)): + if ode_or_bool: + lhs = rhs = S.Zero + else: + lhs = ode_or_bool.lhs + rhs = ode_or_bool.rhs + # No sense in overworking simplify -- just prove that the + # numerator goes to zero + num = trigsimp((lhs - rhs).as_numer_denom()[0]) + # since solutions are obtained using force=True we test + # using the same level of assumptions + ## replace function with dummy so assumptions will work + _func = Dummy('func') + num = num.subs(func, _func) + ## posify the expression + num, reps = posify(num) + s = simplify(num).xreplace(reps).xreplace({_func: func}) + testnum += 1 + else: + break + + if not s: + return (True, s) + elif s is True: # The code above never was able to change s + raise NotImplementedError("Unable to test if " + str(sol) + + " is a solution to " + str(ode) + ".") + else: + return (False, s) + + +def checksysodesol(eqs, sols, func=None): + r""" + Substitutes corresponding ``sols`` for each functions into each ``eqs`` and + checks that the result of substitutions for each equation is ``0``. The + equations and solutions passed can be any iterable. + + This only works when each ``sols`` have one function only, like `x(t)` or `y(t)`. + For each function, ``sols`` can have a single solution or a list of solutions. + In most cases it will not be necessary to explicitly identify the function, + but if the function cannot be inferred from the original equation it + can be supplied through the ``func`` argument. + + When a sequence of equations is passed, the same sequence is used to return + the result for each equation with each function substituted with corresponding + solutions. + + It tries the following method to find zero equivalence for each equation: + + Substitute the solutions for functions, like `x(t)` and `y(t)` into the + original equations containing those functions. + This function returns a tuple. The first item in the tuple is ``True`` if + the substitution results for each equation is ``0``, and ``False`` otherwise. + The second item in the tuple is what the substitution results in. Each element + of the ``list`` should always be ``0`` corresponding to each equation if the + first item is ``True``. Note that sometimes this function may return ``False``, + but with an expression that is identically equal to ``0``, instead of returning + ``True``. This is because :py:meth:`~sympy.simplify.simplify.simplify` cannot + reduce the expression to ``0``. If an expression returned by each function + vanishes identically, then ``sols`` really is a solution to ``eqs``. + + If this function seems to hang, it is probably because of a difficult simplification. + + Examples + ======== + + >>> from sympy import Eq, diff, symbols, sin, cos, exp, sqrt, S, Function + >>> from sympy.solvers.ode.subscheck import checksysodesol + >>> C1, C2 = symbols('C1:3') + >>> t = symbols('t') + >>> x, y = symbols('x, y', cls=Function) + >>> eq = (Eq(diff(x(t),t), x(t) + y(t) + 17), Eq(diff(y(t),t), -2*x(t) + y(t) + 12)) + >>> sol = [Eq(x(t), (C1*sin(sqrt(2)*t) + C2*cos(sqrt(2)*t))*exp(t) - S(5)/3), + ... Eq(y(t), (sqrt(2)*C1*cos(sqrt(2)*t) - sqrt(2)*C2*sin(sqrt(2)*t))*exp(t) - S(46)/3)] + >>> checksysodesol(eq, sol) + (True, [0, 0]) + >>> eq = (Eq(diff(x(t),t),x(t)*y(t)**4), Eq(diff(y(t),t),y(t)**3)) + >>> sol = [Eq(x(t), C1*exp(-1/(4*(C2 + t)))), Eq(y(t), -sqrt(2)*sqrt(-1/(C2 + t))/2), + ... Eq(x(t), C1*exp(-1/(4*(C2 + t)))), Eq(y(t), sqrt(2)*sqrt(-1/(C2 + t))/2)] + >>> checksysodesol(eq, sol) + (True, [0, 0]) + + """ + def _sympify(eq): + return list(map(sympify, eq if iterable(eq) else [eq])) + eqs = _sympify(eqs) + for i in range(len(eqs)): + if isinstance(eqs[i], Equality): + eqs[i] = eqs[i].lhs - eqs[i].rhs + if func is None: + funcs = [] + for eq in eqs: + derivs = eq.atoms(Derivative) + func = set().union(*[d.atoms(AppliedUndef) for d in derivs]) + funcs.extend(func) + funcs = list(set(funcs)) + if not all(isinstance(func, AppliedUndef) and len(func.args) == 1 for func in funcs)\ + and len({func.args for func in funcs})!=1: + raise ValueError("func must be a function of one variable, not %s" % func) + for sol in sols: + if len(sol.atoms(AppliedUndef)) != 1: + raise ValueError("solutions should have one function only") + if len(funcs) != len({sol.lhs for sol in sols}): + raise ValueError("number of solutions provided does not match the number of equations") + dictsol = {} + for sol in sols: + func = list(sol.atoms(AppliedUndef))[0] + if sol.rhs == func: + sol = sol.reversed + solved = sol.lhs == func and not sol.rhs.has(func) + if not solved: + rhs = solve(sol, func) + if not rhs: + raise NotImplementedError + else: + rhs = sol.rhs + dictsol[func] = rhs + checkeq = [] + for eq in eqs: + for func in funcs: + eq = sub_func_doit(eq, func, dictsol[func]) + ss = simplify(eq) + if ss != 0: + eq = ss.expand(force=True) + if eq != 0: + eq = sqrtdenest(eq).simplify() + else: + eq = 0 + checkeq.append(eq) + if len(set(checkeq)) == 1 and list(set(checkeq))[0] == 0: + return (True, checkeq) + else: + return (False, checkeq) diff --git a/lib/python3.10/site-packages/sympy/solvers/ode/systems.py b/lib/python3.10/site-packages/sympy/solvers/ode/systems.py new file mode 100644 index 0000000000000000000000000000000000000000..2d2c9b57a969c7fb5c67c06ce952fa398e22a48d --- /dev/null +++ b/lib/python3.10/site-packages/sympy/solvers/ode/systems.py @@ -0,0 +1,2135 @@ +from sympy.core import Add, Mul, S +from sympy.core.containers import Tuple +from sympy.core.exprtools import factor_terms +from sympy.core.numbers import I +from sympy.core.relational import Eq, Equality +from sympy.core.sorting import default_sort_key, ordered +from sympy.core.symbol import Dummy, Symbol +from sympy.core.function import (expand_mul, expand, Derivative, + AppliedUndef, Function, Subs) +from sympy.functions import (exp, im, cos, sin, re, Piecewise, + piecewise_fold, sqrt, log) +from sympy.functions.combinatorial.factorials import factorial +from sympy.matrices import zeros, Matrix, NonSquareMatrixError, MatrixBase, eye +from sympy.polys import Poly, together +from sympy.simplify import collect, radsimp, signsimp # type: ignore +from sympy.simplify.powsimp import powdenest, powsimp +from sympy.simplify.ratsimp import ratsimp +from sympy.simplify.simplify import simplify +from sympy.sets.sets import FiniteSet +from sympy.solvers.deutils import ode_order +from sympy.solvers.solveset import NonlinearError, solveset +from sympy.utilities.iterables import (connected_components, iterable, + strongly_connected_components) +from sympy.utilities.misc import filldedent +from sympy.integrals.integrals import Integral, integrate + + +def _get_func_order(eqs, funcs): + return {func: max(ode_order(eq, func) for eq in eqs) for func in funcs} + + +class ODEOrderError(ValueError): + """Raised by linear_ode_to_matrix if the system has the wrong order""" + pass + + +class ODENonlinearError(NonlinearError): + """Raised by linear_ode_to_matrix if the system is nonlinear""" + pass + + +def _simpsol(soleq): + lhs = soleq.lhs + sol = soleq.rhs + sol = powsimp(sol) + gens = list(sol.atoms(exp)) + p = Poly(sol, *gens, expand=False) + gens = [factor_terms(g) for g in gens] + if not gens: + gens = p.gens + syms = [Symbol('C1'), Symbol('C2')] + terms = [] + for coeff, monom in zip(p.coeffs(), p.monoms()): + coeff = piecewise_fold(coeff) + if isinstance(coeff, Piecewise): + coeff = Piecewise(*((ratsimp(coef).collect(syms), cond) for coef, cond in coeff.args)) + else: + coeff = ratsimp(coeff).collect(syms) + monom = Mul(*(g ** i for g, i in zip(gens, monom))) + terms.append(coeff * monom) + return Eq(lhs, Add(*terms)) + + +def _solsimp(e, t): + no_t, has_t = powsimp(expand_mul(e)).as_independent(t) + + no_t = ratsimp(no_t) + has_t = has_t.replace(exp, lambda a: exp(factor_terms(a))) + + return no_t + has_t + + +def simpsol(sol, wrt1, wrt2, doit=True): + """Simplify solutions from dsolve_system.""" + + # The parameter sol is the solution as returned by dsolve (list of Eq). + # + # The parameters wrt1 and wrt2 are lists of symbols to be collected for + # with those in wrt1 being collected for first. This allows for collecting + # on any factors involving the independent variable before collecting on + # the integration constants or vice versa using e.g.: + # + # sol = simpsol(sol, [t], [C1, C2]) # t first, constants after + # sol = simpsol(sol, [C1, C2], [t]) # constants first, t after + # + # If doit=True (default) then simpsol will begin by evaluating any + # unevaluated integrals. Since many integrals will appear multiple times + # in the solutions this is done intelligently by computing each integral + # only once. + # + # The strategy is to first perform simple cancellation with factor_terms + # and then multiply out all brackets with expand_mul. This gives an Add + # with many terms. + # + # We split each term into two multiplicative factors dep and coeff where + # all factors that involve wrt1 are in dep and any constant factors are in + # coeff e.g. + # sqrt(2)*C1*exp(t) -> ( exp(t), sqrt(2)*C1 ) + # + # The dep factors are simplified using powsimp to combine expanded + # exponential factors e.g. + # exp(a*t)*exp(b*t) -> exp(t*(a+b)) + # + # We then collect coefficients for all terms having the same (simplified) + # dep. The coefficients are then simplified using together and ratsimp and + # lastly by recursively applying the same transformation to the + # coefficients to collect on wrt2. + # + # Finally the result is recombined into an Add and signsimp is used to + # normalise any minus signs. + + def simprhs(rhs, rep, wrt1, wrt2): + """Simplify the rhs of an ODE solution""" + if rep: + rhs = rhs.subs(rep) + rhs = factor_terms(rhs) + rhs = simp_coeff_dep(rhs, wrt1, wrt2) + rhs = signsimp(rhs) + return rhs + + def simp_coeff_dep(expr, wrt1, wrt2=None): + """Split rhs into terms, split terms into dep and coeff and collect on dep""" + add_dep_terms = lambda e: e.is_Add and e.has(*wrt1) + expandable = lambda e: e.is_Mul and any(map(add_dep_terms, e.args)) + expand_func = lambda e: expand_mul(e, deep=False) + expand_mul_mod = lambda e: e.replace(expandable, expand_func) + terms = Add.make_args(expand_mul_mod(expr)) + dc = {} + for term in terms: + coeff, dep = term.as_independent(*wrt1, as_Add=False) + # Collect together the coefficients for terms that have the same + # dependence on wrt1 (after dep is normalised using simpdep). + dep = simpdep(dep, wrt1) + + # See if the dependence on t cancels out... + if dep is not S.One: + dep2 = factor_terms(dep) + if not dep2.has(*wrt1): + coeff *= dep2 + dep = S.One + + if dep not in dc: + dc[dep] = coeff + else: + dc[dep] += coeff + # Apply the method recursively to the coefficients but this time + # collecting on wrt2 rather than wrt2. + termpairs = ((simpcoeff(c, wrt2), d) for d, c in dc.items()) + if wrt2 is not None: + termpairs = ((simp_coeff_dep(c, wrt2), d) for c, d in termpairs) + return Add(*(c * d for c, d in termpairs)) + + def simpdep(term, wrt1): + """Normalise factors involving t with powsimp and recombine exp""" + def canonicalise(a): + # Using factor_terms here isn't quite right because it leads to things + # like exp(t*(1+t)) that we don't want. We do want to cancel factors + # and pull out a common denominator but ideally the numerator would be + # expressed as a standard form polynomial in t so we expand_mul + # and collect afterwards. + a = factor_terms(a) + num, den = a.as_numer_denom() + num = expand_mul(num) + num = collect(num, wrt1) + return num / den + + term = powsimp(term) + rep = {e: exp(canonicalise(e.args[0])) for e in term.atoms(exp)} + term = term.subs(rep) + return term + + def simpcoeff(coeff, wrt2): + """Bring to a common fraction and cancel with ratsimp""" + coeff = together(coeff) + if coeff.is_polynomial(): + # Calling ratsimp can be expensive. The main reason is to simplify + # sums of terms with irrational denominators so we limit ourselves + # to the case where the expression is polynomial in any symbols. + # Maybe there's a better approach... + coeff = ratsimp(radsimp(coeff)) + # collect on secondary variables first and any remaining symbols after + if wrt2 is not None: + syms = list(wrt2) + list(ordered(coeff.free_symbols - set(wrt2))) + else: + syms = list(ordered(coeff.free_symbols)) + coeff = collect(coeff, syms) + coeff = together(coeff) + return coeff + + # There are often repeated integrals. Collect unique integrals and + # evaluate each once and then substitute into the final result to replace + # all occurrences in each of the solution equations. + if doit: + integrals = set().union(*(s.atoms(Integral) for s in sol)) + rep = {i: factor_terms(i).doit() for i in integrals} + else: + rep = {} + + sol = [Eq(s.lhs, simprhs(s.rhs, rep, wrt1, wrt2)) for s in sol] + return sol + + +def linodesolve_type(A, t, b=None): + r""" + Helper function that determines the type of the system of ODEs for solving with :obj:`sympy.solvers.ode.systems.linodesolve()` + + Explanation + =========== + + This function takes in the coefficient matrix and/or the non-homogeneous term + and returns the type of the equation that can be solved by :obj:`sympy.solvers.ode.systems.linodesolve()`. + + If the system is constant coefficient homogeneous, then "type1" is returned + + If the system is constant coefficient non-homogeneous, then "type2" is returned + + If the system is non-constant coefficient homogeneous, then "type3" is returned + + If the system is non-constant coefficient non-homogeneous, then "type4" is returned + + If the system has a non-constant coefficient matrix which can be factorized into constant + coefficient matrix, then "type5" or "type6" is returned for when the system is homogeneous or + non-homogeneous respectively. + + Note that, if the system of ODEs is of "type3" or "type4", then along with the type, + the commutative antiderivative of the coefficient matrix is also returned. + + If the system cannot be solved by :obj:`sympy.solvers.ode.systems.linodesolve()`, then + NotImplementedError is raised. + + Parameters + ========== + + A : Matrix + Coefficient matrix of the system of ODEs + b : Matrix or None + Non-homogeneous term of the system. The default value is None. + If this argument is None, then the system is assumed to be homogeneous. + + Examples + ======== + + >>> from sympy import symbols, Matrix + >>> from sympy.solvers.ode.systems import linodesolve_type + >>> t = symbols("t") + >>> A = Matrix([[1, 1], [2, 3]]) + >>> b = Matrix([t, 1]) + + >>> linodesolve_type(A, t) + {'antiderivative': None, 'type_of_equation': 'type1'} + + >>> linodesolve_type(A, t, b=b) + {'antiderivative': None, 'type_of_equation': 'type2'} + + >>> A_t = Matrix([[1, t], [-t, 1]]) + + >>> linodesolve_type(A_t, t) + {'antiderivative': Matrix([ + [ t, t**2/2], + [-t**2/2, t]]), 'type_of_equation': 'type3'} + + >>> linodesolve_type(A_t, t, b=b) + {'antiderivative': Matrix([ + [ t, t**2/2], + [-t**2/2, t]]), 'type_of_equation': 'type4'} + + >>> A_non_commutative = Matrix([[1, t], [t, -1]]) + >>> linodesolve_type(A_non_commutative, t) + Traceback (most recent call last): + ... + NotImplementedError: + The system does not have a commutative antiderivative, it cannot be + solved by linodesolve. + + Returns + ======= + + Dict + + Raises + ====== + + NotImplementedError + When the coefficient matrix does not have a commutative antiderivative + + See Also + ======== + + linodesolve: Function for which linodesolve_type gets the information + + """ + + match = {} + is_non_constant = not _matrix_is_constant(A, t) + is_non_homogeneous = not (b is None or b.is_zero_matrix) + type = "type{}".format(int("{}{}".format(int(is_non_constant), int(is_non_homogeneous)), 2) + 1) + + B = None + match.update({"type_of_equation": type, "antiderivative": B}) + + if is_non_constant: + B, is_commuting = _is_commutative_anti_derivative(A, t) + if not is_commuting: + raise NotImplementedError(filldedent(''' + The system does not have a commutative antiderivative, it cannot be solved + by linodesolve. + ''')) + + match['antiderivative'] = B + match.update(_first_order_type5_6_subs(A, t, b=b)) + + return match + + +def _first_order_type5_6_subs(A, t, b=None): + match = {} + + factor_terms = _factor_matrix(A, t) + is_homogeneous = b is None or b.is_zero_matrix + + if factor_terms is not None: + t_ = Symbol("{}_".format(t)) + F_t = integrate(factor_terms[0], t) + inverse = solveset(Eq(t_, F_t), t) + + # Note: A simple way to check if a function is invertible + # or not. + if isinstance(inverse, FiniteSet) and not inverse.has(Piecewise)\ + and len(inverse) == 1: + + A = factor_terms[1] + if not is_homogeneous: + b = b / factor_terms[0] + b = b.subs(t, list(inverse)[0]) + type = "type{}".format(5 + (not is_homogeneous)) + match.update({'func_coeff': A, 'tau': F_t, + 't_': t_, 'type_of_equation': type, 'rhs': b}) + + return match + + +def linear_ode_to_matrix(eqs, funcs, t, order): + r""" + Convert a linear system of ODEs to matrix form + + Explanation + =========== + + Express a system of linear ordinary differential equations as a single + matrix differential equation [1]. For example the system $x' = x + y + 1$ + and $y' = x - y$ can be represented as + + .. math:: A_1 X' = A_0 X + b + + where $A_1$ and $A_0$ are $2 \times 2$ matrices and $b$, $X$ and $X'$ are + $2 \times 1$ matrices with $X = [x, y]^T$. + + Higher-order systems are represented with additional matrices e.g. a + second-order system would look like + + .. math:: A_2 X'' = A_1 X' + A_0 X + b + + Examples + ======== + + >>> from sympy import Function, Symbol, Matrix, Eq + >>> from sympy.solvers.ode.systems import linear_ode_to_matrix + >>> t = Symbol('t') + >>> x = Function('x') + >>> y = Function('y') + + We can create a system of linear ODEs like + + >>> eqs = [ + ... Eq(x(t).diff(t), x(t) + y(t) + 1), + ... Eq(y(t).diff(t), x(t) - y(t)), + ... ] + >>> funcs = [x(t), y(t)] + >>> order = 1 # 1st order system + + Now ``linear_ode_to_matrix`` can represent this as a matrix + differential equation. + + >>> (A1, A0), b = linear_ode_to_matrix(eqs, funcs, t, order) + >>> A1 + Matrix([ + [1, 0], + [0, 1]]) + >>> A0 + Matrix([ + [1, 1], + [1, -1]]) + >>> b + Matrix([ + [1], + [0]]) + + The original equations can be recovered from these matrices: + + >>> eqs_mat = Matrix([eq.lhs - eq.rhs for eq in eqs]) + >>> X = Matrix(funcs) + >>> A1 * X.diff(t) - A0 * X - b == eqs_mat + True + + If the system of equations has a maximum order greater than the + order of the system specified, a ODEOrderError exception is raised. + + >>> eqs = [Eq(x(t).diff(t, 2), x(t).diff(t) + x(t)), Eq(y(t).diff(t), y(t) + x(t))] + >>> linear_ode_to_matrix(eqs, funcs, t, 1) + Traceback (most recent call last): + ... + ODEOrderError: Cannot represent system in 1-order form + + If the system of equations is nonlinear, then ODENonlinearError is + raised. + + >>> eqs = [Eq(x(t).diff(t), x(t) + y(t)), Eq(y(t).diff(t), y(t)**2 + x(t))] + >>> linear_ode_to_matrix(eqs, funcs, t, 1) + Traceback (most recent call last): + ... + ODENonlinearError: The system of ODEs is nonlinear. + + Parameters + ========== + + eqs : list of SymPy expressions or equalities + The equations as expressions (assumed equal to zero). + funcs : list of applied functions + The dependent variables of the system of ODEs. + t : symbol + The independent variable. + order : int + The order of the system of ODEs. + + Returns + ======= + + The tuple ``(As, b)`` where ``As`` is a tuple of matrices and ``b`` is the + the matrix representing the rhs of the matrix equation. + + Raises + ====== + + ODEOrderError + When the system of ODEs have an order greater than what was specified + ODENonlinearError + When the system of ODEs is nonlinear + + See Also + ======== + + linear_eq_to_matrix: for systems of linear algebraic equations. + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Matrix_differential_equation + + """ + from sympy.solvers.solveset import linear_eq_to_matrix + + if any(ode_order(eq, func) > order for eq in eqs for func in funcs): + msg = "Cannot represent system in {}-order form" + raise ODEOrderError(msg.format(order)) + + As = [] + + for o in range(order, -1, -1): + # Work from the highest derivative down + syms = [func.diff(t, o) for func in funcs] + + # Ai is the matrix for X(t).diff(t, o) + # eqs is minus the remainder of the equations. + try: + Ai, b = linear_eq_to_matrix(eqs, syms) + except NonlinearError: + raise ODENonlinearError("The system of ODEs is nonlinear.") + + Ai = Ai.applyfunc(expand_mul) + + As.append(Ai if o == order else -Ai) + + if o: + eqs = [-eq for eq in b] + else: + rhs = b + + return As, rhs + + +def matrix_exp(A, t): + r""" + Matrix exponential $\exp(A*t)$ for the matrix ``A`` and scalar ``t``. + + Explanation + =========== + + This functions returns the $\exp(A*t)$ by doing a simple + matrix multiplication: + + .. math:: \exp(A*t) = P * expJ * P^{-1} + + where $expJ$ is $\exp(J*t)$. $J$ is the Jordan normal + form of $A$ and $P$ is matrix such that: + + .. math:: A = P * J * P^{-1} + + The matrix exponential $\exp(A*t)$ appears in the solution of linear + differential equations. For example if $x$ is a vector and $A$ is a matrix + then the initial value problem + + .. math:: \frac{dx(t)}{dt} = A \times x(t), x(0) = x0 + + has the unique solution + + .. math:: x(t) = \exp(A t) x0 + + Examples + ======== + + >>> from sympy import Symbol, Matrix, pprint + >>> from sympy.solvers.ode.systems import matrix_exp + >>> t = Symbol('t') + + We will consider a 2x2 matrix for comupting the exponential + + >>> A = Matrix([[2, -5], [2, -4]]) + >>> pprint(A) + [2 -5] + [ ] + [2 -4] + + Now, exp(A*t) is given as follows: + + >>> pprint(matrix_exp(A, t)) + [ -t -t -t ] + [3*e *sin(t) + e *cos(t) -5*e *sin(t) ] + [ ] + [ -t -t -t ] + [ 2*e *sin(t) - 3*e *sin(t) + e *cos(t)] + + Parameters + ========== + + A : Matrix + The matrix $A$ in the expression $\exp(A*t)$ + t : Symbol + The independent variable + + See Also + ======== + + matrix_exp_jordan_form: For exponential of Jordan normal form + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Jordan_normal_form + .. [2] https://en.wikipedia.org/wiki/Matrix_exponential + + """ + P, expJ = matrix_exp_jordan_form(A, t) + return P * expJ * P.inv() + + +def matrix_exp_jordan_form(A, t): + r""" + Matrix exponential $\exp(A*t)$ for the matrix *A* and scalar *t*. + + Explanation + =========== + + Returns the Jordan form of the $\exp(A*t)$ along with the matrix $P$ such that: + + .. math:: + \exp(A*t) = P * expJ * P^{-1} + + Examples + ======== + + >>> from sympy import Matrix, Symbol + >>> from sympy.solvers.ode.systems import matrix_exp, matrix_exp_jordan_form + >>> t = Symbol('t') + + We will consider a 2x2 defective matrix. This shows that our method + works even for defective matrices. + + >>> A = Matrix([[1, 1], [0, 1]]) + + It can be observed that this function gives us the Jordan normal form + and the required invertible matrix P. + + >>> P, expJ = matrix_exp_jordan_form(A, t) + + Here, it is shown that P and expJ returned by this function is correct + as they satisfy the formula: P * expJ * P_inverse = exp(A*t). + + >>> P * expJ * P.inv() == matrix_exp(A, t) + True + + Parameters + ========== + + A : Matrix + The matrix $A$ in the expression $\exp(A*t)$ + t : Symbol + The independent variable + + References + ========== + + .. [1] https://en.wikipedia.org/wiki/Defective_matrix + .. [2] https://en.wikipedia.org/wiki/Jordan_matrix + .. [3] https://en.wikipedia.org/wiki/Jordan_normal_form + + """ + + N, M = A.shape + if N != M: + raise ValueError('Needed square matrix but got shape (%s, %s)' % (N, M)) + elif A.has(t): + raise ValueError('Matrix A should not depend on t') + + def jordan_chains(A): + '''Chains from Jordan normal form analogous to M.eigenvects(). + Returns a dict with eignevalues as keys like: + {e1: [[v111,v112,...], [v121, v122,...]], e2:...} + where vijk is the kth vector in the jth chain for eigenvalue i. + ''' + P, blocks = A.jordan_cells() + basis = [P[:,i] for i in range(P.shape[1])] + n = 0 + chains = {} + for b in blocks: + eigval = b[0, 0] + size = b.shape[0] + if eigval not in chains: + chains[eigval] = [] + chains[eigval].append(basis[n:n+size]) + n += size + return chains + + eigenchains = jordan_chains(A) + + # Needed for consistency across Python versions + eigenchains_iter = sorted(eigenchains.items(), key=default_sort_key) + isreal = not A.has(I) + + blocks = [] + vectors = [] + seen_conjugate = set() + for e, chains in eigenchains_iter: + for chain in chains: + n = len(chain) + if isreal and e != e.conjugate() and e.conjugate() in eigenchains: + if e in seen_conjugate: + continue + seen_conjugate.add(e.conjugate()) + exprt = exp(re(e) * t) + imrt = im(e) * t + imblock = Matrix([[cos(imrt), sin(imrt)], + [-sin(imrt), cos(imrt)]]) + expJblock2 = Matrix(n, n, lambda i,j: + imblock * t**(j-i) / factorial(j-i) if j >= i + else zeros(2, 2)) + expJblock = Matrix(2*n, 2*n, lambda i,j: expJblock2[i//2,j//2][i%2,j%2]) + + blocks.append(exprt * expJblock) + for i in range(n): + vectors.append(re(chain[i])) + vectors.append(im(chain[i])) + else: + vectors.extend(chain) + fun = lambda i,j: t**(j-i)/factorial(j-i) if j >= i else 0 + expJblock = Matrix(n, n, fun) + blocks.append(exp(e * t) * expJblock) + + expJ = Matrix.diag(*blocks) + P = Matrix(N, N, lambda i,j: vectors[j][i]) + + return P, expJ + + +# Note: To add a docstring example with tau +def linodesolve(A, t, b=None, B=None, type="auto", doit=False, + tau=None): + r""" + System of n equations linear first-order differential equations + + Explanation + =========== + + This solver solves the system of ODEs of the following form: + + .. math:: + X'(t) = A(t) X(t) + b(t) + + Here, $A(t)$ is the coefficient matrix, $X(t)$ is the vector of n independent variables, + $b(t)$ is the non-homogeneous term and $X'(t)$ is the derivative of $X(t)$ + + Depending on the properties of $A(t)$ and $b(t)$, this solver evaluates the solution + differently. + + When $A(t)$ is constant coefficient matrix and $b(t)$ is zero vector i.e. system is homogeneous, + the system is "type1". The solution is: + + .. math:: + X(t) = \exp(A t) C + + Here, $C$ is a vector of constants and $A$ is the constant coefficient matrix. + + When $A(t)$ is constant coefficient matrix and $b(t)$ is non-zero i.e. system is non-homogeneous, + the system is "type2". The solution is: + + .. math:: + X(t) = e^{A t} ( \int e^{- A t} b \,dt + C) + + When $A(t)$ is coefficient matrix such that its commutative with its antiderivative $B(t)$ and + $b(t)$ is a zero vector i.e. system is homogeneous, the system is "type3". The solution is: + + .. math:: + X(t) = \exp(B(t)) C + + When $A(t)$ is commutative with its antiderivative $B(t)$ and $b(t)$ is non-zero i.e. system is + non-homogeneous, the system is "type4". The solution is: + + .. math:: + X(t) = e^{B(t)} ( \int e^{-B(t)} b(t) \,dt + C) + + When $A(t)$ is a coefficient matrix such that it can be factorized into a scalar and a constant + coefficient matrix: + + .. math:: + A(t) = f(t) * A + + Where $f(t)$ is a scalar expression in the independent variable $t$ and $A$ is a constant matrix, + then we can do the following substitutions: + + .. math:: + tau = \int f(t) dt, X(t) = Y(tau), b(t) = b(f^{-1}(tau)) + + Here, the substitution for the non-homogeneous term is done only when its non-zero. + Using these substitutions, our original system becomes: + + .. math:: + Y'(tau) = A * Y(tau) + b(tau)/f(tau) + + The above system can be easily solved using the solution for "type1" or "type2" depending + on the homogeneity of the system. After we get the solution for $Y(tau)$, we substitute the + solution for $tau$ as $t$ to get back $X(t)$ + + .. math:: + X(t) = Y(tau) + + Systems of "type5" and "type6" have a commutative antiderivative but we use this solution + because its faster to compute. + + The final solution is the general solution for all the four equations since a constant coefficient + matrix is always commutative with its antidervative. + + An additional feature of this function is, if someone wants to substitute for value of the independent + variable, they can pass the substitution `tau` and the solution will have the independent variable + substituted with the passed expression(`tau`). + + Parameters + ========== + + A : Matrix + Coefficient matrix of the system of linear first order ODEs. + t : Symbol + Independent variable in the system of ODEs. + b : Matrix or None + Non-homogeneous term in the system of ODEs. If None is passed, + a homogeneous system of ODEs is assumed. + B : Matrix or None + Antiderivative of the coefficient matrix. If the antiderivative + is not passed and the solution requires the term, then the solver + would compute it internally. + type : String + Type of the system of ODEs passed. Depending on the type, the + solution is evaluated. The type values allowed and the corresponding + system it solves are: "type1" for constant coefficient homogeneous + "type2" for constant coefficient non-homogeneous, "type3" for non-constant + coefficient homogeneous, "type4" for non-constant coefficient non-homogeneous, + "type5" and "type6" for non-constant coefficient homogeneous and non-homogeneous + systems respectively where the coefficient matrix can be factorized to a constant + coefficient matrix. + The default value is "auto" which will let the solver decide the correct type of + the system passed. + doit : Boolean + Evaluate the solution if True, default value is False + tau: Expression + Used to substitute for the value of `t` after we get the solution of the system. + + Examples + ======== + + To solve the system of ODEs using this function directly, several things must be + done in the right order. Wrong inputs to the function will lead to incorrect results. + + >>> from sympy import symbols, Function, Eq + >>> from sympy.solvers.ode.systems import canonical_odes, linear_ode_to_matrix, linodesolve, linodesolve_type + >>> from sympy.solvers.ode.subscheck import checkodesol + >>> f, g = symbols("f, g", cls=Function) + >>> x, a = symbols("x, a") + >>> funcs = [f(x), g(x)] + >>> eqs = [Eq(f(x).diff(x) - f(x), a*g(x) + 1), Eq(g(x).diff(x) + g(x), a*f(x))] + + Here, it is important to note that before we derive the coefficient matrix, it is + important to get the system of ODEs into the desired form. For that we will use + :obj:`sympy.solvers.ode.systems.canonical_odes()`. + + >>> eqs = canonical_odes(eqs, funcs, x) + >>> eqs + [[Eq(Derivative(f(x), x), a*g(x) + f(x) + 1), Eq(Derivative(g(x), x), a*f(x) - g(x))]] + + Now, we will use :obj:`sympy.solvers.ode.systems.linear_ode_to_matrix()` to get the coefficient matrix and the + non-homogeneous term if it is there. + + >>> eqs = eqs[0] + >>> (A1, A0), b = linear_ode_to_matrix(eqs, funcs, x, 1) + >>> A = A0 + + We have the coefficient matrices and the non-homogeneous term ready. Now, we can use + :obj:`sympy.solvers.ode.systems.linodesolve_type()` to get the information for the system of ODEs + to finally pass it to the solver. + + >>> system_info = linodesolve_type(A, x, b=b) + >>> sol_vector = linodesolve(A, x, b=b, B=system_info['antiderivative'], type=system_info['type_of_equation']) + + Now, we can prove if the solution is correct or not by using :obj:`sympy.solvers.ode.checkodesol()` + + >>> sol = [Eq(f, s) for f, s in zip(funcs, sol_vector)] + >>> checkodesol(eqs, sol) + (True, [0, 0]) + + We can also use the doit method to evaluate the solutions passed by the function. + + >>> sol_vector_evaluated = linodesolve(A, x, b=b, type="type2", doit=True) + + Now, we will look at a system of ODEs which is non-constant. + + >>> eqs = [Eq(f(x).diff(x), f(x) + x*g(x)), Eq(g(x).diff(x), -x*f(x) + g(x))] + + The system defined above is already in the desired form, so we do not have to convert it. + + >>> (A1, A0), b = linear_ode_to_matrix(eqs, funcs, x, 1) + >>> A = A0 + + A user can also pass the commutative antiderivative required for type3 and type4 system of ODEs. + Passing an incorrect one will lead to incorrect results. If the coefficient matrix is not commutative + with its antiderivative, then :obj:`sympy.solvers.ode.systems.linodesolve_type()` raises a NotImplementedError. + If it does have a commutative antiderivative, then the function just returns the information about the system. + + >>> system_info = linodesolve_type(A, x, b=b) + + Now, we can pass the antiderivative as an argument to get the solution. If the system information is not + passed, then the solver will compute the required arguments internally. + + >>> sol_vector = linodesolve(A, x, b=b) + + Once again, we can verify the solution obtained. + + >>> sol = [Eq(f, s) for f, s in zip(funcs, sol_vector)] + >>> checkodesol(eqs, sol) + (True, [0, 0]) + + Returns + ======= + + List + + Raises + ====== + + ValueError + This error is raised when the coefficient matrix, non-homogeneous term + or the antiderivative, if passed, are not a matrix or + do not have correct dimensions + NonSquareMatrixError + When the coefficient matrix or its antiderivative, if passed is not a + square matrix + NotImplementedError + If the coefficient matrix does not have a commutative antiderivative + + See Also + ======== + + linear_ode_to_matrix: Coefficient matrix computation function + canonical_odes: System of ODEs representation change + linodesolve_type: Getting information about systems of ODEs to pass in this solver + + """ + + if not isinstance(A, MatrixBase): + raise ValueError(filldedent('''\ + The coefficients of the system of ODEs should be of type Matrix + ''')) + + if not A.is_square: + raise NonSquareMatrixError(filldedent('''\ + The coefficient matrix must be a square + ''')) + + if b is not None: + if not isinstance(b, MatrixBase): + raise ValueError(filldedent('''\ + The non-homogeneous terms of the system of ODEs should be of type Matrix + ''')) + + if A.rows != b.rows: + raise ValueError(filldedent('''\ + The system of ODEs should have the same number of non-homogeneous terms and the number of + equations + ''')) + + if B is not None: + if not isinstance(B, MatrixBase): + raise ValueError(filldedent('''\ + The antiderivative of coefficients of the system of ODEs should be of type Matrix + ''')) + + if not B.is_square: + raise NonSquareMatrixError(filldedent('''\ + The antiderivative of the coefficient matrix must be a square + ''')) + + if A.rows != B.rows: + raise ValueError(filldedent('''\ + The coefficient matrix and its antiderivative should have same dimensions + ''')) + + if not any(type == "type{}".format(i) for i in range(1, 7)) and not type == "auto": + raise ValueError(filldedent('''\ + The input type should be a valid one + ''')) + + n = A.rows + + # constants = numbered_symbols(prefix='C', cls=Dummy, start=const_idx+1) + Cvect = Matrix([Dummy() for _ in range(n)]) + + if b is None and any(type == typ for typ in ["type2", "type4", "type6"]): + b = zeros(n, 1) + + is_transformed = tau is not None + passed_type = type + + if type == "auto": + system_info = linodesolve_type(A, t, b=b) + type = system_info["type_of_equation"] + B = system_info["antiderivative"] + + if type in ("type5", "type6"): + is_transformed = True + if passed_type != "auto": + if tau is None: + system_info = _first_order_type5_6_subs(A, t, b=b) + if not system_info: + raise ValueError(filldedent(''' + The system passed isn't {}. + '''.format(type))) + + tau = system_info['tau'] + t = system_info['t_'] + A = system_info['A'] + b = system_info['b'] + + intx_wrtt = lambda x: Integral(x, t) if x else 0 + if type in ("type1", "type2", "type5", "type6"): + P, J = matrix_exp_jordan_form(A, t) + P = simplify(P) + + if type in ("type1", "type5"): + sol_vector = P * (J * Cvect) + else: + Jinv = J.subs(t, -t) + sol_vector = P * J * ((Jinv * P.inv() * b).applyfunc(intx_wrtt) + Cvect) + else: + if B is None: + B, _ = _is_commutative_anti_derivative(A, t) + + if type == "type3": + sol_vector = B.exp() * Cvect + else: + sol_vector = B.exp() * (((-B).exp() * b).applyfunc(intx_wrtt) + Cvect) + + if is_transformed: + sol_vector = sol_vector.subs(t, tau) + + gens = sol_vector.atoms(exp) + + if type != "type1": + sol_vector = [expand_mul(s) for s in sol_vector] + + sol_vector = [collect(s, ordered(gens), exact=True) for s in sol_vector] + + if doit: + sol_vector = [s.doit() for s in sol_vector] + + return sol_vector + + +def _matrix_is_constant(M, t): + """Checks if the matrix M is independent of t or not.""" + return all(coef.as_independent(t, as_Add=True)[1] == 0 for coef in M) + + +def canonical_odes(eqs, funcs, t): + r""" + Function that solves for highest order derivatives in a system + + Explanation + =========== + + This function inputs a system of ODEs and based on the system, + the dependent variables and their highest order, returns the system + in the following form: + + .. math:: + X'(t) = A(t) X(t) + b(t) + + Here, $X(t)$ is the vector of dependent variables of lower order, $A(t)$ is + the coefficient matrix, $b(t)$ is the non-homogeneous term and $X'(t)$ is the + vector of dependent variables in their respective highest order. We use the term + canonical form to imply the system of ODEs which is of the above form. + + If the system passed has a non-linear term with multiple solutions, then a list of + systems is returned in its canonical form. + + Parameters + ========== + + eqs : List + List of the ODEs + funcs : List + List of dependent variables + t : Symbol + Independent variable + + Examples + ======== + + >>> from sympy import symbols, Function, Eq, Derivative + >>> from sympy.solvers.ode.systems import canonical_odes + >>> f, g = symbols("f g", cls=Function) + >>> x, y = symbols("x y") + >>> funcs = [f(x), g(x)] + >>> eqs = [Eq(f(x).diff(x) - 7*f(x), 12*g(x)), Eq(g(x).diff(x) + g(x), 20*f(x))] + + >>> canonical_eqs = canonical_odes(eqs, funcs, x) + >>> canonical_eqs + [[Eq(Derivative(f(x), x), 7*f(x) + 12*g(x)), Eq(Derivative(g(x), x), 20*f(x) - g(x))]] + + >>> system = [Eq(Derivative(f(x), x)**2 - 2*Derivative(f(x), x) + 1, 4), Eq(-y*f(x) + Derivative(g(x), x), 0)] + + >>> canonical_system = canonical_odes(system, funcs, x) + >>> canonical_system + [[Eq(Derivative(f(x), x), -1), Eq(Derivative(g(x), x), y*f(x))], [Eq(Derivative(f(x), x), 3), Eq(Derivative(g(x), x), y*f(x))]] + + Returns + ======= + + List + + """ + from sympy.solvers.solvers import solve + + order = _get_func_order(eqs, funcs) + + canon_eqs = solve(eqs, *[func.diff(t, order[func]) for func in funcs], dict=True) + + systems = [] + for eq in canon_eqs: + system = [Eq(func.diff(t, order[func]), eq[func.diff(t, order[func])]) for func in funcs] + systems.append(system) + + return systems + + +def _is_commutative_anti_derivative(A, t): + r""" + Helper function for determining if the Matrix passed is commutative with its antiderivative + + Explanation + =========== + + This function checks if the Matrix $A$ passed is commutative with its antiderivative with respect + to the independent variable $t$. + + .. math:: + B(t) = \int A(t) dt + + The function outputs two values, first one being the antiderivative $B(t)$, second one being a + boolean value, if True, then the matrix $A(t)$ passed is commutative with $B(t)$, else the matrix + passed isn't commutative with $B(t)$. + + Parameters + ========== + + A : Matrix + The matrix which has to be checked + t : Symbol + Independent variable + + Examples + ======== + + >>> from sympy import symbols, Matrix + >>> from sympy.solvers.ode.systems import _is_commutative_anti_derivative + >>> t = symbols("t") + >>> A = Matrix([[1, t], [-t, 1]]) + + >>> B, is_commuting = _is_commutative_anti_derivative(A, t) + >>> is_commuting + True + + Returns + ======= + + Matrix, Boolean + + """ + B = integrate(A, t) + is_commuting = (B*A - A*B).applyfunc(expand).applyfunc(factor_terms).is_zero_matrix + + is_commuting = False if is_commuting is None else is_commuting + + return B, is_commuting + + +def _factor_matrix(A, t): + term = None + for element in A: + temp_term = element.as_independent(t)[1] + if temp_term.has(t): + term = temp_term + break + + if term is not None: + A_factored = (A/term).applyfunc(ratsimp) + can_factor = _matrix_is_constant(A_factored, t) + term = (term, A_factored) if can_factor else None + + return term + + +def _is_second_order_type2(A, t): + term = _factor_matrix(A, t) + is_type2 = False + + if term is not None: + term = 1/term[0] + is_type2 = term.is_polynomial() + + if is_type2: + poly = Poly(term.expand(), t) + monoms = poly.monoms() + + if monoms[0][0] in (2, 4): + cs = _get_poly_coeffs(poly, 4) + a, b, c, d, e = cs + + a1 = powdenest(sqrt(a), force=True) + c1 = powdenest(sqrt(e), force=True) + b1 = powdenest(sqrt(c - 2*a1*c1), force=True) + + is_type2 = (b == 2*a1*b1) and (d == 2*b1*c1) + term = a1*t**2 + b1*t + c1 + + else: + is_type2 = False + + return is_type2, term + + +def _get_poly_coeffs(poly, order): + cs = [0 for _ in range(order+1)] + for c, m in zip(poly.coeffs(), poly.monoms()): + cs[-1-m[0]] = c + return cs + + +def _match_second_order_type(A1, A0, t, b=None): + r""" + Works only for second order system in its canonical form. + + Type 0: Constant coefficient matrix, can be simply solved by + introducing dummy variables. + Type 1: When the substitution: $U = t*X' - X$ works for reducing + the second order system to first order system. + Type 2: When the system is of the form: $poly * X'' = A*X$ where + $poly$ is square of a quadratic polynomial with respect to + *t* and $A$ is a constant coefficient matrix. + + """ + match = {"type_of_equation": "type0"} + n = A1.shape[0] + + if _matrix_is_constant(A1, t) and _matrix_is_constant(A0, t): + return match + + if (A1 + A0*t).applyfunc(expand_mul).is_zero_matrix: + match.update({"type_of_equation": "type1", "A1": A1}) + + elif A1.is_zero_matrix and (b is None or b.is_zero_matrix): + is_type2, term = _is_second_order_type2(A0, t) + if is_type2: + a, b, c = _get_poly_coeffs(Poly(term, t), 2) + A = (A0*(term**2).expand()).applyfunc(ratsimp) + (b**2/4 - a*c)*eye(n, n) + tau = integrate(1/term, t) + t_ = Symbol("{}_".format(t)) + match.update({"type_of_equation": "type2", "A0": A, + "g(t)": sqrt(term), "tau": tau, "is_transformed": True, + "t_": t_}) + + return match + + +def _second_order_subs_type1(A, b, funcs, t): + r""" + For a linear, second order system of ODEs, a particular substitution. + + A system of the below form can be reduced to a linear first order system of + ODEs: + .. math:: + X'' = A(t) * (t*X' - X) + b(t) + + By substituting: + .. math:: U = t*X' - X + + To get the system: + .. math:: U' = t*(A(t)*U + b(t)) + + Where $U$ is the vector of dependent variables, $X$ is the vector of dependent + variables in `funcs` and $X'$ is the first order derivative of $X$ with respect to + $t$. It may or may not reduce the system into linear first order system of ODEs. + + Then a check is made to determine if the system passed can be reduced or not, if + this substitution works, then the system is reduced and its solved for the new + substitution. After we get the solution for $U$: + + .. math:: U = a(t) + + We substitute and return the reduced system: + + .. math:: + a(t) = t*X' - X + + Parameters + ========== + + A: Matrix + Coefficient matrix($A(t)*t$) of the second order system of this form. + b: Matrix + Non-homogeneous term($b(t)$) of the system of ODEs. + funcs: List + List of dependent variables + t: Symbol + Independent variable of the system of ODEs. + + Returns + ======= + + List + + """ + + U = Matrix([t*func.diff(t) - func for func in funcs]) + + sol = linodesolve(A, t, t*b) + reduced_eqs = [Eq(u, s) for s, u in zip(sol, U)] + reduced_eqs = canonical_odes(reduced_eqs, funcs, t)[0] + + return reduced_eqs + + +def _second_order_subs_type2(A, funcs, t_): + r""" + Returns a second order system based on the coefficient matrix passed. + + Explanation + =========== + + This function returns a system of second order ODE of the following form: + + .. math:: + X'' = A * X + + Here, $X$ is the vector of dependent variables, but a bit modified, $A$ is the + coefficient matrix passed. + + Along with returning the second order system, this function also returns the new + dependent variables with the new independent variable `t_` passed. + + Parameters + ========== + + A: Matrix + Coefficient matrix of the system + funcs: List + List of old dependent variables + t_: Symbol + New independent variable + + Returns + ======= + + List, List + + """ + func_names = [func.func.__name__ for func in funcs] + new_funcs = [Function(Dummy("{}_".format(name)))(t_) for name in func_names] + rhss = A * Matrix(new_funcs) + new_eqs = [Eq(func.diff(t_, 2), rhs) for func, rhs in zip(new_funcs, rhss)] + + return new_eqs, new_funcs + + +def _is_euler_system(As, t): + return all(_matrix_is_constant((A*t**i).applyfunc(ratsimp), t) for i, A in enumerate(As)) + + +def _classify_linear_system(eqs, funcs, t, is_canon=False): + r""" + Returns a dictionary with details of the eqs if the system passed is linear + and can be classified by this function else returns None + + Explanation + =========== + + This function takes the eqs, converts it into a form Ax = b where x is a vector of terms + containing dependent variables and their derivatives till their maximum order. If it is + possible to convert eqs into Ax = b, then all the equations in eqs are linear otherwise + they are non-linear. + + To check if the equations are constant coefficient, we need to check if all the terms in + A obtained above are constant or not. + + To check if the equations are homogeneous or not, we need to check if b is a zero matrix + or not. + + Parameters + ========== + + eqs: List + List of ODEs + funcs: List + List of dependent variables + t: Symbol + Independent variable of the equations in eqs + is_canon: Boolean + If True, then this function will not try to get the + system in canonical form. Default value is False + + Returns + ======= + + match = { + 'no_of_equation': len(eqs), + 'eq': eqs, + 'func': funcs, + 'order': order, + 'is_linear': is_linear, + 'is_constant': is_constant, + 'is_homogeneous': is_homogeneous, + } + + Dict or list of Dicts or None + Dict with values for keys: + 1. no_of_equation: Number of equations + 2. eq: The set of equations + 3. func: List of dependent variables + 4. order: A dictionary that gives the order of the + dependent variable in eqs + 5. is_linear: Boolean value indicating if the set of + equations are linear or not. + 6. is_constant: Boolean value indicating if the set of + equations have constant coefficients or not. + 7. is_homogeneous: Boolean value indicating if the set of + equations are homogeneous or not. + 8. commutative_antiderivative: Antiderivative of the coefficient + matrix if the coefficient matrix is non-constant + and commutative with its antiderivative. This key + may or may not exist. + 9. is_general: Boolean value indicating if the system of ODEs is + solvable using one of the general case solvers or not. + 10. rhs: rhs of the non-homogeneous system of ODEs in Matrix form. This + key may or may not exist. + 11. is_higher_order: True if the system passed has an order greater than 1. + This key may or may not exist. + 12. is_second_order: True if the system passed is a second order ODE. This + key may or may not exist. + This Dict is the answer returned if the eqs are linear and constant + coefficient. Otherwise, None is returned. + + """ + + # Error for i == 0 can be added but isn't for now + + # Check for len(funcs) == len(eqs) + if len(funcs) != len(eqs): + raise ValueError("Number of functions given is not equal to the number of equations %s" % funcs) + + # ValueError when functions have more than one arguments + for func in funcs: + if len(func.args) != 1: + raise ValueError("dsolve() and classify_sysode() work with " + "functions of one variable only, not %s" % func) + + # Getting the func_dict and order using the helper + # function + order = _get_func_order(eqs, funcs) + system_order = max(order[func] for func in funcs) + is_higher_order = system_order > 1 + is_second_order = system_order == 2 and all(order[func] == 2 for func in funcs) + + # Not adding the check if the len(func.args) for + # every func in funcs is 1 + + # Linearity check + try: + + canon_eqs = canonical_odes(eqs, funcs, t) if not is_canon else [eqs] + if len(canon_eqs) == 1: + As, b = linear_ode_to_matrix(canon_eqs[0], funcs, t, system_order) + else: + + match = { + 'is_implicit': True, + 'canon_eqs': canon_eqs + } + + return match + + # When the system of ODEs is non-linear, an ODENonlinearError is raised. + # This function catches the error and None is returned. + except ODENonlinearError: + return None + + is_linear = True + + # Homogeneous check + is_homogeneous = True if b.is_zero_matrix else False + + # Is general key is used to identify if the system of ODEs can be solved by + # one of the general case solvers or not. + match = { + 'no_of_equation': len(eqs), + 'eq': eqs, + 'func': funcs, + 'order': order, + 'is_linear': is_linear, + 'is_homogeneous': is_homogeneous, + 'is_general': True + } + + if not is_homogeneous: + match['rhs'] = b + + is_constant = all(_matrix_is_constant(A_, t) for A_ in As) + + # The match['is_linear'] check will be added in the future when this + # function becomes ready to deal with non-linear systems of ODEs + + if not is_higher_order: + A = As[1] + match['func_coeff'] = A + + # Constant coefficient check + is_constant = _matrix_is_constant(A, t) + match['is_constant'] = is_constant + + try: + system_info = linodesolve_type(A, t, b=b) + except NotImplementedError: + return None + + match.update(system_info) + antiderivative = match.pop("antiderivative") + + if not is_constant: + match['commutative_antiderivative'] = antiderivative + + return match + else: + match['type_of_equation'] = "type0" + + if is_second_order: + A1, A0 = As[1:] + + match_second_order = _match_second_order_type(A1, A0, t) + match.update(match_second_order) + + match['is_second_order'] = True + + # If system is constant, then no need to check if its in euler + # form or not. It will be easier and faster to directly proceed + # to solve it. + if match['type_of_equation'] == "type0" and not is_constant: + is_euler = _is_euler_system(As, t) + if is_euler: + t_ = Symbol('{}_'.format(t)) + match.update({'is_transformed': True, 'type_of_equation': 'type1', + 't_': t_}) + else: + is_jordan = lambda M: M == Matrix.jordan_block(M.shape[0], M[0, 0]) + terms = _factor_matrix(As[-1], t) + if all(A.is_zero_matrix for A in As[1:-1]) and terms is not None and not is_jordan(terms[1]): + P, J = terms[1].jordan_form() + match.update({'type_of_equation': 'type2', 'J': J, + 'f(t)': terms[0], 'P': P, 'is_transformed': True}) + + if match['type_of_equation'] != 'type0' and is_second_order: + match.pop('is_second_order', None) + + match['is_higher_order'] = is_higher_order + + return match + +def _preprocess_eqs(eqs): + processed_eqs = [] + for eq in eqs: + processed_eqs.append(eq if isinstance(eq, Equality) else Eq(eq, 0)) + + return processed_eqs + + +def _eqs2dict(eqs, funcs): + eqsorig = {} + eqsmap = {} + funcset = set(funcs) + for eq in eqs: + f1, = eq.lhs.atoms(AppliedUndef) + f2s = (eq.rhs.atoms(AppliedUndef) - {f1}) & funcset + eqsmap[f1] = f2s + eqsorig[f1] = eq + return eqsmap, eqsorig + + +def _dict2graph(d): + nodes = list(d) + edges = [(f1, f2) for f1, f2s in d.items() for f2 in f2s] + G = (nodes, edges) + return G + + +def _is_type1(scc, t): + eqs, funcs = scc + + try: + (A1, A0), b = linear_ode_to_matrix(eqs, funcs, t, 1) + except (ODENonlinearError, ODEOrderError): + return False + + if _matrix_is_constant(A0, t) and b.is_zero_matrix: + return True + + return False + + +def _combine_type1_subsystems(subsystem, funcs, t): + indices = [i for i, sys in enumerate(zip(subsystem, funcs)) if _is_type1(sys, t)] + remove = set() + for ip, i in enumerate(indices): + for j in indices[ip+1:]: + if any(eq2.has(funcs[i]) for eq2 in subsystem[j]): + subsystem[j] = subsystem[i] + subsystem[j] + remove.add(i) + subsystem = [sys for i, sys in enumerate(subsystem) if i not in remove] + return subsystem + + +def _component_division(eqs, funcs, t): + + # Assuming that each eq in eqs is in canonical form, + # that is, [f(x).diff(x) = .., g(x).diff(x) = .., etc] + # and that the system passed is in its first order + eqsmap, eqsorig = _eqs2dict(eqs, funcs) + + subsystems = [] + for cc in connected_components(_dict2graph(eqsmap)): + eqsmap_c = {f: eqsmap[f] for f in cc} + sccs = strongly_connected_components(_dict2graph(eqsmap_c)) + subsystem = [[eqsorig[f] for f in scc] for scc in sccs] + subsystem = _combine_type1_subsystems(subsystem, sccs, t) + subsystems.append(subsystem) + + return subsystems + + +# Returns: List of equations +def _linear_ode_solver(match): + t = match['t'] + funcs = match['func'] + + rhs = match.get('rhs', None) + tau = match.get('tau', None) + t = match['t_'] if 't_' in match else t + A = match['func_coeff'] + + # Note: To make B None when the matrix has constant + # coefficient + B = match.get('commutative_antiderivative', None) + type = match['type_of_equation'] + + sol_vector = linodesolve(A, t, b=rhs, B=B, + type=type, tau=tau) + + sol = [Eq(f, s) for f, s in zip(funcs, sol_vector)] + + return sol + + +def _select_equations(eqs, funcs, key=lambda x: x): + eq_dict = {e.lhs: e.rhs for e in eqs} + return [Eq(f, eq_dict[key(f)]) for f in funcs] + + +def _higher_order_ode_solver(match): + eqs = match["eq"] + funcs = match["func"] + t = match["t"] + sysorder = match['order'] + type = match.get('type_of_equation', "type0") + + is_second_order = match.get('is_second_order', False) + is_transformed = match.get('is_transformed', False) + is_euler = is_transformed and type == "type1" + is_higher_order_type2 = is_transformed and type == "type2" and 'P' in match + + if is_second_order: + new_eqs, new_funcs = _second_order_to_first_order(eqs, funcs, t, + A1=match.get("A1", None), A0=match.get("A0", None), + b=match.get("rhs", None), type=type, + t_=match.get("t_", None)) + else: + new_eqs, new_funcs = _higher_order_to_first_order(eqs, sysorder, t, funcs=funcs, + type=type, J=match.get('J', None), + f_t=match.get('f(t)', None), + P=match.get('P', None), b=match.get('rhs', None)) + + if is_transformed: + t = match.get('t_', t) + + if not is_higher_order_type2: + new_eqs = _select_equations(new_eqs, [f.diff(t) for f in new_funcs]) + + sol = None + + # NotImplementedError may be raised when the system may be actually + # solvable if it can be just divided into sub-systems + try: + if not is_higher_order_type2: + sol = _strong_component_solver(new_eqs, new_funcs, t) + except NotImplementedError: + sol = None + + # Dividing the system only when it becomes essential + if sol is None: + try: + sol = _component_solver(new_eqs, new_funcs, t) + except NotImplementedError: + sol = None + + if sol is None: + return sol + + is_second_order_type2 = is_second_order and type == "type2" + + underscores = '__' if is_transformed else '_' + + sol = _select_equations(sol, funcs, + key=lambda x: Function(Dummy('{}{}0'.format(x.func.__name__, underscores)))(t)) + + if match.get("is_transformed", False): + if is_second_order_type2: + g_t = match["g(t)"] + tau = match["tau"] + sol = [Eq(s.lhs, s.rhs.subs(t, tau) * g_t) for s in sol] + elif is_euler: + t = match['t'] + tau = match['t_'] + sol = [s.subs(tau, log(t)) for s in sol] + elif is_higher_order_type2: + P = match['P'] + sol_vector = P * Matrix([s.rhs for s in sol]) + sol = [Eq(f, s) for f, s in zip(funcs, sol_vector)] + + return sol + + +# Returns: List of equations or None +# If None is returned by this solver, then the system +# of ODEs cannot be solved directly by dsolve_system. +def _strong_component_solver(eqs, funcs, t): + from sympy.solvers.ode.ode import dsolve, constant_renumber + + match = _classify_linear_system(eqs, funcs, t, is_canon=True) + sol = None + + # Assuming that we can't get an implicit system + # since we are already canonical equations from + # dsolve_system + if match: + match['t'] = t + + if match.get('is_higher_order', False): + sol = _higher_order_ode_solver(match) + + elif match.get('is_linear', False): + sol = _linear_ode_solver(match) + + # Note: For now, only linear systems are handled by this function + # hence, the match condition is added. This can be removed later. + if sol is None and len(eqs) == 1: + sol = dsolve(eqs[0], func=funcs[0]) + variables = Tuple(eqs[0]).free_symbols + new_constants = [Dummy() for _ in range(ode_order(eqs[0], funcs[0]))] + sol = constant_renumber(sol, variables=variables, newconstants=new_constants) + sol = [sol] + + # To add non-linear case here in future + + return sol + + +def _get_funcs_from_canon(eqs): + return [eq.lhs.args[0] for eq in eqs] + + +# Returns: List of Equations(a solution) +def _weak_component_solver(wcc, t): + + # We will divide the systems into sccs + # only when the wcc cannot be solved as + # a whole + eqs = [] + for scc in wcc: + eqs += scc + funcs = _get_funcs_from_canon(eqs) + + sol = _strong_component_solver(eqs, funcs, t) + if sol: + return sol + + sol = [] + + for scc in wcc: + eqs = scc + funcs = _get_funcs_from_canon(eqs) + + # Substituting solutions for the dependent + # variables solved in previous SCC, if any solved. + comp_eqs = [eq.subs({s.lhs: s.rhs for s in sol}) for eq in eqs] + scc_sol = _strong_component_solver(comp_eqs, funcs, t) + + if scc_sol is None: + raise NotImplementedError(filldedent(''' + The system of ODEs passed cannot be solved by dsolve_system. + ''')) + + # scc_sol: List of equations + # scc_sol is a solution + sol += scc_sol + + return sol + + +# Returns: List of Equations(a solution) +def _component_solver(eqs, funcs, t): + components = _component_division(eqs, funcs, t) + sol = [] + + for wcc in components: + + # wcc_sol: List of Equations + sol += _weak_component_solver(wcc, t) + + # sol: List of Equations + return sol + + +def _second_order_to_first_order(eqs, funcs, t, type="auto", A1=None, + A0=None, b=None, t_=None): + r""" + Expects the system to be in second order and in canonical form + + Explanation + =========== + + Reduces a second order system into a first order one depending on the type of second + order system. + 1. "type0": If this is passed, then the system will be reduced to first order by + introducing dummy variables. + 2. "type1": If this is passed, then a particular substitution will be used to reduce the + the system into first order. + 3. "type2": If this is passed, then the system will be transformed with new dependent + variables and independent variables. This transformation is a part of solving + the corresponding system of ODEs. + + `A1` and `A0` are the coefficient matrices from the system and it is assumed that the + second order system has the form given below: + + .. math:: + A2 * X'' = A1 * X' + A0 * X + b + + Here, $A2$ is the coefficient matrix for the vector $X''$ and $b$ is the non-homogeneous + term. + + Default value for `b` is None but if `A1` and `A0` are passed and `b` is not passed, then the + system will be assumed homogeneous. + + """ + is_a1 = A1 is None + is_a0 = A0 is None + + if (type == "type1" and is_a1) or (type == "type2" and is_a0)\ + or (type == "auto" and (is_a1 or is_a0)): + (A2, A1, A0), b = linear_ode_to_matrix(eqs, funcs, t, 2) + + if not A2.is_Identity: + raise ValueError(filldedent(''' + The system must be in its canonical form. + ''')) + + if type == "auto": + match = _match_second_order_type(A1, A0, t) + type = match["type_of_equation"] + A1 = match.get("A1", None) + A0 = match.get("A0", None) + + sys_order = dict.fromkeys(funcs, 2) + + if type == "type1": + if b is None: + b = zeros(len(eqs)) + eqs = _second_order_subs_type1(A1, b, funcs, t) + sys_order = dict.fromkeys(funcs, 1) + + if type == "type2": + if t_ is None: + t_ = Symbol("{}_".format(t)) + t = t_ + eqs, funcs = _second_order_subs_type2(A0, funcs, t_) + sys_order = dict.fromkeys(funcs, 2) + + return _higher_order_to_first_order(eqs, sys_order, t, funcs=funcs) + + +def _higher_order_type2_to_sub_systems(J, f_t, funcs, t, max_order, b=None, P=None): + + # Note: To add a test for this ValueError + if J is None or f_t is None or not _matrix_is_constant(J, t): + raise ValueError(filldedent(''' + Correctly input for args 'A' and 'f_t' for Linear, Higher Order, + Type 2 + ''')) + + if P is None and b is not None and not b.is_zero_matrix: + raise ValueError(filldedent(''' + Provide the keyword 'P' for matrix P in A = P * J * P-1. + ''')) + + new_funcs = Matrix([Function(Dummy('{}__0'.format(f.func.__name__)))(t) for f in funcs]) + new_eqs = new_funcs.diff(t, max_order) - f_t * J * new_funcs + + if b is not None and not b.is_zero_matrix: + new_eqs -= P.inv() * b + + new_eqs = canonical_odes(new_eqs, new_funcs, t)[0] + + return new_eqs, new_funcs + + +def _higher_order_to_first_order(eqs, sys_order, t, funcs=None, type="type0", **kwargs): + if funcs is None: + funcs = sys_order.keys() + + # Standard Cauchy Euler system + if type == "type1": + t_ = Symbol('{}_'.format(t)) + new_funcs = [Function(Dummy('{}_'.format(f.func.__name__)))(t_) for f in funcs] + max_order = max(sys_order[func] for func in funcs) + subs_dict = dict(zip(funcs, new_funcs)) + subs_dict[t] = exp(t_) + + free_function = Function(Dummy()) + + def _get_coeffs_from_subs_expression(expr): + if isinstance(expr, Subs): + free_symbol = expr.args[1][0] + term = expr.args[0] + return {ode_order(term, free_symbol): 1} + + if isinstance(expr, Mul): + coeff = expr.args[0] + order = list(_get_coeffs_from_subs_expression(expr.args[1]).keys())[0] + return {order: coeff} + + if isinstance(expr, Add): + coeffs = {} + for arg in expr.args: + + if isinstance(arg, Mul): + coeffs.update(_get_coeffs_from_subs_expression(arg)) + + else: + order = list(_get_coeffs_from_subs_expression(arg).keys())[0] + coeffs[order] = 1 + + return coeffs + + for o in range(1, max_order + 1): + expr = free_function(log(t_)).diff(t_, o)*t_**o + coeff_dict = _get_coeffs_from_subs_expression(expr) + coeffs = [coeff_dict[order] if order in coeff_dict else 0 for order in range(o + 1)] + expr_to_subs = sum(free_function(t_).diff(t_, i) * c for i, c in + enumerate(coeffs)) / t**o + subs_dict.update({f.diff(t, o): expr_to_subs.subs(free_function(t_), nf) + for f, nf in zip(funcs, new_funcs)}) + + new_eqs = [eq.subs(subs_dict) for eq in eqs] + new_sys_order = {nf: sys_order[f] for f, nf in zip(funcs, new_funcs)} + + new_eqs = canonical_odes(new_eqs, new_funcs, t_)[0] + + return _higher_order_to_first_order(new_eqs, new_sys_order, t_, funcs=new_funcs) + + # Systems of the form: X(n)(t) = f(t)*A*X + b + # where X(n)(t) is the nth derivative of the vector of dependent variables + # with respect to the independent variable and A is a constant matrix. + if type == "type2": + J = kwargs.get('J', None) + f_t = kwargs.get('f_t', None) + b = kwargs.get('b', None) + P = kwargs.get('P', None) + max_order = max(sys_order[func] for func in funcs) + + return _higher_order_type2_to_sub_systems(J, f_t, funcs, t, max_order, P=P, b=b) + + # Note: To be changed to this after doit option is disabled for default cases + # new_sysorder = _get_func_order(new_eqs, new_funcs) + # + # return _higher_order_to_first_order(new_eqs, new_sysorder, t, funcs=new_funcs) + + new_funcs = [] + + for prev_func in funcs: + func_name = prev_func.func.__name__ + func = Function(Dummy('{}_0'.format(func_name)))(t) + new_funcs.append(func) + subs_dict = {prev_func: func} + new_eqs = [] + + for i in range(1, sys_order[prev_func]): + new_func = Function(Dummy('{}_{}'.format(func_name, i)))(t) + subs_dict[prev_func.diff(t, i)] = new_func + new_funcs.append(new_func) + + prev_f = subs_dict[prev_func.diff(t, i-1)] + new_eq = Eq(prev_f.diff(t), new_func) + new_eqs.append(new_eq) + + eqs = [eq.subs(subs_dict) for eq in eqs] + new_eqs + + return eqs, new_funcs + + +def dsolve_system(eqs, funcs=None, t=None, ics=None, doit=False, simplify=True): + r""" + Solves any(supported) system of Ordinary Differential Equations + + Explanation + =========== + + This function takes a system of ODEs as an input, determines if the + it is solvable by this function, and returns the solution if found any. + + This function can handle: + 1. Linear, First Order, Constant coefficient homogeneous system of ODEs + 2. Linear, First Order, Constant coefficient non-homogeneous system of ODEs + 3. Linear, First Order, non-constant coefficient homogeneous system of ODEs + 4. Linear, First Order, non-constant coefficient non-homogeneous system of ODEs + 5. Any implicit system which can be divided into system of ODEs which is of the above 4 forms + 6. Any higher order linear system of ODEs that can be reduced to one of the 5 forms of systems described above. + + The types of systems described above are not limited by the number of equations, i.e. this + function can solve the above types irrespective of the number of equations in the system passed. + But, the bigger the system, the more time it will take to solve the system. + + This function returns a list of solutions. Each solution is a list of equations where LHS is + the dependent variable and RHS is an expression in terms of the independent variable. + + Among the non constant coefficient types, not all the systems are solvable by this function. Only + those which have either a coefficient matrix with a commutative antiderivative or those systems which + may be divided further so that the divided systems may have coefficient matrix with commutative antiderivative. + + Parameters + ========== + + eqs : List + system of ODEs to be solved + funcs : List or None + List of dependent variables that make up the system of ODEs + t : Symbol or None + Independent variable in the system of ODEs + ics : Dict or None + Set of initial boundary/conditions for the system of ODEs + doit : Boolean + Evaluate the solutions if True. Default value is True. Can be + set to false if the integral evaluation takes too much time and/or + is not required. + simplify: Boolean + Simplify the solutions for the systems. Default value is True. + Can be set to false if simplification takes too much time and/or + is not required. + + Examples + ======== + + >>> from sympy import symbols, Eq, Function + >>> from sympy.solvers.ode.systems import dsolve_system + >>> f, g = symbols("f g", cls=Function) + >>> x = symbols("x") + + >>> eqs = [Eq(f(x).diff(x), g(x)), Eq(g(x).diff(x), f(x))] + >>> dsolve_system(eqs) + [[Eq(f(x), -C1*exp(-x) + C2*exp(x)), Eq(g(x), C1*exp(-x) + C2*exp(x))]] + + You can also pass the initial conditions for the system of ODEs: + + >>> dsolve_system(eqs, ics={f(0): 1, g(0): 0}) + [[Eq(f(x), exp(x)/2 + exp(-x)/2), Eq(g(x), exp(x)/2 - exp(-x)/2)]] + + Optionally, you can pass the dependent variables and the independent + variable for which the system is to be solved: + + >>> funcs = [f(x), g(x)] + >>> dsolve_system(eqs, funcs=funcs, t=x) + [[Eq(f(x), -C1*exp(-x) + C2*exp(x)), Eq(g(x), C1*exp(-x) + C2*exp(x))]] + + Lets look at an implicit system of ODEs: + + >>> eqs = [Eq(f(x).diff(x)**2, g(x)**2), Eq(g(x).diff(x), g(x))] + >>> dsolve_system(eqs) + [[Eq(f(x), C1 - C2*exp(x)), Eq(g(x), C2*exp(x))], [Eq(f(x), C1 + C2*exp(x)), Eq(g(x), C2*exp(x))]] + + Returns + ======= + + List of List of Equations + + Raises + ====== + + NotImplementedError + When the system of ODEs is not solvable by this function. + ValueError + When the parameters passed are not in the required form. + + """ + from sympy.solvers.ode.ode import solve_ics, _extract_funcs, constant_renumber + + if not iterable(eqs): + raise ValueError(filldedent(''' + List of equations should be passed. The input is not valid. + ''')) + + eqs = _preprocess_eqs(eqs) + + if funcs is not None and not isinstance(funcs, list): + raise ValueError(filldedent(''' + Input to the funcs should be a list of functions. + ''')) + + if funcs is None: + funcs = _extract_funcs(eqs) + + if any(len(func.args) != 1 for func in funcs): + raise ValueError(filldedent(''' + dsolve_system can solve a system of ODEs with only one independent + variable. + ''')) + + if len(eqs) != len(funcs): + raise ValueError(filldedent(''' + Number of equations and number of functions do not match + ''')) + + if t is not None and not isinstance(t, Symbol): + raise ValueError(filldedent(''' + The independent variable must be of type Symbol + ''')) + + if t is None: + t = list(list(eqs[0].atoms(Derivative))[0].atoms(Symbol))[0] + + sols = [] + canon_eqs = canonical_odes(eqs, funcs, t) + + for canon_eq in canon_eqs: + try: + sol = _strong_component_solver(canon_eq, funcs, t) + except NotImplementedError: + sol = None + + if sol is None: + sol = _component_solver(canon_eq, funcs, t) + + sols.append(sol) + + if sols: + final_sols = [] + variables = Tuple(*eqs).free_symbols + + for sol in sols: + + sol = _select_equations(sol, funcs) + sol = constant_renumber(sol, variables=variables) + + if ics: + constants = Tuple(*sol).free_symbols - variables + solved_constants = solve_ics(sol, funcs, constants, ics) + sol = [s.subs(solved_constants) for s in sol] + + if simplify: + constants = Tuple(*sol).free_symbols - variables + sol = simpsol(sol, [t], constants, doit=doit) + + final_sols.append(sol) + + sols = final_sols + + return sols diff --git a/lib/python3.10/site-packages/sympy/solvers/ode/tests/__init__.py b/lib/python3.10/site-packages/sympy/solvers/ode/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/lib/python3.10/site-packages/sympy/solvers/ode/tests/__pycache__/__init__.cpython-310.pyc b/lib/python3.10/site-packages/sympy/solvers/ode/tests/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..021b09c287ef95e9f3af0d7e186691e5cd0cd719 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/solvers/ode/tests/__pycache__/__init__.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/solvers/ode/tests/__pycache__/test_lie_group.cpython-310.pyc b/lib/python3.10/site-packages/sympy/solvers/ode/tests/__pycache__/test_lie_group.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..03d15a0b21ddc5046964ada66fc7a0992fbcd185 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/solvers/ode/tests/__pycache__/test_lie_group.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/solvers/ode/tests/__pycache__/test_ode.cpython-310.pyc b/lib/python3.10/site-packages/sympy/solvers/ode/tests/__pycache__/test_ode.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3979eef03fc86b4e3316ad7c5d961f21c654e611 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/solvers/ode/tests/__pycache__/test_ode.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/solvers/ode/tests/__pycache__/test_riccati.cpython-310.pyc b/lib/python3.10/site-packages/sympy/solvers/ode/tests/__pycache__/test_riccati.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..dfda7d15d238293e230dc37bbeb10bd9f07a6ff8 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/solvers/ode/tests/__pycache__/test_riccati.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/solvers/ode/tests/__pycache__/test_single.cpython-310.pyc b/lib/python3.10/site-packages/sympy/solvers/ode/tests/__pycache__/test_single.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3f2d287a1eac0088cc7e86645f761886b42b7808 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/solvers/ode/tests/__pycache__/test_single.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/solvers/ode/tests/__pycache__/test_subscheck.cpython-310.pyc b/lib/python3.10/site-packages/sympy/solvers/ode/tests/__pycache__/test_subscheck.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..16b664e9e0b089ca0220a5e9c41e5c95df35d82a Binary files /dev/null and b/lib/python3.10/site-packages/sympy/solvers/ode/tests/__pycache__/test_subscheck.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/solvers/ode/tests/test_lie_group.py b/lib/python3.10/site-packages/sympy/solvers/ode/tests/test_lie_group.py new file mode 100644 index 0000000000000000000000000000000000000000..153d30ff563773819e49c989f447c1ec7962169b --- /dev/null +++ b/lib/python3.10/site-packages/sympy/solvers/ode/tests/test_lie_group.py @@ -0,0 +1,152 @@ +from sympy.core.function import Function +from sympy.core.numbers import Rational +from sympy.core.relational import Eq +from sympy.core.symbol import (Symbol, symbols) +from sympy.functions.elementary.exponential import (exp, log) +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import (atan, sin, tan) + +from sympy.solvers.ode import (classify_ode, checkinfsol, dsolve, infinitesimals) + +from sympy.solvers.ode.subscheck import checkodesol + +from sympy.testing.pytest import XFAIL + + +C1 = Symbol('C1') +x, y = symbols("x y") +f = Function('f') +xi = Function('xi') +eta = Function('eta') + + +def test_heuristic1(): + a, b, c, a4, a3, a2, a1, a0 = symbols("a b c a4 a3 a2 a1 a0") + df = f(x).diff(x) + eq = Eq(df, x**2*f(x)) + eq1 = f(x).diff(x) + a*f(x) - c*exp(b*x) + eq2 = f(x).diff(x) + 2*x*f(x) - x*exp(-x**2) + eq3 = (1 + 2*x)*df + 2 - 4*exp(-f(x)) + eq4 = f(x).diff(x) - (a4*x**4 + a3*x**3 + a2*x**2 + a1*x + a0)**Rational(-1, 2) + eq5 = x**2*df - f(x) + x**2*exp(x - (1/x)) + eqlist = [eq, eq1, eq2, eq3, eq4, eq5] + + i = infinitesimals(eq, hint='abaco1_simple') + assert i == [{eta(x, f(x)): exp(x**3/3), xi(x, f(x)): 0}, + {eta(x, f(x)): f(x), xi(x, f(x)): 0}, + {eta(x, f(x)): 0, xi(x, f(x)): x**(-2)}] + i1 = infinitesimals(eq1, hint='abaco1_simple') + assert i1 == [{eta(x, f(x)): exp(-a*x), xi(x, f(x)): 0}] + i2 = infinitesimals(eq2, hint='abaco1_simple') + assert i2 == [{eta(x, f(x)): exp(-x**2), xi(x, f(x)): 0}] + i3 = infinitesimals(eq3, hint='abaco1_simple') + assert i3 == [{eta(x, f(x)): 0, xi(x, f(x)): 2*x + 1}, + {eta(x, f(x)): 0, xi(x, f(x)): 1/(exp(f(x)) - 2)}] + i4 = infinitesimals(eq4, hint='abaco1_simple') + assert i4 == [{eta(x, f(x)): 1, xi(x, f(x)): 0}, + {eta(x, f(x)): 0, + xi(x, f(x)): sqrt(a0 + a1*x + a2*x**2 + a3*x**3 + a4*x**4)}] + i5 = infinitesimals(eq5, hint='abaco1_simple') + assert i5 == [{xi(x, f(x)): 0, eta(x, f(x)): exp(-1/x)}] + + ilist = [i, i1, i2, i3, i4, i5] + for eq, i in (zip(eqlist, ilist)): + check = checkinfsol(eq, i) + assert check[0] + + # This ODE can be solved by the Lie Group method, when there are + # better assumptions + eq6 = df - (f(x)/x)*(x*log(x**2/f(x)) + 2) + i = infinitesimals(eq6, hint='abaco1_product') + assert i == [{eta(x, f(x)): f(x)*exp(-x), xi(x, f(x)): 0}] + assert checkinfsol(eq6, i)[0] + + eq7 = x*(f(x).diff(x)) + 1 - f(x)**2 + i = infinitesimals(eq7, hint='chi') + assert checkinfsol(eq7, i)[0] + + +def test_heuristic3(): + a, b = symbols("a b") + df = f(x).diff(x) + + eq = x**2*df + x*f(x) + f(x)**2 + x**2 + i = infinitesimals(eq, hint='bivariate') + assert i == [{eta(x, f(x)): f(x), xi(x, f(x)): x}] + assert checkinfsol(eq, i)[0] + + eq = x**2*(-f(x)**2 + df)- a*x**2*f(x) + 2 - a*x + i = infinitesimals(eq, hint='bivariate') + assert checkinfsol(eq, i)[0] + + +def test_heuristic_function_sum(): + eq = f(x).diff(x) - (3*(1 + x**2/f(x)**2)*atan(f(x)/x) + (1 - 2*f(x))/x + + (1 - 3*f(x))*(x/f(x)**2)) + i = infinitesimals(eq, hint='function_sum') + assert i == [{eta(x, f(x)): f(x)**(-2) + x**(-2), xi(x, f(x)): 0}] + assert checkinfsol(eq, i)[0] + + +def test_heuristic_abaco2_similar(): + a, b = symbols("a b") + F = Function('F') + eq = f(x).diff(x) - F(a*x + b*f(x)) + i = infinitesimals(eq, hint='abaco2_similar') + assert i == [{eta(x, f(x)): -a/b, xi(x, f(x)): 1}] + assert checkinfsol(eq, i)[0] + + eq = f(x).diff(x) - (f(x)**2 / (sin(f(x) - x) - x**2 + 2*x*f(x))) + i = infinitesimals(eq, hint='abaco2_similar') + assert i == [{eta(x, f(x)): f(x)**2, xi(x, f(x)): f(x)**2}] + assert checkinfsol(eq, i)[0] + + +def test_heuristic_abaco2_unique_unknown(): + + a, b = symbols("a b") + F = Function('F') + eq = f(x).diff(x) - x**(a - 1)*(f(x)**(1 - b))*F(x**a/a + f(x)**b/b) + i = infinitesimals(eq, hint='abaco2_unique_unknown') + assert i == [{eta(x, f(x)): -f(x)*f(x)**(-b), xi(x, f(x)): x*x**(-a)}] + assert checkinfsol(eq, i)[0] + + eq = f(x).diff(x) + tan(F(x**2 + f(x)**2) + atan(x/f(x))) + i = infinitesimals(eq, hint='abaco2_unique_unknown') + assert i == [{eta(x, f(x)): x, xi(x, f(x)): -f(x)}] + assert checkinfsol(eq, i)[0] + + eq = (x*f(x).diff(x) + f(x) + 2*x)**2 -4*x*f(x) -4*x**2 -4*a + i = infinitesimals(eq, hint='abaco2_unique_unknown') + assert checkinfsol(eq, i)[0] + + +def test_heuristic_linear(): + a, b, m, n = symbols("a b m n") + + eq = x**(n*(m + 1) - m)*(f(x).diff(x)) - a*f(x)**n -b*x**(n*(m + 1)) + i = infinitesimals(eq, hint='linear') + assert checkinfsol(eq, i)[0] + + +@XFAIL +def test_kamke(): + a, b, alpha, c = symbols("a b alpha c") + eq = x**2*(a*f(x)**2+(f(x).diff(x))) + b*x**alpha + c + i = infinitesimals(eq, hint='sum_function') # XFAIL + assert checkinfsol(eq, i)[0] + + +def test_user_infinitesimals(): + x = Symbol("x") # assuming x is real generates an error + eq = x*(f(x).diff(x)) + 1 - f(x)**2 + sol = Eq(f(x), (C1 + x**2)/(C1 - x**2)) + infinitesimals = {'xi':sqrt(f(x) - 1)/sqrt(f(x) + 1), 'eta':0} + assert dsolve(eq, hint='lie_group', **infinitesimals) == sol + assert checkodesol(eq, sol) == (True, 0) + + +@XFAIL +def test_lie_group_issue15219(): + eqn = exp(f(x).diff(x)-f(x)) + assert 'lie_group' not in classify_ode(eqn, f(x)) diff --git a/lib/python3.10/site-packages/sympy/solvers/ode/tests/test_ode.py b/lib/python3.10/site-packages/sympy/solvers/ode/tests/test_ode.py new file mode 100644 index 0000000000000000000000000000000000000000..b1ddcc784fde15c1176feb23ca47f4adf6fddbff --- /dev/null +++ b/lib/python3.10/site-packages/sympy/solvers/ode/tests/test_ode.py @@ -0,0 +1,1104 @@ +from sympy.core.function import (Derivative, Function, Subs, diff) +from sympy.core.numbers import (E, I, Rational, pi) +from sympy.core.relational import Eq +from sympy.core.singleton import S +from sympy.core.symbol import (Symbol, symbols) +from sympy.functions.elementary.complexes import (im, re) +from sympy.functions.elementary.exponential import (exp, log) +from sympy.functions.elementary.hyperbolic import acosh +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import (atan2, cos, sin, tan) +from sympy.integrals.integrals import Integral +from sympy.polys.polytools import Poly +from sympy.series.order import O +from sympy.simplify.radsimp import collect + +from sympy.solvers.ode import (classify_ode, + homogeneous_order, dsolve) + +from sympy.solvers.ode.subscheck import checkodesol +from sympy.solvers.ode.ode import (classify_sysode, + constant_renumber, constantsimp, get_numbered_constants, solve_ics) + +from sympy.solvers.ode.nonhomogeneous import _undetermined_coefficients_match +from sympy.solvers.ode.single import LinearCoefficients +from sympy.solvers.deutils import ode_order +from sympy.testing.pytest import XFAIL, raises, slow, SKIP +from sympy.utilities.misc import filldedent + + +C0, C1, C2, C3, C4, C5, C6, C7, C8, C9, C10 = symbols('C0:11') +u, x, y, z = symbols('u,x:z', real=True) +f = Function('f') +g = Function('g') +h = Function('h') + +# Note: Examples which were specifically testing Single ODE solver are moved to test_single.py +# and all the system of ode examples are moved to test_systems.py +# Note: the tests below may fail (but still be correct) if ODE solver, +# the integral engine, solve(), or even simplify() changes. Also, in +# differently formatted solutions, the arbitrary constants might not be +# equal. Using specific hints in tests can help to avoid this. + +# Tests of order higher than 1 should run the solutions through +# constant_renumber because it will normalize it (constant_renumber causes +# dsolve() to return different results on different machines) + + +def test_get_numbered_constants(): + with raises(ValueError): + get_numbered_constants(None) + + +def test_dsolve_all_hint(): + eq = f(x).diff(x) + output = dsolve(eq, hint='all') + + # Match the Dummy variables: + sol1 = output['separable_Integral'] + _y = sol1.lhs.args[1][0] + sol1 = output['1st_homogeneous_coeff_subs_dep_div_indep_Integral'] + _u1 = sol1.rhs.args[1].args[1][0] + + expected = {'Bernoulli_Integral': Eq(f(x), C1 + Integral(0, x)), + '1st_homogeneous_coeff_best': Eq(f(x), C1), + 'Bernoulli': Eq(f(x), C1), + 'nth_algebraic': Eq(f(x), C1), + 'nth_linear_euler_eq_homogeneous': Eq(f(x), C1), + 'nth_linear_constant_coeff_homogeneous': Eq(f(x), C1), + 'separable': Eq(f(x), C1), + '1st_homogeneous_coeff_subs_indep_div_dep': Eq(f(x), C1), + 'nth_algebraic_Integral': Eq(f(x), C1), + '1st_linear': Eq(f(x), C1), + '1st_linear_Integral': Eq(f(x), C1 + Integral(0, x)), + '1st_exact': Eq(f(x), C1), + '1st_exact_Integral': Eq(Subs(Integral(0, x) + Integral(1, _y), _y, f(x)), C1), + 'lie_group': Eq(f(x), C1), + '1st_homogeneous_coeff_subs_dep_div_indep': Eq(f(x), C1), + '1st_homogeneous_coeff_subs_dep_div_indep_Integral': Eq(log(x), C1 + Integral(-1/_u1, (_u1, f(x)/x))), + '1st_power_series': Eq(f(x), C1), + 'separable_Integral': Eq(Integral(1, (_y, f(x))), C1 + Integral(0, x)), + '1st_homogeneous_coeff_subs_indep_div_dep_Integral': Eq(f(x), C1), + 'best': Eq(f(x), C1), + 'best_hint': 'nth_algebraic', + 'default': 'nth_algebraic', + 'order': 1} + assert output == expected + + assert dsolve(eq, hint='best') == Eq(f(x), C1) + + +def test_dsolve_ics(): + # Maybe this should just use one of the solutions instead of raising... + with raises(NotImplementedError): + dsolve(f(x).diff(x) - sqrt(f(x)), ics={f(1):1}) + + +@slow +def test_dsolve_options(): + eq = x*f(x).diff(x) + f(x) + a = dsolve(eq, hint='all') + b = dsolve(eq, hint='all', simplify=False) + c = dsolve(eq, hint='all_Integral') + keys = ['1st_exact', '1st_exact_Integral', '1st_homogeneous_coeff_best', + '1st_homogeneous_coeff_subs_dep_div_indep', + '1st_homogeneous_coeff_subs_dep_div_indep_Integral', + '1st_homogeneous_coeff_subs_indep_div_dep', + '1st_homogeneous_coeff_subs_indep_div_dep_Integral', '1st_linear', + '1st_linear_Integral', 'Bernoulli', 'Bernoulli_Integral', + 'almost_linear', 'almost_linear_Integral', 'best', 'best_hint', + 'default', 'factorable', 'lie_group', + 'nth_linear_euler_eq_homogeneous', 'order', + 'separable', 'separable_Integral'] + Integral_keys = ['1st_exact_Integral', + '1st_homogeneous_coeff_subs_dep_div_indep_Integral', + '1st_homogeneous_coeff_subs_indep_div_dep_Integral', '1st_linear_Integral', + 'Bernoulli_Integral', 'almost_linear_Integral', 'best', 'best_hint', 'default', + 'factorable', 'nth_linear_euler_eq_homogeneous', + 'order', 'separable_Integral'] + assert sorted(a.keys()) == keys + assert a['order'] == ode_order(eq, f(x)) + assert a['best'] == Eq(f(x), C1/x) + assert dsolve(eq, hint='best') == Eq(f(x), C1/x) + assert a['default'] == 'factorable' + assert a['best_hint'] == 'factorable' + assert not a['1st_exact'].has(Integral) + assert not a['separable'].has(Integral) + assert not a['1st_homogeneous_coeff_best'].has(Integral) + assert not a['1st_homogeneous_coeff_subs_dep_div_indep'].has(Integral) + assert not a['1st_homogeneous_coeff_subs_indep_div_dep'].has(Integral) + assert not a['1st_linear'].has(Integral) + assert a['1st_linear_Integral'].has(Integral) + assert a['1st_exact_Integral'].has(Integral) + assert a['1st_homogeneous_coeff_subs_dep_div_indep_Integral'].has(Integral) + assert a['1st_homogeneous_coeff_subs_indep_div_dep_Integral'].has(Integral) + assert a['separable_Integral'].has(Integral) + assert sorted(b.keys()) == keys + assert b['order'] == ode_order(eq, f(x)) + assert b['best'] == Eq(f(x), C1/x) + assert dsolve(eq, hint='best', simplify=False) == Eq(f(x), C1/x) + assert b['default'] == 'factorable' + assert b['best_hint'] == 'factorable' + assert a['separable'] != b['separable'] + assert a['1st_homogeneous_coeff_subs_dep_div_indep'] != \ + b['1st_homogeneous_coeff_subs_dep_div_indep'] + assert a['1st_homogeneous_coeff_subs_indep_div_dep'] != \ + b['1st_homogeneous_coeff_subs_indep_div_dep'] + assert not b['1st_exact'].has(Integral) + assert not b['separable'].has(Integral) + assert not b['1st_homogeneous_coeff_best'].has(Integral) + assert not b['1st_homogeneous_coeff_subs_dep_div_indep'].has(Integral) + assert not b['1st_homogeneous_coeff_subs_indep_div_dep'].has(Integral) + assert not b['1st_linear'].has(Integral) + assert b['1st_linear_Integral'].has(Integral) + assert b['1st_exact_Integral'].has(Integral) + assert b['1st_homogeneous_coeff_subs_dep_div_indep_Integral'].has(Integral) + assert b['1st_homogeneous_coeff_subs_indep_div_dep_Integral'].has(Integral) + assert b['separable_Integral'].has(Integral) + assert sorted(c.keys()) == Integral_keys + raises(ValueError, lambda: dsolve(eq, hint='notarealhint')) + raises(ValueError, lambda: dsolve(eq, hint='Liouville')) + assert dsolve(f(x).diff(x) - 1/f(x)**2, hint='all')['best'] == \ + dsolve(f(x).diff(x) - 1/f(x)**2, hint='best') + assert dsolve(f(x) + f(x).diff(x) + sin(x).diff(x) + 1, f(x), + hint="1st_linear_Integral") == \ + Eq(f(x), (C1 + Integral((-sin(x).diff(x) - 1)* + exp(Integral(1, x)), x))*exp(-Integral(1, x))) + + +def test_classify_ode(): + assert classify_ode(f(x).diff(x, 2), f(x)) == \ + ( + 'nth_algebraic', + 'nth_linear_constant_coeff_homogeneous', + 'nth_linear_euler_eq_homogeneous', + 'Liouville', + '2nd_power_series_ordinary', + 'nth_algebraic_Integral', + 'Liouville_Integral', + ) + assert classify_ode(f(x), f(x)) == ('nth_algebraic', 'nth_algebraic_Integral') + assert classify_ode(Eq(f(x).diff(x), 0), f(x)) == ( + 'nth_algebraic', + 'separable', + '1st_exact', + '1st_linear', + 'Bernoulli', + '1st_homogeneous_coeff_best', + '1st_homogeneous_coeff_subs_indep_div_dep', + '1st_homogeneous_coeff_subs_dep_div_indep', + '1st_power_series', 'lie_group', + 'nth_linear_constant_coeff_homogeneous', + 'nth_linear_euler_eq_homogeneous', + 'nth_algebraic_Integral', + 'separable_Integral', + '1st_exact_Integral', + '1st_linear_Integral', + 'Bernoulli_Integral', + '1st_homogeneous_coeff_subs_indep_div_dep_Integral', + '1st_homogeneous_coeff_subs_dep_div_indep_Integral') + assert classify_ode(f(x).diff(x)**2, f(x)) == ('factorable', + 'nth_algebraic', + 'separable', + '1st_exact', + '1st_linear', + 'Bernoulli', + '1st_homogeneous_coeff_best', + '1st_homogeneous_coeff_subs_indep_div_dep', + '1st_homogeneous_coeff_subs_dep_div_indep', + '1st_power_series', + 'lie_group', + 'nth_linear_euler_eq_homogeneous', + 'nth_algebraic_Integral', + 'separable_Integral', + '1st_exact_Integral', + '1st_linear_Integral', + 'Bernoulli_Integral', + '1st_homogeneous_coeff_subs_indep_div_dep_Integral', + '1st_homogeneous_coeff_subs_dep_div_indep_Integral') + # issue 4749: f(x) should be cleared from highest derivative before classifying + a = classify_ode(Eq(f(x).diff(x) + f(x), x), f(x)) + b = classify_ode(f(x).diff(x)*f(x) + f(x)*f(x) - x*f(x), f(x)) + c = classify_ode(f(x).diff(x)/f(x) + f(x)/f(x) - x/f(x), f(x)) + assert a == ('1st_exact', + '1st_linear', + 'Bernoulli', + 'almost_linear', + '1st_power_series', "lie_group", + 'nth_linear_constant_coeff_undetermined_coefficients', + 'nth_linear_constant_coeff_variation_of_parameters', + '1st_exact_Integral', + '1st_linear_Integral', + 'Bernoulli_Integral', + 'almost_linear_Integral', + 'nth_linear_constant_coeff_variation_of_parameters_Integral') + assert b == ('factorable', + '1st_linear', + 'Bernoulli', + '1st_power_series', + 'lie_group', + 'nth_linear_constant_coeff_undetermined_coefficients', + 'nth_linear_constant_coeff_variation_of_parameters', + '1st_linear_Integral', + 'Bernoulli_Integral', + 'nth_linear_constant_coeff_variation_of_parameters_Integral') + assert c == ('factorable', + '1st_linear', + 'Bernoulli', + '1st_power_series', + 'lie_group', + 'nth_linear_constant_coeff_undetermined_coefficients', + 'nth_linear_constant_coeff_variation_of_parameters', + '1st_linear_Integral', + 'Bernoulli_Integral', + 'nth_linear_constant_coeff_variation_of_parameters_Integral') + + assert classify_ode( + 2*x*f(x)*f(x).diff(x) + (1 + x)*f(x)**2 - exp(x), f(x) + ) == ('factorable', '1st_exact', 'Bernoulli', 'almost_linear', 'lie_group', + '1st_exact_Integral', 'Bernoulli_Integral', 'almost_linear_Integral') + assert 'Riccati_special_minus2' in \ + classify_ode(2*f(x).diff(x) + f(x)**2 - f(x)/x + 3*x**(-2), f(x)) + raises(ValueError, lambda: classify_ode(x + f(x, y).diff(x).diff( + y), f(x, y))) + # issue 5176 + k = Symbol('k') + assert classify_ode(f(x).diff(x)/(k*f(x) + k*x*f(x)) + 2*f(x)/(k*f(x) + + k*x*f(x)) + x*f(x).diff(x)/(k*f(x) + k*x*f(x)) + z, f(x)) == \ + ('factorable', 'separable', '1st_exact', '1st_linear', 'Bernoulli', + '1st_power_series', 'lie_group', 'separable_Integral', '1st_exact_Integral', + '1st_linear_Integral', 'Bernoulli_Integral') + # preprocessing + ans = ('factorable', 'nth_algebraic', 'separable', '1st_exact', '1st_linear', 'Bernoulli', + '1st_homogeneous_coeff_best', + '1st_homogeneous_coeff_subs_indep_div_dep', + '1st_homogeneous_coeff_subs_dep_div_indep', + '1st_power_series', 'lie_group', + 'nth_linear_constant_coeff_undetermined_coefficients', + 'nth_linear_euler_eq_nonhomogeneous_undetermined_coefficients', + 'nth_linear_constant_coeff_variation_of_parameters', + 'nth_linear_euler_eq_nonhomogeneous_variation_of_parameters', + 'nth_algebraic_Integral', + 'separable_Integral', '1st_exact_Integral', + '1st_linear_Integral', + 'Bernoulli_Integral', + '1st_homogeneous_coeff_subs_indep_div_dep_Integral', + '1st_homogeneous_coeff_subs_dep_div_indep_Integral', + 'nth_linear_constant_coeff_variation_of_parameters_Integral', + 'nth_linear_euler_eq_nonhomogeneous_variation_of_parameters_Integral') + # w/o f(x) given + assert classify_ode(diff(f(x) + x, x) + diff(f(x), x)) == ans + # w/ f(x) and prep=True + assert classify_ode(diff(f(x) + x, x) + diff(f(x), x), f(x), + prep=True) == ans + + assert classify_ode(Eq(2*x**3*f(x).diff(x), 0), f(x)) == \ + ('factorable', 'nth_algebraic', 'separable', '1st_exact', + '1st_linear', 'Bernoulli', '1st_power_series', + 'lie_group', 'nth_linear_euler_eq_homogeneous', + 'nth_algebraic_Integral', 'separable_Integral', '1st_exact_Integral', + '1st_linear_Integral', 'Bernoulli_Integral') + + + assert classify_ode(Eq(2*f(x)**3*f(x).diff(x), 0), f(x)) == \ + ('factorable', 'nth_algebraic', 'separable', '1st_exact', '1st_linear', + 'Bernoulli', '1st_power_series', 'lie_group', 'nth_algebraic_Integral', + 'separable_Integral', '1st_exact_Integral', '1st_linear_Integral', + 'Bernoulli_Integral') + # test issue 13864 + assert classify_ode(Eq(diff(f(x), x) - f(x)**x, 0), f(x)) == \ + ('1st_power_series', 'lie_group') + assert isinstance(classify_ode(Eq(f(x), 5), f(x), dict=True), dict) + + #This is for new behavior of classify_ode when called internally with default, It should + # return the first hint which matches therefore, 'ordered_hints' key will not be there. + assert sorted(classify_ode(Eq(f(x).diff(x), 0), f(x), dict=True).keys()) == \ + ['default', 'nth_linear_constant_coeff_homogeneous', 'order'] + a = classify_ode(2*x*f(x)*f(x).diff(x) + (1 + x)*f(x)**2 - exp(x), f(x), dict=True, hint='Bernoulli') + assert sorted(a.keys()) == ['Bernoulli', 'Bernoulli_Integral', 'default', 'order', 'ordered_hints'] + + # test issue 22155 + a = classify_ode(f(x).diff(x) - exp(f(x) - x), f(x)) + assert a == ('separable', + '1st_exact', '1st_power_series', + 'lie_group', 'separable_Integral', + '1st_exact_Integral') + + +def test_classify_ode_ics(): + # Dummy + eq = f(x).diff(x, x) - f(x) + + # Not f(0) or f'(0) + ics = {x: 1} + raises(ValueError, lambda: classify_ode(eq, f(x), ics=ics)) + + + ############################ + # f(0) type (AppliedUndef) # + ############################ + + + # Wrong function + ics = {g(0): 1} + raises(ValueError, lambda: classify_ode(eq, f(x), ics=ics)) + + # Contains x + ics = {f(x): 1} + raises(ValueError, lambda: classify_ode(eq, f(x), ics=ics)) + + # Too many args + ics = {f(0, 0): 1} + raises(ValueError, lambda: classify_ode(eq, f(x), ics=ics)) + + # point contains x + ics = {f(0): f(x)} + raises(ValueError, lambda: classify_ode(eq, f(x), ics=ics)) + + # Does not raise + ics = {f(0): f(0)} + classify_ode(eq, f(x), ics=ics) + + # Does not raise + ics = {f(0): 1} + classify_ode(eq, f(x), ics=ics) + + + ##################### + # f'(0) type (Subs) # + ##################### + + # Wrong function + ics = {g(x).diff(x).subs(x, 0): 1} + raises(ValueError, lambda: classify_ode(eq, f(x), ics=ics)) + + # Contains x + ics = {f(y).diff(y).subs(y, x): 1} + raises(ValueError, lambda: classify_ode(eq, f(x), ics=ics)) + + # Wrong variable + ics = {f(y).diff(y).subs(y, 0): 1} + raises(ValueError, lambda: classify_ode(eq, f(x), ics=ics)) + + # Too many args + ics = {f(x, y).diff(x).subs(x, 0): 1} + raises(ValueError, lambda: classify_ode(eq, f(x), ics=ics)) + + # Derivative wrt wrong vars + ics = {Derivative(f(x), x, y).subs(x, 0): 1} + raises(ValueError, lambda: classify_ode(eq, f(x), ics=ics)) + + # point contains x + ics = {f(x).diff(x).subs(x, 0): f(x)} + raises(ValueError, lambda: classify_ode(eq, f(x), ics=ics)) + + # Does not raise + ics = {f(x).diff(x).subs(x, 0): f(x).diff(x).subs(x, 0)} + classify_ode(eq, f(x), ics=ics) + + # Does not raise + ics = {f(x).diff(x).subs(x, 0): 1} + classify_ode(eq, f(x), ics=ics) + + ########################### + # f'(y) type (Derivative) # + ########################### + + # Wrong function + ics = {g(x).diff(x).subs(x, y): 1} + raises(ValueError, lambda: classify_ode(eq, f(x), ics=ics)) + + # Contains x + ics = {f(y).diff(y).subs(y, x): 1} + raises(ValueError, lambda: classify_ode(eq, f(x), ics=ics)) + + # Too many args + ics = {f(x, y).diff(x).subs(x, y): 1} + raises(ValueError, lambda: classify_ode(eq, f(x), ics=ics)) + + # Derivative wrt wrong vars + ics = {Derivative(f(x), x, z).subs(x, y): 1} + raises(ValueError, lambda: classify_ode(eq, f(x), ics=ics)) + + # point contains x + ics = {f(x).diff(x).subs(x, y): f(x)} + raises(ValueError, lambda: classify_ode(eq, f(x), ics=ics)) + + # Does not raise + ics = {f(x).diff(x).subs(x, 0): f(0)} + classify_ode(eq, f(x), ics=ics) + + # Does not raise + ics = {f(x).diff(x).subs(x, y): 1} + classify_ode(eq, f(x), ics=ics) + +def test_classify_sysode(): + # Here x is assumed to be x(t) and y as y(t) for simplicity. + # Similarly diff(x,t) and diff(y,y) is assumed to be x1 and y1 respectively. + k, l, m, n = symbols('k, l, m, n', Integer=True) + k1, k2, k3, l1, l2, l3, m1, m2, m3 = symbols('k1, k2, k3, l1, l2, l3, m1, m2, m3', Integer=True) + P, Q, R, p, q, r = symbols('P, Q, R, p, q, r', cls=Function) + P1, P2, P3, Q1, Q2, R1, R2 = symbols('P1, P2, P3, Q1, Q2, R1, R2', cls=Function) + x, y, z = symbols('x, y, z', cls=Function) + t = symbols('t') + x1 = diff(x(t),t) ; y1 = diff(y(t),t) ; + + eq6 = (Eq(x1, exp(k*x(t))*P(x(t),y(t))), Eq(y1,r(y(t))*P(x(t),y(t)))) + sol6 = {'no_of_equation': 2, 'func_coeff': {(0, x(t), 0): 0, (1, x(t), 1): 0, (0, x(t), 1): 1, (1, y(t), 0): 0, \ + (1, x(t), 0): 0, (0, y(t), 1): 0, (0, y(t), 0): 0, (1, y(t), 1): 1}, 'type_of_equation': 'type2', 'func': \ + [x(t), y(t)], 'is_linear': False, 'eq': [-P(x(t), y(t))*exp(k*x(t)) + Derivative(x(t), t), -P(x(t), \ + y(t))*r(y(t)) + Derivative(y(t), t)], 'order': {y(t): 1, x(t): 1}} + assert classify_sysode(eq6) == sol6 + + eq7 = (Eq(x1, x(t)**2+y(t)/x(t)), Eq(y1, x(t)/y(t))) + sol7 = {'no_of_equation': 2, 'func_coeff': {(0, x(t), 0): 0, (1, x(t), 1): 0, (0, x(t), 1): 1, (1, y(t), 0): 0, \ + (1, x(t), 0): -1/y(t), (0, y(t), 1): 0, (0, y(t), 0): -1/x(t), (1, y(t), 1): 1}, 'type_of_equation': 'type3', \ + 'func': [x(t), y(t)], 'is_linear': False, 'eq': [-x(t)**2 + Derivative(x(t), t) - y(t)/x(t), -x(t)/y(t) + \ + Derivative(y(t), t)], 'order': {y(t): 1, x(t): 1}} + assert classify_sysode(eq7) == sol7 + + eq8 = (Eq(x1, P1(x(t))*Q1(y(t))*R(x(t),y(t),t)), Eq(y1, P1(x(t))*Q1(y(t))*R(x(t),y(t),t))) + sol8 = {'func': [x(t), y(t)], 'is_linear': False, 'type_of_equation': 'type4', 'eq': \ + [-P1(x(t))*Q1(y(t))*R(x(t), y(t), t) + Derivative(x(t), t), -P1(x(t))*Q1(y(t))*R(x(t), y(t), t) + \ + Derivative(y(t), t)], 'func_coeff': {(0, y(t), 1): 0, (1, y(t), 1): 1, (1, x(t), 1): 0, (0, y(t), 0): 0, \ + (1, x(t), 0): 0, (0, x(t), 0): 0, (1, y(t), 0): 0, (0, x(t), 1): 1}, 'order': {y(t): 1, x(t): 1}, 'no_of_equation': 2} + assert classify_sysode(eq8) == sol8 + + eq11 = (Eq(x1,x(t)*y(t)**3), Eq(y1,y(t)**5)) + sol11 = {'no_of_equation': 2, 'func_coeff': {(0, x(t), 0): -y(t)**3, (1, x(t), 1): 0, (0, x(t), 1): 1, \ + (1, y(t), 0): 0, (1, x(t), 0): 0, (0, y(t), 1): 0, (0, y(t), 0): 0, (1, y(t), 1): 1}, 'type_of_equation': \ + 'type1', 'func': [x(t), y(t)], 'is_linear': False, 'eq': [-x(t)*y(t)**3 + Derivative(x(t), t), \ + -y(t)**5 + Derivative(y(t), t)], 'order': {y(t): 1, x(t): 1}} + assert classify_sysode(eq11) == sol11 + + eq13 = (Eq(x1,x(t)*y(t)*sin(t)**2), Eq(y1,y(t)**2*sin(t)**2)) + sol13 = {'no_of_equation': 2, 'func_coeff': {(0, x(t), 0): -y(t)*sin(t)**2, (1, x(t), 1): 0, (0, x(t), 1): 1, \ + (1, y(t), 0): 0, (1, x(t), 0): 0, (0, y(t), 1): 0, (0, y(t), 0): -x(t)*sin(t)**2, (1, y(t), 1): 1}, \ + 'type_of_equation': 'type4', 'func': [x(t), y(t)], 'is_linear': False, 'eq': [-x(t)*y(t)*sin(t)**2 + \ + Derivative(x(t), t), -y(t)**2*sin(t)**2 + Derivative(y(t), t)], 'order': {y(t): 1, x(t): 1}} + assert classify_sysode(eq13) == sol13 + + +def test_solve_ics(): + # Basic tests that things work from dsolve. + assert dsolve(f(x).diff(x) - 1/f(x), f(x), ics={f(1): 2}) == \ + Eq(f(x), sqrt(2 * x + 2)) + assert dsolve(f(x).diff(x) - f(x), f(x), ics={f(0): 1}) == Eq(f(x), exp(x)) + assert dsolve(f(x).diff(x) - f(x), f(x), ics={f(x).diff(x).subs(x, 0): 1}) == Eq(f(x), exp(x)) + assert dsolve(f(x).diff(x, x) + f(x), f(x), ics={f(0): 1, + f(x).diff(x).subs(x, 0): 1}) == Eq(f(x), sin(x) + cos(x)) + assert dsolve([f(x).diff(x) - f(x) + g(x), g(x).diff(x) - g(x) - f(x)], + [f(x), g(x)], ics={f(0): 1, g(0): 0}) == [Eq(f(x), exp(x)*cos(x)), Eq(g(x), exp(x)*sin(x))] + + # Test cases where dsolve returns two solutions. + eq = (x**2*f(x)**2 - x).diff(x) + assert dsolve(eq, f(x), ics={f(1): 0}) == [Eq(f(x), + -sqrt(x - 1)/x), Eq(f(x), sqrt(x - 1)/x)] + assert dsolve(eq, f(x), ics={f(x).diff(x).subs(x, 1): 0}) == [Eq(f(x), + -sqrt(x - S.Half)/x), Eq(f(x), sqrt(x - S.Half)/x)] + + eq = cos(f(x)) - (x*sin(f(x)) - f(x)**2)*f(x).diff(x) + assert dsolve(eq, f(x), + ics={f(0):1}, hint='1st_exact', simplify=False) == Eq(x*cos(f(x)) + f(x)**3/3, Rational(1, 3)) + assert dsolve(eq, f(x), + ics={f(0):1}, hint='1st_exact', simplify=True) == Eq(x*cos(f(x)) + f(x)**3/3, Rational(1, 3)) + + assert solve_ics([Eq(f(x), C1*exp(x))], [f(x)], [C1], {f(0): 1}) == {C1: 1} + assert solve_ics([Eq(f(x), C1*sin(x) + C2*cos(x))], [f(x)], [C1, C2], + {f(0): 1, f(pi/2): 1}) == {C1: 1, C2: 1} + + assert solve_ics([Eq(f(x), C1*sin(x) + C2*cos(x))], [f(x)], [C1, C2], + {f(0): 1, f(x).diff(x).subs(x, 0): 1}) == {C1: 1, C2: 1} + + assert solve_ics([Eq(f(x), C1*sin(x) + C2*cos(x))], [f(x)], [C1, C2], {f(0): 1}) == \ + {C2: 1} + + # Some more complicated tests Refer to PR #16098 + + assert set(dsolve(f(x).diff(x)*(f(x).diff(x, 2)-x), ics={f(0):0, f(x).diff(x).subs(x, 1):0})) == \ + {Eq(f(x), 0), Eq(f(x), x ** 3 / 6 - x / 2)} + assert set(dsolve(f(x).diff(x)*(f(x).diff(x, 2)-x), ics={f(0):0})) == \ + {Eq(f(x), 0), Eq(f(x), C2*x + x**3/6)} + + K, r, f0 = symbols('K r f0') + sol = Eq(f(x), K*f0*exp(r*x)/((-K + f0)*(f0*exp(r*x)/(-K + f0) - 1))) + assert (dsolve(Eq(f(x).diff(x), r * f(x) * (1 - f(x) / K)), f(x), ics={f(0): f0})) == sol + + + #Order dependent issues Refer to PR #16098 + assert set(dsolve(f(x).diff(x)*(f(x).diff(x, 2)-x), ics={f(x).diff(x).subs(x,0):0, f(0):0})) == \ + {Eq(f(x), 0), Eq(f(x), x ** 3 / 6)} + assert set(dsolve(f(x).diff(x)*(f(x).diff(x, 2)-x), ics={f(0):0, f(x).diff(x).subs(x,0):0})) == \ + {Eq(f(x), 0), Eq(f(x), x ** 3 / 6)} + + # XXX: Ought to be ValueError + raises(ValueError, lambda: solve_ics([Eq(f(x), C1*sin(x) + C2*cos(x))], [f(x)], [C1, C2], {f(0): 1, f(pi): 1})) + + # Degenerate case. f'(0) is identically 0. + raises(ValueError, lambda: solve_ics([Eq(f(x), sqrt(C1 - x**2))], [f(x)], [C1], {f(x).diff(x).subs(x, 0): 0})) + + EI, q, L = symbols('EI q L') + + # eq = Eq(EI*diff(f(x), x, 4), q) + sols = [Eq(f(x), C1 + C2*x + C3*x**2 + C4*x**3 + q*x**4/(24*EI))] + funcs = [f(x)] + constants = [C1, C2, C3, C4] + # Test both cases, Derivative (the default from f(x).diff(x).subs(x, L)), + # and Subs + ics1 = {f(0): 0, + f(x).diff(x).subs(x, 0): 0, + f(L).diff(L, 2): 0, + f(L).diff(L, 3): 0} + ics2 = {f(0): 0, + f(x).diff(x).subs(x, 0): 0, + Subs(f(x).diff(x, 2), x, L): 0, + Subs(f(x).diff(x, 3), x, L): 0} + + solved_constants1 = solve_ics(sols, funcs, constants, ics1) + solved_constants2 = solve_ics(sols, funcs, constants, ics2) + assert solved_constants1 == solved_constants2 == { + C1: 0, + C2: 0, + C3: L**2*q/(4*EI), + C4: -L*q/(6*EI)} + + # Allow the ics to refer to f + ics = {f(0): f(0)} + assert dsolve(f(x).diff(x) - f(x), f(x), ics=ics) == Eq(f(x), f(0)*exp(x)) + + ics = {f(x).diff(x).subs(x, 0): f(x).diff(x).subs(x, 0), f(0): f(0)} + assert dsolve(f(x).diff(x, x) + f(x), f(x), ics=ics) == \ + Eq(f(x), f(0)*cos(x) + f(x).diff(x).subs(x, 0)*sin(x)) + +def test_ode_order(): + f = Function('f') + g = Function('g') + x = Symbol('x') + assert ode_order(3*x*exp(f(x)), f(x)) == 0 + assert ode_order(x*diff(f(x), x) + 3*x*f(x) - sin(x)/x, f(x)) == 1 + assert ode_order(x**2*f(x).diff(x, x) + x*diff(f(x), x) - f(x), f(x)) == 2 + assert ode_order(diff(x*exp(f(x)), x, x), f(x)) == 2 + assert ode_order(diff(x*diff(x*exp(f(x)), x, x), x), f(x)) == 3 + assert ode_order(diff(f(x), x, x), g(x)) == 0 + assert ode_order(diff(f(x), x, x)*diff(g(x), x), f(x)) == 2 + assert ode_order(diff(f(x), x, x)*diff(g(x), x), g(x)) == 1 + assert ode_order(diff(x*diff(x*exp(f(x)), x, x), x), g(x)) == 0 + # issue 5835: ode_order has to also work for unevaluated derivatives + # (ie, without using doit()). + assert ode_order(Derivative(x*f(x), x), f(x)) == 1 + assert ode_order(x*sin(Derivative(x*f(x)**2, x, x)), f(x)) == 2 + assert ode_order(Derivative(x*Derivative(x*exp(f(x)), x, x), x), g(x)) == 0 + assert ode_order(Derivative(f(x), x, x), g(x)) == 0 + assert ode_order(Derivative(x*exp(f(x)), x, x), f(x)) == 2 + assert ode_order(Derivative(f(x), x, x)*Derivative(g(x), x), g(x)) == 1 + assert ode_order(Derivative(x*Derivative(f(x), x, x), x), f(x)) == 3 + assert ode_order( + x*sin(Derivative(x*Derivative(f(x), x)**2, x, x)), f(x)) == 3 + + +def test_homogeneous_order(): + assert homogeneous_order(exp(y/x) + tan(y/x), x, y) == 0 + assert homogeneous_order(x**2 + sin(x)*cos(y), x, y) is None + assert homogeneous_order(x - y - x*sin(y/x), x, y) == 1 + assert homogeneous_order((x*y + sqrt(x**4 + y**4) + x**2*(log(x) - log(y)))/ + (pi*x**Rational(2, 3)*sqrt(y)**3), x, y) == Rational(-1, 6) + assert homogeneous_order(y/x*cos(y/x) - x/y*sin(y/x) + cos(y/x), x, y) == 0 + assert homogeneous_order(f(x), x, f(x)) == 1 + assert homogeneous_order(f(x)**2, x, f(x)) == 2 + assert homogeneous_order(x*y*z, x, y) == 2 + assert homogeneous_order(x*y*z, x, y, z) == 3 + assert homogeneous_order(x**2*f(x)/sqrt(x**2 + f(x)**2), f(x)) is None + assert homogeneous_order(f(x, y)**2, x, f(x, y), y) == 2 + assert homogeneous_order(f(x, y)**2, x, f(x), y) is None + assert homogeneous_order(f(x, y)**2, x, f(x, y)) is None + assert homogeneous_order(f(y, x)**2, x, y, f(x, y)) is None + assert homogeneous_order(f(y), f(x), x) is None + assert homogeneous_order(-f(x)/x + 1/sin(f(x)/ x), f(x), x) == 0 + assert homogeneous_order(log(1/y) + log(x**2), x, y) is None + assert homogeneous_order(log(1/y) + log(x), x, y) == 0 + assert homogeneous_order(log(x/y), x, y) == 0 + assert homogeneous_order(2*log(1/y) + 2*log(x), x, y) == 0 + a = Symbol('a') + assert homogeneous_order(a*log(1/y) + a*log(x), x, y) == 0 + assert homogeneous_order(f(x).diff(x), x, y) is None + assert homogeneous_order(-f(x).diff(x) + x, x, y) is None + assert homogeneous_order(O(x), x, y) is None + assert homogeneous_order(x + O(x**2), x, y) is None + assert homogeneous_order(x**pi, x) == pi + assert homogeneous_order(x**x, x) is None + raises(ValueError, lambda: homogeneous_order(x*y)) + + +@XFAIL +def test_noncircularized_real_imaginary_parts(): + # If this passes, lines numbered 3878-3882 (at the time of this commit) + # of sympy/solvers/ode.py for nth_linear_constant_coeff_homogeneous + # should be removed. + y = sqrt(1+x) + i, r = im(y), re(y) + assert not (i.has(atan2) and r.has(atan2)) + + +def test_collect_respecting_exponentials(): + # If this test passes, lines 1306-1311 (at the time of this commit) + # of sympy/solvers/ode.py should be removed. + sol = 1 + exp(x/2) + assert sol == collect( sol, exp(x/3)) + + +def test_undetermined_coefficients_match(): + assert _undetermined_coefficients_match(g(x), x) == {'test': False} + assert _undetermined_coefficients_match(sin(2*x + sqrt(5)), x) == \ + {'test': True, 'trialset': + {cos(2*x + sqrt(5)), sin(2*x + sqrt(5))}} + assert _undetermined_coefficients_match(sin(x)*cos(x), x) == \ + {'test': False} + s = {cos(x), x*cos(x), x**2*cos(x), x**2*sin(x), x*sin(x), sin(x)} + assert _undetermined_coefficients_match(sin(x)*(x**2 + x + 1), x) == \ + {'test': True, 'trialset': s} + assert _undetermined_coefficients_match( + sin(x)*x**2 + sin(x)*x + sin(x), x) == {'test': True, 'trialset': s} + assert _undetermined_coefficients_match( + exp(2*x)*sin(x)*(x**2 + x + 1), x + ) == { + 'test': True, 'trialset': {exp(2*x)*sin(x), x**2*exp(2*x)*sin(x), + cos(x)*exp(2*x), x**2*cos(x)*exp(2*x), x*cos(x)*exp(2*x), + x*exp(2*x)*sin(x)}} + assert _undetermined_coefficients_match(1/sin(x), x) == {'test': False} + assert _undetermined_coefficients_match(log(x), x) == {'test': False} + assert _undetermined_coefficients_match(2**(x)*(x**2 + x + 1), x) == \ + {'test': True, 'trialset': {2**x, x*2**x, x**2*2**x}} + assert _undetermined_coefficients_match(x**y, x) == {'test': False} + assert _undetermined_coefficients_match(exp(x)*exp(2*x + 1), x) == \ + {'test': True, 'trialset': {exp(1 + 3*x)}} + assert _undetermined_coefficients_match(sin(x)*(x**2 + x + 1), x) == \ + {'test': True, 'trialset': {x*cos(x), x*sin(x), x**2*cos(x), + x**2*sin(x), cos(x), sin(x)}} + assert _undetermined_coefficients_match(sin(x)*(x + sin(x)), x) == \ + {'test': False} + assert _undetermined_coefficients_match(sin(x)*(x + sin(2*x)), x) == \ + {'test': False} + assert _undetermined_coefficients_match(sin(x)*tan(x), x) == \ + {'test': False} + assert _undetermined_coefficients_match( + x**2*sin(x)*exp(x) + x*sin(x) + x, x + ) == { + 'test': True, 'trialset': {x**2*cos(x)*exp(x), x, cos(x), S.One, + exp(x)*sin(x), sin(x), x*exp(x)*sin(x), x*cos(x), x*cos(x)*exp(x), + x*sin(x), cos(x)*exp(x), x**2*exp(x)*sin(x)}} + assert _undetermined_coefficients_match(4*x*sin(x - 2), x) == { + 'trialset': {x*cos(x - 2), x*sin(x - 2), cos(x - 2), sin(x - 2)}, + 'test': True, + } + assert _undetermined_coefficients_match(2**x*x, x) == \ + {'test': True, 'trialset': {2**x, x*2**x}} + assert _undetermined_coefficients_match(2**x*exp(2*x), x) == \ + {'test': True, 'trialset': {2**x*exp(2*x)}} + assert _undetermined_coefficients_match(exp(-x)/x, x) == \ + {'test': False} + # Below are from Ordinary Differential Equations, + # Tenenbaum and Pollard, pg. 231 + assert _undetermined_coefficients_match(S(4), x) == \ + {'test': True, 'trialset': {S.One}} + assert _undetermined_coefficients_match(12*exp(x), x) == \ + {'test': True, 'trialset': {exp(x)}} + assert _undetermined_coefficients_match(exp(I*x), x) == \ + {'test': True, 'trialset': {exp(I*x)}} + assert _undetermined_coefficients_match(sin(x), x) == \ + {'test': True, 'trialset': {cos(x), sin(x)}} + assert _undetermined_coefficients_match(cos(x), x) == \ + {'test': True, 'trialset': {cos(x), sin(x)}} + assert _undetermined_coefficients_match(8 + 6*exp(x) + 2*sin(x), x) == \ + {'test': True, 'trialset': {S.One, cos(x), sin(x), exp(x)}} + assert _undetermined_coefficients_match(x**2, x) == \ + {'test': True, 'trialset': {S.One, x, x**2}} + assert _undetermined_coefficients_match(9*x*exp(x) + exp(-x), x) == \ + {'test': True, 'trialset': {x*exp(x), exp(x), exp(-x)}} + assert _undetermined_coefficients_match(2*exp(2*x)*sin(x), x) == \ + {'test': True, 'trialset': {exp(2*x)*sin(x), cos(x)*exp(2*x)}} + assert _undetermined_coefficients_match(x - sin(x), x) == \ + {'test': True, 'trialset': {S.One, x, cos(x), sin(x)}} + assert _undetermined_coefficients_match(x**2 + 2*x, x) == \ + {'test': True, 'trialset': {S.One, x, x**2}} + assert _undetermined_coefficients_match(4*x*sin(x), x) == \ + {'test': True, 'trialset': {x*cos(x), x*sin(x), cos(x), sin(x)}} + assert _undetermined_coefficients_match(x*sin(2*x), x) == \ + {'test': True, 'trialset': + {x*cos(2*x), x*sin(2*x), cos(2*x), sin(2*x)}} + assert _undetermined_coefficients_match(x**2*exp(-x), x) == \ + {'test': True, 'trialset': {x*exp(-x), x**2*exp(-x), exp(-x)}} + assert _undetermined_coefficients_match(2*exp(-x) - x**2*exp(-x), x) == \ + {'test': True, 'trialset': {x*exp(-x), x**2*exp(-x), exp(-x)}} + assert _undetermined_coefficients_match(exp(-2*x) + x**2, x) == \ + {'test': True, 'trialset': {S.One, x, x**2, exp(-2*x)}} + assert _undetermined_coefficients_match(x*exp(-x), x) == \ + {'test': True, 'trialset': {x*exp(-x), exp(-x)}} + assert _undetermined_coefficients_match(x + exp(2*x), x) == \ + {'test': True, 'trialset': {S.One, x, exp(2*x)}} + assert _undetermined_coefficients_match(sin(x) + exp(-x), x) == \ + {'test': True, 'trialset': {cos(x), sin(x), exp(-x)}} + assert _undetermined_coefficients_match(exp(x), x) == \ + {'test': True, 'trialset': {exp(x)}} + # converted from sin(x)**2 + assert _undetermined_coefficients_match(S.Half - cos(2*x)/2, x) == \ + {'test': True, 'trialset': {S.One, cos(2*x), sin(2*x)}} + # converted from exp(2*x)*sin(x)**2 + assert _undetermined_coefficients_match( + exp(2*x)*(S.Half + cos(2*x)/2), x + ) == { + 'test': True, 'trialset': {exp(2*x)*sin(2*x), cos(2*x)*exp(2*x), + exp(2*x)}} + assert _undetermined_coefficients_match(2*x + sin(x) + cos(x), x) == \ + {'test': True, 'trialset': {S.One, x, cos(x), sin(x)}} + # converted from sin(2*x)*sin(x) + assert _undetermined_coefficients_match(cos(x)/2 - cos(3*x)/2, x) == \ + {'test': True, 'trialset': {cos(x), cos(3*x), sin(x), sin(3*x)}} + assert _undetermined_coefficients_match(cos(x**2), x) == {'test': False} + assert _undetermined_coefficients_match(2**(x**2), x) == {'test': False} + + +def test_issue_4785_22462(): + from sympy.abc import A + eq = x + A*(x + diff(f(x), x) + f(x)) + diff(f(x), x) + f(x) + 2 + assert classify_ode(eq, f(x)) == ('factorable', '1st_exact', '1st_linear', + 'Bernoulli', 'almost_linear', '1st_power_series', 'lie_group', + 'nth_linear_constant_coeff_undetermined_coefficients', + 'nth_linear_constant_coeff_variation_of_parameters', + '1st_exact_Integral', '1st_linear_Integral', 'Bernoulli_Integral', + 'almost_linear_Integral', + 'nth_linear_constant_coeff_variation_of_parameters_Integral') + # issue 4864 + eq = (x**2 + f(x)**2)*f(x).diff(x) - 2*x*f(x) + assert classify_ode(eq, f(x)) == ('factorable', '1st_exact', + '1st_homogeneous_coeff_best', + '1st_homogeneous_coeff_subs_indep_div_dep', + '1st_homogeneous_coeff_subs_dep_div_indep', + '1st_power_series', + 'lie_group', '1st_exact_Integral', + '1st_homogeneous_coeff_subs_indep_div_dep_Integral', + '1st_homogeneous_coeff_subs_dep_div_indep_Integral') + + +def test_issue_4825(): + raises(ValueError, lambda: dsolve(f(x, y).diff(x) - y*f(x, y), f(x))) + assert classify_ode(f(x, y).diff(x) - y*f(x, y), f(x), dict=True) == \ + {'order': 0, 'default': None, 'ordered_hints': ()} + # See also issue 3793, test Z13. + raises(ValueError, lambda: dsolve(f(x).diff(x), f(y))) + assert classify_ode(f(x).diff(x), f(y), dict=True) == \ + {'order': 0, 'default': None, 'ordered_hints': ()} + + +def test_constant_renumber_order_issue_5308(): + from sympy.utilities.iterables import variations + + assert constant_renumber(C1*x + C2*y) == \ + constant_renumber(C1*y + C2*x) == \ + C1*x + C2*y + e = C1*(C2 + x)*(C3 + y) + for a, b, c in variations([C1, C2, C3], 3): + assert constant_renumber(a*(b + x)*(c + y)) == e + + +def test_constant_renumber(): + e1, e2, x, y = symbols("e1:3 x y") + exprs = [e2*x, e1*x + e2*y] + + assert constant_renumber(exprs[0]) == e2*x + assert constant_renumber(exprs[0], variables=[x]) == C1*x + assert constant_renumber(exprs[0], variables=[x], newconstants=[C2]) == C2*x + assert constant_renumber(exprs, variables=[x, y]) == [C1*x, C1*y + C2*x] + assert constant_renumber(exprs, variables=[x, y], newconstants=symbols("C3:5")) == [C3*x, C3*y + C4*x] + + +def test_issue_5770(): + k = Symbol("k", real=True) + t = Symbol('t') + w = Function('w') + sol = dsolve(w(t).diff(t, 6) - k**6*w(t), w(t)) + assert len([s for s in sol.free_symbols if s.name.startswith('C')]) == 6 + assert constantsimp((C1*cos(x) + C2*cos(x))*exp(x), {C1, C2}) == \ + C1*cos(x)*exp(x) + assert constantsimp(C1*cos(x) + C2*cos(x) + C3*sin(x), {C1, C2, C3}) == \ + C1*cos(x) + C3*sin(x) + assert constantsimp(exp(C1 + x), {C1}) == C1*exp(x) + assert constantsimp(x + C1 + y, {C1, y}) == C1 + x + assert constantsimp(x + C1 + Integral(x, (x, 1, 2)), {C1}) == C1 + x + + +def test_issue_5112_5430(): + assert homogeneous_order(-log(x) + acosh(x), x) is None + assert homogeneous_order(y - log(x), x, y) is None + + +def test_issue_5095(): + f = Function('f') + raises(ValueError, lambda: dsolve(f(x).diff(x)**2, f(x), 'fdsjf')) + + +def test_homogeneous_function(): + f = Function('f') + eq1 = tan(x + f(x)) + eq2 = sin((3*x)/(4*f(x))) + eq3 = cos(x*f(x)*Rational(3, 4)) + eq4 = log((3*x + 4*f(x))/(5*f(x) + 7*x)) + eq5 = exp((2*x**2)/(3*f(x)**2)) + eq6 = log((3*x + 4*f(x))/(5*f(x) + 7*x) + exp((2*x**2)/(3*f(x)**2))) + eq7 = sin((3*x)/(5*f(x) + x**2)) + assert homogeneous_order(eq1, x, f(x)) == None + assert homogeneous_order(eq2, x, f(x)) == 0 + assert homogeneous_order(eq3, x, f(x)) == None + assert homogeneous_order(eq4, x, f(x)) == 0 + assert homogeneous_order(eq5, x, f(x)) == 0 + assert homogeneous_order(eq6, x, f(x)) == 0 + assert homogeneous_order(eq7, x, f(x)) == None + + +def test_linear_coeff_match(): + n, d = z*(2*x + 3*f(x) + 5), z*(7*x + 9*f(x) + 11) + rat = n/d + eq1 = sin(rat) + cos(rat.expand()) + obj1 = LinearCoefficients(eq1) + eq2 = rat + obj2 = LinearCoefficients(eq2) + eq3 = log(sin(rat)) + obj3 = LinearCoefficients(eq3) + ans = (4, Rational(-13, 3)) + assert obj1._linear_coeff_match(eq1, f(x)) == ans + assert obj2._linear_coeff_match(eq2, f(x)) == ans + assert obj3._linear_coeff_match(eq3, f(x)) == ans + + # no c + eq4 = (3*x)/f(x) + obj4 = LinearCoefficients(eq4) + # not x and f(x) + eq5 = (3*x + 2)/x + obj5 = LinearCoefficients(eq5) + # denom will be zero + eq6 = (3*x + 2*f(x) + 1)/(3*x + 2*f(x) + 5) + obj6 = LinearCoefficients(eq6) + # not rational coefficient + eq7 = (3*x + 2*f(x) + sqrt(2))/(3*x + 2*f(x) + 5) + obj7 = LinearCoefficients(eq7) + assert obj4._linear_coeff_match(eq4, f(x)) is None + assert obj5._linear_coeff_match(eq5, f(x)) is None + assert obj6._linear_coeff_match(eq6, f(x)) is None + assert obj7._linear_coeff_match(eq7, f(x)) is None + + +def test_constantsimp_take_problem(): + c = exp(C1) + 2 + assert len(Poly(constantsimp(exp(C1) + c + c*x, [C1])).gens) == 2 + + +def test_series(): + C1 = Symbol("C1") + eq = f(x).diff(x) - f(x) + sol = Eq(f(x), C1 + C1*x + C1*x**2/2 + C1*x**3/6 + C1*x**4/24 + + C1*x**5/120 + O(x**6)) + assert dsolve(eq, hint='1st_power_series') == sol + assert checkodesol(eq, sol, order=1)[0] + + eq = f(x).diff(x) - x*f(x) + sol = Eq(f(x), C1*x**4/8 + C1*x**2/2 + C1 + O(x**6)) + assert dsolve(eq, hint='1st_power_series') == sol + assert checkodesol(eq, sol, order=1)[0] + + eq = f(x).diff(x) - sin(x*f(x)) + sol = Eq(f(x), (x - 2)**2*(1+ sin(4))*cos(4) + (x - 2)*sin(4) + 2 + O(x**3)) + assert dsolve(eq, hint='1st_power_series', ics={f(2): 2}, n=3) == sol + # FIXME: The solution here should be O((x-2)**3) so is incorrect + #assert checkodesol(eq, sol, order=1)[0] + + +@slow +def test_2nd_power_series_ordinary(): + C1, C2 = symbols("C1 C2") + + eq = f(x).diff(x, 2) - x*f(x) + assert classify_ode(eq) == ('2nd_linear_airy', '2nd_power_series_ordinary') + sol = Eq(f(x), C2*(x**3/6 + 1) + C1*x*(x**3/12 + 1) + O(x**6)) + assert dsolve(eq, hint='2nd_power_series_ordinary') == sol + assert checkodesol(eq, sol) == (True, 0) + + sol = Eq(f(x), C2*((x + 2)**4/6 + (x + 2)**3/6 - (x + 2)**2 + 1) + + C1*(x + (x + 2)**4/12 - (x + 2)**3/3 + S(2)) + + O(x**6)) + assert dsolve(eq, hint='2nd_power_series_ordinary', x0=-2) == sol + # FIXME: Solution should be O((x+2)**6) + # assert checkodesol(eq, sol) == (True, 0) + + sol = Eq(f(x), C2*x + C1 + O(x**2)) + assert dsolve(eq, hint='2nd_power_series_ordinary', n=2) == sol + assert checkodesol(eq, sol) == (True, 0) + + eq = (1 + x**2)*(f(x).diff(x, 2)) + 2*x*(f(x).diff(x)) -2*f(x) + assert classify_ode(eq) == ('factorable', '2nd_hypergeometric', '2nd_hypergeometric_Integral', + '2nd_power_series_ordinary') + + sol = Eq(f(x), C2*(-x**4/3 + x**2 + 1) + C1*x + O(x**6)) + assert dsolve(eq, hint='2nd_power_series_ordinary') == sol + assert checkodesol(eq, sol) == (True, 0) + + eq = f(x).diff(x, 2) + x*(f(x).diff(x)) + f(x) + assert classify_ode(eq) == ('factorable', '2nd_power_series_ordinary',) + sol = Eq(f(x), C2*(x**4/8 - x**2/2 + 1) + C1*x*(-x**2/3 + 1) + O(x**6)) + assert dsolve(eq) == sol + # FIXME: checkodesol fails for this solution... + # assert checkodesol(eq, sol) == (True, 0) + + eq = f(x).diff(x, 2) + f(x).diff(x) - x*f(x) + assert classify_ode(eq) == ('2nd_power_series_ordinary',) + sol = Eq(f(x), C2*(-x**4/24 + x**3/6 + 1) + + C1*x*(x**3/24 + x**2/6 - x/2 + 1) + O(x**6)) + assert dsolve(eq) == sol + # FIXME: checkodesol fails for this solution... + # assert checkodesol(eq, sol) == (True, 0) + + eq = f(x).diff(x, 2) + x*f(x) + assert classify_ode(eq) == ('2nd_linear_airy', '2nd_power_series_ordinary') + sol = Eq(f(x), C2*(x**6/180 - x**3/6 + 1) + C1*x*(-x**3/12 + 1) + O(x**7)) + assert dsolve(eq, hint='2nd_power_series_ordinary', n=7) == sol + assert checkodesol(eq, sol) == (True, 0) + + +def test_2nd_power_series_regular(): + C1, C2, a = symbols("C1 C2 a") + eq = x**2*(f(x).diff(x, 2)) - 3*x*(f(x).diff(x)) + (4*x + 4)*f(x) + sol = Eq(f(x), C1*x**2*(-16*x**3/9 + 4*x**2 - 4*x + 1) + O(x**6)) + assert dsolve(eq, hint='2nd_power_series_regular') == sol + assert checkodesol(eq, sol) == (True, 0) + + eq = 4*x**2*(f(x).diff(x, 2)) -8*x**2*(f(x).diff(x)) + (4*x**2 + + 1)*f(x) + sol = Eq(f(x), C1*sqrt(x)*(x**4/24 + x**3/6 + x**2/2 + x + 1) + O(x**6)) + assert dsolve(eq, hint='2nd_power_series_regular') == sol + assert checkodesol(eq, sol) == (True, 0) + + eq = x**2*(f(x).diff(x, 2)) - x**2*(f(x).diff(x)) + ( + x**2 - 2)*f(x) + sol = Eq(f(x), C1*(-x**6/720 - 3*x**5/80 - x**4/8 + x**2/2 + x/2 + 1)/x + + C2*x**2*(-x**3/60 + x**2/20 + x/2 + 1) + O(x**6)) + assert dsolve(eq) == sol + assert checkodesol(eq, sol) == (True, 0) + + eq = x**2*(f(x).diff(x, 2)) + x*(f(x).diff(x)) + (x**2 - Rational(1, 4))*f(x) + sol = Eq(f(x), C1*(x**4/24 - x**2/2 + 1)/sqrt(x) + + C2*sqrt(x)*(x**4/120 - x**2/6 + 1) + O(x**6)) + assert dsolve(eq, hint='2nd_power_series_regular') == sol + assert checkodesol(eq, sol) == (True, 0) + + eq = x*f(x).diff(x, 2) + f(x).diff(x) - a*x*f(x) + sol = Eq(f(x), C1*(a**2*x**4/64 + a*x**2/4 + 1) + O(x**6)) + assert dsolve(eq, f(x), hint="2nd_power_series_regular") == sol + assert checkodesol(eq, sol) == (True, 0) + + eq = f(x).diff(x, 2) + ((1 - x)/x)*f(x).diff(x) + (a/x)*f(x) + sol = Eq(f(x), C1*(-a*x**5*(a - 4)*(a - 3)*(a - 2)*(a - 1)/14400 + \ + a*x**4*(a - 3)*(a - 2)*(a - 1)/576 - a*x**3*(a - 2)*(a - 1)/36 + \ + a*x**2*(a - 1)/4 - a*x + 1) + O(x**6)) + assert dsolve(eq, f(x), hint="2nd_power_series_regular") == sol + assert checkodesol(eq, sol) == (True, 0) + + +def test_issue_15056(): + t = Symbol('t') + C3 = Symbol('C3') + assert get_numbered_constants(Symbol('C1') * Function('C2')(t)) == C3 + + +def test_issue_15913(): + eq = -C1/x - 2*x*f(x) - f(x) + Derivative(f(x), x) + sol = C2*exp(x**2 + x) + exp(x**2 + x)*Integral(C1*exp(-x**2 - x)/x, x) + assert checkodesol(eq, sol) == (True, 0) + sol = C1 + C2*exp(-x*y) + eq = Derivative(y*f(x), x) + f(x).diff(x, 2) + assert checkodesol(eq, sol, f(x)) == (True, 0) + + +def test_issue_16146(): + raises(ValueError, lambda: dsolve([f(x).diff(x), g(x).diff(x)], [f(x), g(x), h(x)])) + raises(ValueError, lambda: dsolve([f(x).diff(x), g(x).diff(x)], [f(x)])) + + +def test_dsolve_remove_redundant_solutions(): + + eq = (f(x)-2)*f(x).diff(x) + sol = Eq(f(x), C1) + assert dsolve(eq) == sol + + eq = (f(x)-sin(x))*(f(x).diff(x, 2)) + sol = {Eq(f(x), C1 + C2*x), Eq(f(x), sin(x))} + assert set(dsolve(eq)) == sol + + eq = (f(x)**2-2*f(x)+1)*f(x).diff(x, 3) + sol = Eq(f(x), C1 + C2*x + C3*x**2) + assert dsolve(eq) == sol + + +def test_issue_13060(): + A, B = symbols("A B", cls=Function) + t = Symbol("t") + eq = [Eq(Derivative(A(t), t), A(t)*B(t)), Eq(Derivative(B(t), t), A(t)*B(t))] + sol = dsolve(eq) + assert checkodesol(eq, sol) == (True, [0, 0]) + + +def test_issue_22523(): + N, s = symbols('N s') + rho = Function('rho') + # intentionally use 4.0 to confirm issue with nfloat + # works here + eqn = 4.0*N*sqrt(N - 1)*rho(s) + (4*s**2*(N - 1) + (N - 2*s*(N - 1))**2 + )*Derivative(rho(s), (s, 2)) + match = classify_ode(eqn, dict=True, hint='all') + assert match['2nd_power_series_ordinary']['terms'] == 5 + C1, C2 = symbols('C1,C2') + sol = dsolve(eqn, hint='2nd_power_series_ordinary') + # there is no r(2.0) in this result + assert filldedent(sol) == filldedent(str(''' + Eq(rho(s), C2*(1 - 4.0*s**4*sqrt(N - 1.0)/N + 0.666666666666667*s**4/N + - 2.66666666666667*s**3*sqrt(N - 1.0)/N - 2.0*s**2*sqrt(N - 1.0)/N + + 9.33333333333333*s**4*sqrt(N - 1.0)/N**2 - 0.666666666666667*s**4/N**2 + + 2.66666666666667*s**3*sqrt(N - 1.0)/N**2 - + 5.33333333333333*s**4*sqrt(N - 1.0)/N**3) + C1*s*(1.0 - + 1.33333333333333*s**3*sqrt(N - 1.0)/N - 0.666666666666667*s**2*sqrt(N + - 1.0)/N + 1.33333333333333*s**3*sqrt(N - 1.0)/N**2) + O(s**6))''')) + + +def test_issue_22604(): + x1, x2 = symbols('x1, x2', cls = Function) + t, k1, k2, m1, m2 = symbols('t k1 k2 m1 m2', real = True) + k1, k2, m1, m2 = 1, 1, 1, 1 + eq1 = Eq(m1*diff(x1(t), t, 2) + k1*x1(t) - k2*(x2(t) - x1(t)), 0) + eq2 = Eq(m2*diff(x2(t), t, 2) + k2*(x2(t) - x1(t)), 0) + eqs = [eq1, eq2] + [x1sol, x2sol] = dsolve(eqs, [x1(t), x2(t)], ics = {x1(0):0, x1(t).diff().subs(t,0):0, \ + x2(0):1, x2(t).diff().subs(t,0):0}) + assert x1sol == Eq(x1(t), sqrt(3 - sqrt(5))*(sqrt(10) + 5*sqrt(2))*cos(sqrt(2)*t*sqrt(3 - sqrt(5))/2)/20 + \ + (-5*sqrt(2) + sqrt(10))*sqrt(sqrt(5) + 3)*cos(sqrt(2)*t*sqrt(sqrt(5) + 3)/2)/20) + assert x2sol == Eq(x2(t), (sqrt(5) + 5)*cos(sqrt(2)*t*sqrt(3 - sqrt(5))/2)/10 + (5 - sqrt(5))*cos(sqrt(2)*t*sqrt(sqrt(5) + 3)/2)/10) + + +def test_issue_22462(): + for de in [ + Eq(f(x).diff(x), -20*f(x)**2 - 500*f(x)/7200), + Eq(f(x).diff(x), -2*f(x)**2 - 5*f(x)/7)]: + assert 'Bernoulli' in classify_ode(de, f(x)) + + +def test_issue_23425(): + x = symbols('x') + y = Function('y') + eq = Eq(-E**x*y(x).diff().diff() + y(x).diff(), 0) + assert classify_ode(eq) == \ + ('Liouville', 'nth_order_reducible', \ + '2nd_power_series_ordinary', 'Liouville_Integral') + + +@SKIP("too slow for @slow") +def test_issue_25820(): + x = Symbol('x') + y = Function('y') + eq = y(x)**3*y(x).diff(x, 2) + 49 + assert dsolve(eq, y(x)) is not None # doesn't raise diff --git a/lib/python3.10/site-packages/sympy/solvers/ode/tests/test_riccati.py b/lib/python3.10/site-packages/sympy/solvers/ode/tests/test_riccati.py new file mode 100644 index 0000000000000000000000000000000000000000..548a1ee5b5e82d88f1b0aa319af09b8b9d1d9bfe --- /dev/null +++ b/lib/python3.10/site-packages/sympy/solvers/ode/tests/test_riccati.py @@ -0,0 +1,877 @@ +from sympy.core.random import randint +from sympy.core.function import Function +from sympy.core.mul import Mul +from sympy.core.numbers import (I, Rational, oo) +from sympy.core.relational import Eq +from sympy.core.singleton import S +from sympy.core.symbol import (Dummy, symbols) +from sympy.functions.elementary.exponential import (exp, log) +from sympy.functions.elementary.hyperbolic import tanh +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import sin +from sympy.polys.polytools import Poly +from sympy.simplify.ratsimp import ratsimp +from sympy.solvers.ode.subscheck import checkodesol +from sympy.testing.pytest import slow +from sympy.solvers.ode.riccati import (riccati_normal, riccati_inverse_normal, + riccati_reduced, match_riccati, inverse_transform_poly, limit_at_inf, + check_necessary_conds, val_at_inf, construct_c_case_1, + construct_c_case_2, construct_c_case_3, construct_d_case_4, + construct_d_case_5, construct_d_case_6, rational_laurent_series, + solve_riccati) + +f = Function('f') +x = symbols('x') + +# These are the functions used to generate the tests +# SHOULD NOT BE USED DIRECTLY IN TESTS + +def rand_rational(maxint): + return Rational(randint(-maxint, maxint), randint(1, maxint)) + + +def rand_poly(x, degree, maxint): + return Poly([rand_rational(maxint) for _ in range(degree+1)], x) + + +def rand_rational_function(x, degree, maxint): + degnum = randint(1, degree) + degden = randint(1, degree) + num = rand_poly(x, degnum, maxint) + den = rand_poly(x, degden, maxint) + while den == Poly(0, x): + den = rand_poly(x, degden, maxint) + return num / den + + +def find_riccati_ode(ratfunc, x, yf): + y = ratfunc + yp = y.diff(x) + q1 = rand_rational_function(x, 1, 3) + q2 = rand_rational_function(x, 1, 3) + while q2 == 0: + q2 = rand_rational_function(x, 1, 3) + q0 = ratsimp(yp - q1*y - q2*y**2) + eq = Eq(yf.diff(), q0 + q1*yf + q2*yf**2) + sol = Eq(yf, y) + assert checkodesol(eq, sol) == (True, 0) + return eq, q0, q1, q2 + + +# Testing functions start + +def test_riccati_transformation(): + """ + This function tests the transformation of the + solution of a Riccati ODE to the solution of + its corresponding normal Riccati ODE. + + Each test case 4 values - + + 1. w - The solution to be transformed + 2. b1 - The coefficient of f(x) in the ODE. + 3. b2 - The coefficient of f(x)**2 in the ODE. + 4. y - The solution to the normal Riccati ODE. + """ + tests = [ + ( + x/(x - 1), + (x**2 + 7)/3*x, + x, + -x**2/(x - 1) - x*(x**2/3 + S(7)/3)/2 - 1/(2*x) + ), + ( + (2*x + 3)/(2*x + 2), + (3 - 3*x)/(x + 1), + 5*x, + -5*x*(2*x + 3)/(2*x + 2) - (3 - 3*x)/(Mul(2, x + 1, evaluate=False)) - 1/(2*x) + ), + ( + -1/(2*x**2 - 1), + 0, + (2 - x)/(4*x - 2), + (2 - x)/((4*x - 2)*(2*x**2 - 1)) - (4*x - 2)*(Mul(-4, 2 - x, evaluate=False)/(4*x - \ + 2)**2 - 1/(4*x - 2))/(Mul(2, 2 - x, evaluate=False)) + ), + ( + x, + (8*x - 12)/(12*x + 9), + x**3/(6*x - 9), + -x**4/(6*x - 9) - (8*x - 12)/(Mul(2, 12*x + 9, evaluate=False)) - (6*x - 9)*(-6*x**3/(6*x \ + - 9)**2 + 3*x**2/(6*x - 9))/(2*x**3) + )] + for w, b1, b2, y in tests: + assert y == riccati_normal(w, x, b1, b2) + assert w == riccati_inverse_normal(y, x, b1, b2).cancel() + + # Test bp parameter in riccati_inverse_normal + tests = [ + ( + (-2*x - 1)/(2*x**2 + 2*x - 2), + -2/x, + (-x - 1)/(4*x), + 8*x**2*(1/(4*x) + (-x - 1)/(4*x**2))/(-x - 1)**2 + 4/(-x - 1), + -2*x*(-1/(4*x) - (-x - 1)/(4*x**2))/(-x - 1) - (-2*x - 1)*(-x - 1)/(4*x*(2*x**2 + 2*x \ + - 2)) + 1/x + ), + ( + 3/(2*x**2), + -2/x, + (-x - 1)/(4*x), + 8*x**2*(1/(4*x) + (-x - 1)/(4*x**2))/(-x - 1)**2 + 4/(-x - 1), + -2*x*(-1/(4*x) - (-x - 1)/(4*x**2))/(-x - 1) + 1/x - Mul(3, -x - 1, evaluate=False)/(8*x**3) + )] + for w, b1, b2, bp, y in tests: + assert y == riccati_normal(w, x, b1, b2) + assert w == riccati_inverse_normal(y, x, b1, b2, bp).cancel() + + +def test_riccati_reduced(): + """ + This function tests the transformation of a + Riccati ODE to its normal Riccati ODE. + + Each test case 2 values - + + 1. eq - A Riccati ODE. + 2. normal_eq - The normal Riccati ODE of eq. + """ + tests = [ + ( + f(x).diff(x) - x**2 - x*f(x) - x*f(x)**2, + + f(x).diff(x) + f(x)**2 + x**3 - x**2/4 - 3/(4*x**2) + ), + ( + 6*x/(2*x + 9) + f(x).diff(x) - (x + 1)*f(x)**2/x, + + -3*x**2*(1/x + (-x - 1)/x**2)**2/(4*(-x - 1)**2) + Mul(6, \ + -x - 1, evaluate=False)/(2*x + 9) + f(x)**2 + f(x).diff(x) \ + - (-1 + (x + 1)/x)/(x*(-x - 1)) + ), + ( + f(x)**2 + f(x).diff(x) - (x - 1)*f(x)/(-x - S(1)/2), + + -(2*x - 2)**2/(4*(2*x + 1)**2) + (2*x - 2)/(2*x + 1)**2 + \ + f(x)**2 + f(x).diff(x) - 1/(2*x + 1) + ), + ( + f(x).diff(x) - f(x)**2/x, + + f(x)**2 + f(x).diff(x) + 1/(4*x**2) + ), + ( + -3*(-x**2 - x + 1)/(x**2 + 6*x + 1) + f(x).diff(x) + f(x)**2/x, + + f(x)**2 + f(x).diff(x) + (3*x**2/(x**2 + 6*x + 1) + 3*x/(x**2 \ + + 6*x + 1) - 3/(x**2 + 6*x + 1))/x + 1/(4*x**2) + ), + ( + 6*x/(2*x + 9) + f(x).diff(x) - (x + 1)*f(x)/x, + + False + ), + ( + f(x)*f(x).diff(x) - 1/x + f(x)/3 + f(x)**2/(x**2 - 2), + + False + )] + for eq, normal_eq in tests: + assert normal_eq == riccati_reduced(eq, f, x) + + +def test_match_riccati(): + """ + This function tests if an ODE is Riccati or not. + + Each test case has 5 values - + + 1. eq - The Riccati ODE. + 2. match - Boolean indicating if eq is a Riccati ODE. + 3. b0 - + 4. b1 - Coefficient of f(x) in eq. + 5. b2 - Coefficient of f(x)**2 in eq. + """ + tests = [ + # Test Rational Riccati ODEs + ( + f(x).diff(x) - (405*x**3 - 882*x**2 - 78*x + 92)/(243*x**4 \ + - 945*x**3 + 846*x**2 + 180*x - 72) - 2 - f(x)**2/(3*x + 1) \ + - (S(1)/3 - x)*f(x)/(S(1)/3 - 3*x/2), + + True, + + 45*x**3/(27*x**4 - 105*x**3 + 94*x**2 + 20*x - 8) - 98*x**2/ \ + (27*x**4 - 105*x**3 + 94*x**2 + 20*x - 8) - 26*x/(81*x**4 - \ + 315*x**3 + 282*x**2 + 60*x - 24) + 2 + 92/(243*x**4 - 945*x**3 \ + + 846*x**2 + 180*x - 72), + + Mul(-1, 2 - 6*x, evaluate=False)/(9*x - 2), + + 1/(3*x + 1) + ), + ( + f(x).diff(x) + 4*x/27 - (x/3 - 1)*f(x)**2 - (2*x/3 + \ + 1)*f(x)/(3*x + 2) - S(10)/27 - (265*x**2 + 423*x + 162) \ + /(324*x**3 + 216*x**2), + + True, + + -4*x/27 + S(10)/27 + 3/(6*x**3 + 4*x**2) + 47/(36*x**2 \ + + 24*x) + 265/(324*x + 216), + + Mul(-1, -2*x - 3, evaluate=False)/(9*x + 6), + + x/3 - 1 + ), + ( + f(x).diff(x) - (304*x**5 - 745*x**4 + 631*x**3 - 876*x**2 \ + + 198*x - 108)/(36*x**6 - 216*x**5 + 477*x**4 - 567*x**3 + \ + 360*x**2 - 108*x) - S(17)/9 - (x - S(3)/2)*f(x)/(x/2 - \ + S(3)/2) - (x/3 - 3)*f(x)**2/(3*x), + + True, + + 304*x**4/(36*x**5 - 216*x**4 + 477*x**3 - 567*x**2 + 360*x - \ + 108) - 745*x**3/(36*x**5 - 216*x**4 + 477*x**3 - 567*x**2 + \ + 360*x - 108) + 631*x**2/(36*x**5 - 216*x**4 + 477*x**3 - 567* \ + x**2 + 360*x - 108) - 292*x/(12*x**5 - 72*x**4 + 159*x**3 - \ + 189*x**2 + 120*x - 36) + S(17)/9 - 12/(4*x**6 - 24*x**5 + \ + 53*x**4 - 63*x**3 + 40*x**2 - 12*x) + 22/(4*x**5 - 24*x**4 \ + + 53*x**3 - 63*x**2 + 40*x - 12), + + Mul(-1, 3 - 2*x, evaluate=False)/(x - 3), + + Mul(-1, 9 - x, evaluate=False)/(9*x) + ), + # Test Non-Rational Riccati ODEs + ( + f(x).diff(x) - x**(S(3)/2)/(x**(S(1)/2) - 2) + x**2*f(x) + \ + x*f(x)**2/(x**(S(3)/4)), + False, 0, 0, 0 + ), + ( + f(x).diff(x) - sin(x**2) + exp(x)*f(x) + log(x)*f(x)**2, + False, 0, 0, 0 + ), + ( + f(x).diff(x) - tanh(x + sqrt(x)) + f(x) + x**4*f(x)**2, + False, 0, 0, 0 + ), + # Test Non-Riccati ODEs + ( + (1 - x**2)*f(x).diff(x, 2) - 2*x*f(x).diff(x) + 20*f(x), + False, 0, 0, 0 + ), + ( + f(x).diff(x) - x**2 + x**3*f(x) + (x**2/(x + 1))*f(x)**3, + False, 0, 0, 0 + ), + ( + f(x).diff(x)*f(x)**2 + (x**2 - 1)/(x**3 + 1)*f(x) + 1/(2*x \ + + 3) + f(x)**2, + False, 0, 0, 0 + )] + for eq, res, b0, b1, b2 in tests: + match, funcs = match_riccati(eq, f, x) + assert match == res + if res: + assert [b0, b1, b2] == funcs + + +def test_val_at_inf(): + """ + This function tests the valuation of rational + function at oo. + + Each test case has 3 values - + + 1. num - Numerator of rational function. + 2. den - Denominator of rational function. + 3. val_inf - Valuation of rational function at oo + """ + tests = [ + # degree(denom) > degree(numer) + ( + Poly(10*x**3 + 8*x**2 - 13*x + 6, x), + Poly(-13*x**10 - x**9 + 5*x**8 + 7*x**7 + 10*x**6 + 6*x**5 - 7*x**4 + 11*x**3 - 8*x**2 + 5*x + 13, x), + 7 + ), + ( + Poly(1, x), + Poly(-9*x**4 + 3*x**3 + 15*x**2 - 6*x - 14, x), + 4 + ), + # degree(denom) == degree(numer) + ( + Poly(-6*x**3 - 8*x**2 + 8*x - 6, x), + Poly(-5*x**3 + 12*x**2 - 6*x - 9, x), + 0 + ), + # degree(denom) < degree(numer) + ( + Poly(12*x**8 - 12*x**7 - 11*x**6 + 8*x**5 + 3*x**4 - x**3 + x**2 - 11*x, x), + Poly(-14*x**2 + x, x), + -6 + ), + ( + Poly(5*x**6 + 9*x**5 - 11*x**4 - 9*x**3 + x**2 - 4*x + 4, x), + Poly(15*x**4 + 3*x**3 - 8*x**2 + 15*x + 12, x), + -2 + )] + for num, den, val in tests: + assert val_at_inf(num, den, x) == val + + +def test_necessary_conds(): + """ + This function tests the necessary conditions for + a Riccati ODE to have a rational particular solution. + """ + # Valuation at Infinity is an odd negative integer + assert check_necessary_conds(-3, [1, 2, 4]) == False + # Valuation at Infinity is a positive integer lesser than 2 + assert check_necessary_conds(1, [1, 2, 4]) == False + # Multiplicity of a pole is an odd integer greater than 1 + assert check_necessary_conds(2, [3, 1, 6]) == False + # All values are correct + assert check_necessary_conds(-10, [1, 2, 8, 12]) == True + + +def test_inverse_transform_poly(): + """ + This function tests the substitution x -> 1/x + in rational functions represented using Poly. + """ + fns = [ + (15*x**3 - 8*x**2 - 2*x - 6)/(18*x + 6), + + (180*x**5 + 40*x**4 + 80*x**3 + 30*x**2 - 60*x - 80)/(180*x**3 - 150*x**2 + 75*x + 12), + + (-15*x**5 - 36*x**4 + 75*x**3 - 60*x**2 - 80*x - 60)/(80*x**4 + 60*x**3 + 60*x**2 + 60*x - 80), + + (60*x**7 + 24*x**6 - 15*x**5 - 20*x**4 + 30*x**2 + 100*x - 60)/(240*x**2 - 20*x - 30), + + (30*x**6 - 12*x**5 + 15*x**4 - 15*x**2 + 10*x + 60)/(3*x**10 - 45*x**9 + 15*x**5 + 15*x**4 - 5*x**3 \ + + 15*x**2 + 45*x - 15) + ] + for f in fns: + num, den = [Poly(e, x) for e in f.as_numer_denom()] + num, den = inverse_transform_poly(num, den, x) + assert f.subs(x, 1/x).cancel() == num/den + + +def test_limit_at_inf(): + """ + This function tests the limit at oo of a + rational function. + + Each test case has 3 values - + + 1. num - Numerator of rational function. + 2. den - Denominator of rational function. + 3. limit_at_inf - Limit of rational function at oo + """ + tests = [ + # deg(denom) > deg(numer) + ( + Poly(-12*x**2 + 20*x + 32, x), + Poly(32*x**3 + 72*x**2 + 3*x - 32, x), + 0 + ), + # deg(denom) < deg(numer) + ( + Poly(1260*x**4 - 1260*x**3 - 700*x**2 - 1260*x + 1400, x), + Poly(6300*x**3 - 1575*x**2 + 756*x - 540, x), + oo + ), + # deg(denom) < deg(numer), one of the leading coefficients is negative + ( + Poly(-735*x**8 - 1400*x**7 + 1680*x**6 - 315*x**5 - 600*x**4 + 840*x**3 - 525*x**2 \ + + 630*x + 3780, x), + Poly(1008*x**7 - 2940*x**6 - 84*x**5 + 2940*x**4 - 420*x**3 + 1512*x**2 + 105*x + 168, x), + -oo + ), + # deg(denom) == deg(numer) + ( + Poly(105*x**7 - 960*x**6 + 60*x**5 + 60*x**4 - 80*x**3 + 45*x**2 + 120*x + 15, x), + Poly(735*x**7 + 525*x**6 + 720*x**5 + 720*x**4 - 8400*x**3 - 2520*x**2 + 2800*x + 280, x), + S(1)/7 + ), + ( + Poly(288*x**4 - 450*x**3 + 280*x**2 - 900*x - 90, x), + Poly(607*x**4 + 840*x**3 - 1050*x**2 + 420*x + 420, x), + S(288)/607 + )] + for num, den, lim in tests: + assert limit_at_inf(num, den, x) == lim + + +def test_construct_c_case_1(): + """ + This function tests the Case 1 in the step + to calculate coefficients of c-vectors. + + Each test case has 4 values - + + 1. num - Numerator of the rational function a(x). + 2. den - Denominator of the rational function a(x). + 3. pole - Pole of a(x) for which c-vector is being + calculated. + 4. c - The c-vector for the pole. + """ + tests = [ + ( + Poly(-3*x**3 + 3*x**2 + 4*x - 5, x, extension=True), + Poly(4*x**8 + 16*x**7 + 9*x**5 + 12*x**4 + 6*x**3 + 12*x**2, x, extension=True), + S(0), + [[S(1)/2 + sqrt(6)*I/6], [S(1)/2 - sqrt(6)*I/6]] + ), + ( + Poly(1200*x**3 + 1440*x**2 + 816*x + 560, x, extension=True), + Poly(128*x**5 - 656*x**4 + 1264*x**3 - 1125*x**2 + 385*x + 49, x, extension=True), + S(7)/4, + [[S(1)/2 + sqrt(16367978)/634], [S(1)/2 - sqrt(16367978)/634]] + ), + ( + Poly(4*x + 2, x, extension=True), + Poly(18*x**4 + (2 - 18*sqrt(3))*x**3 + (14 - 11*sqrt(3))*x**2 + (4 - 6*sqrt(3))*x \ + + 8*sqrt(3) + 16, x, domain='QQ'), + (S(1) + sqrt(3))/2, + [[S(1)/2 + sqrt(Mul(4, 2*sqrt(3) + 4, evaluate=False)/(19*sqrt(3) + 44) + 1)/2], \ + [S(1)/2 - sqrt(Mul(4, 2*sqrt(3) + 4, evaluate=False)/(19*sqrt(3) + 44) + 1)/2]] + )] + for num, den, pole, c in tests: + assert construct_c_case_1(num, den, x, pole) == c + + +def test_construct_c_case_2(): + """ + This function tests the Case 2 in the step + to calculate coefficients of c-vectors. + + Each test case has 5 values - + + 1. num - Numerator of the rational function a(x). + 2. den - Denominator of the rational function a(x). + 3. pole - Pole of a(x) for which c-vector is being + calculated. + 4. mul - The multiplicity of the pole. + 5. c - The c-vector for the pole. + """ + tests = [ + # Testing poles with multiplicity 2 + ( + Poly(1, x, extension=True), + Poly((x - 1)**2*(x - 2), x, extension=True), + 1, 2, + [[-I*(-1 - I)/2], [I*(-1 + I)/2]] + ), + ( + Poly(3*x**5 - 12*x**4 - 7*x**3 + 1, x, extension=True), + Poly((3*x - 1)**2*(x + 2)**2, x, extension=True), + S(1)/3, 2, + [[-S(89)/98], [-S(9)/98]] + ), + # Testing poles with multiplicity 4 + ( + Poly(x**3 - x**2 + 4*x, x, extension=True), + Poly((x - 2)**4*(x + 5)**2, x, extension=True), + 2, 4, + [[7*sqrt(3)*(S(60)/343 - 4*sqrt(3)/7)/12, 2*sqrt(3)/7], \ + [-7*sqrt(3)*(S(60)/343 + 4*sqrt(3)/7)/12, -2*sqrt(3)/7]] + ), + ( + Poly(3*x**5 + x**4 + 3, x, extension=True), + Poly((4*x + 1)**4*(x + 2), x, extension=True), + -S(1)/4, 4, + [[128*sqrt(439)*(-sqrt(439)/128 - S(55)/14336)/439, sqrt(439)/256], \ + [-128*sqrt(439)*(sqrt(439)/128 - S(55)/14336)/439, -sqrt(439)/256]] + ), + # Testing poles with multiplicity 6 + ( + Poly(x**3 + 2, x, extension=True), + Poly((3*x - 1)**6*(x**2 + 1), x, extension=True), + S(1)/3, 6, + [[27*sqrt(66)*(-sqrt(66)/54 - S(131)/267300)/22, -2*sqrt(66)/1485, sqrt(66)/162], \ + [-27*sqrt(66)*(sqrt(66)/54 - S(131)/267300)/22, 2*sqrt(66)/1485, -sqrt(66)/162]] + ), + ( + Poly(x**2 + 12, x, extension=True), + Poly((x - sqrt(2))**6, x, extension=True), + sqrt(2), 6, + [[sqrt(14)*(S(6)/7 - 3*sqrt(14))/28, sqrt(7)/7, sqrt(14)], \ + [-sqrt(14)*(S(6)/7 + 3*sqrt(14))/28, -sqrt(7)/7, -sqrt(14)]] + )] + for num, den, pole, mul, c in tests: + assert construct_c_case_2(num, den, x, pole, mul) == c + + +def test_construct_c_case_3(): + """ + This function tests the Case 3 in the step + to calculate coefficients of c-vectors. + """ + assert construct_c_case_3() == [[1]] + + +def test_construct_d_case_4(): + """ + This function tests the Case 4 in the step + to calculate coefficients of the d-vector. + + Each test case has 4 values - + + 1. num - Numerator of the rational function a(x). + 2. den - Denominator of the rational function a(x). + 3. mul - Multiplicity of oo as a pole. + 4. d - The d-vector. + """ + tests = [ + # Tests with multiplicity at oo = 2 + ( + Poly(-x**5 - 2*x**4 + 4*x**3 + 2*x + 5, x, extension=True), + Poly(9*x**3 - 2*x**2 + 10*x - 2, x, extension=True), + 2, + [[10*I/27, I/3, -3*I*(S(158)/243 - I/3)/2], \ + [-10*I/27, -I/3, 3*I*(S(158)/243 + I/3)/2]] + ), + ( + Poly(-x**6 + 9*x**5 + 5*x**4 + 6*x**3 + 5*x**2 + 6*x + 7, x, extension=True), + Poly(x**4 + 3*x**3 + 12*x**2 - x + 7, x, extension=True), + 2, + [[-6*I, I, -I*(17 - I)/2], [6*I, -I, I*(17 + I)/2]] + ), + # Tests with multiplicity at oo = 4 + ( + Poly(-2*x**6 - x**5 - x**4 - 2*x**3 - x**2 - 3*x - 3, x, extension=True), + Poly(3*x**2 + 10*x + 7, x, extension=True), + 4, + [[269*sqrt(6)*I/288, -17*sqrt(6)*I/36, sqrt(6)*I/3, -sqrt(6)*I*(S(16969)/2592 \ + - 2*sqrt(6)*I/3)/4], [-269*sqrt(6)*I/288, 17*sqrt(6)*I/36, -sqrt(6)*I/3, \ + sqrt(6)*I*(S(16969)/2592 + 2*sqrt(6)*I/3)/4]] + ), + ( + Poly(-3*x**5 - 3*x**4 - 3*x**3 - x**2 - 1, x, extension=True), + Poly(12*x - 2, x, extension=True), + 4, + [[41*I/192, 7*I/24, I/2, -I*(-S(59)/6912 - I)], \ + [-41*I/192, -7*I/24, -I/2, I*(-S(59)/6912 + I)]] + ), + # Tests with multiplicity at oo = 4 + ( + Poly(-x**7 - x**5 - x**4 - x**2 - x, x, extension=True), + Poly(x + 2, x, extension=True), + 6, + [[-5*I/2, 2*I, -I, I, -I*(-9 - 3*I)/2], [5*I/2, -2*I, I, -I, I*(-9 + 3*I)/2]] + ), + ( + Poly(-x**7 - x**6 - 2*x**5 - 2*x**4 - x**3 - x**2 + 2*x - 2, x, extension=True), + Poly(2*x - 2, x, extension=True), + 6, + [[3*sqrt(2)*I/4, 3*sqrt(2)*I/4, sqrt(2)*I/2, sqrt(2)*I/2, -sqrt(2)*I*(-S(7)/8 - \ + 3*sqrt(2)*I/2)/2], [-3*sqrt(2)*I/4, -3*sqrt(2)*I/4, -sqrt(2)*I/2, -sqrt(2)*I/2, \ + sqrt(2)*I*(-S(7)/8 + 3*sqrt(2)*I/2)/2]] + )] + for num, den, mul, d in tests: + ser = rational_laurent_series(num, den, x, oo, mul, 1) + assert construct_d_case_4(ser, mul//2) == d + + +def test_construct_d_case_5(): + """ + This function tests the Case 5 in the step + to calculate coefficients of the d-vector. + + Each test case has 3 values - + + 1. num - Numerator of the rational function a(x). + 2. den - Denominator of the rational function a(x). + 3. d - The d-vector. + """ + tests = [ + ( + Poly(2*x**3 + x**2 + x - 2, x, extension=True), + Poly(9*x**3 + 5*x**2 + 2*x - 1, x, extension=True), + [[sqrt(2)/3, -sqrt(2)/108], [-sqrt(2)/3, sqrt(2)/108]] + ), + ( + Poly(3*x**5 + x**4 - x**3 + x**2 - 2*x - 2, x, domain='ZZ'), + Poly(9*x**5 + 7*x**4 + 3*x**3 + 2*x**2 + 5*x + 7, x, domain='ZZ'), + [[sqrt(3)/3, -2*sqrt(3)/27], [-sqrt(3)/3, 2*sqrt(3)/27]] + ), + ( + Poly(x**2 - x + 1, x, domain='ZZ'), + Poly(3*x**2 + 7*x + 3, x, domain='ZZ'), + [[sqrt(3)/3, -5*sqrt(3)/9], [-sqrt(3)/3, 5*sqrt(3)/9]] + )] + for num, den, d in tests: + # Multiplicity of oo is 0 + ser = rational_laurent_series(num, den, x, oo, 0, 1) + assert construct_d_case_5(ser) == d + + +def test_construct_d_case_6(): + """ + This function tests the Case 6 in the step + to calculate coefficients of the d-vector. + + Each test case has 3 values - + + 1. num - Numerator of the rational function a(x). + 2. den - Denominator of the rational function a(x). + 3. d - The d-vector. + """ + tests = [ + ( + Poly(-2*x**2 - 5, x, domain='ZZ'), + Poly(4*x**4 + 2*x**2 + 10*x + 2, x, domain='ZZ'), + [[S(1)/2 + I/2], [S(1)/2 - I/2]] + ), + ( + Poly(-2*x**3 - 4*x**2 - 2*x - 5, x, domain='ZZ'), + Poly(x**6 - x**5 + 2*x**4 - 4*x**3 - 5*x**2 - 5*x + 9, x, domain='ZZ'), + [[1], [0]] + ), + ( + Poly(-5*x**3 + x**2 + 11*x + 12, x, domain='ZZ'), + Poly(6*x**8 - 26*x**7 - 27*x**6 - 10*x**5 - 44*x**4 - 46*x**3 - 34*x**2 \ + - 27*x - 42, x, domain='ZZ'), + [[1], [0]] + )] + for num, den, d in tests: + assert construct_d_case_6(num, den, x) == d + + +def test_rational_laurent_series(): + """ + This function tests the computation of coefficients + of Laurent series of a rational function. + + Each test case has 5 values - + + 1. num - Numerator of the rational function. + 2. den - Denominator of the rational function. + 3. x0 - Point about which Laurent series is to + be calculated. + 4. mul - Multiplicity of x0 if x0 is a pole of + the rational function (0 otherwise). + 5. n - Number of terms upto which the series + is to be calculated. + """ + tests = [ + # Laurent series about simple pole (Multiplicity = 1) + ( + Poly(x**2 - 3*x + 9, x, extension=True), + Poly(x**2 - x, x, extension=True), + S(1), 1, 6, + {1: 7, 0: -8, -1: 9, -2: -9, -3: 9, -4: -9} + ), + # Laurent series about multiple pole (Multiplicity > 1) + ( + Poly(64*x**3 - 1728*x + 1216, x, extension=True), + Poly(64*x**4 - 80*x**3 - 831*x**2 + 1809*x - 972, x, extension=True), + S(9)/8, 2, 3, + {0: S(32177152)/46521675, 2: S(1019)/984, -1: S(11947565056)/28610830125, \ + 1: S(209149)/75645} + ), + ( + Poly(1, x, extension=True), + Poly(x**5 + (-4*sqrt(2) - 1)*x**4 + (4*sqrt(2) + 12)*x**3 + (-12 - 8*sqrt(2))*x**2 \ + + (4 + 8*sqrt(2))*x - 4, x, extension=True), + sqrt(2), 4, 6, + {4: 1 + sqrt(2), 3: -3 - 2*sqrt(2), 2: Mul(-1, -3 - 2*sqrt(2), evaluate=False)/(-1 \ + + sqrt(2)), 1: (-3 - 2*sqrt(2))/(-1 + sqrt(2))**2, 0: Mul(-1, -3 - 2*sqrt(2), evaluate=False \ + )/(-1 + sqrt(2))**3, -1: (-3 - 2*sqrt(2))/(-1 + sqrt(2))**4} + ), + # Laurent series about oo + ( + Poly(x**5 - 4*x**3 + 6*x**2 + 10*x - 13, x, extension=True), + Poly(x**2 - 5, x, extension=True), + oo, 3, 6, + {3: 1, 2: 0, 1: 1, 0: 6, -1: 15, -2: 17} + ), + # Laurent series at x0 where x0 is not a pole of the function + # Using multiplicity as 0 (as x0 will not be a pole) + ( + Poly(3*x**3 + 6*x**2 - 2*x + 5, x, extension=True), + Poly(9*x**4 - x**3 - 3*x**2 + 4*x + 4, x, extension=True), + S(2)/5, 0, 1, + {0: S(3345)/3304, -1: S(399325)/2729104, -2: S(3926413375)/4508479808, \ + -3: S(-5000852751875)/1862002160704, -4: S(-6683640101653125)/6152055138966016} + ), + ( + Poly(-7*x**2 + 2*x - 4, x, extension=True), + Poly(7*x**5 + 9*x**4 + 8*x**3 + 3*x**2 + 6*x + 9, x, extension=True), + oo, 0, 6, + {0: 0, -2: 0, -5: -S(71)/49, -1: 0, -3: -1, -4: S(11)/7} + )] + for num, den, x0, mul, n, ser in tests: + assert ser == rational_laurent_series(num, den, x, x0, mul, n) + + +def check_dummy_sol(eq, solse, dummy_sym): + """ + Helper function to check if actual solution + matches expected solution if actual solution + contains dummy symbols. + """ + if isinstance(eq, Eq): + eq = eq.lhs - eq.rhs + _, funcs = match_riccati(eq, f, x) + + sols = solve_riccati(f(x), x, *funcs) + C1 = Dummy('C1') + sols = [sol.subs(C1, dummy_sym) for sol in sols] + + assert all(x[0] for x in checkodesol(eq, sols)) + assert all(s1.dummy_eq(s2, dummy_sym) for s1, s2 in zip(sols, solse)) + + +def test_solve_riccati(): + """ + This function tests the computation of rational + particular solutions for a Riccati ODE. + + Each test case has 2 values - + + 1. eq - Riccati ODE to be solved. + 2. sol - Expected solution to the equation. + + Some examples have been taken from the paper - "Statistical Investigation of + First-Order Algebraic ODEs and their Rational General Solutions" by + Georg Grasegger, N. Thieu Vo, Franz Winkler + + https://www3.risc.jku.at/publications/download/risc_5197/RISCReport15-19.pdf + """ + C0 = Dummy('C0') + # Type: 1st Order Rational Riccati, dy/dx = a + b*y + c*y**2, + # a, b, c are rational functions of x + + tests = [ + # a(x) is a constant + ( + Eq(f(x).diff(x) + f(x)**2 - 2, 0), + [Eq(f(x), sqrt(2)), Eq(f(x), -sqrt(2))] + ), + # a(x) is a constant + ( + f(x)**2 + f(x).diff(x) + 4*f(x)/x + 2/x**2, + [Eq(f(x), (-2*C0 - x)/(C0*x + x**2))] + ), + # a(x) is a constant + ( + 2*x**2*f(x).diff(x) - x*(4*f(x) + f(x).diff(x) - 4) + (f(x) - 1)*f(x), + [Eq(f(x), (C0 + 2*x**2)/(C0 + x))] + ), + # Pole with multiplicity 1 + ( + Eq(f(x).diff(x), -f(x)**2 - 2/(x**3 - x**2)), + [Eq(f(x), 1/(x**2 - x))] + ), + # One pole of multiplicity 2 + ( + x**2 - (2*x + 1/x)*f(x) + f(x)**2 + f(x).diff(x), + [Eq(f(x), (C0*x + x**3 + 2*x)/(C0 + x**2)), Eq(f(x), x)] + ), + ( + x**4*f(x).diff(x) + x**2 - x*(2*f(x)**2 + f(x).diff(x)) + f(x), + [Eq(f(x), (C0*x**2 + x)/(C0 + x**2)), Eq(f(x), x**2)] + ), + # Multiple poles of multiplicity 2 + ( + -f(x)**2 + f(x).diff(x) + (15*x**2 - 20*x + 7)/((x - 1)**2*(2*x \ + - 1)**2), + [Eq(f(x), (9*C0*x - 6*C0 - 15*x**5 + 60*x**4 - 94*x**3 + 72*x**2 \ + - 30*x + 6)/(6*C0*x**2 - 9*C0*x + 3*C0 + 6*x**6 - 29*x**5 + \ + 57*x**4 - 58*x**3 + 30*x**2 - 6*x)), Eq(f(x), (3*x - 2)/(2*x**2 \ + - 3*x + 1))] + ), + # Regression: Poles with even multiplicity > 2 fixed + ( + f(x)**2 + f(x).diff(x) - (4*x**6 - 8*x**5 + 12*x**4 + 4*x**3 + \ + 7*x**2 - 20*x + 4)/(4*x**4), + [Eq(f(x), (2*x**5 - 2*x**4 - x**3 + 4*x**2 + 3*x - 2)/(2*x**4 \ + - 2*x**2))] + ), + # Regression: Poles with even multiplicity > 2 fixed + ( + Eq(f(x).diff(x), (-x**6 + 15*x**4 - 40*x**3 + 45*x**2 - 24*x + 4)/\ + (x**12 - 12*x**11 + 66*x**10 - 220*x**9 + 495*x**8 - 792*x**7 + 924*x**6 - \ + 792*x**5 + 495*x**4 - 220*x**3 + 66*x**2 - 12*x + 1) + f(x)**2 + f(x)), + [Eq(f(x), 1/(x**6 - 6*x**5 + 15*x**4 - 20*x**3 + 15*x**2 - 6*x + 1))] + ), + # More than 2 poles with multiplicity 2 + # Regression: Fixed mistake in necessary conditions + ( + Eq(f(x).diff(x), x*f(x) + 2*x + (3*x - 2)*f(x)**2/(4*x + 2) + \ + (8*x**2 - 7*x + 26)/(16*x**3 - 24*x**2 + 8) - S(3)/2), + [Eq(f(x), (1 - 4*x)/(2*x - 2))] + ), + # Regression: Fixed mistake in necessary conditions + ( + Eq(f(x).diff(x), (-12*x**2 - 48*x - 15)/(24*x**3 - 40*x**2 + 8*x + 8) \ + + 3*f(x)**2/(6*x + 2)), + [Eq(f(x), (2*x + 1)/(2*x - 2))] + ), + # Imaginary poles + ( + f(x).diff(x) + (3*x**2 + 1)*f(x)**2/x + (6*x**2 - x + 3)*f(x)/(x*(x \ + - 1)) + (3*x**2 - 2*x + 2)/(x*(x - 1)**2), + [Eq(f(x), (-C0 - x**3 + x**2 - 2*x)/(C0*x - C0 + x**4 - x**3 + x**2 \ + - x)), Eq(f(x), -1/(x - 1))], + ), + # Imaginary coefficients in equation + ( + f(x).diff(x) - 2*I*(f(x)**2 + 1)/x, + [Eq(f(x), (-I*C0 + I*x**4)/(C0 + x**4)), Eq(f(x), -I)] + ), + # Regression: linsolve returning empty solution + # Large value of m (> 10) + ( + Eq(f(x).diff(x), x*f(x)/(S(3)/2 - 2*x) + (x/2 - S(1)/3)*f(x)**2/\ + (2*x/3 - S(1)/2) - S(5)/4 + (281*x**2 - 1260*x + 756)/(16*x**3 - 12*x**2)), + [Eq(f(x), (9 - x)/x), Eq(f(x), (40*x**14 + 28*x**13 + 420*x**12 + 2940*x**11 + \ + 18480*x**10 + 103950*x**9 + 519750*x**8 + 2286900*x**7 + 8731800*x**6 + 28378350*\ + x**5 + 76403250*x**4 + 163721250*x**3 + 261954000*x**2 + 278326125*x + 147349125)/\ + ((24*x**14 + 140*x**13 + 840*x**12 + 4620*x**11 + 23100*x**10 + 103950*x**9 + \ + 415800*x**8 + 1455300*x**7 + 4365900*x**6 + 10914750*x**5 + 21829500*x**4 + 32744250\ + *x**3 + 32744250*x**2 + 16372125*x)))] + ), + # Regression: Fixed bug due to a typo in paper + ( + Eq(f(x).diff(x), 18*x**3 + 18*x**2 + (-x/2 - S(1)/2)*f(x)**2 + 6), + [Eq(f(x), 6*x)] + ), + # Regression: Fixed bug due to a typo in paper + ( + Eq(f(x).diff(x), -3*x**3/4 + 15*x/2 + (x/3 - S(4)/3)*f(x)**2 \ + + 9 + (1 - x)*f(x)/x + 3/x), + [Eq(f(x), -3*x/2 - 3)] + )] + for eq, sol in tests: + check_dummy_sol(eq, sol, C0) + + +@slow +def test_solve_riccati_slow(): + """ + This function tests the computation of rational + particular solutions for a Riccati ODE. + + Each test case has 2 values - + + 1. eq - Riccati ODE to be solved. + 2. sol - Expected solution to the equation. + """ + C0 = Dummy('C0') + tests = [ + # Very large values of m (989 and 991) + ( + Eq(f(x).diff(x), (1 - x)*f(x)/(x - 3) + (2 - 12*x)*f(x)**2/(2*x - 9) + \ + (54924*x**3 - 405264*x**2 + 1084347*x - 1087533)/(8*x**4 - 132*x**3 + 810*x**2 - \ + 2187*x + 2187) + 495), + [Eq(f(x), (18*x + 6)/(2*x - 9))] + )] + for eq, sol in tests: + check_dummy_sol(eq, sol, C0) diff --git a/lib/python3.10/site-packages/sympy/solvers/ode/tests/test_single.py b/lib/python3.10/site-packages/sympy/solvers/ode/tests/test_single.py new file mode 100644 index 0000000000000000000000000000000000000000..7bd34add98d072a7fabdd16a16e3afaee7bfe53f --- /dev/null +++ b/lib/python3.10/site-packages/sympy/solvers/ode/tests/test_single.py @@ -0,0 +1,2902 @@ +# +# The main tests for the code in single.py are currently located in +# sympy/solvers/tests/test_ode.py +# +r""" +This File contains test functions for the individual hints used for solving ODEs. + +Examples of each solver will be returned by _get_examples_ode_sol_name_of_solver. + +Examples should have a key 'XFAIL' which stores the list of hints if they are +expected to fail for that hint. + +Functions that are for internal use: + +1) _ode_solver_test(ode_examples) - It takes a dictionary of examples returned by + _get_examples method and tests them with their respective hints. + +2) _test_particular_example(our_hint, example_name) - It tests the ODE example corresponding + to the hint provided. + +3) _test_all_hints(runxfail=False) - It is used to test all the examples with all the hints + currently implemented. It calls _test_all_examples_for_one_hint() which outputs whether the + given hint functions properly if it classifies the ODE example. + If runxfail flag is set to True then it will only test the examples which are expected to fail. + + Everytime the ODE of a particular solver is added, _test_all_hints() is to be executed to find + the possible failures of different solver hints. + +4) _test_all_examples_for_one_hint(our_hint, all_examples) - It takes hint as argument and checks + this hint against all the ODE examples and gives output as the number of ODEs matched, number + of ODEs which were solved correctly, list of ODEs which gives incorrect solution and list of + ODEs which raises exception. + +""" +from sympy.core.function import (Derivative, diff) +from sympy.core.mul import Mul +from sympy.core.numbers import (E, I, Rational, pi) +from sympy.core.relational import (Eq, Ne) +from sympy.core.singleton import S +from sympy.core.symbol import (Dummy, symbols) +from sympy.functions.elementary.complexes import (im, re) +from sympy.functions.elementary.exponential import (LambertW, exp, log) +from sympy.functions.elementary.hyperbolic import (asinh, cosh, sinh, tanh) +from sympy.functions.elementary.miscellaneous import (cbrt, sqrt) +from sympy.functions.elementary.piecewise import Piecewise +from sympy.functions.elementary.trigonometric import (acos, asin, atan, cos, sec, sin, tan) +from sympy.functions.special.error_functions import (Ei, erfi) +from sympy.functions.special.hyper import hyper +from sympy.integrals.integrals import (Integral, integrate) +from sympy.polys.rootoftools import rootof + +from sympy.core import Function, Symbol +from sympy.functions import airyai, airybi, besselj, bessely, lowergamma +from sympy.integrals.risch import NonElementaryIntegral +from sympy.solvers.ode import classify_ode, dsolve +from sympy.solvers.ode.ode import allhints, _remove_redundant_solutions +from sympy.solvers.ode.single import (FirstLinear, ODEMatchError, + SingleODEProblem, SingleODESolver, NthOrderReducible) + +from sympy.solvers.ode.subscheck import checkodesol + +from sympy.testing.pytest import raises, slow +import traceback + + +x = Symbol('x') +u = Symbol('u') +_u = Dummy('u') +y = Symbol('y') +f = Function('f') +g = Function('g') +C1, C2, C3, C4, C5, C6, C7, C8, C9, C10 = symbols('C1:11') +a, b, c = symbols('a b c') + + +hint_message = """\ +Hint did not match the example {example}. + +The ODE is: +{eq}. + +The expected hint was +{our_hint}\ +""" + +expected_sol_message = """\ +Different solution found from dsolve for example {example}. + +The ODE is: +{eq} + +The expected solution was +{sol} + +What dsolve returned is: +{dsolve_sol}\ +""" + +checkodesol_msg = """\ +solution found is not correct for example {example}. + +The ODE is: +{eq}\ +""" + +dsol_incorrect_msg = """\ +solution returned by dsolve is incorrect when using {hint}. + +The ODE is: +{eq} + +The expected solution was +{sol} + +what dsolve returned is: +{dsolve_sol} + +You can test this with: + +eq = {eq} +sol = dsolve(eq, hint='{hint}') +print(sol) +print(checkodesol(eq, sol)) + +""" + +exception_msg = """\ +dsolve raised exception : {e} + +when using {hint} for the example {example} + +You can test this with: + +from sympy.solvers.ode.tests.test_single import _test_an_example + +_test_an_example('{hint}', example_name = '{example}') + +The ODE is: +{eq} + +\ +""" + +check_hint_msg = """\ +Tested hint was : {hint} + +Total of {matched} examples matched with this hint. + +Out of which {solve} gave correct results. + +Examples which gave incorrect results are {unsolve}. + +Examples which raised exceptions are {exceptions} +\ +""" + + +def _add_example_keys(func): + def inner(): + solver=func() + examples=[] + for example in solver['examples']: + temp={ + 'eq': solver['examples'][example]['eq'], + 'sol': solver['examples'][example]['sol'], + 'XFAIL': solver['examples'][example].get('XFAIL', []), + 'func': solver['examples'][example].get('func',solver['func']), + 'example_name': example, + 'slow': solver['examples'][example].get('slow', False), + 'simplify_flag':solver['examples'][example].get('simplify_flag',True), + 'checkodesol_XFAIL': solver['examples'][example].get('checkodesol_XFAIL', False), + 'dsolve_too_slow':solver['examples'][example].get('dsolve_too_slow',False), + 'checkodesol_too_slow':solver['examples'][example].get('checkodesol_too_slow',False), + 'hint': solver['hint'] + } + examples.append(temp) + return examples + return inner() + + +def _ode_solver_test(ode_examples, run_slow_test=False): + for example in ode_examples: + if ((not run_slow_test) and example['slow']) or (run_slow_test and (not example['slow'])): + continue + + result = _test_particular_example(example['hint'], example, solver_flag=True) + if result['xpass_msg'] != "": + print(result['xpass_msg']) + + +def _test_all_hints(runxfail=False): + all_hints = list(allhints)+["default"] + all_examples = _get_all_examples() + + for our_hint in all_hints: + if our_hint.endswith('_Integral') or 'series' in our_hint: + continue + _test_all_examples_for_one_hint(our_hint, all_examples, runxfail) + + +def _test_dummy_sol(expected_sol,dsolve_sol): + if type(dsolve_sol)==list: + return any(expected_sol.dummy_eq(sub_dsol) for sub_dsol in dsolve_sol) + else: + return expected_sol.dummy_eq(dsolve_sol) + + +def _test_an_example(our_hint, example_name): + all_examples = _get_all_examples() + for example in all_examples: + if example['example_name'] == example_name: + _test_particular_example(our_hint, example) + + +def _test_particular_example(our_hint, ode_example, solver_flag=False): + eq = ode_example['eq'] + expected_sol = ode_example['sol'] + example = ode_example['example_name'] + xfail = our_hint in ode_example['XFAIL'] + func = ode_example['func'] + result = {'msg': '', 'xpass_msg': ''} + simplify_flag=ode_example['simplify_flag'] + checkodesol_XFAIL = ode_example['checkodesol_XFAIL'] + dsolve_too_slow = ode_example['dsolve_too_slow'] + checkodesol_too_slow = ode_example['checkodesol_too_slow'] + xpass = True + if solver_flag: + if our_hint not in classify_ode(eq, func): + message = hint_message.format(example=example, eq=eq, our_hint=our_hint) + raise AssertionError(message) + + if our_hint in classify_ode(eq, func): + result['match_list'] = example + try: + if not (dsolve_too_slow): + dsolve_sol = dsolve(eq, func, simplify=simplify_flag,hint=our_hint) + else: + if len(expected_sol)==1: + dsolve_sol = expected_sol[0] + else: + dsolve_sol = expected_sol + + except Exception as e: + dsolve_sol = [] + result['exception_list'] = example + if not solver_flag: + traceback.print_exc() + result['msg'] = exception_msg.format(e=str(e), hint=our_hint, example=example, eq=eq) + if solver_flag and not xfail: + print(result['msg']) + raise + xpass = False + + if solver_flag and dsolve_sol!=[]: + expect_sol_check = False + if type(dsolve_sol)==list: + for sub_sol in expected_sol: + if sub_sol.has(Dummy): + expect_sol_check = not _test_dummy_sol(sub_sol, dsolve_sol) + else: + expect_sol_check = sub_sol not in dsolve_sol + if expect_sol_check: + break + else: + expect_sol_check = dsolve_sol not in expected_sol + for sub_sol in expected_sol: + if sub_sol.has(Dummy): + expect_sol_check = not _test_dummy_sol(sub_sol, dsolve_sol) + + if expect_sol_check: + message = expected_sol_message.format(example=example, eq=eq, sol=expected_sol, dsolve_sol=dsolve_sol) + raise AssertionError(message) + + expected_checkodesol = [(True, 0) for i in range(len(expected_sol))] + if len(expected_sol) == 1: + expected_checkodesol = (True, 0) + + if not checkodesol_too_slow: + if not checkodesol_XFAIL: + if checkodesol(eq, dsolve_sol, func, solve_for_func=False) != expected_checkodesol: + result['unsolve_list'] = example + xpass = False + message = dsol_incorrect_msg.format(hint=our_hint, eq=eq, sol=expected_sol,dsolve_sol=dsolve_sol) + if solver_flag: + message = checkodesol_msg.format(example=example, eq=eq) + raise AssertionError(message) + else: + result['msg'] = 'AssertionError: ' + message + + if xpass and xfail: + result['xpass_msg'] = example + "is now passing for the hint" + our_hint + return result + + +def _test_all_examples_for_one_hint(our_hint, all_examples=[], runxfail=None): + if all_examples == []: + all_examples = _get_all_examples() + match_list, unsolve_list, exception_list = [], [], [] + for ode_example in all_examples: + xfail = our_hint in ode_example['XFAIL'] + if runxfail and not xfail: + continue + if xfail: + continue + result = _test_particular_example(our_hint, ode_example) + match_list += result.get('match_list',[]) + unsolve_list += result.get('unsolve_list',[]) + exception_list += result.get('exception_list',[]) + if runxfail is not None: + msg = result['msg'] + if msg!='': + print(result['msg']) + # print(result.get('xpass_msg','')) + if runxfail is None: + match_count = len(match_list) + solved = len(match_list)-len(unsolve_list)-len(exception_list) + msg = check_hint_msg.format(hint=our_hint, matched=match_count, solve=solved, unsolve=unsolve_list, exceptions=exception_list) + print(msg) + + +def test_SingleODESolver(): + # Test that not implemented methods give NotImplementedError + # Subclasses should override these methods. + problem = SingleODEProblem(f(x).diff(x), f(x), x) + solver = SingleODESolver(problem) + raises(NotImplementedError, lambda: solver.matches()) + raises(NotImplementedError, lambda: solver.get_general_solution()) + raises(NotImplementedError, lambda: solver._matches()) + raises(NotImplementedError, lambda: solver._get_general_solution()) + + # This ODE can not be solved by the FirstLinear solver. Here we test that + # it does not match and the asking for a general solution gives + # ODEMatchError + + problem = SingleODEProblem(f(x).diff(x) + f(x)*f(x), f(x), x) + + solver = FirstLinear(problem) + raises(ODEMatchError, lambda: solver.get_general_solution()) + + solver = FirstLinear(problem) + assert solver.matches() is False + + #These are just test for order of ODE + + problem = SingleODEProblem(f(x).diff(x) + f(x), f(x), x) + assert problem.order == 1 + + problem = SingleODEProblem(f(x).diff(x,4) + f(x).diff(x,2) - f(x).diff(x,3), f(x), x) + assert problem.order == 4 + + problem = SingleODEProblem(f(x).diff(x, 3) + f(x).diff(x, 2) - f(x)**2, f(x), x) + assert problem.is_autonomous == True + + problem = SingleODEProblem(f(x).diff(x, 3) + x*f(x).diff(x, 2) - f(x)**2, f(x), x) + assert problem.is_autonomous == False + + +def test_linear_coefficients(): + _ode_solver_test(_get_examples_ode_sol_linear_coefficients) + + +@slow +def test_1st_homogeneous_coeff_ode(): + #These were marked as test_1st_homogeneous_coeff_corner_case + eq1 = f(x).diff(x) - f(x)/x + c1 = classify_ode(eq1, f(x)) + eq2 = x*f(x).diff(x) - f(x) + c2 = classify_ode(eq2, f(x)) + sdi = "1st_homogeneous_coeff_subs_dep_div_indep" + sid = "1st_homogeneous_coeff_subs_indep_div_dep" + assert sid not in c1 and sdi not in c1 + assert sid not in c2 and sdi not in c2 + _ode_solver_test(_get_examples_ode_sol_1st_homogeneous_coeff_subs_dep_div_indep) + _ode_solver_test(_get_examples_ode_sol_1st_homogeneous_coeff_best) + + +@slow +def test_slow_examples_1st_homogeneous_coeff_ode(): + _ode_solver_test(_get_examples_ode_sol_1st_homogeneous_coeff_subs_dep_div_indep, run_slow_test=True) + _ode_solver_test(_get_examples_ode_sol_1st_homogeneous_coeff_best, run_slow_test=True) + + +@slow +def test_nth_linear_constant_coeff_homogeneous(): + _ode_solver_test(_get_examples_ode_sol_nth_linear_constant_coeff_homogeneous) + + +@slow +def test_slow_examples_nth_linear_constant_coeff_homogeneous(): + _ode_solver_test(_get_examples_ode_sol_nth_linear_constant_coeff_homogeneous, run_slow_test=True) + + +def test_Airy_equation(): + _ode_solver_test(_get_examples_ode_sol_2nd_linear_airy) + + +@slow +def test_lie_group(): + _ode_solver_test(_get_examples_ode_sol_lie_group) + + +@slow +def test_separable_reduced(): + df = f(x).diff(x) + eq = (x / f(x))*df + tan(x**2*f(x) / (x**2*f(x) - 1)) + assert classify_ode(eq) == ('factorable', 'separable_reduced', 'lie_group', + 'separable_reduced_Integral') + _ode_solver_test(_get_examples_ode_sol_separable_reduced) + + +@slow +def test_slow_examples_separable_reduced(): + _ode_solver_test(_get_examples_ode_sol_separable_reduced, run_slow_test=True) + + +@slow +def test_2nd_2F1_hypergeometric(): + _ode_solver_test(_get_examples_ode_sol_2nd_2F1_hypergeometric) + + +def test_2nd_2F1_hypergeometric_integral(): + eq = x*(x-1)*f(x).diff(x, 2) + (-1+ S(7)/2*x)*f(x).diff(x) + f(x) + sol = Eq(f(x), (C1 + C2*Integral(exp(Integral((1 - x/2)/(x*(x - 1)), x))/(1 - + x/2)**2, x))*exp(Integral(1/(x - 1), x)/4)*exp(-Integral(7/(x - + 1), x)/4)*hyper((S(1)/2, -1), (1,), x)) + assert sol == dsolve(eq, hint='2nd_hypergeometric_Integral') + assert checkodesol(eq, sol) == (True, 0) + + +@slow +def test_2nd_nonlinear_autonomous_conserved(): + _ode_solver_test(_get_examples_ode_sol_2nd_nonlinear_autonomous_conserved) + + +def test_2nd_nonlinear_autonomous_conserved_integral(): + eq = f(x).diff(x, 2) + asin(f(x)) + actual = [Eq(Integral(1/sqrt(C1 - 2*Integral(asin(_u), _u)), (_u, f(x))), C2 + x), + Eq(Integral(1/sqrt(C1 - 2*Integral(asin(_u), _u)), (_u, f(x))), C2 - x)] + solved = dsolve(eq, hint='2nd_nonlinear_autonomous_conserved_Integral', simplify=False) + for a,s in zip(actual, solved): + assert a.dummy_eq(s) + # checkodesol unable to simplify solutions with f(x) in an integral equation + assert checkodesol(eq, [s.doit() for s in solved]) == [(True, 0), (True, 0)] + + +@slow +def test_2nd_linear_bessel_equation(): + _ode_solver_test(_get_examples_ode_sol_2nd_linear_bessel) + + +@slow +def test_nth_algebraic(): + eqn = f(x) + f(x)*f(x).diff(x) + solns = [Eq(f(x), exp(x)), + Eq(f(x), C1*exp(C2*x))] + solns_final = _remove_redundant_solutions(eqn, solns, 2, x) + assert solns_final == [Eq(f(x), C1*exp(C2*x))] + + _ode_solver_test(_get_examples_ode_sol_nth_algebraic) + + +@slow +def test_slow_examples_nth_linear_constant_coeff_var_of_parameters(): + _ode_solver_test(_get_examples_ode_sol_nth_linear_var_of_parameters, run_slow_test=True) + + +def test_nth_linear_constant_coeff_var_of_parameters(): + _ode_solver_test(_get_examples_ode_sol_nth_linear_var_of_parameters) + + +@slow +def test_nth_linear_constant_coeff_variation_of_parameters__integral(): + # solve_variation_of_parameters shouldn't attempt to simplify the + # Wronskian if simplify=False. If wronskian() ever gets good enough + # to simplify the result itself, this test might fail. + our_hint = 'nth_linear_constant_coeff_variation_of_parameters_Integral' + eq = f(x).diff(x, 5) + 2*f(x).diff(x, 3) + f(x).diff(x) - 2*x - exp(I*x) + sol_simp = dsolve(eq, f(x), hint=our_hint, simplify=True) + sol_nsimp = dsolve(eq, f(x), hint=our_hint, simplify=False) + assert sol_simp != sol_nsimp + assert checkodesol(eq, sol_simp, order=5, solve_for_func=False) == (True, 0) + assert checkodesol(eq, sol_simp, order=5, solve_for_func=False) == (True, 0) + + +@slow +def test_slow_examples_1st_exact(): + _ode_solver_test(_get_examples_ode_sol_1st_exact, run_slow_test=True) + + +@slow +def test_1st_exact(): + _ode_solver_test(_get_examples_ode_sol_1st_exact) + + +def test_1st_exact_integral(): + eq = cos(f(x)) - (x*sin(f(x)) - f(x)**2)*f(x).diff(x) + sol_1 = dsolve(eq, f(x), simplify=False, hint='1st_exact_Integral') + assert checkodesol(eq, sol_1, order=1, solve_for_func=False) + + +@slow +def test_slow_examples_nth_order_reducible(): + _ode_solver_test(_get_examples_ode_sol_nth_order_reducible, run_slow_test=True) + + +@slow +def test_slow_examples_nth_linear_constant_coeff_undetermined_coefficients(): + _ode_solver_test(_get_examples_ode_sol_nth_linear_undetermined_coefficients, run_slow_test=True) + + +@slow +def test_slow_examples_separable(): + _ode_solver_test(_get_examples_ode_sol_separable, run_slow_test=True) + + +@slow +def test_nth_linear_constant_coeff_undetermined_coefficients(): + #issue-https://github.com/sympy/sympy/issues/5787 + # This test case is to show the classification of imaginary constants under + # nth_linear_constant_coeff_undetermined_coefficients + eq = Eq(diff(f(x), x), I*f(x) + S.Half - I) + our_hint = 'nth_linear_constant_coeff_undetermined_coefficients' + assert our_hint in classify_ode(eq) + _ode_solver_test(_get_examples_ode_sol_nth_linear_undetermined_coefficients) + + +def test_nth_order_reducible(): + F = lambda eq: NthOrderReducible(SingleODEProblem(eq, f(x), x))._matches() + D = Derivative + assert F(D(y*f(x), x, y) + D(f(x), x)) == False + assert F(D(y*f(y), y, y) + D(f(y), y)) == False + assert F(f(x)*D(f(x), x) + D(f(x), x, 2))== False + assert F(D(x*f(y), y, 2) + D(u*y*f(x), x, 3)) == False # no simplification by design + assert F(D(f(y), y, 2) + D(f(y), y, 3) + D(f(x), x, 4)) == False + assert F(D(f(x), x, 2) + D(f(x), x, 3)) == True + _ode_solver_test(_get_examples_ode_sol_nth_order_reducible) + + +@slow +def test_separable(): + _ode_solver_test(_get_examples_ode_sol_separable) + + +@slow +def test_factorable(): + assert integrate(-asin(f(2*x)+pi), x) == -Integral(asin(pi + f(2*x)), x) + _ode_solver_test(_get_examples_ode_sol_factorable) + + +@slow +def test_slow_examples_factorable(): + _ode_solver_test(_get_examples_ode_sol_factorable, run_slow_test=True) + + +def test_Riccati_special_minus2(): + _ode_solver_test(_get_examples_ode_sol_riccati) + + +@slow +def test_1st_rational_riccati(): + _ode_solver_test(_get_examples_ode_sol_1st_rational_riccati) + + +def test_Bernoulli(): + _ode_solver_test(_get_examples_ode_sol_bernoulli) + + +def test_1st_linear(): + _ode_solver_test(_get_examples_ode_sol_1st_linear) + + +def test_almost_linear(): + _ode_solver_test(_get_examples_ode_sol_almost_linear) + + +@slow +def test_Liouville_ODE(): + hint = 'Liouville' + not_Liouville1 = classify_ode(diff(f(x), x)/x + f(x)*diff(f(x), x, x)/2 - + diff(f(x), x)**2/2, f(x)) + not_Liouville2 = classify_ode(diff(f(x), x)/x + diff(f(x), x, x)/2 - + x*diff(f(x), x)**2/2, f(x)) + assert hint not in not_Liouville1 + assert hint not in not_Liouville2 + assert hint + '_Integral' not in not_Liouville1 + assert hint + '_Integral' not in not_Liouville2 + + _ode_solver_test(_get_examples_ode_sol_liouville) + + +def test_nth_order_linear_euler_eq_homogeneous(): + x, t, a, b, c = symbols('x t a b c') + y = Function('y') + our_hint = "nth_linear_euler_eq_homogeneous" + + eq = diff(f(t), t, 4)*t**4 - 13*diff(f(t), t, 2)*t**2 + 36*f(t) + assert our_hint in classify_ode(eq) + + eq = a*y(t) + b*t*diff(y(t), t) + c*t**2*diff(y(t), t, 2) + assert our_hint in classify_ode(eq) + + _ode_solver_test(_get_examples_ode_sol_euler_homogeneous) + + +def test_nth_order_linear_euler_eq_nonhomogeneous_undetermined_coefficients(): + x, t = symbols('x t') + a, b, c, d = symbols('a b c d', integer=True) + our_hint = "nth_linear_euler_eq_nonhomogeneous_undetermined_coefficients" + + eq = x**4*diff(f(x), x, 4) - 13*x**2*diff(f(x), x, 2) + 36*f(x) + x + assert our_hint in classify_ode(eq, f(x)) + + eq = a*x**2*diff(f(x), x, 2) + b*x*diff(f(x), x) + c*f(x) + d*log(x) + assert our_hint in classify_ode(eq, f(x)) + + _ode_solver_test(_get_examples_ode_sol_euler_undetermined_coeff) + + +@slow +def test_nth_order_linear_euler_eq_nonhomogeneous_variation_of_parameters(): + x, t = symbols('x, t') + a, b, c, d = symbols('a, b, c, d', integer=True) + our_hint = "nth_linear_euler_eq_nonhomogeneous_variation_of_parameters" + + eq = Eq(x**2*diff(f(x),x,2) - 8*x*diff(f(x),x) + 12*f(x), x**2) + assert our_hint in classify_ode(eq, f(x)) + + eq = Eq(a*x**3*diff(f(x),x,3) + b*x**2*diff(f(x),x,2) + c*x*diff(f(x),x) + d*f(x), x*log(x)) + assert our_hint in classify_ode(eq, f(x)) + + _ode_solver_test(_get_examples_ode_sol_euler_var_para) + + +@_add_example_keys +def _get_examples_ode_sol_euler_homogeneous(): + r1, r2, r3, r4, r5 = [rootof(x**5 - 14*x**4 + 71*x**3 - 154*x**2 + 120*x - 1, n) for n in range(5)] + return { + 'hint': "nth_linear_euler_eq_homogeneous", + 'func': f(x), + 'examples':{ + 'euler_hom_01': { + 'eq': Eq(-3*diff(f(x), x)*x + 2*x**2*diff(f(x), x, x), 0), + 'sol': [Eq(f(x), C1 + C2*x**Rational(5, 2))], + }, + + 'euler_hom_02': { + 'eq': Eq(3*f(x) - 5*diff(f(x), x)*x + 2*x**2*diff(f(x), x, x), 0), + 'sol': [Eq(f(x), C1*sqrt(x) + C2*x**3)] + }, + + 'euler_hom_03': { + 'eq': Eq(4*f(x) + 5*diff(f(x), x)*x + x**2*diff(f(x), x, x), 0), + 'sol': [Eq(f(x), (C1 + C2*log(x))/x**2)] + }, + + 'euler_hom_04': { + 'eq': Eq(6*f(x) - 6*diff(f(x), x)*x + 1*x**2*diff(f(x), x, x) + x**3*diff(f(x), x, x, x), 0), + 'sol': [Eq(f(x), C1/x**2 + C2*x + C3*x**3)] + }, + + 'euler_hom_05': { + 'eq': Eq(-125*f(x) + 61*diff(f(x), x)*x - 12*x**2*diff(f(x), x, x) + x**3*diff(f(x), x, x, x), 0), + 'sol': [Eq(f(x), x**5*(C1 + C2*log(x) + C3*log(x)**2))] + }, + + 'euler_hom_06': { + 'eq': x**2*diff(f(x), x, 2) + x*diff(f(x), x) - 9*f(x), + 'sol': [Eq(f(x), C1*x**-3 + C2*x**3)] + }, + + 'euler_hom_07': { + 'eq': sin(x)*x**2*f(x).diff(x, 2) + sin(x)*x*f(x).diff(x) + sin(x)*f(x), + 'sol': [Eq(f(x), C1*sin(log(x)) + C2*cos(log(x)))], + 'XFAIL': ['2nd_power_series_regular','nth_linear_euler_eq_nonhomogeneous_undetermined_coefficients'] + }, + + 'euler_hom_08': { + 'eq': x**6 * f(x).diff(x, 6) - x*f(x).diff(x) + f(x), + 'sol': [Eq(f(x), C1*x + C2*x**r1 + C3*x**r2 + C4*x**r3 + C5*x**r4 + C6*x**r5)], + 'checkodesol_XFAIL':True + }, + + #This example is from issue: https://github.com/sympy/sympy/issues/15237 #This example is from issue: + # https://github.com/sympy/sympy/issues/15237 + 'euler_hom_09': { + 'eq': Derivative(x*f(x), x, x, x), + 'sol': [Eq(f(x), C1 + C2/x + C3*x)], + }, + } + } + + +@_add_example_keys +def _get_examples_ode_sol_euler_undetermined_coeff(): + return { + 'hint': "nth_linear_euler_eq_nonhomogeneous_undetermined_coefficients", + 'func': f(x), + 'examples':{ + 'euler_undet_01': { + 'eq': Eq(x**2*diff(f(x), x, x) + x*diff(f(x), x), 1), + 'sol': [Eq(f(x), C1 + C2*log(x) + log(x)**2/2)] + }, + + 'euler_undet_02': { + 'eq': Eq(x**2*diff(f(x), x, x) - 2*x*diff(f(x), x) + 2*f(x), x**3), + 'sol': [Eq(f(x), x*(C1 + C2*x + Rational(1, 2)*x**2))] + }, + + 'euler_undet_03': { + 'eq': Eq(x**2*diff(f(x), x, x) - x*diff(f(x), x) - 3*f(x), log(x)/x), + 'sol': [Eq(f(x), (C1 + C2*x**4 - log(x)**2/8 - log(x)/16)/x)] + }, + + 'euler_undet_04': { + 'eq': Eq(x**2*diff(f(x), x, x) + 3*x*diff(f(x), x) - 8*f(x), log(x)**3 - log(x)), + 'sol': [Eq(f(x), C1/x**4 + C2*x**2 - Rational(1,8)*log(x)**3 - Rational(3,32)*log(x)**2 - Rational(1,64)*log(x) - Rational(7, 256))] + }, + + 'euler_undet_05': { + 'eq': Eq(x**3*diff(f(x), x, x, x) - 3*x**2*diff(f(x), x, x) + 6*x*diff(f(x), x) - 6*f(x), log(x)), + 'sol': [Eq(f(x), C1*x + C2*x**2 + C3*x**3 - Rational(1, 6)*log(x) - Rational(11, 36))] + }, + + #Below examples were added for the issue: https://github.com/sympy/sympy/issues/5096 + 'euler_undet_06': { + 'eq': 2*x**2*f(x).diff(x, 2) + f(x) + sqrt(2*x)*sin(log(2*x)/2), + 'sol': [Eq(f(x), sqrt(x)*(C1*sin(log(x)/2) + C2*cos(log(x)/2) + sqrt(2)*log(x)*cos(log(2*x)/2)/2))] + }, + + 'euler_undet_07': { + 'eq': 2*x**2*f(x).diff(x, 2) + f(x) + sin(log(2*x)/2), + 'sol': [Eq(f(x), C1*sqrt(x)*sin(log(x)/2) + C2*sqrt(x)*cos(log(x)/2) - 2*sin(log(2*x)/2)/5 - 4*cos(log(2*x)/2)/5)] + }, + } + } + + +@_add_example_keys +def _get_examples_ode_sol_euler_var_para(): + return { + 'hint': "nth_linear_euler_eq_nonhomogeneous_variation_of_parameters", + 'func': f(x), + 'examples':{ + 'euler_var_01': { + 'eq': Eq(x**2*Derivative(f(x), x, x) - 2*x*Derivative(f(x), x) + 2*f(x), x**4), + 'sol': [Eq(f(x), x*(C1 + C2*x + x**3/6))] + }, + + 'euler_var_02': { + 'eq': Eq(3*x**2*diff(f(x), x, x) + 6*x*diff(f(x), x) - 6*f(x), x**3*exp(x)), + 'sol': [Eq(f(x), C1/x**2 + C2*x + x*exp(x)/3 - 4*exp(x)/3 + 8*exp(x)/(3*x) - 8*exp(x)/(3*x**2))] + }, + + 'euler_var_03': { + 'eq': Eq(x**2*Derivative(f(x), x, x) - 2*x*Derivative(f(x), x) + 2*f(x), x**4*exp(x)), + 'sol': [Eq(f(x), x*(C1 + C2*x + x*exp(x) - 2*exp(x)))] + }, + + 'euler_var_04': { + 'eq': x**2*Derivative(f(x), x, x) - 2*x*Derivative(f(x), x) + 2*f(x) - log(x), + 'sol': [Eq(f(x), C1*x + C2*x**2 + log(x)/2 + Rational(3, 4))] + }, + + 'euler_var_05': { + 'eq': -exp(x) + (x*Derivative(f(x), (x, 2)) + Derivative(f(x), x))/x, + 'sol': [Eq(f(x), C1 + C2*log(x) + exp(x) - Ei(x))] + }, + + 'euler_var_06': { + 'eq': x**2 * f(x).diff(x, 2) + x * f(x).diff(x) + 4 * f(x) - 1/x, + 'sol': [Eq(f(x), C1*sin(2*log(x)) + C2*cos(2*log(x)) + 1/(5*x))] + }, + } + } + + +@_add_example_keys +def _get_examples_ode_sol_bernoulli(): + # Type: Bernoulli, f'(x) + p(x)*f(x) == q(x)*f(x)**n + return { + 'hint': "Bernoulli", + 'func': f(x), + 'examples':{ + 'bernoulli_01': { + 'eq': Eq(x*f(x).diff(x) + f(x) - f(x)**2, 0), + 'sol': [Eq(f(x), 1/(C1*x + 1))], + 'XFAIL': ['separable_reduced'] + }, + + 'bernoulli_02': { + 'eq': f(x).diff(x) - y*f(x), + 'sol': [Eq(f(x), C1*exp(x*y))] + }, + + 'bernoulli_03': { + 'eq': f(x)*f(x).diff(x) - 1, + 'sol': [Eq(f(x), -sqrt(C1 + 2*x)), Eq(f(x), sqrt(C1 + 2*x))] + }, + } + } + + +@_add_example_keys +def _get_examples_ode_sol_riccati(): + # Type: Riccati special alpha = -2, a*dy/dx + b*y**2 + c*y/x +d/x**2 + return { + 'hint': "Riccati_special_minus2", + 'func': f(x), + 'examples':{ + 'riccati_01': { + 'eq': 2*f(x).diff(x) + f(x)**2 - f(x)/x + 3*x**(-2), + 'sol': [Eq(f(x), (-sqrt(3)*tan(C1 + sqrt(3)*log(x)/4) + 3)/(2*x))], + }, + }, + } + + +@_add_example_keys +def _get_examples_ode_sol_1st_rational_riccati(): + # Type: 1st Order Rational Riccati, dy/dx = a + b*y + c*y**2, + # a, b, c are rational functions of x + return { + 'hint': "1st_rational_riccati", + 'func': f(x), + 'examples':{ + # a(x) is a constant + "rational_riccati_01": { + "eq": Eq(f(x).diff(x) + f(x)**2 - 2, 0), + "sol": [Eq(f(x), sqrt(2)*(-C1 - exp(2*sqrt(2)*x))/(C1 - exp(2*sqrt(2)*x)))] + }, + # a(x) is a constant + "rational_riccati_02": { + "eq": f(x)**2 + Derivative(f(x), x) + 4*f(x)/x + 2/x**2, + "sol": [Eq(f(x), (-2*C1 - x)/(x*(C1 + x)))] + }, + # a(x) is a constant + "rational_riccati_03": { + "eq": 2*x**2*Derivative(f(x), x) - x*(4*f(x) + Derivative(f(x), x) - 4) + (f(x) - 1)*f(x), + "sol": [Eq(f(x), (C1 + 2*x**2)/(C1 + x))] + }, + # Constant coefficients + "rational_riccati_04": { + "eq": f(x).diff(x) - 6 - 5*f(x) - f(x)**2, + "sol": [Eq(f(x), (-2*C1 + 3*exp(x))/(C1 - exp(x)))] + }, + # One pole of multiplicity 2 + "rational_riccati_05": { + "eq": x**2 - (2*x + 1/x)*f(x) + f(x)**2 + Derivative(f(x), x), + "sol": [Eq(f(x), x*(C1 + x**2 + 1)/(C1 + x**2 - 1))] + }, + # One pole of multiplicity 2 + "rational_riccati_06": { + "eq": x**4*Derivative(f(x), x) + x**2 - x*(2*f(x)**2 + Derivative(f(x), x)) + f(x), + "sol": [Eq(f(x), x*(C1*x - x + 1)/(C1 + x**2 - 1))] + }, + # Multiple poles of multiplicity 2 + "rational_riccati_07": { + "eq": -f(x)**2 + Derivative(f(x), x) + (15*x**2 - 20*x + 7)/((x - 1)**2*(2*x \ + - 1)**2), + "sol": [Eq(f(x), (9*C1*x - 6*C1 - 15*x**5 + 60*x**4 - 94*x**3 + 72*x**2 - \ + 33*x + 8)/(6*C1*x**2 - 9*C1*x + 3*C1 + 6*x**6 - 29*x**5 + 57*x**4 - \ + 58*x**3 + 28*x**2 - 3*x - 1))] + }, + # Imaginary poles + "rational_riccati_08": { + "eq": Derivative(f(x), x) + (3*x**2 + 1)*f(x)**2/x + (6*x**2 - x + 3)*f(x)/(x*(x \ + - 1)) + (3*x**2 - 2*x + 2)/(x*(x - 1)**2), + "sol": [Eq(f(x), (-C1 - x**3 + x**2 - 2*x + 1)/(C1*x - C1 + x**4 - x**3 + x**2 - \ + 2*x + 1))], + }, + # Imaginary coefficients in equation + "rational_riccati_09": { + "eq": Derivative(f(x), x) - 2*I*(f(x)**2 + 1)/x, + "sol": [Eq(f(x), (-I*C1 + I*x**4 + I)/(C1 + x**4 - 1))] + }, + # Regression: linsolve returning empty solution + # Large value of m (> 10) + "rational_riccati_10": { + "eq": Eq(Derivative(f(x), x), x*f(x)/(S(3)/2 - 2*x) + (x/2 - S(1)/3)*f(x)**2/\ + (2*x/3 - S(1)/2) - S(5)/4 + (281*x**2 - 1260*x + 756)/(16*x**3 - 12*x**2)), + "sol": [Eq(f(x), (40*C1*x**14 + 28*C1*x**13 + 420*C1*x**12 + 2940*C1*x**11 + \ + 18480*C1*x**10 + 103950*C1*x**9 + 519750*C1*x**8 + 2286900*C1*x**7 + \ + 8731800*C1*x**6 + 28378350*C1*x**5 + 76403250*C1*x**4 + 163721250*C1*x**3 \ + + 261954000*C1*x**2 + 278326125*C1*x + 147349125*C1 + x*exp(2*x) - 9*exp(2*x) \ + )/(x*(24*C1*x**13 + 140*C1*x**12 + 840*C1*x**11 + 4620*C1*x**10 + 23100*C1*x**9 \ + + 103950*C1*x**8 + 415800*C1*x**7 + 1455300*C1*x**6 + 4365900*C1*x**5 + \ + 10914750*C1*x**4 + 21829500*C1*x**3 + 32744250*C1*x**2 + 32744250*C1*x + \ + 16372125*C1 - exp(2*x))))] + } + } + } + + + +@_add_example_keys +def _get_examples_ode_sol_1st_linear(): + # Type: first order linear form f'(x)+p(x)f(x)=q(x) + return { + 'hint': "1st_linear", + 'func': f(x), + 'examples':{ + 'linear_01': { + 'eq': Eq(f(x).diff(x) + x*f(x), x**2), + 'sol': [Eq(f(x), (C1 + x*exp(x**2/2)- sqrt(2)*sqrt(pi)*erfi(sqrt(2)*x/2)/2)*exp(-x**2/2))], + }, + }, + } + + +@_add_example_keys +def _get_examples_ode_sol_factorable(): + """ some hints are marked as xfail for examples because they missed additional algebraic solution + which could be found by Factorable hint. Fact_01 raise exception for + nth_linear_constant_coeff_undetermined_coefficients""" + + y = Dummy('y') + a0,a1,a2,a3,a4 = symbols('a0, a1, a2, a3, a4') + return { + 'hint': "factorable", + 'func': f(x), + 'examples':{ + 'fact_01': { + 'eq': f(x) + f(x)*f(x).diff(x), + 'sol': [Eq(f(x), 0), Eq(f(x), C1 - x)], + 'XFAIL': ['separable', '1st_exact', '1st_linear', 'Bernoulli', '1st_homogeneous_coeff_best', + '1st_homogeneous_coeff_subs_indep_div_dep', '1st_homogeneous_coeff_subs_dep_div_indep', + 'lie_group', 'nth_linear_euler_eq_nonhomogeneous_undetermined_coefficients', + 'nth_linear_constant_coeff_variation_of_parameters', + 'nth_linear_euler_eq_nonhomogeneous_variation_of_parameters', + 'nth_linear_constant_coeff_undetermined_coefficients'] + }, + + 'fact_02': { + 'eq': f(x)*(f(x).diff(x)+f(x)*x+2), + 'sol': [Eq(f(x), (C1 - sqrt(2)*sqrt(pi)*erfi(sqrt(2)*x/2))*exp(-x**2/2)), Eq(f(x), 0)], + 'XFAIL': ['Bernoulli', '1st_linear', 'lie_group'] + }, + + 'fact_03': { + 'eq': (f(x).diff(x)+f(x)*x**2)*(f(x).diff(x, 2) + x*f(x)), + 'sol': [Eq(f(x), C1*airyai(-x) + C2*airybi(-x)),Eq(f(x), C1*exp(-x**3/3))] + }, + + 'fact_04': { + 'eq': (f(x).diff(x)+f(x)*x**2)*(f(x).diff(x, 2) + f(x)), + 'sol': [Eq(f(x), C1*exp(-x**3/3)), Eq(f(x), C1*sin(x) + C2*cos(x))] + }, + + 'fact_05': { + 'eq': (f(x).diff(x)**2-1)*(f(x).diff(x)**2-4), + 'sol': [Eq(f(x), C1 - x), Eq(f(x), C1 + x), Eq(f(x), C1 + 2*x), Eq(f(x), C1 - 2*x)] + }, + + 'fact_06': { + 'eq': (f(x).diff(x, 2)-exp(f(x)))*f(x).diff(x), + 'sol': [ + Eq(f(x), log(-C1/(cos(sqrt(-C1)*(C2 + x)) + 1))), + Eq(f(x), log(-C1/(cos(sqrt(-C1)*(C2 - x)) + 1))), + Eq(f(x), C1) + ], + 'slow': True, + }, + + 'fact_07': { + 'eq': (f(x).diff(x)**2-1)*(f(x)*f(x).diff(x)-1), + 'sol': [Eq(f(x), C1 - x), Eq(f(x), -sqrt(C1 + 2*x)),Eq(f(x), sqrt(C1 + 2*x)), Eq(f(x), C1 + x)] + }, + + 'fact_08': { + 'eq': Derivative(f(x), x)**4 - 2*Derivative(f(x), x)**2 + 1, + 'sol': [Eq(f(x), C1 - x), Eq(f(x), C1 + x)] + }, + + 'fact_09': { + 'eq': f(x)**2*Derivative(f(x), x)**6 - 2*f(x)**2*Derivative(f(x), + x)**4 + f(x)**2*Derivative(f(x), x)**2 - 2*f(x)*Derivative(f(x), + x)**5 + 4*f(x)*Derivative(f(x), x)**3 - 2*f(x)*Derivative(f(x), + x) + Derivative(f(x), x)**4 - 2*Derivative(f(x), x)**2 + 1, + 'sol': [ + Eq(f(x), C1 - x), Eq(f(x), -sqrt(C1 + 2*x)), + Eq(f(x), sqrt(C1 + 2*x)), Eq(f(x), C1 + x) + ] + }, + + 'fact_10': { + 'eq': x**4*f(x)**2 + 2*x**4*f(x)*Derivative(f(x), (x, 2)) + x**4*Derivative(f(x), + (x, 2))**2 + 2*x**3*f(x)*Derivative(f(x), x) + 2*x**3*Derivative(f(x), + x)*Derivative(f(x), (x, 2)) - 7*x**2*f(x)**2 - 7*x**2*f(x)*Derivative(f(x), + (x, 2)) + x**2*Derivative(f(x), x)**2 - 7*x*f(x)*Derivative(f(x), x) + 12*f(x)**2, + 'sol': [ + Eq(f(x), C1*besselj(2, x) + C2*bessely(2, x)), + Eq(f(x), C1*besselj(sqrt(3), x) + C2*bessely(sqrt(3), x)) + ], + 'slow': True, + }, + + 'fact_11': { + 'eq': (f(x).diff(x, 2)-exp(f(x)))*(f(x).diff(x, 2)+exp(f(x))), + 'sol': [ + Eq(f(x), log(C1/(cos(C1*sqrt(-1/C1)*(C2 + x)) - 1))), + Eq(f(x), log(C1/(cos(C1*sqrt(-1/C1)*(C2 - x)) - 1))), + Eq(f(x), log(C1/(1 - cos(C1*sqrt(-1/C1)*(C2 + x))))), + Eq(f(x), log(C1/(1 - cos(C1*sqrt(-1/C1)*(C2 - x))))) + ], + 'dsolve_too_slow': True, + }, + + #Below examples were added for the issue: https://github.com/sympy/sympy/issues/15889 + 'fact_12': { + 'eq': exp(f(x).diff(x))-f(x)**2, + 'sol': [Eq(NonElementaryIntegral(1/log(y**2), (y, f(x))), C1 + x)], + 'XFAIL': ['lie_group'] #It shows not implemented error for lie_group. + }, + + 'fact_13': { + 'eq': f(x).diff(x)**2 - f(x)**3, + 'sol': [Eq(f(x), 4/(C1**2 - 2*C1*x + x**2))], + 'XFAIL': ['lie_group'] #It shows not implemented error for lie_group. + }, + + 'fact_14': { + 'eq': f(x).diff(x)**2 - f(x), + 'sol': [Eq(f(x), C1**2/4 - C1*x/2 + x**2/4)] + }, + + 'fact_15': { + 'eq': f(x).diff(x)**2 - f(x)**2, + 'sol': [Eq(f(x), C1*exp(x)), Eq(f(x), C1*exp(-x))] + }, + + 'fact_16': { + 'eq': f(x).diff(x)**2 - f(x)**3, + 'sol': [Eq(f(x), 4/(C1**2 - 2*C1*x + x**2))], + }, + + # kamke ode 1.1 + 'fact_17': { + 'eq': f(x).diff(x)-(a4*x**4 + a3*x**3 + a2*x**2 + a1*x + a0)**(-1/2), + 'sol': [Eq(f(x), C1 + Integral(1/sqrt(a0 + a1*x + a2*x**2 + a3*x**3 + a4*x**4), x))], + 'slow': True + }, + + # This is from issue: https://github.com/sympy/sympy/issues/9446 + 'fact_18':{ + 'eq': Eq(f(2 * x), sin(Derivative(f(x)))), + 'sol': [Eq(f(x), C1 + Integral(pi - asin(f(2*x)), x)), Eq(f(x), C1 + Integral(asin(f(2*x)), x))], + 'checkodesol_XFAIL':True + }, + + # This is from issue: https://github.com/sympy/sympy/issues/7093 + 'fact_19': { + 'eq': Derivative(f(x), x)**2 - x**3, + 'sol': [Eq(f(x), C1 - 2*x**Rational(5,2)/5), Eq(f(x), C1 + 2*x**Rational(5,2)/5)], + }, + + 'fact_20': { + 'eq': x*f(x).diff(x, 2) - x*f(x), + 'sol': [Eq(f(x), C1*exp(-x) + C2*exp(x))], + }, + } + } + + + +@_add_example_keys +def _get_examples_ode_sol_almost_linear(): + from sympy.functions.special.error_functions import Ei + A = Symbol('A', positive=True) + f = Function('f') + d = f(x).diff(x) + + return { + 'hint': "almost_linear", + 'func': f(x), + 'examples':{ + 'almost_lin_01': { + 'eq': x**2*f(x)**2*d + f(x)**3 + 1, + 'sol': [Eq(f(x), (C1*exp(3/x) - 1)**Rational(1, 3)), + Eq(f(x), (-1 - sqrt(3)*I)*(C1*exp(3/x) - 1)**Rational(1, 3)/2), + Eq(f(x), (-1 + sqrt(3)*I)*(C1*exp(3/x) - 1)**Rational(1, 3)/2)], + + }, + + 'almost_lin_02': { + 'eq': x*f(x)*d + 2*x*f(x)**2 + 1, + 'sol': [Eq(f(x), -sqrt((C1 - 2*Ei(4*x))*exp(-4*x))), Eq(f(x), sqrt((C1 - 2*Ei(4*x))*exp(-4*x)))] + }, + + 'almost_lin_03': { + 'eq': x*d + x*f(x) + 1, + 'sol': [Eq(f(x), (C1 - Ei(x))*exp(-x))] + }, + + 'almost_lin_04': { + 'eq': x*exp(f(x))*d + exp(f(x)) + 3*x, + 'sol': [Eq(f(x), log(C1/x - x*Rational(3, 2)))], + }, + + 'almost_lin_05': { + 'eq': x + A*(x + diff(f(x), x) + f(x)) + diff(f(x), x) + f(x) + 2, + 'sol': [Eq(f(x), (C1 + Piecewise( + (x, Eq(A + 1, 0)), ((-A*x + A - x - 1)*exp(x)/(A + 1), True)))*exp(-x))], + }, + } + } + + +@_add_example_keys +def _get_examples_ode_sol_liouville(): + n = Symbol('n') + _y = Dummy('y') + return { + 'hint': "Liouville", + 'func': f(x), + 'examples':{ + 'liouville_01': { + 'eq': diff(f(x), x)/x + diff(f(x), x, x)/2 - diff(f(x), x)**2/2, + 'sol': [Eq(f(x), log(x/(C1 + C2*x)))], + + }, + + 'liouville_02': { + 'eq': diff(x*exp(-f(x)), x, x), + 'sol': [Eq(f(x), log(x/(C1 + C2*x)))] + }, + + 'liouville_03': { + 'eq': ((diff(f(x), x)/x + diff(f(x), x, x)/2 - diff(f(x), x)**2/2)*exp(-f(x))/exp(f(x))).expand(), + 'sol': [Eq(f(x), log(x/(C1 + C2*x)))] + }, + + 'liouville_04': { + 'eq': diff(f(x), x, x) + 1/f(x)*(diff(f(x), x))**2 + 1/x*diff(f(x), x), + 'sol': [Eq(f(x), -sqrt(C1 + C2*log(x))), Eq(f(x), sqrt(C1 + C2*log(x)))], + }, + + 'liouville_05': { + 'eq': x*diff(f(x), x, x) + x/f(x)*diff(f(x), x)**2 + x*diff(f(x), x), + 'sol': [Eq(f(x), -sqrt(C1 + C2*exp(-x))), Eq(f(x), sqrt(C1 + C2*exp(-x)))], + }, + + 'liouville_06': { + 'eq': Eq((x*exp(f(x))).diff(x, x), 0), + 'sol': [Eq(f(x), log(C1 + C2/x))], + }, + + 'liouville_07': { + 'eq': (diff(f(x), x)/x + diff(f(x), x, x)/2 - diff(f(x), x)**2/2)*exp(-f(x))/exp(f(x)), + 'sol': [Eq(f(x), log(x/(C1 + C2*x)))], + }, + + 'liouville_08': { + 'eq': x**2*diff(f(x),x) + (n*f(x) + f(x)**2)*diff(f(x),x)**2 + diff(f(x), (x, 2)), + 'sol': [Eq(C1 + C2*lowergamma(Rational(1,3), x**3/3) + NonElementaryIntegral(exp(_y**3/3)*exp(_y**2*n/2), (_y, f(x))), 0)], + }, + } + } + + +@_add_example_keys +def _get_examples_ode_sol_nth_algebraic(): + M, m, r, t = symbols('M m r t') + phi = Function('phi') + k = Symbol('k') + # This one needs a substitution f' = g. + # 'algeb_12': { + # 'eq': -exp(x) + (x*Derivative(f(x), (x, 2)) + Derivative(f(x), x))/x, + # 'sol': [Eq(f(x), C1 + C2*log(x) + exp(x) - Ei(x))], + # }, + return { + 'hint': "nth_algebraic", + 'func': f(x), + 'examples':{ + 'algeb_01': { + 'eq': f(x) * f(x).diff(x) * f(x).diff(x, x) * (f(x) - 1) * (f(x).diff(x) - x), + 'sol': [Eq(f(x), C1 + x**2/2), Eq(f(x), C1 + C2*x)] + }, + + 'algeb_02': { + 'eq': f(x) * f(x).diff(x) * f(x).diff(x, x) * (f(x) - 1), + 'sol': [Eq(f(x), C1 + C2*x)] + }, + + 'algeb_03': { + 'eq': f(x) * f(x).diff(x) * f(x).diff(x, x), + 'sol': [Eq(f(x), C1 + C2*x)] + }, + + 'algeb_04': { + 'eq': Eq(-M * phi(t).diff(t), + Rational(3, 2) * m * r**2 * phi(t).diff(t) * phi(t).diff(t,t)), + 'sol': [Eq(phi(t), C1), Eq(phi(t), C1 + C2*t - M*t**2/(3*m*r**2))], + 'func': phi(t) + }, + + 'algeb_05': { + 'eq': (1 - sin(f(x))) * f(x).diff(x), + 'sol': [Eq(f(x), C1)], + 'XFAIL': ['separable'] #It raised exception. + }, + + 'algeb_06': { + 'eq': (diff(f(x)) - x)*(diff(f(x)) + x), + 'sol': [Eq(f(x), C1 - x**2/2), Eq(f(x), C1 + x**2/2)] + }, + + 'algeb_07': { + 'eq': Eq(Derivative(f(x), x), Derivative(g(x), x)), + 'sol': [Eq(f(x), C1 + g(x))], + }, + + 'algeb_08': { + 'eq': f(x).diff(x) - C1, #this example is from issue 15999 + 'sol': [Eq(f(x), C1*x + C2)], + }, + + 'algeb_09': { + 'eq': f(x)*f(x).diff(x), + 'sol': [Eq(f(x), C1)], + }, + + 'algeb_10': { + 'eq': (diff(f(x)) - x)*(diff(f(x)) + x), + 'sol': [Eq(f(x), C1 - x**2/2), Eq(f(x), C1 + x**2/2)], + }, + + 'algeb_11': { + 'eq': f(x) + f(x)*f(x).diff(x), + 'sol': [Eq(f(x), 0), Eq(f(x), C1 - x)], + 'XFAIL': ['separable', '1st_exact', '1st_linear', 'Bernoulli', '1st_homogeneous_coeff_best', + '1st_homogeneous_coeff_subs_indep_div_dep', '1st_homogeneous_coeff_subs_dep_div_indep', + 'lie_group', 'nth_linear_constant_coeff_undetermined_coefficients', + 'nth_linear_euler_eq_nonhomogeneous_undetermined_coefficients', + 'nth_linear_constant_coeff_variation_of_parameters', + 'nth_linear_euler_eq_nonhomogeneous_variation_of_parameters'] + #nth_linear_constant_coeff_undetermined_coefficients raises exception rest all of them misses a solution. + }, + + 'algeb_12': { + 'eq': Derivative(x*f(x), x, x, x), + 'sol': [Eq(f(x), (C1 + C2*x + C3*x**2) / x)], + 'XFAIL': ['nth_algebraic'] # It passes only when prep=False is set in dsolve. + }, + + 'algeb_13': { + 'eq': Eq(Derivative(x*Derivative(f(x), x), x)/x, exp(x)), + 'sol': [Eq(f(x), C1 + C2*log(x) + exp(x) - Ei(x))], + 'XFAIL': ['nth_algebraic'] # It passes only when prep=False is set in dsolve. + }, + + # These are simple tests from the old ode module example 14-18 + 'algeb_14': { + 'eq': Eq(f(x).diff(x), 0), + 'sol': [Eq(f(x), C1)], + }, + + 'algeb_15': { + 'eq': Eq(3*f(x).diff(x) - 5, 0), + 'sol': [Eq(f(x), C1 + x*Rational(5, 3))], + }, + + 'algeb_16': { + 'eq': Eq(3*f(x).diff(x), 5), + 'sol': [Eq(f(x), C1 + x*Rational(5, 3))], + }, + + # Type: 2nd order, constant coefficients (two complex roots) + 'algeb_17': { + 'eq': Eq(3*f(x).diff(x) - 1, 0), + 'sol': [Eq(f(x), C1 + x/3)], + }, + + 'algeb_18': { + 'eq': Eq(x*f(x).diff(x) - 1, 0), + 'sol': [Eq(f(x), C1 + log(x))], + }, + + # https://github.com/sympy/sympy/issues/6989 + 'algeb_19': { + 'eq': f(x).diff(x) - x*exp(-k*x), + 'sol': [Eq(f(x), C1 + Piecewise(((-k*x - 1)*exp(-k*x)/k**2, Ne(k**2, 0)),(x**2/2, True)))], + }, + + 'algeb_20': { + 'eq': -f(x).diff(x) + x*exp(-k*x), + 'sol': [Eq(f(x), C1 + Piecewise(((-k*x - 1)*exp(-k*x)/k**2, Ne(k**2, 0)),(x**2/2, True)))], + }, + + # https://github.com/sympy/sympy/issues/10867 + 'algeb_21': { + 'eq': Eq(g(x).diff(x).diff(x), (x-2)**2 + (x-3)**3), + 'sol': [Eq(g(x), C1 + C2*x + x**5/20 - 2*x**4/3 + 23*x**3/6 - 23*x**2/2)], + 'func': g(x), + }, + + # https://github.com/sympy/sympy/issues/13691 + 'algeb_22': { + 'eq': f(x).diff(x) - C1*g(x).diff(x), + 'sol': [Eq(f(x), C2 + C1*g(x))], + 'func': f(x), + }, + + # https://github.com/sympy/sympy/issues/4838 + 'algeb_23': { + 'eq': f(x).diff(x) - 3*C1 - 3*x**2, + 'sol': [Eq(f(x), C2 + 3*C1*x + x**3)], + }, + } + } + + +@_add_example_keys +def _get_examples_ode_sol_nth_order_reducible(): + return { + 'hint': "nth_order_reducible", + 'func': f(x), + 'examples':{ + 'reducible_01': { + 'eq': Eq(x*Derivative(f(x), x)**2 + Derivative(f(x), x, 2), 0), + 'sol': [Eq(f(x),C1 - sqrt(-1/C2)*log(-C2*sqrt(-1/C2) + x) + + sqrt(-1/C2)*log(C2*sqrt(-1/C2) + x))], + 'slow': True, + }, + + 'reducible_02': { + 'eq': -exp(x) + (x*Derivative(f(x), (x, 2)) + Derivative(f(x), x))/x, + 'sol': [Eq(f(x), C1 + C2*log(x) + exp(x) - Ei(x))], + 'slow': True, + }, + + 'reducible_03': { + 'eq': Eq(sqrt(2) * f(x).diff(x,x,x) + f(x).diff(x), 0), + 'sol': [Eq(f(x), C1 + C2*sin(2**Rational(3, 4)*x/2) + C3*cos(2**Rational(3, 4)*x/2))], + 'slow': True, + }, + + 'reducible_04': { + 'eq': f(x).diff(x, 2) + 2*f(x).diff(x), + 'sol': [Eq(f(x), C1 + C2*exp(-2*x))], + }, + + 'reducible_05': { + 'eq': f(x).diff(x, 3) + f(x).diff(x, 2) - 6*f(x).diff(x), + 'sol': [Eq(f(x), C1 + C2*exp(-3*x) + C3*exp(2*x))], + 'slow': True, + }, + + 'reducible_06': { + 'eq': f(x).diff(x, 4) - f(x).diff(x, 3) - 4*f(x).diff(x, 2) + \ + 4*f(x).diff(x), + 'sol': [Eq(f(x), C1 + C2*exp(-2*x) + C3*exp(x) + C4*exp(2*x))], + 'slow': True, + }, + + 'reducible_07': { + 'eq': f(x).diff(x, 4) + 3*f(x).diff(x, 3), + 'sol': [Eq(f(x), C1 + C2*x + C3*x**2 + C4*exp(-3*x))], + 'slow': True, + }, + + 'reducible_08': { + 'eq': f(x).diff(x, 4) - 2*f(x).diff(x, 2), + 'sol': [Eq(f(x), C1 + C2*x + C3*exp(-sqrt(2)*x) + C4*exp(sqrt(2)*x))], + 'slow': True, + }, + + 'reducible_09': { + 'eq': f(x).diff(x, 4) + 4*f(x).diff(x, 2), + 'sol': [Eq(f(x), C1 + C2*x + C3*sin(2*x) + C4*cos(2*x))], + 'slow': True, + }, + + 'reducible_10': { + 'eq': f(x).diff(x, 5) + 2*f(x).diff(x, 3) + f(x).diff(x), + 'sol': [Eq(f(x), C1 + C2*x*sin(x) + C2*cos(x) - C3*x*cos(x) + C3*sin(x) + C4*sin(x) + C5*cos(x))], + 'slow': True, + }, + + 'reducible_11': { + 'eq': f(x).diff(x, 2) - f(x).diff(x)**3, + 'sol': [Eq(f(x), C1 - sqrt(2)*sqrt(-1/(C2 + x))*(C2 + x)), + Eq(f(x), C1 + sqrt(2)*sqrt(-1/(C2 + x))*(C2 + x))], + 'slow': True, + }, + + # Needs to be a way to know how to combine derivatives in the expression + 'reducible_12': { + 'eq': Derivative(x*f(x), x, x, x) + Derivative(f(x), x, x, x), + 'sol': [Eq(f(x), C1 + C3/Mul(2, (x**2 + 2*x + 1), evaluate=False) + + x*(C2 + C3/Mul(2, (x**2 + 2*x + 1), evaluate=False)))], # 2-arg Mul! + 'slow': True, + }, + } + } + + + +@_add_example_keys +def _get_examples_ode_sol_nth_linear_undetermined_coefficients(): + # examples 3-27 below are from Ordinary Differential Equations, + # Tenenbaum and Pollard, pg. 231 + g = exp(-x) + f2 = f(x).diff(x, 2) + c = 3*f(x).diff(x, 3) + 5*f2 + f(x).diff(x) - f(x) - x + t = symbols("t") + u = symbols("u",cls=Function) + R, L, C, E_0, alpha = symbols("R L C E_0 alpha",positive=True) + omega = Symbol('omega') + return { + 'hint': "nth_linear_constant_coeff_undetermined_coefficients", + 'func': f(x), + 'examples':{ + 'undet_01': { + 'eq': c - x*g, + 'sol': [Eq(f(x), C3*exp(x/3) - x + (C1 + x*(C2 - x**2/24 - 3*x/32))*exp(-x) - 1)], + 'slow': True, + }, + + 'undet_02': { + 'eq': c - g, + 'sol': [Eq(f(x), C3*exp(x/3) - x + (C1 + x*(C2 - x/8))*exp(-x) - 1)], + 'slow': True, + }, + + 'undet_03': { + 'eq': f2 + 3*f(x).diff(x) + 2*f(x) - 4, + 'sol': [Eq(f(x), C1*exp(-2*x) + C2*exp(-x) + 2)], + 'slow': True, + }, + + 'undet_04': { + 'eq': f2 + 3*f(x).diff(x) + 2*f(x) - 12*exp(x), + 'sol': [Eq(f(x), C1*exp(-2*x) + C2*exp(-x) + 2*exp(x))], + 'slow': True, + }, + + 'undet_05': { + 'eq': f2 + 3*f(x).diff(x) + 2*f(x) - exp(I*x), + 'sol': [Eq(f(x), (S(3)/10 + I/10)*(C1*exp(-2*x) + C2*exp(-x) - I*exp(I*x)))], + 'slow': True, + }, + + 'undet_06': { + 'eq': f2 + 3*f(x).diff(x) + 2*f(x) - sin(x), + 'sol': [Eq(f(x), C1*exp(-2*x) + C2*exp(-x) + sin(x)/10 - 3*cos(x)/10)], + 'slow': True, + }, + + 'undet_07': { + 'eq': f2 + 3*f(x).diff(x) + 2*f(x) - cos(x), + 'sol': [Eq(f(x), C1*exp(-2*x) + C2*exp(-x) + 3*sin(x)/10 + cos(x)/10)], + 'slow': True, + }, + + 'undet_08': { + 'eq': f2 + 3*f(x).diff(x) + 2*f(x) - (8 + 6*exp(x) + 2*sin(x)), + 'sol': [Eq(f(x), C1*exp(-2*x) + C2*exp(-x) + exp(x) + sin(x)/5 - 3*cos(x)/5 + 4)], + 'slow': True, + }, + + 'undet_09': { + 'eq': f2 + f(x).diff(x) + f(x) - x**2, + 'sol': [Eq(f(x), -2*x + x**2 + (C1*sin(x*sqrt(3)/2) + C2*cos(x*sqrt(3)/2))*exp(-x/2))], + 'slow': True, + }, + + 'undet_10': { + 'eq': f2 - 2*f(x).diff(x) - 8*f(x) - 9*x*exp(x) - 10*exp(-x), + 'sol': [Eq(f(x), -x*exp(x) - 2*exp(-x) + C1*exp(-2*x) + C2*exp(4*x))], + 'slow': True, + }, + + 'undet_11': { + 'eq': f2 - 3*f(x).diff(x) - 2*exp(2*x)*sin(x), + 'sol': [Eq(f(x), C1 + C2*exp(3*x) - 3*exp(2*x)*sin(x)/5 - exp(2*x)*cos(x)/5)], + 'slow': True, + }, + + 'undet_12': { + 'eq': f(x).diff(x, 4) - 2*f2 + f(x) - x + sin(x), + 'sol': [Eq(f(x), x - sin(x)/4 + (C1 + C2*x)*exp(-x) + (C3 + C4*x)*exp(x))], + 'slow': True, + }, + + 'undet_13': { + 'eq': f2 + f(x).diff(x) - x**2 - 2*x, + 'sol': [Eq(f(x), C1 + x**3/3 + C2*exp(-x))], + 'slow': True, + }, + + 'undet_14': { + 'eq': f2 + f(x).diff(x) - x - sin(2*x), + 'sol': [Eq(f(x), C1 - x - sin(2*x)/5 - cos(2*x)/10 + x**2/2 + C2*exp(-x))], + 'slow': True, + }, + + 'undet_15': { + 'eq': f2 + f(x) - 4*x*sin(x), + 'sol': [Eq(f(x), (C1 - x**2)*cos(x) + (C2 + x)*sin(x))], + 'slow': True, + }, + + 'undet_16': { + 'eq': f2 + 4*f(x) - x*sin(2*x), + 'sol': [Eq(f(x), (C1 - x**2/8)*cos(2*x) + (C2 + x/16)*sin(2*x))], + 'slow': True, + }, + + 'undet_17': { + 'eq': f2 + 2*f(x).diff(x) + f(x) - x**2*exp(-x), + 'sol': [Eq(f(x), (C1 + x*(C2 + x**3/12))*exp(-x))], + 'slow': True, + }, + + 'undet_18': { + 'eq': f(x).diff(x, 3) + 3*f2 + 3*f(x).diff(x) + f(x) - 2*exp(-x) + \ + x**2*exp(-x), + 'sol': [Eq(f(x), (C1 + x*(C2 + x*(C3 - x**3/60 + x/3)))*exp(-x))], + 'slow': True, + }, + + 'undet_19': { + 'eq': f2 + 3*f(x).diff(x) + 2*f(x) - exp(-2*x) - x**2, + 'sol': [Eq(f(x), C2*exp(-x) + x**2/2 - x*Rational(3,2) + (C1 - x)*exp(-2*x) + Rational(7,4))], + 'slow': True, + }, + + 'undet_20': { + 'eq': f2 - 3*f(x).diff(x) + 2*f(x) - x*exp(-x), + 'sol': [Eq(f(x), C1*exp(x) + C2*exp(2*x) + (6*x + 5)*exp(-x)/36)], + 'slow': True, + }, + + 'undet_21': { + 'eq': f2 + f(x).diff(x) - 6*f(x) - x - exp(2*x), + 'sol': [Eq(f(x), Rational(-1, 36) - x/6 + C2*exp(-3*x) + (C1 + x/5)*exp(2*x))], + 'slow': True, + }, + + 'undet_22': { + 'eq': f2 + f(x) - sin(x) - exp(-x), + 'sol': [Eq(f(x), C2*sin(x) + (C1 - x/2)*cos(x) + exp(-x)/2)], + 'slow': True, + }, + + 'undet_23': { + 'eq': f(x).diff(x, 3) - 3*f2 + 3*f(x).diff(x) - f(x) - exp(x), + 'sol': [Eq(f(x), (C1 + x*(C2 + x*(C3 + x/6)))*exp(x))], + 'slow': True, + }, + + 'undet_24': { + 'eq': f2 + f(x) - S.Half - cos(2*x)/2, + 'sol': [Eq(f(x), S.Half - cos(2*x)/6 + C1*sin(x) + C2*cos(x))], + 'slow': True, + }, + + 'undet_25': { + 'eq': f(x).diff(x, 3) - f(x).diff(x) - exp(2*x)*(S.Half - cos(2*x)/2), + 'sol': [Eq(f(x), C1 + C2*exp(-x) + C3*exp(x) + (-21*sin(2*x) + 27*cos(2*x) + 130)*exp(2*x)/1560)], + 'slow': True, + }, + + #Note: 'undet_26' is referred in 'undet_37' + 'undet_26': { + 'eq': (f(x).diff(x, 5) + 2*f(x).diff(x, 3) + f(x).diff(x) - 2*x - + sin(x) - cos(x)), + 'sol': [Eq(f(x), C1 + x**2 + (C2 + x*(C3 - x/8))*sin(x) + (C4 + x*(C5 + x/8))*cos(x))], + 'slow': True, + }, + + 'undet_27': { + 'eq': f2 + f(x) - cos(x)/2 + cos(3*x)/2, + 'sol': [Eq(f(x), cos(3*x)/16 + C2*cos(x) + (C1 + x/4)*sin(x))], + 'slow': True, + }, + + 'undet_28': { + 'eq': f(x).diff(x) - 1, + 'sol': [Eq(f(x), C1 + x)], + 'slow': True, + }, + + # https://github.com/sympy/sympy/issues/19358 + 'undet_29': { + 'eq': f2 + f(x).diff(x) + exp(x-C1), + 'sol': [Eq(f(x), C2 + C3*exp(-x) - exp(-C1 + x)/2)], + 'slow': True, + }, + + # https://github.com/sympy/sympy/issues/18408 + 'undet_30': { + 'eq': f(x).diff(x, 3) - f(x).diff(x) - sinh(x), + 'sol': [Eq(f(x), C1 + C2*exp(-x) + C3*exp(x) + x*sinh(x)/2)], + }, + + 'undet_31': { + 'eq': f(x).diff(x, 2) - 49*f(x) - sinh(3*x), + 'sol': [Eq(f(x), C1*exp(-7*x) + C2*exp(7*x) - sinh(3*x)/40)], + }, + + 'undet_32': { + 'eq': f(x).diff(x, 3) - f(x).diff(x) - sinh(x) - exp(x), + 'sol': [Eq(f(x), C1 + C3*exp(-x) + x*sinh(x)/2 + (C2 + x/2)*exp(x))], + }, + + # https://github.com/sympy/sympy/issues/5096 + 'undet_33': { + 'eq': f(x).diff(x, x) + f(x) - x*sin(x - 2), + 'sol': [Eq(f(x), C1*sin(x) + C2*cos(x) - x**2*cos(x - 2)/4 + x*sin(x - 2)/4)], + }, + + 'undet_34': { + 'eq': f(x).diff(x, 2) + f(x) - x**4*sin(x-1), + 'sol': [ Eq(f(x), C1*sin(x) + C2*cos(x) - x**5*cos(x - 1)/10 + x**4*sin(x - 1)/4 + x**3*cos(x - 1)/2 - 3*x**2*sin(x - 1)/4 - 3*x*cos(x - 1)/4)], + }, + + 'undet_35': { + 'eq': f(x).diff(x, 2) - f(x) - exp(x - 1), + 'sol': [Eq(f(x), C2*exp(-x) + (C1 + x*exp(-1)/2)*exp(x))], + }, + + 'undet_36': { + 'eq': f(x).diff(x, 2)+f(x)-(sin(x-2)+1), + 'sol': [Eq(f(x), C1*sin(x) + C2*cos(x) - x*cos(x - 2)/2 + 1)], + }, + + # Equivalent to example_name 'undet_26'. + # This previously failed because the algorithm for undetermined coefficients + # didn't know to multiply exp(I*x) by sufficient x because it is linearly + # dependent on sin(x) and cos(x). + 'undet_37': { + 'eq': f(x).diff(x, 5) + 2*f(x).diff(x, 3) + f(x).diff(x) - 2*x - exp(I*x), + 'sol': [Eq(f(x), C1 + x**2*(I*exp(I*x)/8 + 1) + (C2 + C3*x)*sin(x) + (C4 + C5*x)*cos(x))], + }, + + # https://github.com/sympy/sympy/issues/12623 + 'undet_38': { + 'eq': Eq( u(t).diff(t,t) + R /L*u(t).diff(t) + 1/(L*C)*u(t), alpha), + 'sol': [Eq(u(t), C*L*alpha + C2*exp(-t*(R + sqrt(C*R**2 - 4*L)/sqrt(C))/(2*L)) + + C1*exp(t*(-R + sqrt(C*R**2 - 4*L)/sqrt(C))/(2*L)))], + 'func': u(t) + }, + + 'undet_39': { + 'eq': Eq( L*C*u(t).diff(t,t) + R*C*u(t).diff(t) + u(t), E_0*exp(I*omega*t) ), + 'sol': [Eq(u(t), C2*exp(-t*(R + sqrt(C*R**2 - 4*L)/sqrt(C))/(2*L)) + + C1*exp(t*(-R + sqrt(C*R**2 - 4*L)/sqrt(C))/(2*L)) + - E_0*exp(I*omega*t)/(C*L*omega**2 - I*C*R*omega - 1))], + 'func': u(t), + }, + + # https://github.com/sympy/sympy/issues/6879 + 'undet_40': { + 'eq': Eq(Derivative(f(x), x, 2) - 2*Derivative(f(x), x) + f(x), sin(x)), + 'sol': [Eq(f(x), (C1 + C2*x)*exp(x) + cos(x)/2)], + }, + } + } + + +@_add_example_keys +def _get_examples_ode_sol_separable(): + # test_separable1-5 are from Ordinary Differential Equations, Tenenbaum and + # Pollard, pg. 55 + t,a = symbols('a,t') + m = 96 + g = 9.8 + k = .2 + f1 = g * m + v = Function('v') + return { + 'hint': "separable", + 'func': f(x), + 'examples':{ + 'separable_01': { + 'eq': f(x).diff(x) - f(x), + 'sol': [Eq(f(x), C1*exp(x))], + }, + + 'separable_02': { + 'eq': x*f(x).diff(x) - f(x), + 'sol': [Eq(f(x), C1*x)], + }, + + 'separable_03': { + 'eq': f(x).diff(x) + sin(x), + 'sol': [Eq(f(x), C1 + cos(x))], + }, + + 'separable_04': { + 'eq': f(x)**2 + 1 - (x**2 + 1)*f(x).diff(x), + 'sol': [Eq(f(x), tan(C1 + atan(x)))], + }, + + 'separable_05': { + 'eq': f(x).diff(x)/tan(x) - f(x) - 2, + 'sol': [Eq(f(x), C1/cos(x) - 2)], + }, + + 'separable_06': { + 'eq': f(x).diff(x) * (1 - sin(f(x))) - 1, + 'sol': [Eq(-x + f(x) + cos(f(x)), C1)], + }, + + 'separable_07': { + 'eq': f(x)*x**2*f(x).diff(x) - f(x)**3 - 2*x**2*f(x).diff(x), + 'sol': [Eq(f(x), (-x - sqrt(x*(4*C1*x + x - 4)))/(C1*x - 1)/2), + Eq(f(x), (-x + sqrt(x*(4*C1*x + x - 4)))/(C1*x - 1)/2)], + 'slow': True, + }, + + 'separable_08': { + 'eq': f(x)**2 - 1 - (2*f(x) + x*f(x))*f(x).diff(x), + 'sol': [Eq(f(x), -sqrt(C1*x**2 + 4*C1*x + 4*C1 + 1)), + Eq(f(x), sqrt(C1*x**2 + 4*C1*x + 4*C1 + 1))], + 'slow': True, + }, + + 'separable_09': { + 'eq': x*log(x)*f(x).diff(x) + sqrt(1 + f(x)**2), + 'sol': [Eq(f(x), sinh(C1 - log(log(x))))], #One more solution is f(x)=I + 'slow': True, + 'checkodesol_XFAIL': True, + }, + + 'separable_10': { + 'eq': exp(x + 1)*tan(f(x)) + cos(f(x))*f(x).diff(x), + 'sol': [Eq(E*exp(x) + log(cos(f(x)) - 1)/2 - log(cos(f(x)) + 1)/2 + cos(f(x)), C1)], + 'slow': True, + }, + + 'separable_11': { + 'eq': (x*cos(f(x)) + x**2*sin(f(x))*f(x).diff(x) - a**2*sin(f(x))*f(x).diff(x)), + 'sol': [ + Eq(f(x), -acos(C1*sqrt(-a**2 + x**2)) + 2*pi), + Eq(f(x), acos(C1*sqrt(-a**2 + x**2))) + ], + 'slow': True, + }, + + 'separable_12': { + 'eq': f(x).diff(x) - f(x)*tan(x), + 'sol': [Eq(f(x), C1/cos(x))], + }, + + 'separable_13': { + 'eq': (x - 1)*cos(f(x))*f(x).diff(x) - 2*x*sin(f(x)), + 'sol': [ + Eq(f(x), pi - asin(C1*(x**2 - 2*x + 1)*exp(2*x))), + Eq(f(x), asin(C1*(x**2 - 2*x + 1)*exp(2*x))) + ], + }, + + 'separable_14': { + 'eq': f(x).diff(x) - f(x)*log(f(x))/tan(x), + 'sol': [Eq(f(x), exp(C1*sin(x)))], + }, + + 'separable_15': { + 'eq': x*f(x).diff(x) + (1 + f(x)**2)*atan(f(x)), + 'sol': [Eq(f(x), tan(C1/x))], #Two more solutions are f(x)=0 and f(x)=I + 'slow': True, + 'checkodesol_XFAIL': True, + }, + + 'separable_16': { + 'eq': f(x).diff(x) + x*(f(x) + 1), + 'sol': [Eq(f(x), -1 + C1*exp(-x**2/2))], + }, + + 'separable_17': { + 'eq': exp(f(x)**2)*(x**2 + 2*x + 1) + (x*f(x) + f(x))*f(x).diff(x), + 'sol': [ + Eq(f(x), -sqrt(log(1/(C1 + x**2 + 2*x)))), + Eq(f(x), sqrt(log(1/(C1 + x**2 + 2*x)))) + ], + }, + + 'separable_18': { + 'eq': f(x).diff(x) + f(x), + 'sol': [Eq(f(x), C1*exp(-x))], + }, + + 'separable_19': { + 'eq': sin(x)*cos(2*f(x)) + cos(x)*sin(2*f(x))*f(x).diff(x), + 'sol': [Eq(f(x), pi - acos(C1/cos(x)**2)/2), Eq(f(x), acos(C1/cos(x)**2)/2)], + }, + + 'separable_20': { + 'eq': (1 - x)*f(x).diff(x) - x*(f(x) + 1), + 'sol': [Eq(f(x), (C1*exp(-x) - x + 1)/(x - 1))], + }, + + 'separable_21': { + 'eq': f(x)*diff(f(x), x) + x - 3*x*f(x)**2, + 'sol': [Eq(f(x), -sqrt(3)*sqrt(C1*exp(3*x**2) + 1)/3), + Eq(f(x), sqrt(3)*sqrt(C1*exp(3*x**2) + 1)/3)], + }, + + 'separable_22': { + 'eq': f(x).diff(x) - exp(x + f(x)), + 'sol': [Eq(f(x), log(-1/(C1 + exp(x))))], + 'XFAIL': ['lie_group'] #It shows 'NoneType' object is not subscriptable for lie_group. + }, + + # https://github.com/sympy/sympy/issues/7081 + 'separable_23': { + 'eq': x*(f(x).diff(x)) + 1 - f(x)**2, + 'sol': [Eq(f(x), (-C1 - x**2)/(-C1 + x**2))], + }, + + # https://github.com/sympy/sympy/issues/10379 + 'separable_24': { + 'eq': f(t).diff(t)-(1-51.05*y*f(t)), + 'sol': [Eq(f(t), (0.019588638589618023*exp(y*(C1 - 51.049999999999997*t)) + 0.019588638589618023)/y)], + 'func': f(t), + }, + + # https://github.com/sympy/sympy/issues/15999 + 'separable_25': { + 'eq': f(x).diff(x) - C1*f(x), + 'sol': [Eq(f(x), C2*exp(C1*x))], + }, + + 'separable_26': { + 'eq': f1 - k * (v(t) ** 2) - m * Derivative(v(t)), + 'sol': [Eq(v(t), -68.585712797928991/tanh(C1 - 0.14288690166235204*t))], + 'func': v(t), + 'checkodesol_XFAIL': True, + }, + + #https://github.com/sympy/sympy/issues/22155 + 'separable_27': { + 'eq': f(x).diff(x) - exp(f(x) - x), + 'sol': [Eq(f(x), log(-exp(x)/(C1*exp(x) - 1)))], + } + } + } + + +@_add_example_keys +def _get_examples_ode_sol_1st_exact(): + # Type: Exact differential equation, p(x,f) + q(x,f)*f' == 0, + # where dp/df == dq/dx + ''' + Example 7 is an exact equation that fails under the exact engine. It is caught + by first order homogeneous albeit with a much contorted solution. The + exact engine fails because of a poorly simplified integral of q(0,y)dy, + where q is the function multiplying f'. The solutions should be + Eq(sqrt(x**2+f(x)**2)**3+y**3, C1). The equation below is + equivalent, but it is so complex that checkodesol fails, and takes a long + time to do so. + ''' + return { + 'hint': "1st_exact", + 'func': f(x), + 'examples':{ + '1st_exact_01': { + 'eq': sin(x)*cos(f(x)) + cos(x)*sin(f(x))*f(x).diff(x), + 'sol': [Eq(f(x), -acos(C1/cos(x)) + 2*pi), Eq(f(x), acos(C1/cos(x)))], + 'slow': True, + }, + + '1st_exact_02': { + 'eq': (2*x*f(x) + 1)/f(x) + (f(x) - x)/f(x)**2*f(x).diff(x), + 'sol': [Eq(f(x), exp(C1 - x**2 + LambertW(-x*exp(-C1 + x**2))))], + 'XFAIL': ['lie_group'], #It shows dsolve raises an exception: List index out of range for lie_group + 'slow': True, + 'checkodesol_XFAIL':True + }, + + '1st_exact_03': { + 'eq': 2*x + f(x)*cos(x) + (2*f(x) + sin(x) - sin(f(x)))*f(x).diff(x), + 'sol': [Eq(f(x)*sin(x) + cos(f(x)) + x**2 + f(x)**2, C1)], + 'XFAIL': ['lie_group'], #It goes into infinite loop for lie_group. + 'slow': True, + }, + + '1st_exact_04': { + 'eq': cos(f(x)) - (x*sin(f(x)) - f(x)**2)*f(x).diff(x), + 'sol': [Eq(x*cos(f(x)) + f(x)**3/3, C1)], + 'slow': True, + }, + + '1st_exact_05': { + 'eq': 2*x*f(x) + (x**2 + f(x)**2)*f(x).diff(x), + 'sol': [Eq(x**2*f(x) + f(x)**3/3, C1)], + 'slow': True, + 'simplify_flag':False + }, + + # This was from issue: https://github.com/sympy/sympy/issues/11290 + '1st_exact_06': { + 'eq': cos(f(x)) - (x*sin(f(x)) - f(x)**2)*f(x).diff(x), + 'sol': [Eq(x*cos(f(x)) + f(x)**3/3, C1)], + 'simplify_flag':False + }, + + '1st_exact_07': { + 'eq': x*sqrt(x**2 + f(x)**2) - (x**2*f(x)/(f(x) - sqrt(x**2 + f(x)**2)))*f(x).diff(x), + 'sol': [Eq(log(x), + C1 - 9*sqrt(1 + f(x)**2/x**2)*asinh(f(x)/x)/(-27*f(x)/x + + 27*sqrt(1 + f(x)**2/x**2)) - 9*sqrt(1 + f(x)**2/x**2)* + log(1 - sqrt(1 + f(x)**2/x**2)*f(x)/x + 2*f(x)**2/x**2)/ + (-27*f(x)/x + 27*sqrt(1 + f(x)**2/x**2)) + + 9*asinh(f(x)/x)*f(x)/(x*(-27*f(x)/x + 27*sqrt(1 + f(x)**2/x**2))) + + 9*f(x)*log(1 - sqrt(1 + f(x)**2/x**2)*f(x)/x + 2*f(x)**2/x**2)/ + (x*(-27*f(x)/x + 27*sqrt(1 + f(x)**2/x**2))))], + 'slow': True, + 'dsolve_too_slow':True + }, + + # Type: a(x)f'(x)+b(x)*f(x)+c(x)=0 + '1st_exact_08': { + 'eq': Eq(x**2*f(x).diff(x) + 3*x*f(x) - sin(x)/x, 0), + 'sol': [Eq(f(x), (C1 - cos(x))/x**3)], + }, + + # these examples are from test_exact_enhancement + '1st_exact_09': { + 'eq': f(x)/x**2 + ((f(x)*x - 1)/x)*f(x).diff(x), + 'sol': [Eq(f(x), (i*sqrt(C1*x**2 + 1) + 1)/x) for i in (-1, 1)], + }, + + '1st_exact_10': { + 'eq': (x*f(x) - 1) + f(x).diff(x)*(x**2 - x*f(x)), + 'sol': [Eq(f(x), x - sqrt(C1 + x**2 - 2*log(x))), Eq(f(x), x + sqrt(C1 + x**2 - 2*log(x)))], + }, + + '1st_exact_11': { + 'eq': (x + 2)*sin(f(x)) + f(x).diff(x)*x*cos(f(x)), + 'sol': [Eq(f(x), -asin(C1*exp(-x)/x**2) + pi), Eq(f(x), asin(C1*exp(-x)/x**2))], + }, + } + } + + +@_add_example_keys +def _get_examples_ode_sol_nth_linear_var_of_parameters(): + g = exp(-x) + f2 = f(x).diff(x, 2) + c = 3*f(x).diff(x, 3) + 5*f2 + f(x).diff(x) - f(x) - x + return { + 'hint': "nth_linear_constant_coeff_variation_of_parameters", + 'func': f(x), + 'examples':{ + 'var_of_parameters_01': { + 'eq': c - x*g, + 'sol': [Eq(f(x), C3*exp(x/3) - x + (C1 + x*(C2 - x**2/24 - 3*x/32))*exp(-x) - 1)], + 'slow': True, + }, + + 'var_of_parameters_02': { + 'eq': c - g, + 'sol': [Eq(f(x), C3*exp(x/3) - x + (C1 + x*(C2 - x/8))*exp(-x) - 1)], + 'slow': True, + }, + + 'var_of_parameters_03': { + 'eq': f(x).diff(x) - 1, + 'sol': [Eq(f(x), C1 + x)], + 'slow': True, + }, + + 'var_of_parameters_04': { + 'eq': f2 + 3*f(x).diff(x) + 2*f(x) - 4, + 'sol': [Eq(f(x), C1*exp(-2*x) + C2*exp(-x) + 2)], + 'slow': True, + }, + + 'var_of_parameters_05': { + 'eq': f2 + 3*f(x).diff(x) + 2*f(x) - 12*exp(x), + 'sol': [Eq(f(x), C1*exp(-2*x) + C2*exp(-x) + 2*exp(x))], + 'slow': True, + }, + + 'var_of_parameters_06': { + 'eq': f2 - 2*f(x).diff(x) - 8*f(x) - 9*x*exp(x) - 10*exp(-x), + 'sol': [Eq(f(x), -x*exp(x) - 2*exp(-x) + C1*exp(-2*x) + C2*exp(4*x))], + 'slow': True, + }, + + 'var_of_parameters_07': { + 'eq': f2 + 2*f(x).diff(x) + f(x) - x**2*exp(-x), + 'sol': [Eq(f(x), (C1 + x*(C2 + x**3/12))*exp(-x))], + 'slow': True, + }, + + 'var_of_parameters_08': { + 'eq': f2 - 3*f(x).diff(x) + 2*f(x) - x*exp(-x), + 'sol': [Eq(f(x), C1*exp(x) + C2*exp(2*x) + (6*x + 5)*exp(-x)/36)], + 'slow': True, + }, + + 'var_of_parameters_09': { + 'eq': f(x).diff(x, 3) - 3*f2 + 3*f(x).diff(x) - f(x) - exp(x), + 'sol': [Eq(f(x), (C1 + x*(C2 + x*(C3 + x/6)))*exp(x))], + 'slow': True, + }, + + 'var_of_parameters_10': { + 'eq': f2 + 2*f(x).diff(x) + f(x) - exp(-x)/x, + 'sol': [Eq(f(x), (C1 + x*(C2 + log(x)))*exp(-x))], + 'slow': True, + }, + + 'var_of_parameters_11': { + 'eq': f2 + f(x) - 1/sin(x)*1/cos(x), + 'sol': [Eq(f(x), (C1 + log(sin(x) - 1)/2 - log(sin(x) + 1)/2 + )*cos(x) + (C2 + log(cos(x) - 1)/2 - log(cos(x) + 1)/2)*sin(x))], + 'slow': True, + }, + + 'var_of_parameters_12': { + 'eq': f(x).diff(x, 4) - 1/x, + 'sol': [Eq(f(x), C1 + C2*x + C3*x**2 + x**3*(C4 + log(x)/6))], + 'slow': True, + }, + + # These were from issue: https://github.com/sympy/sympy/issues/15996 + 'var_of_parameters_13': { + 'eq': f(x).diff(x, 5) + 2*f(x).diff(x, 3) + f(x).diff(x) - 2*x - exp(I*x), + 'sol': [Eq(f(x), C1 + x**2 + (C2 + x*(C3 - x/8 + 3*exp(I*x)/2 + 3*exp(-I*x)/2) + 5*exp(2*I*x)/16 + 2*I*exp(I*x) - 2*I*exp(-I*x))*sin(x) + (C4 + x*(C5 + I*x/8 + 3*I*exp(I*x)/2 - 3*I*exp(-I*x)/2) + + 5*I*exp(2*I*x)/16 - 2*exp(I*x) - 2*exp(-I*x))*cos(x) - I*exp(I*x))], + }, + + 'var_of_parameters_14': { + 'eq': f(x).diff(x, 5) + 2*f(x).diff(x, 3) + f(x).diff(x) - exp(I*x), + 'sol': [Eq(f(x), C1 + (C2 + x*(C3 - x/8) + 5*exp(2*I*x)/16)*sin(x) + (C4 + x*(C5 + I*x/8) + 5*I*exp(2*I*x)/16)*cos(x) - I*exp(I*x))], + }, + + # https://github.com/sympy/sympy/issues/14395 + 'var_of_parameters_15': { + 'eq': Derivative(f(x), x, x) + 9*f(x) - sec(x), + 'sol': [Eq(f(x), (C1 - x/3 + sin(2*x)/3)*sin(3*x) + (C2 + log(cos(x)) + - 2*log(cos(x)**2)/3 + 2*cos(x)**2/3)*cos(3*x))], + 'slow': True, + }, + } + } + + +@_add_example_keys +def _get_examples_ode_sol_2nd_linear_bessel(): + return { + 'hint': "2nd_linear_bessel", + 'func': f(x), + 'examples':{ + '2nd_lin_bessel_01': { + 'eq': x**2*(f(x).diff(x, 2)) + x*(f(x).diff(x)) + (x**2 - 4)*f(x), + 'sol': [Eq(f(x), C1*besselj(2, x) + C2*bessely(2, x))], + }, + + '2nd_lin_bessel_02': { + 'eq': x**2*(f(x).diff(x, 2)) + x*(f(x).diff(x)) + (x**2 +25)*f(x), + 'sol': [Eq(f(x), C1*besselj(5*I, x) + C2*bessely(5*I, x))], + }, + + '2nd_lin_bessel_03': { + 'eq': x**2*(f(x).diff(x, 2)) + x*(f(x).diff(x)) + (x**2)*f(x), + 'sol': [Eq(f(x), C1*besselj(0, x) + C2*bessely(0, x))], + }, + + '2nd_lin_bessel_04': { + 'eq': x**2*(f(x).diff(x, 2)) + x*(f(x).diff(x)) + (81*x**2 -S(1)/9)*f(x), + 'sol': [Eq(f(x), C1*besselj(S(1)/3, 9*x) + C2*bessely(S(1)/3, 9*x))], + }, + + '2nd_lin_bessel_05': { + 'eq': x**2*(f(x).diff(x, 2)) + x*(f(x).diff(x)) + (x**4 - 4)*f(x), + 'sol': [Eq(f(x), C1*besselj(1, x**2/2) + C2*bessely(1, x**2/2))], + }, + + '2nd_lin_bessel_06': { + 'eq': x**2*(f(x).diff(x, 2)) + 2*x*(f(x).diff(x)) + (x**4 - 4)*f(x), + 'sol': [Eq(f(x), (C1*besselj(sqrt(17)/4, x**2/2) + C2*bessely(sqrt(17)/4, x**2/2))/sqrt(x))], + }, + + '2nd_lin_bessel_07': { + 'eq': x**2*(f(x).diff(x, 2)) + x*(f(x).diff(x)) + (x**2 - S(1)/4)*f(x), + 'sol': [Eq(f(x), C1*besselj(S(1)/2, x) + C2*bessely(S(1)/2, x))], + }, + + '2nd_lin_bessel_08': { + 'eq': x**2*(f(x).diff(x, 2)) - 3*x*(f(x).diff(x)) + (4*x + 4)*f(x), + 'sol': [Eq(f(x), x**2*(C1*besselj(0, 4*sqrt(x)) + C2*bessely(0, 4*sqrt(x))))], + }, + + '2nd_lin_bessel_09': { + 'eq': x*(f(x).diff(x, 2)) - f(x).diff(x) + 4*x**3*f(x), + 'sol': [Eq(f(x), x*(C1*besselj(S(1)/2, x**2) + C2*bessely(S(1)/2, x**2)))], + }, + + '2nd_lin_bessel_10': { + 'eq': (x-2)**2*(f(x).diff(x, 2)) - (x-2)*f(x).diff(x) + 4*(x-2)**2*f(x), + 'sol': [Eq(f(x), (x - 2)*(C1*besselj(1, 2*x - 4) + C2*bessely(1, 2*x - 4)))], + }, + + # https://github.com/sympy/sympy/issues/4414 + '2nd_lin_bessel_11': { + 'eq': f(x).diff(x, x) + 2/x*f(x).diff(x) + f(x), + 'sol': [Eq(f(x), (C1*besselj(S(1)/2, x) + C2*bessely(S(1)/2, x))/sqrt(x))], + }, + '2nd_lin_bessel_12': { + 'eq': x**2*f(x).diff(x, 2) + x*f(x).diff(x) + (a**2*x**2/c**2 - b**2)*f(x), + 'sol': [Eq(f(x), C1*besselj(sqrt(b**2), x*sqrt(a**2/c**2)) + C2*bessely(sqrt(b**2), x*sqrt(a**2/c**2)))], + }, + } + } + + +@_add_example_keys +def _get_examples_ode_sol_2nd_2F1_hypergeometric(): + return { + 'hint': "2nd_hypergeometric", + 'func': f(x), + 'examples':{ + '2nd_2F1_hyper_01': { + 'eq': x*(x-1)*f(x).diff(x, 2) + (S(3)/2 -2*x)*f(x).diff(x) + 2*f(x), + 'sol': [Eq(f(x), C1*x**(S(5)/2)*hyper((S(3)/2, S(1)/2), (S(7)/2,), x) + C2*hyper((-1, -2), (-S(3)/2,), x))], + }, + + '2nd_2F1_hyper_02': { + 'eq': x*(x-1)*f(x).diff(x, 2) + (S(7)/2*x)*f(x).diff(x) + f(x), + 'sol': [Eq(f(x), (C1*(1 - x)**(S(5)/2)*hyper((S(1)/2, 2), (S(7)/2,), 1 - x) + + C2*hyper((-S(1)/2, -2), (-S(3)/2,), 1 - x))/(x - 1)**(S(5)/2))], + }, + + '2nd_2F1_hyper_03': { + 'eq': x*(x-1)*f(x).diff(x, 2) + (S(3)+ S(7)/2*x)*f(x).diff(x) + f(x), + 'sol': [Eq(f(x), (C1*(1 - x)**(S(11)/2)*hyper((S(1)/2, 2), (S(13)/2,), 1 - x) + + C2*hyper((-S(7)/2, -5), (-S(9)/2,), 1 - x))/(x - 1)**(S(11)/2))], + }, + + '2nd_2F1_hyper_04': { + 'eq': -x**(S(5)/7)*(-416*x**(S(9)/7)/9 - 2385*x**(S(5)/7)/49 + S(298)*x/3)*f(x)/(196*(-x**(S(6)/7) + + x)**2*(x**(S(6)/7) + x)**2) + Derivative(f(x), (x, 2)), + 'sol': [Eq(f(x), x**(S(45)/98)*(C1*x**(S(4)/49)*hyper((S(1)/3, -S(1)/2), (S(9)/7,), x**(S(2)/7)) + + C2*hyper((S(1)/21, -S(11)/14), (S(5)/7,), x**(S(2)/7)))/(x**(S(2)/7) - 1)**(S(19)/84))], + 'checkodesol_XFAIL':True, + }, + } + } + +@_add_example_keys +def _get_examples_ode_sol_2nd_nonlinear_autonomous_conserved(): + return { + 'hint': "2nd_nonlinear_autonomous_conserved", + 'func': f(x), + 'examples': { + '2nd_nonlinear_autonomous_conserved_01': { + 'eq': f(x).diff(x, 2) + exp(f(x)) + log(f(x)), + 'sol': [ + Eq(Integral(1/sqrt(C1 - 2*_u*log(_u) + 2*_u - 2*exp(_u)), (_u, f(x))), C2 + x), + Eq(Integral(1/sqrt(C1 - 2*_u*log(_u) + 2*_u - 2*exp(_u)), (_u, f(x))), C2 - x) + ], + 'simplify_flag': False, + }, + '2nd_nonlinear_autonomous_conserved_02': { + 'eq': f(x).diff(x, 2) + cbrt(f(x)) + 1/f(x), + 'sol': [ + Eq(sqrt(2)*Integral(1/sqrt(2*C1 - 3*_u**Rational(4, 3) - 4*log(_u)), (_u, f(x))), C2 + x), + Eq(sqrt(2)*Integral(1/sqrt(2*C1 - 3*_u**Rational(4, 3) - 4*log(_u)), (_u, f(x))), C2 - x) + ], + 'simplify_flag': False, + }, + '2nd_nonlinear_autonomous_conserved_03': { + 'eq': f(x).diff(x, 2) + sin(f(x)), + 'sol': [ + Eq(Integral(1/sqrt(C1 + 2*cos(_u)), (_u, f(x))), C2 + x), + Eq(Integral(1/sqrt(C1 + 2*cos(_u)), (_u, f(x))), C2 - x) + ], + 'simplify_flag': False, + }, + '2nd_nonlinear_autonomous_conserved_04': { + 'eq': f(x).diff(x, 2) + cosh(f(x)), + 'sol': [ + Eq(Integral(1/sqrt(C1 - 2*sinh(_u)), (_u, f(x))), C2 + x), + Eq(Integral(1/sqrt(C1 - 2*sinh(_u)), (_u, f(x))), C2 - x) + ], + 'simplify_flag': False, + }, + '2nd_nonlinear_autonomous_conserved_05': { + 'eq': f(x).diff(x, 2) + asin(f(x)), + 'sol': [ + Eq(Integral(1/sqrt(C1 - 2*_u*asin(_u) - 2*sqrt(1 - _u**2)), (_u, f(x))), C2 + x), + Eq(Integral(1/sqrt(C1 - 2*_u*asin(_u) - 2*sqrt(1 - _u**2)), (_u, f(x))), C2 - x) + ], + 'simplify_flag': False, + 'XFAIL': ['2nd_nonlinear_autonomous_conserved_Integral'] + } + } + } + + +@_add_example_keys +def _get_examples_ode_sol_separable_reduced(): + df = f(x).diff(x) + return { + 'hint': "separable_reduced", + 'func': f(x), + 'examples':{ + 'separable_reduced_01': { + 'eq': x* df + f(x)* (1 / (x**2*f(x) - 1)), + 'sol': [Eq(log(x**2*f(x))/3 + log(x**2*f(x) - Rational(3, 2))/6, C1 + log(x))], + 'simplify_flag': False, + 'XFAIL': ['lie_group'], #It hangs. + }, + + #Note: 'separable_reduced_02' is referred in 'separable_reduced_11' + 'separable_reduced_02': { + 'eq': f(x).diff(x) + (f(x) / (x**4*f(x) - x)), + 'sol': [Eq(log(x**3*f(x))/4 + log(x**3*f(x) - Rational(4,3))/12, C1 + log(x))], + 'simplify_flag': False, + 'checkodesol_XFAIL':True, #It hangs for this. + }, + + 'separable_reduced_03': { + 'eq': x*df + f(x)*(x**2*f(x)), + 'sol': [Eq(log(x**2*f(x))/2 - log(x**2*f(x) - 2)/2, C1 + log(x))], + 'simplify_flag': False, + }, + + 'separable_reduced_04': { + 'eq': Eq(f(x).diff(x) + f(x)/x * (1 + (x**(S(2)/3)*f(x))**2), 0), + 'sol': [Eq(-3*log(x**(S(2)/3)*f(x)) + 3*log(3*x**(S(4)/3)*f(x)**2 + 1)/2, C1 + log(x))], + 'simplify_flag': False, + }, + + 'separable_reduced_05': { + 'eq': Eq(f(x).diff(x) + f(x)/x * (1 + (x*f(x))**2), 0), + 'sol': [Eq(f(x), -sqrt(2)*sqrt(1/(C1 + log(x)))/(2*x)),\ + Eq(f(x), sqrt(2)*sqrt(1/(C1 + log(x)))/(2*x))], + }, + + 'separable_reduced_06': { + 'eq': Eq(f(x).diff(x) + (x**4*f(x)**2 + x**2*f(x))*f(x)/(x*(x**6*f(x)**3 + x**4*f(x)**2)), 0), + 'sol': [Eq(f(x), C1 + 1/(2*x**2))], + }, + + 'separable_reduced_07': { + 'eq': Eq(f(x).diff(x) + (f(x)**2)*f(x)/(x), 0), + 'sol': [ + Eq(f(x), -sqrt(2)*sqrt(1/(C1 + log(x)))/2), + Eq(f(x), sqrt(2)*sqrt(1/(C1 + log(x)))/2) + ], + }, + + 'separable_reduced_08': { + 'eq': Eq(f(x).diff(x) + (f(x)+3)*f(x)/(x*(f(x)+2)), 0), + 'sol': [Eq(-log(f(x) + 3)/3 - 2*log(f(x))/3, C1 + log(x))], + 'simplify_flag': False, + 'XFAIL': ['lie_group'], #It hangs. + }, + + 'separable_reduced_09': { + 'eq': Eq(f(x).diff(x) + (f(x)+3)*f(x)/x, 0), + 'sol': [Eq(f(x), 3/(C1*x**3 - 1))], + }, + + 'separable_reduced_10': { + 'eq': Eq(f(x).diff(x) + (f(x)**2+f(x))*f(x)/(x), 0), + 'sol': [Eq(- log(x) - log(f(x) + 1) + log(f(x)) + 1/f(x), C1)], + 'XFAIL': ['lie_group'],#No algorithms are implemented to solve equation -C1 + x*(_y + 1)*exp(-1/_y)/_y + + }, + + # Equivalent to example_name 'separable_reduced_02'. Only difference is testing with simplify=True + 'separable_reduced_11': { + 'eq': f(x).diff(x) + (f(x) / (x**4*f(x) - x)), + 'sol': [Eq(f(x), -sqrt(2)*sqrt(3*3**Rational(1,3)*(sqrt((3*exp(12*C1) + x**(-12))*exp(24*C1)) - exp(12*C1)/x**6)**Rational(1,3) +- 3*3**Rational(2,3)*exp(12*C1)/(sqrt((3*exp(12*C1) + x**(-12))*exp(24*C1)) - exp(12*C1)/x**6)**Rational(1,3) + 2/x**6)/6 +- sqrt(2)*sqrt(-3*3**Rational(1,3)*(sqrt((3*exp(12*C1) + x**(-12))*exp(24*C1)) - exp(12*C1)/x**6)**Rational(1,3) ++ 3*3**Rational(2,3)*exp(12*C1)/(sqrt((3*exp(12*C1) + x**(-12))*exp(24*C1)) - exp(12*C1)/x**6)**Rational(1,3) + 4/x**6 +- 4*sqrt(2)/(x**9*sqrt(3*3**Rational(1,3)*(sqrt((3*exp(12*C1) + x**(-12))*exp(24*C1)) - exp(12*C1)/x**6)**Rational(1,3) +- 3*3**Rational(2,3)*exp(12*C1)/(sqrt((3*exp(12*C1) + x**(-12))*exp(24*C1)) - exp(12*C1)/x**6)**Rational(1,3) + 2/x**6)))/6 + 1/(3*x**3)), +Eq(f(x), -sqrt(2)*sqrt(3*3**Rational(1,3)*(sqrt((3*exp(12*C1) + x**(-12))*exp(24*C1)) - exp(12*C1)/x**6)**Rational(1,3) +- 3*3**Rational(2,3)*exp(12*C1)/(sqrt((3*exp(12*C1) + x**(-12))*exp(24*C1)) - exp(12*C1)/x**6)**Rational(1,3) + 2/x**6)/6 ++ sqrt(2)*sqrt(-3*3**Rational(1,3)*(sqrt((3*exp(12*C1) + x**(-12))*exp(24*C1)) - exp(12*C1)/x**6)**Rational(1,3) ++ 3*3**Rational(2,3)*exp(12*C1)/(sqrt((3*exp(12*C1) + x**(-12))*exp(24*C1)) - exp(12*C1)/x**6)**Rational(1,3) + 4/x**6 +- 4*sqrt(2)/(x**9*sqrt(3*3**Rational(1,3)*(sqrt((3*exp(12*C1) + x**(-12))*exp(24*C1)) - exp(12*C1)/x**6)**Rational(1,3) +- 3*3**Rational(2,3)*exp(12*C1)/(sqrt((3*exp(12*C1) + x**(-12))*exp(24*C1)) - exp(12*C1)/x**6)**Rational(1,3) + 2/x**6)))/6 + 1/(3*x**3)), +Eq(f(x), sqrt(2)*sqrt(3*3**Rational(1,3)*(sqrt((3*exp(12*C1) + x**(-12))*exp(24*C1)) - exp(12*C1)/x**6)**Rational(1,3) +- 3*3**Rational(2,3)*exp(12*C1)/(sqrt((3*exp(12*C1) + x**(-12))*exp(24*C1)) - exp(12*C1)/x**6)**Rational(1,3) + 2/x**6)/6 +- sqrt(2)*sqrt(-3*3**Rational(1,3)*(sqrt((3*exp(12*C1) + x**(-12))*exp(24*C1)) - exp(12*C1)/x**6)**Rational(1,3) ++ 3*3**Rational(2,3)*exp(12*C1)/(sqrt((3*exp(12*C1) + x**(-12))*exp(24*C1)) - exp(12*C1)/x**6)**Rational(1,3) ++ 4/x**6 + 4*sqrt(2)/(x**9*sqrt(3*3**Rational(1,3)*(sqrt((3*exp(12*C1) + x**(-12))*exp(24*C1)) - exp(12*C1)/x**6)**Rational(1,3) +- 3*3**Rational(2,3)*exp(12*C1)/(sqrt((3*exp(12*C1) + x**(-12))*exp(24*C1)) - exp(12*C1)/x**6)**Rational(1,3) + 2/x**6)))/6 + 1/(3*x**3)), +Eq(f(x), sqrt(2)*sqrt(3*3**Rational(1,3)*(sqrt((3*exp(12*C1) + x**(-12))*exp(24*C1)) - exp(12*C1)/x**6)**Rational(1,3) +- 3*3**Rational(2,3)*exp(12*C1)/(sqrt((3*exp(12*C1) + x**(-12))*exp(24*C1)) - exp(12*C1)/x**6)**Rational(1,3) + 2/x**6)/6 ++ sqrt(2)*sqrt(-3*3**Rational(1,3)*(sqrt((3*exp(12*C1) + x**(-12))*exp(24*C1)) - exp(12*C1)/x**6)**Rational(1,3) + 3*3**Rational(2,3)*exp(12*C1)/(sqrt((3*exp(12*C1) ++ x**(-12))*exp(24*C1)) - exp(12*C1)/x**6)**Rational(1,3) + 4/x**6 + 4*sqrt(2)/(x**9*sqrt(3*3**Rational(1,3)*(sqrt((3*exp(12*C1) + x**(-12))*exp(24*C1)) +- exp(12*C1)/x**6)**Rational(1,3) - 3*3**Rational(2,3)*exp(12*C1)/(sqrt((3*exp(12*C1) + x**(-12))*exp(24*C1)) - exp(12*C1)/x**6)**Rational(1,3) + 2/x**6)))/6 + 1/(3*x**3))], + 'checkodesol_XFAIL':True, #It hangs for this. + 'slow': True, + }, + + #These were from issue: https://github.com/sympy/sympy/issues/6247 + 'separable_reduced_12': { + 'eq': x**2*f(x)**2 + x*Derivative(f(x), x), + 'sol': [Eq(f(x), 2*C1/(C1*x**2 - 1))], + }, + } + } + + +@_add_example_keys +def _get_examples_ode_sol_lie_group(): + a, b, c = symbols("a b c") + return { + 'hint': "lie_group", + 'func': f(x), + 'examples':{ + #Example 1-4 and 19-20 were from issue: https://github.com/sympy/sympy/issues/17322 + 'lie_group_01': { + 'eq': x*f(x).diff(x)*(f(x)+4) + (f(x)**2) -2*f(x)-2*x, + 'sol': [], + 'dsolve_too_slow': True, + 'checkodesol_too_slow': True, + }, + + 'lie_group_02': { + 'eq': x*f(x).diff(x)*(f(x)+4) + (f(x)**2) -2*f(x)-2*x, + 'sol': [], + 'dsolve_too_slow': True, + }, + + 'lie_group_03': { + 'eq': Eq(x**7*Derivative(f(x), x) + 5*x**3*f(x)**2 - (2*x**2 + 2)*f(x)**3, 0), + 'sol': [], + 'dsolve_too_slow': True, + }, + + 'lie_group_04': { + 'eq': f(x).diff(x) - (f(x) - x*log(x))**2/x**2 + log(x), + 'sol': [], + 'XFAIL': ['lie_group'], + }, + + 'lie_group_05': { + 'eq': f(x).diff(x)**2, + 'sol': [Eq(f(x), C1)], + 'XFAIL': ['factorable'], #It raises Not Implemented error + }, + + 'lie_group_06': { + 'eq': Eq(f(x).diff(x), x**2*f(x)), + 'sol': [Eq(f(x), C1*exp(x**3)**Rational(1, 3))], + }, + + 'lie_group_07': { + 'eq': f(x).diff(x) + a*f(x) - c*exp(b*x), + 'sol': [Eq(f(x), Piecewise(((-C1*(a + b) + c*exp(x*(a + b)))*exp(-a*x)/(a + b),\ + Ne(a, -b)), ((-C1 + c*x)*exp(-a*x), True)))], + }, + + 'lie_group_08': { + 'eq': f(x).diff(x) + 2*x*f(x) - x*exp(-x**2), + 'sol': [Eq(f(x), (C1 + x**2/2)*exp(-x**2))], + }, + + 'lie_group_09': { + 'eq': (1 + 2*x)*(f(x).diff(x)) + 2 - 4*exp(-f(x)), + 'sol': [Eq(f(x), log(C1/(2*x + 1) + 2))], + }, + + 'lie_group_10': { + 'eq': x**2*(f(x).diff(x)) - f(x) + x**2*exp(x - (1/x)), + 'sol': [Eq(f(x), (C1 - exp(x))*exp(-1/x))], + 'XFAIL': ['factorable'], #It raises Recursion Error (maixmum depth exceeded) + }, + + 'lie_group_11': { + 'eq': x**2*f(x)**2 + x*Derivative(f(x), x), + 'sol': [Eq(f(x), 2/(C1 + x**2))], + }, + + 'lie_group_12': { + 'eq': diff(f(x),x) + 2*x*f(x) - x*exp(-x**2), + 'sol': [Eq(f(x), exp(-x**2)*(C1 + x**2/2))], + }, + + 'lie_group_13': { + 'eq': diff(f(x),x) + f(x)*cos(x) - exp(2*x), + 'sol': [Eq(f(x), exp(-sin(x))*(C1 + Integral(exp(2*x)*exp(sin(x)), x)))], + }, + + 'lie_group_14': { + 'eq': diff(f(x),x) + f(x)*cos(x) - sin(2*x)/2, + 'sol': [Eq(f(x), C1*exp(-sin(x)) + sin(x) - 1)], + }, + + 'lie_group_15': { + 'eq': x*diff(f(x),x) + f(x) - x*sin(x), + 'sol': [Eq(f(x), (C1 - x*cos(x) + sin(x))/x)], + }, + + 'lie_group_16': { + 'eq': x*diff(f(x),x) - f(x) - x/log(x), + 'sol': [Eq(f(x), x*(C1 + log(log(x))))], + }, + + 'lie_group_17': { + 'eq': (f(x).diff(x)-f(x)) * (f(x).diff(x)+f(x)), + 'sol': [Eq(f(x), C1*exp(x)), Eq(f(x), C1*exp(-x))], + }, + + 'lie_group_18': { + 'eq': f(x).diff(x) * (f(x).diff(x) - f(x)), + 'sol': [Eq(f(x), C1*exp(x)), Eq(f(x), C1)], + }, + + 'lie_group_19': { + 'eq': (f(x).diff(x)-f(x)) * (f(x).diff(x)+f(x)), + 'sol': [Eq(f(x), C1*exp(-x)), Eq(f(x), C1*exp(x))], + }, + + 'lie_group_20': { + 'eq': f(x).diff(x)*(f(x).diff(x)+f(x)), + 'sol': [Eq(f(x), C1), Eq(f(x), C1*exp(-x))], + }, + } + } + + +@_add_example_keys +def _get_examples_ode_sol_2nd_linear_airy(): + return { + 'hint': "2nd_linear_airy", + 'func': f(x), + 'examples':{ + '2nd_lin_airy_01': { + 'eq': f(x).diff(x, 2) - x*f(x), + 'sol': [Eq(f(x), C1*airyai(x) + C2*airybi(x))], + }, + + '2nd_lin_airy_02': { + 'eq': f(x).diff(x, 2) + 2*x*f(x), + 'sol': [Eq(f(x), C1*airyai(-2**(S(1)/3)*x) + C2*airybi(-2**(S(1)/3)*x))], + }, + } + } + + +@_add_example_keys +def _get_examples_ode_sol_nth_linear_constant_coeff_homogeneous(): + # From Exercise 20, in Ordinary Differential Equations, + # Tenenbaum and Pollard, pg. 220 + a = Symbol('a', positive=True) + k = Symbol('k', real=True) + r1, r2, r3, r4, r5 = [rootof(x**5 + 11*x - 2, n) for n in range(5)] + r6, r7, r8, r9, r10 = [rootof(x**5 - 3*x + 1, n) for n in range(5)] + r11, r12, r13, r14, r15 = [rootof(x**5 - 100*x**3 + 1000*x + 1, n) for n in range(5)] + r16, r17, r18, r19, r20 = [rootof(x**5 - x**4 + 10, n) for n in range(5)] + r21, r22, r23, r24, r25 = [rootof(x**5 - x + 1, n) for n in range(5)] + E = exp(1) + return { + 'hint': "nth_linear_constant_coeff_homogeneous", + 'func': f(x), + 'examples':{ + 'lin_const_coeff_hom_01': { + 'eq': f(x).diff(x, 2) + 2*f(x).diff(x), + 'sol': [Eq(f(x), C1 + C2*exp(-2*x))], + }, + + 'lin_const_coeff_hom_02': { + 'eq': f(x).diff(x, 2) - 3*f(x).diff(x) + 2*f(x), + 'sol': [Eq(f(x), (C1 + C2*exp(x))*exp(x))], + }, + + 'lin_const_coeff_hom_03': { + 'eq': f(x).diff(x, 2) - f(x), + 'sol': [Eq(f(x), C1*exp(-x) + C2*exp(x))], + }, + + 'lin_const_coeff_hom_04': { + 'eq': f(x).diff(x, 3) + f(x).diff(x, 2) - 6*f(x).diff(x), + 'sol': [Eq(f(x), C1 + C2*exp(-3*x) + C3*exp(2*x))], + 'slow': True, + }, + + 'lin_const_coeff_hom_05': { + 'eq': 6*f(x).diff(x, 2) - 11*f(x).diff(x) + 4*f(x), + 'sol': [Eq(f(x), C1*exp(x/2) + C2*exp(x*Rational(4, 3)))], + 'slow': True, + }, + + 'lin_const_coeff_hom_06': { + 'eq': Eq(f(x).diff(x, 2) + 2*f(x).diff(x) - f(x), 0), + 'sol': [Eq(f(x), C1*exp(x*(-1 + sqrt(2))) + C2*exp(-x*(sqrt(2) + 1)))], + 'slow': True, + }, + + 'lin_const_coeff_hom_07': { + 'eq': diff(f(x), x, 3) + diff(f(x), x, 2) - 10*diff(f(x), x) - 6*f(x), + 'sol': [Eq(f(x), C1*exp(3*x) + C3*exp(-x*(2 + sqrt(2))) + C2*exp(x*(-2 + sqrt(2))))], + 'slow': True, + }, + + 'lin_const_coeff_hom_08': { + 'eq': f(x).diff(x, 4) - f(x).diff(x, 3) - 4*f(x).diff(x, 2) + \ + 4*f(x).diff(x), + 'sol': [Eq(f(x), C1 + C2*exp(-2*x) + C3*exp(x) + C4*exp(2*x))], + 'slow': True, + }, + + 'lin_const_coeff_hom_09': { + 'eq': f(x).diff(x, 4) + 4*f(x).diff(x, 3) + f(x).diff(x, 2) - \ + 4*f(x).diff(x) - 2*f(x), + 'sol': [Eq(f(x), C3*exp(-x) + C4*exp(x) + (C1*exp(-sqrt(2)*x) + C2*exp(sqrt(2)*x))*exp(-2*x))], + 'slow': True, + }, + + 'lin_const_coeff_hom_10': { + 'eq': f(x).diff(x, 4) - a**2*f(x), + 'sol': [Eq(f(x), C1*exp(-sqrt(a)*x) + C2*exp(sqrt(a)*x) + C3*sin(sqrt(a)*x) + C4*cos(sqrt(a)*x))], + 'slow': True, + }, + + 'lin_const_coeff_hom_11': { + 'eq': f(x).diff(x, 2) - 2*k*f(x).diff(x) - 2*f(x), + 'sol': [Eq(f(x), C1*exp(x*(k - sqrt(k**2 + 2))) + C2*exp(x*(k + sqrt(k**2 + 2))))], + 'slow': True, + }, + + 'lin_const_coeff_hom_12': { + 'eq': f(x).diff(x, 2) + 4*k*f(x).diff(x) - 12*k**2*f(x), + 'sol': [Eq(f(x), C1*exp(-6*k*x) + C2*exp(2*k*x))], + 'slow': True, + }, + + 'lin_const_coeff_hom_13': { + 'eq': f(x).diff(x, 4), + 'sol': [Eq(f(x), C1 + C2*x + C3*x**2 + C4*x**3)], + 'slow': True, + }, + + 'lin_const_coeff_hom_14': { + 'eq': f(x).diff(x, 2) + 4*f(x).diff(x) + 4*f(x), + 'sol': [Eq(f(x), (C1 + C2*x)*exp(-2*x))], + 'slow': True, + }, + + 'lin_const_coeff_hom_15': { + 'eq': 3*f(x).diff(x, 3) + 5*f(x).diff(x, 2) + f(x).diff(x) - f(x), + 'sol': [Eq(f(x), (C1 + C2*x)*exp(-x) + C3*exp(x/3))], + 'slow': True, + }, + + 'lin_const_coeff_hom_16': { + 'eq': f(x).diff(x, 3) - 6*f(x).diff(x, 2) + 12*f(x).diff(x) - 8*f(x), + 'sol': [Eq(f(x), (C1 + x*(C2 + C3*x))*exp(2*x))], + 'slow': True, + }, + + 'lin_const_coeff_hom_17': { + 'eq': f(x).diff(x, 2) - 2*a*f(x).diff(x) + a**2*f(x), + 'sol': [Eq(f(x), (C1 + C2*x)*exp(a*x))], + 'slow': True, + }, + + 'lin_const_coeff_hom_18': { + 'eq': f(x).diff(x, 4) + 3*f(x).diff(x, 3), + 'sol': [Eq(f(x), C1 + C2*x + C3*x**2 + C4*exp(-3*x))], + 'slow': True, + }, + + 'lin_const_coeff_hom_19': { + 'eq': f(x).diff(x, 4) - 2*f(x).diff(x, 2), + 'sol': [Eq(f(x), C1 + C2*x + C3*exp(-sqrt(2)*x) + C4*exp(sqrt(2)*x))], + 'slow': True, + }, + + 'lin_const_coeff_hom_20': { + 'eq': f(x).diff(x, 4) + 2*f(x).diff(x, 3) - 11*f(x).diff(x, 2) - \ + 12*f(x).diff(x) + 36*f(x), + 'sol': [Eq(f(x), (C1 + C2*x)*exp(-3*x) + (C3 + C4*x)*exp(2*x))], + 'slow': True, + }, + + 'lin_const_coeff_hom_21': { + 'eq': 36*f(x).diff(x, 4) - 37*f(x).diff(x, 2) + 4*f(x).diff(x) + 5*f(x), + 'sol': [Eq(f(x), C1*exp(-x) + C2*exp(-x/3) + C3*exp(x/2) + C4*exp(x*Rational(5, 6)))], + 'slow': True, + }, + + 'lin_const_coeff_hom_22': { + 'eq': f(x).diff(x, 4) - 8*f(x).diff(x, 2) + 16*f(x), + 'sol': [Eq(f(x), (C1 + C2*x)*exp(-2*x) + (C3 + C4*x)*exp(2*x))], + 'slow': True, + }, + + 'lin_const_coeff_hom_23': { + 'eq': f(x).diff(x, 2) - 2*f(x).diff(x) + 5*f(x), + 'sol': [Eq(f(x), (C1*sin(2*x) + C2*cos(2*x))*exp(x))], + 'slow': True, + }, + + 'lin_const_coeff_hom_24': { + 'eq': f(x).diff(x, 2) - f(x).diff(x) + f(x), + 'sol': [Eq(f(x), (C1*sin(x*sqrt(3)/2) + C2*cos(x*sqrt(3)/2))*exp(x/2))], + 'slow': True, + }, + + 'lin_const_coeff_hom_25': { + 'eq': f(x).diff(x, 4) + 5*f(x).diff(x, 2) + 6*f(x), + 'sol': [Eq(f(x), + C1*sin(sqrt(2)*x) + C2*sin(sqrt(3)*x) + C3*cos(sqrt(2)*x) + C4*cos(sqrt(3)*x))], + 'slow': True, + }, + + 'lin_const_coeff_hom_26': { + 'eq': f(x).diff(x, 2) - 4*f(x).diff(x) + 20*f(x), + 'sol': [Eq(f(x), (C1*sin(4*x) + C2*cos(4*x))*exp(2*x))], + 'slow': True, + }, + + 'lin_const_coeff_hom_27': { + 'eq': f(x).diff(x, 4) + 4*f(x).diff(x, 2) + 4*f(x), + 'sol': [Eq(f(x), (C1 + C2*x)*sin(x*sqrt(2)) + (C3 + C4*x)*cos(x*sqrt(2)))], + 'slow': True, + }, + + 'lin_const_coeff_hom_28': { + 'eq': f(x).diff(x, 3) + 8*f(x), + 'sol': [Eq(f(x), (C1*sin(x*sqrt(3)) + C2*cos(x*sqrt(3)))*exp(x) + C3*exp(-2*x))], + 'slow': True, + }, + + 'lin_const_coeff_hom_29': { + 'eq': f(x).diff(x, 4) + 4*f(x).diff(x, 2), + 'sol': [Eq(f(x), C1 + C2*x + C3*sin(2*x) + C4*cos(2*x))], + 'slow': True, + }, + + 'lin_const_coeff_hom_30': { + 'eq': f(x).diff(x, 5) + 2*f(x).diff(x, 3) + f(x).diff(x), + 'sol': [Eq(f(x), C1 + (C2 + C3*x)*sin(x) + (C4 + C5*x)*cos(x))], + 'slow': True, + }, + + 'lin_const_coeff_hom_31': { + 'eq': f(x).diff(x, 4) + f(x).diff(x, 2) + f(x), + 'sol': [Eq(f(x), (C1*sin(sqrt(3)*x/2) + C2*cos(sqrt(3)*x/2))*exp(-x/2) + + (C3*sin(sqrt(3)*x/2) + C4*cos(sqrt(3)*x/2))*exp(x/2))], + 'slow': True, + }, + + 'lin_const_coeff_hom_32': { + 'eq': f(x).diff(x, 4) + 4*f(x).diff(x, 2) + f(x), + 'sol': [Eq(f(x), C1*sin(x*sqrt(-sqrt(3) + 2)) + C2*sin(x*sqrt(sqrt(3) + 2)) + + C3*cos(x*sqrt(-sqrt(3) + 2)) + C4*cos(x*sqrt(sqrt(3) + 2)))], + 'slow': True, + }, + + # One real root, two complex conjugate pairs + 'lin_const_coeff_hom_33': { + 'eq': f(x).diff(x, 5) + 11*f(x).diff(x) - 2*f(x), + 'sol': [Eq(f(x), + C5*exp(r1*x) + exp(re(r2)*x) * (C1*sin(im(r2)*x) + C2*cos(im(r2)*x)) + + exp(re(r4)*x) * (C3*sin(im(r4)*x) + C4*cos(im(r4)*x)))], + 'checkodesol_XFAIL':True, #It Hangs + }, + + # Three real roots, one complex conjugate pair + 'lin_const_coeff_hom_34': { + 'eq': f(x).diff(x,5) - 3*f(x).diff(x) + f(x), + 'sol': [Eq(f(x), + C3*exp(r6*x) + C4*exp(r7*x) + C5*exp(r8*x) + + exp(re(r9)*x) * (C1*sin(im(r9)*x) + C2*cos(im(r9)*x)))], + 'checkodesol_XFAIL':True, #It Hangs + }, + + # Five distinct real roots + 'lin_const_coeff_hom_35': { + 'eq': f(x).diff(x,5) - 100*f(x).diff(x,3) + 1000*f(x).diff(x) + f(x), + 'sol': [Eq(f(x), C1*exp(r11*x) + C2*exp(r12*x) + C3*exp(r13*x) + C4*exp(r14*x) + C5*exp(r15*x))], + 'checkodesol_XFAIL':True, #It Hangs + }, + + # Rational root and unsolvable quintic + 'lin_const_coeff_hom_36': { + 'eq': f(x).diff(x, 6) - 6*f(x).diff(x, 5) + 5*f(x).diff(x, 4) + 10*f(x).diff(x) - 50 * f(x), + 'sol': [Eq(f(x), + C5*exp(5*x) + + C6*exp(x*r16) + + exp(re(r17)*x) * (C1*sin(im(r17)*x) + C2*cos(im(r17)*x)) + + exp(re(r19)*x) * (C3*sin(im(r19)*x) + C4*cos(im(r19)*x)))], + 'checkodesol_XFAIL':True, #It Hangs + }, + + # Five double roots (this is (x**5 - x + 1)**2) + 'lin_const_coeff_hom_37': { + 'eq': f(x).diff(x, 10) - 2*f(x).diff(x, 6) + 2*f(x).diff(x, 5) + + f(x).diff(x, 2) - 2*f(x).diff(x, 1) + f(x), + 'sol': [Eq(f(x), (C1 + C2*x)*exp(x*r21) + (-((C3 + C4*x)*sin(x*im(r22))) + + (C5 + C6*x)*cos(x*im(r22)))*exp(x*re(r22)) + (-((C7 + C8*x)*sin(x*im(r24))) + + (C10*x + C9)*cos(x*im(r24)))*exp(x*re(r24)))], + 'checkodesol_XFAIL':True, #It Hangs + }, + + 'lin_const_coeff_hom_38': { + 'eq': Eq(sqrt(2) * f(x).diff(x,x,x) + f(x).diff(x), 0), + 'sol': [Eq(f(x), C1 + C2*sin(2**Rational(3, 4)*x/2) + C3*cos(2**Rational(3, 4)*x/2))], + }, + + 'lin_const_coeff_hom_39': { + 'eq': Eq(E * f(x).diff(x,x,x) + f(x).diff(x), 0), + 'sol': [Eq(f(x), C1 + C2*sin(x/sqrt(E)) + C3*cos(x/sqrt(E)))], + }, + + 'lin_const_coeff_hom_40': { + 'eq': Eq(pi * f(x).diff(x,x,x) + f(x).diff(x), 0), + 'sol': [Eq(f(x), C1 + C2*sin(x/sqrt(pi)) + C3*cos(x/sqrt(pi)))], + }, + + 'lin_const_coeff_hom_41': { + 'eq': Eq(I * f(x).diff(x,x,x) + f(x).diff(x), 0), + 'sol': [Eq(f(x), C1 + C2*exp(-sqrt(I)*x) + C3*exp(sqrt(I)*x))], + }, + + 'lin_const_coeff_hom_42': { + 'eq': f(x).diff(x, x) + y*f(x), + 'sol': [Eq(f(x), C1*exp(-x*sqrt(-y)) + C2*exp(x*sqrt(-y)))], + }, + + 'lin_const_coeff_hom_43': { + 'eq': Eq(9*f(x).diff(x, x) + f(x), 0), + 'sol': [Eq(f(x), C1*sin(x/3) + C2*cos(x/3))], + }, + + 'lin_const_coeff_hom_44': { + 'eq': Eq(9*f(x).diff(x, x), f(x)), + 'sol': [Eq(f(x), C1*exp(-x/3) + C2*exp(x/3))], + }, + + 'lin_const_coeff_hom_45': { + 'eq': Eq(f(x).diff(x, x) - 3*diff(f(x), x) + 2*f(x), 0), + 'sol': [Eq(f(x), (C1 + C2*exp(x))*exp(x))], + }, + + 'lin_const_coeff_hom_46': { + 'eq': Eq(f(x).diff(x, x) - 4*diff(f(x), x) + 4*f(x), 0), + 'sol': [Eq(f(x), (C1 + C2*x)*exp(2*x))], + }, + + # Type: 2nd order, constant coefficients (two real equal roots) + 'lin_const_coeff_hom_47': { + 'eq': Eq(f(x).diff(x, x) + 2*diff(f(x), x) + 3*f(x), 0), + 'sol': [Eq(f(x), (C1*sin(x*sqrt(2)) + C2*cos(x*sqrt(2)))*exp(-x))], + }, + + #These were from issue: https://github.com/sympy/sympy/issues/6247 + 'lin_const_coeff_hom_48': { + 'eq': f(x).diff(x, x) + 4*f(x), + 'sol': [Eq(f(x), C1*sin(2*x) + C2*cos(2*x))], + }, + } + } + + +@_add_example_keys +def _get_examples_ode_sol_1st_homogeneous_coeff_subs_dep_div_indep(): + return { + 'hint': "1st_homogeneous_coeff_subs_dep_div_indep", + 'func': f(x), + 'examples':{ + 'dep_div_indep_01': { + 'eq': f(x)/x*cos(f(x)/x) - (x/f(x)*sin(f(x)/x) + cos(f(x)/x))*f(x).diff(x), + 'sol': [Eq(log(x), C1 - log(f(x)*sin(f(x)/x)/x))], + 'slow': True + }, + + #indep_div_dep actually has a simpler solution for example 2 but it runs too slow. + 'dep_div_indep_02': { + 'eq': x*f(x).diff(x) - f(x) - x*sin(f(x)/x), + 'sol': [Eq(log(x), log(C1) + log(cos(f(x)/x) - 1)/2 - log(cos(f(x)/x) + 1)/2)], + 'simplify_flag':False, + }, + + 'dep_div_indep_03': { + 'eq': x*exp(f(x)/x) - f(x)*sin(f(x)/x) + x*sin(f(x)/x)*f(x).diff(x), + 'sol': [Eq(log(x), C1 + exp(-f(x)/x)*sin(f(x)/x)/2 + exp(-f(x)/x)*cos(f(x)/x)/2)], + 'slow': True + }, + + 'dep_div_indep_04': { + 'eq': f(x).diff(x) - f(x)/x + 1/sin(f(x)/x), + 'sol': [Eq(f(x), x*(-acos(C1 + log(x)) + 2*pi)), Eq(f(x), x*acos(C1 + log(x)))], + 'slow': True + }, + + # previous code was testing with these other solution: + # example5_solb = Eq(f(x), log(log(C1/x)**(-x))) + 'dep_div_indep_05': { + 'eq': x*exp(f(x)/x) + f(x) - x*f(x).diff(x), + 'sol': [Eq(f(x), log((1/(C1 - log(x)))**x))], + 'checkodesol_XFAIL':True, #(because of **x?) + }, + } + } + +@_add_example_keys +def _get_examples_ode_sol_linear_coefficients(): + return { + 'hint': "linear_coefficients", + 'func': f(x), + 'examples':{ + 'linear_coeff_01': { + 'eq': f(x).diff(x) + (3 + 2*f(x))/(x + 3), + 'sol': [Eq(f(x), C1/(x**2 + 6*x + 9) - Rational(3, 2))], + }, + } + } + +@_add_example_keys +def _get_examples_ode_sol_1st_homogeneous_coeff_best(): + return { + 'hint': "1st_homogeneous_coeff_best", + 'func': f(x), + 'examples':{ + # previous code was testing this with other solution: + # example1_solb = Eq(-f(x)/(1 + log(x/f(x))), C1) + '1st_homogeneous_coeff_best_01': { + 'eq': f(x) + (x*log(f(x)/x) - 2*x)*diff(f(x), x), + 'sol': [Eq(f(x), -exp(C1)*LambertW(-x*exp(-C1 + 1)))], + 'checkodesol_XFAIL':True, #(because of LambertW?) + }, + + '1st_homogeneous_coeff_best_02': { + 'eq': 2*f(x)*exp(x/f(x)) + f(x)*f(x).diff(x) - 2*x*exp(x/f(x))*f(x).diff(x), + 'sol': [Eq(log(f(x)), C1 - 2*exp(x/f(x)))], + }, + + # previous code was testing this with other solution: + # example3_solb = Eq(log(C1*x*sqrt(1/x)*sqrt(f(x))) + x**2/(2*f(x)**2), 0) + '1st_homogeneous_coeff_best_03': { + 'eq': 2*x**2*f(x) + f(x)**3 + (x*f(x)**2 - 2*x**3)*f(x).diff(x), + 'sol': [Eq(f(x), exp(2*C1 + LambertW(-2*x**4*exp(-4*C1))/2)/x)], + 'checkodesol_XFAIL':True, #(because of LambertW?) + }, + + '1st_homogeneous_coeff_best_04': { + 'eq': (x + sqrt(f(x)**2 - x*f(x)))*f(x).diff(x) - f(x), + 'sol': [Eq(log(f(x)), C1 - 2*sqrt(-x/f(x) + 1))], + 'slow': True, + }, + + '1st_homogeneous_coeff_best_05': { + 'eq': x + f(x) - (x - f(x))*f(x).diff(x), + 'sol': [Eq(log(x), C1 - log(sqrt(1 + f(x)**2/x**2)) + atan(f(x)/x))], + }, + + '1st_homogeneous_coeff_best_06': { + 'eq': x*f(x).diff(x) - f(x) - x*sin(f(x)/x), + 'sol': [Eq(f(x), 2*x*atan(C1*x))], + }, + + '1st_homogeneous_coeff_best_07': { + 'eq': x**2 + f(x)**2 - 2*x*f(x)*f(x).diff(x), + 'sol': [Eq(f(x), -sqrt(x*(C1 + x))), Eq(f(x), sqrt(x*(C1 + x)))], + }, + + '1st_homogeneous_coeff_best_08': { + 'eq': f(x)**2 + (x*sqrt(f(x)**2 - x**2) - x*f(x))*f(x).diff(x), + 'sol': [Eq(f(x), -C1*sqrt(-x/(x - 2*C1))), Eq(f(x), C1*sqrt(-x/(x - 2*C1)))], + 'checkodesol_XFAIL': True # solutions are valid in a range + }, + } + } + + +def _get_all_examples(): + all_examples = _get_examples_ode_sol_euler_homogeneous + \ + _get_examples_ode_sol_euler_undetermined_coeff + \ + _get_examples_ode_sol_euler_var_para + \ + _get_examples_ode_sol_factorable + \ + _get_examples_ode_sol_bernoulli + \ + _get_examples_ode_sol_nth_algebraic + \ + _get_examples_ode_sol_riccati + \ + _get_examples_ode_sol_1st_linear + \ + _get_examples_ode_sol_1st_exact + \ + _get_examples_ode_sol_almost_linear + \ + _get_examples_ode_sol_nth_order_reducible + \ + _get_examples_ode_sol_nth_linear_undetermined_coefficients + \ + _get_examples_ode_sol_liouville + \ + _get_examples_ode_sol_separable + \ + _get_examples_ode_sol_1st_rational_riccati + \ + _get_examples_ode_sol_nth_linear_var_of_parameters + \ + _get_examples_ode_sol_2nd_linear_bessel + \ + _get_examples_ode_sol_2nd_2F1_hypergeometric + \ + _get_examples_ode_sol_2nd_nonlinear_autonomous_conserved + \ + _get_examples_ode_sol_separable_reduced + \ + _get_examples_ode_sol_lie_group + \ + _get_examples_ode_sol_2nd_linear_airy + \ + _get_examples_ode_sol_nth_linear_constant_coeff_homogeneous +\ + _get_examples_ode_sol_1st_homogeneous_coeff_best +\ + _get_examples_ode_sol_1st_homogeneous_coeff_subs_dep_div_indep +\ + _get_examples_ode_sol_linear_coefficients + + return all_examples diff --git a/lib/python3.10/site-packages/sympy/solvers/ode/tests/test_subscheck.py b/lib/python3.10/site-packages/sympy/solvers/ode/tests/test_subscheck.py new file mode 100644 index 0000000000000000000000000000000000000000..799c2854e878208721b600767de350cda08cd7e5 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/solvers/ode/tests/test_subscheck.py @@ -0,0 +1,203 @@ +from sympy.core.function import (Derivative, Function, diff) +from sympy.core.numbers import (I, Rational, pi) +from sympy.core.relational import Eq +from sympy.core.symbol import (Symbol, symbols) +from sympy.functions.elementary.exponential import (exp, log) +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import (cos, sin) +from sympy.functions.special.error_functions import (Ei, erf, erfi) +from sympy.integrals.integrals import Integral + +from sympy.solvers.ode.subscheck import checkodesol, checksysodesol + +from sympy.functions import besselj, bessely + +from sympy.testing.pytest import raises, slow + + +C0, C1, C2, C3, C4 = symbols('C0:5') +u, x, y, z = symbols('u,x:z', real=True) +f = Function('f') +g = Function('g') +h = Function('h') + + +@slow +def test_checkodesol(): + # For the most part, checkodesol is well tested in the tests below. + # These tests only handle cases not checked below. + raises(ValueError, lambda: checkodesol(f(x, y).diff(x), Eq(f(x, y), x))) + raises(ValueError, lambda: checkodesol(f(x).diff(x), Eq(f(x, y), + x), f(x, y))) + assert checkodesol(f(x).diff(x), Eq(f(x, y), x)) == \ + (False, -f(x).diff(x) + f(x, y).diff(x) - 1) + assert checkodesol(f(x).diff(x), Eq(f(x), x)) is not True + assert checkodesol(f(x).diff(x), Eq(f(x), x)) == (False, 1) + sol1 = Eq(f(x)**5 + 11*f(x) - 2*f(x) + x, 0) + assert checkodesol(diff(sol1.lhs, x), sol1) == (True, 0) + assert checkodesol(diff(sol1.lhs, x)*exp(f(x)), sol1) == (True, 0) + assert checkodesol(diff(sol1.lhs, x, 2), sol1) == (True, 0) + assert checkodesol(diff(sol1.lhs, x, 2)*exp(f(x)), sol1) == (True, 0) + assert checkodesol(diff(sol1.lhs, x, 3), sol1) == (True, 0) + assert checkodesol(diff(sol1.lhs, x, 3)*exp(f(x)), sol1) == (True, 0) + assert checkodesol(diff(sol1.lhs, x, 3), Eq(f(x), x*log(x))) == \ + (False, 60*x**4*((log(x) + 1)**2 + log(x))*( + log(x) + 1)*log(x)**2 - 5*x**4*log(x)**4 - 9) + assert checkodesol(diff(exp(f(x)) + x, x)*x, Eq(exp(f(x)) + x, 0)) == \ + (True, 0) + assert checkodesol(diff(exp(f(x)) + x, x)*x, Eq(exp(f(x)) + x, 0), + solve_for_func=False) == (True, 0) + assert checkodesol(f(x).diff(x, 2), [Eq(f(x), C1 + C2*x), + Eq(f(x), C2 + C1*x), Eq(f(x), C1*x + C2*x**2)]) == \ + [(True, 0), (True, 0), (False, C2)] + assert checkodesol(f(x).diff(x, 2), {Eq(f(x), C1 + C2*x), + Eq(f(x), C2 + C1*x), Eq(f(x), C1*x + C2*x**2)}) == \ + {(True, 0), (True, 0), (False, C2)} + assert checkodesol(f(x).diff(x) - 1/f(x)/2, Eq(f(x)**2, x)) == \ + [(True, 0), (True, 0)] + assert checkodesol(f(x).diff(x) - f(x), Eq(C1*exp(x), f(x))) == (True, 0) + # Based on test_1st_homogeneous_coeff_ode2_eq3sol. Make sure that + # checkodesol tries back substituting f(x) when it can. + eq3 = x*exp(f(x)/x) + f(x) - x*f(x).diff(x) + sol3 = Eq(f(x), log(log(C1/x)**(-x))) + assert not checkodesol(eq3, sol3)[1].has(f(x)) + # This case was failing intermittently depending on hash-seed: + eqn = Eq(Derivative(x*Derivative(f(x), x), x)/x, exp(x)) + sol = Eq(f(x), C1 + C2*log(x) + exp(x) - Ei(x)) + assert checkodesol(eqn, sol, order=2, solve_for_func=False)[0] + eq = x**2*(f(x).diff(x, 2)) + x*(f(x).diff(x)) + (2*x**2 +25)*f(x) + sol = Eq(f(x), C1*besselj(5*I, sqrt(2)*x) + C2*bessely(5*I, sqrt(2)*x)) + assert checkodesol(eq, sol) == (True, 0) + + eqs = [Eq(f(x).diff(x), f(x) + g(x)), Eq(g(x).diff(x), f(x) + g(x))] + sol = [Eq(f(x), -C1 + C2*exp(2*x)), Eq(g(x), C1 + C2*exp(2*x))] + assert checkodesol(eqs, sol) == (True, [0, 0]) + + +def test_checksysodesol(): + x, y, z = symbols('x, y, z', cls=Function) + t = Symbol('t') + eq = (Eq(diff(x(t),t), 9*y(t)), Eq(diff(y(t),t), 12*x(t))) + sol = [Eq(x(t), 9*C1*exp(-6*sqrt(3)*t) + 9*C2*exp(6*sqrt(3)*t)), \ + Eq(y(t), -6*sqrt(3)*C1*exp(-6*sqrt(3)*t) + 6*sqrt(3)*C2*exp(6*sqrt(3)*t))] + assert checksysodesol(eq, sol) == (True, [0, 0]) + + eq = (Eq(diff(x(t),t), 2*x(t) + 4*y(t)), Eq(diff(y(t),t), 12*x(t) + 41*y(t))) + sol = [Eq(x(t), 4*C1*exp(t*(-sqrt(1713)/2 + Rational(43, 2))) + 4*C2*exp(t*(sqrt(1713)/2 + \ + Rational(43, 2)))), Eq(y(t), C1*(-sqrt(1713)/2 + Rational(39, 2))*exp(t*(-sqrt(1713)/2 + \ + Rational(43, 2))) + C2*(Rational(39, 2) + sqrt(1713)/2)*exp(t*(sqrt(1713)/2 + Rational(43, 2))))] + assert checksysodesol(eq, sol) == (True, [0, 0]) + + eq = (Eq(diff(x(t),t), x(t) + y(t)), Eq(diff(y(t),t), -2*x(t) + 2*y(t))) + sol = [Eq(x(t), (C1*sin(sqrt(7)*t/2) + C2*cos(sqrt(7)*t/2))*exp(t*Rational(3, 2))), \ + Eq(y(t), ((C1/2 - sqrt(7)*C2/2)*sin(sqrt(7)*t/2) + (sqrt(7)*C1/2 + \ + C2/2)*cos(sqrt(7)*t/2))*exp(t*Rational(3, 2)))] + assert checksysodesol(eq, sol) == (True, [0, 0]) + + eq = (Eq(diff(x(t),t), x(t) + y(t) + 9), Eq(diff(y(t),t), 2*x(t) + 5*y(t) + 23)) + sol = [Eq(x(t), C1*exp(t*(-sqrt(6) + 3)) + C2*exp(t*(sqrt(6) + 3)) - \ + Rational(22, 3)), Eq(y(t), C1*(-sqrt(6) + 2)*exp(t*(-sqrt(6) + 3)) + C2*(2 + \ + sqrt(6))*exp(t*(sqrt(6) + 3)) - Rational(5, 3))] + assert checksysodesol(eq, sol) == (True, [0, 0]) + + eq = (Eq(diff(x(t),t), x(t) + y(t) + 81), Eq(diff(y(t),t), -2*x(t) + y(t) + 23)) + sol = [Eq(x(t), (C1*sin(sqrt(2)*t) + C2*cos(sqrt(2)*t))*exp(t) - Rational(58, 3)), \ + Eq(y(t), (sqrt(2)*C1*cos(sqrt(2)*t) - sqrt(2)*C2*sin(sqrt(2)*t))*exp(t) - Rational(185, 3))] + assert checksysodesol(eq, sol) == (True, [0, 0]) + + eq = (Eq(diff(x(t),t), 5*t*x(t) + 2*y(t)), Eq(diff(y(t),t), 2*x(t) + 5*t*y(t))) + sol = [Eq(x(t), (C1*exp(Integral(2, t).doit()) + C2*exp(-(Integral(2, t)).doit()))*\ + exp((Integral(5*t, t)).doit())), Eq(y(t), (C1*exp((Integral(2, t)).doit()) - \ + C2*exp(-(Integral(2, t)).doit()))*exp((Integral(5*t, t)).doit()))] + assert checksysodesol(eq, sol) == (True, [0, 0]) + + eq = (Eq(diff(x(t),t), 5*t*x(t) + t**2*y(t)), Eq(diff(y(t),t), -t**2*x(t) + 5*t*y(t))) + sol = [Eq(x(t), (C1*cos((Integral(t**2, t)).doit()) + C2*sin((Integral(t**2, t)).doit()))*\ + exp((Integral(5*t, t)).doit())), Eq(y(t), (-C1*sin((Integral(t**2, t)).doit()) + \ + C2*cos((Integral(t**2, t)).doit()))*exp((Integral(5*t, t)).doit()))] + assert checksysodesol(eq, sol) == (True, [0, 0]) + + eq = (Eq(diff(x(t),t), 5*t*x(t) + t**2*y(t)), Eq(diff(y(t),t), -t**2*x(t) + (5*t+9*t**2)*y(t))) + sol = [Eq(x(t), (C1*exp((-sqrt(77)/2 + Rational(9, 2))*(Integral(t**2, t)).doit()) + \ + C2*exp((sqrt(77)/2 + Rational(9, 2))*(Integral(t**2, t)).doit()))*exp((Integral(5*t, t)).doit())), \ + Eq(y(t), (C1*(-sqrt(77)/2 + Rational(9, 2))*exp((-sqrt(77)/2 + Rational(9, 2))*(Integral(t**2, t)).doit()) + \ + C2*(sqrt(77)/2 + Rational(9, 2))*exp((sqrt(77)/2 + Rational(9, 2))*(Integral(t**2, t)).doit()))*exp((Integral(5*t, t)).doit()))] + assert checksysodesol(eq, sol) == (True, [0, 0]) + + eq = (Eq(diff(x(t),t,t), 5*x(t) + 43*y(t)), Eq(diff(y(t),t,t), x(t) + 9*y(t))) + root0 = -sqrt(-sqrt(47) + 7) + root1 = sqrt(-sqrt(47) + 7) + root2 = -sqrt(sqrt(47) + 7) + root3 = sqrt(sqrt(47) + 7) + sol = [Eq(x(t), 43*C1*exp(t*root0) + 43*C2*exp(t*root1) + 43*C3*exp(t*root2) + 43*C4*exp(t*root3)), \ + Eq(y(t), C1*(root0**2 - 5)*exp(t*root0) + C2*(root1**2 - 5)*exp(t*root1) + \ + C3*(root2**2 - 5)*exp(t*root2) + C4*(root3**2 - 5)*exp(t*root3))] + assert checksysodesol(eq, sol) == (True, [0, 0]) + + eq = (Eq(diff(x(t),t,t), 8*x(t)+3*y(t)+31), Eq(diff(y(t),t,t), 9*x(t)+7*y(t)+12)) + root0 = -sqrt(-sqrt(109)/2 + Rational(15, 2)) + root1 = sqrt(-sqrt(109)/2 + Rational(15, 2)) + root2 = -sqrt(sqrt(109)/2 + Rational(15, 2)) + root3 = sqrt(sqrt(109)/2 + Rational(15, 2)) + sol = [Eq(x(t), 3*C1*exp(t*root0) + 3*C2*exp(t*root1) + 3*C3*exp(t*root2) + 3*C4*exp(t*root3) - Rational(181, 29)), \ + Eq(y(t), C1*(root0**2 - 8)*exp(t*root0) + C2*(root1**2 - 8)*exp(t*root1) + \ + C3*(root2**2 - 8)*exp(t*root2) + C4*(root3**2 - 8)*exp(t*root3) + Rational(183, 29))] + assert checksysodesol(eq, sol) == (True, [0, 0]) + + eq = (Eq(diff(x(t),t,t) - 9*diff(y(t),t) + 7*x(t),0), Eq(diff(y(t),t,t) + 9*diff(x(t),t) + 7*y(t),0)) + sol = [Eq(x(t), C1*cos(t*(Rational(9, 2) + sqrt(109)/2)) + C2*sin(t*(Rational(9, 2) + sqrt(109)/2)) + \ + C3*cos(t*(-sqrt(109)/2 + Rational(9, 2))) + C4*sin(t*(-sqrt(109)/2 + Rational(9, 2)))), Eq(y(t), -C1*sin(t*(Rational(9, 2) + sqrt(109)/2)) \ + + C2*cos(t*(Rational(9, 2) + sqrt(109)/2)) - C3*sin(t*(-sqrt(109)/2 + Rational(9, 2))) + C4*cos(t*(-sqrt(109)/2 + Rational(9, 2))))] + assert checksysodesol(eq, sol) == (True, [0, 0]) + + eq = (Eq(diff(x(t),t,t), 9*t*diff(y(t),t)-9*y(t)), Eq(diff(y(t),t,t),7*t*diff(x(t),t)-7*x(t))) + I1 = sqrt(6)*7**Rational(1, 4)*sqrt(pi)*erfi(sqrt(6)*7**Rational(1, 4)*t/2)/2 - exp(3*sqrt(7)*t**2/2)/t + I2 = -sqrt(6)*7**Rational(1, 4)*sqrt(pi)*erf(sqrt(6)*7**Rational(1, 4)*t/2)/2 - exp(-3*sqrt(7)*t**2/2)/t + sol = [Eq(x(t), C3*t + t*(9*C1*I1 + 9*C2*I2)), Eq(y(t), C4*t + t*(3*sqrt(7)*C1*I1 - 3*sqrt(7)*C2*I2))] + assert checksysodesol(eq, sol) == (True, [0, 0]) + + eq = (Eq(diff(x(t),t), 21*x(t)), Eq(diff(y(t),t), 17*x(t)+3*y(t)), Eq(diff(z(t),t), 5*x(t)+7*y(t)+9*z(t))) + sol = [Eq(x(t), C1*exp(21*t)), Eq(y(t), 17*C1*exp(21*t)/18 + C2*exp(3*t)), \ + Eq(z(t), 209*C1*exp(21*t)/216 - 7*C2*exp(3*t)/6 + C3*exp(9*t))] + assert checksysodesol(eq, sol) == (True, [0, 0, 0]) + + eq = (Eq(diff(x(t),t),3*y(t)-11*z(t)),Eq(diff(y(t),t),7*z(t)-3*x(t)),Eq(diff(z(t),t),11*x(t)-7*y(t))) + sol = [Eq(x(t), 7*C0 + sqrt(179)*C1*cos(sqrt(179)*t) + (77*C1/3 + 130*C2/3)*sin(sqrt(179)*t)), \ + Eq(y(t), 11*C0 + sqrt(179)*C2*cos(sqrt(179)*t) + (-58*C1/3 - 77*C2/3)*sin(sqrt(179)*t)), \ + Eq(z(t), 3*C0 + sqrt(179)*(-7*C1/3 - 11*C2/3)*cos(sqrt(179)*t) + (11*C1 - 7*C2)*sin(sqrt(179)*t))] + assert checksysodesol(eq, sol) == (True, [0, 0, 0]) + + eq = (Eq(3*diff(x(t),t),4*5*(y(t)-z(t))),Eq(4*diff(y(t),t),3*5*(z(t)-x(t))),Eq(5*diff(z(t),t),3*4*(x(t)-y(t)))) + sol = [Eq(x(t), C0 + 5*sqrt(2)*C1*cos(5*sqrt(2)*t) + (12*C1/5 + 164*C2/15)*sin(5*sqrt(2)*t)), \ + Eq(y(t), C0 + 5*sqrt(2)*C2*cos(5*sqrt(2)*t) + (-51*C1/10 - 12*C2/5)*sin(5*sqrt(2)*t)), \ + Eq(z(t), C0 + 5*sqrt(2)*(-9*C1/25 - 16*C2/25)*cos(5*sqrt(2)*t) + (12*C1/5 - 12*C2/5)*sin(5*sqrt(2)*t))] + assert checksysodesol(eq, sol) == (True, [0, 0, 0]) + + eq = (Eq(diff(x(t),t),4*x(t) - z(t)),Eq(diff(y(t),t),2*x(t)+2*y(t)-z(t)),Eq(diff(z(t),t),3*x(t)+y(t))) + sol = [Eq(x(t), C1*exp(2*t) + C2*t*exp(2*t) + C2*exp(2*t) + C3*t**2*exp(2*t)/2 + C3*t*exp(2*t) + C3*exp(2*t)), \ + Eq(y(t), C1*exp(2*t) + C2*t*exp(2*t) + C2*exp(2*t) + C3*t**2*exp(2*t)/2 + C3*t*exp(2*t)), \ + Eq(z(t), 2*C1*exp(2*t) + 2*C2*t*exp(2*t) + C2*exp(2*t) + C3*t**2*exp(2*t) + C3*t*exp(2*t) + C3*exp(2*t))] + assert checksysodesol(eq, sol) == (True, [0, 0, 0]) + + eq = (Eq(diff(x(t),t),4*x(t) - y(t) - 2*z(t)),Eq(diff(y(t),t),2*x(t) + y(t)- 2*z(t)),Eq(diff(z(t),t),5*x(t)-3*z(t))) + sol = [Eq(x(t), C1*exp(2*t) + C2*(-sin(t) + 3*cos(t)) + C3*(3*sin(t) + cos(t))), \ + Eq(y(t), C2*(-sin(t) + 3*cos(t)) + C3*(3*sin(t) + cos(t))), Eq(z(t), C1*exp(2*t) + 5*C2*cos(t) + 5*C3*sin(t))] + assert checksysodesol(eq, sol) == (True, [0, 0, 0]) + + eq = (Eq(diff(x(t),t),x(t)*y(t)**3), Eq(diff(y(t),t),y(t)**5)) + sol = [Eq(x(t), C1*exp((-1/(4*C2 + 4*t))**(Rational(-1, 4)))), Eq(y(t), -(-1/(4*C2 + 4*t))**Rational(1, 4)), \ + Eq(x(t), C1*exp(-1/(-1/(4*C2 + 4*t))**Rational(1, 4))), Eq(y(t), (-1/(4*C2 + 4*t))**Rational(1, 4)), \ + Eq(x(t), C1*exp(-I/(-1/(4*C2 + 4*t))**Rational(1, 4))), Eq(y(t), -I*(-1/(4*C2 + 4*t))**Rational(1, 4)), \ + Eq(x(t), C1*exp(I/(-1/(4*C2 + 4*t))**Rational(1, 4))), Eq(y(t), I*(-1/(4*C2 + 4*t))**Rational(1, 4))] + assert checksysodesol(eq, sol) == (True, [0, 0]) + + eq = (Eq(diff(x(t),t), exp(3*x(t))*y(t)**3),Eq(diff(y(t),t), y(t)**5)) + sol = [Eq(x(t), -log(C1 - 3/(-1/(4*C2 + 4*t))**Rational(1, 4))/3), Eq(y(t), -(-1/(4*C2 + 4*t))**Rational(1, 4)), \ + Eq(x(t), -log(C1 + 3/(-1/(4*C2 + 4*t))**Rational(1, 4))/3), Eq(y(t), (-1/(4*C2 + 4*t))**Rational(1, 4)), \ + Eq(x(t), -log(C1 + 3*I/(-1/(4*C2 + 4*t))**Rational(1, 4))/3), Eq(y(t), -I*(-1/(4*C2 + 4*t))**Rational(1, 4)), \ + Eq(x(t), -log(C1 - 3*I/(-1/(4*C2 + 4*t))**Rational(1, 4))/3), Eq(y(t), I*(-1/(4*C2 + 4*t))**Rational(1, 4))] + assert checksysodesol(eq, sol) == (True, [0, 0]) + + eq = (Eq(x(t),t*diff(x(t),t)+diff(x(t),t)*diff(y(t),t)), Eq(y(t),t*diff(y(t),t)+diff(y(t),t)**2)) + sol = {Eq(x(t), C1*C2 + C1*t), Eq(y(t), C2**2 + C2*t)} + assert checksysodesol(eq, sol) == (True, [0, 0]) diff --git a/lib/python3.10/site-packages/sympy/solvers/ode/tests/test_systems.py b/lib/python3.10/site-packages/sympy/solvers/ode/tests/test_systems.py new file mode 100644 index 0000000000000000000000000000000000000000..9d206129dfcf38c7b8c2e0ab42bd875003253f35 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/solvers/ode/tests/test_systems.py @@ -0,0 +1,2544 @@ +from sympy.core.function import (Derivative, Function, diff) +from sympy.core.mul import Mul +from sympy.core.numbers import (I, Rational, pi) +from sympy.core.relational import Eq +from sympy.core.singleton import S +from sympy.core.symbol import (Symbol, symbols) +from sympy.functions.elementary.hyperbolic import sinh +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.matrices.dense import Matrix +from sympy.core.containers import Tuple +from sympy.functions import exp, cos, sin, log, Ci, Si, erf, erfi +from sympy.matrices import dotprodsimp, NonSquareMatrixError +from sympy.solvers.ode import dsolve +from sympy.solvers.ode.ode import constant_renumber +from sympy.solvers.ode.subscheck import checksysodesol +from sympy.solvers.ode.systems import (_classify_linear_system, linear_ode_to_matrix, + ODEOrderError, ODENonlinearError, _simpsol, + _is_commutative_anti_derivative, linodesolve, + canonical_odes, dsolve_system, _component_division, + _eqs2dict, _dict2graph) +from sympy.functions import airyai, airybi +from sympy.integrals.integrals import Integral +from sympy.simplify.ratsimp import ratsimp +from sympy.testing.pytest import raises, slow, tooslow, XFAIL + + +C0, C1, C2, C3, C4, C5, C6, C7, C8, C9, C10 = symbols('C0:11') +x = symbols('x') +f = Function('f') +g = Function('g') +h = Function('h') + + +def test_linear_ode_to_matrix(): + f, g, h = symbols("f, g, h", cls=Function) + t = Symbol("t") + funcs = [f(t), g(t), h(t)] + f1 = f(t).diff(t) + g1 = g(t).diff(t) + h1 = h(t).diff(t) + f2 = f(t).diff(t, 2) + g2 = g(t).diff(t, 2) + h2 = h(t).diff(t, 2) + + eqs_1 = [Eq(f1, g(t)), Eq(g1, f(t))] + sol_1 = ([Matrix([[1, 0], [0, 1]]), Matrix([[ 0, 1], [1, 0]])], Matrix([[0],[0]])) + assert linear_ode_to_matrix(eqs_1, funcs[:-1], t, 1) == sol_1 + + eqs_2 = [Eq(f1, f(t) + 2*g(t)), Eq(g1, h(t)), Eq(h1, g(t) + h(t) + f(t))] + sol_2 = ([Matrix([[1, 0, 0], [0, 1, 0], [0, 0, 1]]), Matrix([[1, 2, 0], [ 0, 0, 1], [1, 1, 1]])], + Matrix([[0], [0], [0]])) + assert linear_ode_to_matrix(eqs_2, funcs, t, 1) == sol_2 + + eqs_3 = [Eq(2*f1 + 3*h1, f(t) + g(t)), Eq(4*h1 + 5*g1, f(t) + h(t)), Eq(5*f1 + 4*g1, g(t) + h(t))] + sol_3 = ([Matrix([[2, 0, 3], [0, 5, 4], [5, 4, 0]]), Matrix([[1, 1, 0], [1, 0, 1], [0, 1, 1]])], + Matrix([[0], [0], [0]])) + assert linear_ode_to_matrix(eqs_3, funcs, t, 1) == sol_3 + + eqs_4 = [Eq(f2 + h(t), f1 + g(t)), Eq(2*h2 + g2 + g1 + g(t), 0), Eq(3*h1, 4)] + sol_4 = ([Matrix([[1, 0, 0], [0, 1, 2], [0, 0, 0]]), Matrix([[1, 0, 0], [0, -1, 0], [0, 0, -3]]), + Matrix([[0, 1, -1], [0, -1, 0], [0, 0, 0]])], Matrix([[0], [0], [4]])) + assert linear_ode_to_matrix(eqs_4, funcs, t, 2) == sol_4 + + eqs_5 = [Eq(f2, g(t)), Eq(f1 + g1, f(t))] + raises(ODEOrderError, lambda: linear_ode_to_matrix(eqs_5, funcs[:-1], t, 1)) + + eqs_6 = [Eq(f1, f(t)**2), Eq(g1, f(t) + g(t))] + raises(ODENonlinearError, lambda: linear_ode_to_matrix(eqs_6, funcs[:-1], t, 1)) + + +def test__classify_linear_system(): + x, y, z, w = symbols('x, y, z, w', cls=Function) + t, k, l = symbols('t k l') + x1 = diff(x(t), t) + y1 = diff(y(t), t) + z1 = diff(z(t), t) + w1 = diff(w(t), t) + x2 = diff(x(t), t, t) + y2 = diff(y(t), t, t) + funcs = [x(t), y(t)] + funcs_2 = funcs + [z(t), w(t)] + + eqs_1 = (5 * x1 + 12 * x(t) - 6 * (y(t)), (2 * y1 - 11 * t * x(t) + 3 * y(t) + t)) + assert _classify_linear_system(eqs_1, funcs, t) is None + + eqs_2 = (5 * (x1**2) + 12 * x(t) - 6 * (y(t)), (2 * y1 - 11 * t * x(t) + 3 * y(t) + t)) + sol2 = {'is_implicit': True, + 'canon_eqs': [[Eq(Derivative(x(t), t), -sqrt(-12*x(t)/5 + 6*y(t)/5)), + Eq(Derivative(y(t), t), 11*t*x(t)/2 - t/2 - 3*y(t)/2)], + [Eq(Derivative(x(t), t), sqrt(-12*x(t)/5 + 6*y(t)/5)), + Eq(Derivative(y(t), t), 11*t*x(t)/2 - t/2 - 3*y(t)/2)]]} + assert _classify_linear_system(eqs_2, funcs, t) == sol2 + + eqs_2_1 = [Eq(Derivative(x(t), t), -sqrt(-12*x(t)/5 + 6*y(t)/5)), + Eq(Derivative(y(t), t), 11*t*x(t)/2 - t/2 - 3*y(t)/2)] + assert _classify_linear_system(eqs_2_1, funcs, t) is None + + eqs_2_2 = [Eq(Derivative(x(t), t), sqrt(-12*x(t)/5 + 6*y(t)/5)), + Eq(Derivative(y(t), t), 11*t*x(t)/2 - t/2 - 3*y(t)/2)] + assert _classify_linear_system(eqs_2_2, funcs, t) is None + + eqs_3 = (5 * x1 + 12 * x(t) - 6 * (y(t)), (2 * y1 - 11 * x(t) + 3 * y(t)), (5 * w1 + z(t)), (z1 + w(t))) + answer_3 = {'no_of_equation': 4, + 'eq': (12*x(t) - 6*y(t) + 5*Derivative(x(t), t), + -11*x(t) + 3*y(t) + 2*Derivative(y(t), t), + z(t) + 5*Derivative(w(t), t), + w(t) + Derivative(z(t), t)), + 'func': [x(t), y(t), z(t), w(t)], + 'order': {x(t): 1, y(t): 1, z(t): 1, w(t): 1}, + 'is_linear': True, + 'is_constant': True, + 'is_homogeneous': True, + 'func_coeff': -Matrix([ + [Rational(12, 5), Rational(-6, 5), 0, 0], + [Rational(-11, 2), Rational(3, 2), 0, 0], + [0, 0, 0, 1], + [0, 0, Rational(1, 5), 0]]), + 'type_of_equation': 'type1', + 'is_general': True} + assert _classify_linear_system(eqs_3, funcs_2, t) == answer_3 + + eqs_4 = (5 * x1 + 12 * x(t) - 6 * (y(t)), (2 * y1 - 11 * x(t) + 3 * y(t)), (z1 - w(t)), (w1 - z(t))) + answer_4 = {'no_of_equation': 4, + 'eq': (12 * x(t) - 6 * y(t) + 5 * Derivative(x(t), t), + -11 * x(t) + 3 * y(t) + 2 * Derivative(y(t), t), + -w(t) + Derivative(z(t), t), + -z(t) + Derivative(w(t), t)), + 'func': [x(t), y(t), z(t), w(t)], + 'order': {x(t): 1, y(t): 1, z(t): 1, w(t): 1}, + 'is_linear': True, + 'is_constant': True, + 'is_homogeneous': True, + 'func_coeff': -Matrix([ + [Rational(12, 5), Rational(-6, 5), 0, 0], + [Rational(-11, 2), Rational(3, 2), 0, 0], + [0, 0, 0, -1], + [0, 0, -1, 0]]), + 'type_of_equation': 'type1', + 'is_general': True} + assert _classify_linear_system(eqs_4, funcs_2, t) == answer_4 + + eqs_5 = (5*x1 + 12*x(t) - 6*(y(t)) + x2, (2*y1 - 11*x(t) + 3*y(t)), (z1 - w(t)), (w1 - z(t))) + answer_5 = {'no_of_equation': 4, 'eq': (12*x(t) - 6*y(t) + 5*Derivative(x(t), t) + Derivative(x(t), (t, 2)), + -11*x(t) + 3*y(t) + 2*Derivative(y(t), t), -w(t) + Derivative(z(t), t), -z(t) + Derivative(w(t), + t)), 'func': [x(t), y(t), z(t), w(t)], 'order': {x(t): 2, y(t): 1, z(t): 1, w(t): 1}, 'is_linear': + True, 'is_homogeneous': True, 'is_general': True, 'type_of_equation': 'type0', 'is_higher_order': True} + assert _classify_linear_system(eqs_5, funcs_2, t) == answer_5 + + eqs_6 = (Eq(x1, 3*y(t) - 11*z(t)), Eq(y1, 7*z(t) - 3*x(t)), Eq(z1, 11*x(t) - 7*y(t))) + answer_6 = {'no_of_equation': 3, 'eq': (Eq(Derivative(x(t), t), 3*y(t) - 11*z(t)), Eq(Derivative(y(t), t), -3*x(t) + 7*z(t)), + Eq(Derivative(z(t), t), 11*x(t) - 7*y(t))), 'func': [x(t), y(t), z(t)], 'order': {x(t): 1, y(t): 1, z(t): 1}, + 'is_linear': True, 'is_constant': True, 'is_homogeneous': True, + 'func_coeff': -Matrix([ + [ 0, -3, 11], + [ 3, 0, -7], + [-11, 7, 0]]), + 'type_of_equation': 'type1', 'is_general': True} + + assert _classify_linear_system(eqs_6, funcs_2[:-1], t) == answer_6 + + eqs_7 = (Eq(x1, y(t)), Eq(y1, x(t))) + answer_7 = {'no_of_equation': 2, 'eq': (Eq(Derivative(x(t), t), y(t)), Eq(Derivative(y(t), t), x(t))), + 'func': [x(t), y(t)], 'order': {x(t): 1, y(t): 1}, 'is_linear': True, 'is_constant': True, + 'is_homogeneous': True, 'func_coeff': -Matrix([ + [ 0, -1], + [-1, 0]]), + 'type_of_equation': 'type1', 'is_general': True} + assert _classify_linear_system(eqs_7, funcs, t) == answer_7 + + eqs_8 = (Eq(x1, 21*x(t)), Eq(y1, 17*x(t) + 3*y(t)), Eq(z1, 5*x(t) + 7*y(t) + 9*z(t))) + answer_8 = {'no_of_equation': 3, 'eq': (Eq(Derivative(x(t), t), 21*x(t)), Eq(Derivative(y(t), t), 17*x(t) + 3*y(t)), + Eq(Derivative(z(t), t), 5*x(t) + 7*y(t) + 9*z(t))), 'func': [x(t), y(t), z(t)], 'order': {x(t): 1, y(t): 1, z(t): 1}, + 'is_linear': True, 'is_constant': True, 'is_homogeneous': True, + 'func_coeff': -Matrix([ + [-21, 0, 0], + [-17, -3, 0], + [ -5, -7, -9]]), + 'type_of_equation': 'type1', 'is_general': True} + + assert _classify_linear_system(eqs_8, funcs_2[:-1], t) == answer_8 + + eqs_9 = (Eq(x1, 4*x(t) + 5*y(t) + 2*z(t)), Eq(y1, x(t) + 13*y(t) + 9*z(t)), Eq(z1, 32*x(t) + 41*y(t) + 11*z(t))) + answer_9 = {'no_of_equation': 3, 'eq': (Eq(Derivative(x(t), t), 4*x(t) + 5*y(t) + 2*z(t)), + Eq(Derivative(y(t), t), x(t) + 13*y(t) + 9*z(t)), Eq(Derivative(z(t), t), 32*x(t) + 41*y(t) + 11*z(t))), + 'func': [x(t), y(t), z(t)], 'order': {x(t): 1, y(t): 1, z(t): 1}, 'is_linear': True, + 'is_constant': True, 'is_homogeneous': True, + 'func_coeff': -Matrix([ + [ -4, -5, -2], + [ -1, -13, -9], + [-32, -41, -11]]), + 'type_of_equation': 'type1', 'is_general': True} + assert _classify_linear_system(eqs_9, funcs_2[:-1], t) == answer_9 + + eqs_10 = (Eq(3*x1, 4*5*(y(t) - z(t))), Eq(4*y1, 3*5*(z(t) - x(t))), Eq(5*z1, 3*4*(x(t) - y(t)))) + answer_10 = {'no_of_equation': 3, 'eq': (Eq(3*Derivative(x(t), t), 20*y(t) - 20*z(t)), + Eq(4*Derivative(y(t), t), -15*x(t) + 15*z(t)), Eq(5*Derivative(z(t), t), 12*x(t) - 12*y(t))), + 'func': [x(t), y(t), z(t)], 'order': {x(t): 1, y(t): 1, z(t): 1}, 'is_linear': True, + 'is_constant': True, 'is_homogeneous': True, + 'func_coeff': -Matrix([ + [ 0, Rational(-20, 3), Rational(20, 3)], + [Rational(15, 4), 0, Rational(-15, 4)], + [Rational(-12, 5), Rational(12, 5), 0]]), + 'type_of_equation': 'type1', 'is_general': True} + assert _classify_linear_system(eqs_10, funcs_2[:-1], t) == answer_10 + + eq11 = (Eq(x1, 3*y(t) - 11*z(t)), Eq(y1, 7*z(t) - 3*x(t)), Eq(z1, 11*x(t) - 7*y(t))) + sol11 = {'no_of_equation': 3, 'eq': (Eq(Derivative(x(t), t), 3*y(t) - 11*z(t)), Eq(Derivative(y(t), t), -3*x(t) + 7*z(t)), + Eq(Derivative(z(t), t), 11*x(t) - 7*y(t))), 'func': [x(t), y(t), z(t)], 'order': {x(t): 1, y(t): 1, z(t): 1}, + 'is_linear': True, 'is_constant': True, 'is_homogeneous': True, 'func_coeff': -Matrix([ + [ 0, -3, 11], [ 3, 0, -7], [-11, 7, 0]]), 'type_of_equation': 'type1', 'is_general': True} + assert _classify_linear_system(eq11, funcs_2[:-1], t) == sol11 + + eq12 = (Eq(Derivative(x(t), t), y(t)), Eq(Derivative(y(t), t), x(t))) + sol12 = {'no_of_equation': 2, 'eq': (Eq(Derivative(x(t), t), y(t)), Eq(Derivative(y(t), t), x(t))), + 'func': [x(t), y(t)], 'order': {x(t): 1, y(t): 1}, 'is_linear': True, 'is_constant': True, + 'is_homogeneous': True, 'func_coeff': -Matrix([ + [0, -1], + [-1, 0]]), 'type_of_equation': 'type1', 'is_general': True} + assert _classify_linear_system(eq12, [x(t), y(t)], t) == sol12 + + eq13 = (Eq(Derivative(x(t), t), 21*x(t)), Eq(Derivative(y(t), t), 17*x(t) + 3*y(t)), + Eq(Derivative(z(t), t), 5*x(t) + 7*y(t) + 9*z(t))) + sol13 = {'no_of_equation': 3, 'eq': ( + Eq(Derivative(x(t), t), 21 * x(t)), Eq(Derivative(y(t), t), 17 * x(t) + 3 * y(t)), + Eq(Derivative(z(t), t), 5 * x(t) + 7 * y(t) + 9 * z(t))), 'func': [x(t), y(t), z(t)], + 'order': {x(t): 1, y(t): 1, z(t): 1}, 'is_linear': True, 'is_constant': True, 'is_homogeneous': True, + 'func_coeff': -Matrix([ + [-21, 0, 0], + [-17, -3, 0], + [-5, -7, -9]]), 'type_of_equation': 'type1', 'is_general': True} + assert _classify_linear_system(eq13, [x(t), y(t), z(t)], t) == sol13 + + eq14 = ( + Eq(Derivative(x(t), t), 4*x(t) + 5*y(t) + 2*z(t)), Eq(Derivative(y(t), t), x(t) + 13*y(t) + 9*z(t)), + Eq(Derivative(z(t), t), 32*x(t) + 41*y(t) + 11*z(t))) + sol14 = {'no_of_equation': 3, 'eq': ( + Eq(Derivative(x(t), t), 4 * x(t) + 5 * y(t) + 2 * z(t)), Eq(Derivative(y(t), t), x(t) + 13 * y(t) + 9 * z(t)), + Eq(Derivative(z(t), t), 32 * x(t) + 41 * y(t) + 11 * z(t))), 'func': [x(t), y(t), z(t)], + 'order': {x(t): 1, y(t): 1, z(t): 1}, 'is_linear': True, 'is_constant': True, 'is_homogeneous': True, + 'func_coeff': -Matrix([ + [-4, -5, -2], + [-1, -13, -9], + [-32, -41, -11]]), 'type_of_equation': 'type1', 'is_general': True} + assert _classify_linear_system(eq14, [x(t), y(t), z(t)], t) == sol14 + + eq15 = (Eq(3*Derivative(x(t), t), 20*y(t) - 20*z(t)), Eq(4*Derivative(y(t), t), -15*x(t) + 15*z(t)), + Eq(5*Derivative(z(t), t), 12*x(t) - 12*y(t))) + sol15 = {'no_of_equation': 3, 'eq': ( + Eq(3 * Derivative(x(t), t), 20 * y(t) - 20 * z(t)), Eq(4 * Derivative(y(t), t), -15 * x(t) + 15 * z(t)), + Eq(5 * Derivative(z(t), t), 12 * x(t) - 12 * y(t))), 'func': [x(t), y(t), z(t)], + 'order': {x(t): 1, y(t): 1, z(t): 1}, 'is_linear': True, 'is_constant': True, 'is_homogeneous': True, + 'func_coeff': -Matrix([ + [0, Rational(-20, 3), Rational(20, 3)], + [Rational(15, 4), 0, Rational(-15, 4)], + [Rational(-12, 5), Rational(12, 5), 0]]), 'type_of_equation': 'type1', 'is_general': True} + assert _classify_linear_system(eq15, [x(t), y(t), z(t)], t) == sol15 + + # Constant coefficient homogeneous ODEs + eq1 = (Eq(diff(x(t), t), x(t) + y(t) + 9), Eq(diff(y(t), t), 2*x(t) + 5*y(t) + 23)) + sol1 = {'no_of_equation': 2, 'eq': (Eq(Derivative(x(t), t), x(t) + y(t) + 9), + Eq(Derivative(y(t), t), 2*x(t) + 5*y(t) + 23)), 'func': [x(t), y(t)], + 'order': {x(t): 1, y(t): 1}, 'is_linear': True, 'is_constant': True, 'is_homogeneous': False, 'is_general': True, + 'func_coeff': -Matrix([[-1, -1], [-2, -5]]), 'rhs': Matrix([[ 9], [23]]), 'type_of_equation': 'type2'} + assert _classify_linear_system(eq1, funcs, t) == sol1 + + # Non constant coefficient homogeneous ODEs + eq1 = (Eq(diff(x(t), t), 5*t*x(t) + 2*y(t)), Eq(diff(y(t), t), 2*x(t) + 5*t*y(t))) + sol1 = {'no_of_equation': 2, 'eq': (Eq(Derivative(x(t), t), 5*t*x(t) + 2*y(t)), Eq(Derivative(y(t), t), 5*t*y(t) + 2*x(t))), + 'func': [x(t), y(t)], 'order': {x(t): 1, y(t): 1}, 'is_linear': True, 'is_constant': False, + 'is_homogeneous': True, 'func_coeff': -Matrix([ [-5*t, -2], [ -2, -5*t]]), 'commutative_antiderivative': Matrix([ + [5*t**2/2, 2*t], [ 2*t, 5*t**2/2]]), 'type_of_equation': 'type3', 'is_general': True} + assert _classify_linear_system(eq1, funcs, t) == sol1 + + # Non constant coefficient non-homogeneous ODEs + eq1 = [Eq(x1, x(t) + t*y(t) + t), Eq(y1, t*x(t) + y(t))] + sol1 = {'no_of_equation': 2, 'eq': [Eq(Derivative(x(t), t), t*y(t) + t + x(t)), Eq(Derivative(y(t), t), + t*x(t) + y(t))], 'func': [x(t), y(t)], 'order': {x(t): 1, y(t): 1}, 'is_linear': True, + 'is_constant': False, 'is_homogeneous': False, 'is_general': True, 'func_coeff': -Matrix([ [-1, -t], + [-t, -1]]), 'commutative_antiderivative': Matrix([ [ t, t**2/2], [t**2/2, t]]), 'rhs': + Matrix([ [t], [0]]), 'type_of_equation': 'type4'} + assert _classify_linear_system(eq1, funcs, t) == sol1 + + eq2 = [Eq(x1, t*x(t) + t*y(t) + t), Eq(y1, t*x(t) + t*y(t) + cos(t))] + sol2 = {'no_of_equation': 2, 'eq': [Eq(Derivative(x(t), t), t*x(t) + t*y(t) + t), Eq(Derivative(y(t), t), + t*x(t) + t*y(t) + cos(t))], 'func': [x(t), y(t)], 'order': {x(t): 1, y(t): 1}, 'is_linear': True, + 'is_homogeneous': False, 'is_general': True, 'rhs': Matrix([ [ t], [cos(t)]]), 'func_coeff': + Matrix([ [t, t], [t, t]]), 'is_constant': False, 'type_of_equation': 'type4', + 'commutative_antiderivative': Matrix([ [t**2/2, t**2/2], [t**2/2, t**2/2]])} + assert _classify_linear_system(eq2, funcs, t) == sol2 + + eq3 = [Eq(x1, t*(x(t) + y(t) + z(t) + 1)), Eq(y1, t*(x(t) + y(t) + z(t))), Eq(z1, t*(x(t) + y(t) + z(t)))] + sol3 = {'no_of_equation': 3, 'eq': [Eq(Derivative(x(t), t), t*(x(t) + y(t) + z(t) + 1)), + Eq(Derivative(y(t), t), t*(x(t) + y(t) + z(t))), Eq(Derivative(z(t), t), t*(x(t) + y(t) + z(t)))], + 'func': [x(t), y(t), z(t)], 'order': {x(t): 1, y(t): 1, z(t): 1}, 'is_linear': True, 'is_constant': + False, 'is_homogeneous': False, 'is_general': True, 'func_coeff': -Matrix([ [-t, -t, -t], [-t, -t, + -t], [-t, -t, -t]]), 'commutative_antiderivative': Matrix([ [t**2/2, t**2/2, t**2/2], [t**2/2, + t**2/2, t**2/2], [t**2/2, t**2/2, t**2/2]]), 'rhs': Matrix([ [t], [0], [0]]), 'type_of_equation': + 'type4'} + assert _classify_linear_system(eq3, funcs_2[:-1], t) == sol3 + + eq4 = [Eq(x1, x(t) + y(t) + t*z(t) + 1), Eq(y1, x(t) + t*y(t) + z(t) + 10), Eq(z1, t*x(t) + y(t) + z(t) + t)] + sol4 = {'no_of_equation': 3, 'eq': [Eq(Derivative(x(t), t), t*z(t) + x(t) + y(t) + 1), Eq(Derivative(y(t), + t), t*y(t) + x(t) + z(t) + 10), Eq(Derivative(z(t), t), t*x(t) + t + y(t) + z(t))], 'func': [x(t), + y(t), z(t)], 'order': {x(t): 1, y(t): 1, z(t): 1}, 'is_linear': True, 'is_constant': False, + 'is_homogeneous': False, 'is_general': True, 'func_coeff': -Matrix([ [-1, -1, -t], [-1, -t, -1], [-t, + -1, -1]]), 'commutative_antiderivative': Matrix([ [ t, t, t**2/2], [ t, t**2/2, + t], [t**2/2, t, t]]), 'rhs': Matrix([ [ 1], [10], [ t]]), 'type_of_equation': 'type4'} + assert _classify_linear_system(eq4, funcs_2[:-1], t) == sol4 + + sum_terms = t*(x(t) + y(t) + z(t) + w(t)) + eq5 = [Eq(x1, sum_terms), Eq(y1, sum_terms), Eq(z1, sum_terms + 1), Eq(w1, sum_terms)] + sol5 = {'no_of_equation': 4, 'eq': [Eq(Derivative(x(t), t), t*(w(t) + x(t) + y(t) + z(t))), + Eq(Derivative(y(t), t), t*(w(t) + x(t) + y(t) + z(t))), Eq(Derivative(z(t), t), t*(w(t) + x(t) + + y(t) + z(t)) + 1), Eq(Derivative(w(t), t), t*(w(t) + x(t) + y(t) + z(t)))], 'func': [x(t), y(t), + z(t), w(t)], 'order': {x(t): 1, y(t): 1, z(t): 1, w(t): 1}, 'is_linear': True, 'is_constant': False, + 'is_homogeneous': False, 'is_general': True, 'func_coeff': -Matrix([ [-t, -t, -t, -t], [-t, -t, -t, + -t], [-t, -t, -t, -t], [-t, -t, -t, -t]]), 'commutative_antiderivative': Matrix([ [t**2/2, t**2/2, + t**2/2, t**2/2], [t**2/2, t**2/2, t**2/2, t**2/2], [t**2/2, t**2/2, t**2/2, t**2/2], [t**2/2, + t**2/2, t**2/2, t**2/2]]), 'rhs': Matrix([ [0], [0], [1], [0]]), 'type_of_equation': 'type4'} + assert _classify_linear_system(eq5, funcs_2, t) == sol5 + + # Second Order + t_ = symbols("t_") + + eq1 = (Eq(9*x(t) + 7*y(t) + 4*Derivative(x(t), t) + Derivative(x(t), (t, 2)) + 3*Derivative(y(t), t), 11*exp(I*t)), + Eq(3*x(t) + 12*y(t) + 5*Derivative(x(t), t) + 8*Derivative(y(t), t) + Derivative(y(t), (t, 2)), 2*exp(I*t))) + sol1 = {'no_of_equation': 2, 'eq': (Eq(9*x(t) + 7*y(t) + 4*Derivative(x(t), t) + Derivative(x(t), (t, 2)) + + 3*Derivative(y(t), t), 11*exp(I*t)), Eq(3*x(t) + 12*y(t) + 5*Derivative(x(t), t) + + 8*Derivative(y(t), t) + Derivative(y(t), (t, 2)), 2*exp(I*t))), 'func': [x(t), y(t)], 'order': + {x(t): 2, y(t): 2}, 'is_linear': True, 'is_homogeneous': False, 'is_general': True, 'rhs': Matrix([ + [11*exp(I*t)], [ 2*exp(I*t)]]), 'type_of_equation': 'type0', 'is_second_order': True, + 'is_higher_order': True} + assert _classify_linear_system(eq1, funcs, t) == sol1 + + eq2 = (Eq((4*t**2 + 7*t + 1)**2*Derivative(x(t), (t, 2)), 5*x(t) + 35*y(t)), + Eq((4*t**2 + 7*t + 1)**2*Derivative(y(t), (t, 2)), x(t) + 9*y(t))) + sol2 = {'no_of_equation': 2, 'eq': (Eq((4*t**2 + 7*t + 1)**2*Derivative(x(t), (t, 2)), 5*x(t) + 35*y(t)), + Eq((4*t**2 + 7*t + 1)**2*Derivative(y(t), (t, 2)), x(t) + 9*y(t))), 'func': [x(t), y(t)], 'order': + {x(t): 2, y(t): 2}, 'is_linear': True, 'is_homogeneous': True, 'is_general': True, + 'type_of_equation': 'type2', 'A0': Matrix([ [Rational(53, 4), 35], [ 1, Rational(69, 4)]]), 'g(t)': sqrt(4*t**2 + 7*t + + 1), 'tau': sqrt(33)*log(t - sqrt(33)/8 + Rational(7, 8))/33 - sqrt(33)*log(t + sqrt(33)/8 + Rational(7, 8))/33, + 'is_transformed': True, 't_': t_, 'is_second_order': True, 'is_higher_order': True} + assert _classify_linear_system(eq2, funcs, t) == sol2 + + eq3 = ((t*Derivative(x(t), t) - x(t))*log(t) + (t*Derivative(y(t), t) - y(t))*exp(t) + Derivative(x(t), (t, 2)), + t**2*(t*Derivative(x(t), t) - x(t)) + t*(t*Derivative(y(t), t) - y(t)) + Derivative(y(t), (t, 2))) + sol3 = {'no_of_equation': 2, 'eq': ((t*Derivative(x(t), t) - x(t))*log(t) + (t*Derivative(y(t), t) - + y(t))*exp(t) + Derivative(x(t), (t, 2)), t**2*(t*Derivative(x(t), t) - x(t)) + t*(t*Derivative(y(t), + t) - y(t)) + Derivative(y(t), (t, 2))), 'func': [x(t), y(t)], 'order': {x(t): 2, y(t): 2}, + 'is_linear': True, 'is_homogeneous': True, 'is_general': True, 'type_of_equation': 'type1', 'A1': + Matrix([ [-t*log(t), -t*exp(t)], [ -t**3, -t**2]]), 'is_second_order': True, + 'is_higher_order': True} + assert _classify_linear_system(eq3, funcs, t) == sol3 + + eq4 = (Eq(x2, k*x(t) - l*y1), Eq(y2, l*x1 + k*y(t))) + sol4 = {'no_of_equation': 2, 'eq': (Eq(Derivative(x(t), (t, 2)), k*x(t) - l*Derivative(y(t), t)), + Eq(Derivative(y(t), (t, 2)), k*y(t) + l*Derivative(x(t), t))), 'func': [x(t), y(t)], 'order': {x(t): + 2, y(t): 2}, 'is_linear': True, 'is_homogeneous': True, 'is_general': True, 'type_of_equation': + 'type0', 'is_second_order': True, 'is_higher_order': True} + assert _classify_linear_system(eq4, funcs, t) == sol4 + + + # Multiple matches + + f, g = symbols("f g", cls=Function) + y, t_ = symbols("y t_") + funcs = [f(t), g(t)] + + eq1 = [Eq(Derivative(f(t), t)**2 - 2*Derivative(f(t), t) + 1, 4), + Eq(-y*f(t) + Derivative(g(t), t), 0)] + sol1 = {'is_implicit': True, + 'canon_eqs': [[Eq(Derivative(f(t), t), -1), Eq(Derivative(g(t), t), y*f(t))], + [Eq(Derivative(f(t), t), 3), Eq(Derivative(g(t), t), y*f(t))]]} + assert _classify_linear_system(eq1, funcs, t) == sol1 + + raises(ValueError, lambda: _classify_linear_system(eq1, funcs[:1], t)) + + eq2 = [Eq(Derivative(f(t), t), (2*f(t) + g(t) + 1)/t), Eq(Derivative(g(t), t), (f(t) + 2*g(t))/t)] + sol2 = {'no_of_equation': 2, 'eq': [Eq(Derivative(f(t), t), (2*f(t) + g(t) + 1)/t), Eq(Derivative(g(t), t), + (f(t) + 2*g(t))/t)], 'func': [f(t), g(t)], 'order': {f(t): 1, g(t): 1}, 'is_linear': True, + 'is_homogeneous': False, 'is_general': True, 'rhs': Matrix([ [1], [0]]), 'func_coeff': Matrix([ [2, + 1], [1, 2]]), 'is_constant': False, 'type_of_equation': 'type6', 't_': t_, 'tau': log(t), + 'commutative_antiderivative': Matrix([ [2*log(t), log(t)], [ log(t), 2*log(t)]])} + assert _classify_linear_system(eq2, funcs, t) == sol2 + + eq3 = [Eq(Derivative(f(t), t), (2*f(t) + g(t))/t), Eq(Derivative(g(t), t), (f(t) + 2*g(t))/t)] + sol3 = {'no_of_equation': 2, 'eq': [Eq(Derivative(f(t), t), (2*f(t) + g(t))/t), Eq(Derivative(g(t), t), + (f(t) + 2*g(t))/t)], 'func': [f(t), g(t)], 'order': {f(t): 1, g(t): 1}, 'is_linear': True, + 'is_homogeneous': True, 'is_general': True, 'func_coeff': Matrix([ [2, 1], [1, 2]]), 'is_constant': + False, 'type_of_equation': 'type5', 't_': t_, 'rhs': Matrix([ [0], [0]]), 'tau': log(t), + 'commutative_antiderivative': Matrix([ [2*log(t), log(t)], [ log(t), 2*log(t)]])} + assert _classify_linear_system(eq3, funcs, t) == sol3 + + +def test_matrix_exp(): + from sympy.matrices.dense import Matrix, eye, zeros + from sympy.solvers.ode.systems import matrix_exp + t = Symbol('t') + + for n in range(1, 6+1): + assert matrix_exp(zeros(n), t) == eye(n) + + for n in range(1, 6+1): + A = eye(n) + expAt = exp(t) * eye(n) + assert matrix_exp(A, t) == expAt + + for n in range(1, 6+1): + A = Matrix(n, n, lambda i,j: i+1 if i==j else 0) + expAt = Matrix(n, n, lambda i,j: exp((i+1)*t) if i==j else 0) + assert matrix_exp(A, t) == expAt + + A = Matrix([[0, 1], [-1, 0]]) + expAt = Matrix([[cos(t), sin(t)], [-sin(t), cos(t)]]) + assert matrix_exp(A, t) == expAt + + A = Matrix([[2, -5], [2, -4]]) + expAt = Matrix([ + [3*exp(-t)*sin(t) + exp(-t)*cos(t), -5*exp(-t)*sin(t)], + [2*exp(-t)*sin(t), -3*exp(-t)*sin(t) + exp(-t)*cos(t)] + ]) + assert matrix_exp(A, t) == expAt + + A = Matrix([[21, 17, 6], [-5, -1, -6], [4, 4, 16]]) + # TO update this. + # expAt = Matrix([ + # [(8*t*exp(12*t) + 5*exp(12*t) - 1)*exp(4*t)/4, + # (8*t*exp(12*t) + 5*exp(12*t) - 5)*exp(4*t)/4, + # (exp(12*t) - 1)*exp(4*t)/2], + # [(-8*t*exp(12*t) - exp(12*t) + 1)*exp(4*t)/4, + # (-8*t*exp(12*t) - exp(12*t) + 5)*exp(4*t)/4, + # (-exp(12*t) + 1)*exp(4*t)/2], + # [4*t*exp(16*t), 4*t*exp(16*t), exp(16*t)]]) + expAt = Matrix([ + [2*t*exp(16*t) + 5*exp(16*t)/4 - exp(4*t)/4, 2*t*exp(16*t) + 5*exp(16*t)/4 - 5*exp(4*t)/4, exp(16*t)/2 - exp(4*t)/2], + [ -2*t*exp(16*t) - exp(16*t)/4 + exp(4*t)/4, -2*t*exp(16*t) - exp(16*t)/4 + 5*exp(4*t)/4, -exp(16*t)/2 + exp(4*t)/2], + [ 4*t*exp(16*t), 4*t*exp(16*t), exp(16*t)] + ]) + assert matrix_exp(A, t) == expAt + + A = Matrix([[1, 1, 0, 0], + [0, 1, 1, 0], + [0, 0, 1, -S(1)/8], + [0, 0, S(1)/2, S(1)/2]]) + expAt = Matrix([ + [exp(t), t*exp(t), 4*t*exp(3*t/4) + 8*t*exp(t) + 48*exp(3*t/4) - 48*exp(t), + -2*t*exp(3*t/4) - 2*t*exp(t) - 16*exp(3*t/4) + 16*exp(t)], + [0, exp(t), -t*exp(3*t/4) - 8*exp(3*t/4) + 8*exp(t), t*exp(3*t/4)/2 + 2*exp(3*t/4) - 2*exp(t)], + [0, 0, t*exp(3*t/4)/4 + exp(3*t/4), -t*exp(3*t/4)/8], + [0, 0, t*exp(3*t/4)/2, -t*exp(3*t/4)/4 + exp(3*t/4)] + ]) + assert matrix_exp(A, t) == expAt + + A = Matrix([ + [ 0, 1, 0, 0], + [-1, 0, 0, 0], + [ 0, 0, 0, 1], + [ 0, 0, -1, 0]]) + + expAt = Matrix([ + [ cos(t), sin(t), 0, 0], + [-sin(t), cos(t), 0, 0], + [ 0, 0, cos(t), sin(t)], + [ 0, 0, -sin(t), cos(t)]]) + assert matrix_exp(A, t) == expAt + + A = Matrix([ + [ 0, 1, 1, 0], + [-1, 0, 0, 1], + [ 0, 0, 0, 1], + [ 0, 0, -1, 0]]) + + expAt = Matrix([ + [ cos(t), sin(t), t*cos(t), t*sin(t)], + [-sin(t), cos(t), -t*sin(t), t*cos(t)], + [ 0, 0, cos(t), sin(t)], + [ 0, 0, -sin(t), cos(t)]]) + assert matrix_exp(A, t) == expAt + + # This case is unacceptably slow right now but should be solvable... + #a, b, c, d, e, f = symbols('a b c d e f') + #A = Matrix([ + #[-a, b, c, d], + #[ a, -b, e, 0], + #[ 0, 0, -c - e - f, 0], + #[ 0, 0, f, -d]]) + + A = Matrix([[0, I], [I, 0]]) + expAt = Matrix([ + [exp(I*t)/2 + exp(-I*t)/2, exp(I*t)/2 - exp(-I*t)/2], + [exp(I*t)/2 - exp(-I*t)/2, exp(I*t)/2 + exp(-I*t)/2]]) + assert matrix_exp(A, t) == expAt + + # Testing Errors + M = Matrix([[1, 2, 3], [4, 5, 6], [7, 7, 7]]) + M1 = Matrix([[t, 1], [1, 1]]) + + raises(ValueError, lambda: matrix_exp(M[:, :2], t)) + raises(ValueError, lambda: matrix_exp(M[:2, :], t)) + raises(ValueError, lambda: matrix_exp(M1, t)) + raises(ValueError, lambda: matrix_exp(M1[:1, :1], t)) + + +def test_canonical_odes(): + f, g, h = symbols('f g h', cls=Function) + x = symbols('x') + funcs = [f(x), g(x), h(x)] + + eqs1 = [Eq(f(x).diff(x, x), f(x) + 2*g(x)), Eq(g(x) + 1, g(x).diff(x) + f(x))] + sol1 = [[Eq(Derivative(f(x), (x, 2)), f(x) + 2*g(x)), Eq(Derivative(g(x), x), -f(x) + g(x) + 1)]] + assert canonical_odes(eqs1, funcs[:2], x) == sol1 + + eqs2 = [Eq(f(x).diff(x), h(x).diff(x) + f(x)), Eq(g(x).diff(x)**2, f(x) + h(x)), Eq(h(x).diff(x), f(x))] + sol2 = [[Eq(Derivative(f(x), x), 2*f(x)), Eq(Derivative(g(x), x), -sqrt(f(x) + h(x))), Eq(Derivative(h(x), x), f(x))], + [Eq(Derivative(f(x), x), 2*f(x)), Eq(Derivative(g(x), x), sqrt(f(x) + h(x))), Eq(Derivative(h(x), x), f(x))]] + assert canonical_odes(eqs2, funcs, x) == sol2 + + +def test_sysode_linear_neq_order1_type1(): + + f, g, x, y, h = symbols('f g x y h', cls=Function) + a, b, c, t = symbols('a b c t') + + eqs1 = [Eq(Derivative(x(t), t), x(t)), + Eq(Derivative(y(t), t), y(t))] + sol1 = [Eq(x(t), C1*exp(t)), + Eq(y(t), C2*exp(t))] + assert dsolve(eqs1) == sol1 + assert checksysodesol(eqs1, sol1) == (True, [0, 0]) + + eqs2 = [Eq(Derivative(x(t), t), 2*x(t)), + Eq(Derivative(y(t), t), 3*y(t))] + sol2 = [Eq(x(t), C1*exp(2*t)), + Eq(y(t), C2*exp(3*t))] + assert dsolve(eqs2) == sol2 + assert checksysodesol(eqs2, sol2) == (True, [0, 0]) + + eqs3 = [Eq(Derivative(x(t), t), a*x(t)), + Eq(Derivative(y(t), t), a*y(t))] + sol3 = [Eq(x(t), C1*exp(a*t)), + Eq(y(t), C2*exp(a*t))] + assert dsolve(eqs3) == sol3 + assert checksysodesol(eqs3, sol3) == (True, [0, 0]) + + # Regression test case for issue #15474 + # https://github.com/sympy/sympy/issues/15474 + eqs4 = [Eq(Derivative(x(t), t), a*x(t)), + Eq(Derivative(y(t), t), b*y(t))] + sol4 = [Eq(x(t), C1*exp(a*t)), + Eq(y(t), C2*exp(b*t))] + assert dsolve(eqs4) == sol4 + assert checksysodesol(eqs4, sol4) == (True, [0, 0]) + + eqs5 = [Eq(Derivative(x(t), t), -y(t)), + Eq(Derivative(y(t), t), x(t))] + sol5 = [Eq(x(t), -C1*sin(t) - C2*cos(t)), + Eq(y(t), C1*cos(t) - C2*sin(t))] + assert dsolve(eqs5) == sol5 + assert checksysodesol(eqs5, sol5) == (True, [0, 0]) + + eqs6 = [Eq(Derivative(x(t), t), -2*y(t)), + Eq(Derivative(y(t), t), 2*x(t))] + sol6 = [Eq(x(t), -C1*sin(2*t) - C2*cos(2*t)), + Eq(y(t), C1*cos(2*t) - C2*sin(2*t))] + assert dsolve(eqs6) == sol6 + assert checksysodesol(eqs6, sol6) == (True, [0, 0]) + + eqs7 = [Eq(Derivative(x(t), t), I*y(t)), + Eq(Derivative(y(t), t), I*x(t))] + sol7 = [Eq(x(t), -C1*exp(-I*t) + C2*exp(I*t)), + Eq(y(t), C1*exp(-I*t) + C2*exp(I*t))] + assert dsolve(eqs7) == sol7 + assert checksysodesol(eqs7, sol7) == (True, [0, 0]) + + eqs8 = [Eq(Derivative(x(t), t), -a*y(t)), + Eq(Derivative(y(t), t), a*x(t))] + sol8 = [Eq(x(t), -I*C1*exp(-I*a*t) + I*C2*exp(I*a*t)), + Eq(y(t), C1*exp(-I*a*t) + C2*exp(I*a*t))] + assert dsolve(eqs8) == sol8 + assert checksysodesol(eqs8, sol8) == (True, [0, 0]) + + eqs9 = [Eq(Derivative(x(t), t), x(t) + y(t)), + Eq(Derivative(y(t), t), x(t) - y(t))] + sol9 = [Eq(x(t), C1*(1 - sqrt(2))*exp(-sqrt(2)*t) + C2*(1 + sqrt(2))*exp(sqrt(2)*t)), + Eq(y(t), C1*exp(-sqrt(2)*t) + C2*exp(sqrt(2)*t))] + assert dsolve(eqs9) == sol9 + assert checksysodesol(eqs9, sol9) == (True, [0, 0]) + + eqs10 = [Eq(Derivative(x(t), t), x(t) + y(t)), + Eq(Derivative(y(t), t), x(t) + y(t))] + sol10 = [Eq(x(t), -C1 + C2*exp(2*t)), + Eq(y(t), C1 + C2*exp(2*t))] + assert dsolve(eqs10) == sol10 + assert checksysodesol(eqs10, sol10) == (True, [0, 0]) + + eqs11 = [Eq(Derivative(x(t), t), 2*x(t) + y(t)), + Eq(Derivative(y(t), t), -x(t) + 2*y(t))] + sol11 = [Eq(x(t), C1*exp(2*t)*sin(t) + C2*exp(2*t)*cos(t)), + Eq(y(t), C1*exp(2*t)*cos(t) - C2*exp(2*t)*sin(t))] + assert dsolve(eqs11) == sol11 + assert checksysodesol(eqs11, sol11) == (True, [0, 0]) + + eqs12 = [Eq(Derivative(x(t), t), x(t) + 2*y(t)), + Eq(Derivative(y(t), t), 2*x(t) + y(t))] + sol12 = [Eq(x(t), -C1*exp(-t) + C2*exp(3*t)), + Eq(y(t), C1*exp(-t) + C2*exp(3*t))] + assert dsolve(eqs12) == sol12 + assert checksysodesol(eqs12, sol12) == (True, [0, 0]) + + eqs13 = [Eq(Derivative(x(t), t), 4*x(t) + y(t)), + Eq(Derivative(y(t), t), -x(t) + 2*y(t))] + sol13 = [Eq(x(t), C2*t*exp(3*t) + (C1 + C2)*exp(3*t)), + Eq(y(t), -C1*exp(3*t) - C2*t*exp(3*t))] + assert dsolve(eqs13) == sol13 + assert checksysodesol(eqs13, sol13) == (True, [0, 0]) + + eqs14 = [Eq(Derivative(x(t), t), a*y(t)), + Eq(Derivative(y(t), t), a*x(t))] + sol14 = [Eq(x(t), -C1*exp(-a*t) + C2*exp(a*t)), + Eq(y(t), C1*exp(-a*t) + C2*exp(a*t))] + assert dsolve(eqs14) == sol14 + assert checksysodesol(eqs14, sol14) == (True, [0, 0]) + + eqs15 = [Eq(Derivative(x(t), t), a*y(t)), + Eq(Derivative(y(t), t), b*x(t))] + sol15 = [Eq(x(t), -C1*a*exp(-t*sqrt(a*b))/sqrt(a*b) + C2*a*exp(t*sqrt(a*b))/sqrt(a*b)), + Eq(y(t), C1*exp(-t*sqrt(a*b)) + C2*exp(t*sqrt(a*b)))] + assert dsolve(eqs15) == sol15 + assert checksysodesol(eqs15, sol15) == (True, [0, 0]) + + eqs16 = [Eq(Derivative(x(t), t), a*x(t) + b*y(t)), + Eq(Derivative(y(t), t), c*x(t))] + sol16 = [Eq(x(t), -2*C1*b*exp(t*(a + sqrt(a**2 + 4*b*c))/2)/(a - sqrt(a**2 + 4*b*c)) - 2*C2*b*exp(t*(a - + sqrt(a**2 + 4*b*c))/2)/(a + sqrt(a**2 + 4*b*c))), + Eq(y(t), C1*exp(t*(a + sqrt(a**2 + 4*b*c))/2) + C2*exp(t*(a - sqrt(a**2 + 4*b*c))/2))] + assert dsolve(eqs16) == sol16 + assert checksysodesol(eqs16, sol16) == (True, [0, 0]) + + # Regression test case for issue #18562 + # https://github.com/sympy/sympy/issues/18562 + eqs17 = [Eq(Derivative(x(t), t), a*y(t) + x(t)), + Eq(Derivative(y(t), t), a*x(t) - y(t))] + sol17 = [Eq(x(t), C1*a*exp(t*sqrt(a**2 + 1))/(sqrt(a**2 + 1) - 1) - C2*a*exp(-t*sqrt(a**2 + 1))/(sqrt(a**2 + + 1) + 1)), + Eq(y(t), C1*exp(t*sqrt(a**2 + 1)) + C2*exp(-t*sqrt(a**2 + 1)))] + assert dsolve(eqs17) == sol17 + assert checksysodesol(eqs17, sol17) == (True, [0, 0]) + + eqs18 = [Eq(Derivative(x(t), t), 0), + Eq(Derivative(y(t), t), 0)] + sol18 = [Eq(x(t), C1), + Eq(y(t), C2)] + assert dsolve(eqs18) == sol18 + assert checksysodesol(eqs18, sol18) == (True, [0, 0]) + + eqs19 = [Eq(Derivative(x(t), t), 2*x(t) - y(t)), + Eq(Derivative(y(t), t), x(t))] + sol19 = [Eq(x(t), C2*t*exp(t) + (C1 + C2)*exp(t)), + Eq(y(t), C1*exp(t) + C2*t*exp(t))] + assert dsolve(eqs19) == sol19 + assert checksysodesol(eqs19, sol19) == (True, [0, 0]) + + eqs20 = [Eq(Derivative(x(t), t), x(t)), + Eq(Derivative(y(t), t), x(t) + y(t))] + sol20 = [Eq(x(t), C1*exp(t)), + Eq(y(t), C1*t*exp(t) + C2*exp(t))] + assert dsolve(eqs20) == sol20 + assert checksysodesol(eqs20, sol20) == (True, [0, 0]) + + eqs21 = [Eq(Derivative(x(t), t), 3*x(t)), + Eq(Derivative(y(t), t), x(t) + y(t))] + sol21 = [Eq(x(t), 2*C1*exp(3*t)), + Eq(y(t), C1*exp(3*t) + C2*exp(t))] + assert dsolve(eqs21) == sol21 + assert checksysodesol(eqs21, sol21) == (True, [0, 0]) + + eqs22 = [Eq(Derivative(x(t), t), 3*x(t)), + Eq(Derivative(y(t), t), y(t))] + sol22 = [Eq(x(t), C1*exp(3*t)), + Eq(y(t), C2*exp(t))] + assert dsolve(eqs22) == sol22 + assert checksysodesol(eqs22, sol22) == (True, [0, 0]) + + +@slow +def test_sysode_linear_neq_order1_type1_slow(): + + t = Symbol('t') + Z0 = Function('Z0') + Z1 = Function('Z1') + Z2 = Function('Z2') + Z3 = Function('Z3') + + k01, k10, k20, k21, k23, k30 = symbols('k01 k10 k20 k21 k23 k30') + + eqs1 = [Eq(Derivative(Z0(t), t), -k01*Z0(t) + k10*Z1(t) + k20*Z2(t) + k30*Z3(t)), + Eq(Derivative(Z1(t), t), k01*Z0(t) - k10*Z1(t) + k21*Z2(t)), + Eq(Derivative(Z2(t), t), (-k20 - k21 - k23)*Z2(t)), + Eq(Derivative(Z3(t), t), k23*Z2(t) - k30*Z3(t))] + sol1 = [Eq(Z0(t), C1*k10/k01 - C2*(k10 - k30)*exp(-k30*t)/(k01 + k10 - k30) - C3*(k10*(k20 + k21 - k30) - + k20**2 - k20*(k21 + k23 - k30) + k23*k30)*exp(-t*(k20 + k21 + k23))/(k23*(-k01 - k10 + k20 + k21 + + k23)) - C4*exp(-t*(k01 + k10))), + Eq(Z1(t), C1 - C2*k01*exp(-k30*t)/(k01 + k10 - k30) + C3*(-k01*(k20 + k21 - k30) + k20*k21 + k21**2 + + k21*(k23 - k30))*exp(-t*(k20 + k21 + k23))/(k23*(-k01 - k10 + k20 + k21 + k23)) + C4*exp(-t*(k01 + + k10))), + Eq(Z2(t), -C3*(k20 + k21 + k23 - k30)*exp(-t*(k20 + k21 + k23))/k23), + Eq(Z3(t), C2*exp(-k30*t) + C3*exp(-t*(k20 + k21 + k23)))] + assert dsolve(eqs1) == sol1 + assert checksysodesol(eqs1, sol1) == (True, [0, 0, 0, 0]) + + x, y, z, u, v, w = symbols('x y z u v w', cls=Function) + k2, k3 = symbols('k2 k3') + a_b, a_c = symbols('a_b a_c', real=True) + + eqs2 = [Eq(Derivative(z(t), t), k2*y(t)), + Eq(Derivative(x(t), t), k3*y(t)), + Eq(Derivative(y(t), t), (-k2 - k3)*y(t))] + sol2 = [Eq(z(t), C1 - C2*k2*exp(-t*(k2 + k3))/(k2 + k3)), + Eq(x(t), -C2*k3*exp(-t*(k2 + k3))/(k2 + k3) + C3), + Eq(y(t), C2*exp(-t*(k2 + k3)))] + assert dsolve(eqs2) == sol2 + assert checksysodesol(eqs2, sol2) == (True, [0, 0, 0]) + + eqs3 = [4*u(t) - v(t) - 2*w(t) + Derivative(u(t), t), + 2*u(t) + v(t) - 2*w(t) + Derivative(v(t), t), + 5*u(t) + v(t) - 3*w(t) + Derivative(w(t), t)] + sol3 = [Eq(u(t), C3*exp(-2*t) + (C1/2 + sqrt(3)*C2/6)*cos(sqrt(3)*t) + sin(sqrt(3)*t)*(sqrt(3)*C1/6 + + C2*Rational(-1, 2))), + Eq(v(t), (C1/2 + sqrt(3)*C2/6)*cos(sqrt(3)*t) + sin(sqrt(3)*t)*(sqrt(3)*C1/6 + C2*Rational(-1, 2))), + Eq(w(t), C1*cos(sqrt(3)*t) - C2*sin(sqrt(3)*t) + C3*exp(-2*t))] + assert dsolve(eqs3) == sol3 + assert checksysodesol(eqs3, sol3) == (True, [0, 0, 0]) + + eqs4 = [Eq(Derivative(x(t), t), w(t)*Rational(-2, 9) + 2*x(t) + y(t) + z(t)*Rational(-8, 9)), + Eq(Derivative(y(t), t), w(t)*Rational(4, 9) + 2*y(t) + z(t)*Rational(16, 9)), + Eq(Derivative(z(t), t), w(t)*Rational(-2, 9) + z(t)*Rational(37, 9)), + Eq(Derivative(w(t), t), w(t)*Rational(44, 9) + z(t)*Rational(-4, 9))] + sol4 = [Eq(x(t), C1*exp(2*t) + C2*t*exp(2*t)), + Eq(y(t), C2*exp(2*t) + 2*C3*exp(4*t)), + Eq(z(t), 2*C3*exp(4*t) + C4*exp(5*t)*Rational(-1, 4)), + Eq(w(t), C3*exp(4*t) + C4*exp(5*t))] + assert dsolve(eqs4) == sol4 + assert checksysodesol(eqs4, sol4) == (True, [0, 0, 0, 0]) + + # Regression test case for issue #15574 + # https://github.com/sympy/sympy/issues/15574 + eq5 = [Eq(x(t).diff(t), x(t)), Eq(y(t).diff(t), y(t)), Eq(z(t).diff(t), z(t)), Eq(w(t).diff(t), w(t))] + sol5 = [Eq(x(t), C1*exp(t)), Eq(y(t), C2*exp(t)), Eq(z(t), C3*exp(t)), Eq(w(t), C4*exp(t))] + assert dsolve(eq5) == sol5 + assert checksysodesol(eq5, sol5) == (True, [0, 0, 0, 0]) + + eqs6 = [Eq(Derivative(x(t), t), x(t) + y(t)), + Eq(Derivative(y(t), t), y(t) + z(t)), + Eq(Derivative(z(t), t), w(t)*Rational(-1, 8) + z(t)), + Eq(Derivative(w(t), t), w(t)/2 + z(t)/2)] + sol6 = [Eq(x(t), C1*exp(t) + C2*t*exp(t) + 4*C4*t*exp(t*Rational(3, 4)) + (4*C3 + 48*C4)*exp(t*Rational(3, + 4))), + Eq(y(t), C2*exp(t) - C4*t*exp(t*Rational(3, 4)) - (C3 + 8*C4)*exp(t*Rational(3, 4))), + Eq(z(t), C4*t*exp(t*Rational(3, 4))/4 + (C3/4 + C4)*exp(t*Rational(3, 4))), + Eq(w(t), C3*exp(t*Rational(3, 4))/2 + C4*t*exp(t*Rational(3, 4))/2)] + assert dsolve(eqs6) == sol6 + assert checksysodesol(eqs6, sol6) == (True, [0, 0, 0, 0]) + + # Regression test case for issue #15574 + # https://github.com/sympy/sympy/issues/15574 + eq7 = [Eq(Derivative(x(t), t), x(t)), Eq(Derivative(y(t), t), y(t)), Eq(Derivative(z(t), t), z(t)), + Eq(Derivative(w(t), t), w(t)), Eq(Derivative(u(t), t), u(t))] + sol7 = [Eq(x(t), C1*exp(t)), Eq(y(t), C2*exp(t)), Eq(z(t), C3*exp(t)), Eq(w(t), C4*exp(t)), + Eq(u(t), C5*exp(t))] + assert dsolve(eq7) == sol7 + assert checksysodesol(eq7, sol7) == (True, [0, 0, 0, 0, 0]) + + eqs8 = [Eq(Derivative(x(t), t), 2*x(t) + y(t)), + Eq(Derivative(y(t), t), 2*y(t)), + Eq(Derivative(z(t), t), 4*z(t)), + Eq(Derivative(w(t), t), u(t) + 5*w(t)), + Eq(Derivative(u(t), t), 5*u(t))] + sol8 = [Eq(x(t), C1*exp(2*t) + C2*t*exp(2*t)), + Eq(y(t), C2*exp(2*t)), + Eq(z(t), C3*exp(4*t)), + Eq(w(t), C4*exp(5*t) + C5*t*exp(5*t)), + Eq(u(t), C5*exp(5*t))] + assert dsolve(eqs8) == sol8 + assert checksysodesol(eqs8, sol8) == (True, [0, 0, 0, 0, 0]) + + # Regression test case for issue #15574 + # https://github.com/sympy/sympy/issues/15574 + eq9 = [Eq(Derivative(x(t), t), x(t)), Eq(Derivative(y(t), t), y(t)), Eq(Derivative(z(t), t), z(t))] + sol9 = [Eq(x(t), C1*exp(t)), Eq(y(t), C2*exp(t)), Eq(z(t), C3*exp(t))] + assert dsolve(eq9) == sol9 + assert checksysodesol(eq9, sol9) == (True, [0, 0, 0]) + + # Regression test case for issue #15407 + # https://github.com/sympy/sympy/issues/15407 + eqs10 = [Eq(Derivative(x(t), t), (-a_b - a_c)*x(t)), + Eq(Derivative(y(t), t), a_b*y(t)), + Eq(Derivative(z(t), t), a_c*x(t))] + sol10 = [Eq(x(t), -C1*(a_b + a_c)*exp(-t*(a_b + a_c))/a_c), + Eq(y(t), C2*exp(a_b*t)), + Eq(z(t), C1*exp(-t*(a_b + a_c)) + C3)] + assert dsolve(eqs10) == sol10 + assert checksysodesol(eqs10, sol10) == (True, [0, 0, 0]) + + # Regression test case for issue #14312 + # https://github.com/sympy/sympy/issues/14312 + eqs11 = [Eq(Derivative(x(t), t), k3*y(t)), + Eq(Derivative(y(t), t), (-k2 - k3)*y(t)), + Eq(Derivative(z(t), t), k2*y(t))] + sol11 = [Eq(x(t), C1 + C2*k3*exp(-t*(k2 + k3))/k2), + Eq(y(t), -C2*(k2 + k3)*exp(-t*(k2 + k3))/k2), + Eq(z(t), C2*exp(-t*(k2 + k3)) + C3)] + assert dsolve(eqs11) == sol11 + assert checksysodesol(eqs11, sol11) == (True, [0, 0, 0]) + + # Regression test case for issue #14312 + # https://github.com/sympy/sympy/issues/14312 + eqs12 = [Eq(Derivative(z(t), t), k2*y(t)), + Eq(Derivative(x(t), t), k3*y(t)), + Eq(Derivative(y(t), t), (-k2 - k3)*y(t))] + sol12 = [Eq(z(t), C1 - C2*k2*exp(-t*(k2 + k3))/(k2 + k3)), + Eq(x(t), -C2*k3*exp(-t*(k2 + k3))/(k2 + k3) + C3), + Eq(y(t), C2*exp(-t*(k2 + k3)))] + assert dsolve(eqs12) == sol12 + assert checksysodesol(eqs12, sol12) == (True, [0, 0, 0]) + + f, g, h = symbols('f, g, h', cls=Function) + a, b, c = symbols('a, b, c') + + # Regression test case for issue #15474 + # https://github.com/sympy/sympy/issues/15474 + eqs13 = [Eq(Derivative(f(t), t), 2*f(t) + g(t)), + Eq(Derivative(g(t), t), a*f(t))] + sol13 = [Eq(f(t), C1*exp(t*(sqrt(a + 1) + 1))/(sqrt(a + 1) - 1) - C2*exp(-t*(sqrt(a + 1) - 1))/(sqrt(a + 1) + + 1)), + Eq(g(t), C1*exp(t*(sqrt(a + 1) + 1)) + C2*exp(-t*(sqrt(a + 1) - 1)))] + assert dsolve(eqs13) == sol13 + assert checksysodesol(eqs13, sol13) == (True, [0, 0]) + + eqs14 = [Eq(Derivative(f(t), t), 2*g(t) - 3*h(t)), + Eq(Derivative(g(t), t), -2*f(t) + 4*h(t)), + Eq(Derivative(h(t), t), 3*f(t) - 4*g(t))] + sol14 = [Eq(f(t), 2*C1 - sin(sqrt(29)*t)*(sqrt(29)*C2*Rational(3, 25) + C3*Rational(-8, 25)) - + cos(sqrt(29)*t)*(C2*Rational(8, 25) + sqrt(29)*C3*Rational(3, 25))), + Eq(g(t), C1*Rational(3, 2) + sin(sqrt(29)*t)*(sqrt(29)*C2*Rational(4, 25) + C3*Rational(6, 25)) - + cos(sqrt(29)*t)*(C2*Rational(6, 25) + sqrt(29)*C3*Rational(-4, 25))), + Eq(h(t), C1 + C2*cos(sqrt(29)*t) - C3*sin(sqrt(29)*t))] + assert dsolve(eqs14) == sol14 + assert checksysodesol(eqs14, sol14) == (True, [0, 0, 0]) + + eqs15 = [Eq(2*Derivative(f(t), t), 12*g(t) - 12*h(t)), + Eq(3*Derivative(g(t), t), -8*f(t) + 8*h(t)), + Eq(4*Derivative(h(t), t), 6*f(t) - 6*g(t))] + sol15 = [Eq(f(t), C1 - sin(sqrt(29)*t)*(sqrt(29)*C2*Rational(6, 13) + C3*Rational(-16, 13)) - + cos(sqrt(29)*t)*(C2*Rational(16, 13) + sqrt(29)*C3*Rational(6, 13))), + Eq(g(t), C1 + sin(sqrt(29)*t)*(sqrt(29)*C2*Rational(8, 39) + C3*Rational(16, 13)) - + cos(sqrt(29)*t)*(C2*Rational(16, 13) + sqrt(29)*C3*Rational(-8, 39))), + Eq(h(t), C1 + C2*cos(sqrt(29)*t) - C3*sin(sqrt(29)*t))] + assert dsolve(eqs15) == sol15 + assert checksysodesol(eqs15, sol15) == (True, [0, 0, 0]) + + eq16 = (Eq(diff(x(t), t), 21*x(t)), Eq(diff(y(t), t), 17*x(t) + 3*y(t)), + Eq(diff(z(t), t), 5*x(t) + 7*y(t) + 9*z(t))) + sol16 = [Eq(x(t), 216*C1*exp(21*t)/209), + Eq(y(t), 204*C1*exp(21*t)/209 - 6*C2*exp(3*t)/7), + Eq(z(t), C1*exp(21*t) + C2*exp(3*t) + C3*exp(9*t))] + assert dsolve(eq16) == sol16 + assert checksysodesol(eq16, sol16) == (True, [0, 0, 0]) + + eqs17 = [Eq(Derivative(x(t), t), 3*y(t) - 11*z(t)), + Eq(Derivative(y(t), t), -3*x(t) + 7*z(t)), + Eq(Derivative(z(t), t), 11*x(t) - 7*y(t))] + sol17 = [Eq(x(t), C1*Rational(7, 3) - sin(sqrt(179)*t)*(sqrt(179)*C2*Rational(11, 170) + C3*Rational(-21, + 170)) - cos(sqrt(179)*t)*(C2*Rational(21, 170) + sqrt(179)*C3*Rational(11, 170))), + Eq(y(t), C1*Rational(11, 3) + sin(sqrt(179)*t)*(sqrt(179)*C2*Rational(7, 170) + C3*Rational(33, + 170)) - cos(sqrt(179)*t)*(C2*Rational(33, 170) + sqrt(179)*C3*Rational(-7, 170))), + Eq(z(t), C1 + C2*cos(sqrt(179)*t) - C3*sin(sqrt(179)*t))] + assert dsolve(eqs17) == sol17 + assert checksysodesol(eqs17, sol17) == (True, [0, 0, 0]) + + eqs18 = [Eq(3*Derivative(x(t), t), 20*y(t) - 20*z(t)), + Eq(4*Derivative(y(t), t), -15*x(t) + 15*z(t)), + Eq(5*Derivative(z(t), t), 12*x(t) - 12*y(t))] + sol18 = [Eq(x(t), C1 - sin(5*sqrt(2)*t)*(sqrt(2)*C2*Rational(4, 3) - C3) - cos(5*sqrt(2)*t)*(C2 + + sqrt(2)*C3*Rational(4, 3))), + Eq(y(t), C1 + sin(5*sqrt(2)*t)*(sqrt(2)*C2*Rational(3, 4) + C3) - cos(5*sqrt(2)*t)*(C2 + + sqrt(2)*C3*Rational(-3, 4))), + Eq(z(t), C1 + C2*cos(5*sqrt(2)*t) - C3*sin(5*sqrt(2)*t))] + assert dsolve(eqs18) == sol18 + assert checksysodesol(eqs18, sol18) == (True, [0, 0, 0]) + + eqs19 = [Eq(Derivative(x(t), t), 4*x(t) - z(t)), + Eq(Derivative(y(t), t), 2*x(t) + 2*y(t) - z(t)), + Eq(Derivative(z(t), t), 3*x(t) + y(t))] + sol19 = [Eq(x(t), C2*t**2*exp(2*t)/2 + t*(2*C2 + C3)*exp(2*t) + (C1 + C2 + 2*C3)*exp(2*t)), + Eq(y(t), C2*t**2*exp(2*t)/2 + t*(2*C2 + C3)*exp(2*t) + (C1 + 2*C3)*exp(2*t)), + Eq(z(t), C2*t**2*exp(2*t) + t*(3*C2 + 2*C3)*exp(2*t) + (2*C1 + 3*C3)*exp(2*t))] + assert dsolve(eqs19) == sol19 + assert checksysodesol(eqs19, sol19) == (True, [0, 0, 0]) + + eqs20 = [Eq(Derivative(x(t), t), 4*x(t) - y(t) - 2*z(t)), + Eq(Derivative(y(t), t), 2*x(t) + y(t) - 2*z(t)), + Eq(Derivative(z(t), t), 5*x(t) - 3*z(t))] + sol20 = [Eq(x(t), C1*exp(2*t) - sin(t)*(C2*Rational(3, 5) + C3/5) - cos(t)*(C2/5 + C3*Rational(-3, 5))), + Eq(y(t), -sin(t)*(C2*Rational(3, 5) + C3/5) - cos(t)*(C2/5 + C3*Rational(-3, 5))), + Eq(z(t), C1*exp(2*t) - C2*sin(t) + C3*cos(t))] + assert dsolve(eqs20) == sol20 + assert checksysodesol(eqs20, sol20) == (True, [0, 0, 0]) + + eq21 = (Eq(diff(x(t), t), 9*y(t)), Eq(diff(y(t), t), 12*x(t))) + sol21 = [Eq(x(t), -sqrt(3)*C1*exp(-6*sqrt(3)*t)/2 + sqrt(3)*C2*exp(6*sqrt(3)*t)/2), + Eq(y(t), C1*exp(-6*sqrt(3)*t) + C2*exp(6*sqrt(3)*t))] + + assert dsolve(eq21) == sol21 + assert checksysodesol(eq21, sol21) == (True, [0, 0]) + + eqs22 = [Eq(Derivative(x(t), t), 2*x(t) + 4*y(t)), + Eq(Derivative(y(t), t), 12*x(t) + 41*y(t))] + sol22 = [Eq(x(t), C1*(39 - sqrt(1713))*exp(t*(sqrt(1713) + 43)/2)*Rational(-1, 24) + C2*(39 + + sqrt(1713))*exp(t*(43 - sqrt(1713))/2)*Rational(-1, 24)), + Eq(y(t), C1*exp(t*(sqrt(1713) + 43)/2) + C2*exp(t*(43 - sqrt(1713))/2))] + assert dsolve(eqs22) == sol22 + assert checksysodesol(eqs22, sol22) == (True, [0, 0]) + + eqs23 = [Eq(Derivative(x(t), t), x(t) + y(t)), + Eq(Derivative(y(t), t), -2*x(t) + 2*y(t))] + sol23 = [Eq(x(t), (C1/4 + sqrt(7)*C2/4)*cos(sqrt(7)*t/2)*exp(t*Rational(3, 2)) + + sin(sqrt(7)*t/2)*(sqrt(7)*C1/4 + C2*Rational(-1, 4))*exp(t*Rational(3, 2))), + Eq(y(t), C1*cos(sqrt(7)*t/2)*exp(t*Rational(3, 2)) - C2*sin(sqrt(7)*t/2)*exp(t*Rational(3, 2)))] + assert dsolve(eqs23) == sol23 + assert checksysodesol(eqs23, sol23) == (True, [0, 0]) + + # Regression test case for issue #15474 + # https://github.com/sympy/sympy/issues/15474 + a = Symbol("a", real=True) + eq24 = [x(t).diff(t) - a*y(t), y(t).diff(t) + a*x(t)] + sol24 = [Eq(x(t), C1*sin(a*t) + C2*cos(a*t)), Eq(y(t), C1*cos(a*t) - C2*sin(a*t))] + assert dsolve(eq24) == sol24 + assert checksysodesol(eq24, sol24) == (True, [0, 0]) + + # Regression test case for issue #19150 + # https://github.com/sympy/sympy/issues/19150 + eqs25 = [Eq(Derivative(f(t), t), 0), + Eq(Derivative(g(t), t), (f(t) - 2*g(t) + x(t))/(b*c)), + Eq(Derivative(x(t), t), (g(t) - 2*x(t) + y(t))/(b*c)), + Eq(Derivative(y(t), t), (h(t) + x(t) - 2*y(t))/(b*c)), + Eq(Derivative(h(t), t), 0)] + sol25 = [Eq(f(t), -3*C1 + 4*C2), + Eq(g(t), -2*C1 + 3*C2 - C3*exp(-2*t/(b*c)) + C4*exp(-t*(sqrt(2) + 2)/(b*c)) + C5*exp(-t*(2 - + sqrt(2))/(b*c))), + Eq(x(t), -C1 + 2*C2 - sqrt(2)*C4*exp(-t*(sqrt(2) + 2)/(b*c)) + sqrt(2)*C5*exp(-t*(2 - + sqrt(2))/(b*c))), + Eq(y(t), C2 + C3*exp(-2*t/(b*c)) + C4*exp(-t*(sqrt(2) + 2)/(b*c)) + C5*exp(-t*(2 - sqrt(2))/(b*c))), + Eq(h(t), C1)] + assert dsolve(eqs25) == sol25 + assert checksysodesol(eqs25, sol25) == (True, [0, 0, 0, 0, 0]) + + eq26 = [Eq(Derivative(f(t), t), 2*f(t)), Eq(Derivative(g(t), t), 3*f(t) + 7*g(t))] + sol26 = [Eq(f(t), -5*C1*exp(2*t)/3), Eq(g(t), C1*exp(2*t) + C2*exp(7*t))] + assert dsolve(eq26) == sol26 + assert checksysodesol(eq26, sol26) == (True, [0, 0]) + + eq27 = [Eq(Derivative(f(t), t), -9*I*f(t) - 4*g(t)), Eq(Derivative(g(t), t), -4*I*g(t))] + sol27 = [Eq(f(t), 4*I*C1*exp(-4*I*t)/5 + C2*exp(-9*I*t)), Eq(g(t), C1*exp(-4*I*t))] + assert dsolve(eq27) == sol27 + assert checksysodesol(eq27, sol27) == (True, [0, 0]) + + eq28 = [Eq(Derivative(f(t), t), -9*I*f(t)), Eq(Derivative(g(t), t), -4*I*g(t))] + sol28 = [Eq(f(t), C1*exp(-9*I*t)), Eq(g(t), C2*exp(-4*I*t))] + assert dsolve(eq28) == sol28 + assert checksysodesol(eq28, sol28) == (True, [0, 0]) + + eq29 = [Eq(Derivative(f(t), t), 0), Eq(Derivative(g(t), t), 0)] + sol29 = [Eq(f(t), C1), Eq(g(t), C2)] + assert dsolve(eq29) == sol29 + assert checksysodesol(eq29, sol29) == (True, [0, 0]) + + eq30 = [Eq(Derivative(f(t), t), f(t)), Eq(Derivative(g(t), t), 0)] + sol30 = [Eq(f(t), C1*exp(t)), Eq(g(t), C2)] + assert dsolve(eq30) == sol30 + assert checksysodesol(eq30, sol30) == (True, [0, 0]) + + eq31 = [Eq(Derivative(f(t), t), g(t)), Eq(Derivative(g(t), t), 0)] + sol31 = [Eq(f(t), C1 + C2*t), Eq(g(t), C2)] + assert dsolve(eq31) == sol31 + assert checksysodesol(eq31, sol31) == (True, [0, 0]) + + eq32 = [Eq(Derivative(f(t), t), 0), Eq(Derivative(g(t), t), f(t))] + sol32 = [Eq(f(t), C1), Eq(g(t), C1*t + C2)] + assert dsolve(eq32) == sol32 + assert checksysodesol(eq32, sol32) == (True, [0, 0]) + + eq33 = [Eq(Derivative(f(t), t), 0), Eq(Derivative(g(t), t), g(t))] + sol33 = [Eq(f(t), C1), Eq(g(t), C2*exp(t))] + assert dsolve(eq33) == sol33 + assert checksysodesol(eq33, sol33) == (True, [0, 0]) + + eq34 = [Eq(Derivative(f(t), t), f(t)), Eq(Derivative(g(t), t), I*g(t))] + sol34 = [Eq(f(t), C1*exp(t)), Eq(g(t), C2*exp(I*t))] + assert dsolve(eq34) == sol34 + assert checksysodesol(eq34, sol34) == (True, [0, 0]) + + eq35 = [Eq(Derivative(f(t), t), I*f(t)), Eq(Derivative(g(t), t), -I*g(t))] + sol35 = [Eq(f(t), C1*exp(I*t)), Eq(g(t), C2*exp(-I*t))] + assert dsolve(eq35) == sol35 + assert checksysodesol(eq35, sol35) == (True, [0, 0]) + + eq36 = [Eq(Derivative(f(t), t), I*g(t)), Eq(Derivative(g(t), t), 0)] + sol36 = [Eq(f(t), I*C1 + I*C2*t), Eq(g(t), C2)] + assert dsolve(eq36) == sol36 + assert checksysodesol(eq36, sol36) == (True, [0, 0]) + + eq37 = [Eq(Derivative(f(t), t), I*g(t)), Eq(Derivative(g(t), t), I*f(t))] + sol37 = [Eq(f(t), -C1*exp(-I*t) + C2*exp(I*t)), Eq(g(t), C1*exp(-I*t) + C2*exp(I*t))] + assert dsolve(eq37) == sol37 + assert checksysodesol(eq37, sol37) == (True, [0, 0]) + + # Multiple systems + eq1 = [Eq(Derivative(f(t), t)**2, g(t)**2), Eq(-f(t) + Derivative(g(t), t), 0)] + sol1 = [[Eq(f(t), -C1*sin(t) - C2*cos(t)), + Eq(g(t), C1*cos(t) - C2*sin(t))], + [Eq(f(t), -C1*exp(-t) + C2*exp(t)), + Eq(g(t), C1*exp(-t) + C2*exp(t))]] + assert dsolve(eq1) == sol1 + for sol in sol1: + assert checksysodesol(eq1, sol) == (True, [0, 0]) + + +def test_sysode_linear_neq_order1_type2(): + + f, g, h, k = symbols('f g h k', cls=Function) + x, t, a, b, c, d, y = symbols('x t a b c d y') + k1, k2 = symbols('k1 k2') + + + eqs1 = [Eq(Derivative(f(x), x), f(x) + g(x) + 5), + Eq(Derivative(g(x), x), -f(x) - g(x) + 7)] + sol1 = [Eq(f(x), C1 + C2 + 6*x**2 + x*(C2 + 5)), + Eq(g(x), -C1 - 6*x**2 - x*(C2 - 7))] + assert dsolve(eqs1) == sol1 + assert checksysodesol(eqs1, sol1) == (True, [0, 0]) + + eqs2 = [Eq(Derivative(f(x), x), f(x) + g(x) + 5), + Eq(Derivative(g(x), x), f(x) + g(x) + 7)] + sol2 = [Eq(f(x), -C1 + C2*exp(2*x) - x - 3), + Eq(g(x), C1 + C2*exp(2*x) + x - 3)] + assert dsolve(eqs2) == sol2 + assert checksysodesol(eqs2, sol2) == (True, [0, 0]) + + eqs3 = [Eq(Derivative(f(x), x), f(x) + 5), + Eq(Derivative(g(x), x), f(x) + 7)] + sol3 = [Eq(f(x), C1*exp(x) - 5), + Eq(g(x), C1*exp(x) + C2 + 2*x - 5)] + assert dsolve(eqs3) == sol3 + assert checksysodesol(eqs3, sol3) == (True, [0, 0]) + + eqs4 = [Eq(Derivative(f(x), x), f(x) + exp(x)), + Eq(Derivative(g(x), x), x*exp(x) + f(x) + g(x))] + sol4 = [Eq(f(x), C1*exp(x) + x*exp(x)), + Eq(g(x), C1*x*exp(x) + C2*exp(x) + x**2*exp(x))] + assert dsolve(eqs4) == sol4 + assert checksysodesol(eqs4, sol4) == (True, [0, 0]) + + eqs5 = [Eq(Derivative(f(x), x), 5*x + f(x) + g(x)), + Eq(Derivative(g(x), x), f(x) - g(x))] + sol5 = [Eq(f(x), C1*(1 + sqrt(2))*exp(sqrt(2)*x) + C2*(1 - sqrt(2))*exp(-sqrt(2)*x) + x*Rational(-5, 2) + + Rational(-5, 2)), + Eq(g(x), C1*exp(sqrt(2)*x) + C2*exp(-sqrt(2)*x) + x*Rational(-5, 2))] + assert dsolve(eqs5) == sol5 + assert checksysodesol(eqs5, sol5) == (True, [0, 0]) + + eqs6 = [Eq(Derivative(f(x), x), -9*f(x) - 4*g(x)), + Eq(Derivative(g(x), x), -4*g(x)), + Eq(Derivative(h(x), x), h(x) + exp(x))] + sol6 = [Eq(f(x), C2*exp(-4*x)*Rational(-4, 5) + C1*exp(-9*x)), + Eq(g(x), C2*exp(-4*x)), + Eq(h(x), C3*exp(x) + x*exp(x))] + assert dsolve(eqs6) == sol6 + assert checksysodesol(eqs6, sol6) == (True, [0, 0, 0]) + + # Regression test case for issue #8859 + # https://github.com/sympy/sympy/issues/8859 + eqs7 = [Eq(Derivative(f(t), t), 3*t + f(t)), + Eq(Derivative(g(t), t), g(t))] + sol7 = [Eq(f(t), C1*exp(t) - 3*t - 3), + Eq(g(t), C2*exp(t))] + assert dsolve(eqs7) == sol7 + assert checksysodesol(eqs7, sol7) == (True, [0, 0]) + + # Regression test case for issue #8567 + # https://github.com/sympy/sympy/issues/8567 + eqs8 = [Eq(Derivative(f(t), t), f(t) + 2*g(t)), + Eq(Derivative(g(t), t), -2*f(t) + g(t) + 2*exp(t))] + sol8 = [Eq(f(t), C1*exp(t)*sin(2*t) + C2*exp(t)*cos(2*t) + + exp(t)*sin(2*t)**2 + exp(t)*cos(2*t)**2), + Eq(g(t), C1*exp(t)*cos(2*t) - C2*exp(t)*sin(2*t))] + assert dsolve(eqs8) == sol8 + assert checksysodesol(eqs8, sol8) == (True, [0, 0]) + + # Regression test case for issue #19150 + # https://github.com/sympy/sympy/issues/19150 + eqs9 = [Eq(Derivative(f(t), t), (c - 2*f(t) + g(t))/(a*b)), + Eq(Derivative(g(t), t), (f(t) - 2*g(t) + h(t))/(a*b)), + Eq(Derivative(h(t), t), (d + g(t) - 2*h(t))/(a*b))] + sol9 = [Eq(f(t), -C1*exp(-2*t/(a*b)) + C2*exp(-t*(sqrt(2) + 2)/(a*b)) + C3*exp(-t*(2 - sqrt(2))/(a*b)) + + Mul(Rational(1, 4), 3*c + d, evaluate=False)), + Eq(g(t), -sqrt(2)*C2*exp(-t*(sqrt(2) + 2)/(a*b)) + sqrt(2)*C3*exp(-t*(2 - sqrt(2))/(a*b)) + + Mul(Rational(1, 2), c + d, evaluate=False)), + Eq(h(t), C1*exp(-2*t/(a*b)) + C2*exp(-t*(sqrt(2) + 2)/(a*b)) + C3*exp(-t*(2 - sqrt(2))/(a*b)) + + Mul(Rational(1, 4), c + 3*d, evaluate=False))] + assert dsolve(eqs9) == sol9 + assert checksysodesol(eqs9, sol9) == (True, [0, 0, 0]) + + # Regression test case for issue #16635 + # https://github.com/sympy/sympy/issues/16635 + eqs10 = [Eq(Derivative(f(t), t), 15*t + f(t) - g(t) - 10), + Eq(Derivative(g(t), t), -15*t + f(t) - g(t) - 5)] + sol10 = [Eq(f(t), C1 + C2 + 5*t**3 + 5*t**2 + t*(C2 - 10)), + Eq(g(t), C1 + 5*t**3 - 10*t**2 + t*(C2 - 5))] + assert dsolve(eqs10) == sol10 + assert checksysodesol(eqs10, sol10) == (True, [0, 0]) + + # Multiple solutions + eqs11 = [Eq(Derivative(f(t), t)**2 - 2*Derivative(f(t), t) + 1, 4), + Eq(-y*f(t) + Derivative(g(t), t), 0)] + sol11 = [[Eq(f(t), C1 - t), Eq(g(t), C1*t*y + C2*y + t**2*y*Rational(-1, 2))], + [Eq(f(t), C1 + 3*t), Eq(g(t), C1*t*y + C2*y + t**2*y*Rational(3, 2))]] + assert dsolve(eqs11) == sol11 + for s11 in sol11: + assert checksysodesol(eqs11, s11) == (True, [0, 0]) + + # test case for issue #19831 + # https://github.com/sympy/sympy/issues/19831 + n = symbols('n', positive=True) + x0 = symbols('x_0') + t0 = symbols('t_0') + x_0 = symbols('x_0') + t_0 = symbols('t_0') + t = symbols('t') + x = Function('x') + y = Function('y') + T = symbols('T') + + eqs12 = [Eq(Derivative(y(t), t), x(t)), + Eq(Derivative(x(t), t), n*(y(t) + 1))] + sol12 = [Eq(y(t), C1*exp(sqrt(n)*t)*n**Rational(-1, 2) - C2*exp(-sqrt(n)*t)*n**Rational(-1, 2) - 1), + Eq(x(t), C1*exp(sqrt(n)*t) + C2*exp(-sqrt(n)*t))] + assert dsolve(eqs12) == sol12 + assert checksysodesol(eqs12, sol12) == (True, [0, 0]) + + sol12b = [ + Eq(y(t), (T*exp(-sqrt(n)*t_0)/2 + exp(-sqrt(n)*t_0)/2 + + x_0*exp(-sqrt(n)*t_0)/(2*sqrt(n)))*exp(sqrt(n)*t) + + (T*exp(sqrt(n)*t_0)/2 + exp(sqrt(n)*t_0)/2 - + x_0*exp(sqrt(n)*t_0)/(2*sqrt(n)))*exp(-sqrt(n)*t) - 1), + Eq(x(t), (T*sqrt(n)*exp(-sqrt(n)*t_0)/2 + sqrt(n)*exp(-sqrt(n)*t_0)/2 + + x_0*exp(-sqrt(n)*t_0)/2)*exp(sqrt(n)*t) + - (T*sqrt(n)*exp(sqrt(n)*t_0)/2 + sqrt(n)*exp(sqrt(n)*t_0)/2 - + x_0*exp(sqrt(n)*t_0)/2)*exp(-sqrt(n)*t)) + ] + assert dsolve(eqs12, ics={y(t0): T, x(t0): x0}) == sol12b + assert checksysodesol(eqs12, sol12b) == (True, [0, 0]) + + #Test cases added for the issue 19763 + #https://github.com/sympy/sympy/issues/19763 + + eq13 = [Eq(Derivative(f(t), t), f(t) + g(t) + 9), + Eq(Derivative(g(t), t), 2*f(t) + 5*g(t) + 23)] + sol13 = [Eq(f(t), -C1*(2 + sqrt(6))*exp(t*(3 - sqrt(6)))/2 - C2*(2 - sqrt(6))*exp(t*(sqrt(6) + 3))/2 - + Rational(22,3)), + Eq(g(t), C1*exp(t*(3 - sqrt(6))) + C2*exp(t*(sqrt(6) + 3)) - Rational(5,3))] + assert dsolve(eq13) == sol13 + assert checksysodesol(eq13, sol13) == (True, [0, 0]) + + eq14 = [Eq(Derivative(f(t), t), f(t) + g(t) + 81), + Eq(Derivative(g(t), t), -2*f(t) + g(t) + 23)] + sol14 = [Eq(f(t), sqrt(2)*C1*exp(t)*sin(sqrt(2)*t)/2 + + sqrt(2)*C2*exp(t)*cos(sqrt(2)*t)/2 + - 58*sin(sqrt(2)*t)**2/3 - 58*cos(sqrt(2)*t)**2/3), + Eq(g(t), C1*exp(t)*cos(sqrt(2)*t) - C2*exp(t)*sin(sqrt(2)*t) + - 185*sin(sqrt(2)*t)**2/3 - 185*cos(sqrt(2)*t)**2/3)] + assert dsolve(eq14) == sol14 + assert checksysodesol(eq14, sol14) == (True, [0,0]) + + eq15 = [Eq(Derivative(f(t), t), f(t) + 2*g(t) + k1), + Eq(Derivative(g(t), t), 3*f(t) + 4*g(t) + k2)] + sol15 = [Eq(f(t), -C1*(3 - sqrt(33))*exp(t*(5 + sqrt(33))/2)/6 - + C2*(3 + sqrt(33))*exp(t*(5 - sqrt(33))/2)/6 + 2*k1 - k2), + Eq(g(t), C1*exp(t*(5 + sqrt(33))/2) + C2*exp(t*(5 - sqrt(33))/2) - + Mul(Rational(1,2), 3*k1 - k2, evaluate = False))] + assert dsolve(eq15) == sol15 + assert checksysodesol(eq15, sol15) == (True, [0,0]) + + eq16 = [Eq(Derivative(f(t), t), k1), + Eq(Derivative(g(t), t), k2)] + sol16 = [Eq(f(t), C1 + k1*t), + Eq(g(t), C2 + k2*t)] + assert dsolve(eq16) == sol16 + assert checksysodesol(eq16, sol16) == (True, [0,0]) + + eq17 = [Eq(Derivative(f(t), t), 0), + Eq(Derivative(g(t), t), c*f(t) + k2)] + sol17 = [Eq(f(t), C1), + Eq(g(t), C2*c + t*(C1*c + k2))] + assert dsolve(eq17) == sol17 + assert checksysodesol(eq17 , sol17) == (True , [0,0]) + + eq18 = [Eq(Derivative(f(t), t), k1), + Eq(Derivative(g(t), t), f(t) + k2)] + sol18 = [Eq(f(t), C1 + k1*t), + Eq(g(t), C2 + k1*t**2/2 + t*(C1 + k2))] + assert dsolve(eq18) == sol18 + assert checksysodesol(eq18 , sol18) == (True , [0,0]) + + eq19 = [Eq(Derivative(f(t), t), k1), + Eq(Derivative(g(t), t), f(t) + 2*g(t) + k2)] + sol19 = [Eq(f(t), -2*C1 + k1*t), + Eq(g(t), C1 + C2*exp(2*t) - k1*t/2 - Mul(Rational(1,4), k1 + 2*k2 , evaluate = False))] + assert dsolve(eq19) == sol19 + assert checksysodesol(eq19 , sol19) == (True , [0,0]) + + eq20 = [Eq(diff(f(t), t), f(t) + k1), + Eq(diff(g(t), t), k2)] + sol20 = [Eq(f(t), C1*exp(t) - k1), + Eq(g(t), C2 + k2*t)] + assert dsolve(eq20) == sol20 + assert checksysodesol(eq20 , sol20) == (True , [0,0]) + + eq21 = [Eq(diff(f(t), t), g(t) + k1), + Eq(diff(g(t), t), 0)] + sol21 = [Eq(f(t), C1 + t*(C2 + k1)), + Eq(g(t), C2)] + assert dsolve(eq21) == sol21 + assert checksysodesol(eq21 , sol21) == (True , [0,0]) + + eq22 = [Eq(Derivative(f(t), t), f(t) + 2*g(t) + k1), + Eq(Derivative(g(t), t), k2)] + sol22 = [Eq(f(t), -2*C1 + C2*exp(t) - k1 - 2*k2*t - 2*k2), + Eq(g(t), C1 + k2*t)] + assert dsolve(eq22) == sol22 + assert checksysodesol(eq22 , sol22) == (True , [0,0]) + + eq23 = [Eq(Derivative(f(t), t), g(t) + k1), + Eq(Derivative(g(t), t), 2*g(t) + k2)] + sol23 = [Eq(f(t), C1 + C2*exp(2*t)/2 - k2/4 + t*(2*k1 - k2)/2), + Eq(g(t), C2*exp(2*t) - k2/2)] + assert dsolve(eq23) == sol23 + assert checksysodesol(eq23 , sol23) == (True , [0,0]) + + eq24 = [Eq(Derivative(f(t), t), f(t) + k1), + Eq(Derivative(g(t), t), 2*f(t) + k2)] + sol24 = [Eq(f(t), C1*exp(t)/2 - k1), + Eq(g(t), C1*exp(t) + C2 - 2*k1 - t*(2*k1 - k2))] + assert dsolve(eq24) == sol24 + assert checksysodesol(eq24 , sol24) == (True , [0,0]) + + eq25 = [Eq(Derivative(f(t), t), f(t) + 2*g(t) + k1), + Eq(Derivative(g(t), t), 3*f(t) + 6*g(t) + k2)] + sol25 = [Eq(f(t), -2*C1 + C2*exp(7*t)/3 + 2*t*(3*k1 - k2)/7 - + Mul(Rational(1,49), k1 + 2*k2 , evaluate = False)), + Eq(g(t), C1 + C2*exp(7*t) - t*(3*k1 - k2)/7 - + Mul(Rational(3,49), k1 + 2*k2 , evaluate = False))] + assert dsolve(eq25) == sol25 + assert checksysodesol(eq25 , sol25) == (True , [0,0]) + + eq26 = [Eq(Derivative(f(t), t), 2*f(t) - g(t) + k1), + Eq(Derivative(g(t), t), 4*f(t) - 2*g(t) + 2*k1)] + sol26 = [Eq(f(t), C1 + 2*C2 + t*(2*C1 + k1)), + Eq(g(t), 4*C2 + t*(4*C1 + 2*k1))] + assert dsolve(eq26) == sol26 + assert checksysodesol(eq26 , sol26) == (True , [0,0]) + + # Test Case added for issue #22715 + # https://github.com/sympy/sympy/issues/22715 + + eq27 = [Eq(diff(x(t),t),-1*y(t)+10), Eq(diff(y(t),t),5*x(t)-2*y(t)+3)] + sol27 = [Eq(x(t), (C1/5 - 2*C2/5)*exp(-t)*cos(2*t) + - (2*C1/5 + C2/5)*exp(-t)*sin(2*t) + + 17*sin(2*t)**2/5 + 17*cos(2*t)**2/5), + Eq(y(t), C1*exp(-t)*cos(2*t) - C2*exp(-t)*sin(2*t) + + 10*sin(2*t)**2 + 10*cos(2*t)**2)] + assert dsolve(eq27) == sol27 + assert checksysodesol(eq27 , sol27) == (True , [0,0]) + + +def test_sysode_linear_neq_order1_type3(): + + f, g, h, k, x0 , y0 = symbols('f g h k x0 y0', cls=Function) + x, t, a = symbols('x t a') + r = symbols('r', real=True) + + eqs1 = [Eq(Derivative(f(r), r), r*g(r) + f(r)), + Eq(Derivative(g(r), r), -r*f(r) + g(r))] + sol1 = [Eq(f(r), C1*exp(r)*sin(r**2/2) + C2*exp(r)*cos(r**2/2)), + Eq(g(r), C1*exp(r)*cos(r**2/2) - C2*exp(r)*sin(r**2/2))] + assert dsolve(eqs1) == sol1 + assert checksysodesol(eqs1, sol1) == (True, [0, 0]) + + eqs2 = [Eq(Derivative(f(x), x), x**2*g(x) + x*f(x)), + Eq(Derivative(g(x), x), 2*x**2*f(x) + (3*x**2 + x)*g(x))] + sol2 = [Eq(f(x), (sqrt(17)*C1/17 + C2*(17 - 3*sqrt(17))/34)*exp(x**3*(3 + sqrt(17))/6 + x**2/2) - + exp(x**3*(3 - sqrt(17))/6 + x**2/2)*(sqrt(17)*C1/17 + C2*(3*sqrt(17) + 17)*Rational(-1, 34))), + Eq(g(x), exp(x**3*(3 - sqrt(17))/6 + x**2/2)*(C1*(17 - 3*sqrt(17))/34 + sqrt(17)*C2*Rational(-2, + 17)) + exp(x**3*(3 + sqrt(17))/6 + x**2/2)*(C1*(3*sqrt(17) + 17)/34 + sqrt(17)*C2*Rational(2, 17)))] + assert dsolve(eqs2) == sol2 + assert checksysodesol(eqs2, sol2) == (True, [0, 0]) + + eqs3 = [Eq(f(x).diff(x), x*f(x) + g(x)), + Eq(g(x).diff(x), -f(x) + x*g(x))] + sol3 = [Eq(f(x), (C1/2 + I*C2/2)*exp(x**2/2 - I*x) + exp(x**2/2 + I*x)*(C1/2 + I*C2*Rational(-1, 2))), + Eq(g(x), (I*C1/2 + C2/2)*exp(x**2/2 + I*x) - exp(x**2/2 - I*x)*(I*C1/2 + C2*Rational(-1, 2)))] + assert dsolve(eqs3) == sol3 + assert checksysodesol(eqs3, sol3) == (True, [0, 0]) + + eqs4 = [Eq(f(x).diff(x), x*(f(x) + g(x) + h(x))), Eq(g(x).diff(x), x*(f(x) + g(x) + h(x))), + Eq(h(x).diff(x), x*(f(x) + g(x) + h(x)))] + sol4 = [Eq(f(x), -C1/3 - C2/3 + 2*C3/3 + (C1/3 + C2/3 + C3/3)*exp(3*x**2/2)), + Eq(g(x), 2*C1/3 - C2/3 - C3/3 + (C1/3 + C2/3 + C3/3)*exp(3*x**2/2)), + Eq(h(x), -C1/3 + 2*C2/3 - C3/3 + (C1/3 + C2/3 + C3/3)*exp(3*x**2/2))] + assert dsolve(eqs4) == sol4 + assert checksysodesol(eqs4, sol4) == (True, [0, 0, 0]) + + eqs5 = [Eq(f(x).diff(x), x**2*(f(x) + g(x) + h(x))), Eq(g(x).diff(x), x**2*(f(x) + g(x) + h(x))), + Eq(h(x).diff(x), x**2*(f(x) + g(x) + h(x)))] + sol5 = [Eq(f(x), -C1/3 - C2/3 + 2*C3/3 + (C1/3 + C2/3 + C3/3)*exp(x**3)), + Eq(g(x), 2*C1/3 - C2/3 - C3/3 + (C1/3 + C2/3 + C3/3)*exp(x**3)), + Eq(h(x), -C1/3 + 2*C2/3 - C3/3 + (C1/3 + C2/3 + C3/3)*exp(x**3))] + assert dsolve(eqs5) == sol5 + assert checksysodesol(eqs5, sol5) == (True, [0, 0, 0]) + + eqs6 = [Eq(Derivative(f(x), x), x*(f(x) + g(x) + h(x) + k(x))), + Eq(Derivative(g(x), x), x*(f(x) + g(x) + h(x) + k(x))), + Eq(Derivative(h(x), x), x*(f(x) + g(x) + h(x) + k(x))), + Eq(Derivative(k(x), x), x*(f(x) + g(x) + h(x) + k(x)))] + sol6 = [Eq(f(x), -C1/4 - C2/4 - C3/4 + 3*C4/4 + (C1/4 + C2/4 + C3/4 + C4/4)*exp(2*x**2)), + Eq(g(x), 3*C1/4 - C2/4 - C3/4 - C4/4 + (C1/4 + C2/4 + C3/4 + C4/4)*exp(2*x**2)), + Eq(h(x), -C1/4 + 3*C2/4 - C3/4 - C4/4 + (C1/4 + C2/4 + C3/4 + C4/4)*exp(2*x**2)), + Eq(k(x), -C1/4 - C2/4 + 3*C3/4 - C4/4 + (C1/4 + C2/4 + C3/4 + C4/4)*exp(2*x**2))] + assert dsolve(eqs6) == sol6 + assert checksysodesol(eqs6, sol6) == (True, [0, 0, 0, 0]) + + y = symbols("y", real=True) + + eqs7 = [Eq(Derivative(f(y), y), y*f(y) + g(y)), + Eq(Derivative(g(y), y), y*g(y) - f(y))] + sol7 = [Eq(f(y), C1*exp(y**2/2)*sin(y) + C2*exp(y**2/2)*cos(y)), + Eq(g(y), C1*exp(y**2/2)*cos(y) - C2*exp(y**2/2)*sin(y))] + assert dsolve(eqs7) == sol7 + assert checksysodesol(eqs7, sol7) == (True, [0, 0]) + + #Test cases added for the issue 19763 + #https://github.com/sympy/sympy/issues/19763 + + eqs8 = [Eq(Derivative(f(t), t), 5*t*f(t) + 2*h(t)), + Eq(Derivative(h(t), t), 2*f(t) + 5*t*h(t))] + sol8 = [Eq(f(t), Mul(-1, (C1/2 - C2/2), evaluate = False)*exp(5*t**2/2 - 2*t) + (C1/2 + C2/2)*exp(5*t**2/2 + 2*t)), + Eq(h(t), (C1/2 - C2/2)*exp(5*t**2/2 - 2*t) + (C1/2 + C2/2)*exp(5*t**2/2 + 2*t))] + assert dsolve(eqs8) == sol8 + assert checksysodesol(eqs8, sol8) == (True, [0, 0]) + + eqs9 = [Eq(diff(f(t), t), 5*t*f(t) + t**2*g(t)), + Eq(diff(g(t), t), -t**2*f(t) + 5*t*g(t))] + sol9 = [Eq(f(t), (C1/2 - I*C2/2)*exp(I*t**3/3 + 5*t**2/2) + (C1/2 + I*C2/2)*exp(-I*t**3/3 + 5*t**2/2)), + Eq(g(t), Mul(-1, (I*C1/2 - C2/2) , evaluate = False)*exp(-I*t**3/3 + 5*t**2/2) + (I*C1/2 + C2/2)*exp(I*t**3/3 + 5*t**2/2))] + assert dsolve(eqs9) == sol9 + assert checksysodesol(eqs9 , sol9) == (True , [0,0]) + + eqs10 = [Eq(diff(f(t), t), t**2*g(t) + 5*t*f(t)), + Eq(diff(g(t), t), -t**2*f(t) + (9*t**2 + 5*t)*g(t))] + sol10 = [Eq(f(t), (C1*(77 - 9*sqrt(77))/154 + sqrt(77)*C2/77)*exp(t**3*(sqrt(77) + 9)/6 + 5*t**2/2) + (C1*(77 + 9*sqrt(77))/154 - sqrt(77)*C2/77)*exp(t**3*(9 - sqrt(77))/6 + 5*t**2/2)), + Eq(g(t), (sqrt(77)*C1/77 + C2*(77 - 9*sqrt(77))/154)*exp(t**3*(9 - sqrt(77))/6 + 5*t**2/2) - (sqrt(77)*C1/77 - C2*(77 + 9*sqrt(77))/154)*exp(t**3*(sqrt(77) + 9)/6 + 5*t**2/2))] + assert dsolve(eqs10) == sol10 + assert checksysodesol(eqs10 , sol10) == (True , [0,0]) + + eqs11 = [Eq(diff(f(t), t), 5*t*f(t) + t**2*g(t)), + Eq(diff(g(t), t), (1-t**2)*f(t) + (5*t + 9*t**2)*g(t))] + sol11 = [Eq(f(t), C1*x0(t) + C2*x0(t)*Integral(t**2*exp(Integral(5*t, t))*exp(Integral(9*t**2 + 5*t, t))/x0(t)**2, t)), + Eq(g(t), C1*y0(t) + C2*(y0(t)*Integral(t**2*exp(Integral(5*t, t))*exp(Integral(9*t**2 + 5*t, t))/x0(t)**2, t) + exp(Integral(5*t, t))*exp(Integral(9*t**2 + 5*t, t))/x0(t)))] + assert dsolve(eqs11) == sol11 + +@slow +def test_sysode_linear_neq_order1_type4(): + + f, g, h, k = symbols('f g h k', cls=Function) + x, t, a = symbols('x t a') + r = symbols('r', real=True) + + eqs1 = [Eq(diff(f(r), r), f(r) + r*g(r) + r**2), Eq(diff(g(r), r), -r*f(r) + g(r) + r)] + sol1 = [Eq(f(r), C1*exp(r)*sin(r**2/2) + C2*exp(r)*cos(r**2/2) + exp(r)*sin(r**2/2)*Integral(r**2*exp(-r)*sin(r**2/2) + + r*exp(-r)*cos(r**2/2), r) + exp(r)*cos(r**2/2)*Integral(r**2*exp(-r)*cos(r**2/2) - r*exp(-r)*sin(r**2/2), r)), + Eq(g(r), C1*exp(r)*cos(r**2/2) - C2*exp(r)*sin(r**2/2) - exp(r)*sin(r**2/2)*Integral(r**2*exp(-r)*cos(r**2/2) - + r*exp(-r)*sin(r**2/2), r) + exp(r)*cos(r**2/2)*Integral(r**2*exp(-r)*sin(r**2/2) + r*exp(-r)*cos(r**2/2), r))] + assert dsolve(eqs1) == sol1 + assert checksysodesol(eqs1, sol1) == (True, [0, 0]) + + eqs2 = [Eq(diff(f(r), r), f(r) + r*g(r) + r), Eq(diff(g(r), r), -r*f(r) + g(r) + log(r))] + sol2 = [Eq(f(r), C1*exp(r)*sin(r**2/2) + C2*exp(r)*cos(r**2/2) + exp(r)*sin(r**2/2)*Integral(r*exp(-r)*sin(r**2/2) + + exp(-r)*log(r)*cos(r**2/2), r) + exp(r)*cos(r**2/2)*Integral(r*exp(-r)*cos(r**2/2) - exp(-r)*log(r)*sin( + r**2/2), r)), + Eq(g(r), C1*exp(r)*cos(r**2/2) - C2*exp(r)*sin(r**2/2) - exp(r)*sin(r**2/2)*Integral(r*exp(-r)*cos(r**2/2) - + exp(-r)*log(r)*sin(r**2/2), r) + exp(r)*cos(r**2/2)*Integral(r*exp(-r)*sin(r**2/2) + exp(-r)*log(r)*cos( + r**2/2), r))] + # XXX: dsolve hangs for this in integration + assert dsolve_system(eqs2, simplify=False, doit=False) == [sol2] + assert checksysodesol(eqs2, sol2) == (True, [0, 0]) + + eqs3 = [Eq(Derivative(f(x), x), x*(f(x) + g(x) + h(x)) + x), + Eq(Derivative(g(x), x), x*(f(x) + g(x) + h(x)) + x), + Eq(Derivative(h(x), x), x*(f(x) + g(x) + h(x)) + 1)] + sol3 = [Eq(f(x), C1*Rational(-1, 3) + C2*Rational(-1, 3) + C3*Rational(2, 3) + x**2/6 + x*Rational(-1, 3) + + (C1/3 + C2/3 + C3/3)*exp(x**2*Rational(3, 2)) + + sqrt(6)*sqrt(pi)*erf(sqrt(6)*x/2)*exp(x**2*Rational(3, 2))/18 + Rational(-2, 9)), + Eq(g(x), C1*Rational(2, 3) + C2*Rational(-1, 3) + C3*Rational(-1, 3) + x**2/6 + x*Rational(-1, 3) + + (C1/3 + C2/3 + C3/3)*exp(x**2*Rational(3, 2)) + + sqrt(6)*sqrt(pi)*erf(sqrt(6)*x/2)*exp(x**2*Rational(3, 2))/18 + Rational(-2, 9)), + Eq(h(x), C1*Rational(-1, 3) + C2*Rational(2, 3) + C3*Rational(-1, 3) + x**2*Rational(-1, 3) + + x*Rational(2, 3) + (C1/3 + C2/3 + C3/3)*exp(x**2*Rational(3, 2)) + + sqrt(6)*sqrt(pi)*erf(sqrt(6)*x/2)*exp(x**2*Rational(3, 2))/18 + Rational(-2, 9))] + assert dsolve(eqs3) == sol3 + assert checksysodesol(eqs3, sol3) == (True, [0, 0, 0]) + + eqs4 = [Eq(Derivative(f(x), x), x*(f(x) + g(x) + h(x)) + sin(x)), + Eq(Derivative(g(x), x), x*(f(x) + g(x) + h(x)) + sin(x)), + Eq(Derivative(h(x), x), x*(f(x) + g(x) + h(x)) + sin(x))] + sol4 = [Eq(f(x), C1*Rational(-1, 3) + C2*Rational(-1, 3) + C3*Rational(2, 3) + (C1/3 + C2/3 + + C3/3)*exp(x**2*Rational(3, 2)) + Integral(sin(x)*exp(x**2*Rational(-3, 2)), x)*exp(x**2*Rational(3, + 2))), + Eq(g(x), C1*Rational(2, 3) + C2*Rational(-1, 3) + C3*Rational(-1, 3) + (C1/3 + C2/3 + + C3/3)*exp(x**2*Rational(3, 2)) + Integral(sin(x)*exp(x**2*Rational(-3, 2)), x)*exp(x**2*Rational(3, + 2))), + Eq(h(x), C1*Rational(-1, 3) + C2*Rational(2, 3) + C3*Rational(-1, 3) + (C1/3 + C2/3 + + C3/3)*exp(x**2*Rational(3, 2)) + Integral(sin(x)*exp(x**2*Rational(-3, 2)), x)*exp(x**2*Rational(3, + 2)))] + assert dsolve(eqs4) == sol4 + assert checksysodesol(eqs4, sol4) == (True, [0, 0, 0]) + + eqs5 = [Eq(Derivative(f(x), x), x*(f(x) + g(x) + h(x) + k(x) + 1)), + Eq(Derivative(g(x), x), x*(f(x) + g(x) + h(x) + k(x) + 1)), + Eq(Derivative(h(x), x), x*(f(x) + g(x) + h(x) + k(x) + 1)), + Eq(Derivative(k(x), x), x*(f(x) + g(x) + h(x) + k(x) + 1))] + sol5 = [Eq(f(x), C1*Rational(-1, 4) + C2*Rational(-1, 4) + C3*Rational(-1, 4) + C4*Rational(3, 4) + (C1/4 + + C2/4 + C3/4 + C4/4)*exp(2*x**2) + Rational(-1, 4)), + Eq(g(x), C1*Rational(3, 4) + C2*Rational(-1, 4) + C3*Rational(-1, 4) + C4*Rational(-1, 4) + (C1/4 + + C2/4 + C3/4 + C4/4)*exp(2*x**2) + Rational(-1, 4)), + Eq(h(x), C1*Rational(-1, 4) + C2*Rational(3, 4) + C3*Rational(-1, 4) + C4*Rational(-1, 4) + (C1/4 + + C2/4 + C3/4 + C4/4)*exp(2*x**2) + Rational(-1, 4)), + Eq(k(x), C1*Rational(-1, 4) + C2*Rational(-1, 4) + C3*Rational(3, 4) + C4*Rational(-1, 4) + (C1/4 + + C2/4 + C3/4 + C4/4)*exp(2*x**2) + Rational(-1, 4))] + assert dsolve(eqs5) == sol5 + assert checksysodesol(eqs5, sol5) == (True, [0, 0, 0, 0]) + + eqs6 = [Eq(Derivative(f(x), x), x**2*(f(x) + g(x) + h(x) + k(x) + 1)), + Eq(Derivative(g(x), x), x**2*(f(x) + g(x) + h(x) + k(x) + 1)), + Eq(Derivative(h(x), x), x**2*(f(x) + g(x) + h(x) + k(x) + 1)), + Eq(Derivative(k(x), x), x**2*(f(x) + g(x) + h(x) + k(x) + 1))] + sol6 = [Eq(f(x), C1*Rational(-1, 4) + C2*Rational(-1, 4) + C3*Rational(-1, 4) + C4*Rational(3, 4) + (C1/4 + + C2/4 + C3/4 + C4/4)*exp(x**3*Rational(4, 3)) + Rational(-1, 4)), + Eq(g(x), C1*Rational(3, 4) + C2*Rational(-1, 4) + C3*Rational(-1, 4) + C4*Rational(-1, 4) + (C1/4 + + C2/4 + C3/4 + C4/4)*exp(x**3*Rational(4, 3)) + Rational(-1, 4)), + Eq(h(x), C1*Rational(-1, 4) + C2*Rational(3, 4) + C3*Rational(-1, 4) + C4*Rational(-1, 4) + (C1/4 + + C2/4 + C3/4 + C4/4)*exp(x**3*Rational(4, 3)) + Rational(-1, 4)), + Eq(k(x), C1*Rational(-1, 4) + C2*Rational(-1, 4) + C3*Rational(3, 4) + C4*Rational(-1, 4) + (C1/4 + + C2/4 + C3/4 + C4/4)*exp(x**3*Rational(4, 3)) + Rational(-1, 4))] + assert dsolve(eqs6) == sol6 + assert checksysodesol(eqs6, sol6) == (True, [0, 0, 0, 0]) + + eqs7 = [Eq(Derivative(f(x), x), (f(x) + g(x) + h(x))*log(x) + sin(x)), Eq(Derivative(g(x), x), (f(x) + g(x) + + h(x))*log(x) + sin(x)), Eq(Derivative(h(x), x), (f(x) + g(x) + h(x))*log(x) + sin(x))] + sol7 = [Eq(f(x), -C1/3 - C2/3 + 2*C3/3 + (C1/3 + C2/3 + + C3/3)*exp(x*(3*log(x) - 3)) + exp(x*(3*log(x) - + 3))*Integral(exp(3*x)*exp(-3*x*log(x))*sin(x), x)), + Eq(g(x), 2*C1/3 - C2/3 - C3/3 + (C1/3 + C2/3 + + C3/3)*exp(x*(3*log(x) - 3)) + exp(x*(3*log(x) - + 3))*Integral(exp(3*x)*exp(-3*x*log(x))*sin(x), x)), + Eq(h(x), -C1/3 + 2*C2/3 - C3/3 + (C1/3 + C2/3 + + C3/3)*exp(x*(3*log(x) - 3)) + exp(x*(3*log(x) - + 3))*Integral(exp(3*x)*exp(-3*x*log(x))*sin(x), x))] + with dotprodsimp(True): + assert dsolve(eqs7, simplify=False, doit=False) == sol7 + assert checksysodesol(eqs7, sol7) == (True, [0, 0, 0]) + + eqs8 = [Eq(Derivative(f(x), x), (f(x) + g(x) + h(x) + k(x))*log(x) + sin(x)), Eq(Derivative(g(x), x), (f(x) + + g(x) + h(x) + k(x))*log(x) + sin(x)), Eq(Derivative(h(x), x), (f(x) + g(x) + h(x) + k(x))*log(x) + + sin(x)), Eq(Derivative(k(x), x), (f(x) + g(x) + h(x) + k(x))*log(x) + sin(x))] + sol8 = [Eq(f(x), -C1/4 - C2/4 - C3/4 + 3*C4/4 + (C1/4 + C2/4 + C3/4 + + C4/4)*exp(x*(4*log(x) - 4)) + exp(x*(4*log(x) - + 4))*Integral(exp(4*x)*exp(-4*x*log(x))*sin(x), x)), + Eq(g(x), 3*C1/4 - C2/4 - C3/4 - C4/4 + (C1/4 + C2/4 + C3/4 + + C4/4)*exp(x*(4*log(x) - 4)) + exp(x*(4*log(x) - + 4))*Integral(exp(4*x)*exp(-4*x*log(x))*sin(x), x)), + Eq(h(x), -C1/4 + 3*C2/4 - C3/4 - C4/4 + (C1/4 + C2/4 + C3/4 + + C4/4)*exp(x*(4*log(x) - 4)) + exp(x*(4*log(x) - + 4))*Integral(exp(4*x)*exp(-4*x*log(x))*sin(x), x)), + Eq(k(x), -C1/4 - C2/4 + 3*C3/4 - C4/4 + (C1/4 + C2/4 + C3/4 + + C4/4)*exp(x*(4*log(x) - 4)) + exp(x*(4*log(x) - + 4))*Integral(exp(4*x)*exp(-4*x*log(x))*sin(x), x))] + with dotprodsimp(True): + assert dsolve(eqs8) == sol8 + assert checksysodesol(eqs8, sol8) == (True, [0, 0, 0, 0]) + + +def test_sysode_linear_neq_order1_type5_type6(): + f, g = symbols("f g", cls=Function) + x, x_ = symbols("x x_") + + # Type 5 + eqs1 = [Eq(Derivative(f(x), x), (2*f(x) + g(x))/x), Eq(Derivative(g(x), x), (f(x) + 2*g(x))/x)] + sol1 = [Eq(f(x), -C1*x + C2*x**3), Eq(g(x), C1*x + C2*x**3)] + assert dsolve(eqs1) == sol1 + assert checksysodesol(eqs1, sol1) == (True, [0, 0]) + + # Type 6 + eqs2 = [Eq(Derivative(f(x), x), (2*f(x) + g(x) + 1)/x), + Eq(Derivative(g(x), x), (x + f(x) + 2*g(x))/x)] + sol2 = [Eq(f(x), C2*x**3 - x*(C1 + Rational(1, 4)) + x*log(x)*Rational(-1, 2) + Rational(-2, 3)), + Eq(g(x), C2*x**3 + x*log(x)/2 + x*(C1 + Rational(-1, 4)) + Rational(1, 3))] + assert dsolve(eqs2) == sol2 + assert checksysodesol(eqs2, sol2) == (True, [0, 0]) + + +def test_higher_order_to_first_order(): + f, g = symbols('f g', cls=Function) + x = symbols('x') + + eqs1 = [Eq(Derivative(f(x), (x, 2)), 2*f(x) + g(x)), + Eq(Derivative(g(x), (x, 2)), -f(x))] + sol1 = [Eq(f(x), -C2*x*exp(-x) + C3*x*exp(x) - (C1 - C2)*exp(-x) + (C3 + C4)*exp(x)), + Eq(g(x), C2*x*exp(-x) - C3*x*exp(x) + (C1 + C2)*exp(-x) + (C3 - C4)*exp(x))] + assert dsolve(eqs1) == sol1 + assert checksysodesol(eqs1, sol1) == (True, [0, 0]) + + eqs2 = [Eq(f(x).diff(x, 2), 0), Eq(g(x).diff(x, 2), f(x))] + sol2 = [Eq(f(x), C1 + C2*x), Eq(g(x), C1*x**2/2 + C2*x**3/6 + C3 + C4*x)] + assert dsolve(eqs2) == sol2 + assert checksysodesol(eqs2, sol2) == (True, [0, 0]) + + eqs3 = [Eq(Derivative(f(x), (x, 2)), 2*f(x)), + Eq(Derivative(g(x), (x, 2)), -f(x) + 2*g(x))] + sol3 = [Eq(f(x), 4*C1*exp(-sqrt(2)*x) + 4*C2*exp(sqrt(2)*x)), + Eq(g(x), sqrt(2)*C1*x*exp(-sqrt(2)*x) - sqrt(2)*C2*x*exp(sqrt(2)*x) + (C1 + + sqrt(2)*C4)*exp(-sqrt(2)*x) + (C2 - sqrt(2)*C3)*exp(sqrt(2)*x))] + assert dsolve(eqs3) == sol3 + assert checksysodesol(eqs3, sol3) == (True, [0, 0]) + + eqs4 = [Eq(Derivative(f(x), (x, 2)), 2*f(x) + g(x)), + Eq(Derivative(g(x), (x, 2)), 2*g(x))] + sol4 = [Eq(f(x), C1*x*exp(sqrt(2)*x)/4 + C3*x*exp(-sqrt(2)*x)/4 + (C2/4 + sqrt(2)*C3/8)*exp(-sqrt(2)*x) - + exp(sqrt(2)*x)*(sqrt(2)*C1/8 + C4*Rational(-1, 4))), + Eq(g(x), sqrt(2)*C1*exp(sqrt(2)*x)/2 + sqrt(2)*C3*exp(-sqrt(2)*x)*Rational(-1, 2))] + assert dsolve(eqs4) == sol4 + assert checksysodesol(eqs4, sol4) == (True, [0, 0]) + + eqs5 = [Eq(f(x).diff(x, 2), f(x)), Eq(g(x).diff(x, 2), f(x))] + sol5 = [Eq(f(x), -C1*exp(-x) + C2*exp(x)), Eq(g(x), -C1*exp(-x) + C2*exp(x) + C3 + C4*x)] + assert dsolve(eqs5) == sol5 + assert checksysodesol(eqs5, sol5) == (True, [0, 0]) + + eqs6 = [Eq(Derivative(f(x), (x, 2)), f(x) + g(x)), + Eq(Derivative(g(x), (x, 2)), -f(x) - g(x))] + sol6 = [Eq(f(x), C1 + C2*x**2/2 + C2 + C4*x**3/6 + x*(C3 + C4)), + Eq(g(x), -C1 + C2*x**2*Rational(-1, 2) - C3*x + C4*x**3*Rational(-1, 6))] + assert dsolve(eqs6) == sol6 + assert checksysodesol(eqs6, sol6) == (True, [0, 0]) + + eqs7 = [Eq(Derivative(f(x), (x, 2)), f(x) + g(x) + 1), + Eq(Derivative(g(x), (x, 2)), f(x) + g(x) + 1)] + sol7 = [Eq(f(x), -C1 - C2*x + sqrt(2)*C3*exp(sqrt(2)*x)/2 + sqrt(2)*C4*exp(-sqrt(2)*x)*Rational(-1, 2) + + Rational(-1, 2)), + Eq(g(x), C1 + C2*x + sqrt(2)*C3*exp(sqrt(2)*x)/2 + sqrt(2)*C4*exp(-sqrt(2)*x)*Rational(-1, 2) + + Rational(-1, 2))] + assert dsolve(eqs7) == sol7 + assert checksysodesol(eqs7, sol7) == (True, [0, 0]) + + eqs8 = [Eq(Derivative(f(x), (x, 2)), f(x) + g(x) + 1), + Eq(Derivative(g(x), (x, 2)), -f(x) - g(x) + 1)] + sol8 = [Eq(f(x), C1 + C2 + C4*x**3/6 + x**4/12 + x**2*(C2/2 + Rational(1, 2)) + x*(C3 + C4)), + Eq(g(x), -C1 - C3*x + C4*x**3*Rational(-1, 6) + x**4*Rational(-1, 12) - x**2*(C2/2 + Rational(-1, + 2)))] + assert dsolve(eqs8) == sol8 + assert checksysodesol(eqs8, sol8) == (True, [0, 0]) + + x, y = symbols('x, y', cls=Function) + t, l = symbols('t, l') + + eqs10 = [Eq(Derivative(x(t), (t, 2)), 5*x(t) + 43*y(t)), + Eq(Derivative(y(t), (t, 2)), x(t) + 9*y(t))] + sol10 = [Eq(x(t), C1*(61 - 9*sqrt(47))*sqrt(sqrt(47) + 7)*exp(-t*sqrt(sqrt(47) + 7))/2 + C2*sqrt(7 - + sqrt(47))*(61 + 9*sqrt(47))*exp(-t*sqrt(7 - sqrt(47)))/2 + C3*(61 - 9*sqrt(47))*sqrt(sqrt(47) + + 7)*exp(t*sqrt(sqrt(47) + 7))*Rational(-1, 2) + C4*sqrt(7 - sqrt(47))*(61 + 9*sqrt(47))*exp(t*sqrt(7 + - sqrt(47)))*Rational(-1, 2)), + Eq(y(t), C1*(7 - sqrt(47))*sqrt(sqrt(47) + 7)*exp(-t*sqrt(sqrt(47) + 7))*Rational(-1, 2) + C2*sqrt(7 + - sqrt(47))*(sqrt(47) + 7)*exp(-t*sqrt(7 - sqrt(47)))*Rational(-1, 2) + C3*(7 - + sqrt(47))*sqrt(sqrt(47) + 7)*exp(t*sqrt(sqrt(47) + 7))/2 + C4*sqrt(7 - sqrt(47))*(sqrt(47) + + 7)*exp(t*sqrt(7 - sqrt(47)))/2)] + assert dsolve(eqs10) == sol10 + assert checksysodesol(eqs10, sol10) == (True, [0, 0]) + + eqs11 = [Eq(7*x(t) + Derivative(x(t), (t, 2)) - 9*Derivative(y(t), t), 0), + Eq(7*y(t) + 9*Derivative(x(t), t) + Derivative(y(t), (t, 2)), 0)] + sol11 = [Eq(y(t), C1*(9 - sqrt(109))*sin(sqrt(2)*t*sqrt(9*sqrt(109) + 95)/2)/14 + C2*(9 - + sqrt(109))*cos(sqrt(2)*t*sqrt(9*sqrt(109) + 95)/2)*Rational(-1, 14) + C3*(9 + + sqrt(109))*sin(sqrt(2)*t*sqrt(95 - 9*sqrt(109))/2)/14 + C4*(9 + sqrt(109))*cos(sqrt(2)*t*sqrt(95 - + 9*sqrt(109))/2)*Rational(-1, 14)), + Eq(x(t), C1*(9 - sqrt(109))*cos(sqrt(2)*t*sqrt(9*sqrt(109) + 95)/2)*Rational(-1, 14) + C2*(9 - + sqrt(109))*sin(sqrt(2)*t*sqrt(9*sqrt(109) + 95)/2)*Rational(-1, 14) + C3*(9 + + sqrt(109))*cos(sqrt(2)*t*sqrt(95 - 9*sqrt(109))/2)/14 + C4*(9 + sqrt(109))*sin(sqrt(2)*t*sqrt(95 - + 9*sqrt(109))/2)/14)] + assert dsolve(eqs11) == sol11 + assert checksysodesol(eqs11, sol11) == (True, [0, 0]) + + # Euler Systems + # Note: To add examples of euler systems solver with non-homogeneous term. + eqs13 = [Eq(Derivative(f(t), (t, 2)), Derivative(f(t), t)/t + f(t)/t**2 + g(t)/t**2), + Eq(Derivative(g(t), (t, 2)), g(t)/t**2)] + sol13 = [Eq(f(t), C1*(sqrt(5) + 3)*Rational(-1, 2)*t**(Rational(1, 2) + + sqrt(5)*Rational(-1, 2)) + C2*t**(Rational(1, 2) + + sqrt(5)/2)*(3 - sqrt(5))*Rational(-1, 2) - C3*t**(1 - + sqrt(2))*(1 + sqrt(2)) - C4*t**(1 + sqrt(2))*(1 - sqrt(2))), + Eq(g(t), C1*(1 + sqrt(5))*Rational(-1, 2)*t**(Rational(1, 2) + + sqrt(5)*Rational(-1, 2)) + C2*t**(Rational(1, 2) + + sqrt(5)/2)*(1 - sqrt(5))*Rational(-1, 2))] + assert dsolve(eqs13) == sol13 + assert checksysodesol(eqs13, sol13) == (True, [0, 0]) + + # Solving systems using dsolve separately + eqs14 = [Eq(Derivative(f(t), (t, 2)), t*f(t)), + Eq(Derivative(g(t), (t, 2)), t*g(t))] + sol14 = [Eq(f(t), C1*airyai(t) + C2*airybi(t)), + Eq(g(t), C3*airyai(t) + C4*airybi(t))] + assert dsolve(eqs14) == sol14 + assert checksysodesol(eqs14, sol14) == (True, [0, 0]) + + + eqs15 = [Eq(Derivative(x(t), (t, 2)), t*(4*Derivative(x(t), t) + 8*Derivative(y(t), t))), + Eq(Derivative(y(t), (t, 2)), t*(12*Derivative(x(t), t) - 6*Derivative(y(t), t)))] + sol15 = [Eq(x(t), C1 - erf(sqrt(6)*t)*(sqrt(6)*sqrt(pi)*C2/33 + sqrt(6)*sqrt(pi)*C3*Rational(-1, 44)) + + erfi(sqrt(5)*t)*(sqrt(5)*sqrt(pi)*C2*Rational(2, 55) + sqrt(5)*sqrt(pi)*C3*Rational(4, 55))), + Eq(y(t), C4 + erf(sqrt(6)*t)*(sqrt(6)*sqrt(pi)*C2*Rational(2, 33) + sqrt(6)*sqrt(pi)*C3*Rational(-1, + 22)) + erfi(sqrt(5)*t)*(sqrt(5)*sqrt(pi)*C2*Rational(3, 110) + sqrt(5)*sqrt(pi)*C3*Rational(3, 55)))] + assert dsolve(eqs15) == sol15 + assert checksysodesol(eqs15, sol15) == (True, [0, 0]) + + +@slow +def test_higher_order_to_first_order_9(): + f, g = symbols('f g', cls=Function) + x = symbols('x') + + eqs9 = [f(x) + g(x) - 2*exp(I*x) + 2*Derivative(f(x), x) + Derivative(f(x), (x, 2)), + f(x) + g(x) - 2*exp(I*x) + 2*Derivative(g(x), x) + Derivative(g(x), (x, 2))] + sol9 = [Eq(f(x), -C1 + C4*exp(-2*x)/2 - (C2/2 - C3/2)*exp(-x)*cos(x) + + (C2/2 + C3/2)*exp(-x)*sin(x) + 2*((1 - 2*I)*exp(I*x)*sin(x)**2/5) + + 2*((1 - 2*I)*exp(I*x)*cos(x)**2/5)), + Eq(g(x), C1 - C4*exp(-2*x)/2 - (C2/2 - C3/2)*exp(-x)*cos(x) + + (C2/2 + C3/2)*exp(-x)*sin(x) + 2*((1 - 2*I)*exp(I*x)*sin(x)**2/5) + + 2*((1 - 2*I)*exp(I*x)*cos(x)**2/5))] + assert dsolve(eqs9) == sol9 + assert checksysodesol(eqs9, sol9) == (True, [0, 0]) + + +def test_higher_order_to_first_order_12(): + f, g = symbols('f g', cls=Function) + x = symbols('x') + + x, y = symbols('x, y', cls=Function) + t, l = symbols('t, l') + + eqs12 = [Eq(4*x(t) + Derivative(x(t), (t, 2)) + 8*Derivative(y(t), t), 0), + Eq(4*y(t) - 8*Derivative(x(t), t) + Derivative(y(t), (t, 2)), 0)] + sol12 = [Eq(y(t), C1*(2 - sqrt(5))*sin(2*t*sqrt(4*sqrt(5) + 9))*Rational(-1, 2) + C2*(2 - + sqrt(5))*cos(2*t*sqrt(4*sqrt(5) + 9))/2 + C3*(2 + sqrt(5))*sin(2*t*sqrt(9 - 4*sqrt(5)))*Rational(-1, + 2) + C4*(2 + sqrt(5))*cos(2*t*sqrt(9 - 4*sqrt(5)))/2), + Eq(x(t), C1*(2 - sqrt(5))*cos(2*t*sqrt(4*sqrt(5) + 9))*Rational(-1, 2) + C2*(2 - + sqrt(5))*sin(2*t*sqrt(4*sqrt(5) + 9))*Rational(-1, 2) + C3*(2 + sqrt(5))*cos(2*t*sqrt(9 - + 4*sqrt(5)))/2 + C4*(2 + sqrt(5))*sin(2*t*sqrt(9 - 4*sqrt(5)))/2)] + assert dsolve(eqs12) == sol12 + assert checksysodesol(eqs12, sol12) == (True, [0, 0]) + + +def test_second_order_to_first_order_2(): + f, g = symbols("f g", cls=Function) + x, t, x_, t_, d, a, m = symbols("x t x_ t_ d a m") + + eqs2 = [Eq(f(x).diff(x, 2), 2*(x*g(x).diff(x) - g(x))), + Eq(g(x).diff(x, 2),-2*(x*f(x).diff(x) - f(x)))] + sol2 = [Eq(f(x), C1*x + x*Integral(C2*exp(-x_)*exp(I*exp(2*x_))/2 + C2*exp(-x_)*exp(-I*exp(2*x_))/2 - + I*C3*exp(-x_)*exp(I*exp(2*x_))/2 + I*C3*exp(-x_)*exp(-I*exp(2*x_))/2, (x_, log(x)))), + Eq(g(x), C4*x + x*Integral(I*C2*exp(-x_)*exp(I*exp(2*x_))/2 - I*C2*exp(-x_)*exp(-I*exp(2*x_))/2 + + C3*exp(-x_)*exp(I*exp(2*x_))/2 + C3*exp(-x_)*exp(-I*exp(2*x_))/2, (x_, log(x))))] + # XXX: dsolve hangs for this in integration + assert dsolve_system(eqs2, simplify=False, doit=False) == [sol2] + assert checksysodesol(eqs2, sol2) == (True, [0, 0]) + + eqs3 = (Eq(diff(f(t),t,t), 9*t*diff(g(t),t)-9*g(t)), Eq(diff(g(t),t,t),7*t*diff(f(t),t)-7*f(t))) + sol3 = [Eq(f(t), C1*t + t*Integral(C2*exp(-t_)*exp(3*sqrt(7)*exp(2*t_)/2)/2 + C2*exp(-t_)* + exp(-3*sqrt(7)*exp(2*t_)/2)/2 + 3*sqrt(7)*C3*exp(-t_)*exp(3*sqrt(7)*exp(2*t_)/2)/14 - + 3*sqrt(7)*C3*exp(-t_)*exp(-3*sqrt(7)*exp(2*t_)/2)/14, (t_, log(t)))), + Eq(g(t), C4*t + t*Integral(sqrt(7)*C2*exp(-t_)*exp(3*sqrt(7)*exp(2*t_)/2)/6 - sqrt(7)*C2*exp(-t_)* + exp(-3*sqrt(7)*exp(2*t_)/2)/6 + C3*exp(-t_)*exp(3*sqrt(7)*exp(2*t_)/2)/2 + C3*exp(-t_)*exp(-3*sqrt(7)* + exp(2*t_)/2)/2, (t_, log(t))))] + # XXX: dsolve hangs for this in integration + assert dsolve_system(eqs3, simplify=False, doit=False) == [sol3] + assert checksysodesol(eqs3, sol3) == (True, [0, 0]) + + # Regression Test case for sympy#19238 + # https://github.com/sympy/sympy/issues/19238 + # Note: When the doit method is removed, these particular types of systems + # can be divided first so that we have lesser number of big matrices. + eqs5 = [Eq(Derivative(g(t), (t, 2)), a*m), + Eq(Derivative(f(t), (t, 2)), 0)] + sol5 = [Eq(g(t), C1 + C2*t + a*m*t**2/2), + Eq(f(t), C3 + C4*t)] + assert dsolve(eqs5) == sol5 + assert checksysodesol(eqs5, sol5) == (True, [0, 0]) + + # Type 2 + eqs6 = [Eq(Derivative(f(t), (t, 2)), f(t)/t**4), + Eq(Derivative(g(t), (t, 2)), d*g(t)/t**4)] + sol6 = [Eq(f(t), C1*sqrt(t**2)*exp(-1/t) - C2*sqrt(t**2)*exp(1/t)), + Eq(g(t), C3*sqrt(t**2)*exp(-sqrt(d)/t)*d**Rational(-1, 2) - + C4*sqrt(t**2)*exp(sqrt(d)/t)*d**Rational(-1, 2))] + assert dsolve(eqs6) == sol6 + assert checksysodesol(eqs6, sol6) == (True, [0, 0]) + + +@slow +def test_second_order_to_first_order_slow1(): + f, g = symbols("f g", cls=Function) + x, t, x_, t_, d, a, m = symbols("x t x_ t_ d a m") + + # Type 1 + + eqs1 = [Eq(f(x).diff(x, 2), 2/x *(x*g(x).diff(x) - g(x))), + Eq(g(x).diff(x, 2),-2/x *(x*f(x).diff(x) - f(x)))] + sol1 = [Eq(f(x), C1*x + 2*C2*x*Ci(2*x) - C2*sin(2*x) - 2*C3*x*Si(2*x) - C3*cos(2*x)), + Eq(g(x), -2*C2*x*Si(2*x) - C2*cos(2*x) - 2*C3*x*Ci(2*x) + C3*sin(2*x) + C4*x)] + assert dsolve(eqs1) == sol1 + assert checksysodesol(eqs1, sol1) == (True, [0, 0]) + + +def test_second_order_to_first_order_slow4(): + f, g = symbols("f g", cls=Function) + x, t, x_, t_, d, a, m = symbols("x t x_ t_ d a m") + + eqs4 = [Eq(Derivative(f(t), (t, 2)), t*sin(t)*Derivative(g(t), t) - g(t)*sin(t)), + Eq(Derivative(g(t), (t, 2)), t*sin(t)*Derivative(f(t), t) - f(t)*sin(t))] + sol4 = [Eq(f(t), C1*t + t*Integral(C2*exp(-t_)*exp(exp(t_)*cos(exp(t_)))*exp(-sin(exp(t_)))/2 + + C2*exp(-t_)*exp(-exp(t_)*cos(exp(t_)))*exp(sin(exp(t_)))/2 - C3*exp(-t_)*exp(exp(t_)*cos(exp(t_)))* + exp(-sin(exp(t_)))/2 + + C3*exp(-t_)*exp(-exp(t_)*cos(exp(t_)))*exp(sin(exp(t_)))/2, (t_, log(t)))), + Eq(g(t), C4*t + t*Integral(-C2*exp(-t_)*exp(exp(t_)*cos(exp(t_)))*exp(-sin(exp(t_)))/2 + + C2*exp(-t_)*exp(-exp(t_)*cos(exp(t_)))*exp(sin(exp(t_)))/2 + C3*exp(-t_)*exp(exp(t_)*cos(exp(t_)))* + exp(-sin(exp(t_)))/2 + C3*exp(-t_)*exp(-exp(t_)*cos(exp(t_)))*exp(sin(exp(t_)))/2, (t_, log(t))))] + # XXX: dsolve hangs for this in integration + assert dsolve_system(eqs4, simplify=False, doit=False) == [sol4] + assert checksysodesol(eqs4, sol4) == (True, [0, 0]) + + +def test_component_division(): + f, g, h, k = symbols('f g h k', cls=Function) + x = symbols("x") + funcs = [f(x), g(x), h(x), k(x)] + + eqs1 = [Eq(Derivative(f(x), x), 2*f(x)), + Eq(Derivative(g(x), x), f(x)), + Eq(Derivative(h(x), x), h(x)), + Eq(Derivative(k(x), x), h(x)**4 + k(x))] + sol1 = [Eq(f(x), 2*C1*exp(2*x)), + Eq(g(x), C1*exp(2*x) + C2), + Eq(h(x), C3*exp(x)), + Eq(k(x), C3**4*exp(4*x)/3 + C4*exp(x))] + assert dsolve(eqs1) == sol1 + assert checksysodesol(eqs1, sol1) == (True, [0, 0, 0, 0]) + + components1 = {((Eq(Derivative(f(x), x), 2*f(x)),), (Eq(Derivative(g(x), x), f(x)),)), + ((Eq(Derivative(h(x), x), h(x)),), (Eq(Derivative(k(x), x), h(x)**4 + k(x)),))} + eqsdict1 = ({f(x): set(), g(x): {f(x)}, h(x): set(), k(x): {h(x)}}, + {f(x): Eq(Derivative(f(x), x), 2*f(x)), + g(x): Eq(Derivative(g(x), x), f(x)), + h(x): Eq(Derivative(h(x), x), h(x)), + k(x): Eq(Derivative(k(x), x), h(x)**4 + k(x))}) + graph1 = [{f(x), g(x), h(x), k(x)}, {(g(x), f(x)), (k(x), h(x))}] + assert {tuple(tuple(scc) for scc in wcc) for wcc in _component_division(eqs1, funcs, x)} == components1 + assert _eqs2dict(eqs1, funcs) == eqsdict1 + assert [set(element) for element in _dict2graph(eqsdict1[0])] == graph1 + + eqs2 = [Eq(Derivative(f(x), x), 2*f(x)), + Eq(Derivative(g(x), x), f(x)), + Eq(Derivative(h(x), x), h(x)), + Eq(Derivative(k(x), x), f(x)**4 + k(x))] + sol2 = [Eq(f(x), C1*exp(2*x)), + Eq(g(x), C1*exp(2*x)/2 + C2), + Eq(h(x), C3*exp(x)), + Eq(k(x), C1**4*exp(8*x)/7 + C4*exp(x))] + assert dsolve(eqs2) == sol2 + assert checksysodesol(eqs2, sol2) == (True, [0, 0, 0, 0]) + + components2 = {frozenset([(Eq(Derivative(f(x), x), 2*f(x)),), + (Eq(Derivative(g(x), x), f(x)),), + (Eq(Derivative(k(x), x), f(x)**4 + k(x)),)]), + frozenset([(Eq(Derivative(h(x), x), h(x)),)])} + eqsdict2 = ({f(x): set(), g(x): {f(x)}, h(x): set(), k(x): {f(x)}}, + {f(x): Eq(Derivative(f(x), x), 2*f(x)), + g(x): Eq(Derivative(g(x), x), f(x)), + h(x): Eq(Derivative(h(x), x), h(x)), + k(x): Eq(Derivative(k(x), x), f(x)**4 + k(x))}) + graph2 = [{f(x), g(x), h(x), k(x)}, {(g(x), f(x)), (k(x), f(x))}] + assert {frozenset(tuple(scc) for scc in wcc) for wcc in _component_division(eqs2, funcs, x)} == components2 + assert _eqs2dict(eqs2, funcs) == eqsdict2 + assert [set(element) for element in _dict2graph(eqsdict2[0])] == graph2 + + eqs3 = [Eq(Derivative(f(x), x), 2*f(x)), + Eq(Derivative(g(x), x), x + f(x)), + Eq(Derivative(h(x), x), h(x)), + Eq(Derivative(k(x), x), f(x)**4 + k(x))] + sol3 = [Eq(f(x), C1*exp(2*x)), + Eq(g(x), C1*exp(2*x)/2 + C2 + x**2/2), + Eq(h(x), C3*exp(x)), + Eq(k(x), C1**4*exp(8*x)/7 + C4*exp(x))] + assert dsolve(eqs3) == sol3 + assert checksysodesol(eqs3, sol3) == (True, [0, 0, 0, 0]) + + components3 = {frozenset([(Eq(Derivative(f(x), x), 2*f(x)),), + (Eq(Derivative(g(x), x), x + f(x)),), + (Eq(Derivative(k(x), x), f(x)**4 + k(x)),)]), + frozenset([(Eq(Derivative(h(x), x), h(x)),),])} + eqsdict3 = ({f(x): set(), g(x): {f(x)}, h(x): set(), k(x): {f(x)}}, + {f(x): Eq(Derivative(f(x), x), 2*f(x)), + g(x): Eq(Derivative(g(x), x), x + f(x)), + h(x): Eq(Derivative(h(x), x), h(x)), + k(x): Eq(Derivative(k(x), x), f(x)**4 + k(x))}) + graph3 = [{f(x), g(x), h(x), k(x)}, {(g(x), f(x)), (k(x), f(x))}] + assert {frozenset(tuple(scc) for scc in wcc) for wcc in _component_division(eqs3, funcs, x)} == components3 + assert _eqs2dict(eqs3, funcs) == eqsdict3 + assert [set(l) for l in _dict2graph(eqsdict3[0])] == graph3 + + # Note: To be uncommented when the default option to call dsolve first for + # single ODE system can be rearranged. This can be done after the doit + # option in dsolve is made False by default. + + eqs4 = [Eq(Derivative(f(x), x), x*f(x) + 2*g(x)), + Eq(Derivative(g(x), x), f(x) + x*g(x) + x), + Eq(Derivative(h(x), x), h(x)), + Eq(Derivative(k(x), x), f(x)**4 + k(x))] + sol4 = [Eq(f(x), (C1/2 - sqrt(2)*C2/2 - sqrt(2)*Integral(x*exp(-x**2/2 - sqrt(2)*x)/2 + x*exp(-x**2/2 +\ + sqrt(2)*x)/2, x)/2 + Integral(sqrt(2)*x*exp(-x**2/2 - sqrt(2)*x)/2 - sqrt(2)*x*exp(-x**2/2 +\ + sqrt(2)*x)/2, x)/2)*exp(x**2/2 - sqrt(2)*x) + (C1/2 + sqrt(2)*C2/2 + sqrt(2)*Integral(x*exp(-x**2/2 + - sqrt(2)*x)/2 + x*exp(-x**2/2 + sqrt(2)*x)/2, x)/2 + Integral(sqrt(2)*x*exp(-x**2/2 - sqrt(2)*x)/2 + - sqrt(2)*x*exp(-x**2/2 + sqrt(2)*x)/2, x)/2)*exp(x**2/2 + sqrt(2)*x)), + Eq(g(x), (-sqrt(2)*C1/4 + C2/2 + Integral(x*exp(-x**2/2 - sqrt(2)*x)/2 + x*exp(-x**2/2 + sqrt(2)*x)/2, x)/2 -\ + sqrt(2)*Integral(sqrt(2)*x*exp(-x**2/2 - sqrt(2)*x)/2 - sqrt(2)*x*exp(-x**2/2 + sqrt(2)*x)/2, + x)/4)*exp(x**2/2 - sqrt(2)*x) + (sqrt(2)*C1/4 + C2/2 + Integral(x*exp(-x**2/2 - sqrt(2)*x)/2 + + x*exp(-x**2/2 + sqrt(2)*x)/2, x)/2 + sqrt(2)*Integral(sqrt(2)*x*exp(-x**2/2 - sqrt(2)*x)/2 - + sqrt(2)*x*exp(-x**2/2 + sqrt(2)*x)/2, x)/4)*exp(x**2/2 + sqrt(2)*x)), + Eq(h(x), C3*exp(x)), + Eq(k(x), C4*exp(x) + exp(x)*Integral((C1*exp(x**2/2 - sqrt(2)*x)/2 + C1*exp(x**2/2 + sqrt(2)*x)/2 - + sqrt(2)*C2*exp(x**2/2 - sqrt(2)*x)/2 + sqrt(2)*C2*exp(x**2/2 + sqrt(2)*x)/2 - sqrt(2)*exp(x**2/2 - + sqrt(2)*x)*Integral(x*exp(-x**2/2 - sqrt(2)*x)/2 + x*exp(-x**2/2 + sqrt(2)*x)/2, x)/2 + exp(x**2/2 - + sqrt(2)*x)*Integral(sqrt(2)*x*exp(-x**2/2 - sqrt(2)*x)/2 - sqrt(2)*x*exp(-x**2/2 + sqrt(2)*x)/2, + x)/2 + sqrt(2)*exp(x**2/2 + sqrt(2)*x)*Integral(x*exp(-x**2/2 - sqrt(2)*x)/2 + x*exp(-x**2/2 + + sqrt(2)*x)/2, x)/2 + exp(x**2/2 + sqrt(2)*x)*Integral(sqrt(2)*x*exp(-x**2/2 - sqrt(2)*x)/2 - + sqrt(2)*x*exp(-x**2/2 + sqrt(2)*x)/2, x)/2)**4*exp(-x), x))] + components4 = {(frozenset([Eq(Derivative(f(x), x), x*f(x) + 2*g(x)), + Eq(Derivative(g(x), x), x*g(x) + x + f(x))]), + frozenset([Eq(Derivative(k(x), x), f(x)**4 + k(x)),])), + (frozenset([Eq(Derivative(h(x), x), h(x)),]),)} + eqsdict4 = ({f(x): {g(x)}, g(x): {f(x)}, h(x): set(), k(x): {f(x)}}, + {f(x): Eq(Derivative(f(x), x), x*f(x) + 2*g(x)), + g(x): Eq(Derivative(g(x), x), x*g(x) + x + f(x)), + h(x): Eq(Derivative(h(x), x), h(x)), + k(x): Eq(Derivative(k(x), x), f(x)**4 + k(x))}) + graph4 = [{f(x), g(x), h(x), k(x)}, {(f(x), g(x)), (g(x), f(x)), (k(x), f(x))}] + assert {tuple(frozenset(scc) for scc in wcc) for wcc in _component_division(eqs4, funcs, x)} == components4 + assert _eqs2dict(eqs4, funcs) == eqsdict4 + assert [set(element) for element in _dict2graph(eqsdict4[0])] == graph4 + # XXX: dsolve hangs in integration here: + assert dsolve_system(eqs4, simplify=False, doit=False) == [sol4] + assert checksysodesol(eqs4, sol4) == (True, [0, 0, 0, 0]) + + eqs5 = [Eq(Derivative(f(x), x), x*f(x) + 2*g(x)), + Eq(Derivative(g(x), x), x*g(x) + f(x)), + Eq(Derivative(h(x), x), h(x)), + Eq(Derivative(k(x), x), f(x)**4 + k(x))] + sol5 = [Eq(f(x), (C1/2 - sqrt(2)*C2/2)*exp(x**2/2 - sqrt(2)*x) + (C1/2 + sqrt(2)*C2/2)*exp(x**2/2 + sqrt(2)*x)), + Eq(g(x), (-sqrt(2)*C1/4 + C2/2)*exp(x**2/2 - sqrt(2)*x) + (sqrt(2)*C1/4 + C2/2)*exp(x**2/2 + sqrt(2)*x)), + Eq(h(x), C3*exp(x)), + Eq(k(x), C4*exp(x) + exp(x)*Integral((C1*exp(x**2/2 - sqrt(2)*x)/2 + C1*exp(x**2/2 + sqrt(2)*x)/2 - + sqrt(2)*C2*exp(x**2/2 - sqrt(2)*x)/2 + sqrt(2)*C2*exp(x**2/2 + sqrt(2)*x)/2)**4*exp(-x), x))] + components5 = {(frozenset([Eq(Derivative(f(x), x), x*f(x) + 2*g(x)), + Eq(Derivative(g(x), x), x*g(x) + f(x))]), + frozenset([Eq(Derivative(k(x), x), f(x)**4 + k(x)),])), + (frozenset([Eq(Derivative(h(x), x), h(x)),]),)} + eqsdict5 = ({f(x): {g(x)}, g(x): {f(x)}, h(x): set(), k(x): {f(x)}}, + {f(x): Eq(Derivative(f(x), x), x*f(x) + 2*g(x)), + g(x): Eq(Derivative(g(x), x), x*g(x) + f(x)), + h(x): Eq(Derivative(h(x), x), h(x)), + k(x): Eq(Derivative(k(x), x), f(x)**4 + k(x))}) + graph5 = [{f(x), g(x), h(x), k(x)}, {(f(x), g(x)), (g(x), f(x)), (k(x), f(x))}] + assert {tuple(frozenset(scc) for scc in wcc) for wcc in _component_division(eqs5, funcs, x)} == components5 + assert _eqs2dict(eqs5, funcs) == eqsdict5 + assert [set(element) for element in _dict2graph(eqsdict5[0])] == graph5 + # XXX: dsolve hangs in integration here: + assert dsolve_system(eqs5, simplify=False, doit=False) == [sol5] + assert checksysodesol(eqs5, sol5) == (True, [0, 0, 0, 0]) + + +def test_linodesolve(): + t, x, a = symbols("t x a") + f, g, h = symbols("f g h", cls=Function) + + # Testing the Errors + raises(ValueError, lambda: linodesolve(1, t)) + raises(ValueError, lambda: linodesolve(a, t)) + + A1 = Matrix([[1, 2], [2, 4], [4, 6]]) + raises(NonSquareMatrixError, lambda: linodesolve(A1, t)) + + A2 = Matrix([[1, 2, 1], [3, 1, 2]]) + raises(NonSquareMatrixError, lambda: linodesolve(A2, t)) + + # Testing auto functionality + func = [f(t), g(t)] + eq = [Eq(f(t).diff(t) + g(t).diff(t), g(t)), Eq(g(t).diff(t), f(t))] + ceq = canonical_odes(eq, func, t) + (A1, A0), b = linear_ode_to_matrix(ceq[0], func, t, 1) + A = A0 + sol = [C1*(-Rational(1, 2) + sqrt(5)/2)*exp(t*(-Rational(1, 2) + sqrt(5)/2)) + C2*(-sqrt(5)/2 - Rational(1, 2))* + exp(t*(-sqrt(5)/2 - Rational(1, 2))), + C1*exp(t*(-Rational(1, 2) + sqrt(5)/2)) + C2*exp(t*(-sqrt(5)/2 - Rational(1, 2)))] + assert constant_renumber(linodesolve(A, t), variables=Tuple(*eq).free_symbols) == sol + + # Testing the Errors + raises(ValueError, lambda: linodesolve(1, t, b=Matrix([t+1]))) + raises(ValueError, lambda: linodesolve(a, t, b=Matrix([log(t) + sin(t)]))) + + raises(ValueError, lambda: linodesolve(Matrix([7]), t, b=t**2)) + raises(ValueError, lambda: linodesolve(Matrix([a+10]), t, b=log(t)*cos(t))) + + raises(ValueError, lambda: linodesolve(7, t, b=t**2)) + raises(ValueError, lambda: linodesolve(a, t, b=log(t) + sin(t))) + + A1 = Matrix([[1, 2], [2, 4], [4, 6]]) + b1 = Matrix([t, 1, t**2]) + raises(NonSquareMatrixError, lambda: linodesolve(A1, t, b=b1)) + + A2 = Matrix([[1, 2, 1], [3, 1, 2]]) + b2 = Matrix([t, t**2]) + raises(NonSquareMatrixError, lambda: linodesolve(A2, t, b=b2)) + + raises(ValueError, lambda: linodesolve(A1[:2, :], t, b=b1)) + raises(ValueError, lambda: linodesolve(A1[:2, :], t, b=b1[:1])) + + # DOIT check + A1 = Matrix([[1, -1], [1, -1]]) + b1 = Matrix([15*t - 10, -15*t - 5]) + sol1 = [C1 + C2*t + C2 - 10*t**3 + 10*t**2 + t*(15*t**2 - 5*t) - 10*t, + C1 + C2*t - 10*t**3 - 5*t**2 + t*(15*t**2 - 5*t) - 5*t] + assert constant_renumber(linodesolve(A1, t, b=b1, type="type2", doit=True), + variables=[t]) == sol1 + + # Testing auto functionality + func = [f(t), g(t)] + eq = [Eq(f(t).diff(t) + g(t).diff(t), g(t) + t), Eq(g(t).diff(t), f(t))] + ceq = canonical_odes(eq, func, t) + (A1, A0), b = linear_ode_to_matrix(ceq[0], func, t, 1) + A = A0 + sol = [-C1*exp(-t/2 + sqrt(5)*t/2)/2 + sqrt(5)*C1*exp(-t/2 + sqrt(5)*t/2)/2 - sqrt(5)*C2*exp(-sqrt(5)*t/2 - + t/2)/2 - C2*exp(-sqrt(5)*t/2 - t/2)/2 - exp(-t/2 + sqrt(5)*t/2)*Integral(t*exp(-sqrt(5)*t/2 + + t/2)/(-5 + sqrt(5)) - sqrt(5)*t*exp(-sqrt(5)*t/2 + t/2)/(-5 + sqrt(5)), t)/2 + sqrt(5)*exp(-t/2 + + sqrt(5)*t/2)*Integral(t*exp(-sqrt(5)*t/2 + t/2)/(-5 + sqrt(5)) - sqrt(5)*t*exp(-sqrt(5)*t/2 + + t/2)/(-5 + sqrt(5)), t)/2 - sqrt(5)*exp(-sqrt(5)*t/2 - t/2)*Integral(-sqrt(5)*t*exp(t/2 + + sqrt(5)*t/2)/5, t)/2 - exp(-sqrt(5)*t/2 - t/2)*Integral(-sqrt(5)*t*exp(t/2 + sqrt(5)*t/2)/5, t)/2, + C1*exp(-t/2 + sqrt(5)*t/2) + C2*exp(-sqrt(5)*t/2 - t/2) + exp(-t/2 + + sqrt(5)*t/2)*Integral(t*exp(-sqrt(5)*t/2 + t/2)/(-5 + sqrt(5)) - sqrt(5)*t*exp(-sqrt(5)*t/2 + + t/2)/(-5 + sqrt(5)), t) + exp(-sqrt(5)*t/2 - + t/2)*Integral(-sqrt(5)*t*exp(t/2 + sqrt(5)*t/2)/5, t)] + assert constant_renumber(linodesolve(A, t, b=b), variables=[t]) == sol + + # non-homogeneous term assumed to be 0 + sol1 = [-C1*exp(-t/2 + sqrt(5)*t/2)/2 + sqrt(5)*C1*exp(-t/2 + sqrt(5)*t/2)/2 - sqrt(5)*C2*exp(-sqrt(5)*t/2 + - t/2)/2 - C2*exp(-sqrt(5)*t/2 - t/2)/2, + C1*exp(-t/2 + sqrt(5)*t/2) + C2*exp(-sqrt(5)*t/2 - t/2)] + assert constant_renumber(linodesolve(A, t, type="type2"), variables=[t]) == sol1 + + # Testing the Errors + raises(ValueError, lambda: linodesolve(t+10, t)) + raises(ValueError, lambda: linodesolve(a*t, t)) + + A1 = Matrix([[1, t], [-t, 1]]) + B1, _ = _is_commutative_anti_derivative(A1, t) + raises(NonSquareMatrixError, lambda: linodesolve(A1[:, :1], t, B=B1)) + raises(ValueError, lambda: linodesolve(A1, t, B=1)) + + A2 = Matrix([[t, t, t], [t, t, t], [t, t, t]]) + B2, _ = _is_commutative_anti_derivative(A2, t) + raises(NonSquareMatrixError, lambda: linodesolve(A2, t, B=B2[:2, :])) + raises(ValueError, lambda: linodesolve(A2, t, B=2)) + raises(ValueError, lambda: linodesolve(A2, t, B=B2, type="type31")) + + raises(ValueError, lambda: linodesolve(A1, t, B=B2)) + raises(ValueError, lambda: linodesolve(A2, t, B=B1)) + + # Testing auto functionality + func = [f(t), g(t)] + eq = [Eq(f(t).diff(t), f(t) + t*g(t)), Eq(g(t).diff(t), -t*f(t) + g(t))] + ceq = canonical_odes(eq, func, t) + (A1, A0), b = linear_ode_to_matrix(ceq[0], func, t, 1) + A = A0 + sol = [(C1/2 - I*C2/2)*exp(I*t**2/2 + t) + (C1/2 + I*C2/2)*exp(-I*t**2/2 + t), + (-I*C1/2 + C2/2)*exp(-I*t**2/2 + t) + (I*C1/2 + C2/2)*exp(I*t**2/2 + t)] + assert constant_renumber(linodesolve(A, t), variables=Tuple(*eq).free_symbols) == sol + assert constant_renumber(linodesolve(A, t, type="type3"), variables=Tuple(*eq).free_symbols) == sol + + A1 = Matrix([[t, 1], [t, -1]]) + raises(NotImplementedError, lambda: linodesolve(A1, t)) + + # Testing the Errors + raises(ValueError, lambda: linodesolve(t+10, t, b=Matrix([t+1]))) + raises(ValueError, lambda: linodesolve(a*t, t, b=Matrix([log(t) + sin(t)]))) + + raises(ValueError, lambda: linodesolve(Matrix([7*t]), t, b=t**2)) + raises(ValueError, lambda: linodesolve(Matrix([a + 10*log(t)]), t, b=log(t)*cos(t))) + + raises(ValueError, lambda: linodesolve(7*t, t, b=t**2)) + raises(ValueError, lambda: linodesolve(a*t**2, t, b=log(t) + sin(t))) + + A1 = Matrix([[1, t], [-t, 1]]) + b1 = Matrix([t, t ** 2]) + B1, _ = _is_commutative_anti_derivative(A1, t) + raises(NonSquareMatrixError, lambda: linodesolve(A1[:, :1], t, b=b1)) + + A2 = Matrix([[t, t, t], [t, t, t], [t, t, t]]) + b2 = Matrix([t, 1, t**2]) + B2, _ = _is_commutative_anti_derivative(A2, t) + raises(NonSquareMatrixError, lambda: linodesolve(A2[:2, :], t, b=b2)) + + raises(ValueError, lambda: linodesolve(A1, t, b=b2)) + raises(ValueError, lambda: linodesolve(A2, t, b=b1)) + + raises(ValueError, lambda: linodesolve(A1, t, b=b1, B=B2)) + raises(ValueError, lambda: linodesolve(A2, t, b=b2, B=B1)) + + # Testing auto functionality + func = [f(x), g(x), h(x)] + eq = [Eq(f(x).diff(x), x*(f(x) + g(x) + h(x)) + x), + Eq(g(x).diff(x), x*(f(x) + g(x) + h(x)) + x), + Eq(h(x).diff(x), x*(f(x) + g(x) + h(x)) + 1)] + ceq = canonical_odes(eq, func, x) + (A1, A0), b = linear_ode_to_matrix(ceq[0], func, x, 1) + A = A0 + _x1 = exp(-3*x**2/2) + _x2 = exp(3*x**2/2) + _x3 = Integral(2*_x1*x/3 + _x1/3 + x/3 - Rational(1, 3), x) + _x4 = 2*_x2*_x3/3 + _x5 = Integral(2*_x1*x/3 + _x1/3 - 2*x/3 + Rational(2, 3), x) + sol = [ + C1*_x2/3 - C1/3 + C2*_x2/3 - C2/3 + C3*_x2/3 + 2*C3/3 + _x2*_x5/3 + _x3/3 + _x4 - _x5/3, + C1*_x2/3 + 2*C1/3 + C2*_x2/3 - C2/3 + C3*_x2/3 - C3/3 + _x2*_x5/3 + _x3/3 + _x4 - _x5/3, + C1*_x2/3 - C1/3 + C2*_x2/3 + 2*C2/3 + C3*_x2/3 - C3/3 + _x2*_x5/3 - 2*_x3/3 + _x4 + 2*_x5/3, + ] + assert constant_renumber(linodesolve(A, x, b=b), variables=Tuple(*eq).free_symbols) == sol + assert constant_renumber(linodesolve(A, x, b=b, type="type4"), + variables=Tuple(*eq).free_symbols) == sol + + A1 = Matrix([[t, 1], [t, -1]]) + raises(NotImplementedError, lambda: linodesolve(A1, t, b=b1)) + + # non-homogeneous term not passed + sol1 = [-C1/3 - C2/3 + 2*C3/3 + (C1/3 + C2/3 + C3/3)*exp(3*x**2/2), 2*C1/3 - C2/3 - C3/3 + (C1/3 + C2/3 + C3/3)*exp(3*x**2/2), + -C1/3 + 2*C2/3 - C3/3 + (C1/3 + C2/3 + C3/3)*exp(3*x**2/2)] + assert constant_renumber(linodesolve(A, x, type="type4", doit=True), variables=Tuple(*eq).free_symbols) == sol1 + + +@slow +def test_linear_3eq_order1_type4_slow(): + x, y, z = symbols('x, y, z', cls=Function) + t = Symbol('t') + + f = t ** 3 + log(t) + g = t ** 2 + sin(t) + eq1 = (Eq(diff(x(t), t), (4 * f + g) * x(t) - f * y(t) - 2 * f * z(t)), + Eq(diff(y(t), t), 2 * f * x(t) + (f + g) * y(t) - 2 * f * z(t)), Eq(diff(z(t), t), 5 * f * x(t) + f * y( + t) + (-3 * f + g) * z(t))) + with dotprodsimp(True): + dsolve(eq1) + + +@slow +def test_linear_neq_order1_type2_slow1(): + i, r1, c1, r2, c2, t = symbols('i, r1, c1, r2, c2, t') + x1 = Function('x1') + x2 = Function('x2') + + eq1 = r1*c1*Derivative(x1(t), t) + x1(t) - x2(t) - r1*i + eq2 = r2*c1*Derivative(x1(t), t) + r2*c2*Derivative(x2(t), t) + x2(t) - r2*i + eq = [eq1, eq2] + + # XXX: Solution is too complicated + [sol] = dsolve_system(eq, simplify=False, doit=False) + assert checksysodesol(eq, sol) == (True, [0, 0]) + + +# Regression test case for issue #9204 +# https://github.com/sympy/sympy/issues/9204 +@tooslow +def test_linear_new_order1_type2_de_lorentz_slow_check(): + m = Symbol("m", real=True) + q = Symbol("q", real=True) + t = Symbol("t", real=True) + + e1, e2, e3 = symbols("e1:4", real=True) + b1, b2, b3 = symbols("b1:4", real=True) + v1, v2, v3 = symbols("v1:4", cls=Function, real=True) + + eqs = [ + -e1*q + m*Derivative(v1(t), t) - q*(-b2*v3(t) + b3*v2(t)), + -e2*q + m*Derivative(v2(t), t) - q*(b1*v3(t) - b3*v1(t)), + -e3*q + m*Derivative(v3(t), t) - q*(-b1*v2(t) + b2*v1(t)) + ] + sol = dsolve(eqs) + assert checksysodesol(eqs, sol) == (True, [0, 0, 0]) + + +# Regression test case for issue #14001 +# https://github.com/sympy/sympy/issues/14001 +@slow +def test_linear_neq_order1_type2_slow_check(): + RC, t, C, Vs, L, R1, V0, I0 = symbols("RC t C Vs L R1 V0 I0") + V = Function("V") + I = Function("I") + system = [Eq(V(t).diff(t), -1/RC*V(t) + I(t)/C), Eq(I(t).diff(t), -R1/L*I(t) - 1/L*V(t) + Vs/L)] + [sol] = dsolve_system(system, simplify=False, doit=False) + + assert checksysodesol(system, sol) == (True, [0, 0]) + + +def _linear_3eq_order1_type4_long(): + x, y, z = symbols('x, y, z', cls=Function) + t = Symbol('t') + + f = t ** 3 + log(t) + g = t ** 2 + sin(t) + + eq1 = (Eq(diff(x(t), t), (4*f + g)*x(t) - f*y(t) - 2*f*z(t)), + Eq(diff(y(t), t), 2*f*x(t) + (f + g)*y(t) - 2*f*z(t)), Eq(diff(z(t), t), 5*f*x(t) + f*y( + t) + (-3*f + g)*z(t))) + + dsolve_sol = dsolve(eq1) + dsolve_sol1 = [_simpsol(sol) for sol in dsolve_sol] + + x_1 = sqrt(-t**6 - 8*t**3*log(t) + 8*t**3 - 16*log(t)**2 + 32*log(t) - 16) + x_2 = sqrt(3) + x_3 = 8324372644*C1*x_1*x_2 + 4162186322*C2*x_1*x_2 - 8324372644*C3*x_1*x_2 + x_4 = 1 / (1903457163*t**3 + 3825881643*x_1*x_2 + 7613828652*log(t) - 7613828652) + x_5 = exp(t**3/3 + t*x_1*x_2/4 - cos(t)) + x_6 = exp(t**3/3 - t*x_1*x_2/4 - cos(t)) + x_7 = exp(t**4/2 + t**3/3 + 2*t*log(t) - 2*t - cos(t)) + x_8 = 91238*C1*x_1*x_2 + 91238*C2*x_1*x_2 - 91238*C3*x_1*x_2 + x_9 = 1 / (66049*t**3 - 50629*x_1*x_2 + 264196*log(t) - 264196) + x_10 = 50629 * C1 / 25189 + 37909*C2/25189 - 50629*C3/25189 - x_3*x_4 + x_11 = -50629*C1/25189 - 12720*C2/25189 + 50629*C3/25189 + x_3*x_4 + sol = [Eq(x(t), x_10*x_5 + x_11*x_6 + x_7*(C1 - C2)), Eq(y(t), x_10*x_5 + x_11*x_6), Eq(z(t), x_5*( + -424*C1/257 - 167*C2/257 + 424*C3/257 - x_8*x_9) + x_6*(167*C1/257 + 424*C2/257 - + 167*C3/257 + x_8*x_9) + x_7*(C1 - C2))] + + assert dsolve_sol1 == sol + assert checksysodesol(eq1, dsolve_sol1) == (True, [0, 0, 0]) + + +@slow +def test_neq_order1_type4_slow_check1(): + f, g = symbols("f g", cls=Function) + x = symbols("x") + + eqs = [Eq(diff(f(x), x), x*f(x) + x**2*g(x) + x), + Eq(diff(g(x), x), 2*x**2*f(x) + (x + 3*x**2)*g(x) + 1)] + sol = dsolve(eqs) + assert checksysodesol(eqs, sol) == (True, [0, 0]) + + +@slow +def test_neq_order1_type4_slow_check2(): + f, g, h = symbols("f, g, h", cls=Function) + x = Symbol("x") + + eqs = [ + Eq(Derivative(f(x), x), x*h(x) + f(x) + g(x) + 1), + Eq(Derivative(g(x), x), x*g(x) + f(x) + h(x) + 10), + Eq(Derivative(h(x), x), x*f(x) + x + g(x) + h(x)) + ] + with dotprodsimp(True): + sol = dsolve(eqs) + assert checksysodesol(eqs, sol) == (True, [0, 0, 0]) + + +def _neq_order1_type4_slow3(): + f, g = symbols("f g", cls=Function) + x = symbols("x") + + eqs = [ + Eq(Derivative(f(x), x), x*f(x) + g(x) + sin(x)), + Eq(Derivative(g(x), x), x**2 + x*g(x) - f(x)) + ] + sol = [ + Eq(f(x), (C1/2 - I*C2/2 - I*Integral(x**2*exp(-x**2/2 - I*x)/2 + + x**2*exp(-x**2/2 + I*x)/2 + I*exp(-x**2/2 - I*x)*sin(x)/2 - + I*exp(-x**2/2 + I*x)*sin(x)/2, x)/2 + Integral(-I*x**2*exp(-x**2/2 + - I*x)/2 + I*x**2*exp(-x**2/2 + I*x)/2 + exp(-x**2/2 - + I*x)*sin(x)/2 + exp(-x**2/2 + I*x)*sin(x)/2, x)/2)*exp(x**2/2 + + I*x) + (C1/2 + I*C2/2 + I*Integral(x**2*exp(-x**2/2 - I*x)/2 + + x**2*exp(-x**2/2 + I*x)/2 + I*exp(-x**2/2 - I*x)*sin(x)/2 - + I*exp(-x**2/2 + I*x)*sin(x)/2, x)/2 + Integral(-I*x**2*exp(-x**2/2 + - I*x)/2 + I*x**2*exp(-x**2/2 + I*x)/2 + exp(-x**2/2 - + I*x)*sin(x)/2 + exp(-x**2/2 + I*x)*sin(x)/2, x)/2)*exp(x**2/2 - + I*x)), + Eq(g(x), (-I*C1/2 + C2/2 + Integral(x**2*exp(-x**2/2 - I*x)/2 + + x**2*exp(-x**2/2 + I*x)/2 + I*exp(-x**2/2 - I*x)*sin(x)/2 - + I*exp(-x**2/2 + I*x)*sin(x)/2, x)/2 - + I*Integral(-I*x**2*exp(-x**2/2 - I*x)/2 + I*x**2*exp(-x**2/2 + + I*x)/2 + exp(-x**2/2 - I*x)*sin(x)/2 + exp(-x**2/2 + + I*x)*sin(x)/2, x)/2)*exp(x**2/2 - I*x) + (I*C1/2 + C2/2 + + Integral(x**2*exp(-x**2/2 - I*x)/2 + x**2*exp(-x**2/2 + I*x)/2 + + I*exp(-x**2/2 - I*x)*sin(x)/2 - I*exp(-x**2/2 + I*x)*sin(x)/2, + x)/2 + I*Integral(-I*x**2*exp(-x**2/2 - I*x)/2 + + I*x**2*exp(-x**2/2 + I*x)/2 + exp(-x**2/2 - I*x)*sin(x)/2 + + exp(-x**2/2 + I*x)*sin(x)/2, x)/2)*exp(x**2/2 + I*x)) + ] + + return eqs, sol + + +def test_neq_order1_type4_slow3(): + eqs, sol = _neq_order1_type4_slow3() + assert dsolve_system(eqs, simplify=False, doit=False) == [sol] + # XXX: dsolve gives an error in integration: + # assert dsolve(eqs) == sol + # https://github.com/sympy/sympy/issues/20155 + + +@slow +def test_neq_order1_type4_slow_check3(): + eqs, sol = _neq_order1_type4_slow3() + assert checksysodesol(eqs, sol) == (True, [0, 0]) + + +@tooslow +@XFAIL +def test_linear_3eq_order1_type4_long_dsolve_slow_xfail(): + eq, sol = _linear_3eq_order1_type4_long() + + dsolve_sol = dsolve(eq) + dsolve_sol1 = [_simpsol(sol) for sol in dsolve_sol] + + assert dsolve_sol1 == sol + + +@tooslow +def test_linear_3eq_order1_type4_long_dsolve_dotprodsimp(): + eq, sol = _linear_3eq_order1_type4_long() + + # XXX: Only works with dotprodsimp see + # test_linear_3eq_order1_type4_long_dsolve_slow_xfail which is too slow + with dotprodsimp(True): + dsolve_sol = dsolve(eq) + + dsolve_sol1 = [_simpsol(sol) for sol in dsolve_sol] + assert dsolve_sol1 == sol + + +@tooslow +def test_linear_3eq_order1_type4_long_check(): + eq, sol = _linear_3eq_order1_type4_long() + assert checksysodesol(eq, sol) == (True, [0, 0, 0]) + + +def test_dsolve_system(): + f, g = symbols("f g", cls=Function) + x = symbols("x") + eqs = [Eq(f(x).diff(x), f(x) + g(x)), Eq(g(x).diff(x), f(x) + g(x))] + funcs = [f(x), g(x)] + + sol = [[Eq(f(x), -C1 + C2*exp(2*x)), Eq(g(x), C1 + C2*exp(2*x))]] + assert dsolve_system(eqs, funcs=funcs, t=x, doit=True) == sol + + raises(ValueError, lambda: dsolve_system(1)) + raises(ValueError, lambda: dsolve_system(eqs, 1)) + raises(ValueError, lambda: dsolve_system(eqs, funcs, 1)) + raises(ValueError, lambda: dsolve_system(eqs, funcs[:1], x)) + + eq = (Eq(f(x).diff(x), 12 * f(x) - 6 * g(x)), Eq(g(x).diff(x) ** 2, 11 * f(x) + 3 * g(x))) + raises(NotImplementedError, lambda: dsolve_system(eq) == ([], [])) + + raises(NotImplementedError, lambda: dsolve_system(eq, funcs=[f(x), g(x)]) == ([], [])) + raises(NotImplementedError, lambda: dsolve_system(eq, funcs=[f(x), g(x)], t=x) == ([], [])) + raises(NotImplementedError, lambda: dsolve_system(eq, funcs=[f(x), g(x)], t=x, ics={f(0): 1, g(0): 1}) == ([], [])) + raises(NotImplementedError, lambda: dsolve_system(eq, t=x, ics={f(0): 1, g(0): 1}) == ([], [])) + raises(NotImplementedError, lambda: dsolve_system(eq, ics={f(0): 1, g(0): 1}) == ([], [])) + raises(NotImplementedError, lambda: dsolve_system(eq, funcs=[f(x), g(x)], ics={f(0): 1, g(0): 1}) == ([], [])) + +def test_dsolve(): + + f, g = symbols('f g', cls=Function) + x, y = symbols('x y') + + eqs = [f(x).diff(x) - x, f(x).diff(x) + x] + with raises(ValueError): + dsolve(eqs) + + eqs = [f(x, y).diff(x)] + with raises(ValueError): + dsolve(eqs) + + eqs = [f(x, y).diff(x)+g(x).diff(x), g(x).diff(x)] + with raises(ValueError): + dsolve(eqs) + + +@slow +def test_higher_order1_slow1(): + x, y = symbols("x y", cls=Function) + t = symbols("t") + + eq = [ + Eq(diff(x(t),t,t), (log(t)+t**2)*diff(x(t),t)+(log(t)+t**2)*3*diff(y(t),t)), + Eq(diff(y(t),t,t), (log(t)+t**2)*2*diff(x(t),t)+(log(t)+t**2)*9*diff(y(t),t)) + ] + sol, = dsolve_system(eq, simplify=False, doit=False) + # The solution is too long to write out explicitly and checkodesol is too + # slow so we test for particular values of t: + for e in eq: + res = (e.lhs - e.rhs).subs({sol[0].lhs:sol[0].rhs, sol[1].lhs:sol[1].rhs}) + res = res.subs({d: d.doit(deep=False) for d in res.atoms(Derivative)}) + assert ratsimp(res.subs(t, 1)) == 0 + + +def test_second_order_type2_slow1(): + x, y, z = symbols('x, y, z', cls=Function) + t, l = symbols('t, l') + + eqs1 = [Eq(Derivative(x(t), (t, 2)), t*(2*x(t) + y(t))), + Eq(Derivative(y(t), (t, 2)), t*(-x(t) + 2*y(t)))] + sol1 = [Eq(x(t), I*C1*airyai(t*(2 - I)**(S(1)/3)) + I*C2*airybi(t*(2 - I)**(S(1)/3)) - I*C3*airyai(t*(2 + + I)**(S(1)/3)) - I*C4*airybi(t*(2 + I)**(S(1)/3))), + Eq(y(t), C1*airyai(t*(2 - I)**(S(1)/3)) + C2*airybi(t*(2 - I)**(S(1)/3)) + C3*airyai(t*(2 + I)**(S(1)/3)) + + C4*airybi(t*(2 + I)**(S(1)/3)))] + assert dsolve(eqs1) == sol1 + assert checksysodesol(eqs1, sol1) == (True, [0, 0]) + + +@tooslow +@XFAIL +def test_nonlinear_3eq_order1_type1(): + a, b, c = symbols('a b c') + + eqs = [ + a * f(x).diff(x) - (b - c) * g(x) * h(x), + b * g(x).diff(x) - (c - a) * h(x) * f(x), + c * h(x).diff(x) - (a - b) * f(x) * g(x), + ] + + assert dsolve(eqs) # NotImplementedError + + +@XFAIL +def test_nonlinear_3eq_order1_type4(): + eqs = [ + Eq(f(x).diff(x), (2*h(x)*g(x) - 3*g(x)*h(x))), + Eq(g(x).diff(x), (4*f(x)*h(x) - 2*h(x)*f(x))), + Eq(h(x).diff(x), (3*g(x)*f(x) - 4*f(x)*g(x))), + ] + dsolve(eqs) # KeyError when matching + # sol = ? + # assert dsolve_sol == sol + # assert checksysodesol(eqs, dsolve_sol) == (True, [0, 0, 0]) + + +@tooslow +@XFAIL +def test_nonlinear_3eq_order1_type3(): + eqs = [ + Eq(f(x).diff(x), (2*f(x)**2 - 3 )), + Eq(g(x).diff(x), (4 - 2*h(x) )), + Eq(h(x).diff(x), (3*h(x) - 4*f(x)**2)), + ] + dsolve(eqs) # Not sure if this finishes... + # sol = ? + # assert dsolve_sol == sol + # assert checksysodesol(eqs, dsolve_sol) == (True, [0, 0, 0]) + + +@XFAIL +def test_nonlinear_3eq_order1_type5(): + eqs = [ + Eq(f(x).diff(x), f(x)*(2*f(x) - 3*g(x))), + Eq(g(x).diff(x), g(x)*(4*g(x) - 2*h(x))), + Eq(h(x).diff(x), h(x)*(3*h(x) - 4*f(x))), + ] + dsolve(eqs) # KeyError + # sol = ? + # assert dsolve_sol == sol + # assert checksysodesol(eqs, dsolve_sol) == (True, [0, 0, 0]) + + +def test_linear_2eq_order1(): + x, y, z = symbols('x, y, z', cls=Function) + k, l, m, n = symbols('k, l, m, n', Integer=True) + t = Symbol('t') + x0, y0 = symbols('x0, y0', cls=Function) + + eq1 = (Eq(diff(x(t),t), x(t) + y(t) + 9), Eq(diff(y(t),t), 2*x(t) + 5*y(t) + 23)) + sol1 = [Eq(x(t), C1*exp(t*(sqrt(6) + 3)) + C2*exp(t*(-sqrt(6) + 3)) - Rational(22, 3)), \ + Eq(y(t), C1*(2 + sqrt(6))*exp(t*(sqrt(6) + 3)) + C2*(-sqrt(6) + 2)*exp(t*(-sqrt(6) + 3)) - Rational(5, 3))] + assert checksysodesol(eq1, sol1) == (True, [0, 0]) + + eq2 = (Eq(diff(x(t),t), x(t) + y(t) + 81), Eq(diff(y(t),t), -2*x(t) + y(t) + 23)) + sol2 = [Eq(x(t), (C1*cos(sqrt(2)*t) + C2*sin(sqrt(2)*t))*exp(t) - Rational(58, 3)), \ + Eq(y(t), (-sqrt(2)*C1*sin(sqrt(2)*t) + sqrt(2)*C2*cos(sqrt(2)*t))*exp(t) - Rational(185, 3))] + assert checksysodesol(eq2, sol2) == (True, [0, 0]) + + eq3 = (Eq(diff(x(t),t), 5*t*x(t) + 2*y(t)), Eq(diff(y(t),t), 2*x(t) + 5*t*y(t))) + sol3 = [Eq(x(t), (C1*exp(2*t) + C2*exp(-2*t))*exp(Rational(5, 2)*t**2)), \ + Eq(y(t), (C1*exp(2*t) - C2*exp(-2*t))*exp(Rational(5, 2)*t**2))] + assert checksysodesol(eq3, sol3) == (True, [0, 0]) + + eq4 = (Eq(diff(x(t),t), 5*t*x(t) + t**2*y(t)), Eq(diff(y(t),t), -t**2*x(t) + 5*t*y(t))) + sol4 = [Eq(x(t), (C1*cos((t**3)/3) + C2*sin((t**3)/3))*exp(Rational(5, 2)*t**2)), \ + Eq(y(t), (-C1*sin((t**3)/3) + C2*cos((t**3)/3))*exp(Rational(5, 2)*t**2))] + assert checksysodesol(eq4, sol4) == (True, [0, 0]) + + eq5 = (Eq(diff(x(t),t), 5*t*x(t) + t**2*y(t)), Eq(diff(y(t),t), -t**2*x(t) + (5*t+9*t**2)*y(t))) + sol5 = [Eq(x(t), (C1*exp((sqrt(77)/2 + Rational(9, 2))*(t**3)/3) + \ + C2*exp((-sqrt(77)/2 + Rational(9, 2))*(t**3)/3))*exp(Rational(5, 2)*t**2)), \ + Eq(y(t), (C1*(sqrt(77)/2 + Rational(9, 2))*exp((sqrt(77)/2 + Rational(9, 2))*(t**3)/3) + \ + C2*(-sqrt(77)/2 + Rational(9, 2))*exp((-sqrt(77)/2 + Rational(9, 2))*(t**3)/3))*exp(Rational(5, 2)*t**2))] + assert checksysodesol(eq5, sol5) == (True, [0, 0]) + + eq6 = (Eq(diff(x(t),t), 5*t*x(t) + t**2*y(t)), Eq(diff(y(t),t), (1-t**2)*x(t) + (5*t+9*t**2)*y(t))) + sol6 = [Eq(x(t), C1*x0(t) + C2*x0(t)*Integral(t**2*exp(Integral(5*t, t))*exp(Integral(9*t**2 + 5*t, t))/x0(t)**2, t)), \ + Eq(y(t), C1*y0(t) + C2*(y0(t)*Integral(t**2*exp(Integral(5*t, t))*exp(Integral(9*t**2 + 5*t, t))/x0(t)**2, t) + \ + exp(Integral(5*t, t))*exp(Integral(9*t**2 + 5*t, t))/x0(t)))] + s = dsolve(eq6) + assert s == sol6 # too complicated to test with subs and simplify + # assert checksysodesol(eq10, sol10) == (True, [0, 0]) # this one fails + + +def test_nonlinear_2eq_order1(): + x, y, z = symbols('x, y, z', cls=Function) + t = Symbol('t') + eq1 = (Eq(diff(x(t),t),x(t)*y(t)**3), Eq(diff(y(t),t),y(t)**5)) + sol1 = [ + Eq(x(t), C1*exp((-1/(4*C2 + 4*t))**(Rational(-1, 4)))), + Eq(y(t), -(-1/(4*C2 + 4*t))**Rational(1, 4)), + Eq(x(t), C1*exp(-1/(-1/(4*C2 + 4*t))**Rational(1, 4))), + Eq(y(t), (-1/(4*C2 + 4*t))**Rational(1, 4)), + Eq(x(t), C1*exp(-I/(-1/(4*C2 + 4*t))**Rational(1, 4))), + Eq(y(t), -I*(-1/(4*C2 + 4*t))**Rational(1, 4)), + Eq(x(t), C1*exp(I/(-1/(4*C2 + 4*t))**Rational(1, 4))), + Eq(y(t), I*(-1/(4*C2 + 4*t))**Rational(1, 4))] + assert dsolve(eq1) == sol1 + assert checksysodesol(eq1, sol1) == (True, [0, 0]) + + eq2 = (Eq(diff(x(t),t), exp(3*x(t))*y(t)**3),Eq(diff(y(t),t), y(t)**5)) + sol2 = [ + Eq(x(t), -log(C1 - 3/(-1/(4*C2 + 4*t))**Rational(1, 4))/3), + Eq(y(t), -(-1/(4*C2 + 4*t))**Rational(1, 4)), + Eq(x(t), -log(C1 + 3/(-1/(4*C2 + 4*t))**Rational(1, 4))/3), + Eq(y(t), (-1/(4*C2 + 4*t))**Rational(1, 4)), + Eq(x(t), -log(C1 + 3*I/(-1/(4*C2 + 4*t))**Rational(1, 4))/3), + Eq(y(t), -I*(-1/(4*C2 + 4*t))**Rational(1, 4)), + Eq(x(t), -log(C1 - 3*I/(-1/(4*C2 + 4*t))**Rational(1, 4))/3), + Eq(y(t), I*(-1/(4*C2 + 4*t))**Rational(1, 4))] + assert dsolve(eq2) == sol2 + assert checksysodesol(eq2, sol2) == (True, [0, 0]) + + eq3 = (Eq(diff(x(t),t), y(t)*x(t)), Eq(diff(y(t),t), x(t)**3)) + tt = Rational(2, 3) + sol3 = [ + Eq(x(t), 6**tt/(6*(-sinh(sqrt(C1)*(C2 + t)/2)/sqrt(C1))**tt)), + Eq(y(t), sqrt(C1 + C1/sinh(sqrt(C1)*(C2 + t)/2)**2)/3)] + assert dsolve(eq3) == sol3 + # FIXME: assert checksysodesol(eq3, sol3) == (True, [0, 0]) + + eq4 = (Eq(diff(x(t),t),x(t)*y(t)*sin(t)**2), Eq(diff(y(t),t),y(t)**2*sin(t)**2)) + sol4 = {Eq(x(t), -2*exp(C1)/(C2*exp(C1) + t - sin(2*t)/2)), Eq(y(t), -2/(C1 + t - sin(2*t)/2))} + assert dsolve(eq4) == sol4 + # FIXME: assert checksysodesol(eq4, sol4) == (True, [0, 0]) + + eq5 = (Eq(x(t),t*diff(x(t),t)+diff(x(t),t)*diff(y(t),t)), Eq(y(t),t*diff(y(t),t)+diff(y(t),t)**2)) + sol5 = {Eq(x(t), C1*C2 + C1*t), Eq(y(t), C2**2 + C2*t)} + assert dsolve(eq5) == sol5 + assert checksysodesol(eq5, sol5) == (True, [0, 0]) + + eq6 = (Eq(diff(x(t),t),x(t)**2*y(t)**3), Eq(diff(y(t),t),y(t)**5)) + sol6 = [ + Eq(x(t), 1/(C1 - 1/(-1/(4*C2 + 4*t))**Rational(1, 4))), + Eq(y(t), -(-1/(4*C2 + 4*t))**Rational(1, 4)), + Eq(x(t), 1/(C1 + (-1/(4*C2 + 4*t))**(Rational(-1, 4)))), + Eq(y(t), (-1/(4*C2 + 4*t))**Rational(1, 4)), + Eq(x(t), 1/(C1 + I/(-1/(4*C2 + 4*t))**Rational(1, 4))), + Eq(y(t), -I*(-1/(4*C2 + 4*t))**Rational(1, 4)), + Eq(x(t), 1/(C1 - I/(-1/(4*C2 + 4*t))**Rational(1, 4))), + Eq(y(t), I*(-1/(4*C2 + 4*t))**Rational(1, 4))] + assert dsolve(eq6) == sol6 + assert checksysodesol(eq6, sol6) == (True, [0, 0]) + + +@slow +def test_nonlinear_3eq_order1(): + x, y, z = symbols('x, y, z', cls=Function) + t, u = symbols('t u') + eq1 = (4*diff(x(t),t) + 2*y(t)*z(t), 3*diff(y(t),t) - z(t)*x(t), 5*diff(z(t),t) - x(t)*y(t)) + sol1 = [Eq(4*Integral(1/(sqrt(-4*u**2 - 3*C1 + C2)*sqrt(-4*u**2 + 5*C1 - C2)), (u, x(t))), + C3 - sqrt(15)*t/15), Eq(3*Integral(1/(sqrt(-6*u**2 - C1 + 5*C2)*sqrt(3*u**2 + C1 - 4*C2)), + (u, y(t))), C3 + sqrt(5)*t/10), Eq(5*Integral(1/(sqrt(-10*u**2 - 3*C1 + C2)* + sqrt(5*u**2 + 4*C1 - C2)), (u, z(t))), C3 + sqrt(3)*t/6)] + assert [i.dummy_eq(j) for i, j in zip(dsolve(eq1), sol1)] + # FIXME: assert checksysodesol(eq1, sol1) == (True, [0, 0, 0]) + + eq2 = (4*diff(x(t),t) + 2*y(t)*z(t)*sin(t), 3*diff(y(t),t) - z(t)*x(t)*sin(t), 5*diff(z(t),t) - x(t)*y(t)*sin(t)) + sol2 = [Eq(3*Integral(1/(sqrt(-6*u**2 - C1 + 5*C2)*sqrt(3*u**2 + C1 - 4*C2)), (u, x(t))), C3 + + sqrt(5)*cos(t)/10), Eq(4*Integral(1/(sqrt(-4*u**2 - 3*C1 + C2)*sqrt(-4*u**2 + 5*C1 - C2)), + (u, y(t))), C3 - sqrt(15)*cos(t)/15), Eq(5*Integral(1/(sqrt(-10*u**2 - 3*C1 + C2)* + sqrt(5*u**2 + 4*C1 - C2)), (u, z(t))), C3 + sqrt(3)*cos(t)/6)] + assert [i.dummy_eq(j) for i, j in zip(dsolve(eq2), sol2)] + # FIXME: assert checksysodesol(eq2, sol2) == (True, [0, 0, 0]) + + +def test_C1_function_9239(): + t = Symbol('t') + C1 = Function('C1') + C2 = Function('C2') + C3 = Symbol('C3') + C4 = Symbol('C4') + eq = (Eq(diff(C1(t), t), 9*C2(t)), Eq(diff(C2(t), t), 12*C1(t))) + sol = [Eq(C1(t), 9*C3*exp(6*sqrt(3)*t) + 9*C4*exp(-6*sqrt(3)*t)), + Eq(C2(t), 6*sqrt(3)*C3*exp(6*sqrt(3)*t) - 6*sqrt(3)*C4*exp(-6*sqrt(3)*t))] + assert checksysodesol(eq, sol) == (True, [0, 0]) + + +def test_dsolve_linsystem_symbol(): + eps = Symbol('epsilon', positive=True) + eq1 = (Eq(diff(f(x), x), -eps*g(x)), Eq(diff(g(x), x), eps*f(x))) + sol1 = [Eq(f(x), -C1*eps*cos(eps*x) - C2*eps*sin(eps*x)), + Eq(g(x), -C1*eps*sin(eps*x) + C2*eps*cos(eps*x))] + assert checksysodesol(eq1, sol1) == (True, [0, 0]) diff --git a/lib/python3.10/site-packages/sympy/solvers/tests/__init__.py b/lib/python3.10/site-packages/sympy/solvers/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/lib/python3.10/site-packages/sympy/solvers/tests/__pycache__/__init__.cpython-310.pyc b/lib/python3.10/site-packages/sympy/solvers/tests/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1d9a2becb162bb3d4a458412c45bebf83fd914f5 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/solvers/tests/__pycache__/__init__.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/solvers/tests/__pycache__/test_constantsimp.cpython-310.pyc b/lib/python3.10/site-packages/sympy/solvers/tests/__pycache__/test_constantsimp.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ad23cecd2cfc596a6c71b66ac554b48cffc0303e Binary files /dev/null and b/lib/python3.10/site-packages/sympy/solvers/tests/__pycache__/test_constantsimp.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/solvers/tests/__pycache__/test_decompogen.cpython-310.pyc b/lib/python3.10/site-packages/sympy/solvers/tests/__pycache__/test_decompogen.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6e92075a3b6974a380adcc192d3b25c7afd70169 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/solvers/tests/__pycache__/test_decompogen.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/solvers/tests/__pycache__/test_inequalities.cpython-310.pyc b/lib/python3.10/site-packages/sympy/solvers/tests/__pycache__/test_inequalities.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e4a342de63cd7b24b18a2471ebeb3e4dd8141c12 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/solvers/tests/__pycache__/test_inequalities.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/solvers/tests/__pycache__/test_numeric.cpython-310.pyc b/lib/python3.10/site-packages/sympy/solvers/tests/__pycache__/test_numeric.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a1088fea0848d28de62dfea4bbdc3bf93b095d55 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/solvers/tests/__pycache__/test_numeric.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/solvers/tests/__pycache__/test_pde.cpython-310.pyc b/lib/python3.10/site-packages/sympy/solvers/tests/__pycache__/test_pde.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b3399b97eb9bf7632121b13a65971b654df1c4d2 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/solvers/tests/__pycache__/test_pde.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/solvers/tests/__pycache__/test_polysys.cpython-310.pyc b/lib/python3.10/site-packages/sympy/solvers/tests/__pycache__/test_polysys.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..aabe676768e596316b73043dad420db1eae870ef Binary files /dev/null and b/lib/python3.10/site-packages/sympy/solvers/tests/__pycache__/test_polysys.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/solvers/tests/__pycache__/test_recurr.cpython-310.pyc b/lib/python3.10/site-packages/sympy/solvers/tests/__pycache__/test_recurr.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5334cca5df3ab54dccb35eee6b75d3b0a79f2eab Binary files /dev/null and b/lib/python3.10/site-packages/sympy/solvers/tests/__pycache__/test_recurr.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/solvers/tests/__pycache__/test_simplex.cpython-310.pyc b/lib/python3.10/site-packages/sympy/solvers/tests/__pycache__/test_simplex.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a42c7e748aa39e8bea78f24b0fb248d552322dab Binary files /dev/null and b/lib/python3.10/site-packages/sympy/solvers/tests/__pycache__/test_simplex.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/solvers/tests/test_constantsimp.py b/lib/python3.10/site-packages/sympy/solvers/tests/test_constantsimp.py new file mode 100644 index 0000000000000000000000000000000000000000..efb966a4c8c2f93558d05e7c330f06530e69180c --- /dev/null +++ b/lib/python3.10/site-packages/sympy/solvers/tests/test_constantsimp.py @@ -0,0 +1,179 @@ +""" +If the arbitrary constant class from issue 4435 is ever implemented, this +should serve as a set of test cases. +""" + +from sympy.core.function import Function +from sympy.core.numbers import I +from sympy.core.power import Pow +from sympy.core.relational import Eq +from sympy.core.singleton import S +from sympy.core.symbol import Symbol +from sympy.functions.elementary.exponential import (exp, log) +from sympy.functions.elementary.hyperbolic import (cosh, sinh) +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import (acos, cos, sin) +from sympy.integrals.integrals import Integral +from sympy.solvers.ode.ode import constantsimp, constant_renumber +from sympy.testing.pytest import XFAIL + + +x = Symbol('x') +y = Symbol('y') +z = Symbol('z') +u2 = Symbol('u2') +_a = Symbol('_a') +C1 = Symbol('C1') +C2 = Symbol('C2') +C3 = Symbol('C3') +f = Function('f') + + +def test_constant_mul(): + # We want C1 (Constant) below to absorb the y's, but not the x's + assert constant_renumber(constantsimp(y*C1, [C1])) == C1*y + assert constant_renumber(constantsimp(C1*y, [C1])) == C1*y + assert constant_renumber(constantsimp(x*C1, [C1])) == x*C1 + assert constant_renumber(constantsimp(C1*x, [C1])) == x*C1 + assert constant_renumber(constantsimp(2*C1, [C1])) == C1 + assert constant_renumber(constantsimp(C1*2, [C1])) == C1 + assert constant_renumber(constantsimp(y*C1*x, [C1, y])) == C1*x + assert constant_renumber(constantsimp(x*y*C1, [C1, y])) == x*C1 + assert constant_renumber(constantsimp(y*x*C1, [C1, y])) == x*C1 + assert constant_renumber(constantsimp(C1*x*y, [C1, y])) == C1*x + assert constant_renumber(constantsimp(x*C1*y, [C1, y])) == x*C1 + assert constant_renumber(constantsimp(C1*y*(y + 1), [C1])) == C1*y*(y+1) + assert constant_renumber(constantsimp(y*C1*(y + 1), [C1])) == C1*y*(y+1) + assert constant_renumber(constantsimp(x*(y*C1), [C1])) == x*y*C1 + assert constant_renumber(constantsimp(x*(C1*y), [C1])) == x*y*C1 + assert constant_renumber(constantsimp(C1*(x*y), [C1, y])) == C1*x + assert constant_renumber(constantsimp((x*y)*C1, [C1, y])) == x*C1 + assert constant_renumber(constantsimp((y*x)*C1, [C1, y])) == x*C1 + assert constant_renumber(constantsimp(y*(y + 1)*C1, [C1, y])) == C1 + assert constant_renumber(constantsimp((C1*x)*y, [C1, y])) == C1*x + assert constant_renumber(constantsimp(y*(x*C1), [C1, y])) == x*C1 + assert constant_renumber(constantsimp((x*C1)*y, [C1, y])) == x*C1 + assert constant_renumber(constantsimp(C1*x*y*x*y*2, [C1, y])) == C1*x**2 + assert constant_renumber(constantsimp(C1*x*y*z, [C1, y, z])) == C1*x + assert constant_renumber(constantsimp(C1*x*y**2*sin(z), [C1, y, z])) == C1*x + assert constant_renumber(constantsimp(C1*C1, [C1])) == C1 + assert constant_renumber(constantsimp(C1*C2, [C1, C2])) == C1 + assert constant_renumber(constantsimp(C2*C2, [C1, C2])) == C1 + assert constant_renumber(constantsimp(C1*C1*C2, [C1, C2])) == C1 + assert constant_renumber(constantsimp(C1*x*2**x, [C1])) == C1*x*2**x + +def test_constant_add(): + assert constant_renumber(constantsimp(C1 + C1, [C1])) == C1 + assert constant_renumber(constantsimp(C1 + 2, [C1])) == C1 + assert constant_renumber(constantsimp(2 + C1, [C1])) == C1 + assert constant_renumber(constantsimp(C1 + y, [C1, y])) == C1 + assert constant_renumber(constantsimp(C1 + x, [C1])) == C1 + x + assert constant_renumber(constantsimp(C1 + C1, [C1])) == C1 + assert constant_renumber(constantsimp(C1 + C2, [C1, C2])) == C1 + assert constant_renumber(constantsimp(C2 + C1, [C1, C2])) == C1 + assert constant_renumber(constantsimp(C1 + C2 + C1, [C1, C2])) == C1 + + +def test_constant_power_as_base(): + assert constant_renumber(constantsimp(C1**C1, [C1])) == C1 + assert constant_renumber(constantsimp(Pow(C1, C1), [C1])) == C1 + assert constant_renumber(constantsimp(C1**C1, [C1])) == C1 + assert constant_renumber(constantsimp(C1**C2, [C1, C2])) == C1 + assert constant_renumber(constantsimp(C2**C1, [C1, C2])) == C1 + assert constant_renumber(constantsimp(C2**C2, [C1, C2])) == C1 + assert constant_renumber(constantsimp(C1**y, [C1, y])) == C1 + assert constant_renumber(constantsimp(C1**x, [C1])) == C1**x + assert constant_renumber(constantsimp(C1**2, [C1])) == C1 + assert constant_renumber( + constantsimp(C1**(x*y), [C1])) == C1**(x*y) + + +def test_constant_power_as_exp(): + assert constant_renumber(constantsimp(x**C1, [C1])) == x**C1 + assert constant_renumber(constantsimp(y**C1, [C1, y])) == C1 + assert constant_renumber(constantsimp(x**y**C1, [C1, y])) == x**C1 + assert constant_renumber( + constantsimp((x**y)**C1, [C1])) == (x**y)**C1 + assert constant_renumber( + constantsimp(x**(y**C1), [C1, y])) == x**C1 + assert constant_renumber(constantsimp(x**C1**y, [C1, y])) == x**C1 + assert constant_renumber( + constantsimp(x**(C1**y), [C1, y])) == x**C1 + assert constant_renumber( + constantsimp((x**C1)**y, [C1])) == (x**C1)**y + assert constant_renumber(constantsimp(2**C1, [C1])) == C1 + assert constant_renumber(constantsimp(S(2)**C1, [C1])) == C1 + assert constant_renumber(constantsimp(exp(C1), [C1])) == C1 + assert constant_renumber( + constantsimp(exp(C1 + x), [C1])) == C1*exp(x) + assert constant_renumber(constantsimp(Pow(2, C1), [C1])) == C1 + + +def test_constant_function(): + assert constant_renumber(constantsimp(sin(C1), [C1])) == C1 + assert constant_renumber(constantsimp(f(C1), [C1])) == C1 + assert constant_renumber(constantsimp(f(C1, C1), [C1])) == C1 + assert constant_renumber(constantsimp(f(C1, C2), [C1, C2])) == C1 + assert constant_renumber(constantsimp(f(C2, C1), [C1, C2])) == C1 + assert constant_renumber(constantsimp(f(C2, C2), [C1, C2])) == C1 + assert constant_renumber( + constantsimp(f(C1, x), [C1])) == f(C1, x) + assert constant_renumber(constantsimp(f(C1, y), [C1, y])) == C1 + assert constant_renumber(constantsimp(f(y, C1), [C1, y])) == C1 + assert constant_renumber(constantsimp(f(C1, y, C2), [C1, C2, y])) == C1 + + +def test_constant_function_multiple(): + # The rules to not renumber in this case would be too complicated, and + # dsolve is not likely to ever encounter anything remotely like this. + assert constant_renumber( + constantsimp(f(C1, C1, x), [C1])) == f(C1, C1, x) + + +def test_constant_multiple(): + assert constant_renumber(constantsimp(C1*2 + 2, [C1])) == C1 + assert constant_renumber(constantsimp(x*2/C1, [C1])) == C1*x + assert constant_renumber(constantsimp(C1**2*2 + 2, [C1])) == C1 + assert constant_renumber( + constantsimp(sin(2*C1) + x + sqrt(2), [C1])) == C1 + x + assert constant_renumber(constantsimp(2*C1 + C2, [C1, C2])) == C1 + +def test_constant_repeated(): + assert C1 + C1*x == constant_renumber( C1 + C1*x) + +def test_ode_solutions(): + # only a few examples here, the rest will be tested in the actual dsolve tests + assert constant_renumber(constantsimp(C1*exp(2*x) + exp(x)*(C2 + C3), [C1, C2, C3])) == \ + constant_renumber(C1*exp(x) + C2*exp(2*x)) + assert constant_renumber( + constantsimp(Eq(f(x), I*C1*sinh(x/3) + C2*cosh(x/3)), [C1, C2]) + ) == constant_renumber(Eq(f(x), C1*sinh(x/3) + C2*cosh(x/3))) + assert constant_renumber(constantsimp(Eq(f(x), acos((-C1)/cos(x))), [C1])) == \ + Eq(f(x), acos(C1/cos(x))) + assert constant_renumber( + constantsimp(Eq(log(f(x)/C1) + 2*exp(x/f(x)), 0), [C1]) + ) == Eq(log(C1*f(x)) + 2*exp(x/f(x)), 0) + assert constant_renumber(constantsimp(Eq(log(x*sqrt(2)*sqrt(1/x)*sqrt(f(x)) + /C1) + x**2/(2*f(x)**2), 0), [C1])) == \ + Eq(log(C1*sqrt(x)*sqrt(f(x))) + x**2/(2*f(x)**2), 0) + assert constant_renumber(constantsimp(Eq(-exp(-f(x)/x)*sin(f(x)/x)/2 + log(x/C1) - + cos(f(x)/x)*exp(-f(x)/x)/2, 0), [C1])) == \ + Eq(-exp(-f(x)/x)*sin(f(x)/x)/2 + log(C1*x) - cos(f(x)/x)* + exp(-f(x)/x)/2, 0) + assert constant_renumber(constantsimp(Eq(-Integral(-1/(sqrt(1 - u2**2)*u2), + (u2, _a, x/f(x))) + log(f(x)/C1), 0), [C1])) == \ + Eq(-Integral(-1/(u2*sqrt(1 - u2**2)), (u2, _a, x/f(x))) + + log(C1*f(x)), 0) + assert [constantsimp(i, [C1]) for i in [Eq(f(x), sqrt(-C1*x + x**2)), Eq(f(x), -sqrt(-C1*x + x**2))]] == \ + [Eq(f(x), sqrt(x*(C1 + x))), Eq(f(x), -sqrt(x*(C1 + x)))] + + +@XFAIL +def test_nonlocal_simplification(): + assert constantsimp(C1 + C2+x*C2, [C1, C2]) == C1 + C2*x + + +def test_constant_Eq(): + # C1 on the rhs is well-tested, but the lhs is only tested here + assert constantsimp(Eq(C1, 3 + f(x)*x), [C1]) == Eq(x*f(x), C1) + assert constantsimp(Eq(C1, 3 * f(x)*x), [C1]) == Eq(f(x)*x, C1) diff --git a/lib/python3.10/site-packages/sympy/solvers/tests/test_decompogen.py b/lib/python3.10/site-packages/sympy/solvers/tests/test_decompogen.py new file mode 100644 index 0000000000000000000000000000000000000000..1ba03f4b42558231b626b6ed169f8b0a81a72bf9 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/solvers/tests/test_decompogen.py @@ -0,0 +1,59 @@ +from sympy.solvers.decompogen import decompogen, compogen +from sympy.core.symbol import symbols +from sympy.functions.elementary.complexes import Abs +from sympy.functions.elementary.exponential import exp +from sympy.functions.elementary.miscellaneous import sqrt, Max +from sympy.functions.elementary.trigonometric import (cos, sin) +from sympy.testing.pytest import XFAIL, raises + +x, y = symbols('x y') + + +def test_decompogen(): + assert decompogen(sin(cos(x)), x) == [sin(x), cos(x)] + assert decompogen(sin(x)**2 + sin(x) + 1, x) == [x**2 + x + 1, sin(x)] + assert decompogen(sqrt(6*x**2 - 5), x) == [sqrt(x), 6*x**2 - 5] + assert decompogen(sin(sqrt(cos(x**2 + 1))), x) == [sin(x), sqrt(x), cos(x), x**2 + 1] + assert decompogen(Abs(cos(x)**2 + 3*cos(x) - 4), x) == [Abs(x), x**2 + 3*x - 4, cos(x)] + assert decompogen(sin(x)**2 + sin(x) - sqrt(3)/2, x) == [x**2 + x - sqrt(3)/2, sin(x)] + assert decompogen(Abs(cos(y)**2 + 3*cos(x) - 4), x) == [Abs(x), 3*x + cos(y)**2 - 4, cos(x)] + assert decompogen(x, y) == [x] + assert decompogen(1, x) == [1] + assert decompogen(Max(3, x), x) == [Max(3, x)] + raises(TypeError, lambda: decompogen(x < 5, x)) + u = 2*x + 3 + assert decompogen(Max(sqrt(u),(u)**2), x) == [Max(sqrt(x), x**2), u] + assert decompogen(Max(u, u**2, y), x) == [Max(x, x**2, y), u] + assert decompogen(Max(sin(x), u), x) == [Max(2*x + 3, sin(x))] + + +def test_decompogen_poly(): + assert decompogen(x**4 + 2*x**2 + 1, x) == [x**2 + 2*x + 1, x**2] + assert decompogen(x**4 + 2*x**3 - x - 1, x) == [x**2 - x - 1, x**2 + x] + + +@XFAIL +def test_decompogen_fails(): + A = lambda x: x**2 + 2*x + 3 + B = lambda x: 4*x**2 + 5*x + 6 + assert decompogen(A(x*exp(x)), x) == [x**2 + 2*x + 3, x*exp(x)] + assert decompogen(A(B(x)), x) == [x**2 + 2*x + 3, 4*x**2 + 5*x + 6] + assert decompogen(A(1/x + 1/x**2), x) == [x**2 + 2*x + 3, 1/x + 1/x**2] + assert decompogen(A(1/x + 2/(x + 1)), x) == [x**2 + 2*x + 3, 1/x + 2/(x + 1)] + + +def test_compogen(): + assert compogen([sin(x), cos(x)], x) == sin(cos(x)) + assert compogen([x**2 + x + 1, sin(x)], x) == sin(x)**2 + sin(x) + 1 + assert compogen([sqrt(x), 6*x**2 - 5], x) == sqrt(6*x**2 - 5) + assert compogen([sin(x), sqrt(x), cos(x), x**2 + 1], x) == sin(sqrt( + cos(x**2 + 1))) + assert compogen([Abs(x), x**2 + 3*x - 4, cos(x)], x) == Abs(cos(x)**2 + + 3*cos(x) - 4) + assert compogen([x**2 + x - sqrt(3)/2, sin(x)], x) == (sin(x)**2 + sin(x) - + sqrt(3)/2) + assert compogen([Abs(x), 3*x + cos(y)**2 - 4, cos(x)], x) == \ + Abs(3*cos(x) + cos(y)**2 - 4) + assert compogen([x**2 + 2*x + 1, x**2], x) == x**4 + 2*x**2 + 1 + # the result is in unsimplified form + assert compogen([x**2 - x - 1, x**2 + x], x) == -x**2 - x + (x**2 + x)**2 - 1 diff --git a/lib/python3.10/site-packages/sympy/solvers/tests/test_inequalities.py b/lib/python3.10/site-packages/sympy/solvers/tests/test_inequalities.py new file mode 100644 index 0000000000000000000000000000000000000000..6ce6f4520b52d8714102c95457c90d44543c685c --- /dev/null +++ b/lib/python3.10/site-packages/sympy/solvers/tests/test_inequalities.py @@ -0,0 +1,500 @@ +"""Tests for tools for solving inequalities and systems of inequalities. """ + +from sympy.concrete.summations import Sum +from sympy.core.function import Function +from sympy.core.numbers import I, Rational, oo, pi +from sympy.core.relational import Eq, Ge, Gt, Le, Lt, Ne +from sympy.core.singleton import S +from sympy.core.symbol import (Dummy, Symbol) +from sympy.functions.elementary.complexes import Abs +from sympy.functions.elementary.exponential import exp, log +from sympy.functions.elementary.miscellaneous import root, sqrt +from sympy.functions.elementary.piecewise import Piecewise +from sympy.functions.elementary.trigonometric import cos, sin, tan +from sympy.integrals.integrals import Integral +from sympy.logic.boolalg import And, Or +from sympy.polys.polytools import Poly, PurePoly +from sympy.sets.sets import FiniteSet, Interval, Union +from sympy.solvers.inequalities import (reduce_inequalities, + solve_poly_inequality as psolve, + reduce_rational_inequalities, + solve_univariate_inequality as isolve, + reduce_abs_inequality, + _solve_inequality) +from sympy.polys.rootoftools import rootof +from sympy.solvers.solvers import solve +from sympy.solvers.solveset import solveset +from sympy.core.mod import Mod +from sympy.abc import x, y + +from sympy.testing.pytest import raises, XFAIL + + +inf = oo.evalf() + + +def test_solve_poly_inequality(): + assert psolve(Poly(0, x), '==') == [S.Reals] + assert psolve(Poly(1, x), '==') == [S.EmptySet] + assert psolve(PurePoly(x + 1, x), ">") == [Interval(-1, oo, True, False)] + + +def test_reduce_poly_inequalities_real_interval(): + assert reduce_rational_inequalities( + [[Eq(x**2, 0)]], x, relational=False) == FiniteSet(0) + assert reduce_rational_inequalities( + [[Le(x**2, 0)]], x, relational=False) == FiniteSet(0) + assert reduce_rational_inequalities( + [[Lt(x**2, 0)]], x, relational=False) == S.EmptySet + assert reduce_rational_inequalities( + [[Ge(x**2, 0)]], x, relational=False) == \ + S.Reals if x.is_real else Interval(-oo, oo) + assert reduce_rational_inequalities( + [[Gt(x**2, 0)]], x, relational=False) == \ + FiniteSet(0).complement(S.Reals) + assert reduce_rational_inequalities( + [[Ne(x**2, 0)]], x, relational=False) == \ + FiniteSet(0).complement(S.Reals) + + assert reduce_rational_inequalities( + [[Eq(x**2, 1)]], x, relational=False) == FiniteSet(-1, 1) + assert reduce_rational_inequalities( + [[Le(x**2, 1)]], x, relational=False) == Interval(-1, 1) + assert reduce_rational_inequalities( + [[Lt(x**2, 1)]], x, relational=False) == Interval(-1, 1, True, True) + assert reduce_rational_inequalities( + [[Ge(x**2, 1)]], x, relational=False) == \ + Union(Interval(-oo, -1), Interval(1, oo)) + assert reduce_rational_inequalities( + [[Gt(x**2, 1)]], x, relational=False) == \ + Interval(-1, 1).complement(S.Reals) + assert reduce_rational_inequalities( + [[Ne(x**2, 1)]], x, relational=False) == \ + FiniteSet(-1, 1).complement(S.Reals) + assert reduce_rational_inequalities([[Eq( + x**2, 1.0)]], x, relational=False) == FiniteSet(-1.0, 1.0).evalf() + assert reduce_rational_inequalities( + [[Le(x**2, 1.0)]], x, relational=False) == Interval(-1.0, 1.0) + assert reduce_rational_inequalities([[Lt( + x**2, 1.0)]], x, relational=False) == Interval(-1.0, 1.0, True, True) + assert reduce_rational_inequalities( + [[Ge(x**2, 1.0)]], x, relational=False) == \ + Union(Interval(-inf, -1.0), Interval(1.0, inf)) + assert reduce_rational_inequalities( + [[Gt(x**2, 1.0)]], x, relational=False) == \ + Union(Interval(-inf, -1.0, right_open=True), + Interval(1.0, inf, left_open=True)) + assert reduce_rational_inequalities([[Ne( + x**2, 1.0)]], x, relational=False) == \ + FiniteSet(-1.0, 1.0).complement(S.Reals) + + s = sqrt(2) + + assert reduce_rational_inequalities([[Lt( + x**2 - 1, 0), Gt(x**2 - 1, 0)]], x, relational=False) == S.EmptySet + assert reduce_rational_inequalities([[Le(x**2 - 1, 0), Ge( + x**2 - 1, 0)]], x, relational=False) == FiniteSet(-1, 1) + assert reduce_rational_inequalities( + [[Le(x**2 - 2, 0), Ge(x**2 - 1, 0)]], x, relational=False + ) == Union(Interval(-s, -1, False, False), Interval(1, s, False, False)) + assert reduce_rational_inequalities( + [[Le(x**2 - 2, 0), Gt(x**2 - 1, 0)]], x, relational=False + ) == Union(Interval(-s, -1, False, True), Interval(1, s, True, False)) + assert reduce_rational_inequalities( + [[Lt(x**2 - 2, 0), Ge(x**2 - 1, 0)]], x, relational=False + ) == Union(Interval(-s, -1, True, False), Interval(1, s, False, True)) + assert reduce_rational_inequalities( + [[Lt(x**2 - 2, 0), Gt(x**2 - 1, 0)]], x, relational=False + ) == Union(Interval(-s, -1, True, True), Interval(1, s, True, True)) + assert reduce_rational_inequalities( + [[Lt(x**2 - 2, 0), Ne(x**2 - 1, 0)]], x, relational=False + ) == Union(Interval(-s, -1, True, True), Interval(-1, 1, True, True), + Interval(1, s, True, True)) + + assert reduce_rational_inequalities([[Lt(x**2, -1.)]], x) is S.false + + +def test_reduce_poly_inequalities_complex_relational(): + assert reduce_rational_inequalities( + [[Eq(x**2, 0)]], x, relational=True) == Eq(x, 0) + assert reduce_rational_inequalities( + [[Le(x**2, 0)]], x, relational=True) == Eq(x, 0) + assert reduce_rational_inequalities( + [[Lt(x**2, 0)]], x, relational=True) == False + assert reduce_rational_inequalities( + [[Ge(x**2, 0)]], x, relational=True) == And(Lt(-oo, x), Lt(x, oo)) + assert reduce_rational_inequalities( + [[Gt(x**2, 0)]], x, relational=True) == \ + And(Gt(x, -oo), Lt(x, oo), Ne(x, 0)) + assert reduce_rational_inequalities( + [[Ne(x**2, 0)]], x, relational=True) == \ + And(Gt(x, -oo), Lt(x, oo), Ne(x, 0)) + + for one in (S.One, S(1.0)): + inf = one*oo + assert reduce_rational_inequalities( + [[Eq(x**2, one)]], x, relational=True) == \ + Or(Eq(x, -one), Eq(x, one)) + assert reduce_rational_inequalities( + [[Le(x**2, one)]], x, relational=True) == \ + And(And(Le(-one, x), Le(x, one))) + assert reduce_rational_inequalities( + [[Lt(x**2, one)]], x, relational=True) == \ + And(And(Lt(-one, x), Lt(x, one))) + assert reduce_rational_inequalities( + [[Ge(x**2, one)]], x, relational=True) == \ + And(Or(And(Le(one, x), Lt(x, inf)), And(Le(x, -one), Lt(-inf, x)))) + assert reduce_rational_inequalities( + [[Gt(x**2, one)]], x, relational=True) == \ + And(Or(And(Lt(-inf, x), Lt(x, -one)), And(Lt(one, x), Lt(x, inf)))) + assert reduce_rational_inequalities( + [[Ne(x**2, one)]], x, relational=True) == \ + Or(And(Lt(-inf, x), Lt(x, -one)), + And(Lt(-one, x), Lt(x, one)), + And(Lt(one, x), Lt(x, inf))) + + +def test_reduce_rational_inequalities_real_relational(): + assert reduce_rational_inequalities([], x) == False + assert reduce_rational_inequalities( + [[(x**2 + 3*x + 2)/(x**2 - 16) >= 0]], x, relational=False) == \ + Union(Interval.open(-oo, -4), Interval(-2, -1), Interval.open(4, oo)) + + assert reduce_rational_inequalities( + [[((-2*x - 10)*(3 - x))/((x**2 + 5)*(x - 2)**2) < 0]], x, + relational=False) == \ + Union(Interval.open(-5, 2), Interval.open(2, 3)) + + assert reduce_rational_inequalities([[(x + 1)/(x - 5) <= 0]], x, + relational=False) == \ + Interval.Ropen(-1, 5) + + assert reduce_rational_inequalities([[(x**2 + 4*x + 3)/(x - 1) > 0]], x, + relational=False) == \ + Union(Interval.open(-3, -1), Interval.open(1, oo)) + + assert reduce_rational_inequalities([[(x**2 - 16)/(x - 1)**2 < 0]], x, + relational=False) == \ + Union(Interval.open(-4, 1), Interval.open(1, 4)) + + assert reduce_rational_inequalities([[(3*x + 1)/(x + 4) >= 1]], x, + relational=False) == \ + Union(Interval.open(-oo, -4), Interval.Ropen(Rational(3, 2), oo)) + + assert reduce_rational_inequalities([[(x - 8)/x <= 3 - x]], x, + relational=False) == \ + Union(Interval.Lopen(-oo, -2), Interval.Lopen(0, 4)) + + # issue sympy/sympy#10237 + assert reduce_rational_inequalities( + [[x < oo, x >= 0, -oo < x]], x, relational=False) == Interval(0, oo) + + +def test_reduce_abs_inequalities(): + e = abs(x - 5) < 3 + ans = And(Lt(2, x), Lt(x, 8)) + assert reduce_inequalities(e) == ans + assert reduce_inequalities(e, x) == ans + assert reduce_inequalities(abs(x - 5)) == Eq(x, 5) + assert reduce_inequalities( + abs(2*x + 3) >= 8) == Or(And(Le(Rational(5, 2), x), Lt(x, oo)), + And(Le(x, Rational(-11, 2)), Lt(-oo, x))) + assert reduce_inequalities(abs(x - 4) + abs( + 3*x - 5) < 7) == And(Lt(S.Half, x), Lt(x, 4)) + assert reduce_inequalities(abs(x - 4) + abs(3*abs(x) - 5) < 7) == \ + Or(And(S(-2) < x, x < -1), And(S.Half < x, x < 4)) + + nr = Symbol('nr', extended_real=False) + raises(TypeError, lambda: reduce_inequalities(abs(nr - 5) < 3)) + assert reduce_inequalities(x < 3, symbols=[x, nr]) == And(-oo < x, x < 3) + + +def test_reduce_inequalities_general(): + assert reduce_inequalities(Ge(sqrt(2)*x, 1)) == And(sqrt(2)/2 <= x, x < oo) + assert reduce_inequalities(x + 1 > 0) == And(S.NegativeOne < x, x < oo) + + +def test_reduce_inequalities_boolean(): + assert reduce_inequalities( + [Eq(x**2, 0), True]) == Eq(x, 0) + assert reduce_inequalities([Eq(x**2, 0), False]) == False + assert reduce_inequalities(x**2 >= 0) is S.true # issue 10196 + + +def test_reduce_inequalities_multivariate(): + assert reduce_inequalities([Ge(x**2, 1), Ge(y**2, 1)]) == And( + Or(And(Le(S.One, x), Lt(x, oo)), And(Le(x, -1), Lt(-oo, x))), + Or(And(Le(S.One, y), Lt(y, oo)), And(Le(y, -1), Lt(-oo, y)))) + + +def test_reduce_inequalities_errors(): + raises(NotImplementedError, lambda: reduce_inequalities(Ge(sin(x) + x, 1))) + raises(NotImplementedError, lambda: reduce_inequalities(Ge(x**2*y + y, 1))) + + +def test__solve_inequalities(): + assert reduce_inequalities(x + y < 1, symbols=[x]) == (x < 1 - y) + assert reduce_inequalities(x + y >= 1, symbols=[x]) == (x < oo) & (x >= -y + 1) + assert reduce_inequalities(Eq(0, x - y), symbols=[x]) == Eq(x, y) + assert reduce_inequalities(Ne(0, x - y), symbols=[x]) == Ne(x, y) + + +def test_issue_6343(): + eq = -3*x**2/2 - x*Rational(45, 4) + Rational(33, 2) > 0 + assert reduce_inequalities(eq) == \ + And(x < Rational(-15, 4) + sqrt(401)/4, -sqrt(401)/4 - Rational(15, 4) < x) + + +def test_issue_8235(): + assert reduce_inequalities(x**2 - 1 < 0) == \ + And(S.NegativeOne < x, x < 1) + assert reduce_inequalities(x**2 - 1 <= 0) == \ + And(S.NegativeOne <= x, x <= 1) + assert reduce_inequalities(x**2 - 1 > 0) == \ + Or(And(-oo < x, x < -1), And(x < oo, S.One < x)) + assert reduce_inequalities(x**2 - 1 >= 0) == \ + Or(And(-oo < x, x <= -1), And(S.One <= x, x < oo)) + + eq = x**8 + x - 9 # we want CRootOf solns here + sol = solve(eq >= 0) + tru = Or(And(rootof(eq, 1) <= x, x < oo), And(-oo < x, x <= rootof(eq, 0))) + assert sol == tru + + # recast vanilla as real + assert solve(sqrt((-x + 1)**2) < 1) == And(S.Zero < x, x < 2) + + +def test_issue_5526(): + assert reduce_inequalities(0 <= + x + Integral(y**2, (y, 1, 3)) - 1, [x]) == \ + (x >= -Integral(y**2, (y, 1, 3)) + 1) + f = Function('f') + e = Sum(f(x), (x, 1, 3)) + assert reduce_inequalities(0 <= x + e + y**2, [x]) == \ + (x >= -y**2 - Sum(f(x), (x, 1, 3))) + + +def test_solve_univariate_inequality(): + assert isolve(x**2 >= 4, x, relational=False) == Union(Interval(-oo, -2), + Interval(2, oo)) + assert isolve(x**2 >= 4, x) == Or(And(Le(2, x), Lt(x, oo)), And(Le(x, -2), + Lt(-oo, x))) + assert isolve((x - 1)*(x - 2)*(x - 3) >= 0, x, relational=False) == \ + Union(Interval(1, 2), Interval(3, oo)) + assert isolve((x - 1)*(x - 2)*(x - 3) >= 0, x) == \ + Or(And(Le(1, x), Le(x, 2)), And(Le(3, x), Lt(x, oo))) + assert isolve((x - 1)*(x - 2)*(x - 4) < 0, x, domain = FiniteSet(0, 3)) == \ + Or(Eq(x, 0), Eq(x, 3)) + # issue 2785: + assert isolve(x**3 - 2*x - 1 > 0, x, relational=False) == \ + Union(Interval(-1, -sqrt(5)/2 + S.Half, True, True), + Interval(S.Half + sqrt(5)/2, oo, True, True)) + # issue 2794: + assert isolve(x**3 - x**2 + x - 1 > 0, x, relational=False) == \ + Interval(1, oo, True) + #issue 13105 + assert isolve((x + I)*(x + 2*I) < 0, x) == Eq(x, 0) + assert isolve(((x - 1)*(x - 2) + I)*((x - 1)*(x - 2) + 2*I) < 0, x) == Or(Eq(x, 1), Eq(x, 2)) + assert isolve((((x - 1)*(x - 2) + I)*((x - 1)*(x - 2) + 2*I))/(x - 2) > 0, x) == Eq(x, 1) + raises (ValueError, lambda: isolve((x**2 - 3*x*I + 2)/x < 0, x)) + + # numerical testing in valid() is needed + assert isolve(x**7 - x - 2 > 0, x) == \ + And(rootof(x**7 - x - 2, 0) < x, x < oo) + + # handle numerator and denominator; although these would be handled as + # rational inequalities, these test confirm that the right thing is done + # when the domain is EX (e.g. when 2 is replaced with sqrt(2)) + assert isolve(1/(x - 2) > 0, x) == And(S(2) < x, x < oo) + den = ((x - 1)*(x - 2)).expand() + assert isolve((x - 1)/den <= 0, x) == \ + (x > -oo) & (x < 2) & Ne(x, 1) + + n = Dummy('n') + raises(NotImplementedError, lambda: isolve(Abs(x) <= n, x, relational=False)) + c1 = Dummy("c1", positive=True) + raises(NotImplementedError, lambda: isolve(n/c1 < 0, c1)) + n = Dummy('n', negative=True) + assert isolve(n/c1 > -2, c1) == (-n/2 < c1) + assert isolve(n/c1 < 0, c1) == True + assert isolve(n/c1 > 0, c1) == False + + zero = cos(1)**2 + sin(1)**2 - 1 + raises(NotImplementedError, lambda: isolve(x**2 < zero, x)) + raises(NotImplementedError, lambda: isolve( + x**2 < zero*I, x)) + raises(NotImplementedError, lambda: isolve(1/(x - y) < 2, x)) + raises(NotImplementedError, lambda: isolve(1/(x - y) < 0, x)) + raises(TypeError, lambda: isolve(x - I < 0, x)) + + zero = x**2 + x - x*(x + 1) + assert isolve(zero < 0, x, relational=False) is S.EmptySet + assert isolve(zero <= 0, x, relational=False) is S.Reals + + # make sure iter_solutions gets a default value + raises(NotImplementedError, lambda: isolve( + Eq(cos(x)**2 + sin(x)**2, 1), x)) + + +def test_trig_inequalities(): + # all the inequalities are solved in a periodic interval. + assert isolve(sin(x) < S.Half, x, relational=False) == \ + Union(Interval(0, pi/6, False, True), Interval.open(pi*Rational(5, 6), 2*pi)) + assert isolve(sin(x) > S.Half, x, relational=False) == \ + Interval(pi/6, pi*Rational(5, 6), True, True) + assert isolve(cos(x) < S.Zero, x, relational=False) == \ + Interval(pi/2, pi*Rational(3, 2), True, True) + assert isolve(cos(x) >= S.Zero, x, relational=False) == \ + Union(Interval(0, pi/2), Interval.Ropen(pi*Rational(3, 2), 2*pi)) + + assert isolve(tan(x) < S.One, x, relational=False) == \ + Union(Interval.Ropen(0, pi/4), Interval.open(pi/2, pi)) + + assert isolve(sin(x) <= S.Zero, x, relational=False) == \ + Union(FiniteSet(S.Zero), Interval.Ropen(pi, 2*pi)) + + assert isolve(sin(x) <= S.One, x, relational=False) == S.Reals + assert isolve(cos(x) < S(-2), x, relational=False) == S.EmptySet + assert isolve(sin(x) >= S.NegativeOne, x, relational=False) == S.Reals + assert isolve(cos(x) > S.One, x, relational=False) == S.EmptySet + + +def test_issue_9954(): + assert isolve(x**2 >= 0, x, relational=False) == S.Reals + assert isolve(x**2 >= 0, x, relational=True) == S.Reals.as_relational(x) + assert isolve(x**2 < 0, x, relational=False) == S.EmptySet + assert isolve(x**2 < 0, x, relational=True) == S.EmptySet.as_relational(x) + + +@XFAIL +def test_slow_general_univariate(): + r = rootof(x**5 - x**2 + 1, 0) + assert solve(sqrt(x) + 1/root(x, 3) > 1) == \ + Or(And(0 < x, x < r**6), And(r**6 < x, x < oo)) + + +def test_issue_8545(): + eq = 1 - x - abs(1 - x) + ans = And(Lt(1, x), Lt(x, oo)) + assert reduce_abs_inequality(eq, '<', x) == ans + eq = 1 - x - sqrt((1 - x)**2) + assert reduce_inequalities(eq < 0) == ans + + +def test_issue_8974(): + assert isolve(-oo < x, x) == And(-oo < x, x < oo) + assert isolve(oo > x, x) == And(-oo < x, x < oo) + + +def test_issue_10198(): + assert reduce_inequalities( + -1 + 1/abs(1/x - 1) < 0) == (x > -oo) & (x < S(1)/2) & Ne(x, 0) + + assert reduce_inequalities(abs(1/sqrt(x)) - 1, x) == Eq(x, 1) + assert reduce_abs_inequality(-3 + 1/abs(1 - 1/x), '<', x) == \ + Or(And(-oo < x, x < 0), + And(S.Zero < x, x < Rational(3, 4)), And(Rational(3, 2) < x, x < oo)) + raises(ValueError,lambda: reduce_abs_inequality(-3 + 1/abs( + 1 - 1/sqrt(x)), '<', x)) + + +def test_issue_10047(): + # issue 10047: this must remain an inequality, not True, since if x + # is not real the inequality is invalid + # assert solve(sin(x) < 2) == (x <= oo) + + # with PR 16956, (x <= oo) autoevaluates when x is extended_real + # which is assumed in the current implementation of inequality solvers + assert solve(sin(x) < 2) == True + assert solveset(sin(x) < 2, domain=S.Reals) == S.Reals + + +def test_issue_10268(): + assert solve(log(x) < 1000) == And(S.Zero < x, x < exp(1000)) + + +@XFAIL +def test_isolve_Sets(): + n = Dummy('n') + assert isolve(Abs(x) <= n, x, relational=False) == \ + Piecewise((S.EmptySet, n < 0), (Interval(-n, n), True)) + + +def test_integer_domain_relational_isolve(): + + dom = FiniteSet(0, 3) + x = Symbol('x',zero=False) + assert isolve((x - 1)*(x - 2)*(x - 4) < 0, x, domain=dom) == Eq(x, 3) + + x = Symbol('x') + assert isolve(x + 2 < 0, x, domain=S.Integers) == \ + (x <= -3) & (x > -oo) & Eq(Mod(x, 1), 0) + assert isolve(2 * x + 3 > 0, x, domain=S.Integers) == \ + (x >= -1) & (x < oo) & Eq(Mod(x, 1), 0) + assert isolve((x ** 2 + 3 * x - 2) < 0, x, domain=S.Integers) == \ + (x >= -3) & (x <= 0) & Eq(Mod(x, 1), 0) + assert isolve((x ** 2 + 3 * x - 2) > 0, x, domain=S.Integers) == \ + ((x >= 1) & (x < oo) & Eq(Mod(x, 1), 0)) | ( + (x <= -4) & (x > -oo) & Eq(Mod(x, 1), 0)) + + +def test_issue_10671_12466(): + assert solveset(sin(y), y, Interval(0, pi)) == FiniteSet(0, pi) + i = Interval(1, 10) + assert solveset((1/x).diff(x) < 0, x, i) == i + assert solveset((log(x - 6)/x) <= 0, x, S.Reals) == \ + Interval.Lopen(6, 7) + + +def test__solve_inequality(): + for op in (Gt, Lt, Le, Ge, Eq, Ne): + assert _solve_inequality(op(x, 1), x).lhs == x + assert _solve_inequality(op(S.One, x), x).lhs == x + # don't get tricked by symbol on right: solve it + assert _solve_inequality(Eq(2*x - 1, x), x) == Eq(x, 1) + ie = Eq(S.One, y) + assert _solve_inequality(ie, x) == ie + for fx in (x**2, exp(x), sin(x) + cos(x), x*(1 + x)): + for c in (0, 1): + e = 2*fx - c > 0 + assert _solve_inequality(e, x, linear=True) == ( + fx > c/S(2)) + assert _solve_inequality(2*x**2 + 2*x - 1 < 0, x, linear=True) == ( + x*(x + 1) < S.Half) + assert _solve_inequality(Eq(x*y, 1), x) == Eq(x*y, 1) + nz = Symbol('nz', nonzero=True) + assert _solve_inequality(Eq(x*nz, 1), x) == Eq(x, 1/nz) + assert _solve_inequality(x*nz < 1, x) == (x*nz < 1) + a = Symbol('a', positive=True) + assert _solve_inequality(a/x > 1, x) == (S.Zero < x) & (x < a) + assert _solve_inequality(a/x > 1, x, linear=True) == (1/x > 1/a) + # make sure to include conditions under which solution is valid + e = Eq(1 - x, x*(1/x - 1)) + assert _solve_inequality(e, x) == Ne(x, 0) + assert _solve_inequality(x < x*(1/x - 1), x) == (x < S.Half) & Ne(x, 0) + + +def test__pt(): + from sympy.solvers.inequalities import _pt + assert _pt(-oo, oo) == 0 + assert _pt(S.One, S(3)) == 2 + assert _pt(S.One, oo) == _pt(oo, S.One) == 2 + assert _pt(S.One, -oo) == _pt(-oo, S.One) == S.Half + assert _pt(S.NegativeOne, oo) == _pt(oo, S.NegativeOne) == Rational(-1, 2) + assert _pt(S.NegativeOne, -oo) == _pt(-oo, S.NegativeOne) == -2 + assert _pt(x, oo) == _pt(oo, x) == x + 1 + assert _pt(x, -oo) == _pt(-oo, x) == x - 1 + raises(ValueError, lambda: _pt(Dummy('i', infinite=True), S.One)) + + +def test_issue_25697(): + assert _solve_inequality(log(x, 3) <= 2, x) == (x <= 9) & (S.Zero < x) + + +def test_issue_25738(): + assert reduce_inequalities(3 < abs(x) + ) == reduce_inequalities(pi < abs(x)).subs(pi, 3) + + +def test_issue_25983(): + assert(reduce_inequalities(pi/Abs(x) <= 1) == ((pi <= x) & (x < oo)) | ((-oo < x) & (x <= -pi))) diff --git a/lib/python3.10/site-packages/sympy/solvers/tests/test_numeric.py b/lib/python3.10/site-packages/sympy/solvers/tests/test_numeric.py new file mode 100644 index 0000000000000000000000000000000000000000..f40bab6965233b82984148960a62ed57a7ddb178 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/solvers/tests/test_numeric.py @@ -0,0 +1,139 @@ +from sympy.core.function import nfloat +from sympy.core.numbers import (Float, I, Rational, pi) +from sympy.core.relational import Eq +from sympy.core.symbol import (Symbol, symbols) +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.piecewise import Piecewise +from sympy.functions.elementary.trigonometric import sin +from sympy.integrals.integrals import Integral +from sympy.matrices.dense import Matrix +from mpmath import mnorm, mpf +from sympy.solvers import nsolve +from sympy.utilities.lambdify import lambdify +from sympy.testing.pytest import raises, XFAIL +from sympy.utilities.decorator import conserve_mpmath_dps + +@XFAIL +def test_nsolve_fail(): + x = symbols('x') + # Sometimes it is better to use the numerator (issue 4829) + # but sometimes it is not (issue 11768) so leave this to + # the discretion of the user + ans = nsolve(x**2/(1 - x)/(1 - 2*x)**2 - 100, x, 0) + assert ans > 0.46 and ans < 0.47 + + +def test_nsolve_denominator(): + x = symbols('x') + # Test that nsolve uses the full expression (numerator and denominator). + ans = nsolve((x**2 + 3*x + 2)/(x + 2), -2.1) + # The root -2 was divided out, so make sure we don't find it. + assert ans == -1.0 + +def test_nsolve(): + # onedimensional + x = Symbol('x') + assert nsolve(sin(x), 2) - pi.evalf() < 1e-15 + assert nsolve(Eq(2*x, 2), x, -10) == nsolve(2*x - 2, -10) + # Testing checks on number of inputs + raises(TypeError, lambda: nsolve(Eq(2*x, 2))) + raises(TypeError, lambda: nsolve(Eq(2*x, 2), x, 1, 2)) + # multidimensional + x1 = Symbol('x1') + x2 = Symbol('x2') + f1 = 3 * x1**2 - 2 * x2**2 - 1 + f2 = x1**2 - 2 * x1 + x2**2 + 2 * x2 - 8 + f = Matrix((f1, f2)).T + F = lambdify((x1, x2), f.T, modules='mpmath') + for x0 in [(-1, 1), (1, -2), (4, 4), (-4, -4)]: + x = nsolve(f, (x1, x2), x0, tol=1.e-8) + assert mnorm(F(*x), 1) <= 1.e-10 + # The Chinese mathematician Zhu Shijie was the very first to solve this + # nonlinear system 700 years ago (z was added to make it 3-dimensional) + x = Symbol('x') + y = Symbol('y') + z = Symbol('z') + f1 = -x + 2*y + f2 = (x**2 + x*(y**2 - 2) - 4*y) / (x + 4) + f3 = sqrt(x**2 + y**2)*z + f = Matrix((f1, f2, f3)).T + F = lambdify((x, y, z), f.T, modules='mpmath') + + def getroot(x0): + root = nsolve(f, (x, y, z), x0) + assert mnorm(F(*root), 1) <= 1.e-8 + return root + assert list(map(round, getroot((1, 1, 1)))) == [2, 1, 0] + assert nsolve([Eq( + f1, 0), Eq(f2, 0), Eq(f3, 0)], [x, y, z], (1, 1, 1)) # just see that it works + a = Symbol('a') + assert abs(nsolve(1/(0.001 + a)**3 - 6/(0.9 - a)**3, a, 0.3) - + mpf('0.31883011387318591')) < 1e-15 + + +def test_issue_6408(): + x = Symbol('x') + assert nsolve(Piecewise((x, x < 1), (x**2, True)), x, 2) == 0.0 + + +def test_issue_6408_integral(): + x, y = symbols('x y') + assert nsolve(Integral(x*y, (x, 0, 5)), y, 2) == 0.0 + + +@conserve_mpmath_dps +def test_increased_dps(): + # Issue 8564 + import mpmath + mpmath.mp.dps = 128 + x = Symbol('x') + e1 = x**2 - pi + q = nsolve(e1, x, 3.0) + + assert abs(sqrt(pi).evalf(128) - q) < 1e-128 + +def test_nsolve_precision(): + x, y = symbols('x y') + sol = nsolve(x**2 - pi, x, 3, prec=128) + assert abs(sqrt(pi).evalf(128) - sol) < 1e-128 + assert isinstance(sol, Float) + + sols = nsolve((y**2 - x, x**2 - pi), (x, y), (3, 3), prec=128) + assert isinstance(sols, Matrix) + assert sols.shape == (2, 1) + assert abs(sqrt(pi).evalf(128) - sols[0]) < 1e-128 + assert abs(sqrt(sqrt(pi)).evalf(128) - sols[1]) < 1e-128 + assert all(isinstance(i, Float) for i in sols) + +def test_nsolve_complex(): + x, y = symbols('x y') + + assert nsolve(x**2 + 2, 1j) == sqrt(2.)*I + assert nsolve(x**2 + 2, I) == sqrt(2.)*I + + assert nsolve([x**2 + 2, y**2 + 2], [x, y], [I, I]) == Matrix([sqrt(2.)*I, sqrt(2.)*I]) + assert nsolve([x**2 + 2, y**2 + 2], [x, y], [I, I]) == Matrix([sqrt(2.)*I, sqrt(2.)*I]) + +def test_nsolve_dict_kwarg(): + x, y = symbols('x y') + # one variable + assert nsolve(x**2 - 2, 1, dict = True) == \ + [{x: sqrt(2.)}] + # one variable with complex solution + assert nsolve(x**2 + 2, I, dict = True) == \ + [{x: sqrt(2.)*I}] + # two variables + assert nsolve([x**2 + y**2 - 5, x**2 - y**2 + 1], [x, y], [1, 1], dict = True) == \ + [{x: sqrt(2.), y: sqrt(3.)}] + +def test_nsolve_rational(): + x = symbols('x') + assert nsolve(x - Rational(1, 3), 0, prec=100) == Rational(1, 3).evalf(100) + + +def test_issue_14950(): + x = Matrix(symbols('t s')) + x0 = Matrix([17, 23]) + eqn = x + x0 + assert nsolve(eqn, x, x0) == nfloat(-x0) + assert nsolve(eqn.T, x.T, x0.T) == nfloat(-x0) diff --git a/lib/python3.10/site-packages/sympy/solvers/tests/test_pde.py b/lib/python3.10/site-packages/sympy/solvers/tests/test_pde.py new file mode 100644 index 0000000000000000000000000000000000000000..948d90c7be21a9e0e03753e723ef04f1fb08a5d6 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/solvers/tests/test_pde.py @@ -0,0 +1,239 @@ +from sympy.core.function import (Derivative as D, Function) +from sympy.core.relational import Eq +from sympy.core.symbol import (Symbol, symbols) +from sympy.functions.elementary.exponential import (exp, log) +from sympy.functions.elementary.trigonometric import (cos, sin) +from sympy.core import S +from sympy.solvers.pde import (pde_separate, pde_separate_add, pde_separate_mul, + pdsolve, classify_pde, checkpdesol) +from sympy.testing.pytest import raises + + +a, b, c, x, y = symbols('a b c x y') + +def test_pde_separate_add(): + x, y, z, t = symbols("x,y,z,t") + F, T, X, Y, Z, u = map(Function, 'FTXYZu') + + eq = Eq(D(u(x, t), x), D(u(x, t), t)*exp(u(x, t))) + res = pde_separate_add(eq, u(x, t), [X(x), T(t)]) + assert res == [D(X(x), x)*exp(-X(x)), D(T(t), t)*exp(T(t))] + + +def test_pde_separate(): + x, y, z, t = symbols("x,y,z,t") + F, T, X, Y, Z, u = map(Function, 'FTXYZu') + + eq = Eq(D(u(x, t), x), D(u(x, t), t)*exp(u(x, t))) + raises(ValueError, lambda: pde_separate(eq, u(x, t), [X(x), T(t)], 'div')) + + +def test_pde_separate_mul(): + x, y, z, t = symbols("x,y,z,t") + c = Symbol("C", real=True) + Phi = Function('Phi') + F, R, T, X, Y, Z, u = map(Function, 'FRTXYZu') + r, theta, z = symbols('r,theta,z') + + # Something simple :) + eq = Eq(D(F(x, y, z), x) + D(F(x, y, z), y) + D(F(x, y, z), z), 0) + + # Duplicate arguments in functions + raises( + ValueError, lambda: pde_separate_mul(eq, F(x, y, z), [X(x), u(z, z)])) + # Wrong number of arguments + raises(ValueError, lambda: pde_separate_mul(eq, F(x, y, z), [X(x), Y(y)])) + # Wrong variables: [x, y] -> [x, z] + raises( + ValueError, lambda: pde_separate_mul(eq, F(x, y, z), [X(t), Y(x, y)])) + + assert pde_separate_mul(eq, F(x, y, z), [Y(y), u(x, z)]) == \ + [D(Y(y), y)/Y(y), -D(u(x, z), x)/u(x, z) - D(u(x, z), z)/u(x, z)] + assert pde_separate_mul(eq, F(x, y, z), [X(x), Y(y), Z(z)]) == \ + [D(X(x), x)/X(x), -D(Z(z), z)/Z(z) - D(Y(y), y)/Y(y)] + + # wave equation + wave = Eq(D(u(x, t), t, t), c**2*D(u(x, t), x, x)) + res = pde_separate_mul(wave, u(x, t), [X(x), T(t)]) + assert res == [D(X(x), x, x)/X(x), D(T(t), t, t)/(c**2*T(t))] + + # Laplace equation in cylindrical coords + eq = Eq(1/r * D(Phi(r, theta, z), r) + D(Phi(r, theta, z), r, 2) + + 1/r**2 * D(Phi(r, theta, z), theta, 2) + D(Phi(r, theta, z), z, 2), 0) + # Separate z + res = pde_separate_mul(eq, Phi(r, theta, z), [Z(z), u(theta, r)]) + assert res == [D(Z(z), z, z)/Z(z), + -D(u(theta, r), r, r)/u(theta, r) - + D(u(theta, r), r)/(r*u(theta, r)) - + D(u(theta, r), theta, theta)/(r**2*u(theta, r))] + # Lets use the result to create a new equation... + eq = Eq(res[1], c) + # ...and separate theta... + res = pde_separate_mul(eq, u(theta, r), [T(theta), R(r)]) + assert res == [D(T(theta), theta, theta)/T(theta), + -r*D(R(r), r)/R(r) - r**2*D(R(r), r, r)/R(r) - c*r**2] + # ...or r... + res = pde_separate_mul(eq, u(theta, r), [R(r), T(theta)]) + assert res == [r*D(R(r), r)/R(r) + r**2*D(R(r), r, r)/R(r) + c*r**2, + -D(T(theta), theta, theta)/T(theta)] + + +def test_issue_11726(): + x, t = symbols("x t") + f = symbols("f", cls=Function) + X, T = symbols("X T", cls=Function) + + u = f(x, t) + eq = u.diff(x, 2) - u.diff(t, 2) + res = pde_separate(eq, u, [T(x), X(t)]) + assert res == [D(T(x), x, x)/T(x),D(X(t), t, t)/X(t)] + + +def test_pde_classify(): + # When more number of hints are added, add tests for classifying here. + f = Function('f') + eq1 = a*f(x,y) + b*f(x,y).diff(x) + c*f(x,y).diff(y) + eq2 = 3*f(x,y) + 2*f(x,y).diff(x) + f(x,y).diff(y) + eq3 = a*f(x,y) + b*f(x,y).diff(x) + 2*f(x,y).diff(y) + eq4 = x*f(x,y) + f(x,y).diff(x) + 3*f(x,y).diff(y) + eq5 = x**2*f(x,y) + x*f(x,y).diff(x) + x*y*f(x,y).diff(y) + eq6 = y*x**2*f(x,y) + y*f(x,y).diff(x) + f(x,y).diff(y) + for eq in [eq1, eq2, eq3]: + assert classify_pde(eq) == ('1st_linear_constant_coeff_homogeneous',) + for eq in [eq4, eq5, eq6]: + assert classify_pde(eq) == ('1st_linear_variable_coeff',) + + +def test_checkpdesol(): + f, F = map(Function, ['f', 'F']) + eq1 = a*f(x,y) + b*f(x,y).diff(x) + c*f(x,y).diff(y) + eq2 = 3*f(x,y) + 2*f(x,y).diff(x) + f(x,y).diff(y) + eq3 = a*f(x,y) + b*f(x,y).diff(x) + 2*f(x,y).diff(y) + for eq in [eq1, eq2, eq3]: + assert checkpdesol(eq, pdsolve(eq))[0] + eq4 = x*f(x,y) + f(x,y).diff(x) + 3*f(x,y).diff(y) + eq5 = 2*f(x,y) + 1*f(x,y).diff(x) + 3*f(x,y).diff(y) + eq6 = f(x,y) + 1*f(x,y).diff(x) + 3*f(x,y).diff(y) + assert checkpdesol(eq4, [pdsolve(eq5), pdsolve(eq6)]) == [ + (False, (x - 2)*F(3*x - y)*exp(-x/S(5) - 3*y/S(5))), + (False, (x - 1)*F(3*x - y)*exp(-x/S(10) - 3*y/S(10)))] + for eq in [eq4, eq5, eq6]: + assert checkpdesol(eq, pdsolve(eq))[0] + sol = pdsolve(eq4) + sol4 = Eq(sol.lhs - sol.rhs, 0) + raises(NotImplementedError, lambda: + checkpdesol(eq4, sol4, solve_for_func=False)) + + +def test_solvefun(): + f, F, G, H = map(Function, ['f', 'F', 'G', 'H']) + eq1 = f(x,y) + f(x,y).diff(x) + f(x,y).diff(y) + assert pdsolve(eq1) == Eq(f(x, y), F(x - y)*exp(-x/2 - y/2)) + assert pdsolve(eq1, solvefun=G) == Eq(f(x, y), G(x - y)*exp(-x/2 - y/2)) + assert pdsolve(eq1, solvefun=H) == Eq(f(x, y), H(x - y)*exp(-x/2 - y/2)) + + +def test_pde_1st_linear_constant_coeff_homogeneous(): + f, F = map(Function, ['f', 'F']) + u = f(x, y) + eq = 2*u + u.diff(x) + u.diff(y) + assert classify_pde(eq) == ('1st_linear_constant_coeff_homogeneous',) + sol = pdsolve(eq) + assert sol == Eq(u, F(x - y)*exp(-x - y)) + assert checkpdesol(eq, sol)[0] + + eq = 4 + (3*u.diff(x)/u) + (2*u.diff(y)/u) + assert classify_pde(eq) == ('1st_linear_constant_coeff_homogeneous',) + sol = pdsolve(eq) + assert sol == Eq(u, F(2*x - 3*y)*exp(-S(12)*x/13 - S(8)*y/13)) + assert checkpdesol(eq, sol)[0] + + eq = u + (6*u.diff(x)) + (7*u.diff(y)) + assert classify_pde(eq) == ('1st_linear_constant_coeff_homogeneous',) + sol = pdsolve(eq) + assert sol == Eq(u, F(7*x - 6*y)*exp(-6*x/S(85) - 7*y/S(85))) + assert checkpdesol(eq, sol)[0] + + eq = a*u + b*u.diff(x) + c*u.diff(y) + sol = pdsolve(eq) + assert checkpdesol(eq, sol)[0] + + +def test_pde_1st_linear_constant_coeff(): + f, F = map(Function, ['f', 'F']) + u = f(x,y) + eq = -2*u.diff(x) + 4*u.diff(y) + 5*u - exp(x + 3*y) + sol = pdsolve(eq) + assert sol == Eq(f(x,y), + (F(4*x + 2*y)*exp(x/2) + exp(x + 4*y)/15)*exp(-y)) + assert classify_pde(eq) == ('1st_linear_constant_coeff', + '1st_linear_constant_coeff_Integral') + assert checkpdesol(eq, sol)[0] + + eq = (u.diff(x)/u) + (u.diff(y)/u) + 1 - (exp(x + y)/u) + sol = pdsolve(eq) + assert sol == Eq(f(x, y), F(x - y)*exp(-x/2 - y/2) + exp(x + y)/3) + assert classify_pde(eq) == ('1st_linear_constant_coeff', + '1st_linear_constant_coeff_Integral') + assert checkpdesol(eq, sol)[0] + + eq = 2*u + -u.diff(x) + 3*u.diff(y) + sin(x) + sol = pdsolve(eq) + assert sol == Eq(f(x, y), + F(3*x + y)*exp(x/5 - 3*y/5) - 2*sin(x)/5 - cos(x)/5) + assert classify_pde(eq) == ('1st_linear_constant_coeff', + '1st_linear_constant_coeff_Integral') + assert checkpdesol(eq, sol)[0] + + eq = u + u.diff(x) + u.diff(y) + x*y + sol = pdsolve(eq) + assert sol.expand() == Eq(f(x, y), + x + y + (x - y)**2/4 - (x + y)**2/4 + F(x - y)*exp(-x/2 - y/2) - 2).expand() + assert classify_pde(eq) == ('1st_linear_constant_coeff', + '1st_linear_constant_coeff_Integral') + assert checkpdesol(eq, sol)[0] + eq = u + u.diff(x) + u.diff(y) + log(x) + assert classify_pde(eq) == ('1st_linear_constant_coeff', + '1st_linear_constant_coeff_Integral') + + +def test_pdsolve_all(): + f, F = map(Function, ['f', 'F']) + u = f(x,y) + eq = u + u.diff(x) + u.diff(y) + x**2*y + sol = pdsolve(eq, hint = 'all') + keys = ['1st_linear_constant_coeff', + '1st_linear_constant_coeff_Integral', 'default', 'order'] + assert sorted(sol.keys()) == keys + assert sol['order'] == 1 + assert sol['default'] == '1st_linear_constant_coeff' + assert sol['1st_linear_constant_coeff'].expand() == Eq(f(x, y), + -x**2*y + x**2 + 2*x*y - 4*x - 2*y + F(x - y)*exp(-x/2 - y/2) + 6).expand() + + +def test_pdsolve_variable_coeff(): + f, F = map(Function, ['f', 'F']) + u = f(x, y) + eq = x*(u.diff(x)) - y*(u.diff(y)) + y**2*u - y**2 + sol = pdsolve(eq, hint="1st_linear_variable_coeff") + assert sol == Eq(u, F(x*y)*exp(y**2/2) + 1) + assert checkpdesol(eq, sol)[0] + + eq = x**2*u + x*u.diff(x) + x*y*u.diff(y) + sol = pdsolve(eq, hint='1st_linear_variable_coeff') + assert sol == Eq(u, F(y*exp(-x))*exp(-x**2/2)) + assert checkpdesol(eq, sol)[0] + + eq = y*x**2*u + y*u.diff(x) + u.diff(y) + sol = pdsolve(eq, hint='1st_linear_variable_coeff') + assert sol == Eq(u, F(-2*x + y**2)*exp(-x**3/3)) + assert checkpdesol(eq, sol)[0] + + eq = exp(x)**2*(u.diff(x)) + y + sol = pdsolve(eq, hint='1st_linear_variable_coeff') + assert sol == Eq(u, y*exp(-2*x)/2 + F(y)) + assert checkpdesol(eq, sol)[0] + + eq = exp(2*x)*(u.diff(y)) + y*u - u + sol = pdsolve(eq, hint='1st_linear_variable_coeff') + assert sol == Eq(u, F(x)*exp(-y*(y - 2)*exp(-2*x)/2)) diff --git a/lib/python3.10/site-packages/sympy/solvers/tests/test_polysys.py b/lib/python3.10/site-packages/sympy/solvers/tests/test_polysys.py new file mode 100644 index 0000000000000000000000000000000000000000..9f0a70c89cd94e9f03cb7cc5a009bc8209a21178 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/solvers/tests/test_polysys.py @@ -0,0 +1,178 @@ +"""Tests for solvers of systems of polynomial equations. """ +from sympy.core.numbers import (I, Integer, Rational) +from sympy.core.singleton import S +from sympy.core.symbol import symbols +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.polys.domains.rationalfield import QQ +from sympy.polys.polyerrors import UnsolvableFactorError +from sympy.polys.polyoptions import Options +from sympy.polys.polytools import Poly +from sympy.solvers.solvers import solve +from sympy.utilities.iterables import flatten +from sympy.abc import x, y, z +from sympy.polys import PolynomialError +from sympy.solvers.polysys import (solve_poly_system, + solve_triangulated, + solve_biquadratic, SolveFailed, + solve_generic) +from sympy.polys.polytools import parallel_poly_from_expr +from sympy.testing.pytest import raises + + +def test_solve_poly_system(): + assert solve_poly_system([x - 1], x) == [(S.One,)] + + assert solve_poly_system([y - x, y - x - 1], x, y) is None + + assert solve_poly_system([y - x**2, y + x**2], x, y) == [(S.Zero, S.Zero)] + + assert solve_poly_system([2*x - 3, y*Rational(3, 2) - 2*x, z - 5*y], x, y, z) == \ + [(Rational(3, 2), Integer(2), Integer(10))] + + assert solve_poly_system([x*y - 2*y, 2*y**2 - x**2], x, y) == \ + [(0, 0), (2, -sqrt(2)), (2, sqrt(2))] + + assert solve_poly_system([y - x**2, y + x**2 + 1], x, y) == \ + [(-I*sqrt(S.Half), Rational(-1, 2)), (I*sqrt(S.Half), Rational(-1, 2))] + + f_1 = x**2 + y + z - 1 + f_2 = x + y**2 + z - 1 + f_3 = x + y + z**2 - 1 + + a, b = sqrt(2) - 1, -sqrt(2) - 1 + + assert solve_poly_system([f_1, f_2, f_3], x, y, z) == \ + [(0, 0, 1), (0, 1, 0), (1, 0, 0), (a, a, a), (b, b, b)] + + solution = [(1, -1), (1, 1)] + + assert solve_poly_system([Poly(x**2 - y**2), Poly(x - 1)]) == solution + assert solve_poly_system([x**2 - y**2, x - 1], x, y) == solution + assert solve_poly_system([x**2 - y**2, x - 1]) == solution + + assert solve_poly_system( + [x + x*y - 3, y + x*y - 4], x, y) == [(-3, -2), (1, 2)] + + raises(NotImplementedError, lambda: solve_poly_system([x**3 - y**3], x, y)) + raises(NotImplementedError, lambda: solve_poly_system( + [z, -2*x*y**2 + x + y**2*z, y**2*(-z - 4) + 2])) + raises(PolynomialError, lambda: solve_poly_system([1/x], x)) + + raises(NotImplementedError, lambda: solve_poly_system( + [x-1,], (x, y))) + raises(NotImplementedError, lambda: solve_poly_system( + [y-1,], (x, y))) + + # solve_poly_system should ideally construct solutions using + # CRootOf for the following four tests + assert solve_poly_system([x**5 - x + 1], [x], strict=False) == [] + raises(UnsolvableFactorError, lambda: solve_poly_system( + [x**5 - x + 1], [x], strict=True)) + + assert solve_poly_system([(x - 1)*(x**5 - x + 1), y**2 - 1], [x, y], + strict=False) == [(1, -1), (1, 1)] + raises(UnsolvableFactorError, + lambda: solve_poly_system([(x - 1)*(x**5 - x + 1), y**2-1], + [x, y], strict=True)) + + +def test_solve_generic(): + NewOption = Options((x, y), {'domain': 'ZZ'}) + assert solve_generic([x**2 - 2*y**2, y**2 - y + 1], NewOption) == \ + [(-sqrt(-1 - sqrt(3)*I), Rational(1, 2) - sqrt(3)*I/2), + (sqrt(-1 - sqrt(3)*I), Rational(1, 2) - sqrt(3)*I/2), + (-sqrt(-1 + sqrt(3)*I), Rational(1, 2) + sqrt(3)*I/2), + (sqrt(-1 + sqrt(3)*I), Rational(1, 2) + sqrt(3)*I/2)] + + # solve_generic should ideally construct solutions using + # CRootOf for the following two tests + assert solve_generic( + [2*x - y, (y - 1)*(y**5 - y + 1)], NewOption, strict=False) == \ + [(Rational(1, 2), 1)] + raises(UnsolvableFactorError, lambda: solve_generic( + [2*x - y, (y - 1)*(y**5 - y + 1)], NewOption, strict=True)) + + +def test_solve_biquadratic(): + x0, y0, x1, y1, r = symbols('x0 y0 x1 y1 r') + + f_1 = (x - 1)**2 + (y - 1)**2 - r**2 + f_2 = (x - 2)**2 + (y - 2)**2 - r**2 + s = sqrt(2*r**2 - 1) + a = (3 - s)/2 + b = (3 + s)/2 + assert solve_poly_system([f_1, f_2], x, y) == [(a, b), (b, a)] + + f_1 = (x - 1)**2 + (y - 2)**2 - r**2 + f_2 = (x - 1)**2 + (y - 1)**2 - r**2 + + assert solve_poly_system([f_1, f_2], x, y) == \ + [(1 - sqrt((2*r - 1)*(2*r + 1))/2, Rational(3, 2)), + (1 + sqrt((2*r - 1)*(2*r + 1))/2, Rational(3, 2))] + + query = lambda expr: expr.is_Pow and expr.exp is S.Half + + f_1 = (x - 1 )**2 + (y - 2)**2 - r**2 + f_2 = (x - x1)**2 + (y - 1)**2 - r**2 + + result = solve_poly_system([f_1, f_2], x, y) + + assert len(result) == 2 and all(len(r) == 2 for r in result) + assert all(r.count(query) == 1 for r in flatten(result)) + + f_1 = (x - x0)**2 + (y - y0)**2 - r**2 + f_2 = (x - x1)**2 + (y - y1)**2 - r**2 + + result = solve_poly_system([f_1, f_2], x, y) + + assert len(result) == 2 and all(len(r) == 2 for r in result) + assert all(len(r.find(query)) == 1 for r in flatten(result)) + + s1 = (x*y - y, x**2 - x) + assert solve(s1) == [{x: 1}, {x: 0, y: 0}] + s2 = (x*y - x, y**2 - y) + assert solve(s2) == [{y: 1}, {x: 0, y: 0}] + gens = (x, y) + for seq in (s1, s2): + (f, g), opt = parallel_poly_from_expr(seq, *gens) + raises(SolveFailed, lambda: solve_biquadratic(f, g, opt)) + seq = (x**2 + y**2 - 2, y**2 - 1) + (f, g), opt = parallel_poly_from_expr(seq, *gens) + assert solve_biquadratic(f, g, opt) == [ + (-1, -1), (-1, 1), (1, -1), (1, 1)] + ans = [(0, -1), (0, 1)] + seq = (x**2 + y**2 - 1, y**2 - 1) + (f, g), opt = parallel_poly_from_expr(seq, *gens) + assert solve_biquadratic(f, g, opt) == ans + seq = (x**2 + y**2 - 1, x**2 - x + y**2 - 1) + (f, g), opt = parallel_poly_from_expr(seq, *gens) + assert solve_biquadratic(f, g, opt) == ans + + +def test_solve_triangulated(): + f_1 = x**2 + y + z - 1 + f_2 = x + y**2 + z - 1 + f_3 = x + y + z**2 - 1 + + a, b = sqrt(2) - 1, -sqrt(2) - 1 + + assert solve_triangulated([f_1, f_2, f_3], x, y, z) == \ + [(0, 0, 1), (0, 1, 0), (1, 0, 0)] + + dom = QQ.algebraic_field(sqrt(2)) + + assert solve_triangulated([f_1, f_2, f_3], x, y, z, domain=dom) == \ + [(0, 0, 1), (0, 1, 0), (1, 0, 0), (a, a, a), (b, b, b)] + + +def test_solve_issue_3686(): + roots = solve_poly_system([((x - 5)**2/250000 + (y - Rational(5, 10))**2/250000) - 1, x], x, y) + assert roots == [(0, S.Half - 15*sqrt(1111)), (0, S.Half + 15*sqrt(1111))] + + roots = solve_poly_system([((x - 5)**2/250000 + (y - 5.0/10)**2/250000) - 1, x], x, y) + # TODO: does this really have to be so complicated?! + assert len(roots) == 2 + assert roots[0][0] == 0 + assert roots[0][1].epsilon_eq(-499.474999374969, 1e12) + assert roots[1][0] == 0 + assert roots[1][1].epsilon_eq(500.474999374969, 1e12) diff --git a/lib/python3.10/site-packages/sympy/solvers/tests/test_recurr.py b/lib/python3.10/site-packages/sympy/solvers/tests/test_recurr.py new file mode 100644 index 0000000000000000000000000000000000000000..5a6306b51a5cf33ccd9fae131430a24690d540a7 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/solvers/tests/test_recurr.py @@ -0,0 +1,295 @@ +from sympy.core.function import (Function, Lambda, expand) +from sympy.core.numbers import (I, Rational) +from sympy.core.relational import Eq +from sympy.core.singleton import S +from sympy.core.symbol import (Symbol, symbols) +from sympy.functions.combinatorial.factorials import (rf, binomial, factorial) +from sympy.functions.elementary.complexes import Abs +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import (cos, sin) +from sympy.polys.polytools import factor +from sympy.solvers.recurr import rsolve, rsolve_hyper, rsolve_poly, rsolve_ratio +from sympy.testing.pytest import raises, slow, XFAIL +from sympy.abc import a, b + +y = Function('y') +n, k = symbols('n,k', integer=True) +C0, C1, C2 = symbols('C0,C1,C2') + + +def test_rsolve_poly(): + assert rsolve_poly([-1, -1, 1], 0, n) == 0 + assert rsolve_poly([-1, -1, 1], 1, n) == -1 + + assert rsolve_poly([-1, n + 1], n, n) == 1 + assert rsolve_poly([-1, 1], n, n) == C0 + (n**2 - n)/2 + assert rsolve_poly([-n - 1, n], 1, n) == C0*n - 1 + assert rsolve_poly([-4*n - 2, 1], 4*n + 1, n) == -1 + + assert rsolve_poly([-1, 1], n**5 + n**3, n) == \ + C0 - n**3 / 2 - n**5 / 2 + n**2 / 6 + n**6 / 6 + 2*n**4 / 3 + + +def test_rsolve_ratio(): + solution = rsolve_ratio([-2*n**3 + n**2 + 2*n - 1, 2*n**3 + n**2 - 6*n, + -2*n**3 - 11*n**2 - 18*n - 9, 2*n**3 + 13*n**2 + 22*n + 8], 0, n) + assert solution == C0*(2*n - 3)/(n**2 - 1)/2 + + +def test_rsolve_hyper(): + assert rsolve_hyper([-1, -1, 1], 0, n) in [ + C0*(S.Half - S.Half*sqrt(5))**n + C1*(S.Half + S.Half*sqrt(5))**n, + C1*(S.Half - S.Half*sqrt(5))**n + C0*(S.Half + S.Half*sqrt(5))**n, + ] + + assert rsolve_hyper([n**2 - 2, -2*n - 1, 1], 0, n) in [ + C0*rf(sqrt(2), n) + C1*rf(-sqrt(2), n), + C1*rf(sqrt(2), n) + C0*rf(-sqrt(2), n), + ] + + assert rsolve_hyper([n**2 - k, -2*n - 1, 1], 0, n) in [ + C0*rf(sqrt(k), n) + C1*rf(-sqrt(k), n), + C1*rf(sqrt(k), n) + C0*rf(-sqrt(k), n), + ] + + assert rsolve_hyper( + [2*n*(n + 1), -n**2 - 3*n + 2, n - 1], 0, n) == C1*factorial(n) + C0*2**n + + assert rsolve_hyper( + [n + 2, -(2*n + 3)*(17*n**2 + 51*n + 39), n + 1], 0, n) == 0 + + assert rsolve_hyper([-n - 1, -1, 1], 0, n) == 0 + + assert rsolve_hyper([-1, 1], n, n).expand() == C0 + n**2/2 - n/2 + + assert rsolve_hyper([-1, 1], 1 + n, n).expand() == C0 + n**2/2 + n/2 + + assert rsolve_hyper([-1, 1], 3*(n + n**2), n).expand() == C0 + n**3 - n + + assert rsolve_hyper([-a, 1],0,n).expand() == C0*a**n + + assert rsolve_hyper([-a, 0, 1], 0, n).expand() == (-1)**n*C1*a**(n/2) + C0*a**(n/2) + + assert rsolve_hyper([1, 1, 1], 0, n).expand() == \ + C0*(Rational(-1, 2) - sqrt(3)*I/2)**n + C1*(Rational(-1, 2) + sqrt(3)*I/2)**n + + assert rsolve_hyper([1, -2*n/a - 2/a, 1], 0, n) == 0 + + +@XFAIL +def test_rsolve_ratio_missed(): + # this arises during computation + # assert rsolve_hyper([-1, 1], 3*(n + n**2), n).expand() == C0 + n**3 - n + assert rsolve_ratio([-n, n + 2], n, n) is not None + + +def recurrence_term(c, f): + """Compute RHS of recurrence in f(n) with coefficients in c.""" + return sum(c[i]*f.subs(n, n + i) for i in range(len(c))) + + +def test_rsolve_bulk(): + """Some bulk-generated tests.""" + funcs = [ n, n + 1, n**2, n**3, n**4, n + n**2, 27*n + 52*n**2 - 3* + n**3 + 12*n**4 - 52*n**5 ] + coeffs = [ [-2, 1], [-2, -1, 1], [-1, 1, 1, -1, 1], [-n, 1], [n**2 - + n + 12, 1] ] + for p in funcs: + # compute difference + for c in coeffs: + q = recurrence_term(c, p) + if p.is_polynomial(n): + assert rsolve_poly(c, q, n) == p + # See issue 3956: + if p.is_hypergeometric(n) and len(c) <= 3: + assert rsolve_hyper(c, q, n).subs(zip(symbols('C:3'), [0, 0, 0])).expand() == p + + +def test_rsolve_0_sol_homogeneous(): + # fixed by cherry-pick from + # https://github.com/diofant/diofant/commit/e1d2e52125199eb3df59f12e8944f8a5f24b00a5 + assert rsolve_hyper([n**2 - n + 12, 1], n*(n**2 - n + 12) + n + 1, n) == n + + +def test_rsolve(): + f = y(n + 2) - y(n + 1) - y(n) + h = sqrt(5)*(S.Half + S.Half*sqrt(5))**n \ + - sqrt(5)*(S.Half - S.Half*sqrt(5))**n + + assert rsolve(f, y(n)) in [ + C0*(S.Half - S.Half*sqrt(5))**n + C1*(S.Half + S.Half*sqrt(5))**n, + C1*(S.Half - S.Half*sqrt(5))**n + C0*(S.Half + S.Half*sqrt(5))**n, + ] + + assert rsolve(f, y(n), [0, 5]) == h + assert rsolve(f, y(n), {0: 0, 1: 5}) == h + assert rsolve(f, y(n), {y(0): 0, y(1): 5}) == h + assert rsolve(y(n) - y(n - 1) - y(n - 2), y(n), [0, 5]) == h + assert rsolve(Eq(y(n), y(n - 1) + y(n - 2)), y(n), [0, 5]) == h + + assert f.subs(y, Lambda(k, rsolve(f, y(n)).subs(n, k))).simplify() == 0 + + f = (n - 1)*y(n + 2) - (n**2 + 3*n - 2)*y(n + 1) + 2*n*(n + 1)*y(n) + g = C1*factorial(n) + C0*2**n + h = -3*factorial(n) + 3*2**n + + assert rsolve(f, y(n)) == g + assert rsolve(f, y(n), []) == g + assert rsolve(f, y(n), {}) == g + + assert rsolve(f, y(n), [0, 3]) == h + assert rsolve(f, y(n), {0: 0, 1: 3}) == h + assert rsolve(f, y(n), {y(0): 0, y(1): 3}) == h + + assert f.subs(y, Lambda(k, rsolve(f, y(n)).subs(n, k))).simplify() == 0 + + f = y(n) - y(n - 1) - 2 + + assert rsolve(f, y(n), {y(0): 0}) == 2*n + assert rsolve(f, y(n), {y(0): 1}) == 2*n + 1 + assert rsolve(f, y(n), {y(0): 0, y(1): 1}) is None + + assert f.subs(y, Lambda(k, rsolve(f, y(n)).subs(n, k))).simplify() == 0 + + f = 3*y(n - 1) - y(n) - 1 + + assert rsolve(f, y(n), {y(0): 0}) == -3**n/2 + S.Half + assert rsolve(f, y(n), {y(0): 1}) == 3**n/2 + S.Half + assert rsolve(f, y(n), {y(0): 2}) == 3*3**n/2 + S.Half + + assert f.subs(y, Lambda(k, rsolve(f, y(n)).subs(n, k))).simplify() == 0 + + f = y(n) - 1/n*y(n - 1) + assert rsolve(f, y(n)) == C0/factorial(n) + assert f.subs(y, Lambda(k, rsolve(f, y(n)).subs(n, k))).simplify() == 0 + + f = y(n) - 1/n*y(n - 1) - 1 + assert rsolve(f, y(n)) is None + + f = 2*y(n - 1) + (1 - n)*y(n)/n + + assert rsolve(f, y(n), {y(1): 1}) == 2**(n - 1)*n + assert rsolve(f, y(n), {y(1): 2}) == 2**(n - 1)*n*2 + assert rsolve(f, y(n), {y(1): 3}) == 2**(n - 1)*n*3 + + assert f.subs(y, Lambda(k, rsolve(f, y(n)).subs(n, k))).simplify() == 0 + + f = (n - 1)*(n - 2)*y(n + 2) - (n + 1)*(n + 2)*y(n) + + assert rsolve(f, y(n), {y(3): 6, y(4): 24}) == n*(n - 1)*(n - 2) + assert rsolve( + f, y(n), {y(3): 6, y(4): -24}) == -n*(n - 1)*(n - 2)*(-1)**(n) + + assert f.subs(y, Lambda(k, rsolve(f, y(n)).subs(n, k))).simplify() == 0 + + assert rsolve(Eq(y(n + 1), a*y(n)), y(n), {y(1): a}).simplify() == a**n + + assert rsolve(y(n) - a*y(n-2),y(n), \ + {y(1): sqrt(a)*(a + b), y(2): a*(a - b)}).simplify() == \ + a**(n/2 + 1) - b*(-sqrt(a))**n + + f = (-16*n**2 + 32*n - 12)*y(n - 1) + (4*n**2 - 12*n + 9)*y(n) + + yn = rsolve(f, y(n), {y(1): binomial(2*n + 1, 3)}) + sol = 2**(2*n)*n*(2*n - 1)**2*(2*n + 1)/12 + assert factor(expand(yn, func=True)) == sol + + sol = rsolve(y(n) + a*(y(n + 1) + y(n - 1))/2, y(n)) + assert str(sol) == 'C0*((-sqrt(1 - a**2) - 1)/a)**n + C1*((sqrt(1 - a**2) - 1)/a)**n' + + assert rsolve((k + 1)*y(k), y(k)) is None + assert (rsolve((k + 1)*y(k) + (k + 3)*y(k + 1) + (k + 5)*y(k + 2), y(k)) + is None) + + assert rsolve(y(n) + y(n + 1) + 2**n + 3**n, y(n)) == (-1)**n*C0 - 2**n/3 - 3**n/4 + + +def test_rsolve_raises(): + x = Function('x') + raises(ValueError, lambda: rsolve(y(n) - y(k + 1), y(n))) + raises(ValueError, lambda: rsolve(y(n) - y(n + 1), x(n))) + raises(ValueError, lambda: rsolve(y(n) - x(n + 1), y(n))) + raises(ValueError, lambda: rsolve(y(n) - sqrt(n)*y(n + 1), y(n))) + raises(ValueError, lambda: rsolve(y(n) - y(n + 1), y(n), {x(0): 0})) + raises(ValueError, lambda: rsolve(y(n) + y(n + 1) + 2**n + cos(n), y(n))) + + +def test_issue_6844(): + f = y(n + 2) - y(n + 1) + y(n)/4 + assert rsolve(f, y(n)) == 2**(-n + 1)*C1*n + 2**(-n)*C0 + assert rsolve(f, y(n), {y(0): 0, y(1): 1}) == 2**(1 - n)*n + + +def test_issue_18751(): + r = Symbol('r', positive=True) + theta = Symbol('theta', real=True) + f = y(n) - 2 * r * cos(theta) * y(n - 1) + r**2 * y(n - 2) + assert rsolve(f, y(n)) == \ + C0*(r*(cos(theta) - I*Abs(sin(theta))))**n + C1*(r*(cos(theta) + I*Abs(sin(theta))))**n + + +def test_constant_naming(): + #issue 8697 + assert rsolve(y(n+3) - y(n+2) - y(n+1) + y(n), y(n)) == (-1)**n*C1 + C0 + C2*n + assert rsolve(y(n+3)+3*y(n+2)+3*y(n+1)+y(n), y(n)).expand() == (-1)**n*C0 - (-1)**n*C1*n - (-1)**n*C2*n**2 + assert rsolve(y(n) - 2*y(n - 3) + 5*y(n - 2) - 4*y(n - 1),y(n),[1,3,8]) == 3*2**n - n - 2 + + #issue 19630 + assert rsolve(y(n+3) - 3*y(n+1) + 2*y(n), y(n), {y(1):0, y(2):8, y(3):-2}) == (-2)**n + 2*n + + +@slow +def test_issue_15751(): + f = y(n) + 21*y(n + 1) - 273*y(n + 2) - 1092*y(n + 3) + 1820*y(n + 4) + 1092*y(n + 5) - 273*y(n + 6) - 21*y(n + 7) + y(n + 8) + assert rsolve(f, y(n)) is not None + + +def test_issue_17990(): + f = -10*y(n) + 4*y(n + 1) + 6*y(n + 2) + 46*y(n + 3) + sol = rsolve(f, y(n)) + expected = C0*((86*18**(S(1)/3)/69 + (-12 + (-1 + sqrt(3)*I)*(290412 + + 3036*sqrt(9165))**(S(1)/3))*(1 - sqrt(3)*I)*(24201 + 253*sqrt(9165))** + (S(1)/3)/276)/((1 - sqrt(3)*I)*(24201 + 253*sqrt(9165))**(S(1)/3)) + )**n + C1*((86*18**(S(1)/3)/69 + (-12 + (-1 - sqrt(3)*I)*(290412 + 3036 + *sqrt(9165))**(S(1)/3))*(1 + sqrt(3)*I)*(24201 + 253*sqrt(9165))** + (S(1)/3)/276)/((1 + sqrt(3)*I)*(24201 + 253*sqrt(9165))**(S(1)/3)) + )**n + C2*(-43*18**(S(1)/3)/(69*(24201 + 253*sqrt(9165))**(S(1)/3)) - + S(1)/23 + (290412 + 3036*sqrt(9165))**(S(1)/3)/138)**n + assert sol == expected + e = sol.subs({C0: 1, C1: 1, C2: 1, n: 1}).evalf() + assert abs(e + 0.130434782608696) < 1e-13 + + +def test_issue_8697(): + a = Function('a') + eq = a(n + 3) - a(n + 2) - a(n + 1) + a(n) + assert rsolve(eq, a(n)) == (-1)**n*C1 + C0 + C2*n + eq2 = a(n + 3) + 3*a(n + 2) + 3*a(n + 1) + a(n) + assert (rsolve(eq2, a(n)) == + (-1)**n*C0 + (-1)**(n + 1)*C1*n + (-1)**(n + 1)*C2*n**2) + + assert rsolve(a(n) - 2*a(n - 3) + 5*a(n - 2) - 4*a(n - 1), + a(n), {a(0): 1, a(1): 3, a(2): 8}) == 3*2**n - n - 2 + + # From issue thread (but fixed by https://github.com/diofant/diofant/commit/da9789c6cd7d0c2ceeea19fbf59645987125b289): + assert rsolve(a(n) - 2*a(n - 1) - n, a(n), {a(0): 1}) == 3*2**n - n - 2 + + +def test_diofantissue_294(): + f = y(n) - y(n - 1) - 2*y(n - 2) - 2*n + assert rsolve(f, y(n)) == (-1)**n*C0 + 2**n*C1 - n - Rational(5, 2) + # issue sympy/sympy#11261 + assert rsolve(f, y(n), {y(0): -1, y(1): 1}) == (-(-1)**n/2 + 2*2**n - + n - Rational(5, 2)) + # issue sympy/sympy#7055 + assert rsolve(-2*y(n) + y(n + 1) + n - 1, y(n)) == 2**n*C0 + n + + +def test_issue_15553(): + f = Function("f") + assert rsolve(Eq(f(n), 2*f(n - 1) + n), f(n)) == 2**n*C0 - n - 2 + assert rsolve(Eq(f(n + 1), 2*f(n) + n**2 + 1), f(n)) == 2**n*C0 - n**2 - 2*n - 4 + assert rsolve(Eq(f(n + 1), 2*f(n) + n**2 + 1), f(n), {f(1): 0}) == 7*2**n/2 - n**2 - 2*n - 4 + assert rsolve(Eq(f(n), 2*f(n - 1) + 3*n**2), f(n)) == 2**n*C0 - 3*n**2 - 12*n - 18 + assert rsolve(Eq(f(n), 2*f(n - 1) + n**2), f(n)) == 2**n*C0 - n**2 - 4*n - 6 + assert rsolve(Eq(f(n), 2*f(n - 1) + n), f(n), {f(0): 1}) == 3*2**n - n - 2 diff --git a/lib/python3.10/site-packages/sympy/solvers/tests/test_simplex.py b/lib/python3.10/site-packages/sympy/solvers/tests/test_simplex.py new file mode 100644 index 0000000000000000000000000000000000000000..611205f5df009a6d0de6e687501695b63bb932c9 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/solvers/tests/test_simplex.py @@ -0,0 +1,254 @@ +from sympy.core.numbers import Rational +from sympy.core.relational import Eq, Ne +from sympy.core.symbol import symbols +from sympy.core.sympify import sympify +from sympy.core.singleton import S +from sympy.core.random import random, choice +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.ntheory.generate import randprime +from sympy.matrices.dense import Matrix +from sympy.solvers.solveset import linear_eq_to_matrix +from sympy.solvers.simplex import (_lp as lp, _primal_dual, + UnboundedLPError, InfeasibleLPError, lpmin, lpmax, + _m, _abcd, _simplex, linprog) + +from sympy.external.importtools import import_module + +from sympy.testing.pytest import raises + +from sympy.abc import x, y, z + + +np = import_module("numpy") +scipy = import_module("scipy") + + +def test_lp(): + r1 = y + 2*z <= 3 + r2 = -x - 3*z <= -2 + r3 = 2*x + y + 7*z <= 5 + constraints = [r1, r2, r3, x >= 0, y >= 0, z >= 0] + objective = -x - y - 5 * z + ans = optimum, argmax = lp(max, objective, constraints) + assert ans == lpmax(objective, constraints) + assert objective.subs(argmax) == optimum + for constr in constraints: + assert constr.subs(argmax) == True + + r1 = x - y + 2*z <= 3 + r2 = -x + 2*y - 3*z <= -2 + r3 = 2*x + y - 7*z <= -5 + constraints = [r1, r2, r3, x >= 0, y >= 0, z >= 0] + objective = -x - y - 5*z + ans = optimum, argmax = lp(max, objective, constraints) + assert ans == lpmax(objective, constraints) + assert objective.subs(argmax) == optimum + for constr in constraints: + assert constr.subs(argmax) == True + + r1 = x - y + 2*z <= -4 + r2 = -x + 2*y - 3*z <= 8 + r3 = 2*x + y - 7*z <= 10 + constraints = [r1, r2, r3, x >= 0, y >= 0, z >= 0] + const = 2 + objective = -x-y-5*z+const # has constant term + ans = optimum, argmax = lp(max, objective, constraints) + assert ans == lpmax(objective, constraints) + assert objective.subs(argmax) == optimum + for constr in constraints: + assert constr.subs(argmax) == True + + # Section 4 Problem 1 from + # http://web.tecnico.ulisboa.pt/mcasquilho/acad/or/ftp/FergusonUCLA_LP.pdf + # answer on page 55 + v = x1, x2, x3, x4 = symbols('x1 x2 x3 x4') + r1 = x1 - x2 - 2*x3 - x4 <= 4 + r2 = 2*x1 + x3 -4*x4 <= 2 + r3 = -2*x1 + x2 + x4 <= 1 + objective, constraints = x1 - 2*x2 - 3*x3 - x4, [r1, r2, r3] + [ + i >= 0 for i in v] + ans = optimum, argmax = lp(max, objective, constraints) + assert ans == lpmax(objective, constraints) + assert ans == (4, {x1: 7, x2: 0, x3: 0, x4: 3}) + + # input contains Floats + r1 = x - y + 2.0*z <= -4 + r2 = -x + 2*y - 3.0*z <= 8 + r3 = 2*x + y - 7*z <= 10 + constraints = [r1, r2, r3] + [i >= 0 for i in (x, y, z)] + objective = -x-y-5*z + optimum, argmax = lp(max, objective, constraints) + assert objective.subs(argmax) == optimum + for constr in constraints: + assert constr.subs(argmax) == True + + # input contains non-float or non-Rational + r1 = x - y + sqrt(2) * z <= -4 + r2 = -x + 2*y - 3*z <= 8 + r3 = 2*x + y - 7*z <= 10 + raises(TypeError, lambda: lp(max, -x-y-5*z, [r1, r2, r3])) + + r1 = x >= 0 + raises(UnboundedLPError, lambda: lp(max, x, [r1])) + r2 = x <= -1 + raises(InfeasibleLPError, lambda: lp(max, x, [r1, r2])) + + # strict inequalities are not allowed + r1 = x > 0 + raises(TypeError, lambda: lp(max, x, [r1])) + + # not equals not allowed + r1 = Ne(x, 0) + raises(TypeError, lambda: lp(max, x, [r1])) + + def make_random_problem(nvar=2, num_constraints=2, sparsity=.1): + def rand(): + if random() < sparsity: + return sympify(0) + int1, int2 = [randprime(0, 200) for _ in range(2)] + return Rational(int1, int2)*choice([-1, 1]) + variables = symbols('x1:%s' % (nvar + 1)) + constraints = [(sum(rand()*x for x in variables) <= rand()) + for _ in range(num_constraints)] + objective = sum(rand() * x for x in variables) + return objective, constraints, variables + + # equality + r1 = Eq(x, y) + r2 = Eq(y, z) + r3 = z <= 3 + constraints = [r1, r2, r3] + objective = x + ans = optimum, argmax = lp(max, objective, constraints) + assert ans == lpmax(objective, constraints) + assert objective.subs(argmax) == optimum + for constr in constraints: + assert constr.subs(argmax) == True + + +def test_simplex(): + L = [ + [[1, 1], [-1, 1], [0, 1], [-1, 0]], + [5, 1, 2, -1], + [[1, 1]], + [-1]] + A, B, C, D = _abcd(_m(*L), list=False) + assert _simplex(A, B, -C, -D) == (-6, [3, 2], [1, 0, 0, 0]) + assert _simplex(A, B, -C, -D, dual=True) == (-6, + [1, 0, 0, 0], [5, 0]) + + assert _simplex([[]],[],[[1]],[0]) == (0, [0], []) + + # handling of Eq (or Eq-like x<=y, x>=y conditions) + assert lpmax(x - y, [x <= y + 2, x >= y + 2, x >= 0, y >= 0] + ) == (2, {x: 2, y: 0}) + assert lpmax(x - y, [x <= y + 2, Eq(x, y + 2), x >= 0, y >= 0] + ) == (2, {x: 2, y: 0}) + assert lpmax(x - y, [x <= y + 2, Eq(x, 2)]) == (2, {x: 2, y: 0}) + assert lpmax(y, [Eq(y, 2)]) == (2, {y: 2}) + + # the conditions are equivalent to Eq(x, y + 2) + assert lpmin(y, [x <= y + 2, x >= y + 2, y >= 0] + ) == (0, {x: 2, y: 0}) + # equivalent to Eq(y, -2) + assert lpmax(y, [0 <= y + 2, 0 >= y + 2]) == (-2, {y: -2}) + assert lpmax(y, [0 <= y + 2, 0 >= y + 2, y <= 0] + ) == (-2, {y: -2}) + + # extra symbols symbols + assert lpmin(x, [y >= 1, x >= y]) == (1, {x: 1, y: 1}) + assert lpmin(x, [y >= 1, x >= y + z, x >= 0, z >= 0] + ) == (1, {x: 1, y: 1, z: 0}) + + # detect oscillation + # o1 + v = x1, x2, x3, x4 = symbols('x1 x2 x3 x4') + raises(InfeasibleLPError, lambda: lpmin( + 9*x2 - 8*x3 + 3*x4 + 6, + [5*x2 - 2*x3 <= 0, + -x1 - 8*x2 + 9*x3 <= -3, + 10*x1 - x2+ 9*x4 <= -4] + [i >= 0 for i in v])) + # o2 - equations fed to lpmin are changed into a matrix + # system that doesn't oscillate and has the same solution + # as below + M = linear_eq_to_matrix + f = 5*x2 + x3 + 4*x4 - x1 + L = 5*x2 + 2*x3 + 5*x4 - (x1 + 5) + cond = [L <= 0] + [Eq(3*x2 + x4, 2), Eq(-x1 + x3 + 2*x4, 1)] + c, d = M(f, v) + a, b = M(L, v) + aeq, beq = M(cond[1:], v) + ans = (S(9)/2, [0, S(1)/2, 0, S(1)/2]) + assert linprog(c, a, b, aeq, beq, bounds=(0, 1)) == ans + lpans = lpmin(f, cond + [x1 >= 0, x1 <= 1, + x2 >= 0, x2 <= 1, x3 >= 0, x3 <= 1, x4 >= 0, x4 <= 1]) + assert (lpans[0], list(lpans[1].values())) == ans + + +def test_lpmin_lpmax(): + v = x1, x2, y1, y2 = symbols('x1 x2 y1 y2') + L = [[1, -1]], [1], [[1, 1]], [2] + a, b, c, d = [Matrix(i) for i in L] + m = Matrix([[a, b], [c, d]]) + f, constr = _primal_dual(m)[0] + ans = lpmin(f, constr + [i >= 0 for i in v[:2]]) + assert ans == (-1, {x1: 1, x2: 0}),ans + + L = [[1, -1], [1, 1]], [1, 1], [[1, 1]], [2] + a, b, c, d = [Matrix(i) for i in L] + m = Matrix([[a, b], [c, d]]) + f, constr = _primal_dual(m)[1] + ans = lpmax(f, constr + [i >= 0 for i in v[-2:]]) + assert ans == (-1, {y1: 1, y2: 0}) + + +def test_linprog(): + for do in range(2): + if not do: + M = lambda a, b: linear_eq_to_matrix(a, b) + else: + # check matrices as list + M = lambda a, b: tuple([ + i.tolist() for i in linear_eq_to_matrix(a, b)]) + + v = x, y, z = symbols('x1:4') + f = x + y - 2*z + c = M(f, v)[0] + ineq = [7*x + 4*y - 7*z <= 3, + 3*x - y + 10*z <= 6, + x >= 0, y >= 0, z >= 0] + ab = M([i.lts - i.gts for i in ineq], v) + ans = (-S(6)/5, [0, 0, S(3)/5]) + assert lpmin(f, ineq) == (ans[0], dict(zip(v, ans[1]))) + assert linprog(c, *ab) == ans + + f += 1 + c = M(f, v)[0] + eq = [Eq(y - 9*x, 1)] + abeq = M([i.lhs - i.rhs for i in eq], v) + ans = (1 - S(2)/5, [0, 1, S(7)/10]) + assert lpmin(f, ineq + eq) == (ans[0], dict(zip(v, ans[1]))) + assert linprog(c, *ab, *abeq) == (ans[0] - 1, ans[1]) + + eq = [z - y <= S.Half] + abeq = M([i.lhs - i.rhs for i in eq], v) + ans = (1 - S(10)/9, [0, S(1)/9, S(11)/18]) + assert lpmin(f, ineq + eq) == (ans[0], dict(zip(v, ans[1]))) + assert linprog(c, *ab, *abeq) == (ans[0] - 1, ans[1]) + + bounds = [(0, None), (0, None), (None, S.Half)] + ans = (0, [0, 0, S.Half]) + assert lpmin(f, ineq + [z <= S.Half]) == ( + ans[0], dict(zip(v, ans[1]))) + assert linprog(c, *ab, bounds=bounds) == (ans[0] - 1, ans[1]) + assert linprog(c, *ab, bounds={v.index(z): bounds[-1]} + ) == (ans[0] - 1, ans[1]) + eq = [z - y <= S.Half] + + assert linprog([[1]], [], [], bounds=(2, 3)) == (2, [2]) + assert linprog([1], [], [], bounds=(2, 3)) == (2, [2]) + assert linprog([1], bounds=(2, 3)) == (2, [2]) + assert linprog([1, -1], [[1, 1]], [2], bounds={1:(None, None)} + ) == (-2, [0, 2]) + assert linprog([1, -1], [[1, 1]], [5], bounds={1:(3, None)} + ) == (-5, [0, 5]) diff --git a/lib/python3.10/site-packages/sympy/solvers/tests/test_solvers.py b/lib/python3.10/site-packages/sympy/solvers/tests/test_solvers.py new file mode 100644 index 0000000000000000000000000000000000000000..c3ef819bbecd171adebdad76d9e52dead4f9fe31 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/solvers/tests/test_solvers.py @@ -0,0 +1,2703 @@ +from sympy.assumptions.ask import (Q, ask) +from sympy.core.add import Add +from sympy.core.containers import Tuple +from sympy.core.function import (Derivative, Function, diff) +from sympy.core.mod import Mod +from sympy.core.mul import Mul +from sympy.core import (GoldenRatio, TribonacciConstant) +from sympy.core.numbers import (E, Float, I, Rational, oo, pi) +from sympy.core.relational import (Eq, Gt, Lt, Ne) +from sympy.core.singleton import S +from sympy.core.symbol import (Dummy, Symbol, Wild, symbols) +from sympy.core.sympify import sympify +from sympy.functions.combinatorial.factorials import binomial +from sympy.functions.elementary.complexes import (Abs, arg, conjugate, im, re) +from sympy.functions.elementary.exponential import (LambertW, exp, log) +from sympy.functions.elementary.hyperbolic import (atanh, cosh, sinh, tanh) +from sympy.functions.elementary.integers import floor +from sympy.functions.elementary.miscellaneous import (cbrt, root, sqrt) +from sympy.functions.elementary.piecewise import Piecewise +from sympy.functions.elementary.trigonometric import (acos, asin, atan, atan2, cos, sec, sin, tan) +from sympy.functions.special.error_functions import (erf, erfc, erfcinv, erfinv) +from sympy.integrals.integrals import Integral +from sympy.logic.boolalg import (And, Or) +from sympy.matrices.dense import Matrix +from sympy.matrices import SparseMatrix +from sympy.polys.polytools import Poly +from sympy.printing.str import sstr +from sympy.simplify.radsimp import denom +from sympy.solvers.solvers import (nsolve, solve, solve_linear) + +from sympy.core.function import nfloat +from sympy.solvers import solve_linear_system, solve_linear_system_LU, \ + solve_undetermined_coeffs +from sympy.solvers.bivariate import _filtered_gens, _solve_lambert, _lambert +from sympy.solvers.solvers import _invert, unrad, checksol, posify, _ispow, \ + det_quick, det_perm, det_minor, _simple_dens, denoms + +from sympy.physics.units import cm +from sympy.polys.rootoftools import CRootOf + +from sympy.testing.pytest import slow, XFAIL, SKIP, raises +from sympy.core.random import verify_numerically as tn + +from sympy.abc import a, b, c, d, e, k, h, p, x, y, z, t, q, m, R + + +def NS(e, n=15, **options): + return sstr(sympify(e).evalf(n, **options), full_prec=True) + + +def test_swap_back(): + f, g = map(Function, 'fg') + fx, gx = f(x), g(x) + assert solve([fx + y - 2, fx - gx - 5], fx, y, gx) == \ + {fx: gx + 5, y: -gx - 3} + assert solve(fx + gx*x - 2, [fx, gx], dict=True) == [{fx: 2, gx: 0}] + assert solve(fx + gx**2*x - y, [fx, gx], dict=True) == [{fx: y, gx: 0}] + assert solve([f(1) - 2, x + 2], dict=True) == [{x: -2, f(1): 2}] + + +def guess_solve_strategy(eq, symbol): + try: + solve(eq, symbol) + return True + except (TypeError, NotImplementedError): + return False + + +def test_guess_poly(): + # polynomial equations + assert guess_solve_strategy( S(4), x ) # == GS_POLY + assert guess_solve_strategy( x, x ) # == GS_POLY + assert guess_solve_strategy( x + a, x ) # == GS_POLY + assert guess_solve_strategy( 2*x, x ) # == GS_POLY + assert guess_solve_strategy( x + sqrt(2), x) # == GS_POLY + assert guess_solve_strategy( x + 2**Rational(1, 4), x) # == GS_POLY + assert guess_solve_strategy( x**2 + 1, x ) # == GS_POLY + assert guess_solve_strategy( x**2 - 1, x ) # == GS_POLY + assert guess_solve_strategy( x*y + y, x ) # == GS_POLY + assert guess_solve_strategy( x*exp(y) + y, x) # == GS_POLY + assert guess_solve_strategy( + (x - y**3)/(y**2*sqrt(1 - y**2)), x) # == GS_POLY + + +def test_guess_poly_cv(): + # polynomial equations via a change of variable + assert guess_solve_strategy( sqrt(x) + 1, x ) # == GS_POLY_CV_1 + assert guess_solve_strategy( + x**Rational(1, 3) + sqrt(x) + 1, x ) # == GS_POLY_CV_1 + assert guess_solve_strategy( 4*x*(1 - sqrt(x)), x ) # == GS_POLY_CV_1 + + # polynomial equation multiplying both sides by x**n + assert guess_solve_strategy( x + 1/x + y, x ) # == GS_POLY_CV_2 + + +def test_guess_rational_cv(): + # rational functions + assert guess_solve_strategy( (x + 1)/(x**2 + 2), x) # == GS_RATIONAL + assert guess_solve_strategy( + (x - y**3)/(y**2*sqrt(1 - y**2)), y) # == GS_RATIONAL_CV_1 + + # rational functions via the change of variable y -> x**n + assert guess_solve_strategy( (sqrt(x) + 1)/(x**Rational(1, 3) + sqrt(x) + 1), x ) \ + #== GS_RATIONAL_CV_1 + + +def test_guess_transcendental(): + #transcendental functions + assert guess_solve_strategy( exp(x) + 1, x ) # == GS_TRANSCENDENTAL + assert guess_solve_strategy( 2*cos(x) - y, x ) # == GS_TRANSCENDENTAL + assert guess_solve_strategy( + exp(x) + exp(-x) - y, x ) # == GS_TRANSCENDENTAL + assert guess_solve_strategy(3**x - 10, x) # == GS_TRANSCENDENTAL + assert guess_solve_strategy(-3**x + 10, x) # == GS_TRANSCENDENTAL + + assert guess_solve_strategy(a*x**b - y, x) # == GS_TRANSCENDENTAL + + +def test_solve_args(): + # equation container, issue 5113 + ans = {x: -3, y: 1} + eqs = (x + 5*y - 2, -3*x + 6*y - 15) + assert all(solve(container(eqs), x, y) == ans for container in + (tuple, list, set, frozenset)) + assert solve(Tuple(*eqs), x, y) == ans + # implicit symbol to solve for + assert set(solve(x**2 - 4)) == {S(2), -S(2)} + assert solve([x + y - 3, x - y - 5]) == {x: 4, y: -1} + assert solve(x - exp(x), x, implicit=True) == [exp(x)] + # no symbol to solve for + assert solve(42) == solve(42, x) == [] + assert solve([1, 2]) == [] + assert solve([sqrt(2)],[x]) == [] + # duplicate symbols raises + raises(ValueError, lambda: solve((x - 3, y + 2), x, y, x)) + raises(ValueError, lambda: solve(x, x, x)) + # no error in exclude + assert solve(x, x, exclude=[y, y]) == [0] + # duplicate symbols raises + raises(ValueError, lambda: solve((x - 3, y + 2), x, y, x)) + raises(ValueError, lambda: solve(x, x, x)) + # no error in exclude + assert solve(x, x, exclude=[y, y]) == [0] + # unordered symbols + # only 1 + assert solve(y - 3, {y}) == [3] + # more than 1 + assert solve(y - 3, {x, y}) == [{y: 3}] + # multiple symbols: take the first linear solution+ + # - return as tuple with values for all requested symbols + assert solve(x + y - 3, [x, y]) == [(3 - y, y)] + # - unless dict is True + assert solve(x + y - 3, [x, y], dict=True) == [{x: 3 - y}] + # - or no symbols are given + assert solve(x + y - 3) == [{x: 3 - y}] + # multiple symbols might represent an undetermined coefficients system + assert solve(a + b*x - 2, [a, b]) == {a: 2, b: 0} + assert solve((a + b)*x + b - c, [a, b]) == {a: -c, b: c} + eq = a*x**2 + b*x + c - ((x - h)**2 + 4*p*k)/4/p + # - check that flags are obeyed + sol = solve(eq, [h, p, k], exclude=[a, b, c]) + assert sol == {h: -b/(2*a), k: (4*a*c - b**2)/(4*a), p: 1/(4*a)} + assert solve(eq, [h, p, k], dict=True) == [sol] + assert solve(eq, [h, p, k], set=True) == \ + ([h, p, k], {(-b/(2*a), 1/(4*a), (4*a*c - b**2)/(4*a))}) + # issue 23889 - polysys not simplified + assert solve(eq, [h, p, k], exclude=[a, b, c], simplify=False) == \ + {h: -b/(2*a), k: (4*a*c - b**2)/(4*a), p: 1/(4*a)} + # but this only happens when system has a single solution + args = (a + b)*x - b**2 + 2, a, b + assert solve(*args) == [((b**2 - b*x - 2)/x, b)] + # and if the system has a solution; the following doesn't so + # an algebraic solution is returned + assert solve(a*x + b**2/(x + 4) - 3*x - 4/x, a, b, dict=True) == \ + [{a: (-b**2*x + 3*x**3 + 12*x**2 + 4*x + 16)/(x**2*(x + 4))}] + # failed single equation + assert solve(1/(1/x - y + exp(y))) == [] + raises( + NotImplementedError, lambda: solve(exp(x) + sin(x) + exp(y) + sin(y))) + # failed system + # -- when no symbols given, 1 fails + assert solve([y, exp(x) + x]) == [{x: -LambertW(1), y: 0}] + # both fail + assert solve( + (exp(x) - x, exp(y) - y)) == [{x: -LambertW(-1), y: -LambertW(-1)}] + # -- when symbols given + assert solve([y, exp(x) + x], x, y) == [(-LambertW(1), 0)] + # symbol is a number + assert solve(x**2 - pi, pi) == [x**2] + # no equations + assert solve([], [x]) == [] + # nonlinear system + assert solve((x**2 - 4, y - 2), x, y) == [(-2, 2), (2, 2)] + assert solve((x**2 - 4, y - 2), y, x) == [(2, -2), (2, 2)] + assert solve((x**2 - 4 + z, y - 2 - z), a, z, y, x, set=True + ) == ([a, z, y, x], { + (a, z, z + 2, -sqrt(4 - z)), + (a, z, z + 2, sqrt(4 - z))}) + # overdetermined system + # - nonlinear + assert solve([(x + y)**2 - 4, x + y - 2]) == [{x: -y + 2}] + # - linear + assert solve((x + y - 2, 2*x + 2*y - 4)) == {x: -y + 2} + # When one or more args are Boolean + assert solve(Eq(x**2, 0.0)) == [0.0] # issue 19048 + assert solve([True, Eq(x, 0)], [x], dict=True) == [{x: 0}] + assert solve([Eq(x, x), Eq(x, 0), Eq(x, x+1)], [x], dict=True) == [] + assert not solve([Eq(x, x+1), x < 2], x) + assert solve([Eq(x, 0), x+1<2]) == Eq(x, 0) + assert solve([Eq(x, x), Eq(x, x+1)], x) == [] + assert solve(True, x) == [] + assert solve([x - 1, False], [x], set=True) == ([], set()) + assert solve([-y*(x + y - 1)/2, (y - 1)/x/y + 1/y], + set=True, check=False) == ([x, y], {(1 - y, y), (x, 0)}) + # ordering should be canonical, fastest to order by keys instead + # of by size + assert list(solve((y - 1, x - sqrt(3)*z)).keys()) == [x, y] + # as set always returns as symbols, set even if no solution + assert solve([x - 1, x], (y, x), set=True) == ([y, x], set()) + assert solve([x - 1, x], {y, x}, set=True) == ([x, y], set()) + + +def test_solve_polynomial1(): + assert solve(3*x - 2, x) == [Rational(2, 3)] + assert solve(Eq(3*x, 2), x) == [Rational(2, 3)] + + assert set(solve(x**2 - 1, x)) == {-S.One, S.One} + assert set(solve(Eq(x**2, 1), x)) == {-S.One, S.One} + + assert solve(x - y**3, x) == [y**3] + rx = root(x, 3) + assert solve(x - y**3, y) == [ + rx, -rx/2 - sqrt(3)*I*rx/2, -rx/2 + sqrt(3)*I*rx/2] + a11, a12, a21, a22, b1, b2 = symbols('a11,a12,a21,a22,b1,b2') + + assert solve([a11*x + a12*y - b1, a21*x + a22*y - b2], x, y) == \ + { + x: (a22*b1 - a12*b2)/(a11*a22 - a12*a21), + y: (a11*b2 - a21*b1)/(a11*a22 - a12*a21), + } + + solution = {x: S.Zero, y: S.Zero} + + assert solve((x - y, x + y), x, y ) == solution + assert solve((x - y, x + y), (x, y)) == solution + assert solve((x - y, x + y), [x, y]) == solution + + assert set(solve(x**3 - 15*x - 4, x)) == { + -2 + 3**S.Half, + S(4), + -2 - 3**S.Half + } + + assert set(solve((x**2 - 1)**2 - a, x)) == \ + {sqrt(1 + sqrt(a)), -sqrt(1 + sqrt(a)), + sqrt(1 - sqrt(a)), -sqrt(1 - sqrt(a))} + + +def test_solve_polynomial2(): + assert solve(4, x) == [] + + +def test_solve_polynomial_cv_1a(): + """ + Test for solving on equations that can be converted to a polynomial equation + using the change of variable y -> x**Rational(p, q) + """ + assert solve( sqrt(x) - 1, x) == [1] + assert solve( sqrt(x) - 2, x) == [4] + assert solve( x**Rational(1, 4) - 2, x) == [16] + assert solve( x**Rational(1, 3) - 3, x) == [27] + assert solve(sqrt(x) + x**Rational(1, 3) + x**Rational(1, 4), x) == [0] + + +def test_solve_polynomial_cv_1b(): + assert set(solve(4*x*(1 - a*sqrt(x)), x)) == {S.Zero, 1/a**2} + assert set(solve(x*(root(x, 3) - 3), x)) == {S.Zero, S(27)} + + +def test_solve_polynomial_cv_2(): + """ + Test for solving on equations that can be converted to a polynomial equation + multiplying both sides of the equation by x**m + """ + assert solve(x + 1/x - 1, x) in \ + [[ S.Half + I*sqrt(3)/2, S.Half - I*sqrt(3)/2], + [ S.Half - I*sqrt(3)/2, S.Half + I*sqrt(3)/2]] + + +def test_quintics_1(): + f = x**5 - 110*x**3 - 55*x**2 + 2310*x + 979 + s = solve(f, check=False) + for r in s: + res = f.subs(x, r.n()).n() + assert tn(res, 0) + + f = x**5 - 15*x**3 - 5*x**2 + 10*x + 20 + s = solve(f) + for r in s: + assert r.func == CRootOf + + # if one uses solve to get the roots of a polynomial that has a CRootOf + # solution, make sure that the use of nfloat during the solve process + # doesn't fail. Note: if you want numerical solutions to a polynomial + # it is *much* faster to use nroots to get them than to solve the + # equation only to get RootOf solutions which are then numerically + # evaluated. So for eq = x**5 + 3*x + 7 do Poly(eq).nroots() rather + # than [i.n() for i in solve(eq)] to get the numerical roots of eq. + assert nfloat(solve(x**5 + 3*x**3 + 7)[0], exponent=False) == \ + CRootOf(x**5 + 3*x**3 + 7, 0).n() + + +def test_quintics_2(): + f = x**5 + 15*x + 12 + s = solve(f, check=False) + for r in s: + res = f.subs(x, r.n()).n() + assert tn(res, 0) + + f = x**5 - 15*x**3 - 5*x**2 + 10*x + 20 + s = solve(f) + for r in s: + assert r.func == CRootOf + + assert solve(x**5 - 6*x**3 - 6*x**2 + x - 6) == [ + CRootOf(x**5 - 6*x**3 - 6*x**2 + x - 6, 0), + CRootOf(x**5 - 6*x**3 - 6*x**2 + x - 6, 1), + CRootOf(x**5 - 6*x**3 - 6*x**2 + x - 6, 2), + CRootOf(x**5 - 6*x**3 - 6*x**2 + x - 6, 3), + CRootOf(x**5 - 6*x**3 - 6*x**2 + x - 6, 4)] + + +def test_quintics_3(): + y = x**5 + x**3 - 2**Rational(1, 3) + assert solve(y) == solve(-y) == [] + + +def test_highorder_poly(): + # just testing that the uniq generator is unpacked + sol = solve(x**6 - 2*x + 2) + assert all(isinstance(i, CRootOf) for i in sol) and len(sol) == 6 + + +def test_solve_rational(): + """Test solve for rational functions""" + assert solve( ( x - y**3 )/( (y**2)*sqrt(1 - y**2) ), x) == [y**3] + + +def test_solve_conjugate(): + """Test solve for simple conjugate functions""" + assert solve(conjugate(x) -3 + I) == [3 + I] + + +def test_solve_nonlinear(): + assert solve(x**2 - y**2, x, y, dict=True) == [{x: -y}, {x: y}] + assert solve(x**2 - y**2/exp(x), y, x, dict=True) == [{y: -x*sqrt(exp(x))}, + {y: x*sqrt(exp(x))}] + + +def test_issue_8666(): + x = symbols('x') + assert solve(Eq(x**2 - 1/(x**2 - 4), 4 - 1/(x**2 - 4)), x) == [] + assert solve(Eq(x + 1/x, 1/x), x) == [] + + +def test_issue_7228(): + assert solve(4**(2*(x**2) + 2*x) - 8, x) == [Rational(-3, 2), S.Half] + + +def test_issue_7190(): + assert solve(log(x-3) + log(x+3), x) == [sqrt(10)] + + +def test_issue_21004(): + x = symbols('x') + f = x/sqrt(x**2+1) + f_diff = f.diff(x) + assert solve(f_diff, x) == [] + + +def test_issue_24650(): + x = symbols('x') + r = solve(Eq(Piecewise((x, Eq(x, 0) | (x > 1))), 0)) + assert r == [0] + r = checksol(Eq(Piecewise((x, Eq(x, 0) | (x > 1))), 0), x, sol=0) + assert r is True + + +def test_linear_system(): + x, y, z, t, n = symbols('x, y, z, t, n') + + assert solve([x - 1, x - y, x - 2*y, y - 1], [x, y]) == [] + + assert solve([x - 1, x - y, x - 2*y, x - 1], [x, y]) == [] + assert solve([x - 1, x - 1, x - y, x - 2*y], [x, y]) == [] + + assert solve([x + 5*y - 2, -3*x + 6*y - 15], x, y) == {x: -3, y: 1} + + M = Matrix([[0, 0, n*(n + 1), (n + 1)**2, 0], + [n + 1, n + 1, -2*n - 1, -(n + 1), 0], + [-1, 0, 1, 0, 0]]) + + assert solve_linear_system(M, x, y, z, t) == \ + {x: t*(-n-1)/n, y: 0, z: t*(-n-1)/n} + + assert solve([x + y + z + t, -z - t], x, y, z, t) == {x: -y, z: -t} + + +@XFAIL +def test_linear_system_xfail(): + # https://github.com/sympy/sympy/issues/6420 + M = Matrix([[0, 15.0, 10.0, 700.0], + [1, 1, 1, 100.0], + [0, 10.0, 5.0, 200.0], + [-5.0, 0, 0, 0 ]]) + + assert solve_linear_system(M, x, y, z) == {x: 0, y: -60.0, z: 160.0} + + +def test_linear_system_function(): + a = Function('a') + assert solve([a(0, 0) + a(0, 1) + a(1, 0) + a(1, 1), -a(1, 0) - a(1, 1)], + a(0, 0), a(0, 1), a(1, 0), a(1, 1)) == {a(1, 0): -a(1, 1), a(0, 0): -a(0, 1)} + + +def test_linear_system_symbols_doesnt_hang_1(): + + def _mk_eqs(wy): + # Equations for fitting a wy*2 - 1 degree polynomial between two points, + # at end points derivatives are known up to order: wy - 1 + order = 2*wy - 1 + x, x0, x1 = symbols('x, x0, x1', real=True) + y0s = symbols('y0_:{}'.format(wy), real=True) + y1s = symbols('y1_:{}'.format(wy), real=True) + c = symbols('c_:{}'.format(order+1), real=True) + + expr = sum(coeff*x**o for o, coeff in enumerate(c)) + eqs = [] + for i in range(wy): + eqs.append(expr.diff(x, i).subs({x: x0}) - y0s[i]) + eqs.append(expr.diff(x, i).subs({x: x1}) - y1s[i]) + return eqs, c + + # + # The purpose of this test is just to see that these calls don't hang. The + # expressions returned are complicated so are not included here. Testing + # their correctness takes longer than solving the system. + # + + for n in range(1, 7+1): + eqs, c = _mk_eqs(n) + solve(eqs, c) + + +def test_linear_system_symbols_doesnt_hang_2(): + + M = Matrix([ + [66, 24, 39, 50, 88, 40, 37, 96, 16, 65, 31, 11, 37, 72, 16, 19, 55, 37, 28, 76], + [10, 93, 34, 98, 59, 44, 67, 74, 74, 94, 71, 61, 60, 23, 6, 2, 57, 8, 29, 78], + [19, 91, 57, 13, 64, 65, 24, 53, 77, 34, 85, 58, 87, 39, 39, 7, 36, 67, 91, 3], + [74, 70, 15, 53, 68, 43, 86, 83, 81, 72, 25, 46, 67, 17, 59, 25, 78, 39, 63, 6], + [69, 40, 67, 21, 67, 40, 17, 13, 93, 44, 46, 89, 62, 31, 30, 38, 18, 20, 12, 81], + [50, 22, 74, 76, 34, 45, 19, 76, 28, 28, 11, 99, 97, 82, 8, 46, 99, 57, 68, 35], + [58, 18, 45, 88, 10, 64, 9, 34, 90, 82, 17, 41, 43, 81, 45, 83, 22, 88, 24, 39], + [42, 21, 70, 68, 6, 33, 64, 81, 83, 15, 86, 75, 86, 17, 77, 34, 62, 72, 20, 24], + [ 7, 8, 2, 72, 71, 52, 96, 5, 32, 51, 31, 36, 79, 88, 25, 77, 29, 26, 33, 13], + [19, 31, 30, 85, 81, 39, 63, 28, 19, 12, 16, 49, 37, 66, 38, 13, 3, 71, 61, 51], + [29, 82, 80, 49, 26, 85, 1, 37, 2, 74, 54, 82, 26, 47, 54, 9, 35, 0, 99, 40], + [15, 49, 82, 91, 93, 57, 45, 25, 45, 97, 15, 98, 48, 52, 66, 24, 62, 54, 97, 37], + [62, 23, 73, 53, 52, 86, 28, 38, 0, 74, 92, 38, 97, 70, 71, 29, 26, 90, 67, 45], + [ 2, 32, 23, 24, 71, 37, 25, 71, 5, 41, 97, 65, 93, 13, 65, 45, 25, 88, 69, 50], + [40, 56, 1, 29, 79, 98, 79, 62, 37, 28, 45, 47, 3, 1, 32, 74, 98, 35, 84, 32], + [33, 15, 87, 79, 65, 9, 14, 63, 24, 19, 46, 28, 74, 20, 29, 96, 84, 91, 93, 1], + [97, 18, 12, 52, 1, 2, 50, 14, 52, 76, 19, 82, 41, 73, 51, 79, 13, 3, 82, 96], + [40, 28, 52, 10, 10, 71, 56, 78, 82, 5, 29, 48, 1, 26, 16, 18, 50, 76, 86, 52], + [38, 89, 83, 43, 29, 52, 90, 77, 57, 0, 67, 20, 81, 88, 48, 96, 88, 58, 14, 3]]) + + syms = x0,x1,x2,x3,x4,x5,x6,x7,x8,x9,x10,x11,x12,x13,x14,x15,x16,x17,x18 = symbols('x:19') + + sol = { + x0: -S(1967374186044955317099186851240896179)/3166636564687820453598895768302256588, + x1: -S(84268280268757263347292368432053826)/791659141171955113399723942075564147, + x2: -S(229962957341664730974463872411844965)/1583318282343910226799447884151128294, + x3: S(990156781744251750886760432229180537)/6333273129375640907197791536604513176, + x4: -S(2169830351210066092046760299593096265)/18999819388126922721593374609813539528, + x5: S(4680868883477577389628494526618745355)/9499909694063461360796687304906769764, + x6: -S(1590820774344371990683178396480879213)/3166636564687820453598895768302256588, + x7: -S(54104723404825537735226491634383072)/339282489073695048599881689460956063, + x8: S(3182076494196560075964847771774733847)/6333273129375640907197791536604513176, + x9: -S(10870817431029210431989147852497539675)/18999819388126922721593374609813539528, + x10: -S(13118019242576506476316318268573312603)/18999819388126922721593374609813539528, + x11: -S(5173852969886775824855781403820641259)/4749954847031730680398343652453384882, + x12: S(4261112042731942783763341580651820563)/4749954847031730680398343652453384882, + x13: -S(821833082694661608993818117038209051)/6333273129375640907197791536604513176, + x14: S(906881575107250690508618713632090559)/904753304196520129599684505229216168, + x15: -S(732162528717458388995329317371283987)/6333273129375640907197791536604513176, + x16: S(4524215476705983545537087360959896817)/9499909694063461360796687304906769764, + x17: -S(3898571347562055611881270844646055217)/6333273129375640907197791536604513176, + x18: S(7513502486176995632751685137907442269)/18999819388126922721593374609813539528 + } + + eqs = list(M * Matrix(syms + (1,))) + assert solve(eqs, syms) == sol + + y = Symbol('y') + eqs = list(y * M * Matrix(syms + (1,))) + assert solve(eqs, syms) == sol + + +def test_linear_systemLU(): + n = Symbol('n') + + M = Matrix([[1, 2, 0, 1], [1, 3, 2*n, 1], [4, -1, n**2, 1]]) + + assert solve_linear_system_LU(M, [x, y, z]) == {z: -3/(n**2 + 18*n), + x: 1 - 12*n/(n**2 + 18*n), + y: 6*n/(n**2 + 18*n)} + +# Note: multiple solutions exist for some of these equations, so the tests +# should be expected to break if the implementation of the solver changes +# in such a way that a different branch is chosen + +@slow +def test_solve_transcendental(): + from sympy.abc import a, b + + assert solve(exp(x) - 3, x) == [log(3)] + assert set(solve((a*x + b)*(exp(x) - 3), x)) == {-b/a, log(3)} + assert solve(cos(x) - y, x) == [-acos(y) + 2*pi, acos(y)] + assert solve(2*cos(x) - y, x) == [-acos(y/2) + 2*pi, acos(y/2)] + assert solve(Eq(cos(x), sin(x)), x) == [pi/4] + + assert set(solve(exp(x) + exp(-x) - y, x)) in [{ + log(y/2 - sqrt(y**2 - 4)/2), + log(y/2 + sqrt(y**2 - 4)/2), + }, { + log(y - sqrt(y**2 - 4)) - log(2), + log(y + sqrt(y**2 - 4)) - log(2)}, + { + log(y/2 - sqrt((y - 2)*(y + 2))/2), + log(y/2 + sqrt((y - 2)*(y + 2))/2)}] + assert solve(exp(x) - 3, x) == [log(3)] + assert solve(Eq(exp(x), 3), x) == [log(3)] + assert solve(log(x) - 3, x) == [exp(3)] + assert solve(sqrt(3*x) - 4, x) == [Rational(16, 3)] + assert solve(3**(x + 2), x) == [] + assert solve(3**(2 - x), x) == [] + assert solve(x + 2**x, x) == [-LambertW(log(2))/log(2)] + assert solve(2*x + 5 + log(3*x - 2), x) == \ + [Rational(2, 3) + LambertW(2*exp(Rational(-19, 3))/3)/2] + assert solve(3*x + log(4*x), x) == [LambertW(Rational(3, 4))/3] + assert set(solve((2*x + 8)*(8 + exp(x)), x)) == {S(-4), log(8) + pi*I} + eq = 2*exp(3*x + 4) - 3 + ans = solve(eq, x) # this generated a failure in flatten + assert len(ans) == 3 and all(eq.subs(x, a).n(chop=True) == 0 for a in ans) + assert solve(2*log(3*x + 4) - 3, x) == [(exp(Rational(3, 2)) - 4)/3] + assert solve(exp(x) + 1, x) == [pi*I] + + eq = 2*(3*x + 4)**5 - 6*7**(3*x + 9) + result = solve(eq, x) + x0 = -log(2401) + x1 = 3**Rational(1, 5) + x2 = log(7**(7*x1/20)) + x3 = sqrt(2) + x4 = sqrt(5) + x5 = x3*sqrt(x4 - 5) + x6 = x4 + 1 + x7 = 1/(3*log(7)) + x8 = -x4 + x9 = x3*sqrt(x8 - 5) + x10 = x8 + 1 + ans = [x7*(x0 - 5*LambertW(x2*(-x5 + x6))), + x7*(x0 - 5*LambertW(x2*(x5 + x6))), + x7*(x0 - 5*LambertW(x2*(x10 - x9))), + x7*(x0 - 5*LambertW(x2*(x10 + x9))), + x7*(x0 - 5*LambertW(-log(7**(7*x1/5))))] + assert result == ans, result + # it works if expanded, too + assert solve(eq.expand(), x) == result + + assert solve(z*cos(x) - y, x) == [-acos(y/z) + 2*pi, acos(y/z)] + assert solve(z*cos(2*x) - y, x) == [-acos(y/z)/2 + pi, acos(y/z)/2] + assert solve(z*cos(sin(x)) - y, x) == [ + pi - asin(acos(y/z)), asin(acos(y/z) - 2*pi) + pi, + -asin(acos(y/z) - 2*pi), asin(acos(y/z))] + + assert solve(z*cos(x), x) == [pi/2, pi*Rational(3, 2)] + + # issue 4508 + assert solve(y - b*x/(a + x), x) in [[-a*y/(y - b)], [a*y/(b - y)]] + assert solve(y - b*exp(a/x), x) == [a/log(y/b)] + # issue 4507 + assert solve(y - b/(1 + a*x), x) in [[(b - y)/(a*y)], [-((y - b)/(a*y))]] + # issue 4506 + assert solve(y - a*x**b, x) == [(y/a)**(1/b)] + # issue 4505 + assert solve(z**x - y, x) == [log(y)/log(z)] + # issue 4504 + assert solve(2**x - 10, x) == [1 + log(5)/log(2)] + # issue 6744 + assert solve(x*y) == [{x: 0}, {y: 0}] + assert solve([x*y]) == [{x: 0}, {y: 0}] + assert solve(x**y - 1) == [{x: 1}, {y: 0}] + assert solve([x**y - 1]) == [{x: 1}, {y: 0}] + assert solve(x*y*(x**2 - y**2)) == [{x: 0}, {x: -y}, {x: y}, {y: 0}] + assert solve([x*y*(x**2 - y**2)]) == [{x: 0}, {x: -y}, {x: y}, {y: 0}] + # issue 4739 + assert solve(exp(log(5)*x) - 2**x, x) == [0] + # issue 14791 + assert solve(exp(log(5)*x) - exp(log(2)*x), x) == [0] + f = Function('f') + assert solve(y*f(log(5)*x) - y*f(log(2)*x), x) == [0] + assert solve(f(x) - f(0), x) == [0] + assert solve(f(x) - f(2 - x), x) == [1] + raises(NotImplementedError, lambda: solve(f(x, y) - f(1, 2), x)) + raises(NotImplementedError, lambda: solve(f(x, y) - f(2 - x, 2), x)) + raises(ValueError, lambda: solve(f(x, y) - f(1 - x), x)) + raises(ValueError, lambda: solve(f(x, y) - f(1), x)) + + # misc + # make sure that the right variables is picked up in tsolve + # shouldn't generate a GeneratorsNeeded error in _tsolve when the NaN is generated + # for eq_down. Actual answers, as determined numerically are approx. +/- 0.83 + raises(NotImplementedError, lambda: + solve(sinh(x)*sinh(sinh(x)) + cosh(x)*cosh(sinh(x)) - 3)) + + # watch out for recursive loop in tsolve + raises(NotImplementedError, lambda: solve((x + 2)**y*x - 3, x)) + + # issue 7245 + assert solve(sin(sqrt(x))) == [0, pi**2] + + # issue 7602 + a, b = symbols('a, b', real=True, negative=False) + assert str(solve(Eq(a, 0.5 - cos(pi*b)/2), b)) == \ + '[2.0 - 0.318309886183791*acos(1.0 - 2.0*a), 0.318309886183791*acos(1.0 - 2.0*a)]' + + # issue 15325 + assert solve(y**(1/x) - z, x) == [log(y)/log(z)] + + # issue 25685 (basic trig identies should give simple solutions) + for yi in [cos(2*x),sin(2*x),cos(x - pi/3)]: + sol = solve([cos(x) - S(3)/5, yi - y]) + assert (sol[0][y] + sol[1][y]).is_Rational, (yi,sol) + # don't allow massive expansion + assert solve(cos(1000*x) - S.Half) == [pi/3000, pi/600] + assert solve(cos(x - 1000*y) - 1, x) == [1000*y, 1000*y + 2*pi] + assert solve(cos(x + y + z) - 1, x) == [-y - z, -y - z + 2*pi] + + # issue 26008 + assert solve(sin(x + pi/6)) == [-pi/6, 5*pi/6] + + +def test_solve_for_functions_derivatives(): + t = Symbol('t') + x = Function('x')(t) + y = Function('y')(t) + a11, a12, a21, a22, b1, b2 = symbols('a11,a12,a21,a22,b1,b2') + + soln = solve([a11*x + a12*y - b1, a21*x + a22*y - b2], x, y) + assert soln == { + x: (a22*b1 - a12*b2)/(a11*a22 - a12*a21), + y: (a11*b2 - a21*b1)/(a11*a22 - a12*a21), + } + + assert solve(x - 1, x) == [1] + assert solve(3*x - 2, x) == [Rational(2, 3)] + + soln = solve([a11*x.diff(t) + a12*y.diff(t) - b1, a21*x.diff(t) + + a22*y.diff(t) - b2], x.diff(t), y.diff(t)) + assert soln == { y.diff(t): (a11*b2 - a21*b1)/(a11*a22 - a12*a21), + x.diff(t): (a22*b1 - a12*b2)/(a11*a22 - a12*a21) } + + assert solve(x.diff(t) - 1, x.diff(t)) == [1] + assert solve(3*x.diff(t) - 2, x.diff(t)) == [Rational(2, 3)] + + eqns = {3*x - 1, 2*y - 4} + assert solve(eqns, {x, y}) == { x: Rational(1, 3), y: 2 } + x = Symbol('x') + f = Function('f') + F = x**2 + f(x)**2 - 4*x - 1 + assert solve(F.diff(x), diff(f(x), x)) == [(-x + 2)/f(x)] + + # Mixed cased with a Symbol and a Function + x = Symbol('x') + y = Function('y')(t) + + soln = solve([a11*x + a12*y.diff(t) - b1, a21*x + + a22*y.diff(t) - b2], x, y.diff(t)) + assert soln == { y.diff(t): (a11*b2 - a21*b1)/(a11*a22 - a12*a21), + x: (a22*b1 - a12*b2)/(a11*a22 - a12*a21) } + + # issue 13263 + x = Symbol('x') + f = Function('f') + soln = solve([f(x).diff(x) + f(x).diff(x, 2) - 1, f(x).diff(x) - f(x).diff(x, 2)], + f(x).diff(x), f(x).diff(x, 2)) + assert soln == { f(x).diff(x, 2): S(1)/2, f(x).diff(x): S(1)/2 } + + soln = solve([f(x).diff(x, 2) + f(x).diff(x, 3) - 1, 1 - f(x).diff(x, 2) - + f(x).diff(x, 3), 1 - f(x).diff(x,3)], f(x).diff(x, 2), f(x).diff(x, 3)) + assert soln == { f(x).diff(x, 2): 0, f(x).diff(x, 3): 1 } + + +def test_issue_3725(): + f = Function('f') + F = x**2 + f(x)**2 - 4*x - 1 + e = F.diff(x) + assert solve(e, f(x).diff(x)) in [[(2 - x)/f(x)], [-((x - 2)/f(x))]] + + +def test_issue_3870(): + a, b, c, d = symbols('a b c d') + A = Matrix(2, 2, [a, b, c, d]) + B = Matrix(2, 2, [0, 2, -3, 0]) + C = Matrix(2, 2, [1, 2, 3, 4]) + + assert solve(A*B - C, [a, b, c, d]) == {a: 1, b: Rational(-1, 3), c: 2, d: -1} + assert solve([A*B - C], [a, b, c, d]) == {a: 1, b: Rational(-1, 3), c: 2, d: -1} + assert solve(Eq(A*B, C), [a, b, c, d]) == {a: 1, b: Rational(-1, 3), c: 2, d: -1} + + assert solve([A*B - B*A], [a, b, c, d]) == {a: d, b: Rational(-2, 3)*c} + assert solve([A*C - C*A], [a, b, c, d]) == {a: d - c, b: Rational(2, 3)*c} + assert solve([A*B - B*A, A*C - C*A], [a, b, c, d]) == {a: d, b: 0, c: 0} + + assert solve([Eq(A*B, B*A)], [a, b, c, d]) == {a: d, b: Rational(-2, 3)*c} + assert solve([Eq(A*C, C*A)], [a, b, c, d]) == {a: d - c, b: Rational(2, 3)*c} + assert solve([Eq(A*B, B*A), Eq(A*C, C*A)], [a, b, c, d]) == {a: d, b: 0, c: 0} + + +def test_solve_linear(): + w = Wild('w') + assert solve_linear(x, x) == (0, 1) + assert solve_linear(x, exclude=[x]) == (0, 1) + assert solve_linear(x, symbols=[w]) == (0, 1) + assert solve_linear(x, y - 2*x) in [(x, y/3), (y, 3*x)] + assert solve_linear(x, y - 2*x, exclude=[x]) == (y, 3*x) + assert solve_linear(3*x - y, 0) in [(x, y/3), (y, 3*x)] + assert solve_linear(3*x - y, 0, [x]) == (x, y/3) + assert solve_linear(3*x - y, 0, [y]) == (y, 3*x) + assert solve_linear(x**2/y, 1) == (y, x**2) + assert solve_linear(w, x) in [(w, x), (x, w)] + assert solve_linear(cos(x)**2 + sin(x)**2 + 2 + y) == \ + (y, -2 - cos(x)**2 - sin(x)**2) + assert solve_linear(cos(x)**2 + sin(x)**2 + 2 + y, symbols=[x]) == (0, 1) + assert solve_linear(Eq(x, 3)) == (x, 3) + assert solve_linear(1/(1/x - 2)) == (0, 0) + assert solve_linear((x + 1)*exp(-x), symbols=[x]) == (x, -1) + assert solve_linear((x + 1)*exp(x), symbols=[x]) == ((x + 1)*exp(x), 1) + assert solve_linear(x*exp(-x**2), symbols=[x]) == (x, 0) + assert solve_linear(0**x - 1) == (0**x - 1, 1) + assert solve_linear(1 + 1/(x - 1)) == (x, 0) + eq = y*cos(x)**2 + y*sin(x)**2 - y # = y*(1 - 1) = 0 + assert solve_linear(eq) == (0, 1) + eq = cos(x)**2 + sin(x)**2 # = 1 + assert solve_linear(eq) == (0, 1) + raises(ValueError, lambda: solve_linear(Eq(x, 3), 3)) + + +def test_solve_undetermined_coeffs(): + assert solve_undetermined_coeffs( + a*x**2 + b*x**2 + b*x + 2*c*x + c + 1, [a, b, c], x + ) == {a: -2, b: 2, c: -1} + # Test that rational functions work + assert solve_undetermined_coeffs(a/x + b/(x + 1) + - (2*x + 1)/(x**2 + x), [a, b], x) == {a: 1, b: 1} + # Test cancellation in rational functions + assert solve_undetermined_coeffs( + ((c + 1)*a*x**2 + (c + 1)*b*x**2 + + (c + 1)*b*x + (c + 1)*2*c*x + (c + 1)**2)/(c + 1), + [a, b, c], x) == \ + {a: -2, b: 2, c: -1} + # multivariate + X, Y, Z = y, x**y, y*x**y + eq = a*X + b*Y + c*Z - X - 2*Y - 3*Z + coeffs = a, b, c + syms = x, y + assert solve_undetermined_coeffs(eq, coeffs) == { + a: 1, b: 2, c: 3} + assert solve_undetermined_coeffs(eq, coeffs, syms) == { + a: 1, b: 2, c: 3} + assert solve_undetermined_coeffs(eq, coeffs, *syms) == { + a: 1, b: 2, c: 3} + # check output format + assert solve_undetermined_coeffs(a*x + a - 2, [a]) == [] + assert solve_undetermined_coeffs(a**2*x - 4*x, [a]) == [ + {a: -2}, {a: 2}] + assert solve_undetermined_coeffs(0, [a]) == [] + assert solve_undetermined_coeffs(0, [a], dict=True) == [] + assert solve_undetermined_coeffs(0, [a], set=True) == ([], {}) + assert solve_undetermined_coeffs(1, [a]) == [] + abeq = a*x - 2*x + b - 3 + s = {b, a} + assert solve_undetermined_coeffs(abeq, s, x) == {a: 2, b: 3} + assert solve_undetermined_coeffs(abeq, s, x, set=True) == ([a, b], {(2, 3)}) + assert solve_undetermined_coeffs(sin(a*x) - sin(2*x), (a,)) is None + assert solve_undetermined_coeffs(a*x + b*x - 2*x, (a, b)) == {a: 2 - b} + + +def test_solve_inequalities(): + x = Symbol('x') + sol = And(S.Zero < x, x < oo) + assert solve(x + 1 > 1) == sol + assert solve([x + 1 > 1]) == sol + assert solve([x + 1 > 1], x) == sol + assert solve([x + 1 > 1], [x]) == sol + + system = [Lt(x**2 - 2, 0), Gt(x**2 - 1, 0)] + assert solve(system) == \ + And(Or(And(Lt(-sqrt(2), x), Lt(x, -1)), + And(Lt(1, x), Lt(x, sqrt(2)))), Eq(0, 0)) + + x = Symbol('x', real=True) + system = [Lt(x**2 - 2, 0), Gt(x**2 - 1, 0)] + assert solve(system) == \ + Or(And(Lt(-sqrt(2), x), Lt(x, -1)), And(Lt(1, x), Lt(x, sqrt(2)))) + + # issues 6627, 3448 + assert solve((x - 3)/(x - 2) < 0, x) == And(Lt(2, x), Lt(x, 3)) + assert solve(x/(x + 1) > 1, x) == And(Lt(-oo, x), Lt(x, -1)) + + assert solve(sin(x) > S.Half) == And(pi/6 < x, x < pi*Rational(5, 6)) + + assert solve(Eq(False, x < 1)) == (S.One <= x) & (x < oo) + assert solve(Eq(True, x < 1)) == (-oo < x) & (x < 1) + assert solve(Eq(x < 1, False)) == (S.One <= x) & (x < oo) + assert solve(Eq(x < 1, True)) == (-oo < x) & (x < 1) + + assert solve(Eq(False, x)) == False + assert solve(Eq(0, x)) == [0] + assert solve(Eq(True, x)) == True + assert solve(Eq(1, x)) == [1] + assert solve(Eq(False, ~x)) == True + assert solve(Eq(True, ~x)) == False + assert solve(Ne(True, x)) == False + assert solve(Ne(1, x)) == (x > -oo) & (x < oo) & Ne(x, 1) + + +def test_issue_4793(): + assert solve(1/x) == [] + assert solve(x*(1 - 5/x)) == [5] + assert solve(x + sqrt(x) - 2) == [1] + assert solve(-(1 + x)/(2 + x)**2 + 1/(2 + x)) == [] + assert solve(-x**2 - 2*x + (x + 1)**2 - 1) == [] + assert solve((x/(x + 1) + 3)**(-2)) == [] + assert solve(x/sqrt(x**2 + 1), x) == [0] + assert solve(exp(x) - y, x) == [log(y)] + assert solve(exp(x)) == [] + assert solve(x**2 + x + sin(y)**2 + cos(y)**2 - 1, x) in [[0, -1], [-1, 0]] + eq = 4*3**(5*x + 2) - 7 + ans = solve(eq, x) + assert len(ans) == 5 and all(eq.subs(x, a).n(chop=True) == 0 for a in ans) + assert solve(log(x**2) - y**2/exp(x), x, y, set=True) == ( + [x, y], + {(x, sqrt(exp(x) * log(x ** 2))), (x, -sqrt(exp(x) * log(x ** 2)))}) + assert solve(x**2*z**2 - z**2*y**2) == [{x: -y}, {x: y}, {z: 0}] + assert solve((x - 1)/(1 + 1/(x - 1))) == [] + assert solve(x**(y*z) - x, x) == [1] + raises(NotImplementedError, lambda: solve(log(x) - exp(x), x)) + raises(NotImplementedError, lambda: solve(2**x - exp(x) - 3)) + + +def test_PR1964(): + # issue 5171 + assert solve(sqrt(x)) == solve(sqrt(x**3)) == [0] + assert solve(sqrt(x - 1)) == [1] + # issue 4462 + a = Symbol('a') + assert solve(-3*a/sqrt(x), x) == [] + # issue 4486 + assert solve(2*x/(x + 2) - 1, x) == [2] + # issue 4496 + assert set(solve((x**2/(7 - x)).diff(x))) == {S.Zero, S(14)} + # issue 4695 + f = Function('f') + assert solve((3 - 5*x/f(x))*f(x), f(x)) == [x*Rational(5, 3)] + # issue 4497 + assert solve(1/root(5 + x, 5) - 9, x) == [Rational(-295244, 59049)] + + assert solve(sqrt(x) + sqrt(sqrt(x)) - 4) == [(Rational(-1, 2) + sqrt(17)/2)**4] + assert set(solve(Poly(sqrt(exp(x)) + sqrt(exp(-x)) - 4))) in \ + [ + {log((-sqrt(3) + 2)**2), log((sqrt(3) + 2)**2)}, + {2*log(-sqrt(3) + 2), 2*log(sqrt(3) + 2)}, + {log(-4*sqrt(3) + 7), log(4*sqrt(3) + 7)}, + ] + assert set(solve(Poly(exp(x) + exp(-x) - 4))) == \ + {log(-sqrt(3) + 2), log(sqrt(3) + 2)} + assert set(solve(x**y + x**(2*y) - 1, x)) == \ + {(Rational(-1, 2) + sqrt(5)/2)**(1/y), (Rational(-1, 2) - sqrt(5)/2)**(1/y)} + + assert solve(exp(x/y)*exp(-z/y) - 2, y) == [(x - z)/log(2)] + assert solve( + x**z*y**z - 2, z) in [[log(2)/(log(x) + log(y))], [log(2)/(log(x*y))]] + # if you do inversion too soon then multiple roots (as for the following) + # will be missed, e.g. if exp(3*x) = exp(3) -> 3*x = 3 + E = S.Exp1 + assert solve(exp(3*x) - exp(3), x) in [ + [1, log(E*(Rational(-1, 2) - sqrt(3)*I/2)), log(E*(Rational(-1, 2) + sqrt(3)*I/2))], + [1, log(-E/2 - sqrt(3)*E*I/2), log(-E/2 + sqrt(3)*E*I/2)], + ] + + # coverage test + p = Symbol('p', positive=True) + assert solve((1/p + 1)**(p + 1)) == [] + + +def test_issue_5197(): + x = Symbol('x', real=True) + assert solve(x**2 + 1, x) == [] + n = Symbol('n', integer=True, positive=True) + assert solve((n - 1)*(n + 2)*(2*n - 1), n) == [1] + x = Symbol('x', positive=True) + y = Symbol('y') + assert solve([x + 5*y - 2, -3*x + 6*y - 15], x, y) == [] + # not {x: -3, y: 1} b/c x is positive + # The solution following should not contain (-sqrt(2), sqrt(2)) + assert solve([(x + y), 2 - y**2], x, y) == [(sqrt(2), -sqrt(2))] + y = Symbol('y', positive=True) + # The solution following should not contain {y: -x*exp(x/2)} + assert solve(x**2 - y**2/exp(x), y, x, dict=True) == [{y: x*exp(x/2)}] + x, y, z = symbols('x y z', positive=True) + assert solve(z**2*x**2 - z**2*y**2/exp(x), y, x, z, dict=True) == [{y: x*exp(x/2)}] + + +def test_checking(): + assert set( + solve(x*(x - y/x), x, check=False)) == {sqrt(y), S.Zero, -sqrt(y)} + assert set(solve(x*(x - y/x), x, check=True)) == {sqrt(y), -sqrt(y)} + # {x: 0, y: 4} sets denominator to 0 in the following so system should return None + assert solve((1/(1/x + 2), 1/(y - 3) - 1)) == [] + # 0 sets denominator of 1/x to zero so None is returned + assert solve(1/(1/x + 2)) == [] + + +def test_issue_4671_4463_4467(): + assert solve(sqrt(x**2 - 1) - 2) in ([sqrt(5), -sqrt(5)], + [-sqrt(5), sqrt(5)]) + assert solve((2**exp(y**2/x) + 2)/(x**2 + 15), y) == [ + -sqrt(x*log(1 + I*pi/log(2))), sqrt(x*log(1 + I*pi/log(2)))] + + C1, C2 = symbols('C1 C2') + f = Function('f') + assert solve(C1 + C2/x**2 - exp(-f(x)), f(x)) == [log(x**2/(C1*x**2 + C2))] + a = Symbol('a') + E = S.Exp1 + assert solve(1 - log(a + 4*x**2), x) in ( + [-sqrt(-a + E)/2, sqrt(-a + E)/2], + [sqrt(-a + E)/2, -sqrt(-a + E)/2] + ) + assert solve(log(a**(-3) - x**2)/a, x) in ( + [-sqrt(-1 + a**(-3)), sqrt(-1 + a**(-3))], + [sqrt(-1 + a**(-3)), -sqrt(-1 + a**(-3))],) + assert solve(1 - log(a + 4*x**2), x) in ( + [-sqrt(-a + E)/2, sqrt(-a + E)/2], + [sqrt(-a + E)/2, -sqrt(-a + E)/2],) + assert solve((a**2 + 1)*(sin(a*x) + cos(a*x)), x) == [-pi/(4*a)] + assert solve(3 - (sinh(a*x) + cosh(a*x)), x) == [log(3)/a] + assert set(solve(3 - (sinh(a*x) + cosh(a*x)**2), x)) == \ + {log(-2 + sqrt(5))/a, log(-sqrt(2) + 1)/a, + log(-sqrt(5) - 2)/a, log(1 + sqrt(2))/a} + assert solve(atan(x) - 1) == [tan(1)] + + +def test_issue_5132(): + r, t = symbols('r,t') + assert set(solve([r - x**2 - y**2, tan(t) - y/x], [x, y])) == \ + {( + -sqrt(r*cos(t)**2), -1*sqrt(r*cos(t)**2)*tan(t)), + (sqrt(r*cos(t)**2), sqrt(r*cos(t)**2)*tan(t))} + assert solve([exp(x) - sin(y), 1/y - 3], [x, y]) == \ + [(log(sin(Rational(1, 3))), Rational(1, 3))] + assert solve([exp(x) - sin(y), 1/exp(y) - 3], [x, y]) == \ + [(log(-sin(log(3))), -log(3))] + assert set(solve([exp(x) - sin(y), y**2 - 4], [x, y])) == \ + {(log(-sin(2)), -S(2)), (log(sin(2)), S(2))} + eqs = [exp(x)**2 - sin(y) + z**2, 1/exp(y) - 3] + assert solve(eqs, set=True) == \ + ([y, z], { + (-log(3), sqrt(-exp(2*x) - sin(log(3)))), + (-log(3), -sqrt(-exp(2*x) - sin(log(3))))}) + assert solve(eqs, x, z, set=True) == ( + [x, z], + {(x, sqrt(-exp(2*x) + sin(y))), (x, -sqrt(-exp(2*x) + sin(y)))}) + assert set(solve(eqs, x, y)) == \ + { + (log(-sqrt(-z**2 - sin(log(3)))), -log(3)), + (log(-z**2 - sin(log(3)))/2, -log(3))} + assert set(solve(eqs, y, z)) == \ + { + (-log(3), -sqrt(-exp(2*x) - sin(log(3)))), + (-log(3), sqrt(-exp(2*x) - sin(log(3))))} + eqs = [exp(x)**2 - sin(y) + z, 1/exp(y) - 3] + assert solve(eqs, set=True) == ([y, z], { + (-log(3), -exp(2*x) - sin(log(3)))}) + assert solve(eqs, x, z, set=True) == ( + [x, z], {(x, -exp(2*x) + sin(y))}) + assert set(solve(eqs, x, y)) == { + (log(-sqrt(-z - sin(log(3)))), -log(3)), + (log(-z - sin(log(3)))/2, -log(3))} + assert solve(eqs, z, y) == \ + [(-exp(2*x) - sin(log(3)), -log(3))] + assert solve((sqrt(x**2 + y**2) - sqrt(10), x + y - 4), set=True) == ( + [x, y], {(S.One, S(3)), (S(3), S.One)}) + assert set(solve((sqrt(x**2 + y**2) - sqrt(10), x + y - 4), x, y)) == \ + {(S.One, S(3)), (S(3), S.One)} + + +def test_issue_5335(): + lam, a0, conc = symbols('lam a0 conc') + a = 0.005 + b = 0.743436700916726 + eqs = [lam + 2*y - a0*(1 - x/2)*x - a*x/2*x, + a0*(1 - x/2)*x - 1*y - b*y, + x + y - conc] + sym = [x, y, a0] + # there are 4 solutions obtained manually but only two are valid + assert len(solve(eqs, sym, manual=True, minimal=True)) == 2 + assert len(solve(eqs, sym)) == 2 # cf below with rational=False + + +@SKIP("Hangs") +def _test_issue_5335_float(): + # gives ZeroDivisionError: polynomial division + lam, a0, conc = symbols('lam a0 conc') + a = 0.005 + b = 0.743436700916726 + eqs = [lam + 2*y - a0*(1 - x/2)*x - a*x/2*x, + a0*(1 - x/2)*x - 1*y - b*y, + x + y - conc] + sym = [x, y, a0] + assert len(solve(eqs, sym, rational=False)) == 2 + + +def test_issue_5767(): + assert set(solve([x**2 + y + 4], [x])) == \ + {(-sqrt(-y - 4),), (sqrt(-y - 4),)} + + +def _make_example_24609(): + D, R, H, B_g, V, D_c = symbols("D, R, H, B_g, V, D_c", real=True, positive=True) + Sigma_f, Sigma_a, nu = symbols("Sigma_f, Sigma_a, nu", real=True, positive=True) + x = symbols("x", real=True, positive=True) + eq = ( + 2**(S(2)/3)*pi**(S(2)/3)*D_c*(S(231361)/10000 + pi**2/x**2) + /(6*V**(S(2)/3)*x**(S(1)/3)) + - 2**(S(2)/3)*pi**(S(8)/3)*D_c/(2*V**(S(2)/3)*x**(S(7)/3)) + ) + expected = 100*sqrt(2)*pi/481 + return eq, expected, x + + +def test_issue_24609(): + # https://github.com/sympy/sympy/issues/24609 + eq, expected, x = _make_example_24609() + assert solve(eq, x, simplify=True) == [expected] + [solapprox] = solve(eq.n(), x) + assert abs(solapprox - expected.n()) < 1e-14 + + +@XFAIL +def test_issue_24609_xfail(): + # + # This returns 5 solutions when it should be 1 (with x positive). + # Simplification reveals all solutions to be equivalent. It is expected + # that solve without simplify=True returns duplicate solutions in some + # cases but the core of this equation is a simple quadratic that can easily + # be solved without introducing any redundant solutions: + # + # >>> print(factor_terms(eq.as_numer_denom()[0])) + # 2**(2/3)*pi**(2/3)*D_c*V**(2/3)*x**(7/3)*(231361*x**2 - 20000*pi**2) + # + eq, expected, x = _make_example_24609() + assert len(solve(eq, x)) == [expected] + # + # We do not want to pass this test just by using simplify so if the above + # passes then uncomment the additional test below: + # + # assert len(solve(eq, x, simplify=False)) == 1 + + +def test_polysys(): + assert set(solve([x**2 + 2/y - 2, x + y - 3], [x, y])) == \ + {(S.One, S(2)), (1 + sqrt(5), 2 - sqrt(5)), + (1 - sqrt(5), 2 + sqrt(5))} + assert solve([x**2 + y - 2, x**2 + y]) == [] + # the ordering should be whatever the user requested + assert solve([x**2 + y - 3, x - y - 4], (x, y)) != solve([x**2 + + y - 3, x - y - 4], (y, x)) + + +@slow +def test_unrad1(): + raises(NotImplementedError, lambda: + unrad(sqrt(x) + sqrt(x + 1) + sqrt(1 - sqrt(x)) + 3)) + raises(NotImplementedError, lambda: + unrad(sqrt(x) + (x + 1)**Rational(1, 3) + 2*sqrt(y))) + + s = symbols('s', cls=Dummy) + + # checkers to deal with possibility of answer coming + # back with a sign change (cf issue 5203) + def check(rv, ans): + assert bool(rv[1]) == bool(ans[1]) + if ans[1]: + return s_check(rv, ans) + e = rv[0].expand() + a = ans[0].expand() + return e in [a, -a] and rv[1] == ans[1] + + def s_check(rv, ans): + # get the dummy + rv = list(rv) + d = rv[0].atoms(Dummy) + reps = list(zip(d, [s]*len(d))) + # replace s with this dummy + rv = (rv[0].subs(reps).expand(), [rv[1][0].subs(reps), rv[1][1].subs(reps)]) + ans = (ans[0].subs(reps).expand(), [ans[1][0].subs(reps), ans[1][1].subs(reps)]) + return str(rv[0]) in [str(ans[0]), str(-ans[0])] and \ + str(rv[1]) == str(ans[1]) + + assert unrad(1) is None + assert check(unrad(sqrt(x)), + (x, [])) + assert check(unrad(sqrt(x) + 1), + (x - 1, [])) + assert check(unrad(sqrt(x) + root(x, 3) + 2), + (s**3 + s**2 + 2, [s, s**6 - x])) + assert check(unrad(sqrt(x)*root(x, 3) + 2), + (x**5 - 64, [])) + assert check(unrad(sqrt(x) + (x + 1)**Rational(1, 3)), + (x**3 - (x + 1)**2, [])) + assert check(unrad(sqrt(x) + sqrt(x + 1) + sqrt(2*x)), + (-2*sqrt(2)*x - 2*x + 1, [])) + assert check(unrad(sqrt(x) + sqrt(x + 1) + 2), + (16*x - 9, [])) + assert check(unrad(sqrt(x) + sqrt(x + 1) + sqrt(1 - x)), + (5*x**2 - 4*x, [])) + assert check(unrad(a*sqrt(x) + b*sqrt(x) + c*sqrt(y) + d*sqrt(y)), + ((a*sqrt(x) + b*sqrt(x))**2 - (c*sqrt(y) + d*sqrt(y))**2, [])) + assert check(unrad(sqrt(x) + sqrt(1 - x)), + (2*x - 1, [])) + assert check(unrad(sqrt(x) + sqrt(1 - x) - 3), + (x**2 - x + 16, [])) + assert check(unrad(sqrt(x) + sqrt(1 - x) + sqrt(2 + x)), + (5*x**2 - 2*x + 1, [])) + assert unrad(sqrt(x) + sqrt(1 - x) + sqrt(2 + x) - 3) in [ + (25*x**4 + 376*x**3 + 1256*x**2 - 2272*x + 784, []), + (25*x**8 - 476*x**6 + 2534*x**4 - 1468*x**2 + 169, [])] + assert unrad(sqrt(x) + sqrt(1 - x) + sqrt(2 + x) - sqrt(1 - 2*x)) == \ + (41*x**4 + 40*x**3 + 232*x**2 - 160*x + 16, []) # orig root at 0.487 + assert check(unrad(sqrt(x) + sqrt(x + 1)), (S.One, [])) + + eq = sqrt(x) + sqrt(x + 1) + sqrt(1 - sqrt(x)) + assert check(unrad(eq), + (16*x**2 - 9*x, [])) + assert set(solve(eq, check=False)) == {S.Zero, Rational(9, 16)} + assert solve(eq) == [] + # but this one really does have those solutions + assert set(solve(sqrt(x) - sqrt(x + 1) + sqrt(1 - sqrt(x)))) == \ + {S.Zero, Rational(9, 16)} + + assert check(unrad(sqrt(x) + root(x + 1, 3) + 2*sqrt(y), y), + (S('2*sqrt(x)*(x + 1)**(1/3) + x - 4*y + (x + 1)**(2/3)'), [])) + assert check(unrad(sqrt(x/(1 - x)) + (x + 1)**Rational(1, 3)), + (x**5 - x**4 - x**3 + 2*x**2 + x - 1, [])) + assert check(unrad(sqrt(x/(1 - x)) + 2*sqrt(y), y), + (4*x*y + x - 4*y, [])) + assert check(unrad(sqrt(x)*sqrt(1 - x) + 2, x), + (x**2 - x + 4, [])) + + # http://tutorial.math.lamar.edu/ + # Classes/Alg/SolveRadicalEqns.aspx#Solve_Rad_Ex2_a + assert solve(Eq(x, sqrt(x + 6))) == [3] + assert solve(Eq(x + sqrt(x - 4), 4)) == [4] + assert solve(Eq(1, x + sqrt(2*x - 3))) == [] + assert set(solve(Eq(sqrt(5*x + 6) - 2, x))) == {-S.One, S(2)} + assert set(solve(Eq(sqrt(2*x - 1) - sqrt(x - 4), 2))) == {S(5), S(13)} + assert solve(Eq(sqrt(x + 7) + 2, sqrt(3 - x))) == [-6] + # http://www.purplemath.com/modules/solverad.htm + assert solve((2*x - 5)**Rational(1, 3) - 3) == [16] + assert set(solve(x + 1 - root(x**4 + 4*x**3 - x, 4))) == \ + {Rational(-1, 2), Rational(-1, 3)} + assert set(solve(sqrt(2*x**2 - 7) - (3 - x))) == {-S(8), S(2)} + assert solve(sqrt(2*x + 9) - sqrt(x + 1) - sqrt(x + 4)) == [0] + assert solve(sqrt(x + 4) + sqrt(2*x - 1) - 3*sqrt(x - 1)) == [5] + assert solve(sqrt(x)*sqrt(x - 7) - 12) == [16] + assert solve(sqrt(x - 3) + sqrt(x) - 3) == [4] + assert solve(sqrt(9*x**2 + 4) - (3*x + 2)) == [0] + assert solve(sqrt(x) - 2 - 5) == [49] + assert solve(sqrt(x - 3) - sqrt(x) - 3) == [] + assert solve(sqrt(x - 1) - x + 7) == [10] + assert solve(sqrt(x - 2) - 5) == [27] + assert solve(sqrt(17*x - sqrt(x**2 - 5)) - 7) == [3] + assert solve(sqrt(x) - sqrt(x - 1) + sqrt(sqrt(x))) == [] + + # don't posify the expression in unrad and do use _mexpand + z = sqrt(2*x + 1)/sqrt(x) - sqrt(2 + 1/x) + p = posify(z)[0] + assert solve(p) == [] + assert solve(z) == [] + assert solve(z + 6*I) == [Rational(-1, 11)] + assert solve(p + 6*I) == [] + # issue 8622 + assert unrad(root(x + 1, 5) - root(x, 3)) == ( + -(x**5 - x**3 - 3*x**2 - 3*x - 1), []) + # issue #8679 + assert check(unrad(x + root(x, 3) + root(x, 3)**2 + sqrt(y), x), + (s**3 + s**2 + s + sqrt(y), [s, s**3 - x])) + + # for coverage + assert check(unrad(sqrt(x) + root(x, 3) + y), + (s**3 + s**2 + y, [s, s**6 - x])) + assert solve(sqrt(x) + root(x, 3) - 2) == [1] + raises(NotImplementedError, lambda: + solve(sqrt(x) + root(x, 3) + root(x + 1, 5) - 2)) + # fails through a different code path + raises(NotImplementedError, lambda: solve(-sqrt(2) + cosh(x)/x)) + # unrad some + assert solve(sqrt(x + root(x, 3))+root(x - y, 5), y) == [ + x + (x**Rational(1, 3) + x)**Rational(5, 2)] + assert check(unrad(sqrt(x) - root(x + 1, 3)*sqrt(x + 2) + 2), + (s**10 + 8*s**8 + 24*s**6 - 12*s**5 - 22*s**4 - 160*s**3 - 212*s**2 - + 192*s - 56, [s, s**2 - x])) + e = root(x + 1, 3) + root(x, 3) + assert unrad(e) == (2*x + 1, []) + eq = (sqrt(x) + sqrt(x + 1) + sqrt(1 - x) - 6*sqrt(5)/5) + assert check(unrad(eq), + (15625*x**4 + 173000*x**3 + 355600*x**2 - 817920*x + 331776, [])) + assert check(unrad(root(x, 4) + root(x, 4)**3 - 1), + (s**3 + s - 1, [s, s**4 - x])) + assert check(unrad(root(x, 2) + root(x, 2)**3 - 1), + (x**3 + 2*x**2 + x - 1, [])) + assert unrad(x**0.5) is None + assert check(unrad(t + root(x + y, 5) + root(x + y, 5)**3), + (s**3 + s + t, [s, s**5 - x - y])) + assert check(unrad(x + root(x + y, 5) + root(x + y, 5)**3, y), + (s**3 + s + x, [s, s**5 - x - y])) + assert check(unrad(x + root(x + y, 5) + root(x + y, 5)**3, x), + (s**5 + s**3 + s - y, [s, s**5 - x - y])) + assert check(unrad(root(x - 1, 3) + root(x + 1, 5) + root(2, 5)), + (s**5 + 5*2**Rational(1, 5)*s**4 + s**3 + 10*2**Rational(2, 5)*s**3 + + 10*2**Rational(3, 5)*s**2 + 5*2**Rational(4, 5)*s + 4, [s, s**3 - x + 1])) + raises(NotImplementedError, lambda: + unrad((root(x, 2) + root(x, 3) + root(x, 4)).subs(x, x**5 - x + 1))) + + # the simplify flag should be reset to False for unrad results; + # if it's not then this next test will take a long time + assert solve(root(x, 3) + root(x, 5) - 2) == [1] + eq = (sqrt(x) + sqrt(x + 1) + sqrt(1 - x) - 6*sqrt(5)/5) + assert check(unrad(eq), + ((5*x - 4)*(3125*x**3 + 37100*x**2 + 100800*x - 82944), [])) + ans = S(''' + [4/5, -1484/375 + 172564/(140625*(114*sqrt(12657)/78125 + + 12459439/52734375)**(1/3)) + + 4*(114*sqrt(12657)/78125 + 12459439/52734375)**(1/3)]''') + assert solve(eq) == ans + # duplicate radical handling + assert check(unrad(sqrt(x + root(x + 1, 3)) - root(x + 1, 3) - 2), + (s**3 - s**2 - 3*s - 5, [s, s**3 - x - 1])) + # cov post-processing + e = root(x**2 + 1, 3) - root(x**2 - 1, 5) - 2 + assert check(unrad(e), + (s**5 - 10*s**4 + 39*s**3 - 80*s**2 + 80*s - 30, + [s, s**3 - x**2 - 1])) + + e = sqrt(x + root(x + 1, 2)) - root(x + 1, 3) - 2 + assert check(unrad(e), + (s**6 - 2*s**5 - 7*s**4 - 3*s**3 + 26*s**2 + 40*s + 25, + [s, s**3 - x - 1])) + assert check(unrad(e, _reverse=True), + (s**6 - 14*s**5 + 73*s**4 - 187*s**3 + 276*s**2 - 228*s + 89, + [s, s**2 - x - sqrt(x + 1)])) + # this one needs r0, r1 reversal to work + assert check(unrad(sqrt(x + sqrt(root(x, 3) - 1)) - root(x, 6) - 2), + (s**12 - 2*s**8 - 8*s**7 - 8*s**6 + s**4 + 8*s**3 + 23*s**2 + + 32*s + 17, [s, s**6 - x])) + + # why does this pass + assert unrad(root(cosh(x), 3)/x*root(x + 1, 5) - 1) == ( + -(x**15 - x**3*cosh(x)**5 - 3*x**2*cosh(x)**5 - 3*x*cosh(x)**5 + - cosh(x)**5), []) + # and this fail? + #assert unrad(sqrt(cosh(x)/x) + root(x + 1, 3)*sqrt(x) - 1) == ( + # -s**6 + 6*s**5 - 15*s**4 + 20*s**3 - 15*s**2 + 6*s + x**5 + + # 2*x**4 + x**3 - 1, [s, s**2 - cosh(x)/x]) + + # watch for symbols in exponents + assert unrad(S('(x+y)**(2*y/3) + (x+y)**(1/3) + 1')) is None + assert check(unrad(S('(x+y)**(2*y/3) + (x+y)**(1/3) + 1'), x), + (s**(2*y) + s + 1, [s, s**3 - x - y])) + # should _Q be so lenient? + assert unrad(x**(S.Half/y) + y, x) == (x**(1/y) - y**2, []) + + # This tests two things: that if full unrad is attempted and fails + # the solution should still be found; also it tests that the use of + # composite + assert len(solve(sqrt(y)*x + x**3 - 1, x)) == 3 + assert len(solve(-512*y**3 + 1344*(x + 2)**Rational(1, 3)*y**2 - + 1176*(x + 2)**Rational(2, 3)*y - 169*x + 686, y, _unrad=False)) == 3 + + # watch out for when the cov doesn't involve the symbol of interest + eq = S('-x + (7*y/8 - (27*x/2 + 27*sqrt(x**2)/2)**(1/3)/3)**3 - 1') + assert solve(eq, y) == [ + 2**(S(2)/3)*(27*x + 27*sqrt(x**2))**(S(1)/3)*S(4)/21 + (512*x/343 + + S(512)/343)**(S(1)/3)*(-S(1)/2 - sqrt(3)*I/2), 2**(S(2)/3)*(27*x + + 27*sqrt(x**2))**(S(1)/3)*S(4)/21 + (512*x/343 + + S(512)/343)**(S(1)/3)*(-S(1)/2 + sqrt(3)*I/2), 2**(S(2)/3)*(27*x + + 27*sqrt(x**2))**(S(1)/3)*S(4)/21 + (512*x/343 + S(512)/343)**(S(1)/3)] + + eq = root(x + 1, 3) - (root(x, 3) + root(x, 5)) + assert check(unrad(eq), + (3*s**13 + 3*s**11 + s**9 - 1, [s, s**15 - x])) + assert check(unrad(eq - 2), + (3*s**13 + 3*s**11 + 6*s**10 + s**9 + 12*s**8 + 6*s**6 + 12*s**5 + + 12*s**3 + 7, [s, s**15 - x])) + assert check(unrad(root(x, 3) - root(x + 1, 4)/2 + root(x + 2, 3)), + (s*(4096*s**9 + 960*s**8 + 48*s**7 - s**6 - 1728), + [s, s**4 - x - 1])) # orig expr has two real roots: -1, -.389 + assert check(unrad(root(x, 3) + root(x + 1, 4) - root(x + 2, 3)/2), + (343*s**13 + 2904*s**12 + 1344*s**11 + 512*s**10 - 1323*s**9 - + 3024*s**8 - 1728*s**7 + 1701*s**5 + 216*s**4 - 729*s, [s, s**4 - x - + 1])) # orig expr has one real root: -0.048 + assert check(unrad(root(x, 3)/2 - root(x + 1, 4) + root(x + 2, 3)), + (729*s**13 - 216*s**12 + 1728*s**11 - 512*s**10 + 1701*s**9 - + 3024*s**8 + 1344*s**7 + 1323*s**5 - 2904*s**4 + 343*s, [s, s**4 - x - + 1])) # orig expr has 2 real roots: -0.91, -0.15 + assert check(unrad(root(x, 3)/2 - root(x + 1, 4) + root(x + 2, 3) - 2), + (729*s**13 + 1242*s**12 + 18496*s**10 + 129701*s**9 + 388602*s**8 + + 453312*s**7 - 612864*s**6 - 3337173*s**5 - 6332418*s**4 - 7134912*s**3 + - 5064768*s**2 - 2111913*s - 398034, [s, s**4 - x - 1])) + # orig expr has 1 real root: 19.53 + + ans = solve(sqrt(x) + sqrt(x + 1) - + sqrt(1 - x) - sqrt(2 + x)) + assert len(ans) == 1 and NS(ans[0])[:4] == '0.73' + # the fence optimization problem + # https://github.com/sympy/sympy/issues/4793#issuecomment-36994519 + F = Symbol('F') + eq = F - (2*x + 2*y + sqrt(x**2 + y**2)) + ans = F*Rational(2, 7) - sqrt(2)*F/14 + X = solve(eq, x, check=False) + for xi in reversed(X): # reverse since currently, ans is the 2nd one + Y = solve((x*y).subs(x, xi).diff(y), y, simplify=False, check=False) + if any((a - ans).expand().is_zero for a in Y): + break + else: + assert None # no answer was found + assert solve(sqrt(x + 1) + root(x, 3) - 2) == S(''' + [(-11/(9*(47/54 + sqrt(93)/6)**(1/3)) + 1/3 + (47/54 + + sqrt(93)/6)**(1/3))**3]''') + assert solve(sqrt(sqrt(x + 1)) + x**Rational(1, 3) - 2) == S(''' + [(-sqrt(-2*(-1/16 + sqrt(6913)/16)**(1/3) + 6/(-1/16 + + sqrt(6913)/16)**(1/3) + 17/2 + 121/(4*sqrt(-6/(-1/16 + + sqrt(6913)/16)**(1/3) + 2*(-1/16 + sqrt(6913)/16)**(1/3) + 17/4)))/2 + + sqrt(-6/(-1/16 + sqrt(6913)/16)**(1/3) + 2*(-1/16 + + sqrt(6913)/16)**(1/3) + 17/4)/2 + 9/4)**3]''') + assert solve(sqrt(x) + root(sqrt(x) + 1, 3) - 2) == S(''' + [(-(81/2 + 3*sqrt(741)/2)**(1/3)/3 + (81/2 + 3*sqrt(741)/2)**(-1/3) + + 2)**2]''') + eq = S(''' + -x + (1/2 - sqrt(3)*I/2)*(3*x**3/2 - x*(3*x**2 - 34)/2 + sqrt((-3*x**3 + + x*(3*x**2 - 34) + 90)**2/4 - 39304/27) - 45)**(1/3) + 34/(3*(1/2 - + sqrt(3)*I/2)*(3*x**3/2 - x*(3*x**2 - 34)/2 + sqrt((-3*x**3 + x*(3*x**2 + - 34) + 90)**2/4 - 39304/27) - 45)**(1/3))''') + assert check(unrad(eq), + (s*-(-s**6 + sqrt(3)*s**6*I - 153*2**Rational(2, 3)*3**Rational(1, 3)*s**4 + + 51*12**Rational(1, 3)*s**4 - 102*2**Rational(2, 3)*3**Rational(5, 6)*s**4*I - 1620*s**3 + + 1620*sqrt(3)*s**3*I + 13872*18**Rational(1, 3)*s**2 - 471648 + + 471648*sqrt(3)*I), [s, s**3 - 306*x - sqrt(3)*sqrt(31212*x**2 - + 165240*x + 61484) + 810])) + + assert solve(eq) == [] # not other code errors + eq = root(x, 3) - root(y, 3) + root(x, 5) + assert check(unrad(eq), + (s**15 + 3*s**13 + 3*s**11 + s**9 - y, [s, s**15 - x])) + eq = root(x, 3) + root(y, 3) + root(x*y, 4) + assert check(unrad(eq), + (s*y*(-s**12 - 3*s**11*y - 3*s**10*y**2 - s**9*y**3 - + 3*s**8*y**2 + 21*s**7*y**3 - 3*s**6*y**4 - 3*s**4*y**4 - + 3*s**3*y**5 - y**6), [s, s**4 - x*y])) + raises(NotImplementedError, + lambda: unrad(root(x, 3) + root(y, 3) + root(x*y, 5))) + + # Test unrad with an Equality + eq = Eq(-x**(S(1)/5) + x**(S(1)/3), -3**(S(1)/3) - (-1)**(S(3)/5)*3**(S(1)/5)) + assert check(unrad(eq), + (-s**5 + s**3 - 3**(S(1)/3) - (-1)**(S(3)/5)*3**(S(1)/5), [s, s**15 - x])) + + # make sure buried radicals are exposed + s = sqrt(x) - 1 + assert unrad(s**2 - s**3) == (x**3 - 6*x**2 + 9*x - 4, []) + # make sure numerators which are already polynomial are rejected + assert unrad((x/(x + 1) + 3)**(-2), x) is None + + # https://github.com/sympy/sympy/issues/23707 + eq = sqrt(x - y)*exp(t*sqrt(x - y)) - exp(t*sqrt(x - y)) + assert solve(eq, y) == [x - 1] + assert unrad(eq) is None + + +@slow +def test_unrad_slow(): + # this has roots with multiplicity > 1; there should be no + # repeats in roots obtained, however + eq = (sqrt(1 + sqrt(1 - 4*x**2)) - x*(1 + sqrt(1 + 2*sqrt(1 - 4*x**2)))) + assert solve(eq) == [S.Half] + + +@XFAIL +def test_unrad_fail(): + # this only works if we check real_root(eq.subs(x, Rational(1, 3))) + # but checksol doesn't work like that + assert solve(root(x**3 - 3*x**2, 3) + 1 - x) == [Rational(1, 3)] + assert solve(root(x + 1, 3) + root(x**2 - 2, 5) + 1) == [ + -1, -1 + CRootOf(x**5 + x**4 + 5*x**3 + 8*x**2 + 10*x + 5, 0)**3] + + +def test_checksol(): + x, y, r, t = symbols('x, y, r, t') + eq = r - x**2 - y**2 + dict_var_soln = {y: - sqrt(r) / sqrt(tan(t)**2 + 1), + x: -sqrt(r)*tan(t)/sqrt(tan(t)**2 + 1)} + assert checksol(eq, dict_var_soln) == True + assert checksol(Eq(x, False), {x: False}) is True + assert checksol(Ne(x, False), {x: False}) is False + assert checksol(Eq(x < 1, True), {x: 0}) is True + assert checksol(Eq(x < 1, True), {x: 1}) is False + assert checksol(Eq(x < 1, False), {x: 1}) is True + assert checksol(Eq(x < 1, False), {x: 0}) is False + assert checksol(Eq(x + 1, x**2 + 1), {x: 1}) is True + assert checksol([x - 1, x**2 - 1], x, 1) is True + assert checksol([x - 1, x**2 - 2], x, 1) is False + assert checksol(Poly(x**2 - 1), x, 1) is True + assert checksol(0, {}) is True + assert checksol([1e-10, x - 2], x, 2) is False + assert checksol([0.5, 0, x], x, 0) is False + assert checksol(y, x, 2) is False + assert checksol(x+1e-10, x, 0, numerical=True) is True + assert checksol(x+1e-10, x, 0, numerical=False) is False + assert checksol(exp(92*x), {x: log(sqrt(2)/2)}) is False + assert checksol(exp(92*x), {x: log(sqrt(2)/2) + I*pi}) is False + assert checksol(1/x**5, x, 1000) is False + raises(ValueError, lambda: checksol(x, 1)) + raises(ValueError, lambda: checksol([], x, 1)) + + +def test__invert(): + assert _invert(x - 2) == (2, x) + assert _invert(2) == (2, 0) + assert _invert(exp(1/x) - 3, x) == (1/log(3), x) + assert _invert(exp(1/x + a/x) - 3, x) == ((a + 1)/log(3), x) + assert _invert(a, x) == (a, 0) + + +def test_issue_4463(): + assert solve(-a*x + 2*x*log(x), x) == [exp(a/2)] + assert solve(x**x) == [] + assert solve(x**x - 2) == [exp(LambertW(log(2)))] + assert solve(((x - 3)*(x - 2))**((x - 3)*(x - 4))) == [2] + +@slow +def test_issue_5114_solvers(): + a, b, c, d, e, f, g, h, i, j, k, l, m, n, o, p, q, r = symbols('a:r') + + # there is no 'a' in the equation set but this is how the + # problem was originally posed + syms = a, b, c, f, h, k, n + eqs = [b + r/d - c/d, + c*(1/d + 1/e + 1/g) - f/g - r/d, + f*(1/g + 1/i + 1/j) - c/g - h/i, + h*(1/i + 1/l + 1/m) - f/i - k/m, + k*(1/m + 1/o + 1/p) - h/m - n/p, + n*(1/p + 1/q) - k/p] + assert len(solve(eqs, syms, manual=True, check=False, simplify=False)) == 1 + + +def test_issue_5849(): + # + # XXX: This system does not have a solution for most values of the + # parameters. Generally solve returns the empty set for systems that are + # generically inconsistent. + # + I1, I2, I3, I4, I5, I6 = symbols('I1:7') + dI1, dI4, dQ2, dQ4, Q2, Q4 = symbols('dI1,dI4,dQ2,dQ4,Q2,Q4') + + e = ( + I1 - I2 - I3, + I3 - I4 - I5, + I4 + I5 - I6, + -I1 + I2 + I6, + -2*I1 - 2*I3 - 2*I5 - 3*I6 - dI1/2 + 12, + -I4 + dQ4, + -I2 + dQ2, + 2*I3 + 2*I5 + 3*I6 - Q2, + I4 - 2*I5 + 2*Q4 + dI4 + ) + + ans = [{ + I1: I2 + I3, + dI1: -4*I2 - 8*I3 - 4*I5 - 6*I6 + 24, + I4: I3 - I5, + dQ4: I3 - I5, + Q4: -I3/2 + 3*I5/2 - dI4/2, + dQ2: I2, + Q2: 2*I3 + 2*I5 + 3*I6}] + + v = I1, I4, Q2, Q4, dI1, dI4, dQ2, dQ4 + assert solve(e, *v, manual=True, check=False, dict=True) == ans + assert solve(e, *v, manual=True, check=False) == [ + tuple([a.get(i, i) for i in v]) for a in ans] + assert solve(e, *v, manual=True) == [] + assert solve(e, *v) == [] + + # the matrix solver (tested below) doesn't like this because it produces + # a zero row in the matrix. Is this related to issue 4551? + assert [ei.subs( + ans[0]) for ei in e] == [0, 0, I3 - I6, -I3 + I6, 0, 0, 0, 0, 0] + + +def test_issue_5849_matrix(): + '''Same as test_issue_5849 but solved with the matrix solver. + + A solution only exists if I3 == I6 which is not generically true, + but `solve` does not return conditions under which the solution is + valid, only a solution that is canonical and consistent with the input. + ''' + # a simple example with the same issue + # assert solve([x+y+z, x+y], [x, y]) == {x: y} + # the longer example + I1, I2, I3, I4, I5, I6 = symbols('I1:7') + dI1, dI4, dQ2, dQ4, Q2, Q4 = symbols('dI1,dI4,dQ2,dQ4,Q2,Q4') + + e = ( + I1 - I2 - I3, + I3 - I4 - I5, + I4 + I5 - I6, + -I1 + I2 + I6, + -2*I1 - 2*I3 - 2*I5 - 3*I6 - dI1/2 + 12, + -I4 + dQ4, + -I2 + dQ2, + 2*I3 + 2*I5 + 3*I6 - Q2, + I4 - 2*I5 + 2*Q4 + dI4 + ) + assert solve(e, I1, I4, Q2, Q4, dI1, dI4, dQ2, dQ4) == [] + + +def test_issue_21882(): + + a, b, c, d, f, g, k = unknowns = symbols('a, b, c, d, f, g, k') + + equations = [ + -k*a + b + 5*f/6 + 2*c/9 + 5*d/6 + 4*a/3, + -k*f + 4*f/3 + d/2, + -k*d + f/6 + d, + 13*b/18 + 13*c/18 + 13*a/18, + -k*c + b/2 + 20*c/9 + a, + -k*b + b + c/18 + a/6, + 5*b/3 + c/3 + a, + 2*b/3 + 2*c + 4*a/3, + -g, + ] + + answer = [ + {a: 0, f: 0, b: 0, d: 0, c: 0, g: 0}, + {a: 0, f: -d, b: 0, k: S(5)/6, c: 0, g: 0}, + {a: -2*c, f: 0, b: c, d: 0, k: S(13)/18, g: 0}] + # but not {a: 0, f: 0, b: 0, k: S(3)/2, c: 0, d: 0, g: 0} + # since this is already covered by the first solution + got = solve(equations, unknowns, dict=True) + assert got == answer, (got,answer) + + +def test_issue_5901(): + f, g, h = map(Function, 'fgh') + a = Symbol('a') + D = Derivative(f(x), x) + G = Derivative(g(a), a) + assert solve(f(x) + f(x).diff(x), f(x)) == \ + [-D] + assert solve(f(x) - 3, f(x)) == \ + [3] + assert solve(f(x) - 3*f(x).diff(x), f(x)) == \ + [3*D] + assert solve([f(x) - 3*f(x).diff(x)], f(x)) == \ + {f(x): 3*D} + assert solve([f(x) - 3*f(x).diff(x), f(x)**2 - y + 4], f(x), y) == \ + [(3*D, 9*D**2 + 4)] + assert solve(-f(a)**2*g(a)**2 + f(a)**2*h(a)**2 + g(a).diff(a), + h(a), g(a), set=True) == \ + ([h(a), g(a)], { + (-sqrt(f(a)**2*g(a)**2 - G)/f(a), g(a)), + (sqrt(f(a)**2*g(a)**2 - G)/f(a), g(a))}), solve(-f(a)**2*g(a)**2 + f(a)**2*h(a)**2 + g(a).diff(a), + h(a), g(a), set=True) + args = [[f(x).diff(x, 2)*(f(x) + g(x)), 2 - g(x)**2], f(x), g(x)] + assert solve(*args, set=True)[1] == \ + {(-sqrt(2), sqrt(2)), (sqrt(2), -sqrt(2))} + eqs = [f(x)**2 + g(x) - 2*f(x).diff(x), g(x)**2 - 4] + assert solve(eqs, f(x), g(x), set=True) == \ + ([f(x), g(x)], { + (-sqrt(2*D - 2), S(2)), + (sqrt(2*D - 2), S(2)), + (-sqrt(2*D + 2), -S(2)), + (sqrt(2*D + 2), -S(2))}) + + # the underlying problem was in solve_linear that was not masking off + # anything but a Mul or Add; it now raises an error if it gets anything + # but a symbol and solve handles the substitutions necessary so solve_linear + # won't make this error + raises( + ValueError, lambda: solve_linear(f(x) + f(x).diff(x), symbols=[f(x)])) + assert solve_linear(f(x) + f(x).diff(x), symbols=[x]) == \ + (f(x) + Derivative(f(x), x), 1) + assert solve_linear(f(x) + Integral(x, (x, y)), symbols=[x]) == \ + (f(x) + Integral(x, (x, y)), 1) + assert solve_linear(f(x) + Integral(x, (x, y)) + x, symbols=[x]) == \ + (x + f(x) + Integral(x, (x, y)), 1) + assert solve_linear(f(y) + Integral(x, (x, y)) + x, symbols=[x]) == \ + (x, -f(y) - Integral(x, (x, y))) + assert solve_linear(x - f(x)/a + (f(x) - 1)/a, symbols=[x]) == \ + (x, 1/a) + assert solve_linear(x + Derivative(2*x, x)) == \ + (x, -2) + assert solve_linear(x + Integral(x, y), symbols=[x]) == \ + (x, 0) + assert solve_linear(x + Integral(x, y) - 2, symbols=[x]) == \ + (x, 2/(y + 1)) + + assert set(solve(x + exp(x)**2, exp(x))) == \ + {-sqrt(-x), sqrt(-x)} + assert solve(x + exp(x), x, implicit=True) == \ + [-exp(x)] + assert solve(cos(x) - sin(x), x, implicit=True) == [] + assert solve(x - sin(x), x, implicit=True) == \ + [sin(x)] + assert solve(x**2 + x - 3, x, implicit=True) == \ + [-x**2 + 3] + assert solve(x**2 + x - 3, x**2, implicit=True) == \ + [-x + 3] + + +def test_issue_5912(): + assert set(solve(x**2 - x - 0.1, rational=True)) == \ + {S.Half + sqrt(35)/10, -sqrt(35)/10 + S.Half} + ans = solve(x**2 - x - 0.1, rational=False) + assert len(ans) == 2 and all(a.is_Number for a in ans) + ans = solve(x**2 - x - 0.1) + assert len(ans) == 2 and all(a.is_Number for a in ans) + + +def test_float_handling(): + def test(e1, e2): + return len(e1.atoms(Float)) == len(e2.atoms(Float)) + assert solve(x - 0.5, rational=True)[0].is_Rational + assert solve(x - 0.5, rational=False)[0].is_Float + assert solve(x - S.Half, rational=False)[0].is_Rational + assert solve(x - 0.5, rational=None)[0].is_Float + assert solve(x - S.Half, rational=None)[0].is_Rational + assert test(nfloat(1 + 2*x), 1.0 + 2.0*x) + for contain in [list, tuple, set]: + ans = nfloat(contain([1 + 2*x])) + assert type(ans) is contain and test(list(ans)[0], 1.0 + 2.0*x) + k, v = list(nfloat({2*x: [1 + 2*x]}).items())[0] + assert test(k, 2*x) and test(v[0], 1.0 + 2.0*x) + assert test(nfloat(cos(2*x)), cos(2.0*x)) + assert test(nfloat(3*x**2), 3.0*x**2) + assert test(nfloat(3*x**2, exponent=True), 3.0*x**2.0) + assert test(nfloat(exp(2*x)), exp(2.0*x)) + assert test(nfloat(x/3), x/3.0) + assert test(nfloat(x**4 + 2*x + cos(Rational(1, 3)) + 1), + x**4 + 2.0*x + 1.94495694631474) + # don't call nfloat if there is no solution + tot = 100 + c + z + t + assert solve(((.7 + c)/tot - .6, (.2 + z)/tot - .3, t/tot - .1)) == [] + + +def test_check_assumptions(): + x = symbols('x', positive=True) + assert solve(x**2 - 1) == [1] + + +def test_issue_6056(): + assert solve(tanh(x + 3)*tanh(x - 3) - 1) == [] + assert solve(tanh(x - 1)*tanh(x + 1) + 1) == \ + [I*pi*Rational(-3, 4), -I*pi/4, I*pi/4, I*pi*Rational(3, 4)] + assert solve((tanh(x + 3)*tanh(x - 3) + 1)**2) == \ + [I*pi*Rational(-3, 4), -I*pi/4, I*pi/4, I*pi*Rational(3, 4)] + + +def test_issue_5673(): + eq = -x + exp(exp(LambertW(log(x)))*LambertW(log(x))) + assert checksol(eq, x, 2) is True + assert checksol(eq, x, 2, numerical=False) is None + + +def test_exclude(): + R, C, Ri, Vout, V1, Vminus, Vplus, s = \ + symbols('R, C, Ri, Vout, V1, Vminus, Vplus, s') + Rf = symbols('Rf', positive=True) # to eliminate Rf = 0 soln + eqs = [C*V1*s + Vplus*(-2*C*s - 1/R), + Vminus*(-1/Ri - 1/Rf) + Vout/Rf, + C*Vplus*s + V1*(-C*s - 1/R) + Vout/R, + -Vminus + Vplus] + assert solve(eqs, exclude=s*C*R) == [ + { + Rf: Ri*(C*R*s + 1)**2/(C*R*s), + Vminus: Vplus, + V1: 2*Vplus + Vplus/(C*R*s), + Vout: C*R*Vplus*s + 3*Vplus + Vplus/(C*R*s)}, + { + Vplus: 0, + Vminus: 0, + V1: 0, + Vout: 0}, + ] + + # TODO: Investigate why currently solution [0] is preferred over [1]. + assert solve(eqs, exclude=[Vplus, s, C]) in [[{ + Vminus: Vplus, + V1: Vout/2 + Vplus/2 + sqrt((Vout - 5*Vplus)*(Vout - Vplus))/2, + R: (Vout - 3*Vplus - sqrt(Vout**2 - 6*Vout*Vplus + 5*Vplus**2))/(2*C*Vplus*s), + Rf: Ri*(Vout - Vplus)/Vplus, + }, { + Vminus: Vplus, + V1: Vout/2 + Vplus/2 - sqrt((Vout - 5*Vplus)*(Vout - Vplus))/2, + R: (Vout - 3*Vplus + sqrt(Vout**2 - 6*Vout*Vplus + 5*Vplus**2))/(2*C*Vplus*s), + Rf: Ri*(Vout - Vplus)/Vplus, + }], [{ + Vminus: Vplus, + Vout: (V1**2 - V1*Vplus - Vplus**2)/(V1 - 2*Vplus), + Rf: Ri*(V1 - Vplus)**2/(Vplus*(V1 - 2*Vplus)), + R: Vplus/(C*s*(V1 - 2*Vplus)), + }]] + + +def test_high_order_roots(): + s = x**5 + 4*x**3 + 3*x**2 + Rational(7, 4) + assert set(solve(s)) == set(Poly(s*4, domain='ZZ').all_roots()) + + +def test_minsolve_linear_system(): + pqt = {"quick": True, "particular": True} + pqf = {"quick": False, "particular": True} + assert solve([x + y - 5, 2*x - y - 1], **pqt) == {x: 2, y: 3} + assert solve([x + y - 5, 2*x - y - 1], **pqf) == {x: 2, y: 3} + def count(dic): + return len([x for x in dic.values() if x == 0]) + assert count(solve([x + y + z, y + z + a + t], **pqt)) == 3 + assert count(solve([x + y + z, y + z + a + t], **pqf)) == 3 + assert count(solve([x + y + z, y + z + a], **pqt)) == 1 + assert count(solve([x + y + z, y + z + a], **pqf)) == 2 + # issue 22718 + A = Matrix([ + [ 1, 1, 1, 0, 1, 1, 0, 1, 0, 0, 1, 1, 1, 0], + [ 1, 1, 0, 1, 1, 0, 1, 0, 1, 0, -1, -1, 0, 0], + [-1, -1, 0, 0, -1, 0, 0, 0, 0, 0, 1, 1, 0, 1], + [ 1, 0, 1, 1, 0, 1, 1, 0, 0, 1, -1, 0, -1, 0], + [-1, 0, -1, 0, 0, -1, 0, 0, 0, 0, 1, 0, 1, 1], + [-1, 0, 0, -1, 0, 0, -1, 0, 0, 0, -1, 0, 0, -1], + [ 0, 1, 1, 1, 0, 0, 0, 1, 1, 1, 0, -1, -1, 0], + [ 0, -1, -1, 0, 0, 0, 0, -1, 0, 0, 0, 1, 1, 1], + [ 0, -1, 0, -1, 0, 0, 0, 0, -1, 0, 0, -1, 0, -1], + [ 0, 0, -1, -1, 0, 0, 0, 0, 0, -1, 0, 0, -1, -1], + [ 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0], + [ 0, 0, 0, 0, -1, -1, 0, -1, 0, 0, 0, 0, 0, 0]]) + v = Matrix(symbols("v:14", integer=True)) + B = Matrix([[2], [-2], [0], [0], [0], [0], [0], [0], [0], + [0], [0], [0]]) + eqs = A@v-B + assert solve(eqs) == [] + assert solve(eqs, particular=True) == [] # assumption violated + assert all(v for v in solve([x + y + z, y + z + a]).values()) + for _q in (True, False): + assert not all(v for v in solve( + [x + y + z, y + z + a], quick=_q, + particular=True).values()) + # raise error if quick used w/o particular=True + raises(ValueError, lambda: solve([x + 1], quick=_q)) + raises(ValueError, lambda: solve([x + 1], quick=_q, particular=False)) + # and give a good error message if someone tries to use + # particular with a single equation + raises(ValueError, lambda: solve(x + 1, particular=True)) + + +def test_real_roots(): + # cf. issue 6650 + x = Symbol('x', real=True) + assert len(solve(x**5 + x**3 + 1)) == 1 + + +def test_issue_6528(): + eqs = [ + 327600995*x**2 - 37869137*x + 1809975124*y**2 - 9998905626, + 895613949*x**2 - 273830224*x*y + 530506983*y**2 - 10000000000] + # two expressions encountered are > 1400 ops long so if this hangs + # it is likely because simplification is being done + assert len(solve(eqs, y, x, check=False)) == 4 + + +def test_overdetermined(): + x = symbols('x', real=True) + eqs = [Abs(4*x - 7) - 5, Abs(3 - 8*x) - 1] + assert solve(eqs, x) == [(S.Half,)] + assert solve(eqs, x, manual=True) == [(S.Half,)] + assert solve(eqs, x, manual=True, check=False) == [(S.Half,), (S(3),)] + + +def test_issue_6605(): + x = symbols('x') + assert solve(4**(x/2) - 2**(x/3)) == [0, 3*I*pi/log(2)] + # while the first one passed, this one failed + x = symbols('x', real=True) + assert solve(5**(x/2) - 2**(x/3)) == [0] + b = sqrt(6)*sqrt(log(2))/sqrt(log(5)) + assert solve(5**(x/2) - 2**(3/x)) == [-b, b] + + +def test__ispow(): + assert _ispow(x**2) + assert not _ispow(x) + assert not _ispow(True) + + +def test_issue_6644(): + eq = -sqrt((m - q)**2 + (-m/(2*q) + S.Half)**2) + sqrt((-m**2/2 - sqrt( + 4*m**4 - 4*m**2 + 8*m + 1)/4 - Rational(1, 4))**2 + (m**2/2 - m - sqrt( + 4*m**4 - 4*m**2 + 8*m + 1)/4 - Rational(1, 4))**2) + sol = solve(eq, q, simplify=False, check=False) + assert len(sol) == 5 + + +def test_issue_6752(): + assert solve([a**2 + a, a - b], [a, b]) == [(-1, -1), (0, 0)] + assert solve([a**2 + a*c, a - b], [a, b]) == [(0, 0), (-c, -c)] + + +def test_issue_6792(): + assert solve(x*(x - 1)**2*(x + 1)*(x**6 - x + 1)) == [ + -1, 0, 1, CRootOf(x**6 - x + 1, 0), CRootOf(x**6 - x + 1, 1), + CRootOf(x**6 - x + 1, 2), CRootOf(x**6 - x + 1, 3), + CRootOf(x**6 - x + 1, 4), CRootOf(x**6 - x + 1, 5)] + + +def test_issues_6819_6820_6821_6248_8692_25777_25779(): + # issue 6821 + x, y = symbols('x y', real=True) + assert solve(abs(x + 3) - 2*abs(x - 3)) == [1, 9] + assert solve([abs(x) - 2, arg(x) - pi], x) == [(-2,)] + assert set(solve(abs(x - 7) - 8)) == {-S.One, S(15)} + + # issue 8692 + assert solve(Eq(Abs(x + 1) + Abs(x**2 - 7), 9), x) == [ + Rational(-1, 2) + sqrt(61)/2, -sqrt(69)/2 + S.Half] + + # issue 7145 + assert solve(2*abs(x) - abs(x - 1)) == [-1, Rational(1, 3)] + + # 25777 + assert solve(abs(x**3 + x + 2)/(x + 1)) == [] + + # 25779 + assert solve(abs(x)) == [0] + assert solve(Eq(abs(x**2 - 2*x), 4), x) == [ + 1 - sqrt(5), 1 + sqrt(5)] + nn = symbols('nn', nonnegative=True) + assert solve(abs(sqrt(nn))) == [0] + nz = symbols('nz', nonzero=True) + assert solve(Eq(Abs(4 + 1 / (4*nz)), 0)) == [-Rational(1, 16)] + + x = symbols('x') + assert solve([re(x) - 1, im(x) - 2], x) == [ + {x: 1 + 2*I, re(x): 1, im(x): 2}] + + # check for 'dict' handling of solution + eq = sqrt(re(x)**2 + im(x)**2) - 3 + assert solve(eq) == solve(eq, x) + + i = symbols('i', imaginary=True) + assert solve(abs(i) - 3) == [-3*I, 3*I] + raises(NotImplementedError, lambda: solve(abs(x) - 3)) + + w = symbols('w', integer=True) + assert solve(2*x**w - 4*y**w, w) == solve((x/y)**w - 2, w) + + x, y = symbols('x y', real=True) + assert solve(x + y*I + 3) == {y: 0, x: -3} + # issue 2642 + assert solve(x*(1 + I)) == [0] + + x, y = symbols('x y', imaginary=True) + assert solve(x + y*I + 3 + 2*I) == {x: -2*I, y: 3*I} + + x = symbols('x', real=True) + assert solve(x + y + 3 + 2*I) == {x: -3, y: -2*I} + + # issue 6248 + f = Function('f') + assert solve(f(x + 1) - f(2*x - 1)) == [2] + assert solve(log(x + 1) - log(2*x - 1)) == [2] + + x = symbols('x') + assert solve(2**x + 4**x) == [I*pi/log(2)] + +def test_issue_17638(): + + assert solve(((2-exp(2*x))*exp(x))/(exp(2*x)+2)**2 > 0, x) == (-oo < x) & (x < log(2)/2) + assert solve(((2-exp(2*x)+2)*exp(x+2))/(exp(x)+2)**2 > 0, x) == (-oo < x) & (x < log(4)/2) + assert solve((exp(x)+2+x**2)*exp(2*x+2)/(exp(x)+2)**2 > 0, x) == (-oo < x) & (x < oo) + + + +def test_issue_14607(): + # issue 14607 + s, tau_c, tau_1, tau_2, phi, K = symbols( + 's, tau_c, tau_1, tau_2, phi, K') + + target = (s**2*tau_1*tau_2 + s*tau_1 + s*tau_2 + 1)/(K*s*(-phi + tau_c)) + + K_C, tau_I, tau_D = symbols('K_C, tau_I, tau_D', + positive=True, nonzero=True) + PID = K_C*(1 + 1/(tau_I*s) + tau_D*s) + + eq = (target - PID).together() + eq *= denom(eq).simplify() + eq = Poly(eq, s) + c = eq.coeffs() + + vars = [K_C, tau_I, tau_D] + s = solve(c, vars, dict=True) + + assert len(s) == 1 + + knownsolution = {K_C: -(tau_1 + tau_2)/(K*(phi - tau_c)), + tau_I: tau_1 + tau_2, + tau_D: tau_1*tau_2/(tau_1 + tau_2)} + + for var in vars: + assert s[0][var].simplify() == knownsolution[var].simplify() + + +def test_lambert_multivariate(): + from sympy.abc import x, y + assert _filtered_gens(Poly(x + 1/x + exp(x) + y), x) == {x, exp(x)} + assert _lambert(x, x) == [] + assert solve((x**2 - 2*x + 1).subs(x, log(x) + 3*x)) == [LambertW(3*S.Exp1)/3] + assert solve((x**2 - 2*x + 1).subs(x, (log(x) + 3*x)**2 - 1)) == \ + [LambertW(3*exp(-sqrt(2)))/3, LambertW(3*exp(sqrt(2)))/3] + assert solve((x**2 - 2*x - 2).subs(x, log(x) + 3*x)) == \ + [LambertW(3*exp(1 - sqrt(3)))/3, LambertW(3*exp(1 + sqrt(3)))/3] + eq = (x*exp(x) - 3).subs(x, x*exp(x)) + assert solve(eq) == [LambertW(3*exp(-LambertW(3)))] + # coverage test + raises(NotImplementedError, lambda: solve(x - sin(x)*log(y - x), x)) + ans = [3, -3*LambertW(-log(3)/3)/log(3)] # 3 and 2.478... + assert solve(x**3 - 3**x, x) == ans + assert set(solve(3*log(x) - x*log(3))) == set(ans) + assert solve(LambertW(2*x) - y, x) == [y*exp(y)/2] + + +@XFAIL +def test_other_lambert(): + assert solve(3*sin(x) - x*sin(3), x) == [3] + assert set(solve(x**a - a**x), x) == { + a, -a*LambertW(-log(a)/a)/log(a)} + + +@slow +def test_lambert_bivariate(): + # tests passing current implementation + assert solve((x**2 + x)*exp(x**2 + x) - 1) == [ + Rational(-1, 2) + sqrt(1 + 4*LambertW(1))/2, + Rational(-1, 2) - sqrt(1 + 4*LambertW(1))/2] + assert solve((x**2 + x)*exp((x**2 + x)*2) - 1) == [ + Rational(-1, 2) + sqrt(1 + 2*LambertW(2))/2, + Rational(-1, 2) - sqrt(1 + 2*LambertW(2))/2] + assert solve(a/x + exp(x/2), x) == [2*LambertW(-a/2)] + assert solve((a/x + exp(x/2)).diff(x), x) == \ + [4*LambertW(-sqrt(2)*sqrt(a)/4), 4*LambertW(sqrt(2)*sqrt(a)/4)] + assert solve((1/x + exp(x/2)).diff(x), x) == \ + [4*LambertW(-sqrt(2)/4), + 4*LambertW(sqrt(2)/4), # nsimplifies as 2*2**(141/299)*3**(206/299)*5**(205/299)*7**(37/299)/21 + 4*LambertW(-sqrt(2)/4, -1)] + assert solve(x*log(x) + 3*x + 1, x) == \ + [exp(-3 + LambertW(-exp(3)))] + assert solve(-x**2 + 2**x, x) == [2, 4, -2*LambertW(log(2)/2)/log(2)] + assert solve(x**2 - 2**x, x) == [2, 4, -2*LambertW(log(2)/2)/log(2)] + ans = solve(3*x + 5 + 2**(-5*x + 3), x) + assert len(ans) == 1 and ans[0].expand() == \ + Rational(-5, 3) + LambertW(-10240*root(2, 3)*log(2)/3)/(5*log(2)) + assert solve(5*x - 1 + 3*exp(2 - 7*x), x) == \ + [Rational(1, 5) + LambertW(-21*exp(Rational(3, 5))/5)/7] + assert solve((log(x) + x).subs(x, x**2 + 1)) == [ + -I*sqrt(-LambertW(1) + 1), sqrt(-1 + LambertW(1))] + # check collection + ax = a**(3*x + 5) + ans = solve(3*log(ax) + b*log(ax) + ax, x) + x0 = 1/log(a) + x1 = sqrt(3)*I + x2 = b + 3 + x3 = x2*LambertW(1/x2)/a**5 + x4 = x3**Rational(1, 3)/2 + assert ans == [ + x0*log(x4*(-x1 - 1)), + x0*log(x4*(x1 - 1)), + x0*log(x3)/3] + x1 = LambertW(Rational(1, 3)) + x2 = a**(-5) + x3 = -3**Rational(1, 3) + x4 = 3**Rational(5, 6)*I + x5 = x1**Rational(1, 3)*x2**Rational(1, 3)/2 + ans = solve(3*log(ax) + ax, x) + assert ans == [ + x0*log(3*x1*x2)/3, + x0*log(x5*(x3 - x4)), + x0*log(x5*(x3 + x4))] + # coverage + p = symbols('p', positive=True) + eq = 4*2**(2*p + 3) - 2*p - 3 + assert _solve_lambert(eq, p, _filtered_gens(Poly(eq), p)) == [ + Rational(-3, 2) - LambertW(-4*log(2))/(2*log(2))] + assert set(solve(3**cos(x) - cos(x)**3)) == { + acos(3), acos(-3*LambertW(-log(3)/3)/log(3))} + # should give only one solution after using `uniq` + assert solve(2*log(x) - 2*log(z) + log(z + log(x) + log(z)), x) == [ + exp(-z + LambertW(2*z**4*exp(2*z))/2)/z] + # cases when p != S.One + # issue 4271 + ans = solve((a/x + exp(x/2)).diff(x, 2), x) + x0 = (-a)**Rational(1, 3) + x1 = sqrt(3)*I + x2 = x0/6 + assert ans == [ + 6*LambertW(x0/3), + 6*LambertW(x2*(-x1 - 1)), + 6*LambertW(x2*(x1 - 1))] + assert solve((1/x + exp(x/2)).diff(x, 2), x) == \ + [6*LambertW(Rational(-1, 3)), 6*LambertW(Rational(1, 6) - sqrt(3)*I/6), \ + 6*LambertW(Rational(1, 6) + sqrt(3)*I/6), 6*LambertW(Rational(-1, 3), -1)] + assert solve(x**2 - y**2/exp(x), x, y, dict=True) == \ + [{x: 2*LambertW(-y/2)}, {x: 2*LambertW(y/2)}] + # this is slow but not exceedingly slow + assert solve((x**3)**(x/2) + pi/2, x) == [ + exp(LambertW(-2*log(2)/3 + 2*log(pi)/3 + I*pi*Rational(2, 3)))] + + # issue 23253 + assert solve((1/log(sqrt(x) + 2)**2 - 1/x)) == [ + (LambertW(-exp(-2), -1) + 2)**2] + assert solve((1/log(1/sqrt(x) + 2)**2 - x)) == [ + (LambertW(-exp(-2), -1) + 2)**-2] + assert solve((1/log(x**2 + 2)**2 - x**-4)) == [ + -I*sqrt(2 - LambertW(exp(2))), + -I*sqrt(LambertW(-exp(-2)) + 2), + sqrt(-2 - LambertW(-exp(-2))), + sqrt(-2 + LambertW(exp(2))), + -sqrt(-2 - LambertW(-exp(-2), -1)), + sqrt(-2 - LambertW(-exp(-2), -1))] + + +def test_rewrite_trig(): + assert solve(sin(x) + tan(x)) == [0, -pi, pi, 2*pi] + assert solve(sin(x) + sec(x)) == [ + -2*atan(Rational(-1, 2) + sqrt(2)*sqrt(1 - sqrt(3)*I)/2 + sqrt(3)*I/2), + 2*atan(S.Half - sqrt(2)*sqrt(1 + sqrt(3)*I)/2 + sqrt(3)*I/2), 2*atan(S.Half + + sqrt(2)*sqrt(1 + sqrt(3)*I)/2 + sqrt(3)*I/2), 2*atan(S.Half - + sqrt(3)*I/2 + sqrt(2)*sqrt(1 - sqrt(3)*I)/2)] + assert solve(sinh(x) + tanh(x)) == [0, I*pi] + + # issue 6157 + assert solve(2*sin(x) - cos(x), x) == [atan(S.Half)] + + +@XFAIL +def test_rewrite_trigh(): + # if this import passes then the test below should also pass + from sympy.functions.elementary.hyperbolic import sech + assert solve(sinh(x) + sech(x)) == [ + 2*atanh(Rational(-1, 2) + sqrt(5)/2 - sqrt(-2*sqrt(5) + 2)/2), + 2*atanh(Rational(-1, 2) + sqrt(5)/2 + sqrt(-2*sqrt(5) + 2)/2), + 2*atanh(-sqrt(5)/2 - S.Half + sqrt(2 + 2*sqrt(5))/2), + 2*atanh(-sqrt(2 + 2*sqrt(5))/2 - sqrt(5)/2 - S.Half)] + + +def test_uselogcombine(): + eq = z - log(x) + log(y/(x*(-1 + y**2/x**2))) + assert solve(eq, x, force=True) == [-sqrt(y*(y - exp(z))), sqrt(y*(y - exp(z)))] + assert solve(log(x + 3) + log(1 + 3/x) - 3) in [ + [-3 + sqrt(-12 + exp(3))*exp(Rational(3, 2))/2 + exp(3)/2, + -sqrt(-12 + exp(3))*exp(Rational(3, 2))/2 - 3 + exp(3)/2], + [-3 + sqrt(-36 + (-exp(3) + 6)**2)/2 + exp(3)/2, + -3 - sqrt(-36 + (-exp(3) + 6)**2)/2 + exp(3)/2], + ] + assert solve(log(exp(2*x) + 1) + log(-tanh(x) + 1) - log(2)) == [] + + +def test_atan2(): + assert solve(atan2(x, 2) - pi/3, x) == [2*sqrt(3)] + + +def test_errorinverses(): + assert solve(erf(x) - y, x) == [erfinv(y)] + assert solve(erfinv(x) - y, x) == [erf(y)] + assert solve(erfc(x) - y, x) == [erfcinv(y)] + assert solve(erfcinv(x) - y, x) == [erfc(y)] + + +def test_issue_2725(): + R = Symbol('R') + eq = sqrt(2)*R*sqrt(1/(R + 1)) + (R + 1)*(sqrt(2)*sqrt(1/(R + 1)) - 1) + sol = solve(eq, R, set=True)[1] + assert sol == {(Rational(5, 3) + (Rational(-1, 2) - sqrt(3)*I/2)*(Rational(251, 27) + + sqrt(111)*I/9)**Rational(1, 3) + 40/(9*((Rational(-1, 2) - sqrt(3)*I/2)*(Rational(251, 27) + + sqrt(111)*I/9)**Rational(1, 3))),), (Rational(5, 3) + 40/(9*(Rational(251, 27) + + sqrt(111)*I/9)**Rational(1, 3)) + (Rational(251, 27) + sqrt(111)*I/9)**Rational(1, 3),)} + + +def test_issue_5114_6611(): + # See that it doesn't hang; this solves in about 2 seconds. + # Also check that the solution is relatively small. + # Note: the system in issue 6611 solves in about 5 seconds and has + # an op-count of 138336 (with simplify=False). + b, c, d, e, f, g, h, i, j, k, l, m, n, o, p, q, r = symbols('b:r') + eqs = Matrix([ + [b - c/d + r/d], [c*(1/g + 1/e + 1/d) - f/g - r/d], + [-c/g + f*(1/j + 1/i + 1/g) - h/i], [-f/i + h*(1/m + 1/l + 1/i) - k/m], + [-h/m + k*(1/p + 1/o + 1/m) - n/p], [-k/p + n*(1/q + 1/p)]]) + v = Matrix([f, h, k, n, b, c]) + ans = solve(list(eqs), list(v), simplify=False) + # If time is taken to simplify then then 2617 below becomes + # 1168 and the time is about 50 seconds instead of 2. + assert sum(s.count_ops() for s in ans.values()) <= 3270 + + +def test_det_quick(): + m = Matrix(3, 3, symbols('a:9')) + assert m.det() == det_quick(m) # calls det_perm + m[0, 0] = 1 + assert m.det() == det_quick(m) # calls det_minor + m = Matrix(3, 3, list(range(9))) + assert m.det() == det_quick(m) # defaults to .det() + # make sure they work with Sparse + s = SparseMatrix(2, 2, (1, 2, 1, 4)) + assert det_perm(s) == det_minor(s) == s.det() + + +def test_real_imag_splitting(): + a, b = symbols('a b', real=True) + assert solve(sqrt(a**2 + b**2) - 3, a) == \ + [-sqrt(-b**2 + 9), sqrt(-b**2 + 9)] + a, b = symbols('a b', imaginary=True) + assert solve(sqrt(a**2 + b**2) - 3, a) == [] + + +def test_issue_7110(): + y = -2*x**3 + 4*x**2 - 2*x + 5 + assert any(ask(Q.real(i)) for i in solve(y)) + + +def test_units(): + assert solve(1/x - 1/(2*cm)) == [2*cm] + + +def test_issue_7547(): + A, B, V = symbols('A,B,V') + eq1 = Eq(630.26*(V - 39.0)*V*(V + 39) - A + B, 0) + eq2 = Eq(B, 1.36*10**8*(V - 39)) + eq3 = Eq(A, 5.75*10**5*V*(V + 39.0)) + sol = Matrix(nsolve(Tuple(eq1, eq2, eq3), [A, B, V], (0, 0, 0))) + assert str(sol) == str(Matrix( + [['4442890172.68209'], + ['4289299466.1432'], + ['70.5389666628177']])) + + +def test_issue_7895(): + r = symbols('r', real=True) + assert solve(sqrt(r) - 2) == [4] + + +def test_issue_2777(): + # the equations represent two circles + x, y = symbols('x y', real=True) + e1, e2 = sqrt(x**2 + y**2) - 10, sqrt(y**2 + (-x + 10)**2) - 3 + a, b = Rational(191, 20), 3*sqrt(391)/20 + ans = [(a, -b), (a, b)] + assert solve((e1, e2), (x, y)) == ans + assert solve((e1, e2/(x - a)), (x, y)) == [] + # make the 2nd circle's radius be -3 + e2 += 6 + assert solve((e1, e2), (x, y)) == [] + assert solve((e1, e2), (x, y), check=False) == ans + + +def test_issue_7322(): + number = 5.62527e-35 + assert solve(x - number, x)[0] == number + + +def test_nsolve(): + raises(ValueError, lambda: nsolve(x, (-1, 1), method='bisect')) + raises(TypeError, lambda: nsolve((x - y + 3,x + y,z - y),(x,y,z),(-50,50))) + raises(TypeError, lambda: nsolve((x + y, x - y), (0, 1))) + raises(TypeError, lambda: nsolve(x < 0.5, x, 1)) + + +@slow +def test_high_order_multivariate(): + assert len(solve(a*x**3 - x + 1, x)) == 3 + assert len(solve(a*x**4 - x + 1, x)) == 4 + assert solve(a*x**5 - x + 1, x) == [] # incomplete solution allowed + raises(NotImplementedError, lambda: + solve(a*x**5 - x + 1, x, incomplete=False)) + + # result checking must always consider the denominator and CRootOf + # must be checked, too + d = x**5 - x + 1 + assert solve(d*(1 + 1/d)) == [CRootOf(d + 1, i) for i in range(5)] + d = x - 1 + assert solve(d*(2 + 1/d)) == [S.Half] + + +def test_base_0_exp_0(): + assert solve(0**x - 1) == [0] + assert solve(0**(x - 2) - 1) == [2] + assert solve(S('x*(1/x**0 - x)', evaluate=False)) == \ + [0, 1] + + +def test__simple_dens(): + assert _simple_dens(1/x**0, [x]) == set() + assert _simple_dens(1/x**y, [x]) == {x**y} + assert _simple_dens(1/root(x, 3), [x]) == {x} + + +def test_issue_8755(): + # This tests two things: that if full unrad is attempted and fails + # the solution should still be found; also it tests the use of + # keyword `composite`. + assert len(solve(sqrt(y)*x + x**3 - 1, x)) == 3 + assert len(solve(-512*y**3 + 1344*(x + 2)**Rational(1, 3)*y**2 - + 1176*(x + 2)**Rational(2, 3)*y - 169*x + 686, y, _unrad=False)) == 3 + + +@slow +def test_issue_8828(): + x1 = 0 + y1 = -620 + r1 = 920 + x2 = 126 + y2 = 276 + x3 = 51 + y3 = 205 + r3 = 104 + v = x, y, z + + f1 = (x - x1)**2 + (y - y1)**2 - (r1 - z)**2 + f2 = (x - x2)**2 + (y - y2)**2 - z**2 + f3 = (x - x3)**2 + (y - y3)**2 - (r3 - z)**2 + F = f1,f2,f3 + + g1 = sqrt((x - x1)**2 + (y - y1)**2) + z - r1 + g2 = f2 + g3 = sqrt((x - x3)**2 + (y - y3)**2) + z - r3 + G = g1,g2,g3 + + A = solve(F, v) + B = solve(G, v) + C = solve(G, v, manual=True) + + p, q, r = [{tuple(i.evalf(2) for i in j) for j in R} for R in [A, B, C]] + assert p == q == r + + +def test_issue_2840_8155(): + # with parameter-free solutions (i.e. no `n`), we want to avoid + # excessive periodic solutions + assert solve(sin(3*x) + sin(6*x)) == [0, -2*pi/9, 2*pi/9] + assert solve(sin(300*x) + sin(600*x)) == [0, -pi/450, pi/450] + assert solve(2*sin(x) - 2*sin(2*x)) == [0, -pi/3, pi/3] + + +def test_issue_9567(): + assert solve(1 + 1/(x - 1)) == [0] + + +def test_issue_11538(): + assert solve(x + E) == [-E] + assert solve(x**2 + E) == [-I*sqrt(E), I*sqrt(E)] + assert solve(x**3 + 2*E) == [ + -cbrt(2 * E), + cbrt(2)*cbrt(E)/2 - cbrt(2)*sqrt(3)*I*cbrt(E)/2, + cbrt(2)*cbrt(E)/2 + cbrt(2)*sqrt(3)*I*cbrt(E)/2] + assert solve([x + 4, y + E], x, y) == {x: -4, y: -E} + assert solve([x**2 + 4, y + E], x, y) == [ + (-2*I, -E), (2*I, -E)] + + e1 = x - y**3 + 4 + e2 = x + y + 4 + 4 * E + assert len(solve([e1, e2], x, y)) == 3 + + +@slow +def test_issue_12114(): + a, b, c, d, e, f, g = symbols('a,b,c,d,e,f,g') + terms = [1 + a*b + d*e, 1 + a*c + d*f, 1 + b*c + e*f, + g - a**2 - d**2, g - b**2 - e**2, g - c**2 - f**2] + sol = solve(terms, [a, b, c, d, e, f, g], dict=True) + s = sqrt(-f**2 - 1) + s2 = sqrt(2 - f**2) + s3 = sqrt(6 - 3*f**2) + s4 = sqrt(3)*f + s5 = sqrt(3)*s2 + assert sol == [ + {a: -s, b: -s, c: -s, d: f, e: f, g: -1}, + {a: s, b: s, c: s, d: f, e: f, g: -1}, + {a: -s4/2 - s2/2, b: s4/2 - s2/2, c: s2, + d: -f/2 + s3/2, e: -f/2 - s5/2, g: 2}, + {a: -s4/2 + s2/2, b: s4/2 + s2/2, c: -s2, + d: -f/2 - s3/2, e: -f/2 + s5/2, g: 2}, + {a: s4/2 - s2/2, b: -s4/2 - s2/2, c: s2, + d: -f/2 - s3/2, e: -f/2 + s5/2, g: 2}, + {a: s4/2 + s2/2, b: -s4/2 + s2/2, c: -s2, + d: -f/2 + s3/2, e: -f/2 - s5/2, g: 2}] + + +def test_inf(): + assert solve(1 - oo*x) == [] + assert solve(oo*x, x) == [] + assert solve(oo*x - oo, x) == [] + + +def test_issue_12448(): + f = Function('f') + fun = [f(i) for i in range(15)] + sym = symbols('x:15') + reps = dict(zip(fun, sym)) + + (x, y, z), c = sym[:3], sym[3:] + ssym = solve([c[4*i]*x + c[4*i + 1]*y + c[4*i + 2]*z + c[4*i + 3] + for i in range(3)], (x, y, z)) + + (x, y, z), c = fun[:3], fun[3:] + sfun = solve([c[4*i]*x + c[4*i + 1]*y + c[4*i + 2]*z + c[4*i + 3] + for i in range(3)], (x, y, z)) + + assert sfun[fun[0]].xreplace(reps).count_ops() == \ + ssym[sym[0]].count_ops() + + +def test_denoms(): + assert denoms(x/2 + 1/y) == {2, y} + assert denoms(x/2 + 1/y, y) == {y} + assert denoms(x/2 + 1/y, [y]) == {y} + assert denoms(1/x + 1/y + 1/z, [x, y]) == {x, y} + assert denoms(1/x + 1/y + 1/z, x, y) == {x, y} + assert denoms(1/x + 1/y + 1/z, {x, y}) == {x, y} + + +def test_issue_12476(): + x0, x1, x2, x3, x4, x5 = symbols('x0 x1 x2 x3 x4 x5') + eqns = [x0**2 - x0, x0*x1 - x1, x0*x2 - x2, x0*x3 - x3, x0*x4 - x4, x0*x5 - x5, + x0*x1 - x1, -x0/3 + x1**2 - 2*x2/3, x1*x2 - x1/3 - x2/3 - x3/3, + x1*x3 - x2/3 - x3/3 - x4/3, x1*x4 - 2*x3/3 - x5/3, x1*x5 - x4, x0*x2 - x2, + x1*x2 - x1/3 - x2/3 - x3/3, -x0/6 - x1/6 + x2**2 - x2/6 - x3/3 - x4/6, + -x1/6 + x2*x3 - x2/3 - x3/6 - x4/6 - x5/6, x2*x4 - x2/3 - x3/3 - x4/3, + x2*x5 - x3, x0*x3 - x3, x1*x3 - x2/3 - x3/3 - x4/3, + -x1/6 + x2*x3 - x2/3 - x3/6 - x4/6 - x5/6, + -x0/6 - x1/6 - x2/6 + x3**2 - x3/3 - x4/6, -x1/3 - x2/3 + x3*x4 - x3/3, + -x2 + x3*x5, x0*x4 - x4, x1*x4 - 2*x3/3 - x5/3, x2*x4 - x2/3 - x3/3 - x4/3, + -x1/3 - x2/3 + x3*x4 - x3/3, -x0/3 - 2*x2/3 + x4**2, -x1 + x4*x5, x0*x5 - x5, + x1*x5 - x4, x2*x5 - x3, -x2 + x3*x5, -x1 + x4*x5, -x0 + x5**2, x0 - 1] + sols = [{x0: 1, x3: Rational(1, 6), x2: Rational(1, 6), x4: Rational(-2, 3), x1: Rational(-2, 3), x5: 1}, + {x0: 1, x3: S.Half, x2: Rational(-1, 2), x4: 0, x1: 0, x5: -1}, + {x0: 1, x3: Rational(-1, 3), x2: Rational(-1, 3), x4: Rational(1, 3), x1: Rational(1, 3), x5: 1}, + {x0: 1, x3: 1, x2: 1, x4: 1, x1: 1, x5: 1}, + {x0: 1, x3: Rational(-1, 3), x2: Rational(1, 3), x4: sqrt(5)/3, x1: -sqrt(5)/3, x5: -1}, + {x0: 1, x3: Rational(-1, 3), x2: Rational(1, 3), x4: -sqrt(5)/3, x1: sqrt(5)/3, x5: -1}] + + assert solve(eqns) == sols + + +def test_issue_13849(): + t = symbols('t') + assert solve((t*(sqrt(5) + sqrt(2)) - sqrt(2), t), t) == [] + + +def test_issue_14860(): + from sympy.physics.units import newton, kilo + assert solve(8*kilo*newton + x + y, x) == [-8000*newton - y] + + +def test_issue_14721(): + k, h, a, b = symbols(':4') + assert solve([ + -1 + (-k + 1)**2/b**2 + (-h - 1)**2/a**2, + -1 + (-k + 1)**2/b**2 + (-h + 1)**2/a**2, + h, k + 2], h, k, a, b) == [ + (0, -2, -b*sqrt(1/(b**2 - 9)), b), + (0, -2, b*sqrt(1/(b**2 - 9)), b)] + assert solve([ + h, h/a + 1/b**2 - 2, -h/2 + 1/b**2 - 2], a, h, b) == [ + (a, 0, -sqrt(2)/2), (a, 0, sqrt(2)/2)] + assert solve((a + b**2 - 1, a + b**2 - 2)) == [] + + +def test_issue_14779(): + x = symbols('x', real=True) + assert solve(sqrt(x**4 - 130*x**2 + 1089) + sqrt(x**4 - 130*x**2 + + 3969) - 96*Abs(x)/x,x) == [sqrt(130)] + + +def test_issue_15307(): + assert solve((y - 2, Mul(x + 3,x - 2, evaluate=False))) == \ + [{x: -3, y: 2}, {x: 2, y: 2}] + assert solve((y - 2, Mul(3, x - 2, evaluate=False))) == \ + {x: 2, y: 2} + assert solve((y - 2, Add(x + 4, x - 2, evaluate=False))) == \ + {x: -1, y: 2} + eq1 = Eq(12513*x + 2*y - 219093, -5726*x - y) + eq2 = Eq(-2*x + 8, 2*x - 40) + assert solve([eq1, eq2]) == {x:12, y:75} + + +def test_issue_15415(): + assert solve(x - 3, x) == [3] + assert solve([x - 3], x) == {x:3} + assert solve(Eq(y + 3*x**2/2, y + 3*x), y) == [] + assert solve([Eq(y + 3*x**2/2, y + 3*x)], y) == [] + assert solve([Eq(y + 3*x**2/2, y + 3*x), Eq(x, 1)], y) == [] + + +@slow +def test_issue_15731(): + # f(x)**g(x)=c + assert solve(Eq((x**2 - 7*x + 11)**(x**2 - 13*x + 42), 1)) == [2, 3, 4, 5, 6, 7] + assert solve((x)**(x + 4) - 4) == [-2] + assert solve((-x)**(-x + 4) - 4) == [2] + assert solve((x**2 - 6)**(x**2 - 2) - 4) == [-2, 2] + assert solve((x**2 - 2*x - 1)**(x**2 - 3) - 1/(1 - 2*sqrt(2))) == [sqrt(2)] + assert solve(x**(x + S.Half) - 4*sqrt(2)) == [S(2)] + assert solve((x**2 + 1)**x - 25) == [2] + assert solve(x**(2/x) - 2) == [2, 4] + assert solve((x/2)**(2/x) - sqrt(2)) == [4, 8] + assert solve(x**(x + S.Half) - Rational(9, 4)) == [Rational(3, 2)] + # a**g(x)=c + assert solve((-sqrt(sqrt(2)))**x - 2) == [4, log(2)/(log(2**Rational(1, 4)) + I*pi)] + assert solve((sqrt(2))**x - sqrt(sqrt(2))) == [S.Half] + assert solve((-sqrt(2))**x + 2*(sqrt(2))) == [3, + (3*log(2)**2 + 4*pi**2 - 4*I*pi*log(2))/(log(2)**2 + 4*pi**2)] + assert solve((sqrt(2))**x - 2*(sqrt(2))) == [3] + assert solve(I**x + 1) == [2] + assert solve((1 + I)**x - 2*I) == [2] + assert solve((sqrt(2) + sqrt(3))**x - (2*sqrt(6) + 5)**Rational(1, 3)) == [Rational(2, 3)] + # bases of both sides are equal + b = Symbol('b') + assert solve(b**x - b**2, x) == [2] + assert solve(b**x - 1/b, x) == [-1] + assert solve(b**x - b, x) == [1] + b = Symbol('b', positive=True) + assert solve(b**x - b**2, x) == [2] + assert solve(b**x - 1/b, x) == [-1] + + +def test_issue_10933(): + assert solve(x**4 + y*(x + 0.1), x) # doesn't fail + assert solve(I*x**4 + x**3 + x**2 + 1.) # doesn't fail + + +def test_Abs_handling(): + x = symbols('x', real=True) + assert solve(abs(x/y), x) == [0] + + +def test_issue_7982(): + x = Symbol('x') + # Test that no exception happens + assert solve([2*x**2 + 5*x + 20 <= 0, x >= 1.5], x) is S.false + # From #8040 + assert solve([x**3 - 8.08*x**2 - 56.48*x/5 - 106 >= 0, x - 1 <= 0], [x]) is S.false + + +def test_issue_14645(): + x, y = symbols('x y') + assert solve([x*y - x - y, x*y - x - y], [x, y]) == [(y/(y - 1), y)] + + +def test_issue_12024(): + x, y = symbols('x y') + assert solve(Piecewise((0.0, x < 0.1), (x, x >= 0.1)) - y) == \ + [{y: Piecewise((0.0, x < 0.1), (x, True))}] + + +def test_issue_17452(): + assert solve((7**x)**x + pi, x) == [-sqrt(log(pi) + I*pi)/sqrt(log(7)), + sqrt(log(pi) + I*pi)/sqrt(log(7))] + assert solve(x**(x/11) + pi/11, x) == [exp(LambertW(-11*log(11) + 11*log(pi) + 11*I*pi))] + + +def test_issue_17799(): + assert solve(-erf(x**(S(1)/3))**pi + I, x) == [] + + +def test_issue_17650(): + x = Symbol('x', real=True) + assert solve(abs(abs(x**2 - 1) - x) - x) == [1, -1 + sqrt(2), 1 + sqrt(2)] + + +def test_issue_17882(): + eq = -8*x**2/(9*(x**2 - 1)**(S(4)/3)) + 4/(3*(x**2 - 1)**(S(1)/3)) + assert unrad(eq) is None + + +def test_issue_17949(): + assert solve(exp(+x+x**2), x) == [] + assert solve(exp(-x+x**2), x) == [] + assert solve(exp(+x-x**2), x) == [] + assert solve(exp(-x-x**2), x) == [] + + +def test_issue_10993(): + assert solve(Eq(binomial(x, 2), 3)) == [-2, 3] + assert solve(Eq(pow(x, 2) + binomial(x, 3), x)) == [-4, 0, 1] + assert solve(Eq(binomial(x, 2), 0)) == [0, 1] + assert solve(a+binomial(x, 3), a) == [-binomial(x, 3)] + assert solve(x-binomial(a, 3) + binomial(y, 2) + sin(a), x) == [-sin(a) + binomial(a, 3) - binomial(y, 2)] + assert solve((x+1)-binomial(x+1, 3), x) == [-2, -1, 3] + + +def test_issue_11553(): + eq1 = x + y + 1 + eq2 = x + GoldenRatio + assert solve([eq1, eq2], x, y) == {x: -GoldenRatio, y: -1 + GoldenRatio} + eq3 = x + 2 + TribonacciConstant + assert solve([eq1, eq3], x, y) == {x: -2 - TribonacciConstant, y: 1 + TribonacciConstant} + + +def test_issue_19113_19102(): + t = S(1)/3 + solve(cos(x)**5-sin(x)**5) + assert solve(4*cos(x)**3 - 2*sin(x)**3) == [ + atan(2**(t)), -atan(2**(t)*(1 - sqrt(3)*I)/2), + -atan(2**(t)*(1 + sqrt(3)*I)/2)] + h = S.Half + assert solve(cos(x)**2 + sin(x)) == [ + 2*atan(-h + sqrt(5)/2 + sqrt(2)*sqrt(1 - sqrt(5))/2), + -2*atan(h + sqrt(5)/2 + sqrt(2)*sqrt(1 + sqrt(5))/2), + -2*atan(-sqrt(5)/2 + h + sqrt(2)*sqrt(1 - sqrt(5))/2), + -2*atan(-sqrt(2)*sqrt(1 + sqrt(5))/2 + h + sqrt(5)/2)] + assert solve(3*cos(x) - sin(x)) == [atan(3)] + + +def test_issue_19509(): + a = S(3)/4 + b = S(5)/8 + c = sqrt(5)/8 + d = sqrt(5)/4 + assert solve(1/(x -1)**5 - 1) == [2, + -d + a - sqrt(-b + c), + -d + a + sqrt(-b + c), + d + a - sqrt(-b - c), + d + a + sqrt(-b - c)] + +def test_issue_20747(): + THT, HT, DBH, dib, c0, c1, c2, c3, c4 = symbols('THT HT DBH dib c0 c1 c2 c3 c4') + f = DBH*c3 + THT*c4 + c2 + rhs = 1 - ((HT - 1)/(THT - 1))**c1*(1 - exp(c0/f)) + eq = dib - DBH*(c0 - f*log(rhs)) + term = ((1 - exp((DBH*c0 - dib)/(DBH*(DBH*c3 + THT*c4 + c2)))) + / (1 - exp(c0/(DBH*c3 + THT*c4 + c2)))) + sol = [THT*term**(1/c1) - term**(1/c1) + 1] + assert solve(eq, HT) == sol + + +def test_issue_20902(): + f = (t / ((1 + t) ** 2)) + assert solve(f.subs({t: 3 * x + 2}).diff(x) > 0, x) == (S(-1) < x) & (x < S(-1)/3) + assert solve(f.subs({t: 3 * x + 3}).diff(x) > 0, x) == (S(-4)/3 < x) & (x < S(-2)/3) + assert solve(f.subs({t: 3 * x + 4}).diff(x) > 0, x) == (S(-5)/3 < x) & (x < S(-1)) + assert solve(f.subs({t: 3 * x + 2}).diff(x) > 0, x) == (S(-1) < x) & (x < S(-1)/3) + + +def test_issue_21034(): + a = symbols('a', real=True) + system = [x - cosh(cos(4)), y - sinh(cos(a)), z - tanh(x)] + # constants inside hyperbolic functions should not be rewritten in terms of exp + assert solve(system, x, y, z) == [(cosh(cos(4)), sinh(cos(a)), tanh(cosh(cos(4))))] + # but if the variable of interest is present in a hyperbolic function, + # then it should be rewritten in terms of exp and solved further + newsystem = [(exp(x) - exp(-x)) - tanh(x)*(exp(x) + exp(-x)) + x - 5] + assert solve(newsystem, x) == {x: 5} + + +def test_issue_4886(): + z = a*sqrt(R**2*a**2 + R**2*b**2 - c**2)/(a**2 + b**2) + t = b*c/(a**2 + b**2) + sol = [((b*(t - z) - c)/(-a), t - z), ((b*(t + z) - c)/(-a), t + z)] + assert solve([x**2 + y**2 - R**2, a*x + b*y - c], x, y) == sol + + +def test_issue_6819(): + a, b, c, d = symbols('a b c d', positive=True) + assert solve(a*b**x - c*d**x, x) == [log(c/a)/log(b/d)] + + +def test_issue_17454(): + x = Symbol('x') + assert solve((1 - x - I)**4, x) == [1 - I] + + +def test_issue_21852(): + solution = [21 - 21*sqrt(2)/2] + assert solve(2*x + sqrt(2*x**2) - 21) == solution + + +def test_issue_21942(): + eq = -d + (a*c**(1 - e) + b**(1 - e)*(1 - a))**(1/(1 - e)) + sol = solve(eq, c, simplify=False, check=False) + assert sol == [((a*b**(1 - e) - b**(1 - e) + + d**(1 - e))/a)**(1/(1 - e))] + + +def test_solver_flags(): + root = solve(x**5 + x**2 - x - 1, cubics=False) + rad = solve(x**5 + x**2 - x - 1, cubics=True) + assert root != rad + + +def test_issue_22768(): + eq = 2*x**3 - 16*(y - 1)**6*z**3 + assert solve(eq.expand(), x, simplify=False + ) == [2*z*(y - 1)**2, z*(-1 + sqrt(3)*I)*(y - 1)**2, + -z*(1 + sqrt(3)*I)*(y - 1)**2] + + +def test_issue_22717(): + assert solve((-y**2 + log(y**2/x) + 2, -2*x*y + 2*x/y)) == [ + {y: -1, x: E}, {y: 1, x: E}] + + +def test_issue_25176(): + eq = (x - 5)**-8 - 3 + sol = solve(eq) + assert not any(eq.subs(x, i) for i in sol) + + +def test_issue_10169(): + eq = S(-8*a - x**5*(a + b + c + e) - x**4*(4*a - 2**Rational(3,4)*c + 4*c + + d + 2**Rational(3,4)*e + 4*e + k) - x**3*(-4*2**Rational(3,4)*c + sqrt(2)*c - + 2**Rational(3,4)*d + 4*d + sqrt(2)*e + 4*2**Rational(3,4)*e + 2**Rational(3,4)*k + 4*k) - + x**2*(4*sqrt(2)*c - 4*2**Rational(3,4)*d + sqrt(2)*d + 4*sqrt(2)*e + + sqrt(2)*k + 4*2**Rational(3,4)*k) - x*(2*a + 2*b + 4*sqrt(2)*d + + 4*sqrt(2)*k) + 5) + assert solve_undetermined_coeffs(eq, [a, b, c, d, e, k], x) == { + a: Rational(5,8), + b: Rational(-5,1032), + c: Rational(-40,129) - 5*2**Rational(3,4)/129 + 5*2**Rational(1,4)/1032, + d: -20*2**Rational(3,4)/129 - 10*sqrt(2)/129 - 5*2**Rational(1,4)/258, + e: Rational(-40,129) - 5*2**Rational(1,4)/1032 + 5*2**Rational(3,4)/129, + k: -10*sqrt(2)/129 + 5*2**Rational(1,4)/258 + 20*2**Rational(3,4)/129 + } + + +def test_solve_undetermined_coeffs_issue_23927(): + A, B, r, phi = symbols('A, B, r, phi') + e = Eq(A*sin(t) + B*cos(t), r*sin(t - phi)) + eq = (e.lhs - e.rhs).expand(trig=True) + soln = solve_undetermined_coeffs(eq, (r, phi), t) + assert soln == [{ + phi: 2*atan((A - sqrt(A**2 + B**2))/B), + r: (-A**2 + A*sqrt(A**2 + B**2) - B**2)/(A - sqrt(A**2 + B**2)) + }, { + phi: 2*atan((A + sqrt(A**2 + B**2))/B), + r: (A**2 + A*sqrt(A**2 + B**2) + B**2)/(A + sqrt(A**2 + B**2))/-1 + }] + +def test_issue_24368(): + # Ideally these would produce a solution, but for now just check that they + # don't fail with a RuntimeError + raises(NotImplementedError, lambda: solve(Mod(x**2, 49), x)) + s2 = Symbol('s2', integer=True, positive=True) + f = floor(s2/2 - S(1)/2) + raises(NotImplementedError, lambda: solve((Mod(f**2/(f + 1) + 2*f/(f + 1) + 1/(f + 1), 1))*f + Mod(f**2/(f + 1) + 2*f/(f + 1) + 1/(f + 1), 1), s2)) + + +def test_solve_Piecewise(): + assert [S(10)/3] == solve(3*Piecewise( + (S.NaN, x <= 0), + (20*x - 3*(x - 6)**2/2 - 176, (x >= 0) & (x >= 2) & (x>= 4) & (x >= 6) & (x < 10)), + (100 - 26*x, (x >= 0) & (x >= 2) & (x >= 4) & (x < 10)), + (16*x - 3*(x - 6)**2/2 - 176, (x >= 2) & (x >= 4) & (x >= 6) & (x < 10)), + (100 - 30*x, (x >= 2) & (x >= 4) & (x < 10)), + (30*x - 3*(x - 6)**2/2 - 196, (x>= 0) & (x >= 4) & (x >= 6) & (x < 10)), + (80 - 16*x, (x >= 0) & (x >= 4) & (x < 10)), + (26*x - 3*(x - 6)**2/2 - 196, (x >= 4) & (x >= 6) & (x < 10)), + (80 - 20*x, (x >= 4) & (x < 10)), + (40*x - 3*(x - 6)**2/2 - 256, (x >= 0) & (x >= 2) & (x >= 6) & (x < 10)), + (20 - 6*x, (x >= 0) & (x >= 2) & (x < 10)), + (36*x - 3*(x - 6)**2/2 - 256, (x >= 2) & (x >= 6) & (x < 10)), + (20 - 10*x, (x >= 2) & (x < 10)), + (50*x - 3*(x - 6)**2/2 - 276, (x >= 0) & (x >= 6) & (x < 10)), + (4*x, (x >= 0) & (x < 10)), + (46*x - 3*(x - 6)**2/2 - 276, (x >= 6) & (x < 10)), + (0, x < 10), # this will simplify away + (S.NaN,True))) diff --git a/lib/python3.10/site-packages/sympy/solvers/tests/test_solveset.py b/lib/python3.10/site-packages/sympy/solvers/tests/test_solveset.py new file mode 100644 index 0000000000000000000000000000000000000000..a1ba7a11e68ed518c4d83c050947b78756ade181 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/solvers/tests/test_solveset.py @@ -0,0 +1,3548 @@ +from math import isclose + +from sympy.calculus.util import stationary_points +from sympy.core.containers import Tuple +from sympy.core.function import (Function, Lambda, nfloat, diff) +from sympy.core.mod import Mod +from sympy.core.numbers import (E, I, Rational, oo, pi, Integer, all_close) +from sympy.core.relational import (Eq, Gt, Ne, Ge) +from sympy.core.singleton import S +from sympy.core.sorting import ordered +from sympy.core.symbol import (Dummy, Symbol, symbols) +from sympy.core.sympify import sympify +from sympy.functions.elementary.complexes import (Abs, arg, im, re, sign, conjugate) +from sympy.functions.elementary.exponential import (LambertW, exp, log) +from sympy.functions.elementary.hyperbolic import (HyperbolicFunction, + sinh, cosh, tanh, coth, sech, csch, asinh, acosh, atanh, acoth, asech, acsch) +from sympy.functions.elementary.miscellaneous import sqrt, Min, Max +from sympy.functions.elementary.piecewise import Piecewise +from sympy.functions.elementary.trigonometric import ( + TrigonometricFunction, acos, acot, acsc, asec, asin, atan, atan2, + cos, cot, csc, sec, sin, tan) +from sympy.functions.special.error_functions import (erf, erfc, + erfcinv, erfinv) +from sympy.logic.boolalg import And +from sympy.matrices.dense import MutableDenseMatrix as Matrix +from sympy.matrices.immutable import ImmutableDenseMatrix +from sympy.polys.polytools import Poly +from sympy.polys.rootoftools import CRootOf +from sympy.sets.contains import Contains +from sympy.sets.conditionset import ConditionSet +from sympy.sets.fancysets import ImageSet, Range +from sympy.sets.sets import (Complement, FiniteSet, + Intersection, Interval, Union, imageset, ProductSet) +from sympy.simplify import simplify +from sympy.tensor.indexed import Indexed +from sympy.utilities.iterables import numbered_symbols + +from sympy.testing.pytest import (XFAIL, raises, skip, slow, SKIP, _both_exp_pow) +from sympy.core.random import verify_numerically as tn +from sympy.physics.units import cm + +from sympy.solvers import solve +from sympy.solvers.solveset import ( + solveset_real, domain_check, solveset_complex, linear_eq_to_matrix, + linsolve, _is_function_class_equation, invert_real, invert_complex, + _invert_trig_hyp_real, solveset, solve_decomposition, substitution, + nonlinsolve, solvify, + _is_finite_with_finite_vars, _transolve, _is_exponential, + _solve_exponential, _is_logarithmic, _is_lambert, + _solve_logarithm, _term_factors, _is_modular, NonlinearError) + +from sympy.abc import (a, b, c, d, e, f, g, h, i, j, k, l, m, n, q, r, + t, w, x, y, z) + + +def dumeq(i, j): + if type(i) in (list, tuple): + return all(dumeq(i, j) for i, j in zip(i, j)) + return i == j or i.dummy_eq(j) + + +def assert_close_ss(sol1, sol2): + """Test solutions with floats from solveset are close""" + sol1 = sympify(sol1) + sol2 = sympify(sol2) + assert isinstance(sol1, FiniteSet) + assert isinstance(sol2, FiniteSet) + assert len(sol1) == len(sol2) + assert all(isclose(v1, v2) for v1, v2 in zip(sol1, sol2)) + + +def assert_close_nl(sol1, sol2): + """Test solutions with floats from nonlinsolve are close""" + sol1 = sympify(sol1) + sol2 = sympify(sol2) + assert isinstance(sol1, FiniteSet) + assert isinstance(sol2, FiniteSet) + assert len(sol1) == len(sol2) + for s1, s2 in zip(sol1, sol2): + assert len(s1) == len(s2) + assert all(isclose(v1, v2) for v1, v2 in zip(s1, s2)) + + +@_both_exp_pow +def test_invert_real(): + x = Symbol('x', real=True) + + def ireal(x, s=S.Reals): + return Intersection(s, x) + + assert invert_real(exp(x), z, x) == (x, ireal(FiniteSet(log(z)))) + + y = Symbol('y', positive=True) + n = Symbol('n', real=True) + assert invert_real(x + 3, y, x) == (x, FiniteSet(y - 3)) + assert invert_real(x*3, y, x) == (x, FiniteSet(y / 3)) + + assert invert_real(exp(x), y, x) == (x, FiniteSet(log(y))) + assert invert_real(exp(3*x), y, x) == (x, FiniteSet(log(y) / 3)) + assert invert_real(exp(x + 3), y, x) == (x, FiniteSet(log(y) - 3)) + + assert invert_real(exp(x) + 3, y, x) == (x, ireal(FiniteSet(log(y - 3)))) + assert invert_real(exp(x)*3, y, x) == (x, FiniteSet(log(y / 3))) + + assert invert_real(log(x), y, x) == (x, FiniteSet(exp(y))) + assert invert_real(log(3*x), y, x) == (x, FiniteSet(exp(y) / 3)) + assert invert_real(log(x + 3), y, x) == (x, FiniteSet(exp(y) - 3)) + + assert invert_real(Abs(x), y, x) == (x, FiniteSet(y, -y)) + + assert invert_real(2**x, y, x) == (x, FiniteSet(log(y)/log(2))) + assert invert_real(2**exp(x), y, x) == (x, ireal(FiniteSet(log(log(y)/log(2))))) + + assert invert_real(x**2, y, x) == (x, FiniteSet(sqrt(y), -sqrt(y))) + assert invert_real(x**S.Half, y, x) == (x, FiniteSet(y**2)) + + raises(ValueError, lambda: invert_real(x, x, x)) + + # issue 21236 + assert invert_real(x**pi, y, x) == (x, FiniteSet(y**(1/pi))) + assert invert_real(x**pi, -E, x) == (x, S.EmptySet) + assert invert_real(x**Rational(3/2), 1000, x) == (x, FiniteSet(100)) + assert invert_real(x**1.0, 1, x) == (x**1.0, FiniteSet(1)) + + raises(ValueError, lambda: invert_real(S.One, y, x)) + + assert invert_real(x**31 + x, y, x) == (x**31 + x, FiniteSet(y)) + + lhs = x**31 + x + base_values = FiniteSet(y - 1, -y - 1) + assert invert_real(Abs(x**31 + x + 1), y, x) == (lhs, base_values) + + assert dumeq(invert_real(sin(x), y, x), (x, + ConditionSet(x, (S(-1) <= y) & (y <= S(1)), Union( + ImageSet(Lambda(n, 2*n*pi + asin(y)), S.Integers), + ImageSet(Lambda(n, pi*2*n + pi - asin(y)), S.Integers))))) + + assert dumeq(invert_real(sin(exp(x)), y, x), (x, + ConditionSet(x, (S(-1) <= y) & (y <= S(1)), Union( + ImageSet(Lambda(n, log(2*n*pi + asin(y))), S.Integers), + ImageSet(Lambda(n, log(pi*2*n + pi - asin(y))), S.Integers))))) + + assert dumeq(invert_real(csc(x), y, x), (x, + ConditionSet(x, ((S(1) <= y) & (y < oo)) | ((-oo < y) & (y <= S(-1))), + Union(ImageSet(Lambda(n, 2*n*pi + acsc(y)), S.Integers), + ImageSet(Lambda(n, 2*n*pi - acsc(y) + pi), S.Integers))))) + + assert dumeq(invert_real(csc(exp(x)), y, x), (x, + ConditionSet(x, ((S(1) <= y) & (y < oo)) | ((-oo < y) & (y <= S(-1))), + Union(ImageSet(Lambda(n, log(2*n*pi + acsc(y))), S.Integers), + ImageSet(Lambda(n, log(2*n*pi - acsc(y) + pi)), S.Integers))))) + + assert dumeq(invert_real(cos(x), y, x), (x, + ConditionSet(x, (S(-1) <= y) & (y <= S(1)), Union( + ImageSet(Lambda(n, 2*n*pi + acos(y)), S.Integers), + ImageSet(Lambda(n, 2*n*pi - acos(y)), S.Integers))))) + + assert dumeq(invert_real(cos(exp(x)), y, x), (x, + ConditionSet(x, (S(-1) <= y) & (y <= S(1)), Union( + ImageSet(Lambda(n, log(2*n*pi + acos(y))), S.Integers), + ImageSet(Lambda(n, log(2*n*pi - acos(y))), S.Integers))))) + + assert dumeq(invert_real(sec(x), y, x), (x, + ConditionSet(x, ((S(1) <= y) & (y < oo)) | ((-oo < y) & (y <= S(-1))), + Union(ImageSet(Lambda(n, 2*n*pi + asec(y)), S.Integers), \ + ImageSet(Lambda(n, 2*n*pi - asec(y)), S.Integers))))) + + assert dumeq(invert_real(sec(exp(x)), y, x), (x, + ConditionSet(x, ((S(1) <= y) & (y < oo)) | ((-oo < y) & (y <= S(-1))), + Union(ImageSet(Lambda(n, log(2*n*pi - asec(y))), S.Integers), + ImageSet(Lambda(n, log(2*n*pi + asec(y))), S.Integers))))) + + assert dumeq(invert_real(tan(x), y, x), (x, + ConditionSet(x, (-oo < y) & (y < oo), + ImageSet(Lambda(n, n*pi + atan(y)), S.Integers)))) + + assert dumeq(invert_real(tan(exp(x)), y, x), (x, + ConditionSet(x, (-oo < y) & (y < oo), + ImageSet(Lambda(n, log(n*pi + atan(y))), S.Integers)))) + + assert dumeq(invert_real(cot(x), y, x), (x, + ConditionSet(x, (-oo < y) & (y < oo), + ImageSet(Lambda(n, n*pi + acot(y)), S.Integers)))) + + assert dumeq(invert_real(cot(exp(x)), y, x), (x, + ConditionSet(x, (-oo < y) & (y < oo), + ImageSet(Lambda(n, log(n*pi + acot(y))), S.Integers)))) + + assert dumeq(invert_real(tan(tan(x)), y, x), + (x, ConditionSet(x, Eq(tan(tan(x)), y), S.Reals))) + # slight regression compared to previous result: + # (tan(x), imageset(Lambda(n, n*pi + atan(y)), S.Integers))) + + x = Symbol('x', positive=True) + assert invert_real(x**pi, y, x) == (x, FiniteSet(y**(1/pi))) + + r = Symbol('r', real=True) + p = Symbol('p', positive=True) + assert invert_real(sinh(x), r, x) == (x, FiniteSet(asinh(r))) + assert invert_real(sinh(log(x)), p, x) == (x, FiniteSet(exp(asinh(p)))) + + assert invert_real(cosh(x), r, x) == (x, Intersection( + FiniteSet(-acosh(r), acosh(r)), S.Reals)) + assert invert_real(cosh(x), p + 1, x) == (x, + FiniteSet(-acosh(p + 1), acosh(p + 1))) + + assert invert_real(tanh(x), r, x) == (x, Intersection(FiniteSet(atanh(r)), S.Reals)) + assert invert_real(coth(x), p+1, x) == (x, FiniteSet(acoth(p+1))) + assert invert_real(sech(x), r, x) == (x, Intersection( + FiniteSet(-asech(r), asech(r)), S.Reals)) + assert invert_real(csch(x), p, x) == (x, FiniteSet(acsch(p))) + + assert dumeq(invert_real(tanh(sin(x)), r, x), (x, + ConditionSet(x, (S(-1) <= atanh(r)) & (atanh(r) <= S(1)), Union( + ImageSet(Lambda(n, 2*n*pi + asin(atanh(r))), S.Integers), + ImageSet(Lambda(n, 2*n*pi - asin(atanh(r)) + pi), S.Integers))))) + + +def test_invert_trig_hyp_real(): + # check some codepaths that are not as easily reached otherwise + n = Dummy('n') + assert _invert_trig_hyp_real(cosh(x), Range(-5, 10, 1), x)[1].dummy_eq(Union( + ImageSet(Lambda(n, -acosh(n)), Range(1, 10, 1)), + ImageSet(Lambda(n, acosh(n)), Range(1, 10, 1)))) + assert _invert_trig_hyp_real(coth(x), Interval(-3, 2), x) == (x, Union( + Interval(-oo, -acoth(3)), Interval(acoth(2), oo))) + assert _invert_trig_hyp_real(tanh(x), Interval(-S.Half, 1), x) == (x, + Interval(-atanh(S.Half), oo)) + assert _invert_trig_hyp_real(sech(x), imageset(n, S.Half + n/3, S.Naturals0), x) == \ + (x, FiniteSet(-asech(S(1)/2), asech(S(1)/2), -asech(S(5)/6), asech(S(5)/6))) + assert _invert_trig_hyp_real(csch(x), S.Reals, x) == (x, + Union(Interval.open(-oo, 0), Interval.open(0, oo))) + + +def test_invert_complex(): + assert invert_complex(x + 3, y, x) == (x, FiniteSet(y - 3)) + assert invert_complex(x*3, y, x) == (x, FiniteSet(y / 3)) + assert invert_complex((x - 1)**3, 0, x) == (x, FiniteSet(1)) + + assert dumeq(invert_complex(exp(x), y, x), + (x, imageset(Lambda(n, I*(2*pi*n + arg(y)) + log(Abs(y))), S.Integers))) + + assert invert_complex(log(x), y, x) == (x, FiniteSet(exp(y))) + + raises(ValueError, lambda: invert_real(1, y, x)) + raises(ValueError, lambda: invert_complex(x, x, x)) + raises(ValueError, lambda: invert_complex(x, x, 1)) + + assert dumeq(invert_complex(sin(x), I, x), (x, Union( + ImageSet(Lambda(n, 2*n*pi + I*log(1 + sqrt(2))), S.Integers), + ImageSet(Lambda(n, 2*n*pi + pi - I*log(1 + sqrt(2))), S.Integers)))) + assert dumeq(invert_complex(cos(x), 1+I, x), (x, Union( + ImageSet(Lambda(n, 2*n*pi - acos(1 + I)), S.Integers), + ImageSet(Lambda(n, 2*n*pi + acos(1 + I)), S.Integers)))) + assert dumeq(invert_complex(tan(2*x), 1, x), (x, + ImageSet(Lambda(n, n*pi/2 + pi/8), S.Integers))) + assert dumeq(invert_complex(cot(x), 2*I, x), (x, + ImageSet(Lambda(n, n*pi - I*acoth(2)), S.Integers))) + + assert dumeq(invert_complex(sinh(x), 0, x), (x, Union( + ImageSet(Lambda(n, 2*n*I*pi), S.Integers), + ImageSet(Lambda(n, 2*n*I*pi + I*pi), S.Integers)))) + assert dumeq(invert_complex(cosh(x), 0, x), (x, Union( + ImageSet(Lambda(n, 2*n*I*pi + I*pi/2), S.Integers), + ImageSet(Lambda(n, 2*n*I*pi + 3*I*pi/2), S.Integers)))) + assert invert_complex(tanh(x), 1, x) == (x, S.EmptySet) + assert dumeq(invert_complex(tanh(x), a, x), (x, + ConditionSet(x, Ne(a, -1) & Ne(a, 1), + ImageSet(Lambda(n, n*I*pi + atanh(a)), S.Integers)))) + assert invert_complex(coth(x), 1, x) == (x, S.EmptySet) + assert dumeq(invert_complex(coth(x), a, x), (x, + ConditionSet(x, Ne(a, -1) & Ne(a, 1), + ImageSet(Lambda(n, n*I*pi + acoth(a)), S.Integers)))) + assert dumeq(invert_complex(sech(x), 2, x), (x, Union( + ImageSet(Lambda(n, 2*n*I*pi + I*pi/3), S.Integers), + ImageSet(Lambda(n, 2*n*I*pi + 5*I*pi/3), S.Integers)))) + + +def test_domain_check(): + assert domain_check(1/(1 + (1/(x+1))**2), x, -1) is False + assert domain_check(x**2, x, 0) is True + assert domain_check(x, x, oo) is False + assert domain_check(0, x, oo) is False + + +def test_issue_11536(): + assert solveset(0**x - 100, x, S.Reals) == S.EmptySet + assert solveset(0**x - 1, x, S.Reals) == FiniteSet(0) + + +def test_issue_17479(): + f = (x**2 + y**2)**2 + (x**2 + z**2)**2 - 2*(2*x**2 + y**2 + z**2) + fx = f.diff(x) + fy = f.diff(y) + fz = f.diff(z) + sol = nonlinsolve([fx, fy, fz], [x, y, z]) + assert len(sol) >= 4 and len(sol) <= 20 + # nonlinsolve has been giving a varying number of solutions + # (originally 18, then 20, now 19) due to various internal changes. + # Unfortunately not all the solutions are actually valid and some are + # redundant. Since the original issue was that an exception was raised, + # this first test only checks that nonlinsolve returns a "plausible" + # solution set. The next test checks the result for correctness. + + +@XFAIL +def test_issue_18449(): + x, y, z = symbols("x, y, z") + f = (x**2 + y**2)**2 + (x**2 + z**2)**2 - 2*(2*x**2 + y**2 + z**2) + fx = diff(f, x) + fy = diff(f, y) + fz = diff(f, z) + sol = nonlinsolve([fx, fy, fz], [x, y, z]) + for (xs, ys, zs) in sol: + d = {x: xs, y: ys, z: zs} + assert tuple(_.subs(d).simplify() for _ in (fx, fy, fz)) == (0, 0, 0) + # After simplification and removal of duplicate elements, there should + # only be 4 parametric solutions left: + # simplifiedsolutions = FiniteSet((sqrt(1 - z**2), z, z), + # (-sqrt(1 - z**2), z, z), + # (sqrt(1 - z**2), -z, z), + # (-sqrt(1 - z**2), -z, z)) + # TODO: Is the above solution set definitely complete? + + +def test_issue_21047(): + f = (2 - x)**2 + (sqrt(x - 1) - 1)**6 + assert solveset(f, x, S.Reals) == FiniteSet(2) + + f = (sqrt(x)-1)**2 + (sqrt(x)+1)**2 -2*x**2 + sqrt(2) + assert solveset(f, x, S.Reals) == FiniteSet( + S.Half - sqrt(2*sqrt(2) + 5)/2, S.Half + sqrt(2*sqrt(2) + 5)/2) + + +def test_is_function_class_equation(): + assert _is_function_class_equation(TrigonometricFunction, + tan(x), x) is True + assert _is_function_class_equation(TrigonometricFunction, + tan(x) - 1, x) is True + assert _is_function_class_equation(TrigonometricFunction, + tan(x) + sin(x), x) is True + assert _is_function_class_equation(TrigonometricFunction, + tan(x) + sin(x) - a, x) is True + assert _is_function_class_equation(TrigonometricFunction, + sin(x)*tan(x) + sin(x), x) is True + assert _is_function_class_equation(TrigonometricFunction, + sin(x)*tan(x + a) + sin(x), x) is True + assert _is_function_class_equation(TrigonometricFunction, + sin(x)*tan(x*a) + sin(x), x) is True + assert _is_function_class_equation(TrigonometricFunction, + a*tan(x) - 1, x) is True + assert _is_function_class_equation(TrigonometricFunction, + tan(x)**2 + sin(x) - 1, x) is True + assert _is_function_class_equation(TrigonometricFunction, + tan(x) + x, x) is False + assert _is_function_class_equation(TrigonometricFunction, + tan(x**2), x) is False + assert _is_function_class_equation(TrigonometricFunction, + tan(x**2) + sin(x), x) is False + assert _is_function_class_equation(TrigonometricFunction, + tan(x)**sin(x), x) is False + assert _is_function_class_equation(TrigonometricFunction, + tan(sin(x)) + sin(x), x) is False + assert _is_function_class_equation(HyperbolicFunction, + tanh(x), x) is True + assert _is_function_class_equation(HyperbolicFunction, + tanh(x) - 1, x) is True + assert _is_function_class_equation(HyperbolicFunction, + tanh(x) + sinh(x), x) is True + assert _is_function_class_equation(HyperbolicFunction, + tanh(x) + sinh(x) - a, x) is True + assert _is_function_class_equation(HyperbolicFunction, + sinh(x)*tanh(x) + sinh(x), x) is True + assert _is_function_class_equation(HyperbolicFunction, + sinh(x)*tanh(x + a) + sinh(x), x) is True + assert _is_function_class_equation(HyperbolicFunction, + sinh(x)*tanh(x*a) + sinh(x), x) is True + assert _is_function_class_equation(HyperbolicFunction, + a*tanh(x) - 1, x) is True + assert _is_function_class_equation(HyperbolicFunction, + tanh(x)**2 + sinh(x) - 1, x) is True + assert _is_function_class_equation(HyperbolicFunction, + tanh(x) + x, x) is False + assert _is_function_class_equation(HyperbolicFunction, + tanh(x**2), x) is False + assert _is_function_class_equation(HyperbolicFunction, + tanh(x**2) + sinh(x), x) is False + assert _is_function_class_equation(HyperbolicFunction, + tanh(x)**sinh(x), x) is False + assert _is_function_class_equation(HyperbolicFunction, + tanh(sinh(x)) + sinh(x), x) is False + + +def test_garbage_input(): + raises(ValueError, lambda: solveset_real([y], y)) + x = Symbol('x', real=True) + assert solveset_real(x, 1) == S.EmptySet + assert solveset_real(x - 1, 1) == FiniteSet(x) + assert solveset_real(x, pi) == S.EmptySet + assert solveset_real(x, x**2) == S.EmptySet + + raises(ValueError, lambda: solveset_complex([x], x)) + assert solveset_complex(x, pi) == S.EmptySet + + raises(ValueError, lambda: solveset((x, y), x)) + raises(ValueError, lambda: solveset(x + 1, S.Reals)) + raises(ValueError, lambda: solveset(x + 1, x, 2)) + + +def test_solve_mul(): + assert solveset_real((a*x + b)*(exp(x) - 3), x) == \ + Union({log(3)}, Intersection({-b/a}, S.Reals)) + anz = Symbol('anz', nonzero=True) + bb = Symbol('bb', real=True) + assert solveset_real((anz*x + bb)*(exp(x) - 3), x) == \ + FiniteSet(-bb/anz, log(3)) + assert solveset_real((2*x + 8)*(8 + exp(x)), x) == FiniteSet(S(-4)) + assert solveset_real(x/log(x), x) is S.EmptySet + + +def test_solve_invert(): + assert solveset_real(exp(x) - 3, x) == FiniteSet(log(3)) + assert solveset_real(log(x) - 3, x) == FiniteSet(exp(3)) + + assert solveset_real(3**(x + 2), x) == FiniteSet() + assert solveset_real(3**(2 - x), x) == FiniteSet() + + assert solveset_real(y - b*exp(a/x), x) == Intersection( + S.Reals, FiniteSet(a/log(y/b))) + + # issue 4504 + assert solveset_real(2**x - 10, x) == FiniteSet(1 + log(5)/log(2)) + + +def test_issue_25768(): + assert dumeq(solveset_real(sin(x) - S.Half, x), Union( + ImageSet(Lambda(n, pi*2*n + pi/6), S.Integers), + ImageSet(Lambda(n, pi*2*n + pi*5/6), S.Integers))) + n1 = solveset_real(sin(x) - 0.5, x).n(5) + n2 = solveset_real(sin(x) - S.Half, x).n(5) + # help pass despite fp differences + eq = [i.replace( + lambda x:x.is_Float, + lambda x:Rational(x).limit_denominator(1000)) for i in (n1, n2)] + assert dumeq(*eq),(n1,n2) + + +def test_errorinverses(): + assert solveset_real(erf(x) - S.Half, x) == \ + FiniteSet(erfinv(S.Half)) + assert solveset_real(erfinv(x) - 2, x) == \ + FiniteSet(erf(2)) + assert solveset_real(erfc(x) - S.One, x) == \ + FiniteSet(erfcinv(S.One)) + assert solveset_real(erfcinv(x) - 2, x) == FiniteSet(erfc(2)) + + +def test_solve_polynomial(): + x = Symbol('x', real=True) + y = Symbol('y', real=True) + assert solveset_real(3*x - 2, x) == FiniteSet(Rational(2, 3)) + + assert solveset_real(x**2 - 1, x) == FiniteSet(-S.One, S.One) + assert solveset_real(x - y**3, x) == FiniteSet(y ** 3) + + assert solveset_real(x**3 - 15*x - 4, x) == FiniteSet( + -2 + 3 ** S.Half, + S(4), + -2 - 3 ** S.Half) + + assert solveset_real(sqrt(x) - 1, x) == FiniteSet(1) + assert solveset_real(sqrt(x) - 2, x) == FiniteSet(4) + assert solveset_real(x**Rational(1, 4) - 2, x) == FiniteSet(16) + assert solveset_real(x**Rational(1, 3) - 3, x) == FiniteSet(27) + assert len(solveset_real(x**5 + x**3 + 1, x)) == 1 + assert len(solveset_real(-2*x**3 + 4*x**2 - 2*x + 6, x)) > 0 + assert solveset_real(x**6 + x**4 + I, x) is S.EmptySet + + +def test_return_root_of(): + f = x**5 - 15*x**3 - 5*x**2 + 10*x + 20 + s = list(solveset_complex(f, x)) + for root in s: + assert root.func == CRootOf + + # if one uses solve to get the roots of a polynomial that has a CRootOf + # solution, make sure that the use of nfloat during the solve process + # doesn't fail. Note: if you want numerical solutions to a polynomial + # it is *much* faster to use nroots to get them than to solve the + # equation only to get CRootOf solutions which are then numerically + # evaluated. So for eq = x**5 + 3*x + 7 do Poly(eq).nroots() rather + # than [i.n() for i in solve(eq)] to get the numerical roots of eq. + assert nfloat(list(solveset_complex(x**5 + 3*x**3 + 7, x))[0], + exponent=False) == CRootOf(x**5 + 3*x**3 + 7, 0).n() + + sol = list(solveset_complex(x**6 - 2*x + 2, x)) + assert all(isinstance(i, CRootOf) for i in sol) and len(sol) == 6 + + f = x**5 - 15*x**3 - 5*x**2 + 10*x + 20 + s = list(solveset_complex(f, x)) + for root in s: + assert root.func == CRootOf + + s = x**5 + 4*x**3 + 3*x**2 + Rational(7, 4) + assert solveset_complex(s, x) == \ + FiniteSet(*Poly(s*4, domain='ZZ').all_roots()) + + # Refer issue #7876 + eq = x*(x - 1)**2*(x + 1)*(x**6 - x + 1) + assert solveset_complex(eq, x) == \ + FiniteSet(-1, 0, 1, CRootOf(x**6 - x + 1, 0), + CRootOf(x**6 - x + 1, 1), + CRootOf(x**6 - x + 1, 2), + CRootOf(x**6 - x + 1, 3), + CRootOf(x**6 - x + 1, 4), + CRootOf(x**6 - x + 1, 5)) + + +def test_solveset_sqrt_1(): + assert solveset_real(sqrt(5*x + 6) - 2 - x, x) == \ + FiniteSet(-S.One, S(2)) + assert solveset_real(sqrt(x - 1) - x + 7, x) == FiniteSet(10) + assert solveset_real(sqrt(x - 2) - 5, x) == FiniteSet(27) + assert solveset_real(sqrt(x) - 2 - 5, x) == FiniteSet(49) + assert solveset_real(sqrt(x**3), x) == FiniteSet(0) + assert solveset_real(sqrt(x - 1), x) == FiniteSet(1) + assert solveset_real(sqrt((x-3)/x), x) == FiniteSet(3) + assert solveset_real(sqrt((x-3)/x)-Rational(1, 2), x) == \ + FiniteSet(4) + +def test_solveset_sqrt_2(): + x = Symbol('x', real=True) + y = Symbol('y', real=True) + # http://tutorial.math.lamar.edu/Classes/Alg/SolveRadicalEqns.aspx#Solve_Rad_Ex2_a + assert solveset_real(sqrt(2*x - 1) - sqrt(x - 4) - 2, x) == \ + FiniteSet(S(5), S(13)) + assert solveset_real(sqrt(x + 7) + 2 - sqrt(3 - x), x) == \ + FiniteSet(-6) + + # http://www.purplemath.com/modules/solverad.htm + assert solveset_real(sqrt(17*x - sqrt(x**2 - 5)) - 7, x) == \ + FiniteSet(3) + + eq = x + 1 - (x**4 + 4*x**3 - x)**Rational(1, 4) + assert solveset_real(eq, x) == FiniteSet(Rational(-1, 2), Rational(-1, 3)) + + eq = sqrt(2*x + 9) - sqrt(x + 1) - sqrt(x + 4) + assert solveset_real(eq, x) == FiniteSet(0) + + eq = sqrt(x + 4) + sqrt(2*x - 1) - 3*sqrt(x - 1) + assert solveset_real(eq, x) == FiniteSet(5) + + eq = sqrt(x)*sqrt(x - 7) - 12 + assert solveset_real(eq, x) == FiniteSet(16) + + eq = sqrt(x - 3) + sqrt(x) - 3 + assert solveset_real(eq, x) == FiniteSet(4) + + eq = sqrt(2*x**2 - 7) - (3 - x) + assert solveset_real(eq, x) == FiniteSet(-S(8), S(2)) + + # others + eq = sqrt(9*x**2 + 4) - (3*x + 2) + assert solveset_real(eq, x) == FiniteSet(0) + + assert solveset_real(sqrt(x - 3) - sqrt(x) - 3, x) == FiniteSet() + + eq = (2*x - 5)**Rational(1, 3) - 3 + assert solveset_real(eq, x) == FiniteSet(16) + + assert solveset_real(sqrt(x) + sqrt(sqrt(x)) - 4, x) == \ + FiniteSet((Rational(-1, 2) + sqrt(17)/2)**4) + + eq = sqrt(x) - sqrt(x - 1) + sqrt(sqrt(x)) + assert solveset_real(eq, x) == FiniteSet() + + eq = (x - 4)**2 + (sqrt(x) - 2)**4 + assert solveset_real(eq, x) == FiniteSet(-4, 4) + + eq = (sqrt(x) + sqrt(x + 1) + sqrt(1 - x) - 6*sqrt(5)/5) + ans = solveset_real(eq, x) + ra = S('''-1484/375 - 4*(-S(1)/2 + sqrt(3)*I/2)*(-12459439/52734375 + + 114*sqrt(12657)/78125)**(S(1)/3) - 172564/(140625*(-S(1)/2 + + sqrt(3)*I/2)*(-12459439/52734375 + 114*sqrt(12657)/78125)**(S(1)/3))''') + rb = Rational(4, 5) + assert all(abs(eq.subs(x, i).n()) < 1e-10 for i in (ra, rb)) and \ + len(ans) == 2 and \ + {i.n(chop=True) for i in ans} == \ + {i.n(chop=True) for i in (ra, rb)} + + assert solveset_real(sqrt(x) + x**Rational(1, 3) + + x**Rational(1, 4), x) == FiniteSet(0) + + assert solveset_real(x/sqrt(x**2 + 1), x) == FiniteSet(0) + + eq = (x - y**3)/((y**2)*sqrt(1 - y**2)) + assert solveset_real(eq, x) == FiniteSet(y**3) + + # issue 4497 + assert solveset_real(1/(5 + x)**Rational(1, 5) - 9, x) == \ + FiniteSet(Rational(-295244, 59049)) + + +@XFAIL +def test_solve_sqrt_fail(): + # this only works if we check real_root(eq.subs(x, Rational(1, 3))) + # but checksol doesn't work like that + eq = (x**3 - 3*x**2)**Rational(1, 3) + 1 - x + assert solveset_real(eq, x) == FiniteSet(Rational(1, 3)) + + +@slow +def test_solve_sqrt_3(): + R = Symbol('R') + eq = sqrt(2)*R*sqrt(1/(R + 1)) + (R + 1)*(sqrt(2)*sqrt(1/(R + 1)) - 1) + sol = solveset_complex(eq, R) + fset = [Rational(5, 3) + 4*sqrt(10)*cos(atan(3*sqrt(111)/251)/3)/3, + -sqrt(10)*cos(atan(3*sqrt(111)/251)/3)/3 + + 40*re(1/((Rational(-1, 2) - sqrt(3)*I/2)*(Rational(251, 27) + sqrt(111)*I/9)**Rational(1, 3)))/9 + + sqrt(30)*sin(atan(3*sqrt(111)/251)/3)/3 + Rational(5, 3) + + I*(-sqrt(30)*cos(atan(3*sqrt(111)/251)/3)/3 - + sqrt(10)*sin(atan(3*sqrt(111)/251)/3)/3 + + 40*im(1/((Rational(-1, 2) - sqrt(3)*I/2)*(Rational(251, 27) + sqrt(111)*I/9)**Rational(1, 3)))/9)] + cset = [40*re(1/((Rational(-1, 2) + sqrt(3)*I/2)*(Rational(251, 27) + sqrt(111)*I/9)**Rational(1, 3)))/9 - + sqrt(10)*cos(atan(3*sqrt(111)/251)/3)/3 - sqrt(30)*sin(atan(3*sqrt(111)/251)/3)/3 + + Rational(5, 3) + + I*(40*im(1/((Rational(-1, 2) + sqrt(3)*I/2)*(Rational(251, 27) + sqrt(111)*I/9)**Rational(1, 3)))/9 - + sqrt(10)*sin(atan(3*sqrt(111)/251)/3)/3 + + sqrt(30)*cos(atan(3*sqrt(111)/251)/3)/3)] + + fs = FiniteSet(*fset) + cs = ConditionSet(R, Eq(eq, 0), FiniteSet(*cset)) + assert sol == (fs - {-1}) | (cs - {-1}) + + # the number of real roots will depend on the value of m: for m=1 there are 4 + # and for m=-1 there are none. + eq = -sqrt((m - q)**2 + (-m/(2*q) + S.Half)**2) + sqrt((-m**2/2 - sqrt( + 4*m**4 - 4*m**2 + 8*m + 1)/4 - Rational(1, 4))**2 + (m**2/2 - m - sqrt( + 4*m**4 - 4*m**2 + 8*m + 1)/4 - Rational(1, 4))**2) + unsolved_object = ConditionSet(q, Eq(sqrt((m - q)**2 + (-m/(2*q) + S.Half)**2) - + sqrt((-m**2/2 - sqrt(4*m**4 - 4*m**2 + 8*m + 1)/4 - Rational(1, 4))**2 + (m**2/2 - m - + sqrt(4*m**4 - 4*m**2 + 8*m + 1)/4 - Rational(1, 4))**2), 0), S.Reals) + assert solveset_real(eq, q) == unsolved_object + + +def test_solve_polynomial_symbolic_param(): + assert solveset_complex((x**2 - 1)**2 - a, x) == \ + FiniteSet(sqrt(1 + sqrt(a)), -sqrt(1 + sqrt(a)), + sqrt(1 - sqrt(a)), -sqrt(1 - sqrt(a))) + + # issue 4507 + assert solveset_complex(y - b/(1 + a*x), x) == \ + FiniteSet((b/y - 1)/a) - FiniteSet(-1/a) + + # issue 4508 + assert solveset_complex(y - b*x/(a + x), x) == \ + FiniteSet(-a*y/(y - b)) - FiniteSet(-a) + + +def test_solve_rational(): + assert solveset_real(1/x + 1, x) == FiniteSet(-S.One) + assert solveset_real(1/exp(x) - 1, x) == FiniteSet(0) + assert solveset_real(x*(1 - 5/x), x) == FiniteSet(5) + assert solveset_real(2*x/(x + 2) - 1, x) == FiniteSet(2) + assert solveset_real((x**2/(7 - x)).diff(x), x) == \ + FiniteSet(S.Zero, S(14)) + + +def test_solveset_real_gen_is_pow(): + assert solveset_real(sqrt(1) + 1, x) is S.EmptySet + + +def test_no_sol(): + assert solveset(1 - oo*x) is S.EmptySet + assert solveset(oo*x, x) is S.EmptySet + assert solveset(oo*x - oo, x) is S.EmptySet + assert solveset_real(4, x) is S.EmptySet + assert solveset_real(exp(x), x) is S.EmptySet + assert solveset_real(x**2 + 1, x) is S.EmptySet + assert solveset_real(-3*a/sqrt(x), x) is S.EmptySet + assert solveset_real(1/x, x) is S.EmptySet + assert solveset_real(-(1 + x)/(2 + x)**2 + 1/(2 + x), x + ) is S.EmptySet + + +def test_sol_zero_real(): + assert solveset_real(0, x) == S.Reals + assert solveset(0, x, Interval(1, 2)) == Interval(1, 2) + assert solveset_real(-x**2 - 2*x + (x + 1)**2 - 1, x) == S.Reals + + +def test_no_sol_rational_extragenous(): + assert solveset_real((x/(x + 1) + 3)**(-2), x) is S.EmptySet + assert solveset_real((x - 1)/(1 + 1/(x - 1)), x) is S.EmptySet + + +def test_solve_polynomial_cv_1a(): + """ + Test for solving on equations that can be converted to + a polynomial equation using the change of variable y -> x**Rational(p, q) + """ + assert solveset_real(sqrt(x) - 1, x) == FiniteSet(1) + assert solveset_real(sqrt(x) - 2, x) == FiniteSet(4) + assert solveset_real(x**Rational(1, 4) - 2, x) == FiniteSet(16) + assert solveset_real(x**Rational(1, 3) - 3, x) == FiniteSet(27) + assert solveset_real(x*(x**(S.One / 3) - 3), x) == \ + FiniteSet(S.Zero, S(27)) + + +def test_solveset_real_rational(): + """Test solveset_real for rational functions""" + x = Symbol('x', real=True) + y = Symbol('y', real=True) + assert solveset_real((x - y**3) / ((y**2)*sqrt(1 - y**2)), x) \ + == FiniteSet(y**3) + # issue 4486 + assert solveset_real(2*x/(x + 2) - 1, x) == FiniteSet(2) + + +def test_solveset_real_log(): + assert solveset_real(log((x-1)*(x+1)), x) == \ + FiniteSet(sqrt(2), -sqrt(2)) + + +def test_poly_gens(): + assert solveset_real(4**(2*(x**2) + 2*x) - 8, x) == \ + FiniteSet(Rational(-3, 2), S.Half) + + +def test_solve_abs(): + n = Dummy('n') + raises(ValueError, lambda: solveset(Abs(x) - 1, x)) + assert solveset(Abs(x) - n, x, S.Reals).dummy_eq( + ConditionSet(x, Contains(n, Interval(0, oo)), {-n, n})) + assert solveset_real(Abs(x) - 2, x) == FiniteSet(-2, 2) + assert solveset_real(Abs(x) + 2, x) is S.EmptySet + assert solveset_real(Abs(x + 3) - 2*Abs(x - 3), x) == \ + FiniteSet(1, 9) + assert solveset_real(2*Abs(x) - Abs(x - 1), x) == \ + FiniteSet(-1, Rational(1, 3)) + + sol = ConditionSet( + x, + And( + Contains(b, Interval(0, oo)), + Contains(a + b, Interval(0, oo)), + Contains(a - b, Interval(0, oo))), + FiniteSet(-a - b - 3, -a + b - 3, a - b - 3, a + b - 3)) + eq = Abs(Abs(x + 3) - a) - b + assert invert_real(eq, 0, x)[1] == sol + reps = {a: 3, b: 1} + eqab = eq.subs(reps) + for si in sol.subs(reps): + assert not eqab.subs(x, si) + assert dumeq(solveset(Eq(sin(Abs(x)), 1), x, domain=S.Reals), Union( + Intersection(Interval(0, oo), Union( + Intersection(ImageSet(Lambda(n, 2*n*pi + 3*pi/2), S.Integers), + Interval(-oo, 0)), + Intersection(ImageSet(Lambda(n, 2*n*pi + pi/2), S.Integers), + Interval(0, oo)))))) + + +def test_issue_9824(): + assert dumeq(solveset(sin(x)**2 - 2*sin(x) + 1, x), ImageSet(Lambda(n, 2*n*pi + pi/2), S.Integers)) + assert dumeq(solveset(cos(x)**2 - 2*cos(x) + 1, x), ImageSet(Lambda(n, 2*n*pi), S.Integers)) + + +def test_issue_9565(): + assert solveset_real(Abs((x - 1)/(x - 5)) <= Rational(1, 3), x) == Interval(-1, 2) + + +def test_issue_10069(): + eq = abs(1/(x - 1)) - 1 > 0 + assert solveset_real(eq, x) == Union( + Interval.open(0, 1), Interval.open(1, 2)) + + +def test_real_imag_splitting(): + a, b = symbols('a b', real=True) + assert solveset_real(sqrt(a**2 - b**2) - 3, a) == \ + FiniteSet(-sqrt(b**2 + 9), sqrt(b**2 + 9)) + assert solveset_real(sqrt(a**2 + b**2) - 3, a) != \ + S.EmptySet + + +def test_units(): + assert solveset_real(1/x - 1/(2*cm), x) == FiniteSet(2*cm) + + +def test_solve_only_exp_1(): + y = Symbol('y', positive=True) + assert solveset_real(exp(x) - y, x) == FiniteSet(log(y)) + assert solveset_real(exp(x) + exp(-x) - 4, x) == \ + FiniteSet(log(-sqrt(3) + 2), log(sqrt(3) + 2)) + assert solveset_real(exp(x) + exp(-x) - y, x) != S.EmptySet + + +def test_atan2(): + # The .inverse() method on atan2 works only if x.is_real is True and the + # second argument is a real constant + assert solveset_real(atan2(x, 2) - pi/3, x) == FiniteSet(2*sqrt(3)) + + +def test_piecewise_solveset(): + eq = Piecewise((x - 2, Gt(x, 2)), (2 - x, True)) - 3 + assert set(solveset_real(eq, x)) == set(FiniteSet(-1, 5)) + + absxm3 = Piecewise( + (x - 3, 0 <= x - 3), + (3 - x, 0 > x - 3)) + y = Symbol('y', positive=True) + assert solveset_real(absxm3 - y, x) == FiniteSet(-y + 3, y + 3) + + f = Piecewise(((x - 2)**2, x >= 0), (0, True)) + assert solveset(f, x, domain=S.Reals) == Union(FiniteSet(2), Interval(-oo, 0, True, True)) + + assert solveset( + Piecewise((x + 1, x > 0), (I, True)) - I, x, S.Reals + ) == Interval(-oo, 0) + + assert solveset(Piecewise((x - 1, Ne(x, I)), (x, True)), x) == FiniteSet(1) + + # issue 19718 + g = Piecewise((1, x > 10), (0, True)) + assert solveset(g > 0, x, S.Reals) == Interval.open(10, oo) + + from sympy.logic.boolalg import BooleanTrue + f = BooleanTrue() + assert solveset(f, x, domain=Interval(-3, 10)) == Interval(-3, 10) + + # issue 20552 + f = Piecewise((0, Eq(x, 0)), (x**2/Abs(x), True)) + g = Piecewise((0, Eq(x, pi)), ((x - pi)/sin(x), True)) + assert solveset(f, x, domain=S.Reals) == FiniteSet(0) + assert solveset(g) == FiniteSet(pi) + + +def test_solveset_complex_polynomial(): + assert solveset_complex(a*x**2 + b*x + c, x) == \ + FiniteSet(-b/(2*a) - sqrt(-4*a*c + b**2)/(2*a), + -b/(2*a) + sqrt(-4*a*c + b**2)/(2*a)) + + assert solveset_complex(x - y**3, y) == FiniteSet( + (-x**Rational(1, 3))/2 + I*sqrt(3)*x**Rational(1, 3)/2, + x**Rational(1, 3), + (-x**Rational(1, 3))/2 - I*sqrt(3)*x**Rational(1, 3)/2) + + assert solveset_complex(x + 1/x - 1, x) == \ + FiniteSet(S.Half + I*sqrt(3)/2, S.Half - I*sqrt(3)/2) + + +def test_sol_zero_complex(): + assert solveset_complex(0, x) is S.Complexes + + +def test_solveset_complex_rational(): + assert solveset_complex((x - 1)*(x - I)/(x - 3), x) == \ + FiniteSet(1, I) + + assert solveset_complex((x - y**3)/((y**2)*sqrt(1 - y**2)), x) == \ + FiniteSet(y**3) + assert solveset_complex(-x**2 - I, x) == \ + FiniteSet(-sqrt(2)/2 + sqrt(2)*I/2, sqrt(2)/2 - sqrt(2)*I/2) + + +def test_solve_quintics(): + skip("This test is too slow") + f = x**5 - 110*x**3 - 55*x**2 + 2310*x + 979 + s = solveset_complex(f, x) + for root in s: + res = f.subs(x, root.n()).n() + assert tn(res, 0) + + f = x**5 + 15*x + 12 + s = solveset_complex(f, x) + for root in s: + res = f.subs(x, root.n()).n() + assert tn(res, 0) + + +def test_solveset_complex_exp(): + assert dumeq(solveset_complex(exp(x) - 1, x), + imageset(Lambda(n, I*2*n*pi), S.Integers)) + assert dumeq(solveset_complex(exp(x) - I, x), + imageset(Lambda(n, I*(2*n*pi + pi/2)), S.Integers)) + assert solveset_complex(1/exp(x), x) == S.EmptySet + assert dumeq(solveset_complex(sinh(x).rewrite(exp), x), + imageset(Lambda(n, n*pi*I), S.Integers)) + + +def test_solveset_real_exp(): + assert solveset(Eq((-2)**x, 4), x, S.Reals) == FiniteSet(2) + assert solveset(Eq(-2**x, 4), x, S.Reals) == S.EmptySet + assert solveset(Eq((-3)**x, 27), x, S.Reals) == S.EmptySet + assert solveset(Eq((-5)**(x+1), 625), x, S.Reals) == FiniteSet(3) + assert solveset(Eq(2**(x-3), -16), x, S.Reals) == S.EmptySet + assert solveset(Eq((-3)**(x - 3), -3**39), x, S.Reals) == FiniteSet(42) + assert solveset(Eq(2**x, y), x, S.Reals) == Intersection(S.Reals, FiniteSet(log(y)/log(2))) + + assert invert_real((-2)**(2*x) - 16, 0, x) == (x, FiniteSet(2)) + + +def test_solve_complex_log(): + assert solveset_complex(log(x), x) == FiniteSet(1) + assert solveset_complex(1 - log(a + 4*x**2), x) == \ + FiniteSet(-sqrt(-a + E)/2, sqrt(-a + E)/2) + + +def test_solve_complex_sqrt(): + assert solveset_complex(sqrt(5*x + 6) - 2 - x, x) == \ + FiniteSet(-S.One, S(2)) + assert solveset_complex(sqrt(5*x + 6) - (2 + 2*I) - x, x) == \ + FiniteSet(-S(2), 3 - 4*I) + assert solveset_complex(4*x*(1 - a * sqrt(x)), x) == \ + FiniteSet(S.Zero, 1 / a ** 2) + + +def test_solveset_complex_tan(): + s = solveset_complex(tan(x).rewrite(exp), x) + assert dumeq(s, imageset(Lambda(n, pi*n), S.Integers) - \ + imageset(Lambda(n, pi*n + pi/2), S.Integers)) + + +@_both_exp_pow +def test_solve_trig(): + assert dumeq(solveset_real(sin(x), x), + Union(imageset(Lambda(n, 2*pi*n), S.Integers), + imageset(Lambda(n, 2*pi*n + pi), S.Integers))) + + assert dumeq(solveset_real(sin(x) - 1, x), + imageset(Lambda(n, 2*pi*n + pi/2), S.Integers)) + + assert dumeq(solveset_real(cos(x), x), + Union(imageset(Lambda(n, 2*pi*n + pi/2), S.Integers), + imageset(Lambda(n, 2*pi*n + pi*Rational(3, 2)), S.Integers))) + + assert dumeq(solveset_real(sin(x) + cos(x), x), + Union(imageset(Lambda(n, 2*n*pi + pi*Rational(3, 4)), S.Integers), + imageset(Lambda(n, 2*n*pi + pi*Rational(7, 4)), S.Integers))) + + assert solveset_real(sin(x)**2 + cos(x)**2, x) == S.EmptySet + + assert dumeq(solveset_complex(cos(x) - S.Half, x), + Union(imageset(Lambda(n, 2*n*pi + pi*Rational(5, 3)), S.Integers), + imageset(Lambda(n, 2*n*pi + pi/3), S.Integers))) + + assert dumeq(solveset(sin(y + a) - sin(y), a, domain=S.Reals), + ConditionSet(a, (S(-1) <= sin(y)) & (sin(y) <= S(1)), Union( + ImageSet(Lambda(n, 2*n*pi - y + asin(sin(y))), S.Integers), + ImageSet(Lambda(n, 2*n*pi - y - asin(sin(y)) + pi), S.Integers)))) + + assert dumeq(solveset_real(sin(2*x)*cos(x) + cos(2*x)*sin(x)-1, x), + ImageSet(Lambda(n, n*pi*Rational(2, 3) + pi/6), S.Integers)) + + assert dumeq(solveset_real(2*tan(x)*sin(x) + 1, x), Union( + ImageSet(Lambda(n, 2*n*pi + atan(sqrt(2)*sqrt(-1 + sqrt(17))/ + (1 - sqrt(17))) + pi), S.Integers), + ImageSet(Lambda(n, 2*n*pi - atan(sqrt(2)*sqrt(-1 + sqrt(17))/ + (1 - sqrt(17))) + pi), S.Integers))) + + assert dumeq(solveset_real(cos(2*x)*cos(4*x) - 1, x), + ImageSet(Lambda(n, n*pi), S.Integers)) + + assert dumeq(solveset(sin(x/10) + Rational(3, 4)), Union( + ImageSet(Lambda(n, 20*n*pi - 10*asin(S(3)/4) + 20*pi), S.Integers), + ImageSet(Lambda(n, 20*n*pi + 10*asin(S(3)/4) + 10*pi), S.Integers))) + + assert dumeq(solveset(cos(x/15) + cos(x/5)), Union( + ImageSet(Lambda(n, 30*n*pi + 15*pi/2), S.Integers), + ImageSet(Lambda(n, 30*n*pi + 45*pi/2), S.Integers), + ImageSet(Lambda(n, 30*n*pi + 75*pi/4), S.Integers), + ImageSet(Lambda(n, 30*n*pi + 45*pi/4), S.Integers), + ImageSet(Lambda(n, 30*n*pi + 105*pi/4), S.Integers), + ImageSet(Lambda(n, 30*n*pi + 15*pi/4), S.Integers))) + + assert dumeq(solveset(sec(sqrt(2)*x/3) + 5), Union( + ImageSet(Lambda(n, 3*sqrt(2)*(2*n*pi - asec(-5))/2), S.Integers), + ImageSet(Lambda(n, 3*sqrt(2)*(2*n*pi + asec(-5))/2), S.Integers))) + + assert dumeq(simplify(solveset(tan(pi*x) - cot(pi/2*x))), Union( + ImageSet(Lambda(n, 4*n + 1), S.Integers), + ImageSet(Lambda(n, 4*n + 3), S.Integers), + ImageSet(Lambda(n, 4*n + Rational(7, 3)), S.Integers), + ImageSet(Lambda(n, 4*n + Rational(5, 3)), S.Integers), + ImageSet(Lambda(n, 4*n + Rational(11, 3)), S.Integers), + ImageSet(Lambda(n, 4*n + Rational(1, 3)), S.Integers))) + + assert dumeq(solveset(cos(9*x)), Union( + ImageSet(Lambda(n, 2*n*pi/9 + pi/18), S.Integers), + ImageSet(Lambda(n, 2*n*pi/9 + pi/6), S.Integers))) + + assert dumeq(solveset(sin(8*x) + cot(12*x), x, S.Reals), Union( + ImageSet(Lambda(n, n*pi/2 + pi/8), S.Integers), + ImageSet(Lambda(n, n*pi/2 + 3*pi/8), S.Integers), + ImageSet(Lambda(n, n*pi/2 + 5*pi/16), S.Integers), + ImageSet(Lambda(n, n*pi/2 + 3*pi/16), S.Integers), + ImageSet(Lambda(n, n*pi/2 + 7*pi/16), S.Integers), + ImageSet(Lambda(n, n*pi/2 + pi/16), S.Integers))) + + # This is the only remaining solveset test that actually ends up being solved + # by _solve_trig2(). All others are handled by the improved _solve_trig1. + assert dumeq(solveset_real(2*cos(x)*cos(2*x) - 1, x), + Union(ImageSet(Lambda(n, 2*n*pi + 2*atan(sqrt(-2*2**Rational(1, 3)*(67 + + 9*sqrt(57))**Rational(2, 3) + 8*2**Rational(2, 3) + 11*(67 + + 9*sqrt(57))**Rational(1, 3))/(3*(67 + 9*sqrt(57))**Rational(1, 6)))), S.Integers), + ImageSet(Lambda(n, 2*n*pi - 2*atan(sqrt(-2*2**Rational(1, 3)*(67 + + 9*sqrt(57))**Rational(2, 3) + 8*2**Rational(2, 3) + 11*(67 + + 9*sqrt(57))**Rational(1, 3))/(3*(67 + 9*sqrt(57))**Rational(1, 6))) + + 2*pi), S.Integers))) + + # issue #16870 + assert dumeq(simplify(solveset(sin(x/180*pi) - S.Half, x, S.Reals)), Union( + ImageSet(Lambda(n, 360*n + 150), S.Integers), + ImageSet(Lambda(n, 360*n + 30), S.Integers))) + + +def test_solve_trig_hyp_by_inversion(): + n = Dummy('n') + assert solveset_real(sin(2*x + 3) - S(1)/2, x).dummy_eq(Union( + ImageSet(Lambda(n, n*pi - S(3)/2 + 13*pi/12), S.Integers), + ImageSet(Lambda(n, n*pi - S(3)/2 + 17*pi/12), S.Integers))) + assert solveset_complex(sin(2*x + 3) - S(1)/2, x).dummy_eq(Union( + ImageSet(Lambda(n, n*pi - S(3)/2 + 13*pi/12), S.Integers), + ImageSet(Lambda(n, n*pi - S(3)/2 + 17*pi/12), S.Integers))) + assert solveset_real(tan(x) - tan(pi/10), x).dummy_eq( + ImageSet(Lambda(n, n*pi + pi/10), S.Integers)) + assert solveset_complex(tan(x) - tan(pi/10), x).dummy_eq( + ImageSet(Lambda(n, n*pi + pi/10), S.Integers)) + + assert solveset_real(3*cosh(2*x) - 5, x) == FiniteSet( + -acosh(S(5)/3)/2, acosh(S(5)/3)/2) + assert solveset_complex(3*cosh(2*x) - 5, x).dummy_eq(Union( + ImageSet(Lambda(n, n*I*pi - acosh(S(5)/3)/2), S.Integers), + ImageSet(Lambda(n, n*I*pi + acosh(S(5)/3)/2), S.Integers))) + assert solveset_real(sinh(x - 3) - 2, x) == FiniteSet( + asinh(2) + 3) + assert solveset_complex(sinh(x - 3) - 2, x).dummy_eq(Union( + ImageSet(Lambda(n, 2*n*I*pi + asinh(2) + 3), S.Integers), + ImageSet(Lambda(n, 2*n*I*pi - asinh(2) + 3 + I*pi), S.Integers))) + + assert solveset_real(cos(sinh(x))-cos(pi/12), x).dummy_eq(Union( + ImageSet(Lambda(n, asinh(2*n*pi + pi/12)), S.Integers), + ImageSet(Lambda(n, asinh(2*n*pi + 23*pi/12)), S.Integers))) + assert solveset(cos(sinh(x))-cos(pi/12), x, Interval(2,3)) == \ + FiniteSet(asinh(23*pi/12), asinh(25*pi/12)) + assert solveset_real(cosh(x**2-1)-2, x) == FiniteSet( + -sqrt(1 + acosh(2)), sqrt(1 + acosh(2))) + + assert solveset_real(sin(x) - 2, x) == S.EmptySet # issue #17334 + assert solveset_real(cos(x) + 2, x) == S.EmptySet + assert solveset_real(sec(x), x) == S.EmptySet + assert solveset_real(csc(x), x) == S.EmptySet + assert solveset_real(cosh(x) + 1, x) == S.EmptySet + assert solveset_real(coth(x), x) == S.EmptySet + assert solveset_real(sech(x) - 2, x) == S.EmptySet + assert solveset_real(sech(x), x) == S.EmptySet + assert solveset_real(tanh(x) + 1, x) == S.EmptySet + assert solveset_complex(tanh(x), 1) == S.EmptySet + assert solveset_complex(coth(x), -1) == S.EmptySet + assert solveset_complex(sech(x), 0) == S.EmptySet + assert solveset_complex(csch(x), 0) == S.EmptySet + + assert solveset_real(abs(csch(x)) - 3, x) == FiniteSet(-acsch(3), acsch(3)) + + assert solveset_real(tanh(x**2 - 1) - exp(-9), x) == FiniteSet( + -sqrt(atanh(exp(-9)) + 1), sqrt(atanh(exp(-9)) + 1)) + + assert solveset_real(coth(log(x)) + 2, x) == FiniteSet(exp(-acoth(2))) + assert solveset_real(coth(exp(x)) + 2, x) == S.EmptySet + + assert solveset_complex(sinh(x) - I/2, x).dummy_eq(Union( + ImageSet(Lambda(n, 2*I*pi*n + 5*I*pi/6), S.Integers), + ImageSet(Lambda(n, 2*I*pi*n + I*pi/6), S.Integers))) + assert solveset_complex(sinh(x/10) + Rational(3, 4), x).dummy_eq(Union( + ImageSet(Lambda(n, 20*n*I*pi - 10*asinh(S(3)/4)), S.Integers), + ImageSet(Lambda(n, 20*n*I*pi + 10*asinh(S(3)/4) + 10*I*pi), S.Integers))) + assert solveset_complex(sech(sqrt(2)*x/3) + 5, x).dummy_eq(Union( + ImageSet(Lambda(n, 3*sqrt(2)*(2*n*I*pi - asech(-5))/2), S.Integers), + ImageSet(Lambda(n, 3*sqrt(2)*(2*n*I*pi + asech(-5))/2), S.Integers))) + assert solveset_complex(cosh(9*x), x).dummy_eq(Union( + ImageSet(Lambda(n, 2*n*I*pi/9 + I*pi/18), S.Integers), + ImageSet(Lambda(n, 2*n*I*pi/9 + I*pi/6), S.Integers))) + + eq = (x**5 -4*x + 1).subs(x, coth(z)) + assert solveset(eq, z, S.Complexes).dummy_eq(Union( + ImageSet(Lambda(n, n*I*pi + acoth(CRootOf(x**5 -4*x + 1, 0))), S.Integers), + ImageSet(Lambda(n, n*I*pi + acoth(CRootOf(x**5 -4*x + 1, 1))), S.Integers), + ImageSet(Lambda(n, n*I*pi + acoth(CRootOf(x**5 -4*x + 1, 2))), S.Integers), + ImageSet(Lambda(n, n*I*pi + acoth(CRootOf(x**5 -4*x + 1, 3))), S.Integers), + ImageSet(Lambda(n, n*I*pi + acoth(CRootOf(x**5 -4*x + 1, 4))), S.Integers))) + assert solveset(eq, z, S.Reals) == FiniteSet( + acoth(CRootOf(x**5 - 4*x + 1, 0)), acoth(CRootOf(x**5 - 4*x + 1, 2))) + + eq = ((x-sqrt(3)/2)*(x+2)).expand().subs(x, cos(x)) + assert solveset(eq, x, S.Complexes).dummy_eq(Union( + ImageSet(Lambda(n, 2*n*pi - acos(-2)), S.Integers), + ImageSet(Lambda(n, 2*n*pi + acos(-2)), S.Integers), + ImageSet(Lambda(n, 2*n*pi + pi/6), S.Integers), + ImageSet(Lambda(n, 2*n*pi + 11*pi/6), S.Integers))) + assert solveset(eq, x, S.Reals).dummy_eq(Union( + ImageSet(Lambda(n, 2*n*pi + pi/6), S.Integers), + ImageSet(Lambda(n, 2*n*pi + 11*pi/6), S.Integers))) + + assert solveset((1+sec(sqrt(3)*x+4)**2)/(1-sec(sqrt(3)*x+4))).dummy_eq(Union( + ImageSet(Lambda(n, sqrt(3)*(2*n*pi - 4 - asec(I))/3), S.Integers), + ImageSet(Lambda(n, sqrt(3)*(2*n*pi - 4 + asec(I))/3), S.Integers), + ImageSet(Lambda(n, sqrt(3)*(2*n*pi - 4 - asec(-I))/3), S.Integers), + ImageSet(Lambda(n, sqrt(3)*(2*n*pi - 4 + asec(-I))/3), S.Integers))) + + assert all_close(solveset(tan(3.14*x)**(S(3)/2)-5.678, x, Interval(0, 3)), + FiniteSet(0.403301114561067, 0.403301114561067 + 0.318471337579618*pi, + 0.403301114561067 + 0.636942675159236*pi)) + + +def test_old_trig_issues(): + # issues #9606 / #9531: + assert solveset(sinh(x), x, S.Reals) == FiniteSet(0) + assert solveset(sinh(x), x, S.Complexes).dummy_eq(Union( + ImageSet(Lambda(n, 2*n*I*pi), S.Integers), + ImageSet(Lambda(n, 2*n*I*pi + I*pi), S.Integers))) + + # issues #11218 / #18427 + assert solveset(sin(pi*x), x, S.Reals).dummy_eq(Union( + ImageSet(Lambda(n, (2*n*pi + pi)/pi), S.Integers), + ImageSet(Lambda(n, 2*n), S.Integers))) + assert solveset(sin(pi*x), x).dummy_eq(Union( + ImageSet(Lambda(n, (2*n*pi + pi)/pi), S.Integers), + ImageSet(Lambda(n, 2*n), S.Integers))) + + # issue #17543 + assert solveset(I*cot(8*x - 8*E), x).dummy_eq( + ImageSet(Lambda(n, pi*n/8 - 13*pi/16 + E), S.Integers)) + + # issue #20798 + assert all_close(solveset(cos(2*x) - 0.5, x, Interval(0, 2*pi)), FiniteSet( + 0.523598775598299, -0.523598775598299 + pi, + -0.523598775598299 + 2*pi, 0.523598775598299 + pi)) + sol = Union(ImageSet(Lambda(n, n*pi - 0.523598775598299), S.Integers), + ImageSet(Lambda(n, n*pi + 0.523598775598299), S.Integers)) + ret = solveset(cos(2*x) - 0.5, x, S.Reals) + # replace Dummy n by the regular Symbol n to allow all_close comparison. + ret = ret.subs(ret.atoms(Dummy).pop(), n) + assert all_close(ret, sol) + ret = solveset(cos(2*x) - 0.5, x, S.Complexes) + ret = ret.subs(ret.atoms(Dummy).pop(), n) + assert all_close(ret, sol) + + # issue #21296 / #17667 + assert solveset(tan(x)-sqrt(2), x, Interval(0, pi/2)) == FiniteSet(atan(sqrt(2))) + assert solveset(tan(x)-pi, x, Interval(0, pi/2)) == FiniteSet(atan(pi)) + + # issue #17667 + # not yet working properly: + # solveset(cos(x)-y, x, Interval(0, pi)) + assert solveset(cos(x)-y, x, S.Reals).dummy_eq( + ConditionSet(x,(S(-1) <= y) & (y <= S(1)), Union( + ImageSet(Lambda(n, 2*n*pi - acos(y)), S.Integers), + ImageSet(Lambda(n, 2*n*pi + acos(y)), S.Integers)))) + + # issue #17579 + # Valid result, but the intersection could potentially be simplified. + assert solveset(sin(log(x)), x, Interval(0,1, True, False)).dummy_eq( + Union(Intersection(ImageSet(Lambda(n, exp(2*n*pi)), S.Integers), Interval.Lopen(0, 1)), + Intersection(ImageSet(Lambda(n, exp(2*n*pi + pi)), S.Integers), Interval.Lopen(0, 1)))) + + # issue #17334 + assert solveset(sin(x) - sin(1), x, S.Reals).dummy_eq(Union( + ImageSet(Lambda(n, 2*n*pi + 1), S.Integers), + ImageSet(Lambda(n, 2*n*pi - 1 + pi), S.Integers))) + assert solveset(sin(x) - sqrt(5)/3, x, S.Reals).dummy_eq(Union( + ImageSet(Lambda(n, 2*n*pi + asin(sqrt(5)/3)), S.Integers), + ImageSet(Lambda(n, 2*n*pi - asin(sqrt(5)/3) + pi), S.Integers))) + assert solveset(sinh(x)-cosh(2), x, S.Reals) == FiniteSet(asinh(cosh(2))) + + # issue 9825 + assert solveset(Eq(tan(x), y), x, domain=S.Reals).dummy_eq( + ConditionSet(x, (-oo < y) & (y < oo), + ImageSet(Lambda(n, n*pi + atan(y)), S.Integers))) + r = Symbol('r', real=True) + assert solveset(Eq(tan(x), r), x, domain=S.Reals).dummy_eq( + ImageSet(Lambda(n, n*pi + atan(r)), S.Integers)) + + +def test_solve_hyperbolic(): + # actual solver: _solve_trig1 + n = Dummy('n') + assert solveset(sinh(x) + cosh(x), x) == S.EmptySet + assert solveset(sinh(x) + cos(x), x) == ConditionSet(x, + Eq(cos(x) + sinh(x), 0), S.Complexes) + assert solveset_real(sinh(x) + sech(x), x) == FiniteSet( + log(sqrt(sqrt(5) - 2))) + assert solveset_real(cosh(2*x) + 2*sinh(x) - 5, x) == FiniteSet( + log(-2 + sqrt(5)), log(1 + sqrt(2))) + assert solveset_real((coth(x) + sinh(2*x))/cosh(x) - 3, x) == FiniteSet( + log(S.Half + sqrt(5)/2), log(1 + sqrt(2))) + assert solveset_real(cosh(x)*sinh(x) - 2, x) == FiniteSet( + log(4 + sqrt(17))/2) + assert solveset_real(sinh(x) + tanh(x) - 1, x) == FiniteSet( + log(sqrt(2)/2 + sqrt(-S(1)/2 + sqrt(2)))) + + assert dumeq(solveset_complex(sinh(x) + sech(x), x), Union( + ImageSet(Lambda(n, 2*n*I*pi + log(sqrt(-2 + sqrt(5)))), S.Integers), + ImageSet(Lambda(n, I*(2*n*pi + pi/2) + log(sqrt(2 + sqrt(5)))), S.Integers), + ImageSet(Lambda(n, I*(2*n*pi + pi) + log(sqrt(-2 + sqrt(5)))), S.Integers), + ImageSet(Lambda(n, I*(2*n*pi - pi/2) + log(sqrt(2 + sqrt(5)))), S.Integers))) + + assert dumeq(solveset(cosh(x/15) + cosh(x/5)), Union( + ImageSet(Lambda(n, 15*I*(2*n*pi + pi/2)), S.Integers), + ImageSet(Lambda(n, 15*I*(2*n*pi - pi/2)), S.Integers), + ImageSet(Lambda(n, 15*I*(2*n*pi - 3*pi/4)), S.Integers), + ImageSet(Lambda(n, 15*I*(2*n*pi + 3*pi/4)), S.Integers), + ImageSet(Lambda(n, 15*I*(2*n*pi - pi/4)), S.Integers), + ImageSet(Lambda(n, 15*I*(2*n*pi + pi/4)), S.Integers))) + + assert dumeq(solveset(tanh(pi*x) - coth(pi/2*x)), Union( + ImageSet(Lambda(n, 2*I*(2*n*pi + pi/2)/pi), S.Integers), + ImageSet(Lambda(n, 2*I*(2*n*pi - pi/2)/pi), S.Integers))) + + # issues #18490 / #19489 + assert solveset(cosh(x) + cosh(3*x) - cosh(5*x), x, S.Reals + ).dummy_eq(ConditionSet(x, + Eq(cosh(x) + cosh(3*x) - cosh(5*x), 0), S.Reals)) + assert solveset(sinh(8*x) + coth(12*x)).dummy_eq( + ConditionSet(x, Eq(sinh(8*x) + coth(12*x), 0), S.Complexes)) + + +def test_solve_trig_hyp_symbolic(): + # actual solver: invert_trig_hyp + assert dumeq(solveset(sin(a*x), x), ConditionSet(x, Ne(a, 0), Union( + ImageSet(Lambda(n, (2*n*pi + pi)/a), S.Integers), + ImageSet(Lambda(n, 2*n*pi/a), S.Integers)))) + + assert dumeq(solveset(cosh(x/a), x), ConditionSet(x, Ne(a, 0), Union( + ImageSet(Lambda(n, a*(2*n*I*pi + I*pi/2)), S.Integers), + ImageSet(Lambda(n, a*(2*n*I*pi + 3*I*pi/2)), S.Integers)))) + + assert dumeq(solveset(sin(2*sqrt(3)/3*a**2/(b*pi)*x) + + cos(4*sqrt(3)/3*a**2/(b*pi)*x), x), + ConditionSet(x, Ne(b, 0) & Ne(a**2, 0), Union( + ImageSet(Lambda(n, sqrt(3)*pi*b*(2*n*pi + pi/2)/(2*a**2)), S.Integers), + ImageSet(Lambda(n, sqrt(3)*pi*b*(2*n*pi - 5*pi/6)/(2*a**2)), S.Integers), + ImageSet(Lambda(n, sqrt(3)*pi*b*(2*n*pi - pi/6)/(2*a**2)), S.Integers)))) + + assert dumeq(solveset(cosh((a**2 + 1)*x) - 3, x), ConditionSet( + x, Ne(a**2 + 1, 0), Union( + ImageSet(Lambda(n, (2*n*I*pi - acosh(3))/(a**2 + 1)), S.Integers), + ImageSet(Lambda(n, (2*n*I*pi + acosh(3))/(a**2 + 1)), S.Integers)))) + + ar = Symbol('ar', real=True) + assert solveset(cosh((ar**2 + 1)*x) - 2, x, S.Reals) == FiniteSet( + -acosh(2)/(ar**2 + 1), acosh(2)/(ar**2 + 1)) + + # actual solver: _solve_trig1 + assert dumeq(simplify(solveset(cot((1 + I)*x) - cot((3 + 3*I)*x), x)), Union( + ImageSet(Lambda(n, pi*(1 - I)*(4*n + 1)/4), S.Integers), + ImageSet(Lambda(n, pi*(1 - I)*(4*n - 1)/4), S.Integers))) + + +def test_issue_9616(): + assert dumeq(solveset(sinh(x) + tanh(x) - 1, x), Union( + ImageSet(Lambda(n, 2*n*I*pi + log(sqrt(2)/2 + sqrt(-S.Half + sqrt(2)))), S.Integers), + ImageSet(Lambda(n, I*(2*n*pi - atan(sqrt(2)*sqrt(S.Half + sqrt(2))) + pi) + + log(sqrt(1 + sqrt(2)))), S.Integers), + ImageSet(Lambda(n, I*(2*n*pi + pi) + log(-sqrt(2)/2 + sqrt(-S.Half + sqrt(2)))), S.Integers), + ImageSet(Lambda(n, I*(2*n*pi - pi + atan(sqrt(2)*sqrt(S.Half + sqrt(2)))) + + log(sqrt(1 + sqrt(2)))), S.Integers))) + f1 = (sinh(x)).rewrite(exp) + f2 = (tanh(x)).rewrite(exp) + assert dumeq(solveset(f1 + f2 - 1, x), Union( + Complement(ImageSet( + Lambda(n, I*(2*n*pi + pi) + log(-sqrt(2)/2 + sqrt(-S.Half + sqrt(2)))), S.Integers), + ImageSet(Lambda(n, I*(2*n*pi + pi)/2), S.Integers)), + Complement(ImageSet(Lambda(n, I*(2*n*pi - pi + atan(sqrt(2)*sqrt(S.Half + sqrt(2)))) + + log(sqrt(1 + sqrt(2)))), S.Integers), + ImageSet(Lambda(n, I*(2*n*pi + pi)/2), S.Integers)), + Complement(ImageSet(Lambda(n, I*(2*n*pi - atan(sqrt(2)*sqrt(S.Half + sqrt(2))) + pi) + + log(sqrt(1 + sqrt(2)))), S.Integers), + ImageSet(Lambda(n, I*(2*n*pi + pi)/2), S.Integers)), + Complement( + ImageSet(Lambda(n, 2*n*I*pi + log(sqrt(2)/2 + sqrt(-S.Half + sqrt(2)))), S.Integers), + ImageSet(Lambda(n, I*(2*n*pi + pi)/2), S.Integers)))) + + +def test_solve_invalid_sol(): + assert 0 not in solveset_real(sin(x)/x, x) + assert 0 not in solveset_complex((exp(x) - 1)/x, x) + + +@XFAIL +def test_solve_trig_simplified(): + n = Dummy('n') + assert dumeq(solveset_real(sin(x), x), + imageset(Lambda(n, n*pi), S.Integers)) + + assert dumeq(solveset_real(cos(x), x), + imageset(Lambda(n, n*pi + pi/2), S.Integers)) + + assert dumeq(solveset_real(cos(x) + sin(x), x), + imageset(Lambda(n, n*pi - pi/4), S.Integers)) + + +@XFAIL +def test_solve_lambert(): + assert solveset_real(x*exp(x) - 1, x) == FiniteSet(LambertW(1)) + assert solveset_real(exp(x) + x, x) == FiniteSet(-LambertW(1)) + assert solveset_real(x + 2**x, x) == \ + FiniteSet(-LambertW(log(2))/log(2)) + + # issue 4739 + ans = solveset_real(3*x + 5 + 2**(-5*x + 3), x) + assert ans == FiniteSet(Rational(-5, 3) + + LambertW(-10240*2**Rational(1, 3)*log(2)/3)/(5*log(2))) + + eq = 2*(3*x + 4)**5 - 6*7**(3*x + 9) + result = solveset_real(eq, x) + ans = FiniteSet((log(2401) + + 5*LambertW(-log(7**(7*3**Rational(1, 5)/5))))/(3*log(7))/-1) + assert result == ans + assert solveset_real(eq.expand(), x) == result + + assert solveset_real(5*x - 1 + 3*exp(2 - 7*x), x) == \ + FiniteSet(Rational(1, 5) + LambertW(-21*exp(Rational(3, 5))/5)/7) + + assert solveset_real(2*x + 5 + log(3*x - 2), x) == \ + FiniteSet(Rational(2, 3) + LambertW(2*exp(Rational(-19, 3))/3)/2) + + assert solveset_real(3*x + log(4*x), x) == \ + FiniteSet(LambertW(Rational(3, 4))/3) + + assert solveset_real(x**x - 2) == FiniteSet(exp(LambertW(log(2)))) + + a = Symbol('a') + assert solveset_real(-a*x + 2*x*log(x), x) == FiniteSet(exp(a/2)) + a = Symbol('a', real=True) + assert solveset_real(a/x + exp(x/2), x) == \ + FiniteSet(2*LambertW(-a/2)) + assert solveset_real((a/x + exp(x/2)).diff(x), x) == \ + FiniteSet(4*LambertW(sqrt(2)*sqrt(a)/4)) + + # coverage test + assert solveset_real(tanh(x + 3)*tanh(x - 3) - 1, x) is S.EmptySet + + assert solveset_real((x**2 - 2*x + 1).subs(x, log(x) + 3*x), x) == \ + FiniteSet(LambertW(3*S.Exp1)/3) + assert solveset_real((x**2 - 2*x + 1).subs(x, (log(x) + 3*x)**2 - 1), x) == \ + FiniteSet(LambertW(3*exp(-sqrt(2)))/3, LambertW(3*exp(sqrt(2)))/3) + assert solveset_real((x**2 - 2*x - 2).subs(x, log(x) + 3*x), x) == \ + FiniteSet(LambertW(3*exp(1 + sqrt(3)))/3, LambertW(3*exp(-sqrt(3) + 1))/3) + assert solveset_real(x*log(x) + 3*x + 1, x) == \ + FiniteSet(exp(-3 + LambertW(-exp(3)))) + eq = (x*exp(x) - 3).subs(x, x*exp(x)) + assert solveset_real(eq, x) == \ + FiniteSet(LambertW(3*exp(-LambertW(3)))) + + assert solveset_real(3*log(a**(3*x + 5)) + a**(3*x + 5), x) == \ + FiniteSet(-((log(a**5) + LambertW(Rational(1, 3)))/(3*log(a)))) + p = symbols('p', positive=True) + assert solveset_real(3*log(p**(3*x + 5)) + p**(3*x + 5), x) == \ + FiniteSet( + log((-3**Rational(1, 3) - 3**Rational(5, 6)*I)*LambertW(Rational(1, 3))**Rational(1, 3)/(2*p**Rational(5, 3)))/log(p), + log((-3**Rational(1, 3) + 3**Rational(5, 6)*I)*LambertW(Rational(1, 3))**Rational(1, 3)/(2*p**Rational(5, 3)))/log(p), + log((3*LambertW(Rational(1, 3))/p**5)**(1/(3*log(p)))),) # checked numerically + # check collection + b = Symbol('b') + eq = 3*log(a**(3*x + 5)) + b*log(a**(3*x + 5)) + a**(3*x + 5) + assert solveset_real(eq, x) == FiniteSet( + -((log(a**5) + LambertW(1/(b + 3)))/(3*log(a)))) + + # issue 4271 + assert solveset_real((a/x + exp(x/2)).diff(x, 2), x) == FiniteSet( + 6*LambertW((-1)**Rational(1, 3)*a**Rational(1, 3)/3)) + + assert solveset_real(x**3 - 3**x, x) == \ + FiniteSet(-3/log(3)*LambertW(-log(3)/3)) + assert solveset_real(3**cos(x) - cos(x)**3) == FiniteSet( + acos(-3*LambertW(-log(3)/3)/log(3))) + + assert solveset_real(x**2 - 2**x, x) == \ + solveset_real(-x**2 + 2**x, x) + + assert solveset_real(3*log(x) - x*log(3)) == FiniteSet( + -3*LambertW(-log(3)/3)/log(3), + -3*LambertW(-log(3)/3, -1)/log(3)) + + assert solveset_real(LambertW(2*x) - y) == FiniteSet( + y*exp(y)/2) + + +@XFAIL +def test_other_lambert(): + a = Rational(6, 5) + assert solveset_real(x**a - a**x, x) == FiniteSet( + a, -a*LambertW(-log(a)/a)/log(a)) + + +@_both_exp_pow +def test_solveset(): + f = Function('f') + raises(ValueError, lambda: solveset(x + y)) + assert solveset(x, 1) == S.EmptySet + assert solveset(f(1)**2 + y + 1, f(1) + ) == FiniteSet(-sqrt(-y - 1), sqrt(-y - 1)) + assert solveset(f(1)**2 - 1, f(1), S.Reals) == FiniteSet(-1, 1) + assert solveset(f(1)**2 + 1, f(1)) == FiniteSet(-I, I) + assert solveset(x - 1, 1) == FiniteSet(x) + assert solveset(sin(x) - cos(x), sin(x)) == FiniteSet(cos(x)) + + assert solveset(0, domain=S.Reals) == S.Reals + assert solveset(1) == S.EmptySet + assert solveset(True, domain=S.Reals) == S.Reals # issue 10197 + assert solveset(False, domain=S.Reals) == S.EmptySet + + assert solveset(exp(x) - 1, domain=S.Reals) == FiniteSet(0) + assert solveset(exp(x) - 1, x, S.Reals) == FiniteSet(0) + assert solveset(Eq(exp(x), 1), x, S.Reals) == FiniteSet(0) + assert solveset(exp(x) - 1, exp(x), S.Reals) == FiniteSet(1) + A = Indexed('A', x) + assert solveset(A - 1, A, S.Reals) == FiniteSet(1) + + assert solveset(x - 1 >= 0, x, S.Reals) == Interval(1, oo) + assert solveset(exp(x) - 1 >= 0, x, S.Reals) == Interval(0, oo) + + assert dumeq(solveset(exp(x) - 1, x), imageset(Lambda(n, 2*I*pi*n), S.Integers)) + assert dumeq(solveset(Eq(exp(x), 1), x), imageset(Lambda(n, 2*I*pi*n), + S.Integers)) + # issue 13825 + assert solveset(x**2 + f(0) + 1, x) == {-sqrt(-f(0) - 1), sqrt(-f(0) - 1)} + + # issue 19977 + assert solveset(atan(log(x)) > 0, x, domain=Interval.open(0, oo)) == Interval.open(1, oo) + + +@_both_exp_pow +def test_multi_exp(): + k1, k2, k3 = symbols('k1, k2, k3') + assert dumeq(solveset(exp(exp(x)) - 5, x),\ + imageset(Lambda(((k1, n),), I*(2*k1*pi + arg(2*n*I*pi + log(5))) + log(Abs(2*n*I*pi + log(5)))),\ + ProductSet(S.Integers, S.Integers))) + assert dumeq(solveset((d*exp(exp(a*x + b)) + c), x),\ + imageset(Lambda(x, (-b + x)/a), ImageSet(Lambda(((k1, n),), \ + I*(2*k1*pi + arg(I*(2*n*pi + arg(-c/d)) + log(Abs(c/d)))) + log(Abs(I*(2*n*pi + arg(-c/d)) + log(Abs(c/d))))), \ + ProductSet(S.Integers, S.Integers)))) + + assert dumeq(solveset((d*exp(exp(exp(a*x + b))) + c), x),\ + imageset(Lambda(x, (-b + x)/a), ImageSet(Lambda(((k2, k1, n),), \ + I*(2*k2*pi + arg(I*(2*k1*pi + arg(I*(2*n*pi + arg(-c/d)) + log(Abs(c/d)))) + \ + log(Abs(I*(2*n*pi + arg(-c/d)) + log(Abs(c/d)))))) + log(Abs(I*(2*k1*pi + arg(I*(2*n*pi + arg(-c/d)) + \ + log(Abs(c/d)))) + log(Abs(I*(2*n*pi + arg(-c/d)) + log(Abs(c/d))))))), \ + ProductSet(S.Integers, S.Integers, S.Integers)))) + + assert dumeq(solveset((d*exp(exp(exp(exp(a*x + b)))) + c), x),\ + ImageSet(Lambda(x, (-b + x)/a), ImageSet(Lambda(((k3, k2, k1, n),), \ + I*(2*k3*pi + arg(I*(2*k2*pi + arg(I*(2*k1*pi + arg(I*(2*n*pi + arg(-c/d)) + log(Abs(c/d)))) + \ + log(Abs(I*(2*n*pi + arg(-c/d)) + log(Abs(c/d)))))) + log(Abs(I*(2*k1*pi + arg(I*(2*n*pi + arg(-c/d)) + \ + log(Abs(c/d)))) + log(Abs(I*(2*n*pi + arg(-c/d)) + log(Abs(c/d)))))))) + log(Abs(I*(2*k2*pi + \ + arg(I*(2*k1*pi + arg(I*(2*n*pi + arg(-c/d)) + log(Abs(c/d)))) + log(Abs(I*(2*n*pi + arg(-c/d)) + log(Abs(c/d)))))) + \ + log(Abs(I*(2*k1*pi + arg(I*(2*n*pi + arg(-c/d)) + log(Abs(c/d)))) + log(Abs(I*(2*n*pi + arg(-c/d)) + log(Abs(c/d))))))))), \ + ProductSet(S.Integers, S.Integers, S.Integers, S.Integers)))) + + +def test__solveset_multi(): + from sympy.solvers.solveset import _solveset_multi + from sympy.sets import Reals + + # Basic univariate case: + assert _solveset_multi([x**2-1], [x], [S.Reals]) == FiniteSet((1,), (-1,)) + + # Linear systems of two equations + assert _solveset_multi([x+y, x+1], [x, y], [Reals, Reals]) == FiniteSet((-1, 1)) + assert _solveset_multi([x+y, x+1], [y, x], [Reals, Reals]) == FiniteSet((1, -1)) + assert _solveset_multi([x+y, x-y-1], [x, y], [Reals, Reals]) == FiniteSet((S(1)/2, -S(1)/2)) + assert _solveset_multi([x-1, y-2], [x, y], [Reals, Reals]) == FiniteSet((1, 2)) + # assert dumeq(_solveset_multi([x+y], [x, y], [Reals, Reals]), ImageSet(Lambda(x, (x, -x)), Reals)) + assert dumeq(_solveset_multi([x+y], [x, y], [Reals, Reals]), Union( + ImageSet(Lambda(((x,),), (x, -x)), ProductSet(Reals)), + ImageSet(Lambda(((y,),), (-y, y)), ProductSet(Reals)))) + assert _solveset_multi([x+y, x+y+1], [x, y], [Reals, Reals]) == S.EmptySet + assert _solveset_multi([x+y, x-y, x-1], [x, y], [Reals, Reals]) == S.EmptySet + assert _solveset_multi([x+y, x-y, x-1], [y, x], [Reals, Reals]) == S.EmptySet + + # Systems of three equations: + assert _solveset_multi([x+y+z-1, x+y-z-2, x-y-z-3], [x, y, z], [Reals, + Reals, Reals]) == FiniteSet((2, -S.Half, -S.Half)) + + # Nonlinear systems: + from sympy.abc import theta + assert _solveset_multi([x**2+y**2-2, x+y], [x, y], [Reals, Reals]) == FiniteSet((-1, 1), (1, -1)) + assert _solveset_multi([x**2-1, y], [x, y], [Reals, Reals]) == FiniteSet((1, 0), (-1, 0)) + #assert _solveset_multi([x**2-y**2], [x, y], [Reals, Reals]) == Union( + # ImageSet(Lambda(x, (x, -x)), Reals), ImageSet(Lambda(x, (x, x)), Reals)) + assert dumeq(_solveset_multi([x**2-y**2], [x, y], [Reals, Reals]), Union( + ImageSet(Lambda(((x,),), (x, -Abs(x))), ProductSet(Reals)), + ImageSet(Lambda(((x,),), (x, Abs(x))), ProductSet(Reals)), + ImageSet(Lambda(((y,),), (-Abs(y), y)), ProductSet(Reals)), + ImageSet(Lambda(((y,),), (Abs(y), y)), ProductSet(Reals)))) + assert _solveset_multi([r*cos(theta)-1, r*sin(theta)], [theta, r], + [Interval(0, pi), Interval(-1, 1)]) == FiniteSet((0, 1), (pi, -1)) + assert _solveset_multi([r*cos(theta)-1, r*sin(theta)], [r, theta], + [Interval(0, 1), Interval(0, pi)]) == FiniteSet((1, 0)) + assert _solveset_multi([r*cos(theta)-r, r*sin(theta)], [r, theta], + [Interval(0, 1), Interval(0, pi)]) == Union( + ImageSet(Lambda(((r,),), (r, 0)), + ImageSet(Lambda(r, (r,)), Interval(0, 1))), + ImageSet(Lambda(((theta,),), (0, theta)), + ImageSet(Lambda(theta, (theta,)), Interval(0, pi)))) + + +def test_conditionset(): + assert solveset(Eq(sin(x)**2 + cos(x)**2, 1), x, domain=S.Reals + ) is S.Reals + + assert solveset(Eq(x**2 + x*sin(x), 1), x, domain=S.Reals + ).dummy_eq(ConditionSet(x, Eq(x**2 + x*sin(x) - 1, 0), S.Reals)) + + assert dumeq(solveset(Eq(-I*(exp(I*x) - exp(-I*x))/2, 1), x + ), imageset(Lambda(n, 2*n*pi + pi/2), S.Integers)) + + assert solveset(x + sin(x) > 1, x, domain=S.Reals + ).dummy_eq(ConditionSet(x, x + sin(x) > 1, S.Reals)) + + assert solveset(Eq(sin(Abs(x)), x), x, domain=S.Reals + ).dummy_eq(ConditionSet(x, Eq(-x + sin(Abs(x)), 0), S.Reals)) + + assert solveset(y**x-z, x, S.Reals + ).dummy_eq(ConditionSet(x, Eq(y**x - z, 0), S.Reals)) + + +@XFAIL +def test_conditionset_equality(): + ''' Checking equality of different representations of ConditionSet''' + assert solveset(Eq(tan(x), y), x) == ConditionSet(x, Eq(tan(x), y), S.Complexes) + + +def test_solveset_domain(): + assert solveset(x**2 - x - 6, x, Interval(0, oo)) == FiniteSet(3) + assert solveset(x**2 - 1, x, Interval(0, oo)) == FiniteSet(1) + assert solveset(x**4 - 16, x, Interval(0, 10)) == FiniteSet(2) + + +def test_improve_coverage(): + solution = solveset(exp(x) + sin(x), x, S.Reals) + unsolved_object = ConditionSet(x, Eq(exp(x) + sin(x), 0), S.Reals) + assert solution.dummy_eq(unsolved_object) + + +def test_issue_9522(): + expr1 = Eq(1/(x**2 - 4) + x, 1/(x**2 - 4) + 2) + expr2 = Eq(1/x + x, 1/x) + + assert solveset(expr1, x, S.Reals) is S.EmptySet + assert solveset(expr2, x, S.Reals) is S.EmptySet + + +def test_solvify(): + assert solvify(x**2 + 10, x, S.Reals) == [] + assert solvify(x**3 + 1, x, S.Complexes) == [-1, S.Half - sqrt(3)*I/2, + S.Half + sqrt(3)*I/2] + assert solvify(log(x), x, S.Reals) == [1] + assert solvify(cos(x), x, S.Reals) == [pi/2, pi*Rational(3, 2)] + assert solvify(sin(x) + 1, x, S.Reals) == [pi*Rational(3, 2)] + raises(NotImplementedError, lambda: solvify(sin(exp(x)), x, S.Complexes)) + + +def test_solvify_piecewise(): + p1 = Piecewise((0, x < -1), (x**2, x <= 1), (log(x), True)) + p2 = Piecewise((0, x < -10), (x**2 + 5*x - 6, x >= -9)) + p3 = Piecewise((0, Eq(x, 0)), (x**2/Abs(x), True)) + p4 = Piecewise((0, Eq(x, pi)), ((x - pi)/sin(x), True)) + + # issue 21079 + assert solvify(p1, x, S.Reals) == [0] + assert solvify(p2, x, S.Reals) == [-6, 1] + assert solvify(p3, x, S.Reals) == [0] + assert solvify(p4, x, S.Reals) == [pi] + + +def test_abs_invert_solvify(): + + x = Symbol('x',positive=True) + assert solvify(sin(Abs(x)), x, S.Reals) == [0, pi] + x = Symbol('x') + assert solvify(sin(Abs(x)), x, S.Reals) is None + + +def test_linear_eq_to_matrix(): + assert linear_eq_to_matrix(0, x) == (Matrix([[0]]), Matrix([[0]])) + assert linear_eq_to_matrix(1, x) == (Matrix([[0]]), Matrix([[-1]])) + + # integer coefficients + eqns1 = [2*x + y - 2*z - 3, x - y - z, x + y + 3*z - 12] + eqns2 = [Eq(3*x + 2*y - z, 1), Eq(2*x - 2*y + 4*z, -2), -2*x + y - 2*z] + + A, B = linear_eq_to_matrix(eqns1, x, y, z) + assert A == Matrix([[2, 1, -2], [1, -1, -1], [1, 1, 3]]) + assert B == Matrix([[3], [0], [12]]) + + A, B = linear_eq_to_matrix(eqns2, x, y, z) + assert A == Matrix([[3, 2, -1], [2, -2, 4], [-2, 1, -2]]) + assert B == Matrix([[1], [-2], [0]]) + + # Pure symbolic coefficients + eqns3 = [a*b*x + b*y + c*z - d, e*x + d*x + f*y + g*z - h, i*x + j*y + k*z - l] + A, B = linear_eq_to_matrix(eqns3, x, y, z) + assert A == Matrix([[a*b, b, c], [d + e, f, g], [i, j, k]]) + assert B == Matrix([[d], [h], [l]]) + + # raise Errors if + # 1) no symbols are given + raises(ValueError, lambda: linear_eq_to_matrix(eqns3)) + # 2) there are duplicates + raises(ValueError, lambda: linear_eq_to_matrix(eqns3, [x, x, y])) + # 3) a nonlinear term is detected in the original expression + raises(NonlinearError, lambda: linear_eq_to_matrix(Eq(1/x + x, 1/x), [x])) + raises(NonlinearError, lambda: linear_eq_to_matrix([x**2], [x])) + raises(NonlinearError, lambda: linear_eq_to_matrix([x*y], [x, y])) + # 4) Eq being used to represent equations autoevaluates + # (use unevaluated Eq instead) + raises(ValueError, lambda: linear_eq_to_matrix(Eq(x, x), x)) + raises(ValueError, lambda: linear_eq_to_matrix(Eq(x, x + 1), x)) + + + # if non-symbols are passed, the user is responsible for interpreting + assert linear_eq_to_matrix([x], [1/x]) == (Matrix([[0]]), Matrix([[-x]])) + + # issue 15195 + assert linear_eq_to_matrix(x + y*(z*(3*x + 2) + 3), x) == ( + Matrix([[3*y*z + 1]]), Matrix([[-y*(2*z + 3)]])) + assert linear_eq_to_matrix(Matrix( + [[a*x + b*y - 7], [5*x + 6*y - c]]), x, y) == ( + Matrix([[a, b], [5, 6]]), Matrix([[7], [c]])) + + # issue 15312 + assert linear_eq_to_matrix(Eq(x + 2, 1), x) == ( + Matrix([[1]]), Matrix([[-1]])) + + # issue 25423 + raises(TypeError, lambda: linear_eq_to_matrix([], {x, y})) + raises(TypeError, lambda: linear_eq_to_matrix([x + y], {x, y})) + raises(ValueError, lambda: linear_eq_to_matrix({x + y}, (x, y))) + + +def test_issue_16577(): + assert linear_eq_to_matrix(Eq(a*(2*x + 3*y) + 4*y, 5), x, y) == ( + Matrix([[2*a, 3*a + 4]]), Matrix([[5]])) + + +def test_issue_10085(): + assert invert_real(exp(x),0,x) == (x, S.EmptySet) + + +def test_linsolve(): + x1, x2, x3, x4 = symbols('x1, x2, x3, x4') + + # Test for different input forms + + M = Matrix([[1, 2, 1, 1, 7], [1, 2, 2, -1, 12], [2, 4, 0, 6, 4]]) + system1 = A, B = M[:, :-1], M[:, -1] + Eqns = [x1 + 2*x2 + x3 + x4 - 7, x1 + 2*x2 + 2*x3 - x4 - 12, + 2*x1 + 4*x2 + 6*x4 - 4] + + sol = FiniteSet((-2*x2 - 3*x4 + 2, x2, 2*x4 + 5, x4)) + assert linsolve(Eqns, (x1, x2, x3, x4)) == sol + assert linsolve(Eqns, *(x1, x2, x3, x4)) == sol + assert linsolve(system1, (x1, x2, x3, x4)) == sol + assert linsolve(system1, *(x1, x2, x3, x4)) == sol + # issue 9667 - symbols can be Dummy symbols + x1, x2, x3, x4 = symbols('x:4', cls=Dummy) + assert linsolve(system1, x1, x2, x3, x4) == FiniteSet( + (-2*x2 - 3*x4 + 2, x2, 2*x4 + 5, x4)) + + # raise ValueError for garbage value + raises(ValueError, lambda: linsolve(Eqns)) + raises(ValueError, lambda: linsolve(x1)) + raises(ValueError, lambda: linsolve(x1, x2)) + raises(ValueError, lambda: linsolve((A,), x1, x2)) + raises(ValueError, lambda: linsolve(A, B, x1, x2)) + raises(ValueError, lambda: linsolve([x1], x1, x1)) + raises(ValueError, lambda: linsolve([x1], (i for i in (x1, x1)))) + + #raise ValueError if equations are non-linear in given variables + raises(NonlinearError, lambda: linsolve([x + y - 1, x ** 2 + y - 3], [x, y])) + raises(NonlinearError, lambda: linsolve([cos(x) + y, x + y], [x, y])) + assert linsolve([x + z - 1, x ** 2 + y - 3], [z, y]) == {(-x + 1, -x**2 + 3)} + + # Fully symbolic test + A = Matrix([[a, b], [c, d]]) + B = Matrix([[e], [g]]) + system2 = (A, B) + sol = FiniteSet(((-b*g + d*e)/(a*d - b*c), (a*g - c*e)/(a*d - b*c))) + assert linsolve(system2, [x, y]) == sol + + # No solution + A = Matrix([[1, 2, 3], [2, 4, 6], [3, 6, 9]]) + B = Matrix([0, 0, 1]) + assert linsolve((A, B), (x, y, z)) is S.EmptySet + + # Issue #10056 + A, B, J1, J2 = symbols('A B J1 J2') + Augmatrix = Matrix([ + [2*I*J1, 2*I*J2, -2/J1], + [-2*I*J2, -2*I*J1, 2/J2], + [0, 2, 2*I/(J1*J2)], + [2, 0, 0], + ]) + + assert linsolve(Augmatrix, A, B) == FiniteSet((0, I/(J1*J2))) + + # Issue #10121 - Assignment of free variables + Augmatrix = Matrix([[0, 1, 0, 0, 0, 0], [0, 0, 0, 1, 0, 0]]) + assert linsolve(Augmatrix, a, b, c, d, e) == FiniteSet((a, 0, c, 0, e)) + #raises(IndexError, lambda: linsolve(Augmatrix, a, b, c)) + + x0, x1, x2, _x0 = symbols('tau0 tau1 tau2 _tau0') + assert linsolve(Matrix([[0, 1, 0, 0, 0, 0], [0, 0, 0, 1, 0, _x0]]) + ) == FiniteSet((x0, 0, x1, _x0, x2)) + x0, x1, x2, _x0 = symbols('tau00 tau01 tau02 tau0') + assert linsolve(Matrix([[0, 1, 0, 0, 0, 0], [0, 0, 0, 1, 0, _x0]]) + ) == FiniteSet((x0, 0, x1, _x0, x2)) + x0, x1, x2, _x0 = symbols('tau00 tau01 tau02 tau1') + assert linsolve(Matrix([[0, 1, 0, 0, 0, 0], [0, 0, 0, 1, 0, _x0]]) + ) == FiniteSet((x0, 0, x1, _x0, x2)) + # symbols can be given as generators + x0, x2, x4 = symbols('x0, x2, x4') + assert linsolve(Augmatrix, numbered_symbols('x') + ) == FiniteSet((x0, 0, x2, 0, x4)) + Augmatrix[-1, -1] = x0 + # use Dummy to avoid clash; the names may clash but the symbols + # will not + Augmatrix[-1, -1] = symbols('_x0') + assert len(linsolve( + Augmatrix, numbered_symbols('x', cls=Dummy)).free_symbols) == 4 + + # Issue #12604 + f = Function('f') + assert linsolve([f(x) - 5], f(x)) == FiniteSet((5,)) + + # Issue #14860 + from sympy.physics.units import meter, newton, kilo + kN = kilo*newton + Eqns = [8*kN + x + y, 28*kN*meter + 3*x*meter] + assert linsolve(Eqns, x, y) == { + (kilo*newton*Rational(-28, 3), kN*Rational(4, 3))} + + # linsolve does not allow expansion (real or implemented) + # to remove singularities, but it will cancel linear terms + assert linsolve([Eq(x, x + y)], [x, y]) == {(x, 0)} + assert linsolve([Eq(x + x*y, 1 + y)], [x]) == {(1,)} + assert linsolve([Eq(1 + y, x + x*y)], [x]) == {(1,)} + raises(NonlinearError, lambda: + linsolve([Eq(x**2, x**2 + y)], [x, y])) + + # corner cases + # + # XXX: The case below should give the same as for [0] + # assert linsolve([], [x]) == {(x,)} + assert linsolve([], [x]) is S.EmptySet + assert linsolve([0], [x]) == {(x,)} + assert linsolve([x], [x, y]) == {(0, y)} + assert linsolve([x, 0], [x, y]) == {(0, y)} + + +def test_linsolve_large_sparse(): + # + # This is mainly a performance test + # + + def _mk_eqs_sol(n): + xs = symbols('x:{}'.format(n)) + ys = symbols('y:{}'.format(n)) + syms = xs + ys + eqs = [] + sol = (-S.Half,) * n + (S.Half,) * n + for xi, yi in zip(xs, ys): + eqs.extend([xi + yi, xi - yi + 1]) + return eqs, syms, FiniteSet(sol) + + n = 500 + eqs, syms, sol = _mk_eqs_sol(n) + assert linsolve(eqs, syms) == sol + + +def test_linsolve_immutable(): + A = ImmutableDenseMatrix([[1, 1, 2], [0, 1, 2], [0, 0, 1]]) + B = ImmutableDenseMatrix([2, 1, -1]) + assert linsolve([A, B], (x, y, z)) == FiniteSet((1, 3, -1)) + + A = ImmutableDenseMatrix([[1, 1, 7], [1, -1, 3]]) + assert linsolve(A) == FiniteSet((5, 2)) + + +def test_solve_decomposition(): + n = Dummy('n') + + f1 = exp(3*x) - 6*exp(2*x) + 11*exp(x) - 6 + f2 = sin(x)**2 - 2*sin(x) + 1 + f3 = sin(x)**2 - sin(x) + f4 = sin(x + 1) + f5 = exp(x + 2) - 1 + f6 = 1/log(x) + f7 = 1/x + + s1 = ImageSet(Lambda(n, 2*n*pi), S.Integers) + s2 = ImageSet(Lambda(n, 2*n*pi + pi), S.Integers) + s3 = ImageSet(Lambda(n, 2*n*pi + pi/2), S.Integers) + s4 = ImageSet(Lambda(n, 2*n*pi - 1), S.Integers) + s5 = ImageSet(Lambda(n, 2*n*pi - 1 + pi), S.Integers) + + assert solve_decomposition(f1, x, S.Reals) == FiniteSet(0, log(2), log(3)) + assert dumeq(solve_decomposition(f2, x, S.Reals), s3) + assert dumeq(solve_decomposition(f3, x, S.Reals), Union(s1, s2, s3)) + assert dumeq(solve_decomposition(f4, x, S.Reals), Union(s4, s5)) + assert solve_decomposition(f5, x, S.Reals) == FiniteSet(-2) + assert solve_decomposition(f6, x, S.Reals) == S.EmptySet + assert solve_decomposition(f7, x, S.Reals) == S.EmptySet + assert solve_decomposition(x, x, Interval(1, 2)) == S.EmptySet + + +# nonlinsolve testcases +def test_nonlinsolve_basic(): + assert nonlinsolve([],[]) == S.EmptySet + assert nonlinsolve([],[x, y]) == S.EmptySet + + system = [x, y - x - 5] + assert nonlinsolve([x],[x, y]) == FiniteSet((0, y)) + assert nonlinsolve(system, [y]) == S.EmptySet + soln = (ImageSet(Lambda(n, 2*n*pi + pi/2), S.Integers),) + assert dumeq(nonlinsolve([sin(x) - 1], [x]), FiniteSet(tuple(soln))) + soln = ((ImageSet(Lambda(n, 2*n*pi + pi), S.Integers), 1), + (ImageSet(Lambda(n, 2*n*pi), S.Integers), 1)) + assert dumeq(nonlinsolve([sin(x), y - 1], [x, y]), FiniteSet(*soln)) + assert nonlinsolve([x**2 - 1], [x]) == FiniteSet((-1,), (1,)) + + soln = FiniteSet((y, y)) + assert nonlinsolve([x - y, 0], x, y) == soln + assert nonlinsolve([0, x - y], x, y) == soln + assert nonlinsolve([x - y, x - y], x, y) == soln + assert nonlinsolve([x, 0], x, y) == FiniteSet((0, y)) + f = Function('f') + assert nonlinsolve([f(x), 0], f(x), y) == FiniteSet((0, y)) + assert nonlinsolve([f(x), 0], f(x), f(y)) == FiniteSet((0, f(y))) + A = Indexed('A', x) + assert nonlinsolve([A, 0], A, y) == FiniteSet((0, y)) + assert nonlinsolve([x**2 -1], [sin(x)]) == FiniteSet((S.EmptySet,)) + assert nonlinsolve([x**2 -1], sin(x)) == FiniteSet((S.EmptySet,)) + assert nonlinsolve([x**2 -1], 1) == FiniteSet((x**2,)) + assert nonlinsolve([x**2 -1], x + y) == FiniteSet((S.EmptySet,)) + assert nonlinsolve([Eq(1, x + y), Eq(1, -x + y - 1), Eq(1, -x + y - 1)], x, y) == FiniteSet( + (-S.Half, 3*S.Half)) + + +def test_nonlinsolve_abs(): + soln = FiniteSet((y, y), (-y, y)) + assert nonlinsolve([Abs(x) - y], x, y) == soln + + +def test_raise_exception_nonlinsolve(): + raises(IndexError, lambda: nonlinsolve([x**2 -1], [])) + raises(ValueError, lambda: nonlinsolve([x**2 -1])) + + +def test_trig_system(): + # TODO: add more simple testcases when solveset returns + # simplified soln for Trig eq + assert nonlinsolve([sin(x) - 1, cos(x) -1 ], x) == S.EmptySet + soln1 = (ImageSet(Lambda(n, 2*n*pi + pi/2), S.Integers),) + soln = FiniteSet(soln1) + assert dumeq(nonlinsolve([sin(x) - 1, cos(x)], x), soln) + + +@XFAIL +def test_trig_system_fail(): + # fails because solveset trig solver is not much smart. + sys = [x + y - pi/2, sin(x) + sin(y) - 1] + # solveset returns conditionset for sin(x) + sin(y) - 1 + soln_1 = (ImageSet(Lambda(n, n*pi + pi/2), S.Integers), + ImageSet(Lambda(n, n*pi), S.Integers)) + soln_1 = FiniteSet(soln_1) + soln_2 = (ImageSet(Lambda(n, n*pi), S.Integers), + ImageSet(Lambda(n, n*pi+ pi/2), S.Integers)) + soln_2 = FiniteSet(soln_2) + soln = soln_1 + soln_2 + assert dumeq(nonlinsolve(sys, [x, y]), soln) + + # Add more cases from here + # http://www.vitutor.com/geometry/trigonometry/equations_systems.html#uno + sys = [sin(x) + sin(y) - (sqrt(3)+1)/2, sin(x) - sin(y) - (sqrt(3) - 1)/2] + soln_x = Union(ImageSet(Lambda(n, 2*n*pi + pi/3), S.Integers), + ImageSet(Lambda(n, 2*n*pi + pi*Rational(2, 3)), S.Integers)) + soln_y = Union(ImageSet(Lambda(n, 2*n*pi + pi/6), S.Integers), + ImageSet(Lambda(n, 2*n*pi + pi*Rational(5, 6)), S.Integers)) + assert dumeq(nonlinsolve(sys, [x, y]), FiniteSet((soln_x, soln_y))) + + +def test_nonlinsolve_positive_dimensional(): + x, y, a, b, c, d = symbols('x, y, a, b, c, d', extended_real=True) + assert nonlinsolve([x*y, x*y - x], [x, y]) == FiniteSet((0, y)) + + system = [a**2 + a*c, a - b] + assert nonlinsolve(system, [a, b]) == FiniteSet((0, 0), (-c, -c)) + # here (a= 0, b = 0) is independent soln so both is printed. + # if symbols = [a, b, c] then only {a : -c ,b : -c} + + eq1 = a + b + c + d + eq2 = a*b + b*c + c*d + d*a + eq3 = a*b*c + b*c*d + c*d*a + d*a*b + eq4 = a*b*c*d - 1 + system = [eq1, eq2, eq3, eq4] + sol1 = (-1/d, -d, 1/d, FiniteSet(d) - FiniteSet(0)) + sol2 = (1/d, -d, -1/d, FiniteSet(d) - FiniteSet(0)) + soln = FiniteSet(sol1, sol2) + assert nonlinsolve(system, [a, b, c, d]) == soln + + assert nonlinsolve([x**4 - 3*x**2 + y*x, x*z**2, y*z - 1], [x, y, z]) == \ + {(0, 1/z, z)} + + +def test_nonlinsolve_polysys(): + x, y, z = symbols('x, y, z', real=True) + assert nonlinsolve([x**2 + y - 2, x**2 + y], [x, y]) == S.EmptySet + + s = (-y + 2, y) + assert nonlinsolve([(x + y)**2 - 4, x + y - 2], [x, y]) == FiniteSet(s) + + system = [x**2 - y**2] + soln_real = FiniteSet((-y, y), (y, y)) + soln_complex = FiniteSet((-Abs(y), y), (Abs(y), y)) + soln =soln_real + soln_complex + assert nonlinsolve(system, [x, y]) == soln + + system = [x**2 - y**2] + soln_real= FiniteSet((y, -y), (y, y)) + soln_complex = FiniteSet((y, -Abs(y)), (y, Abs(y))) + soln = soln_real + soln_complex + assert nonlinsolve(system, [y, x]) == soln + + system = [x**2 + y - 3, x - y - 4] + assert nonlinsolve(system, (x, y)) != nonlinsolve(system, (y, x)) + + assert nonlinsolve([-x**2 - y**2 + z, -2*x, -2*y, S.One], [x, y, z]) == S.EmptySet + assert nonlinsolve([x + y + z, S.One, S.One, S.One], [x, y, z]) == S.EmptySet + + system = [-x**2*z**2 + x*y*z + y**4, -2*x*z**2 + y*z, x*z + 4*y**3, -2*x**2*z + x*y] + assert nonlinsolve(system, [x, y, z]) == FiniteSet((0, 0, z), (x, 0, 0)) + + +def test_nonlinsolve_using_substitution(): + x, y, z, n = symbols('x, y, z, n', real = True) + system = [(x + y)*n - y**2 + 2] + s_x = (n*y - y**2 + 2)/n + soln = (-s_x, y) + assert nonlinsolve(system, [x, y]) == FiniteSet(soln) + + system = [z**2*x**2 - z**2*y**2/exp(x)] + soln_real_1 = (y, x, 0) + soln_real_2 = (-exp(x/2)*Abs(x), x, z) + soln_real_3 = (exp(x/2)*Abs(x), x, z) + soln_complex_1 = (-x*exp(x/2), x, z) + soln_complex_2 = (x*exp(x/2), x, z) + syms = [y, x, z] + soln = FiniteSet(soln_real_1, soln_complex_1, soln_complex_2,\ + soln_real_2, soln_real_3) + assert nonlinsolve(system,syms) == soln + + +def test_nonlinsolve_complex(): + n = Dummy('n') + assert dumeq(nonlinsolve([exp(x) - sin(y), 1/y - 3], [x, y]), { + (ImageSet(Lambda(n, 2*n*I*pi + log(sin(Rational(1, 3)))), S.Integers), Rational(1, 3))}) + + system = [exp(x) - sin(y), 1/exp(y) - 3] + assert dumeq(nonlinsolve(system, [x, y]), { + (ImageSet(Lambda(n, I*(2*n*pi + pi) + + log(sin(log(3)))), S.Integers), -log(3)), + (ImageSet(Lambda(n, I*(2*n*pi + arg(sin(2*n*I*pi - log(3)))) + + log(Abs(sin(2*n*I*pi - log(3))))), S.Integers), + ImageSet(Lambda(n, 2*n*I*pi - log(3)), S.Integers))}) + + system = [exp(x) - sin(y), y**2 - 4] + assert dumeq(nonlinsolve(system, [x, y]), { + (ImageSet(Lambda(n, I*(2*n*pi + pi) + log(sin(2))), S.Integers), -2), + (ImageSet(Lambda(n, 2*n*I*pi + log(sin(2))), S.Integers), 2)}) + + system = [exp(x) - 2, y ** 2 - 2] + assert dumeq(nonlinsolve(system, [x, y]), { + (log(2), -sqrt(2)), (log(2), sqrt(2)), + (ImageSet(Lambda(n, 2*n*I*pi + log(2)), S.Integers), -sqrt(2)), + (ImageSet(Lambda(n, 2 * n * I * pi + log(2)), S.Integers), sqrt(2))}) + + +def test_nonlinsolve_radical(): + assert nonlinsolve([sqrt(y) - x - z, y - 1], [x, y, z]) == {(1 - z, 1, z)} + + +def test_nonlinsolve_inexact(): + sol = [(-1.625, -1.375), (1.625, 1.375)] + res = nonlinsolve([(x + y)**2 - 9, x**2 - y**2 - 0.75], [x, y]) + assert all(abs(res.args[i][j]-sol[i][j]) < 1e-9 + for i in range(2) for j in range(2)) + + assert nonlinsolve([(x + y)**2 - 9, (x + y)**2 - 0.75], [x, y]) == S.EmptySet + + assert nonlinsolve([y**2 + (x - 0.5)**2 - 0.0625, 2*x - 1.0, 2*y], [x, y]) == \ + S.EmptySet + + res = nonlinsolve([x**2 + y - 0.5, (x + y)**2, log(z)], [x, y, z]) + sol = [(-0.366025403784439, 0.366025403784439, 1), + (-0.366025403784439, 0.366025403784439, 1), + (1.36602540378444, -1.36602540378444, 1)] + assert all(abs(res.args[i][j]-sol[i][j]) < 1e-9 + for i in range(3) for j in range(3)) + + res = nonlinsolve([y - x**2, x**5 - x + 1.0], [x, y]) + sol = [(-1.16730397826142, 1.36259857766493), + (-0.181232444469876 - 1.08395410131771*I, + -1.14211129483496 + 0.392895302949911*I), + (-0.181232444469876 + 1.08395410131771*I, + -1.14211129483496 - 0.392895302949911*I), + (0.764884433600585 - 0.352471546031726*I, + 0.460812006002492 - 0.539199997693599*I), + (0.764884433600585 + 0.352471546031726*I, + 0.460812006002492 + 0.539199997693599*I)] + assert all(abs(res.args[i][j] - sol[i][j]) < 1e-9 + for i in range(5) for j in range(2)) + +@XFAIL +def test_solve_nonlinear_trans(): + # After the transcendental equation solver these will work + x, y = symbols('x, y', real=True) + soln1 = FiniteSet((2*LambertW(y/2), y)) + soln2 = FiniteSet((-x*sqrt(exp(x)), y), (x*sqrt(exp(x)), y)) + soln3 = FiniteSet((x*exp(x/2), x)) + soln4 = FiniteSet(2*LambertW(y/2), y) + assert nonlinsolve([x**2 - y**2/exp(x)], [x, y]) == soln1 + assert nonlinsolve([x**2 - y**2/exp(x)], [y, x]) == soln2 + assert nonlinsolve([x**2 - y**2/exp(x)], [y, x]) == soln3 + assert nonlinsolve([x**2 - y**2/exp(x)], [x, y]) == soln4 + + +def test_nonlinsolve_issue_25182(): + a1, b1, c1, ca, cb, cg = symbols('a1, b1, c1, ca, cb, cg') + eq1 = a1*a1 + b1*b1 - 2.*a1*b1*cg - c1*c1 + eq2 = a1*a1 + c1*c1 - 2.*a1*c1*cb - b1*b1 + eq3 = b1*b1 + c1*c1 - 2.*b1*c1*ca - a1*a1 + assert nonlinsolve([eq1, eq2, eq3], [c1, cb, cg]) == FiniteSet( + (1.0*b1*ca - 1.0*sqrt(a1**2 + b1**2*ca**2 - b1**2), + -1.0*sqrt(a1**2 + b1**2*ca**2 - b1**2)/a1, + -1.0*b1*(ca - 1)*(ca + 1)/a1 + 1.0*ca*sqrt(a1**2 + b1**2*ca**2 - b1**2)/a1), + (1.0*b1*ca + 1.0*sqrt(a1**2 + b1**2*ca**2 - b1**2), + 1.0*sqrt(a1**2 + b1**2*ca**2 - b1**2)/a1, + -1.0*b1*(ca - 1)*(ca + 1)/a1 - 1.0*ca*sqrt(a1**2 + b1**2*ca**2 - b1**2)/a1)) + + +def test_issue_14642(): + x = Symbol('x') + n1 = 0.5*x**3+x**2+0.5+I #add I in the Polynomials + solution = solveset(n1, x) + assert abs(solution.args[0] - (-2.28267560928153 - 0.312325580497716*I)) <= 1e-9 + assert abs(solution.args[1] - (-0.297354141679308 + 1.01904778618762*I)) <= 1e-9 + assert abs(solution.args[2] - (0.580029750960839 - 0.706722205689907*I)) <= 1e-9 + + # Symbolic + n1 = S.Half*x**3+x**2+S.Half+I + res = FiniteSet(-((3*sqrt(3)*31985**(S(1)/4)*sin(atan(S(172)/49)/2)/2 + + S(43)/2)**2 + (27 + 3*sqrt(3)*31985**(S(1)/4)*cos(atan(S(172)/49) + /2)/2)**2)**(S(1)/6)*cos(atan((27 + 3*sqrt(3)*31985**(S(1)/4)* + cos(atan(S(172)/49)/2)/2)/(3*sqrt(3)*31985**(S(1)/4)*sin(atan( + S(172)/49)/2)/2 + S(43)/2))/3)/3 - S(2)/3 - 4*cos(atan((27 + + 3*sqrt(3)*31985**(S(1)/4)*cos(atan(S(172)/49)/2)/2)/(3*sqrt(3)* + 31985**(S(1)/4)*sin(atan(S(172)/49)/2)/2 + S(43)/2))/3)/(3*((3* + sqrt(3)*31985**(S(1)/4)*sin(atan(S(172)/49)/2)/2 + S(43)/2)**2 + + (27 + 3*sqrt(3)*31985**(S(1)/4)*cos(atan(S(172)/49)/2)/2)**2)**(S(1)/ + 6)) + I*(-((3*sqrt(3)*31985**(S(1)/4)*sin(atan(S(172)/49)/2)/2 + + S(43)/2)**2 + (27 + 3*sqrt(3)*31985**(S(1)/4)*cos(atan(S(172)/49)/ + 2)/2)**2)**(S(1)/6)*sin(atan((27 + 3*sqrt(3)*31985**(S(1)/4)*cos( + atan(S(172)/49)/2)/2)/(3*sqrt(3)*31985**(S(1)/4)*sin(atan(S(172)/49) + /2)/2 + S(43)/2))/3)/3 + 4*sin(atan((27 + 3*sqrt(3)*31985**(S(1)/4)* + cos(atan(S(172)/49)/2)/2)/(3*sqrt(3)*31985**(S(1)/4)*sin(atan(S(172) + /49)/2)/2 + S(43)/2))/3)/(3*((3*sqrt(3)*31985**(S(1)/4)*sin(atan( + S(172)/49)/2)/2 + S(43)/2)**2 + (27 + 3*sqrt(3)*31985**(S(1)/4)* + cos(atan(S(172)/49)/2)/2)**2)**(S(1)/6))), -S(2)/3 - sqrt(3)*((3* + sqrt(3)*31985**(S(1)/4)*sin(atan(S(172)/49)/2)/2 + S(43)/2)**2 + + (27 + 3*sqrt(3)*31985**(S(1)/4)*cos(atan(S(172)/49)/2)/2)**2)**(S(1) + /6)*sin(atan((27 + 3*sqrt(3)*31985**(S(1)/4)*cos(atan(S(172)/49)/2) + /2)/(3*sqrt(3)*31985**(S(1)/4)*sin(atan(S(172)/49)/2)/2 + S(43)/2)) + /3)/6 - 4*re(1/((-S(1)/2 - sqrt(3)*I/2)*(S(43)/2 + 27*I + sqrt(-256 + + (43 + 54*I)**2)/2)**(S(1)/3)))/3 + ((3*sqrt(3)*31985**(S(1)/4)*sin( + atan(S(172)/49)/2)/2 + S(43)/2)**2 + (27 + 3*sqrt(3)*31985**(S(1)/4)* + cos(atan(S(172)/49)/2)/2)**2)**(S(1)/6)*cos(atan((27 + 3*sqrt(3)* + 31985**(S(1)/4)*cos(atan(S(172)/49)/2)/2)/(3*sqrt(3)*31985**(S(1)/4)* + sin(atan(S(172)/49)/2)/2 + S(43)/2))/3)/6 + I*(-4*im(1/((-S(1)/2 - + sqrt(3)*I/2)*(S(43)/2 + 27*I + sqrt(-256 + (43 + 54*I)**2)/2)**(S(1)/ + 3)))/3 + ((3*sqrt(3)*31985**(S(1)/4)*sin(atan(S(172)/49)/2)/2 + + S(43)/2)**2 + (27 + 3*sqrt(3)*31985**(S(1)/4)*cos(atan(S(172)/49)/2) + /2)**2)**(S(1)/6)*sin(atan((27 + 3*sqrt(3)*31985**(S(1)/4)*cos(atan( + S(172)/49)/2)/2)/(3*sqrt(3)*31985**(S(1)/4)*sin(atan(S(172)/49)/2)/2 + + S(43)/2))/3)/6 + sqrt(3)*((3*sqrt(3)*31985**(S(1)/4)*sin(atan(S(172)/ + 49)/2)/2 + S(43)/2)**2 + (27 + 3*sqrt(3)*31985**(S(1)/4)*cos(atan( + S(172)/49)/2)/2)**2)**(S(1)/6)*cos(atan((27 + 3*sqrt(3)*31985**(S(1)/ + 4)*cos(atan(S(172)/49)/2)/2)/(3*sqrt(3)*31985**(S(1)/4)*sin(atan( + S(172)/49)/2)/2 + S(43)/2))/3)/6), -S(2)/3 - 4*re(1/((-S(1)/2 + + sqrt(3)*I/2)*(S(43)/2 + 27*I + sqrt(-256 + (43 + 54*I)**2)/2)**(S(1) + /3)))/3 + sqrt(3)*((3*sqrt(3)*31985**(S(1)/4)*sin(atan(S(172)/49)/2)/2 + + S(43)/2)**2 + (27 + 3*sqrt(3)*31985**(S(1)/4)*cos(atan(S(172)/49)/2) + /2)**2)**(S(1)/6)*sin(atan((27 + 3*sqrt(3)*31985**(S(1)/4)*cos(atan( + S(172)/49)/2)/2)/(3*sqrt(3)*31985**(S(1)/4)*sin(atan(S(172)/49)/2)/2 + + S(43)/2))/3)/6 + ((3*sqrt(3)*31985**(S(1)/4)*sin(atan(S(172)/49)/2)/2 + + S(43)/2)**2 + (27 + 3*sqrt(3)*31985**(S(1)/4)*cos(atan(S(172)/49)/2) + /2)**2)**(S(1)/6)*cos(atan((27 + 3*sqrt(3)*31985**(S(1)/4)*cos(atan( + S(172)/49)/2)/2)/(3*sqrt(3)*31985**(S(1)/4)*sin(atan(S(172)/49)/2)/2 + + S(43)/2))/3)/6 + I*(-sqrt(3)*((3*sqrt(3)*31985**(S(1)/4)*sin(atan( + S(172)/49)/2)/2 + S(43)/2)**2 + (27 + 3*sqrt(3)*31985**(S(1)/4)*cos( + atan(S(172)/49)/2)/2)**2)**(S(1)/6)*cos(atan((27 + 3*sqrt(3)*31985**( + S(1)/4)*cos(atan(S(172)/49)/2)/2)/(3*sqrt(3)*31985**(S(1)/4)*sin( + atan(S(172)/49)/2)/2 + S(43)/2))/3)/6 + ((3*sqrt(3)*31985**(S(1)/4)* + sin(atan(S(172)/49)/2)/2 + S(43)/2)**2 + (27 + 3*sqrt(3)*31985**(S(1)/4)* + cos(atan(S(172)/49)/2)/2)**2)**(S(1)/6)*sin(atan((27 + 3*sqrt(3)*31985**( + S(1)/4)*cos(atan(S(172)/49)/2)/2)/(3*sqrt(3)*31985**(S(1)/4)*sin( + atan(S(172)/49)/2)/2 + S(43)/2))/3)/6 - 4*im(1/((-S(1)/2 + sqrt(3)*I/2)* + (S(43)/2 + 27*I + sqrt(-256 + (43 + 54*I)**2)/2)**(S(1)/3)))/3)) + + assert solveset(n1, x) == res + + +def test_issue_13961(): + V = (ax, bx, cx, gx, jx, lx, mx, nx, q) = symbols('ax bx cx gx jx lx mx nx q') + S = (ax*q - lx*q - mx, ax - gx*q - lx, bx*q**2 + cx*q - jx*q - nx, q*(-ax*q + lx*q + mx), q*(-ax + gx*q + lx)) + + sol = FiniteSet((lx + mx/q, (-cx*q + jx*q + nx)/q**2, cx, mx/q**2, jx, lx, mx, nx, Complement({q}, {0})), + (lx + mx/q, (cx*q - jx*q - nx)/q**2*-1, cx, mx/q**2, jx, lx, mx, nx, Complement({q}, {0}))) + assert nonlinsolve(S, *V) == sol + # The two solutions are in fact identical, so even better if only one is returned + + +def test_issue_14541(): + solutions = solveset(sqrt(-x**2 - 2.0), x) + assert abs(solutions.args[0]+1.4142135623731*I) <= 1e-9 + assert abs(solutions.args[1]-1.4142135623731*I) <= 1e-9 + + +def test_issue_13396(): + expr = -2*y*exp(-x**2 - y**2)*Abs(x) + sol = FiniteSet(0) + + assert solveset(expr, y, domain=S.Reals) == sol + + # Related type of equation also solved here + assert solveset(atan(x**2 - y**2)-pi/2, y, S.Reals) is S.EmptySet + + +def test_issue_12032(): + sol = FiniteSet(-sqrt(-2/(3*(Rational(1, 16) + sqrt(849)/144)**(Rational(1, 3))) + + 2*(Rational(1, 16) + sqrt(849)/144)**(Rational(1, 3)))/2 + + sqrt(Abs(-2*(Rational(1, 16) + sqrt(849)/144)**(Rational(1, 3)) + + 2/(3*(Rational(1, 16) + sqrt(849)/144)**(Rational(1, 3))) + + 2/sqrt(-2/(3*(Rational(1, 16) + sqrt(849)/144)**(Rational(1, 3))) + + 2*(Rational(1, 16) + sqrt(849)/144)**(Rational(1, 3)))))/2, + -sqrt(Abs(-2*(Rational(1, 16) + sqrt(849)/144)**(Rational(1, 3)) + + 2/(3*(Rational(1, 16) + sqrt(849)/144)**(Rational(1, 3))) + + 2/sqrt(-2/(3*(Rational(1, 16) + sqrt(849)/144)**(Rational(1, 3))) + + 2*(Rational(1, 16) + sqrt(849)/144)**(Rational(1, 3)))))/2 - + sqrt(-2/(3*(Rational(1, 16) + sqrt(849)/144)**(Rational(1, 3))) + + 2*(Rational(1, 16) + sqrt(849)/144)**(Rational(1, 3)))/2, + sqrt(-2/(3*(Rational(1, 16) + sqrt(849)/144)**(Rational(1, 3))) + + 2*(Rational(1, 16) + sqrt(849)/144)**(Rational(1, 3)))/2 - + I*sqrt(Abs(-2/sqrt(-2/(3*(Rational(1, 16) + sqrt(849)/144)**(Rational(1, 3))) + + 2*(Rational(1, 16) + sqrt(849)/144)**(Rational(1, 3))) - + 2*(Rational(1, 16) + sqrt(849)/144)**(Rational(1, 3)) + + 2/(3*(Rational(1, 16) + sqrt(849)/144)**(Rational(1, 3)))))/2, + sqrt(-2/(3*(Rational(1, 16) + sqrt(849)/144)**(Rational(1, 3))) + + 2*(Rational(1, 16) + sqrt(849)/144)**(Rational(1, 3)))/2 + + I*sqrt(Abs(-2/sqrt(-2/(3*(Rational(1, 16) + sqrt(849)/144)**(Rational(1, 3))) + + 2*(Rational(1, 16) + sqrt(849)/144)**(Rational(1, 3))) - + 2*(Rational(1, 16) + sqrt(849)/144)**(Rational(1, 3)) + + 2/(3*(Rational(1, 16) + sqrt(849)/144)**(Rational(1,3)))))/2) + assert solveset(x**4 + x - 1, x) == sol + + +def test_issue_10876(): + assert solveset(1/sqrt(x), x) == S.EmptySet + + +def test_issue_19050(): + # test_issue_19050 --> TypeError removed + assert dumeq(nonlinsolve([x + y, sin(y)], [x, y]), + FiniteSet((ImageSet(Lambda(n, -2*n*pi), S.Integers), ImageSet(Lambda(n, 2*n*pi), S.Integers)),\ + (ImageSet(Lambda(n, -2*n*pi - pi), S.Integers), ImageSet(Lambda(n, 2*n*pi + pi), S.Integers)))) + assert dumeq(nonlinsolve([x + y, sin(y) + cos(y)], [x, y]), + FiniteSet((ImageSet(Lambda(n, -2*n*pi - 3*pi/4), S.Integers), ImageSet(Lambda(n, 2*n*pi + 3*pi/4), S.Integers)), \ + (ImageSet(Lambda(n, -2*n*pi - 7*pi/4), S.Integers), ImageSet(Lambda(n, 2*n*pi + 7*pi/4), S.Integers)))) + + +def test_issue_16618(): + eqn = [sin(x)*sin(y), cos(x)*cos(y) - 1] + # nonlinsolve's answer is still suspicious since it contains only three + # distinct Dummys instead of 4. (Both 'x' ImageSets share the same Dummy.) + ans = FiniteSet((ImageSet(Lambda(n, 2*n*pi), S.Integers), ImageSet(Lambda(n, 2*n*pi), S.Integers)), + (ImageSet(Lambda(n, 2*n*pi + pi), S.Integers), ImageSet(Lambda(n, 2*n*pi + pi), S.Integers))) + sol = nonlinsolve(eqn, [x, y]) + + for i0, j0 in zip(ordered(sol), ordered(ans)): + assert len(i0) == len(j0) == 2 + assert all(a.dummy_eq(b) for a, b in zip(i0, j0)) + assert len(sol) == len(ans) + + +def test_issue_17566(): + assert nonlinsolve([32*(2**x)/2**(-y) - 4**y, 27*(3**x) - S(1)/3**y], x, y) ==\ + FiniteSet((-log(81)/log(3), 1)) + + +def test_issue_16643(): + n = Dummy('n') + assert solveset(x**2*sin(x), x).dummy_eq(Union(ImageSet(Lambda(n, 2*n*pi + pi), S.Integers), + ImageSet(Lambda(n, 2*n*pi), S.Integers))) + + +def test_issue_19587(): + n,m = symbols('n m') + assert nonlinsolve([32*2**m*2**n - 4**n, 27*3**m - 3**(-n)], m, n) ==\ + FiniteSet((-log(81)/log(3), 1)) + + +def test_issue_5132_1(): + system = [sqrt(x**2 + y**2) - sqrt(10), x + y - 4] + assert nonlinsolve(system, [x, y]) == FiniteSet((1, 3), (3, 1)) + + n = Dummy('n') + eqs = [exp(x)**2 - sin(y) + z**2, 1/exp(y) - 3] + s_real_y = -log(3) + s_real_z = sqrt(-exp(2*x) - sin(log(3))) + soln_real = FiniteSet((s_real_y, s_real_z), (s_real_y, -s_real_z)) + lam = Lambda(n, 2*n*I*pi + -log(3)) + s_complex_y = ImageSet(lam, S.Integers) + lam = Lambda(n, sqrt(-exp(2*x) + sin(2*n*I*pi + -log(3)))) + s_complex_z_1 = ImageSet(lam, S.Integers) + lam = Lambda(n, -sqrt(-exp(2*x) + sin(2*n*I*pi + -log(3)))) + s_complex_z_2 = ImageSet(lam, S.Integers) + soln_complex = FiniteSet( + (s_complex_y, s_complex_z_1), + (s_complex_y, s_complex_z_2) + ) + soln = soln_real + soln_complex + assert dumeq(nonlinsolve(eqs, [y, z]), soln) + + +def test_issue_5132_2(): + x, y = symbols('x, y', real=True) + eqs = [exp(x)**2 - sin(y) + z**2] + n = Dummy('n') + soln_real = (log(-z**2 + sin(y))/2, z) + lam = Lambda( n, I*(2*n*pi + arg(-z**2 + sin(y)))/2 + log(Abs(z**2 - sin(y)))/2) + img = ImageSet(lam, S.Integers) + # not sure about the complex soln. But it looks correct. + soln_complex = (img, z) + soln = FiniteSet(soln_real, soln_complex) + assert dumeq(nonlinsolve(eqs, [x, z]), soln) + + system = [r - x**2 - y**2, tan(t) - y/x] + s_x = sqrt(r/(tan(t)**2 + 1)) + s_y = sqrt(r/(tan(t)**2 + 1))*tan(t) + soln = FiniteSet((s_x, s_y), (-s_x, -s_y)) + assert nonlinsolve(system, [x, y]) == soln + + +def test_issue_6752(): + a, b = symbols('a, b', real=True) + assert nonlinsolve([a**2 + a, a - b], [a, b]) == {(-1, -1), (0, 0)} + + +@SKIP("slow") +def test_issue_5114_solveset(): + # slow testcase + from sympy.abc import o, p + + # there is no 'a' in the equation set but this is how the + # problem was originally posed + syms = [a, b, c, f, h, k, n] + eqs = [b + r/d - c/d, + c*(1/d + 1/e + 1/g) - f/g - r/d, + f*(1/g + 1/i + 1/j) - c/g - h/i, + h*(1/i + 1/l + 1/m) - f/i - k/m, + k*(1/m + 1/o + 1/p) - h/m - n/p, + n*(1/p + 1/q) - k/p] + assert len(nonlinsolve(eqs, syms)) == 1 + + +@SKIP("Hangs") +def _test_issue_5335(): + # Not able to check zero dimensional system. + # is_zero_dimensional Hangs + lam, a0, conc = symbols('lam a0 conc') + eqs = [lam + 2*y - a0*(1 - x/2)*x - 0.005*x/2*x, + a0*(1 - x/2)*x - 1*y - 0.743436700916726*y, + x + y - conc] + sym = [x, y, a0] + # there are 4 solutions but only two are valid + assert len(nonlinsolve(eqs, sym)) == 2 + # float + eqs = [lam + 2*y - a0*(1 - x/2)*x - 0.005*x/2*x, + a0*(1 - x/2)*x - 1*y - 0.743436700916726*y, + x + y - conc] + sym = [x, y, a0] + assert len(nonlinsolve(eqs, sym)) == 2 + + +def test_issue_2777(): + # the equations represent two circles + x, y = symbols('x y', real=True) + e1, e2 = sqrt(x**2 + y**2) - 10, sqrt(y**2 + (-x + 10)**2) - 3 + a, b = Rational(191, 20), 3*sqrt(391)/20 + ans = {(a, -b), (a, b)} + assert nonlinsolve((e1, e2), (x, y)) == ans + assert nonlinsolve((e1, e2/(x - a)), (x, y)) == S.EmptySet + # make the 2nd circle's radius be -3 + e2 += 6 + assert nonlinsolve((e1, e2), (x, y)) == S.EmptySet + + +def test_issue_8828(): + x1 = 0 + y1 = -620 + r1 = 920 + x2 = 126 + y2 = 276 + x3 = 51 + y3 = 205 + r3 = 104 + v = [x, y, z] + + f1 = (x - x1)**2 + (y - y1)**2 - (r1 - z)**2 + f2 = (x2 - x)**2 + (y2 - y)**2 - z**2 + f3 = (x - x3)**2 + (y - y3)**2 - (r3 - z)**2 + F = [f1, f2, f3] + + g1 = sqrt((x - x1)**2 + (y - y1)**2) + z - r1 + g2 = f2 + g3 = sqrt((x - x3)**2 + (y - y3)**2) + z - r3 + G = [g1, g2, g3] + + # both soln same + A = nonlinsolve(F, v) + B = nonlinsolve(G, v) + assert A == B + + +def test_nonlinsolve_conditionset(): + # when solveset failed to solve all the eq + # return conditionset + f = Function('f') + f1 = f(x) - pi/2 + f2 = f(y) - pi*Rational(3, 2) + intermediate_system = Eq(2*f(x) - pi, 0) & Eq(2*f(y) - 3*pi, 0) + syms = Tuple(x, y) + soln = ConditionSet( + syms, + intermediate_system, + S.Complexes**2) + assert nonlinsolve([f1, f2], [x, y]) == soln + + +def test_substitution_basic(): + assert substitution([], [x, y]) == S.EmptySet + assert substitution([], []) == S.EmptySet + system = [2*x**2 + 3*y**2 - 30, 3*x**2 - 2*y**2 - 19] + soln = FiniteSet((-3, -2), (-3, 2), (3, -2), (3, 2)) + assert substitution(system, [x, y]) == soln + + soln = FiniteSet((-1, 1)) + assert substitution([x + y], [x], [{y: 1}], [y], set(), [x, y]) == soln + assert substitution( + [x + y], [x], [{y: 1}], [y], + {x + 1}, [y, x]) == S.EmptySet + + +def test_substitution_incorrect(): + # the solutions in the following two tests are incorrect. The + # correct result is EmptySet in both cases. + assert substitution([h - 1, k - 1, f - 2, f - 4, -2 * k], + [h, k, f]) == {(1, 1, f)} + assert substitution([x + y + z, S.One, S.One, S.One], [x, y, z]) == \ + {(-y - z, y, z)} + + # the correct result in the test below is {(-I, I, I, -I), + # (I, -I, -I, I)} + assert substitution([a - d, b + d, c + d, d**2 + 1], [a, b, c, d]) == \ + {(d, -d, -d, d)} + + # the result in the test below is incomplete. The complete result + # is {(0, b), (log(2), 2)} + assert substitution([a*(a - log(b)), a*(b - 2)], [a, b]) == \ + {(0, b)} + + # The system in the test below is zero-dimensional, so the result + # should have no free symbols + assert substitution([-k*y + 6*x - 4*y, -81*k + 49*y**2 - 270, + -3*k*z + k + z**3, k**2 - 2*k + 4], + [x, y, z, k]).free_symbols == {z} + + +def test_substitution_redundant(): + # the third and fourth solutions are redundant in the test below + assert substitution([x**2 - y**2, z - 1], [x, z]) == \ + {(-y, 1), (y, 1), (-sqrt(y**2), 1), (sqrt(y**2), 1)} + + # the system below has three solutions. Two of the solutions + # returned by substitution are redundant. + res = substitution([x - y, y**3 - 3*y**2 + 1], [x, y]) + assert len(res) == 5 + + +def test_issue_5132_substitution(): + x, y, z, r, t = symbols('x, y, z, r, t', real=True) + system = [r - x**2 - y**2, tan(t) - y/x] + s_x_1 = Complement(FiniteSet(-sqrt(r/(tan(t)**2 + 1))), FiniteSet(0)) + s_x_2 = Complement(FiniteSet(sqrt(r/(tan(t)**2 + 1))), FiniteSet(0)) + s_y = sqrt(r/(tan(t)**2 + 1))*tan(t) + soln = FiniteSet((s_x_2, s_y)) + FiniteSet((s_x_1, -s_y)) + assert substitution(system, [x, y]) == soln + + n = Dummy('n') + eqs = [exp(x)**2 - sin(y) + z**2, 1/exp(y) - 3] + s_real_y = -log(3) + s_real_z = sqrt(-exp(2*x) - sin(log(3))) + soln_real = FiniteSet((s_real_y, s_real_z), (s_real_y, -s_real_z)) + lam = Lambda(n, 2*n*I*pi + -log(3)) + s_complex_y = ImageSet(lam, S.Integers) + lam = Lambda(n, sqrt(-exp(2*x) + sin(2*n*I*pi + -log(3)))) + s_complex_z_1 = ImageSet(lam, S.Integers) + lam = Lambda(n, -sqrt(-exp(2*x) + sin(2*n*I*pi + -log(3)))) + s_complex_z_2 = ImageSet(lam, S.Integers) + soln_complex = FiniteSet( + (s_complex_y, s_complex_z_1), + (s_complex_y, s_complex_z_2)) + soln = soln_real + soln_complex + assert dumeq(substitution(eqs, [y, z]), soln) + + +def test_raises_substitution(): + raises(ValueError, lambda: substitution([x**2 -1], [])) + raises(TypeError, lambda: substitution([x**2 -1])) + raises(ValueError, lambda: substitution([x**2 -1], [sin(x)])) + raises(TypeError, lambda: substitution([x**2 -1], x)) + raises(TypeError, lambda: substitution([x**2 -1], 1)) + + +def test_issue_21022(): + from sympy.core.sympify import sympify + + eqs = [ + 'k-16', + 'p-8', + 'y*y+z*z-x*x', + 'd - x + p', + 'd*d+k*k-y*y', + 'z*z-p*p-k*k', + 'abc-efg', + ] + efg = Symbol('efg') + eqs = [sympify(x) for x in eqs] + + syb = list(ordered(set.union(*[x.free_symbols for x in eqs]))) + res = nonlinsolve(eqs, syb) + + ans = FiniteSet( + (efg, 32, efg, 16, 8, 40, -16*sqrt(5), -8*sqrt(5)), + (efg, 32, efg, 16, 8, 40, -16*sqrt(5), 8*sqrt(5)), + (efg, 32, efg, 16, 8, 40, 16*sqrt(5), -8*sqrt(5)), + (efg, 32, efg, 16, 8, 40, 16*sqrt(5), 8*sqrt(5)), + ) + assert len(res) == len(ans) == 4 + assert res == ans + for result in res.args: + assert len(result) == 8 + + +def test_issue_17940(): + n = Dummy('n') + k1 = Dummy('k1') + sol = ImageSet(Lambda(((k1, n),), I*(2*k1*pi + arg(2*n*I*pi + log(5))) + + log(Abs(2*n*I*pi + log(5)))), + ProductSet(S.Integers, S.Integers)) + assert solveset(exp(exp(x)) - 5, x).dummy_eq(sol) + + +def test_issue_17906(): + assert solveset(7**(x**2 - 80) - 49**x, x) == FiniteSet(-8, 10) + + +@XFAIL +def test_issue_17933(): + eq1 = x*sin(45) - y*cos(q) + eq2 = x*cos(45) - y*sin(q) + eq3 = 9*x*sin(45)/10 + y*cos(q) + eq4 = 9*x*cos(45)/10 + y*sin(z) - z + assert nonlinsolve([eq1, eq2, eq3, eq4], x, y, z, q) ==\ + FiniteSet((0, 0, 0, q)) + +def test_issue_17933_bis(): + # nonlinsolve's result depends on the 'default_sort_key' ordering of + # the unknowns. + eq1 = x*sin(45) - y*cos(q) + eq2 = x*cos(45) - y*sin(q) + eq3 = 9*x*sin(45)/10 + y*cos(q) + eq4 = 9*x*cos(45)/10 + y*sin(z) - z + zz = Symbol('zz') + eqs = [e.subs(q, zz) for e in (eq1, eq2, eq3, eq4)] + assert nonlinsolve(eqs, x, y, z, zz) == FiniteSet((0, 0, 0, zz)) + + +def test_issue_14565(): + # removed redundancy + assert dumeq(nonlinsolve([k + m, k + m*exp(-2*pi*k)], [k, m]) , + FiniteSet((-n*I, ImageSet(Lambda(n, n*I), S.Integers)))) + + +# end of tests for nonlinsolve + + +def test_issue_9556(): + b = Symbol('b', positive=True) + + assert solveset(Abs(x) + 1, x, S.Reals) is S.EmptySet + assert solveset(Abs(x) + b, x, S.Reals) is S.EmptySet + assert solveset(Eq(b, -1), b, S.Reals) is S.EmptySet + + +def test_issue_9611(): + assert solveset(Eq(x - x + a, a), x, S.Reals) == S.Reals + assert solveset(Eq(y - y + a, a), y) == S.Complexes + + +def test_issue_9557(): + assert solveset(x**2 + a, x, S.Reals) == Intersection(S.Reals, + FiniteSet(-sqrt(-a), sqrt(-a))) + + +def test_issue_9778(): + x = Symbol('x', real=True) + y = Symbol('y', real=True) + assert solveset(x**3 + 1, x, S.Reals) == FiniteSet(-1) + assert solveset(x**Rational(3, 5) + 1, x, S.Reals) == S.EmptySet + assert solveset(x**3 + y, x, S.Reals) == \ + FiniteSet(-Abs(y)**Rational(1, 3)*sign(y)) + + +def test_issue_10214(): + assert solveset(x**Rational(3, 2) + 4, x, S.Reals) == S.EmptySet + assert solveset(x**(Rational(-3, 2)) + 4, x, S.Reals) == S.EmptySet + + ans = FiniteSet(-2**Rational(2, 3)) + assert solveset(x**(S(3)) + 4, x, S.Reals) == ans + assert (x**(S(3)) + 4).subs(x,list(ans)[0]) == 0 # substituting ans and verifying the result. + assert (x**(S(3)) + 4).subs(x,-(-2)**Rational(2, 3)) == 0 + + +def test_issue_9849(): + assert solveset(Abs(sin(x)) + 1, x, S.Reals) == S.EmptySet + + +def test_issue_9953(): + assert linsolve([ ], x) == S.EmptySet + + +def test_issue_9913(): + assert solveset(2*x + 1/(x - 10)**2, x, S.Reals) == \ + FiniteSet(-(3*sqrt(24081)/4 + Rational(4027, 4))**Rational(1, 3)/3 - 100/ + (3*(3*sqrt(24081)/4 + Rational(4027, 4))**Rational(1, 3)) + Rational(20, 3)) + + +def test_issue_10397(): + assert solveset(sqrt(x), x, S.Complexes) == FiniteSet(0) + + +def test_issue_14987(): + raises(ValueError, lambda: linear_eq_to_matrix( + [x**2], x)) + raises(ValueError, lambda: linear_eq_to_matrix( + [x*(-3/x + 1) + 2*y - a], [x, y])) + raises(ValueError, lambda: linear_eq_to_matrix( + [(x**2 - 3*x)/(x - 3) - 3], x)) + raises(ValueError, lambda: linear_eq_to_matrix( + [(x + 1)**3 - x**3 - 3*x**2 + 7], x)) + raises(ValueError, lambda: linear_eq_to_matrix( + [x*(1/x + 1) + y], [x, y])) + raises(ValueError, lambda: linear_eq_to_matrix( + [(x + 1)*y], [x, y])) + raises(ValueError, lambda: linear_eq_to_matrix( + [Eq(1/x, 1/x + y)], [x, y])) + raises(ValueError, lambda: linear_eq_to_matrix( + [Eq(y/x, y/x + y)], [x, y])) + raises(ValueError, lambda: linear_eq_to_matrix( + [Eq(x*(x + 1), x**2 + y)], [x, y])) + + +def test_simplification(): + eq = x + (a - b)/(-2*a + 2*b) + assert solveset(eq, x) == FiniteSet(S.Half) + assert solveset(eq, x, S.Reals) == Intersection({-((a - b)/(-2*a + 2*b))}, S.Reals) + # So that ap - bn is not zero: + ap = Symbol('ap', positive=True) + bn = Symbol('bn', negative=True) + eq = x + (ap - bn)/(-2*ap + 2*bn) + assert solveset(eq, x) == FiniteSet(S.Half) + assert solveset(eq, x, S.Reals) == FiniteSet(S.Half) + + +def test_integer_domain_relational(): + eq1 = 2*x + 3 > 0 + eq2 = x**2 + 3*x - 2 >= 0 + eq3 = x + 1/x > -2 + 1/x + eq4 = x + sqrt(x**2 - 5) > 0 + eq = x + 1/x > -2 + 1/x + eq5 = eq.subs(x,log(x)) + eq6 = log(x)/x <= 0 + eq7 = log(x)/x < 0 + eq8 = x/(x-3) < 3 + eq9 = x/(x**2-3) < 3 + + assert solveset(eq1, x, S.Integers) == Range(-1, oo, 1) + assert solveset(eq2, x, S.Integers) == Union(Range(-oo, -3, 1), Range(1, oo, 1)) + assert solveset(eq3, x, S.Integers) == Union(Range(-1, 0, 1), Range(1, oo, 1)) + assert solveset(eq4, x, S.Integers) == Range(3, oo, 1) + assert solveset(eq5, x, S.Integers) == Range(2, oo, 1) + assert solveset(eq6, x, S.Integers) == Range(1, 2, 1) + assert solveset(eq7, x, S.Integers) == S.EmptySet + assert solveset(eq8, x, domain=Range(0,5)) == Range(0, 3, 1) + assert solveset(eq9, x, domain=Range(0,5)) == Union(Range(0, 2, 1), Range(2, 5, 1)) + + # test_issue_19794 + assert solveset(x + 2 < 0, x, S.Integers) == Range(-oo, -2, 1) + + +def test_issue_10555(): + f = Function('f') + g = Function('g') + assert solveset(f(x) - pi/2, x, S.Reals).dummy_eq( + ConditionSet(x, Eq(f(x) - pi/2, 0), S.Reals)) + assert solveset(f(g(x)) - pi/2, g(x), S.Reals).dummy_eq( + ConditionSet(g(x), Eq(f(g(x)) - pi/2, 0), S.Reals)) + + +def test_issue_8715(): + eq = x + 1/x > -2 + 1/x + assert solveset(eq, x, S.Reals) == \ + (Interval.open(-2, oo) - FiniteSet(0)) + assert solveset(eq.subs(x,log(x)), x, S.Reals) == \ + Interval.open(exp(-2), oo) - FiniteSet(1) + + +def test_issue_11174(): + eq = z**2 + exp(2*x) - sin(y) + soln = Intersection(S.Reals, FiniteSet(log(-z**2 + sin(y))/2)) + assert solveset(eq, x, S.Reals) == soln + + eq = sqrt(r)*Abs(tan(t))/sqrt(tan(t)**2 + 1) + x*tan(t) + s = -sqrt(r)*Abs(tan(t))/(sqrt(tan(t)**2 + 1)*tan(t)) + soln = Intersection(S.Reals, FiniteSet(s)) + assert solveset(eq, x, S.Reals) == soln + + +def test_issue_11534(): + # eq1 and eq2 should not have the same solutions because squaring both + # sides of the radical equation introduces a spurious solution branch. + # The equations have a symbolic parameter y and it is easy to see that for + # y != 0 the solution s1 will not be valid for eq1. + x = Symbol('x', real=True) + y = Symbol('y', real=True) + eq1 = -y + x/sqrt(-x**2 + 1) + eq2 = -y**2 + x**2/(-x**2 + 1) + + # We get a ConditionSet here because s1 works in eq1 if y is equal to zero + # although not for any other value of y. That case is redundant though + # because if y=0 then s1=s2 so the solution for eq1 could just be returned + # as s2 - {-1, 1}. In fact we have + # |y/sqrt(y**2 + 1)| < 1 + # So the complements are not needed either. The ideal output here would be + # sol1 = s2 + # sol2 = s1 | s2. + s1, s2 = FiniteSet(-y/sqrt(y**2 + 1)), FiniteSet(y/sqrt(y**2 + 1)) + cset = ConditionSet(x, Eq(eq1, 0), s1) + sol1 = (s2 - {-1, 1}) | (cset - {-1, 1}) + sol2 = (s1 | s2) - {-1, 1} + + assert solveset(eq1, x, S.Reals) == sol1 + assert solveset(eq2, x, S.Reals) == sol2 + + +def test_issue_10477(): + assert solveset((x**2 + 4*x - 3)/x < 2, x, S.Reals) == \ + Union(Interval.open(-oo, -3), Interval.open(0, 1)) + + +def test_issue_10671(): + assert solveset(sin(y), y, Interval(0, pi)) == FiniteSet(0, pi) + i = Interval(1, 10) + assert solveset((1/x).diff(x) < 0, x, i) == i + + +def test_issue_11064(): + eq = x + sqrt(x**2 - 5) + assert solveset(eq > 0, x, S.Reals) == \ + Interval(sqrt(5), oo) + assert solveset(eq < 0, x, S.Reals) == \ + Interval(-oo, -sqrt(5)) + assert solveset(eq > sqrt(5), x, S.Reals) == \ + Interval.Lopen(sqrt(5), oo) + + +def test_issue_12478(): + eq = sqrt(x - 2) + 2 + soln = solveset_real(eq, x) + assert soln is S.EmptySet + assert solveset(eq < 0, x, S.Reals) is S.EmptySet + assert solveset(eq > 0, x, S.Reals) == Interval(2, oo) + + +def test_issue_12429(): + eq = solveset(log(x)/x <= 0, x, S.Reals) + sol = Interval.Lopen(0, 1) + assert eq == sol + + +def test_issue_19506(): + eq = arg(x + I) + C = Dummy('C') + assert solveset(eq).dummy_eq(Intersection(ConditionSet(C, Eq(im(C) + 1, 0), S.Complexes), + ConditionSet(C, re(C) > 0, S.Complexes))) + + +def test_solveset_arg(): + assert solveset(arg(x), x, S.Reals) == Interval.open(0, oo) + assert solveset(arg(4*x -3), x, S.Reals) == Interval.open(Rational(3, 4), oo) + + +def test__is_finite_with_finite_vars(): + f = _is_finite_with_finite_vars + # issue 12482 + assert all(f(1/x) is None for x in ( + Dummy(), Dummy(real=True), Dummy(complex=True))) + assert f(1/Dummy(real=False)) is True # b/c it's finite but not 0 + + +def test_issue_13550(): + assert solveset(x**2 - 2*x - 15, symbol = x, domain = Interval(-oo, 0)) == FiniteSet(-3) + + +def test_issue_13849(): + assert nonlinsolve((t*(sqrt(5) + sqrt(2)) - sqrt(2), t), t) is S.EmptySet + + +def test_issue_14223(): + assert solveset((Abs(x + Min(x, 2)) - 2).rewrite(Piecewise), x, + S.Reals) == FiniteSet(-1, 1) + assert solveset((Abs(x + Min(x, 2)) - 2).rewrite(Piecewise), x, + Interval(0, 2)) == FiniteSet(1) + assert solveset(x, x, FiniteSet(1, 2)) is S.EmptySet + + +def test_issue_10158(): + dom = S.Reals + assert solveset(x*Max(x, 15) - 10, x, dom) == FiniteSet(Rational(2, 3)) + assert solveset(x*Min(x, 15) - 10, x, dom) == FiniteSet(-sqrt(10), sqrt(10)) + assert solveset(Max(Abs(x - 3) - 1, x + 2) - 3, x, dom) == FiniteSet(-1, 1) + assert solveset(Abs(x - 1) - Abs(y), x, dom) == FiniteSet(-Abs(y) + 1, Abs(y) + 1) + assert solveset(Abs(x + 4*Abs(x + 1)), x, dom) == FiniteSet(Rational(-4, 3), Rational(-4, 5)) + assert solveset(2*Abs(x + Abs(x + Max(3, x))) - 2, x, S.Reals) == FiniteSet(-1, -2) + dom = S.Complexes + raises(ValueError, lambda: solveset(x*Max(x, 15) - 10, x, dom)) + raises(ValueError, lambda: solveset(x*Min(x, 15) - 10, x, dom)) + raises(ValueError, lambda: solveset(Max(Abs(x - 3) - 1, x + 2) - 3, x, dom)) + raises(ValueError, lambda: solveset(Abs(x - 1) - Abs(y), x, dom)) + raises(ValueError, lambda: solveset(Abs(x + 4*Abs(x + 1)), x, dom)) + + +def test_issue_14300(): + f = 1 - exp(-18000000*x) - y + a1 = FiniteSet(-log(-y + 1)/18000000) + + assert solveset(f, x, S.Reals) == \ + Intersection(S.Reals, a1) + assert dumeq(solveset(f, x), + ImageSet(Lambda(n, -I*(2*n*pi + arg(-y + 1))/18000000 - + log(Abs(y - 1))/18000000), S.Integers)) + + +def test_issue_14454(): + number = CRootOf(x**4 + x - 1, 2) + raises(ValueError, lambda: invert_real(number, 0, x)) + assert invert_real(x**2, number, x) # no error + + +def test_issue_17882(): + assert solveset(-8*x**2/(9*(x**2 - 1)**(S(4)/3)) + 4/(3*(x**2 - 1)**(S(1)/3)), x, S.Complexes) == \ + FiniteSet(sqrt(3), -sqrt(3)) + + +def test_term_factors(): + assert list(_term_factors(3**x - 2)) == [-2, 3**x] + expr = 4**(x + 1) + 4**(x + 2) + 4**(x - 1) - 3**(x + 2) - 3**(x + 3) + assert set(_term_factors(expr)) == { + 3**(x + 2), 4**(x + 2), 3**(x + 3), 4**(x - 1), -1, 4**(x + 1)} + + +#################### tests for transolve and its helpers ############### + +def test_transolve(): + + assert _transolve(3**x, x, S.Reals) == S.EmptySet + assert _transolve(3**x - 9**(x + 5), x, S.Reals) == FiniteSet(-10) + + +def test_issue_21276(): + eq = (2*x*(y - z) - y*erf(y - z) - y + z*erf(y - z) + z)**2 + assert solveset(eq.expand(), y) == FiniteSet(z, z + erfinv(2*x - 1)) + + +# exponential tests +def test_exponential_real(): + from sympy.abc import y + + e1 = 3**(2*x) - 2**(x + 3) + e2 = 4**(5 - 9*x) - 8**(2 - x) + e3 = 2**x + 4**x + e4 = exp(log(5)*x) - 2**x + e5 = exp(x/y)*exp(-z/y) - 2 + e6 = 5**(x/2) - 2**(x/3) + e7 = 4**(x + 1) + 4**(x + 2) + 4**(x - 1) - 3**(x + 2) - 3**(x + 3) + e8 = -9*exp(-2*x + 5) + 4*exp(3*x + 1) + e9 = 2**x + 4**x + 8**x - 84 + e10 = 29*2**(x + 1)*615**(x) - 123*2726**(x) + + assert solveset(e1, x, S.Reals) == FiniteSet( + -3*log(2)/(-2*log(3) + log(2))) + assert solveset(e2, x, S.Reals) == FiniteSet(Rational(4, 15)) + assert solveset(e3, x, S.Reals) == S.EmptySet + assert solveset(e4, x, S.Reals) == FiniteSet(0) + assert solveset(e5, x, S.Reals) == Intersection( + S.Reals, FiniteSet(y*log(2*exp(z/y)))) + assert solveset(e6, x, S.Reals) == FiniteSet(0) + assert solveset(e7, x, S.Reals) == FiniteSet(2) + assert solveset(e8, x, S.Reals) == FiniteSet(-2*log(2)/5 + 2*log(3)/5 + Rational(4, 5)) + assert solveset(e9, x, S.Reals) == FiniteSet(2) + assert solveset(e10,x, S.Reals) == FiniteSet((-log(29) - log(2) + log(123))/(-log(2726) + log(2) + log(615))) + + assert solveset_real(-9*exp(-2*x + 5) + 2**(x + 1), x) == FiniteSet( + -((-5 - 2*log(3) + log(2))/(log(2) + 2))) + assert solveset_real(4**(x/2) - 2**(x/3), x) == FiniteSet(0) + b = sqrt(6)*sqrt(log(2))/sqrt(log(5)) + assert solveset_real(5**(x/2) - 2**(3/x), x) == FiniteSet(-b, b) + + # coverage test + C1, C2 = symbols('C1 C2') + f = Function('f') + assert solveset_real(C1 + C2/x**2 - exp(-f(x)), f(x)) == Intersection( + S.Reals, FiniteSet(-log(C1 + C2/x**2))) + y = symbols('y', positive=True) + assert solveset_real(x**2 - y**2/exp(x), y) == Intersection( + S.Reals, FiniteSet(-sqrt(x**2*exp(x)), sqrt(x**2*exp(x)))) + p = Symbol('p', positive=True) + assert solveset_real((1/p + 1)**(p + 1), p).dummy_eq( + ConditionSet(x, Eq((1 + 1/x)**(x + 1), 0), S.Reals)) + assert solveset(2**x - 4**x + 12, x, S.Reals) == {2} + assert solveset(2**x - 2**(2*x) + 12, x, S.Reals) == {2} + + +@XFAIL +def test_exponential_complex(): + n = Dummy('n') + + assert dumeq(solveset_complex(2**x + 4**x, x),imageset( + Lambda(n, I*(2*n*pi + pi)/log(2)), S.Integers)) + assert solveset_complex(x**z*y**z - 2, z) == FiniteSet( + log(2)/(log(x) + log(y))) + assert dumeq(solveset_complex(4**(x/2) - 2**(x/3), x), imageset( + Lambda(n, 3*n*I*pi/log(2)), S.Integers)) + assert dumeq(solveset(2**x + 32, x), imageset( + Lambda(n, (I*(2*n*pi + pi) + 5*log(2))/log(2)), S.Integers)) + + eq = (2**exp(y**2/x) + 2)/(x**2 + 15) + a = sqrt(x)*sqrt(-log(log(2)) + log(log(2) + 2*n*I*pi)) + assert solveset_complex(eq, y) == FiniteSet(-a, a) + + union1 = imageset(Lambda(n, I*(2*n*pi - pi*Rational(2, 3))/log(2)), S.Integers) + union2 = imageset(Lambda(n, I*(2*n*pi + pi*Rational(2, 3))/log(2)), S.Integers) + assert dumeq(solveset(2**x + 4**x + 8**x, x), Union(union1, union2)) + + eq = 4**(x + 1) + 4**(x + 2) + 4**(x - 1) - 3**(x + 2) - 3**(x + 3) + res = solveset(eq, x) + num = 2*n*I*pi - 4*log(2) + 2*log(3) + den = -2*log(2) + log(3) + ans = imageset(Lambda(n, num/den), S.Integers) + assert dumeq(res, ans) + + +def test_expo_conditionset(): + + f1 = (exp(x) + 1)**x - 2 + f2 = (x + 2)**y*x - 3 + f3 = 2**x - exp(x) - 3 + f4 = log(x) - exp(x) + f5 = 2**x + 3**x - 5**x + + assert solveset(f1, x, S.Reals).dummy_eq(ConditionSet( + x, Eq((exp(x) + 1)**x - 2, 0), S.Reals)) + assert solveset(f2, x, S.Reals).dummy_eq(ConditionSet( + x, Eq(x*(x + 2)**y - 3, 0), S.Reals)) + assert solveset(f3, x, S.Reals).dummy_eq(ConditionSet( + x, Eq(2**x - exp(x) - 3, 0), S.Reals)) + assert solveset(f4, x, S.Reals).dummy_eq(ConditionSet( + x, Eq(-exp(x) + log(x), 0), S.Reals)) + assert solveset(f5, x, S.Reals).dummy_eq(ConditionSet( + x, Eq(2**x + 3**x - 5**x, 0), S.Reals)) + + +def test_exponential_symbols(): + x, y, z = symbols('x y z', positive=True) + xr, zr = symbols('xr, zr', real=True) + + assert solveset(z**x - y, x, S.Reals) == Intersection( + S.Reals, FiniteSet(log(y)/log(z))) + + f1 = 2*x**w - 4*y**w + f2 = (x/y)**w - 2 + sol1 = Intersection({log(2)/(log(x) - log(y))}, S.Reals) + sol2 = Intersection({log(2)/log(x/y)}, S.Reals) + assert solveset(f1, w, S.Reals) == sol1, solveset(f1, w, S.Reals) + assert solveset(f2, w, S.Reals) == sol2, solveset(f2, w, S.Reals) + + assert solveset(x**x, x, Interval.Lopen(0,oo)).dummy_eq( + ConditionSet(w, Eq(w**w, 0), Interval.open(0, oo))) + assert solveset(x**y - 1, y, S.Reals) == FiniteSet(0) + assert solveset(exp(x/y)*exp(-z/y) - 2, y, S.Reals) == \ + Complement(ConditionSet(y, Eq(im(x)/y, 0) & Eq(im(z)/y, 0), \ + Complement(Intersection(FiniteSet((x - z)/log(2)), S.Reals), FiniteSet(0))), FiniteSet(0)) + assert solveset(exp(xr/y)*exp(-zr/y) - 2, y, S.Reals) == \ + Complement(FiniteSet((xr - zr)/log(2)), FiniteSet(0)) + + assert solveset(a**x - b**x, x).dummy_eq(ConditionSet( + w, Ne(a, 0) & Ne(b, 0), FiniteSet(0))) + + +def test_ignore_assumptions(): + # make sure assumptions are ignored + xpos = symbols('x', positive=True) + x = symbols('x') + assert solveset_complex(xpos**2 - 4, xpos + ) == solveset_complex(x**2 - 4, x) + + +@XFAIL +def test_issue_10864(): + assert solveset(x**(y*z) - x, x, S.Reals) == FiniteSet(1) + + +@XFAIL +def test_solve_only_exp_2(): + assert solveset_real(sqrt(exp(x)) + sqrt(exp(-x)) - 4, x) == \ + FiniteSet(2*log(-sqrt(3) + 2), 2*log(sqrt(3) + 2)) + + +def test_is_exponential(): + assert _is_exponential(y, x) is False + assert _is_exponential(3**x - 2, x) is True + assert _is_exponential(5**x - 7**(2 - x), x) is True + assert _is_exponential(sin(2**x) - 4*x, x) is False + assert _is_exponential(x**y - z, y) is True + assert _is_exponential(x**y - z, x) is False + assert _is_exponential(2**x + 4**x - 1, x) is True + assert _is_exponential(x**(y*z) - x, x) is False + assert _is_exponential(x**(2*x) - 3**x, x) is False + assert _is_exponential(x**y - y*z, y) is False + assert _is_exponential(x**y - x*z, y) is True + + +def test_solve_exponential(): + assert _solve_exponential(3**(2*x) - 2**(x + 3), 0, x, S.Reals) == \ + FiniteSet(-3*log(2)/(-2*log(3) + log(2))) + assert _solve_exponential(2**y + 4**y, 1, y, S.Reals) == \ + FiniteSet(log(Rational(-1, 2) + sqrt(5)/2)/log(2)) + assert _solve_exponential(2**y + 4**y, 0, y, S.Reals) == \ + S.EmptySet + assert _solve_exponential(2**x + 3**x - 5**x, 0, x, S.Reals) == \ + ConditionSet(x, Eq(2**x + 3**x - 5**x, 0), S.Reals) + +# end of exponential tests + + +# logarithmic tests +def test_logarithmic(): + assert solveset_real(log(x - 3) + log(x + 3), x) == FiniteSet( + -sqrt(10), sqrt(10)) + assert solveset_real(log(x + 1) - log(2*x - 1), x) == FiniteSet(2) + assert solveset_real(log(x + 3) + log(1 + 3/x) - 3, x) == FiniteSet( + -3 + sqrt(-12 + exp(3))*exp(Rational(3, 2))/2 + exp(3)/2, + -sqrt(-12 + exp(3))*exp(Rational(3, 2))/2 - 3 + exp(3)/2) + + eq = z - log(x) + log(y/(x*(-1 + y**2/x**2))) + assert solveset_real(eq, x) == \ + Intersection(S.Reals, FiniteSet(-sqrt(y**2 - y*exp(z)), + sqrt(y**2 - y*exp(z)))) - \ + Intersection(S.Reals, FiniteSet(-sqrt(y**2), sqrt(y**2))) + assert solveset_real( + log(3*x) - log(-x + 1) - log(4*x + 1), x) == FiniteSet(Rational(-1, 2), S.Half) + assert solveset(log(x**y) - y*log(x), x, S.Reals) == S.Reals + +@XFAIL +def test_uselogcombine_2(): + eq = log(exp(2*x) + 1) + log(-tanh(x) + 1) - log(2) + assert solveset_real(eq, x) is S.EmptySet + eq = log(8*x) - log(sqrt(x) + 1) - 2 + assert solveset_real(eq, x) is S.EmptySet + + +def test_is_logarithmic(): + assert _is_logarithmic(y, x) is False + assert _is_logarithmic(log(x), x) is True + assert _is_logarithmic(log(x) - 3, x) is True + assert _is_logarithmic(log(x)*log(y), x) is True + assert _is_logarithmic(log(x)**2, x) is False + assert _is_logarithmic(log(x - 3) + log(x + 3), x) is True + assert _is_logarithmic(log(x**y) - y*log(x), x) is True + assert _is_logarithmic(sin(log(x)), x) is False + assert _is_logarithmic(x + y, x) is False + assert _is_logarithmic(log(3*x) - log(1 - x) + 4, x) is True + assert _is_logarithmic(log(x) + log(y) + x, x) is False + assert _is_logarithmic(log(log(x - 3)) + log(x - 3), x) is True + assert _is_logarithmic(log(log(3) + x) + log(x), x) is True + assert _is_logarithmic(log(x)*(y + 3) + log(x), y) is False + + +def test_solve_logarithm(): + y = Symbol('y') + assert _solve_logarithm(log(x**y) - y*log(x), 0, x, S.Reals) == S.Reals + y = Symbol('y', positive=True) + assert _solve_logarithm(log(x)*log(y), 0, x, S.Reals) == FiniteSet(1) + +# end of logarithmic tests + + +# lambert tests +def test_is_lambert(): + a, b, c = symbols('a,b,c') + assert _is_lambert(x**2, x) is False + assert _is_lambert(a**x**2+b*x+c, x) is True + assert _is_lambert(E**2, x) is False + assert _is_lambert(x*E**2, x) is False + assert _is_lambert(3*log(x) - x*log(3), x) is True + assert _is_lambert(log(log(x - 3)) + log(x-3), x) is True + assert _is_lambert(5*x - 1 + 3*exp(2 - 7*x), x) is True + assert _is_lambert((a/x + exp(x/2)).diff(x, 2), x) is True + assert _is_lambert((x**2 - 2*x + 1).subs(x, (log(x) + 3*x)**2 - 1), x) is True + assert _is_lambert(x*sinh(x) - 1, x) is True + assert _is_lambert(x*cos(x) - 5, x) is True + assert _is_lambert(tanh(x) - 5*x, x) is True + assert _is_lambert(cosh(x) - sinh(x), x) is False + +# end of lambert tests + + +def test_linear_coeffs(): + from sympy.solvers.solveset import linear_coeffs + assert linear_coeffs(0, x) == [0, 0] + assert all(i is S.Zero for i in linear_coeffs(0, x)) + assert linear_coeffs(x + 2*y + 3, x, y) == [1, 2, 3] + assert linear_coeffs(x + 2*y + 3, y, x) == [2, 1, 3] + assert linear_coeffs(x + 2*x**2 + 3, x, x**2) == [1, 2, 3] + raises(ValueError, lambda: + linear_coeffs(x + 2*x**2 + x**3, x, x**2)) + raises(ValueError, lambda: + linear_coeffs(1/x*(x - 1) + 1/x, x)) + raises(ValueError, lambda: + linear_coeffs(x, x, x)) + assert linear_coeffs(a*(x + y), x, y) == [a, a, 0] + assert linear_coeffs(1.0, x, y) == [0, 0, 1.0] + # don't include coefficients of 0 + assert linear_coeffs(Eq(x, x + y), x, y, dict=True) == {y: -1} + assert linear_coeffs(0, x, y, dict=True) == {} + + +def test_is_modular(): + assert _is_modular(y, x) is False + assert _is_modular(Mod(x, 3) - 1, x) is True + assert _is_modular(Mod(x**3 - 3*x**2 - x + 1, 3) - 1, x) is True + assert _is_modular(Mod(exp(x + y), 3) - 2, x) is True + assert _is_modular(Mod(exp(x + y), 3) - log(x), x) is True + assert _is_modular(Mod(x, 3) - 1, y) is False + assert _is_modular(Mod(x, 3)**2 - 5, x) is False + assert _is_modular(Mod(x, 3)**2 - y, x) is False + assert _is_modular(exp(Mod(x, 3)) - 1, x) is False + assert _is_modular(Mod(3, y) - 1, y) is False + + +def test_invert_modular(): + n = Dummy('n', integer=True) + from sympy.solvers.solveset import _invert_modular as invert_modular + + # no solutions + assert invert_modular(Mod(x, 12), S(1)/2, n, x) == (x, S.EmptySet) + # non invertible cases + assert invert_modular(Mod(sin(x), 7), S(5), n, x) == (Mod(sin(x), 7), 5) + assert invert_modular(Mod(exp(x), 7), S(5), n, x) == (Mod(exp(x), 7), 5) + assert invert_modular(Mod(log(x), 7), S(5), n, x) == (Mod(log(x), 7), 5) + # a is symbol + assert dumeq(invert_modular(Mod(x, 7), S(5), n, x), + (x, ImageSet(Lambda(n, 7*n + 5), S.Integers))) + # a.is_Add + assert dumeq(invert_modular(Mod(x + 8, 7), S(5), n, x), + (x, ImageSet(Lambda(n, 7*n + 4), S.Integers))) + assert invert_modular(Mod(x**2 + x, 7), S(5), n, x) == \ + (Mod(x**2 + x, 7), 5) + # a.is_Mul + assert dumeq(invert_modular(Mod(3*x, 7), S(5), n, x), + (x, ImageSet(Lambda(n, 7*n + 4), S.Integers))) + assert invert_modular(Mod((x + 1)*(x + 2), 7), S(5), n, x) == \ + (Mod((x + 1)*(x + 2), 7), 5) + # a.is_Pow + assert invert_modular(Mod(x**4, 7), S(5), n, x) == \ + (x, S.EmptySet) + assert dumeq(invert_modular(Mod(3**x, 4), S(3), n, x), + (x, ImageSet(Lambda(n, 2*n + 1), S.Naturals0))) + assert dumeq(invert_modular(Mod(2**(x**2 + x + 1), 7), S(2), n, x), + (x**2 + x + 1, ImageSet(Lambda(n, 3*n + 1), S.Naturals0))) + assert invert_modular(Mod(sin(x)**4, 7), S(5), n, x) == (x, S.EmptySet) + + +def test_solve_modular(): + n = Dummy('n', integer=True) + # if rhs has symbol (need to be implemented in future). + assert solveset(Mod(x, 4) - x, x, S.Integers + ).dummy_eq( + ConditionSet(x, Eq(-x + Mod(x, 4), 0), + S.Integers)) + # when _invert_modular fails to invert + assert solveset(3 - Mod(sin(x), 7), x, S.Integers + ).dummy_eq( + ConditionSet(x, Eq(Mod(sin(x), 7) - 3, 0), S.Integers)) + assert solveset(3 - Mod(log(x), 7), x, S.Integers + ).dummy_eq( + ConditionSet(x, Eq(Mod(log(x), 7) - 3, 0), S.Integers)) + assert solveset(3 - Mod(exp(x), 7), x, S.Integers + ).dummy_eq(ConditionSet(x, Eq(Mod(exp(x), 7) - 3, 0), + S.Integers)) + # EmptySet solution definitely + assert solveset(7 - Mod(x, 5), x, S.Integers) is S.EmptySet + assert solveset(5 - Mod(x, 5), x, S.Integers) is S.EmptySet + # Negative m + assert dumeq(solveset(2 + Mod(x, -3), x, S.Integers), + ImageSet(Lambda(n, -3*n - 2), S.Integers)) + assert solveset(4 + Mod(x, -3), x, S.Integers) is S.EmptySet + # linear expression in Mod + assert dumeq(solveset(3 - Mod(x, 5), x, S.Integers), + ImageSet(Lambda(n, 5*n + 3), S.Integers)) + assert dumeq(solveset(3 - Mod(5*x - 8, 7), x, S.Integers), + ImageSet(Lambda(n, 7*n + 5), S.Integers)) + assert dumeq(solveset(3 - Mod(5*x, 7), x, S.Integers), + ImageSet(Lambda(n, 7*n + 2), S.Integers)) + # higher degree expression in Mod + assert dumeq(solveset(Mod(x**2, 160) - 9, x, S.Integers), + Union(ImageSet(Lambda(n, 160*n + 3), S.Integers), + ImageSet(Lambda(n, 160*n + 13), S.Integers), + ImageSet(Lambda(n, 160*n + 67), S.Integers), + ImageSet(Lambda(n, 160*n + 77), S.Integers), + ImageSet(Lambda(n, 160*n + 83), S.Integers), + ImageSet(Lambda(n, 160*n + 93), S.Integers), + ImageSet(Lambda(n, 160*n + 147), S.Integers), + ImageSet(Lambda(n, 160*n + 157), S.Integers))) + assert solveset(3 - Mod(x**4, 7), x, S.Integers) is S.EmptySet + assert dumeq(solveset(Mod(x**4, 17) - 13, x, S.Integers), + Union(ImageSet(Lambda(n, 17*n + 3), S.Integers), + ImageSet(Lambda(n, 17*n + 5), S.Integers), + ImageSet(Lambda(n, 17*n + 12), S.Integers), + ImageSet(Lambda(n, 17*n + 14), S.Integers))) + # a.is_Pow tests + assert dumeq(solveset(Mod(7**x, 41) - 15, x, S.Integers), + ImageSet(Lambda(n, 40*n + 3), S.Naturals0)) + assert dumeq(solveset(Mod(12**x, 21) - 18, x, S.Integers), + ImageSet(Lambda(n, 6*n + 2), S.Naturals0)) + assert dumeq(solveset(Mod(3**x, 4) - 3, x, S.Integers), + ImageSet(Lambda(n, 2*n + 1), S.Naturals0)) + assert dumeq(solveset(Mod(2**x, 7) - 2 , x, S.Integers), + ImageSet(Lambda(n, 3*n + 1), S.Naturals0)) + assert dumeq(solveset(Mod(3**(3**x), 4) - 3, x, S.Integers), + Intersection(ImageSet(Lambda(n, Intersection({log(2*n + 1)/log(3)}, + S.Integers)), S.Naturals0), S.Integers)) + # Implemented for m without primitive root + assert solveset(Mod(x**3, 7) - 2, x, S.Integers) is S.EmptySet + assert dumeq(solveset(Mod(x**3, 8) - 1, x, S.Integers), + ImageSet(Lambda(n, 8*n + 1), S.Integers)) + assert dumeq(solveset(Mod(x**4, 9) - 4, x, S.Integers), + Union(ImageSet(Lambda(n, 9*n + 4), S.Integers), + ImageSet(Lambda(n, 9*n + 5), S.Integers))) + # domain intersection + assert dumeq(solveset(3 - Mod(5*x - 8, 7), x, S.Naturals0), + Intersection(ImageSet(Lambda(n, 7*n + 5), S.Integers), S.Naturals0)) + # Complex args + assert solveset(Mod(x, 3) - I, x, S.Integers) == \ + S.EmptySet + assert solveset(Mod(I*x, 3) - 2, x, S.Integers + ).dummy_eq( + ConditionSet(x, Eq(Mod(I*x, 3) - 2, 0), S.Integers)) + assert solveset(Mod(I + x, 3) - 2, x, S.Integers + ).dummy_eq( + ConditionSet(x, Eq(Mod(x + I, 3) - 2, 0), S.Integers)) + + # issue 17373 (https://github.com/sympy/sympy/issues/17373) + assert dumeq(solveset(Mod(x**4, 14) - 11, x, S.Integers), + Union(ImageSet(Lambda(n, 14*n + 3), S.Integers), + ImageSet(Lambda(n, 14*n + 11), S.Integers))) + assert dumeq(solveset(Mod(x**31, 74) - 43, x, S.Integers), + ImageSet(Lambda(n, 74*n + 31), S.Integers)) + + # issue 13178 + n = symbols('n', integer=True) + a = 742938285 + b = 1898888478 + m = 2**31 - 1 + c = 20170816 + assert dumeq(solveset(c - Mod(a**n*b, m), n, S.Integers), + ImageSet(Lambda(n, 2147483646*n + 100), S.Naturals0)) + assert dumeq(solveset(c - Mod(a**n*b, m), n, S.Naturals0), + Intersection(ImageSet(Lambda(n, 2147483646*n + 100), S.Naturals0), + S.Naturals0)) + assert dumeq(solveset(c - Mod(a**(2*n)*b, m), n, S.Integers), + Intersection(ImageSet(Lambda(n, 1073741823*n + 50), S.Naturals0), + S.Integers)) + assert solveset(c - Mod(a**(2*n + 7)*b, m), n, S.Integers) is S.EmptySet + assert dumeq(solveset(c - Mod(a**(n - 4)*b, m), n, S.Integers), + Intersection(ImageSet(Lambda(n, 2147483646*n + 104), S.Naturals0), + S.Integers)) + +# end of modular tests + +def test_issue_17276(): + assert nonlinsolve([Eq(x, 5**(S(1)/5)), Eq(x*y, 25*sqrt(5))], x, y) == \ + FiniteSet((5**(S(1)/5), 25*5**(S(3)/10))) + + +def test_issue_10426(): + x = Dummy('x') + a = Symbol('a') + n = Dummy('n') + assert (solveset(sin(x + a) - sin(x), a)).dummy_eq(Dummy('x')) == (Union( + ImageSet(Lambda(n, 2*n*pi), S.Integers), + Intersection(S.Complexes, ImageSet(Lambda(n, -I*(I*(2*n*pi + arg(-exp(-2*I*x))) + 2*im(x))), + S.Integers)))).dummy_eq(Dummy('x,n')) + + +def test_solveset_conjugate(): + """Test solveset for simple conjugate functions""" + assert solveset(conjugate(x) -3 + I) == FiniteSet(3 + I) + + +def test_issue_18208(): + variables = symbols('x0:16') + symbols('y0:12') + x0, x1, x2, x3, x4, x5, x6, x7, x8, x9, x10, x11, x12, x13, x14, x15,\ + y0, y1, y2, y3, y4, y5, y6, y7, y8, y9, y10, y11 = variables + + eqs = [x0 + x1 + x2 + x3 - 51, + x0 + x1 + x4 + x5 - 46, + x2 + x3 + x6 + x7 - 39, + x0 + x3 + x4 + x7 - 50, + x1 + x2 + x5 + x6 - 35, + x4 + x5 + x6 + x7 - 34, + x4 + x5 + x8 + x9 - 46, + x10 + x11 + x6 + x7 - 23, + x11 + x4 + x7 + x8 - 25, + x10 + x5 + x6 + x9 - 44, + x10 + x11 + x8 + x9 - 35, + x12 + x13 + x8 + x9 - 35, + x10 + x11 + x14 + x15 - 29, + x11 + x12 + x15 + x8 - 35, + x10 + x13 + x14 + x9 - 29, + x12 + x13 + x14 + x15 - 29, + y0 + y1 + y2 + y3 - 55, + y0 + y1 + y4 + y5 - 53, + y2 + y3 + y6 + y7 - 56, + y0 + y3 + y4 + y7 - 57, + y1 + y2 + y5 + y6 - 52, + y4 + y5 + y6 + y7 - 54, + y4 + y5 + y8 + y9 - 48, + y10 + y11 + y6 + y7 - 60, + y11 + y4 + y7 + y8 - 51, + y10 + y5 + y6 + y9 - 57, + y10 + y11 + y8 + y9 - 54, + x10 - 2, + x11 - 5, + x12 - 1, + x13 - 6, + x14 - 1, + x15 - 21, + y0 - 12, + y1 - 20] + + expected = [38 - x3, x3 - 10, 23 - x3, x3, 12 - x7, x7 + 6, 16 - x7, x7, + 8, 20, 2, 5, 1, 6, 1, 21, 12, 20, -y11 + y9 + 2, y11 - y9 + 21, + -y11 - y7 + y9 + 24, y11 + y7 - y9 - 3, 33 - y7, y7, 27 - y9, y9, + 27 - y11, y11] + + A, b = linear_eq_to_matrix(eqs, variables) + + # solve + solve_expected = {v:eq for v, eq in zip(variables, expected) if v != eq} + + assert solve(eqs, variables) == solve_expected + + # linsolve + linsolve_expected = FiniteSet(Tuple(*expected)) + + assert linsolve(eqs, variables) == linsolve_expected + assert linsolve((A, b), variables) == linsolve_expected + + # gauss_jordan_solve + gj_solve, new_vars = A.gauss_jordan_solve(b) + gj_solve = list(gj_solve) + + gj_expected = linsolve_expected.subs(zip([x3, x7, y7, y9, y11], new_vars)) + + assert FiniteSet(Tuple(*gj_solve)) == gj_expected + + # nonlinsolve + # The solution set of nonlinsolve is currently equivalent to linsolve and is + # also correct. However, we would prefer to use the same symbols as parameters + # for the solution to the underdetermined system in all cases if possible. + # We want a solution that is not just equivalent but also given in the same form. + # This test may be changed should nonlinsolve be modified in this way. + + nonlinsolve_expected = FiniteSet((38 - x3, x3 - 10, 23 - x3, x3, 12 - x7, x7 + 6, + 16 - x7, x7, 8, 20, 2, 5, 1, 6, 1, 21, 12, 20, + -y5 + y7 - 1, y5 - y7 + 24, 21 - y5, y5, 33 - y7, + y7, 27 - y9, y9, -y5 + y7 - y9 + 24, y5 - y7 + y9 + 3)) + + assert nonlinsolve(eqs, variables) == nonlinsolve_expected + + +def test_substitution_with_infeasible_solution(): + a00, a01, a10, a11, l0, l1, l2, l3, m0, m1, m2, m3, m4, m5, m6, m7, c00, c01, c10, c11, p00, p01, p10, p11 = symbols( + 'a00, a01, a10, a11, l0, l1, l2, l3, m0, m1, m2, m3, m4, m5, m6, m7, c00, c01, c10, c11, p00, p01, p10, p11' + ) + solvefor = [p00, p01, p10, p11, c00, c01, c10, c11, m0, m1, m3, l0, l1, l2, l3] + system = [ + -l0 * c00 - l1 * c01 + m0 + c00 + c01, + -l0 * c10 - l1 * c11 + m1, + -l2 * c00 - l3 * c01 + c00 + c01, + -l2 * c10 - l3 * c11 + m3, + -l0 * p00 - l2 * p10 + p00 + p10, + -l1 * p00 - l3 * p10 + p00 + p10, + -l0 * p01 - l2 * p11, + -l1 * p01 - l3 * p11, + -a00 + c00 * p00 + c10 * p01, + -a01 + c01 * p00 + c11 * p01, + -a10 + c00 * p10 + c10 * p11, + -a11 + c01 * p10 + c11 * p11, + -m0 * p00, + -m1 * p01, + -m2 * p10, + -m3 * p11, + -m4 * c00, + -m5 * c01, + -m6 * c10, + -m7 * c11, + m2, + m4, + m5, + m6, + m7 + ] + sol = FiniteSet( + (0, Complement(FiniteSet(p01), FiniteSet(0)), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, l2, l3), + (p00, Complement(FiniteSet(p01), FiniteSet(0)), 0, p11, 0, 0, 0, 0, 0, 0, 0, 1, 1, -p01/p11, -p01/p11), + (0, Complement(FiniteSet(p01), FiniteSet(0)), 0, p11, 0, 0, 0, 0, 0, 0, 0, 1, -l3*p11/p01, -p01/p11, l3), + (0, Complement(FiniteSet(p01), FiniteSet(0)), 0, p11, 0, 0, 0, 0, 0, 0, 0, -l2*p11/p01, -l3*p11/p01, l2, l3), + ) + assert sol != nonlinsolve(system, solvefor) + + +def test_issue_20097(): + assert solveset(1/sqrt(x)) is S.EmptySet + + +def test_issue_15350(): + assert solveset(diff(sqrt(1/x+x))) == FiniteSet(-1, 1) + + +def test_issue_18359(): + c1 = Piecewise((0, x < 0), (Min(1, x)/2 - Min(2, x)/2 + Min(3, x)/2, True)) + c2 = Piecewise((Piecewise((0, x < 0), (Min(1, x)/2 - Min(2, x)/2 + Min(3, x)/2, True)), x >= 0), (0, True)) + correct_result = Interval(1, 2) + result1 = solveset(c1 - Rational(1, 2), x, Interval(0, 3)) + result2 = solveset(c2 - Rational(1, 2), x, Interval(0, 3)) + assert result1 == correct_result + assert result2 == correct_result + + +def test_issue_17604(): + lhs = -2**(3*x/11)*exp(x/11) + pi**(x/11) + assert _is_exponential(lhs, x) + assert _solve_exponential(lhs, 0, x, S.Complexes) == FiniteSet(0) + + +def test_issue_17580(): + assert solveset(1/(1 - x**3)**2, x, S.Reals) is S.EmptySet + + +def test_issue_17566_actual(): + sys = [2**x + 2**y - 3, 4**x + 9**y - 5] + # Not clear this is the correct result, but at least no recursion error + assert nonlinsolve(sys, x, y) == FiniteSet((log(3 - 2**y)/log(2), y)) + + +def test_issue_17565(): + eq = Ge(2*(x - 2)**2/(3*(x + 1)**(Integer(1)/3)) + 2*(x - 2)*(x + 1)**(Integer(2)/3), 0) + res = Union(Interval.Lopen(-1, -Rational(1, 4)), Interval(2, oo)) + assert solveset(eq, x, S.Reals) == res + + +def test_issue_15024(): + function = (x + 5)/sqrt(-x**2 - 10*x) + assert solveset(function, x, S.Reals) == FiniteSet(Integer(-5)) + + +def test_issue_16877(): + assert dumeq(nonlinsolve([x - 1, sin(y)], x, y), + FiniteSet((1, ImageSet(Lambda(n, 2*n*pi), S.Integers)), + (1, ImageSet(Lambda(n, 2*n*pi + pi), S.Integers)))) + # Even better if (1, ImageSet(Lambda(n, n*pi), S.Integers)) is obtained + + +def test_issue_16876(): + assert dumeq(nonlinsolve([sin(x), 2*x - 4*y], x, y), + FiniteSet((ImageSet(Lambda(n, 2*n*pi), S.Integers), + ImageSet(Lambda(n, n*pi), S.Integers)), + (ImageSet(Lambda(n, 2*n*pi + pi), S.Integers), + ImageSet(Lambda(n, n*pi + pi/2), S.Integers)))) + # Even better if (ImageSet(Lambda(n, n*pi), S.Integers), + # ImageSet(Lambda(n, n*pi/2), S.Integers)) is obtained + +def test_issue_21236(): + x, z = symbols("x z") + y = symbols('y', rational=True) + assert solveset(x**y - z, x, S.Reals) == ConditionSet(x, Eq(x**y - z, 0), S.Reals) + e1, e2 = symbols('e1 e2', even=True) + y = e1/e2 # don't know if num or den will be odd and the other even + assert solveset(x**y - z, x, S.Reals) == ConditionSet(x, Eq(x**y - z, 0), S.Reals) + + +def test_issue_21908(): + assert nonlinsolve([(x**2 + 2*x - y**2)*exp(x), -2*y*exp(x)], x, y + ) == {(-2, 0), (0, 0)} + + +def test_issue_19144(): + # test case 1 + expr1 = [x + y - 1, y**2 + 1] + eq1 = [Eq(i, 0) for i in expr1] + soln1 = {(1 - I, I), (1 + I, -I)} + soln_expr1 = nonlinsolve(expr1, [x, y]) + soln_eq1 = nonlinsolve(eq1, [x, y]) + assert soln_eq1 == soln_expr1 == soln1 + # test case 2 - with denoms + expr2 = [x/y - 1, y**2 + 1] + eq2 = [Eq(i, 0) for i in expr2] + soln2 = {(-I, -I), (I, I)} + soln_expr2 = nonlinsolve(expr2, [x, y]) + soln_eq2 = nonlinsolve(eq2, [x, y]) + assert soln_eq2 == soln_expr2 == soln2 + # denominators that cancel in expression + assert nonlinsolve([Eq(x + 1/x, 1/x)], [x]) == FiniteSet((S.EmptySet,)) + + +def test_issue_22413(): + res = nonlinsolve((4*y*(2*x + 2*exp(y) + 1)*exp(2*x), + 4*x*exp(2*x) + 4*y*exp(2*x + y) + 4*exp(2*x + y) + 1), + x, y) + # First solution is not correct, but the issue was an exception + sols = FiniteSet((x, S.Zero), (-exp(y) - S.Half, y)) + assert res == sols + + +def test_issue_23318(): + eqs_eq = [ + Eq(53.5780461486929, x * log(y / (5.0 - y) + 1) / y), + Eq(x, 0.0015 * z), + Eq(0.0015, 7845.32 * y / z), + ] + eqs_expr = [eq.lhs - eq.rhs for eq in eqs_eq] + + sol = {(266.97755814852, 0.0340301680681629, 177985.03876568)} + + assert_close_nl(nonlinsolve(eqs_eq, [x, y, z]), sol) + assert_close_nl(nonlinsolve(eqs_expr, [x, y, z]), sol) + + logterm = log(1.91196789933362e-7*z/(5.0 - 1.91196789933362e-7*z) + 1) + eq = -0.0015*z*logterm + 1.02439504345316e-5*z + assert_close_ss(solveset(eq, z), {0, 177985.038765679}) + + +def test_issue_19814(): + assert nonlinsolve([ 2**m - 2**(2*n), 4*2**m - 2**(4*n)], m, n + ) == FiniteSet((log(2**(2*n))/log(2), S.Complexes)) + + +def test_issue_22058(): + sol = solveset(-sqrt(t)*x**2 + 2*x + sqrt(t), x, S.Reals) + # doesn't fail (and following numerical check) + assert sol.xreplace({t: 1}) == {1 - sqrt(2), 1 + sqrt(2)}, sol.xreplace({t: 1}) + + +def test_issue_11184(): + assert solveset(20*sqrt(y**2 + (sqrt(-(y - 10)*(y + 10)) + 10)**2) - 60, y, S.Reals) is S.EmptySet + + +def test_issue_21890(): + e = S(2)/3 + assert nonlinsolve([4*x**3*y**4 - 2*y, 4*x**4*y**3 - 2*x], x, y) == { + (2**e/(2*y), y), ((-2**e/4 - 2**e*sqrt(3)*I/4)/y, y), + ((-2**e/4 + 2**e*sqrt(3)*I/4)/y, y)} + assert nonlinsolve([(1 - 4*x**2)*exp(-2*x**2 - 2*y**2), + -4*x*y*exp(-2*x**2)*exp(-2*y**2)], x, y) == {(-S(1)/2, 0), (S(1)/2, 0)} + rx, ry = symbols('x y', real=True) + sol = nonlinsolve([4*rx**3*ry**4 - 2*ry, 4*rx**4*ry**3 - 2*rx], rx, ry) + ans = {(2**(S(2)/3)/(2*ry), ry), + ((-2**(S(2)/3)/4 - 2**(S(2)/3)*sqrt(3)*I/4)/ry, ry), + ((-2**(S(2)/3)/4 + 2**(S(2)/3)*sqrt(3)*I/4)/ry, ry)} + assert sol == ans + + +def test_issue_22628(): + assert nonlinsolve([h - 1, k - 1, f - 2, f - 4, -2*k], h, k, f) == S.EmptySet + assert nonlinsolve([x**3 - 1, x + y, x**2 - 4], [x, y]) == S.EmptySet + + +def test_issue_25781(): + assert solve(sqrt(x/2) - x) == [0, S.Half] + + +def test_issue_26077(): + _n = Symbol('_n') + function = x*cot(5*x) + critical_points = stationary_points(function, x, S.Reals) + excluded_points = Union( + ImageSet(Lambda(_n, 2*_n*pi/5), S.Integers), + ImageSet(Lambda(_n, 2*_n*pi/5 + pi/5), S.Integers) + ) + solution = ConditionSet(x, + Eq(x*(-5*cot(5*x)**2 - 5) + cot(5*x), 0), + Complement(S.Reals, excluded_points) + ) + assert solution.as_dummy() == critical_points.as_dummy() diff --git a/lib/python3.10/site-packages/sympy/stats/__pycache__/__init__.cpython-310.pyc b/lib/python3.10/site-packages/sympy/stats/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c06198ef43e870b5fa955e579a2f6278381c9795 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/stats/__pycache__/__init__.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/stats/__pycache__/drv.cpython-310.pyc b/lib/python3.10/site-packages/sympy/stats/__pycache__/drv.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..aa374f4cd9acda6e9083007082c3c4f882b431f5 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/stats/__pycache__/drv.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/stats/__pycache__/error_prop.cpython-310.pyc b/lib/python3.10/site-packages/sympy/stats/__pycache__/error_prop.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..606a7ec445a0ea320229b0532d6750d66f99b187 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/stats/__pycache__/error_prop.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/stats/__pycache__/frv.cpython-310.pyc b/lib/python3.10/site-packages/sympy/stats/__pycache__/frv.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6277a59712f84889cefdbe62e73f51ffeb5fba88 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/stats/__pycache__/frv.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/stats/__pycache__/frv_types.cpython-310.pyc b/lib/python3.10/site-packages/sympy/stats/__pycache__/frv_types.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..826b3b499f2a29d8e411e70f72d0eaf4c6709be4 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/stats/__pycache__/frv_types.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/stats/__pycache__/joint_rv.cpython-310.pyc b/lib/python3.10/site-packages/sympy/stats/__pycache__/joint_rv.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9b388aacba83e09848388a652cd45d84a1c1e8aa Binary files /dev/null and b/lib/python3.10/site-packages/sympy/stats/__pycache__/joint_rv.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/stats/__pycache__/joint_rv_types.cpython-310.pyc b/lib/python3.10/site-packages/sympy/stats/__pycache__/joint_rv_types.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fa7406152f85a770568b08188bf72b59cfb61446 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/stats/__pycache__/joint_rv_types.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/stats/__pycache__/matrix_distributions.cpython-310.pyc b/lib/python3.10/site-packages/sympy/stats/__pycache__/matrix_distributions.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d20cf7c458aaf29acd5249cd9e7b7c16b285d09f Binary files /dev/null and b/lib/python3.10/site-packages/sympy/stats/__pycache__/matrix_distributions.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/stats/__pycache__/random_matrix.cpython-310.pyc b/lib/python3.10/site-packages/sympy/stats/__pycache__/random_matrix.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d2f862fdd8177817bbb0ca5cb150901244d2ac05 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/stats/__pycache__/random_matrix.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/stats/__pycache__/random_matrix_models.cpython-310.pyc b/lib/python3.10/site-packages/sympy/stats/__pycache__/random_matrix_models.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..169761314647ccda0143fe4aaca3776ebeeb56d2 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/stats/__pycache__/random_matrix_models.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/stats/__pycache__/rv.cpython-310.pyc b/lib/python3.10/site-packages/sympy/stats/__pycache__/rv.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3f362496bf694344d47b0f3f15b761dc814aacc9 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/stats/__pycache__/rv.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/stats/__pycache__/rv_interface.cpython-310.pyc b/lib/python3.10/site-packages/sympy/stats/__pycache__/rv_interface.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c10c07bc977a737836cf9909264ff97458cbd3f0 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/stats/__pycache__/rv_interface.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/stats/__pycache__/stochastic_process.cpython-310.pyc b/lib/python3.10/site-packages/sympy/stats/__pycache__/stochastic_process.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..48bbceec189400f26e2719b11d1fffdb4e66c718 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/stats/__pycache__/stochastic_process.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/stats/__pycache__/stochastic_process_types.cpython-310.pyc b/lib/python3.10/site-packages/sympy/stats/__pycache__/stochastic_process_types.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e034d72baf2801f7ada06c8498a26821f917cd8c Binary files /dev/null and b/lib/python3.10/site-packages/sympy/stats/__pycache__/stochastic_process_types.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/stats/__pycache__/symbolic_multivariate_probability.cpython-310.pyc b/lib/python3.10/site-packages/sympy/stats/__pycache__/symbolic_multivariate_probability.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..70c565b11a91e090e67ce1c530b1d0d558009b04 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/stats/__pycache__/symbolic_multivariate_probability.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/stats/__pycache__/symbolic_probability.cpython-310.pyc b/lib/python3.10/site-packages/sympy/stats/__pycache__/symbolic_probability.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c0ee3292cf10acbc658142858f137d64713fb545 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/stats/__pycache__/symbolic_probability.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/stats/sampling/__pycache__/__init__.cpython-310.pyc b/lib/python3.10/site-packages/sympy/stats/sampling/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6def329e2a72ceade6bfc36a51a9e878d13186fc Binary files /dev/null and b/lib/python3.10/site-packages/sympy/stats/sampling/__pycache__/__init__.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/stats/sampling/__pycache__/sample_numpy.cpython-310.pyc b/lib/python3.10/site-packages/sympy/stats/sampling/__pycache__/sample_numpy.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a85f8b90b8479279dd45e9ce06e0c80bd675c98b Binary files /dev/null and b/lib/python3.10/site-packages/sympy/stats/sampling/__pycache__/sample_numpy.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/stats/sampling/__pycache__/sample_pymc.cpython-310.pyc b/lib/python3.10/site-packages/sympy/stats/sampling/__pycache__/sample_pymc.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f074c0a1102c72f743536e4ffd1e59cb23478a51 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/stats/sampling/__pycache__/sample_pymc.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/stats/sampling/__pycache__/sample_scipy.cpython-310.pyc b/lib/python3.10/site-packages/sympy/stats/sampling/__pycache__/sample_scipy.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ca7e189fa04a63db335b1d830d4bb8e0f051b43c Binary files /dev/null and b/lib/python3.10/site-packages/sympy/stats/sampling/__pycache__/sample_scipy.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/stats/sampling/tests/__init__.py b/lib/python3.10/site-packages/sympy/stats/sampling/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/lib/python3.10/site-packages/sympy/stats/sampling/tests/__pycache__/__init__.cpython-310.pyc b/lib/python3.10/site-packages/sympy/stats/sampling/tests/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f88755923e24fd68be474e017c681bd0871998fd Binary files /dev/null and b/lib/python3.10/site-packages/sympy/stats/sampling/tests/__pycache__/__init__.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/stats/sampling/tests/__pycache__/test_sample_continuous_rv.cpython-310.pyc b/lib/python3.10/site-packages/sympy/stats/sampling/tests/__pycache__/test_sample_continuous_rv.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..757d416ad346ec91d5b2c4e96e15d3fd9cc57546 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/stats/sampling/tests/__pycache__/test_sample_continuous_rv.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/stats/sampling/tests/__pycache__/test_sample_discrete_rv.cpython-310.pyc b/lib/python3.10/site-packages/sympy/stats/sampling/tests/__pycache__/test_sample_discrete_rv.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d3c265bea4d4ae2fd438b79afa5a4684cb639766 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/stats/sampling/tests/__pycache__/test_sample_discrete_rv.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/stats/sampling/tests/__pycache__/test_sample_finite_rv.cpython-310.pyc b/lib/python3.10/site-packages/sympy/stats/sampling/tests/__pycache__/test_sample_finite_rv.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..24aad77f3abdb66985f38de334d124f4bb0494c0 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/stats/sampling/tests/__pycache__/test_sample_finite_rv.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/stats/sampling/tests/test_sample_continuous_rv.py b/lib/python3.10/site-packages/sympy/stats/sampling/tests/test_sample_continuous_rv.py new file mode 100644 index 0000000000000000000000000000000000000000..953bb602df5e63da2882ee118de9dbf24b6f7804 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/stats/sampling/tests/test_sample_continuous_rv.py @@ -0,0 +1,181 @@ +from sympy.core.numbers import oo +from sympy.core.symbol import Symbol +from sympy.functions.elementary.exponential import exp +from sympy.sets.sets import Interval +from sympy.external import import_module +from sympy.stats import Beta, Chi, Normal, Gamma, Exponential, LogNormal, Pareto, ChiSquared, Uniform, sample, \ + BetaPrime, Cauchy, GammaInverse, GaussianInverse, StudentT, Weibull, density, ContinuousRV, FDistribution, \ + Gumbel, Laplace, Logistic, Rayleigh, Triangular +from sympy.testing.pytest import skip, raises + + +def test_sample_numpy(): + distribs_numpy = [ + Beta("B", 1, 1), + Normal("N", 0, 1), + Gamma("G", 2, 7), + Exponential("E", 2), + LogNormal("LN", 0, 1), + Pareto("P", 1, 1), + ChiSquared("CS", 2), + Uniform("U", 0, 1), + FDistribution("FD", 1, 2), + Gumbel("GB", 1, 2), + Laplace("L", 1, 2), + Logistic("LO", 1, 2), + Rayleigh("R", 1), + Triangular("T", 1, 2, 2), + ] + size = 3 + numpy = import_module('numpy') + if not numpy: + skip('Numpy is not installed. Abort tests for _sample_numpy.') + else: + for X in distribs_numpy: + samps = sample(X, size=size, library='numpy') + for sam in samps: + assert sam in X.pspace.domain.set + raises(NotImplementedError, + lambda: sample(Chi("C", 1), library='numpy')) + raises(NotImplementedError, + lambda: Chi("C", 1).pspace.distribution.sample(library='tensorflow')) + + +def test_sample_scipy(): + distribs_scipy = [ + Beta("B", 1, 1), + BetaPrime("BP", 1, 1), + Cauchy("C", 1, 1), + Chi("C", 1), + Normal("N", 0, 1), + Gamma("G", 2, 7), + GammaInverse("GI", 1, 1), + GaussianInverse("GUI", 1, 1), + Exponential("E", 2), + LogNormal("LN", 0, 1), + Pareto("P", 1, 1), + StudentT("S", 2), + ChiSquared("CS", 2), + Uniform("U", 0, 1) + ] + size = 3 + scipy = import_module('scipy') + if not scipy: + skip('Scipy is not installed. Abort tests for _sample_scipy.') + else: + for X in distribs_scipy: + samps = sample(X, size=size, library='scipy') + samps2 = sample(X, size=(2, 2), library='scipy') + for sam in samps: + assert sam in X.pspace.domain.set + for i in range(2): + for j in range(2): + assert samps2[i][j] in X.pspace.domain.set + + +def test_sample_pymc(): + distribs_pymc = [ + Beta("B", 1, 1), + Cauchy("C", 1, 1), + Normal("N", 0, 1), + Gamma("G", 2, 7), + GaussianInverse("GI", 1, 1), + Exponential("E", 2), + LogNormal("LN", 0, 1), + Pareto("P", 1, 1), + ChiSquared("CS", 2), + Uniform("U", 0, 1) + ] + size = 3 + pymc = import_module('pymc') + if not pymc: + skip('PyMC is not installed. Abort tests for _sample_pymc.') + else: + for X in distribs_pymc: + samps = sample(X, size=size, library='pymc') + for sam in samps: + assert sam in X.pspace.domain.set + raises(NotImplementedError, + lambda: sample(Chi("C", 1), library='pymc')) + + +def test_sampling_gamma_inverse(): + scipy = import_module('scipy') + if not scipy: + skip('Scipy not installed. Abort tests for sampling of gamma inverse.') + X = GammaInverse("x", 1, 1) + assert sample(X) in X.pspace.domain.set + + +def test_lognormal_sampling(): + # Right now, only density function and sampling works + scipy = import_module('scipy') + if not scipy: + skip('Scipy is not installed. Abort tests') + for i in range(3): + X = LogNormal('x', i, 1) + assert sample(X) in X.pspace.domain.set + + size = 5 + samps = sample(X, size=size) + for samp in samps: + assert samp in X.pspace.domain.set + + +def test_sampling_gaussian_inverse(): + scipy = import_module('scipy') + if not scipy: + skip('Scipy not installed. Abort tests for sampling of Gaussian inverse.') + X = GaussianInverse("x", 1, 1) + assert sample(X, library='scipy') in X.pspace.domain.set + + +def test_prefab_sampling(): + scipy = import_module('scipy') + if not scipy: + skip('Scipy is not installed. Abort tests') + N = Normal('X', 0, 1) + L = LogNormal('L', 0, 1) + E = Exponential('Ex', 1) + P = Pareto('P', 1, 3) + W = Weibull('W', 1, 1) + U = Uniform('U', 0, 1) + B = Beta('B', 2, 5) + G = Gamma('G', 1, 3) + + variables = [N, L, E, P, W, U, B, G] + niter = 10 + size = 5 + for var in variables: + for _ in range(niter): + assert sample(var) in var.pspace.domain.set + samps = sample(var, size=size) + for samp in samps: + assert samp in var.pspace.domain.set + + +def test_sample_continuous(): + z = Symbol('z') + Z = ContinuousRV(z, exp(-z), set=Interval(0, oo)) + assert density(Z)(-1) == 0 + + scipy = import_module('scipy') + if not scipy: + skip('Scipy is not installed. Abort tests') + assert sample(Z) in Z.pspace.domain.set + sym, val = list(Z.pspace.sample().items())[0] + assert sym == Z and val in Interval(0, oo) + + libraries = ['scipy', 'numpy', 'pymc'] + for lib in libraries: + try: + imported_lib = import_module(lib) + if imported_lib: + s0, s1, s2 = [], [], [] + s0 = sample(Z, size=10, library=lib, seed=0) + s1 = sample(Z, size=10, library=lib, seed=0) + s2 = sample(Z, size=10, library=lib, seed=1) + assert all(s0 == s1) + assert all(s1 != s2) + except NotImplementedError: + continue diff --git a/lib/python3.10/site-packages/sympy/stats/sampling/tests/test_sample_discrete_rv.py b/lib/python3.10/site-packages/sympy/stats/sampling/tests/test_sample_discrete_rv.py new file mode 100644 index 0000000000000000000000000000000000000000..10029af647b9811fe1e29f05360069c20ea1e659 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/stats/sampling/tests/test_sample_discrete_rv.py @@ -0,0 +1,99 @@ +from sympy.core.singleton import S +from sympy.core.symbol import Symbol +from sympy.external import import_module +from sympy.stats import Geometric, Poisson, Zeta, sample, Skellam, DiscreteRV, Logarithmic, NegativeBinomial, YuleSimon +from sympy.testing.pytest import skip, raises, slow + + +def test_sample_numpy(): + distribs_numpy = [ + Geometric('G', 0.5), + Poisson('P', 1), + Zeta('Z', 2) + ] + size = 3 + numpy = import_module('numpy') + if not numpy: + skip('Numpy is not installed. Abort tests for _sample_numpy.') + else: + for X in distribs_numpy: + samps = sample(X, size=size, library='numpy') + for sam in samps: + assert sam in X.pspace.domain.set + raises(NotImplementedError, + lambda: sample(Skellam('S', 1, 1), library='numpy')) + raises(NotImplementedError, + lambda: Skellam('S', 1, 1).pspace.distribution.sample(library='tensorflow')) + + +def test_sample_scipy(): + p = S(2)/3 + x = Symbol('x', integer=True, positive=True) + pdf = p*(1 - p)**(x - 1) # pdf of Geometric Distribution + distribs_scipy = [ + DiscreteRV(x, pdf, set=S.Naturals), + Geometric('G', 0.5), + Logarithmic('L', 0.5), + NegativeBinomial('N', 5, 0.4), + Poisson('P', 1), + Skellam('S', 1, 1), + YuleSimon('Y', 1), + Zeta('Z', 2) + ] + size = 3 + scipy = import_module('scipy') + if not scipy: + skip('Scipy is not installed. Abort tests for _sample_scipy.') + else: + for X in distribs_scipy: + samps = sample(X, size=size, library='scipy') + samps2 = sample(X, size=(2, 2), library='scipy') + for sam in samps: + assert sam in X.pspace.domain.set + for i in range(2): + for j in range(2): + assert samps2[i][j] in X.pspace.domain.set + + +def test_sample_pymc(): + distribs_pymc = [ + Geometric('G', 0.5), + Poisson('P', 1), + NegativeBinomial('N', 5, 0.4) + ] + size = 3 + pymc = import_module('pymc') + if not pymc: + skip('PyMC is not installed. Abort tests for _sample_pymc.') + else: + for X in distribs_pymc: + samps = sample(X, size=size, library='pymc') + for sam in samps: + assert sam in X.pspace.domain.set + raises(NotImplementedError, + lambda: sample(Skellam('S', 1, 1), library='pymc')) + +@slow +def test_sample_discrete(): + X = Geometric('X', S.Half) + scipy = import_module('scipy') + if not scipy: + skip('Scipy not installed. Abort tests') + assert sample(X) in X.pspace.domain.set + samps = sample(X, size=2) # This takes long time if ran without scipy + for samp in samps: + assert samp in X.pspace.domain.set + + libraries = ['scipy', 'numpy', 'pymc'] + for lib in libraries: + try: + imported_lib = import_module(lib) + if imported_lib: + s0, s1, s2 = [], [], [] + s0 = sample(X, size=10, library=lib, seed=0) + s1 = sample(X, size=10, library=lib, seed=0) + s2 = sample(X, size=10, library=lib, seed=1) + assert all(s0 == s1) + assert not all(s1 == s2) + except NotImplementedError: + continue diff --git a/lib/python3.10/site-packages/sympy/stats/sampling/tests/test_sample_finite_rv.py b/lib/python3.10/site-packages/sympy/stats/sampling/tests/test_sample_finite_rv.py new file mode 100644 index 0000000000000000000000000000000000000000..96cabe0ff4aaa5977e16600217fbbdeb08b962ae --- /dev/null +++ b/lib/python3.10/site-packages/sympy/stats/sampling/tests/test_sample_finite_rv.py @@ -0,0 +1,94 @@ +from sympy.core.numbers import Rational +from sympy.core.singleton import S +from sympy.external import import_module +from sympy.stats import Binomial, sample, Die, FiniteRV, DiscreteUniform, Bernoulli, BetaBinomial, Hypergeometric, \ + Rademacher +from sympy.testing.pytest import skip, raises + +def test_given_sample(): + X = Die('X', 6) + scipy = import_module('scipy') + if not scipy: + skip('Scipy is not installed. Abort tests') + assert sample(X, X > 5) == 6 + +def test_sample_numpy(): + distribs_numpy = [ + Binomial("B", 5, 0.4), + Hypergeometric("H", 2, 1, 1) + ] + size = 3 + numpy = import_module('numpy') + if not numpy: + skip('Numpy is not installed. Abort tests for _sample_numpy.') + else: + for X in distribs_numpy: + samps = sample(X, size=size, library='numpy') + for sam in samps: + assert sam in X.pspace.domain.set + raises(NotImplementedError, + lambda: sample(Die("D"), library='numpy')) + raises(NotImplementedError, + lambda: Die("D").pspace.sample(library='tensorflow')) + + +def test_sample_scipy(): + distribs_scipy = [ + FiniteRV('F', {1: S.Half, 2: Rational(1, 4), 3: Rational(1, 4)}), + DiscreteUniform("Y", list(range(5))), + Die("D"), + Bernoulli("Be", 0.3), + Binomial("Bi", 5, 0.4), + BetaBinomial("Bb", 2, 1, 1), + Hypergeometric("H", 1, 1, 1), + Rademacher("R") + ] + + size = 3 + scipy = import_module('scipy') + if not scipy: + skip('Scipy not installed. Abort tests for _sample_scipy.') + else: + for X in distribs_scipy: + samps = sample(X, size=size) + samps2 = sample(X, size=(2, 2)) + for sam in samps: + assert sam in X.pspace.domain.set + for i in range(2): + for j in range(2): + assert samps2[i][j] in X.pspace.domain.set + + +def test_sample_pymc(): + distribs_pymc = [ + Bernoulli('B', 0.2), + Binomial('N', 5, 0.4) + ] + size = 3 + pymc = import_module('pymc') + if not pymc: + skip('PyMC is not installed. Abort tests for _sample_pymc.') + else: + for X in distribs_pymc: + samps = sample(X, size=size, library='pymc') + for sam in samps: + assert sam in X.pspace.domain.set + raises(NotImplementedError, + lambda: (sample(Die("D"), library='pymc'))) + + +def test_sample_seed(): + F = FiniteRV('F', {1: S.Half, 2: Rational(1, 4), 3: Rational(1, 4)}) + size = 10 + libraries = ['scipy', 'numpy', 'pymc'] + for lib in libraries: + try: + imported_lib = import_module(lib) + if imported_lib: + s0 = sample(F, size=size, library=lib, seed=0) + s1 = sample(F, size=size, library=lib, seed=0) + s2 = sample(F, size=size, library=lib, seed=1) + assert all(s0 == s1) + assert not all(s1 == s2) + except NotImplementedError: + continue diff --git a/lib/python3.10/site-packages/sympy/stats/tests/__pycache__/__init__.cpython-310.pyc b/lib/python3.10/site-packages/sympy/stats/tests/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ca4b448b05d0d1ac79ba707e616ebf46375df9f7 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/stats/tests/__pycache__/__init__.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/stats/tests/__pycache__/test_compound_rv.cpython-310.pyc b/lib/python3.10/site-packages/sympy/stats/tests/__pycache__/test_compound_rv.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9972253c0989e135eed1287c986afa344023b87a Binary files /dev/null and b/lib/python3.10/site-packages/sympy/stats/tests/__pycache__/test_compound_rv.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/stats/tests/__pycache__/test_continuous_rv.cpython-310.pyc b/lib/python3.10/site-packages/sympy/stats/tests/__pycache__/test_continuous_rv.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5623198b55aa35f0a28993695af68961f626a382 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/stats/tests/__pycache__/test_continuous_rv.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/stats/tests/__pycache__/test_discrete_rv.cpython-310.pyc b/lib/python3.10/site-packages/sympy/stats/tests/__pycache__/test_discrete_rv.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d94079756763b0e74aee518daa61b53af0d14551 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/stats/tests/__pycache__/test_discrete_rv.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/stats/tests/__pycache__/test_error_prop.cpython-310.pyc b/lib/python3.10/site-packages/sympy/stats/tests/__pycache__/test_error_prop.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..55c06ead3d34ddb7ff9017ea51fab95a04081c8c Binary files /dev/null and b/lib/python3.10/site-packages/sympy/stats/tests/__pycache__/test_error_prop.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/stats/tests/__pycache__/test_finite_rv.cpython-310.pyc b/lib/python3.10/site-packages/sympy/stats/tests/__pycache__/test_finite_rv.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..47ce1adb42f49d2860cde75bca9e50138ec8bcb6 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/stats/tests/__pycache__/test_finite_rv.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/stats/tests/__pycache__/test_joint_rv.cpython-310.pyc b/lib/python3.10/site-packages/sympy/stats/tests/__pycache__/test_joint_rv.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b65cd72fbbcf90244a10019984e98861b1364168 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/stats/tests/__pycache__/test_joint_rv.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/stats/tests/__pycache__/test_matrix_distributions.cpython-310.pyc b/lib/python3.10/site-packages/sympy/stats/tests/__pycache__/test_matrix_distributions.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d5208c662ce462e017589a2d981b6fafbe4c15c3 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/stats/tests/__pycache__/test_matrix_distributions.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/stats/tests/__pycache__/test_mix.cpython-310.pyc b/lib/python3.10/site-packages/sympy/stats/tests/__pycache__/test_mix.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..74f5e58f9e51295fa4ca8892b8a1ba9fdda94516 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/stats/tests/__pycache__/test_mix.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/stats/tests/__pycache__/test_random_matrix.cpython-310.pyc b/lib/python3.10/site-packages/sympy/stats/tests/__pycache__/test_random_matrix.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..95f01788d601d6214d590afb47abb8a5742fa82c Binary files /dev/null and b/lib/python3.10/site-packages/sympy/stats/tests/__pycache__/test_random_matrix.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/stats/tests/__pycache__/test_rv.cpython-310.pyc b/lib/python3.10/site-packages/sympy/stats/tests/__pycache__/test_rv.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..85e262c461c3685a6f1b50cf0c24642a7e50b9d7 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/stats/tests/__pycache__/test_rv.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/stats/tests/__pycache__/test_stochastic_process.cpython-310.pyc b/lib/python3.10/site-packages/sympy/stats/tests/__pycache__/test_stochastic_process.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cc47b1bc85445aeb6e5771d7ece6432a261425ef Binary files /dev/null and b/lib/python3.10/site-packages/sympy/stats/tests/__pycache__/test_stochastic_process.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/stats/tests/__pycache__/test_symbolic_multivariate.cpython-310.pyc b/lib/python3.10/site-packages/sympy/stats/tests/__pycache__/test_symbolic_multivariate.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2349d8b1ecbce3396d1808fa0eca073ad349f807 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/stats/tests/__pycache__/test_symbolic_multivariate.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/stats/tests/__pycache__/test_symbolic_probability.cpython-310.pyc b/lib/python3.10/site-packages/sympy/stats/tests/__pycache__/test_symbolic_probability.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4a232d5c69fc27e8774de740793afcedfeb2d41b Binary files /dev/null and b/lib/python3.10/site-packages/sympy/stats/tests/__pycache__/test_symbolic_probability.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/utilities/__pycache__/__init__.cpython-310.pyc b/lib/python3.10/site-packages/sympy/utilities/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0fcb1e0a5b6c16b2583daf7c08c8219e6249e75f Binary files /dev/null and b/lib/python3.10/site-packages/sympy/utilities/__pycache__/__init__.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/utilities/__pycache__/autowrap.cpython-310.pyc b/lib/python3.10/site-packages/sympy/utilities/__pycache__/autowrap.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..86ef2e4f55c869669e8ccefc9ea9c4c0984ab93c Binary files /dev/null and b/lib/python3.10/site-packages/sympy/utilities/__pycache__/autowrap.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/utilities/__pycache__/codegen.cpython-310.pyc b/lib/python3.10/site-packages/sympy/utilities/__pycache__/codegen.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8b3ec447dcff82afced71a0ee1ac0ed5909abc79 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/utilities/__pycache__/codegen.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/utilities/__pycache__/decorator.cpython-310.pyc b/lib/python3.10/site-packages/sympy/utilities/__pycache__/decorator.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bed880c88bb6b53f05a27002dc39e890ea1b2b45 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/utilities/__pycache__/decorator.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/utilities/__pycache__/enumerative.cpython-310.pyc b/lib/python3.10/site-packages/sympy/utilities/__pycache__/enumerative.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c9c3c5e6299e634eedfd2a3c4162d52baa47f68f Binary files /dev/null and b/lib/python3.10/site-packages/sympy/utilities/__pycache__/enumerative.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/utilities/__pycache__/exceptions.cpython-310.pyc b/lib/python3.10/site-packages/sympy/utilities/__pycache__/exceptions.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ab8c97bf3f6298661499b6bbf271154e53293bf9 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/utilities/__pycache__/exceptions.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/utilities/__pycache__/iterables.cpython-310.pyc b/lib/python3.10/site-packages/sympy/utilities/__pycache__/iterables.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b1e7dffa0e636ae19e40314b5f58b60271ec485a Binary files /dev/null and b/lib/python3.10/site-packages/sympy/utilities/__pycache__/iterables.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/utilities/__pycache__/lambdify.cpython-310.pyc b/lib/python3.10/site-packages/sympy/utilities/__pycache__/lambdify.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..09aee2449b7ccba275a1517ebd51dfa8bf986ed8 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/utilities/__pycache__/lambdify.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/utilities/__pycache__/magic.cpython-310.pyc b/lib/python3.10/site-packages/sympy/utilities/__pycache__/magic.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..97bd8ee53dc1b270497535c547ff0b25d758b41a Binary files /dev/null and b/lib/python3.10/site-packages/sympy/utilities/__pycache__/magic.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/utilities/__pycache__/matchpy_connector.cpython-310.pyc b/lib/python3.10/site-packages/sympy/utilities/__pycache__/matchpy_connector.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..95a59992438d2df19dec74d128865580d3aff016 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/utilities/__pycache__/matchpy_connector.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/utilities/__pycache__/memoization.cpython-310.pyc b/lib/python3.10/site-packages/sympy/utilities/__pycache__/memoization.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..46cf7dafdc910457a2820d1b41d577b5fc929116 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/utilities/__pycache__/memoization.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/utilities/__pycache__/misc.cpython-310.pyc b/lib/python3.10/site-packages/sympy/utilities/__pycache__/misc.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..36714649ff2683d74cce048cad6b0a12bb37105a Binary files /dev/null and b/lib/python3.10/site-packages/sympy/utilities/__pycache__/misc.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/utilities/__pycache__/pkgdata.cpython-310.pyc b/lib/python3.10/site-packages/sympy/utilities/__pycache__/pkgdata.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d820aea84b0c555a46b7008da9e6242e7237ef9f Binary files /dev/null and b/lib/python3.10/site-packages/sympy/utilities/__pycache__/pkgdata.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/utilities/__pycache__/pytest.cpython-310.pyc b/lib/python3.10/site-packages/sympy/utilities/__pycache__/pytest.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9ec871a1dfb2960b40f191a368ff480fba04f845 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/utilities/__pycache__/pytest.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/utilities/__pycache__/randtest.cpython-310.pyc b/lib/python3.10/site-packages/sympy/utilities/__pycache__/randtest.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..de2ef4c4536e02671ac31800be62817dfca4951c Binary files /dev/null and b/lib/python3.10/site-packages/sympy/utilities/__pycache__/randtest.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/utilities/__pycache__/runtests.cpython-310.pyc b/lib/python3.10/site-packages/sympy/utilities/__pycache__/runtests.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8a2df1534f61dd9feb29927403f37cff25301265 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/utilities/__pycache__/runtests.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/utilities/__pycache__/source.cpython-310.pyc b/lib/python3.10/site-packages/sympy/utilities/__pycache__/source.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..36b1cf5eb1f4543be8b3f9dd92adc8feb4768005 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/utilities/__pycache__/source.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/utilities/__pycache__/timeutils.cpython-310.pyc b/lib/python3.10/site-packages/sympy/utilities/__pycache__/timeutils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3cec92f2a1dee0594c2e3c36f95dddd6aeaa1e81 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/utilities/__pycache__/timeutils.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/utilities/__pycache__/tmpfiles.cpython-310.pyc b/lib/python3.10/site-packages/sympy/utilities/__pycache__/tmpfiles.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..615223d6354ce797a852c58416cea439bf7d2e90 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/utilities/__pycache__/tmpfiles.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/utilities/_compilation/__init__.py b/lib/python3.10/site-packages/sympy/utilities/_compilation/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d2c05ad48a93493bb5434256c88dfd614ac47b2d --- /dev/null +++ b/lib/python3.10/site-packages/sympy/utilities/_compilation/__init__.py @@ -0,0 +1,22 @@ +""" This sub-module is private, i.e. external code should not depend on it. + +These functions are used by tests run as part of continuous integration. +Once the implementation is mature (it should support the major +platforms: Windows, OS X & Linux) it may become official API which + may be relied upon by downstream libraries. Until then API may break +without prior notice. + +TODO: +- (optionally) clean up after tempfile.mkdtemp() +- cross-platform testing +- caching of compiler choice and intermediate files + +""" + +from .compilation import compile_link_import_strings, compile_run_strings +from .availability import has_fortran, has_c, has_cxx + +__all__ = [ + 'compile_link_import_strings', 'compile_run_strings', + 'has_fortran', 'has_c', 'has_cxx', +] diff --git a/lib/python3.10/site-packages/sympy/utilities/_compilation/__pycache__/__init__.cpython-310.pyc b/lib/python3.10/site-packages/sympy/utilities/_compilation/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..abac70a3ae22221e68d9ee3242aa8017a02ddb81 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/utilities/_compilation/__pycache__/__init__.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/utilities/_compilation/__pycache__/availability.cpython-310.pyc b/lib/python3.10/site-packages/sympy/utilities/_compilation/__pycache__/availability.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..03224963c72830908acb0029e9a89c8e99c951f7 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/utilities/_compilation/__pycache__/availability.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/utilities/_compilation/__pycache__/compilation.cpython-310.pyc b/lib/python3.10/site-packages/sympy/utilities/_compilation/__pycache__/compilation.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1121778d60bc22a52c028bcbd7e6656d3fa87d03 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/utilities/_compilation/__pycache__/compilation.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/utilities/_compilation/__pycache__/runners.cpython-310.pyc b/lib/python3.10/site-packages/sympy/utilities/_compilation/__pycache__/runners.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..84f8e281169efd1fc1bfaa69afea7b1f2d3850a0 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/utilities/_compilation/__pycache__/runners.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/utilities/_compilation/__pycache__/util.cpython-310.pyc b/lib/python3.10/site-packages/sympy/utilities/_compilation/__pycache__/util.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..768414fe99a16e551180db7ede1c04c2b172afe2 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/utilities/_compilation/__pycache__/util.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/utilities/_compilation/availability.py b/lib/python3.10/site-packages/sympy/utilities/_compilation/availability.py new file mode 100644 index 0000000000000000000000000000000000000000..dc97b3e7b8c7e7307c6c21352ed4035d977aabb3 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/utilities/_compilation/availability.py @@ -0,0 +1,77 @@ +import os +from .compilation import compile_run_strings +from .util import CompilerNotFoundError + +def has_fortran(): + if not hasattr(has_fortran, 'result'): + try: + (stdout, stderr), info = compile_run_strings( + [('main.f90', ( + 'program foo\n' + 'print *, "hello world"\n' + 'end program' + ))], clean=True + ) + except CompilerNotFoundError: + has_fortran.result = False + if os.environ.get('SYMPY_STRICT_COMPILER_CHECKS', '0') == '1': + raise + else: + if info['exit_status'] != os.EX_OK or 'hello world' not in stdout: + if os.environ.get('SYMPY_STRICT_COMPILER_CHECKS', '0') == '1': + raise ValueError("Failed to compile test program:\n%s\n%s\n" % (stdout, stderr)) + has_fortran.result = False + else: + has_fortran.result = True + return has_fortran.result + + +def has_c(): + if not hasattr(has_c, 'result'): + try: + (stdout, stderr), info = compile_run_strings( + [('main.c', ( + '#include \n' + 'int main(){\n' + 'printf("hello world\\n");\n' + 'return 0;\n' + '}' + ))], clean=True + ) + except CompilerNotFoundError: + has_c.result = False + if os.environ.get('SYMPY_STRICT_COMPILER_CHECKS', '0') == '1': + raise + else: + if info['exit_status'] != os.EX_OK or 'hello world' not in stdout: + if os.environ.get('SYMPY_STRICT_COMPILER_CHECKS', '0') == '1': + raise ValueError("Failed to compile test program:\n%s\n%s\n" % (stdout, stderr)) + has_c.result = False + else: + has_c.result = True + return has_c.result + + +def has_cxx(): + if not hasattr(has_cxx, 'result'): + try: + (stdout, stderr), info = compile_run_strings( + [('main.cxx', ( + '#include \n' + 'int main(){\n' + 'std::cout << "hello world" << std::endl;\n' + '}' + ))], clean=True + ) + except CompilerNotFoundError: + has_cxx.result = False + if os.environ.get('SYMPY_STRICT_COMPILER_CHECKS', '0') == '1': + raise + else: + if info['exit_status'] != os.EX_OK or 'hello world' not in stdout: + if os.environ.get('SYMPY_STRICT_COMPILER_CHECKS', '0') == '1': + raise ValueError("Failed to compile test program:\n%s\n%s\n" % (stdout, stderr)) + has_cxx.result = False + else: + has_cxx.result = True + return has_cxx.result diff --git a/lib/python3.10/site-packages/sympy/utilities/_compilation/compilation.py b/lib/python3.10/site-packages/sympy/utilities/_compilation/compilation.py new file mode 100644 index 0000000000000000000000000000000000000000..ca6c916506de2e66c5b6061a295b58a431bd2d04 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/utilities/_compilation/compilation.py @@ -0,0 +1,657 @@ +import glob +import os +import shutil +import subprocess +import sys +import tempfile +import warnings +from sysconfig import get_config_var, get_config_vars, get_path + +from .runners import ( + CCompilerRunner, + CppCompilerRunner, + FortranCompilerRunner +) +from .util import ( + get_abspath, make_dirs, copy, Glob, ArbitraryDepthGlob, + glob_at_depth, import_module_from_file, pyx_is_cplus, + sha256_of_string, sha256_of_file, CompileError +) + +if os.name == 'posix': + objext = '.o' +elif os.name == 'nt': + objext = '.obj' +else: + warnings.warn("Unknown os.name: {}".format(os.name)) + objext = '.o' + + +def compile_sources(files, Runner=None, destdir=None, cwd=None, keep_dir_struct=False, + per_file_kwargs=None, **kwargs): + """ Compile source code files to object files. + + Parameters + ========== + + files : iterable of str + Paths to source files, if ``cwd`` is given, the paths are taken as relative. + Runner: CompilerRunner subclass (optional) + Could be e.g. ``FortranCompilerRunner``. Will be inferred from filename + extensions if missing. + destdir: str + Output directory, if cwd is given, the path is taken as relative. + cwd: str + Working directory. Specify to have compiler run in other directory. + also used as root of relative paths. + keep_dir_struct: bool + Reproduce directory structure in `destdir`. default: ``False`` + per_file_kwargs: dict + Dict mapping instances in ``files`` to keyword arguments. + \\*\\*kwargs: dict + Default keyword arguments to pass to ``Runner``. + + Returns + ======= + List of strings (paths of object files). + """ + _per_file_kwargs = {} + + if per_file_kwargs is not None: + for k, v in per_file_kwargs.items(): + if isinstance(k, Glob): + for path in glob.glob(k.pathname): + _per_file_kwargs[path] = v + elif isinstance(k, ArbitraryDepthGlob): + for path in glob_at_depth(k.filename, cwd): + _per_file_kwargs[path] = v + else: + _per_file_kwargs[k] = v + + # Set up destination directory + destdir = destdir or '.' + if not os.path.isdir(destdir): + if os.path.exists(destdir): + raise OSError("{} is not a directory".format(destdir)) + else: + make_dirs(destdir) + if cwd is None: + cwd = '.' + for f in files: + copy(f, destdir, only_update=True, dest_is_dir=True) + + # Compile files and return list of paths to the objects + dstpaths = [] + for f in files: + if keep_dir_struct: + name, ext = os.path.splitext(f) + else: + name, ext = os.path.splitext(os.path.basename(f)) + file_kwargs = kwargs.copy() + file_kwargs.update(_per_file_kwargs.get(f, {})) + dstpaths.append(src2obj(f, Runner, cwd=cwd, **file_kwargs)) + return dstpaths + + +def get_mixed_fort_c_linker(vendor=None, cplus=False, cwd=None): + vendor = vendor or os.environ.get('SYMPY_COMPILER_VENDOR', 'gnu') + + if vendor.lower() == 'intel': + if cplus: + return (FortranCompilerRunner, + {'flags': ['-nofor_main', '-cxxlib']}, vendor) + else: + return (FortranCompilerRunner, + {'flags': ['-nofor_main']}, vendor) + elif vendor.lower() == 'gnu' or 'llvm': + if cplus: + return (CppCompilerRunner, + {'lib_options': ['fortran']}, vendor) + else: + return (FortranCompilerRunner, + {}, vendor) + else: + raise ValueError("No vendor found.") + + +def link(obj_files, out_file=None, shared=False, Runner=None, + cwd=None, cplus=False, fort=False, extra_objs=None, **kwargs): + """ Link object files. + + Parameters + ========== + + obj_files: iterable of str + Paths to object files. + out_file: str (optional) + Path to executable/shared library, if ``None`` it will be + deduced from the last item in obj_files. + shared: bool + Generate a shared library? + Runner: CompilerRunner subclass (optional) + If not given the ``cplus`` and ``fort`` flags will be inspected + (fallback is the C compiler). + cwd: str + Path to the root of relative paths and working directory for compiler. + cplus: bool + C++ objects? default: ``False``. + fort: bool + Fortran objects? default: ``False``. + extra_objs: list + List of paths to extra object files / static libraries. + \\*\\*kwargs: dict + Keyword arguments passed to ``Runner``. + + Returns + ======= + + The absolute path to the generated shared object / executable. + + """ + if out_file is None: + out_file, ext = os.path.splitext(os.path.basename(obj_files[-1])) + if shared: + out_file += get_config_var('EXT_SUFFIX') + + if not Runner: + if fort: + Runner, extra_kwargs, vendor = \ + get_mixed_fort_c_linker( + vendor=kwargs.get('vendor', None), + cplus=cplus, + cwd=cwd, + ) + for k, v in extra_kwargs.items(): + if k in kwargs: + kwargs[k].expand(v) + else: + kwargs[k] = v + else: + if cplus: + Runner = CppCompilerRunner + else: + Runner = CCompilerRunner + + flags = kwargs.pop('flags', []) + if shared: + if '-shared' not in flags: + flags.append('-shared') + run_linker = kwargs.pop('run_linker', True) + if not run_linker: + raise ValueError("run_linker was set to False (nonsensical).") + + out_file = get_abspath(out_file, cwd=cwd) + runner = Runner(obj_files+(extra_objs or []), out_file, flags, cwd=cwd, **kwargs) + runner.run() + return out_file + + +def link_py_so(obj_files, so_file=None, cwd=None, libraries=None, + cplus=False, fort=False, extra_objs=None, **kwargs): + """ Link Python extension module (shared object) for importing + + Parameters + ========== + + obj_files: iterable of str + Paths to object files to be linked. + so_file: str + Name (path) of shared object file to create. If not specified it will + have the basname of the last object file in `obj_files` but with the + extension '.so' (Unix). + cwd: path string + Root of relative paths and working directory of linker. + libraries: iterable of strings + Libraries to link against, e.g. ['m']. + cplus: bool + Any C++ objects? default: ``False``. + fort: bool + Any Fortran objects? default: ``False``. + extra_objs: list + List of paths of extra object files / static libraries to link against. + kwargs**: dict + Keyword arguments passed to ``link(...)``. + + Returns + ======= + + Absolute path to the generate shared object. + """ + libraries = libraries or [] + + include_dirs = kwargs.pop('include_dirs', []) + library_dirs = kwargs.pop('library_dirs', []) + + # Add Python include and library directories + # PY_LDFLAGS does not available on all python implementations + # e.g. when with pypy, so it's LDFLAGS we need to use + if sys.platform == "win32": + warnings.warn("Windows not yet supported.") + elif sys.platform == 'darwin': + cfgDict = get_config_vars() + kwargs['linkline'] = kwargs.get('linkline', []) + [cfgDict['LDFLAGS']] + library_dirs += [cfgDict['LIBDIR']] + + # In macOS, linker needs to compile frameworks + # e.g. "-framework CoreFoundation" + is_framework = False + for opt in cfgDict['LIBS'].split(): + if is_framework: + kwargs['linkline'] = kwargs.get('linkline', []) + ['-framework', opt] + is_framework = False + elif opt.startswith('-l'): + libraries.append(opt[2:]) + elif opt.startswith('-framework'): + is_framework = True + # The python library is not included in LIBS + libfile = cfgDict['LIBRARY'] + libname = ".".join(libfile.split('.')[:-1])[3:] + libraries.append(libname) + + elif sys.platform[:3] == 'aix': + # Don't use the default code below + pass + else: + if get_config_var('Py_ENABLE_SHARED'): + cfgDict = get_config_vars() + kwargs['linkline'] = kwargs.get('linkline', []) + [cfgDict['LDFLAGS']] + library_dirs += [cfgDict['LIBDIR']] + for opt in cfgDict['BLDLIBRARY'].split(): + if opt.startswith('-l'): + libraries += [opt[2:]] + else: + pass + + flags = kwargs.pop('flags', []) + needed_flags = ('-pthread',) + for flag in needed_flags: + if flag not in flags: + flags.append(flag) + + return link(obj_files, shared=True, flags=flags, cwd=cwd, cplus=cplus, fort=fort, + include_dirs=include_dirs, libraries=libraries, + library_dirs=library_dirs, extra_objs=extra_objs, **kwargs) + + +def simple_cythonize(src, destdir=None, cwd=None, **cy_kwargs): + """ Generates a C file from a Cython source file. + + Parameters + ========== + + src: str + Path to Cython source. + destdir: str (optional) + Path to output directory (default: '.'). + cwd: path string (optional) + Root of relative paths (default: '.'). + **cy_kwargs: + Second argument passed to cy_compile. Generates a .cpp file if ``cplus=True`` in ``cy_kwargs``, + else a .c file. + """ + from Cython.Compiler.Main import ( + default_options, CompilationOptions + ) + from Cython.Compiler.Main import compile as cy_compile + + assert src.lower().endswith('.pyx') or src.lower().endswith('.py') + cwd = cwd or '.' + destdir = destdir or '.' + + ext = '.cpp' if cy_kwargs.get('cplus', False) else '.c' + c_name = os.path.splitext(os.path.basename(src))[0] + ext + + dstfile = os.path.join(destdir, c_name) + + if cwd: + ori_dir = os.getcwd() + else: + ori_dir = '.' + os.chdir(cwd) + try: + cy_options = CompilationOptions(default_options) + cy_options.__dict__.update(cy_kwargs) + # Set language_level if not set by cy_kwargs + # as not setting it is deprecated + if 'language_level' not in cy_kwargs: + cy_options.__dict__['language_level'] = 3 + cy_result = cy_compile([src], cy_options) + if cy_result.num_errors > 0: + raise ValueError("Cython compilation failed.") + + # Move generated C file to destination + # In macOS, the generated C file is in the same directory as the source + # but the /var is a symlink to /private/var, so we need to use realpath + if os.path.realpath(os.path.dirname(src)) != os.path.realpath(destdir): + if os.path.exists(dstfile): + os.unlink(dstfile) + shutil.move(os.path.join(os.path.dirname(src), c_name), destdir) + finally: + os.chdir(ori_dir) + return dstfile + + +extension_mapping = { + '.c': (CCompilerRunner, None), + '.cpp': (CppCompilerRunner, None), + '.cxx': (CppCompilerRunner, None), + '.f': (FortranCompilerRunner, None), + '.for': (FortranCompilerRunner, None), + '.ftn': (FortranCompilerRunner, None), + '.f90': (FortranCompilerRunner, None), # ifort only knows about .f90 + '.f95': (FortranCompilerRunner, 'f95'), + '.f03': (FortranCompilerRunner, 'f2003'), + '.f08': (FortranCompilerRunner, 'f2008'), +} + + +def src2obj(srcpath, Runner=None, objpath=None, cwd=None, inc_py=False, **kwargs): + """ Compiles a source code file to an object file. + + Files ending with '.pyx' assumed to be cython files and + are dispatched to pyx2obj. + + Parameters + ========== + + srcpath: str + Path to source file. + Runner: CompilerRunner subclass (optional) + If ``None``: deduced from extension of srcpath. + objpath : str (optional) + Path to generated object. If ``None``: deduced from ``srcpath``. + cwd: str (optional) + Working directory and root of relative paths. If ``None``: current dir. + inc_py: bool + Add Python include path to kwarg "include_dirs". Default: False + \\*\\*kwargs: dict + keyword arguments passed to Runner or pyx2obj + + """ + name, ext = os.path.splitext(os.path.basename(srcpath)) + if objpath is None: + if os.path.isabs(srcpath): + objpath = '.' + else: + objpath = os.path.dirname(srcpath) + objpath = objpath or '.' # avoid objpath == '' + + if os.path.isdir(objpath): + objpath = os.path.join(objpath, name + objext) + + include_dirs = kwargs.pop('include_dirs', []) + if inc_py: + py_inc_dir = get_path('include') + if py_inc_dir not in include_dirs: + include_dirs.append(py_inc_dir) + + if ext.lower() == '.pyx': + return pyx2obj(srcpath, objpath=objpath, include_dirs=include_dirs, cwd=cwd, + **kwargs) + + if Runner is None: + Runner, std = extension_mapping[ext.lower()] + if 'std' not in kwargs: + kwargs['std'] = std + + flags = kwargs.pop('flags', []) + needed_flags = ('-fPIC',) + for flag in needed_flags: + if flag not in flags: + flags.append(flag) + + # src2obj implies not running the linker... + run_linker = kwargs.pop('run_linker', False) + if run_linker: + raise CompileError("src2obj called with run_linker=True") + + runner = Runner([srcpath], objpath, include_dirs=include_dirs, + run_linker=run_linker, cwd=cwd, flags=flags, **kwargs) + runner.run() + return objpath + + +def pyx2obj(pyxpath, objpath=None, destdir=None, cwd=None, + include_dirs=None, cy_kwargs=None, cplus=None, **kwargs): + """ + Convenience function + + If cwd is specified, pyxpath and dst are taken to be relative + If only_update is set to `True` the modification time is checked + and compilation is only run if the source is newer than the + destination + + Parameters + ========== + + pyxpath: str + Path to Cython source file. + objpath: str (optional) + Path to object file to generate. + destdir: str (optional) + Directory to put generated C file. When ``None``: directory of ``objpath``. + cwd: str (optional) + Working directory and root of relative paths. + include_dirs: iterable of path strings (optional) + Passed onto src2obj and via cy_kwargs['include_path'] + to simple_cythonize. + cy_kwargs: dict (optional) + Keyword arguments passed onto `simple_cythonize` + cplus: bool (optional) + Indicate whether C++ is used. default: auto-detect using ``.util.pyx_is_cplus``. + compile_kwargs: dict + keyword arguments passed onto src2obj + + Returns + ======= + + Absolute path of generated object file. + + """ + assert pyxpath.endswith('.pyx') + cwd = cwd or '.' + objpath = objpath or '.' + destdir = destdir or os.path.dirname(objpath) + + abs_objpath = get_abspath(objpath, cwd=cwd) + + if os.path.isdir(abs_objpath): + pyx_fname = os.path.basename(pyxpath) + name, ext = os.path.splitext(pyx_fname) + objpath = os.path.join(objpath, name + objext) + + cy_kwargs = cy_kwargs or {} + cy_kwargs['output_dir'] = cwd + if cplus is None: + cplus = pyx_is_cplus(pyxpath) + cy_kwargs['cplus'] = cplus + + interm_c_file = simple_cythonize(pyxpath, destdir=destdir, cwd=cwd, **cy_kwargs) + + include_dirs = include_dirs or [] + flags = kwargs.pop('flags', []) + needed_flags = ('-fwrapv', '-pthread', '-fPIC') + for flag in needed_flags: + if flag not in flags: + flags.append(flag) + + options = kwargs.pop('options', []) + + if kwargs.pop('strict_aliasing', False): + raise CompileError("Cython requires strict aliasing to be disabled.") + + # Let's be explicit about standard + if cplus: + std = kwargs.pop('std', 'c++98') + else: + std = kwargs.pop('std', 'c99') + + return src2obj(interm_c_file, objpath=objpath, cwd=cwd, + include_dirs=include_dirs, flags=flags, std=std, + options=options, inc_py=True, strict_aliasing=False, + **kwargs) + + +def _any_X(srcs, cls): + for src in srcs: + name, ext = os.path.splitext(src) + key = ext.lower() + if key in extension_mapping: + if extension_mapping[key][0] == cls: + return True + return False + + +def any_fortran_src(srcs): + return _any_X(srcs, FortranCompilerRunner) + + +def any_cplus_src(srcs): + return _any_X(srcs, CppCompilerRunner) + + +def compile_link_import_py_ext(sources, extname=None, build_dir='.', compile_kwargs=None, + link_kwargs=None, extra_objs=None): + """ Compiles sources to a shared object (Python extension) and imports it + + Sources in ``sources`` which is imported. If shared object is newer than the sources, they + are not recompiled but instead it is imported. + + Parameters + ========== + + sources : list of strings + List of paths to sources. + extname : string + Name of extension (default: ``None``). + If ``None``: taken from the last file in ``sources`` without extension. + build_dir: str + Path to directory in which objects files etc. are generated. + compile_kwargs: dict + keyword arguments passed to ``compile_sources`` + link_kwargs: dict + keyword arguments passed to ``link_py_so`` + extra_objs: list + List of paths to (prebuilt) object files / static libraries to link against. + + Returns + ======= + + The imported module from of the Python extension. + """ + if extname is None: + extname = os.path.splitext(os.path.basename(sources[-1]))[0] + + compile_kwargs = compile_kwargs or {} + link_kwargs = link_kwargs or {} + + try: + mod = import_module_from_file(os.path.join(build_dir, extname), sources) + except ImportError: + objs = compile_sources(list(map(get_abspath, sources)), destdir=build_dir, + cwd=build_dir, **compile_kwargs) + so = link_py_so(objs, cwd=build_dir, fort=any_fortran_src(sources), + cplus=any_cplus_src(sources), extra_objs=extra_objs, **link_kwargs) + mod = import_module_from_file(so) + return mod + + +def _write_sources_to_build_dir(sources, build_dir): + build_dir = build_dir or tempfile.mkdtemp() + if not os.path.isdir(build_dir): + raise OSError("Non-existent directory: ", build_dir) + + source_files = [] + for name, src in sources: + dest = os.path.join(build_dir, name) + differs = True + sha256_in_mem = sha256_of_string(src.encode('utf-8')).hexdigest() + if os.path.exists(dest): + if os.path.exists(dest + '.sha256'): + with open(dest + '.sha256') as fh: + sha256_on_disk = fh.read() + else: + sha256_on_disk = sha256_of_file(dest).hexdigest() + + differs = sha256_on_disk != sha256_in_mem + if differs: + with open(dest, 'wt') as fh: + fh.write(src) + with open(dest + '.sha256', 'wt') as fh: + fh.write(sha256_in_mem) + source_files.append(dest) + return source_files, build_dir + + +def compile_link_import_strings(sources, build_dir=None, **kwargs): + """ Compiles, links and imports extension module from source. + + Parameters + ========== + + sources : iterable of name/source pair tuples + build_dir : string (default: None) + Path. ``None`` implies use a temporary directory. + **kwargs: + Keyword arguments passed onto `compile_link_import_py_ext`. + + Returns + ======= + + mod : module + The compiled and imported extension module. + info : dict + Containing ``build_dir`` as 'build_dir'. + + """ + source_files, build_dir = _write_sources_to_build_dir(sources, build_dir) + mod = compile_link_import_py_ext(source_files, build_dir=build_dir, **kwargs) + info = {"build_dir": build_dir} + return mod, info + + +def compile_run_strings(sources, build_dir=None, clean=False, compile_kwargs=None, link_kwargs=None): + """ Compiles, links and runs a program built from sources. + + Parameters + ========== + + sources : iterable of name/source pair tuples + build_dir : string (default: None) + Path. ``None`` implies use a temporary directory. + clean : bool + Whether to remove build_dir after use. This will only have an + effect if ``build_dir`` is ``None`` (which creates a temporary directory). + Passing ``clean == True`` and ``build_dir != None`` raises a ``ValueError``. + This will also set ``build_dir`` in returned info dictionary to ``None``. + compile_kwargs: dict + Keyword arguments passed onto ``compile_sources`` + link_kwargs: dict + Keyword arguments passed onto ``link`` + + Returns + ======= + + (stdout, stderr): pair of strings + info: dict + Containing exit status as 'exit_status' and ``build_dir`` as 'build_dir' + + """ + if clean and build_dir is not None: + raise ValueError("Automatic removal of build_dir is only available for temporary directory.") + try: + source_files, build_dir = _write_sources_to_build_dir(sources, build_dir) + objs = compile_sources(list(map(get_abspath, source_files)), destdir=build_dir, + cwd=build_dir, **(compile_kwargs or {})) + prog = link(objs, cwd=build_dir, + fort=any_fortran_src(source_files), + cplus=any_cplus_src(source_files), **(link_kwargs or {})) + p = subprocess.Popen([prog], stdout=subprocess.PIPE, stderr=subprocess.PIPE) + exit_status = p.wait() + stdout, stderr = [txt.decode('utf-8') for txt in p.communicate()] + finally: + if clean and os.path.isdir(build_dir): + shutil.rmtree(build_dir) + build_dir = None + info = {"exit_status": exit_status, "build_dir": build_dir} + return (stdout, stderr), info diff --git a/lib/python3.10/site-packages/sympy/utilities/_compilation/runners.py b/lib/python3.10/site-packages/sympy/utilities/_compilation/runners.py new file mode 100644 index 0000000000000000000000000000000000000000..1f37d6cf8ac47807da7f3f00dfc5cd847c03fa8d --- /dev/null +++ b/lib/python3.10/site-packages/sympy/utilities/_compilation/runners.py @@ -0,0 +1,301 @@ +from __future__ import annotations +from typing import Callable, Optional + +from collections import OrderedDict +import os +import re +import subprocess +import warnings + +from .util import ( + find_binary_of_command, unique_list, CompileError +) + + +class CompilerRunner: + """ CompilerRunner base class. + + Parameters + ========== + + sources : list of str + Paths to sources. + out : str + flags : iterable of str + Compiler flags. + run_linker : bool + compiler_name_exe : (str, str) tuple + Tuple of compiler name & command to call. + cwd : str + Path of root of relative paths. + include_dirs : list of str + Include directories. + libraries : list of str + Libraries to link against. + library_dirs : list of str + Paths to search for shared libraries. + std : str + Standard string, e.g. ``'c++11'``, ``'c99'``, ``'f2003'``. + define: iterable of strings + macros to define + undef : iterable of strings + macros to undefine + preferred_vendor : string + name of preferred vendor e.g. 'gnu' or 'intel' + + Methods + ======= + + run(): + Invoke compilation as a subprocess. + + """ + + environ_key_compiler: str # e.g. 'CC', 'CXX', ... + environ_key_flags: str # e.g. 'CFLAGS', 'CXXFLAGS', ... + environ_key_ldflags: str = "LDFLAGS" # typically 'LDFLAGS' + + # Subclass to vendor/binary dict + compiler_dict: dict[str, str] + + # Standards should be a tuple of supported standards + # (first one will be the default) + standards: tuple[None | str, ...] + + # Subclass to dict of binary/formater-callback + std_formater: dict[str, Callable[[Optional[str]], str]] + + # subclass to be e.g. {'gcc': 'gnu', ...} + compiler_name_vendor_mapping: dict[str, str] + + def __init__(self, sources, out, flags=None, run_linker=True, compiler=None, cwd='.', + include_dirs=None, libraries=None, library_dirs=None, std=None, define=None, + undef=None, strict_aliasing=None, preferred_vendor=None, linkline=None, **kwargs): + if isinstance(sources, str): + raise ValueError("Expected argument sources to be a list of strings.") + self.sources = list(sources) + self.out = out + self.flags = flags or [] + if os.environ.get(self.environ_key_flags): + self.flags += os.environ[self.environ_key_flags].split() + self.cwd = cwd + if compiler: + self.compiler_name, self.compiler_binary = compiler + elif os.environ.get(self.environ_key_compiler): + self.compiler_binary = os.environ[self.environ_key_compiler] + for k, v in self.compiler_dict.items(): + if k in self.compiler_binary: + self.compiler_vendor = k + self.compiler_name = v + break + else: + self.compiler_vendor, self.compiler_name = list(self.compiler_dict.items())[0] + warnings.warn("failed to determine what kind of compiler %s is, assuming %s" % + (self.compiler_binary, self.compiler_name)) + else: + # Find a compiler + if preferred_vendor is None: + preferred_vendor = os.environ.get('SYMPY_COMPILER_VENDOR', None) + self.compiler_name, self.compiler_binary, self.compiler_vendor = self.find_compiler(preferred_vendor) + if self.compiler_binary is None: + raise ValueError("No compiler found (searched: {})".format(', '.join(self.compiler_dict.values()))) + self.define = define or [] + self.undef = undef or [] + self.include_dirs = include_dirs or [] + self.libraries = libraries or [] + self.library_dirs = library_dirs or [] + self.std = std or self.standards[0] + self.run_linker = run_linker + if self.run_linker: + # both gnu and intel compilers use '-c' for disabling linker + self.flags = list(filter(lambda x: x != '-c', self.flags)) + else: + if '-c' not in self.flags: + self.flags.append('-c') + + if self.std: + self.flags.append(self.std_formater[ + self.compiler_name](self.std)) + + self.linkline = (linkline or []) + [lf for lf in map( + str.strip, os.environ.get(self.environ_key_ldflags, "").split() + ) if lf != ""] + + if strict_aliasing is not None: + nsa_re = re.compile("no-strict-aliasing$") + sa_re = re.compile("strict-aliasing$") + if strict_aliasing is True: + if any(map(nsa_re.match, flags)): + raise CompileError("Strict aliasing cannot be both enforced and disabled") + elif any(map(sa_re.match, flags)): + pass # already enforced + else: + flags.append('-fstrict-aliasing') + elif strict_aliasing is False: + if any(map(nsa_re.match, flags)): + pass # already disabled + else: + if any(map(sa_re.match, flags)): + raise CompileError("Strict aliasing cannot be both enforced and disabled") + else: + flags.append('-fno-strict-aliasing') + else: + msg = "Expected argument strict_aliasing to be True/False, got {}" + raise ValueError(msg.format(strict_aliasing)) + + @classmethod + def find_compiler(cls, preferred_vendor=None): + """ Identify a suitable C/fortran/other compiler. """ + candidates = list(cls.compiler_dict.keys()) + if preferred_vendor: + if preferred_vendor in candidates: + candidates = [preferred_vendor]+candidates + else: + raise ValueError("Unknown vendor {}".format(preferred_vendor)) + name, path = find_binary_of_command([cls.compiler_dict[x] for x in candidates]) + return name, path, cls.compiler_name_vendor_mapping[name] + + def cmd(self): + """ List of arguments (str) to be passed to e.g. ``subprocess.Popen``. """ + cmd = ( + [self.compiler_binary] + + self.flags + + ['-U'+x for x in self.undef] + + ['-D'+x for x in self.define] + + ['-I'+x for x in self.include_dirs] + + self.sources + ) + if self.run_linker: + cmd += (['-L'+x for x in self.library_dirs] + + ['-l'+x for x in self.libraries] + + self.linkline) + counted = [] + for envvar in re.findall(r'\$\{(\w+)\}', ' '.join(cmd)): + if os.getenv(envvar) is None: + if envvar not in counted: + counted.append(envvar) + msg = "Environment variable '{}' undefined.".format(envvar) + raise CompileError(msg) + return cmd + + def run(self): + self.flags = unique_list(self.flags) + + # Append output flag and name to tail of flags + self.flags.extend(['-o', self.out]) + env = os.environ.copy() + env['PWD'] = self.cwd + + # NOTE: intel compilers seems to need shell=True + p = subprocess.Popen(' '.join(self.cmd()), + shell=True, + cwd=self.cwd, + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + env=env) + comm = p.communicate() + try: + self.cmd_outerr = comm[0].decode('utf-8') + except UnicodeDecodeError: + self.cmd_outerr = comm[0].decode('iso-8859-1') # win32 + self.cmd_returncode = p.returncode + + # Error handling + if self.cmd_returncode != 0: + msg = "Error executing '{}' in {} (exited status {}):\n {}\n".format( + ' '.join(self.cmd()), self.cwd, str(self.cmd_returncode), self.cmd_outerr + ) + raise CompileError(msg) + + return self.cmd_outerr, self.cmd_returncode + + +class CCompilerRunner(CompilerRunner): + + environ_key_compiler = 'CC' + environ_key_flags = 'CFLAGS' + + compiler_dict = OrderedDict([ + ('gnu', 'gcc'), + ('intel', 'icc'), + ('llvm', 'clang'), + ]) + + standards = ('c89', 'c90', 'c99', 'c11') # First is default + + std_formater = { + 'gcc': '-std={}'.format, + 'icc': '-std={}'.format, + 'clang': '-std={}'.format, + } + + compiler_name_vendor_mapping = { + 'gcc': 'gnu', + 'icc': 'intel', + 'clang': 'llvm' + } + + +def _mk_flag_filter(cmplr_name): # helper for class initialization + not_welcome = {'g++': ("Wimplicit-interface",)} # "Wstrict-prototypes",)} + if cmplr_name in not_welcome: + def fltr(x): + for nw in not_welcome[cmplr_name]: + if nw in x: + return False + return True + else: + def fltr(x): + return True + return fltr + + +class CppCompilerRunner(CompilerRunner): + + environ_key_compiler = 'CXX' + environ_key_flags = 'CXXFLAGS' + + compiler_dict = OrderedDict([ + ('gnu', 'g++'), + ('intel', 'icpc'), + ('llvm', 'clang++'), + ]) + + # First is the default, c++0x == c++11 + standards = ('c++98', 'c++0x') + + std_formater = { + 'g++': '-std={}'.format, + 'icpc': '-std={}'.format, + 'clang++': '-std={}'.format, + } + + compiler_name_vendor_mapping = { + 'g++': 'gnu', + 'icpc': 'intel', + 'clang++': 'llvm' + } + + +class FortranCompilerRunner(CompilerRunner): + + environ_key_compiler = 'FC' + environ_key_flags = 'FFLAGS' + + standards = (None, 'f77', 'f95', 'f2003', 'f2008') + + std_formater = { + 'gfortran': lambda x: '-std=gnu' if x is None else '-std=legacy' if x == 'f77' else '-std={}'.format(x), + 'ifort': lambda x: '-stand f08' if x is None else '-stand f{}'.format(x[-2:]), # f2008 => f08 + } + + compiler_dict = OrderedDict([ + ('gnu', 'gfortran'), + ('intel', 'ifort'), + ]) + + compiler_name_vendor_mapping = { + 'gfortran': 'gnu', + 'ifort': 'intel', + } diff --git a/lib/python3.10/site-packages/sympy/utilities/_compilation/tests/__init__.py b/lib/python3.10/site-packages/sympy/utilities/_compilation/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/lib/python3.10/site-packages/sympy/utilities/_compilation/tests/__pycache__/__init__.cpython-310.pyc b/lib/python3.10/site-packages/sympy/utilities/_compilation/tests/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b890779e74ea0b892d2a166e862f4d51bf3c794d Binary files /dev/null and b/lib/python3.10/site-packages/sympy/utilities/_compilation/tests/__pycache__/__init__.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/utilities/_compilation/tests/__pycache__/test_compilation.cpython-310.pyc b/lib/python3.10/site-packages/sympy/utilities/_compilation/tests/__pycache__/test_compilation.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..644673009905291e9aeb82bc37e0e677f3b4c829 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/utilities/_compilation/tests/__pycache__/test_compilation.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/utilities/_compilation/tests/test_compilation.py b/lib/python3.10/site-packages/sympy/utilities/_compilation/tests/test_compilation.py new file mode 100644 index 0000000000000000000000000000000000000000..d9cf601edb404d8a98c55a779ee89219cf68af4a --- /dev/null +++ b/lib/python3.10/site-packages/sympy/utilities/_compilation/tests/test_compilation.py @@ -0,0 +1,101 @@ +import shutil +import os +import subprocess +import tempfile +from sympy.external import import_module +from sympy.testing.pytest import skip + +from sympy.utilities._compilation.compilation import compile_link_import_py_ext, compile_link_import_strings, compile_sources, get_abspath + +numpy = import_module('numpy') +cython = import_module('cython') + +_sources1 = [ + ('sigmoid.c', r""" +#include + +void sigmoid(int n, const double * const restrict in, + double * const restrict out, double lim){ + for (int i=0; i 0: + if not os.path.exists(parent): + make_dirs(parent) + + if not os.path.exists(path): + os.mkdir(path, 0o777) + else: + assert os.path.isdir(path) + +def missing_or_other_newer(path, other_path, cwd=None): + """ + Investigate if path is non-existant or older than provided reference + path. + + Parameters + ========== + path: string + path to path which might be missing or too old + other_path: string + reference path + cwd: string + working directory (root of relative paths) + + Returns + ======= + True if path is older or missing. + """ + cwd = cwd or '.' + path = get_abspath(path, cwd=cwd) + other_path = get_abspath(other_path, cwd=cwd) + if not os.path.exists(path): + return True + if os.path.getmtime(other_path) - 1e-6 >= os.path.getmtime(path): + # 1e-6 is needed beacuse http://stackoverflow.com/questions/17086426/ + return True + return False + +def copy(src, dst, only_update=False, copystat=True, cwd=None, + dest_is_dir=False, create_dest_dirs=False): + """ Variation of ``shutil.copy`` with extra options. + + Parameters + ========== + + src : str + Path to source file. + dst : str + Path to destination. + only_update : bool + Only copy if source is newer than destination + (returns None if it was newer), default: ``False``. + copystat : bool + See ``shutil.copystat``. default: ``True``. + cwd : str + Path to working directory (root of relative paths). + dest_is_dir : bool + Ensures that dst is treated as a directory. default: ``False`` + create_dest_dirs : bool + Creates directories if needed. + + Returns + ======= + + Path to the copied file. + + """ + if cwd: # Handle working directory + if not os.path.isabs(src): + src = os.path.join(cwd, src) + if not os.path.isabs(dst): + dst = os.path.join(cwd, dst) + + if not os.path.exists(src): # Make sure source file extists + raise FileNotFoundError("Source: `{}` does not exist".format(src)) + + # We accept both (re)naming destination file _or_ + # passing a (possible non-existent) destination directory + if dest_is_dir: + if not dst[-1] == '/': + dst = dst+'/' + else: + if os.path.exists(dst) and os.path.isdir(dst): + dest_is_dir = True + + if dest_is_dir: + dest_dir = dst + dest_fname = os.path.basename(src) + dst = os.path.join(dest_dir, dest_fname) + else: + dest_dir = os.path.dirname(dst) + + if not os.path.exists(dest_dir): + if create_dest_dirs: + make_dirs(dest_dir) + else: + raise FileNotFoundError("You must create directory first.") + + if only_update: + if not missing_or_other_newer(dst, src): + return + + if os.path.islink(dst): + dst = os.path.abspath(os.path.realpath(dst), cwd=cwd) + + shutil.copy(src, dst) + if copystat: + shutil.copystat(src, dst) + + return dst + +Glob = namedtuple('Glob', 'pathname') +ArbitraryDepthGlob = namedtuple('ArbitraryDepthGlob', 'filename') + +def glob_at_depth(filename_glob, cwd=None): + if cwd is not None: + cwd = '.' + globbed = [] + for root, dirs, filenames in os.walk(cwd): + for fn in filenames: + # This is not tested: + if fnmatch.fnmatch(fn, filename_glob): + globbed.append(os.path.join(root, fn)) + return globbed + +def sha256_of_file(path, nblocks=128): + """ Computes the SHA256 hash of a file. + + Parameters + ========== + + path : string + Path to file to compute hash of. + nblocks : int + Number of blocks to read per iteration. + + Returns + ======= + + hashlib sha256 hash object. Use ``.digest()`` or ``.hexdigest()`` + on returned object to get binary or hex encoded string. + """ + sh = sha256() + with open(path, 'rb') as f: + for chunk in iter(lambda: f.read(nblocks*sh.block_size), b''): + sh.update(chunk) + return sh + + +def sha256_of_string(string): + """ Computes the SHA256 hash of a string. """ + sh = sha256() + sh.update(string) + return sh + + +def pyx_is_cplus(path): + """ + Inspect a Cython source file (.pyx) and look for comment line like: + + # distutils: language = c++ + + Returns True if such a file is present in the file, else False. + """ + with open(path) as fh: + for line in fh: + if line.startswith('#') and '=' in line: + splitted = line.split('=') + if len(splitted) != 2: + continue + lhs, rhs = splitted + if lhs.strip().split()[-1].lower() == 'language' and \ + rhs.strip().split()[0].lower() == 'c++': + return True + return False + +def import_module_from_file(filename, only_if_newer_than=None): + """ Imports Python extension (from shared object file) + + Provide a list of paths in `only_if_newer_than` to check + timestamps of dependencies. import_ raises an ImportError + if any is newer. + + Word of warning: The OS may cache shared objects which makes + reimporting same path of an shared object file very problematic. + + It will not detect the new time stamp, nor new checksum, but will + instead silently use old module. Use unique names for this reason. + + Parameters + ========== + + filename : str + Path to shared object. + only_if_newer_than : iterable of strings + Paths to dependencies of the shared object. + + Raises + ====== + + ``ImportError`` if any of the files specified in ``only_if_newer_than`` are newer + than the file given by filename. + """ + path, name = os.path.split(filename) + name, ext = os.path.splitext(name) + name = name.split('.')[0] + if sys.version_info[0] == 2: + from imp import find_module, load_module + fobj, filename, data = find_module(name, [path]) + if only_if_newer_than: + for dep in only_if_newer_than: + if os.path.getmtime(filename) < os.path.getmtime(dep): + raise ImportError("{} is newer than {}".format(dep, filename)) + mod = load_module(name, fobj, filename, data) + else: + import importlib.util + spec = importlib.util.spec_from_file_location(name, filename) + if spec is None: + raise ImportError("Failed to import: '%s'" % filename) + mod = importlib.util.module_from_spec(spec) + spec.loader.exec_module(mod) + return mod + + +def find_binary_of_command(candidates): + """ Finds binary first matching name among candidates. + + Calls ``which`` from shutils for provided candidates and returns + first hit. + + Parameters + ========== + + candidates : iterable of str + Names of candidate commands + + Raises + ====== + + CompilerNotFoundError if no candidates match. + """ + from shutil import which + for c in candidates: + binary_path = which(c) + if c and binary_path: + return c, binary_path + + raise CompilerNotFoundError('No binary located for candidates: {}'.format(candidates)) + + +def unique_list(l): + """ Uniquify a list (skip duplicate items). """ + result = [] + for x in l: + if x not in result: + result.append(x) + return result diff --git a/lib/python3.10/site-packages/sympy/utilities/mathml/__init__.py b/lib/python3.10/site-packages/sympy/utilities/mathml/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..eded44ee3c0f34ad1324765ba06ee9d6eb5e9899 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/utilities/mathml/__init__.py @@ -0,0 +1,122 @@ +"""Module with some functions for MathML, like transforming MathML +content in MathML presentation. + +To use this module, you will need lxml. +""" + +from pathlib import Path + +from sympy.utilities.decorator import doctest_depends_on + + +__doctest_requires__ = {('apply_xsl', 'c2p'): ['lxml']} + + +def add_mathml_headers(s): + return """""" + s + "" + + +def _read_binary(pkgname, filename): + import sys + + if sys.version_info >= (3, 10): + # files was added in Python 3.9 but only seems to work here in 3.10+ + from importlib.resources import files + return files(pkgname).joinpath(filename).read_bytes() + else: + # read_binary was deprecated in Python 3.11 + from importlib.resources import read_binary + return read_binary(pkgname, filename) + + +def _read_xsl(xsl): + # Previously these values were allowed: + if xsl == 'mathml/data/simple_mmlctop.xsl': + xsl = 'simple_mmlctop.xsl' + elif xsl == 'mathml/data/mmlctop.xsl': + xsl = 'mmlctop.xsl' + elif xsl == 'mathml/data/mmltex.xsl': + xsl = 'mmltex.xsl' + + if xsl in ['simple_mmlctop.xsl', 'mmlctop.xsl', 'mmltex.xsl']: + xslbytes = _read_binary('sympy.utilities.mathml.data', xsl) + else: + xslbytes = Path(xsl).read_bytes() + + return xslbytes + + +@doctest_depends_on(modules=('lxml',)) +def apply_xsl(mml, xsl): + """Apply a xsl to a MathML string. + + Parameters + ========== + + mml + A string with MathML code. + xsl + A string giving the name of an xsl (xml stylesheet) file which can be + found in sympy/utilities/mathml/data. The following files are supplied + with SymPy: + + - mmlctop.xsl + - mmltex.xsl + - simple_mmlctop.xsl + + Alternatively, a full path to an xsl file can be given. + + Examples + ======== + + >>> from sympy.utilities.mathml import apply_xsl + >>> xsl = 'simple_mmlctop.xsl' + >>> mml = ' a b ' + >>> res = apply_xsl(mml,xsl) + >>> print(res) + + + a + + + b + + """ + from lxml import etree + + parser = etree.XMLParser(resolve_entities=False) + ac = etree.XSLTAccessControl.DENY_ALL + + s = etree.XML(_read_xsl(xsl), parser=parser) + transform = etree.XSLT(s, access_control=ac) + doc = etree.XML(mml, parser=parser) + result = transform(doc) + s = str(result) + return s + + +@doctest_depends_on(modules=('lxml',)) +def c2p(mml, simple=False): + """Transforms a document in MathML content (like the one that sympy produces) + in one document in MathML presentation, more suitable for printing, and more + widely accepted + + Examples + ======== + + >>> from sympy.utilities.mathml import c2p + >>> mml = ' 2 ' + >>> c2p(mml,simple=True) != c2p(mml,simple=False) + True + + """ + + if not mml.startswith('= 0x03000000 +static struct PyModuleDef moduledef = { + PyModuleDef_HEAD_INIT, + "wrapper_module_%(num)s", + NULL, + -1, + wrapper_module_%(num)sMethods, + NULL, + NULL, + NULL, + NULL +}; + +PyMODINIT_FUNC PyInit_wrapper_module_%(num)s(void) +{ + PyObject *m, *d; + PyObject *ufunc0; + m = PyModule_Create(&moduledef); + if (!m) { + return NULL; + } + import_array(); + import_umath(); + d = PyModule_GetDict(m); + ufunc0 = PyUFunc_FromFuncAndData(test_funcs, test_data, test_types, 1, 3, 1, + PyUFunc_None, "wrapper_module_%(num)s", "Created in SymPy with Ufuncify", 0); + PyDict_SetItemString(d, "test", ufunc0); + Py_DECREF(ufunc0); + return m; +} +#else +PyMODINIT_FUNC initwrapper_module_%(num)s(void) +{ + PyObject *m, *d; + PyObject *ufunc0; + m = Py_InitModule("wrapper_module_%(num)s", wrapper_module_%(num)sMethods); + if (m == NULL) { + return; + } + import_array(); + import_umath(); + d = PyModule_GetDict(m); + ufunc0 = PyUFunc_FromFuncAndData(test_funcs, test_data, test_types, 1, 3, 1, + PyUFunc_None, "wrapper_module_%(num)s", "Created in SymPy with Ufuncify", 0); + PyDict_SetItemString(d, "test", ufunc0); + Py_DECREF(ufunc0); +} +#endif""" % {'num': CodeWrapper._module_counter} + assert source == expected + + +def test_ufuncify_source_multioutput(): + x, y, z = symbols('x,y,z') + var_symbols = (x, y, z) + expr = x + y**3 + 10*z**2 + code_wrapper = UfuncifyCodeWrapper(C99CodeGen("ufuncify")) + routines = [make_routine("func{}".format(i), expr.diff(var_symbols[i]), var_symbols) for i in range(len(var_symbols))] + source = get_string(code_wrapper.dump_c, routines, funcname='multitest') + expected = """\ +#include "Python.h" +#include "math.h" +#include "numpy/ndarraytypes.h" +#include "numpy/ufuncobject.h" +#include "numpy/halffloat.h" +#include "file.h" + +static PyMethodDef wrapper_module_%(num)sMethods[] = { + {NULL, NULL, 0, NULL} +}; + +#ifdef NPY_1_19_API_VERSION +static void multitest_ufunc(char **args, const npy_intp *dimensions, const npy_intp* steps, void* data) +#else +static void multitest_ufunc(char **args, npy_intp *dimensions, npy_intp* steps, void* data) +#endif +{ + npy_intp i; + npy_intp n = dimensions[0]; + char *in0 = args[0]; + char *in1 = args[1]; + char *in2 = args[2]; + char *out0 = args[3]; + char *out1 = args[4]; + char *out2 = args[5]; + npy_intp in0_step = steps[0]; + npy_intp in1_step = steps[1]; + npy_intp in2_step = steps[2]; + npy_intp out0_step = steps[3]; + npy_intp out1_step = steps[4]; + npy_intp out2_step = steps[5]; + for (i = 0; i < n; i++) { + *((double *)out0) = func0(*(double *)in0, *(double *)in1, *(double *)in2); + *((double *)out1) = func1(*(double *)in0, *(double *)in1, *(double *)in2); + *((double *)out2) = func2(*(double *)in0, *(double *)in1, *(double *)in2); + in0 += in0_step; + in1 += in1_step; + in2 += in2_step; + out0 += out0_step; + out1 += out1_step; + out2 += out2_step; + } +} +PyUFuncGenericFunction multitest_funcs[1] = {&multitest_ufunc}; +static char multitest_types[6] = {NPY_DOUBLE, NPY_DOUBLE, NPY_DOUBLE, NPY_DOUBLE, NPY_DOUBLE, NPY_DOUBLE}; +static void *multitest_data[1] = {NULL}; + +#if PY_VERSION_HEX >= 0x03000000 +static struct PyModuleDef moduledef = { + PyModuleDef_HEAD_INIT, + "wrapper_module_%(num)s", + NULL, + -1, + wrapper_module_%(num)sMethods, + NULL, + NULL, + NULL, + NULL +}; + +PyMODINIT_FUNC PyInit_wrapper_module_%(num)s(void) +{ + PyObject *m, *d; + PyObject *ufunc0; + m = PyModule_Create(&moduledef); + if (!m) { + return NULL; + } + import_array(); + import_umath(); + d = PyModule_GetDict(m); + ufunc0 = PyUFunc_FromFuncAndData(multitest_funcs, multitest_data, multitest_types, 1, 3, 3, + PyUFunc_None, "wrapper_module_%(num)s", "Created in SymPy with Ufuncify", 0); + PyDict_SetItemString(d, "multitest", ufunc0); + Py_DECREF(ufunc0); + return m; +} +#else +PyMODINIT_FUNC initwrapper_module_%(num)s(void) +{ + PyObject *m, *d; + PyObject *ufunc0; + m = Py_InitModule("wrapper_module_%(num)s", wrapper_module_%(num)sMethods); + if (m == NULL) { + return; + } + import_array(); + import_umath(); + d = PyModule_GetDict(m); + ufunc0 = PyUFunc_FromFuncAndData(multitest_funcs, multitest_data, multitest_types, 1, 3, 3, + PyUFunc_None, "wrapper_module_%(num)s", "Created in SymPy with Ufuncify", 0); + PyDict_SetItemString(d, "multitest", ufunc0); + Py_DECREF(ufunc0); +} +#endif""" % {'num': CodeWrapper._module_counter} + assert source == expected diff --git a/lib/python3.10/site-packages/sympy/utilities/tests/test_codegen.py b/lib/python3.10/site-packages/sympy/utilities/tests/test_codegen.py new file mode 100644 index 0000000000000000000000000000000000000000..4ccc6f9a90fb0a0bec39cea22420da8091ede740 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/utilities/tests/test_codegen.py @@ -0,0 +1,1632 @@ +from io import StringIO + +from sympy.core import symbols, Eq, pi, Catalan, Lambda, Dummy +from sympy.core.relational import Equality +from sympy.core.symbol import Symbol +from sympy.functions.special.error_functions import erf +from sympy.integrals.integrals import Integral +from sympy.matrices import Matrix, MatrixSymbol +from sympy.utilities.codegen import ( + codegen, make_routine, CCodeGen, C89CodeGen, C99CodeGen, InputArgument, + CodeGenError, FCodeGen, CodeGenArgumentListError, OutputArgument, + InOutArgument) +from sympy.testing.pytest import raises +from sympy.utilities.lambdify import implemented_function + +#FIXME: Fails due to circular import in with core +# from sympy import codegen + + +def get_string(dump_fn, routines, prefix="file", header=False, empty=False): + """Wrapper for dump_fn. dump_fn writes its results to a stream object and + this wrapper returns the contents of that stream as a string. This + auxiliary function is used by many tests below. + + The header and the empty lines are not generated to facilitate the + testing of the output. + """ + output = StringIO() + dump_fn(routines, output, prefix, header, empty) + source = output.getvalue() + output.close() + return source + + +def test_Routine_argument_order(): + a, x, y, z = symbols('a x y z') + expr = (x + y)*z + raises(CodeGenArgumentListError, lambda: make_routine("test", expr, + argument_sequence=[z, x])) + raises(CodeGenArgumentListError, lambda: make_routine("test", Eq(a, + expr), argument_sequence=[z, x, y])) + r = make_routine('test', Eq(a, expr), argument_sequence=[z, x, a, y]) + assert [ arg.name for arg in r.arguments ] == [z, x, a, y] + assert [ type(arg) for arg in r.arguments ] == [ + InputArgument, InputArgument, OutputArgument, InputArgument ] + r = make_routine('test', Eq(z, expr), argument_sequence=[z, x, y]) + assert [ type(arg) for arg in r.arguments ] == [ + InOutArgument, InputArgument, InputArgument ] + + from sympy.tensor import IndexedBase, Idx + A, B = map(IndexedBase, ['A', 'B']) + m = symbols('m', integer=True) + i = Idx('i', m) + r = make_routine('test', Eq(A[i], B[i]), argument_sequence=[B, A, m]) + assert [ arg.name for arg in r.arguments ] == [B.label, A.label, m] + + expr = Integral(x*y*z, (x, 1, 2), (y, 1, 3)) + r = make_routine('test', Eq(a, expr), argument_sequence=[z, x, a, y]) + assert [ arg.name for arg in r.arguments ] == [z, x, a, y] + + +def test_empty_c_code(): + code_gen = C89CodeGen() + source = get_string(code_gen.dump_c, []) + assert source == "#include \"file.h\"\n#include \n" + + +def test_empty_c_code_with_comment(): + code_gen = C89CodeGen() + source = get_string(code_gen.dump_c, [], header=True) + assert source[:82] == ( + "/******************************************************************************\n *" + ) + # " Code generated with SymPy 0.7.2-git " + assert source[158:] == ( "*\n" + " * *\n" + " * See http://www.sympy.org/ for more information. *\n" + " * *\n" + " * This file is part of 'project' *\n" + " ******************************************************************************/\n" + "#include \"file.h\"\n" + "#include \n" + ) + + +def test_empty_c_header(): + code_gen = C99CodeGen() + source = get_string(code_gen.dump_h, []) + assert source == "#ifndef PROJECT__FILE__H\n#define PROJECT__FILE__H\n#endif\n" + + +def test_simple_c_code(): + x, y, z = symbols('x,y,z') + expr = (x + y)*z + routine = make_routine("test", expr) + code_gen = C89CodeGen() + source = get_string(code_gen.dump_c, [routine]) + expected = ( + "#include \"file.h\"\n" + "#include \n" + "double test(double x, double y, double z) {\n" + " double test_result;\n" + " test_result = z*(x + y);\n" + " return test_result;\n" + "}\n" + ) + assert source == expected + + +def test_c_code_reserved_words(): + x, y, z = symbols('if, typedef, while') + expr = (x + y) * z + routine = make_routine("test", expr) + code_gen = C99CodeGen() + source = get_string(code_gen.dump_c, [routine]) + expected = ( + "#include \"file.h\"\n" + "#include \n" + "double test(double if_, double typedef_, double while_) {\n" + " double test_result;\n" + " test_result = while_*(if_ + typedef_);\n" + " return test_result;\n" + "}\n" + ) + assert source == expected + + +def test_numbersymbol_c_code(): + routine = make_routine("test", pi**Catalan) + code_gen = C89CodeGen() + source = get_string(code_gen.dump_c, [routine]) + expected = ( + "#include \"file.h\"\n" + "#include \n" + "double test() {\n" + " double test_result;\n" + " double const Catalan = %s;\n" + " test_result = pow(M_PI, Catalan);\n" + " return test_result;\n" + "}\n" + ) % Catalan.evalf(17) + assert source == expected + + +def test_c_code_argument_order(): + x, y, z = symbols('x,y,z') + expr = x + y + routine = make_routine("test", expr, argument_sequence=[z, x, y]) + code_gen = C89CodeGen() + source = get_string(code_gen.dump_c, [routine]) + expected = ( + "#include \"file.h\"\n" + "#include \n" + "double test(double z, double x, double y) {\n" + " double test_result;\n" + " test_result = x + y;\n" + " return test_result;\n" + "}\n" + ) + assert source == expected + + +def test_simple_c_header(): + x, y, z = symbols('x,y,z') + expr = (x + y)*z + routine = make_routine("test", expr) + code_gen = C89CodeGen() + source = get_string(code_gen.dump_h, [routine]) + expected = ( + "#ifndef PROJECT__FILE__H\n" + "#define PROJECT__FILE__H\n" + "double test(double x, double y, double z);\n" + "#endif\n" + ) + assert source == expected + + +def test_simple_c_codegen(): + x, y, z = symbols('x,y,z') + expr = (x + y)*z + expected = [ + ("file.c", + "#include \"file.h\"\n" + "#include \n" + "double test(double x, double y, double z) {\n" + " double test_result;\n" + " test_result = z*(x + y);\n" + " return test_result;\n" + "}\n"), + ("file.h", + "#ifndef PROJECT__FILE__H\n" + "#define PROJECT__FILE__H\n" + "double test(double x, double y, double z);\n" + "#endif\n") + ] + result = codegen(("test", expr), "C", "file", header=False, empty=False) + assert result == expected + + +def test_multiple_results_c(): + x, y, z = symbols('x,y,z') + expr1 = (x + y)*z + expr2 = (x - y)*z + routine = make_routine( + "test", + [expr1, expr2] + ) + code_gen = C99CodeGen() + raises(CodeGenError, lambda: get_string(code_gen.dump_h, [routine])) + + +def test_no_results_c(): + raises(ValueError, lambda: make_routine("test", [])) + + +def test_ansi_math1_codegen(): + # not included: log10 + from sympy.functions.elementary.complexes import Abs + from sympy.functions.elementary.exponential import log + from sympy.functions.elementary.hyperbolic import (cosh, sinh, tanh) + from sympy.functions.elementary.integers import (ceiling, floor) + from sympy.functions.elementary.miscellaneous import sqrt + from sympy.functions.elementary.trigonometric import (acos, asin, atan, cos, sin, tan) + x = symbols('x') + name_expr = [ + ("test_fabs", Abs(x)), + ("test_acos", acos(x)), + ("test_asin", asin(x)), + ("test_atan", atan(x)), + ("test_ceil", ceiling(x)), + ("test_cos", cos(x)), + ("test_cosh", cosh(x)), + ("test_floor", floor(x)), + ("test_log", log(x)), + ("test_ln", log(x)), + ("test_sin", sin(x)), + ("test_sinh", sinh(x)), + ("test_sqrt", sqrt(x)), + ("test_tan", tan(x)), + ("test_tanh", tanh(x)), + ] + result = codegen(name_expr, "C89", "file", header=False, empty=False) + assert result[0][0] == "file.c" + assert result[0][1] == ( + '#include "file.h"\n#include \n' + 'double test_fabs(double x) {\n double test_fabs_result;\n test_fabs_result = fabs(x);\n return test_fabs_result;\n}\n' + 'double test_acos(double x) {\n double test_acos_result;\n test_acos_result = acos(x);\n return test_acos_result;\n}\n' + 'double test_asin(double x) {\n double test_asin_result;\n test_asin_result = asin(x);\n return test_asin_result;\n}\n' + 'double test_atan(double x) {\n double test_atan_result;\n test_atan_result = atan(x);\n return test_atan_result;\n}\n' + 'double test_ceil(double x) {\n double test_ceil_result;\n test_ceil_result = ceil(x);\n return test_ceil_result;\n}\n' + 'double test_cos(double x) {\n double test_cos_result;\n test_cos_result = cos(x);\n return test_cos_result;\n}\n' + 'double test_cosh(double x) {\n double test_cosh_result;\n test_cosh_result = cosh(x);\n return test_cosh_result;\n}\n' + 'double test_floor(double x) {\n double test_floor_result;\n test_floor_result = floor(x);\n return test_floor_result;\n}\n' + 'double test_log(double x) {\n double test_log_result;\n test_log_result = log(x);\n return test_log_result;\n}\n' + 'double test_ln(double x) {\n double test_ln_result;\n test_ln_result = log(x);\n return test_ln_result;\n}\n' + 'double test_sin(double x) {\n double test_sin_result;\n test_sin_result = sin(x);\n return test_sin_result;\n}\n' + 'double test_sinh(double x) {\n double test_sinh_result;\n test_sinh_result = sinh(x);\n return test_sinh_result;\n}\n' + 'double test_sqrt(double x) {\n double test_sqrt_result;\n test_sqrt_result = sqrt(x);\n return test_sqrt_result;\n}\n' + 'double test_tan(double x) {\n double test_tan_result;\n test_tan_result = tan(x);\n return test_tan_result;\n}\n' + 'double test_tanh(double x) {\n double test_tanh_result;\n test_tanh_result = tanh(x);\n return test_tanh_result;\n}\n' + ) + assert result[1][0] == "file.h" + assert result[1][1] == ( + '#ifndef PROJECT__FILE__H\n#define PROJECT__FILE__H\n' + 'double test_fabs(double x);\ndouble test_acos(double x);\n' + 'double test_asin(double x);\ndouble test_atan(double x);\n' + 'double test_ceil(double x);\ndouble test_cos(double x);\n' + 'double test_cosh(double x);\ndouble test_floor(double x);\n' + 'double test_log(double x);\ndouble test_ln(double x);\n' + 'double test_sin(double x);\ndouble test_sinh(double x);\n' + 'double test_sqrt(double x);\ndouble test_tan(double x);\n' + 'double test_tanh(double x);\n#endif\n' + ) + + +def test_ansi_math2_codegen(): + # not included: frexp, ldexp, modf, fmod + from sympy.functions.elementary.trigonometric import atan2 + x, y = symbols('x,y') + name_expr = [ + ("test_atan2", atan2(x, y)), + ("test_pow", x**y), + ] + result = codegen(name_expr, "C89", "file", header=False, empty=False) + assert result[0][0] == "file.c" + assert result[0][1] == ( + '#include "file.h"\n#include \n' + 'double test_atan2(double x, double y) {\n double test_atan2_result;\n test_atan2_result = atan2(x, y);\n return test_atan2_result;\n}\n' + 'double test_pow(double x, double y) {\n double test_pow_result;\n test_pow_result = pow(x, y);\n return test_pow_result;\n}\n' + ) + assert result[1][0] == "file.h" + assert result[1][1] == ( + '#ifndef PROJECT__FILE__H\n#define PROJECT__FILE__H\n' + 'double test_atan2(double x, double y);\n' + 'double test_pow(double x, double y);\n' + '#endif\n' + ) + + +def test_complicated_codegen(): + from sympy.functions.elementary.trigonometric import (cos, sin, tan) + x, y, z = symbols('x,y,z') + name_expr = [ + ("test1", ((sin(x) + cos(y) + tan(z))**7).expand()), + ("test2", cos(cos(cos(cos(cos(cos(cos(cos(x + y + z))))))))), + ] + result = codegen(name_expr, "C89", "file", header=False, empty=False) + assert result[0][0] == "file.c" + assert result[0][1] == ( + '#include "file.h"\n#include \n' + 'double test1(double x, double y, double z) {\n' + ' double test1_result;\n' + ' test1_result = ' + 'pow(sin(x), 7) + ' + '7*pow(sin(x), 6)*cos(y) + ' + '7*pow(sin(x), 6)*tan(z) + ' + '21*pow(sin(x), 5)*pow(cos(y), 2) + ' + '42*pow(sin(x), 5)*cos(y)*tan(z) + ' + '21*pow(sin(x), 5)*pow(tan(z), 2) + ' + '35*pow(sin(x), 4)*pow(cos(y), 3) + ' + '105*pow(sin(x), 4)*pow(cos(y), 2)*tan(z) + ' + '105*pow(sin(x), 4)*cos(y)*pow(tan(z), 2) + ' + '35*pow(sin(x), 4)*pow(tan(z), 3) + ' + '35*pow(sin(x), 3)*pow(cos(y), 4) + ' + '140*pow(sin(x), 3)*pow(cos(y), 3)*tan(z) + ' + '210*pow(sin(x), 3)*pow(cos(y), 2)*pow(tan(z), 2) + ' + '140*pow(sin(x), 3)*cos(y)*pow(tan(z), 3) + ' + '35*pow(sin(x), 3)*pow(tan(z), 4) + ' + '21*pow(sin(x), 2)*pow(cos(y), 5) + ' + '105*pow(sin(x), 2)*pow(cos(y), 4)*tan(z) + ' + '210*pow(sin(x), 2)*pow(cos(y), 3)*pow(tan(z), 2) + ' + '210*pow(sin(x), 2)*pow(cos(y), 2)*pow(tan(z), 3) + ' + '105*pow(sin(x), 2)*cos(y)*pow(tan(z), 4) + ' + '21*pow(sin(x), 2)*pow(tan(z), 5) + ' + '7*sin(x)*pow(cos(y), 6) + ' + '42*sin(x)*pow(cos(y), 5)*tan(z) + ' + '105*sin(x)*pow(cos(y), 4)*pow(tan(z), 2) + ' + '140*sin(x)*pow(cos(y), 3)*pow(tan(z), 3) + ' + '105*sin(x)*pow(cos(y), 2)*pow(tan(z), 4) + ' + '42*sin(x)*cos(y)*pow(tan(z), 5) + ' + '7*sin(x)*pow(tan(z), 6) + ' + 'pow(cos(y), 7) + ' + '7*pow(cos(y), 6)*tan(z) + ' + '21*pow(cos(y), 5)*pow(tan(z), 2) + ' + '35*pow(cos(y), 4)*pow(tan(z), 3) + ' + '35*pow(cos(y), 3)*pow(tan(z), 4) + ' + '21*pow(cos(y), 2)*pow(tan(z), 5) + ' + '7*cos(y)*pow(tan(z), 6) + ' + 'pow(tan(z), 7);\n' + ' return test1_result;\n' + '}\n' + 'double test2(double x, double y, double z) {\n' + ' double test2_result;\n' + ' test2_result = cos(cos(cos(cos(cos(cos(cos(cos(x + y + z))))))));\n' + ' return test2_result;\n' + '}\n' + ) + assert result[1][0] == "file.h" + assert result[1][1] == ( + '#ifndef PROJECT__FILE__H\n' + '#define PROJECT__FILE__H\n' + 'double test1(double x, double y, double z);\n' + 'double test2(double x, double y, double z);\n' + '#endif\n' + ) + + +def test_loops_c(): + from sympy.tensor import IndexedBase, Idx + from sympy.core.symbol import symbols + n, m = symbols('n m', integer=True) + A = IndexedBase('A') + x = IndexedBase('x') + y = IndexedBase('y') + i = Idx('i', m) + j = Idx('j', n) + + (f1, code), (f2, interface) = codegen( + ('matrix_vector', Eq(y[i], A[i, j]*x[j])), "C99", "file", header=False, empty=False) + + assert f1 == 'file.c' + expected = ( + '#include "file.h"\n' + '#include \n' + 'void matrix_vector(double *A, int m, int n, double *x, double *y) {\n' + ' for (int i=0; i\n' + 'void test_dummies(int m_%(mno)i, double *x, double *y) {\n' + ' for (int i_%(ino)i=0; i_%(ino)i\n' + 'void matrix_vector(double *A, int m, int n, int o, int p, double *x, double *y) {\n' + ' for (int i=o; i<%(upperi)s; i++){\n' + ' y[i] = 0;\n' + ' }\n' + ' for (int i=o; i<%(upperi)s; i++){\n' + ' for (int j=0; j\n' + 'double foo(double x, double *y) {\n' + ' (*y) = sin(x);\n' + ' double foo_result;\n' + ' foo_result = cos(x);\n' + ' return foo_result;\n' + '}\n' + ) + assert result[0][1] == expected + + +def test_output_arg_c_reserved_words(): + from sympy.core.relational import Equality + from sympy.functions.elementary.trigonometric import (cos, sin) + x, y, z = symbols("if, while, z") + r = make_routine("foo", [Equality(y, sin(x)), cos(x)]) + c = C89CodeGen() + result = c.write([r], "test", header=False, empty=False) + assert result[0][0] == "test.c" + expected = ( + '#include "test.h"\n' + '#include \n' + 'double foo(double if_, double *while_) {\n' + ' (*while_) = sin(if_);\n' + ' double foo_result;\n' + ' foo_result = cos(if_);\n' + ' return foo_result;\n' + '}\n' + ) + assert result[0][1] == expected + + +def test_multidim_c_argument_cse(): + A_sym = MatrixSymbol('A', 3, 3) + b_sym = MatrixSymbol('b', 3, 1) + A = Matrix(A_sym) + b = Matrix(b_sym) + c = A*b + cgen = CCodeGen(project="test", cse=True) + r = cgen.routine("c", c) + r.arguments[-1].result_var = "out" + r.arguments[-1]._name = "out" + code = get_string(cgen.dump_c, [r], prefix="test") + expected = ( + '#include "test.h"\n' + "#include \n" + "void c(double *A, double *b, double *out) {\n" + " out[0] = A[0]*b[0] + A[1]*b[1] + A[2]*b[2];\n" + " out[1] = A[3]*b[0] + A[4]*b[1] + A[5]*b[2];\n" + " out[2] = A[6]*b[0] + A[7]*b[1] + A[8]*b[2];\n" + "}\n" + ) + assert code == expected + + +def test_ccode_results_named_ordered(): + x, y, z = symbols('x,y,z') + B, C = symbols('B,C') + A = MatrixSymbol('A', 1, 3) + expr1 = Equality(A, Matrix([[1, 2, x]])) + expr2 = Equality(C, (x + y)*z) + expr3 = Equality(B, 2*x) + name_expr = ("test", [expr1, expr2, expr3]) + expected = ( + '#include "test.h"\n' + '#include \n' + 'void test(double x, double *C, double z, double y, double *A, double *B) {\n' + ' (*C) = z*(x + y);\n' + ' A[0] = 1;\n' + ' A[1] = 2;\n' + ' A[2] = x;\n' + ' (*B) = 2*x;\n' + '}\n' + ) + + result = codegen(name_expr, "c", "test", header=False, empty=False, + argument_sequence=(x, C, z, y, A, B)) + source = result[0][1] + assert source == expected + + +def test_ccode_matrixsymbol_slice(): + A = MatrixSymbol('A', 5, 3) + B = MatrixSymbol('B', 1, 3) + C = MatrixSymbol('C', 1, 3) + D = MatrixSymbol('D', 5, 1) + name_expr = ("test", [Equality(B, A[0, :]), + Equality(C, A[1, :]), + Equality(D, A[:, 2])]) + result = codegen(name_expr, "c99", "test", header=False, empty=False) + source = result[0][1] + expected = ( + '#include "test.h"\n' + '#include \n' + 'void test(double *A, double *B, double *C, double *D) {\n' + ' B[0] = A[0];\n' + ' B[1] = A[1];\n' + ' B[2] = A[2];\n' + ' C[0] = A[3];\n' + ' C[1] = A[4];\n' + ' C[2] = A[5];\n' + ' D[0] = A[2];\n' + ' D[1] = A[5];\n' + ' D[2] = A[8];\n' + ' D[3] = A[11];\n' + ' D[4] = A[14];\n' + '}\n' + ) + assert source == expected + +def test_ccode_cse(): + a, b, c, d = symbols('a b c d') + e = MatrixSymbol('e', 3, 1) + name_expr = ("test", [Equality(e, Matrix([[a*b], [a*b + c*d], [a*b*c*d]]))]) + generator = CCodeGen(cse=True) + result = codegen(name_expr, code_gen=generator, header=False, empty=False) + source = result[0][1] + expected = ( + '#include "test.h"\n' + '#include \n' + 'void test(double a, double b, double c, double d, double *e) {\n' + ' const double x0 = a*b;\n' + ' const double x1 = c*d;\n' + ' e[0] = x0;\n' + ' e[1] = x0 + x1;\n' + ' e[2] = x0*x1;\n' + '}\n' + ) + assert source == expected + +def test_ccode_unused_array_arg(): + x = MatrixSymbol('x', 2, 1) + # x does not appear in output + name_expr = ("test", 1.0) + generator = CCodeGen() + result = codegen(name_expr, code_gen=generator, header=False, empty=False, argument_sequence=(x,)) + source = result[0][1] + # note: x should appear as (double *) + expected = ( + '#include "test.h"\n' + '#include \n' + 'double test(double *x) {\n' + ' double test_result;\n' + ' test_result = 1.0;\n' + ' return test_result;\n' + '}\n' + ) + assert source == expected + +def test_ccode_unused_array_arg_func(): + # issue 16689 + X = MatrixSymbol('X',3,1) + Y = MatrixSymbol('Y',3,1) + z = symbols('z',integer = True) + name_expr = ('testBug', X[0] + X[1]) + result = codegen(name_expr, language='C', header=False, empty=False, argument_sequence=(X, Y, z)) + source = result[0][1] + expected = ( + '#include "testBug.h"\n' + '#include \n' + 'double testBug(double *X, double *Y, int z) {\n' + ' double testBug_result;\n' + ' testBug_result = X[0] + X[1];\n' + ' return testBug_result;\n' + '}\n' + ) + assert source == expected + +def test_empty_f_code(): + code_gen = FCodeGen() + source = get_string(code_gen.dump_f95, []) + assert source == "" + + +def test_empty_f_code_with_header(): + code_gen = FCodeGen() + source = get_string(code_gen.dump_f95, [], header=True) + assert source[:82] == ( + "!******************************************************************************\n!*" + ) + # " Code generated with SymPy 0.7.2-git " + assert source[158:] == ( "*\n" + "!* *\n" + "!* See http://www.sympy.org/ for more information. *\n" + "!* *\n" + "!* This file is part of 'project' *\n" + "!******************************************************************************\n" + ) + + +def test_empty_f_header(): + code_gen = FCodeGen() + source = get_string(code_gen.dump_h, []) + assert source == "" + + +def test_simple_f_code(): + x, y, z = symbols('x,y,z') + expr = (x + y)*z + routine = make_routine("test", expr) + code_gen = FCodeGen() + source = get_string(code_gen.dump_f95, [routine]) + expected = ( + "REAL*8 function test(x, y, z)\n" + "implicit none\n" + "REAL*8, intent(in) :: x\n" + "REAL*8, intent(in) :: y\n" + "REAL*8, intent(in) :: z\n" + "test = z*(x + y)\n" + "end function\n" + ) + assert source == expected + + +def test_numbersymbol_f_code(): + routine = make_routine("test", pi**Catalan) + code_gen = FCodeGen() + source = get_string(code_gen.dump_f95, [routine]) + expected = ( + "REAL*8 function test()\n" + "implicit none\n" + "REAL*8, parameter :: Catalan = %sd0\n" + "REAL*8, parameter :: pi = %sd0\n" + "test = pi**Catalan\n" + "end function\n" + ) % (Catalan.evalf(17), pi.evalf(17)) + assert source == expected + +def test_erf_f_code(): + x = symbols('x') + routine = make_routine("test", erf(x) - erf(-2 * x)) + code_gen = FCodeGen() + source = get_string(code_gen.dump_f95, [routine]) + expected = ( + "REAL*8 function test(x)\n" + "implicit none\n" + "REAL*8, intent(in) :: x\n" + "test = erf(x) + erf(2.0d0*x)\n" + "end function\n" + ) + assert source == expected, source + +def test_f_code_argument_order(): + x, y, z = symbols('x,y,z') + expr = x + y + routine = make_routine("test", expr, argument_sequence=[z, x, y]) + code_gen = FCodeGen() + source = get_string(code_gen.dump_f95, [routine]) + expected = ( + "REAL*8 function test(z, x, y)\n" + "implicit none\n" + "REAL*8, intent(in) :: z\n" + "REAL*8, intent(in) :: x\n" + "REAL*8, intent(in) :: y\n" + "test = x + y\n" + "end function\n" + ) + assert source == expected + + +def test_simple_f_header(): + x, y, z = symbols('x,y,z') + expr = (x + y)*z + routine = make_routine("test", expr) + code_gen = FCodeGen() + source = get_string(code_gen.dump_h, [routine]) + expected = ( + "interface\n" + "REAL*8 function test(x, y, z)\n" + "implicit none\n" + "REAL*8, intent(in) :: x\n" + "REAL*8, intent(in) :: y\n" + "REAL*8, intent(in) :: z\n" + "end function\n" + "end interface\n" + ) + assert source == expected + + +def test_simple_f_codegen(): + x, y, z = symbols('x,y,z') + expr = (x + y)*z + result = codegen( + ("test", expr), "F95", "file", header=False, empty=False) + expected = [ + ("file.f90", + "REAL*8 function test(x, y, z)\n" + "implicit none\n" + "REAL*8, intent(in) :: x\n" + "REAL*8, intent(in) :: y\n" + "REAL*8, intent(in) :: z\n" + "test = z*(x + y)\n" + "end function\n"), + ("file.h", + "interface\n" + "REAL*8 function test(x, y, z)\n" + "implicit none\n" + "REAL*8, intent(in) :: x\n" + "REAL*8, intent(in) :: y\n" + "REAL*8, intent(in) :: z\n" + "end function\n" + "end interface\n") + ] + assert result == expected + + +def test_multiple_results_f(): + x, y, z = symbols('x,y,z') + expr1 = (x + y)*z + expr2 = (x - y)*z + routine = make_routine( + "test", + [expr1, expr2] + ) + code_gen = FCodeGen() + raises(CodeGenError, lambda: get_string(code_gen.dump_h, [routine])) + + +def test_no_results_f(): + raises(ValueError, lambda: make_routine("test", [])) + + +def test_intrinsic_math_codegen(): + # not included: log10 + from sympy.functions.elementary.complexes import Abs + from sympy.functions.elementary.exponential import log + from sympy.functions.elementary.hyperbolic import (cosh, sinh, tanh) + from sympy.functions.elementary.miscellaneous import sqrt + from sympy.functions.elementary.trigonometric import (acos, asin, atan, cos, sin, tan) + x = symbols('x') + name_expr = [ + ("test_abs", Abs(x)), + ("test_acos", acos(x)), + ("test_asin", asin(x)), + ("test_atan", atan(x)), + ("test_cos", cos(x)), + ("test_cosh", cosh(x)), + ("test_log", log(x)), + ("test_ln", log(x)), + ("test_sin", sin(x)), + ("test_sinh", sinh(x)), + ("test_sqrt", sqrt(x)), + ("test_tan", tan(x)), + ("test_tanh", tanh(x)), + ] + result = codegen(name_expr, "F95", "file", header=False, empty=False) + assert result[0][0] == "file.f90" + expected = ( + 'REAL*8 function test_abs(x)\n' + 'implicit none\n' + 'REAL*8, intent(in) :: x\n' + 'test_abs = abs(x)\n' + 'end function\n' + 'REAL*8 function test_acos(x)\n' + 'implicit none\n' + 'REAL*8, intent(in) :: x\n' + 'test_acos = acos(x)\n' + 'end function\n' + 'REAL*8 function test_asin(x)\n' + 'implicit none\n' + 'REAL*8, intent(in) :: x\n' + 'test_asin = asin(x)\n' + 'end function\n' + 'REAL*8 function test_atan(x)\n' + 'implicit none\n' + 'REAL*8, intent(in) :: x\n' + 'test_atan = atan(x)\n' + 'end function\n' + 'REAL*8 function test_cos(x)\n' + 'implicit none\n' + 'REAL*8, intent(in) :: x\n' + 'test_cos = cos(x)\n' + 'end function\n' + 'REAL*8 function test_cosh(x)\n' + 'implicit none\n' + 'REAL*8, intent(in) :: x\n' + 'test_cosh = cosh(x)\n' + 'end function\n' + 'REAL*8 function test_log(x)\n' + 'implicit none\n' + 'REAL*8, intent(in) :: x\n' + 'test_log = log(x)\n' + 'end function\n' + 'REAL*8 function test_ln(x)\n' + 'implicit none\n' + 'REAL*8, intent(in) :: x\n' + 'test_ln = log(x)\n' + 'end function\n' + 'REAL*8 function test_sin(x)\n' + 'implicit none\n' + 'REAL*8, intent(in) :: x\n' + 'test_sin = sin(x)\n' + 'end function\n' + 'REAL*8 function test_sinh(x)\n' + 'implicit none\n' + 'REAL*8, intent(in) :: x\n' + 'test_sinh = sinh(x)\n' + 'end function\n' + 'REAL*8 function test_sqrt(x)\n' + 'implicit none\n' + 'REAL*8, intent(in) :: x\n' + 'test_sqrt = sqrt(x)\n' + 'end function\n' + 'REAL*8 function test_tan(x)\n' + 'implicit none\n' + 'REAL*8, intent(in) :: x\n' + 'test_tan = tan(x)\n' + 'end function\n' + 'REAL*8 function test_tanh(x)\n' + 'implicit none\n' + 'REAL*8, intent(in) :: x\n' + 'test_tanh = tanh(x)\n' + 'end function\n' + ) + assert result[0][1] == expected + + assert result[1][0] == "file.h" + expected = ( + 'interface\n' + 'REAL*8 function test_abs(x)\n' + 'implicit none\n' + 'REAL*8, intent(in) :: x\n' + 'end function\n' + 'end interface\n' + 'interface\n' + 'REAL*8 function test_acos(x)\n' + 'implicit none\n' + 'REAL*8, intent(in) :: x\n' + 'end function\n' + 'end interface\n' + 'interface\n' + 'REAL*8 function test_asin(x)\n' + 'implicit none\n' + 'REAL*8, intent(in) :: x\n' + 'end function\n' + 'end interface\n' + 'interface\n' + 'REAL*8 function test_atan(x)\n' + 'implicit none\n' + 'REAL*8, intent(in) :: x\n' + 'end function\n' + 'end interface\n' + 'interface\n' + 'REAL*8 function test_cos(x)\n' + 'implicit none\n' + 'REAL*8, intent(in) :: x\n' + 'end function\n' + 'end interface\n' + 'interface\n' + 'REAL*8 function test_cosh(x)\n' + 'implicit none\n' + 'REAL*8, intent(in) :: x\n' + 'end function\n' + 'end interface\n' + 'interface\n' + 'REAL*8 function test_log(x)\n' + 'implicit none\n' + 'REAL*8, intent(in) :: x\n' + 'end function\n' + 'end interface\n' + 'interface\n' + 'REAL*8 function test_ln(x)\n' + 'implicit none\n' + 'REAL*8, intent(in) :: x\n' + 'end function\n' + 'end interface\n' + 'interface\n' + 'REAL*8 function test_sin(x)\n' + 'implicit none\n' + 'REAL*8, intent(in) :: x\n' + 'end function\n' + 'end interface\n' + 'interface\n' + 'REAL*8 function test_sinh(x)\n' + 'implicit none\n' + 'REAL*8, intent(in) :: x\n' + 'end function\n' + 'end interface\n' + 'interface\n' + 'REAL*8 function test_sqrt(x)\n' + 'implicit none\n' + 'REAL*8, intent(in) :: x\n' + 'end function\n' + 'end interface\n' + 'interface\n' + 'REAL*8 function test_tan(x)\n' + 'implicit none\n' + 'REAL*8, intent(in) :: x\n' + 'end function\n' + 'end interface\n' + 'interface\n' + 'REAL*8 function test_tanh(x)\n' + 'implicit none\n' + 'REAL*8, intent(in) :: x\n' + 'end function\n' + 'end interface\n' + ) + assert result[1][1] == expected + + +def test_intrinsic_math2_codegen(): + # not included: frexp, ldexp, modf, fmod + from sympy.functions.elementary.trigonometric import atan2 + x, y = symbols('x,y') + name_expr = [ + ("test_atan2", atan2(x, y)), + ("test_pow", x**y), + ] + result = codegen(name_expr, "F95", "file", header=False, empty=False) + assert result[0][0] == "file.f90" + expected = ( + 'REAL*8 function test_atan2(x, y)\n' + 'implicit none\n' + 'REAL*8, intent(in) :: x\n' + 'REAL*8, intent(in) :: y\n' + 'test_atan2 = atan2(x, y)\n' + 'end function\n' + 'REAL*8 function test_pow(x, y)\n' + 'implicit none\n' + 'REAL*8, intent(in) :: x\n' + 'REAL*8, intent(in) :: y\n' + 'test_pow = x**y\n' + 'end function\n' + ) + assert result[0][1] == expected + + assert result[1][0] == "file.h" + expected = ( + 'interface\n' + 'REAL*8 function test_atan2(x, y)\n' + 'implicit none\n' + 'REAL*8, intent(in) :: x\n' + 'REAL*8, intent(in) :: y\n' + 'end function\n' + 'end interface\n' + 'interface\n' + 'REAL*8 function test_pow(x, y)\n' + 'implicit none\n' + 'REAL*8, intent(in) :: x\n' + 'REAL*8, intent(in) :: y\n' + 'end function\n' + 'end interface\n' + ) + assert result[1][1] == expected + + +def test_complicated_codegen_f95(): + from sympy.functions.elementary.trigonometric import (cos, sin, tan) + x, y, z = symbols('x,y,z') + name_expr = [ + ("test1", ((sin(x) + cos(y) + tan(z))**7).expand()), + ("test2", cos(cos(cos(cos(cos(cos(cos(cos(x + y + z))))))))), + ] + result = codegen(name_expr, "F95", "file", header=False, empty=False) + assert result[0][0] == "file.f90" + expected = ( + 'REAL*8 function test1(x, y, z)\n' + 'implicit none\n' + 'REAL*8, intent(in) :: x\n' + 'REAL*8, intent(in) :: y\n' + 'REAL*8, intent(in) :: z\n' + 'test1 = sin(x)**7 + 7*sin(x)**6*cos(y) + 7*sin(x)**6*tan(z) + 21*sin(x) &\n' + ' **5*cos(y)**2 + 42*sin(x)**5*cos(y)*tan(z) + 21*sin(x)**5*tan(z) &\n' + ' **2 + 35*sin(x)**4*cos(y)**3 + 105*sin(x)**4*cos(y)**2*tan(z) + &\n' + ' 105*sin(x)**4*cos(y)*tan(z)**2 + 35*sin(x)**4*tan(z)**3 + 35*sin( &\n' + ' x)**3*cos(y)**4 + 140*sin(x)**3*cos(y)**3*tan(z) + 210*sin(x)**3* &\n' + ' cos(y)**2*tan(z)**2 + 140*sin(x)**3*cos(y)*tan(z)**3 + 35*sin(x) &\n' + ' **3*tan(z)**4 + 21*sin(x)**2*cos(y)**5 + 105*sin(x)**2*cos(y)**4* &\n' + ' tan(z) + 210*sin(x)**2*cos(y)**3*tan(z)**2 + 210*sin(x)**2*cos(y) &\n' + ' **2*tan(z)**3 + 105*sin(x)**2*cos(y)*tan(z)**4 + 21*sin(x)**2*tan &\n' + ' (z)**5 + 7*sin(x)*cos(y)**6 + 42*sin(x)*cos(y)**5*tan(z) + 105* &\n' + ' sin(x)*cos(y)**4*tan(z)**2 + 140*sin(x)*cos(y)**3*tan(z)**3 + 105 &\n' + ' *sin(x)*cos(y)**2*tan(z)**4 + 42*sin(x)*cos(y)*tan(z)**5 + 7*sin( &\n' + ' x)*tan(z)**6 + cos(y)**7 + 7*cos(y)**6*tan(z) + 21*cos(y)**5*tan( &\n' + ' z)**2 + 35*cos(y)**4*tan(z)**3 + 35*cos(y)**3*tan(z)**4 + 21*cos( &\n' + ' y)**2*tan(z)**5 + 7*cos(y)*tan(z)**6 + tan(z)**7\n' + 'end function\n' + 'REAL*8 function test2(x, y, z)\n' + 'implicit none\n' + 'REAL*8, intent(in) :: x\n' + 'REAL*8, intent(in) :: y\n' + 'REAL*8, intent(in) :: z\n' + 'test2 = cos(cos(cos(cos(cos(cos(cos(cos(x + y + z))))))))\n' + 'end function\n' + ) + assert result[0][1] == expected + assert result[1][0] == "file.h" + expected = ( + 'interface\n' + 'REAL*8 function test1(x, y, z)\n' + 'implicit none\n' + 'REAL*8, intent(in) :: x\n' + 'REAL*8, intent(in) :: y\n' + 'REAL*8, intent(in) :: z\n' + 'end function\n' + 'end interface\n' + 'interface\n' + 'REAL*8 function test2(x, y, z)\n' + 'implicit none\n' + 'REAL*8, intent(in) :: x\n' + 'REAL*8, intent(in) :: y\n' + 'REAL*8, intent(in) :: z\n' + 'end function\n' + 'end interface\n' + ) + assert result[1][1] == expected + + +def test_loops(): + from sympy.tensor import IndexedBase, Idx + from sympy.core.symbol import symbols + + n, m = symbols('n,m', integer=True) + A, x, y = map(IndexedBase, 'Axy') + i = Idx('i', m) + j = Idx('j', n) + + (f1, code), (f2, interface) = codegen( + ('matrix_vector', Eq(y[i], A[i, j]*x[j])), "F95", "file", header=False, empty=False) + + assert f1 == 'file.f90' + expected = ( + 'subroutine matrix_vector(A, m, n, x, y)\n' + 'implicit none\n' + 'INTEGER*4, intent(in) :: m\n' + 'INTEGER*4, intent(in) :: n\n' + 'REAL*8, intent(in), dimension(1:m, 1:n) :: A\n' + 'REAL*8, intent(in), dimension(1:n) :: x\n' + 'REAL*8, intent(out), dimension(1:m) :: y\n' + 'INTEGER*4 :: i\n' + 'INTEGER*4 :: j\n' + 'do i = 1, m\n' + ' y(i) = 0\n' + 'end do\n' + 'do i = 1, m\n' + ' do j = 1, n\n' + ' y(i) = %(rhs)s + y(i)\n' + ' end do\n' + 'end do\n' + 'end subroutine\n' + ) + + assert code == expected % {'rhs': 'A(i, j)*x(j)'} or\ + code == expected % {'rhs': 'x(j)*A(i, j)'} + assert f2 == 'file.h' + assert interface == ( + 'interface\n' + 'subroutine matrix_vector(A, m, n, x, y)\n' + 'implicit none\n' + 'INTEGER*4, intent(in) :: m\n' + 'INTEGER*4, intent(in) :: n\n' + 'REAL*8, intent(in), dimension(1:m, 1:n) :: A\n' + 'REAL*8, intent(in), dimension(1:n) :: x\n' + 'REAL*8, intent(out), dimension(1:m) :: y\n' + 'end subroutine\n' + 'end interface\n' + ) + + +def test_dummy_loops_f95(): + from sympy.tensor import IndexedBase, Idx + i, m = symbols('i m', integer=True, cls=Dummy) + x = IndexedBase('x') + y = IndexedBase('y') + i = Idx(i, m) + expected = ( + 'subroutine test_dummies(m_%(mcount)i, x, y)\n' + 'implicit none\n' + 'INTEGER*4, intent(in) :: m_%(mcount)i\n' + 'REAL*8, intent(in), dimension(1:m_%(mcount)i) :: x\n' + 'REAL*8, intent(out), dimension(1:m_%(mcount)i) :: y\n' + 'INTEGER*4 :: i_%(icount)i\n' + 'do i_%(icount)i = 1, m_%(mcount)i\n' + ' y(i_%(icount)i) = x(i_%(icount)i)\n' + 'end do\n' + 'end subroutine\n' + ) % {'icount': i.label.dummy_index, 'mcount': m.dummy_index} + r = make_routine('test_dummies', Eq(y[i], x[i])) + c = FCodeGen() + code = get_string(c.dump_f95, [r]) + assert code == expected + + +def test_loops_InOut(): + from sympy.tensor import IndexedBase, Idx + from sympy.core.symbol import symbols + + i, j, n, m = symbols('i,j,n,m', integer=True) + A, x, y = symbols('A,x,y') + A = IndexedBase(A)[Idx(i, m), Idx(j, n)] + x = IndexedBase(x)[Idx(j, n)] + y = IndexedBase(y)[Idx(i, m)] + + (f1, code), (f2, interface) = codegen( + ('matrix_vector', Eq(y, y + A*x)), "F95", "file", header=False, empty=False) + + assert f1 == 'file.f90' + expected = ( + 'subroutine matrix_vector(A, m, n, x, y)\n' + 'implicit none\n' + 'INTEGER*4, intent(in) :: m\n' + 'INTEGER*4, intent(in) :: n\n' + 'REAL*8, intent(in), dimension(1:m, 1:n) :: A\n' + 'REAL*8, intent(in), dimension(1:n) :: x\n' + 'REAL*8, intent(inout), dimension(1:m) :: y\n' + 'INTEGER*4 :: i\n' + 'INTEGER*4 :: j\n' + 'do i = 1, m\n' + ' do j = 1, n\n' + ' y(i) = %(rhs)s + y(i)\n' + ' end do\n' + 'end do\n' + 'end subroutine\n' + ) + + assert (code == expected % {'rhs': 'A(i, j)*x(j)'} or + code == expected % {'rhs': 'x(j)*A(i, j)'}) + assert f2 == 'file.h' + assert interface == ( + 'interface\n' + 'subroutine matrix_vector(A, m, n, x, y)\n' + 'implicit none\n' + 'INTEGER*4, intent(in) :: m\n' + 'INTEGER*4, intent(in) :: n\n' + 'REAL*8, intent(in), dimension(1:m, 1:n) :: A\n' + 'REAL*8, intent(in), dimension(1:n) :: x\n' + 'REAL*8, intent(inout), dimension(1:m) :: y\n' + 'end subroutine\n' + 'end interface\n' + ) + + +def test_partial_loops_f(): + # check that loop boundaries are determined by Idx, and array strides + # determined by shape of IndexedBase object. + from sympy.tensor import IndexedBase, Idx + from sympy.core.symbol import symbols + n, m, o, p = symbols('n m o p', integer=True) + A = IndexedBase('A', shape=(m, p)) + x = IndexedBase('x') + y = IndexedBase('y') + i = Idx('i', (o, m - 5)) # Note: bounds are inclusive + j = Idx('j', n) # dimension n corresponds to bounds (0, n - 1) + + (f1, code), (f2, interface) = codegen( + ('matrix_vector', Eq(y[i], A[i, j]*x[j])), "F95", "file", header=False, empty=False) + + expected = ( + 'subroutine matrix_vector(A, m, n, o, p, x, y)\n' + 'implicit none\n' + 'INTEGER*4, intent(in) :: m\n' + 'INTEGER*4, intent(in) :: n\n' + 'INTEGER*4, intent(in) :: o\n' + 'INTEGER*4, intent(in) :: p\n' + 'REAL*8, intent(in), dimension(1:m, 1:p) :: A\n' + 'REAL*8, intent(in), dimension(1:n) :: x\n' + 'REAL*8, intent(out), dimension(1:%(iup-ilow)s) :: y\n' + 'INTEGER*4 :: i\n' + 'INTEGER*4 :: j\n' + 'do i = %(ilow)s, %(iup)s\n' + ' y(i) = 0\n' + 'end do\n' + 'do i = %(ilow)s, %(iup)s\n' + ' do j = 1, n\n' + ' y(i) = %(rhs)s + y(i)\n' + ' end do\n' + 'end do\n' + 'end subroutine\n' + ) % { + 'rhs': '%(rhs)s', + 'iup': str(m - 4), + 'ilow': str(1 + o), + 'iup-ilow': str(m - 4 - o) + } + + assert code == expected % {'rhs': 'A(i, j)*x(j)'} or\ + code == expected % {'rhs': 'x(j)*A(i, j)'} + + +def test_output_arg_f(): + from sympy.core.relational import Equality + from sympy.functions.elementary.trigonometric import (cos, sin) + x, y, z = symbols("x,y,z") + r = make_routine("foo", [Equality(y, sin(x)), cos(x)]) + c = FCodeGen() + result = c.write([r], "test", header=False, empty=False) + assert result[0][0] == "test.f90" + assert result[0][1] == ( + 'REAL*8 function foo(x, y)\n' + 'implicit none\n' + 'REAL*8, intent(in) :: x\n' + 'REAL*8, intent(out) :: y\n' + 'y = sin(x)\n' + 'foo = cos(x)\n' + 'end function\n' + ) + + +def test_inline_function(): + from sympy.tensor import IndexedBase, Idx + from sympy.core.symbol import symbols + n, m = symbols('n m', integer=True) + A, x, y = map(IndexedBase, 'Axy') + i = Idx('i', m) + p = FCodeGen() + func = implemented_function('func', Lambda(n, n*(n + 1))) + routine = make_routine('test_inline', Eq(y[i], func(x[i]))) + code = get_string(p.dump_f95, [routine]) + expected = ( + 'subroutine test_inline(m, x, y)\n' + 'implicit none\n' + 'INTEGER*4, intent(in) :: m\n' + 'REAL*8, intent(in), dimension(1:m) :: x\n' + 'REAL*8, intent(out), dimension(1:m) :: y\n' + 'INTEGER*4 :: i\n' + 'do i = 1, m\n' + ' y(i) = %s*%s\n' + 'end do\n' + 'end subroutine\n' + ) + args = ('x(i)', '(x(i) + 1)') + assert code == expected % args or\ + code == expected % args[::-1] + + +def test_f_code_call_signature_wrap(): + # Issue #7934 + x = symbols('x:20') + expr = 0 + for sym in x: + expr += sym + routine = make_routine("test", expr) + code_gen = FCodeGen() + source = get_string(code_gen.dump_f95, [routine]) + expected = """\ +REAL*8 function test(x0, x1, x10, x11, x12, x13, x14, x15, x16, x17, x18, & + x19, x2, x3, x4, x5, x6, x7, x8, x9) +implicit none +REAL*8, intent(in) :: x0 +REAL*8, intent(in) :: x1 +REAL*8, intent(in) :: x10 +REAL*8, intent(in) :: x11 +REAL*8, intent(in) :: x12 +REAL*8, intent(in) :: x13 +REAL*8, intent(in) :: x14 +REAL*8, intent(in) :: x15 +REAL*8, intent(in) :: x16 +REAL*8, intent(in) :: x17 +REAL*8, intent(in) :: x18 +REAL*8, intent(in) :: x19 +REAL*8, intent(in) :: x2 +REAL*8, intent(in) :: x3 +REAL*8, intent(in) :: x4 +REAL*8, intent(in) :: x5 +REAL*8, intent(in) :: x6 +REAL*8, intent(in) :: x7 +REAL*8, intent(in) :: x8 +REAL*8, intent(in) :: x9 +test = x0 + x1 + x10 + x11 + x12 + x13 + x14 + x15 + x16 + x17 + x18 + & + x19 + x2 + x3 + x4 + x5 + x6 + x7 + x8 + x9 +end function +""" + assert source == expected + + +def test_check_case(): + x, X = symbols('x,X') + raises(CodeGenError, lambda: codegen(('test', x*X), 'f95', 'prefix')) + + +def test_check_case_false_positive(): + # The upper case/lower case exception should not be triggered by SymPy + # objects that differ only because of assumptions. (It may be useful to + # have a check for that as well, but here we only want to test against + # false positives with respect to case checking.) + x1 = symbols('x') + x2 = symbols('x', my_assumption=True) + try: + codegen(('test', x1*x2), 'f95', 'prefix') + except CodeGenError as e: + if e.args[0].startswith("Fortran ignores case."): + raise AssertionError("This exception should not be raised!") + + +def test_c_fortran_omit_routine_name(): + x, y = symbols("x,y") + name_expr = [("foo", 2*x)] + result = codegen(name_expr, "F95", header=False, empty=False) + expresult = codegen(name_expr, "F95", "foo", header=False, empty=False) + assert result[0][1] == expresult[0][1] + + name_expr = ("foo", x*y) + result = codegen(name_expr, "F95", header=False, empty=False) + expresult = codegen(name_expr, "F95", "foo", header=False, empty=False) + assert result[0][1] == expresult[0][1] + + name_expr = ("foo", Matrix([[x, y], [x+y, x-y]])) + result = codegen(name_expr, "C89", header=False, empty=False) + expresult = codegen(name_expr, "C89", "foo", header=False, empty=False) + assert result[0][1] == expresult[0][1] + + +def test_fcode_matrix_output(): + x, y, z = symbols('x,y,z') + e1 = x + y + e2 = Matrix([[x, y], [z, 16]]) + name_expr = ("test", (e1, e2)) + result = codegen(name_expr, "f95", "test", header=False, empty=False) + source = result[0][1] + expected = ( + "REAL*8 function test(x, y, z, out_%(hash)s)\n" + "implicit none\n" + "REAL*8, intent(in) :: x\n" + "REAL*8, intent(in) :: y\n" + "REAL*8, intent(in) :: z\n" + "REAL*8, intent(out), dimension(1:2, 1:2) :: out_%(hash)s\n" + "out_%(hash)s(1, 1) = x\n" + "out_%(hash)s(2, 1) = z\n" + "out_%(hash)s(1, 2) = y\n" + "out_%(hash)s(2, 2) = 16\n" + "test = x + y\n" + "end function\n" + ) + # look for the magic number + a = source.splitlines()[5] + b = a.split('_') + out = b[1] + expected = expected % {'hash': out} + assert source == expected + + +def test_fcode_results_named_ordered(): + x, y, z = symbols('x,y,z') + B, C = symbols('B,C') + A = MatrixSymbol('A', 1, 3) + expr1 = Equality(A, Matrix([[1, 2, x]])) + expr2 = Equality(C, (x + y)*z) + expr3 = Equality(B, 2*x) + name_expr = ("test", [expr1, expr2, expr3]) + result = codegen(name_expr, "f95", "test", header=False, empty=False, + argument_sequence=(x, z, y, C, A, B)) + source = result[0][1] + expected = ( + "subroutine test(x, z, y, C, A, B)\n" + "implicit none\n" + "REAL*8, intent(in) :: x\n" + "REAL*8, intent(in) :: z\n" + "REAL*8, intent(in) :: y\n" + "REAL*8, intent(out) :: C\n" + "REAL*8, intent(out) :: B\n" + "REAL*8, intent(out), dimension(1:1, 1:3) :: A\n" + "C = z*(x + y)\n" + "A(1, 1) = 1\n" + "A(1, 2) = 2\n" + "A(1, 3) = x\n" + "B = 2*x\n" + "end subroutine\n" + ) + assert source == expected + + +def test_fcode_matrixsymbol_slice(): + A = MatrixSymbol('A', 2, 3) + B = MatrixSymbol('B', 1, 3) + C = MatrixSymbol('C', 1, 3) + D = MatrixSymbol('D', 2, 1) + name_expr = ("test", [Equality(B, A[0, :]), + Equality(C, A[1, :]), + Equality(D, A[:, 2])]) + result = codegen(name_expr, "f95", "test", header=False, empty=False) + source = result[0][1] + expected = ( + "subroutine test(A, B, C, D)\n" + "implicit none\n" + "REAL*8, intent(in), dimension(1:2, 1:3) :: A\n" + "REAL*8, intent(out), dimension(1:1, 1:3) :: B\n" + "REAL*8, intent(out), dimension(1:1, 1:3) :: C\n" + "REAL*8, intent(out), dimension(1:2, 1:1) :: D\n" + "B(1, 1) = A(1, 1)\n" + "B(1, 2) = A(1, 2)\n" + "B(1, 3) = A(1, 3)\n" + "C(1, 1) = A(2, 1)\n" + "C(1, 2) = A(2, 2)\n" + "C(1, 3) = A(2, 3)\n" + "D(1, 1) = A(1, 3)\n" + "D(2, 1) = A(2, 3)\n" + "end subroutine\n" + ) + assert source == expected + + +def test_fcode_matrixsymbol_slice_autoname(): + # see issue #8093 + A = MatrixSymbol('A', 2, 3) + name_expr = ("test", A[:, 1]) + result = codegen(name_expr, "f95", "test", header=False, empty=False) + source = result[0][1] + expected = ( + "subroutine test(A, out_%(hash)s)\n" + "implicit none\n" + "REAL*8, intent(in), dimension(1:2, 1:3) :: A\n" + "REAL*8, intent(out), dimension(1:2, 1:1) :: out_%(hash)s\n" + "out_%(hash)s(1, 1) = A(1, 2)\n" + "out_%(hash)s(2, 1) = A(2, 2)\n" + "end subroutine\n" + ) + # look for the magic number + a = source.splitlines()[3] + b = a.split('_') + out = b[1] + expected = expected % {'hash': out} + assert source == expected + + +def test_global_vars(): + x, y, z, t = symbols("x y z t") + result = codegen(('f', x*y), "F95", header=False, empty=False, + global_vars=(y,)) + source = result[0][1] + expected = ( + "REAL*8 function f(x)\n" + "implicit none\n" + "REAL*8, intent(in) :: x\n" + "f = x*y\n" + "end function\n" + ) + assert source == expected + + expected = ( + '#include "f.h"\n' + '#include \n' + 'double f(double x, double y) {\n' + ' double f_result;\n' + ' f_result = x*y + z;\n' + ' return f_result;\n' + '}\n' + ) + result = codegen(('f', x*y+z), "C", header=False, empty=False, + global_vars=(z, t)) + source = result[0][1] + assert source == expected + +def test_custom_codegen(): + from sympy.printing.c import C99CodePrinter + from sympy.functions.elementary.exponential import exp + + printer = C99CodePrinter(settings={'user_functions': {'exp': 'fastexp'}}) + + x, y = symbols('x y') + expr = exp(x + y) + + # replace math.h with a different header + gen = C99CodeGen(printer=printer, + preprocessor_statements=['#include "fastexp.h"']) + + expected = ( + '#include "expr.h"\n' + '#include "fastexp.h"\n' + 'double expr(double x, double y) {\n' + ' double expr_result;\n' + ' expr_result = fastexp(x + y);\n' + ' return expr_result;\n' + '}\n' + ) + + result = codegen(('expr', expr), header=False, empty=False, code_gen=gen) + source = result[0][1] + assert source == expected + + # use both math.h and an external header + gen = C99CodeGen(printer=printer) + gen.preprocessor_statements.append('#include "fastexp.h"') + + expected = ( + '#include "expr.h"\n' + '#include \n' + '#include "fastexp.h"\n' + 'double expr(double x, double y) {\n' + ' double expr_result;\n' + ' expr_result = fastexp(x + y);\n' + ' return expr_result;\n' + '}\n' + ) + + result = codegen(('expr', expr), header=False, empty=False, code_gen=gen) + source = result[0][1] + assert source == expected + +def test_c_with_printer(): + # issue 13586 + from sympy.printing.c import C99CodePrinter + class CustomPrinter(C99CodePrinter): + def _print_Pow(self, expr): + return "fastpow({}, {})".format(self._print(expr.base), + self._print(expr.exp)) + + x = symbols('x') + expr = x**3 + expected =[ + ("file.c", + "#include \"file.h\"\n" + "#include \n" + "double test(double x) {\n" + " double test_result;\n" + " test_result = fastpow(x, 3);\n" + " return test_result;\n" + "}\n"), + ("file.h", + "#ifndef PROJECT__FILE__H\n" + "#define PROJECT__FILE__H\n" + "double test(double x);\n" + "#endif\n") + ] + result = codegen(("test", expr), "C","file", header=False, empty=False, printer = CustomPrinter()) + assert result == expected + + +def test_fcode_complex(): + import sympy.utilities.codegen + sympy.utilities.codegen.COMPLEX_ALLOWED = True + x = Symbol('x', real=True) + y = Symbol('y',real=True) + result = codegen(('test',x+y), 'f95', 'test', header=False, empty=False) + source = (result[0][1]) + expected = ( + "REAL*8 function test(x, y)\n" + "implicit none\n" + "REAL*8, intent(in) :: x\n" + "REAL*8, intent(in) :: y\n" + "test = x + y\n" + "end function\n") + assert source == expected + x = Symbol('x') + y = Symbol('y',real=True) + result = codegen(('test',x+y), 'f95', 'test', header=False, empty=False) + source = (result[0][1]) + expected = ( + "COMPLEX*16 function test(x, y)\n" + "implicit none\n" + "COMPLEX*16, intent(in) :: x\n" + "REAL*8, intent(in) :: y\n" + "test = x + y\n" + "end function\n" + ) + assert source==expected + sympy.utilities.codegen.COMPLEX_ALLOWED = False diff --git a/lib/python3.10/site-packages/sympy/utilities/tests/test_codegen_julia.py b/lib/python3.10/site-packages/sympy/utilities/tests/test_codegen_julia.py new file mode 100644 index 0000000000000000000000000000000000000000..eb4d5920554555a103e017b0518e028ff7d51f8d --- /dev/null +++ b/lib/python3.10/site-packages/sympy/utilities/tests/test_codegen_julia.py @@ -0,0 +1,620 @@ +from io import StringIO + +from sympy.core import S, symbols, Eq, pi, Catalan, EulerGamma, Function +from sympy.core.relational import Equality +from sympy.functions.elementary.piecewise import Piecewise +from sympy.matrices import Matrix, MatrixSymbol +from sympy.utilities.codegen import JuliaCodeGen, codegen, make_routine +from sympy.testing.pytest import XFAIL +import sympy + + +x, y, z = symbols('x,y,z') + + +def test_empty_jl_code(): + code_gen = JuliaCodeGen() + output = StringIO() + code_gen.dump_jl([], output, "file", header=False, empty=False) + source = output.getvalue() + assert source == "" + + +def test_jl_simple_code(): + name_expr = ("test", (x + y)*z) + result, = codegen(name_expr, "Julia", header=False, empty=False) + assert result[0] == "test.jl" + source = result[1] + expected = ( + "function test(x, y, z)\n" + " out1 = z .* (x + y)\n" + " return out1\n" + "end\n" + ) + assert source == expected + + +def test_jl_simple_code_with_header(): + name_expr = ("test", (x + y)*z) + result, = codegen(name_expr, "Julia", header=True, empty=False) + assert result[0] == "test.jl" + source = result[1] + expected = ( + "# Code generated with SymPy " + sympy.__version__ + "\n" + "#\n" + "# See http://www.sympy.org/ for more information.\n" + "#\n" + "# This file is part of 'project'\n" + "function test(x, y, z)\n" + " out1 = z .* (x + y)\n" + " return out1\n" + "end\n" + ) + assert source == expected + + +def test_jl_simple_code_nameout(): + expr = Equality(z, (x + y)) + name_expr = ("test", expr) + result, = codegen(name_expr, "Julia", header=False, empty=False) + source = result[1] + expected = ( + "function test(x, y)\n" + " z = x + y\n" + " return z\n" + "end\n" + ) + assert source == expected + + +def test_jl_numbersymbol(): + name_expr = ("test", pi**Catalan) + result, = codegen(name_expr, "Julia", header=False, empty=False) + source = result[1] + expected = ( + "function test()\n" + " out1 = pi ^ catalan\n" + " return out1\n" + "end\n" + ) + assert source == expected + + +@XFAIL +def test_jl_numbersymbol_no_inline(): + # FIXME: how to pass inline=False to the JuliaCodePrinter? + name_expr = ("test", [pi**Catalan, EulerGamma]) + result, = codegen(name_expr, "Julia", header=False, + empty=False, inline=False) + source = result[1] + expected = ( + "function test()\n" + " Catalan = 0.915965594177219\n" + " EulerGamma = 0.5772156649015329\n" + " out1 = pi ^ Catalan\n" + " out2 = EulerGamma\n" + " return out1, out2\n" + "end\n" + ) + assert source == expected + + +def test_jl_code_argument_order(): + expr = x + y + routine = make_routine("test", expr, argument_sequence=[z, x, y], language="julia") + code_gen = JuliaCodeGen() + output = StringIO() + code_gen.dump_jl([routine], output, "test", header=False, empty=False) + source = output.getvalue() + expected = ( + "function test(z, x, y)\n" + " out1 = x + y\n" + " return out1\n" + "end\n" + ) + assert source == expected + + +def test_multiple_results_m(): + # Here the output order is the input order + expr1 = (x + y)*z + expr2 = (x - y)*z + name_expr = ("test", [expr1, expr2]) + result, = codegen(name_expr, "Julia", header=False, empty=False) + source = result[1] + expected = ( + "function test(x, y, z)\n" + " out1 = z .* (x + y)\n" + " out2 = z .* (x - y)\n" + " return out1, out2\n" + "end\n" + ) + assert source == expected + + +def test_results_named_unordered(): + # Here output order is based on name_expr + A, B, C = symbols('A,B,C') + expr1 = Equality(C, (x + y)*z) + expr2 = Equality(A, (x - y)*z) + expr3 = Equality(B, 2*x) + name_expr = ("test", [expr1, expr2, expr3]) + result, = codegen(name_expr, "Julia", header=False, empty=False) + source = result[1] + expected = ( + "function test(x, y, z)\n" + " C = z .* (x + y)\n" + " A = z .* (x - y)\n" + " B = 2 * x\n" + " return C, A, B\n" + "end\n" + ) + assert source == expected + + +def test_results_named_ordered(): + A, B, C = symbols('A,B,C') + expr1 = Equality(C, (x + y)*z) + expr2 = Equality(A, (x - y)*z) + expr3 = Equality(B, 2*x) + name_expr = ("test", [expr1, expr2, expr3]) + result = codegen(name_expr, "Julia", header=False, empty=False, + argument_sequence=(x, z, y)) + assert result[0][0] == "test.jl" + source = result[0][1] + expected = ( + "function test(x, z, y)\n" + " C = z .* (x + y)\n" + " A = z .* (x - y)\n" + " B = 2 * x\n" + " return C, A, B\n" + "end\n" + ) + assert source == expected + + +def test_complicated_jl_codegen(): + from sympy.functions.elementary.trigonometric import (cos, sin, tan) + name_expr = ("testlong", + [ ((sin(x) + cos(y) + tan(z))**3).expand(), + cos(cos(cos(cos(cos(cos(cos(cos(x + y + z)))))))) + ]) + result = codegen(name_expr, "Julia", header=False, empty=False) + assert result[0][0] == "testlong.jl" + source = result[0][1] + expected = ( + "function testlong(x, y, z)\n" + " out1 = sin(x) .^ 3 + 3 * sin(x) .^ 2 .* cos(y) + 3 * sin(x) .^ 2 .* tan(z)" + " + 3 * sin(x) .* cos(y) .^ 2 + 6 * sin(x) .* cos(y) .* tan(z) + 3 * sin(x) .* tan(z) .^ 2" + " + cos(y) .^ 3 + 3 * cos(y) .^ 2 .* tan(z) + 3 * cos(y) .* tan(z) .^ 2 + tan(z) .^ 3\n" + " out2 = cos(cos(cos(cos(cos(cos(cos(cos(x + y + z))))))))\n" + " return out1, out2\n" + "end\n" + ) + assert source == expected + + +def test_jl_output_arg_mixed_unordered(): + # named outputs are alphabetical, unnamed output appear in the given order + from sympy.functions.elementary.trigonometric import (cos, sin) + a = symbols("a") + name_expr = ("foo", [cos(2*x), Equality(y, sin(x)), cos(x), Equality(a, sin(2*x))]) + result, = codegen(name_expr, "Julia", header=False, empty=False) + assert result[0] == "foo.jl" + source = result[1]; + expected = ( + 'function foo(x)\n' + ' out1 = cos(2 * x)\n' + ' y = sin(x)\n' + ' out3 = cos(x)\n' + ' a = sin(2 * x)\n' + ' return out1, y, out3, a\n' + 'end\n' + ) + assert source == expected + + +def test_jl_piecewise_(): + pw = Piecewise((0, x < -1), (x**2, x <= 1), (-x+2, x > 1), (1, True), evaluate=False) + name_expr = ("pwtest", pw) + result, = codegen(name_expr, "Julia", header=False, empty=False) + source = result[1] + expected = ( + "function pwtest(x)\n" + " out1 = ((x < -1) ? (0) :\n" + " (x <= 1) ? (x .^ 2) :\n" + " (x > 1) ? (2 - x) : (1))\n" + " return out1\n" + "end\n" + ) + assert source == expected + + +@XFAIL +def test_jl_piecewise_no_inline(): + # FIXME: how to pass inline=False to the JuliaCodePrinter? + pw = Piecewise((0, x < -1), (x**2, x <= 1), (-x+2, x > 1), (1, True)) + name_expr = ("pwtest", pw) + result, = codegen(name_expr, "Julia", header=False, empty=False, + inline=False) + source = result[1] + expected = ( + "function pwtest(x)\n" + " if (x < -1)\n" + " out1 = 0\n" + " elseif (x <= 1)\n" + " out1 = x .^ 2\n" + " elseif (x > 1)\n" + " out1 = -x + 2\n" + " else\n" + " out1 = 1\n" + " end\n" + " return out1\n" + "end\n" + ) + assert source == expected + + +def test_jl_multifcns_per_file(): + name_expr = [ ("foo", [2*x, 3*y]), ("bar", [y**2, 4*y]) ] + result = codegen(name_expr, "Julia", header=False, empty=False) + assert result[0][0] == "foo.jl" + source = result[0][1]; + expected = ( + "function foo(x, y)\n" + " out1 = 2 * x\n" + " out2 = 3 * y\n" + " return out1, out2\n" + "end\n" + "function bar(y)\n" + " out1 = y .^ 2\n" + " out2 = 4 * y\n" + " return out1, out2\n" + "end\n" + ) + assert source == expected + + +def test_jl_multifcns_per_file_w_header(): + name_expr = [ ("foo", [2*x, 3*y]), ("bar", [y**2, 4*y]) ] + result = codegen(name_expr, "Julia", header=True, empty=False) + assert result[0][0] == "foo.jl" + source = result[0][1]; + expected = ( + "# Code generated with SymPy " + sympy.__version__ + "\n" + "#\n" + "# See http://www.sympy.org/ for more information.\n" + "#\n" + "# This file is part of 'project'\n" + "function foo(x, y)\n" + " out1 = 2 * x\n" + " out2 = 3 * y\n" + " return out1, out2\n" + "end\n" + "function bar(y)\n" + " out1 = y .^ 2\n" + " out2 = 4 * y\n" + " return out1, out2\n" + "end\n" + ) + assert source == expected + + +def test_jl_filename_match_prefix(): + name_expr = [ ("foo", [2*x, 3*y]), ("bar", [y**2, 4*y]) ] + result, = codegen(name_expr, "Julia", prefix="baz", header=False, + empty=False) + assert result[0] == "baz.jl" + + +def test_jl_matrix_named(): + e2 = Matrix([[x, 2*y, pi*z]]) + name_expr = ("test", Equality(MatrixSymbol('myout1', 1, 3), e2)) + result = codegen(name_expr, "Julia", header=False, empty=False) + assert result[0][0] == "test.jl" + source = result[0][1] + expected = ( + "function test(x, y, z)\n" + " myout1 = [x 2 * y pi * z]\n" + " return myout1\n" + "end\n" + ) + assert source == expected + + +def test_jl_matrix_named_matsym(): + myout1 = MatrixSymbol('myout1', 1, 3) + e2 = Matrix([[x, 2*y, pi*z]]) + name_expr = ("test", Equality(myout1, e2, evaluate=False)) + result, = codegen(name_expr, "Julia", header=False, empty=False) + source = result[1] + expected = ( + "function test(x, y, z)\n" + " myout1 = [x 2 * y pi * z]\n" + " return myout1\n" + "end\n" + ) + assert source == expected + + +def test_jl_matrix_output_autoname(): + expr = Matrix([[x, x+y, 3]]) + name_expr = ("test", expr) + result, = codegen(name_expr, "Julia", header=False, empty=False) + source = result[1] + expected = ( + "function test(x, y)\n" + " out1 = [x x + y 3]\n" + " return out1\n" + "end\n" + ) + assert source == expected + + +def test_jl_matrix_output_autoname_2(): + e1 = (x + y) + e2 = Matrix([[2*x, 2*y, 2*z]]) + e3 = Matrix([[x], [y], [z]]) + e4 = Matrix([[x, y], [z, 16]]) + name_expr = ("test", (e1, e2, e3, e4)) + result, = codegen(name_expr, "Julia", header=False, empty=False) + source = result[1] + expected = ( + "function test(x, y, z)\n" + " out1 = x + y\n" + " out2 = [2 * x 2 * y 2 * z]\n" + " out3 = [x, y, z]\n" + " out4 = [x y;\n" + " z 16]\n" + " return out1, out2, out3, out4\n" + "end\n" + ) + assert source == expected + + +def test_jl_results_matrix_named_ordered(): + B, C = symbols('B,C') + A = MatrixSymbol('A', 1, 3) + expr1 = Equality(C, (x + y)*z) + expr2 = Equality(A, Matrix([[1, 2, x]])) + expr3 = Equality(B, 2*x) + name_expr = ("test", [expr1, expr2, expr3]) + result, = codegen(name_expr, "Julia", header=False, empty=False, + argument_sequence=(x, z, y)) + source = result[1] + expected = ( + "function test(x, z, y)\n" + " C = z .* (x + y)\n" + " A = [1 2 x]\n" + " B = 2 * x\n" + " return C, A, B\n" + "end\n" + ) + assert source == expected + + +def test_jl_matrixsymbol_slice(): + A = MatrixSymbol('A', 2, 3) + B = MatrixSymbol('B', 1, 3) + C = MatrixSymbol('C', 1, 3) + D = MatrixSymbol('D', 2, 1) + name_expr = ("test", [Equality(B, A[0, :]), + Equality(C, A[1, :]), + Equality(D, A[:, 2])]) + result, = codegen(name_expr, "Julia", header=False, empty=False) + source = result[1] + expected = ( + "function test(A)\n" + " B = A[1,:]\n" + " C = A[2,:]\n" + " D = A[:,3]\n" + " return B, C, D\n" + "end\n" + ) + assert source == expected + + +def test_jl_matrixsymbol_slice2(): + A = MatrixSymbol('A', 3, 4) + B = MatrixSymbol('B', 2, 2) + C = MatrixSymbol('C', 2, 2) + name_expr = ("test", [Equality(B, A[0:2, 0:2]), + Equality(C, A[0:2, 1:3])]) + result, = codegen(name_expr, "Julia", header=False, empty=False) + source = result[1] + expected = ( + "function test(A)\n" + " B = A[1:2,1:2]\n" + " C = A[1:2,2:3]\n" + " return B, C\n" + "end\n" + ) + assert source == expected + + +def test_jl_matrixsymbol_slice3(): + A = MatrixSymbol('A', 8, 7) + B = MatrixSymbol('B', 2, 2) + C = MatrixSymbol('C', 4, 2) + name_expr = ("test", [Equality(B, A[6:, 1::3]), + Equality(C, A[::2, ::3])]) + result, = codegen(name_expr, "Julia", header=False, empty=False) + source = result[1] + expected = ( + "function test(A)\n" + " B = A[7:end,2:3:end]\n" + " C = A[1:2:end,1:3:end]\n" + " return B, C\n" + "end\n" + ) + assert source == expected + + +def test_jl_matrixsymbol_slice_autoname(): + A = MatrixSymbol('A', 2, 3) + B = MatrixSymbol('B', 1, 3) + name_expr = ("test", [Equality(B, A[0,:]), A[1,:], A[:,0], A[:,1]]) + result, = codegen(name_expr, "Julia", header=False, empty=False) + source = result[1] + expected = ( + "function test(A)\n" + " B = A[1,:]\n" + " out2 = A[2,:]\n" + " out3 = A[:,1]\n" + " out4 = A[:,2]\n" + " return B, out2, out3, out4\n" + "end\n" + ) + assert source == expected + + +def test_jl_loops(): + # Note: an Julia programmer would probably vectorize this across one or + # more dimensions. Also, size(A) would be used rather than passing in m + # and n. Perhaps users would expect us to vectorize automatically here? + # Or is it possible to represent such things using IndexedBase? + from sympy.tensor import IndexedBase, Idx + from sympy.core.symbol import symbols + n, m = symbols('n m', integer=True) + A = IndexedBase('A') + x = IndexedBase('x') + y = IndexedBase('y') + i = Idx('i', m) + j = Idx('j', n) + result, = codegen(('mat_vec_mult', Eq(y[i], A[i, j]*x[j])), "Julia", + header=False, empty=False) + source = result[1] + expected = ( + 'function mat_vec_mult(y, A, m, n, x)\n' + ' for i = 1:m\n' + ' y[i] = 0\n' + ' end\n' + ' for i = 1:m\n' + ' for j = 1:n\n' + ' y[i] = %(rhs)s + y[i]\n' + ' end\n' + ' end\n' + ' return y\n' + 'end\n' + ) + assert (source == expected % {'rhs': 'A[%s,%s] .* x[j]' % (i, j)} or + source == expected % {'rhs': 'x[j] .* A[%s,%s]' % (i, j)}) + + +def test_jl_tensor_loops_multiple_contractions(): + # see comments in previous test about vectorizing + from sympy.tensor import IndexedBase, Idx + from sympy.core.symbol import symbols + n, m, o, p = symbols('n m o p', integer=True) + A = IndexedBase('A') + B = IndexedBase('B') + y = IndexedBase('y') + i = Idx('i', m) + j = Idx('j', n) + k = Idx('k', o) + l = Idx('l', p) + result, = codegen(('tensorthing', Eq(y[i], B[j, k, l]*A[i, j, k, l])), + "Julia", header=False, empty=False) + source = result[1] + expected = ( + 'function tensorthing(y, A, B, m, n, o, p)\n' + ' for i = 1:m\n' + ' y[i] = 0\n' + ' end\n' + ' for i = 1:m\n' + ' for j = 1:n\n' + ' for k = 1:o\n' + ' for l = 1:p\n' + ' y[i] = A[i,j,k,l] .* B[j,k,l] + y[i]\n' + ' end\n' + ' end\n' + ' end\n' + ' end\n' + ' return y\n' + 'end\n' + ) + assert source == expected + + +def test_jl_InOutArgument(): + expr = Equality(x, x**2) + name_expr = ("mysqr", expr) + result, = codegen(name_expr, "Julia", header=False, empty=False) + source = result[1] + expected = ( + "function mysqr(x)\n" + " x = x .^ 2\n" + " return x\n" + "end\n" + ) + assert source == expected + + +def test_jl_InOutArgument_order(): + # can specify the order as (x, y) + expr = Equality(x, x**2 + y) + name_expr = ("test", expr) + result, = codegen(name_expr, "Julia", header=False, + empty=False, argument_sequence=(x,y)) + source = result[1] + expected = ( + "function test(x, y)\n" + " x = x .^ 2 + y\n" + " return x\n" + "end\n" + ) + assert source == expected + # make sure it gives (x, y) not (y, x) + expr = Equality(x, x**2 + y) + name_expr = ("test", expr) + result, = codegen(name_expr, "Julia", header=False, empty=False) + source = result[1] + expected = ( + "function test(x, y)\n" + " x = x .^ 2 + y\n" + " return x\n" + "end\n" + ) + assert source == expected + + +def test_jl_not_supported(): + f = Function('f') + name_expr = ("test", [f(x).diff(x), S.ComplexInfinity]) + result, = codegen(name_expr, "Julia", header=False, empty=False) + source = result[1] + expected = ( + "function test(x)\n" + " # unsupported: Derivative(f(x), x)\n" + " # unsupported: zoo\n" + " out1 = Derivative(f(x), x)\n" + " out2 = zoo\n" + " return out1, out2\n" + "end\n" + ) + assert source == expected + + +def test_global_vars_octave(): + x, y, z, t = symbols("x y z t") + result = codegen(('f', x*y), "Julia", header=False, empty=False, + global_vars=(y,)) + source = result[0][1] + expected = ( + "function f(x)\n" + " out1 = x .* y\n" + " return out1\n" + "end\n" + ) + assert source == expected + + result = codegen(('f', x*y+z), "Julia", header=False, empty=False, + argument_sequence=(x, y), global_vars=(z, t)) + source = result[0][1] + expected = ( + "function f(x, y)\n" + " out1 = x .* y + z\n" + " return out1\n" + "end\n" + ) + assert source == expected diff --git a/lib/python3.10/site-packages/sympy/utilities/tests/test_codegen_octave.py b/lib/python3.10/site-packages/sympy/utilities/tests/test_codegen_octave.py new file mode 100644 index 0000000000000000000000000000000000000000..53634cb1cd945fabcfb87dc2acbf73c9af23e519 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/utilities/tests/test_codegen_octave.py @@ -0,0 +1,589 @@ +from io import StringIO + +from sympy.core import S, symbols, Eq, pi, Catalan, EulerGamma, Function +from sympy.core.relational import Equality +from sympy.functions.elementary.piecewise import Piecewise +from sympy.matrices import Matrix, MatrixSymbol +from sympy.utilities.codegen import OctaveCodeGen, codegen, make_routine +from sympy.testing.pytest import raises +from sympy.testing.pytest import XFAIL +import sympy + + +x, y, z = symbols('x,y,z') + + +def test_empty_m_code(): + code_gen = OctaveCodeGen() + output = StringIO() + code_gen.dump_m([], output, "file", header=False, empty=False) + source = output.getvalue() + assert source == "" + + +def test_m_simple_code(): + name_expr = ("test", (x + y)*z) + result, = codegen(name_expr, "Octave", header=False, empty=False) + assert result[0] == "test.m" + source = result[1] + expected = ( + "function out1 = test(x, y, z)\n" + " out1 = z.*(x + y);\n" + "end\n" + ) + assert source == expected + + +def test_m_simple_code_with_header(): + name_expr = ("test", (x + y)*z) + result, = codegen(name_expr, "Octave", header=True, empty=False) + assert result[0] == "test.m" + source = result[1] + expected = ( + "function out1 = test(x, y, z)\n" + " %TEST Autogenerated by SymPy\n" + " % Code generated with SymPy " + sympy.__version__ + "\n" + " %\n" + " % See http://www.sympy.org/ for more information.\n" + " %\n" + " % This file is part of 'project'\n" + " out1 = z.*(x + y);\n" + "end\n" + ) + assert source == expected + + +def test_m_simple_code_nameout(): + expr = Equality(z, (x + y)) + name_expr = ("test", expr) + result, = codegen(name_expr, "Octave", header=False, empty=False) + source = result[1] + expected = ( + "function z = test(x, y)\n" + " z = x + y;\n" + "end\n" + ) + assert source == expected + + +def test_m_numbersymbol(): + name_expr = ("test", pi**Catalan) + result, = codegen(name_expr, "Octave", header=False, empty=False) + source = result[1] + expected = ( + "function out1 = test()\n" + " out1 = pi^%s;\n" + "end\n" + ) % Catalan.evalf(17) + assert source == expected + + +@XFAIL +def test_m_numbersymbol_no_inline(): + # FIXME: how to pass inline=False to the OctaveCodePrinter? + name_expr = ("test", [pi**Catalan, EulerGamma]) + result, = codegen(name_expr, "Octave", header=False, + empty=False, inline=False) + source = result[1] + expected = ( + "function [out1, out2] = test()\n" + " Catalan = 0.915965594177219; % constant\n" + " EulerGamma = 0.5772156649015329; % constant\n" + " out1 = pi^Catalan;\n" + " out2 = EulerGamma;\n" + "end\n" + ) + assert source == expected + + +def test_m_code_argument_order(): + expr = x + y + routine = make_routine("test", expr, argument_sequence=[z, x, y], language="octave") + code_gen = OctaveCodeGen() + output = StringIO() + code_gen.dump_m([routine], output, "test", header=False, empty=False) + source = output.getvalue() + expected = ( + "function out1 = test(z, x, y)\n" + " out1 = x + y;\n" + "end\n" + ) + assert source == expected + + +def test_multiple_results_m(): + # Here the output order is the input order + expr1 = (x + y)*z + expr2 = (x - y)*z + name_expr = ("test", [expr1, expr2]) + result, = codegen(name_expr, "Octave", header=False, empty=False) + source = result[1] + expected = ( + "function [out1, out2] = test(x, y, z)\n" + " out1 = z.*(x + y);\n" + " out2 = z.*(x - y);\n" + "end\n" + ) + assert source == expected + + +def test_results_named_unordered(): + # Here output order is based on name_expr + A, B, C = symbols('A,B,C') + expr1 = Equality(C, (x + y)*z) + expr2 = Equality(A, (x - y)*z) + expr3 = Equality(B, 2*x) + name_expr = ("test", [expr1, expr2, expr3]) + result, = codegen(name_expr, "Octave", header=False, empty=False) + source = result[1] + expected = ( + "function [C, A, B] = test(x, y, z)\n" + " C = z.*(x + y);\n" + " A = z.*(x - y);\n" + " B = 2*x;\n" + "end\n" + ) + assert source == expected + + +def test_results_named_ordered(): + A, B, C = symbols('A,B,C') + expr1 = Equality(C, (x + y)*z) + expr2 = Equality(A, (x - y)*z) + expr3 = Equality(B, 2*x) + name_expr = ("test", [expr1, expr2, expr3]) + result = codegen(name_expr, "Octave", header=False, empty=False, + argument_sequence=(x, z, y)) + assert result[0][0] == "test.m" + source = result[0][1] + expected = ( + "function [C, A, B] = test(x, z, y)\n" + " C = z.*(x + y);\n" + " A = z.*(x - y);\n" + " B = 2*x;\n" + "end\n" + ) + assert source == expected + + +def test_complicated_m_codegen(): + from sympy.functions.elementary.trigonometric import (cos, sin, tan) + name_expr = ("testlong", + [ ((sin(x) + cos(y) + tan(z))**3).expand(), + cos(cos(cos(cos(cos(cos(cos(cos(x + y + z)))))))) + ]) + result = codegen(name_expr, "Octave", header=False, empty=False) + assert result[0][0] == "testlong.m" + source = result[0][1] + expected = ( + "function [out1, out2] = testlong(x, y, z)\n" + " out1 = sin(x).^3 + 3*sin(x).^2.*cos(y) + 3*sin(x).^2.*tan(z)" + " + 3*sin(x).*cos(y).^2 + 6*sin(x).*cos(y).*tan(z) + 3*sin(x).*tan(z).^2" + " + cos(y).^3 + 3*cos(y).^2.*tan(z) + 3*cos(y).*tan(z).^2 + tan(z).^3;\n" + " out2 = cos(cos(cos(cos(cos(cos(cos(cos(x + y + z))))))));\n" + "end\n" + ) + assert source == expected + + +def test_m_output_arg_mixed_unordered(): + # named outputs are alphabetical, unnamed output appear in the given order + from sympy.functions.elementary.trigonometric import (cos, sin) + a = symbols("a") + name_expr = ("foo", [cos(2*x), Equality(y, sin(x)), cos(x), Equality(a, sin(2*x))]) + result, = codegen(name_expr, "Octave", header=False, empty=False) + assert result[0] == "foo.m" + source = result[1]; + expected = ( + 'function [out1, y, out3, a] = foo(x)\n' + ' out1 = cos(2*x);\n' + ' y = sin(x);\n' + ' out3 = cos(x);\n' + ' a = sin(2*x);\n' + 'end\n' + ) + assert source == expected + + +def test_m_piecewise_(): + pw = Piecewise((0, x < -1), (x**2, x <= 1), (-x+2, x > 1), (1, True), evaluate=False) + name_expr = ("pwtest", pw) + result, = codegen(name_expr, "Octave", header=False, empty=False) + source = result[1] + expected = ( + "function out1 = pwtest(x)\n" + " out1 = ((x < -1).*(0) + (~(x < -1)).*( ...\n" + " (x <= 1).*(x.^2) + (~(x <= 1)).*( ...\n" + " (x > 1).*(2 - x) + (~(x > 1)).*(1))));\n" + "end\n" + ) + assert source == expected + + +@XFAIL +def test_m_piecewise_no_inline(): + # FIXME: how to pass inline=False to the OctaveCodePrinter? + pw = Piecewise((0, x < -1), (x**2, x <= 1), (-x+2, x > 1), (1, True)) + name_expr = ("pwtest", pw) + result, = codegen(name_expr, "Octave", header=False, empty=False, + inline=False) + source = result[1] + expected = ( + "function out1 = pwtest(x)\n" + " if (x < -1)\n" + " out1 = 0;\n" + " elseif (x <= 1)\n" + " out1 = x.^2;\n" + " elseif (x > 1)\n" + " out1 = -x + 2;\n" + " else\n" + " out1 = 1;\n" + " end\n" + "end\n" + ) + assert source == expected + + +def test_m_multifcns_per_file(): + name_expr = [ ("foo", [2*x, 3*y]), ("bar", [y**2, 4*y]) ] + result = codegen(name_expr, "Octave", header=False, empty=False) + assert result[0][0] == "foo.m" + source = result[0][1]; + expected = ( + "function [out1, out2] = foo(x, y)\n" + " out1 = 2*x;\n" + " out2 = 3*y;\n" + "end\n" + "function [out1, out2] = bar(y)\n" + " out1 = y.^2;\n" + " out2 = 4*y;\n" + "end\n" + ) + assert source == expected + + +def test_m_multifcns_per_file_w_header(): + name_expr = [ ("foo", [2*x, 3*y]), ("bar", [y**2, 4*y]) ] + result = codegen(name_expr, "Octave", header=True, empty=False) + assert result[0][0] == "foo.m" + source = result[0][1]; + expected = ( + "function [out1, out2] = foo(x, y)\n" + " %FOO Autogenerated by SymPy\n" + " % Code generated with SymPy " + sympy.__version__ + "\n" + " %\n" + " % See http://www.sympy.org/ for more information.\n" + " %\n" + " % This file is part of 'project'\n" + " out1 = 2*x;\n" + " out2 = 3*y;\n" + "end\n" + "function [out1, out2] = bar(y)\n" + " out1 = y.^2;\n" + " out2 = 4*y;\n" + "end\n" + ) + assert source == expected + + +def test_m_filename_match_first_fcn(): + name_expr = [ ("foo", [2*x, 3*y]), ("bar", [y**2, 4*y]) ] + raises(ValueError, lambda: codegen(name_expr, + "Octave", prefix="bar", header=False, empty=False)) + + +def test_m_matrix_named(): + e2 = Matrix([[x, 2*y, pi*z]]) + name_expr = ("test", Equality(MatrixSymbol('myout1', 1, 3), e2)) + result = codegen(name_expr, "Octave", header=False, empty=False) + assert result[0][0] == "test.m" + source = result[0][1] + expected = ( + "function myout1 = test(x, y, z)\n" + " myout1 = [x 2*y pi*z];\n" + "end\n" + ) + assert source == expected + + +def test_m_matrix_named_matsym(): + myout1 = MatrixSymbol('myout1', 1, 3) + e2 = Matrix([[x, 2*y, pi*z]]) + name_expr = ("test", Equality(myout1, e2, evaluate=False)) + result, = codegen(name_expr, "Octave", header=False, empty=False) + source = result[1] + expected = ( + "function myout1 = test(x, y, z)\n" + " myout1 = [x 2*y pi*z];\n" + "end\n" + ) + assert source == expected + + +def test_m_matrix_output_autoname(): + expr = Matrix([[x, x+y, 3]]) + name_expr = ("test", expr) + result, = codegen(name_expr, "Octave", header=False, empty=False) + source = result[1] + expected = ( + "function out1 = test(x, y)\n" + " out1 = [x x + y 3];\n" + "end\n" + ) + assert source == expected + + +def test_m_matrix_output_autoname_2(): + e1 = (x + y) + e2 = Matrix([[2*x, 2*y, 2*z]]) + e3 = Matrix([[x], [y], [z]]) + e4 = Matrix([[x, y], [z, 16]]) + name_expr = ("test", (e1, e2, e3, e4)) + result, = codegen(name_expr, "Octave", header=False, empty=False) + source = result[1] + expected = ( + "function [out1, out2, out3, out4] = test(x, y, z)\n" + " out1 = x + y;\n" + " out2 = [2*x 2*y 2*z];\n" + " out3 = [x; y; z];\n" + " out4 = [x y; z 16];\n" + "end\n" + ) + assert source == expected + + +def test_m_results_matrix_named_ordered(): + B, C = symbols('B,C') + A = MatrixSymbol('A', 1, 3) + expr1 = Equality(C, (x + y)*z) + expr2 = Equality(A, Matrix([[1, 2, x]])) + expr3 = Equality(B, 2*x) + name_expr = ("test", [expr1, expr2, expr3]) + result, = codegen(name_expr, "Octave", header=False, empty=False, + argument_sequence=(x, z, y)) + source = result[1] + expected = ( + "function [C, A, B] = test(x, z, y)\n" + " C = z.*(x + y);\n" + " A = [1 2 x];\n" + " B = 2*x;\n" + "end\n" + ) + assert source == expected + + +def test_m_matrixsymbol_slice(): + A = MatrixSymbol('A', 2, 3) + B = MatrixSymbol('B', 1, 3) + C = MatrixSymbol('C', 1, 3) + D = MatrixSymbol('D', 2, 1) + name_expr = ("test", [Equality(B, A[0, :]), + Equality(C, A[1, :]), + Equality(D, A[:, 2])]) + result, = codegen(name_expr, "Octave", header=False, empty=False) + source = result[1] + expected = ( + "function [B, C, D] = test(A)\n" + " B = A(1, :);\n" + " C = A(2, :);\n" + " D = A(:, 3);\n" + "end\n" + ) + assert source == expected + + +def test_m_matrixsymbol_slice2(): + A = MatrixSymbol('A', 3, 4) + B = MatrixSymbol('B', 2, 2) + C = MatrixSymbol('C', 2, 2) + name_expr = ("test", [Equality(B, A[0:2, 0:2]), + Equality(C, A[0:2, 1:3])]) + result, = codegen(name_expr, "Octave", header=False, empty=False) + source = result[1] + expected = ( + "function [B, C] = test(A)\n" + " B = A(1:2, 1:2);\n" + " C = A(1:2, 2:3);\n" + "end\n" + ) + assert source == expected + + +def test_m_matrixsymbol_slice3(): + A = MatrixSymbol('A', 8, 7) + B = MatrixSymbol('B', 2, 2) + C = MatrixSymbol('C', 4, 2) + name_expr = ("test", [Equality(B, A[6:, 1::3]), + Equality(C, A[::2, ::3])]) + result, = codegen(name_expr, "Octave", header=False, empty=False) + source = result[1] + expected = ( + "function [B, C] = test(A)\n" + " B = A(7:end, 2:3:end);\n" + " C = A(1:2:end, 1:3:end);\n" + "end\n" + ) + assert source == expected + + +def test_m_matrixsymbol_slice_autoname(): + A = MatrixSymbol('A', 2, 3) + B = MatrixSymbol('B', 1, 3) + name_expr = ("test", [Equality(B, A[0,:]), A[1,:], A[:,0], A[:,1]]) + result, = codegen(name_expr, "Octave", header=False, empty=False) + source = result[1] + expected = ( + "function [B, out2, out3, out4] = test(A)\n" + " B = A(1, :);\n" + " out2 = A(2, :);\n" + " out3 = A(:, 1);\n" + " out4 = A(:, 2);\n" + "end\n" + ) + assert source == expected + + +def test_m_loops(): + # Note: an Octave programmer would probably vectorize this across one or + # more dimensions. Also, size(A) would be used rather than passing in m + # and n. Perhaps users would expect us to vectorize automatically here? + # Or is it possible to represent such things using IndexedBase? + from sympy.tensor import IndexedBase, Idx + from sympy.core.symbol import symbols + n, m = symbols('n m', integer=True) + A = IndexedBase('A') + x = IndexedBase('x') + y = IndexedBase('y') + i = Idx('i', m) + j = Idx('j', n) + result, = codegen(('mat_vec_mult', Eq(y[i], A[i, j]*x[j])), "Octave", + header=False, empty=False) + source = result[1] + expected = ( + 'function y = mat_vec_mult(A, m, n, x)\n' + ' for i = 1:m\n' + ' y(i) = 0;\n' + ' end\n' + ' for i = 1:m\n' + ' for j = 1:n\n' + ' y(i) = %(rhs)s + y(i);\n' + ' end\n' + ' end\n' + 'end\n' + ) + assert (source == expected % {'rhs': 'A(%s, %s).*x(j)' % (i, j)} or + source == expected % {'rhs': 'x(j).*A(%s, %s)' % (i, j)}) + + +def test_m_tensor_loops_multiple_contractions(): + # see comments in previous test about vectorizing + from sympy.tensor import IndexedBase, Idx + from sympy.core.symbol import symbols + n, m, o, p = symbols('n m o p', integer=True) + A = IndexedBase('A') + B = IndexedBase('B') + y = IndexedBase('y') + i = Idx('i', m) + j = Idx('j', n) + k = Idx('k', o) + l = Idx('l', p) + result, = codegen(('tensorthing', Eq(y[i], B[j, k, l]*A[i, j, k, l])), + "Octave", header=False, empty=False) + source = result[1] + expected = ( + 'function y = tensorthing(A, B, m, n, o, p)\n' + ' for i = 1:m\n' + ' y(i) = 0;\n' + ' end\n' + ' for i = 1:m\n' + ' for j = 1:n\n' + ' for k = 1:o\n' + ' for l = 1:p\n' + ' y(i) = A(i, j, k, l).*B(j, k, l) + y(i);\n' + ' end\n' + ' end\n' + ' end\n' + ' end\n' + 'end\n' + ) + assert source == expected + + +def test_m_InOutArgument(): + expr = Equality(x, x**2) + name_expr = ("mysqr", expr) + result, = codegen(name_expr, "Octave", header=False, empty=False) + source = result[1] + expected = ( + "function x = mysqr(x)\n" + " x = x.^2;\n" + "end\n" + ) + assert source == expected + + +def test_m_InOutArgument_order(): + # can specify the order as (x, y) + expr = Equality(x, x**2 + y) + name_expr = ("test", expr) + result, = codegen(name_expr, "Octave", header=False, + empty=False, argument_sequence=(x,y)) + source = result[1] + expected = ( + "function x = test(x, y)\n" + " x = x.^2 + y;\n" + "end\n" + ) + assert source == expected + # make sure it gives (x, y) not (y, x) + expr = Equality(x, x**2 + y) + name_expr = ("test", expr) + result, = codegen(name_expr, "Octave", header=False, empty=False) + source = result[1] + expected = ( + "function x = test(x, y)\n" + " x = x.^2 + y;\n" + "end\n" + ) + assert source == expected + + +def test_m_not_supported(): + f = Function('f') + name_expr = ("test", [f(x).diff(x), S.ComplexInfinity]) + result, = codegen(name_expr, "Octave", header=False, empty=False) + source = result[1] + expected = ( + "function [out1, out2] = test(x)\n" + " % unsupported: Derivative(f(x), x)\n" + " % unsupported: zoo\n" + " out1 = Derivative(f(x), x);\n" + " out2 = zoo;\n" + "end\n" + ) + assert source == expected + + +def test_global_vars_octave(): + x, y, z, t = symbols("x y z t") + result = codegen(('f', x*y), "Octave", header=False, empty=False, + global_vars=(y,)) + source = result[0][1] + expected = ( + "function out1 = f(x)\n" + " global y\n" + " out1 = x.*y;\n" + "end\n" + ) + assert source == expected + + result = codegen(('f', x*y+z), "Octave", header=False, empty=False, + argument_sequence=(x, y), global_vars=(z, t)) + source = result[0][1] + expected = ( + "function out1 = f(x, y)\n" + " global t z\n" + " out1 = x.*y + z;\n" + "end\n" + ) + assert source == expected diff --git a/lib/python3.10/site-packages/sympy/utilities/tests/test_codegen_rust.py b/lib/python3.10/site-packages/sympy/utilities/tests/test_codegen_rust.py new file mode 100644 index 0000000000000000000000000000000000000000..235cc6350e051ab1a2915284aa4669274030943b --- /dev/null +++ b/lib/python3.10/site-packages/sympy/utilities/tests/test_codegen_rust.py @@ -0,0 +1,401 @@ +from io import StringIO + +from sympy.core import S, symbols, pi, Catalan, EulerGamma, Function +from sympy.core.relational import Equality +from sympy.functions.elementary.piecewise import Piecewise +from sympy.utilities.codegen import RustCodeGen, codegen, make_routine +from sympy.testing.pytest import XFAIL +import sympy + + +x, y, z = symbols('x,y,z') + + +def test_empty_rust_code(): + code_gen = RustCodeGen() + output = StringIO() + code_gen.dump_rs([], output, "file", header=False, empty=False) + source = output.getvalue() + assert source == "" + + +def test_simple_rust_code(): + name_expr = ("test", (x + y)*z) + result, = codegen(name_expr, "Rust", header=False, empty=False) + assert result[0] == "test.rs" + source = result[1] + expected = ( + "fn test(x: f64, y: f64, z: f64) -> f64 {\n" + " let out1 = z*(x + y);\n" + " out1\n" + "}\n" + ) + assert source == expected + + +def test_simple_code_with_header(): + name_expr = ("test", (x + y)*z) + result, = codegen(name_expr, "Rust", header=True, empty=False) + assert result[0] == "test.rs" + source = result[1] + version_str = "Code generated with SymPy %s" % sympy.__version__ + version_line = version_str.center(76).rstrip() + expected = ( + "/*\n" + " *%(version_line)s\n" + " *\n" + " * See http://www.sympy.org/ for more information.\n" + " *\n" + " * This file is part of 'project'\n" + " */\n" + "fn test(x: f64, y: f64, z: f64) -> f64 {\n" + " let out1 = z*(x + y);\n" + " out1\n" + "}\n" + ) % {'version_line': version_line} + assert source == expected + + +def test_simple_code_nameout(): + expr = Equality(z, (x + y)) + name_expr = ("test", expr) + result, = codegen(name_expr, "Rust", header=False, empty=False) + source = result[1] + expected = ( + "fn test(x: f64, y: f64) -> f64 {\n" + " let z = x + y;\n" + " z\n" + "}\n" + ) + assert source == expected + + +def test_numbersymbol(): + name_expr = ("test", pi**Catalan) + result, = codegen(name_expr, "Rust", header=False, empty=False) + source = result[1] + expected = ( + "fn test() -> f64 {\n" + " const Catalan: f64 = %s;\n" + " let out1 = PI.powf(Catalan);\n" + " out1\n" + "}\n" + ) % Catalan.evalf(17) + assert source == expected + + +@XFAIL +def test_numbersymbol_inline(): + # FIXME: how to pass inline to the RustCodePrinter? + name_expr = ("test", [pi**Catalan, EulerGamma]) + result, = codegen(name_expr, "Rust", header=False, + empty=False, inline=True) + source = result[1] + expected = ( + "fn test() -> (f64, f64) {\n" + " const Catalan: f64 = %s;\n" + " const EulerGamma: f64 = %s;\n" + " let out1 = PI.powf(Catalan);\n" + " let out2 = EulerGamma);\n" + " (out1, out2)\n" + "}\n" + ) % (Catalan.evalf(17), EulerGamma.evalf(17)) + assert source == expected + + +def test_argument_order(): + expr = x + y + routine = make_routine("test", expr, argument_sequence=[z, x, y], language="rust") + code_gen = RustCodeGen() + output = StringIO() + code_gen.dump_rs([routine], output, "test", header=False, empty=False) + source = output.getvalue() + expected = ( + "fn test(z: f64, x: f64, y: f64) -> f64 {\n" + " let out1 = x + y;\n" + " out1\n" + "}\n" + ) + assert source == expected + + +def test_multiple_results_rust(): + # Here the output order is the input order + expr1 = (x + y)*z + expr2 = (x - y)*z + name_expr = ("test", [expr1, expr2]) + result, = codegen(name_expr, "Rust", header=False, empty=False) + source = result[1] + expected = ( + "fn test(x: f64, y: f64, z: f64) -> (f64, f64) {\n" + " let out1 = z*(x + y);\n" + " let out2 = z*(x - y);\n" + " (out1, out2)\n" + "}\n" + ) + assert source == expected + + +def test_results_named_unordered(): + # Here output order is based on name_expr + A, B, C = symbols('A,B,C') + expr1 = Equality(C, (x + y)*z) + expr2 = Equality(A, (x - y)*z) + expr3 = Equality(B, 2*x) + name_expr = ("test", [expr1, expr2, expr3]) + result, = codegen(name_expr, "Rust", header=False, empty=False) + source = result[1] + expected = ( + "fn test(x: f64, y: f64, z: f64) -> (f64, f64, f64) {\n" + " let C = z*(x + y);\n" + " let A = z*(x - y);\n" + " let B = 2*x;\n" + " (C, A, B)\n" + "}\n" + ) + assert source == expected + + +def test_results_named_ordered(): + A, B, C = symbols('A,B,C') + expr1 = Equality(C, (x + y)*z) + expr2 = Equality(A, (x - y)*z) + expr3 = Equality(B, 2*x) + name_expr = ("test", [expr1, expr2, expr3]) + result = codegen(name_expr, "Rust", header=False, empty=False, + argument_sequence=(x, z, y)) + assert result[0][0] == "test.rs" + source = result[0][1] + expected = ( + "fn test(x: f64, z: f64, y: f64) -> (f64, f64, f64) {\n" + " let C = z*(x + y);\n" + " let A = z*(x - y);\n" + " let B = 2*x;\n" + " (C, A, B)\n" + "}\n" + ) + assert source == expected + + +def test_complicated_rs_codegen(): + from sympy.functions.elementary.trigonometric import (cos, sin, tan) + name_expr = ("testlong", + [ ((sin(x) + cos(y) + tan(z))**3).expand(), + cos(cos(cos(cos(cos(cos(cos(cos(x + y + z)))))))) + ]) + result = codegen(name_expr, "Rust", header=False, empty=False) + assert result[0][0] == "testlong.rs" + source = result[0][1] + expected = ( + "fn testlong(x: f64, y: f64, z: f64) -> (f64, f64) {\n" + " let out1 = x.sin().powi(3) + 3*x.sin().powi(2)*y.cos()" + " + 3*x.sin().powi(2)*z.tan() + 3*x.sin()*y.cos().powi(2)" + " + 6*x.sin()*y.cos()*z.tan() + 3*x.sin()*z.tan().powi(2)" + " + y.cos().powi(3) + 3*y.cos().powi(2)*z.tan()" + " + 3*y.cos()*z.tan().powi(2) + z.tan().powi(3);\n" + " let out2 = (x + y + z).cos().cos().cos().cos()" + ".cos().cos().cos().cos();\n" + " (out1, out2)\n" + "}\n" + ) + assert source == expected + + +def test_output_arg_mixed_unordered(): + # named outputs are alphabetical, unnamed output appear in the given order + from sympy.functions.elementary.trigonometric import (cos, sin) + a = symbols("a") + name_expr = ("foo", [cos(2*x), Equality(y, sin(x)), cos(x), Equality(a, sin(2*x))]) + result, = codegen(name_expr, "Rust", header=False, empty=False) + assert result[0] == "foo.rs" + source = result[1]; + expected = ( + "fn foo(x: f64) -> (f64, f64, f64, f64) {\n" + " let out1 = (2*x).cos();\n" + " let y = x.sin();\n" + " let out3 = x.cos();\n" + " let a = (2*x).sin();\n" + " (out1, y, out3, a)\n" + "}\n" + ) + assert source == expected + + +def test_piecewise_(): + pw = Piecewise((0, x < -1), (x**2, x <= 1), (-x+2, x > 1), (1, True), evaluate=False) + name_expr = ("pwtest", pw) + result, = codegen(name_expr, "Rust", header=False, empty=False) + source = result[1] + expected = ( + "fn pwtest(x: f64) -> f64 {\n" + " let out1 = if (x < -1) {\n" + " 0\n" + " } else if (x <= 1) {\n" + " x.powi(2)\n" + " } else if (x > 1) {\n" + " 2 - x\n" + " } else {\n" + " 1\n" + " };\n" + " out1\n" + "}\n" + ) + assert source == expected + + +@XFAIL +def test_piecewise_inline(): + # FIXME: how to pass inline to the RustCodePrinter? + pw = Piecewise((0, x < -1), (x**2, x <= 1), (-x+2, x > 1), (1, True)) + name_expr = ("pwtest", pw) + result, = codegen(name_expr, "Rust", header=False, empty=False, + inline=True) + source = result[1] + expected = ( + "fn pwtest(x: f64) -> f64 {\n" + " let out1 = if (x < -1) { 0 } else if (x <= 1) { x.powi(2) }" + " else if (x > 1) { -x + 2 } else { 1 };\n" + " out1\n" + "}\n" + ) + assert source == expected + + +def test_multifcns_per_file(): + name_expr = [ ("foo", [2*x, 3*y]), ("bar", [y**2, 4*y]) ] + result = codegen(name_expr, "Rust", header=False, empty=False) + assert result[0][0] == "foo.rs" + source = result[0][1]; + expected = ( + "fn foo(x: f64, y: f64) -> (f64, f64) {\n" + " let out1 = 2*x;\n" + " let out2 = 3*y;\n" + " (out1, out2)\n" + "}\n" + "fn bar(y: f64) -> (f64, f64) {\n" + " let out1 = y.powi(2);\n" + " let out2 = 4*y;\n" + " (out1, out2)\n" + "}\n" + ) + assert source == expected + + +def test_multifcns_per_file_w_header(): + name_expr = [ ("foo", [2*x, 3*y]), ("bar", [y**2, 4*y]) ] + result = codegen(name_expr, "Rust", header=True, empty=False) + assert result[0][0] == "foo.rs" + source = result[0][1]; + version_str = "Code generated with SymPy %s" % sympy.__version__ + version_line = version_str.center(76).rstrip() + expected = ( + "/*\n" + " *%(version_line)s\n" + " *\n" + " * See http://www.sympy.org/ for more information.\n" + " *\n" + " * This file is part of 'project'\n" + " */\n" + "fn foo(x: f64, y: f64) -> (f64, f64) {\n" + " let out1 = 2*x;\n" + " let out2 = 3*y;\n" + " (out1, out2)\n" + "}\n" + "fn bar(y: f64) -> (f64, f64) {\n" + " let out1 = y.powi(2);\n" + " let out2 = 4*y;\n" + " (out1, out2)\n" + "}\n" + ) % {'version_line': version_line} + assert source == expected + + +def test_filename_match_prefix(): + name_expr = [ ("foo", [2*x, 3*y]), ("bar", [y**2, 4*y]) ] + result, = codegen(name_expr, "Rust", prefix="baz", header=False, + empty=False) + assert result[0] == "baz.rs" + + +def test_InOutArgument(): + expr = Equality(x, x**2) + name_expr = ("mysqr", expr) + result, = codegen(name_expr, "Rust", header=False, empty=False) + source = result[1] + expected = ( + "fn mysqr(x: f64) -> f64 {\n" + " let x = x.powi(2);\n" + " x\n" + "}\n" + ) + assert source == expected + + +def test_InOutArgument_order(): + # can specify the order as (x, y) + expr = Equality(x, x**2 + y) + name_expr = ("test", expr) + result, = codegen(name_expr, "Rust", header=False, + empty=False, argument_sequence=(x,y)) + source = result[1] + expected = ( + "fn test(x: f64, y: f64) -> f64 {\n" + " let x = x.powi(2) + y;\n" + " x\n" + "}\n" + ) + assert source == expected + # make sure it gives (x, y) not (y, x) + expr = Equality(x, x**2 + y) + name_expr = ("test", expr) + result, = codegen(name_expr, "Rust", header=False, empty=False) + source = result[1] + expected = ( + "fn test(x: f64, y: f64) -> f64 {\n" + " let x = x.powi(2) + y;\n" + " x\n" + "}\n" + ) + assert source == expected + + +def test_not_supported(): + f = Function('f') + name_expr = ("test", [f(x).diff(x), S.ComplexInfinity]) + result, = codegen(name_expr, "Rust", header=False, empty=False) + source = result[1] + expected = ( + "fn test(x: f64) -> (f64, f64) {\n" + " // unsupported: Derivative(f(x), x)\n" + " // unsupported: zoo\n" + " let out1 = Derivative(f(x), x);\n" + " let out2 = zoo;\n" + " (out1, out2)\n" + "}\n" + ) + assert source == expected + + +def test_global_vars_rust(): + x, y, z, t = symbols("x y z t") + result = codegen(('f', x*y), "Rust", header=False, empty=False, + global_vars=(y,)) + source = result[0][1] + expected = ( + "fn f(x: f64) -> f64 {\n" + " let out1 = x*y;\n" + " out1\n" + "}\n" + ) + assert source == expected + + result = codegen(('f', x*y+z), "Rust", header=False, empty=False, + argument_sequence=(x, y), global_vars=(z, t)) + source = result[0][1] + expected = ( + "fn f(x: f64, y: f64) -> f64 {\n" + " let out1 = x*y + z;\n" + " out1\n" + "}\n" + ) + assert source == expected diff --git a/lib/python3.10/site-packages/sympy/utilities/tests/test_decorator.py b/lib/python3.10/site-packages/sympy/utilities/tests/test_decorator.py new file mode 100644 index 0000000000000000000000000000000000000000..b1870d4db8f719fdabfeab14120bfb3ce10131a9 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/utilities/tests/test_decorator.py @@ -0,0 +1,129 @@ +from functools import wraps + +from sympy.utilities.decorator import threaded, xthreaded, memoize_property, deprecated +from sympy.testing.pytest import warns_deprecated_sympy + +from sympy.core.basic import Basic +from sympy.core.relational import Eq +from sympy.matrices.dense import Matrix + +from sympy.abc import x, y + + +def test_threaded(): + @threaded + def function(expr, *args): + return 2*expr + sum(args) + + assert function(Matrix([[x, y], [1, x]]), 1, 2) == \ + Matrix([[2*x + 3, 2*y + 3], [5, 2*x + 3]]) + + assert function(Eq(x, y), 1, 2) == Eq(2*x + 3, 2*y + 3) + + assert function([x, y], 1, 2) == [2*x + 3, 2*y + 3] + assert function((x, y), 1, 2) == (2*x + 3, 2*y + 3) + + assert function({x, y}, 1, 2) == {2*x + 3, 2*y + 3} + + @threaded + def function(expr, n): + return expr**n + + assert function(x + y, 2) == x**2 + y**2 + assert function(x, 2) == x**2 + + +def test_xthreaded(): + @xthreaded + def function(expr, n): + return expr**n + + assert function(x + y, 2) == (x + y)**2 + + +def test_wraps(): + def my_func(x): + """My function. """ + + my_func.is_my_func = True + + new_my_func = threaded(my_func) + new_my_func = wraps(my_func)(new_my_func) + + assert new_my_func.__name__ == 'my_func' + assert new_my_func.__doc__ == 'My function. ' + assert hasattr(new_my_func, 'is_my_func') + assert new_my_func.is_my_func is True + + +def test_memoize_property(): + class TestMemoize(Basic): + @memoize_property + def prop(self): + return Basic() + + member = TestMemoize() + obj1 = member.prop + obj2 = member.prop + assert obj1 is obj2 + +def test_deprecated(): + @deprecated('deprecated_function is deprecated', + deprecated_since_version='1.10', + # This is the target at the top of the file, which will never + # go away. + active_deprecations_target='active-deprecations') + def deprecated_function(x): + return x + + with warns_deprecated_sympy(): + assert deprecated_function(1) == 1 + + @deprecated('deprecated_class is deprecated', + deprecated_since_version='1.10', + active_deprecations_target='active-deprecations') + class deprecated_class: + pass + + with warns_deprecated_sympy(): + assert isinstance(deprecated_class(), deprecated_class) + + # Ensure the class decorator works even when the class never returns + # itself + @deprecated('deprecated_class_new is deprecated', + deprecated_since_version='1.10', + active_deprecations_target='active-deprecations') + class deprecated_class_new: + def __new__(cls, arg): + return arg + + with warns_deprecated_sympy(): + assert deprecated_class_new(1) == 1 + + @deprecated('deprecated_class_init is deprecated', + deprecated_since_version='1.10', + active_deprecations_target='active-deprecations') + class deprecated_class_init: + def __init__(self, arg): + self.arg = 1 + + with warns_deprecated_sympy(): + assert deprecated_class_init(1).arg == 1 + + @deprecated('deprecated_class_new_init is deprecated', + deprecated_since_version='1.10', + active_deprecations_target='active-deprecations') + class deprecated_class_new_init: + def __new__(cls, arg): + if arg == 0: + return arg + return object.__new__(cls) + + def __init__(self, arg): + self.arg = 1 + + with warns_deprecated_sympy(): + assert deprecated_class_new_init(0) == 0 + + with warns_deprecated_sympy(): + assert deprecated_class_new_init(1).arg == 1 diff --git a/lib/python3.10/site-packages/sympy/utilities/tests/test_deprecated.py b/lib/python3.10/site-packages/sympy/utilities/tests/test_deprecated.py new file mode 100644 index 0000000000000000000000000000000000000000..dd4534ef1abc38ff368011b3ef9d11c497f3675b --- /dev/null +++ b/lib/python3.10/site-packages/sympy/utilities/tests/test_deprecated.py @@ -0,0 +1,13 @@ +from sympy.testing.pytest import warns_deprecated_sympy + +# See https://github.com/sympy/sympy/pull/18095 + +def test_deprecated_utilities(): + with warns_deprecated_sympy(): + import sympy.utilities.pytest # noqa:F401 + with warns_deprecated_sympy(): + import sympy.utilities.runtests # noqa:F401 + with warns_deprecated_sympy(): + import sympy.utilities.randtest # noqa:F401 + with warns_deprecated_sympy(): + import sympy.utilities.tmpfiles # noqa:F401 diff --git a/lib/python3.10/site-packages/sympy/utilities/tests/test_enumerative.py b/lib/python3.10/site-packages/sympy/utilities/tests/test_enumerative.py new file mode 100644 index 0000000000000000000000000000000000000000..a29c6341dd6f3a19f145c94f6e996cc348428325 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/utilities/tests/test_enumerative.py @@ -0,0 +1,178 @@ +from itertools import zip_longest + +from sympy.utilities.enumerative import ( + list_visitor, + MultisetPartitionTraverser, + multiset_partitions_taocp + ) +from sympy.utilities.iterables import _set_partitions + +# first some functions only useful as test scaffolding - these provide +# straightforward, but slow reference implementations against which to +# compare the real versions, and also a comparison to verify that +# different versions are giving identical results. + +def part_range_filter(partition_iterator, lb, ub): + """ + Filters (on the number of parts) a multiset partition enumeration + + Arguments + ========= + + lb, and ub are a range (in the Python slice sense) on the lpart + variable returned from a multiset partition enumeration. Recall + that lpart is 0-based (it points to the topmost part on the part + stack), so if you want to return parts of sizes 2,3,4,5 you would + use lb=1 and ub=5. + """ + for state in partition_iterator: + f, lpart, pstack = state + if lpart >= lb and lpart < ub: + yield state + +def multiset_partitions_baseline(multiplicities, components): + """Enumerates partitions of a multiset + + Parameters + ========== + + multiplicities + list of integer multiplicities of the components of the multiset. + + components + the components (elements) themselves + + Returns + ======= + + Set of partitions. Each partition is tuple of parts, and each + part is a tuple of components (with repeats to indicate + multiplicity) + + Notes + ===== + + Multiset partitions can be created as equivalence classes of set + partitions, and this function does just that. This approach is + slow and memory intensive compared to the more advanced algorithms + available, but the code is simple and easy to understand. Hence + this routine is strictly for testing -- to provide a + straightforward baseline against which to regress the production + versions. (This code is a simplified version of an earlier + production implementation.) + """ + + canon = [] # list of components with repeats + for ct, elem in zip(multiplicities, components): + canon.extend([elem]*ct) + + # accumulate the multiset partitions in a set to eliminate dups + cache = set() + n = len(canon) + for nc, q in _set_partitions(n): + rv = [[] for i in range(nc)] + for i in range(n): + rv[q[i]].append(canon[i]) + canonical = tuple( + sorted([tuple(p) for p in rv])) + cache.add(canonical) + return cache + + +def compare_multiset_w_baseline(multiplicities): + """ + Enumerates the partitions of multiset with AOCP algorithm and + baseline implementation, and compare the results. + + """ + letters = "abcdefghijklmnopqrstuvwxyz" + bl_partitions = multiset_partitions_baseline(multiplicities, letters) + + # The partitions returned by the different algorithms may have + # their parts in different orders. Also, they generate partitions + # in different orders. Hence the sorting, and set comparison. + + aocp_partitions = set() + for state in multiset_partitions_taocp(multiplicities): + p1 = tuple(sorted( + [tuple(p) for p in list_visitor(state, letters)])) + aocp_partitions.add(p1) + + assert bl_partitions == aocp_partitions + +def compare_multiset_states(s1, s2): + """compare for equality two instances of multiset partition states + + This is useful for comparing different versions of the algorithm + to verify correctness.""" + # Comparison is physical, the only use of semantics is to ignore + # trash off the top of the stack. + f1, lpart1, pstack1 = s1 + f2, lpart2, pstack2 = s2 + + if (lpart1 == lpart2) and (f1[0:lpart1+1] == f2[0:lpart2+1]): + if pstack1[0:f1[lpart1+1]] == pstack2[0:f2[lpart2+1]]: + return True + return False + +def test_multiset_partitions_taocp(): + """Compares the output of multiset_partitions_taocp with a baseline + (set partition based) implementation.""" + + # Test cases should not be too large, since the baseline + # implementation is fairly slow. + multiplicities = [2,2] + compare_multiset_w_baseline(multiplicities) + + multiplicities = [4,3,1] + compare_multiset_w_baseline(multiplicities) + +def test_multiset_partitions_versions(): + """Compares Knuth-based versions of multiset_partitions""" + multiplicities = [5,2,2,1] + m = MultisetPartitionTraverser() + for s1, s2 in zip_longest(m.enum_all(multiplicities), + multiset_partitions_taocp(multiplicities)): + assert compare_multiset_states(s1, s2) + +def subrange_exercise(mult, lb, ub): + """Compare filter-based and more optimized subrange implementations + + Helper for tests, called with both small and larger multisets. + """ + m = MultisetPartitionTraverser() + assert m.count_partitions(mult) == \ + m.count_partitions_slow(mult) + + # Note - multiple traversals from the same + # MultisetPartitionTraverser object cannot execute at the same + # time, hence make several instances here. + ma = MultisetPartitionTraverser() + mc = MultisetPartitionTraverser() + md = MultisetPartitionTraverser() + + # Several paths to compute just the size two partitions + a_it = ma.enum_range(mult, lb, ub) + b_it = part_range_filter(multiset_partitions_taocp(mult), lb, ub) + c_it = part_range_filter(mc.enum_small(mult, ub), lb, sum(mult)) + d_it = part_range_filter(md.enum_large(mult, lb), 0, ub) + + for sa, sb, sc, sd in zip_longest(a_it, b_it, c_it, d_it): + assert compare_multiset_states(sa, sb) + assert compare_multiset_states(sa, sc) + assert compare_multiset_states(sa, sd) + +def test_subrange(): + # Quick, but doesn't hit some of the corner cases + mult = [4,4,2,1] # mississippi + lb = 1 + ub = 2 + subrange_exercise(mult, lb, ub) + + +def test_subrange_large(): + # takes a second or so, depending on cpu, Python version, etc. + mult = [6,3,2,1] + lb = 4 + ub = 7 + subrange_exercise(mult, lb, ub) diff --git a/lib/python3.10/site-packages/sympy/utilities/tests/test_exceptions.py b/lib/python3.10/site-packages/sympy/utilities/tests/test_exceptions.py new file mode 100644 index 0000000000000000000000000000000000000000..d91e55e95d0ae4ac57cdd1989e0b3d39a55cd07d --- /dev/null +++ b/lib/python3.10/site-packages/sympy/utilities/tests/test_exceptions.py @@ -0,0 +1,12 @@ +from sympy.testing.pytest import raises +from sympy.utilities.exceptions import sympy_deprecation_warning + +# Only test exceptions here because the other cases are tested in the +# warns_deprecated_sympy tests +def test_sympy_deprecation_warning(): + raises(TypeError, lambda: sympy_deprecation_warning('test', + deprecated_since_version=1.10, + active_deprecations_target='active-deprecations')) + + raises(ValueError, lambda: sympy_deprecation_warning('test', + deprecated_since_version="1.10", active_deprecations_target='(active-deprecations)=')) diff --git a/lib/python3.10/site-packages/sympy/utilities/tests/test_iterables.py b/lib/python3.10/site-packages/sympy/utilities/tests/test_iterables.py new file mode 100644 index 0000000000000000000000000000000000000000..1003522bcd556c6f63e04de7da57b43498575fee --- /dev/null +++ b/lib/python3.10/site-packages/sympy/utilities/tests/test_iterables.py @@ -0,0 +1,945 @@ +from textwrap import dedent +from itertools import islice, product + +from sympy.core.basic import Basic +from sympy.core.numbers import Integer +from sympy.core.sorting import ordered +from sympy.core.symbol import (Dummy, symbols) +from sympy.functions.combinatorial.factorials import factorial +from sympy.matrices.dense import Matrix +from sympy.combinatorics import RGS_enum, RGS_unrank, Permutation +from sympy.utilities.iterables import ( + _partition, _set_partitions, binary_partitions, bracelets, capture, + cartes, common_prefix, common_suffix, connected_components, dict_merge, + filter_symbols, flatten, generate_bell, generate_derangements, + generate_involutions, generate_oriented_forest, group, has_dups, ibin, + iproduct, kbins, minlex, multiset, multiset_combinations, + multiset_partitions, multiset_permutations, necklaces, numbered_symbols, + partitions, permutations, postfixes, + prefixes, reshape, rotate_left, rotate_right, runs, sift, + strongly_connected_components, subsets, take, topological_sort, unflatten, + uniq, variations, ordered_partitions, rotations, is_palindromic, iterable, + NotIterable, multiset_derangements, signed_permutations, + sequence_partitions, sequence_partitions_empty) +from sympy.utilities.enumerative import ( + factoring_visitor, multiset_partitions_taocp ) + +from sympy.core.singleton import S +from sympy.testing.pytest import raises, warns_deprecated_sympy + +w, x, y, z = symbols('w,x,y,z') + + +def test_deprecated_iterables(): + from sympy.utilities.iterables import default_sort_key, ordered + with warns_deprecated_sympy(): + assert list(ordered([y, x])) == [x, y] + with warns_deprecated_sympy(): + assert sorted([y, x], key=default_sort_key) == [x, y] + + +def test_is_palindromic(): + assert is_palindromic('') + assert is_palindromic('x') + assert is_palindromic('xx') + assert is_palindromic('xyx') + assert not is_palindromic('xy') + assert not is_palindromic('xyzx') + assert is_palindromic('xxyzzyx', 1) + assert not is_palindromic('xxyzzyx', 2) + assert is_palindromic('xxyzzyx', 2, -1) + assert is_palindromic('xxyzzyx', 2, 6) + assert is_palindromic('xxyzyx', 1) + assert not is_palindromic('xxyzyx', 2) + assert is_palindromic('xxyzyx', 2, 2 + 3) + + +def test_flatten(): + assert flatten((1, (1,))) == [1, 1] + assert flatten((x, (x,))) == [x, x] + + ls = [[(-2, -1), (1, 2)], [(0, 0)]] + + assert flatten(ls, levels=0) == ls + assert flatten(ls, levels=1) == [(-2, -1), (1, 2), (0, 0)] + assert flatten(ls, levels=2) == [-2, -1, 1, 2, 0, 0] + assert flatten(ls, levels=3) == [-2, -1, 1, 2, 0, 0] + + raises(ValueError, lambda: flatten(ls, levels=-1)) + + class MyOp(Basic): + pass + + assert flatten([MyOp(x, y), z]) == [MyOp(x, y), z] + assert flatten([MyOp(x, y), z], cls=MyOp) == [x, y, z] + + assert flatten({1, 11, 2}) == list({1, 11, 2}) + + +def test_iproduct(): + assert list(iproduct()) == [()] + assert list(iproduct([])) == [] + assert list(iproduct([1,2,3])) == [(1,),(2,),(3,)] + assert sorted(iproduct([1, 2], [3, 4, 5])) == [ + (1,3),(1,4),(1,5),(2,3),(2,4),(2,5)] + assert sorted(iproduct([0,1],[0,1],[0,1])) == [ + (0,0,0),(0,0,1),(0,1,0),(0,1,1),(1,0,0),(1,0,1),(1,1,0),(1,1,1)] + assert iterable(iproduct(S.Integers)) is True + assert iterable(iproduct(S.Integers, S.Integers)) is True + assert (3,) in iproduct(S.Integers) + assert (4, 5) in iproduct(S.Integers, S.Integers) + assert (1, 2, 3) in iproduct(S.Integers, S.Integers, S.Integers) + triples = set(islice(iproduct(S.Integers, S.Integers, S.Integers), 1000)) + for n1, n2, n3 in triples: + assert isinstance(n1, Integer) + assert isinstance(n2, Integer) + assert isinstance(n3, Integer) + for t in set(product(*([range(-2, 3)]*3))): + assert t in iproduct(S.Integers, S.Integers, S.Integers) + + +def test_group(): + assert group([]) == [] + assert group([], multiple=False) == [] + + assert group([1]) == [[1]] + assert group([1], multiple=False) == [(1, 1)] + + assert group([1, 1]) == [[1, 1]] + assert group([1, 1], multiple=False) == [(1, 2)] + + assert group([1, 1, 1]) == [[1, 1, 1]] + assert group([1, 1, 1], multiple=False) == [(1, 3)] + + assert group([1, 2, 1]) == [[1], [2], [1]] + assert group([1, 2, 1], multiple=False) == [(1, 1), (2, 1), (1, 1)] + + assert group([1, 1, 2, 2, 2, 1, 3, 3]) == [[1, 1], [2, 2, 2], [1], [3, 3]] + assert group([1, 1, 2, 2, 2, 1, 3, 3], multiple=False) == [(1, 2), + (2, 3), (1, 1), (3, 2)] + + +def test_subsets(): + # combinations + assert list(subsets([1, 2, 3], 0)) == [()] + assert list(subsets([1, 2, 3], 1)) == [(1,), (2,), (3,)] + assert list(subsets([1, 2, 3], 2)) == [(1, 2), (1, 3), (2, 3)] + assert list(subsets([1, 2, 3], 3)) == [(1, 2, 3)] + l = list(range(4)) + assert list(subsets(l, 0, repetition=True)) == [()] + assert list(subsets(l, 1, repetition=True)) == [(0,), (1,), (2,), (3,)] + assert list(subsets(l, 2, repetition=True)) == [(0, 0), (0, 1), (0, 2), + (0, 3), (1, 1), (1, 2), + (1, 3), (2, 2), (2, 3), + (3, 3)] + assert list(subsets(l, 3, repetition=True)) == [(0, 0, 0), (0, 0, 1), + (0, 0, 2), (0, 0, 3), + (0, 1, 1), (0, 1, 2), + (0, 1, 3), (0, 2, 2), + (0, 2, 3), (0, 3, 3), + (1, 1, 1), (1, 1, 2), + (1, 1, 3), (1, 2, 2), + (1, 2, 3), (1, 3, 3), + (2, 2, 2), (2, 2, 3), + (2, 3, 3), (3, 3, 3)] + assert len(list(subsets(l, 4, repetition=True))) == 35 + + assert list(subsets(l[:2], 3, repetition=False)) == [] + assert list(subsets(l[:2], 3, repetition=True)) == [(0, 0, 0), + (0, 0, 1), + (0, 1, 1), + (1, 1, 1)] + assert list(subsets([1, 2], repetition=True)) == \ + [(), (1,), (2,), (1, 1), (1, 2), (2, 2)] + assert list(subsets([1, 2], repetition=False)) == \ + [(), (1,), (2,), (1, 2)] + assert list(subsets([1, 2, 3], 2)) == \ + [(1, 2), (1, 3), (2, 3)] + assert list(subsets([1, 2, 3], 2, repetition=True)) == \ + [(1, 1), (1, 2), (1, 3), (2, 2), (2, 3), (3, 3)] + + +def test_variations(): + # permutations + l = list(range(4)) + assert list(variations(l, 0, repetition=False)) == [()] + assert list(variations(l, 1, repetition=False)) == [(0,), (1,), (2,), (3,)] + assert list(variations(l, 2, repetition=False)) == [(0, 1), (0, 2), (0, 3), (1, 0), (1, 2), (1, 3), (2, 0), (2, 1), (2, 3), (3, 0), (3, 1), (3, 2)] + assert list(variations(l, 3, repetition=False)) == [(0, 1, 2), (0, 1, 3), (0, 2, 1), (0, 2, 3), (0, 3, 1), (0, 3, 2), (1, 0, 2), (1, 0, 3), (1, 2, 0), (1, 2, 3), (1, 3, 0), (1, 3, 2), (2, 0, 1), (2, 0, 3), (2, 1, 0), (2, 1, 3), (2, 3, 0), (2, 3, 1), (3, 0, 1), (3, 0, 2), (3, 1, 0), (3, 1, 2), (3, 2, 0), (3, 2, 1)] + assert list(variations(l, 0, repetition=True)) == [()] + assert list(variations(l, 1, repetition=True)) == [(0,), (1,), (2,), (3,)] + assert list(variations(l, 2, repetition=True)) == [(0, 0), (0, 1), (0, 2), + (0, 3), (1, 0), (1, 1), + (1, 2), (1, 3), (2, 0), + (2, 1), (2, 2), (2, 3), + (3, 0), (3, 1), (3, 2), + (3, 3)] + assert len(list(variations(l, 3, repetition=True))) == 64 + assert len(list(variations(l, 4, repetition=True))) == 256 + assert list(variations(l[:2], 3, repetition=False)) == [] + assert list(variations(l[:2], 3, repetition=True)) == [ + (0, 0, 0), (0, 0, 1), (0, 1, 0), (0, 1, 1), + (1, 0, 0), (1, 0, 1), (1, 1, 0), (1, 1, 1) + ] + + +def test_cartes(): + assert list(cartes([1, 2], [3, 4, 5])) == \ + [(1, 3), (1, 4), (1, 5), (2, 3), (2, 4), (2, 5)] + assert list(cartes()) == [()] + assert list(cartes('a')) == [('a',)] + assert list(cartes('a', repeat=2)) == [('a', 'a')] + assert list(cartes(list(range(2)))) == [(0,), (1,)] + + +def test_filter_symbols(): + s = numbered_symbols() + filtered = filter_symbols(s, symbols("x0 x2 x3")) + assert take(filtered, 3) == list(symbols("x1 x4 x5")) + + +def test_numbered_symbols(): + s = numbered_symbols(cls=Dummy) + assert isinstance(next(s), Dummy) + assert next(numbered_symbols('C', start=1, exclude=[symbols('C1')])) == \ + symbols('C2') + + +def test_sift(): + assert sift(list(range(5)), lambda _: _ % 2) == {1: [1, 3], 0: [0, 2, 4]} + assert sift([x, y], lambda _: _.has(x)) == {False: [y], True: [x]} + assert sift([S.One], lambda _: _.has(x)) == {False: [1]} + assert sift([0, 1, 2, 3], lambda x: x % 2, binary=True) == ( + [1, 3], [0, 2]) + assert sift([0, 1, 2, 3], lambda x: x % 3 == 1, binary=True) == ( + [1], [0, 2, 3]) + raises(ValueError, lambda: + sift([0, 1, 2, 3], lambda x: x % 3, binary=True)) + + +def test_take(): + X = numbered_symbols() + + assert take(X, 5) == list(symbols('x0:5')) + assert take(X, 5) == list(symbols('x5:10')) + + assert take([1, 2, 3, 4, 5], 5) == [1, 2, 3, 4, 5] + + +def test_dict_merge(): + assert dict_merge({}, {1: x, y: z}) == {1: x, y: z} + assert dict_merge({1: x, y: z}, {}) == {1: x, y: z} + + assert dict_merge({2: z}, {1: x, y: z}) == {1: x, 2: z, y: z} + assert dict_merge({1: x, y: z}, {2: z}) == {1: x, 2: z, y: z} + + assert dict_merge({1: y, 2: z}, {1: x, y: z}) == {1: x, 2: z, y: z} + assert dict_merge({1: x, y: z}, {1: y, 2: z}) == {1: y, 2: z, y: z} + + +def test_prefixes(): + assert list(prefixes([])) == [] + assert list(prefixes([1])) == [[1]] + assert list(prefixes([1, 2])) == [[1], [1, 2]] + + assert list(prefixes([1, 2, 3, 4, 5])) == \ + [[1], [1, 2], [1, 2, 3], [1, 2, 3, 4], [1, 2, 3, 4, 5]] + + +def test_postfixes(): + assert list(postfixes([])) == [] + assert list(postfixes([1])) == [[1]] + assert list(postfixes([1, 2])) == [[2], [1, 2]] + + assert list(postfixes([1, 2, 3, 4, 5])) == \ + [[5], [4, 5], [3, 4, 5], [2, 3, 4, 5], [1, 2, 3, 4, 5]] + + +def test_topological_sort(): + V = [2, 3, 5, 7, 8, 9, 10, 11] + E = [(7, 11), (7, 8), (5, 11), + (3, 8), (3, 10), (11, 2), + (11, 9), (11, 10), (8, 9)] + + assert topological_sort((V, E)) == [3, 5, 7, 8, 11, 2, 9, 10] + assert topological_sort((V, E), key=lambda v: -v) == \ + [7, 5, 11, 3, 10, 8, 9, 2] + + raises(ValueError, lambda: topological_sort((V, E + [(10, 7)]))) + + +def test_strongly_connected_components(): + assert strongly_connected_components(([], [])) == [] + assert strongly_connected_components(([1, 2, 3], [])) == [[1], [2], [3]] + + V = [1, 2, 3] + E = [(1, 2), (1, 3), (2, 1), (2, 3), (3, 1)] + assert strongly_connected_components((V, E)) == [[1, 2, 3]] + + V = [1, 2, 3, 4] + E = [(1, 2), (2, 3), (3, 2), (3, 4)] + assert strongly_connected_components((V, E)) == [[4], [2, 3], [1]] + + V = [1, 2, 3, 4] + E = [(1, 2), (2, 1), (3, 4), (4, 3)] + assert strongly_connected_components((V, E)) == [[1, 2], [3, 4]] + + +def test_connected_components(): + assert connected_components(([], [])) == [] + assert connected_components(([1, 2, 3], [])) == [[1], [2], [3]] + + V = [1, 2, 3] + E = [(1, 2), (1, 3), (2, 1), (2, 3), (3, 1)] + assert connected_components((V, E)) == [[1, 2, 3]] + + V = [1, 2, 3, 4] + E = [(1, 2), (2, 3), (3, 2), (3, 4)] + assert connected_components((V, E)) == [[1, 2, 3, 4]] + + V = [1, 2, 3, 4] + E = [(1, 2), (3, 4)] + assert connected_components((V, E)) == [[1, 2], [3, 4]] + + +def test_rotate(): + A = [0, 1, 2, 3, 4] + + assert rotate_left(A, 2) == [2, 3, 4, 0, 1] + assert rotate_right(A, 1) == [4, 0, 1, 2, 3] + A = [] + B = rotate_right(A, 1) + assert B == [] + B.append(1) + assert A == [] + B = rotate_left(A, 1) + assert B == [] + B.append(1) + assert A == [] + + +def test_multiset_partitions(): + A = [0, 1, 2, 3, 4] + + assert list(multiset_partitions(A, 5)) == [[[0], [1], [2], [3], [4]]] + assert len(list(multiset_partitions(A, 4))) == 10 + assert len(list(multiset_partitions(A, 3))) == 25 + + assert list(multiset_partitions([1, 1, 1, 2, 2], 2)) == [ + [[1, 1, 1, 2], [2]], [[1, 1, 1], [2, 2]], [[1, 1, 2, 2], [1]], + [[1, 1, 2], [1, 2]], [[1, 1], [1, 2, 2]]] + + assert list(multiset_partitions([1, 1, 2, 2], 2)) == [ + [[1, 1, 2], [2]], [[1, 1], [2, 2]], [[1, 2, 2], [1]], + [[1, 2], [1, 2]]] + + assert list(multiset_partitions([1, 2, 3, 4], 2)) == [ + [[1, 2, 3], [4]], [[1, 2, 4], [3]], [[1, 2], [3, 4]], + [[1, 3, 4], [2]], [[1, 3], [2, 4]], [[1, 4], [2, 3]], + [[1], [2, 3, 4]]] + + assert list(multiset_partitions([1, 2, 2], 2)) == [ + [[1, 2], [2]], [[1], [2, 2]]] + + assert list(multiset_partitions(3)) == [ + [[0, 1, 2]], [[0, 1], [2]], [[0, 2], [1]], [[0], [1, 2]], + [[0], [1], [2]]] + assert list(multiset_partitions(3, 2)) == [ + [[0, 1], [2]], [[0, 2], [1]], [[0], [1, 2]]] + assert list(multiset_partitions([1] * 3, 2)) == [[[1], [1, 1]]] + assert list(multiset_partitions([1] * 3)) == [ + [[1, 1, 1]], [[1], [1, 1]], [[1], [1], [1]]] + a = [3, 2, 1] + assert list(multiset_partitions(a)) == \ + list(multiset_partitions(sorted(a))) + assert list(multiset_partitions(a, 5)) == [] + assert list(multiset_partitions(a, 1)) == [[[1, 2, 3]]] + assert list(multiset_partitions(a + [4], 5)) == [] + assert list(multiset_partitions(a + [4], 1)) == [[[1, 2, 3, 4]]] + assert list(multiset_partitions(2, 5)) == [] + assert list(multiset_partitions(2, 1)) == [[[0, 1]]] + assert list(multiset_partitions('a')) == [[['a']]] + assert list(multiset_partitions('a', 2)) == [] + assert list(multiset_partitions('ab')) == [[['a', 'b']], [['a'], ['b']]] + assert list(multiset_partitions('ab', 1)) == [[['a', 'b']]] + assert list(multiset_partitions('aaa', 1)) == [['aaa']] + assert list(multiset_partitions([1, 1], 1)) == [[[1, 1]]] + ans = [('mpsyy',), ('mpsy', 'y'), ('mps', 'yy'), ('mps', 'y', 'y'), + ('mpyy', 's'), ('mpy', 'sy'), ('mpy', 's', 'y'), ('mp', 'syy'), + ('mp', 'sy', 'y'), ('mp', 's', 'yy'), ('mp', 's', 'y', 'y'), + ('msyy', 'p'), ('msy', 'py'), ('msy', 'p', 'y'), ('ms', 'pyy'), + ('ms', 'py', 'y'), ('ms', 'p', 'yy'), ('ms', 'p', 'y', 'y'), + ('myy', 'ps'), ('myy', 'p', 's'), ('my', 'psy'), ('my', 'ps', 'y'), + ('my', 'py', 's'), ('my', 'p', 'sy'), ('my', 'p', 's', 'y'), + ('m', 'psyy'), ('m', 'psy', 'y'), ('m', 'ps', 'yy'), + ('m', 'ps', 'y', 'y'), ('m', 'pyy', 's'), ('m', 'py', 'sy'), + ('m', 'py', 's', 'y'), ('m', 'p', 'syy'), + ('m', 'p', 'sy', 'y'), ('m', 'p', 's', 'yy'), + ('m', 'p', 's', 'y', 'y')] + assert [tuple("".join(part) for part in p) + for p in multiset_partitions('sympy')] == ans + factorings = [[24], [8, 3], [12, 2], [4, 6], [4, 2, 3], + [6, 2, 2], [2, 2, 2, 3]] + assert [factoring_visitor(p, [2,3]) for + p in multiset_partitions_taocp([3, 1])] == factorings + + +def test_multiset_combinations(): + ans = ['iii', 'iim', 'iip', 'iis', 'imp', 'ims', 'ipp', 'ips', + 'iss', 'mpp', 'mps', 'mss', 'pps', 'pss', 'sss'] + assert [''.join(i) for i in + list(multiset_combinations('mississippi', 3))] == ans + M = multiset('mississippi') + assert [''.join(i) for i in + list(multiset_combinations(M, 3))] == ans + assert [''.join(i) for i in multiset_combinations(M, 30)] == [] + assert list(multiset_combinations([[1], [2, 3]], 2)) == [[[1], [2, 3]]] + assert len(list(multiset_combinations('a', 3))) == 0 + assert len(list(multiset_combinations('a', 0))) == 1 + assert list(multiset_combinations('abc', 1)) == [['a'], ['b'], ['c']] + raises(ValueError, lambda: list(multiset_combinations({0: 3, 1: -1}, 2))) + + +def test_multiset_permutations(): + ans = ['abby', 'abyb', 'aybb', 'baby', 'bayb', 'bbay', 'bbya', 'byab', + 'byba', 'yabb', 'ybab', 'ybba'] + assert [''.join(i) for i in multiset_permutations('baby')] == ans + assert [''.join(i) for i in multiset_permutations(multiset('baby'))] == ans + assert list(multiset_permutations([0, 0, 0], 2)) == [[0, 0]] + assert list(multiset_permutations([0, 2, 1], 2)) == [ + [0, 1], [0, 2], [1, 0], [1, 2], [2, 0], [2, 1]] + assert len(list(multiset_permutations('a', 0))) == 1 + assert len(list(multiset_permutations('a', 3))) == 0 + for nul in ([], {}, ''): + assert list(multiset_permutations(nul)) == [[]] + assert list(multiset_permutations(nul, 0)) == [[]] + # impossible requests give no result + assert list(multiset_permutations(nul, 1)) == [] + assert list(multiset_permutations(nul, -1)) == [] + + def test(): + for i in range(1, 7): + print(i) + for p in multiset_permutations([0, 0, 1, 0, 1], i): + print(p) + assert capture(lambda: test()) == dedent('''\ + 1 + [0] + [1] + 2 + [0, 0] + [0, 1] + [1, 0] + [1, 1] + 3 + [0, 0, 0] + [0, 0, 1] + [0, 1, 0] + [0, 1, 1] + [1, 0, 0] + [1, 0, 1] + [1, 1, 0] + 4 + [0, 0, 0, 1] + [0, 0, 1, 0] + [0, 0, 1, 1] + [0, 1, 0, 0] + [0, 1, 0, 1] + [0, 1, 1, 0] + [1, 0, 0, 0] + [1, 0, 0, 1] + [1, 0, 1, 0] + [1, 1, 0, 0] + 5 + [0, 0, 0, 1, 1] + [0, 0, 1, 0, 1] + [0, 0, 1, 1, 0] + [0, 1, 0, 0, 1] + [0, 1, 0, 1, 0] + [0, 1, 1, 0, 0] + [1, 0, 0, 0, 1] + [1, 0, 0, 1, 0] + [1, 0, 1, 0, 0] + [1, 1, 0, 0, 0] + 6\n''') + raises(ValueError, lambda: list(multiset_permutations({0: 3, 1: -1}))) + + +def test_partitions(): + ans = [[{}], [(0, {})]] + for i in range(2): + assert list(partitions(0, size=i)) == ans[i] + assert list(partitions(1, 0, size=i)) == ans[i] + assert list(partitions(6, 2, 2, size=i)) == ans[i] + assert list(partitions(6, 2, None, size=i)) != ans[i] + assert list(partitions(6, None, 2, size=i)) != ans[i] + assert list(partitions(6, 2, 0, size=i)) == ans[i] + + assert list(partitions(6, k=2)) == [ + {2: 3}, {1: 2, 2: 2}, {1: 4, 2: 1}, {1: 6}] + + assert list(partitions(6, k=3)) == [ + {3: 2}, {1: 1, 2: 1, 3: 1}, {1: 3, 3: 1}, {2: 3}, {1: 2, 2: 2}, + {1: 4, 2: 1}, {1: 6}] + + assert list(partitions(8, k=4, m=3)) == [ + {4: 2}, {1: 1, 3: 1, 4: 1}, {2: 2, 4: 1}, {2: 1, 3: 2}] == [ + i for i in partitions(8, k=4, m=3) if all(k <= 4 for k in i) + and sum(i.values()) <=3] + + assert list(partitions(S(3), m=2)) == [ + {3: 1}, {1: 1, 2: 1}] + + assert list(partitions(4, k=3)) == [ + {1: 1, 3: 1}, {2: 2}, {1: 2, 2: 1}, {1: 4}] == [ + i for i in partitions(4) if all(k <= 3 for k in i)] + + + # Consistency check on output of _partitions and RGS_unrank. + # This provides a sanity test on both routines. Also verifies that + # the total number of partitions is the same in each case. + # (from pkrathmann2) + + for n in range(2, 6): + i = 0 + for m, q in _set_partitions(n): + assert q == RGS_unrank(i, n) + i += 1 + assert i == RGS_enum(n) + + +def test_binary_partitions(): + assert [i[:] for i in binary_partitions(10)] == [[8, 2], [8, 1, 1], + [4, 4, 2], [4, 4, 1, 1], [4, 2, 2, 2], [4, 2, 2, 1, 1], + [4, 2, 1, 1, 1, 1], [4, 1, 1, 1, 1, 1, 1], [2, 2, 2, 2, 2], + [2, 2, 2, 2, 1, 1], [2, 2, 2, 1, 1, 1, 1], [2, 2, 1, 1, 1, 1, 1, 1], + [2, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]] + + assert len([j[:] for j in binary_partitions(16)]) == 36 + + +def test_bell_perm(): + assert [len(set(generate_bell(i))) for i in range(1, 7)] == [ + factorial(i) for i in range(1, 7)] + assert list(generate_bell(3)) == [ + (0, 1, 2), (0, 2, 1), (2, 0, 1), (2, 1, 0), (1, 2, 0), (1, 0, 2)] + # generate_bell and trotterjohnson are advertised to return the same + # permutations; this is not technically necessary so this test could + # be removed + for n in range(1, 5): + p = Permutation(range(n)) + b = generate_bell(n) + for bi in b: + assert bi == tuple(p.array_form) + p = p.next_trotterjohnson() + raises(ValueError, lambda: list(generate_bell(0))) # XXX is this consistent with other permutation algorithms? + + +def test_involutions(): + lengths = [1, 2, 4, 10, 26, 76] + for n, N in enumerate(lengths): + i = list(generate_involutions(n + 1)) + assert len(i) == N + assert len({Permutation(j)**2 for j in i}) == 1 + + +def test_derangements(): + assert len(list(generate_derangements(list(range(6))))) == 265 + assert ''.join(''.join(i) for i in generate_derangements('abcde')) == ( + 'badecbaecdbcaedbcdeabceadbdaecbdeacbdecabeacdbedacbedcacabedcadebcaebd' + 'cdaebcdbeacdeabcdebaceabdcebadcedabcedbadabecdaebcdaecbdcaebdcbeadceab' + 'dcebadeabcdeacbdebacdebcaeabcdeadbceadcbecabdecbadecdabecdbaedabcedacb' + 'edbacedbca') + assert list(generate_derangements([0, 1, 2, 3])) == [ + [1, 0, 3, 2], [1, 2, 3, 0], [1, 3, 0, 2], [2, 0, 3, 1], + [2, 3, 0, 1], [2, 3, 1, 0], [3, 0, 1, 2], [3, 2, 0, 1], [3, 2, 1, 0]] + assert list(generate_derangements([0, 1, 2, 2])) == [ + [2, 2, 0, 1], [2, 2, 1, 0]] + assert list(generate_derangements('ba')) == [list('ab')] + # multiset_derangements + D = multiset_derangements + assert list(D('abb')) == [] + assert [''.join(i) for i in D('ab')] == ['ba'] + assert [''.join(i) for i in D('abc')] == ['bca', 'cab'] + assert [''.join(i) for i in D('aabb')] == ['bbaa'] + assert [''.join(i) for i in D('aabbcccc')] == [ + 'ccccaabb', 'ccccabab', 'ccccabba', 'ccccbaab', 'ccccbaba', + 'ccccbbaa'] + assert [''.join(i) for i in D('aabbccc')] == [ + 'cccabba', 'cccabab', 'cccaabb', 'ccacbba', 'ccacbab', + 'ccacabb', 'cbccbaa', 'cbccaba', 'cbccaab', 'bcccbaa', + 'bcccaba', 'bcccaab'] + assert [''.join(i) for i in D('books')] == ['kbsoo', 'ksboo', + 'sbkoo', 'skboo', 'oksbo', 'oskbo', 'okbso', 'obkso', 'oskob', + 'oksob', 'osbok', 'obsok'] + assert list(generate_derangements([[3], [2], [2], [1]])) == [ + [[2], [1], [3], [2]], [[2], [3], [1], [2]]] + + +def test_necklaces(): + def count(n, k, f): + return len(list(necklaces(n, k, f))) + m = [] + for i in range(1, 8): + m.append(( + i, count(i, 2, 0), count(i, 2, 1), count(i, 3, 1))) + assert Matrix(m) == Matrix([ + [1, 2, 2, 3], + [2, 3, 3, 6], + [3, 4, 4, 10], + [4, 6, 6, 21], + [5, 8, 8, 39], + [6, 14, 13, 92], + [7, 20, 18, 198]]) + + +def test_bracelets(): + bc = list(bracelets(2, 4)) + assert Matrix(bc) == Matrix([ + [0, 0], + [0, 1], + [0, 2], + [0, 3], + [1, 1], + [1, 2], + [1, 3], + [2, 2], + [2, 3], + [3, 3] + ]) + bc = list(bracelets(4, 2)) + assert Matrix(bc) == Matrix([ + [0, 0, 0, 0], + [0, 0, 0, 1], + [0, 0, 1, 1], + [0, 1, 0, 1], + [0, 1, 1, 1], + [1, 1, 1, 1] + ]) + + +def test_generate_oriented_forest(): + assert list(generate_oriented_forest(5)) == [[0, 1, 2, 3, 4], + [0, 1, 2, 3, 3], [0, 1, 2, 3, 2], [0, 1, 2, 3, 1], [0, 1, 2, 3, 0], + [0, 1, 2, 2, 2], [0, 1, 2, 2, 1], [0, 1, 2, 2, 0], [0, 1, 2, 1, 2], + [0, 1, 2, 1, 1], [0, 1, 2, 1, 0], [0, 1, 2, 0, 1], [0, 1, 2, 0, 0], + [0, 1, 1, 1, 1], [0, 1, 1, 1, 0], [0, 1, 1, 0, 1], [0, 1, 1, 0, 0], + [0, 1, 0, 1, 0], [0, 1, 0, 0, 0], [0, 0, 0, 0, 0]] + assert len(list(generate_oriented_forest(10))) == 1842 + + +def test_unflatten(): + r = list(range(10)) + assert unflatten(r) == list(zip(r[::2], r[1::2])) + assert unflatten(r, 5) == [tuple(r[:5]), tuple(r[5:])] + raises(ValueError, lambda: unflatten(list(range(10)), 3)) + raises(ValueError, lambda: unflatten(list(range(10)), -2)) + + +def test_common_prefix_suffix(): + assert common_prefix([], [1]) == [] + assert common_prefix(list(range(3))) == [0, 1, 2] + assert common_prefix(list(range(3)), list(range(4))) == [0, 1, 2] + assert common_prefix([1, 2, 3], [1, 2, 5]) == [1, 2] + assert common_prefix([1, 2, 3], [1, 3, 5]) == [1] + + assert common_suffix([], [1]) == [] + assert common_suffix(list(range(3))) == [0, 1, 2] + assert common_suffix(list(range(3)), list(range(3))) == [0, 1, 2] + assert common_suffix(list(range(3)), list(range(4))) == [] + assert common_suffix([1, 2, 3], [9, 2, 3]) == [2, 3] + assert common_suffix([1, 2, 3], [9, 7, 3]) == [3] + + +def test_minlex(): + assert minlex([1, 2, 0]) == (0, 1, 2) + assert minlex((1, 2, 0)) == (0, 1, 2) + assert minlex((1, 0, 2)) == (0, 2, 1) + assert minlex((1, 0, 2), directed=False) == (0, 1, 2) + assert minlex('aba') == 'aab' + assert minlex(('bb', 'aaa', 'c', 'a'), key=len) == ('c', 'a', 'bb', 'aaa') + + +def test_ordered(): + assert list(ordered((x, y), hash, default=False)) in [[x, y], [y, x]] + assert list(ordered((x, y), hash, default=False)) == \ + list(ordered((y, x), hash, default=False)) + assert list(ordered((x, y))) == [x, y] + + seq, keys = [[[1, 2, 1], [0, 3, 1], [1, 1, 3], [2], [1]], + (lambda x: len(x), lambda x: sum(x))] + assert list(ordered(seq, keys, default=False, warn=False)) == \ + [[1], [2], [1, 2, 1], [0, 3, 1], [1, 1, 3]] + raises(ValueError, lambda: + list(ordered(seq, keys, default=False, warn=True))) + + +def test_runs(): + assert runs([]) == [] + assert runs([1]) == [[1]] + assert runs([1, 1]) == [[1], [1]] + assert runs([1, 1, 2]) == [[1], [1, 2]] + assert runs([1, 2, 1]) == [[1, 2], [1]] + assert runs([2, 1, 1]) == [[2], [1], [1]] + from operator import lt + assert runs([2, 1, 1], lt) == [[2, 1], [1]] + + +def test_reshape(): + seq = list(range(1, 9)) + assert reshape(seq, [4]) == \ + [[1, 2, 3, 4], [5, 6, 7, 8]] + assert reshape(seq, (4,)) == \ + [(1, 2, 3, 4), (5, 6, 7, 8)] + assert reshape(seq, (2, 2)) == \ + [(1, 2, 3, 4), (5, 6, 7, 8)] + assert reshape(seq, (2, [2])) == \ + [(1, 2, [3, 4]), (5, 6, [7, 8])] + assert reshape(seq, ((2,), [2])) == \ + [((1, 2), [3, 4]), ((5, 6), [7, 8])] + assert reshape(seq, (1, [2], 1)) == \ + [(1, [2, 3], 4), (5, [6, 7], 8)] + assert reshape(tuple(seq), ([[1], 1, (2,)],)) == \ + (([[1], 2, (3, 4)],), ([[5], 6, (7, 8)],)) + assert reshape(tuple(seq), ([1], 1, (2,))) == \ + (([1], 2, (3, 4)), ([5], 6, (7, 8))) + assert reshape(list(range(12)), [2, [3], {2}, (1, (3,), 1)]) == \ + [[0, 1, [2, 3, 4], {5, 6}, (7, (8, 9, 10), 11)]] + raises(ValueError, lambda: reshape([0, 1], [-1])) + raises(ValueError, lambda: reshape([0, 1], [3])) + + +def test_uniq(): + assert list(uniq(p for p in partitions(4))) == \ + [{4: 1}, {1: 1, 3: 1}, {2: 2}, {1: 2, 2: 1}, {1: 4}] + assert list(uniq(x % 2 for x in range(5))) == [0, 1] + assert list(uniq('a')) == ['a'] + assert list(uniq('ababc')) == list('abc') + assert list(uniq([[1], [2, 1], [1]])) == [[1], [2, 1]] + assert list(uniq(permutations(i for i in [[1], 2, 2]))) == \ + [([1], 2, 2), (2, [1], 2), (2, 2, [1])] + assert list(uniq([2, 3, 2, 4, [2], [1], [2], [3], [1]])) == \ + [2, 3, 4, [2], [1], [3]] + f = [1] + raises(RuntimeError, lambda: [f.remove(i) for i in uniq(f)]) + f = [[1]] + raises(RuntimeError, lambda: [f.remove(i) for i in uniq(f)]) + + +def test_kbins(): + assert len(list(kbins('1123', 2, ordered=1))) == 24 + assert len(list(kbins('1123', 2, ordered=11))) == 36 + assert len(list(kbins('1123', 2, ordered=10))) == 10 + assert len(list(kbins('1123', 2, ordered=0))) == 5 + assert len(list(kbins('1123', 2, ordered=None))) == 3 + + def test1(): + for orderedval in [None, 0, 1, 10, 11]: + print('ordered =', orderedval) + for p in kbins([0, 0, 1], 2, ordered=orderedval): + print(' ', p) + assert capture(lambda : test1()) == dedent('''\ + ordered = None + [[0], [0, 1]] + [[0, 0], [1]] + ordered = 0 + [[0, 0], [1]] + [[0, 1], [0]] + ordered = 1 + [[0], [0, 1]] + [[0], [1, 0]] + [[1], [0, 0]] + ordered = 10 + [[0, 0], [1]] + [[1], [0, 0]] + [[0, 1], [0]] + [[0], [0, 1]] + ordered = 11 + [[0], [0, 1]] + [[0, 0], [1]] + [[0], [1, 0]] + [[0, 1], [0]] + [[1], [0, 0]] + [[1, 0], [0]]\n''') + + def test2(): + for orderedval in [None, 0, 1, 10, 11]: + print('ordered =', orderedval) + for p in kbins(list(range(3)), 2, ordered=orderedval): + print(' ', p) + assert capture(lambda : test2()) == dedent('''\ + ordered = None + [[0], [1, 2]] + [[0, 1], [2]] + ordered = 0 + [[0, 1], [2]] + [[0, 2], [1]] + [[0], [1, 2]] + ordered = 1 + [[0], [1, 2]] + [[0], [2, 1]] + [[1], [0, 2]] + [[1], [2, 0]] + [[2], [0, 1]] + [[2], [1, 0]] + ordered = 10 + [[0, 1], [2]] + [[2], [0, 1]] + [[0, 2], [1]] + [[1], [0, 2]] + [[0], [1, 2]] + [[1, 2], [0]] + ordered = 11 + [[0], [1, 2]] + [[0, 1], [2]] + [[0], [2, 1]] + [[0, 2], [1]] + [[1], [0, 2]] + [[1, 0], [2]] + [[1], [2, 0]] + [[1, 2], [0]] + [[2], [0, 1]] + [[2, 0], [1]] + [[2], [1, 0]] + [[2, 1], [0]]\n''') + + +def test_has_dups(): + assert has_dups(set()) is False + assert has_dups(list(range(3))) is False + assert has_dups([1, 2, 1]) is True + assert has_dups([[1], [1]]) is True + assert has_dups([[1], [2]]) is False + + +def test__partition(): + assert _partition('abcde', [1, 0, 1, 2, 0]) == [ + ['b', 'e'], ['a', 'c'], ['d']] + assert _partition('abcde', [1, 0, 1, 2, 0], 3) == [ + ['b', 'e'], ['a', 'c'], ['d']] + output = (3, [1, 0, 1, 2, 0]) + assert _partition('abcde', *output) == [['b', 'e'], ['a', 'c'], ['d']] + + +def test_ordered_partitions(): + from sympy.functions.combinatorial.numbers import nT + f = ordered_partitions + assert list(f(0, 1)) == [[]] + assert list(f(1, 0)) == [[]] + for i in range(1, 7): + for j in [None] + list(range(1, i)): + assert ( + sum(1 for p in f(i, j, 1)) == + sum(1 for p in f(i, j, 0)) == + nT(i, j)) + + +def test_rotations(): + assert list(rotations('ab')) == [['a', 'b'], ['b', 'a']] + assert list(rotations(range(3))) == [[0, 1, 2], [1, 2, 0], [2, 0, 1]] + assert list(rotations(range(3), dir=-1)) == [[0, 1, 2], [2, 0, 1], [1, 2, 0]] + + +def test_ibin(): + assert ibin(3) == [1, 1] + assert ibin(3, 3) == [0, 1, 1] + assert ibin(3, str=True) == '11' + assert ibin(3, 3, str=True) == '011' + assert list(ibin(2, 'all')) == [(0, 0), (0, 1), (1, 0), (1, 1)] + assert list(ibin(2, '', str=True)) == ['00', '01', '10', '11'] + raises(ValueError, lambda: ibin(-.5)) + raises(ValueError, lambda: ibin(2, 1)) + + +def test_iterable(): + assert iterable(0) is False + assert iterable(1) is False + assert iterable(None) is False + + class Test1(NotIterable): + pass + + assert iterable(Test1()) is False + + class Test2(NotIterable): + _iterable = True + + assert iterable(Test2()) is True + + class Test3: + pass + + assert iterable(Test3()) is False + + class Test4: + _iterable = True + + assert iterable(Test4()) is True + + class Test5: + def __iter__(self): + yield 1 + + assert iterable(Test5()) is True + + class Test6(Test5): + _iterable = False + + assert iterable(Test6()) is False + + +def test_sequence_partitions(): + assert list(sequence_partitions([1], 1)) == [[[1]]] + assert list(sequence_partitions([1, 2], 1)) == [[[1, 2]]] + assert list(sequence_partitions([1, 2], 2)) == [[[1], [2]]] + assert list(sequence_partitions([1, 2, 3], 1)) == [[[1, 2, 3]]] + assert list(sequence_partitions([1, 2, 3], 2)) == \ + [[[1], [2, 3]], [[1, 2], [3]]] + assert list(sequence_partitions([1, 2, 3], 3)) == [[[1], [2], [3]]] + + # Exceptional cases + assert list(sequence_partitions([], 0)) == [] + assert list(sequence_partitions([], 1)) == [] + assert list(sequence_partitions([1, 2], 0)) == [] + assert list(sequence_partitions([1, 2], 3)) == [] + + +def test_sequence_partitions_empty(): + assert list(sequence_partitions_empty([], 1)) == [[[]]] + assert list(sequence_partitions_empty([], 2)) == [[[], []]] + assert list(sequence_partitions_empty([], 3)) == [[[], [], []]] + assert list(sequence_partitions_empty([1], 1)) == [[[1]]] + assert list(sequence_partitions_empty([1], 2)) == [[[], [1]], [[1], []]] + assert list(sequence_partitions_empty([1], 3)) == \ + [[[], [], [1]], [[], [1], []], [[1], [], []]] + assert list(sequence_partitions_empty([1, 2], 1)) == [[[1, 2]]] + assert list(sequence_partitions_empty([1, 2], 2)) == \ + [[[], [1, 2]], [[1], [2]], [[1, 2], []]] + assert list(sequence_partitions_empty([1, 2], 3)) == [ + [[], [], [1, 2]], [[], [1], [2]], [[], [1, 2], []], + [[1], [], [2]], [[1], [2], []], [[1, 2], [], []] + ] + assert list(sequence_partitions_empty([1, 2, 3], 1)) == [[[1, 2, 3]]] + assert list(sequence_partitions_empty([1, 2, 3], 2)) == \ + [[[], [1, 2, 3]], [[1], [2, 3]], [[1, 2], [3]], [[1, 2, 3], []]] + assert list(sequence_partitions_empty([1, 2, 3], 3)) == [ + [[], [], [1, 2, 3]], [[], [1], [2, 3]], + [[], [1, 2], [3]], [[], [1, 2, 3], []], + [[1], [], [2, 3]], [[1], [2], [3]], + [[1], [2, 3], []], [[1, 2], [], [3]], + [[1, 2], [3], []], [[1, 2, 3], [], []] + ] + + # Exceptional cases + assert list(sequence_partitions([], 0)) == [] + assert list(sequence_partitions([1], 0)) == [] + assert list(sequence_partitions([1, 2], 0)) == [] + + +def test_signed_permutations(): + ans = [(0, 1, 1), (0, -1, 1), (0, 1, -1), (0, -1, -1), + (1, 0, 1), (-1, 0, 1), (1, 0, -1), (-1, 0, -1), + (1, 1, 0), (-1, 1, 0), (1, -1, 0), (-1, -1, 0)] + assert list(signed_permutations((0, 1, 1))) == ans + assert list(signed_permutations((1, 0, 1))) == ans + assert list(signed_permutations((1, 1, 0))) == ans diff --git a/lib/python3.10/site-packages/sympy/utilities/tests/test_lambdify.py b/lib/python3.10/site-packages/sympy/utilities/tests/test_lambdify.py new file mode 100644 index 0000000000000000000000000000000000000000..631cd7bfb67cec64e87fb6681b4e3b79ef3f37e8 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/utilities/tests/test_lambdify.py @@ -0,0 +1,1917 @@ +from itertools import product +import math +import inspect + + + +import mpmath +from sympy.testing.pytest import raises, warns_deprecated_sympy +from sympy.concrete.summations import Sum +from sympy.core.function import (Function, Lambda, diff) +from sympy.core.numbers import (E, Float, I, Rational, all_close, oo, pi) +from sympy.core.relational import Eq +from sympy.core.singleton import S +from sympy.core.symbol import (Dummy, symbols) +from sympy.functions.combinatorial.factorials import (RisingFactorial, factorial) +from sympy.functions.combinatorial.numbers import bernoulli, harmonic +from sympy.functions.elementary.complexes import Abs +from sympy.functions.elementary.exponential import exp, log +from sympy.functions.elementary.hyperbolic import acosh +from sympy.functions.elementary.integers import floor +from sympy.functions.elementary.miscellaneous import (Max, Min, sqrt) +from sympy.functions.elementary.piecewise import Piecewise +from sympy.functions.elementary.trigonometric import (acos, cos, cot, sin, + sinc, tan) +from sympy.functions.special.bessel import (besseli, besselj, besselk, bessely, jn, yn) +from sympy.functions.special.beta_functions import (beta, betainc, betainc_regularized) +from sympy.functions.special.delta_functions import (Heaviside) +from sympy.functions.special.error_functions import (Ei, erf, erfc, fresnelc, fresnels, Si, Ci) +from sympy.functions.special.gamma_functions import (digamma, gamma, loggamma, polygamma) +from sympy.integrals.integrals import Integral +from sympy.logic.boolalg import (And, false, ITE, Not, Or, true) +from sympy.matrices.expressions.dotproduct import DotProduct +from sympy.simplify.cse_main import cse +from sympy.tensor.array import derive_by_array, Array +from sympy.tensor.indexed import IndexedBase +from sympy.utilities.lambdify import lambdify +from sympy.utilities.iterables import numbered_symbols +from sympy.vector import CoordSys3D +from sympy.core.expr import UnevaluatedExpr +from sympy.codegen.cfunctions import expm1, log1p, exp2, log2, log10, hypot +from sympy.codegen.numpy_nodes import logaddexp, logaddexp2 +from sympy.codegen.scipy_nodes import cosm1, powm1 +from sympy.functions.elementary.complexes import re, im, arg +from sympy.functions.special.polynomials import \ + chebyshevt, chebyshevu, legendre, hermite, laguerre, gegenbauer, \ + assoc_legendre, assoc_laguerre, jacobi +from sympy.matrices import Matrix, MatrixSymbol, SparseMatrix +from sympy.printing.lambdarepr import LambdaPrinter +from sympy.printing.numpy import NumPyPrinter +from sympy.utilities.lambdify import implemented_function, lambdastr +from sympy.testing.pytest import skip +from sympy.utilities.decorator import conserve_mpmath_dps +from sympy.utilities.exceptions import ignore_warnings +from sympy.external import import_module +from sympy.functions.special.gamma_functions import uppergamma, lowergamma + + +import sympy + + +MutableDenseMatrix = Matrix + +numpy = import_module('numpy') +scipy = import_module('scipy', import_kwargs={'fromlist': ['sparse']}) +numexpr = import_module('numexpr') +tensorflow = import_module('tensorflow') +cupy = import_module('cupy') +jax = import_module('jax') +numba = import_module('numba') + +if tensorflow: + # Hide Tensorflow warnings + import os + os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' + +w, x, y, z = symbols('w,x,y,z') + +#================== Test different arguments ======================= + + +def test_no_args(): + f = lambdify([], 1) + raises(TypeError, lambda: f(-1)) + assert f() == 1 + + +def test_single_arg(): + f = lambdify(x, 2*x) + assert f(1) == 2 + + +def test_list_args(): + f = lambdify([x, y], x + y) + assert f(1, 2) == 3 + + +def test_nested_args(): + f1 = lambdify([[w, x]], [w, x]) + assert f1([91, 2]) == [91, 2] + raises(TypeError, lambda: f1(1, 2)) + + f2 = lambdify([(w, x), (y, z)], [w, x, y, z]) + assert f2((18, 12), (73, 4)) == [18, 12, 73, 4] + raises(TypeError, lambda: f2(3, 4)) + + f3 = lambdify([w, [[[x]], y], z], [w, x, y, z]) + assert f3(10, [[[52]], 31], 44) == [10, 52, 31, 44] + + +def test_str_args(): + f = lambdify('x,y,z', 'z,y,x') + assert f(3, 2, 1) == (1, 2, 3) + assert f(1.0, 2.0, 3.0) == (3.0, 2.0, 1.0) + # make sure correct number of args required + raises(TypeError, lambda: f(0)) + + +def test_own_namespace_1(): + myfunc = lambda x: 1 + f = lambdify(x, sin(x), {"sin": myfunc}) + assert f(0.1) == 1 + assert f(100) == 1 + + +def test_own_namespace_2(): + def myfunc(x): + return 1 + f = lambdify(x, sin(x), {'sin': myfunc}) + assert f(0.1) == 1 + assert f(100) == 1 + + +def test_own_module(): + f = lambdify(x, sin(x), math) + assert f(0) == 0.0 + + p, q, r = symbols("p q r", real=True) + ae = abs(exp(p+UnevaluatedExpr(q+r))) + f = lambdify([p, q, r], [ae, ae], modules=math) + results = f(1.0, 1e18, -1e18) + refvals = [math.exp(1.0)]*2 + for res, ref in zip(results, refvals): + assert abs((res-ref)/ref) < 1e-15 + + +def test_bad_args(): + # no vargs given + raises(TypeError, lambda: lambdify(1)) + # same with vector exprs + raises(TypeError, lambda: lambdify([1, 2])) + + +def test_atoms(): + # Non-Symbol atoms should not be pulled out from the expression namespace + f = lambdify(x, pi + x, {"pi": 3.14}) + assert f(0) == 3.14 + f = lambdify(x, I + x, {"I": 1j}) + assert f(1) == 1 + 1j + +#================== Test different modules ========================= + +# high precision output of sin(0.2*pi) is used to detect if precision is lost unwanted + + +@conserve_mpmath_dps +def test_sympy_lambda(): + mpmath.mp.dps = 50 + sin02 = mpmath.mpf("0.19866933079506121545941262711838975037020672954020") + f = lambdify(x, sin(x), "sympy") + assert f(x) == sin(x) + prec = 1e-15 + assert -prec < f(Rational(1, 5)).evalf() - Float(str(sin02)) < prec + # arctan is in numpy module and should not be available + # The arctan below gives NameError. What is this supposed to test? + # raises(NameError, lambda: lambdify(x, arctan(x), "sympy")) + + +@conserve_mpmath_dps +def test_math_lambda(): + mpmath.mp.dps = 50 + sin02 = mpmath.mpf("0.19866933079506121545941262711838975037020672954020") + f = lambdify(x, sin(x), "math") + prec = 1e-15 + assert -prec < f(0.2) - sin02 < prec + raises(TypeError, lambda: f(x)) + # if this succeeds, it can't be a Python math function + + +@conserve_mpmath_dps +def test_mpmath_lambda(): + mpmath.mp.dps = 50 + sin02 = mpmath.mpf("0.19866933079506121545941262711838975037020672954020") + f = lambdify(x, sin(x), "mpmath") + prec = 1e-49 # mpmath precision is around 50 decimal places + assert -prec < f(mpmath.mpf("0.2")) - sin02 < prec + raises(TypeError, lambda: f(x)) + # if this succeeds, it can't be a mpmath function + + ref2 = (mpmath.mpf("1e-30") + - mpmath.mpf("1e-45")/2 + + 5*mpmath.mpf("1e-60")/6 + - 3*mpmath.mpf("1e-75")/4 + + 33*mpmath.mpf("1e-90")/40 + ) + f2a = lambdify((x, y), x**y - 1, "mpmath") + f2b = lambdify((x, y), powm1(x, y), "mpmath") + f2c = lambdify((x,), expm1(x*log1p(x)), "mpmath") + ans2a = f2a(mpmath.mpf("1")+mpmath.mpf("1e-15"), mpmath.mpf("1e-15")) + ans2b = f2b(mpmath.mpf("1")+mpmath.mpf("1e-15"), mpmath.mpf("1e-15")) + ans2c = f2c(mpmath.mpf("1e-15")) + assert abs(ans2a - ref2) < 1e-51 + assert abs(ans2b - ref2) < 1e-67 + assert abs(ans2c - ref2) < 1e-80 + + +@conserve_mpmath_dps +def test_number_precision(): + mpmath.mp.dps = 50 + sin02 = mpmath.mpf("0.19866933079506121545941262711838975037020672954020") + f = lambdify(x, sin02, "mpmath") + prec = 1e-49 # mpmath precision is around 50 decimal places + assert -prec < f(0) - sin02 < prec + +@conserve_mpmath_dps +def test_mpmath_precision(): + mpmath.mp.dps = 100 + assert str(lambdify((), pi.evalf(100), 'mpmath')()) == str(pi.evalf(100)) + +#================== Test Translations ============================== +# We can only check if all translated functions are valid. It has to be checked +# by hand if they are complete. + + +def test_math_transl(): + from sympy.utilities.lambdify import MATH_TRANSLATIONS + for sym, mat in MATH_TRANSLATIONS.items(): + assert sym in sympy.__dict__ + assert mat in math.__dict__ + + +def test_mpmath_transl(): + from sympy.utilities.lambdify import MPMATH_TRANSLATIONS + for sym, mat in MPMATH_TRANSLATIONS.items(): + assert sym in sympy.__dict__ or sym == 'Matrix' + assert mat in mpmath.__dict__ + + +def test_numpy_transl(): + if not numpy: + skip("numpy not installed.") + + from sympy.utilities.lambdify import NUMPY_TRANSLATIONS + for sym, nump in NUMPY_TRANSLATIONS.items(): + assert sym in sympy.__dict__ + assert nump in numpy.__dict__ + + +def test_scipy_transl(): + if not scipy: + skip("scipy not installed.") + + from sympy.utilities.lambdify import SCIPY_TRANSLATIONS + for sym, scip in SCIPY_TRANSLATIONS.items(): + assert sym in sympy.__dict__ + assert scip in scipy.__dict__ or scip in scipy.special.__dict__ + + +def test_numpy_translation_abs(): + if not numpy: + skip("numpy not installed.") + + f = lambdify(x, Abs(x), "numpy") + assert f(-1) == 1 + assert f(1) == 1 + + +def test_numexpr_printer(): + if not numexpr: + skip("numexpr not installed.") + + # if translation/printing is done incorrectly then evaluating + # a lambdified numexpr expression will throw an exception + from sympy.printing.lambdarepr import NumExprPrinter + + blacklist = ('where', 'complex', 'contains') + arg_tuple = (x, y, z) # some functions take more than one argument + for sym in NumExprPrinter._numexpr_functions.keys(): + if sym in blacklist: + continue + ssym = S(sym) + if hasattr(ssym, '_nargs'): + nargs = ssym._nargs[0] + else: + nargs = 1 + args = arg_tuple[:nargs] + f = lambdify(args, ssym(*args), modules='numexpr') + assert f(*(1, )*nargs) is not None + + +def test_issue_9334(): + if not numexpr: + skip("numexpr not installed.") + if not numpy: + skip("numpy not installed.") + expr = S('b*a - sqrt(a**2)') + a, b = sorted(expr.free_symbols, key=lambda s: s.name) + func_numexpr = lambdify((a,b), expr, modules=[numexpr], dummify=False) + foo, bar = numpy.random.random((2, 4)) + func_numexpr(foo, bar) + + +def test_issue_12984(): + if not numexpr: + skip("numexpr not installed.") + func_numexpr = lambdify((x,y,z), Piecewise((y, x >= 0), (z, x > -1)), numexpr) + with ignore_warnings(RuntimeWarning): + assert func_numexpr(1, 24, 42) == 24 + assert str(func_numexpr(-1, 24, 42)) == 'nan' + + +def test_empty_modules(): + x, y = symbols('x y') + expr = -(x % y) + + no_modules = lambdify([x, y], expr) + empty_modules = lambdify([x, y], expr, modules=[]) + assert no_modules(3, 7) == empty_modules(3, 7) + assert no_modules(3, 7) == -3 + + +def test_exponentiation(): + f = lambdify(x, x**2) + assert f(-1) == 1 + assert f(0) == 0 + assert f(1) == 1 + assert f(-2) == 4 + assert f(2) == 4 + assert f(2.5) == 6.25 + + +def test_sqrt(): + f = lambdify(x, sqrt(x)) + assert f(0) == 0.0 + assert f(1) == 1.0 + assert f(4) == 2.0 + assert abs(f(2) - 1.414) < 0.001 + assert f(6.25) == 2.5 + + +def test_trig(): + f = lambdify([x], [cos(x), sin(x)], 'math') + d = f(pi) + prec = 1e-11 + assert -prec < d[0] + 1 < prec + assert -prec < d[1] < prec + d = f(3.14159) + prec = 1e-5 + assert -prec < d[0] + 1 < prec + assert -prec < d[1] < prec + + +def test_integral(): + if numpy and not scipy: + skip("scipy not installed.") + f = Lambda(x, exp(-x**2)) + l = lambdify(y, Integral(f(x), (x, y, oo))) + d = l(-oo) + assert 1.77245385 < d < 1.772453851 + + +def test_double_integral(): + if numpy and not scipy: + skip("scipy not installed.") + # example from http://mpmath.org/doc/current/calculus/integration.html + i = Integral(1/(1 - x**2*y**2), (x, 0, 1), (y, 0, z)) + l = lambdify([z], i) + d = l(1) + assert 1.23370055 < d < 1.233700551 + +def test_spherical_bessel(): + if numpy and not scipy: + skip("scipy not installed.") + test_point = 4.2 #randomly selected + x = symbols("x") + jtest = jn(2, x) + assert abs(lambdify(x,jtest)(test_point) - + jtest.subs(x,test_point).evalf()) < 1e-8 + ytest = yn(2, x) + assert abs(lambdify(x,ytest)(test_point) - + ytest.subs(x,test_point).evalf()) < 1e-8 + + +#================== Test vectors =================================== + + +def test_vector_simple(): + f = lambdify((x, y, z), (z, y, x)) + assert f(3, 2, 1) == (1, 2, 3) + assert f(1.0, 2.0, 3.0) == (3.0, 2.0, 1.0) + # make sure correct number of args required + raises(TypeError, lambda: f(0)) + + +def test_vector_discontinuous(): + f = lambdify(x, (-1/x, 1/x)) + raises(ZeroDivisionError, lambda: f(0)) + assert f(1) == (-1.0, 1.0) + assert f(2) == (-0.5, 0.5) + assert f(-2) == (0.5, -0.5) + + +def test_trig_symbolic(): + f = lambdify([x], [cos(x), sin(x)], 'math') + d = f(pi) + assert abs(d[0] + 1) < 0.0001 + assert abs(d[1] - 0) < 0.0001 + + +def test_trig_float(): + f = lambdify([x], [cos(x), sin(x)]) + d = f(3.14159) + assert abs(d[0] + 1) < 0.0001 + assert abs(d[1] - 0) < 0.0001 + + +def test_docs(): + f = lambdify(x, x**2) + assert f(2) == 4 + f = lambdify([x, y, z], [z, y, x]) + assert f(1, 2, 3) == [3, 2, 1] + f = lambdify(x, sqrt(x)) + assert f(4) == 2.0 + f = lambdify((x, y), sin(x*y)**2) + assert f(0, 5) == 0 + + +def test_math(): + f = lambdify((x, y), sin(x), modules="math") + assert f(0, 5) == 0 + + +def test_sin(): + f = lambdify(x, sin(x)**2) + assert isinstance(f(2), float) + f = lambdify(x, sin(x)**2, modules="math") + assert isinstance(f(2), float) + + +def test_matrix(): + A = Matrix([[x, x*y], [sin(z) + 4, x**z]]) + sol = Matrix([[1, 2], [sin(3) + 4, 1]]) + f = lambdify((x, y, z), A, modules="sympy") + assert f(1, 2, 3) == sol + f = lambdify((x, y, z), (A, [A]), modules="sympy") + assert f(1, 2, 3) == (sol, [sol]) + J = Matrix((x, x + y)).jacobian((x, y)) + v = Matrix((x, y)) + sol = Matrix([[1, 0], [1, 1]]) + assert lambdify(v, J, modules='sympy')(1, 2) == sol + assert lambdify(v.T, J, modules='sympy')(1, 2) == sol + + +def test_numpy_matrix(): + if not numpy: + skip("numpy not installed.") + A = Matrix([[x, x*y], [sin(z) + 4, x**z]]) + sol_arr = numpy.array([[1, 2], [numpy.sin(3) + 4, 1]]) + #Lambdify array first, to ensure return to array as default + f = lambdify((x, y, z), A, ['numpy']) + numpy.testing.assert_allclose(f(1, 2, 3), sol_arr) + #Check that the types are arrays and matrices + assert isinstance(f(1, 2, 3), numpy.ndarray) + + # gh-15071 + class dot(Function): + pass + x_dot_mtx = dot(x, Matrix([[2], [1], [0]])) + f_dot1 = lambdify(x, x_dot_mtx) + inp = numpy.zeros((17, 3)) + assert numpy.all(f_dot1(inp) == 0) + + strict_kw = {"allow_unknown_functions": False, "inline": True, "fully_qualified_modules": False} + p2 = NumPyPrinter(dict(user_functions={'dot': 'dot'}, **strict_kw)) + f_dot2 = lambdify(x, x_dot_mtx, printer=p2) + assert numpy.all(f_dot2(inp) == 0) + + p3 = NumPyPrinter(strict_kw) + # The line below should probably fail upon construction (before calling with "(inp)"): + raises(Exception, lambda: lambdify(x, x_dot_mtx, printer=p3)(inp)) + + +def test_numpy_transpose(): + if not numpy: + skip("numpy not installed.") + A = Matrix([[1, x], [0, 1]]) + f = lambdify((x), A.T, modules="numpy") + numpy.testing.assert_array_equal(f(2), numpy.array([[1, 0], [2, 1]])) + + +def test_numpy_dotproduct(): + if not numpy: + skip("numpy not installed") + A = Matrix([x, y, z]) + f1 = lambdify([x, y, z], DotProduct(A, A), modules='numpy') + f2 = lambdify([x, y, z], DotProduct(A, A.T), modules='numpy') + f3 = lambdify([x, y, z], DotProduct(A.T, A), modules='numpy') + f4 = lambdify([x, y, z], DotProduct(A, A.T), modules='numpy') + + assert f1(1, 2, 3) == \ + f2(1, 2, 3) == \ + f3(1, 2, 3) == \ + f4(1, 2, 3) == \ + numpy.array([14]) + + +def test_numpy_inverse(): + if not numpy: + skip("numpy not installed.") + A = Matrix([[1, x], [0, 1]]) + f = lambdify((x), A**-1, modules="numpy") + numpy.testing.assert_array_equal(f(2), numpy.array([[1, -2], [0, 1]])) + + +def test_numpy_old_matrix(): + if not numpy: + skip("numpy not installed.") + A = Matrix([[x, x*y], [sin(z) + 4, x**z]]) + sol_arr = numpy.array([[1, 2], [numpy.sin(3) + 4, 1]]) + f = lambdify((x, y, z), A, [{'ImmutableDenseMatrix': numpy.matrix}, 'numpy']) + with ignore_warnings(PendingDeprecationWarning): + numpy.testing.assert_allclose(f(1, 2, 3), sol_arr) + assert isinstance(f(1, 2, 3), numpy.matrix) + + +def test_scipy_sparse_matrix(): + if not scipy: + skip("scipy not installed.") + A = SparseMatrix([[x, 0], [0, y]]) + f = lambdify((x, y), A, modules="scipy") + B = f(1, 2) + assert isinstance(B, scipy.sparse.coo_matrix) + + +def test_python_div_zero_issue_11306(): + if not numpy: + skip("numpy not installed.") + p = Piecewise((1 / x, y < -1), (x, y < 1), (1 / x, True)) + f = lambdify([x, y], p, modules='numpy') + with numpy.errstate(divide='ignore'): + assert float(f(numpy.array(0), numpy.array(0.5))) == 0 + assert float(f(numpy.array(0), numpy.array(1))) == float('inf') + + +def test_issue9474(): + mods = [None, 'math'] + if numpy: + mods.append('numpy') + if mpmath: + mods.append('mpmath') + for mod in mods: + f = lambdify(x, S.One/x, modules=mod) + assert f(2) == 0.5 + f = lambdify(x, floor(S.One/x), modules=mod) + assert f(2) == 0 + + for absfunc, modules in product([Abs, abs], mods): + f = lambdify(x, absfunc(x), modules=modules) + assert f(-1) == 1 + assert f(1) == 1 + assert f(3+4j) == 5 + + +def test_issue_9871(): + if not numexpr: + skip("numexpr not installed.") + if not numpy: + skip("numpy not installed.") + + r = sqrt(x**2 + y**2) + expr = diff(1/r, x) + + xn = yn = numpy.linspace(1, 10, 16) + # expr(xn, xn) = -xn/(sqrt(2)*xn)^3 + fv_exact = -numpy.sqrt(2.)**-3 * xn**-2 + + fv_numpy = lambdify((x, y), expr, modules='numpy')(xn, yn) + fv_numexpr = lambdify((x, y), expr, modules='numexpr')(xn, yn) + numpy.testing.assert_allclose(fv_numpy, fv_exact, rtol=1e-10) + numpy.testing.assert_allclose(fv_numexpr, fv_exact, rtol=1e-10) + + +def test_numpy_piecewise(): + if not numpy: + skip("numpy not installed.") + pieces = Piecewise((x, x < 3), (x**2, x > 5), (0, True)) + f = lambdify(x, pieces, modules="numpy") + numpy.testing.assert_array_equal(f(numpy.arange(10)), + numpy.array([0, 1, 2, 0, 0, 0, 36, 49, 64, 81])) + # If we evaluate somewhere all conditions are False, we should get back NaN + nodef_func = lambdify(x, Piecewise((x, x > 0), (-x, x < 0))) + numpy.testing.assert_array_equal(nodef_func(numpy.array([-1, 0, 1])), + numpy.array([1, numpy.nan, 1])) + + +def test_numpy_logical_ops(): + if not numpy: + skip("numpy not installed.") + and_func = lambdify((x, y), And(x, y), modules="numpy") + and_func_3 = lambdify((x, y, z), And(x, y, z), modules="numpy") + or_func = lambdify((x, y), Or(x, y), modules="numpy") + or_func_3 = lambdify((x, y, z), Or(x, y, z), modules="numpy") + not_func = lambdify((x), Not(x), modules="numpy") + arr1 = numpy.array([True, True]) + arr2 = numpy.array([False, True]) + arr3 = numpy.array([True, False]) + numpy.testing.assert_array_equal(and_func(arr1, arr2), numpy.array([False, True])) + numpy.testing.assert_array_equal(and_func_3(arr1, arr2, arr3), numpy.array([False, False])) + numpy.testing.assert_array_equal(or_func(arr1, arr2), numpy.array([True, True])) + numpy.testing.assert_array_equal(or_func_3(arr1, arr2, arr3), numpy.array([True, True])) + numpy.testing.assert_array_equal(not_func(arr2), numpy.array([True, False])) + + +def test_numpy_matmul(): + if not numpy: + skip("numpy not installed.") + xmat = Matrix([[x, y], [z, 1+z]]) + ymat = Matrix([[x**2], [Abs(x)]]) + mat_func = lambdify((x, y, z), xmat*ymat, modules="numpy") + numpy.testing.assert_array_equal(mat_func(0.5, 3, 4), numpy.array([[1.625], [3.5]])) + numpy.testing.assert_array_equal(mat_func(-0.5, 3, 4), numpy.array([[1.375], [3.5]])) + # Multiple matrices chained together in multiplication + f = lambdify((x, y, z), xmat*xmat*xmat, modules="numpy") + numpy.testing.assert_array_equal(f(0.5, 3, 4), numpy.array([[72.125, 119.25], + [159, 251]])) + + +def test_numpy_numexpr(): + if not numpy: + skip("numpy not installed.") + if not numexpr: + skip("numexpr not installed.") + a, b, c = numpy.random.randn(3, 128, 128) + # ensure that numpy and numexpr return same value for complicated expression + expr = sin(x) + cos(y) + tan(z)**2 + Abs(z-y)*acos(sin(y*z)) + \ + Abs(y-z)*acosh(2+exp(y-x))- sqrt(x**2+I*y**2) + npfunc = lambdify((x, y, z), expr, modules='numpy') + nefunc = lambdify((x, y, z), expr, modules='numexpr') + assert numpy.allclose(npfunc(a, b, c), nefunc(a, b, c)) + + +def test_numexpr_userfunctions(): + if not numpy: + skip("numpy not installed.") + if not numexpr: + skip("numexpr not installed.") + a, b = numpy.random.randn(2, 10) + uf = type('uf', (Function, ), + {'eval' : classmethod(lambda x, y : y**2+1)}) + func = lambdify(x, 1-uf(x), modules='numexpr') + assert numpy.allclose(func(a), -(a**2)) + + uf = implemented_function(Function('uf'), lambda x, y : 2*x*y+1) + func = lambdify((x, y), uf(x, y), modules='numexpr') + assert numpy.allclose(func(a, b), 2*a*b+1) + + +def test_tensorflow_basic_math(): + if not tensorflow: + skip("tensorflow not installed.") + expr = Max(sin(x), Abs(1/(x+2))) + func = lambdify(x, expr, modules="tensorflow") + + with tensorflow.compat.v1.Session() as s: + a = tensorflow.constant(0, dtype=tensorflow.float32) + assert func(a).eval(session=s) == 0.5 + + +def test_tensorflow_placeholders(): + if not tensorflow: + skip("tensorflow not installed.") + expr = Max(sin(x), Abs(1/(x+2))) + func = lambdify(x, expr, modules="tensorflow") + + with tensorflow.compat.v1.Session() as s: + a = tensorflow.compat.v1.placeholder(dtype=tensorflow.float32) + assert func(a).eval(session=s, feed_dict={a: 0}) == 0.5 + + +def test_tensorflow_variables(): + if not tensorflow: + skip("tensorflow not installed.") + expr = Max(sin(x), Abs(1/(x+2))) + func = lambdify(x, expr, modules="tensorflow") + + with tensorflow.compat.v1.Session() as s: + a = tensorflow.Variable(0, dtype=tensorflow.float32) + s.run(a.initializer) + assert func(a).eval(session=s, feed_dict={a: 0}) == 0.5 + + +def test_tensorflow_logical_operations(): + if not tensorflow: + skip("tensorflow not installed.") + expr = Not(And(Or(x, y), y)) + func = lambdify([x, y], expr, modules="tensorflow") + + with tensorflow.compat.v1.Session() as s: + assert func(False, True).eval(session=s) == False + + +def test_tensorflow_piecewise(): + if not tensorflow: + skip("tensorflow not installed.") + expr = Piecewise((0, Eq(x,0)), (-1, x < 0), (1, x > 0)) + func = lambdify(x, expr, modules="tensorflow") + + with tensorflow.compat.v1.Session() as s: + assert func(-1).eval(session=s) == -1 + assert func(0).eval(session=s) == 0 + assert func(1).eval(session=s) == 1 + + +def test_tensorflow_multi_max(): + if not tensorflow: + skip("tensorflow not installed.") + expr = Max(x, -x, x**2) + func = lambdify(x, expr, modules="tensorflow") + + with tensorflow.compat.v1.Session() as s: + assert func(-2).eval(session=s) == 4 + + +def test_tensorflow_multi_min(): + if not tensorflow: + skip("tensorflow not installed.") + expr = Min(x, -x, x**2) + func = lambdify(x, expr, modules="tensorflow") + + with tensorflow.compat.v1.Session() as s: + assert func(-2).eval(session=s) == -2 + + +def test_tensorflow_relational(): + if not tensorflow: + skip("tensorflow not installed.") + expr = x >= 0 + func = lambdify(x, expr, modules="tensorflow") + + with tensorflow.compat.v1.Session() as s: + assert func(1).eval(session=s) == True + + +def test_tensorflow_complexes(): + if not tensorflow: + skip("tensorflow not installed") + + func1 = lambdify(x, re(x), modules="tensorflow") + func2 = lambdify(x, im(x), modules="tensorflow") + func3 = lambdify(x, Abs(x), modules="tensorflow") + func4 = lambdify(x, arg(x), modules="tensorflow") + + with tensorflow.compat.v1.Session() as s: + # For versions before + # https://github.com/tensorflow/tensorflow/issues/30029 + # resolved, using Python numeric types may not work + a = tensorflow.constant(1+2j) + assert func1(a).eval(session=s) == 1 + assert func2(a).eval(session=s) == 2 + + tensorflow_result = func3(a).eval(session=s) + sympy_result = Abs(1 + 2j).evalf() + assert abs(tensorflow_result-sympy_result) < 10**-6 + + tensorflow_result = func4(a).eval(session=s) + sympy_result = arg(1 + 2j).evalf() + assert abs(tensorflow_result-sympy_result) < 10**-6 + + +def test_tensorflow_array_arg(): + # Test for issue 14655 (tensorflow part) + if not tensorflow: + skip("tensorflow not installed.") + + f = lambdify([[x, y]], x*x + y, 'tensorflow') + + with tensorflow.compat.v1.Session() as s: + fcall = f(tensorflow.constant([2.0, 1.0])) + assert fcall.eval(session=s) == 5.0 + + +#================== Test symbolic ================================== + + +def test_sym_single_arg(): + f = lambdify(x, x * y) + assert f(z) == z * y + + +def test_sym_list_args(): + f = lambdify([x, y], x + y + z) + assert f(1, 2) == 3 + z + + +def test_sym_integral(): + f = Lambda(x, exp(-x**2)) + l = lambdify(x, Integral(f(x), (x, -oo, oo)), modules="sympy") + assert l(y) == Integral(exp(-y**2), (y, -oo, oo)) + assert l(y).doit() == sqrt(pi) + + +def test_namespace_order(): + # lambdify had a bug, such that module dictionaries or cached module + # dictionaries would pull earlier namespaces into themselves. + # Because the module dictionaries form the namespace of the + # generated lambda, this meant that the behavior of a previously + # generated lambda function could change as a result of later calls + # to lambdify. + n1 = {'f': lambda x: 'first f'} + n2 = {'f': lambda x: 'second f', + 'g': lambda x: 'function g'} + f = sympy.Function('f') + g = sympy.Function('g') + if1 = lambdify(x, f(x), modules=(n1, "sympy")) + assert if1(1) == 'first f' + if2 = lambdify(x, g(x), modules=(n2, "sympy")) + # previously gave 'second f' + assert if1(1) == 'first f' + + assert if2(1) == 'function g' + + +def test_imps(): + # Here we check if the default returned functions are anonymous - in + # the sense that we can have more than one function with the same name + f = implemented_function('f', lambda x: 2*x) + g = implemented_function('f', lambda x: math.sqrt(x)) + l1 = lambdify(x, f(x)) + l2 = lambdify(x, g(x)) + assert str(f(x)) == str(g(x)) + assert l1(3) == 6 + assert l2(3) == math.sqrt(3) + # check that we can pass in a Function as input + func = sympy.Function('myfunc') + assert not hasattr(func, '_imp_') + my_f = implemented_function(func, lambda x: 2*x) + assert hasattr(my_f, '_imp_') + # Error for functions with same name and different implementation + f2 = implemented_function("f", lambda x: x + 101) + raises(ValueError, lambda: lambdify(x, f(f2(x)))) + + +def test_imps_errors(): + # Test errors that implemented functions can return, and still be able to + # form expressions. + # See: https://github.com/sympy/sympy/issues/10810 + # + # XXX: Removed AttributeError here. This test was added due to issue 10810 + # but that issue was about ValueError. It doesn't seem reasonable to + # "support" catching AttributeError in the same context... + for val, error_class in product((0, 0., 2, 2.0), (TypeError, ValueError)): + + def myfunc(a): + if a == 0: + raise error_class + return 1 + + f = implemented_function('f', myfunc) + expr = f(val) + assert expr == f(val) + + +def test_imps_wrong_args(): + raises(ValueError, lambda: implemented_function(sin, lambda x: x)) + + +def test_lambdify_imps(): + # Test lambdify with implemented functions + # first test basic (sympy) lambdify + f = sympy.cos + assert lambdify(x, f(x))(0) == 1 + assert lambdify(x, 1 + f(x))(0) == 2 + assert lambdify((x, y), y + f(x))(0, 1) == 2 + # make an implemented function and test + f = implemented_function("f", lambda x: x + 100) + assert lambdify(x, f(x))(0) == 100 + assert lambdify(x, 1 + f(x))(0) == 101 + assert lambdify((x, y), y + f(x))(0, 1) == 101 + # Can also handle tuples, lists, dicts as expressions + lam = lambdify(x, (f(x), x)) + assert lam(3) == (103, 3) + lam = lambdify(x, [f(x), x]) + assert lam(3) == [103, 3] + lam = lambdify(x, [f(x), (f(x), x)]) + assert lam(3) == [103, (103, 3)] + lam = lambdify(x, {f(x): x}) + assert lam(3) == {103: 3} + lam = lambdify(x, {f(x): x}) + assert lam(3) == {103: 3} + lam = lambdify(x, {x: f(x)}) + assert lam(3) == {3: 103} + # Check that imp preferred to other namespaces by default + d = {'f': lambda x: x + 99} + lam = lambdify(x, f(x), d) + assert lam(3) == 103 + # Unless flag passed + lam = lambdify(x, f(x), d, use_imps=False) + assert lam(3) == 102 + + +def test_dummification(): + t = symbols('t') + F = Function('F') + G = Function('G') + #"\alpha" is not a valid Python variable name + #lambdify should sub in a dummy for it, and return + #without a syntax error + alpha = symbols(r'\alpha') + some_expr = 2 * F(t)**2 / G(t) + lam = lambdify((F(t), G(t)), some_expr) + assert lam(3, 9) == 2 + lam = lambdify(sin(t), 2 * sin(t)**2) + assert lam(F(t)) == 2 * F(t)**2 + #Test that \alpha was properly dummified + lam = lambdify((alpha, t), 2*alpha + t) + assert lam(2, 1) == 5 + raises(SyntaxError, lambda: lambdify(F(t) * G(t), F(t) * G(t) + 5)) + raises(SyntaxError, lambda: lambdify(2 * F(t), 2 * F(t) + 5)) + raises(SyntaxError, lambda: lambdify(2 * F(t), 4 * F(t) + 5)) + + +def test_lambdify__arguments_with_invalid_python_identifiers(): + # see sympy/sympy#26690 + N = CoordSys3D('N') + xn, yn, zn = N.base_scalars() + expr = xn + yn + f = lambdify([xn, yn], expr) + res = f(0.2, 0.3) + ref = 0.2 + 0.3 + assert abs(res-ref) < 1e-15 + + +def test_curly_matrix_symbol(): + # Issue #15009 + curlyv = sympy.MatrixSymbol("{v}", 2, 1) + lam = lambdify(curlyv, curlyv) + assert lam(1)==1 + lam = lambdify(curlyv, curlyv, dummify=True) + assert lam(1)==1 + + +def test_python_keywords(): + # Test for issue 7452. The automatic dummification should ensure use of + # Python reserved keywords as symbol names will create valid lambda + # functions. This is an additional regression test. + python_if = symbols('if') + expr = python_if / 2 + f = lambdify(python_if, expr) + assert f(4.0) == 2.0 + + +def test_lambdify_docstring(): + func = lambdify((w, x, y, z), w + x + y + z) + ref = ( + "Created with lambdify. Signature:\n\n" + "func(w, x, y, z)\n\n" + "Expression:\n\n" + "w + x + y + z" + ).splitlines() + assert func.__doc__.splitlines()[:len(ref)] == ref + syms = symbols('a1:26') + func = lambdify(syms, sum(syms)) + ref = ( + "Created with lambdify. Signature:\n\n" + "func(a1, a2, a3, a4, a5, a6, a7, a8, a9, a10, a11, a12, a13, a14, a15,\n" + " a16, a17, a18, a19, a20, a21, a22, a23, a24, a25)\n\n" + "Expression:\n\n" + "a1 + a10 + a11 + a12 + a13 + a14 + a15 + a16 + a17 + a18 + a19 + a2 + a20 +..." + ).splitlines() + assert func.__doc__.splitlines()[:len(ref)] == ref + + +#================== Test special printers ========================== + + +def test_special_printers(): + from sympy.printing.lambdarepr import IntervalPrinter + + def intervalrepr(expr): + return IntervalPrinter().doprint(expr) + + expr = sqrt(sqrt(2) + sqrt(3)) + S.Half + + func0 = lambdify((), expr, modules="mpmath", printer=intervalrepr) + func1 = lambdify((), expr, modules="mpmath", printer=IntervalPrinter) + func2 = lambdify((), expr, modules="mpmath", printer=IntervalPrinter()) + + mpi = type(mpmath.mpi(1, 2)) + + assert isinstance(func0(), mpi) + assert isinstance(func1(), mpi) + assert isinstance(func2(), mpi) + + # To check Is lambdify loggamma works for mpmath or not + exp1 = lambdify(x, loggamma(x), 'mpmath')(5) + exp2 = lambdify(x, loggamma(x), 'mpmath')(1.8) + exp3 = lambdify(x, loggamma(x), 'mpmath')(15) + exp_ls = [exp1, exp2, exp3] + + sol1 = mpmath.loggamma(5) + sol2 = mpmath.loggamma(1.8) + sol3 = mpmath.loggamma(15) + sol_ls = [sol1, sol2, sol3] + + assert exp_ls == sol_ls + + +def test_true_false(): + # We want exact is comparison here, not just == + assert lambdify([], true)() is True + assert lambdify([], false)() is False + + +def test_issue_2790(): + assert lambdify((x, (y, z)), x + y)(1, (2, 4)) == 3 + assert lambdify((x, (y, (w, z))), w + x + y + z)(1, (2, (3, 4))) == 10 + assert lambdify(x, x + 1, dummify=False)(1) == 2 + + +def test_issue_12092(): + f = implemented_function('f', lambda x: x**2) + assert f(f(2)).evalf() == Float(16) + + +def test_issue_14911(): + class Variable(sympy.Symbol): + def _sympystr(self, printer): + return printer.doprint(self.name) + + _lambdacode = _sympystr + _numpycode = _sympystr + + x = Variable('x') + y = 2 * x + code = LambdaPrinter().doprint(y) + assert code.replace(' ', '') == '2*x' + + +def test_ITE(): + assert lambdify((x, y, z), ITE(x, y, z))(True, 5, 3) == 5 + assert lambdify((x, y, z), ITE(x, y, z))(False, 5, 3) == 3 + + +def test_Min_Max(): + # see gh-10375 + assert lambdify((x, y, z), Min(x, y, z))(1, 2, 3) == 1 + assert lambdify((x, y, z), Max(x, y, z))(1, 2, 3) == 3 + + +def test_Indexed(): + # Issue #10934 + if not numpy: + skip("numpy not installed") + + a = IndexedBase('a') + i, j = symbols('i j') + b = numpy.array([[1, 2], [3, 4]]) + assert lambdify(a, Sum(a[x, y], (x, 0, 1), (y, 0, 1)))(b) == 10 + + +def test_issue_12173(): + #test for issue 12173 + expr1 = lambdify((x, y), uppergamma(x, y),"mpmath")(1, 2) + expr2 = lambdify((x, y), lowergamma(x, y),"mpmath")(1, 2) + assert expr1 == uppergamma(1, 2).evalf() + assert expr2 == lowergamma(1, 2).evalf() + + +def test_issue_13642(): + if not numpy: + skip("numpy not installed") + f = lambdify(x, sinc(x)) + assert Abs(f(1) - sinc(1)).n() < 1e-15 + + +def test_sinc_mpmath(): + f = lambdify(x, sinc(x), "mpmath") + assert Abs(f(1) - sinc(1)).n() < 1e-15 + + +def test_lambdify_dummy_arg(): + d1 = Dummy() + f1 = lambdify(d1, d1 + 1, dummify=False) + assert f1(2) == 3 + f1b = lambdify(d1, d1 + 1) + assert f1b(2) == 3 + d2 = Dummy('x') + f2 = lambdify(d2, d2 + 1) + assert f2(2) == 3 + f3 = lambdify([[d2]], d2 + 1) + assert f3([2]) == 3 + + +def test_lambdify_mixed_symbol_dummy_args(): + d = Dummy() + # Contrived example of name clash + dsym = symbols(str(d)) + f = lambdify([d, dsym], d - dsym) + assert f(4, 1) == 3 + + +def test_numpy_array_arg(): + # Test for issue 14655 (numpy part) + if not numpy: + skip("numpy not installed") + + f = lambdify([[x, y]], x*x + y, 'numpy') + + assert f(numpy.array([2.0, 1.0])) == 5 + + +def test_scipy_fns(): + if not scipy: + skip("scipy not installed") + + single_arg_sympy_fns = [Ei, erf, erfc, factorial, gamma, loggamma, digamma, Si, Ci] + single_arg_scipy_fns = [scipy.special.expi, scipy.special.erf, scipy.special.erfc, + scipy.special.factorial, scipy.special.gamma, scipy.special.gammaln, + scipy.special.psi, scipy.special.sici, scipy.special.sici] + numpy.random.seed(0) + for (sympy_fn, scipy_fn) in zip(single_arg_sympy_fns, single_arg_scipy_fns): + f = lambdify(x, sympy_fn(x), modules="scipy") + for i in range(20): + tv = numpy.random.uniform(-10, 10) + 1j*numpy.random.uniform(-5, 5) + # SciPy thinks that factorial(z) is 0 when re(z) < 0 and + # does not support complex numbers. + # SymPy does not think so. + if sympy_fn == factorial: + tv = numpy.abs(tv) + # SciPy supports gammaln for real arguments only, + # and there is also a branch cut along the negative real axis + if sympy_fn == loggamma: + tv = numpy.abs(tv) + # SymPy's digamma evaluates as polygamma(0, z) + # which SciPy supports for real arguments only + if sympy_fn == digamma: + tv = numpy.real(tv) + sympy_result = sympy_fn(tv).evalf() + scipy_result = scipy_fn(tv) + # SciPy's sici returns a tuple with both Si and Ci present in it + # which needs to be unpacked + if sympy_fn == Si: + scipy_result = scipy_fn(tv)[0] + if sympy_fn == Ci: + scipy_result = scipy_fn(tv)[1] + assert abs(f(tv) - sympy_result) < 1e-13*(1 + abs(sympy_result)) + assert abs(f(tv) - scipy_result) < 1e-13*(1 + abs(sympy_result)) + + double_arg_sympy_fns = [RisingFactorial, besselj, bessely, besseli, + besselk, polygamma] + double_arg_scipy_fns = [scipy.special.poch, scipy.special.jv, + scipy.special.yv, scipy.special.iv, scipy.special.kv, scipy.special.polygamma] + for (sympy_fn, scipy_fn) in zip(double_arg_sympy_fns, double_arg_scipy_fns): + f = lambdify((x, y), sympy_fn(x, y), modules="scipy") + for i in range(20): + # SciPy supports only real orders of Bessel functions + tv1 = numpy.random.uniform(-10, 10) + tv2 = numpy.random.uniform(-10, 10) + 1j*numpy.random.uniform(-5, 5) + # SciPy requires a real valued 2nd argument for: poch, polygamma + if sympy_fn in (RisingFactorial, polygamma): + tv2 = numpy.real(tv2) + if sympy_fn == polygamma: + tv1 = abs(int(tv1)) # first argument to polygamma must be a non-negative integer. + sympy_result = sympy_fn(tv1, tv2).evalf() + assert abs(f(tv1, tv2) - sympy_result) < 1e-13*(1 + abs(sympy_result)) + assert abs(f(tv1, tv2) - scipy_fn(tv1, tv2)) < 1e-13*(1 + abs(sympy_result)) + + +def test_scipy_polys(): + if not scipy: + skip("scipy not installed") + numpy.random.seed(0) + + params = symbols('n k a b') + # list polynomials with the number of parameters + polys = [ + (chebyshevt, 1), + (chebyshevu, 1), + (legendre, 1), + (hermite, 1), + (laguerre, 1), + (gegenbauer, 2), + (assoc_legendre, 2), + (assoc_laguerre, 2), + (jacobi, 3) + ] + + msg = \ + "The random test of the function {func} with the arguments " \ + "{args} had failed because the SymPy result {sympy_result} " \ + "and SciPy result {scipy_result} had failed to converge " \ + "within the tolerance {tol} " \ + "(Actual absolute difference : {diff})" + + for sympy_fn, num_params in polys: + args = params[:num_params] + (x,) + f = lambdify(args, sympy_fn(*args)) + for _ in range(10): + tn = numpy.random.randint(3, 10) + tparams = tuple(numpy.random.uniform(0, 5, size=num_params-1)) + tv = numpy.random.uniform(-10, 10) + 1j*numpy.random.uniform(-5, 5) + # SciPy supports hermite for real arguments only + if sympy_fn == hermite: + tv = numpy.real(tv) + # assoc_legendre needs x in (-1, 1) and integer param at most n + if sympy_fn == assoc_legendre: + tv = numpy.random.uniform(-1, 1) + tparams = tuple(numpy.random.randint(1, tn, size=1)) + + vals = (tn,) + tparams + (tv,) + scipy_result = f(*vals) + sympy_result = sympy_fn(*vals).evalf() + atol = 1e-9*(1 + abs(sympy_result)) + diff = abs(scipy_result - sympy_result) + try: + assert diff < atol + except TypeError: + raise AssertionError( + msg.format( + func=repr(sympy_fn), + args=repr(vals), + sympy_result=repr(sympy_result), + scipy_result=repr(scipy_result), + diff=diff, + tol=atol) + ) + + +def test_lambdify_inspect(): + f = lambdify(x, x**2) + # Test that inspect.getsource works but don't hard-code implementation + # details + assert 'x**2' in inspect.getsource(f) + + +def test_issue_14941(): + x, y = Dummy(), Dummy() + + # test dict + f1 = lambdify([x, y], {x: 3, y: 3}, 'sympy') + assert f1(2, 3) == {2: 3, 3: 3} + + # test tuple + f2 = lambdify([x, y], (y, x), 'sympy') + assert f2(2, 3) == (3, 2) + f2b = lambdify([], (1,)) # gh-23224 + assert f2b() == (1,) + + # test list + f3 = lambdify([x, y], [y, x], 'sympy') + assert f3(2, 3) == [3, 2] + + +def test_lambdify_Derivative_arg_issue_16468(): + f = Function('f')(x) + fx = f.diff() + assert lambdify((f, fx), f + fx)(10, 5) == 15 + assert eval(lambdastr((f, fx), f/fx))(10, 5) == 2 + raises(Exception, lambda: + eval(lambdastr((f, fx), f/fx, dummify=False))) + assert eval(lambdastr((f, fx), f/fx, dummify=True))(10, 5) == 2 + assert eval(lambdastr((fx, f), f/fx, dummify=True))(S(10), 5) == S.Half + assert lambdify(fx, 1 + fx)(41) == 42 + assert eval(lambdastr(fx, 1 + fx, dummify=True))(41) == 42 + + +def test_imag_real(): + f_re = lambdify([z], sympy.re(z)) + val = 3+2j + assert f_re(val) == val.real + + f_im = lambdify([z], sympy.im(z)) # see #15400 + assert f_im(val) == val.imag + + +def test_MatrixSymbol_issue_15578(): + if not numpy: + skip("numpy not installed") + A = MatrixSymbol('A', 2, 2) + A0 = numpy.array([[1, 2], [3, 4]]) + f = lambdify(A, A**(-1)) + assert numpy.allclose(f(A0), numpy.array([[-2., 1.], [1.5, -0.5]])) + g = lambdify(A, A**3) + assert numpy.allclose(g(A0), numpy.array([[37, 54], [81, 118]])) + + +def test_issue_15654(): + if not scipy: + skip("scipy not installed") + from sympy.abc import n, l, r, Z + from sympy.physics import hydrogen + nv, lv, rv, Zv = 1, 0, 3, 1 + sympy_value = hydrogen.R_nl(nv, lv, rv, Zv).evalf() + f = lambdify((n, l, r, Z), hydrogen.R_nl(n, l, r, Z)) + scipy_value = f(nv, lv, rv, Zv) + assert abs(sympy_value - scipy_value) < 1e-15 + + +def test_issue_15827(): + if not numpy: + skip("numpy not installed") + A = MatrixSymbol("A", 3, 3) + B = MatrixSymbol("B", 2, 3) + C = MatrixSymbol("C", 3, 4) + D = MatrixSymbol("D", 4, 5) + k=symbols("k") + f = lambdify(A, (2*k)*A) + g = lambdify(A, (2+k)*A) + h = lambdify(A, 2*A) + i = lambdify((B, C, D), 2*B*C*D) + assert numpy.array_equal(f(numpy.array([[1, 2, 3], [1, 2, 3], [1, 2, 3]])), \ + numpy.array([[2*k, 4*k, 6*k], [2*k, 4*k, 6*k], [2*k, 4*k, 6*k]], dtype=object)) + + assert numpy.array_equal(g(numpy.array([[1, 2, 3], [1, 2, 3], [1, 2, 3]])), \ + numpy.array([[k + 2, 2*k + 4, 3*k + 6], [k + 2, 2*k + 4, 3*k + 6], \ + [k + 2, 2*k + 4, 3*k + 6]], dtype=object)) + + assert numpy.array_equal(h(numpy.array([[1, 2, 3], [1, 2, 3], [1, 2, 3]])), \ + numpy.array([[2, 4, 6], [2, 4, 6], [2, 4, 6]])) + + assert numpy.array_equal(i(numpy.array([[1, 2, 3], [1, 2, 3]]), numpy.array([[1, 2, 3, 4], [1, 2, 3, 4], [1, 2, 3, 4]]), \ + numpy.array([[1, 2, 3, 4, 5], [1, 2, 3, 4, 5], [1, 2, 3, 4, 5], [1, 2, 3, 4, 5]])), numpy.array([[ 120, 240, 360, 480, 600], \ + [ 120, 240, 360, 480, 600]])) + + +def test_issue_16930(): + if not scipy: + skip("scipy not installed") + + x = symbols("x") + f = lambda x: S.GoldenRatio * x**2 + f_ = lambdify(x, f(x), modules='scipy') + assert f_(1) == scipy.constants.golden_ratio + +def test_issue_17898(): + if not scipy: + skip("scipy not installed") + x = symbols("x") + f_ = lambdify([x], sympy.LambertW(x,-1), modules='scipy') + assert f_(0.1) == mpmath.lambertw(0.1, -1) + +def test_issue_13167_21411(): + if not numpy: + skip("numpy not installed") + f1 = lambdify(x, sympy.Heaviside(x)) + f2 = lambdify(x, sympy.Heaviside(x, 1)) + res1 = f1([-1, 0, 1]) + res2 = f2([-1, 0, 1]) + assert Abs(res1[0]).n() < 1e-15 # First functionality: only one argument passed + assert Abs(res1[1] - 1/2).n() < 1e-15 + assert Abs(res1[2] - 1).n() < 1e-15 + assert Abs(res2[0]).n() < 1e-15 # Second functionality: two arguments passed + assert Abs(res2[1] - 1).n() < 1e-15 + assert Abs(res2[2] - 1).n() < 1e-15 + +def test_single_e(): + f = lambdify(x, E) + assert f(23) == exp(1.0) + +def test_issue_16536(): + if not scipy: + skip("scipy not installed") + + a = symbols('a') + f1 = lowergamma(a, x) + F = lambdify((a, x), f1, modules='scipy') + assert abs(lowergamma(1, 3) - F(1, 3)) <= 1e-10 + + f2 = uppergamma(a, x) + F = lambdify((a, x), f2, modules='scipy') + assert abs(uppergamma(1, 3) - F(1, 3)) <= 1e-10 + + +def test_issue_22726(): + if not numpy: + skip("numpy not installed") + + x1, x2 = symbols('x1 x2') + f = Max(S.Zero, Min(x1, x2)) + g = derive_by_array(f, (x1, x2)) + G = lambdify((x1, x2), g, modules='numpy') + point = {x1: 1, x2: 2} + assert (abs(g.subs(point) - G(*point.values())) <= 1e-10).all() + + +def test_issue_22739(): + if not numpy: + skip("numpy not installed") + + x1, x2 = symbols('x1 x2') + f = Heaviside(Min(x1, x2)) + F = lambdify((x1, x2), f, modules='numpy') + point = {x1: 1, x2: 2} + assert abs(f.subs(point) - F(*point.values())) <= 1e-10 + + +def test_issue_22992(): + if not numpy: + skip("numpy not installed") + + a, t = symbols('a t') + expr = a*(log(cot(t/2)) - cos(t)) + F = lambdify([a, t], expr, 'numpy') + + point = {a: 10, t: 2} + + assert abs(expr.subs(point) - F(*point.values())) <= 1e-10 + + # Standard math + F = lambdify([a, t], expr) + + assert abs(expr.subs(point) - F(*point.values())) <= 1e-10 + + +def test_issue_19764(): + if not numpy: + skip("numpy not installed") + + expr = Array([x, x**2]) + f = lambdify(x, expr, 'numpy') + + assert f(1).__class__ == numpy.ndarray + +def test_issue_20070(): + if not numba: + skip("numba not installed") + + f = lambdify(x, sin(x), 'numpy') + assert numba.jit(f, nopython=True)(1)==0.8414709848078965 + + +def test_fresnel_integrals_scipy(): + if not scipy: + skip("scipy not installed") + + f1 = fresnelc(x) + f2 = fresnels(x) + F1 = lambdify(x, f1, modules='scipy') + F2 = lambdify(x, f2, modules='scipy') + + assert abs(fresnelc(1.3) - F1(1.3)) <= 1e-10 + assert abs(fresnels(1.3) - F2(1.3)) <= 1e-10 + + +def test_beta_scipy(): + if not scipy: + skip("scipy not installed") + + f = beta(x, y) + F = lambdify((x, y), f, modules='scipy') + + assert abs(beta(1.3, 2.3) - F(1.3, 2.3)) <= 1e-10 + + +def test_beta_math(): + f = beta(x, y) + F = lambdify((x, y), f, modules='math') + + assert abs(beta(1.3, 2.3) - F(1.3, 2.3)) <= 1e-10 + + +def test_betainc_scipy(): + if not scipy: + skip("scipy not installed") + + f = betainc(w, x, y, z) + F = lambdify((w, x, y, z), f, modules='scipy') + + assert abs(betainc(1.4, 3.1, 0.1, 0.5) - F(1.4, 3.1, 0.1, 0.5)) <= 1e-10 + + +def test_betainc_regularized_scipy(): + if not scipy: + skip("scipy not installed") + + f = betainc_regularized(w, x, y, z) + F = lambdify((w, x, y, z), f, modules='scipy') + + assert abs(betainc_regularized(0.2, 3.5, 0.1, 1) - F(0.2, 3.5, 0.1, 1)) <= 1e-10 + + +def test_numpy_special_math(): + if not numpy: + skip("numpy not installed") + + funcs = [expm1, log1p, exp2, log2, log10, hypot, logaddexp, logaddexp2] + for func in funcs: + if 2 in func.nargs: + expr = func(x, y) + args = (x, y) + num_args = (0.3, 0.4) + elif 1 in func.nargs: + expr = func(x) + args = (x,) + num_args = (0.3,) + else: + raise NotImplementedError("Need to handle other than unary & binary functions in test") + f = lambdify(args, expr) + result = f(*num_args) + reference = expr.subs(dict(zip(args, num_args))).evalf() + assert numpy.allclose(result, float(reference)) + + lae2 = lambdify((x, y), logaddexp2(log2(x), log2(y))) + assert abs(2.0**lae2(1e-50, 2.5e-50) - 3.5e-50) < 1e-62 # from NumPy's docstring + + +def test_scipy_special_math(): + if not scipy: + skip("scipy not installed") + + cm1 = lambdify((x,), cosm1(x), modules='scipy') + assert abs(cm1(1e-20) + 5e-41) < 1e-200 + + have_scipy_1_10plus = tuple(map(int, scipy.version.version.split('.')[:2])) >= (1, 10) + + if have_scipy_1_10plus: + cm2 = lambdify((x, y), powm1(x, y), modules='scipy') + assert abs(cm2(1.2, 1e-9) - 1.82321557e-10) < 1e-17 + + +def test_scipy_bernoulli(): + if not scipy: + skip("scipy not installed") + + bern = lambdify((x,), bernoulli(x), modules='scipy') + assert bern(1) == 0.5 + + +def test_scipy_harmonic(): + if not scipy: + skip("scipy not installed") + + hn = lambdify((x,), harmonic(x), modules='scipy') + assert hn(2) == 1.5 + hnm = lambdify((x, y), harmonic(x, y), modules='scipy') + assert hnm(2, 2) == 1.25 + + +def test_cupy_array_arg(): + if not cupy: + skip("CuPy not installed") + + f = lambdify([[x, y]], x*x + y, 'cupy') + result = f(cupy.array([2.0, 1.0])) + assert result == 5 + assert "cupy" in str(type(result)) + + +def test_cupy_array_arg_using_numpy(): + # numpy functions can be run on cupy arrays + # unclear if we can "officially" support this, + # depends on numpy __array_function__ support + if not cupy: + skip("CuPy not installed") + + f = lambdify([[x, y]], x*x + y, 'numpy') + result = f(cupy.array([2.0, 1.0])) + assert result == 5 + assert "cupy" in str(type(result)) + +def test_cupy_dotproduct(): + if not cupy: + skip("CuPy not installed") + + A = Matrix([x, y, z]) + f1 = lambdify([x, y, z], DotProduct(A, A), modules='cupy') + f2 = lambdify([x, y, z], DotProduct(A, A.T), modules='cupy') + f3 = lambdify([x, y, z], DotProduct(A.T, A), modules='cupy') + f4 = lambdify([x, y, z], DotProduct(A, A.T), modules='cupy') + + assert f1(1, 2, 3) == \ + f2(1, 2, 3) == \ + f3(1, 2, 3) == \ + f4(1, 2, 3) == \ + cupy.array([14]) + + +def test_jax_array_arg(): + if not jax: + skip("JAX not installed") + + f = lambdify([[x, y]], x*x + y, 'jax') + result = f(jax.numpy.array([2.0, 1.0])) + assert result == 5 + assert "jax" in str(type(result)) + + +def test_jax_array_arg_using_numpy(): + if not jax: + skip("JAX not installed") + + f = lambdify([[x, y]], x*x + y, 'numpy') + result = f(jax.numpy.array([2.0, 1.0])) + assert result == 5 + assert "jax" in str(type(result)) + + +def test_jax_dotproduct(): + if not jax: + skip("JAX not installed") + + A = Matrix([x, y, z]) + f1 = lambdify([x, y, z], DotProduct(A, A), modules='jax') + f2 = lambdify([x, y, z], DotProduct(A, A.T), modules='jax') + f3 = lambdify([x, y, z], DotProduct(A.T, A), modules='jax') + f4 = lambdify([x, y, z], DotProduct(A, A.T), modules='jax') + + assert f1(1, 2, 3) == \ + f2(1, 2, 3) == \ + f3(1, 2, 3) == \ + f4(1, 2, 3) == \ + jax.numpy.array([14]) + + +def test_lambdify_cse(): + def no_op_cse(exprs): + return (), exprs + + def dummy_cse(exprs): + from sympy.simplify.cse_main import cse + return cse(exprs, symbols=numbered_symbols(cls=Dummy)) + + def minmem(exprs): + from sympy.simplify.cse_main import cse_release_variables, cse + return cse(exprs, postprocess=cse_release_variables) + + class Case: + def __init__(self, *, args, exprs, num_args, requires_numpy=False): + self.args = args + self.exprs = exprs + self.num_args = num_args + subs_dict = dict(zip(self.args, self.num_args)) + self.ref = [e.subs(subs_dict).evalf() for e in exprs] + self.requires_numpy = requires_numpy + + def lambdify(self, *, cse): + return lambdify(self.args, self.exprs, cse=cse) + + def assertAllClose(self, result, *, abstol=1e-15, reltol=1e-15): + if self.requires_numpy: + assert all(numpy.allclose(result[i], numpy.asarray(r, dtype=float), + rtol=reltol, atol=abstol) + for i, r in enumerate(self.ref)) + return + + for i, r in enumerate(self.ref): + abs_err = abs(result[i] - r) + if r == 0: + assert abs_err < abstol + else: + assert abs_err/abs(r) < reltol + + cases = [ + Case( + args=(x, y, z), + exprs=[ + x + y + z, + x + y - z, + 2*x + 2*y - z, + (x+y)**2 + (y+z)**2, + ], + num_args=(2., 3., 4.) + ), + Case( + args=(x, y, z), + exprs=[ + x + sympy.Heaviside(x), + y + sympy.Heaviside(x), + z + sympy.Heaviside(x, 1), + z/sympy.Heaviside(x, 1) + ], + num_args=(0., 3., 4.) + ), + Case( + args=(x, y, z), + exprs=[ + x + sinc(y), + y + sinc(y), + z - sinc(y) + ], + num_args=(0.1, 0.2, 0.3) + ), + Case( + args=(x, y, z), + exprs=[ + Matrix([[x, x*y], [sin(z) + 4, x**z]]), + x*y+sin(z)-x**z, + Matrix([x*x, sin(z), x**z]) + ], + num_args=(1.,2.,3.), + requires_numpy=True + ), + Case( + args=(x, y), + exprs=[(x + y - 1)**2, x, x + y, + (x + y)/(2*x + 1) + (x + y - 1)**2, (2*x + 1)**(x + y)], + num_args=(1,2) + ) + ] + for case in cases: + if not numpy and case.requires_numpy: + continue + for _cse in [False, True, minmem, no_op_cse, dummy_cse]: + f = case.lambdify(cse=_cse) + result = f(*case.num_args) + case.assertAllClose(result) + +def test_issue_25288(): + syms = numbered_symbols(cls=Dummy) + ok = lambdify(x, [x**2, sin(x**2)], cse=lambda e: cse(e, symbols=syms))(2) + assert ok + + +def test_deprecated_set(): + with warns_deprecated_sympy(): + lambdify({x, y}, x + y) + +def test_issue_13881(): + if not numpy: + skip("numpy not installed.") + + X = MatrixSymbol('X', 3, 1) + + f = lambdify(X, X.T*X, 'numpy') + assert f(numpy.array([1, 2, 3])) == 14 + assert f(numpy.array([3, 2, 1])) == 14 + + f = lambdify(X, X*X.T, 'numpy') + assert f(numpy.array([1, 2, 3])) == 14 + assert f(numpy.array([3, 2, 1])) == 14 + + f = lambdify(X, (X*X.T)*X, 'numpy') + arr1 = numpy.array([[1], [2], [3]]) + arr2 = numpy.array([[14],[28],[42]]) + + assert numpy.array_equal(f(arr1), arr2) + + +def test_23536_lambdify_cse_dummy(): + + f = Function('x')(y) + g = Function('w')(y) + expr = z + (f**4 + g**5)*(f**3 + (g*f)**3) + expr = expr.expand() + eval_expr = lambdify(((f, g), z), expr, cse=True) + ans = eval_expr((1.0, 2.0), 3.0) # shouldn't raise NameError + assert ans == 300.0 # not a list and value is 300 + + +class LambdifyDocstringTestCase: + SIGNATURE = None + EXPR = None + SRC = None + + def __init__(self, docstring_limit, expected_redacted): + self.docstring_limit = docstring_limit + self.expected_redacted = expected_redacted + + @property + def expected_expr(self): + expr_redacted_msg = "EXPRESSION REDACTED DUE TO LENGTH, (see lambdify's `docstring_limit`)" + return self.EXPR if not self.expected_redacted else expr_redacted_msg + + @property + def expected_src(self): + src_redacted_msg = "SOURCE CODE REDACTED DUE TO LENGTH, (see lambdify's `docstring_limit`)" + return self.SRC if not self.expected_redacted else src_redacted_msg + + @property + def expected_docstring(self): + expected_docstring = ( + f'Created with lambdify. Signature:\n\n' + f'func({self.SIGNATURE})\n\n' + f'Expression:\n\n' + f'{self.expected_expr}\n\n' + f'Source code:\n\n' + f'{self.expected_src}\n\n' + f'Imported modules:\n\n' + ) + return expected_docstring + + def __len__(self): + return len(self.expected_docstring) + + def __repr__(self): + return ( + f'{self.__class__.__name__}(' + f'docstring_limit={self.docstring_limit}, ' + f'expected_redacted={self.expected_redacted})' + ) + + +def test_lambdify_docstring_size_limit_simple_symbol(): + + class SimpleSymbolTestCase(LambdifyDocstringTestCase): + SIGNATURE = 'x' + EXPR = 'x' + SRC = ( + 'def _lambdifygenerated(x):\n' + ' return x\n' + ) + + x = symbols('x') + + test_cases = ( + SimpleSymbolTestCase(docstring_limit=None, expected_redacted=False), + SimpleSymbolTestCase(docstring_limit=100, expected_redacted=False), + SimpleSymbolTestCase(docstring_limit=1, expected_redacted=False), + SimpleSymbolTestCase(docstring_limit=0, expected_redacted=True), + SimpleSymbolTestCase(docstring_limit=-1, expected_redacted=True), + ) + for test_case in test_cases: + lambdified_expr = lambdify( + [x], + x, + 'sympy', + docstring_limit=test_case.docstring_limit, + ) + assert lambdified_expr.__doc__ == test_case.expected_docstring + + +def test_lambdify_docstring_size_limit_nested_expr(): + + class ExprListTestCase(LambdifyDocstringTestCase): + SIGNATURE = 'x, y, z' + EXPR = ( + '[x, [y], z, x**3 + 3*x**2*y + 3*x**2*z + 3*x*y**2 + 6*x*y*z ' + '+ 3*x*z**2 +...' + ) + SRC = ( + 'def _lambdifygenerated(x, y, z):\n' + ' return [x, [y], z, x**3 + 3*x**2*y + 3*x**2*z + 3*x*y**2 ' + '+ 6*x*y*z + 3*x*z**2 + y**3 + 3*y**2*z + 3*y*z**2 + z**3]\n' + ) + + x, y, z = symbols('x, y, z') + expr = [x, [y], z, ((x + y + z)**3).expand()] + + test_cases = ( + ExprListTestCase(docstring_limit=None, expected_redacted=False), + ExprListTestCase(docstring_limit=200, expected_redacted=False), + ExprListTestCase(docstring_limit=50, expected_redacted=True), + ExprListTestCase(docstring_limit=0, expected_redacted=True), + ExprListTestCase(docstring_limit=-1, expected_redacted=True), + ) + for test_case in test_cases: + lambdified_expr = lambdify( + [x, y, z], + expr, + 'sympy', + docstring_limit=test_case.docstring_limit, + ) + assert lambdified_expr.__doc__ == test_case.expected_docstring + + +def test_lambdify_docstring_size_limit_matrix(): + + class MatrixTestCase(LambdifyDocstringTestCase): + SIGNATURE = 'x, y, z' + EXPR = ( + 'Matrix([[0, x], [x + y + z, x**3 + 3*x**2*y + 3*x**2*z + 3*x*y**2 ' + '+ 6*x*y*z...' + ) + SRC = ( + 'def _lambdifygenerated(x, y, z):\n' + ' return ImmutableDenseMatrix([[0, x], [x + y + z, x**3 ' + '+ 3*x**2*y + 3*x**2*z + 3*x*y**2 + 6*x*y*z + 3*x*z**2 + y**3 ' + '+ 3*y**2*z + 3*y*z**2 + z**3]])\n' + ) + + x, y, z = symbols('x, y, z') + expr = Matrix([[S.Zero, x], [x + y + z, ((x + y + z)**3).expand()]]) + + test_cases = ( + MatrixTestCase(docstring_limit=None, expected_redacted=False), + MatrixTestCase(docstring_limit=200, expected_redacted=False), + MatrixTestCase(docstring_limit=50, expected_redacted=True), + MatrixTestCase(docstring_limit=0, expected_redacted=True), + MatrixTestCase(docstring_limit=-1, expected_redacted=True), + ) + for test_case in test_cases: + lambdified_expr = lambdify( + [x, y, z], + expr, + 'sympy', + docstring_limit=test_case.docstring_limit, + ) + assert lambdified_expr.__doc__ == test_case.expected_docstring + + +def test_lambdify_empty_tuple(): + a = symbols("a") + expr = ((), (a,)) + f = lambdify(a, expr) + result = f(1) + assert result == ((), (1,)), "Lambdify did not handle the empty tuple correctly." + +def test_assoc_legendre_numerical_evaluation(): + + tol = 1e-10 + + sympy_result_integer = assoc_legendre(1, 1/2, 0.1).evalf() + sympy_result_complex = assoc_legendre(2, 1, 3).evalf() + mpmath_result_integer = -0.474572528387641 + mpmath_result_complex = -25.45584412271571*I + + assert all_close(sympy_result_integer, mpmath_result_integer, tol) + assert all_close(sympy_result_complex, mpmath_result_complex, tol) diff --git a/lib/python3.10/site-packages/sympy/utilities/tests/test_matchpy_connector.py b/lib/python3.10/site-packages/sympy/utilities/tests/test_matchpy_connector.py new file mode 100644 index 0000000000000000000000000000000000000000..3648bd49f9e56ca20fbf428ed46c01429dbe8b15 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/utilities/tests/test_matchpy_connector.py @@ -0,0 +1,164 @@ +import pickle + +from sympy.core.relational import (Eq, Ne) +from sympy.core.singleton import S +from sympy.core.symbol import symbols +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import (cos, sin) +from sympy.external import import_module +from sympy.testing.pytest import skip +from sympy.utilities.matchpy_connector import WildDot, WildPlus, WildStar, Replacer + +matchpy = import_module("matchpy") + +x, y, z = symbols("x y z") + + +def _get_first_match(expr, pattern): + from matchpy import ManyToOneMatcher, Pattern + + matcher = ManyToOneMatcher() + matcher.add(Pattern(pattern)) + return next(iter(matcher.match(expr))) + + +def test_matchpy_connector(): + if matchpy is None: + skip("matchpy not installed") + + from multiset import Multiset + from matchpy import Pattern, Substitution + + w_ = WildDot("w_") + w__ = WildPlus("w__") + w___ = WildStar("w___") + + expr = x + y + pattern = x + w_ + p, subst = _get_first_match(expr, pattern) + assert p == Pattern(pattern) + assert subst == Substitution({'w_': y}) + + expr = x + y + z + pattern = x + w__ + p, subst = _get_first_match(expr, pattern) + assert p == Pattern(pattern) + assert subst == Substitution({'w__': Multiset([y, z])}) + + expr = x + y + z + pattern = x + y + z + w___ + p, subst = _get_first_match(expr, pattern) + assert p == Pattern(pattern) + assert subst == Substitution({'w___': Multiset()}) + + +def test_matchpy_optional(): + if matchpy is None: + skip("matchpy not installed") + + from matchpy import Pattern, Substitution + from matchpy import ManyToOneReplacer, ReplacementRule + + p = WildDot("p", optional=1) + q = WildDot("q", optional=0) + + pattern = p*x + q + + expr1 = 2*x + pa, subst = _get_first_match(expr1, pattern) + assert pa == Pattern(pattern) + assert subst == Substitution({'p': 2, 'q': 0}) + + expr2 = x + 3 + pa, subst = _get_first_match(expr2, pattern) + assert pa == Pattern(pattern) + assert subst == Substitution({'p': 1, 'q': 3}) + + expr3 = x + pa, subst = _get_first_match(expr3, pattern) + assert pa == Pattern(pattern) + assert subst == Substitution({'p': 1, 'q': 0}) + + expr4 = x*y + z + pa, subst = _get_first_match(expr4, pattern) + assert pa == Pattern(pattern) + assert subst == Substitution({'p': y, 'q': z}) + + replacer = ManyToOneReplacer() + replacer.add(ReplacementRule(Pattern(pattern), lambda p, q: sin(p)*cos(q))) + assert replacer.replace(expr1) == sin(2)*cos(0) + assert replacer.replace(expr2) == sin(1)*cos(3) + assert replacer.replace(expr3) == sin(1)*cos(0) + assert replacer.replace(expr4) == sin(y)*cos(z) + + +def test_replacer(): + if matchpy is None: + skip("matchpy not installed") + + for info in [True, False]: + for lambdify in [True, False]: + _perform_test_replacer(info, lambdify) + + +def _perform_test_replacer(info, lambdify): + + x1_ = WildDot("x1_") + x2_ = WildDot("x2_") + + a_ = WildDot("a_", optional=S.One) + b_ = WildDot("b_", optional=S.One) + c_ = WildDot("c_", optional=S.Zero) + + replacer = Replacer(common_constraints=[ + matchpy.CustomConstraint(lambda a_: not a_.has(x)), + matchpy.CustomConstraint(lambda b_: not b_.has(x)), + matchpy.CustomConstraint(lambda c_: not c_.has(x)), + ], lambdify=lambdify, info=info) + + # Rewrite the equation into implicit form, unless it's already solved: + replacer.add(Eq(x1_, x2_), Eq(x1_ - x2_, 0), conditions_nonfalse=[Ne(x2_, 0), Ne(x1_, 0), Ne(x1_, x), Ne(x2_, x)], info=1) + + # Simple equation solver for real numbers: + replacer.add(Eq(a_*x + b_, 0), Eq(x, -b_/a_), info=2) + disc = b_**2 - 4*a_*c_ + replacer.add( + Eq(a_*x**2 + b_*x + c_, 0), + Eq(x, (-b_ - sqrt(disc))/(2*a_)) | Eq(x, (-b_ + sqrt(disc))/(2*a_)), + conditions_nonfalse=[disc >= 0], + info=3 + ) + replacer.add( + Eq(a_*x**2 + c_, 0), + Eq(x, sqrt(-c_/a_)) | Eq(x, -sqrt(-c_/a_)), + conditions_nonfalse=[-c_*a_ > 0], + info=4 + ) + + g = lambda expr, infos: (expr, infos) if info else expr + + assert replacer.replace(Eq(3*x, y)) == g(Eq(x, y/3), [1, 2]) + assert replacer.replace(Eq(x**2 + 1, 0)) == g(Eq(x**2 + 1, 0), []) + assert replacer.replace(Eq(x**2, 4)) == g((Eq(x, 2) | Eq(x, -2)), [1, 4]) + assert replacer.replace(Eq(x**2 + 4*y*x + 4*y**2, 0)) == g(Eq(x, -2*y), [3]) + + +def test_matchpy_object_pickle(): + if matchpy is None: + return + + a1 = WildDot("a") + a2 = pickle.loads(pickle.dumps(a1)) + assert a1 == a2 + + a1 = WildDot("a", S(1)) + a2 = pickle.loads(pickle.dumps(a1)) + assert a1 == a2 + + a1 = WildPlus("a", S(1)) + a2 = pickle.loads(pickle.dumps(a1)) + assert a1 == a2 + + a1 = WildStar("a", S(1)) + a2 = pickle.loads(pickle.dumps(a1)) + assert a1 == a2 diff --git a/lib/python3.10/site-packages/sympy/utilities/tests/test_mathml.py b/lib/python3.10/site-packages/sympy/utilities/tests/test_mathml.py new file mode 100644 index 0000000000000000000000000000000000000000..e4a7598f175be34f8bb34c0bd9c003d1c0238c7b --- /dev/null +++ b/lib/python3.10/site-packages/sympy/utilities/tests/test_mathml.py @@ -0,0 +1,33 @@ +import os +from textwrap import dedent +from sympy.external import import_module +from sympy.testing.pytest import skip +from sympy.utilities.mathml import apply_xsl + + + +lxml = import_module('lxml') + +path = os.path.abspath(os.path.join(os.path.dirname(__file__), "test_xxe.py")) + + +def test_xxe(): + assert os.path.isfile(path) + if not lxml: + skip("lxml not installed.") + + mml = dedent( + rf""" + + ]> + + John + &ent; + + """ + ) + xsl = 'mathml/data/simple_mmlctop.xsl' + + res = apply_xsl(mml, xsl) + assert res == \ + '\n\nJohn\n\n\n' diff --git a/lib/python3.10/site-packages/sympy/utilities/tests/test_misc.py b/lib/python3.10/site-packages/sympy/utilities/tests/test_misc.py new file mode 100644 index 0000000000000000000000000000000000000000..f9f61ee6c84def5388ba9cd206851f36950aa2c5 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/utilities/tests/test_misc.py @@ -0,0 +1,151 @@ +from textwrap import dedent +import sys +from subprocess import Popen, PIPE +import os + +from sympy.core.singleton import S +from sympy.testing.pytest import (raises, warns_deprecated_sympy, + skip_under_pyodide) +from sympy.utilities.misc import (translate, replace, ordinal, rawlines, + strlines, as_int, find_executable) +from sympy.external import import_module + +pyodide_js = import_module('pyodide_js') + + +def test_translate(): + abc = 'abc' + assert translate(abc, None, 'a') == 'bc' + assert translate(abc, None, '') == 'abc' + assert translate(abc, {'a': 'x'}, 'c') == 'xb' + assert translate(abc, {'a': 'bc'}, 'c') == 'bcb' + assert translate(abc, {'ab': 'x'}, 'c') == 'x' + assert translate(abc, {'ab': ''}, 'c') == '' + assert translate(abc, {'bc': 'x'}, 'c') == 'ab' + assert translate(abc, {'abc': 'x', 'a': 'y'}) == 'x' + u = chr(4096) + assert translate(abc, 'a', 'x', u) == 'xbc' + assert (u in translate(abc, 'a', u, u)) is True + + +def test_replace(): + assert replace('abc', ('a', 'b')) == 'bbc' + assert replace('abc', {'a': 'Aa'}) == 'Aabc' + assert replace('abc', ('a', 'b'), ('c', 'C')) == 'bbC' + + +def test_ordinal(): + assert ordinal(-1) == '-1st' + assert ordinal(0) == '0th' + assert ordinal(1) == '1st' + assert ordinal(2) == '2nd' + assert ordinal(3) == '3rd' + assert all(ordinal(i).endswith('th') for i in range(4, 21)) + assert ordinal(100) == '100th' + assert ordinal(101) == '101st' + assert ordinal(102) == '102nd' + assert ordinal(103) == '103rd' + assert ordinal(104) == '104th' + assert ordinal(200) == '200th' + assert all(ordinal(i) == str(i) + 'th' for i in range(-220, -203)) + + +def test_rawlines(): + assert rawlines('a a\na') == "dedent('''\\\n a a\n a''')" + assert rawlines('a a') == "'a a'" + assert rawlines(strlines('\\le"ft')) == ( + '(\n' + " '(\\n'\n" + ' \'r\\\'\\\\le"ft\\\'\\n\'\n' + " ')'\n" + ')') + + +def test_strlines(): + q = 'this quote (") is in the middle' + # the following assert rhs was prepared with + # print(rawlines(strlines(q, 10))) + assert strlines(q, 10) == dedent('''\ + ( + 'this quo' + 'te (") i' + 's in the' + ' middle' + )''') + assert q == ( + 'this quo' + 'te (") i' + 's in the' + ' middle' + ) + q = "this quote (') is in the middle" + assert strlines(q, 20) == dedent('''\ + ( + "this quote (') is " + "in the middle" + )''') + assert strlines('\\left') == ( + '(\n' + "r'\\left'\n" + ')') + assert strlines('\\left', short=True) == r"r'\left'" + assert strlines('\\le"ft') == ( + '(\n' + 'r\'\\le"ft\'\n' + ')') + q = 'this\nother line' + assert strlines(q) == rawlines(q) + + +def test_translate_args(): + try: + translate(None, None, None, 'not_none') + except ValueError: + pass # Exception raised successfully + else: + assert False + + assert translate('s', None, None, None) == 's' + + try: + translate('s', 'a', 'bc') + except ValueError: + pass # Exception raised successfully + else: + assert False + + +@skip_under_pyodide("Cannot create subprocess under pyodide.") +def test_debug_output(): + env = os.environ.copy() + env['SYMPY_DEBUG'] = 'True' + cmd = 'from sympy import *; x = Symbol("x"); print(integrate((1-cos(x))/x, x))' + cmdline = [sys.executable, '-c', cmd] + proc = Popen(cmdline, env=env, stdout=PIPE, stderr=PIPE) + out, err = proc.communicate() + out = out.decode('ascii') # utf-8? + err = err.decode('ascii') + expected = 'substituted: -x*(1 - cos(x)), u: 1/x, u_var: _u' + assert expected in err, err + + +def test_as_int(): + raises(ValueError, lambda : as_int(True)) + raises(ValueError, lambda : as_int(1.1)) + raises(ValueError, lambda : as_int([])) + raises(ValueError, lambda : as_int(S.NaN)) + raises(ValueError, lambda : as_int(S.Infinity)) + raises(ValueError, lambda : as_int(S.NegativeInfinity)) + raises(ValueError, lambda : as_int(S.ComplexInfinity)) + # for the following, limited precision makes int(arg) == arg + # but the int value is not necessarily what a user might have + # expected; Q.prime is more nuanced in its response for + # expressions which might be complex representations of an + # integer. This is not -- by design -- as_ints role. + raises(ValueError, lambda : as_int(1e23)) + raises(ValueError, lambda : as_int(S('1.'+'0'*20+'1'))) + assert as_int(True, strict=False) == 1 + +def test_deprecated_find_executable(): + with warns_deprecated_sympy(): + find_executable('python') diff --git a/lib/python3.10/site-packages/sympy/utilities/tests/test_pickling.py b/lib/python3.10/site-packages/sympy/utilities/tests/test_pickling.py new file mode 100644 index 0000000000000000000000000000000000000000..fdb61428139b36c23826d41d6fc7d3a81b5dafc6 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/utilities/tests/test_pickling.py @@ -0,0 +1,717 @@ +import inspect +import copy +import pickle + +from sympy.physics.units import meter + +from sympy.testing.pytest import XFAIL, raises, ignore_warnings + +from sympy.core.basic import Atom, Basic +from sympy.core.singleton import SingletonRegistry +from sympy.core.symbol import Str, Dummy, Symbol, Wild +from sympy.core.numbers import (E, I, pi, oo, zoo, nan, Integer, + Rational, Float, AlgebraicNumber) +from sympy.core.relational import (Equality, GreaterThan, LessThan, Relational, + StrictGreaterThan, StrictLessThan, Unequality) +from sympy.core.add import Add +from sympy.core.mul import Mul +from sympy.core.power import Pow +from sympy.core.function import Derivative, Function, FunctionClass, Lambda, \ + WildFunction +from sympy.sets.sets import Interval +from sympy.core.multidimensional import vectorize + +from sympy.external.gmpy import gmpy as _gmpy +from sympy.utilities.exceptions import SymPyDeprecationWarning + +from sympy.core.singleton import S +from sympy.core.symbol import symbols + +from sympy.external import import_module +cloudpickle = import_module('cloudpickle') + + +not_equal_attrs = { + '_assumptions', # This is a local cache that isn't automatically filled on creation + '_mhash', # Cached after __hash__ is called but set to None after creation +} + + +deprecated_attrs = { + 'is_EmptySet', # Deprecated from SymPy 1.5. This can be removed when is_EmptySet is removed. + 'expr_free_symbols', # Deprecated from SymPy 1.9. This can be removed when exr_free_symbols is removed. +} + + +def check(a, exclude=[], check_attr=True, deprecated=()): + """ Check that pickling and copying round-trips. + """ + # Pickling with protocols 0 and 1 is disabled for Basic instances: + if isinstance(a, Basic): + for protocol in [0, 1]: + raises(NotImplementedError, lambda: pickle.dumps(a, protocol)) + + protocols = [2, copy.copy, copy.deepcopy, 3, 4] + if cloudpickle: + protocols.extend([cloudpickle]) + + for protocol in protocols: + if protocol in exclude: + continue + + if callable(protocol): + if isinstance(a, type): + # Classes can't be copied, but that's okay. + continue + b = protocol(a) + elif inspect.ismodule(protocol): + b = protocol.loads(protocol.dumps(a)) + else: + b = pickle.loads(pickle.dumps(a, protocol)) + + d1 = dir(a) + d2 = dir(b) + assert set(d1) == set(d2) + + if not check_attr: + continue + + def c(a, b, d): + for i in d: + if i in not_equal_attrs: + if hasattr(a, i): + assert hasattr(b, i), i + elif i in deprecated_attrs or i in deprecated: + with ignore_warnings(SymPyDeprecationWarning): + assert getattr(a, i) == getattr(b, i), i + elif not hasattr(a, i): + continue + else: + attr = getattr(a, i) + if not hasattr(attr, "__call__"): + assert hasattr(b, i), i + assert getattr(b, i) == attr, "%s != %s, protocol: %s" % (getattr(b, i), attr, protocol) + + c(a, b, d1) + c(b, a, d2) + + + +#================== core ========================= + + +def test_core_basic(): + for c in (Atom, Atom(), Basic, Basic(), SingletonRegistry, S): + check(c) + +def test_core_Str(): + check(Str('x')) + +def test_core_symbol(): + # make the Symbol a unique name that doesn't class with any other + # testing variable in this file since after this test the symbol + # having the same name will be cached as noncommutative + for c in (Dummy, Dummy("x", commutative=False), Symbol, + Symbol("_issue_3130", commutative=False), Wild, Wild("x")): + check(c) + + +def test_core_numbers(): + for c in (Integer(2), Rational(2, 3), Float("1.2")): + check(c) + for c in (AlgebraicNumber, AlgebraicNumber(sqrt(3))): + check(c, check_attr=False) + + +def test_core_float_copy(): + # See gh-7457 + y = Symbol("x") + 1.0 + check(y) # does not raise TypeError ("argument is not an mpz") + + +def test_core_relational(): + x = Symbol("x") + y = Symbol("y") + for c in (Equality, Equality(x, y), GreaterThan, GreaterThan(x, y), + LessThan, LessThan(x, y), Relational, Relational(x, y), + StrictGreaterThan, StrictGreaterThan(x, y), StrictLessThan, + StrictLessThan(x, y), Unequality, Unequality(x, y)): + check(c) + + +def test_core_add(): + x = Symbol("x") + for c in (Add, Add(x, 4)): + check(c) + + +def test_core_mul(): + x = Symbol("x") + for c in (Mul, Mul(x, 4)): + check(c) + + +def test_core_power(): + x = Symbol("x") + for c in (Pow, Pow(x, 4)): + check(c) + + +def test_core_function(): + x = Symbol("x") + for f in (Derivative, Derivative(x), Function, FunctionClass, Lambda, + WildFunction): + check(f) + + +def test_core_undefinedfunctions(): + f = Function("f") + # Full XFAILed test below + exclude = list(range(5)) + # https://github.com/cloudpipe/cloudpickle/issues/65 + # https://github.com/cloudpipe/cloudpickle/issues/190 + exclude.append(cloudpickle) + check(f, exclude=exclude) + +@XFAIL +def test_core_undefinedfunctions_fail(): + # This fails because f is assumed to be a class at sympy.basic.function.f + f = Function("f") + check(f) + + +def test_core_interval(): + for c in (Interval, Interval(0, 2)): + check(c) + + +def test_core_multidimensional(): + for c in (vectorize, vectorize(0)): + check(c) + + +def test_Singletons(): + protocols = [0, 1, 2, 3, 4] + copiers = [copy.copy, copy.deepcopy] + copiers += [lambda x: pickle.loads(pickle.dumps(x, proto)) + for proto in protocols] + if cloudpickle: + copiers += [lambda x: cloudpickle.loads(cloudpickle.dumps(x))] + + for obj in (Integer(-1), Integer(0), Integer(1), Rational(1, 2), pi, E, I, + oo, -oo, zoo, nan, S.GoldenRatio, S.TribonacciConstant, + S.EulerGamma, S.Catalan, S.EmptySet, S.IdentityFunction): + for func in copiers: + assert func(obj) is obj + + +#================== functions =================== +from sympy.functions import (Piecewise, lowergamma, acosh, chebyshevu, + chebyshevt, ln, chebyshevt_root, legendre, Heaviside, bernoulli, coth, + tanh, assoc_legendre, sign, arg, asin, DiracDelta, re, rf, Abs, + uppergamma, binomial, sinh, cos, cot, acos, acot, gamma, bell, + hermite, harmonic, LambertW, zeta, log, factorial, asinh, acoth, cosh, + dirichlet_eta, Eijk, loggamma, erf, ceiling, im, fibonacci, + tribonacci, conjugate, tan, chebyshevu_root, floor, atanh, sqrt, sin, + atan, ff, lucas, atan2, polygamma, exp) + + +def test_functions(): + one_var = (acosh, ln, Heaviside, factorial, bernoulli, coth, tanh, + sign, arg, asin, DiracDelta, re, Abs, sinh, cos, cot, acos, acot, + gamma, bell, harmonic, LambertW, zeta, log, factorial, asinh, + acoth, cosh, dirichlet_eta, loggamma, erf, ceiling, im, fibonacci, + tribonacci, conjugate, tan, floor, atanh, sin, atan, lucas, exp) + two_var = (rf, ff, lowergamma, chebyshevu, chebyshevt, binomial, + atan2, polygamma, hermite, legendre, uppergamma) + x, y, z = symbols("x,y,z") + others = (chebyshevt_root, chebyshevu_root, Eijk(x, y, z), + Piecewise( (0, x < -1), (x**2, x <= 1), (x**3, True)), + assoc_legendre) + for cls in one_var: + check(cls) + c = cls(x) + check(c) + for cls in two_var: + check(cls) + c = cls(x, y) + check(c) + for cls in others: + check(cls) + +#================== geometry ==================== +from sympy.geometry.entity import GeometryEntity +from sympy.geometry.point import Point +from sympy.geometry.ellipse import Circle, Ellipse +from sympy.geometry.line import Line, LinearEntity, Ray, Segment +from sympy.geometry.polygon import Polygon, RegularPolygon, Triangle + + +def test_geometry(): + p1 = Point(1, 2) + p2 = Point(2, 3) + p3 = Point(0, 0) + p4 = Point(0, 1) + for c in ( + GeometryEntity, GeometryEntity(), Point, p1, Circle, Circle(p1, 2), + Ellipse, Ellipse(p1, 3, 4), Line, Line(p1, p2), LinearEntity, + LinearEntity(p1, p2), Ray, Ray(p1, p2), Segment, Segment(p1, p2), + Polygon, Polygon(p1, p2, p3, p4), RegularPolygon, + RegularPolygon(p1, 4, 5), Triangle, Triangle(p1, p2, p3)): + check(c, check_attr=False) + +#================== integrals ==================== +from sympy.integrals.integrals import Integral + + +def test_integrals(): + x = Symbol("x") + for c in (Integral, Integral(x)): + check(c) + +#==================== logic ===================== +from sympy.core.logic import Logic + + +def test_logic(): + for c in (Logic, Logic(1)): + check(c) + +#================== matrices ==================== +from sympy.matrices import Matrix, SparseMatrix + + +def test_matrices(): + for c in (Matrix, Matrix([1, 2, 3]), SparseMatrix, SparseMatrix([[1, 2], [3, 4]])): + check(c, deprecated=['_smat', '_mat']) + +#================== ntheory ===================== +from sympy.ntheory.generate import Sieve + + +def test_ntheory(): + for c in (Sieve, Sieve()): + check(c) + +#================== physics ===================== +from sympy.physics.paulialgebra import Pauli +from sympy.physics.units import Unit + + +def test_physics(): + for c in (Unit, meter, Pauli, Pauli(1)): + check(c) + +#================== plotting ==================== +# XXX: These tests are not complete, so XFAIL them + + +@XFAIL +def test_plotting(): + from sympy.plotting.pygletplot.color_scheme import ColorGradient, ColorScheme + from sympy.plotting.pygletplot.managed_window import ManagedWindow + from sympy.plotting.plot import Plot, ScreenShot + from sympy.plotting.pygletplot.plot_axes import PlotAxes, PlotAxesBase, PlotAxesFrame, PlotAxesOrdinate + from sympy.plotting.pygletplot.plot_camera import PlotCamera + from sympy.plotting.pygletplot.plot_controller import PlotController + from sympy.plotting.pygletplot.plot_curve import PlotCurve + from sympy.plotting.pygletplot.plot_interval import PlotInterval + from sympy.plotting.pygletplot.plot_mode import PlotMode + from sympy.plotting.pygletplot.plot_modes import Cartesian2D, Cartesian3D, Cylindrical, \ + ParametricCurve2D, ParametricCurve3D, ParametricSurface, Polar, Spherical + from sympy.plotting.pygletplot.plot_object import PlotObject + from sympy.plotting.pygletplot.plot_surface import PlotSurface + from sympy.plotting.pygletplot.plot_window import PlotWindow + for c in ( + ColorGradient, ColorGradient(0.2, 0.4), ColorScheme, ManagedWindow, + ManagedWindow, Plot, ScreenShot, PlotAxes, PlotAxesBase, + PlotAxesFrame, PlotAxesOrdinate, PlotCamera, PlotController, + PlotCurve, PlotInterval, PlotMode, Cartesian2D, Cartesian3D, + Cylindrical, ParametricCurve2D, ParametricCurve3D, + ParametricSurface, Polar, Spherical, PlotObject, PlotSurface, + PlotWindow): + check(c) + + +@XFAIL +def test_plotting2(): + #from sympy.plotting.color_scheme import ColorGradient + from sympy.plotting.pygletplot.color_scheme import ColorScheme + #from sympy.plotting.managed_window import ManagedWindow + from sympy.plotting.plot import Plot + #from sympy.plotting.plot import ScreenShot + from sympy.plotting.pygletplot.plot_axes import PlotAxes + #from sympy.plotting.plot_axes import PlotAxesBase, PlotAxesFrame, PlotAxesOrdinate + #from sympy.plotting.plot_camera import PlotCamera + #from sympy.plotting.plot_controller import PlotController + #from sympy.plotting.plot_curve import PlotCurve + #from sympy.plotting.plot_interval import PlotInterval + #from sympy.plotting.plot_mode import PlotMode + #from sympy.plotting.plot_modes import Cartesian2D, Cartesian3D, Cylindrical, \ + # ParametricCurve2D, ParametricCurve3D, ParametricSurface, Polar, Spherical + #from sympy.plotting.plot_object import PlotObject + #from sympy.plotting.plot_surface import PlotSurface + # from sympy.plotting.plot_window import PlotWindow + check(ColorScheme("rainbow")) + check(Plot(1, visible=False)) + check(PlotAxes()) + +#================== polys ======================= +from sympy.polys.domains.integerring import ZZ +from sympy.polys.domains.rationalfield import QQ +from sympy.polys.orderings import lex +from sympy.polys.polytools import Poly + +def test_pickling_polys_polytools(): + from sympy.polys.polytools import PurePoly + # from sympy.polys.polytools import GroebnerBasis + x = Symbol('x') + + for c in (Poly, Poly(x, x)): + check(c) + + for c in (PurePoly, PurePoly(x)): + check(c) + + # TODO: fix pickling of Options class (see GroebnerBasis._options) + # for c in (GroebnerBasis, GroebnerBasis([x**2 - 1], x, order=lex)): + # check(c) + +def test_pickling_polys_polyclasses(): + from sympy.polys.polyclasses import DMP, DMF, ANP + + for c in (DMP, DMP([[ZZ(1)], [ZZ(2)], [ZZ(3)]], ZZ)): + check(c, deprecated=['rep']) + for c in (DMF, DMF(([ZZ(1), ZZ(2)], [ZZ(1), ZZ(3)]), ZZ)): + check(c) + for c in (ANP, ANP([QQ(1), QQ(2)], [QQ(1), QQ(2), QQ(3)], QQ)): + check(c) + +@XFAIL +def test_pickling_polys_rings(): + # NOTE: can't use protocols < 2 because we have to execute __new__ to + # make sure caching of rings works properly. + + from sympy.polys.rings import PolyRing + + ring = PolyRing("x,y,z", ZZ, lex) + + for c in (PolyRing, ring): + check(c, exclude=[0, 1]) + + for c in (ring.dtype, ring.one): + check(c, exclude=[0, 1], check_attr=False) # TODO: Py3k + +def test_pickling_polys_fields(): + pass + # NOTE: can't use protocols < 2 because we have to execute __new__ to + # make sure caching of fields works properly. + + # from sympy.polys.fields import FracField + + # field = FracField("x,y,z", ZZ, lex) + + # TODO: AssertionError: assert id(obj) not in self.memo + # for c in (FracField, field): + # check(c, exclude=[0, 1]) + + # TODO: AssertionError: assert id(obj) not in self.memo + # for c in (field.dtype, field.one): + # check(c, exclude=[0, 1]) + +def test_pickling_polys_elements(): + from sympy.polys.domains.pythonrational import PythonRational + #from sympy.polys.domains.pythonfinitefield import PythonFiniteField + #from sympy.polys.domains.mpelements import MPContext + + for c in (PythonRational, PythonRational(1, 7)): + check(c) + + #gf = PythonFiniteField(17) + + # TODO: fix pickling of ModularInteger + # for c in (gf.dtype, gf(5)): + # check(c) + + #mp = MPContext() + + # TODO: fix pickling of RealElement + # for c in (mp.mpf, mp.mpf(1.0)): + # check(c) + + # TODO: fix pickling of ComplexElement + # for c in (mp.mpc, mp.mpc(1.0, -1.5)): + # check(c) + +def test_pickling_polys_domains(): + # from sympy.polys.domains.pythonfinitefield import PythonFiniteField + from sympy.polys.domains.pythonintegerring import PythonIntegerRing + from sympy.polys.domains.pythonrationalfield import PythonRationalField + + # TODO: fix pickling of ModularInteger + # for c in (PythonFiniteField, PythonFiniteField(17)): + # check(c) + + for c in (PythonIntegerRing, PythonIntegerRing()): + check(c, check_attr=False) + + for c in (PythonRationalField, PythonRationalField()): + check(c, check_attr=False) + + if _gmpy is not None: + # from sympy.polys.domains.gmpyfinitefield import GMPYFiniteField + from sympy.polys.domains.gmpyintegerring import GMPYIntegerRing + from sympy.polys.domains.gmpyrationalfield import GMPYRationalField + + # TODO: fix pickling of ModularInteger + # for c in (GMPYFiniteField, GMPYFiniteField(17)): + # check(c) + + for c in (GMPYIntegerRing, GMPYIntegerRing()): + check(c, check_attr=False) + + for c in (GMPYRationalField, GMPYRationalField()): + check(c, check_attr=False) + + #from sympy.polys.domains.realfield import RealField + #from sympy.polys.domains.complexfield import ComplexField + from sympy.polys.domains.algebraicfield import AlgebraicField + #from sympy.polys.domains.polynomialring import PolynomialRing + #from sympy.polys.domains.fractionfield import FractionField + from sympy.polys.domains.expressiondomain import ExpressionDomain + + # TODO: fix pickling of RealElement + # for c in (RealField, RealField(100)): + # check(c) + + # TODO: fix pickling of ComplexElement + # for c in (ComplexField, ComplexField(100)): + # check(c) + + for c in (AlgebraicField, AlgebraicField(QQ, sqrt(3))): + check(c, check_attr=False) + + # TODO: AssertionError + # for c in (PolynomialRing, PolynomialRing(ZZ, "x,y,z")): + # check(c) + + # TODO: AttributeError: 'PolyElement' object has no attribute 'ring' + # for c in (FractionField, FractionField(ZZ, "x,y,z")): + # check(c) + + for c in (ExpressionDomain, ExpressionDomain()): + check(c, check_attr=False) + + +def test_pickling_polys_orderings(): + from sympy.polys.orderings import (LexOrder, GradedLexOrder, + ReversedGradedLexOrder, InverseOrder) + # from sympy.polys.orderings import ProductOrder + + for c in (LexOrder, LexOrder()): + check(c) + + for c in (GradedLexOrder, GradedLexOrder()): + check(c) + + for c in (ReversedGradedLexOrder, ReversedGradedLexOrder()): + check(c) + + # TODO: Argh, Python is so naive. No lambdas nor inner function support in + # pickling module. Maybe someone could figure out what to do with this. + # + # for c in (ProductOrder, ProductOrder((LexOrder(), lambda m: m[:2]), + # (GradedLexOrder(), lambda m: m[2:]))): + # check(c) + + for c in (InverseOrder, InverseOrder(LexOrder())): + check(c) + +def test_pickling_polys_monomials(): + from sympy.polys.monomials import MonomialOps, Monomial + x, y, z = symbols("x,y,z") + + for c in (MonomialOps, MonomialOps(3)): + check(c) + + for c in (Monomial, Monomial((1, 2, 3), (x, y, z))): + check(c) + +def test_pickling_polys_errors(): + from sympy.polys.polyerrors import (HeuristicGCDFailed, + HomomorphismFailed, IsomorphismFailed, ExtraneousFactors, + EvaluationFailed, RefinementFailed, CoercionFailed, NotInvertible, + NotReversible, NotAlgebraic, DomainError, PolynomialError, + UnificationFailed, GeneratorsError, GeneratorsNeeded, + UnivariatePolynomialError, MultivariatePolynomialError, OptionError, + FlagError) + # from sympy.polys.polyerrors import (ExactQuotientFailed, + # OperationNotSupported, ComputationFailed, PolificationFailed) + + # x = Symbol('x') + + # TODO: TypeError: __init__() takes at least 3 arguments (1 given) + # for c in (ExactQuotientFailed, ExactQuotientFailed(x, 3*x, ZZ)): + # check(c) + + # TODO: TypeError: can't pickle instancemethod objects + # for c in (OperationNotSupported, OperationNotSupported(Poly(x), Poly.gcd)): + # check(c) + + for c in (HeuristicGCDFailed, HeuristicGCDFailed()): + check(c) + + for c in (HomomorphismFailed, HomomorphismFailed()): + check(c) + + for c in (IsomorphismFailed, IsomorphismFailed()): + check(c) + + for c in (ExtraneousFactors, ExtraneousFactors()): + check(c) + + for c in (EvaluationFailed, EvaluationFailed()): + check(c) + + for c in (RefinementFailed, RefinementFailed()): + check(c) + + for c in (CoercionFailed, CoercionFailed()): + check(c) + + for c in (NotInvertible, NotInvertible()): + check(c) + + for c in (NotReversible, NotReversible()): + check(c) + + for c in (NotAlgebraic, NotAlgebraic()): + check(c) + + for c in (DomainError, DomainError()): + check(c) + + for c in (PolynomialError, PolynomialError()): + check(c) + + for c in (UnificationFailed, UnificationFailed()): + check(c) + + for c in (GeneratorsError, GeneratorsError()): + check(c) + + for c in (GeneratorsNeeded, GeneratorsNeeded()): + check(c) + + # TODO: PicklingError: Can't pickle at 0x38578c0>: it's not found as __main__. + # for c in (ComputationFailed, ComputationFailed(lambda t: t, 3, None)): + # check(c) + + for c in (UnivariatePolynomialError, UnivariatePolynomialError()): + check(c) + + for c in (MultivariatePolynomialError, MultivariatePolynomialError()): + check(c) + + # TODO: TypeError: __init__() takes at least 3 arguments (1 given) + # for c in (PolificationFailed, PolificationFailed({}, x, x, False)): + # check(c) + + for c in (OptionError, OptionError()): + check(c) + + for c in (FlagError, FlagError()): + check(c) + +#def test_pickling_polys_options(): + #from sympy.polys.polyoptions import Options + + # TODO: fix pickling of `symbols' flag + # for c in (Options, Options((), dict(domain='ZZ', polys=False))): + # check(c) + +# TODO: def test_pickling_polys_rootisolation(): +# RealInterval +# ComplexInterval + +def test_pickling_polys_rootoftools(): + from sympy.polys.rootoftools import CRootOf, RootSum + + x = Symbol('x') + f = x**3 + x + 3 + + for c in (CRootOf, CRootOf(f, 0)): + check(c) + + for c in (RootSum, RootSum(f, exp)): + check(c) + +#================== printing ==================== +from sympy.printing.latex import LatexPrinter +from sympy.printing.mathml import MathMLContentPrinter, MathMLPresentationPrinter +from sympy.printing.pretty.pretty import PrettyPrinter +from sympy.printing.pretty.stringpict import prettyForm, stringPict +from sympy.printing.printer import Printer +from sympy.printing.python import PythonPrinter + + +def test_printing(): + for c in (LatexPrinter, LatexPrinter(), MathMLContentPrinter, + MathMLPresentationPrinter, PrettyPrinter, prettyForm, stringPict, + stringPict("a"), Printer, Printer(), PythonPrinter, + PythonPrinter()): + check(c) + + +@XFAIL +def test_printing1(): + check(MathMLContentPrinter()) + + +@XFAIL +def test_printing2(): + check(MathMLPresentationPrinter()) + + +@XFAIL +def test_printing3(): + check(PrettyPrinter()) + +#================== series ====================== +from sympy.series.limits import Limit +from sympy.series.order import Order + + +def test_series(): + e = Symbol("e") + x = Symbol("x") + for c in (Limit, Limit(e, x, 1), Order, Order(e)): + check(c) + +#================== concrete ================== +from sympy.concrete.products import Product +from sympy.concrete.summations import Sum + + +def test_concrete(): + x = Symbol("x") + for c in (Product, Product(x, (x, 2, 4)), Sum, Sum(x, (x, 2, 4))): + check(c) + +def test_deprecation_warning(): + w = SymPyDeprecationWarning("message", deprecated_since_version='1.0', active_deprecations_target="active-deprecations") + check(w) + +def test_issue_18438(): + assert pickle.loads(pickle.dumps(S.Half)) == S.Half + + +#================= old pickles ================= +def test_unpickle_from_older_versions(): + data = ( + b'\x80\x04\x95^\x00\x00\x00\x00\x00\x00\x00\x8c\x10sympy.core.power' + b'\x94\x8c\x03Pow\x94\x93\x94\x8c\x12sympy.core.numbers\x94\x8c' + b'\x07Integer\x94\x93\x94K\x02\x85\x94R\x94}\x94bh\x03\x8c\x04Half' + b'\x94\x93\x94)R\x94}\x94b\x86\x94R\x94}\x94b.' + ) + assert pickle.loads(data) == sqrt(2) diff --git a/lib/python3.10/site-packages/sympy/utilities/tests/test_source.py b/lib/python3.10/site-packages/sympy/utilities/tests/test_source.py new file mode 100644 index 0000000000000000000000000000000000000000..468185bc579fc325aee21024dfa15ebf14287b5f --- /dev/null +++ b/lib/python3.10/site-packages/sympy/utilities/tests/test_source.py @@ -0,0 +1,11 @@ +from sympy.utilities.source import get_mod_func, get_class + + +def test_get_mod_func(): + assert get_mod_func( + 'sympy.core.basic.Basic') == ('sympy.core.basic', 'Basic') + + +def test_get_class(): + _basic = get_class('sympy.core.basic.Basic') + assert _basic.__name__ == 'Basic' diff --git a/lib/python3.10/site-packages/sympy/utilities/tests/test_timeutils.py b/lib/python3.10/site-packages/sympy/utilities/tests/test_timeutils.py new file mode 100644 index 0000000000000000000000000000000000000000..14edfd089c7315ee9f39a4298af0289f8919da6b --- /dev/null +++ b/lib/python3.10/site-packages/sympy/utilities/tests/test_timeutils.py @@ -0,0 +1,10 @@ +"""Tests for simple tools for timing functions' execution. """ + +from sympy.utilities.timeutils import timed + +def test_timed(): + result = timed(lambda: 1 + 1, limit=100000) + assert result[0] == 100000 and result[3] == "ns", str(result) + + result = timed("1 + 1", limit=100000) + assert result[0] == 100000 and result[3] == "ns" diff --git a/lib/python3.10/site-packages/sympy/utilities/tests/test_wester.py b/lib/python3.10/site-packages/sympy/utilities/tests/test_wester.py new file mode 100644 index 0000000000000000000000000000000000000000..848dbdae82bcbfa6a1374d91003209a0d6a2ab8e --- /dev/null +++ b/lib/python3.10/site-packages/sympy/utilities/tests/test_wester.py @@ -0,0 +1,3104 @@ +""" Tests from Michael Wester's 1999 paper "Review of CAS mathematical +capabilities". + +http://www.math.unm.edu/~wester/cas/book/Wester.pdf +See also http://math.unm.edu/~wester/cas_review.html for detailed output of +each tested system. +""" + +from sympy.assumptions.ask import Q, ask +from sympy.assumptions.refine import refine +from sympy.concrete.products import product +from sympy.core import EulerGamma +from sympy.core.evalf import N +from sympy.core.function import (Derivative, Function, Lambda, Subs, + diff, expand, expand_func) +from sympy.core.mul import Mul +from sympy.core.intfunc import igcd +from sympy.core.numbers import (AlgebraicNumber, E, I, Rational, + nan, oo, pi, zoo) +from sympy.core.relational import Eq, Lt +from sympy.core.singleton import S +from sympy.core.symbol import Dummy, Symbol, symbols +from sympy.functions.combinatorial.factorials import (rf, binomial, + factorial, factorial2) +from sympy.functions.combinatorial.numbers import bernoulli, fibonacci, totient, partition +from sympy.functions.elementary.complexes import (conjugate, im, re, + sign) +from sympy.functions.elementary.exponential import LambertW, exp, log +from sympy.functions.elementary.hyperbolic import (asinh, cosh, sinh, + tanh) +from sympy.functions.elementary.integers import ceiling, floor +from sympy.functions.elementary.miscellaneous import Max, Min, sqrt +from sympy.functions.elementary.piecewise import Piecewise +from sympy.functions.elementary.trigonometric import (acos, acot, asin, + atan, cos, cot, csc, sec, sin, tan) +from sympy.functions.special.bessel import besselj +from sympy.functions.special.delta_functions import DiracDelta +from sympy.functions.special.elliptic_integrals import (elliptic_e, + elliptic_f) +from sympy.functions.special.gamma_functions import gamma, polygamma +from sympy.functions.special.hyper import hyper +from sympy.functions.special.polynomials import (assoc_legendre, + chebyshevt) +from sympy.functions.special.zeta_functions import polylog +from sympy.geometry.util import idiff +from sympy.logic.boolalg import And +from sympy.matrices.dense import hessian, wronskian +from sympy.matrices.expressions.matmul import MatMul +from sympy.ntheory.continued_fraction import ( + continued_fraction_convergents as cf_c, + continued_fraction_iterator as cf_i, continued_fraction_periodic as + cf_p, continued_fraction_reduce as cf_r) +from sympy.ntheory.factor_ import factorint +from sympy.ntheory.generate import primerange +from sympy.polys.domains.integerring import ZZ +from sympy.polys.orthopolys import legendre_poly +from sympy.polys.partfrac import apart +from sympy.polys.polytools import Poly, factor, gcd, resultant +from sympy.series.limits import limit +from sympy.series.order import O +from sympy.series.residues import residue +from sympy.series.series import series +from sympy.sets.fancysets import ImageSet +from sympy.sets.sets import FiniteSet, Intersection, Interval, Union +from sympy.simplify.combsimp import combsimp +from sympy.simplify.hyperexpand import hyperexpand +from sympy.simplify.powsimp import powdenest, powsimp +from sympy.simplify.radsimp import radsimp +from sympy.simplify.simplify import logcombine, simplify +from sympy.simplify.sqrtdenest import sqrtdenest +from sympy.simplify.trigsimp import trigsimp +from sympy.solvers.solvers import solve + +import mpmath +from sympy.functions.combinatorial.numbers import stirling +from sympy.functions.special.delta_functions import Heaviside +from sympy.functions.special.error_functions import Ci, Si, erf +from sympy.functions.special.zeta_functions import zeta +from sympy.testing.pytest import (XFAIL, slow, SKIP, tooslow, raises) +from sympy.utilities.iterables import partitions +from mpmath import mpi, mpc +from sympy.matrices import Matrix, GramSchmidt, eye +from sympy.matrices.expressions.blockmatrix import BlockMatrix, block_collapse +from sympy.matrices.expressions import MatrixSymbol, ZeroMatrix +from sympy.physics.quantum import Commutator +from sympy.polys.rings import PolyRing +from sympy.polys.fields import FracField +from sympy.polys.solvers import solve_lin_sys +from sympy.concrete import Sum +from sympy.concrete.products import Product +from sympy.integrals import integrate +from sympy.integrals.transforms import laplace_transform,\ + inverse_laplace_transform, LaplaceTransform, fourier_transform,\ + mellin_transform, laplace_correspondence, laplace_initial_conds +from sympy.solvers.recurr import rsolve +from sympy.solvers.solveset import solveset, solveset_real, linsolve +from sympy.solvers.ode import dsolve +from sympy.core.relational import Equality +from itertools import islice, takewhile +from sympy.series.formal import fps +from sympy.series.fourier import fourier_series +from sympy.calculus.util import minimum + + +EmptySet = S.EmptySet +R = Rational +x, y, z = symbols('x y z') +i, j, k, l, m, n = symbols('i j k l m n', integer=True) +f = Function('f') +g = Function('g') + +# A. Boolean Logic and Quantifier Elimination +# Not implemented. + +# B. Set Theory + + +def test_B1(): + assert (FiniteSet(i, j, j, k, k, k) | FiniteSet(l, k, j) | + FiniteSet(j, m, j)) == FiniteSet(i, j, k, l, m) + + +def test_B2(): + assert (FiniteSet(i, j, j, k, k, k) & FiniteSet(l, k, j) & + FiniteSet(j, m, j)) == Intersection({j, m}, {i, j, k}, {j, k, l}) + # Previous output below. Not sure why that should be the expected output. + # There should probably be a way to rewrite Intersections that way but I + # don't see why an Intersection should evaluate like that: + # + # == Union({j}, Intersection({m}, Union({j, k}, Intersection({i}, {l})))) + + +def test_B3(): + assert (FiniteSet(i, j, k, l, m) - FiniteSet(j) == + FiniteSet(i, k, l, m)) + + +def test_B4(): + assert (FiniteSet(*(FiniteSet(i, j)*FiniteSet(k, l))) == + FiniteSet((i, k), (i, l), (j, k), (j, l))) + + +# C. Numbers + + +def test_C1(): + assert (factorial(50) == + 30414093201713378043612608166064768844377641568960512000000000000) + + +def test_C2(): + assert (factorint(factorial(50)) == {2: 47, 3: 22, 5: 12, 7: 8, + 11: 4, 13: 3, 17: 2, 19: 2, 23: 2, 29: 1, 31: 1, 37: 1, + 41: 1, 43: 1, 47: 1}) + + +def test_C3(): + assert (factorial2(10), factorial2(9)) == (3840, 945) + + +# Base conversions; not really implemented by SymPy +# Whatever. Take credit! +def test_C4(): + assert 0xABC == 2748 + + +def test_C5(): + assert 123 == int('234', 7) + + +def test_C6(): + assert int('677', 8) == int('1BF', 16) == 447 + + +def test_C7(): + assert log(32768, 8) == 5 + + +def test_C8(): + # Modular multiplicative inverse. Would be nice if divmod could do this. + assert ZZ.invert(5, 7) == 3 + assert ZZ.invert(5, 6) == 5 + + +def test_C9(): + assert igcd(igcd(1776, 1554), 5698) == 74 + + +def test_C10(): + x = 0 + for n in range(2, 11): + x += R(1, n) + assert x == R(4861, 2520) + + +def test_C11(): + assert R(1, 7) == S('0.[142857]') + + +def test_C12(): + assert R(7, 11) * R(22, 7) == 2 + + +def test_C13(): + test = R(10, 7) * (1 + R(29, 1000)) ** R(1, 3) + good = 3 ** R(1, 3) + assert test == good + + +def test_C14(): + assert sqrtdenest(sqrt(2*sqrt(3) + 4)) == 1 + sqrt(3) + + +def test_C15(): + test = sqrtdenest(sqrt(14 + 3*sqrt(3 + 2*sqrt(5 - 12*sqrt(3 - 2*sqrt(2)))))) + good = sqrt(2) + 3 + assert test == good + + +def test_C16(): + test = sqrtdenest(sqrt(10 + 2*sqrt(6) + 2*sqrt(10) + 2*sqrt(15))) + good = sqrt(2) + sqrt(3) + sqrt(5) + assert test == good + + +def test_C17(): + test = radsimp((sqrt(3) + sqrt(2)) / (sqrt(3) - sqrt(2))) + good = 5 + 2*sqrt(6) + assert test == good + + +def test_C18(): + assert simplify((sqrt(-2 + sqrt(-5)) * sqrt(-2 - sqrt(-5))).expand(complex=True)) == 3 + + +@XFAIL +def test_C19(): + assert radsimp(simplify((90 + 34*sqrt(7)) ** R(1, 3))) == 3 + sqrt(7) + + +def test_C20(): + inside = (135 + 78*sqrt(3)) + test = AlgebraicNumber((inside**R(2, 3) + 3) * sqrt(3) / inside**R(1, 3)) + assert simplify(test) == AlgebraicNumber(12) + + +def test_C21(): + assert simplify(AlgebraicNumber((41 + 29*sqrt(2)) ** R(1, 5))) == \ + AlgebraicNumber(1 + sqrt(2)) + + +@XFAIL +def test_C22(): + test = simplify(((6 - 4*sqrt(2))*log(3 - 2*sqrt(2)) + (3 - 2*sqrt(2))*log(17 + - 12*sqrt(2)) + 32 - 24*sqrt(2)) / (48*sqrt(2) - 72)) + good = sqrt(2)/3 - log(sqrt(2) - 1)/3 + assert test == good + + +def test_C23(): + assert 2 * oo - 3 is oo + + +@XFAIL +def test_C24(): + raise NotImplementedError("2**aleph_null == aleph_1") + +# D. Numerical Analysis + + +def test_D1(): + assert 0.0 / sqrt(2) == 0.0 + + +def test_D2(): + assert str(exp(-1000000).evalf()) == '3.29683147808856e-434295' + + +def test_D3(): + assert exp(pi*sqrt(163)).evalf(50).num.ae(262537412640768744) + + +def test_D4(): + assert floor(R(-5, 3)) == -2 + assert ceiling(R(-5, 3)) == -1 + + +@XFAIL +def test_D5(): + raise NotImplementedError("cubic_spline([1, 2, 4, 5], [1, 4, 2, 3], x)(3) == 27/8") + + +@XFAIL +def test_D6(): + raise NotImplementedError("translate sum(a[i]*x**i, (i,1,n)) to FORTRAN") + + +@XFAIL +def test_D7(): + raise NotImplementedError("translate sum(a[i]*x**i, (i,1,n)) to C") + + +@XFAIL +def test_D8(): + # One way is to cheat by converting the sum to a string, + # and replacing the '[' and ']' with ''. + # E.g., horner(S(str(_).replace('[','').replace(']',''))) + raise NotImplementedError("apply Horner's rule to sum(a[i]*x**i, (i,1,5))") + + +@XFAIL +def test_D9(): + raise NotImplementedError("translate D8 to FORTRAN") + + +@XFAIL +def test_D10(): + raise NotImplementedError("translate D8 to C") + + +@XFAIL +def test_D11(): + #Is there a way to use count_ops? + raise NotImplementedError("flops(sum(product(f[i][k], (i,1,k)), (k,1,n)))") + + +@XFAIL +def test_D12(): + assert (mpi(-4, 2) * x + mpi(1, 3)) ** 2 == mpi(-8, 16)*x**2 + mpi(-24, 12)*x + mpi(1, 9) + + +@XFAIL +def test_D13(): + raise NotImplementedError("discretize a PDE: diff(f(x,t),t) == diff(diff(f(x,t),x),x)") + +# E. Statistics +# See scipy; all of this is numerical. + +# F. Combinatorial Theory. + + +def test_F1(): + assert rf(x, 3) == x*(1 + x)*(2 + x) + + +def test_F2(): + assert expand_func(binomial(n, 3)) == n*(n - 1)*(n - 2)/6 + + +@XFAIL +def test_F3(): + assert combsimp(2**n * factorial(n) * factorial2(2*n - 1)) == factorial(2*n) + + +@XFAIL +def test_F4(): + assert combsimp(2**n * factorial(n) * product(2*k - 1, (k, 1, n))) == factorial(2*n) + + +@XFAIL +def test_F5(): + assert gamma(n + R(1, 2)) / sqrt(pi) / factorial(n) == factorial(2*n)/2**(2*n)/factorial(n)**2 + + +def test_F6(): + partTest = [p.copy() for p in partitions(4)] + partDesired = [{4: 1}, {1: 1, 3: 1}, {2: 2}, {1: 2, 2:1}, {1: 4}] + assert partTest == partDesired + + +def test_F7(): + assert partition(4) == 5 + + +def test_F8(): + assert stirling(5, 2, signed=True) == -50 # if signed, then kind=1 + + +def test_F9(): + assert totient(1776) == 576 + +# G. Number Theory + + +def test_G1(): + assert list(primerange(999983, 1000004)) == [999983, 1000003] + + +@XFAIL +def test_G2(): + raise NotImplementedError("find the primitive root of 191 == 19") + + +@XFAIL +def test_G3(): + raise NotImplementedError("(a+b)**p mod p == a**p + b**p mod p; p prime") + +# ... G14 Modular equations are not implemented. + +def test_G15(): + assert Rational(sqrt(3).evalf()).limit_denominator(15) == R(26, 15) + assert list(takewhile(lambda x: x.q <= 15, cf_c(cf_i(sqrt(3)))))[-1] == \ + R(26, 15) + + +def test_G16(): + assert list(islice(cf_i(pi),10)) == [3, 7, 15, 1, 292, 1, 1, 1, 2, 1] + + +def test_G17(): + assert cf_p(0, 1, 23) == [4, [1, 3, 1, 8]] + + +def test_G18(): + assert cf_p(1, 2, 5) == [[1]] + assert cf_r([[1]]).expand() == S.Half + sqrt(5)/2 + + +@XFAIL +def test_G19(): + s = symbols('s', integer=True, positive=True) + it = cf_i((exp(1/s) - 1)/(exp(1/s) + 1)) + assert list(islice(it, 5)) == [0, 2*s, 6*s, 10*s, 14*s] + + +def test_G20(): + s = symbols('s', integer=True, positive=True) + # Wester erroneously has this as -s + sqrt(s**2 + 1) + assert cf_r([[2*s]]) == s + sqrt(s**2 + 1) + + +@XFAIL +def test_G20b(): + s = symbols('s', integer=True, positive=True) + assert cf_p(s, 1, s**2 + 1) == [[2*s]] + + +# H. Algebra + + +def test_H1(): + assert simplify(2*2**n) == simplify(2**(n + 1)) + assert powdenest(2*2**n) == simplify(2**(n + 1)) + + +def test_H2(): + assert powsimp(4 * 2**n) == 2**(n + 2) + + +def test_H3(): + assert (-1)**(n*(n + 1)) == 1 + + +def test_H4(): + expr = factor(6*x - 10) + assert type(expr) is Mul + assert expr.args[0] == 2 + assert expr.args[1] == 3*x - 5 + +p1 = 64*x**34 - 21*x**47 - 126*x**8 - 46*x**5 - 16*x**60 - 81 +p2 = 72*x**60 - 25*x**25 - 19*x**23 - 22*x**39 - 83*x**52 + 54*x**10 + 81 +q = 34*x**19 - 25*x**16 + 70*x**7 + 20*x**3 - 91*x - 86 + + +def test_H5(): + assert gcd(p1, p2, x) == 1 + + +def test_H6(): + assert gcd(expand(p1 * q), expand(p2 * q)) == q + + +def test_H7(): + p1 = 24*x*y**19*z**8 - 47*x**17*y**5*z**8 + 6*x**15*y**9*z**2 - 3*x**22 + 5 + p2 = 34*x**5*y**8*z**13 + 20*x**7*y**7*z**7 + 12*x**9*y**16*z**4 + 80*y**14*z + assert gcd(p1, p2, x, y, z) == 1 + + +def test_H8(): + p1 = 24*x*y**19*z**8 - 47*x**17*y**5*z**8 + 6*x**15*y**9*z**2 - 3*x**22 + 5 + p2 = 34*x**5*y**8*z**13 + 20*x**7*y**7*z**7 + 12*x**9*y**16*z**4 + 80*y**14*z + q = 11*x**12*y**7*z**13 - 23*x**2*y**8*z**10 + 47*x**17*y**5*z**8 + assert gcd(p1 * q, p2 * q, x, y, z) == q + + +def test_H9(): + x = Symbol('x', zero=False) + p1 = 2*x**(n + 4) - x**(n + 2) + p2 = 4*x**(n + 1) + 3*x**n + assert gcd(p1, p2) == x**n + + +def test_H10(): + p1 = 3*x**4 + 3*x**3 + x**2 - x - 2 + p2 = x**3 - 3*x**2 + x + 5 + assert resultant(p1, p2, x) == 0 + + +def test_H11(): + assert resultant(p1 * q, p2 * q, x) == 0 + + +def test_H12(): + num = x**2 - 4 + den = x**2 + 4*x + 4 + assert simplify(num/den) == (x - 2)/(x + 2) + + +@XFAIL +def test_H13(): + assert simplify((exp(x) - 1) / (exp(x/2) + 1)) == exp(x/2) - 1 + + +def test_H14(): + p = (x + 1) ** 20 + ep = expand(p) + assert ep == (1 + 20*x + 190*x**2 + 1140*x**3 + 4845*x**4 + 15504*x**5 + + 38760*x**6 + 77520*x**7 + 125970*x**8 + 167960*x**9 + 184756*x**10 + + 167960*x**11 + 125970*x**12 + 77520*x**13 + 38760*x**14 + 15504*x**15 + + 4845*x**16 + 1140*x**17 + 190*x**18 + 20*x**19 + x**20) + dep = diff(ep, x) + assert dep == (20 + 380*x + 3420*x**2 + 19380*x**3 + 77520*x**4 + + 232560*x**5 + 542640*x**6 + 1007760*x**7 + 1511640*x**8 + 1847560*x**9 + + 1847560*x**10 + 1511640*x**11 + 1007760*x**12 + 542640*x**13 + + 232560*x**14 + 77520*x**15 + 19380*x**16 + 3420*x**17 + 380*x**18 + + 20*x**19) + assert factor(dep) == 20*(1 + x)**19 + + +def test_H15(): + assert simplify(Mul(*[x - r for r in solveset(x**3 + x**2 - 7)])) == x**3 + x**2 - 7 + + +def test_H16(): + assert factor(x**100 - 1) == ((x - 1)*(x + 1)*(x**2 + 1)*(x**4 - x**3 + + x**2 - x + 1)*(x**4 + x**3 + x**2 + x + 1)*(x**8 - x**6 + x**4 + - x**2 + 1)*(x**20 - x**15 + x**10 - x**5 + 1)*(x**20 + x**15 + x**10 + + x**5 + 1)*(x**40 - x**30 + x**20 - x**10 + 1)) + + +def test_H17(): + assert simplify(factor(expand(p1 * p2)) - p1*p2) == 0 + + +@XFAIL +def test_H18(): + # Factor over complex rationals. + test = factor(4*x**4 + 8*x**3 + 77*x**2 + 18*x + 153) + good = (2*x + 3*I)*(2*x - 3*I)*(x + 1 - 4*I)*(x + 1 + 4*I) + assert test == good + + +def test_H19(): + a = symbols('a') + # The idea is to let a**2 == 2, then solve 1/(a-1). Answer is a+1") + assert Poly(a - 1).invert(Poly(a**2 - 2)) == a + 1 + + +@XFAIL +def test_H20(): + raise NotImplementedError("let a**2==2; (x**3 + (a-2)*x**2 - " + + "(2*a+3)*x - 3*a) / (x**2-2) = (x**2 - 2*x - 3) / (x-a)") + + +@XFAIL +def test_H21(): + raise NotImplementedError("evaluate (b+c)**4 assuming b**3==2, c**2==3. \ + Answer is 2*b + 8*c + 18*b**2 + 12*b*c + 9") + + +def test_H22(): + assert factor(x**4 - 3*x**2 + 1, modulus=5) == (x - 2)**2 * (x + 2)**2 + + +def test_H23(): + f = x**11 + x + 1 + g = (x**2 + x + 1) * (x**9 - x**8 + x**6 - x**5 + x**3 - x**2 + 1) + assert factor(f, modulus=65537) == g + + +def test_H24(): + phi = AlgebraicNumber(S.GoldenRatio.expand(func=True), alias='phi') + assert factor(x**4 - 3*x**2 + 1, extension=phi) == \ + (x - phi)*(x + 1 - phi)*(x - 1 + phi)*(x + phi) + + +def test_H25(): + e = (x - 2*y**2 + 3*z**3) ** 20 + assert factor(expand(e)) == e + + +def test_H26(): + g = expand((sin(x) - 2*cos(y)**2 + 3*tan(z)**3)**20) + assert factor(g, expand=False) == (-sin(x) + 2*cos(y)**2 - 3*tan(z)**3)**20 + + +def test_H27(): + f = 24*x*y**19*z**8 - 47*x**17*y**5*z**8 + 6*x**15*y**9*z**2 - 3*x**22 + 5 + g = 34*x**5*y**8*z**13 + 20*x**7*y**7*z**7 + 12*x**9*y**16*z**4 + 80*y**14*z + h = -2*z*y**7 \ + *(6*x**9*y**9*z**3 + 10*x**7*z**6 + 17*y*x**5*z**12 + 40*y**7) \ + *(3*x**22 + 47*x**17*y**5*z**8 - 6*x**15*y**9*z**2 - 24*x*y**19*z**8 - 5) + assert factor(expand(f*g)) == h + + +@XFAIL +def test_H28(): + raise NotImplementedError("expand ((1 - c**2)**5 * (1 - s**2)**5 * " + + "(c**2 + s**2)**10) with c**2 + s**2 = 1. Answer is c**10*s**10.") + + +@XFAIL +def test_H29(): + assert factor(4*x**2 - 21*x*y + 20*y**2, modulus=3) == (x + y)*(x - y) + + +def test_H30(): + test = factor(x**3 + y**3, extension=sqrt(-3)) + answer = (x + y)*(x + y*(-R(1, 2) - sqrt(3)/2*I))*(x + y*(-R(1, 2) + sqrt(3)/2*I)) + assert answer == test + + +def test_H31(): + f = (x**2 + 2*x + 3)/(x**3 + 4*x**2 + 5*x + 2) + g = 2 / (x + 1)**2 - 2 / (x + 1) + 3 / (x + 2) + assert apart(f) == g + + +@XFAIL +def test_H32(): # issue 6558 + raise NotImplementedError("[A*B*C - (A*B*C)**(-1)]*A*C*B (product \ + of a non-commuting product and its inverse)") + + +def test_H33(): + A, B, C = symbols('A, B, C', commutative=False) + assert (Commutator(A, Commutator(B, C)) + + Commutator(B, Commutator(C, A)) + + Commutator(C, Commutator(A, B))).doit().expand() == 0 + + +# I. Trigonometry + +def test_I1(): + assert tan(pi*R(7, 10)) == -sqrt(1 + 2/sqrt(5)) + + +@XFAIL +def test_I2(): + assert sqrt((1 + cos(6))/2) == -cos(3) + + +def test_I3(): + assert cos(n*pi) + sin((4*n - 1)*pi/2) == (-1)**n - 1 + + +def test_I4(): + assert refine(cos(pi*cos(n*pi)) + sin(pi/2*cos(n*pi)), Q.integer(n)) == (-1)**n - 1 + + +@XFAIL +def test_I5(): + assert sin((n**5/5 + n**4/2 + n**3/3 - n/30) * pi) == 0 + + +@XFAIL +def test_I6(): + raise NotImplementedError("assuming -3*pi pi**E) + + +@XFAIL +def test_N2(): + x = symbols('x', real=True) + assert ask(x**4 - x + 1 > 0) is True + assert ask(x**4 - x + 1 > 1) is False + + +@XFAIL +def test_N3(): + x = symbols('x', real=True) + assert ask(And(Lt(-1, x), Lt(x, 1)), abs(x) < 1 ) + +@XFAIL +def test_N4(): + x, y = symbols('x y', real=True) + assert ask(2*x**2 > 2*y**2, (x > y) & (y > 0)) is True + + +@XFAIL +def test_N5(): + x, y, k = symbols('x y k', real=True) + assert ask(k*x**2 > k*y**2, (x > y) & (y > 0) & (k > 0)) is True + + +@slow +@XFAIL +def test_N6(): + x, y, k, n = symbols('x y k n', real=True) + assert ask(k*x**n > k*y**n, (x > y) & (y > 0) & (k > 0) & (n > 0)) is True + + +@XFAIL +def test_N7(): + x, y = symbols('x y', real=True) + assert ask(y > 0, (x > 1) & (y >= x - 1)) is True + + +@XFAIL +@slow +def test_N8(): + x, y, z = symbols('x y z', real=True) + assert ask(Eq(x, y) & Eq(y, z), + (x >= y) & (y >= z) & (z >= x)) + + +def test_N9(): + x = Symbol('x') + assert solveset(abs(x - 1) > 2, domain=S.Reals) == Union(Interval(-oo, -1, False, True), + Interval(3, oo, True)) + + +def test_N10(): + x = Symbol('x') + p = (x - 1)*(x - 2)*(x - 3)*(x - 4)*(x - 5) + assert solveset(expand(p) < 0, domain=S.Reals) == Union(Interval(-oo, 1, True, True), + Interval(2, 3, True, True), + Interval(4, 5, True, True)) + + +def test_N11(): + x = Symbol('x') + assert solveset(6/(x - 3) <= 3, domain=S.Reals) == Union(Interval(-oo, 3, True, True), Interval(5, oo)) + + +def test_N12(): + x = Symbol('x') + assert solveset(sqrt(x) < 2, domain=S.Reals) == Interval(0, 4, False, True) + + +def test_N13(): + x = Symbol('x') + assert solveset(sin(x) < 2, domain=S.Reals) == S.Reals + + +@XFAIL +def test_N14(): + x = Symbol('x') + # Gives 'Union(Interval(Integer(0), Mul(Rational(1, 2), pi), false, true), + # Interval(Mul(Rational(1, 2), pi), Mul(Integer(2), pi), true, false))' + # which is not the correct answer, but the provided also seems wrong. + assert solveset(sin(x) < 1, x, domain=S.Reals) == Union(Interval(-oo, pi/2, True, True), + Interval(pi/2, oo, True, True)) + + +def test_N15(): + r, t = symbols('r t') + # raises NotImplementedError: only univariate inequalities are supported + solveset(abs(2*r*(cos(t) - 1) + 1) <= 1, r, S.Reals) + + +def test_N16(): + r, t = symbols('r t') + solveset((r**2)*((cos(t) - 4)**2)*sin(t)**2 < 9, r, S.Reals) + + +@XFAIL +def test_N17(): + # currently only univariate inequalities are supported + assert solveset((x + y > 0, x - y < 0), (x, y)) == (abs(x) < y) + + +def test_O1(): + M = Matrix((1 + I, -2, 3*I)) + assert sqrt(expand(M.dot(M.H))) == sqrt(15) + + +def test_O2(): + assert Matrix((2, 2, -3)).cross(Matrix((1, 3, 1))) == Matrix([[11], + [-5], + [4]]) + +# The vector module has no way of representing vectors symbolically (without +# respect to a basis) +@XFAIL +def test_O3(): + # assert (va ^ vb) | (vc ^ vd) == -(va | vc)*(vb | vd) + (va | vd)*(vb | vc) + raise NotImplementedError("""The vector module has no way of representing + vectors symbolically (without respect to a basis)""") + +def test_O4(): + from sympy.vector import CoordSys3D, Del + N = CoordSys3D("N") + delop = Del() + i, j, k = N.base_vectors() + x, y, z = N.base_scalars() + F = i*(x*y*z) + j*((x*y*z)**2) + k*((y**2)*(z**3)) + assert delop.cross(F).doit() == (-2*x**2*y**2*z + 2*y*z**3)*i + x*y*j + (2*x*y**2*z**2 - x*z)*k + +@XFAIL +def test_O5(): + #assert grad|(f^g)-g|(grad^f)+f|(grad^g) == 0 + raise NotImplementedError("""The vector module has no way of representing + vectors symbolically (without respect to a basis)""") + +#testO8-O9 MISSING!! + + +def test_O10(): + L = [Matrix([2, 3, 5]), Matrix([3, 6, 2]), Matrix([8, 3, 6])] + assert GramSchmidt(L) == [Matrix([ + [2], + [3], + [5]]), + Matrix([ + [R(23, 19)], + [R(63, 19)], + [R(-47, 19)]]), + Matrix([ + [R(1692, 353)], + [R(-1551, 706)], + [R(-423, 706)]])] + + +def test_P1(): + assert Matrix(3, 3, lambda i, j: j - i).diagonal(-1) == Matrix( + 1, 2, [-1, -1]) + + +def test_P2(): + M = Matrix([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) + M.row_del(1) + M.col_del(2) + assert M == Matrix([[1, 2], + [7, 8]]) + + +def test_P3(): + A = Matrix([ + [11, 12, 13, 14], + [21, 22, 23, 24], + [31, 32, 33, 34], + [41, 42, 43, 44]]) + + A11 = A[0:3, 1:4] + A12 = A[(0, 1, 3), (2, 0, 3)] + A21 = A + A221 = -A[0:2, 2:4] + A222 = -A[(3, 0), (2, 1)] + A22 = BlockMatrix([[A221, A222]]).T + rows = [[-A11, A12], [A21, A22]] + raises(ValueError, lambda: BlockMatrix(rows)) + B = Matrix(rows) + assert B == Matrix([ + [-12, -13, -14, 13, 11, 14], + [-22, -23, -24, 23, 21, 24], + [-32, -33, -34, 43, 41, 44], + [11, 12, 13, 14, -13, -23], + [21, 22, 23, 24, -14, -24], + [31, 32, 33, 34, -43, -13], + [41, 42, 43, 44, -42, -12]]) + + +@XFAIL +def test_P4(): + raise NotImplementedError("Block matrix diagonalization not supported") + + +def test_P5(): + M = Matrix([[7, 11], + [3, 8]]) + assert M % 2 == Matrix([[1, 1], + [1, 0]]) + + +def test_P6(): + M = Matrix([[cos(x), sin(x)], + [-sin(x), cos(x)]]) + assert M.diff(x, 2) == Matrix([[-cos(x), -sin(x)], + [sin(x), -cos(x)]]) + + +def test_P7(): + M = Matrix([[x, y]])*( + z*Matrix([[1, 3, 5], + [2, 4, 6]]) + Matrix([[7, -9, 11], + [-8, 10, -12]])) + assert M == Matrix([[x*(z + 7) + y*(2*z - 8), x*(3*z - 9) + y*(4*z + 10), + x*(5*z + 11) + y*(6*z - 12)]]) + + +def test_P8(): + M = Matrix([[1, -2*I], + [-3*I, 4]]) + assert M.norm(ord=S.Infinity) == 7 + + +def test_P9(): + a, b, c = symbols('a b c', nonzero=True) + M = Matrix([[a/(b*c), 1/c, 1/b], + [1/c, b/(a*c), 1/a], + [1/b, 1/a, c/(a*b)]]) + assert factor(M.norm('fro')) == (a**2 + b**2 + c**2)/(abs(a)*abs(b)*abs(c)) + + +@XFAIL +def test_P10(): + M = Matrix([[1, 2 + 3*I], + [f(4 - 5*I), 6]]) + # conjugate(f(4 - 5*i)) is not simplified to f(4+5*I) + assert M.H == Matrix([[1, f(4 + 5*I)], + [2 + 3*I, 6]]) + + +@XFAIL +def test_P11(): + # raises NotImplementedError("Matrix([[x,y],[1,x*y]]).inv() + # not simplifying to extract common factor") + assert Matrix([[x, y], + [1, x*y]]).inv() == (1/(x**2 - 1))*Matrix([[x, -1], + [-1/y, x/y]]) + + +def test_P11_workaround(): + # This test was changed to inverse method ADJ because it depended on the + # specific form of inverse returned from the 'GE' method which has changed. + M = Matrix([[x, y], [1, x*y]]).inv('ADJ') + c = gcd(tuple(M)) + assert MatMul(c, M/c, evaluate=False) == MatMul(c, Matrix([ + [x*y, -y], + [ -1, x]]), evaluate=False) + + +def test_P12(): + A11 = MatrixSymbol('A11', n, n) + A12 = MatrixSymbol('A12', n, n) + A22 = MatrixSymbol('A22', n, n) + B = BlockMatrix([[A11, A12], + [ZeroMatrix(n, n), A22]]) + assert block_collapse(B.I) == BlockMatrix([[A11.I, (-1)*A11.I*A12*A22.I], + [ZeroMatrix(n, n), A22.I]]) + + +def test_P13(): + M = Matrix([[1, x - 2, x - 3], + [x - 1, x**2 - 3*x + 6, x**2 - 3*x - 2], + [x - 2, x**2 - 8, 2*(x**2) - 12*x + 14]]) + L, U, _ = M.LUdecomposition() + assert simplify(L) == Matrix([[1, 0, 0], + [x - 1, 1, 0], + [x - 2, x - 3, 1]]) + assert simplify(U) == Matrix([[1, x - 2, x - 3], + [0, 4, x - 5], + [0, 0, x - 7]]) + + +def test_P14(): + M = Matrix([[1, 2, 3, 1, 3], + [3, 2, 1, 1, 7], + [0, 2, 4, 1, 1], + [1, 1, 1, 1, 4]]) + R, _ = M.rref() + assert R == Matrix([[1, 0, -1, 0, 2], + [0, 1, 2, 0, -1], + [0, 0, 0, 1, 3], + [0, 0, 0, 0, 0]]) + + +def test_P15(): + M = Matrix([[-1, 3, 7, -5], + [4, -2, 1, 3], + [2, 4, 15, -7]]) + assert M.rank() == 2 + + +def test_P16(): + M = Matrix([[2*sqrt(2), 8], + [6*sqrt(6), 24*sqrt(3)]]) + assert M.rank() == 1 + + +def test_P17(): + t = symbols('t', real=True) + M=Matrix([ + [sin(2*t), cos(2*t)], + [2*(1 - (cos(t)**2))*cos(t), (1 - 2*(sin(t)**2))*sin(t)]]) + assert M.rank() == 1 + + +def test_P18(): + M = Matrix([[1, 0, -2, 0], + [-2, 1, 0, 3], + [-1, 2, -6, 6]]) + assert M.nullspace() == [Matrix([[2], + [4], + [1], + [0]]), + Matrix([[0], + [-3], + [0], + [1]])] + + +def test_P19(): + w = symbols('w') + M = Matrix([[1, 1, 1, 1], + [w, x, y, z], + [w**2, x**2, y**2, z**2], + [w**3, x**3, y**3, z**3]]) + assert M.det() == (w**3*x**2*y - w**3*x**2*z - w**3*x*y**2 + w**3*x*z**2 + + w**3*y**2*z - w**3*y*z**2 - w**2*x**3*y + w**2*x**3*z + + w**2*x*y**3 - w**2*x*z**3 - w**2*y**3*z + w**2*y*z**3 + + w*x**3*y**2 - w*x**3*z**2 - w*x**2*y**3 + w*x**2*z**3 + + w*y**3*z**2 - w*y**2*z**3 - x**3*y**2*z + x**3*y*z**2 + + x**2*y**3*z - x**2*y*z**3 - x*y**3*z**2 + x*y**2*z**3 + ) + + +@XFAIL +def test_P20(): + raise NotImplementedError("Matrix minimal polynomial not supported") + + +def test_P21(): + M = Matrix([[5, -3, -7], + [-2, 1, 2], + [2, -3, -4]]) + assert M.charpoly(x).as_expr() == x**3 - 2*x**2 - 5*x + 6 + + +def test_P22(): + d = 100 + M = (2 - x)*eye(d) + assert M.eigenvals() == {-x + 2: d} + + +def test_P23(): + M = Matrix([ + [2, 1, 0, 0, 0], + [1, 2, 1, 0, 0], + [0, 1, 2, 1, 0], + [0, 0, 1, 2, 1], + [0, 0, 0, 1, 2]]) + assert M.eigenvals() == { + S('1'): 1, + S('2'): 1, + S('3'): 1, + S('sqrt(3) + 2'): 1, + S('-sqrt(3) + 2'): 1} + + +def test_P24(): + M = Matrix([[611, 196, -192, 407, -8, -52, -49, 29], + [196, 899, 113, -192, -71, -43, -8, -44], + [-192, 113, 899, 196, 61, 49, 8, 52], + [ 407, -192, 196, 611, 8, 44, 59, -23], + [ -8, -71, 61, 8, 411, -599, 208, 208], + [ -52, -43, 49, 44, -599, 411, 208, 208], + [ -49, -8, 8, 59, 208, 208, 99, -911], + [ 29, -44, 52, -23, 208, 208, -911, 99]]) + assert M.eigenvals() == { + S('0'): 1, + S('10*sqrt(10405)'): 1, + S('100*sqrt(26) + 510'): 1, + S('1000'): 2, + S('-100*sqrt(26) + 510'): 1, + S('-10*sqrt(10405)'): 1, + S('1020'): 1} + + +def test_P25(): + MF = N(Matrix([[ 611, 196, -192, 407, -8, -52, -49, 29], + [ 196, 899, 113, -192, -71, -43, -8, -44], + [-192, 113, 899, 196, 61, 49, 8, 52], + [ 407, -192, 196, 611, 8, 44, 59, -23], + [ -8, -71, 61, 8, 411, -599, 208, 208], + [ -52, -43, 49, 44, -599, 411, 208, 208], + [ -49, -8, 8, 59, 208, 208, 99, -911], + [ 29, -44, 52, -23, 208, 208, -911, 99]])) + + ev_1 = sorted(MF.eigenvals(multiple=True)) + ev_2 = sorted( + [-1020.0490184299969, 0.0, 0.09804864072151699, 1000.0, 1000.0, + 1019.9019513592784, 1020.0, 1020.0490184299969]) + + for x, y in zip(ev_1, ev_2): + assert abs(x - y) < 1e-12 + + +def test_P26(): + a0, a1, a2, a3, a4 = symbols('a0 a1 a2 a3 a4') + M = Matrix([[-a4, -a3, -a2, -a1, -a0, 0, 0, 0, 0], + [ 1, 0, 0, 0, 0, 0, 0, 0, 0], + [ 0, 1, 0, 0, 0, 0, 0, 0, 0], + [ 0, 0, 1, 0, 0, 0, 0, 0, 0], + [ 0, 0, 0, 1, 0, 0, 0, 0, 0], + [ 0, 0, 0, 0, 0, -1, -1, 0, 0], + [ 0, 0, 0, 0, 0, 1, 0, 0, 0], + [ 0, 0, 0, 0, 0, 0, 1, -1, -1], + [ 0, 0, 0, 0, 0, 0, 0, 1, 0]]) + assert M.eigenvals(error_when_incomplete=False) == { + S('-1/2 - sqrt(3)*I/2'): 2, + S('-1/2 + sqrt(3)*I/2'): 2} + + +def test_P27(): + a = symbols('a') + M = Matrix([[a, 0, 0, 0, 0], + [0, 0, 0, 0, 1], + [0, 0, a, 0, 0], + [0, 0, 0, a, 0], + [0, -2, 0, 0, 2]]) + + assert M.eigenvects() == [ + (a, 3, [ + Matrix([1, 0, 0, 0, 0]), + Matrix([0, 0, 1, 0, 0]), + Matrix([0, 0, 0, 1, 0]) + ]), + (1 - I, 1, [ + Matrix([0, (1 + I)/2, 0, 0, 1]) + ]), + (1 + I, 1, [ + Matrix([0, (1 - I)/2, 0, 0, 1]) + ]), + ] + + +@XFAIL +def test_P28(): + raise NotImplementedError("Generalized eigenvectors not supported \ +https://github.com/sympy/sympy/issues/5293") + + +@XFAIL +def test_P29(): + raise NotImplementedError("Generalized eigenvectors not supported \ +https://github.com/sympy/sympy/issues/5293") + + +def test_P30(): + M = Matrix([[1, 0, 0, 1, -1], + [0, 1, -2, 3, -3], + [0, 0, -1, 2, -2], + [1, -1, 1, 0, 1], + [1, -1, 1, -1, 2]]) + _, J = M.jordan_form() + assert J == Matrix([[-1, 0, 0, 0, 0], + [0, 1, 1, 0, 0], + [0, 0, 1, 0, 0], + [0, 0, 0, 1, 1], + [0, 0, 0, 0, 1]]) + + +@XFAIL +def test_P31(): + raise NotImplementedError("Smith normal form not implemented") + + +def test_P32(): + M = Matrix([[1, -2], + [2, 1]]) + assert exp(M).rewrite(cos).simplify() == Matrix([[E*cos(2), -E*sin(2)], + [E*sin(2), E*cos(2)]]) + + +def test_P33(): + w, t = symbols('w t') + M = Matrix([[0, 1, 0, 0], + [0, 0, 0, 2*w], + [0, 0, 0, 1], + [0, -2*w, 3*w**2, 0]]) + assert exp(M*t).rewrite(cos).expand() == Matrix([ + [1, -3*t + 4*sin(t*w)/w, 6*t*w - 6*sin(t*w), -2*cos(t*w)/w + 2/w], + [0, 4*cos(t*w) - 3, -6*w*cos(t*w) + 6*w, 2*sin(t*w)], + [0, 2*cos(t*w)/w - 2/w, -3*cos(t*w) + 4, sin(t*w)/w], + [0, -2*sin(t*w), 3*w*sin(t*w), cos(t*w)]]) + + +@XFAIL +def test_P34(): + a, b, c = symbols('a b c', real=True) + M = Matrix([[a, 1, 0, 0, 0, 0], + [0, a, 0, 0, 0, 0], + [0, 0, b, 0, 0, 0], + [0, 0, 0, c, 1, 0], + [0, 0, 0, 0, c, 1], + [0, 0, 0, 0, 0, c]]) + # raises exception, sin(M) not supported. exp(M*I) also not supported + # https://github.com/sympy/sympy/issues/6218 + assert sin(M) == Matrix([[sin(a), cos(a), 0, 0, 0, 0], + [0, sin(a), 0, 0, 0, 0], + [0, 0, sin(b), 0, 0, 0], + [0, 0, 0, sin(c), cos(c), -sin(c)/2], + [0, 0, 0, 0, sin(c), cos(c)], + [0, 0, 0, 0, 0, sin(c)]]) + + +@XFAIL +def test_P35(): + M = pi/2*Matrix([[2, 1, 1], + [2, 3, 2], + [1, 1, 2]]) + # raises exception, sin(M) not supported. exp(M*I) also not supported + # https://github.com/sympy/sympy/issues/6218 + assert sin(M) == eye(3) + + +@XFAIL +def test_P36(): + M = Matrix([[10, 7], + [7, 17]]) + assert sqrt(M) == Matrix([[3, 1], + [1, 4]]) + + +def test_P37(): + M = Matrix([[1, 1, 0], + [0, 1, 0], + [0, 0, 1]]) + assert M**S.Half == Matrix([[1, R(1, 2), 0], + [0, 1, 0], + [0, 0, 1]]) + + +@XFAIL +def test_P38(): + M=Matrix([[0, 1, 0], + [0, 0, 0], + [0, 0, 0]]) + + with raises(AssertionError): + # raises ValueError: Matrix det == 0; not invertible + M**S.Half + # if it doesn't raise then this assertion will be + # raised and the test will be flagged as not XFAILing + assert None + +@XFAIL +def test_P39(): + """ + M=Matrix([ + [1, 1], + [2, 2], + [3, 3]]) + M.SVD() + """ + raise NotImplementedError("Singular value decomposition not implemented") + + +def test_P40(): + r, t = symbols('r t', real=True) + M = Matrix([r*cos(t), r*sin(t)]) + assert M.jacobian(Matrix([r, t])) == Matrix([[cos(t), -r*sin(t)], + [sin(t), r*cos(t)]]) + + +def test_P41(): + r, t = symbols('r t', real=True) + assert hessian(r**2*sin(t),(r,t)) == Matrix([[ 2*sin(t), 2*r*cos(t)], + [2*r*cos(t), -r**2*sin(t)]]) + + +def test_P42(): + assert wronskian([cos(x), sin(x)], x).simplify() == 1 + + +def test_P43(): + def __my_jacobian(M, Y): + return Matrix([M.diff(v).T for v in Y]).T + r, t = symbols('r t', real=True) + M = Matrix([r*cos(t), r*sin(t)]) + assert __my_jacobian(M,[r,t]) == Matrix([[cos(t), -r*sin(t)], + [sin(t), r*cos(t)]]) + + +def test_P44(): + def __my_hessian(f, Y): + V = Matrix([diff(f, v) for v in Y]) + return Matrix([V.T.diff(v) for v in Y]) + r, t = symbols('r t', real=True) + assert __my_hessian(r**2*sin(t), (r, t)) == Matrix([ + [ 2*sin(t), 2*r*cos(t)], + [2*r*cos(t), -r**2*sin(t)]]) + + +def test_P45(): + def __my_wronskian(Y, v): + M = Matrix([Matrix(Y).T.diff(x, n) for n in range(0, len(Y))]) + return M.det() + assert __my_wronskian([cos(x), sin(x)], x).simplify() == 1 + +# Q1-Q6 Tensor tests missing + + +@XFAIL +def test_R1(): + i, j, n = symbols('i j n', integer=True, positive=True) + xn = MatrixSymbol('xn', n, 1) + Sm = Sum((xn[i, 0] - Sum(xn[j, 0], (j, 0, n - 1))/n)**2, (i, 0, n - 1)) + # sum does not calculate + # Unknown result + Sm.doit() + raise NotImplementedError('Unknown result') + +@XFAIL +def test_R2(): + m, b = symbols('m b') + i, n = symbols('i n', integer=True, positive=True) + xn = MatrixSymbol('xn', n, 1) + yn = MatrixSymbol('yn', n, 1) + f = Sum((yn[i, 0] - m*xn[i, 0] - b)**2, (i, 0, n - 1)) + f1 = diff(f, m) + f2 = diff(f, b) + # raises TypeError: solveset() takes at most 2 arguments (3 given) + solveset((f1, f2), (m, b), domain=S.Reals) + + +@XFAIL +def test_R3(): + n, k = symbols('n k', integer=True, positive=True) + sk = ((-1)**k) * (binomial(2*n, k))**2 + Sm = Sum(sk, (k, 1, oo)) + T = Sm.doit() + T2 = T.combsimp() + # returns -((-1)**n*factorial(2*n) + # - (factorial(n))**2)*exp_polar(-I*pi)/(factorial(n))**2 + assert T2 == (-1)**n*binomial(2*n, n) + + +@XFAIL +def test_R4(): +# Macsyma indefinite sum test case: +#(c15) /* Check whether the full Gosper algorithm is implemented +# => 1/2^(n + 1) binomial(n, k - 1) */ +#closedform(indefsum(binomial(n, k)/2^n - binomial(n + 1, k)/2^(n + 1), k)); +#Time= 2690 msecs +# (- n + k - 1) binomial(n + 1, k) +#(d15) - -------------------------------- +# n +# 2 2 (n + 1) +# +#(c16) factcomb(makefact(%)); +#Time= 220 msecs +# n! +#(d16) ---------------- +# n +# 2 k! 2 (n - k)! +# Might be possible after fixing https://github.com/sympy/sympy/pull/1879 + raise NotImplementedError("Indefinite sum not supported") + + +@XFAIL +def test_R5(): + a, b, c, n, k = symbols('a b c n k', integer=True, positive=True) + sk = ((-1)**k)*(binomial(a + b, a + k) + *binomial(b + c, b + k)*binomial(c + a, c + k)) + Sm = Sum(sk, (k, 1, oo)) + T = Sm.doit() # hypergeometric series not calculated + assert T == factorial(a+b+c)/(factorial(a)*factorial(b)*factorial(c)) + + +def test_R6(): + n, k = symbols('n k', integer=True, positive=True) + gn = MatrixSymbol('gn', n + 2, 1) + Sm = Sum(gn[k, 0] - gn[k - 1, 0], (k, 1, n + 1)) + assert Sm.doit() == -gn[0, 0] + gn[n + 1, 0] + + +def test_R7(): + n, k = symbols('n k', integer=True, positive=True) + T = Sum(k**3,(k,1,n)).doit() + assert T.factor() == n**2*(n + 1)**2/4 + +@XFAIL +def test_R8(): + n, k = symbols('n k', integer=True, positive=True) + Sm = Sum(k**2*binomial(n, k), (k, 1, n)) + T = Sm.doit() #returns Piecewise function + assert T.combsimp() == n*(n + 1)*2**(n - 2) + + +def test_R9(): + n, k = symbols('n k', integer=True, positive=True) + Sm = Sum(binomial(n, k - 1)/k, (k, 1, n + 1)) + assert Sm.doit().simplify() == (2**(n + 1) - 1)/(n + 1) + + +@XFAIL +def test_R10(): + n, m, r, k = symbols('n m r k', integer=True, positive=True) + Sm = Sum(binomial(n, k)*binomial(m, r - k), (k, 0, r)) + T = Sm.doit() + T2 = T.combsimp().rewrite(factorial) + assert T2 == factorial(m + n)/(factorial(r)*factorial(m + n - r)) + assert T2 == binomial(m + n, r).rewrite(factorial) + # rewrite(binomial) is not working. + # https://github.com/sympy/sympy/issues/7135 + T3 = T2.rewrite(binomial) + assert T3 == binomial(m + n, r) + + +@XFAIL +def test_R11(): + n, k = symbols('n k', integer=True, positive=True) + sk = binomial(n, k)*fibonacci(k) + Sm = Sum(sk, (k, 0, n)) + T = Sm.doit() + # Fibonacci simplification not implemented + # https://github.com/sympy/sympy/issues/7134 + assert T == fibonacci(2*n) + + +@XFAIL +def test_R12(): + n, k = symbols('n k', integer=True, positive=True) + Sm = Sum(fibonacci(k)**2, (k, 0, n)) + T = Sm.doit() + assert T == fibonacci(n)*fibonacci(n + 1) + + +@XFAIL +def test_R13(): + n, k = symbols('n k', integer=True, positive=True) + Sm = Sum(sin(k*x), (k, 1, n)) + T = Sm.doit() # Sum is not calculated + assert T.simplify() == cot(x/2)/2 - cos(x*(2*n + 1)/2)/(2*sin(x/2)) + + +@XFAIL +def test_R14(): + n, k = symbols('n k', integer=True, positive=True) + Sm = Sum(sin((2*k - 1)*x), (k, 1, n)) + T = Sm.doit() # Sum is not calculated + assert T.simplify() == sin(n*x)**2/sin(x) + + +@XFAIL +def test_R15(): + n, k = symbols('n k', integer=True, positive=True) + Sm = Sum(binomial(n - k, k), (k, 0, floor(n/2))) + T = Sm.doit() # Sum is not calculated + assert T.simplify() == fibonacci(n + 1) + + +def test_R16(): + k = symbols('k', integer=True, positive=True) + Sm = Sum(1/k**2 + 1/k**3, (k, 1, oo)) + assert Sm.doit() == zeta(3) + pi**2/6 + + +def test_R17(): + k = symbols('k', integer=True, positive=True) + assert abs(float(Sum(1/k**2 + 1/k**3, (k, 1, oo))) + - 2.8469909700078206) < 1e-15 + + +def test_R18(): + k = symbols('k', integer=True, positive=True) + Sm = Sum(1/(2**k*k**2), (k, 1, oo)) + T = Sm.doit() + assert T.simplify() == -log(2)**2/2 + pi**2/12 + + +@slow +@XFAIL +def test_R19(): + k = symbols('k', integer=True, positive=True) + Sm = Sum(1/((3*k + 1)*(3*k + 2)*(3*k + 3)), (k, 0, oo)) + T = Sm.doit() + # assert fails, T not simplified + assert T.simplify() == -log(3)/4 + sqrt(3)*pi/12 + + +@XFAIL +def test_R20(): + n, k = symbols('n k', integer=True, positive=True) + Sm = Sum(binomial(n, 4*k), (k, 0, oo)) + T = Sm.doit() + # assert fails, T not simplified + assert T.simplify() == 2**(n/2)*cos(pi*n/4)/2 + 2**(n - 1)/2 + + +@XFAIL +def test_R21(): + k = symbols('k', integer=True, positive=True) + Sm = Sum(1/(sqrt(k*(k + 1)) * (sqrt(k) + sqrt(k + 1))), (k, 1, oo)) + T = Sm.doit() # Sum not calculated + assert T.simplify() == 1 + + +# test_R22 answer not available in Wester samples +# Sum(Sum(binomial(n, k)*binomial(n - k, n - 2*k)*x**n*y**(n - 2*k), +# (k, 0, floor(n/2))), (n, 0, oo)) with abs(x*y)<1? + + +@XFAIL +def test_R23(): + n, k = symbols('n k', integer=True, positive=True) + Sm = Sum(Sum((factorial(n)/(factorial(k)**2*factorial(n - 2*k)))* + (x/y)**k*(x*y)**(n - k), (n, 2*k, oo)), (k, 0, oo)) + # Missing how to express constraint abs(x*y)<1? + T = Sm.doit() # Sum not calculated + assert T == -1/sqrt(x**2*y**2 - 4*x**2 - 2*x*y + 1) + + +def test_R24(): + m, k = symbols('m k', integer=True, positive=True) + Sm = Sum(Product(k/(2*k - 1), (k, 1, m)), (m, 2, oo)) + assert Sm.doit() == pi/2 + + +def test_S1(): + k = symbols('k', integer=True, positive=True) + Pr = Product(gamma(k/3), (k, 1, 8)) + assert Pr.doit().simplify() == 640*sqrt(3)*pi**3/6561 + + +def test_S2(): + n, k = symbols('n k', integer=True, positive=True) + assert Product(k, (k, 1, n)).doit() == factorial(n) + + +def test_S3(): + n, k = symbols('n k', integer=True, positive=True) + assert Product(x**k, (k, 1, n)).doit().simplify() == x**(n*(n + 1)/2) + + +def test_S4(): + n, k = symbols('n k', integer=True, positive=True) + assert Product(1 + 1/k, (k, 1, n -1)).doit().simplify() == n + + +def test_S5(): + n, k = symbols('n k', integer=True, positive=True) + assert (Product((2*k - 1)/(2*k), (k, 1, n)).doit().gammasimp() == + gamma(n + S.Half)/(sqrt(pi)*gamma(n + 1))) + + +@XFAIL +def test_S6(): + n, k = symbols('n k', integer=True, positive=True) + # Product does not evaluate + assert (Product(x**2 -2*x*cos(k*pi/n) + 1, (k, 1, n - 1)).doit().simplify() + == (x**(2*n) - 1)/(x**2 - 1)) + + +@XFAIL +def test_S7(): + k = symbols('k', integer=True, positive=True) + Pr = Product((k**3 - 1)/(k**3 + 1), (k, 2, oo)) + T = Pr.doit() # Product does not evaluate + assert T.simplify() == R(2, 3) + + +@XFAIL +def test_S8(): + k = symbols('k', integer=True, positive=True) + Pr = Product(1 - 1/(2*k)**2, (k, 1, oo)) + T = Pr.doit() + # Product does not evaluate + assert T.simplify() == 2/pi + + +@XFAIL +def test_S9(): + k = symbols('k', integer=True, positive=True) + Pr = Product(1 + (-1)**(k + 1)/(2*k - 1), (k, 1, oo)) + T = Pr.doit() + # Product produces 0 + # https://github.com/sympy/sympy/issues/7133 + assert T.simplify() == sqrt(2) + + +@XFAIL +def test_S10(): + k = symbols('k', integer=True, positive=True) + Pr = Product((k*(k + 1) + 1 + I)/(k*(k + 1) + 1 - I), (k, 0, oo)) + T = Pr.doit() + # Product does not evaluate + assert T.simplify() == -1 + + +def test_T1(): + assert limit((1 + 1/n)**n, n, oo) == E + assert limit((1 - cos(x))/x**2, x, 0) == S.Half + + +def test_T2(): + assert limit((3**x + 5**x)**(1/x), x, oo) == 5 + + +def test_T3(): + assert limit(log(x)/(log(x) + sin(x)), x, oo) == 1 + + +def test_T4(): + assert limit((exp(x*exp(-x)/(exp(-x) + exp(-2*x**2/(x + 1)))) + - exp(x))/x, x, oo) == -exp(2) + + +def test_T5(): + assert limit(x*log(x)*log(x*exp(x) - x**2)**2/log(log(x**2 + + 2*exp(exp(3*x**3*log(x))))), x, oo) == R(1, 3) + + +def test_T6(): + assert limit(1/n * factorial(n)**(1/n), n, oo) == exp(-1) + + +def test_T7(): + limit(1/n * gamma(n + 1)**(1/n), n, oo) + + +def test_T8(): + a, z = symbols('a z', positive=True) + assert limit(gamma(z + a)/gamma(z)*exp(-a*log(z)), z, oo) == 1 + + +@XFAIL +def test_T9(): + z, k = symbols('z k', positive=True) + # raises NotImplementedError: + # Don't know how to calculate the mrv of '(1, k)' + assert limit(hyper((1, k), (1,), z/k), k, oo) == exp(z) + + +@XFAIL +def test_T10(): + # No longer raises PoleError, but should return euler-mascheroni constant + assert limit(zeta(x) - 1/(x - 1), x, 1) == integrate(-1/x + 1/floor(x), (x, 1, oo)) + +@XFAIL +def test_T11(): + n, k = symbols('n k', integer=True, positive=True) + # evaluates to 0 + assert limit(n**x/(x*product((1 + x/k), (k, 1, n))), n, oo) == gamma(x) + + +def test_T12(): + x, t = symbols('x t', real=True) + # Does not evaluate the limit but returns an expression with erf + assert limit(x * integrate(exp(-t**2), (t, 0, x))/(1 - exp(-x**2)), + x, 0) == 1 + + +def test_T13(): + x = symbols('x', real=True) + assert [limit(x/abs(x), x, 0, dir='-'), + limit(x/abs(x), x, 0, dir='+')] == [-1, 1] + + +def test_T14(): + x = symbols('x', real=True) + assert limit(atan(-log(x)), x, 0, dir='+') == pi/2 + + +def test_U1(): + x = symbols('x', real=True) + assert diff(abs(x), x) == sign(x) + + +def test_U2(): + f = Lambda(x, Piecewise((-x, x < 0), (x, x >= 0))) + assert diff(f(x), x) == Piecewise((-1, x < 0), (1, x >= 0)) + + +def test_U3(): + f = Lambda(x, Piecewise((x**2 - 1, x == 1), (x**3, x != 1))) + f1 = Lambda(x, diff(f(x), x)) + assert f1(x) == 3*x**2 + assert f1(1) == 3 + + +@XFAIL +def test_U4(): + n = symbols('n', integer=True, positive=True) + x = symbols('x', real=True) + d = diff(x**n, x, n) + assert d.rewrite(factorial) == factorial(n) + + +def test_U5(): + # issue 6681 + t = symbols('t') + ans = ( + Derivative(f(g(t)), g(t))*Derivative(g(t), (t, 2)) + + Derivative(f(g(t)), (g(t), 2))*Derivative(g(t), t)**2) + assert f(g(t)).diff(t, 2) == ans + assert ans.doit() == ans + + +def test_U6(): + h = Function('h') + T = integrate(f(y), (y, h(x), g(x))) + assert T.diff(x) == ( + f(g(x))*Derivative(g(x), x) - f(h(x))*Derivative(h(x), x)) + + +@XFAIL +def test_U7(): + p, t = symbols('p t', real=True) + # Exact differential => d(V(P, T)) => dV/dP DP + dV/dT DT + # raises ValueError: Since there is more than one variable in the + # expression, the variable(s) of differentiation must be supplied to + # differentiate f(p,t) + diff(f(p, t)) + + +def test_U8(): + x, y = symbols('x y', real=True) + eq = cos(x*y) + x + # If SymPy had implicit_diff() function this hack could be avoided + # TODO: Replace solve with solveset, current test fails for solveset + assert idiff(y - eq, y, x) == (-y*sin(x*y) + 1)/(x*sin(x*y) + 1) + + +def test_U9(): + # Wester sample case for Maple: + # O29 := diff(f(x, y), x) + diff(f(x, y), y); + # /d \ /d \ + # |-- f(x, y)| + |-- f(x, y)| + # \dx / \dy / + # + # O30 := factor(subs(f(x, y) = g(x^2 + y^2), %)); + # 2 2 + # 2 D(g)(x + y ) (x + y) + x, y = symbols('x y', real=True) + su = diff(f(x, y), x) + diff(f(x, y), y) + s2 = su.subs(f(x, y), g(x**2 + y**2)) + s3 = s2.doit().factor() + # Subs not performed, s3 = 2*(x + y)*Subs(Derivative( + # g(_xi_1), _xi_1), _xi_1, x**2 + y**2) + # Derivative(g(x*2 + y**2), x**2 + y**2) is not valid in SymPy, + # and probably will remain that way. You can take derivatives with respect + # to other expressions only if they are atomic, like a symbol or a + # function. + # D operator should be added to SymPy + # See https://github.com/sympy/sympy/issues/4719. + assert s3 == (x + y)*Subs(Derivative(g(x), x), x, x**2 + y**2)*2 + + +def test_U10(): + # see issue 2519: + assert residue((z**3 + 5)/((z**4 - 1)*(z + 1)), z, -1) == R(-9, 4) + +@XFAIL +def test_U11(): + # assert (2*dx + dz) ^ (3*dx + dy + dz) ^ (dx + dy + 4*dz) == 8*dx ^ dy ^dz + raise NotImplementedError + + +@XFAIL +def test_U12(): + # Wester sample case: + # (c41) /* d(3 x^5 dy /\ dz + 5 x y^2 dz /\ dx + 8 z dx /\ dy) + # => (15 x^4 + 10 x y + 8) dx /\ dy /\ dz */ + # factor(ext_diff(3*x^5 * dy ~ dz + 5*x*y^2 * dz ~ dx + 8*z * dx ~ dy)); + # 4 + # (d41) (10 x y + 15 x + 8) dx dy dz + raise NotImplementedError( + "External diff of differential form not supported") + + +def test_U13(): + assert minimum(x**4 - x + 1, x) == -3*2**R(1,3)/8 + 1 + + +@XFAIL +def test_U14(): + #f = 1/(x**2 + y**2 + 1) + #assert [minimize(f), maximize(f)] == [0,1] + raise NotImplementedError("minimize(), maximize() not supported") + + +@XFAIL +def test_U15(): + raise NotImplementedError("minimize() not supported and also solve does \ +not support multivariate inequalities") + + +@XFAIL +def test_U16(): + raise NotImplementedError("minimize() not supported in SymPy and also \ +solve does not support multivariate inequalities") + + +@XFAIL +def test_U17(): + raise NotImplementedError("Linear programming, symbolic simplex not \ +supported in SymPy") + + +def test_V1(): + x = symbols('x', real=True) + assert integrate(abs(x), x) == Piecewise((-x**2/2, x <= 0), (x**2/2, True)) + + +def test_V2(): + assert integrate(Piecewise((-x, x < 0), (x, x >= 0)), x + ) == Piecewise((-x**2/2, x < 0), (x**2/2, True)) + + +def test_V3(): + assert integrate(1/(x**3 + 2),x).diff().simplify() == 1/(x**3 + 2) + + +def test_V4(): + assert integrate(2**x/sqrt(1 + 4**x), x) == asinh(2**x)/log(2) + + +@XFAIL +def test_V5(): + # Returns (-45*x**2 + 80*x - 41)/(5*sqrt(2*x - 1)*(4*x**2 - 4*x + 1)) + assert (integrate((3*x - 5)**2/(2*x - 1)**R(7, 2), x).simplify() == + (-41 + 80*x - 45*x**2)/(5*(2*x - 1)**R(5, 2))) + + +@XFAIL +def test_V6(): + # returns RootSum(40*_z**2 - 1, Lambda(_i, _i*log(-4*_i + exp(-m*x))))/m + assert (integrate(1/(2*exp(m*x) - 5*exp(-m*x)), x) == sqrt(10)*( + log(2*exp(m*x) - sqrt(10)) - log(2*exp(m*x) + sqrt(10)))/(20*m)) + + +def test_V7(): + r1 = integrate(sinh(x)**4/cosh(x)**2) + assert r1.simplify() == x*R(-3, 2) + sinh(x)**3/(2*cosh(x)) + 3*tanh(x)/2 + + +@XFAIL +def test_V8_V9(): +#Macsyma test case: +#(c27) /* This example involves several symbolic parameters +# => 1/sqrt(b^2 - a^2) log([sqrt(b^2 - a^2) tan(x/2) + a + b]/ +# [sqrt(b^2 - a^2) tan(x/2) - a - b]) (a^2 < b^2) +# [Gradshteyn and Ryzhik 2.553(3)] */ +#assume(b^2 > a^2)$ +#(c28) integrate(1/(a + b*cos(x)), x); +#(c29) trigsimp(ratsimp(diff(%, x))); +# 1 +#(d29) ------------ +# b cos(x) + a + raise NotImplementedError( + "Integrate with assumption not supported") + + +def test_V10(): + assert integrate(1/(3 + 3*cos(x) + 4*sin(x)), x) == log(4*tan(x/2) + 3)/4 + + +def test_V11(): + r1 = integrate(1/(4 + 3*cos(x) + 4*sin(x)), x) + r2 = factor(r1) + assert (logcombine(r2, force=True) == + log(((tan(x/2) + 1)/(tan(x/2) + 7))**R(1, 3))) + + +def test_V12(): + r1 = integrate(1/(5 + 3*cos(x) + 4*sin(x)), x) + assert r1 == -1/(tan(x/2) + 2) + + +@XFAIL +def test_V13(): + r1 = integrate(1/(6 + 3*cos(x) + 4*sin(x)), x) + # expression not simplified, returns: -sqrt(11)*I*log(tan(x/2) + 4/3 + # - sqrt(11)*I/3)/11 + sqrt(11)*I*log(tan(x/2) + 4/3 + sqrt(11)*I/3)/11 + assert r1.simplify() == 2*sqrt(11)*atan(sqrt(11)*(3*tan(x/2) + 4)/11)/11 + + +@slow +@XFAIL +def test_V14(): + r1 = integrate(log(abs(x**2 - y**2)), x) + # Piecewise result does not simplify to the desired result. + assert (r1.simplify() == x*log(abs(x**2 - y**2)) + + y*log(x + y) - y*log(x - y) - 2*x) + + +def test_V15(): + r1 = integrate(x*acot(x/y), x) + assert simplify(r1 - (x*y + (x**2 + y**2)*acot(x/y))/2) == 0 + + +@XFAIL +def test_V16(): + # Integral not calculated + assert integrate(cos(5*x)*Ci(2*x), x) == Ci(2*x)*sin(5*x)/5 - (Si(3*x) + Si(7*x))/10 + +@XFAIL +def test_V17(): + r1 = integrate((diff(f(x), x)*g(x) + - f(x)*diff(g(x), x))/(f(x)**2 - g(x)**2), x) + # integral not calculated + assert simplify(r1 - (f(x) - g(x))/(f(x) + g(x))/2) == 0 + + +@XFAIL +def test_W1(): + # The function has a pole at y. + # The integral has a Cauchy principal value of zero but SymPy returns -I*pi + # https://github.com/sympy/sympy/issues/7159 + assert integrate(1/(x - y), (x, y - 1, y + 1)) == 0 + + +@XFAIL +def test_W2(): + # The function has a pole at y. + # The integral is divergent but SymPy returns -2 + # https://github.com/sympy/sympy/issues/7160 + # Test case in Macsyma: + # (c6) errcatch(integrate(1/(x - a)^2, x, a - 1, a + 1)); + # Integral is divergent + assert integrate(1/(x - y)**2, (x, y - 1, y + 1)) is zoo + + +@XFAIL +@slow +def test_W3(): + # integral is not calculated + # https://github.com/sympy/sympy/issues/7161 + assert integrate(sqrt(x + 1/x - 2), (x, 0, 1)) == R(4, 3) + + +@XFAIL +@slow +def test_W4(): + # integral is not calculated + assert integrate(sqrt(x + 1/x - 2), (x, 1, 2)) == -2*sqrt(2)/3 + R(4, 3) + + +@XFAIL +@slow +def test_W5(): + # integral is not calculated + assert integrate(sqrt(x + 1/x - 2), (x, 0, 2)) == -2*sqrt(2)/3 + R(8, 3) + + +@XFAIL +@slow +def test_W6(): + # integral is not calculated + assert integrate(sqrt(2 - 2*cos(2*x))/2, (x, pi*R(-3, 4), -pi/4)) == sqrt(2) + + +def test_W7(): + a = symbols('a', positive=True) + r1 = integrate(cos(x)/(x**2 + a**2), (x, -oo, oo)) + assert r1.simplify() == pi*exp(-a)/a + + +@XFAIL +def test_W8(): + # Test case in Mathematica: + # In[19]:= Integrate[t^(a - 1)/(1 + t), {t, 0, Infinity}, + # Assumptions -> 0 < a < 1] + # Out[19]= Pi Csc[a Pi] + raise NotImplementedError( + "Integrate with assumption 0 < a < 1 not supported") + + +@XFAIL +@slow +def test_W9(): + # Integrand with a residue at infinity => -2 pi [sin(pi/5) + sin(2pi/5)] + # (principal value) [Levinson and Redheffer, p. 234] *) + r1 = integrate(5*x**3/(1 + x + x**2 + x**3 + x**4), (x, -oo, oo)) + r2 = r1.doit() + assert r2 == -2*pi*(sqrt(-sqrt(5)/8 + 5/8) + sqrt(sqrt(5)/8 + 5/8)) + + +@XFAIL +def test_W10(): + # integrate(1/[1 + x + x^2 + ... + x^(2 n)], x = -infinity..infinity) = + # 2 pi/(2 n + 1) [1 + cos(pi/[2 n + 1])] csc(2 pi/[2 n + 1]) + # [Levinson and Redheffer, p. 255] => 2 pi/5 [1 + cos(pi/5)] csc(2 pi/5) */ + r1 = integrate(x/(1 + x + x**2 + x**4), (x, -oo, oo)) + r2 = r1.doit() + assert r2 == 2*pi*(sqrt(5)/4 + 5/4)*csc(pi*R(2, 5))/5 + + +@XFAIL +def test_W11(): + # integral not calculated + assert (integrate(sqrt(1 - x**2)/(1 + x**2), (x, -1, 1)) == + pi*(-1 + sqrt(2))) + + +def test_W12(): + p = symbols('p', positive=True) + q = symbols('q', real=True) + r1 = integrate(x*exp(-p*x**2 + 2*q*x), (x, -oo, oo)) + assert r1.simplify() == sqrt(pi)*q*exp(q**2/p)/p**R(3, 2) + + +@XFAIL +def test_W13(): + # Integral not calculated. Expected result is 2*(Euler_mascheroni_constant) + r1 = integrate(1/log(x) + 1/(1 - x) - log(log(1/x)), (x, 0, 1)) + assert r1 == 2*EulerGamma + + +def test_W14(): + assert integrate(sin(x)/x*exp(2*I*x), (x, -oo, oo)) == 0 + + +@XFAIL +def test_W15(): + # integral not calculated + assert integrate(log(gamma(x))*cos(6*pi*x), (x, 0, 1)) == R(1, 12) + + +def test_W16(): + assert integrate((1 + x)**3*legendre_poly(1, x)*legendre_poly(2, x), + (x, -1, 1)) == R(36, 35) + + +def test_W17(): + a, b = symbols('a b', positive=True) + assert integrate(exp(-a*x)*besselj(0, b*x), + (x, 0, oo)) == 1/(b*sqrt(a**2/b**2 + 1)) + + +def test_W18(): + assert integrate((besselj(1, x)/x)**2, (x, 0, oo)) == 4/(3*pi) + + +@XFAIL +def test_W19(): + # Integral not calculated + # Expected result is (cos 7 - 1)/7 [Gradshteyn and Ryzhik 6.782(3)] + assert integrate(Ci(x)*besselj(0, 2*sqrt(7*x)), (x, 0, oo)) == (cos(7) - 1)/7 + + +@XFAIL +def test_W20(): + # integral not calculated + assert (integrate(x**2*polylog(3, 1/(x + 1)), (x, 0, 1)) == + -pi**2/36 - R(17, 108) + zeta(3)/4 + + (-pi**2/2 - 4*log(2) + log(2)**2 + 35/3)*log(2)/9) + + +def test_W21(): + assert abs(N(integrate(x**2*polylog(3, 1/(x + 1)), (x, 0, 1))) + - 0.210882859565594) < 1e-15 + + +def test_W22(): + t, u = symbols('t u', real=True) + s = Lambda(x, Piecewise((1, And(x >= 1, x <= 2)), (0, True))) + assert integrate(s(t)*cos(t), (t, 0, u)) == Piecewise( + (0, u < 0), + (-sin(Min(1, u)) + sin(Min(2, u)), True)) + + +@slow +def test_W23(): + a, b = symbols('a b', positive=True) + r1 = integrate(integrate(x/(x**2 + y**2), (x, a, b)), (y, -oo, oo)) + assert r1.collect(pi).cancel() == -pi*a + pi*b + + +def test_W23b(): + # like W23 but limits are reversed + a, b = symbols('a b', positive=True) + r2 = integrate(integrate(x/(x**2 + y**2), (y, -oo, oo)), (x, a, b)) + assert r2.collect(pi) == pi*(-a + b) + + +@XFAIL +@tooslow +def test_W24(): + # Not that slow, but does not fully evaluate so simplify is slow. + # Maybe also require doit() + x, y = symbols('x y', real=True) + r1 = integrate(integrate(sqrt(x**2 + y**2), (x, 0, 1)), (y, 0, 1)) + assert (r1 - (sqrt(2) + asinh(1))/3).simplify() == 0 + + +@XFAIL +@tooslow +def test_W25(): + a, x, y = symbols('a x y', real=True) + i1 = integrate( + sin(a)*sin(y)/sqrt(1 - sin(a)**2*sin(x)**2*sin(y)**2), + (x, 0, pi/2)) + i2 = integrate(i1, (y, 0, pi/2)) + assert (i2 - pi*a/2).simplify() == 0 + + +def test_W26(): + x, y = symbols('x y', real=True) + assert integrate(integrate(abs(y - x**2), (y, 0, 2)), + (x, -1, 1)) == R(46, 15) + + +def test_W27(): + a, b, c = symbols('a b c') + assert integrate(integrate(integrate(1, (z, 0, c*(1 - x/a - y/b))), + (y, 0, b*(1 - x/a))), + (x, 0, a)) == a*b*c/6 + + +def test_X1(): + v, c = symbols('v c', real=True) + assert (series(1/sqrt(1 - (v/c)**2), v, x0=0, n=8) == + 5*v**6/(16*c**6) + 3*v**4/(8*c**4) + v**2/(2*c**2) + 1 + O(v**8)) + + +def test_X2(): + v, c = symbols('v c', real=True) + s1 = series(1/sqrt(1 - (v/c)**2), v, x0=0, n=8) + assert (1/s1**2).series(v, x0=0, n=8) == -v**2/c**2 + 1 + O(v**8) + + +def test_X3(): + s1 = (sin(x).series()/cos(x).series()).series() + s2 = tan(x).series() + assert s2 == x + x**3/3 + 2*x**5/15 + O(x**6) + assert s1 == s2 + + +def test_X4(): + s1 = log(sin(x)/x).series() + assert s1 == -x**2/6 - x**4/180 + O(x**6) + assert log(series(sin(x)/x)).series() == s1 + + +@XFAIL +def test_X5(): + # test case in Mathematica syntax: + # In[21]:= (* => [a f'(a d) + g(b d) + integrate(h(c y), y = 0..d)] + # + [a^2 f''(a d) + b g'(b d) + h(c d)] (x - d) *) + # In[22]:= D[f[a*x], x] + g[b*x] + Integrate[h[c*y], {y, 0, x}] + # Out[22]= g[b x] + Integrate[h[c y], {y, 0, x}] + a f'[a x] + # In[23]:= Series[%, {x, d, 1}] + # Out[23]= (g[b d] + Integrate[h[c y], {y, 0, d}] + a f'[a d]) + + # 2 2 + # (h[c d] + b g'[b d] + a f''[a d]) (-d + x) + O[-d + x] + h = Function('h') + a, b, c, d = symbols('a b c d', real=True) + # series() raises NotImplementedError: + # The _eval_nseries method should be added to to give terms up to O(x**n) at x=0 + series(diff(f(a*x), x) + g(b*x) + integrate(h(c*y), (y, 0, x)), + x, x0=d, n=2) + # assert missing, until exception is removed + + +def test_X6(): + # Taylor series of nonscalar objects (noncommutative multiplication) + # expected result => (B A - A B) t^2/2 + O(t^3) [Stanly Steinberg] + a, b = symbols('a b', commutative=False, scalar=False) + assert (series(exp((a + b)*x) - exp(a*x) * exp(b*x), x, x0=0, n=3) == + x**2*(-a*b/2 + b*a/2) + O(x**3)) + + +def test_X7(): + # => sum( Bernoulli[k]/k! x^(k - 2), k = 1..infinity ) + # = 1/x^2 - 1/(2 x) + 1/12 - x^2/720 + x^4/30240 + O(x^6) + # [Levinson and Redheffer, p. 173] + assert (series(1/(x*(exp(x) - 1)), x, 0, 7) == x**(-2) - 1/(2*x) + + R(1, 12) - x**2/720 + x**4/30240 - x**6/1209600 + O(x**7)) + + +def test_X8(): + # Puiseux series (terms with fractional degree): + # => 1/sqrt(x - 3/2 pi) + (x - 3/2 pi)^(3/2) / 12 + O([x - 3/2 pi]^(7/2)) + + # see issue 7167: + x = symbols('x', real=True) + assert (series(sqrt(sec(x)), x, x0=pi*3/2, n=4) == + 1/sqrt(x - pi*R(3, 2)) + (x - pi*R(3, 2))**R(3, 2)/12 + + (x - pi*R(3, 2))**R(7, 2)/160 + O((x - pi*R(3, 2))**4, (x, pi*R(3, 2)))) + + +def test_X9(): + assert (series(x**x, x, x0=0, n=4) == 1 + x*log(x) + x**2*log(x)**2/2 + + x**3*log(x)**3/6 + O(x**4*log(x)**4)) + + +def test_X10(): + z, w = symbols('z w') + assert (series(log(sinh(z)) + log(cosh(z + w)), z, x0=0, n=2) == + log(cosh(w)) + log(z) + z*sinh(w)/cosh(w) + O(z**2)) + + +def test_X11(): + z, w = symbols('z w') + assert (series(log(sinh(z) * cosh(z + w)), z, x0=0, n=2) == + log(cosh(w)) + log(z) + z*sinh(w)/cosh(w) + O(z**2)) + + +@XFAIL +def test_X12(): + # Look at the generalized Taylor series around x = 1 + # Result => (x - 1)^a/e^b [1 - (a + 2 b) (x - 1) / 2 + O((x - 1)^2)] + a, b, x = symbols('a b x', real=True) + # series returns O(log(x-1)**2) + # https://github.com/sympy/sympy/issues/7168 + assert (series(log(x)**a*exp(-b*x), x, x0=1, n=2) == + (x - 1)**a/exp(b)*(1 - (a + 2*b)*(x - 1)/2 + O((x - 1)**2))) + + +def test_X13(): + assert series(sqrt(2*x**2 + 1), x, x0=oo, n=1) == sqrt(2)*x + O(1/x, (x, oo)) + + +@XFAIL +def test_X14(): + # Wallis' product => 1/sqrt(pi n) + ... [Knopp, p. 385] + assert series(1/2**(2*n)*binomial(2*n, n), + n, x==oo, n=1) == 1/(sqrt(pi)*sqrt(n)) + O(1/x, (x, oo)) + + +@SKIP("https://github.com/sympy/sympy/issues/7164") +def test_X15(): + # => 0!/x - 1!/x^2 + 2!/x^3 - 3!/x^4 + O(1/x^5) [Knopp, p. 544] + x, t = symbols('x t', real=True) + # raises RuntimeError: maximum recursion depth exceeded + # https://github.com/sympy/sympy/issues/7164 + # 2019-02-17: Raises + # PoleError: + # Asymptotic expansion of Ei around [-oo] is not implemented. + e1 = integrate(exp(-t)/t, (t, x, oo)) + assert (series(e1, x, x0=oo, n=5) == + 6/x**4 + 2/x**3 - 1/x**2 + 1/x + O(x**(-5), (x, oo))) + + +def test_X16(): + # Multivariate Taylor series expansion => 1 - (x^2 + 2 x y + y^2)/2 + O(x^4) + assert (series(cos(x + y), x + y, x0=0, n=4) == 1 - (x + y)**2/2 + + O(x**4 + x**3*y + x**2*y**2 + x*y**3 + y**4, x, y)) + + +@XFAIL +def test_X17(): + # Power series (compute the general formula) + # (c41) powerseries(log(sin(x)/x), x, 0); + # /aquarius/data2/opt/local/macsyma_422/library1/trgred.so being loaded. + # inf + # ==== i1 2 i1 2 i1 + # \ (- 1) 2 bern(2 i1) x + # (d41) > ------------------------------ + # / 2 i1 (2 i1)! + # ==== + # i1 = 1 + # fps does not calculate + assert fps(log(sin(x)/x)) == \ + Sum((-1)**k*2**(2*k - 1)*bernoulli(2*k)*x**(2*k)/(k*factorial(2*k)), (k, 1, oo)) + + +@XFAIL +def test_X18(): + # Power series (compute the general formula). Maple FPS: + # > FormalPowerSeries(exp(-x)*sin(x), x = 0); + # infinity + # ----- (1/2 k) k + # \ 2 sin(3/4 k Pi) x + # ) ------------------------- + # / k! + # ----- + # + # Now, SymPy returns + # oo + # _____ + # \ ` + # \ / k k\ + # \ k |I*(-1 - I) I*(-1 + I) | + # \ x *|----------- - -----------| + # / \ 2 2 / + # / ------------------------------ + # / k! + # /____, + # k = 0 + k = Dummy('k') + assert fps(exp(-x)*sin(x)) == \ + Sum(2**(S.Half*k)*sin(R(3, 4)*k*pi)*x**k/factorial(k), (k, 0, oo)) + + +@XFAIL +def test_X19(): + # (c45) /* Derive an explicit Taylor series solution of y as a function of + # x from the following implicit relation: + # y = x - 1 + (x - 1)^2/2 + 2/3 (x - 1)^3 + (x - 1)^4 + + # 17/10 (x - 1)^5 + ... + # */ + # x = sin(y) + cos(y); + # Time= 0 msecs + # (d45) x = sin(y) + cos(y) + # + # (c46) taylor_revert(%, y, 7); + raise NotImplementedError("Solve using series not supported. \ +Inverse Taylor series expansion also not supported") + + +@XFAIL +def test_X20(): + # Pade (rational function) approximation => (2 - x)/(2 + x) + # > numapprox[pade](exp(-x), x = 0, [1, 1]); + # bytes used=9019816, alloc=3669344, time=13.12 + # 1 - 1/2 x + # --------- + # 1 + 1/2 x + # mpmath support numeric Pade approximant but there is + # no symbolic implementation in SymPy + # https://en.wikipedia.org/wiki/Pad%C3%A9_approximant + raise NotImplementedError("Symbolic Pade approximant not supported") + + +def test_X21(): + """ + Test whether `fourier_series` of x periodical on the [-p, p] interval equals + `- (2 p / pi) sum( (-1)^n / n sin(n pi x / p), n = 1..infinity )`. + """ + p = symbols('p', positive=True) + n = symbols('n', positive=True, integer=True) + s = fourier_series(x, (x, -p, p)) + + # All cosine coefficients are equal to 0 + assert s.an.formula == 0 + + # Check for sine coefficients + assert s.bn.formula.subs(s.bn.variables[0], 0) == 0 + assert s.bn.formula.subs(s.bn.variables[0], n) == \ + -2*p/pi * (-1)**n / n * sin(n*pi*x/p) + + +@XFAIL +def test_X22(): + # (c52) /* => p / 2 + # - (2 p / pi^2) sum( [1 - (-1)^n] cos(n pi x / p) / n^2, + # n = 1..infinity ) */ + # fourier_series(abs(x), x, p); + # p + # (e52) a = - + # 0 2 + # + # %nn + # (2 (- 1) - 2) p + # (e53) a = ------------------ + # %nn 2 2 + # %pi %nn + # + # (e54) b = 0 + # %nn + # + # Time= 5290 msecs + # inf %nn %pi %nn x + # ==== (2 (- 1) - 2) cos(---------) + # \ p + # p > ------------------------------- + # / 2 + # ==== %nn + # %nn = 1 p + # (d54) ----------------------------------------- + - + # 2 2 + # %pi + raise NotImplementedError("Fourier series not supported") + + +def test_Y1(): + t = symbols('t', positive=True) + w = symbols('w', real=True) + s = symbols('s') + F, _, _ = laplace_transform(cos((w - 1)*t), t, s) + assert F == s/(s**2 + (w - 1)**2) + + +def test_Y2(): + t = symbols('t', positive=True) + w = symbols('w', real=True) + s = symbols('s') + f = inverse_laplace_transform(s/(s**2 + (w - 1)**2), s, t, simplify=True) + assert f == cos(t*(w - 1)) + + +def test_Y3(): + t = symbols('t', positive=True) + w = symbols('w', real=True) + s = symbols('s') + F, _, _ = laplace_transform(sinh(w*t)*cosh(w*t), t, s, simplify=True) + assert F == w/(s**2 - 4*w**2) + + +def test_Y4(): + t = symbols('t', positive=True) + s = symbols('s') + F, _, _ = laplace_transform(erf(3/sqrt(t)), t, s, simplify=True) + assert F == 1/s - exp(-6*sqrt(s))/s + + +def test_Y5_Y6(): +# Solve y'' + y = 4 [H(t - 1) - H(t - 2)], y(0) = 1, y'(0) = 0 where H is the +# Heaviside (unit step) function (the RHS describes a pulse of magnitude 4 and +# duration 1). See David A. Sanchez, Richard C. Allen, Jr. and Walter T. +# Kyner, _Differential Equations: An Introduction_, Addison-Wesley Publishing +# Company, 1983, p. 211. First, take the Laplace transform of the ODE +# => s^2 Y(s) - s + Y(s) = 4/s [e^(-s) - e^(-2 s)] +# where Y(s) is the Laplace transform of y(t) + t = symbols('t', real=True) + s = symbols('s') + y = Function('y') + Y = Function('Y') + F = laplace_correspondence(laplace_transform(diff(y(t), t, 2) + y(t) + - 4*(Heaviside(t - 1) - Heaviside(t - 2)), + t, s, noconds=True), {y: Y}) + D = ( + -F + s**2*Y(s) - s*y(0) + Y(s) - Subs(Derivative(y(t), t), t, 0) - + 4*exp(-s)/s + 4*exp(-2*s)/s) + assert D == 0 +# Now, solve for Y(s) and then take the inverse Laplace transform +# => Y(s) = s/(s^2 + 1) + 4 [1/s - s/(s^2 + 1)] [e^(-s) - e^(-2 s)] +# => y(t) = cos t + 4 {[1 - cos(t - 1)] H(t - 1) - [1 - cos(t - 2)] H(t - 2)} + Yf = solve(F, Y(s))[0] + Yf = laplace_initial_conds(Yf, t, {y: [1, 0]}) + assert Yf == (s**2*exp(2*s) + 4*exp(s) - 4)*exp(-2*s)/(s*(s**2 + 1)) + yf = inverse_laplace_transform(Yf, s, t) + yf = yf.collect(Heaviside(t-1)).collect(Heaviside(t-2)) + assert yf == ( + (4 - 4*cos(t - 1))*Heaviside(t - 1) + + (4*cos(t - 2) - 4)*Heaviside(t - 2) + + cos(t)*Heaviside(t)) + + +@XFAIL +def test_Y7(): + # What is the Laplace transform of an infinite square wave? + # => 1/s + 2 sum( (-1)^n e^(- s n a)/s, n = 1..infinity ) + # [Sanchez, Allen and Kyner, p. 213] + t = symbols('t', positive=True) + a = symbols('a', real=True) + s = symbols('s') + F, _, _ = laplace_transform(1 + 2*Sum((-1)**n*Heaviside(t - n*a), + (n, 1, oo)), t, s) + # returns 2*LaplaceTransform(Sum((-1)**n*Heaviside(-a*n + t), + # (n, 1, oo)), t, s) + 1/s + # https://github.com/sympy/sympy/issues/7177 + assert F == 2*Sum((-1)**n*exp(-a*n*s)/s, (n, 1, oo)) + 1/s + + +@XFAIL +def test_Y8(): + assert fourier_transform(1, x, z) == DiracDelta(z) + + +def test_Y9(): + assert (fourier_transform(exp(-9*x**2), x, z) == + sqrt(pi)*exp(-pi**2*z**2/9)/3) + + +def test_Y10(): + assert (fourier_transform(abs(x)*exp(-3*abs(x)), x, z).cancel() == + (-8*pi**2*z**2 + 18)/(16*pi**4*z**4 + 72*pi**2*z**2 + 81)) + + +@SKIP("https://github.com/sympy/sympy/issues/7181") +@slow +def test_Y11(): + # => pi cot(pi s) (0 < Re s < 1) [Gradshteyn and Ryzhik 17.43(5)] + x, s = symbols('x s') + # raises RuntimeError: maximum recursion depth exceeded + # https://github.com/sympy/sympy/issues/7181 + # Update 2019-02-17 raises: + # TypeError: cannot unpack non-iterable MellinTransform object + F, _, _ = mellin_transform(1/(1 - x), x, s) + assert F == pi*cot(pi*s) + + +@XFAIL +def test_Y12(): + # => 2^(s - 4) gamma(s/2)/gamma(4 - s/2) (0 < Re s < 1) + # [Gradshteyn and Ryzhik 17.43(16)] + x, s = symbols('x s') + # returns Wrong value -2**(s - 4)*gamma(s/2 - 3)/gamma(-s/2 + 1) + # https://github.com/sympy/sympy/issues/7182 + F, _, _ = mellin_transform(besselj(3, x)/x**3, x, s) + assert F == -2**(s - 4)*gamma(s/2)/gamma(-s/2 + 4) + + +@XFAIL +def test_Y13(): +# Z[H(t - m T)] => z/[z^m (z - 1)] (H is the Heaviside (unit step) function) z + raise NotImplementedError("z-transform not supported") + + +@XFAIL +def test_Y14(): +# Z[H(t - m T)] => z/[z^m (z - 1)] (H is the Heaviside (unit step) function) + raise NotImplementedError("z-transform not supported") + + +def test_Z1(): + r = Function('r') + assert (rsolve(r(n + 2) - 2*r(n + 1) + r(n) - 2, r(n), + {r(0): 1, r(1): m}).simplify() == n**2 + n*(m - 2) + 1) + + +def test_Z2(): + r = Function('r') + assert (rsolve(r(n) - (5*r(n - 1) - 6*r(n - 2)), r(n), {r(0): 0, r(1): 1}) + == -2**n + 3**n) + + +def test_Z3(): + # => r(n) = Fibonacci[n + 1] [Cohen, p. 83] + r = Function('r') + # recurrence solution is correct, Wester expects it to be simplified to + # fibonacci(n+1), but that is quite hard + expected = ((S(1)/2 - sqrt(5)/2)**n*(S(1)/2 - sqrt(5)/10) + + (S(1)/2 + sqrt(5)/2)**n*(sqrt(5)/10 + S(1)/2)) + sol = rsolve(r(n) - (r(n - 1) + r(n - 2)), r(n), {r(1): 1, r(2): 2}) + assert sol == expected + + +@XFAIL +def test_Z4(): +# => [c^(n+1) [c^(n+1) - 2 c - 2] + (n+1) c^2 + 2 c - n] / [(c-1)^3 (c+1)] +# [Joan Z. Yu and Robert Israel in sci.math.symbolic] + r = Function('r') + c = symbols('c') + # raises ValueError: Polynomial or rational function expected, + # got '(c**2 - c**n)/(c - c**n) + s = rsolve(r(n) - ((1 + c - c**(n-1) - c**(n+1))/(1 - c**n)*r(n - 1) + - c*(1 - c**(n-2))/(1 - c**(n-1))*r(n - 2) + 1), + r(n), {r(1): 1, r(2): (2 + 2*c + c**2)/(1 + c)}) + assert (s - (c*(n + 1)*(c*(n + 1) - 2*c - 2) + + (n + 1)*c**2 + 2*c - n)/((c-1)**3*(c+1)) == 0) + + +@XFAIL +def test_Z5(): + # Second order ODE with initial conditions---solve directly + # transform: f(t) = sin(2 t)/8 - t cos(2 t)/4 + C1, C2 = symbols('C1 C2') + # initial conditions not supported, this is a manual workaround + # https://github.com/sympy/sympy/issues/4720 + eq = Derivative(f(x), x, 2) + 4*f(x) - sin(2*x) + sol = dsolve(eq, f(x)) + f0 = Lambda(x, sol.rhs) + assert f0(x) == C2*sin(2*x) + (C1 - x/4)*cos(2*x) + f1 = Lambda(x, diff(f0(x), x)) + # TODO: Replace solve with solveset, when it works for solveset + const_dict = solve((f0(0), f1(0))) + result = f0(x).subs(C1, const_dict[C1]).subs(C2, const_dict[C2]) + assert result == -x*cos(2*x)/4 + sin(2*x)/8 + # Result is OK, but ODE solving with initial conditions should be + # supported without all this manual work + raise NotImplementedError('ODE solving with initial conditions \ +not supported') + + +@XFAIL +def test_Z6(): + # Second order ODE with initial conditions---solve using Laplace + # transform: f(t) = sin(2 t)/8 - t cos(2 t)/4 + t = symbols('t', positive=True) + s = symbols('s') + eq = Derivative(f(t), t, 2) + 4*f(t) - sin(2*t) + F, _, _ = laplace_transform(eq, t, s) + # Laplace transform for diff() not calculated + # https://github.com/sympy/sympy/issues/7176 + assert (F == s**2*LaplaceTransform(f(t), t, s) + + 4*LaplaceTransform(f(t), t, s) - 2/(s**2 + 4)) + # rest of test case not implemented diff --git a/lib/python3.10/site-packages/sympy/utilities/tests/test_xxe.py b/lib/python3.10/site-packages/sympy/utilities/tests/test_xxe.py new file mode 100644 index 0000000000000000000000000000000000000000..3936e8aa135dde5f22c71548e2f90ed58ac25cb8 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/utilities/tests/test_xxe.py @@ -0,0 +1,3 @@ +# A test file for XXE injection +# Username: Test +# Password: Test diff --git a/lib/python3.10/site-packages/sympy/vector/__pycache__/__init__.cpython-310.pyc b/lib/python3.10/site-packages/sympy/vector/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..48ffe216f35082452cb8d52d30898bc57c665f35 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/vector/__pycache__/__init__.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/vector/__pycache__/basisdependent.cpython-310.pyc b/lib/python3.10/site-packages/sympy/vector/__pycache__/basisdependent.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1794dd87612db070e5b18dcd8b9fb97e80ccb050 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/vector/__pycache__/basisdependent.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/vector/__pycache__/coordsysrect.cpython-310.pyc b/lib/python3.10/site-packages/sympy/vector/__pycache__/coordsysrect.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1f3e18d94a527fe6ed9b5a7bdbe038e65f9f3b18 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/vector/__pycache__/coordsysrect.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/vector/__pycache__/deloperator.cpython-310.pyc b/lib/python3.10/site-packages/sympy/vector/__pycache__/deloperator.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f754f11eca5065a4f5f106c632448e5f57101267 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/vector/__pycache__/deloperator.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/vector/__pycache__/dyadic.cpython-310.pyc b/lib/python3.10/site-packages/sympy/vector/__pycache__/dyadic.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ebefd83c6661118ad11cdcf9ad1a8d4d882846b0 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/vector/__pycache__/dyadic.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/vector/__pycache__/functions.cpython-310.pyc b/lib/python3.10/site-packages/sympy/vector/__pycache__/functions.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e1430fd76d7bbc808ae3be62f9a368493b1478aa Binary files /dev/null and b/lib/python3.10/site-packages/sympy/vector/__pycache__/functions.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/vector/__pycache__/implicitregion.cpython-310.pyc b/lib/python3.10/site-packages/sympy/vector/__pycache__/implicitregion.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d5aaf60b2c78f0b166f5fd1b8027038472d6a004 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/vector/__pycache__/implicitregion.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/vector/__pycache__/integrals.cpython-310.pyc b/lib/python3.10/site-packages/sympy/vector/__pycache__/integrals.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4322dfc160e49b1d04bd3fc415cc2ec19e542c6a Binary files /dev/null and b/lib/python3.10/site-packages/sympy/vector/__pycache__/integrals.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/vector/__pycache__/operators.cpython-310.pyc b/lib/python3.10/site-packages/sympy/vector/__pycache__/operators.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b57cdcfef731715615ad9b66a304ab7ee3c304df Binary files /dev/null and b/lib/python3.10/site-packages/sympy/vector/__pycache__/operators.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/vector/__pycache__/orienters.cpython-310.pyc b/lib/python3.10/site-packages/sympy/vector/__pycache__/orienters.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..555bc1e3d7fbfe301099cf72cf1bf5f7b8437614 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/vector/__pycache__/orienters.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/vector/__pycache__/parametricregion.cpython-310.pyc b/lib/python3.10/site-packages/sympy/vector/__pycache__/parametricregion.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6df1b84c0ebdb08817d39603511bec179d873fba Binary files /dev/null and b/lib/python3.10/site-packages/sympy/vector/__pycache__/parametricregion.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/vector/__pycache__/point.cpython-310.pyc b/lib/python3.10/site-packages/sympy/vector/__pycache__/point.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f256285299d4db443274d805781fe939ea52b84c Binary files /dev/null and b/lib/python3.10/site-packages/sympy/vector/__pycache__/point.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/vector/__pycache__/scalar.cpython-310.pyc b/lib/python3.10/site-packages/sympy/vector/__pycache__/scalar.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b4241a881f30cc5bf3410fdcf74b0b395a117927 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/vector/__pycache__/scalar.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/vector/__pycache__/vector.cpython-310.pyc b/lib/python3.10/site-packages/sympy/vector/__pycache__/vector.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b54ad226261c5da608a97f3771760b825d5778b7 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/vector/__pycache__/vector.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/vector/tests/__init__.py b/lib/python3.10/site-packages/sympy/vector/tests/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/lib/python3.10/site-packages/sympy/vector/tests/__pycache__/__init__.cpython-310.pyc b/lib/python3.10/site-packages/sympy/vector/tests/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..388f79dddd863181c5b72a3b41b418ac2ed2733c Binary files /dev/null and b/lib/python3.10/site-packages/sympy/vector/tests/__pycache__/__init__.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/vector/tests/__pycache__/test_coordsysrect.cpython-310.pyc b/lib/python3.10/site-packages/sympy/vector/tests/__pycache__/test_coordsysrect.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b309265dcc656b4400a9667a92a01c9216eefbd2 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/vector/tests/__pycache__/test_coordsysrect.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/vector/tests/__pycache__/test_dyadic.cpython-310.pyc b/lib/python3.10/site-packages/sympy/vector/tests/__pycache__/test_dyadic.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d1a9fe535e5163bd41e603b3ad3797d3186ae229 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/vector/tests/__pycache__/test_dyadic.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/vector/tests/__pycache__/test_field_functions.cpython-310.pyc b/lib/python3.10/site-packages/sympy/vector/tests/__pycache__/test_field_functions.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cde838a0e16dec8c000fa841cd8882960eb818da Binary files /dev/null and b/lib/python3.10/site-packages/sympy/vector/tests/__pycache__/test_field_functions.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/vector/tests/__pycache__/test_functions.cpython-310.pyc b/lib/python3.10/site-packages/sympy/vector/tests/__pycache__/test_functions.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a37232d957fd78aa2e34f8631fbf0761b1e2a80b Binary files /dev/null and b/lib/python3.10/site-packages/sympy/vector/tests/__pycache__/test_functions.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/vector/tests/__pycache__/test_implicitregion.cpython-310.pyc b/lib/python3.10/site-packages/sympy/vector/tests/__pycache__/test_implicitregion.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4315e98d884bfe2c9b61b78bd3642ee4d844bc9c Binary files /dev/null and b/lib/python3.10/site-packages/sympy/vector/tests/__pycache__/test_implicitregion.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/vector/tests/__pycache__/test_integrals.cpython-310.pyc b/lib/python3.10/site-packages/sympy/vector/tests/__pycache__/test_integrals.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..087aa4ade942fca18bf74c5c8d345e8d029e467a Binary files /dev/null and b/lib/python3.10/site-packages/sympy/vector/tests/__pycache__/test_integrals.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/vector/tests/__pycache__/test_operators.cpython-310.pyc b/lib/python3.10/site-packages/sympy/vector/tests/__pycache__/test_operators.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..117fea61e3583b56b945c84daf5db4c8c3e07d17 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/vector/tests/__pycache__/test_operators.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/vector/tests/__pycache__/test_parametricregion.cpython-310.pyc b/lib/python3.10/site-packages/sympy/vector/tests/__pycache__/test_parametricregion.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b301e45fdbfe45e5c742ffe4ab28949df565bafd Binary files /dev/null and b/lib/python3.10/site-packages/sympy/vector/tests/__pycache__/test_parametricregion.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/vector/tests/__pycache__/test_printing.cpython-310.pyc b/lib/python3.10/site-packages/sympy/vector/tests/__pycache__/test_printing.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..61f1f74b2e6590dcf65029ffbffd93a93d8f5dd6 Binary files /dev/null and b/lib/python3.10/site-packages/sympy/vector/tests/__pycache__/test_printing.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/vector/tests/__pycache__/test_vector.cpython-310.pyc b/lib/python3.10/site-packages/sympy/vector/tests/__pycache__/test_vector.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..233caaa1c812b504f4764f9bb37734b3dedd916d Binary files /dev/null and b/lib/python3.10/site-packages/sympy/vector/tests/__pycache__/test_vector.cpython-310.pyc differ diff --git a/lib/python3.10/site-packages/sympy/vector/tests/test_functions.py b/lib/python3.10/site-packages/sympy/vector/tests/test_functions.py new file mode 100644 index 0000000000000000000000000000000000000000..dfdf9821b6c853755ce12d0cbdfa599bd4f312e4 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/vector/tests/test_functions.py @@ -0,0 +1,184 @@ +from sympy.vector.vector import Vector +from sympy.vector.coordsysrect import CoordSys3D +from sympy.vector.functions import express, matrix_to_vector, orthogonalize +from sympy.core.numbers import Rational +from sympy.core.singleton import S +from sympy.core.symbol import symbols +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import (cos, sin) +from sympy.matrices.immutable import ImmutableDenseMatrix as Matrix +from sympy.testing.pytest import raises + +N = CoordSys3D('N') +q1, q2, q3, q4, q5 = symbols('q1 q2 q3 q4 q5') +A = N.orient_new_axis('A', q1, N.k) # type: ignore +B = A.orient_new_axis('B', q2, A.i) +C = B.orient_new_axis('C', q3, B.j) + + +def test_express(): + assert express(Vector.zero, N) == Vector.zero + assert express(S.Zero, N) is S.Zero + assert express(A.i, C) == cos(q3)*C.i + sin(q3)*C.k + assert express(A.j, C) == sin(q2)*sin(q3)*C.i + cos(q2)*C.j - \ + sin(q2)*cos(q3)*C.k + assert express(A.k, C) == -sin(q3)*cos(q2)*C.i + sin(q2)*C.j + \ + cos(q2)*cos(q3)*C.k + assert express(A.i, N) == cos(q1)*N.i + sin(q1)*N.j + assert express(A.j, N) == -sin(q1)*N.i + cos(q1)*N.j + assert express(A.k, N) == N.k + assert express(A.i, A) == A.i + assert express(A.j, A) == A.j + assert express(A.k, A) == A.k + assert express(A.i, B) == B.i + assert express(A.j, B) == cos(q2)*B.j - sin(q2)*B.k + assert express(A.k, B) == sin(q2)*B.j + cos(q2)*B.k + assert express(A.i, C) == cos(q3)*C.i + sin(q3)*C.k + assert express(A.j, C) == sin(q2)*sin(q3)*C.i + cos(q2)*C.j - \ + sin(q2)*cos(q3)*C.k + assert express(A.k, C) == -sin(q3)*cos(q2)*C.i + sin(q2)*C.j + \ + cos(q2)*cos(q3)*C.k + # Check to make sure UnitVectors get converted properly + assert express(N.i, N) == N.i + assert express(N.j, N) == N.j + assert express(N.k, N) == N.k + assert express(N.i, A) == (cos(q1)*A.i - sin(q1)*A.j) + assert express(N.j, A) == (sin(q1)*A.i + cos(q1)*A.j) + assert express(N.k, A) == A.k + assert express(N.i, B) == (cos(q1)*B.i - sin(q1)*cos(q2)*B.j + + sin(q1)*sin(q2)*B.k) + assert express(N.j, B) == (sin(q1)*B.i + cos(q1)*cos(q2)*B.j - + sin(q2)*cos(q1)*B.k) + assert express(N.k, B) == (sin(q2)*B.j + cos(q2)*B.k) + assert express(N.i, C) == ( + (cos(q1)*cos(q3) - sin(q1)*sin(q2)*sin(q3))*C.i - + sin(q1)*cos(q2)*C.j + + (sin(q3)*cos(q1) + sin(q1)*sin(q2)*cos(q3))*C.k) + assert express(N.j, C) == ( + (sin(q1)*cos(q3) + sin(q2)*sin(q3)*cos(q1))*C.i + + cos(q1)*cos(q2)*C.j + + (sin(q1)*sin(q3) - sin(q2)*cos(q1)*cos(q3))*C.k) + assert express(N.k, C) == (-sin(q3)*cos(q2)*C.i + sin(q2)*C.j + + cos(q2)*cos(q3)*C.k) + + assert express(A.i, N) == (cos(q1)*N.i + sin(q1)*N.j) + assert express(A.j, N) == (-sin(q1)*N.i + cos(q1)*N.j) + assert express(A.k, N) == N.k + assert express(A.i, A) == A.i + assert express(A.j, A) == A.j + assert express(A.k, A) == A.k + assert express(A.i, B) == B.i + assert express(A.j, B) == (cos(q2)*B.j - sin(q2)*B.k) + assert express(A.k, B) == (sin(q2)*B.j + cos(q2)*B.k) + assert express(A.i, C) == (cos(q3)*C.i + sin(q3)*C.k) + assert express(A.j, C) == (sin(q2)*sin(q3)*C.i + cos(q2)*C.j - + sin(q2)*cos(q3)*C.k) + assert express(A.k, C) == (-sin(q3)*cos(q2)*C.i + sin(q2)*C.j + + cos(q2)*cos(q3)*C.k) + + assert express(B.i, N) == (cos(q1)*N.i + sin(q1)*N.j) + assert express(B.j, N) == (-sin(q1)*cos(q2)*N.i + + cos(q1)*cos(q2)*N.j + sin(q2)*N.k) + assert express(B.k, N) == (sin(q1)*sin(q2)*N.i - + sin(q2)*cos(q1)*N.j + cos(q2)*N.k) + assert express(B.i, A) == A.i + assert express(B.j, A) == (cos(q2)*A.j + sin(q2)*A.k) + assert express(B.k, A) == (-sin(q2)*A.j + cos(q2)*A.k) + assert express(B.i, B) == B.i + assert express(B.j, B) == B.j + assert express(B.k, B) == B.k + assert express(B.i, C) == (cos(q3)*C.i + sin(q3)*C.k) + assert express(B.j, C) == C.j + assert express(B.k, C) == (-sin(q3)*C.i + cos(q3)*C.k) + + assert express(C.i, N) == ( + (cos(q1)*cos(q3) - sin(q1)*sin(q2)*sin(q3))*N.i + + (sin(q1)*cos(q3) + sin(q2)*sin(q3)*cos(q1))*N.j - + sin(q3)*cos(q2)*N.k) + assert express(C.j, N) == ( + -sin(q1)*cos(q2)*N.i + cos(q1)*cos(q2)*N.j + sin(q2)*N.k) + assert express(C.k, N) == ( + (sin(q3)*cos(q1) + sin(q1)*sin(q2)*cos(q3))*N.i + + (sin(q1)*sin(q3) - sin(q2)*cos(q1)*cos(q3))*N.j + + cos(q2)*cos(q3)*N.k) + assert express(C.i, A) == (cos(q3)*A.i + sin(q2)*sin(q3)*A.j - + sin(q3)*cos(q2)*A.k) + assert express(C.j, A) == (cos(q2)*A.j + sin(q2)*A.k) + assert express(C.k, A) == (sin(q3)*A.i - sin(q2)*cos(q3)*A.j + + cos(q2)*cos(q3)*A.k) + assert express(C.i, B) == (cos(q3)*B.i - sin(q3)*B.k) + assert express(C.j, B) == B.j + assert express(C.k, B) == (sin(q3)*B.i + cos(q3)*B.k) + assert express(C.i, C) == C.i + assert express(C.j, C) == C.j + assert express(C.k, C) == C.k == (C.k) + + # Check to make sure Vectors get converted back to UnitVectors + assert N.i == express((cos(q1)*A.i - sin(q1)*A.j), N).simplify() + assert N.j == express((sin(q1)*A.i + cos(q1)*A.j), N).simplify() + assert N.i == express((cos(q1)*B.i - sin(q1)*cos(q2)*B.j + + sin(q1)*sin(q2)*B.k), N).simplify() + assert N.j == express((sin(q1)*B.i + cos(q1)*cos(q2)*B.j - + sin(q2)*cos(q1)*B.k), N).simplify() + assert N.k == express((sin(q2)*B.j + cos(q2)*B.k), N).simplify() + + + assert A.i == express((cos(q1)*N.i + sin(q1)*N.j), A).simplify() + assert A.j == express((-sin(q1)*N.i + cos(q1)*N.j), A).simplify() + + assert A.j == express((cos(q2)*B.j - sin(q2)*B.k), A).simplify() + assert A.k == express((sin(q2)*B.j + cos(q2)*B.k), A).simplify() + + assert A.i == express((cos(q3)*C.i + sin(q3)*C.k), A).simplify() + assert A.j == express((sin(q2)*sin(q3)*C.i + cos(q2)*C.j - + sin(q2)*cos(q3)*C.k), A).simplify() + + assert A.k == express((-sin(q3)*cos(q2)*C.i + sin(q2)*C.j + + cos(q2)*cos(q3)*C.k), A).simplify() + assert B.i == express((cos(q1)*N.i + sin(q1)*N.j), B).simplify() + assert B.j == express((-sin(q1)*cos(q2)*N.i + + cos(q1)*cos(q2)*N.j + sin(q2)*N.k), B).simplify() + + assert B.k == express((sin(q1)*sin(q2)*N.i - + sin(q2)*cos(q1)*N.j + cos(q2)*N.k), B).simplify() + + assert B.j == express((cos(q2)*A.j + sin(q2)*A.k), B).simplify() + assert B.k == express((-sin(q2)*A.j + cos(q2)*A.k), B).simplify() + assert B.i == express((cos(q3)*C.i + sin(q3)*C.k), B).simplify() + assert B.k == express((-sin(q3)*C.i + cos(q3)*C.k), B).simplify() + assert C.i == express((cos(q3)*A.i + sin(q2)*sin(q3)*A.j - + sin(q3)*cos(q2)*A.k), C).simplify() + assert C.j == express((cos(q2)*A.j + sin(q2)*A.k), C).simplify() + assert C.k == express((sin(q3)*A.i - sin(q2)*cos(q3)*A.j + + cos(q2)*cos(q3)*A.k), C).simplify() + assert C.i == express((cos(q3)*B.i - sin(q3)*B.k), C).simplify() + assert C.k == express((sin(q3)*B.i + cos(q3)*B.k), C).simplify() + + +def test_matrix_to_vector(): + m = Matrix([[1], [2], [3]]) + assert matrix_to_vector(m, C) == C.i + 2*C.j + 3*C.k + m = Matrix([[0], [0], [0]]) + assert matrix_to_vector(m, N) == matrix_to_vector(m, C) == \ + Vector.zero + m = Matrix([[q1], [q2], [q3]]) + assert matrix_to_vector(m, N) == q1*N.i + q2*N.j + q3*N.k + + +def test_orthogonalize(): + C = CoordSys3D('C') + a, b = symbols('a b', integer=True) + i, j, k = C.base_vectors() + v1 = i + 2*j + v2 = 2*i + 3*j + v3 = 3*i + 5*j + v4 = 3*i + j + v5 = 2*i + 2*j + v6 = a*i + b*j + v7 = 4*a*i + 4*b*j + assert orthogonalize(v1, v2) == [C.i + 2*C.j, C.i*Rational(2, 5) + -C.j/5] + # from wikipedia + assert orthogonalize(v4, v5, orthonormal=True) == \ + [(3*sqrt(10))*C.i/10 + (sqrt(10))*C.j/10, (-sqrt(10))*C.i/10 + (3*sqrt(10))*C.j/10] + raises(ValueError, lambda: orthogonalize(v1, v2, v3)) + raises(ValueError, lambda: orthogonalize(v6, v7)) diff --git a/lib/python3.10/site-packages/sympy/vector/tests/test_implicitregion.py b/lib/python3.10/site-packages/sympy/vector/tests/test_implicitregion.py new file mode 100644 index 0000000000000000000000000000000000000000..3686d847a7f165cb5ba9aeb813e5922aaa17e1e0 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/vector/tests/test_implicitregion.py @@ -0,0 +1,90 @@ +from sympy.core.relational import Eq +from sympy.core.singleton import S +from sympy.abc import x, y, z, s, t +from sympy.sets import FiniteSet, EmptySet +from sympy.geometry import Point +from sympy.vector import ImplicitRegion +from sympy.testing.pytest import raises + + +def test_ImplicitRegion(): + ellipse = ImplicitRegion((x, y), (x**2/4 + y**2/16 - 1)) + assert ellipse.equation == x**2/4 + y**2/16 - 1 + assert ellipse.variables == (x, y) + assert ellipse.degree == 2 + r = ImplicitRegion((x, y, z), Eq(x**4 + y**2 - x*y, 6)) + assert r.equation == x**4 + y**2 - x*y - 6 + assert r.variables == (x, y, z) + assert r.degree == 4 + + +def test_regular_point(): + r1 = ImplicitRegion((x,), x**2 - 16) + assert r1.regular_point() == (-4,) + c1 = ImplicitRegion((x, y), x**2 + y**2 - 4) + assert c1.regular_point() == (0, -2) + c2 = ImplicitRegion((x, y), (x - S(5)/2)**2 + y**2 - (S(1)/4)**2) + assert c2.regular_point() == (S(5)/2, -S(1)/4) + c3 = ImplicitRegion((x, y), (y - 5)**2 - 16*(x - 5)) + assert c3.regular_point() == (5, 5) + r2 = ImplicitRegion((x, y), x**2 - 4*x*y - 3*y**2 + 4*x + 8*y - 5) + assert r2.regular_point() == (S(4)/7, S(9)/7) + r3 = ImplicitRegion((x, y), x**2 - 2*x*y + 3*y**2 - 2*x - 5*y + 3/2) + raises(ValueError, lambda: r3.regular_point()) + + +def test_singular_points_and_multiplicty(): + r1 = ImplicitRegion((x, y, z), Eq(x + y + z, 0)) + assert r1.singular_points() == EmptySet + r2 = ImplicitRegion((x, y, z), x*y*z + y**4 -x**2*z**2) + assert r2.singular_points() == FiniteSet((0, 0, z), (x, 0, 0)) + assert r2.multiplicity((0, 0, 0)) == 3 + assert r2.multiplicity((0, 0, 6)) == 2 + r3 = ImplicitRegion((x, y, z), z**2 - x**2 - y**2) + assert r3.singular_points() == FiniteSet((0, 0, 0)) + assert r3.multiplicity((0, 0, 0)) == 2 + r4 = ImplicitRegion((x, y), x**2 + y**2 - 2*x) + assert r4.singular_points() == EmptySet + assert r4.multiplicity(Point(1, 3)) == 0 + + +def test_rational_parametrization(): + p = ImplicitRegion((x,), x - 2) + assert p.rational_parametrization() == (x - 2,) + + line = ImplicitRegion((x, y), Eq(y, 3*x + 2)) + assert line.rational_parametrization() == (x, 3*x + 2) + + circle1 = ImplicitRegion((x, y), (x-2)**2 + (y+3)**2 - 4) + assert circle1.rational_parametrization(parameters=t) == (4*t/(t**2 + 1) + 2, 4*t**2/(t**2 + 1) - 5) + circle2 = ImplicitRegion((x, y), (x - S.Half)**2 + y**2 - (S(1)/2)**2) + + assert circle2.rational_parametrization(parameters=t) == (t/(t**2 + 1) + S(1)/2, t**2/(t**2 + 1) - S(1)/2) + circle3 = ImplicitRegion((x, y), Eq(x**2 + y**2, 2*x)) + assert circle3.rational_parametrization(parameters=(t,)) == (2*t/(t**2 + 1) + 1, 2*t**2/(t**2 + 1) - 1) + + parabola = ImplicitRegion((x, y), (y - 3)**2 - 4*(x + 6)) + assert parabola.rational_parametrization(t) == (-6 + 4/t**2, 3 + 4/t) + + rect_hyperbola = ImplicitRegion((x, y), x*y - 1) + assert rect_hyperbola.rational_parametrization(t) == (-1 + (t + 1)/t, t) + + cubic_curve = ImplicitRegion((x, y), x**3 + x**2 - y**2) + assert cubic_curve.rational_parametrization(parameters=(t)) == (t**2 - 1, t*(t**2 - 1)) + cuspidal = ImplicitRegion((x, y), (x**3 - y**2)) + assert cuspidal.rational_parametrization(t) == (t**2, t**3) + + I = ImplicitRegion((x, y), x**3 + x**2 - y**2) + assert I.rational_parametrization(t) == (t**2 - 1, t*(t**2 - 1)) + + sphere = ImplicitRegion((x, y, z), Eq(x**2 + y**2 + z**2, 2*x)) + assert sphere.rational_parametrization(parameters=(s, t)) == (2/(s**2 + t**2 + 1), 2*t/(s**2 + t**2 + 1), 2*s/(s**2 + t**2 + 1)) + + conic = ImplicitRegion((x, y), Eq(x**2 + 4*x*y + 3*y**2 + x - y + 10, 0)) + assert conic.rational_parametrization(t) == ( + S(17)/2 + 4/(3*t**2 + 4*t + 1), 4*t/(3*t**2 + 4*t + 1) - S(11)/2) + + r1 = ImplicitRegion((x, y), y**2 - x**3 + x) + raises(NotImplementedError, lambda: r1.rational_parametrization()) + r2 = ImplicitRegion((x, y), y**2 - x**3 - x**2 + 1) + raises(NotImplementedError, lambda: r2.rational_parametrization()) diff --git a/lib/python3.10/site-packages/sympy/vector/tests/test_integrals.py b/lib/python3.10/site-packages/sympy/vector/tests/test_integrals.py new file mode 100644 index 0000000000000000000000000000000000000000..08e15562cacf088d469266ca33a3cb993584aa9a --- /dev/null +++ b/lib/python3.10/site-packages/sympy/vector/tests/test_integrals.py @@ -0,0 +1,106 @@ +from sympy.core.numbers import pi +from sympy.core.singleton import S +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import (cos, sin) +from sympy.testing.pytest import raises +from sympy.vector.coordsysrect import CoordSys3D +from sympy.vector.integrals import ParametricIntegral, vector_integrate +from sympy.vector.parametricregion import ParametricRegion +from sympy.vector.implicitregion import ImplicitRegion +from sympy.abc import x, y, z, u, v, r, t, theta, phi +from sympy.geometry import Point, Segment, Curve, Circle, Polygon, Plane + +C = CoordSys3D('C') + +def test_parametric_lineintegrals(): + halfcircle = ParametricRegion((4*cos(theta), 4*sin(theta)), (theta, -pi/2, pi/2)) + assert ParametricIntegral(C.x*C.y**4, halfcircle) == S(8192)/5 + + curve = ParametricRegion((t, t**2, t**3), (t, 0, 1)) + field1 = 8*C.x**2*C.y*C.z*C.i + 5*C.z*C.j - 4*C.x*C.y*C.k + assert ParametricIntegral(field1, curve) == 1 + line = ParametricRegion((4*t - 1, 2 - 2*t, t), (t, 0, 1)) + assert ParametricIntegral(C.x*C.z*C.i - C.y*C.z*C.k, line) == 3 + + assert ParametricIntegral(4*C.x**3, ParametricRegion((1, t), (t, 0, 2))) == 8 + + helix = ParametricRegion((cos(t), sin(t), 3*t), (t, 0, 4*pi)) + assert ParametricIntegral(C.x*C.y*C.z, helix) == -3*sqrt(10)*pi + + field2 = C.y*C.i + C.z*C.j + C.z*C.k + assert ParametricIntegral(field2, ParametricRegion((cos(t), sin(t), t**2), (t, 0, pi))) == -5*pi/2 + pi**4/2 + +def test_parametric_surfaceintegrals(): + + semisphere = ParametricRegion((2*sin(phi)*cos(theta), 2*sin(phi)*sin(theta), 2*cos(phi)),\ + (theta, 0, 2*pi), (phi, 0, pi/2)) + assert ParametricIntegral(C.z, semisphere) == 8*pi + + cylinder = ParametricRegion((sqrt(3)*cos(theta), sqrt(3)*sin(theta), z), (z, 0, 6), (theta, 0, 2*pi)) + assert ParametricIntegral(C.y, cylinder) == 0 + + cone = ParametricRegion((v*cos(u), v*sin(u), v), (u, 0, 2*pi), (v, 0, 1)) + assert ParametricIntegral(C.x*C.i + C.y*C.j + C.z**4*C.k, cone) == pi/3 + + triangle1 = ParametricRegion((x, y), (x, 0, 2), (y, 0, 10 - 5*x)) + triangle2 = ParametricRegion((x, y), (y, 0, 10 - 5*x), (x, 0, 2)) + assert ParametricIntegral(-15.6*C.y*C.k, triangle1) == ParametricIntegral(-15.6*C.y*C.k, triangle2) + assert ParametricIntegral(C.z, triangle1) == 10*C.z + +def test_parametric_volumeintegrals(): + + cube = ParametricRegion((x, y, z), (x, 0, 1), (y, 0, 1), (z, 0, 1)) + assert ParametricIntegral(1, cube) == 1 + + solidsphere1 = ParametricRegion((r*sin(phi)*cos(theta), r*sin(phi)*sin(theta), r*cos(phi)),\ + (r, 0, 2), (theta, 0, 2*pi), (phi, 0, pi)) + solidsphere2 = ParametricRegion((r*sin(phi)*cos(theta), r*sin(phi)*sin(theta), r*cos(phi)),\ + (r, 0, 2), (phi, 0, pi), (theta, 0, 2*pi)) + assert ParametricIntegral(C.x**2 + C.y**2, solidsphere1) == -256*pi/15 + assert ParametricIntegral(C.x**2 + C.y**2, solidsphere2) == 256*pi/15 + + region_under_plane1 = ParametricRegion((x, y, z), (x, 0, 3), (y, 0, -2*x/3 + 2),\ + (z, 0, 6 - 2*x - 3*y)) + region_under_plane2 = ParametricRegion((x, y, z), (x, 0, 3), (z, 0, 6 - 2*x - 3*y),\ + (y, 0, -2*x/3 + 2)) + + assert ParametricIntegral(C.x*C.i + C.j - 100*C.k, region_under_plane1) == \ + ParametricIntegral(C.x*C.i + C.j - 100*C.k, region_under_plane2) + assert ParametricIntegral(2*C.x, region_under_plane2) == -9 + +def test_vector_integrate(): + halfdisc = ParametricRegion((r*cos(theta), r* sin(theta)), (r, -2, 2), (theta, 0, pi)) + assert vector_integrate(C.x**2, halfdisc) == 4*pi + assert vector_integrate(C.x, ParametricRegion((t, t**2), (t, 2, 3))) == -17*sqrt(17)/12 + 37*sqrt(37)/12 + + assert vector_integrate(C.y**3*C.z, (C.x, 0, 3), (C.y, -1, 4)) == 765*C.z/4 + + s1 = Segment(Point(0, 0), Point(0, 1)) + assert vector_integrate(-15*C.y, s1) == S(-15)/2 + s2 = Segment(Point(4, 3, 9), Point(1, 1, 7)) + assert vector_integrate(C.y*C.i, s2) == -6 + + curve = Curve((sin(t), cos(t)), (t, 0, 2)) + assert vector_integrate(5*C.z, curve) == 10*C.z + + c1 = Circle(Point(2, 3), 6) + assert vector_integrate(C.x*C.y, c1) == 72*pi + c2 = Circle(Point(0, 0), Point(1, 1), Point(1, 0)) + assert vector_integrate(1, c2) == c2.circumference + + triangle = Polygon((0, 0), (1, 0), (1, 1)) + assert vector_integrate(C.x*C.i - 14*C.y*C.j, triangle) == 0 + p1, p2, p3, p4 = [(0, 0), (1, 0), (5, 1), (0, 1)] + poly = Polygon(p1, p2, p3, p4) + assert vector_integrate(-23*C.z, poly) == -161*C.z - 23*sqrt(17)*C.z + + point = Point(2, 3) + assert vector_integrate(C.i*C.y - C.z, point) == ParametricIntegral(C.y*C.i, ParametricRegion((2, 3))) + + c3 = ImplicitRegion((x, y), x**2 + y**2 - 4) + assert vector_integrate(45, c3) == 180*pi + c4 = ImplicitRegion((x, y), (x - 3)**2 + (y - 4)**2 - 9) + assert vector_integrate(1, c4) == 6*pi + + pl = Plane(Point(1, 1, 1), Point(2, 3, 4), Point(2, 2, 2)) + raises(ValueError, lambda: vector_integrate(C.x*C.z*C.i + C.k, pl)) diff --git a/lib/python3.10/site-packages/sympy/vector/tests/test_operators.py b/lib/python3.10/site-packages/sympy/vector/tests/test_operators.py new file mode 100644 index 0000000000000000000000000000000000000000..5734edadd00547c67d6f864b50afd966ad8392a6 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/vector/tests/test_operators.py @@ -0,0 +1,43 @@ +from sympy.vector import CoordSys3D, Gradient, Divergence, Curl, VectorZero, Laplacian +from sympy.printing.repr import srepr + +R = CoordSys3D('R') +s1 = R.x*R.y*R.z # type: ignore +s2 = R.x + 3*R.y**2 # type: ignore +s3 = R.x**2 + R.y**2 + R.z**2 # type: ignore +v1 = R.x*R.i + R.z*R.z*R.j # type: ignore +v2 = R.x*R.i + R.y*R.j + R.z*R.k # type: ignore +v3 = R.x**2*R.i + R.y**2*R.j + R.z**2*R.k # type: ignore + + +def test_Gradient(): + assert Gradient(s1) == Gradient(R.x*R.y*R.z) + assert Gradient(s2) == Gradient(R.x + 3*R.y**2) + assert Gradient(s1).doit() == R.y*R.z*R.i + R.x*R.z*R.j + R.x*R.y*R.k + assert Gradient(s2).doit() == R.i + 6*R.y*R.j + + +def test_Divergence(): + assert Divergence(v1) == Divergence(R.x*R.i + R.z*R.z*R.j) + assert Divergence(v2) == Divergence(R.x*R.i + R.y*R.j + R.z*R.k) + assert Divergence(v1).doit() == 1 + assert Divergence(v2).doit() == 3 + # issue 22384 + Rc = CoordSys3D('R', transformation='cylindrical') + assert Divergence(Rc.i).doit() == 1/Rc.r + + +def test_Curl(): + assert Curl(v1) == Curl(R.x*R.i + R.z*R.z*R.j) + assert Curl(v2) == Curl(R.x*R.i + R.y*R.j + R.z*R.k) + assert Curl(v1).doit() == (-2*R.z)*R.i + assert Curl(v2).doit() == VectorZero() + + +def test_Laplacian(): + assert Laplacian(s3) == Laplacian(R.x**2 + R.y**2 + R.z**2) + assert Laplacian(v3) == Laplacian(R.x**2*R.i + R.y**2*R.j + R.z**2*R.k) + assert Laplacian(s3).doit() == 6 + assert Laplacian(v3).doit() == 2*R.i + 2*R.j + 2*R.k + assert srepr(Laplacian(s3)) == \ + 'Laplacian(Add(Pow(R.x, Integer(2)), Pow(R.y, Integer(2)), Pow(R.z, Integer(2))))' diff --git a/lib/python3.10/site-packages/sympy/vector/tests/test_parametricregion.py b/lib/python3.10/site-packages/sympy/vector/tests/test_parametricregion.py new file mode 100644 index 0000000000000000000000000000000000000000..e785b96744f9e2c39e91b997fcb70f8a921256bd --- /dev/null +++ b/lib/python3.10/site-packages/sympy/vector/tests/test_parametricregion.py @@ -0,0 +1,97 @@ +from sympy.core.numbers import pi +from sympy.functions.elementary.trigonometric import (cos, sin) +from sympy.vector.coordsysrect import CoordSys3D +from sympy.vector.parametricregion import ParametricRegion, parametric_region_list +from sympy.geometry import Point, Segment, Curve, Ellipse, Line, Parabola, Polygon +from sympy.testing.pytest import raises +from sympy.abc import a, b, r, t, x, y, z, theta, phi + + +C = CoordSys3D('C') + +def test_ParametricRegion(): + + point = ParametricRegion((3, 4)) + assert point.definition == (3, 4) + assert point.parameters == () + assert point.limits == {} + assert point.dimensions == 0 + + # line x = y + line_xy = ParametricRegion((y, y), (y, 1, 5)) + assert line_xy .definition == (y, y) + assert line_xy.parameters == (y,) + assert line_xy.dimensions == 1 + + # line y = z + line_yz = ParametricRegion((x,t,t), x, (t, 1, 2)) + assert line_yz.definition == (x,t,t) + assert line_yz.parameters == (x, t) + assert line_yz.limits == {t: (1, 2)} + assert line_yz.dimensions == 1 + + p1 = ParametricRegion((9*a, -16*b), (a, 0, 2), (b, -1, 5)) + assert p1.definition == (9*a, -16*b) + assert p1.parameters == (a, b) + assert p1.limits == {a: (0, 2), b: (-1, 5)} + assert p1.dimensions == 2 + + p2 = ParametricRegion((t, t**3), t) + assert p2.parameters == (t,) + assert p2.limits == {} + assert p2.dimensions == 0 + + circle = ParametricRegion((r*cos(theta), r*sin(theta)), r, (theta, 0, 2*pi)) + assert circle.definition == (r*cos(theta), r*sin(theta)) + assert circle.dimensions == 1 + + halfdisc = ParametricRegion((r*cos(theta), r*sin(theta)), (r, -2, 2), (theta, 0, pi)) + assert halfdisc.definition == (r*cos(theta), r*sin(theta)) + assert halfdisc.parameters == (r, theta) + assert halfdisc.limits == {r: (-2, 2), theta: (0, pi)} + assert halfdisc.dimensions == 2 + + ellipse = ParametricRegion((a*cos(t), b*sin(t)), (t, 0, 8)) + assert ellipse.parameters == (t,) + assert ellipse.limits == {t: (0, 8)} + assert ellipse.dimensions == 1 + + cylinder = ParametricRegion((r*cos(theta), r*sin(theta), z), (r, 0, 1), (theta, 0, 2*pi), (z, 0, 4)) + assert cylinder.parameters == (r, theta, z) + assert cylinder.dimensions == 3 + + sphere = ParametricRegion((r*sin(phi)*cos(theta),r*sin(phi)*sin(theta), r*cos(phi)), + r, (theta, 0, 2*pi), (phi, 0, pi)) + assert sphere.definition == (r*sin(phi)*cos(theta),r*sin(phi)*sin(theta), r*cos(phi)) + assert sphere.parameters == (r, theta, phi) + assert sphere.dimensions == 2 + + raises(ValueError, lambda: ParametricRegion((a*t**2, 2*a*t), (a, -2))) + raises(ValueError, lambda: ParametricRegion((a, b), (a**2, sin(b)), (a, 2, 4, 6))) + + +def test_parametric_region_list(): + + point = Point(-5, 12) + assert parametric_region_list(point) == [ParametricRegion((-5, 12))] + + e = Ellipse(Point(2, 8), 2, 6) + assert parametric_region_list(e, t) == [ParametricRegion((2*cos(t) + 2, 6*sin(t) + 8), (t, 0, 2*pi))] + + c = Curve((t, t**3), (t, 5, 3)) + assert parametric_region_list(c) == [ParametricRegion((t, t**3), (t, 5, 3))] + + s = Segment(Point(2, 11, -6), Point(0, 2, 5)) + assert parametric_region_list(s, t) == [ParametricRegion((2 - 2*t, 11 - 9*t, 11*t - 6), (t, 0, 1))] + s1 = Segment(Point(0, 0), (1, 0)) + assert parametric_region_list(s1, t) == [ParametricRegion((t, 0), (t, 0, 1))] + s2 = Segment(Point(1, 2, 3), Point(1, 2, 5)) + assert parametric_region_list(s2, t) == [ParametricRegion((1, 2, 2*t + 3), (t, 0, 1))] + s3 = Segment(Point(12, 56), Point(12, 56)) + assert parametric_region_list(s3) == [ParametricRegion((12, 56))] + + poly = Polygon((1,3), (-3, 8), (2, 4)) + assert parametric_region_list(poly, t) == [ParametricRegion((1 - 4*t, 5*t + 3), (t, 0, 1)), ParametricRegion((5*t - 3, 8 - 4*t), (t, 0, 1)), ParametricRegion((2 - t, 4 - t), (t, 0, 1))] + + p1 = Parabola(Point(0, 0), Line(Point(5, 8), Point(7,8))) + raises(ValueError, lambda: parametric_region_list(p1)) diff --git a/lib/python3.10/site-packages/sympy/vector/tests/test_printing.py b/lib/python3.10/site-packages/sympy/vector/tests/test_printing.py new file mode 100644 index 0000000000000000000000000000000000000000..ae76905e967bdf93485f135c6a69f968e1208986 --- /dev/null +++ b/lib/python3.10/site-packages/sympy/vector/tests/test_printing.py @@ -0,0 +1,221 @@ +# -*- coding: utf-8 -*- +from sympy.core.function import Function +from sympy.integrals.integrals import Integral +from sympy.printing.latex import latex +from sympy.printing.pretty import pretty as xpretty +from sympy.vector import CoordSys3D, Del, Vector, express +from sympy.abc import a, b, c +from sympy.testing.pytest import XFAIL + + +def pretty(expr): + """ASCII pretty-printing""" + return xpretty(expr, use_unicode=False, wrap_line=False) + + +def upretty(expr): + """Unicode pretty-printing""" + return xpretty(expr, use_unicode=True, wrap_line=False) + + +# Initialize the basic and tedious vector/dyadic expressions +# needed for testing. +# Some of the pretty forms shown denote how the expressions just +# above them should look with pretty printing. +N = CoordSys3D('N') +C = N.orient_new_axis('C', a, N.k) # type: ignore +v = [] +d = [] +v.append(Vector.zero) +v.append(N.i) # type: ignore +v.append(-N.i) # type: ignore +v.append(N.i + N.j) # type: ignore +v.append(a*N.i) # type: ignore +v.append(a*N.i - b*N.j) # type: ignore +v.append((a**2 + N.x)*N.i + N.k) # type: ignore +v.append((a**2 + b)*N.i + 3*(C.y - c)*N.k) # type: ignore +f = Function('f') +v.append(N.j - (Integral(f(b)) - C.x**2)*N.k) # type: ignore +upretty_v_8 = """\ + ⎛ 2 ⌠ ⎞ \n\ +j_N + ⎜x_C - ⎮ f(b) db⎟ k_N\n\ + ⎝ ⌡ ⎠ \ +""" +pretty_v_8 = """\ +j_N + / / \\\n\ + | 2 | |\n\ + |x_C - | f(b) db|\n\ + | | |\n\ + \\ / / \ +""" + +v.append(N.i + C.k) # type: ignore +v.append(express(N.i, C)) # type: ignore +v.append((a**2 + b)*N.i + (Integral(f(b)))*N.k) # type: ignore +upretty_v_11 = """\ +⎛ 2 ⎞ ⎛⌠ ⎞ \n\ +⎝a + b⎠ i_N + ⎜⎮ f(b) db⎟ k_N\n\ + ⎝⌡ ⎠ \ +""" +pretty_v_11 = """\ +/ 2 \\ + / / \\\n\ +\\a + b/ i_N| | |\n\ + | | f(b) db|\n\ + | | |\n\ + \\/ / \ +""" + +for x in v: + d.append(x | N.k) # type: ignore +s = 3*N.x**2*C.y # type: ignore +upretty_s = """\ + 2\n\ +3⋅y_C⋅x_N \ +""" +pretty_s = """\ + 2\n\ +3*y_C*x_N \ +""" + +# This is the pretty form for ((a**2 + b)*N.i + 3*(C.y - c)*N.k) | N.k +upretty_d_7 = """\ +⎛ 2 ⎞ \n\ +⎝a + b⎠ (i_N|k_N) + (3⋅y_C - 3⋅c) (k_N|k_N)\ +""" +pretty_d_7 = """\ +/ 2 \\ (i_N|k_N) + (3*y_C - 3*c) (k_N|k_N)\n\ +\\a + b/ \ +""" + + +def test_str_printing(): + assert str(v[0]) == '0' + assert str(v[1]) == 'N.i' + assert str(v[2]) == '(-1)*N.i' + assert str(v[3]) == 'N.i + N.j' + assert str(v[8]) == 'N.j + (C.x**2 - Integral(f(b), b))*N.k' + assert str(v[9]) == 'C.k + N.i' + assert str(s) == '3*C.y*N.x**2' + assert str(d[0]) == '0' + assert str(d[1]) == '(N.i|N.k)' + assert str(d[4]) == 'a*(N.i|N.k)' + assert str(d[5]) == 'a*(N.i|N.k) + (-b)*(N.j|N.k)' + assert str(d[8]) == ('(N.j|N.k) + (C.x**2 - ' + + 'Integral(f(b), b))*(N.k|N.k)') + + +@XFAIL +def test_pretty_printing_ascii(): + assert pretty(v[0]) == '0' + assert pretty(v[1]) == 'i_N' + assert pretty(v[5]) == '(a) i_N + (-b) j_N' + assert pretty(v[8]) == pretty_v_8 + assert pretty(v[2]) == '(-1) i_N' + assert pretty(v[11]) == pretty_v_11 + assert pretty(s) == pretty_s + assert pretty(d[0]) == '(0|0)' + assert pretty(d[5]) == '(a) (i_N|k_N) + (-b) (j_N|k_N)' + assert pretty(d[7]) == pretty_d_7 + assert pretty(d[10]) == '(cos(a)) (i_C|k_N) + (-sin(a)) (j_C|k_N)' + + +def test_pretty_print_unicode_v(): + assert upretty(v[0]) == '0' + assert upretty(v[1]) == 'i_N' + assert upretty(v[5]) == '(a) i_N + (-b) j_N' + # Make sure the printing works in other objects + assert upretty(v[5].args) == '((a) i_N, (-b) j_N)' + assert upretty(v[8]) == upretty_v_8 + assert upretty(v[2]) == '(-1) i_N' + assert upretty(v[11]) == upretty_v_11 + assert upretty(s) == upretty_s + assert upretty(d[0]) == '(0|0)' + assert upretty(d[5]) == '(a) (i_N|k_N) + (-b) (j_N|k_N)' + assert upretty(d[7]) == upretty_d_7 + assert upretty(d[10]) == '(cos(a)) (i_C|k_N) + (-sin(a)) (j_C|k_N)' + + +def test_latex_printing(): + assert latex(v[0]) == '\\mathbf{\\hat{0}}' + assert latex(v[1]) == '\\mathbf{\\hat{i}_{N}}' + assert latex(v[2]) == '- \\mathbf{\\hat{i}_{N}}' + assert latex(v[5]) == ('\\left(a\\right)\\mathbf{\\hat{i}_{N}} + ' + + '\\left(- b\\right)\\mathbf{\\hat{j}_{N}}') + assert latex(v[6]) == ('\\left(\\mathbf{{x}_{N}} + a^{2}\\right)\\mathbf{\\hat{i}_' + + '{N}} + \\mathbf{\\hat{k}_{N}}') + assert latex(v[8]) == ('\\mathbf{\\hat{j}_{N}} + \\left(\\mathbf{{x}_' + + '{C}}^{2} - \\int f{\\left(b \\right)}\\,' + + ' db\\right)\\mathbf{\\hat{k}_{N}}') + assert latex(s) == '3 \\mathbf{{y}_{C}} \\mathbf{{x}_{N}}^{2}' + assert latex(d[0]) == '(\\mathbf{\\hat{0}}|\\mathbf{\\hat{0}})' + assert latex(d[4]) == ('\\left(a\\right)\\left(\\mathbf{\\hat{i}_{N}}{\\middle|}' + + '\\mathbf{\\hat{k}_{N}}\\right)') + assert latex(d[9]) == ('\\left(\\mathbf{\\hat{k}_{C}}{\\middle|}' + + '\\mathbf{\\hat{k}_{N}}\\right) + \\left(' + + '\\mathbf{\\hat{i}_{N}}{\\middle|}\\mathbf{' + + '\\hat{k}_{N}}\\right)') + assert latex(d[11]) == ('\\left(a^{2} + b\\right)\\left(\\mathbf{\\hat{i}_{N}}' + + '{\\middle|}\\mathbf{\\hat{k}_{N}}\\right) + ' + + '\\left(\\int f{\\left(b \\right)}\\, db\\right)\\left(' + + '\\mathbf{\\hat{k}_{N}}{\\middle|}\\mathbf{' + + '\\hat{k}_{N}}\\right)') + +def test_issue_23058(): + from sympy import symbols, sin, cos, pi, UnevaluatedExpr + + delop = Del() + CC_ = CoordSys3D("C") + y = CC_.y + xhat = CC_.i + + t = symbols("t") + ten = symbols("10", positive=True) + eps, mu = 4*pi*ten**(-11), ten**(-5) + + Bx = 2 * ten**(-4) * cos(ten**5 * t) * sin(ten**(-3) * y) + vecB = Bx * xhat + vecE = (1/eps) * Integral(delop.cross(vecB/mu).doit(), t) + vecE = vecE.doit() + + vecB_str = """\ +⎛ ⎛y_C⎞ ⎛ 5 ⎞⎞ \n\ +⎜2⋅sin⎜───⎟⋅cos⎝10 ⋅t⎠⎟ i_C\n\ +⎜ ⎜ 3⎟ ⎟ \n\ +⎜ ⎝10 ⎠ ⎟ \n\ +⎜─────────────────────⎟ \n\ +⎜ 4 ⎟ \n\ +⎝ 10 ⎠ \ +""" + vecE_str = """\ +⎛ 4 ⎛ 5 ⎞ ⎛y_C⎞ ⎞ \n\ +⎜-10 ⋅sin⎝10 ⋅t⎠⋅cos⎜───⎟ ⎟ k_C\n\ +⎜ ⎜ 3⎟ ⎟ \n\ +⎜ ⎝10 ⎠ ⎟ \n\ +⎜─────────────────────────⎟ \n\ +⎝ 2⋅π ⎠ \ +""" + + assert upretty(vecB) == vecB_str + assert upretty(vecE) == vecE_str + + ten = UnevaluatedExpr(10) + eps, mu = 4*pi*ten**(-11), ten**(-5) + + Bx = 2 * ten**(-4) * cos(ten**5 * t) * sin(ten**(-3) * y) + vecB = Bx * xhat + + vecB_str = """\ +⎛ -4 ⎛ 5⎞ ⎛ -3⎞⎞ \n\ +⎝2⋅10 ⋅cos⎝t⋅10 ⎠⋅sin⎝y_C⋅10 ⎠⎠ i_C \ +""" + assert upretty(vecB) == vecB_str + +def test_custom_names(): + A = CoordSys3D('A', vector_names=['x', 'y', 'z'], + variable_names=['i', 'j', 'k']) + assert A.i.__str__() == 'A.i' + assert A.x.__str__() == 'A.x' + assert A.i._pretty_form == 'i_A' + assert A.x._pretty_form == 'x_A' + assert A.i._latex_form == r'\mathbf{{i}_{A}}' + assert A.x._latex_form == r"\mathbf{\hat{x}_{A}}" diff --git a/lib/python3.10/site-packages/sympy/vector/tests/test_vector.py b/lib/python3.10/site-packages/sympy/vector/tests/test_vector.py new file mode 100644 index 0000000000000000000000000000000000000000..b68fb9fb3efb1f11f1d5a8908aa80dc1f9d7e46e --- /dev/null +++ b/lib/python3.10/site-packages/sympy/vector/tests/test_vector.py @@ -0,0 +1,266 @@ +from sympy.core import Rational, S +from sympy.simplify import simplify, trigsimp +from sympy.core.function import (Derivative, Function, diff) +from sympy.core.numbers import pi +from sympy.core.symbol import symbols +from sympy.functions.elementary.miscellaneous import sqrt +from sympy.functions.elementary.trigonometric import (cos, sin) +from sympy.integrals.integrals import Integral +from sympy.matrices.immutable import ImmutableDenseMatrix as Matrix +from sympy.vector.vector import Vector, BaseVector, VectorAdd, \ + VectorMul, VectorZero +from sympy.vector.coordsysrect import CoordSys3D +from sympy.vector.vector import Cross, Dot, cross +from sympy.testing.pytest import raises + +C = CoordSys3D('C') + +i, j, k = C.base_vectors() +a, b, c = symbols('a b c') + + +def test_cross(): + v1 = C.x * i + C.z * C.z * j + v2 = C.x * i + C.y * j + C.z * k + assert Cross(v1, v2) == Cross(C.x*C.i + C.z**2*C.j, C.x*C.i + C.y*C.j + C.z*C.k) + assert Cross(v1, v2).doit() == C.z**3*C.i + (-C.x*C.z)*C.j + (C.x*C.y - C.x*C.z**2)*C.k + assert cross(v1, v2) == C.z**3*C.i + (-C.x*C.z)*C.j + (C.x*C.y - C.x*C.z**2)*C.k + assert Cross(v1, v2) == -Cross(v2, v1) + assert Cross(v1, v2) + Cross(v2, v1) == Vector.zero + + +def test_dot(): + v1 = C.x * i + C.z * C.z * j + v2 = C.x * i + C.y * j + C.z * k + assert Dot(v1, v2) == Dot(C.x*C.i + C.z**2*C.j, C.x*C.i + C.y*C.j + C.z*C.k) + assert Dot(v1, v2).doit() == C.x**2 + C.y*C.z**2 + assert Dot(v1, v2).doit() == C.x**2 + C.y*C.z**2 + assert Dot(v1, v2) == Dot(v2, v1) + + +def test_vector_sympy(): + """ + Test whether the Vector framework confirms to the hashing + and equality testing properties of SymPy. + """ + v1 = 3*j + assert v1 == j*3 + assert v1.components == {j: 3} + v2 = 3*i + 4*j + 5*k + v3 = 2*i + 4*j + i + 4*k + k + assert v3 == v2 + assert v3.__hash__() == v2.__hash__() + + +def test_vector(): + assert isinstance(i, BaseVector) + assert i != j + assert j != k + assert k != i + assert i - i == Vector.zero + assert i + Vector.zero == i + assert i - Vector.zero == i + assert Vector.zero != 0 + assert -Vector.zero == Vector.zero + + v1 = a*i + b*j + c*k + v2 = a**2*i + b**2*j + c**2*k + v3 = v1 + v2 + v4 = 2 * v1 + v5 = a * i + + assert isinstance(v1, VectorAdd) + assert v1 - v1 == Vector.zero + assert v1 + Vector.zero == v1 + assert v1.dot(i) == a + assert v1.dot(j) == b + assert v1.dot(k) == c + assert i.dot(v2) == a**2 + assert j.dot(v2) == b**2 + assert k.dot(v2) == c**2 + assert v3.dot(i) == a**2 + a + assert v3.dot(j) == b**2 + b + assert v3.dot(k) == c**2 + c + + assert v1 + v2 == v2 + v1 + assert v1 - v2 == -1 * (v2 - v1) + assert a * v1 == v1 * a + + assert isinstance(v5, VectorMul) + assert v5.base_vector == i + assert v5.measure_number == a + assert isinstance(v4, Vector) + assert isinstance(v4, VectorAdd) + assert isinstance(v4, Vector) + assert isinstance(Vector.zero, VectorZero) + assert isinstance(Vector.zero, Vector) + assert isinstance(v1 * 0, VectorZero) + + assert v1.to_matrix(C) == Matrix([[a], [b], [c]]) + + assert i.components == {i: 1} + assert v5.components == {i: a} + assert v1.components == {i: a, j: b, k: c} + + assert VectorAdd(v1, Vector.zero) == v1 + assert VectorMul(a, v1) == v1*a + assert VectorMul(1, i) == i + assert VectorAdd(v1, Vector.zero) == v1 + assert VectorMul(0, Vector.zero) == Vector.zero + raises(TypeError, lambda: v1.outer(1)) + raises(TypeError, lambda: v1.dot(1)) + + +def test_vector_magnitude_normalize(): + assert Vector.zero.magnitude() == 0 + assert Vector.zero.normalize() == Vector.zero + + assert i.magnitude() == 1 + assert j.magnitude() == 1 + assert k.magnitude() == 1 + assert i.normalize() == i + assert j.normalize() == j + assert k.normalize() == k + + v1 = a * i + assert v1.normalize() == (a/sqrt(a**2))*i + assert v1.magnitude() == sqrt(a**2) + + v2 = a*i + b*j + c*k + assert v2.magnitude() == sqrt(a**2 + b**2 + c**2) + assert v2.normalize() == v2 / v2.magnitude() + + v3 = i + j + assert v3.normalize() == (sqrt(2)/2)*C.i + (sqrt(2)/2)*C.j + + +def test_vector_simplify(): + A, s, k, m = symbols('A, s, k, m') + + test1 = (1 / a + 1 / b) * i + assert (test1 & i) != (a + b) / (a * b) + test1 = simplify(test1) + assert (test1 & i) == (a + b) / (a * b) + assert test1.simplify() == simplify(test1) + + test2 = (A**2 * s**4 / (4 * pi * k * m**3)) * i + test2 = simplify(test2) + assert (test2 & i) == (A**2 * s**4 / (4 * pi * k * m**3)) + + test3 = ((4 + 4 * a - 2 * (2 + 2 * a)) / (2 + 2 * a)) * i + test3 = simplify(test3) + assert (test3 & i) == 0 + + test4 = ((-4 * a * b**2 - 2 * b**3 - 2 * a**2 * b) / (a + b)**2) * i + test4 = simplify(test4) + assert (test4 & i) == -2 * b + + v = (sin(a)+cos(a))**2*i - j + assert trigsimp(v) == (2*sin(a + pi/4)**2)*i + (-1)*j + assert trigsimp(v) == v.trigsimp() + + assert simplify(Vector.zero) == Vector.zero + + +def test_vector_dot(): + assert i.dot(Vector.zero) == 0 + assert Vector.zero.dot(i) == 0 + assert i & Vector.zero == 0 + + assert i.dot(i) == 1 + assert i.dot(j) == 0 + assert i.dot(k) == 0 + assert i & i == 1 + assert i & j == 0 + assert i & k == 0 + + assert j.dot(i) == 0 + assert j.dot(j) == 1 + assert j.dot(k) == 0 + assert j & i == 0 + assert j & j == 1 + assert j & k == 0 + + assert k.dot(i) == 0 + assert k.dot(j) == 0 + assert k.dot(k) == 1 + assert k & i == 0 + assert k & j == 0 + assert k & k == 1 + + raises(TypeError, lambda: k.dot(1)) + + +def test_vector_cross(): + assert i.cross(Vector.zero) == Vector.zero + assert Vector.zero.cross(i) == Vector.zero + + assert i.cross(i) == Vector.zero + assert i.cross(j) == k + assert i.cross(k) == -j + assert i ^ i == Vector.zero + assert i ^ j == k + assert i ^ k == -j + + assert j.cross(i) == -k + assert j.cross(j) == Vector.zero + assert j.cross(k) == i + assert j ^ i == -k + assert j ^ j == Vector.zero + assert j ^ k == i + + assert k.cross(i) == j + assert k.cross(j) == -i + assert k.cross(k) == Vector.zero + assert k ^ i == j + assert k ^ j == -i + assert k ^ k == Vector.zero + + assert k.cross(1) == Cross(k, 1) + + +def test_projection(): + v1 = i + j + k + v2 = 3*i + 4*j + v3 = 0*i + 0*j + assert v1.projection(v1) == i + j + k + assert v1.projection(v2) == Rational(7, 3)*C.i + Rational(7, 3)*C.j + Rational(7, 3)*C.k + assert v1.projection(v1, scalar=True) == S.One + assert v1.projection(v2, scalar=True) == Rational(7, 3) + assert v3.projection(v1) == Vector.zero + assert v3.projection(v1, scalar=True) == S.Zero + + +def test_vector_diff_integrate(): + f = Function('f') + v = f(a)*C.i + a**2*C.j - C.k + assert Derivative(v, a) == Derivative((f(a))*C.i + + a**2*C.j + (-1)*C.k, a) + assert (diff(v, a) == v.diff(a) == Derivative(v, a).doit() == + (Derivative(f(a), a))*C.i + 2*a*C.j) + assert (Integral(v, a) == (Integral(f(a), a))*C.i + + (Integral(a**2, a))*C.j + (Integral(-1, a))*C.k) + + +def test_vector_args(): + raises(ValueError, lambda: BaseVector(3, C)) + raises(TypeError, lambda: BaseVector(0, Vector.zero)) + + +def test_srepr(): + from sympy.printing.repr import srepr + res = "CoordSys3D(Str('C'), Tuple(ImmutableDenseMatrix([[Integer(1), "\ + "Integer(0), Integer(0)], [Integer(0), Integer(1), Integer(0)], "\ + "[Integer(0), Integer(0), Integer(1)]]), VectorZero())).i" + assert srepr(C.i) == res + + +def test_scalar(): + from sympy.vector import CoordSys3D + C = CoordSys3D('C') + v1 = 3*C.i + 4*C.j + 5*C.k + v2 = 3*C.i - 4*C.j + 5*C.k + assert v1.is_Vector is True + assert v1.is_scalar is False + assert (v1.dot(v2)).is_scalar is True + assert (v1.cross(v2)).is_scalar is False