File size: 4,850 Bytes
ad5f26a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
import pickle

from sympy.core.relational import (Eq, Ne)
from sympy.core.singleton import S
from sympy.core.symbol import symbols
from sympy.functions.elementary.miscellaneous import sqrt
from sympy.functions.elementary.trigonometric import (cos, sin)
from sympy.external import import_module
from sympy.testing.pytest import skip
from sympy.utilities.matchpy_connector import WildDot, WildPlus, WildStar, Replacer

# Optional dependency handle: the tests below compare this against ``None``
# and bail out early when matchpy is not available.
matchpy = import_module("matchpy")

# Symbols shared by all tests in this module.
x, y, z = symbols("x y z")


def _get_first_match(expr, pattern):
    """Return the first match of ``pattern`` against ``expr``.

    A fresh ``ManyToOneMatcher`` holding only ``Pattern(pattern)`` is built
    for every call. Callers in this module unpack the result as a
    ``(pattern, substitution)`` pair.

    Raises ``StopIteration`` (from ``next``) when the matcher yields nothing.
    """
    # Imported lazily so that the module can still be imported (and the
    # tests skipped) when matchpy is not installed.
    from matchpy import ManyToOneMatcher, Pattern

    wrapped_pattern = Pattern(pattern)
    single_pattern_matcher = ManyToOneMatcher()
    single_pattern_matcher.add(wrapped_pattern)

    matches = iter(single_pattern_matcher.match(expr))
    return next(matches)


def test_matchpy_connector():
    """Check the bindings produced for ``WildDot``, ``WildPlus`` and ``WildStar``."""
    if matchpy is None:
        skip("matchpy not installed")

    from multiset import Multiset
    from matchpy import Pattern, Substitution

    w_ = WildDot("w_")
    w__ = WildPlus("w__")
    w___ = WildStar("w___")

    # (expression, pattern, expected wildcard bindings)
    cases = [
        # The dot wildcard is expected to bind the single remaining term.
        (x + y, x + w_, {'w_': y}),
        # The plus wildcard is expected to bind all remaining terms.
        (x + y + z, x + w__, {'w__': Multiset([y, z])}),
        # The star wildcard is expected to bind an empty multiset when
        # every term is already matched explicitly.
        (x + y + z, x + y + z + w___, {'w___': Multiset()}),
    ]

    for expr, pattern, bindings in cases:
        matched_pattern, subst = _get_first_match(expr, pattern)
        assert matched_pattern == Pattern(pattern)
        assert subst == Substitution(bindings)


def test_matchpy_optional():
    """Check matching and replacement with optional ``WildDot`` wildcards.

    The pattern ``p*x + q`` uses ``p`` with ``optional=1`` and ``q`` with
    ``optional=0``; the expected substitutions below show those values being
    used when the corresponding part is absent from the expression.
    """
    if matchpy is None:
        skip("matchpy not installed")

    from matchpy import Pattern, Substitution
    from matchpy import ManyToOneReplacer, ReplacementRule

    p = WildDot("p", optional=1)
    q = WildDot("q", optional=0)

    pattern = p*x + q

    # (expression, expected bindings, expected result of the replacement)
    cases = [
        (2*x, {'p': 2, 'q': 0}, sin(2)*cos(0)),
        (x + 3, {'p': 1, 'q': 3}, sin(1)*cos(3)),
        (x, {'p': 1, 'q': 0}, sin(1)*cos(0)),
        (x*y + z, {'p': y, 'q': z}, sin(y)*cos(z)),
    ]

    # First pass: plain matching of every expression against the pattern.
    for expr, bindings, _ in cases:
        matched_pattern, subst = _get_first_match(expr, pattern)
        assert matched_pattern == Pattern(pattern)
        assert subst == Substitution(bindings)

    # Second pass: the same pattern driving a replacement rule. The lambda's
    # parameter names mirror the wildcard names ``p`` and ``q``.
    replacer = ManyToOneReplacer()
    rule = ReplacementRule(Pattern(pattern), lambda p, q: sin(p)*cos(q))
    replacer.add(rule)
    for expr, _, replaced in cases:
        assert replacer.replace(expr) == replaced


def test_replacer():
    """Run the replacer scenario for every ``info``/``lambdify`` combination."""
    if matchpy is None:
        skip("matchpy not installed")

    # Same iteration order as two nested loops: ``info`` varies slowest.
    flag_pairs = [
        (info, lambdify)
        for info in (True, False)
        for lambdify in (True, False)
    ]
    for info, lambdify in flag_pairs:
        _perform_test_replacer(info, lambdify)


def _perform_test_replacer(info, lambdify):
    """Build a small equation-solving ``Replacer`` and check its rewrites.

    ``info`` and ``lambdify`` are forwarded unchanged to the ``Replacer``
    constructor. When ``info`` is truthy, ``replace`` is expected to return
    an ``(expression, infos)`` pair, where ``infos`` lists the ``info`` tags
    of the rules registered below; otherwise only the expression is expected.
    """
    # Wildcards standing for the two sides of an arbitrary equation.
    lhs_ = WildDot("x1_")
    rhs_ = WildDot("x2_")

    # Coefficient wildcards with defaults for when they are absent.
    a_ = WildDot("a_", optional=S.One)
    b_ = WildDot("b_", optional=S.One)
    c_ = WildDot("c_", optional=S.Zero)

    # Constraints shared by all rules: the coefficients must not contain x.
    # The lambda parameter names mirror the wildcard names.
    coefficients_free_of_x = [
        matchpy.CustomConstraint(lambda a_: not a_.has(x)),
        matchpy.CustomConstraint(lambda b_: not b_.has(x)),
        matchpy.CustomConstraint(lambda c_: not c_.has(x)),
    ]
    replacer = Replacer(
        common_constraints=coefficients_free_of_x,
        lambdify=lambdify,
        info=info,
    )

    # Rewrite the equation into implicit form, unless it's already solved:
    replacer.add(
        Eq(lhs_, rhs_),
        Eq(lhs_ - rhs_, 0),
        conditions_nonfalse=[Ne(rhs_, 0), Ne(lhs_, 0), Ne(lhs_, x), Ne(rhs_, x)],
        info=1,
    )

    # Simple equation solver for real numbers:
    replacer.add(Eq(a_*x + b_, 0), Eq(x, -b_/a_), info=2)

    disc = b_**2 - 4*a_*c_
    sqrt_disc = sqrt(disc)
    replacer.add(
        Eq(a_*x**2 + b_*x + c_, 0),
        Eq(x, (-b_ - sqrt_disc)/(2*a_)) | Eq(x, (-b_ + sqrt_disc)/(2*a_)),
        conditions_nonfalse=[disc >= 0],
        info=3,
    )

    pure_root = sqrt(-c_/a_)
    replacer.add(
        Eq(a_*x**2 + c_, 0),
        Eq(x, pure_root) | Eq(x, -pure_root),
        conditions_nonfalse=[-c_*a_ > 0],
        info=4,
    )

    def expected(expr, infos):
        # Shape of the expected result depends on the ``info`` flag.
        if info:
            return (expr, infos)
        return expr

    # (input equation, expected expression, expected info tags)
    cases = [
        (Eq(3*x, y), Eq(x, y/3), [1, 2]),
        # Expected to come back unchanged with no info tags.
        (Eq(x**2 + 1, 0), Eq(x**2 + 1, 0), []),
        (Eq(x**2, 4), (Eq(x, 2) | Eq(x, -2)), [1, 4]),
        (Eq(x**2 + 4*y*x + 4*y**2, 0), Eq(x, -2*y), [3]),
    ]
    for equation, solution, tags in cases:
        assert replacer.replace(equation) == expected(solution, tags)


def test_matchpy_object_pickle():
    """Check that connector wildcards survive a pickle round trip.

    Each wildcard is dumped and reloaded with ``pickle`` and the reloaded
    object must compare equal to the original.
    """
    if matchpy is None:
        # Report the test as skipped (like the other tests in this module)
        # instead of silently passing when matchpy is unavailable.
        skip("matchpy not installed")

    wildcards = [
        WildDot("a"),
        WildDot("a", S(1)),
        WildPlus("a", S(1)),
        WildStar("a", S(1)),
    ]
    for original in wildcards:
        # Only objects created in this test are unpickled here.
        restored = pickle.loads(pickle.dumps(original))
        assert original == restored