MTerryJack commited on
Commit
1aa2e4f
·
verified ·
1 Parent(s): 8e202ef

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .venv/lib/python3.13/site-packages/sympy/calculus/__init__.py +25 -0
  2. .venv/lib/python3.13/site-packages/sympy/calculus/accumulationbounds.py +804 -0
  3. .venv/lib/python3.13/site-packages/sympy/calculus/euler.py +108 -0
  4. .venv/lib/python3.13/site-packages/sympy/calculus/finite_diff.py +476 -0
  5. .venv/lib/python3.13/site-packages/sympy/calculus/singularities.py +406 -0
  6. .venv/lib/python3.13/site-packages/sympy/calculus/tests/__init__.py +0 -0
  7. .venv/lib/python3.13/site-packages/sympy/calculus/tests/test_accumulationbounds.py +336 -0
  8. .venv/lib/python3.13/site-packages/sympy/calculus/tests/test_euler.py +74 -0
  9. .venv/lib/python3.13/site-packages/sympy/calculus/tests/test_finite_diff.py +164 -0
  10. .venv/lib/python3.13/site-packages/sympy/calculus/tests/test_singularities.py +122 -0
  11. .venv/lib/python3.13/site-packages/sympy/calculus/tests/test_util.py +392 -0
  12. .venv/lib/python3.13/site-packages/sympy/calculus/util.py +895 -0
  13. .venv/lib/python3.13/site-packages/sympy/categories/__init__.py +33 -0
  14. .venv/lib/python3.13/site-packages/sympy/categories/baseclasses.py +978 -0
  15. .venv/lib/python3.13/site-packages/sympy/categories/diagram_drawing.py +2580 -0
  16. .venv/lib/python3.13/site-packages/sympy/categories/tests/__init__.py +0 -0
  17. .venv/lib/python3.13/site-packages/sympy/categories/tests/test_baseclasses.py +209 -0
  18. .venv/lib/python3.13/site-packages/sympy/categories/tests/test_drawing.py +919 -0
  19. .venv/lib/python3.13/site-packages/sympy/diffgeom/__init__.py +19 -0
  20. .venv/lib/python3.13/site-packages/sympy/diffgeom/diffgeom.py +2270 -0
  21. .venv/lib/python3.13/site-packages/sympy/diffgeom/rn.py +143 -0
  22. .venv/lib/python3.13/site-packages/sympy/diffgeom/tests/__init__.py +0 -0
  23. .venv/lib/python3.13/site-packages/sympy/diffgeom/tests/test_class_structure.py +33 -0
  24. .venv/lib/python3.13/site-packages/sympy/diffgeom/tests/test_diffgeom.py +342 -0
  25. .venv/lib/python3.13/site-packages/sympy/diffgeom/tests/test_function_diffgeom_book.py +145 -0
  26. .venv/lib/python3.13/site-packages/sympy/diffgeom/tests/test_hyperbolic_space.py +91 -0
  27. .venv/lib/python3.13/site-packages/sympy/external/__init__.py +20 -0
  28. .venv/lib/python3.13/site-packages/sympy/external/gmpy.py +342 -0
  29. .venv/lib/python3.13/site-packages/sympy/external/importtools.py +187 -0
  30. .venv/lib/python3.13/site-packages/sympy/external/ntheory.py +618 -0
  31. .venv/lib/python3.13/site-packages/sympy/external/pythonmpq.py +341 -0
  32. .venv/lib/python3.13/site-packages/sympy/external/tests/__init__.py +0 -0
  33. .venv/lib/python3.13/site-packages/sympy/external/tests/test_autowrap.py +313 -0
  34. .venv/lib/python3.13/site-packages/sympy/external/tests/test_codegen.py +375 -0
  35. .venv/lib/python3.13/site-packages/sympy/external/tests/test_gmpy.py +12 -0
  36. .venv/lib/python3.13/site-packages/sympy/external/tests/test_importtools.py +40 -0
  37. .venv/lib/python3.13/site-packages/sympy/external/tests/test_ntheory.py +307 -0
  38. .venv/lib/python3.13/site-packages/sympy/external/tests/test_numpy.py +335 -0
  39. .venv/lib/python3.13/site-packages/sympy/external/tests/test_pythonmpq.py +176 -0
  40. .venv/lib/python3.13/site-packages/sympy/external/tests/test_scipy.py +35 -0
  41. .venv/lib/python3.13/site-packages/sympy/functions/__init__.py +115 -0
  42. .venv/lib/python3.13/site-packages/sympy/geometry/__init__.py +45 -0
  43. .venv/lib/python3.13/site-packages/sympy/geometry/curve.py +424 -0
  44. .venv/lib/python3.13/site-packages/sympy/geometry/ellipse.py +1768 -0
  45. .venv/lib/python3.13/site-packages/sympy/geometry/entity.py +641 -0
  46. .venv/lib/python3.13/site-packages/sympy/geometry/exceptions.py +5 -0
  47. .venv/lib/python3.13/site-packages/sympy/geometry/line.py +2877 -0
  48. .venv/lib/python3.13/site-packages/sympy/geometry/parabola.py +422 -0
  49. .venv/lib/python3.13/site-packages/sympy/geometry/plane.py +878 -0
  50. .venv/lib/python3.13/site-packages/sympy/geometry/point.py +1378 -0
.venv/lib/python3.13/site-packages/sympy/calculus/__init__.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Calculus-related methods."""
2
+
3
+ from .euler import euler_equations
4
+ from .singularities import (singularities, is_increasing,
5
+ is_strictly_increasing, is_decreasing,
6
+ is_strictly_decreasing, is_monotonic)
7
+ from .finite_diff import finite_diff_weights, apply_finite_diff, differentiate_finite
8
+ from .util import (periodicity, not_empty_in, is_convex,
9
+ stationary_points, minimum, maximum)
10
+ from .accumulationbounds import AccumBounds
11
+
12
+ __all__ = [
13
+ 'euler_equations',
14
+
15
+ 'singularities', 'is_increasing',
16
+ 'is_strictly_increasing', 'is_decreasing',
17
+ 'is_strictly_decreasing', 'is_monotonic',
18
+
19
+ 'finite_diff_weights', 'apply_finite_diff', 'differentiate_finite',
20
+
21
+ 'periodicity', 'not_empty_in', 'is_convex', 'stationary_points',
22
+ 'minimum', 'maximum',
23
+
24
+ 'AccumBounds'
25
+ ]
.venv/lib/python3.13/site-packages/sympy/calculus/accumulationbounds.py ADDED
@@ -0,0 +1,804 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sympy.core import Add, Mul, Pow, S
2
+ from sympy.core.basic import Basic
3
+ from sympy.core.expr import Expr
4
+ from sympy.core.numbers import _sympifyit, oo, zoo
5
+ from sympy.core.relational import is_le, is_lt, is_ge, is_gt
6
+ from sympy.core.sympify import _sympify
7
+ from sympy.functions.elementary.miscellaneous import Min, Max
8
+ from sympy.logic.boolalg import And
9
+ from sympy.multipledispatch import dispatch
10
+ from sympy.series.order import Order
11
+ from sympy.sets.sets import FiniteSet
12
+
13
+
14
+ class AccumulationBounds(Expr):
15
+ r"""An accumulation bounds.
16
+
17
+ # Note AccumulationBounds has an alias: AccumBounds
18
+
19
+ AccumulationBounds represent an interval `[a, b]`, which is always closed
20
+ at the ends. Here `a` and `b` can be any value from extended real numbers.
21
+
22
+ The intended meaning of AccummulationBounds is to give an approximate
23
+ location of the accumulation points of a real function at a limit point.
24
+
25
+ Let `a` and `b` be reals such that `a \le b`.
26
+
27
+ `\left\langle a, b\right\rangle = \{x \in \mathbb{R} \mid a \le x \le b\}`
28
+
29
+ `\left\langle -\infty, b\right\rangle = \{x \in \mathbb{R} \mid x \le b\} \cup \{-\infty, \infty\}`
30
+
31
+ `\left\langle a, \infty \right\rangle = \{x \in \mathbb{R} \mid a \le x\} \cup \{-\infty, \infty\}`
32
+
33
+ `\left\langle -\infty, \infty \right\rangle = \mathbb{R} \cup \{-\infty, \infty\}`
34
+
35
+ ``oo`` and ``-oo`` are added to the second and third definition respectively,
36
+ since if either ``-oo`` or ``oo`` is an argument, then the other one should
37
+ be included (though not as an end point). This is forced, since we have,
38
+ for example, ``1/AccumBounds(0, 1) = AccumBounds(1, oo)``, and the limit at
39
+ `0` is not one-sided. As `x` tends to `0-`, then `1/x \rightarrow -\infty`, so `-\infty`
40
+ should be interpreted as belonging to ``AccumBounds(1, oo)`` though it need
41
+ not appear explicitly.
42
+
43
+ In many cases it suffices to know that the limit set is bounded.
44
+ However, in some other cases more exact information could be useful.
45
+ For example, all accumulation values of `\cos(x) + 1` are non-negative.
46
+ (``AccumBounds(-1, 1) + 1 = AccumBounds(0, 2)``)
47
+
48
+ A AccumulationBounds object is defined to be real AccumulationBounds,
49
+ if its end points are finite reals.
50
+
51
+ Let `X`, `Y` be real AccumulationBounds, then their sum, difference,
52
+ product are defined to be the following sets:
53
+
54
+ `X + Y = \{ x+y \mid x \in X \cap y \in Y\}`
55
+
56
+ `X - Y = \{ x-y \mid x \in X \cap y \in Y\}`
57
+
58
+ `X \times Y = \{ x \times y \mid x \in X \cap y \in Y\}`
59
+
60
+ When an AccumBounds is raised to a negative power, if 0 is contained
61
+ between the bounds then an infinite range is returned, otherwise if an
62
+ endpoint is 0 then a semi-infinite range with consistent sign will be returned.
63
+
64
+ AccumBounds in expressions behave a lot like Intervals but the
65
+ semantics are not necessarily the same. Division (or exponentiation
66
+ to a negative integer power) could be handled with *intervals* by
67
+ returning a union of the results obtained after splitting the
68
+ bounds between negatives and positives, but that is not done with
69
+ AccumBounds. In addition, bounds are assumed to be independent of
70
+ each other; if the same bound is used in more than one place in an
71
+ expression, the result may not be the supremum or infimum of the
72
+ expression (see below). Finally, when a boundary is ``1``,
73
+ exponentiation to the power of ``oo`` yields ``oo``, neither
74
+ ``1`` nor ``nan``.
75
+
76
+ Examples
77
+ ========
78
+
79
+ >>> from sympy import AccumBounds, sin, exp, log, pi, E, S, oo
80
+ >>> from sympy.abc import x
81
+
82
+ >>> AccumBounds(0, 1) + AccumBounds(1, 2)
83
+ AccumBounds(1, 3)
84
+
85
+ >>> AccumBounds(0, 1) - AccumBounds(0, 2)
86
+ AccumBounds(-2, 1)
87
+
88
+ >>> AccumBounds(-2, 3)*AccumBounds(-1, 1)
89
+ AccumBounds(-3, 3)
90
+
91
+ >>> AccumBounds(1, 2)*AccumBounds(3, 5)
92
+ AccumBounds(3, 10)
93
+
94
+ The exponentiation of AccumulationBounds is defined
95
+ as follows:
96
+
97
+ If 0 does not belong to `X` or `n > 0` then
98
+
99
+ `X^n = \{ x^n \mid x \in X\}`
100
+
101
+ >>> AccumBounds(1, 4)**(S(1)/2)
102
+ AccumBounds(1, 2)
103
+
104
+ otherwise, an infinite or semi-infinite result is obtained:
105
+
106
+ >>> 1/AccumBounds(-1, 1)
107
+ AccumBounds(-oo, oo)
108
+ >>> 1/AccumBounds(0, 2)
109
+ AccumBounds(1/2, oo)
110
+ >>> 1/AccumBounds(-oo, 0)
111
+ AccumBounds(-oo, 0)
112
+
113
+ A boundary of 1 will always generate all nonnegatives:
114
+
115
+ >>> AccumBounds(1, 2)**oo
116
+ AccumBounds(0, oo)
117
+ >>> AccumBounds(0, 1)**oo
118
+ AccumBounds(0, oo)
119
+
120
+ If the exponent is itself an AccumulationBounds or is not an
121
+ integer then unevaluated results will be returned unless the base
122
+ values are positive:
123
+
124
+ >>> AccumBounds(2, 3)**AccumBounds(-1, 2)
125
+ AccumBounds(1/3, 9)
126
+ >>> AccumBounds(-2, 3)**AccumBounds(-1, 2)
127
+ AccumBounds(-2, 3)**AccumBounds(-1, 2)
128
+
129
+ >>> AccumBounds(-2, -1)**(S(1)/2)
130
+ sqrt(AccumBounds(-2, -1))
131
+
132
+ Note: `\left\langle a, b\right\rangle^2` is not same as `\left\langle a, b\right\rangle \times \left\langle a, b\right\rangle`
133
+
134
+ >>> AccumBounds(-1, 1)**2
135
+ AccumBounds(0, 1)
136
+
137
+ >>> AccumBounds(1, 3) < 4
138
+ True
139
+
140
+ >>> AccumBounds(1, 3) < -1
141
+ False
142
+
143
+ Some elementary functions can also take AccumulationBounds as input.
144
+ A function `f` evaluated for some real AccumulationBounds `\left\langle a, b \right\rangle`
145
+ is defined as `f(\left\langle a, b\right\rangle) = \{ f(x) \mid a \le x \le b \}`
146
+
147
+ >>> sin(AccumBounds(pi/6, pi/3))
148
+ AccumBounds(1/2, sqrt(3)/2)
149
+
150
+ >>> exp(AccumBounds(0, 1))
151
+ AccumBounds(1, E)
152
+
153
+ >>> log(AccumBounds(1, E))
154
+ AccumBounds(0, 1)
155
+
156
+ Some symbol in an expression can be substituted for a AccumulationBounds
157
+ object. But it does not necessarily evaluate the AccumulationBounds for
158
+ that expression.
159
+
160
+ The same expression can be evaluated to different values depending upon
161
+ the form it is used for substitution since each instance of an
162
+ AccumulationBounds is considered independent. For example:
163
+
164
+ >>> (x**2 + 2*x + 1).subs(x, AccumBounds(-1, 1))
165
+ AccumBounds(-1, 4)
166
+
167
+ >>> ((x + 1)**2).subs(x, AccumBounds(-1, 1))
168
+ AccumBounds(0, 4)
169
+
170
+ References
171
+ ==========
172
+
173
+ .. [1] https://en.wikipedia.org/wiki/Interval_arithmetic
174
+
175
+ .. [2] https://fab.cba.mit.edu/classes/S62.12/docs/Hickey_interval.pdf
176
+
177
+ Notes
178
+ =====
179
+
180
+ Do not use ``AccumulationBounds`` for floating point interval arithmetic
181
+ calculations, use ``mpmath.iv`` instead.
182
+ """
183
+
184
+ is_extended_real = True
185
+ is_number = False
186
+
187
+ def __new__(cls, min, max) -> Expr: # type: ignore
188
+
189
+ min = _sympify(min)
190
+ max = _sympify(max)
191
+
192
+ # Only allow real intervals (use symbols with 'is_extended_real=True').
193
+ if not min.is_extended_real or not max.is_extended_real:
194
+ raise ValueError("Only real AccumulationBounds are supported")
195
+
196
+ if max == min:
197
+ return max
198
+
199
+ # Make sure that the created AccumBounds object will be valid.
200
+ if max.is_number and min.is_number:
201
+ bad = max.is_comparable and min.is_comparable and max < min
202
+ else:
203
+ bad = (max - min).is_extended_negative
204
+ if bad:
205
+ raise ValueError(
206
+ "Lower limit should be smaller than upper limit")
207
+
208
+ return Basic.__new__(cls, min, max)
209
+
210
+ # setting the operation priority
211
+ _op_priority = 11.0
212
+
213
+ def _eval_is_real(self):
214
+ if self.min.is_real and self.max.is_real:
215
+ return True
216
+
217
+ @property
218
+ def min(self):
219
+ """
220
+ Returns the minimum possible value attained by AccumulationBounds
221
+ object.
222
+
223
+ Examples
224
+ ========
225
+
226
+ >>> from sympy import AccumBounds
227
+ >>> AccumBounds(1, 3).min
228
+ 1
229
+
230
+ """
231
+ return self.args[0]
232
+
233
+ @property
234
+ def max(self):
235
+ """
236
+ Returns the maximum possible value attained by AccumulationBounds
237
+ object.
238
+
239
+ Examples
240
+ ========
241
+
242
+ >>> from sympy import AccumBounds
243
+ >>> AccumBounds(1, 3).max
244
+ 3
245
+
246
+ """
247
+ return self.args[1]
248
+
249
+ @property
250
+ def delta(self):
251
+ """
252
+ Returns the difference of maximum possible value attained by
253
+ AccumulationBounds object and minimum possible value attained
254
+ by AccumulationBounds object.
255
+
256
+ Examples
257
+ ========
258
+
259
+ >>> from sympy import AccumBounds
260
+ >>> AccumBounds(1, 3).delta
261
+ 2
262
+
263
+ """
264
+ return self.max - self.min
265
+
266
+ @property
267
+ def mid(self):
268
+ """
269
+ Returns the mean of maximum possible value attained by
270
+ AccumulationBounds object and minimum possible value
271
+ attained by AccumulationBounds object.
272
+
273
+ Examples
274
+ ========
275
+
276
+ >>> from sympy import AccumBounds
277
+ >>> AccumBounds(1, 3).mid
278
+ 2
279
+
280
+ """
281
+ return (self.min + self.max) / 2
282
+
283
+ @_sympifyit('other', NotImplemented)
284
+ def _eval_power(self, other):
285
+ return self.__pow__(other)
286
+
287
+ @_sympifyit('other', NotImplemented)
288
+ def __add__(self, other):
289
+ if isinstance(other, Expr):
290
+ if isinstance(other, AccumBounds):
291
+ return AccumBounds(
292
+ Add(self.min, other.min),
293
+ Add(self.max, other.max))
294
+ if other is S.Infinity and self.min is S.NegativeInfinity or \
295
+ other is S.NegativeInfinity and self.max is S.Infinity:
296
+ return AccumBounds(-oo, oo)
297
+ elif other.is_extended_real:
298
+ if self.min is S.NegativeInfinity and self.max is S.Infinity:
299
+ return AccumBounds(-oo, oo)
300
+ elif self.min is S.NegativeInfinity:
301
+ return AccumBounds(-oo, self.max + other)
302
+ elif self.max is S.Infinity:
303
+ return AccumBounds(self.min + other, oo)
304
+ else:
305
+ return AccumBounds(Add(self.min, other), Add(self.max, other))
306
+ return Add(self, other, evaluate=False)
307
+ return NotImplemented
308
+
309
+ __radd__ = __add__
310
+
311
+ def __neg__(self):
312
+ return AccumBounds(-self.max, -self.min)
313
+
314
+ @_sympifyit('other', NotImplemented)
315
+ def __sub__(self, other):
316
+ if isinstance(other, Expr):
317
+ if isinstance(other, AccumBounds):
318
+ return AccumBounds(
319
+ Add(self.min, -other.max),
320
+ Add(self.max, -other.min))
321
+ if other is S.NegativeInfinity and self.min is S.NegativeInfinity or \
322
+ other is S.Infinity and self.max is S.Infinity:
323
+ return AccumBounds(-oo, oo)
324
+ elif other.is_extended_real:
325
+ if self.min is S.NegativeInfinity and self.max is S.Infinity:
326
+ return AccumBounds(-oo, oo)
327
+ elif self.min is S.NegativeInfinity:
328
+ return AccumBounds(-oo, self.max - other)
329
+ elif self.max is S.Infinity:
330
+ return AccumBounds(self.min - other, oo)
331
+ else:
332
+ return AccumBounds(
333
+ Add(self.min, -other),
334
+ Add(self.max, -other))
335
+ return Add(self, -other, evaluate=False)
336
+ return NotImplemented
337
+
338
+ @_sympifyit('other', NotImplemented)
339
+ def __rsub__(self, other):
340
+ return self.__neg__() + other
341
+
342
+ @_sympifyit('other', NotImplemented)
343
+ def __mul__(self, other):
344
+ if self.args == (-oo, oo):
345
+ return self
346
+ if isinstance(other, Expr):
347
+ if isinstance(other, AccumBounds):
348
+ if other.args == (-oo, oo):
349
+ return other
350
+ v = set()
351
+ for a in self.args:
352
+ vi = other*a
353
+ v.update(vi.args or (vi,))
354
+ return AccumBounds(Min(*v), Max(*v))
355
+ if other is S.Infinity:
356
+ if self.min.is_zero:
357
+ return AccumBounds(0, oo)
358
+ if self.max.is_zero:
359
+ return AccumBounds(-oo, 0)
360
+ if other is S.NegativeInfinity:
361
+ if self.min.is_zero:
362
+ return AccumBounds(-oo, 0)
363
+ if self.max.is_zero:
364
+ return AccumBounds(0, oo)
365
+ if other.is_extended_real:
366
+ if other.is_zero:
367
+ if self.max is S.Infinity:
368
+ return AccumBounds(0, oo)
369
+ if self.min is S.NegativeInfinity:
370
+ return AccumBounds(-oo, 0)
371
+ return S.Zero
372
+ if other.is_extended_positive:
373
+ return AccumBounds(
374
+ Mul(self.min, other),
375
+ Mul(self.max, other))
376
+ elif other.is_extended_negative:
377
+ return AccumBounds(
378
+ Mul(self.max, other),
379
+ Mul(self.min, other))
380
+ if isinstance(other, Order):
381
+ return other
382
+ return Mul(self, other, evaluate=False)
383
+ return NotImplemented
384
+
385
+ __rmul__ = __mul__
386
+
387
+ @_sympifyit('other', NotImplemented)
388
+ def __truediv__(self, other):
389
+ if isinstance(other, Expr):
390
+ if isinstance(other, AccumBounds):
391
+ if other.min.is_positive or other.max.is_negative:
392
+ return self * AccumBounds(1/other.max, 1/other.min)
393
+
394
+ if (self.min.is_extended_nonpositive and self.max.is_extended_nonnegative and
395
+ other.min.is_extended_nonpositive and other.max.is_extended_nonnegative):
396
+ if self.min.is_zero and other.min.is_zero:
397
+ return AccumBounds(0, oo)
398
+ if self.max.is_zero and other.min.is_zero:
399
+ return AccumBounds(-oo, 0)
400
+ return AccumBounds(-oo, oo)
401
+
402
+ if self.max.is_extended_negative:
403
+ if other.min.is_extended_negative:
404
+ if other.max.is_zero:
405
+ return AccumBounds(self.max / other.min, oo)
406
+ if other.max.is_extended_positive:
407
+ # if we were dealing with intervals we would return
408
+ # Union(Interval(-oo, self.max/other.max),
409
+ # Interval(self.max/other.min, oo))
410
+ return AccumBounds(-oo, oo)
411
+
412
+ if other.min.is_zero and other.max.is_extended_positive:
413
+ return AccumBounds(-oo, self.max / other.max)
414
+
415
+ if self.min.is_extended_positive:
416
+ if other.min.is_extended_negative:
417
+ if other.max.is_zero:
418
+ return AccumBounds(-oo, self.min / other.min)
419
+ if other.max.is_extended_positive:
420
+ # if we were dealing with intervals we would return
421
+ # Union(Interval(-oo, self.min/other.min),
422
+ # Interval(self.min/other.max, oo))
423
+ return AccumBounds(-oo, oo)
424
+
425
+ if other.min.is_zero and other.max.is_extended_positive:
426
+ return AccumBounds(self.min / other.max, oo)
427
+
428
+ elif other.is_extended_real:
429
+ if other in (S.Infinity, S.NegativeInfinity):
430
+ if self == AccumBounds(-oo, oo):
431
+ return AccumBounds(-oo, oo)
432
+ if self.max is S.Infinity:
433
+ return AccumBounds(Min(0, other), Max(0, other))
434
+ if self.min is S.NegativeInfinity:
435
+ return AccumBounds(Min(0, -other), Max(0, -other))
436
+ if other.is_extended_positive:
437
+ return AccumBounds(self.min / other, self.max / other)
438
+ elif other.is_extended_negative:
439
+ return AccumBounds(self.max / other, self.min / other)
440
+ if (1 / other) is S.ComplexInfinity:
441
+ return Mul(self, 1 / other, evaluate=False)
442
+ else:
443
+ return Mul(self, 1 / other)
444
+
445
+ return NotImplemented
446
+
447
+ @_sympifyit('other', NotImplemented)
448
+ def __rtruediv__(self, other):
449
+ if isinstance(other, Expr):
450
+ if other.is_extended_real:
451
+ if other.is_zero:
452
+ return S.Zero
453
+ if (self.min.is_extended_nonpositive and self.max.is_extended_nonnegative):
454
+ if self.min.is_zero:
455
+ if other.is_extended_positive:
456
+ return AccumBounds(Mul(other, 1 / self.max), oo)
457
+ if other.is_extended_negative:
458
+ return AccumBounds(-oo, Mul(other, 1 / self.max))
459
+ if self.max.is_zero:
460
+ if other.is_extended_positive:
461
+ return AccumBounds(-oo, Mul(other, 1 / self.min))
462
+ if other.is_extended_negative:
463
+ return AccumBounds(Mul(other, 1 / self.min), oo)
464
+ return AccumBounds(-oo, oo)
465
+ else:
466
+ return AccumBounds(Min(other / self.min, other / self.max),
467
+ Max(other / self.min, other / self.max))
468
+ return Mul(other, 1 / self, evaluate=False)
469
+ else:
470
+ return NotImplemented
471
+
472
+ @_sympifyit('other', NotImplemented)
473
+ def __pow__(self, other):
474
+ if isinstance(other, Expr):
475
+ if other is S.Infinity:
476
+ if self.min.is_extended_nonnegative:
477
+ if self.max < 1:
478
+ return S.Zero
479
+ if self.min > 1:
480
+ return S.Infinity
481
+ return AccumBounds(0, oo)
482
+ elif self.max.is_extended_negative:
483
+ if self.min > -1:
484
+ return S.Zero
485
+ if self.max < -1:
486
+ return zoo
487
+ return S.NaN
488
+ else:
489
+ if self.min > -1:
490
+ if self.max < 1:
491
+ return S.Zero
492
+ return AccumBounds(0, oo)
493
+ return AccumBounds(-oo, oo)
494
+
495
+ if other is S.NegativeInfinity:
496
+ return (1/self)**oo
497
+
498
+ # generically true
499
+ if (self.max - self.min).is_nonnegative:
500
+ # well defined
501
+ if self.min.is_nonnegative:
502
+ # no 0 to worry about
503
+ if other.is_nonnegative:
504
+ # no infinity to worry about
505
+ return self.func(self.min**other, self.max**other)
506
+
507
+ if other.is_zero:
508
+ return S.One # x**0 = 1
509
+
510
+ if other.is_Integer or other.is_integer:
511
+ if self.min.is_extended_positive:
512
+ return AccumBounds(
513
+ Min(self.min**other, self.max**other),
514
+ Max(self.min**other, self.max**other))
515
+ elif self.max.is_extended_negative:
516
+ return AccumBounds(
517
+ Min(self.max**other, self.min**other),
518
+ Max(self.max**other, self.min**other))
519
+
520
+ if other % 2 == 0:
521
+ if other.is_extended_negative:
522
+ if self.min.is_zero:
523
+ return AccumBounds(self.max**other, oo)
524
+ if self.max.is_zero:
525
+ return AccumBounds(self.min**other, oo)
526
+ return (1/self)**(-other)
527
+ return AccumBounds(
528
+ S.Zero, Max(self.min**other, self.max**other))
529
+ elif other % 2 == 1:
530
+ if other.is_extended_negative:
531
+ if self.min.is_zero:
532
+ return AccumBounds(self.max**other, oo)
533
+ if self.max.is_zero:
534
+ return AccumBounds(-oo, self.min**other)
535
+ return (1/self)**(-other)
536
+ return AccumBounds(self.min**other, self.max**other)
537
+
538
+ # non-integer exponent
539
+ # 0**neg or neg**frac yields complex
540
+ if (other.is_number or other.is_rational) and (
541
+ self.min.is_extended_nonnegative or (
542
+ other.is_extended_nonnegative and
543
+ self.min.is_extended_nonnegative)):
544
+ num, den = other.as_numer_denom()
545
+ if num is S.One:
546
+ return AccumBounds(*[i**(1/den) for i in self.args])
547
+
548
+ elif den is not S.One: # e.g. if other is not Float
549
+ return (self**num)**(1/den) # ok for non-negative base
550
+
551
+ if isinstance(other, AccumBounds):
552
+ if (self.min.is_extended_positive or
553
+ self.min.is_extended_nonnegative and
554
+ other.min.is_extended_nonnegative):
555
+ p = [self**i for i in other.args]
556
+ if not any(i.is_Pow for i in p):
557
+ a = [j for i in p for j in i.args or (i,)]
558
+ try:
559
+ return self.func(min(a), max(a))
560
+ except TypeError: # can't sort
561
+ pass
562
+
563
+ return Pow(self, other, evaluate=False)
564
+
565
+ return NotImplemented
566
+
567
+ @_sympifyit('other', NotImplemented)
568
+ def __rpow__(self, other):
569
+ if other.is_real and other.is_extended_nonnegative and (
570
+ self.max - self.min).is_extended_positive:
571
+ if other is S.One:
572
+ return S.One
573
+ if other.is_extended_positive:
574
+ a, b = [other**i for i in self.args]
575
+ if min(a, b) != a:
576
+ a, b = b, a
577
+ return self.func(a, b)
578
+ if other.is_zero:
579
+ if self.min.is_zero:
580
+ return self.func(0, 1)
581
+ if self.min.is_extended_positive:
582
+ return S.Zero
583
+
584
+ return Pow(other, self, evaluate=False)
585
+
586
+ def __abs__(self):
587
+ if self.max.is_extended_negative:
588
+ return self.__neg__()
589
+ elif self.min.is_extended_negative:
590
+ return AccumBounds(S.Zero, Max(abs(self.min), self.max))
591
+ else:
592
+ return self
593
+
594
+
595
+ def __contains__(self, other):
596
+ """
597
+ Returns ``True`` if other is contained in self, where other
598
+ belongs to extended real numbers, ``False`` if not contained,
599
+ otherwise TypeError is raised.
600
+
601
+ Examples
602
+ ========
603
+
604
+ >>> from sympy import AccumBounds, oo
605
+ >>> 1 in AccumBounds(-1, 3)
606
+ True
607
+
608
+ -oo and oo go together as limits (in AccumulationBounds).
609
+
610
+ >>> -oo in AccumBounds(1, oo)
611
+ True
612
+
613
+ >>> oo in AccumBounds(-oo, 0)
614
+ True
615
+
616
+ """
617
+ other = _sympify(other)
618
+
619
+ if other in (S.Infinity, S.NegativeInfinity):
620
+ if self.min is S.NegativeInfinity or self.max is S.Infinity:
621
+ return True
622
+ return False
623
+
624
+ rv = And(self.min <= other, self.max >= other)
625
+ if rv not in (True, False):
626
+ raise TypeError("input failed to evaluate")
627
+ return rv
628
+
629
+ def intersection(self, other):
630
+ """
631
+ Returns the intersection of 'self' and 'other'.
632
+ Here other can be an instance of :py:class:`~.FiniteSet` or AccumulationBounds.
633
+
634
+ Parameters
635
+ ==========
636
+
637
+ other : AccumulationBounds
638
+ Another AccumulationBounds object with which the intersection
639
+ has to be computed.
640
+
641
+ Returns
642
+ =======
643
+
644
+ AccumulationBounds
645
+ Intersection of ``self`` and ``other``.
646
+
647
+ Examples
648
+ ========
649
+
650
+ >>> from sympy import AccumBounds, FiniteSet
651
+ >>> AccumBounds(1, 3).intersection(AccumBounds(2, 4))
652
+ AccumBounds(2, 3)
653
+
654
+ >>> AccumBounds(1, 3).intersection(AccumBounds(4, 6))
655
+ EmptySet
656
+
657
+ >>> AccumBounds(1, 4).intersection(FiniteSet(1, 2, 5))
658
+ {1, 2}
659
+
660
+ """
661
+ if not isinstance(other, (AccumBounds, FiniteSet)):
662
+ raise TypeError(
663
+ "Input must be AccumulationBounds or FiniteSet object")
664
+
665
+ if isinstance(other, FiniteSet):
666
+ fin_set = S.EmptySet
667
+ for i in other:
668
+ if i in self:
669
+ fin_set = fin_set + FiniteSet(i)
670
+ return fin_set
671
+
672
+ if self.max < other.min or self.min > other.max:
673
+ return S.EmptySet
674
+
675
+ if self.min <= other.min:
676
+ if self.max <= other.max:
677
+ return AccumBounds(other.min, self.max)
678
+ if self.max > other.max:
679
+ return other
680
+
681
+ if other.min <= self.min:
682
+ if other.max < self.max:
683
+ return AccumBounds(self.min, other.max)
684
+ if other.max > self.max:
685
+ return self
686
+
687
+ def union(self, other):
688
+ # TODO : Devise a better method for Union of AccumBounds
689
+ # this method is not actually correct and
690
+ # can be made better
691
+ if not isinstance(other, AccumBounds):
692
+ raise TypeError(
693
+ "Input must be AccumulationBounds or FiniteSet object")
694
+
695
+ if self.min <= other.min and self.max >= other.min:
696
+ return AccumBounds(self.min, Max(self.max, other.max))
697
+
698
+ if other.min <= self.min and other.max >= self.min:
699
+ return AccumBounds(other.min, Max(self.max, other.max))
700
+
701
+
702
+ @dispatch(AccumulationBounds, AccumulationBounds) # type: ignore # noqa:F811
703
+ def _eval_is_le(lhs, rhs): # noqa:F811
704
+ if is_le(lhs.max, rhs.min):
705
+ return True
706
+ if is_gt(lhs.min, rhs.max):
707
+ return False
708
+
709
+
710
+ @dispatch(AccumulationBounds, Basic) # type: ignore # noqa:F811
711
+ def _eval_is_le(lhs, rhs): # noqa: F811
712
+
713
+ """
714
+ Returns ``True `` if range of values attained by ``lhs`` AccumulationBounds
715
+ object is greater than the range of values attained by ``rhs``,
716
+ where ``rhs`` may be any value of type AccumulationBounds object or
717
+ extended real number value, ``False`` if ``rhs`` satisfies
718
+ the same property, else an unevaluated :py:class:`~.Relational`.
719
+
720
+ Examples
721
+ ========
722
+
723
+ >>> from sympy import AccumBounds, oo
724
+ >>> AccumBounds(1, 3) > AccumBounds(4, oo)
725
+ False
726
+ >>> AccumBounds(1, 4) > AccumBounds(3, 4)
727
+ AccumBounds(1, 4) > AccumBounds(3, 4)
728
+ >>> AccumBounds(1, oo) > -1
729
+ True
730
+
731
+ """
732
+ if not rhs.is_extended_real:
733
+ raise TypeError(
734
+ "Invalid comparison of %s %s" %
735
+ (type(rhs), rhs))
736
+ elif rhs.is_comparable:
737
+ if is_le(lhs.max, rhs):
738
+ return True
739
+ if is_gt(lhs.min, rhs):
740
+ return False
741
+
742
+
743
+ @dispatch(AccumulationBounds, AccumulationBounds)
744
+ def _eval_is_ge(lhs, rhs): # noqa:F811
745
+ if is_ge(lhs.min, rhs.max):
746
+ return True
747
+ if is_lt(lhs.max, rhs.min):
748
+ return False
749
+
750
+
751
+ @dispatch(AccumulationBounds, Expr) # type:ignore
752
+ def _eval_is_ge(lhs, rhs): # noqa: F811
753
+ """
754
+ Returns ``True`` if range of values attained by ``lhs`` AccumulationBounds
755
+ object is less that the range of values attained by ``rhs``, where
756
+ other may be any value of type AccumulationBounds object or extended
757
+ real number value, ``False`` if ``rhs`` satisfies the same
758
+ property, else an unevaluated :py:class:`~.Relational`.
759
+
760
+ Examples
761
+ ========
762
+
763
+ >>> from sympy import AccumBounds, oo
764
+ >>> AccumBounds(1, 3) >= AccumBounds(4, oo)
765
+ False
766
+ >>> AccumBounds(1, 4) >= AccumBounds(3, 4)
767
+ AccumBounds(1, 4) >= AccumBounds(3, 4)
768
+ >>> AccumBounds(1, oo) >= 1
769
+ True
770
+ """
771
+
772
+ if not rhs.is_extended_real:
773
+ raise TypeError(
774
+ "Invalid comparison of %s %s" %
775
+ (type(rhs), rhs))
776
+ elif rhs.is_comparable:
777
+ if is_ge(lhs.min, rhs):
778
+ return True
779
+ if is_lt(lhs.max, rhs):
780
+ return False
781
+
782
+
783
+ @dispatch(Expr, AccumulationBounds) # type:ignore
784
+ def _eval_is_ge(lhs, rhs): # noqa:F811
785
+ if not lhs.is_extended_real:
786
+ raise TypeError(
787
+ "Invalid comparison of %s %s" %
788
+ (type(lhs), lhs))
789
+ elif lhs.is_comparable:
790
+ if is_le(rhs.max, lhs):
791
+ return True
792
+ if is_gt(rhs.min, lhs):
793
+ return False
794
+
795
+
796
+ @dispatch(AccumulationBounds, AccumulationBounds) # type:ignore
797
+ def _eval_is_ge(lhs, rhs): # noqa:F811
798
+ if is_ge(lhs.min, rhs.max):
799
+ return True
800
+ if is_lt(lhs.max, rhs.min):
801
+ return False
802
+
803
+ # setting an alias for AccumulationBounds
804
+ AccumBounds = AccumulationBounds
.venv/lib/python3.13/site-packages/sympy/calculus/euler.py ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ This module implements a method to find
3
+ Euler-Lagrange Equations for given Lagrangian.
4
+ """
5
+ from itertools import combinations_with_replacement
6
+ from sympy.core.function import (Derivative, Function, diff)
7
+ from sympy.core.relational import Eq
8
+ from sympy.core.singleton import S
9
+ from sympy.core.symbol import Symbol
10
+ from sympy.core.sympify import sympify
11
+ from sympy.utilities.iterables import iterable
12
+
13
+
14
+ def euler_equations(L, funcs=(), vars=()):
15
+ r"""
16
+ Find the Euler-Lagrange equations [1]_ for a given Lagrangian.
17
+
18
+ Parameters
19
+ ==========
20
+
21
+ L : Expr
22
+ The Lagrangian that should be a function of the functions listed
23
+ in the second argument and their derivatives.
24
+
25
+ For example, in the case of two functions $f(x,y)$, $g(x,y)$ and
26
+ two independent variables $x$, $y$ the Lagrangian has the form:
27
+
28
+ .. math:: L\left(f(x,y),g(x,y),\frac{\partial f(x,y)}{\partial x},
29
+ \frac{\partial f(x,y)}{\partial y},
30
+ \frac{\partial g(x,y)}{\partial x},
31
+ \frac{\partial g(x,y)}{\partial y},x,y\right)
32
+
33
+ In many cases it is not necessary to provide anything, except the
34
+ Lagrangian, it will be auto-detected (and an error raised if this
35
+ cannot be done).
36
+
37
+ funcs : Function or an iterable of Functions
38
+ The functions that the Lagrangian depends on. The Euler equations
39
+ are differential equations for each of these functions.
40
+
41
+ vars : Symbol or an iterable of Symbols
42
+ The Symbols that are the independent variables of the functions.
43
+
44
+ Returns
45
+ =======
46
+
47
+ eqns : list of Eq
48
+ The list of differential equations, one for each function.
49
+
50
+ Examples
51
+ ========
52
+
53
+ >>> from sympy import euler_equations, Symbol, Function
54
+ >>> x = Function('x')
55
+ >>> t = Symbol('t')
56
+ >>> L = (x(t).diff(t))**2/2 - x(t)**2/2
57
+ >>> euler_equations(L, x(t), t)
58
+ [Eq(-x(t) - Derivative(x(t), (t, 2)), 0)]
59
+ >>> u = Function('u')
60
+ >>> x = Symbol('x')
61
+ >>> L = (u(t, x).diff(t))**2/2 - (u(t, x).diff(x))**2/2
62
+ >>> euler_equations(L, u(t, x), [t, x])
63
+ [Eq(-Derivative(u(t, x), (t, 2)) + Derivative(u(t, x), (x, 2)), 0)]
64
+
65
+ References
66
+ ==========
67
+
68
+ .. [1] https://en.wikipedia.org/wiki/Euler%E2%80%93Lagrange_equation
69
+
70
+ """
71
+
72
+ funcs = tuple(funcs) if iterable(funcs) else (funcs,)
73
+
74
+ if not funcs:
75
+ funcs = tuple(L.atoms(Function))
76
+ else:
77
+ for f in funcs:
78
+ if not isinstance(f, Function):
79
+ raise TypeError('Function expected, got: %s' % f)
80
+
81
+ vars = tuple(vars) if iterable(vars) else (vars,)
82
+
83
+ if not vars:
84
+ vars = funcs[0].args
85
+ else:
86
+ vars = tuple(sympify(var) for var in vars)
87
+
88
+ if not all(isinstance(v, Symbol) for v in vars):
89
+ raise TypeError('Variables are not symbols, got %s' % vars)
90
+
91
+ for f in funcs:
92
+ if not vars == f.args:
93
+ raise ValueError("Variables %s do not match args: %s" % (vars, f))
94
+
95
+ order = max([len(d.variables) for d in L.atoms(Derivative)
96
+ if d.expr in funcs] + [0])
97
+
98
+ eqns = []
99
+ for f in funcs:
100
+ eq = diff(L, f)
101
+ for i in range(1, order + 1):
102
+ for p in combinations_with_replacement(vars, i):
103
+ eq = eq + S.NegativeOne**i*diff(L, diff(f, *p), *p)
104
+ new_eq = Eq(eq, 0)
105
+ if isinstance(new_eq, Eq):
106
+ eqns.append(new_eq)
107
+
108
+ return eqns
.venv/lib/python3.13/site-packages/sympy/calculus/finite_diff.py ADDED
@@ -0,0 +1,476 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Finite difference weights
3
+ =========================
4
+
5
+ This module implements an algorithm for efficient generation of finite
6
+ difference weights for ordinary differentials of functions for
7
+ derivatives from 0 (interpolation) up to arbitrary order.
8
+
9
+ The core algorithm is provided in the finite difference weight generating
10
+ function (``finite_diff_weights``), and two convenience functions are provided
11
+ for:
12
+
13
+ - estimating a derivative (or interpolate) directly from a series of points
14
+ is also provided (``apply_finite_diff``).
15
+ - differentiating by using finite difference approximations
16
+ (``differentiate_finite``).
17
+
18
+ """
19
+
20
+ from sympy.core.function import Derivative
21
+ from sympy.core.singleton import S
22
+ from sympy.core.function import Subs
23
+ from sympy.core.traversal import preorder_traversal
24
+ from sympy.utilities.exceptions import sympy_deprecation_warning
25
+ from sympy.utilities.iterables import iterable
26
+
27
+
28
+
29
+ def finite_diff_weights(order, x_list, x0=S.One):
30
+ """
31
+ Calculates the finite difference weights for an arbitrarily spaced
32
+ one-dimensional grid (``x_list``) for derivatives at ``x0`` of order
33
+ 0, 1, ..., up to ``order`` using a recursive formula. Order of accuracy
34
+ is at least ``len(x_list) - order``, if ``x_list`` is defined correctly.
35
+
36
+ Parameters
37
+ ==========
38
+
39
+ order: int
40
+ Up to what derivative order weights should be calculated.
41
+ 0 corresponds to interpolation.
42
+ x_list: sequence
43
+ Sequence of (unique) values for the independent variable.
44
+ It is useful (but not necessary) to order ``x_list`` from
45
+ nearest to furthest from ``x0``; see examples below.
46
+ x0: Number or Symbol
47
+ Root or value of the independent variable for which the finite
48
+ difference weights should be generated. Default is ``S.One``.
49
+
50
+ Returns
51
+ =======
52
+
53
+ list
54
+ A list of sublists, each corresponding to coefficients for
55
+ increasing derivative order, and each containing lists of
56
+ coefficients for increasing subsets of x_list.
57
+
58
+ Examples
59
+ ========
60
+
61
+ >>> from sympy import finite_diff_weights, S
62
+ >>> res = finite_diff_weights(1, [-S(1)/2, S(1)/2, S(3)/2, S(5)/2], 0)
63
+ >>> res
64
+ [[[1, 0, 0, 0],
65
+ [1/2, 1/2, 0, 0],
66
+ [3/8, 3/4, -1/8, 0],
67
+ [5/16, 15/16, -5/16, 1/16]],
68
+ [[0, 0, 0, 0],
69
+ [-1, 1, 0, 0],
70
+ [-1, 1, 0, 0],
71
+ [-23/24, 7/8, 1/8, -1/24]]]
72
+ >>> res[0][-1] # FD weights for 0th derivative, using full x_list
73
+ [5/16, 15/16, -5/16, 1/16]
74
+ >>> res[1][-1] # FD weights for 1st derivative
75
+ [-23/24, 7/8, 1/8, -1/24]
76
+ >>> res[1][-2] # FD weights for 1st derivative, using x_list[:-1]
77
+ [-1, 1, 0, 0]
78
+ >>> res[1][-1][0] # FD weight for 1st deriv. for x_list[0]
79
+ -23/24
80
+ >>> res[1][-1][1] # FD weight for 1st deriv. for x_list[1], etc.
81
+ 7/8
82
+
83
+ Each sublist contains the most accurate formula at the end.
84
+ Note, that in the above example ``res[1][1]`` is the same as ``res[1][2]``.
85
+ Since res[1][2] has an order of accuracy of
86
+ ``len(x_list[:3]) - order = 3 - 1 = 2``, the same is true for ``res[1][1]``!
87
+
88
+ >>> res = finite_diff_weights(1, [S(0), S(1), -S(1), S(2), -S(2)], 0)[1]
89
+ >>> res
90
+ [[0, 0, 0, 0, 0],
91
+ [-1, 1, 0, 0, 0],
92
+ [0, 1/2, -1/2, 0, 0],
93
+ [-1/2, 1, -1/3, -1/6, 0],
94
+ [0, 2/3, -2/3, -1/12, 1/12]]
95
+ >>> res[0] # no approximation possible, using x_list[0] only
96
+ [0, 0, 0, 0, 0]
97
+ >>> res[1] # classic forward step approximation
98
+ [-1, 1, 0, 0, 0]
99
+ >>> res[2] # classic centered approximation
100
+ [0, 1/2, -1/2, 0, 0]
101
+ >>> res[3:] # higher order approximations
102
+ [[-1/2, 1, -1/3, -1/6, 0], [0, 2/3, -2/3, -1/12, 1/12]]
103
+
104
+ Let us compare this to a differently defined ``x_list``. Pay attention to
105
+ ``foo[i][k]`` corresponding to the gridpoint defined by ``x_list[k]``.
106
+
107
+ >>> foo = finite_diff_weights(1, [-S(2), -S(1), S(0), S(1), S(2)], 0)[1]
108
+ >>> foo
109
+ [[0, 0, 0, 0, 0],
110
+ [-1, 1, 0, 0, 0],
111
+ [1/2, -2, 3/2, 0, 0],
112
+ [1/6, -1, 1/2, 1/3, 0],
113
+ [1/12, -2/3, 0, 2/3, -1/12]]
114
+ >>> foo[1] # not the same and of lower accuracy as res[1]!
115
+ [-1, 1, 0, 0, 0]
116
+ >>> foo[2] # classic double backward step approximation
117
+ [1/2, -2, 3/2, 0, 0]
118
+ >>> foo[4] # the same as res[4]
119
+ [1/12, -2/3, 0, 2/3, -1/12]
120
+
121
+ Note that, unless you plan on using approximations based on subsets of
122
+ ``x_list``, the order of gridpoints does not matter.
123
+
124
+ The capability to generate weights at arbitrary points can be
125
+ used e.g. to minimize Runge's phenomenon by using Chebyshev nodes:
126
+
127
+ >>> from sympy import cos, symbols, pi, simplify
128
+ >>> N, (h, x) = 4, symbols('h x')
129
+ >>> x_list = [x+h*cos(i*pi/(N)) for i in range(N,-1,-1)] # chebyshev nodes
130
+ >>> print(x_list)
131
+ [-h + x, -sqrt(2)*h/2 + x, x, sqrt(2)*h/2 + x, h + x]
132
+ >>> mycoeffs = finite_diff_weights(1, x_list, 0)[1][4]
133
+ >>> [simplify(c) for c in mycoeffs] #doctest: +NORMALIZE_WHITESPACE
134
+ [(h**3/2 + h**2*x - 3*h*x**2 - 4*x**3)/h**4,
135
+ (-sqrt(2)*h**3 - 4*h**2*x + 3*sqrt(2)*h*x**2 + 8*x**3)/h**4,
136
+ (6*h**2*x - 8*x**3)/h**4,
137
+ (sqrt(2)*h**3 - 4*h**2*x - 3*sqrt(2)*h*x**2 + 8*x**3)/h**4,
138
+ (-h**3/2 + h**2*x + 3*h*x**2 - 4*x**3)/h**4]
139
+
140
+ Notes
141
+ =====
142
+
143
+ If weights for a finite difference approximation of 3rd order
144
+ derivative is wanted, weights for 0th, 1st and 2nd order are
145
+ calculated "for free", so are formulae using subsets of ``x_list``.
146
+ This is something one can take advantage of to save computational cost.
147
+ Be aware that one should define ``x_list`` from nearest to furthest from
148
+ ``x0``. If not, subsets of ``x_list`` will yield poorer approximations,
149
+ which might not grand an order of accuracy of ``len(x_list) - order``.
150
+
151
+ See also
152
+ ========
153
+
154
+ sympy.calculus.finite_diff.apply_finite_diff
155
+
156
+ References
157
+ ==========
158
+
159
+ .. [1] Generation of Finite Difference Formulas on Arbitrarily Spaced
160
+ Grids, Bengt Fornberg; Mathematics of computation; 51; 184;
161
+ (1988); 699-706; doi:10.1090/S0025-5718-1988-0935077-0
162
+
163
+ """
164
+ # The notation below closely corresponds to the one used in the paper.
165
+ order = S(order)
166
+ if not order.is_number:
167
+ raise ValueError("Cannot handle symbolic order.")
168
+ if order < 0:
169
+ raise ValueError("Negative derivative order illegal.")
170
+ if int(order) != order:
171
+ raise ValueError("Non-integer order illegal")
172
+ M = order
173
+ N = len(x_list) - 1
174
+ delta = [[[0 for nu in range(N+1)] for n in range(N+1)] for
175
+ m in range(M+1)]
176
+ delta[0][0][0] = S.One
177
+ c1 = S.One
178
+ for n in range(1, N+1):
179
+ c2 = S.One
180
+ for nu in range(n):
181
+ c3 = x_list[n] - x_list[nu]
182
+ c2 = c2 * c3
183
+ if n <= M:
184
+ delta[n][n-1][nu] = 0
185
+ for m in range(min(n, M)+1):
186
+ delta[m][n][nu] = (x_list[n]-x0)*delta[m][n-1][nu] -\
187
+ m*delta[m-1][n-1][nu]
188
+ delta[m][n][nu] /= c3
189
+ for m in range(min(n, M)+1):
190
+ delta[m][n][n] = c1/c2*(m*delta[m-1][n-1][n-1] -
191
+ (x_list[n-1]-x0)*delta[m][n-1][n-1])
192
+ c1 = c2
193
+ return delta
194
+
195
+
196
+ def apply_finite_diff(order, x_list, y_list, x0=S.Zero):
197
+ """
198
+ Calculates the finite difference approximation of
199
+ the derivative of requested order at ``x0`` from points
200
+ provided in ``x_list`` and ``y_list``.
201
+
202
+ Parameters
203
+ ==========
204
+
205
+ order: int
206
+ order of derivative to approximate. 0 corresponds to interpolation.
207
+ x_list: sequence
208
+ Sequence of (unique) values for the independent variable.
209
+ y_list: sequence
210
+ The function value at corresponding values for the independent
211
+ variable in x_list.
212
+ x0: Number or Symbol
213
+ At what value of the independent variable the derivative should be
214
+ evaluated. Defaults to 0.
215
+
216
+ Returns
217
+ =======
218
+
219
+ sympy.core.add.Add or sympy.core.numbers.Number
220
+ The finite difference expression approximating the requested
221
+ derivative order at ``x0``.
222
+
223
+ Examples
224
+ ========
225
+
226
+ >>> from sympy import apply_finite_diff
227
+ >>> cube = lambda arg: (1.0*arg)**3
228
+ >>> xlist = range(-3,3+1)
229
+ >>> apply_finite_diff(2, xlist, map(cube, xlist), 2) - 12 # doctest: +SKIP
230
+ -3.55271367880050e-15
231
+
232
+ we see that the example above only contain rounding errors.
233
+ apply_finite_diff can also be used on more abstract objects:
234
+
235
+ >>> from sympy import IndexedBase, Idx
236
+ >>> x, y = map(IndexedBase, 'xy')
237
+ >>> i = Idx('i')
238
+ >>> x_list, y_list = zip(*[(x[i+j], y[i+j]) for j in range(-1,2)])
239
+ >>> apply_finite_diff(1, x_list, y_list, x[i])
240
+ ((x[i + 1] - x[i])/(-x[i - 1] + x[i]) - 1)*y[i]/(x[i + 1] - x[i]) -
241
+ (x[i + 1] - x[i])*y[i - 1]/((x[i + 1] - x[i - 1])*(-x[i - 1] + x[i])) +
242
+ (-x[i - 1] + x[i])*y[i + 1]/((x[i + 1] - x[i - 1])*(x[i + 1] - x[i]))
243
+
244
+ Notes
245
+ =====
246
+
247
+ Order = 0 corresponds to interpolation.
248
+ Only supply so many points you think makes sense
249
+ to around x0 when extracting the derivative (the function
250
+ need to be well behaved within that region). Also beware
251
+ of Runge's phenomenon.
252
+
253
+ See also
254
+ ========
255
+
256
+ sympy.calculus.finite_diff.finite_diff_weights
257
+
258
+ References
259
+ ==========
260
+
261
+ Fortran 90 implementation with Python interface for numerics: finitediff_
262
+
263
+ .. _finitediff: https://github.com/bjodah/finitediff
264
+
265
+ """
266
+
267
+ # In the original paper the following holds for the notation:
268
+ # M = order
269
+ # N = len(x_list) - 1
270
+
271
+ N = len(x_list) - 1
272
+ if len(x_list) != len(y_list):
273
+ raise ValueError("x_list and y_list not equal in length.")
274
+
275
+ delta = finite_diff_weights(order, x_list, x0)
276
+
277
+ derivative = 0
278
+ for nu in range(len(x_list)):
279
+ derivative += delta[order][N][nu]*y_list[nu]
280
+ return derivative
281
+
282
+
283
+ def _as_finite_diff(derivative, points=1, x0=None, wrt=None):
284
+ """
285
+ Returns an approximation of a derivative of a function in
286
+ the form of a finite difference formula. The expression is a
287
+ weighted sum of the function at a number of discrete values of
288
+ (one of) the independent variable(s).
289
+
290
+ Parameters
291
+ ==========
292
+
293
+ derivative: a Derivative instance
294
+
295
+ points: sequence or coefficient, optional
296
+ If sequence: discrete values (length >= order+1) of the
297
+ independent variable used for generating the finite
298
+ difference weights.
299
+ If it is a coefficient, it will be used as the step-size
300
+ for generating an equidistant sequence of length order+1
301
+ centered around ``x0``. default: 1 (step-size 1)
302
+
303
+ x0: number or Symbol, optional
304
+ the value of the independent variable (``wrt``) at which the
305
+ derivative is to be approximated. Default: same as ``wrt``.
306
+
307
+ wrt: Symbol, optional
308
+ "with respect to" the variable for which the (partial)
309
+ derivative is to be approximated for. If not provided it
310
+ is required that the Derivative is ordinary. Default: ``None``.
311
+
312
+ Examples
313
+ ========
314
+
315
+ >>> from sympy import symbols, Function, exp, sqrt, Symbol
316
+ >>> from sympy.calculus.finite_diff import _as_finite_diff
317
+ >>> x, h = symbols('x h')
318
+ >>> f = Function('f')
319
+ >>> _as_finite_diff(f(x).diff(x))
320
+ -f(x - 1/2) + f(x + 1/2)
321
+
322
+ The default step size and number of points are 1 and ``order + 1``
323
+ respectively. We can change the step size by passing a symbol
324
+ as a parameter:
325
+
326
+ >>> _as_finite_diff(f(x).diff(x), h)
327
+ -f(-h/2 + x)/h + f(h/2 + x)/h
328
+
329
+ We can also specify the discretized values to be used in a sequence:
330
+
331
+ >>> _as_finite_diff(f(x).diff(x), [x, x+h, x+2*h])
332
+ -3*f(x)/(2*h) + 2*f(h + x)/h - f(2*h + x)/(2*h)
333
+
334
+ The algorithm is not restricted to use equidistant spacing, nor
335
+ do we need to make the approximation around ``x0``, but we can get
336
+ an expression estimating the derivative at an offset:
337
+
338
+ >>> e, sq2 = exp(1), sqrt(2)
339
+ >>> xl = [x-h, x+h, x+e*h]
340
+ >>> _as_finite_diff(f(x).diff(x, 1), xl, x+h*sq2)
341
+ 2*h*((h + sqrt(2)*h)/(2*h) - (-sqrt(2)*h + h)/(2*h))*f(E*h + x)/((-h + E*h)*(h + E*h)) +
342
+ (-(-sqrt(2)*h + h)/(2*h) - (-sqrt(2)*h + E*h)/(2*h))*f(-h + x)/(h + E*h) +
343
+ (-(h + sqrt(2)*h)/(2*h) + (-sqrt(2)*h + E*h)/(2*h))*f(h + x)/(-h + E*h)
344
+
345
+ Partial derivatives are also supported:
346
+
347
+ >>> y = Symbol('y')
348
+ >>> d2fdxdy=f(x,y).diff(x,y)
349
+ >>> _as_finite_diff(d2fdxdy, wrt=x)
350
+ -Derivative(f(x - 1/2, y), y) + Derivative(f(x + 1/2, y), y)
351
+
352
+ See also
353
+ ========
354
+
355
+ sympy.calculus.finite_diff.apply_finite_diff
356
+ sympy.calculus.finite_diff.finite_diff_weights
357
+
358
+ """
359
+ if derivative.is_Derivative:
360
+ pass
361
+ elif derivative.is_Atom:
362
+ return derivative
363
+ else:
364
+ return derivative.fromiter(
365
+ [_as_finite_diff(ar, points, x0, wrt) for ar
366
+ in derivative.args], **derivative.assumptions0)
367
+
368
+ if wrt is None:
369
+ old = None
370
+ for v in derivative.variables:
371
+ if old is v:
372
+ continue
373
+ derivative = _as_finite_diff(derivative, points, x0, v)
374
+ old = v
375
+ return derivative
376
+
377
+ order = derivative.variables.count(wrt)
378
+
379
+ if x0 is None:
380
+ x0 = wrt
381
+
382
+ if not iterable(points):
383
+ if getattr(points, 'is_Function', False) and wrt in points.args:
384
+ points = points.subs(wrt, x0)
385
+ # points is simply the step-size, let's make it a
386
+ # equidistant sequence centered around x0
387
+ if order % 2 == 0:
388
+ # even order => odd number of points, grid point included
389
+ points = [x0 + points*i for i
390
+ in range(-order//2, order//2 + 1)]
391
+ else:
392
+ # odd order => even number of points, half-way wrt grid point
393
+ points = [x0 + points*S(i)/2 for i
394
+ in range(-order, order + 1, 2)]
395
+ others = [wrt, 0]
396
+ for v in set(derivative.variables):
397
+ if v == wrt:
398
+ continue
399
+ others += [v, derivative.variables.count(v)]
400
+ if len(points) < order+1:
401
+ raise ValueError("Too few points for order %d" % order)
402
+ return apply_finite_diff(order, points, [
403
+ Derivative(derivative.expr.subs({wrt: x}), *others) for
404
+ x in points], x0)
405
+
406
+
407
+ def differentiate_finite(expr, *symbols,
408
+ points=1, x0=None, wrt=None, evaluate=False):
409
+ r""" Differentiate expr and replace Derivatives with finite differences.
410
+
411
+ Parameters
412
+ ==========
413
+
414
+ expr : expression
415
+ \*symbols : differentiate with respect to symbols
416
+ points: sequence, coefficient or undefined function, optional
417
+ see ``Derivative.as_finite_difference``
418
+ x0: number or Symbol, optional
419
+ see ``Derivative.as_finite_difference``
420
+ wrt: Symbol, optional
421
+ see ``Derivative.as_finite_difference``
422
+
423
+ Examples
424
+ ========
425
+
426
+ >>> from sympy import sin, Function, differentiate_finite
427
+ >>> from sympy.abc import x, y, h
428
+ >>> f, g = Function('f'), Function('g')
429
+ >>> differentiate_finite(f(x)*g(x), x, points=[x-h, x+h])
430
+ -f(-h + x)*g(-h + x)/(2*h) + f(h + x)*g(h + x)/(2*h)
431
+
432
+ ``differentiate_finite`` works on any expression, including the expressions
433
+ with embedded derivatives:
434
+
435
+ >>> differentiate_finite(f(x) + sin(x), x, 2)
436
+ -2*f(x) + f(x - 1) + f(x + 1) - 2*sin(x) + sin(x - 1) + sin(x + 1)
437
+ >>> differentiate_finite(f(x, y), x, y)
438
+ f(x - 1/2, y - 1/2) - f(x - 1/2, y + 1/2) - f(x + 1/2, y - 1/2) + f(x + 1/2, y + 1/2)
439
+ >>> differentiate_finite(f(x)*g(x).diff(x), x)
440
+ (-g(x) + g(x + 1))*f(x + 1/2) - (g(x) - g(x - 1))*f(x - 1/2)
441
+
442
+ To make finite difference with non-constant discretization step use
443
+ undefined functions:
444
+
445
+ >>> dx = Function('dx')
446
+ >>> differentiate_finite(f(x)*g(x).diff(x), points=dx(x))
447
+ -(-g(x - dx(x)/2 - dx(x - dx(x)/2)/2)/dx(x - dx(x)/2) +
448
+ g(x - dx(x)/2 + dx(x - dx(x)/2)/2)/dx(x - dx(x)/2))*f(x - dx(x)/2)/dx(x) +
449
+ (-g(x + dx(x)/2 - dx(x + dx(x)/2)/2)/dx(x + dx(x)/2) +
450
+ g(x + dx(x)/2 + dx(x + dx(x)/2)/2)/dx(x + dx(x)/2))*f(x + dx(x)/2)/dx(x)
451
+
452
+ """
453
+ if any(term.is_Derivative for term in list(preorder_traversal(expr))):
454
+ evaluate = False
455
+
456
+ Dexpr = expr.diff(*symbols, evaluate=evaluate)
457
+ if evaluate:
458
+ sympy_deprecation_warning("""
459
+ The evaluate flag to differentiate_finite() is deprecated.
460
+
461
+ evaluate=True expands the intermediate derivatives before computing
462
+ differences, but this usually not what you want, as it does not
463
+ satisfy the product rule.
464
+ """,
465
+ deprecated_since_version="1.5",
466
+ active_deprecations_target="deprecated-differentiate_finite-evaluate",
467
+ )
468
+ return Dexpr.replace(
469
+ lambda arg: arg.is_Derivative,
470
+ lambda arg: arg.as_finite_difference(points=points, x0=x0, wrt=wrt))
471
+ else:
472
+ DFexpr = Dexpr.as_finite_difference(points=points, x0=x0, wrt=wrt)
473
+ return DFexpr.replace(
474
+ lambda arg: isinstance(arg, Subs),
475
+ lambda arg: arg.expr.as_finite_difference(
476
+ points=points, x0=arg.point[0], wrt=arg.variables[0]))
.venv/lib/python3.13/site-packages/sympy/calculus/singularities.py ADDED
@@ -0,0 +1,406 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Singularities
3
+ =============
4
+
5
+ This module implements algorithms for finding singularities for a function
6
+ and identifying types of functions.
7
+
8
+ The differential calculus methods in this module include methods to identify
9
+ the following function types in the given ``Interval``:
10
+ - Increasing
11
+ - Strictly Increasing
12
+ - Decreasing
13
+ - Strictly Decreasing
14
+ - Monotonic
15
+
16
+ """
17
+
18
+ from sympy.core.power import Pow
19
+ from sympy.core.singleton import S
20
+ from sympy.core.symbol import Symbol
21
+ from sympy.core.sympify import sympify
22
+ from sympy.functions.elementary.exponential import log
23
+ from sympy.functions.elementary.trigonometric import sec, csc, cot, tan, cos
24
+ from sympy.functions.elementary.hyperbolic import (
25
+ sech, csch, coth, tanh, cosh, asech, acsch, atanh, acoth)
26
+ from sympy.utilities.misc import filldedent
27
+
28
+
29
+ def singularities(expression, symbol, domain=None):
30
+ """
31
+ Find singularities of a given function.
32
+
33
+ Parameters
34
+ ==========
35
+
36
+ expression : Expr
37
+ The target function in which singularities need to be found.
38
+ symbol : Symbol
39
+ The symbol over the values of which the singularity in
40
+ expression in being searched for.
41
+
42
+ Returns
43
+ =======
44
+
45
+ Set
46
+ A set of values for ``symbol`` for which ``expression`` has a
47
+ singularity. An ``EmptySet`` is returned if ``expression`` has no
48
+ singularities for any given value of ``Symbol``.
49
+
50
+ Raises
51
+ ======
52
+
53
+ NotImplementedError
54
+ Methods for determining the singularities of this function have
55
+ not been developed.
56
+
57
+ Notes
58
+ =====
59
+
60
+ This function does not find non-isolated singularities
61
+ nor does it find branch points of the expression.
62
+
63
+ Currently supported functions are:
64
+ - univariate continuous (real or complex) functions
65
+
66
+ References
67
+ ==========
68
+
69
+ .. [1] https://en.wikipedia.org/wiki/Mathematical_singularity
70
+
71
+ Examples
72
+ ========
73
+
74
+ >>> from sympy import singularities, Symbol, log
75
+ >>> x = Symbol('x', real=True)
76
+ >>> y = Symbol('y', real=False)
77
+ >>> singularities(x**2 + x + 1, x)
78
+ EmptySet
79
+ >>> singularities(1/(x + 1), x)
80
+ {-1}
81
+ >>> singularities(1/(y**2 + 1), y)
82
+ {-I, I}
83
+ >>> singularities(1/(y**3 + 1), y)
84
+ {-1, 1/2 - sqrt(3)*I/2, 1/2 + sqrt(3)*I/2}
85
+ >>> singularities(log(x), x)
86
+ {0}
87
+
88
+ """
89
+ from sympy.solvers.solveset import solveset
90
+
91
+ if domain is None:
92
+ domain = S.Reals if symbol.is_real else S.Complexes
93
+ try:
94
+ sings = S.EmptySet
95
+ e = expression.rewrite([sec, csc, cot, tan], cos)
96
+ e = e.rewrite([sech, csch, coth, tanh], cosh)
97
+ for i in e.atoms(Pow):
98
+ if i.exp.is_infinite:
99
+ raise NotImplementedError
100
+ if i.exp.is_negative:
101
+ # XXX: exponent of varying sign not handled
102
+ sings += solveset(i.base, symbol, domain)
103
+ for i in expression.atoms(log, asech, acsch):
104
+ sings += solveset(i.args[0], symbol, domain)
105
+ for i in expression.atoms(atanh, acoth):
106
+ sings += solveset(i.args[0] - 1, symbol, domain)
107
+ sings += solveset(i.args[0] + 1, symbol, domain)
108
+ return sings
109
+ except NotImplementedError:
110
+ raise NotImplementedError(filldedent('''
111
+ Methods for determining the singularities
112
+ of this function have not been developed.'''))
113
+
114
+
115
+ ###########################################################################
116
+ # DIFFERENTIAL CALCULUS METHODS #
117
+ ###########################################################################
118
+
119
+
120
+ def monotonicity_helper(expression, predicate, interval=S.Reals, symbol=None):
121
+ """
122
+ Helper function for functions checking function monotonicity.
123
+
124
+ Parameters
125
+ ==========
126
+
127
+ expression : Expr
128
+ The target function which is being checked
129
+ predicate : function
130
+ The property being tested for. The function takes in an integer
131
+ and returns a boolean. The integer input is the derivative and
132
+ the boolean result should be true if the property is being held,
133
+ and false otherwise.
134
+ interval : Set, optional
135
+ The range of values in which we are testing, defaults to all reals.
136
+ symbol : Symbol, optional
137
+ The symbol present in expression which gets varied over the given range.
138
+
139
+ It returns a boolean indicating whether the interval in which
140
+ the function's derivative satisfies given predicate is a superset
141
+ of the given interval.
142
+
143
+ Returns
144
+ =======
145
+
146
+ Boolean
147
+ True if ``predicate`` is true for all the derivatives when ``symbol``
148
+ is varied in ``range``, False otherwise.
149
+
150
+ """
151
+ from sympy.solvers.solveset import solveset
152
+
153
+ expression = sympify(expression)
154
+ free = expression.free_symbols
155
+
156
+ if symbol is None:
157
+ if len(free) > 1:
158
+ raise NotImplementedError(
159
+ 'The function has not yet been implemented'
160
+ ' for all multivariate expressions.'
161
+ )
162
+
163
+ variable = symbol or (free.pop() if free else Symbol('x'))
164
+ derivative = expression.diff(variable)
165
+ predicate_interval = solveset(predicate(derivative), variable, S.Reals)
166
+ return interval.is_subset(predicate_interval)
167
+
168
+
169
+ def is_increasing(expression, interval=S.Reals, symbol=None):
170
+ """
171
+ Return whether the function is increasing in the given interval.
172
+
173
+ Parameters
174
+ ==========
175
+
176
+ expression : Expr
177
+ The target function which is being checked.
178
+ interval : Set, optional
179
+ The range of values in which we are testing (defaults to set of
180
+ all real numbers).
181
+ symbol : Symbol, optional
182
+ The symbol present in expression which gets varied over the given range.
183
+
184
+ Returns
185
+ =======
186
+
187
+ Boolean
188
+ True if ``expression`` is increasing (either strictly increasing or
189
+ constant) in the given ``interval``, False otherwise.
190
+
191
+ Examples
192
+ ========
193
+
194
+ >>> from sympy import is_increasing
195
+ >>> from sympy.abc import x, y
196
+ >>> from sympy import S, Interval, oo
197
+ >>> is_increasing(x**3 - 3*x**2 + 4*x, S.Reals)
198
+ True
199
+ >>> is_increasing(-x**2, Interval(-oo, 0))
200
+ True
201
+ >>> is_increasing(-x**2, Interval(0, oo))
202
+ False
203
+ >>> is_increasing(4*x**3 - 6*x**2 - 72*x + 30, Interval(-2, 3))
204
+ False
205
+ >>> is_increasing(x**2 + y, Interval(1, 2), x)
206
+ True
207
+
208
+ """
209
+ return monotonicity_helper(expression, lambda x: x >= 0, interval, symbol)
210
+
211
+
212
+ def is_strictly_increasing(expression, interval=S.Reals, symbol=None):
213
+ """
214
+ Return whether the function is strictly increasing in the given interval.
215
+
216
+ Parameters
217
+ ==========
218
+
219
+ expression : Expr
220
+ The target function which is being checked.
221
+ interval : Set, optional
222
+ The range of values in which we are testing (defaults to set of
223
+ all real numbers).
224
+ symbol : Symbol, optional
225
+ The symbol present in expression which gets varied over the given range.
226
+
227
+ Returns
228
+ =======
229
+
230
+ Boolean
231
+ True if ``expression`` is strictly increasing in the given ``interval``,
232
+ False otherwise.
233
+
234
+ Examples
235
+ ========
236
+
237
+ >>> from sympy import is_strictly_increasing
238
+ >>> from sympy.abc import x, y
239
+ >>> from sympy import Interval, oo
240
+ >>> is_strictly_increasing(4*x**3 - 6*x**2 - 72*x + 30, Interval.Ropen(-oo, -2))
241
+ True
242
+ >>> is_strictly_increasing(4*x**3 - 6*x**2 - 72*x + 30, Interval.Lopen(3, oo))
243
+ True
244
+ >>> is_strictly_increasing(4*x**3 - 6*x**2 - 72*x + 30, Interval.open(-2, 3))
245
+ False
246
+ >>> is_strictly_increasing(-x**2, Interval(0, oo))
247
+ False
248
+ >>> is_strictly_increasing(-x**2 + y, Interval(-oo, 0), x)
249
+ False
250
+
251
+ """
252
+ return monotonicity_helper(expression, lambda x: x > 0, interval, symbol)
253
+
254
+
255
+ def is_decreasing(expression, interval=S.Reals, symbol=None):
256
+ """
257
+ Return whether the function is decreasing in the given interval.
258
+
259
+ Parameters
260
+ ==========
261
+
262
+ expression : Expr
263
+ The target function which is being checked.
264
+ interval : Set, optional
265
+ The range of values in which we are testing (defaults to set of
266
+ all real numbers).
267
+ symbol : Symbol, optional
268
+ The symbol present in expression which gets varied over the given range.
269
+
270
+ Returns
271
+ =======
272
+
273
+ Boolean
274
+ True if ``expression`` is decreasing (either strictly decreasing or
275
+ constant) in the given ``interval``, False otherwise.
276
+
277
+ Examples
278
+ ========
279
+
280
+ >>> from sympy import is_decreasing
281
+ >>> from sympy.abc import x, y
282
+ >>> from sympy import S, Interval, oo
283
+ >>> is_decreasing(1/(x**2 - 3*x), Interval.open(S(3)/2, 3))
284
+ True
285
+ >>> is_decreasing(1/(x**2 - 3*x), Interval.open(1.5, 3))
286
+ True
287
+ >>> is_decreasing(1/(x**2 - 3*x), Interval.Lopen(3, oo))
288
+ True
289
+ >>> is_decreasing(1/(x**2 - 3*x), Interval.Ropen(-oo, S(3)/2))
290
+ False
291
+ >>> is_decreasing(1/(x**2 - 3*x), Interval.Ropen(-oo, 1.5))
292
+ False
293
+ >>> is_decreasing(-x**2, Interval(-oo, 0))
294
+ False
295
+ >>> is_decreasing(-x**2 + y, Interval(-oo, 0), x)
296
+ False
297
+
298
+ """
299
+ return monotonicity_helper(expression, lambda x: x <= 0, interval, symbol)
300
+
301
+
302
+ def is_strictly_decreasing(expression, interval=S.Reals, symbol=None):
303
+ """
304
+ Return whether the function is strictly decreasing in the given interval.
305
+
306
+ Parameters
307
+ ==========
308
+
309
+ expression : Expr
310
+ The target function which is being checked.
311
+ interval : Set, optional
312
+ The range of values in which we are testing (defaults to set of
313
+ all real numbers).
314
+ symbol : Symbol, optional
315
+ The symbol present in expression which gets varied over the given range.
316
+
317
+ Returns
318
+ =======
319
+
320
+ Boolean
321
+ True if ``expression`` is strictly decreasing in the given ``interval``,
322
+ False otherwise.
323
+
324
+ Examples
325
+ ========
326
+
327
+ >>> from sympy import is_strictly_decreasing
328
+ >>> from sympy.abc import x, y
329
+ >>> from sympy import S, Interval, oo
330
+ >>> is_strictly_decreasing(1/(x**2 - 3*x), Interval.Lopen(3, oo))
331
+ True
332
+ >>> is_strictly_decreasing(1/(x**2 - 3*x), Interval.Ropen(-oo, S(3)/2))
333
+ False
334
+ >>> is_strictly_decreasing(1/(x**2 - 3*x), Interval.Ropen(-oo, 1.5))
335
+ False
336
+ >>> is_strictly_decreasing(-x**2, Interval(-oo, 0))
337
+ False
338
+ >>> is_strictly_decreasing(-x**2 + y, Interval(-oo, 0), x)
339
+ False
340
+
341
+ """
342
+ return monotonicity_helper(expression, lambda x: x < 0, interval, symbol)
343
+
344
+
345
+ def is_monotonic(expression, interval=S.Reals, symbol=None):
346
+ """
347
+ Return whether the function is monotonic in the given interval.
348
+
349
+ Parameters
350
+ ==========
351
+
352
+ expression : Expr
353
+ The target function which is being checked.
354
+ interval : Set, optional
355
+ The range of values in which we are testing (defaults to set of
356
+ all real numbers).
357
+ symbol : Symbol, optional
358
+ The symbol present in expression which gets varied over the given range.
359
+
360
+ Returns
361
+ =======
362
+
363
+ Boolean
364
+ True if ``expression`` is monotonic in the given ``interval``,
365
+ False otherwise.
366
+
367
+ Raises
368
+ ======
369
+
370
+ NotImplementedError
371
+ Monotonicity check has not been implemented for the queried function.
372
+
373
+ Examples
374
+ ========
375
+
376
+ >>> from sympy import is_monotonic
377
+ >>> from sympy.abc import x, y
378
+ >>> from sympy import S, Interval, oo
379
+ >>> is_monotonic(1/(x**2 - 3*x), Interval.open(S(3)/2, 3))
380
+ True
381
+ >>> is_monotonic(1/(x**2 - 3*x), Interval.open(1.5, 3))
382
+ True
383
+ >>> is_monotonic(1/(x**2 - 3*x), Interval.Lopen(3, oo))
384
+ True
385
+ >>> is_monotonic(x**3 - 3*x**2 + 4*x, S.Reals)
386
+ True
387
+ >>> is_monotonic(-x**2, S.Reals)
388
+ False
389
+ >>> is_monotonic(x**2 + y + 1, Interval(1, 2), x)
390
+ True
391
+
392
+ """
393
+ from sympy.solvers.solveset import solveset
394
+
395
+ expression = sympify(expression)
396
+
397
+ free = expression.free_symbols
398
+ if symbol is None and len(free) > 1:
399
+ raise NotImplementedError(
400
+ 'is_monotonic has not yet been implemented'
401
+ ' for all multivariate expressions.'
402
+ )
403
+
404
+ variable = symbol or (free.pop() if free else Symbol('x'))
405
+ turning_points = solveset(expression.diff(variable), variable, interval)
406
+ return interval.intersection(turning_points) is S.EmptySet
.venv/lib/python3.13/site-packages/sympy/calculus/tests/__init__.py ADDED
File without changes
.venv/lib/python3.13/site-packages/sympy/calculus/tests/test_accumulationbounds.py ADDED
@@ -0,0 +1,336 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sympy.core.numbers import (E, Rational, oo, pi, zoo)
2
+ from sympy.core.singleton import S
3
+ from sympy.core.symbol import Symbol
4
+ from sympy.functions.elementary.exponential import (exp, log)
5
+ from sympy.functions.elementary.miscellaneous import (Max, Min, sqrt)
6
+ from sympy.functions.elementary.trigonometric import (cos, sin, tan)
7
+ from sympy.calculus.accumulationbounds import AccumBounds
8
+ from sympy.core import Add, Mul, Pow
9
+ from sympy.core.expr import unchanged
10
+ from sympy.testing.pytest import raises, XFAIL
11
+ from sympy.abc import x
12
+
13
+ a = Symbol('a', real=True)
14
+ B = AccumBounds
15
+
16
+
17
+ def test_AccumBounds():
18
+ assert B(1, 2).args == (1, 2)
19
+ assert B(1, 2).delta is S.One
20
+ assert B(1, 2).mid == Rational(3, 2)
21
+ assert B(1, 3).is_real == True
22
+
23
+ assert B(1, 1) is S.One
24
+
25
+ assert B(1, 2) + 1 == B(2, 3)
26
+ assert 1 + B(1, 2) == B(2, 3)
27
+ assert B(1, 2) + B(2, 3) == B(3, 5)
28
+
29
+ assert -B(1, 2) == B(-2, -1)
30
+
31
+ assert B(1, 2) - 1 == B(0, 1)
32
+ assert 1 - B(1, 2) == B(-1, 0)
33
+ assert B(2, 3) - B(1, 2) == B(0, 2)
34
+
35
+ assert x + B(1, 2) == Add(B(1, 2), x)
36
+ assert a + B(1, 2) == B(1 + a, 2 + a)
37
+ assert B(1, 2) - x == Add(B(1, 2), -x)
38
+
39
+ assert B(-oo, 1) + oo == B(-oo, oo)
40
+ assert B(1, oo) + oo is oo
41
+ assert B(1, oo) - oo == B(-oo, oo)
42
+ assert (-oo - B(-1, oo)) is -oo
43
+ assert B(-oo, 1) - oo is -oo
44
+
45
+ assert B(1, oo) - oo == B(-oo, oo)
46
+ assert B(-oo, 1) - (-oo) == B(-oo, oo)
47
+ assert (oo - B(1, oo)) == B(-oo, oo)
48
+ assert (-oo - B(1, oo)) is -oo
49
+
50
+ assert B(1, 2)/2 == B(S.Half, 1)
51
+ assert 2/B(2, 3) == B(Rational(2, 3), 1)
52
+ assert 1/B(-1, 1) == B(-oo, oo)
53
+
54
+ assert abs(B(1, 2)) == B(1, 2)
55
+ assert abs(B(-2, -1)) == B(1, 2)
56
+ assert abs(B(-2, 1)) == B(0, 2)
57
+ assert abs(B(-1, 2)) == B(0, 2)
58
+ c = Symbol('c')
59
+ raises(ValueError, lambda: B(0, c))
60
+ raises(ValueError, lambda: B(1, -1))
61
+ r = Symbol('r', real=True)
62
+ raises(ValueError, lambda: B(r, r - 1))
63
+
64
+
65
+ def test_AccumBounds_mul():
66
+ assert B(1, 2)*2 == B(2, 4)
67
+ assert 2*B(1, 2) == B(2, 4)
68
+ assert B(1, 2)*B(2, 3) == B(2, 6)
69
+ assert B(0, 2)*B(2, oo) == B(0, oo)
70
+ l, r = B(-oo, oo), B(-a, a)
71
+ assert l*r == B(-oo, oo)
72
+ assert r*l == B(-oo, oo)
73
+ l, r = B(1, oo), B(-3, -2)
74
+ assert l*r == B(-oo, -2)
75
+ assert r*l == B(-oo, -2)
76
+ assert B(1, 2)*0 == 0
77
+ assert B(1, oo)*0 == B(0, oo)
78
+ assert B(-oo, 1)*0 == B(-oo, 0)
79
+ assert B(-oo, oo)*0 == B(-oo, oo)
80
+
81
+ assert B(1, 2)*x == Mul(B(1, 2), x, evaluate=False)
82
+
83
+ assert B(0, 2)*oo == B(0, oo)
84
+ assert B(-2, 0)*oo == B(-oo, 0)
85
+ assert B(0, 2)*(-oo) == B(-oo, 0)
86
+ assert B(-2, 0)*(-oo) == B(0, oo)
87
+ assert B(-1, 1)*oo == B(-oo, oo)
88
+ assert B(-1, 1)*(-oo) == B(-oo, oo)
89
+ assert B(-oo, oo)*oo == B(-oo, oo)
90
+
91
+
92
+ def test_AccumBounds_div():
93
+ assert B(-1, 3)/B(3, 4) == B(Rational(-1, 3), 1)
94
+ assert B(-2, 4)/B(-3, 4) == B(-oo, oo)
95
+ assert B(-3, -2)/B(-4, 0) == B(S.Half, oo)
96
+
97
+ # these two tests can have a better answer
98
+ # after Union of B is improved
99
+ assert B(-3, -2)/B(-2, 1) == B(-oo, oo)
100
+ assert B(2, 3)/B(-2, 2) == B(-oo, oo)
101
+
102
+ assert B(-3, -2)/B(0, 4) == B(-oo, Rational(-1, 2))
103
+ assert B(2, 4)/B(-3, 0) == B(-oo, Rational(-2, 3))
104
+ assert B(2, 4)/B(0, 3) == B(Rational(2, 3), oo)
105
+
106
+ assert B(0, 1)/B(0, 1) == B(0, oo)
107
+ assert B(-1, 0)/B(0, 1) == B(-oo, 0)
108
+ assert B(-1, 2)/B(-2, 2) == B(-oo, oo)
109
+
110
+ assert 1/B(-1, 2) == B(-oo, oo)
111
+ assert 1/B(0, 2) == B(S.Half, oo)
112
+ assert (-1)/B(0, 2) == B(-oo, Rational(-1, 2))
113
+ assert 1/B(-oo, 0) == B(-oo, 0)
114
+ assert 1/B(-1, 0) == B(-oo, -1)
115
+ assert (-2)/B(-oo, 0) == B(0, oo)
116
+ assert 1/B(-oo, -1) == B(-1, 0)
117
+
118
+ assert B(1, 2)/a == Mul(B(1, 2), 1/a, evaluate=False)
119
+
120
+ assert B(1, 2)/0 == B(1, 2)*zoo
121
+ assert B(1, oo)/oo == B(0, oo)
122
+ assert B(1, oo)/(-oo) == B(-oo, 0)
123
+ assert B(-oo, -1)/oo == B(-oo, 0)
124
+ assert B(-oo, -1)/(-oo) == B(0, oo)
125
+ assert B(-oo, oo)/oo == B(-oo, oo)
126
+ assert B(-oo, oo)/(-oo) == B(-oo, oo)
127
+ assert B(-1, oo)/oo == B(0, oo)
128
+ assert B(-1, oo)/(-oo) == B(-oo, 0)
129
+ assert B(-oo, 1)/oo == B(-oo, 0)
130
+ assert B(-oo, 1)/(-oo) == B(0, oo)
131
+
132
+
133
+ def test_issue_18795():
134
+ r = Symbol('r', real=True)
135
+ a = B(-1,1)
136
+ c = B(7, oo)
137
+ b = B(-oo, oo)
138
+ assert c - tan(r) == B(7-tan(r), oo)
139
+ assert b + tan(r) == B(-oo, oo)
140
+ assert (a + r)/a == B(-oo, oo)*B(r - 1, r + 1)
141
+ assert (b + a)/a == B(-oo, oo)
142
+
143
+
144
+ def test_AccumBounds_func():
145
+ assert (x**2 + 2*x + 1).subs(x, B(-1, 1)) == B(-1, 4)
146
+ assert exp(B(0, 1)) == B(1, E)
147
+ assert exp(B(-oo, oo)) == B(0, oo)
148
+ assert log(B(3, 6)) == B(log(3), log(6))
149
+
150
+
151
+ @XFAIL
152
+ def test_AccumBounds_powf():
153
+ nn = Symbol('nn', nonnegative=True)
154
+ assert B(1 + nn, 2 + nn)**B(1, 2) == B(1 + nn, (2 + nn)**2)
155
+ i = Symbol('i', integer=True, negative=True)
156
+ assert B(1, 2)**i == B(2**i, 1)
157
+
158
+
159
+ def test_AccumBounds_pow():
160
+ assert B(0, 2)**2 == B(0, 4)
161
+ assert B(-1, 1)**2 == B(0, 1)
162
+ assert B(1, 2)**2 == B(1, 4)
163
+ assert B(-1, 2)**3 == B(-1, 8)
164
+ assert B(-1, 1)**0 == 1
165
+
166
+ assert B(1, 2)**Rational(5, 2) == B(1, 4*sqrt(2))
167
+ assert B(0, 2)**S.Half == B(0, sqrt(2))
168
+
169
+ neg = Symbol('neg', negative=True)
170
+ assert unchanged(Pow, B(neg, 1), S.Half)
171
+ nn = Symbol('nn', nonnegative=True)
172
+ assert B(nn, nn + 1)**S.Half == B(sqrt(nn), sqrt(nn + 1))
173
+ assert B(nn, nn + 1)**nn == B(nn**nn, (nn + 1)**nn)
174
+ assert unchanged(Pow, B(nn, nn + 1), x)
175
+ i = Symbol('i', integer=True)
176
+ assert B(1, 2)**i == B(Min(1, 2**i), Max(1, 2**i))
177
+ i = Symbol('i', integer=True, nonnegative=True)
178
+ assert B(1, 2)**i == B(1, 2**i)
179
+ assert B(0, 1)**i == B(0**i, 1)
180
+
181
+ assert B(1, 5)**(-2) == B(Rational(1, 25), 1)
182
+ assert B(-1, 3)**(-2) == B(0, oo)
183
+ assert B(0, 2)**(-3) == B(Rational(1, 8), oo)
184
+ assert B(-2, 0)**(-3) == B(-oo, -Rational(1, 8))
185
+ assert B(0, 2)**(-2) == B(Rational(1, 4), oo)
186
+ assert B(-1, 2)**(-3) == B(-oo, oo)
187
+ assert B(-3, -2)**(-3) == B(Rational(-1, 8), Rational(-1, 27))
188
+ assert B(-3, -2)**(-2) == B(Rational(1, 9), Rational(1, 4))
189
+ assert B(0, oo)**S.Half == B(0, oo)
190
+ assert B(-oo, 0)**(-2) == B(0, oo)
191
+ assert B(-2, 0)**(-2) == B(Rational(1, 4), oo)
192
+
193
+ assert B(Rational(1, 3), S.Half)**oo is S.Zero
194
+ assert B(0, S.Half)**oo is S.Zero
195
+ assert B(S.Half, 1)**oo == B(0, oo)
196
+ assert B(0, 1)**oo == B(0, oo)
197
+ assert B(2, 3)**oo is oo
198
+ assert B(1, 2)**oo == B(0, oo)
199
+ assert B(S.Half, 3)**oo == B(0, oo)
200
+ assert B(Rational(-1, 3), Rational(-1, 4))**oo is S.Zero
201
+ assert B(-1, Rational(-1, 2))**oo is S.NaN
202
+ assert B(-3, -2)**oo is zoo
203
+ assert B(-2, -1)**oo is S.NaN
204
+ assert B(-2, Rational(-1, 2))**oo is S.NaN
205
+ assert B(Rational(-1, 2), S.Half)**oo is S.Zero
206
+ assert B(Rational(-1, 2), 1)**oo == B(0, oo)
207
+ assert B(Rational(-2, 3), 2)**oo == B(0, oo)
208
+ assert B(-1, 1)**oo == B(-oo, oo)
209
+ assert B(-1, S.Half)**oo == B(-oo, oo)
210
+ assert B(-1, 2)**oo == B(-oo, oo)
211
+ assert B(-2, S.Half)**oo == B(-oo, oo)
212
+
213
+ assert B(1, 2)**x == Pow(B(1, 2), x, evaluate=False)
214
+
215
+ assert B(2, 3)**(-oo) is S.Zero
216
+ assert B(0, 2)**(-oo) == B(0, oo)
217
+ assert B(-1, 2)**(-oo) == B(-oo, oo)
218
+
219
+ assert (tan(x)**sin(2*x)).subs(x, B(0, pi/2)) == \
220
+ Pow(B(-oo, oo), B(0, 1))
221
+
222
+
223
+ def test_AccumBounds_exponent():
224
+ # base is 0
225
+ z = 0**B(a, a + S.Half)
226
+ assert z.subs(a, 0) == B(0, 1)
227
+ assert z.subs(a, 1) == 0
228
+ p = z.subs(a, -1)
229
+ assert p.is_Pow and p.args == (0, B(-1, -S.Half))
230
+ # base > 0
231
+ # when base is 1 the type of bounds does not matter
232
+ assert 1**B(a, a + 1) == 1
233
+ # otherwise we need to know if 0 is in the bounds
234
+ assert S.Half**B(-2, 2) == B(S(1)/4, 4)
235
+ assert 2**B(-2, 2) == B(S(1)/4, 4)
236
+
237
+ # +eps may introduce +oo
238
+ # if there is a negative integer exponent
239
+ assert B(0, 1)**B(S(1)/2, 1) == B(0, 1)
240
+ assert B(0, 1)**B(0, 1) == B(0, 1)
241
+
242
+ # positive bases have positive bounds
243
+ assert B(2, 3)**B(-3, -2) == B(S(1)/27, S(1)/4)
244
+ assert B(2, 3)**B(-3, 2) == B(S(1)/27, 9)
245
+
246
+ # bounds generating imaginary parts unevaluated
247
+ assert unchanged(Pow, B(-1, 1), B(1, 2))
248
+ assert B(0, S(1)/2)**B(1, oo) == B(0, S(1)/2)
249
+ assert B(0, 1)**B(1, oo) == B(0, oo)
250
+ assert B(0, 2)**B(1, oo) == B(0, oo)
251
+ assert B(0, oo)**B(1, oo) == B(0, oo)
252
+ assert B(S(1)/2, 1)**B(1, oo) == B(0, oo)
253
+ assert B(S(1)/2, 1)**B(-oo, -1) == B(0, oo)
254
+ assert B(S(1)/2, 1)**B(-oo, oo) == B(0, oo)
255
+ assert B(S(1)/2, 2)**B(1, oo) == B(0, oo)
256
+ assert B(S(1)/2, 2)**B(-oo, -1) == B(0, oo)
257
+ assert B(S(1)/2, 2)**B(-oo, oo) == B(0, oo)
258
+ assert B(S(1)/2, oo)**B(1, oo) == B(0, oo)
259
+ assert B(S(1)/2, oo)**B(-oo, -1) == B(0, oo)
260
+ assert B(S(1)/2, oo)**B(-oo, oo) == B(0, oo)
261
+ assert B(1, 2)**B(1, oo) == B(0, oo)
262
+ assert B(1, 2)**B(-oo, -1) == B(0, oo)
263
+ assert B(1, 2)**B(-oo, oo) == B(0, oo)
264
+ assert B(1, oo)**B(1, oo) == B(0, oo)
265
+ assert B(1, oo)**B(-oo, -1) == B(0, oo)
266
+ assert B(1, oo)**B(-oo, oo) == B(0, oo)
267
+ assert B(2, oo)**B(1, oo) == B(2, oo)
268
+ assert B(2, oo)**B(-oo, -1) == B(0, S(1)/2)
269
+ assert B(2, oo)**B(-oo, oo) == B(0, oo)
270
+
271
+
272
+ def test_comparison_AccumBounds():
273
+ assert (B(1, 3) < 4) == S.true
274
+ assert (B(1, 3) < -1) == S.false
275
+ assert (B(1, 3) < 2).rel_op == '<'
276
+ assert (B(1, 3) <= 2).rel_op == '<='
277
+
278
+ assert (B(1, 3) > 4) == S.false
279
+ assert (B(1, 3) > -1) == S.true
280
+ assert (B(1, 3) > 2).rel_op == '>'
281
+ assert (B(1, 3) >= 2).rel_op == '>='
282
+
283
+ assert (B(1, 3) < B(4, 6)) == S.true
284
+ assert (B(1, 3) < B(2, 4)).rel_op == '<'
285
+ assert (B(1, 3) < B(-2, 0)) == S.false
286
+
287
+ assert (B(1, 3) <= B(4, 6)) == S.true
288
+ assert (B(1, 3) <= B(-2, 0)) == S.false
289
+
290
+ assert (B(1, 3) > B(4, 6)) == S.false
291
+ assert (B(1, 3) > B(-2, 0)) == S.true
292
+
293
+ assert (B(1, 3) >= B(4, 6)) == S.false
294
+ assert (B(1, 3) >= B(-2, 0)) == S.true
295
+
296
+ # issue 13499
297
+ assert (cos(x) > 0).subs(x, oo) == (B(-1, 1) > 0)
298
+
299
+ c = Symbol('c')
300
+ raises(TypeError, lambda: (B(0, 1) < c))
301
+ raises(TypeError, lambda: (B(0, 1) <= c))
302
+ raises(TypeError, lambda: (B(0, 1) > c))
303
+ raises(TypeError, lambda: (B(0, 1) >= c))
304
+
305
+
306
+ def test_contains_AccumBounds():
307
+ assert (1 in B(1, 2)) == S.true
308
+ raises(TypeError, lambda: a in B(1, 2))
309
+ assert 0 in B(-1, 0)
310
+ raises(TypeError, lambda:
311
+ (cos(1)**2 + sin(1)**2 - 1) in B(-1, 0))
312
+ assert (-oo in B(1, oo)) == S.true
313
+ assert (oo in B(-oo, 0)) == S.true
314
+
315
+ # issue 13159
316
+ assert Mul(0, B(-1, 1)) == Mul(B(-1, 1), 0) == 0
317
+ import itertools
318
+ for perm in itertools.permutations([0, B(-1, 1), x]):
319
+ assert Mul(*perm) == 0
320
+
321
+
322
+ def test_intersection_AccumBounds():
323
+ assert B(0, 3).intersection(B(1, 2)) == B(1, 2)
324
+ assert B(0, 3).intersection(B(1, 4)) == B(1, 3)
325
+ assert B(0, 3).intersection(B(-1, 2)) == B(0, 2)
326
+ assert B(0, 3).intersection(B(-1, 4)) == B(0, 3)
327
+ assert B(0, 1).intersection(B(2, 3)) == S.EmptySet
328
+ raises(TypeError, lambda: B(0, 3).intersection(1))
329
+
330
+
331
+ def test_union_AccumBounds():
332
+ assert B(0, 3).union(B(1, 2)) == B(0, 3)
333
+ assert B(0, 3).union(B(1, 4)) == B(0, 4)
334
+ assert B(0, 3).union(B(-1, 2)) == B(-1, 3)
335
+ assert B(0, 3).union(B(-1, 4)) == B(-1, 4)
336
+ raises(TypeError, lambda: B(0, 3).union(1))
.venv/lib/python3.13/site-packages/sympy/calculus/tests/test_euler.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sympy.core.function import (Derivative as D, Function)
2
+ from sympy.core.relational import Eq
3
+ from sympy.core.symbol import (Symbol, symbols)
4
+ from sympy.functions.elementary.trigonometric import (cos, sin)
5
+ from sympy.testing.pytest import raises
6
+ from sympy.calculus.euler import euler_equations as euler
7
+
8
+
9
+ def test_euler_interface():
10
+ x = Function('x')
11
+ y = Symbol('y')
12
+ t = Symbol('t')
13
+ raises(TypeError, lambda: euler())
14
+ raises(TypeError, lambda: euler(D(x(t), t)*y(t), [x(t), y]))
15
+ raises(ValueError, lambda: euler(D(x(t), t)*x(y), [x(t), x(y)]))
16
+ raises(TypeError, lambda: euler(D(x(t), t)**2, x(0)))
17
+ raises(TypeError, lambda: euler(D(x(t), t)*y(t), [t]))
18
+ assert euler(D(x(t), t)**2/2, {x(t)}) == [Eq(-D(x(t), t, t), 0)]
19
+ assert euler(D(x(t), t)**2/2, x(t), {t}) == [Eq(-D(x(t), t, t), 0)]
20
+
21
+
22
+ def test_euler_pendulum():
23
+ x = Function('x')
24
+ t = Symbol('t')
25
+ L = D(x(t), t)**2/2 + cos(x(t))
26
+ assert euler(L, x(t), t) == [Eq(-sin(x(t)) - D(x(t), t, t), 0)]
27
+
28
+
29
+ def test_euler_henonheiles():
30
+ x = Function('x')
31
+ y = Function('y')
32
+ t = Symbol('t')
33
+ L = sum(D(z(t), t)**2/2 - z(t)**2/2 for z in [x, y])
34
+ L += -x(t)**2*y(t) + y(t)**3/3
35
+ assert euler(L, [x(t), y(t)], t) == [Eq(-2*x(t)*y(t) - x(t) -
36
+ D(x(t), t, t), 0),
37
+ Eq(-x(t)**2 + y(t)**2 -
38
+ y(t) - D(y(t), t, t), 0)]
39
+
40
+
41
+ def test_euler_sineg():
42
+ psi = Function('psi')
43
+ t = Symbol('t')
44
+ x = Symbol('x')
45
+ L = D(psi(t, x), t)**2/2 - D(psi(t, x), x)**2/2 + cos(psi(t, x))
46
+ assert euler(L, psi(t, x), [t, x]) == [Eq(-sin(psi(t, x)) -
47
+ D(psi(t, x), t, t) +
48
+ D(psi(t, x), x, x), 0)]
49
+
50
+
51
+ def test_euler_high_order():
52
+ # an example from hep-th/0309038
53
+ m = Symbol('m')
54
+ k = Symbol('k')
55
+ x = Function('x')
56
+ y = Function('y')
57
+ t = Symbol('t')
58
+ L = (m*D(x(t), t)**2/2 + m*D(y(t), t)**2/2 -
59
+ k*D(x(t), t)*D(y(t), t, t) + k*D(y(t), t)*D(x(t), t, t))
60
+ assert euler(L, [x(t), y(t)]) == [Eq(2*k*D(y(t), t, t, t) -
61
+ m*D(x(t), t, t), 0),
62
+ Eq(-2*k*D(x(t), t, t, t) -
63
+ m*D(y(t), t, t), 0)]
64
+
65
+ w = Symbol('w')
66
+ L = D(x(t, w), t, w)**2/2
67
+ assert euler(L) == [Eq(D(x(t, w), t, t, w, w), 0)]
68
+
69
+ def test_issue_18653():
70
+ x, y, z = symbols("x y z")
71
+ f, g, h = symbols("f g h", cls=Function, args=(x, y))
72
+ f, g, h = f(), g(), h()
73
+ expr2 = f.diff(x)*h.diff(z)
74
+ assert euler(expr2, (f,), (x, y)) == []
.venv/lib/python3.13/site-packages/sympy/calculus/tests/test_finite_diff.py ADDED
@@ -0,0 +1,164 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from itertools import product
2
+
3
+ from sympy.core.function import (Function, diff)
4
+ from sympy.core.numbers import Rational
5
+ from sympy.core.singleton import S
6
+ from sympy.core.symbol import symbols
7
+ from sympy.functions.elementary.exponential import exp
8
+ from sympy.calculus.finite_diff import (
9
+ apply_finite_diff, differentiate_finite, finite_diff_weights,
10
+ _as_finite_diff
11
+ )
12
+ from sympy.testing.pytest import raises, warns_deprecated_sympy
13
+
14
+
15
+ def test_apply_finite_diff():
16
+ x, h = symbols('x h')
17
+ f = Function('f')
18
+ assert (apply_finite_diff(1, [x-h, x+h], [f(x-h), f(x+h)], x) -
19
+ (f(x+h)-f(x-h))/(2*h)).simplify() == 0
20
+
21
+ assert (apply_finite_diff(1, [5, 6, 7], [f(5), f(6), f(7)], 5) -
22
+ (Rational(-3, 2)*f(5) + 2*f(6) - S.Half*f(7))).simplify() == 0
23
+ raises(ValueError, lambda: apply_finite_diff(1, [x, h], [f(x)]))
24
+
25
+
26
+ def test_finite_diff_weights():
27
+
28
+ d = finite_diff_weights(1, [5, 6, 7], 5)
29
+ assert d[1][2] == [Rational(-3, 2), 2, Rational(-1, 2)]
30
+
31
+ # Table 1, p. 702 in doi:10.1090/S0025-5718-1988-0935077-0
32
+ # --------------------------------------------------------
33
+ xl = [0, 1, -1, 2, -2, 3, -3, 4, -4]
34
+
35
+ # d holds all coefficients
36
+ d = finite_diff_weights(4, xl, S.Zero)
37
+
38
+ # Zeroeth derivative
39
+ for i in range(5):
40
+ assert d[0][i] == [S.One] + [S.Zero]*8
41
+
42
+ # First derivative
43
+ assert d[1][0] == [S.Zero]*9
44
+ assert d[1][2] == [S.Zero, S.Half, Rational(-1, 2)] + [S.Zero]*6
45
+ assert d[1][4] == [S.Zero, Rational(2, 3), Rational(-2, 3), Rational(-1, 12), Rational(1, 12)] + [S.Zero]*4
46
+ assert d[1][6] == [S.Zero, Rational(3, 4), Rational(-3, 4), Rational(-3, 20), Rational(3, 20),
47
+ Rational(1, 60), Rational(-1, 60)] + [S.Zero]*2
48
+ assert d[1][8] == [S.Zero, Rational(4, 5), Rational(-4, 5), Rational(-1, 5), Rational(1, 5),
49
+ Rational(4, 105), Rational(-4, 105), Rational(-1, 280), Rational(1, 280)]
50
+
51
+ # Second derivative
52
+ for i in range(2):
53
+ assert d[2][i] == [S.Zero]*9
54
+ assert d[2][2] == [-S(2), S.One, S.One] + [S.Zero]*6
55
+ assert d[2][4] == [Rational(-5, 2), Rational(4, 3), Rational(4, 3), Rational(-1, 12), Rational(-1, 12)] + [S.Zero]*4
56
+ assert d[2][6] == [Rational(-49, 18), Rational(3, 2), Rational(3, 2), Rational(-3, 20), Rational(-3, 20),
57
+ Rational(1, 90), Rational(1, 90)] + [S.Zero]*2
58
+ assert d[2][8] == [Rational(-205, 72), Rational(8, 5), Rational(8, 5), Rational(-1, 5), Rational(-1, 5),
59
+ Rational(8, 315), Rational(8, 315), Rational(-1, 560), Rational(-1, 560)]
60
+
61
+ # Third derivative
62
+ for i in range(3):
63
+ assert d[3][i] == [S.Zero]*9
64
+ assert d[3][4] == [S.Zero, -S.One, S.One, S.Half, Rational(-1, 2)] + [S.Zero]*4
65
+ assert d[3][6] == [S.Zero, Rational(-13, 8), Rational(13, 8), S.One, -S.One,
66
+ Rational(-1, 8), Rational(1, 8)] + [S.Zero]*2
67
+ assert d[3][8] == [S.Zero, Rational(-61, 30), Rational(61, 30), Rational(169, 120), Rational(-169, 120),
68
+ Rational(-3, 10), Rational(3, 10), Rational(7, 240), Rational(-7, 240)]
69
+
70
+ # Fourth derivative
71
+ for i in range(4):
72
+ assert d[4][i] == [S.Zero]*9
73
+ assert d[4][4] == [S(6), -S(4), -S(4), S.One, S.One] + [S.Zero]*4
74
+ assert d[4][6] == [Rational(28, 3), Rational(-13, 2), Rational(-13, 2), S(2), S(2),
75
+ Rational(-1, 6), Rational(-1, 6)] + [S.Zero]*2
76
+ assert d[4][8] == [Rational(91, 8), Rational(-122, 15), Rational(-122, 15), Rational(169, 60), Rational(169, 60),
77
+ Rational(-2, 5), Rational(-2, 5), Rational(7, 240), Rational(7, 240)]
78
+
79
+ # Table 2, p. 703 in doi:10.1090/S0025-5718-1988-0935077-0
80
+ # --------------------------------------------------------
81
+ xl = [[j/S(2) for j in list(range(-i*2+1, 0, 2))+list(range(1, i*2+1, 2))]
82
+ for i in range(1, 5)]
83
+
84
+ # d holds all coefficients
85
+ d = [finite_diff_weights({0: 1, 1: 2, 2: 4, 3: 4}[i], xl[i], 0) for
86
+ i in range(4)]
87
+
88
+ # Zeroth derivative
89
+ assert d[0][0][1] == [S.Half, S.Half]
90
+ assert d[1][0][3] == [Rational(-1, 16), Rational(9, 16), Rational(9, 16), Rational(-1, 16)]
91
+ assert d[2][0][5] == [Rational(3, 256), Rational(-25, 256), Rational(75, 128), Rational(75, 128),
92
+ Rational(-25, 256), Rational(3, 256)]
93
+ assert d[3][0][7] == [Rational(-5, 2048), Rational(49, 2048), Rational(-245, 2048), Rational(1225, 2048),
94
+ Rational(1225, 2048), Rational(-245, 2048), Rational(49, 2048), Rational(-5, 2048)]
95
+
96
+ # First derivative
97
+ assert d[0][1][1] == [-S.One, S.One]
98
+ assert d[1][1][3] == [Rational(1, 24), Rational(-9, 8), Rational(9, 8), Rational(-1, 24)]
99
+ assert d[2][1][5] == [Rational(-3, 640), Rational(25, 384), Rational(-75, 64),
100
+ Rational(75, 64), Rational(-25, 384), Rational(3, 640)]
101
+ assert d[3][1][7] == [Rational(5, 7168), Rational(-49, 5120),
102
+ Rational(245, 3072), Rational(-1225, 1024),
103
+ Rational(1225, 1024), Rational(-245, 3072),
104
+ Rational(49, 5120), Rational(-5, 7168)]
105
+
106
+ # Reasonably the rest of the table is also correct... (testing of that
107
+ # deemed excessive at the moment)
108
+ raises(ValueError, lambda: finite_diff_weights(-1, [1, 2]))
109
+ raises(ValueError, lambda: finite_diff_weights(1.2, [1, 2]))
110
+ x = symbols('x')
111
+ raises(ValueError, lambda: finite_diff_weights(x, [1, 2]))
112
+
113
+
114
+ def test_as_finite_diff():
115
+ x = symbols('x')
116
+ f = Function('f')
117
+ dx = Function('dx')
118
+
119
+ _as_finite_diff(f(x).diff(x), [x-2, x-1, x, x+1, x+2])
120
+
121
+ # Use of undefined functions in ``points``
122
+ df_true = -f(x+dx(x)/2-dx(x+dx(x)/2)/2) / dx(x+dx(x)/2) \
123
+ + f(x+dx(x)/2+dx(x+dx(x)/2)/2) / dx(x+dx(x)/2)
124
+ df_test = diff(f(x), x).as_finite_difference(points=dx(x), x0=x+dx(x)/2)
125
+ assert (df_test - df_true).simplify() == 0
126
+
127
+
128
+ def test_differentiate_finite():
129
+ x, y, h = symbols('x y h')
130
+ f = Function('f')
131
+ with warns_deprecated_sympy():
132
+ res0 = differentiate_finite(f(x, y) + exp(42), x, y, evaluate=True)
133
+ xm, xp, ym, yp = [v + sign*S.Half for v, sign in product([x, y], [-1, 1])]
134
+ ref0 = f(xm, ym) + f(xp, yp) - f(xm, yp) - f(xp, ym)
135
+ assert (res0 - ref0).simplify() == 0
136
+
137
+ g = Function('g')
138
+ with warns_deprecated_sympy():
139
+ res1 = differentiate_finite(f(x)*g(x) + 42, x, evaluate=True)
140
+ ref1 = (-f(x - S.Half) + f(x + S.Half))*g(x) + \
141
+ (-g(x - S.Half) + g(x + S.Half))*f(x)
142
+ assert (res1 - ref1).simplify() == 0
143
+
144
+ res2 = differentiate_finite(f(x) + x**3 + 42, x, points=[x-1, x+1])
145
+ ref2 = (f(x + 1) + (x + 1)**3 - f(x - 1) - (x - 1)**3)/2
146
+ assert (res2 - ref2).simplify() == 0
147
+ raises(TypeError, lambda: differentiate_finite(f(x)*g(x), x,
148
+ pints=[x-1, x+1]))
149
+
150
+ res3 = differentiate_finite(f(x)*g(x).diff(x), x)
151
+ ref3 = (-g(x) + g(x + 1))*f(x + S.Half) - (g(x) - g(x - 1))*f(x - S.Half)
152
+ assert res3 == ref3
153
+
154
+ res4 = differentiate_finite(f(x)*g(x).diff(x).diff(x), x)
155
+ ref4 = -((g(x - Rational(3, 2)) - 2*g(x - S.Half) + g(x + S.Half))*f(x - S.Half)) \
156
+ + (g(x - S.Half) - 2*g(x + S.Half) + g(x + Rational(3, 2)))*f(x + S.Half)
157
+ assert res4 == ref4
158
+
159
+ res5_expr = f(x).diff(x)*g(x).diff(x)
160
+ res5 = differentiate_finite(res5_expr, points=[x-h, x, x+h])
161
+ ref5 = (-2*f(x)/h + f(-h + x)/(2*h) + 3*f(h + x)/(2*h))*(-2*g(x)/h + g(-h + x)/(2*h) \
162
+ + 3*g(h + x)/(2*h))/(2*h) - (2*f(x)/h - 3*f(-h + x)/(2*h) - \
163
+ f(h + x)/(2*h))*(2*g(x)/h - 3*g(-h + x)/(2*h) - g(h + x)/(2*h))/(2*h)
164
+ assert res5 == ref5
.venv/lib/python3.13/site-packages/sympy/calculus/tests/test_singularities.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sympy.core.numbers import (I, Rational, pi, oo)
2
+ from sympy.core.singleton import S
3
+ from sympy.core.symbol import Symbol, Dummy
4
+ from sympy.core.function import Lambda
5
+ from sympy.functions.elementary.exponential import (exp, log)
6
+ from sympy.functions.elementary.trigonometric import sec, csc
7
+ from sympy.functions.elementary.hyperbolic import (coth, sech,
8
+ atanh, asech, acoth, acsch)
9
+ from sympy.functions.elementary.miscellaneous import sqrt
10
+ from sympy.calculus.singularities import (
11
+ singularities,
12
+ is_increasing,
13
+ is_strictly_increasing,
14
+ is_decreasing,
15
+ is_strictly_decreasing,
16
+ is_monotonic
17
+ )
18
+ from sympy.sets import Interval, FiniteSet, Union, ImageSet
19
+ from sympy.testing.pytest import raises
20
+ from sympy.abc import x, y
21
+
22
+
23
+ def test_singularities():
24
+ x = Symbol('x')
25
+ assert singularities(x**2, x) == S.EmptySet
26
+ assert singularities(x/(x**2 + 3*x + 2), x) == FiniteSet(-2, -1)
27
+ assert singularities(1/(x**2 + 1), x) == FiniteSet(I, -I)
28
+ assert singularities(x/(x**3 + 1), x) == \
29
+ FiniteSet(-1, (1 - sqrt(3) * I) / 2, (1 + sqrt(3) * I) / 2)
30
+ assert singularities(1/(y**2 + 2*I*y + 1), y) == \
31
+ FiniteSet(-I + sqrt(2)*I, -I - sqrt(2)*I)
32
+ _n = Dummy('n')
33
+ assert singularities(sech(x), x).dummy_eq(Union(
34
+ ImageSet(Lambda(_n, 2*_n*I*pi + I*pi/2), S.Integers),
35
+ ImageSet(Lambda(_n, 2*_n*I*pi + 3*I*pi/2), S.Integers)))
36
+ assert singularities(coth(x), x).dummy_eq(Union(
37
+ ImageSet(Lambda(_n, 2*_n*I*pi + I*pi), S.Integers),
38
+ ImageSet(Lambda(_n, 2*_n*I*pi), S.Integers)))
39
+ assert singularities(atanh(x), x) == FiniteSet(-1, 1)
40
+ assert singularities(acoth(x), x) == FiniteSet(-1, 1)
41
+ assert singularities(asech(x), x) == FiniteSet(0)
42
+ assert singularities(acsch(x), x) == FiniteSet(0)
43
+
44
+ x = Symbol('x', real=True)
45
+ assert singularities(1/(x**2 + 1), x) == S.EmptySet
46
+ assert singularities(exp(1/x), x, S.Reals) == FiniteSet(0)
47
+ assert singularities(exp(1/x), x, Interval(1, 2)) == S.EmptySet
48
+ assert singularities(log((x - 2)**2), x, Interval(1, 3)) == FiniteSet(2)
49
+ raises(NotImplementedError, lambda: singularities(x**-oo, x))
50
+ assert singularities(sec(x), x, Interval(0, 3*pi)) == FiniteSet(
51
+ pi/2, 3*pi/2, 5*pi/2)
52
+ assert singularities(csc(x), x, Interval(0, 3*pi)) == FiniteSet(
53
+ 0, pi, 2*pi, 3*pi)
54
+
55
+
56
+ def test_is_increasing():
57
+ """Test whether is_increasing returns correct value."""
58
+ a = Symbol('a', negative=True)
59
+
60
+ assert is_increasing(x**3 - 3*x**2 + 4*x, S.Reals)
61
+ assert is_increasing(-x**2, Interval(-oo, 0))
62
+ assert not is_increasing(-x**2, Interval(0, oo))
63
+ assert not is_increasing(4*x**3 - 6*x**2 - 72*x + 30, Interval(-2, 3))
64
+ assert is_increasing(x**2 + y, Interval(1, oo), x)
65
+ assert is_increasing(-x**2*a, Interval(1, oo), x)
66
+ assert is_increasing(1)
67
+
68
+ assert is_increasing(4*x**3 - 6*x**2 - 72*x + 30, Interval(-2, 3)) is False
69
+
70
+
71
+ def test_is_strictly_increasing():
72
+ """Test whether is_strictly_increasing returns correct value."""
73
+ assert is_strictly_increasing(
74
+ 4*x**3 - 6*x**2 - 72*x + 30, Interval.Ropen(-oo, -2))
75
+ assert is_strictly_increasing(
76
+ 4*x**3 - 6*x**2 - 72*x + 30, Interval.Lopen(3, oo))
77
+ assert not is_strictly_increasing(
78
+ 4*x**3 - 6*x**2 - 72*x + 30, Interval.open(-2, 3))
79
+ assert not is_strictly_increasing(-x**2, Interval(0, oo))
80
+ assert not is_strictly_decreasing(1)
81
+
82
+ assert is_strictly_increasing(4*x**3 - 6*x**2 - 72*x + 30, Interval.open(-2, 3)) is False
83
+
84
+
85
+ def test_is_decreasing():
86
+ """Test whether is_decreasing returns correct value."""
87
+ b = Symbol('b', positive=True)
88
+
89
+ assert is_decreasing(1/(x**2 - 3*x), Interval.open(Rational(3,2), 3))
90
+ assert is_decreasing(1/(x**2 - 3*x), Interval.open(1.5, 3))
91
+ assert is_decreasing(1/(x**2 - 3*x), Interval.Lopen(3, oo))
92
+ assert not is_decreasing(1/(x**2 - 3*x), Interval.Ropen(-oo, Rational(3, 2)))
93
+ assert not is_decreasing(-x**2, Interval(-oo, 0))
94
+ assert not is_decreasing(-x**2*b, Interval(-oo, 0), x)
95
+
96
+
97
+ def test_is_strictly_decreasing():
98
+ """Test whether is_strictly_decreasing returns correct value."""
99
+ assert is_strictly_decreasing(1/(x**2 - 3*x), Interval.Lopen(3, oo))
100
+ assert not is_strictly_decreasing(
101
+ 1/(x**2 - 3*x), Interval.Ropen(-oo, Rational(3, 2)))
102
+ assert not is_strictly_decreasing(-x**2, Interval(-oo, 0))
103
+ assert not is_strictly_decreasing(1)
104
+ assert is_strictly_decreasing(1/(x**2 - 3*x), Interval.open(Rational(3,2), 3))
105
+ assert is_strictly_decreasing(1/(x**2 - 3*x), Interval.open(1.5, 3))
106
+
107
+
108
+ def test_is_monotonic():
109
+ """Test whether is_monotonic returns correct value."""
110
+ assert is_monotonic(1/(x**2 - 3*x), Interval.open(Rational(3,2), 3))
111
+ assert is_monotonic(1/(x**2 - 3*x), Interval.open(1.5, 3))
112
+ assert is_monotonic(1/(x**2 - 3*x), Interval.Lopen(3, oo))
113
+ assert is_monotonic(x**3 - 3*x**2 + 4*x, S.Reals)
114
+ assert not is_monotonic(-x**2, S.Reals)
115
+ assert is_monotonic(x**2 + y + 1, Interval(1, 2), x)
116
+ raises(NotImplementedError, lambda: is_monotonic(x**2 + y + 1))
117
+
118
+
119
+ def test_issue_23401():
120
+ x = Symbol('x')
121
+ expr = (x + 1)/(-1.0e-3*x**2 + 0.1*x + 0.1)
122
+ assert is_increasing(expr, Interval(1,2), x)
.venv/lib/python3.13/site-packages/sympy/calculus/tests/test_util.py ADDED
@@ -0,0 +1,392 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sympy.core.function import Lambda
2
+ from sympy.core.numbers import (E, I, Rational, oo, pi)
3
+ from sympy.core.relational import Eq
4
+ from sympy.core.singleton import S
5
+ from sympy.core.symbol import (Dummy, Symbol)
6
+ from sympy.functions.elementary.complexes import (Abs, re)
7
+ from sympy.functions.elementary.exponential import (exp, log)
8
+ from sympy.functions.elementary.integers import frac
9
+ from sympy.functions.elementary.miscellaneous import sqrt
10
+ from sympy.functions.elementary.piecewise import Piecewise
11
+ from sympy.functions.elementary.trigonometric import (
12
+ cos, cot, csc, sec, sin, tan, asin, acos, atan, acot, asec, acsc)
13
+ from sympy.functions.elementary.hyperbolic import (sinh, cosh, tanh, coth,
14
+ sech, csch, asinh, acosh, atanh, acoth, asech, acsch)
15
+ from sympy.functions.special.gamma_functions import gamma
16
+ from sympy.functions.special.error_functions import expint
17
+ from sympy.matrices.expressions.matexpr import MatrixSymbol
18
+ from sympy.simplify.simplify import simplify
19
+ from sympy.calculus.util import (function_range, continuous_domain, not_empty_in,
20
+ periodicity, lcim, is_convex,
21
+ stationary_points, minimum, maximum)
22
+ from sympy.sets.sets import (Interval, FiniteSet, Complement, Union)
23
+ from sympy.sets.fancysets import ImageSet
24
+ from sympy.sets.conditionset import ConditionSet
25
+ from sympy.testing.pytest import XFAIL, raises, _both_exp_pow, slow
26
+ from sympy.abc import x, y
27
+
28
+ a = Symbol('a', real=True)
29
+
30
+ def test_function_range():
31
+ assert function_range(sin(x), x, Interval(-pi/2, pi/2)
32
+ ) == Interval(-1, 1)
33
+ assert function_range(sin(x), x, Interval(0, pi)
34
+ ) == Interval(0, 1)
35
+ assert function_range(tan(x), x, Interval(0, pi)
36
+ ) == Interval(-oo, oo)
37
+ assert function_range(tan(x), x, Interval(pi/2, pi)
38
+ ) == Interval(-oo, 0)
39
+ assert function_range((x + 3)/(x - 2), x, Interval(-5, 5)
40
+ ) == Union(Interval(-oo, Rational(2, 7)), Interval(Rational(8, 3), oo))
41
+ assert function_range(1/(x**2), x, Interval(-1, 1)
42
+ ) == Interval(1, oo)
43
+ assert function_range(exp(x), x, Interval(-1, 1)
44
+ ) == Interval(exp(-1), exp(1))
45
+ assert function_range(log(x) - x, x, S.Reals
46
+ ) == Interval(-oo, -1)
47
+ assert function_range(sqrt(3*x - 1), x, Interval(0, 2)
48
+ ) == Interval(0, sqrt(5))
49
+ assert function_range(x*(x - 1) - (x**2 - x), x, S.Reals
50
+ ) == FiniteSet(0)
51
+ assert function_range(x*(x - 1) - (x**2 - x) + y, x, S.Reals
52
+ ) == FiniteSet(y)
53
+ assert function_range(sin(x), x, Union(Interval(-5, -3), FiniteSet(4))
54
+ ) == Union(Interval(-sin(3), 1), FiniteSet(sin(4)))
55
+ assert function_range(cos(x), x, Interval(-oo, -4)
56
+ ) == Interval(-1, 1)
57
+ assert function_range(cos(x), x, S.EmptySet) == S.EmptySet
58
+ assert function_range(x/sqrt(x**2+1), x, S.Reals) == Interval.open(-1,1)
59
+ raises(NotImplementedError, lambda : function_range(
60
+ exp(x)*(sin(x) - cos(x))/2 - x, x, S.Reals))
61
+ raises(NotImplementedError, lambda : function_range(
62
+ sin(x) + x, x, S.Reals)) # issue 13273
63
+ raises(NotImplementedError, lambda : function_range(
64
+ log(x), x, S.Integers))
65
+ raises(NotImplementedError, lambda : function_range(
66
+ sin(x)/2, x, S.Naturals))
67
+
68
+
69
+ @slow
70
+ def test_function_range1():
71
+ assert function_range(tan(x)**2 + tan(3*x)**2 + 1, x, S.Reals) == Interval(1,oo)
72
+
73
+
74
+ def test_continuous_domain():
75
+ assert continuous_domain(sin(x), x, Interval(0, 2*pi)) == Interval(0, 2*pi)
76
+ assert continuous_domain(tan(x), x, Interval(0, 2*pi)) == \
77
+ Union(Interval(0, pi/2, False, True), Interval(pi/2, pi*Rational(3, 2), True, True),
78
+ Interval(pi*Rational(3, 2), 2*pi, True, False))
79
+ assert continuous_domain(cot(x), x, Interval(0, 2*pi)) == Union(
80
+ Interval.open(0, pi), Interval.open(pi, 2*pi))
81
+ assert continuous_domain((x - 1)/((x - 1)**2), x, S.Reals) == \
82
+ Union(Interval(-oo, 1, True, True), Interval(1, oo, True, True))
83
+ assert continuous_domain(log(x) + log(4*x - 1), x, S.Reals) == \
84
+ Interval(Rational(1, 4), oo, True, True)
85
+ assert continuous_domain(1/sqrt(x - 3), x, S.Reals) == Interval(3, oo, True, True)
86
+ assert continuous_domain(1/x - 2, x, S.Reals) == \
87
+ Union(Interval.open(-oo, 0), Interval.open(0, oo))
88
+ assert continuous_domain(1/(x**2 - 4) + 2, x, S.Reals) == \
89
+ Union(Interval.open(-oo, -2), Interval.open(-2, 2), Interval.open(2, oo))
90
+ assert continuous_domain((x+1)**pi, x, S.Reals) == Interval(-1, oo)
91
+ assert continuous_domain((x+1)**(pi/2), x, S.Reals) == Interval(-1, oo)
92
+ assert continuous_domain(x**x, x, S.Reals) == Interval(0, oo)
93
+ assert continuous_domain((x+1)**log(x**2), x, S.Reals) == Union(
94
+ Interval.Ropen(-1, 0), Interval.open(0, oo))
95
+ domain = continuous_domain(log(tan(x)**2 + 1), x, S.Reals)
96
+ assert not domain.contains(3*pi/2)
97
+ assert domain.contains(5)
98
+ d = Symbol('d', even=True, zero=False)
99
+ assert continuous_domain(x**(1/d), x, S.Reals) == Interval(0, oo)
100
+ n = Dummy('n')
101
+ assert continuous_domain(1/sin(x), x, S.Reals).dummy_eq(Complement(
102
+ S.Reals, Union(ImageSet(Lambda(n, 2*n*pi + pi), S.Integers),
103
+ ImageSet(Lambda(n, 2*n*pi), S.Integers))))
104
+ assert continuous_domain(sin(x) + cos(x), x, S.Reals) == S.Reals
105
+ assert continuous_domain(asin(x), x, S.Reals) == Interval(-1, 1) # issue #21786
106
+ assert continuous_domain(1/acos(log(x)), x, S.Reals) == Interval.Ropen(exp(-1), E)
107
+ assert continuous_domain(sinh(x)+cosh(x), x, S.Reals) == S.Reals
108
+ assert continuous_domain(tanh(x)+sech(x), x, S.Reals) == S.Reals
109
+ assert continuous_domain(atan(x)+asinh(x), x, S.Reals) == S.Reals
110
+ assert continuous_domain(acosh(x), x, S.Reals) == Interval(1, oo)
111
+ assert continuous_domain(atanh(x), x, S.Reals) == Interval.open(-1, 1)
112
+ assert continuous_domain(atanh(x)+acosh(x), x, S.Reals) == S.EmptySet
113
+ assert continuous_domain(asech(x), x, S.Reals) == Interval.Lopen(0, 1)
114
+ assert continuous_domain(acoth(x), x, S.Reals) == Union(
115
+ Interval.open(-oo, -1), Interval.open(1, oo))
116
+ assert continuous_domain(asec(x), x, S.Reals) == Union(
117
+ Interval(-oo, -1), Interval(1, oo))
118
+ assert continuous_domain(acsc(x), x, S.Reals) == Union(
119
+ Interval(-oo, -1), Interval(1, oo))
120
+ for f in (coth, acsch, csch):
121
+ assert continuous_domain(f(x), x, S.Reals) == Union(
122
+ Interval.open(-oo, 0), Interval.open(0, oo))
123
+ assert continuous_domain(acot(x), x, S.Reals).contains(0) == False
124
+ assert continuous_domain(1/(exp(x) - x), x, S.Reals) == Complement(
125
+ S.Reals, ConditionSet(x, Eq(-x + exp(x), 0), S.Reals))
126
+ assert continuous_domain(frac(x**2), x, Interval(-2,-1)) == Union(
127
+ Interval.open(-2, -sqrt(3)), Interval.open(-sqrt(2), -1),
128
+ Interval.open(-sqrt(3), -sqrt(2)))
129
+ assert continuous_domain(frac(x), x, S.Reals) == Complement(
130
+ S.Reals, S.Integers)
131
+ raises(NotImplementedError, lambda : continuous_domain(
132
+ 1/(x**2+1), x, S.Complexes))
133
+ raises(NotImplementedError, lambda : continuous_domain(
134
+ gamma(x), x, Interval(-5,0)))
135
+ assert continuous_domain(x + gamma(pi), x, S.Reals) == S.Reals
136
+
137
+
138
+ @XFAIL
139
+ def test_continuous_domain_acot():
140
+ acot_cont = Piecewise((pi+acot(x), x<0), (acot(x), True))
141
+ assert continuous_domain(acot_cont, x, S.Reals) == S.Reals
142
+
143
+ @XFAIL
144
+ def test_continuous_domain_gamma():
145
+ assert continuous_domain(gamma(x), x, S.Reals).contains(-1) == False
146
+
147
+ @XFAIL
148
+ def test_continuous_domain_neg_power():
149
+ assert continuous_domain((x-2)**(1-x), x, S.Reals) == Interval.open(2, oo)
150
+
151
+
152
+ def test_not_empty_in():
153
+ assert not_empty_in(FiniteSet(x, 2*x).intersect(Interval(1, 2, True, False)), x) == \
154
+ Interval(S.Half, 2, True, False)
155
+ assert not_empty_in(FiniteSet(x, x**2).intersect(Interval(1, 2)), x) == \
156
+ Union(Interval(-sqrt(2), -1), Interval(1, 2))
157
+ assert not_empty_in(FiniteSet(x**2 + x, x).intersect(Interval(2, 4)), x) == \
158
+ Union(Interval(-sqrt(17)/2 - S.Half, -2),
159
+ Interval(1, Rational(-1, 2) + sqrt(17)/2), Interval(2, 4))
160
+ assert not_empty_in(FiniteSet(x/(x - 1)).intersect(S.Reals), x) == \
161
+ Complement(S.Reals, FiniteSet(1))
162
+ assert not_empty_in(FiniteSet(a/(a - 1)).intersect(S.Reals), a) == \
163
+ Complement(S.Reals, FiniteSet(1))
164
+ assert not_empty_in(FiniteSet((x**2 - 3*x + 2)/(x - 1)).intersect(S.Reals), x) == \
165
+ Complement(S.Reals, FiniteSet(1))
166
+ assert not_empty_in(FiniteSet(3, 4, x/(x - 1)).intersect(Interval(2, 3)), x) == \
167
+ Interval(-oo, oo)
168
+ assert not_empty_in(FiniteSet(4, x/(x - 1)).intersect(Interval(2, 3)), x) == \
169
+ Interval(S(3)/2, 2)
170
+ assert not_empty_in(FiniteSet(x/(x**2 - 1)).intersect(S.Reals), x) == \
171
+ Complement(S.Reals, FiniteSet(-1, 1))
172
+ assert not_empty_in(FiniteSet(x, x**2).intersect(Union(Interval(1, 3, True, True),
173
+ Interval(4, 5))), x) == \
174
+ Union(Interval(-sqrt(5), -2), Interval(-sqrt(3), -1, True, True),
175
+ Interval(1, 3, True, True), Interval(4, 5))
176
+ assert not_empty_in(FiniteSet(1).intersect(Interval(3, 4)), x) == S.EmptySet
177
+ assert not_empty_in(FiniteSet(x**2/(x + 2)).intersect(Interval(1, oo)), x) == \
178
+ Union(Interval(-2, -1, True, False), Interval(2, oo))
179
+ raises(ValueError, lambda: not_empty_in(x))
180
+ raises(ValueError, lambda: not_empty_in(Interval(0, 1), x))
181
+ raises(NotImplementedError,
182
+ lambda: not_empty_in(FiniteSet(x).intersect(S.Reals), x, a))
183
+
184
+
185
+ @_both_exp_pow
186
+ def test_periodicity():
187
+ assert periodicity(sin(2*x), x) == pi
188
+ assert periodicity((-2)*tan(4*x), x) == pi/4
189
+ assert periodicity(sin(x)**2, x) == 2*pi
190
+ assert periodicity(3**tan(3*x), x) == pi/3
191
+ assert periodicity(tan(x)*cos(x), x) == 2*pi
192
+ assert periodicity(sin(x)**(tan(x)), x) == 2*pi
193
+ assert periodicity(tan(x)*sec(x), x) == 2*pi
194
+ assert periodicity(sin(2*x)*cos(2*x) - y, x) == pi/2
195
+ assert periodicity(tan(x) + cot(x), x) == pi
196
+ assert periodicity(sin(x) - cos(2*x), x) == 2*pi
197
+ assert periodicity(sin(x) - 1, x) == 2*pi
198
+ assert periodicity(sin(4*x) + sin(x)*cos(x), x) == pi
199
+ assert periodicity(exp(sin(x)), x) == 2*pi
200
+ assert periodicity(log(cot(2*x)) - sin(cos(2*x)), x) == pi
201
+ assert periodicity(sin(2*x)*exp(tan(x) - csc(2*x)), x) == pi
202
+ assert periodicity(cos(sec(x) - csc(2*x)), x) == 2*pi
203
+ assert periodicity(tan(sin(2*x)), x) == pi
204
+ assert periodicity(2*tan(x)**2, x) == pi
205
+ assert periodicity(sin(x%4), x) == 4
206
+ assert periodicity(sin(x)%4, x) == 2*pi
207
+ assert periodicity(tan((3*x-2)%4), x) == Rational(4, 3)
208
+ assert periodicity((sqrt(2)*(x+1)+x) % 3, x) == 3 / (sqrt(2)+1)
209
+ assert periodicity((x**2+1) % x, x) is None
210
+ assert periodicity(sin(re(x)), x) == 2*pi
211
+ assert periodicity(sin(x)**2 + cos(x)**2, x) is S.Zero
212
+ assert periodicity(tan(x), y) is S.Zero
213
+ assert periodicity(sin(x) + I*cos(x), x) == 2*pi
214
+ assert periodicity(x - sin(2*y), y) == pi
215
+
216
+ assert periodicity(exp(x), x) is None
217
+ assert periodicity(exp(I*x), x) == 2*pi
218
+ assert periodicity(exp(I*a), a) == 2*pi
219
+ assert periodicity(exp(a), a) is None
220
+ assert periodicity(exp(log(sin(a) + I*cos(2*a)), evaluate=False), a) == 2*pi
221
+ assert periodicity(exp(log(sin(2*a) + I*cos(a)), evaluate=False), a) == 2*pi
222
+ assert periodicity(exp(sin(a)), a) == 2*pi
223
+ assert periodicity(exp(2*I*a), a) == pi
224
+ assert periodicity(exp(a + I*sin(a)), a) is None
225
+ assert periodicity(exp(cos(a/2) + sin(a)), a) == 4*pi
226
+ assert periodicity(log(x), x) is None
227
+ assert periodicity(exp(x)**sin(x), x) is None
228
+ assert periodicity(sin(x)**y, y) is None
229
+
230
+ assert periodicity(Abs(sin(Abs(sin(x)))), x) == pi
231
+ assert all(periodicity(Abs(f(x)), x) == pi for f in (
232
+ cos, sin, sec, csc, tan, cot))
233
+ assert periodicity(Abs(sin(tan(x))), x) == pi
234
+ assert periodicity(Abs(sin(sin(x) + tan(x))), x) == 2*pi
235
+ assert periodicity(sin(x) > S.Half, x) == 2*pi
236
+
237
+ assert periodicity(x > 2, x) is None
238
+ assert periodicity(x**3 - x**2 + 1, x) is None
239
+ assert periodicity(Abs(x), x) is None
240
+ assert periodicity(Abs(x**2 - 1), x) is None
241
+
242
+ assert periodicity((x**2 + 4)%2, x) is None
243
+ assert periodicity((E**x)%3, x) is None
244
+
245
+ assert periodicity(sin(expint(1, x))/expint(1, x), x) is None
246
+ # returning `None` for any Piecewise
247
+ p = Piecewise((0, x < -1), (x**2, x <= 1), (log(x), True))
248
+ assert periodicity(p, x) is None
249
+
250
+ m = MatrixSymbol('m', 3, 3)
251
+ raises(NotImplementedError, lambda: periodicity(sin(m), m))
252
+ raises(NotImplementedError, lambda: periodicity(sin(m[0, 0]), m))
253
+ raises(NotImplementedError, lambda: periodicity(sin(m), m[0, 0]))
254
+ raises(NotImplementedError, lambda: periodicity(sin(m[0, 0]), m[0, 0]))
255
+
256
+
257
+ def test_periodicity_check():
258
+ assert periodicity(tan(x), x, check=True) == pi
259
+ assert periodicity(sin(x) + cos(x), x, check=True) == 2*pi
260
+ assert periodicity(sec(x), x) == 2*pi
261
+ assert periodicity(sin(x*y), x) == 2*pi/abs(y)
262
+ assert periodicity(Abs(sec(sec(x))), x) == pi
263
+
264
+
265
+ def test_lcim():
266
+ assert lcim([S.Half, S(2), S(3)]) == 6
267
+ assert lcim([pi/2, pi/4, pi]) == pi
268
+ assert lcim([2*pi, pi/2]) == 2*pi
269
+ assert lcim([S.One, 2*pi]) is None
270
+ assert lcim([S(2) + 2*E, E/3 + Rational(1, 3), S.One + E]) == S(2) + 2*E
271
+
272
+
273
+ def test_is_convex():
274
+ assert is_convex(1/x, x, domain=Interval.open(0, oo)) == True
275
+ assert is_convex(1/x, x, domain=Interval(-oo, 0)) == False
276
+ assert is_convex(x**2, x, domain=Interval(0, oo)) == True
277
+ assert is_convex(1/x**3, x, domain=Interval.Lopen(0, oo)) == True
278
+ assert is_convex(-1/x**3, x, domain=Interval.Ropen(-oo, 0)) == True
279
+ assert is_convex(log(x) ,x) == False
280
+ assert is_convex(x**2+y**2, x, y) == True
281
+ assert is_convex(cos(x) + cos(y), x) == False
282
+ assert is_convex(8*x**2 - 2*y**2, x, y) == False
283
+
284
+
285
+ def test_stationary_points():
286
+ assert stationary_points(sin(x), x, Interval(-pi/2, pi/2)
287
+ ) == {-pi/2, pi/2}
288
+ assert stationary_points(sin(x), x, Interval.Ropen(0, pi/4)
289
+ ) is S.EmptySet
290
+ assert stationary_points(tan(x), x,
291
+ ) is S.EmptySet
292
+ assert stationary_points(sin(x)*cos(x), x, Interval(0, pi)
293
+ ) == {pi/4, pi*Rational(3, 4)}
294
+ assert stationary_points(sec(x), x, Interval(0, pi)
295
+ ) == {0, pi}
296
+ assert stationary_points((x+3)*(x-2), x
297
+ ) == FiniteSet(Rational(-1, 2))
298
+ assert stationary_points((x + 3)/(x - 2), x, Interval(-5, 5)
299
+ ) is S.EmptySet
300
+ assert stationary_points((x**2+3)/(x-2), x
301
+ ) == {2 - sqrt(7), 2 + sqrt(7)}
302
+ assert stationary_points((x**2+3)/(x-2), x, Interval(0, 5)
303
+ ) == {2 + sqrt(7)}
304
+ assert stationary_points(x**4 + x**3 - 5*x**2, x, S.Reals
305
+ ) == FiniteSet(-2, 0, Rational(5, 4))
306
+ assert stationary_points(exp(x), x
307
+ ) is S.EmptySet
308
+ assert stationary_points(log(x) - x, x, S.Reals
309
+ ) == {1}
310
+ assert stationary_points(cos(x), x, Union(Interval(0, 5), Interval(-6, -3))
311
+ ) == {0, -pi, pi}
312
+ assert stationary_points(y, x, S.Reals
313
+ ) == S.Reals
314
+ assert stationary_points(y, x, S.EmptySet) == S.EmptySet
315
+
316
+
317
+ def test_maximum():
318
+ assert maximum(sin(x), x) is S.One
319
+ assert maximum(sin(x), x, Interval(0, 1)) == sin(1)
320
+ assert maximum(tan(x), x) is oo
321
+ assert maximum(tan(x), x, Interval(-pi/4, pi/4)) is S.One
322
+ assert maximum(sin(x)*cos(x), x, S.Reals) == S.Half
323
+ assert simplify(maximum(sin(x)*cos(x), x, Interval(pi*Rational(3, 8), pi*Rational(5, 8)))
324
+ ) == sqrt(2)/4
325
+ assert maximum((x+3)*(x-2), x) is oo
326
+ assert maximum((x+3)*(x-2), x, Interval(-5, 0)) == S(14)
327
+ assert maximum((x+3)/(x-2), x, Interval(-5, 0)) == Rational(2, 7)
328
+ assert simplify(maximum(-x**4-x**3+x**2+10, x)
329
+ ) == 41*sqrt(41)/512 + Rational(5419, 512)
330
+ assert maximum(exp(x), x, Interval(-oo, 2)) == exp(2)
331
+ assert maximum(log(x) - x, x, S.Reals) is S.NegativeOne
332
+ assert maximum(cos(x), x, Union(Interval(0, 5), Interval(-6, -3))
333
+ ) is S.One
334
+ assert maximum(cos(x)-sin(x), x, S.Reals) == sqrt(2)
335
+ assert maximum(y, x, S.Reals) == y
336
+ assert maximum(abs(a**3 + a), a, Interval(0, 2)) == 10
337
+ assert maximum(abs(60*a**3 + 24*a), a, Interval(0, 2)) == 528
338
+ assert maximum(abs(12*a*(5*a**2 + 2)), a, Interval(0, 2)) == 528
339
+ assert maximum(x/sqrt(x**2+1), x, S.Reals) == 1
340
+
341
+ raises(ValueError, lambda : maximum(sin(x), x, S.EmptySet))
342
+ raises(ValueError, lambda : maximum(log(cos(x)), x, S.EmptySet))
343
+ raises(ValueError, lambda : maximum(1/(x**2 + y**2 + 1), x, S.EmptySet))
344
+ raises(ValueError, lambda : maximum(sin(x), sin(x)))
345
+ raises(ValueError, lambda : maximum(sin(x), x*y, S.EmptySet))
346
+ raises(ValueError, lambda : maximum(sin(x), S.One))
347
+
348
+
349
+ def test_minimum():
350
+ assert minimum(sin(x), x) is S.NegativeOne
351
+ assert minimum(sin(x), x, Interval(1, 4)) == sin(4)
352
+ assert minimum(tan(x), x) is -oo
353
+ assert minimum(tan(x), x, Interval(-pi/4, pi/4)) is S.NegativeOne
354
+ assert minimum(sin(x)*cos(x), x, S.Reals) == Rational(-1, 2)
355
+ assert simplify(minimum(sin(x)*cos(x), x, Interval(pi*Rational(3, 8), pi*Rational(5, 8)))
356
+ ) == -sqrt(2)/4
357
+ assert minimum((x+3)*(x-2), x) == Rational(-25, 4)
358
+ assert minimum((x+3)/(x-2), x, Interval(-5, 0)) == Rational(-3, 2)
359
+ assert minimum(x**4-x**3+x**2+10, x) == S(10)
360
+ assert minimum(exp(x), x, Interval(-2, oo)) == exp(-2)
361
+ assert minimum(log(x) - x, x, S.Reals) is -oo
362
+ assert minimum(cos(x), x, Union(Interval(0, 5), Interval(-6, -3))
363
+ ) is S.NegativeOne
364
+ assert minimum(cos(x)-sin(x), x, S.Reals) == -sqrt(2)
365
+ assert minimum(y, x, S.Reals) == y
366
+ assert minimum(x/sqrt(x**2+1), x, S.Reals) == -1
367
+
368
+ raises(ValueError, lambda : minimum(sin(x), x, S.EmptySet))
369
+ raises(ValueError, lambda : minimum(log(cos(x)), x, S.EmptySet))
370
+ raises(ValueError, lambda : minimum(1/(x**2 + y**2 + 1), x, S.EmptySet))
371
+ raises(ValueError, lambda : minimum(sin(x), sin(x)))
372
+ raises(ValueError, lambda : minimum(sin(x), x*y, S.EmptySet))
373
+ raises(ValueError, lambda : minimum(sin(x), S.One))
374
+
375
+
376
+ def test_issue_19869():
377
+ assert (maximum(sqrt(3)*(x - 1)/(3*sqrt(x**2 + 1)), x)
378
+ ) == sqrt(3)/3
379
+
380
+
381
+ def test_issue_16469():
382
+ f = abs(a)
383
+ assert function_range(f, a, S.Reals) == Interval(0, oo, False, True)
384
+
385
+
386
+ @_both_exp_pow
387
+ def test_issue_18747():
388
+ assert periodicity(exp(pi*I*(x/4 + S.Half/2)), x) == 8
389
+
390
+
391
+ def test_issue_25942():
392
+ assert (acos(x) > pi/3).as_set() == Interval.Ropen(-1, S(1)/2)
.venv/lib/python3.13/site-packages/sympy/calculus/util.py ADDED
@@ -0,0 +1,895 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .accumulationbounds import AccumBounds, AccumulationBounds # noqa: F401
2
+ from .singularities import singularities
3
+ from sympy.core import Pow, S
4
+ from sympy.core.function import diff, expand_mul, Function
5
+ from sympy.core.kind import NumberKind
6
+ from sympy.core.mod import Mod
7
+ from sympy.core.numbers import equal_valued
8
+ from sympy.core.relational import Relational
9
+ from sympy.core.symbol import Symbol, Dummy
10
+ from sympy.core.sympify import _sympify
11
+ from sympy.functions.elementary.complexes import Abs, im, re
12
+ from sympy.functions.elementary.exponential import exp, log
13
+ from sympy.functions.elementary.integers import frac
14
+ from sympy.functions.elementary.piecewise import Piecewise
15
+ from sympy.functions.elementary.trigonometric import (
16
+ TrigonometricFunction, sin, cos, tan, cot, csc, sec,
17
+ asin, acos, acot, atan, asec, acsc)
18
+ from sympy.functions.elementary.hyperbolic import (sinh, cosh, tanh, coth,
19
+ sech, csch, asinh, acosh, atanh, acoth, asech, acsch)
20
+ from sympy.polys.polytools import degree, lcm_list
21
+ from sympy.sets.sets import (Interval, Intersection, FiniteSet, Union,
22
+ Complement)
23
+ from sympy.sets.fancysets import ImageSet
24
+ from sympy.sets.conditionset import ConditionSet
25
+ from sympy.utilities import filldedent
26
+ from sympy.utilities.iterables import iterable
27
+ from sympy.matrices.dense import hessian
28
+
29
+
30
+ def continuous_domain(f, symbol, domain):
31
+ """
32
+ Returns the domain on which the function expression f is continuous.
33
+
34
+ This function is limited by the ability to determine the various
35
+ singularities and discontinuities of the given function.
36
+ The result is either given as a union of intervals or constructed using
37
+ other set operations.
38
+
39
+ Parameters
40
+ ==========
41
+
42
+ f : :py:class:`~.Expr`
43
+ The concerned function.
44
+ symbol : :py:class:`~.Symbol`
45
+ The variable for which the intervals are to be determined.
46
+ domain : :py:class:`~.Interval`
47
+ The domain over which the continuity of the symbol has to be checked.
48
+
49
+ Examples
50
+ ========
51
+
52
+ >>> from sympy import Interval, Symbol, S, tan, log, pi, sqrt
53
+ >>> from sympy.calculus.util import continuous_domain
54
+ >>> x = Symbol('x')
55
+ >>> continuous_domain(1/x, x, S.Reals)
56
+ Union(Interval.open(-oo, 0), Interval.open(0, oo))
57
+ >>> continuous_domain(tan(x), x, Interval(0, pi))
58
+ Union(Interval.Ropen(0, pi/2), Interval.Lopen(pi/2, pi))
59
+ >>> continuous_domain(sqrt(x - 2), x, Interval(-5, 5))
60
+ Interval(2, 5)
61
+ >>> continuous_domain(log(2*x - 1), x, S.Reals)
62
+ Interval.open(1/2, oo)
63
+
64
+ Returns
65
+ =======
66
+
67
+ :py:class:`~.Interval`
68
+ Union of all intervals where the function is continuous.
69
+
70
+ Raises
71
+ ======
72
+
73
+ NotImplementedError
74
+ If the method to determine continuity of such a function
75
+ has not yet been developed.
76
+
77
+ """
78
+ from sympy.solvers.inequalities import solve_univariate_inequality
79
+
80
+ if not domain.is_subset(S.Reals):
81
+ raise NotImplementedError(filldedent('''
82
+ Domain must be a subset of S.Reals.
83
+ '''))
84
+ implemented = [Pow, exp, log, Abs, frac,
85
+ sin, cos, tan, cot, sec, csc,
86
+ asin, acos, atan, acot, asec, acsc,
87
+ sinh, cosh, tanh, coth, sech, csch,
88
+ asinh, acosh, atanh, acoth, asech, acsch]
89
+ used = [fct.func for fct in f.atoms(Function) if fct.has(symbol)]
90
+ if any(func not in implemented for func in used):
91
+ raise NotImplementedError(filldedent('''
92
+ Unable to determine the domain of the given function.
93
+ '''))
94
+
95
+ x = Symbol('x')
96
+ constraints = {
97
+ log: (x > 0,),
98
+ asin: (x >= -1, x <= 1),
99
+ acos: (x >= -1, x <= 1),
100
+ acosh: (x >= 1,),
101
+ atanh: (x > -1, x < 1),
102
+ asech: (x > 0, x <= 1)
103
+ }
104
+ constraints_union = {
105
+ asec: (x <= -1, x >= 1),
106
+ acsc: (x <= -1, x >= 1),
107
+ acoth: (x < -1, x > 1)
108
+ }
109
+
110
+ cont_domain = domain
111
+ for atom in f.atoms(Pow):
112
+ den = atom.exp.as_numer_denom()[1]
113
+ if atom.exp.is_rational and den.is_odd:
114
+ pass # 0**negative handled by singularities()
115
+ else:
116
+ constraint = solve_univariate_inequality(atom.base >= 0,
117
+ symbol).as_set()
118
+ cont_domain = Intersection(constraint, cont_domain)
119
+
120
+ for atom in f.atoms(Function):
121
+ if atom.func in constraints:
122
+ for c in constraints[atom.func]:
123
+ constraint_relational = c.subs(x, atom.args[0])
124
+ constraint_set = solve_univariate_inequality(
125
+ constraint_relational, symbol).as_set()
126
+ cont_domain = Intersection(constraint_set, cont_domain)
127
+ elif atom.func in constraints_union:
128
+ constraint_set = S.EmptySet
129
+ for c in constraints_union[atom.func]:
130
+ constraint_relational = c.subs(x, atom.args[0])
131
+ constraint_set += solve_univariate_inequality(
132
+ constraint_relational, symbol).as_set()
133
+ cont_domain = Intersection(constraint_set, cont_domain)
134
+ # XXX: the discontinuities below could be factored out in
135
+ # a new "discontinuities()".
136
+ elif atom.func == acot:
137
+ from sympy.solvers.solveset import solveset_real
138
+ # Sympy's acot() has a step discontinuity at 0. Since it's
139
+ # neither an essential singularity nor a pole, singularities()
140
+ # will not report it. But it's still relevant for determining
141
+ # the continuity of the function f.
142
+ cont_domain -= solveset_real(atom.args[0], symbol)
143
+ # Note that the above may introduce spurious discontinuities, e.g.
144
+ # for abs(acot(x)) at 0.
145
+ elif atom.func == frac:
146
+ from sympy.solvers.solveset import solveset_real
147
+ r = function_range(atom.args[0], symbol, domain)
148
+ r = Intersection(r, S.Integers)
149
+ if r.is_finite_set:
150
+ discont = S.EmptySet
151
+ for n in r:
152
+ discont += solveset_real(atom.args[0]-n, symbol)
153
+ else:
154
+ discont = ConditionSet(
155
+ symbol, S.Integers.contains(atom.args[0]), cont_domain)
156
+ cont_domain -= discont
157
+
158
+ return cont_domain - singularities(f, symbol, domain)
159
+
160
+
161
+ def function_range(f, symbol, domain):
162
+ """
163
+ Finds the range of a function in a given domain.
164
+ This method is limited by the ability to determine the singularities and
165
+ determine limits.
166
+
167
+ Parameters
168
+ ==========
169
+
170
+ f : :py:class:`~.Expr`
171
+ The concerned function.
172
+ symbol : :py:class:`~.Symbol`
173
+ The variable for which the range of function is to be determined.
174
+ domain : :py:class:`~.Interval`
175
+ The domain under which the range of the function has to be found.
176
+
177
+ Examples
178
+ ========
179
+
180
+ >>> from sympy import Interval, Symbol, S, exp, log, pi, sqrt, sin, tan
181
+ >>> from sympy.calculus.util import function_range
182
+ >>> x = Symbol('x')
183
+ >>> function_range(sin(x), x, Interval(0, 2*pi))
184
+ Interval(-1, 1)
185
+ >>> function_range(tan(x), x, Interval(-pi/2, pi/2))
186
+ Interval(-oo, oo)
187
+ >>> function_range(1/x, x, S.Reals)
188
+ Union(Interval.open(-oo, 0), Interval.open(0, oo))
189
+ >>> function_range(exp(x), x, S.Reals)
190
+ Interval.open(0, oo)
191
+ >>> function_range(log(x), x, S.Reals)
192
+ Interval(-oo, oo)
193
+ >>> function_range(sqrt(x), x, Interval(-5, 9))
194
+ Interval(0, 3)
195
+
196
+ Returns
197
+ =======
198
+
199
+ :py:class:`~.Interval`
200
+ Union of all ranges for all intervals under domain where function is
201
+ continuous.
202
+
203
+ Raises
204
+ ======
205
+
206
+ NotImplementedError
207
+ If any of the intervals, in the given domain, for which function
208
+ is continuous are not finite or real,
209
+ OR if the critical points of the function on the domain cannot be found.
210
+ """
211
+
212
+ if domain is S.EmptySet:
213
+ return S.EmptySet
214
+
215
+ period = periodicity(f, symbol)
216
+ if period == S.Zero:
217
+ # the expression is constant wrt symbol
218
+ return FiniteSet(f.expand())
219
+
220
+ from sympy.series.limits import limit
221
+ from sympy.solvers.solveset import solveset
222
+
223
+ if period is not None:
224
+ if isinstance(domain, Interval):
225
+ if (domain.inf - domain.sup).is_infinite:
226
+ domain = Interval(0, period)
227
+ elif isinstance(domain, Union):
228
+ for sub_dom in domain.args:
229
+ if isinstance(sub_dom, Interval) and \
230
+ ((sub_dom.inf - sub_dom.sup).is_infinite):
231
+ domain = Interval(0, period)
232
+
233
+ intervals = continuous_domain(f, symbol, domain)
234
+ range_int = S.EmptySet
235
+ if isinstance(intervals,(Interval, FiniteSet)):
236
+ interval_iter = (intervals,)
237
+ elif isinstance(intervals, Union):
238
+ interval_iter = intervals.args
239
+ else:
240
+ raise NotImplementedError("Unable to find range for the given domain.")
241
+
242
+ for interval in interval_iter:
243
+ if isinstance(interval, FiniteSet):
244
+ for singleton in interval:
245
+ if singleton in domain:
246
+ range_int += FiniteSet(f.subs(symbol, singleton))
247
+ elif isinstance(interval, Interval):
248
+ vals = S.EmptySet
249
+ critical_values = S.EmptySet
250
+ bounds = ((interval.left_open, interval.inf, '+'),
251
+ (interval.right_open, interval.sup, '-'))
252
+
253
+ for is_open, limit_point, direction in bounds:
254
+ if is_open:
255
+ critical_values += FiniteSet(limit(f, symbol, limit_point, direction))
256
+ vals += critical_values
257
+ else:
258
+ vals += FiniteSet(f.subs(symbol, limit_point))
259
+
260
+ critical_points = solveset(f.diff(symbol), symbol, interval)
261
+
262
+ if not iterable(critical_points):
263
+ raise NotImplementedError(
264
+ 'Unable to find critical points for {}'.format(f))
265
+ if isinstance(critical_points, ImageSet):
266
+ raise NotImplementedError(
267
+ 'Infinite number of critical points for {}'.format(f))
268
+
269
+ for critical_point in critical_points:
270
+ vals += FiniteSet(f.subs(symbol, critical_point))
271
+
272
+ left_open, right_open = False, False
273
+
274
+ if critical_values is not S.EmptySet:
275
+ if critical_values.inf == vals.inf:
276
+ left_open = True
277
+
278
+ if critical_values.sup == vals.sup:
279
+ right_open = True
280
+
281
+ range_int += Interval(vals.inf, vals.sup, left_open, right_open)
282
+ else:
283
+ raise NotImplementedError("Unable to find range for the given domain.")
284
+
285
+ return range_int
286
+
287
+
288
+ def not_empty_in(finset_intersection, *syms):
289
+ """
290
+ Finds the domain of the functions in ``finset_intersection`` in which the
291
+ ``finite_set`` is not-empty.
292
+
293
+ Parameters
294
+ ==========
295
+
296
+ finset_intersection : Intersection of FiniteSet
297
+ The unevaluated intersection of FiniteSet containing
298
+ real-valued functions with Union of Sets
299
+ syms : Tuple of symbols
300
+ Symbol for which domain is to be found
301
+
302
+ Raises
303
+ ======
304
+
305
+ NotImplementedError
306
+ The algorithms to find the non-emptiness of the given FiniteSet are
307
+ not yet implemented.
308
+ ValueError
309
+ The input is not valid.
310
+ RuntimeError
311
+ It is a bug, please report it to the github issue tracker
312
+ (https://github.com/sympy/sympy/issues).
313
+
314
+ Examples
315
+ ========
316
+
317
+ >>> from sympy import FiniteSet, Interval, not_empty_in, oo
318
+ >>> from sympy.abc import x
319
+ >>> not_empty_in(FiniteSet(x/2).intersect(Interval(0, 1)), x)
320
+ Interval(0, 2)
321
+ >>> not_empty_in(FiniteSet(x, x**2).intersect(Interval(1, 2)), x)
322
+ Union(Interval(1, 2), Interval(-sqrt(2), -1))
323
+ >>> not_empty_in(FiniteSet(x**2/(x + 2)).intersect(Interval(1, oo)), x)
324
+ Union(Interval.Lopen(-2, -1), Interval(2, oo))
325
+ """
326
+
327
+ # TODO: handle piecewise defined functions
328
+ # TODO: handle transcendental functions
329
+ # TODO: handle multivariate functions
330
+ if len(syms) == 0:
331
+ raise ValueError("One or more symbols must be given in syms.")
332
+
333
+ if finset_intersection is S.EmptySet:
334
+ return S.EmptySet
335
+
336
+ if isinstance(finset_intersection, Union):
337
+ elm_in_sets = finset_intersection.args[0]
338
+ return Union(not_empty_in(finset_intersection.args[1], *syms),
339
+ elm_in_sets)
340
+
341
+ if isinstance(finset_intersection, FiniteSet):
342
+ finite_set = finset_intersection
343
+ _sets = S.Reals
344
+ else:
345
+ finite_set = finset_intersection.args[1]
346
+ _sets = finset_intersection.args[0]
347
+
348
+ if not isinstance(finite_set, FiniteSet):
349
+ raise ValueError('A FiniteSet must be given, not %s: %s' %
350
+ (type(finite_set), finite_set))
351
+
352
+ if len(syms) == 1:
353
+ symb = syms[0]
354
+ else:
355
+ raise NotImplementedError('more than one variables %s not handled' %
356
+ (syms,))
357
+
358
+ def elm_domain(expr, intrvl):
359
+ """ Finds the domain of an expression in any given interval """
360
+ from sympy.solvers.solveset import solveset
361
+
362
+ _start = intrvl.start
363
+ _end = intrvl.end
364
+ _singularities = solveset(expr.as_numer_denom()[1], symb,
365
+ domain=S.Reals)
366
+
367
+ if intrvl.right_open:
368
+ if _end is S.Infinity:
369
+ _domain1 = S.Reals
370
+ else:
371
+ _domain1 = solveset(expr < _end, symb, domain=S.Reals)
372
+ else:
373
+ _domain1 = solveset(expr <= _end, symb, domain=S.Reals)
374
+
375
+ if intrvl.left_open:
376
+ if _start is S.NegativeInfinity:
377
+ _domain2 = S.Reals
378
+ else:
379
+ _domain2 = solveset(expr > _start, symb, domain=S.Reals)
380
+ else:
381
+ _domain2 = solveset(expr >= _start, symb, domain=S.Reals)
382
+
383
+ # domain in the interval
384
+ expr_with_sing = Intersection(_domain1, _domain2)
385
+ expr_domain = Complement(expr_with_sing, _singularities)
386
+ return expr_domain
387
+
388
+ if isinstance(_sets, Interval):
389
+ return Union(*[elm_domain(element, _sets) for element in finite_set])
390
+
391
+ if isinstance(_sets, Union):
392
+ _domain = S.EmptySet
393
+ for intrvl in _sets.args:
394
+ _domain_element = Union(*[elm_domain(element, intrvl)
395
+ for element in finite_set])
396
+ _domain = Union(_domain, _domain_element)
397
+ return _domain
398
+
399
+
400
+ def periodicity(f, symbol, check=False):
401
+ """
402
+ Tests the given function for periodicity in the given symbol.
403
+
404
+ Parameters
405
+ ==========
406
+
407
+ f : :py:class:`~.Expr`
408
+ The concerned function.
409
+ symbol : :py:class:`~.Symbol`
410
+ The variable for which the period is to be determined.
411
+ check : bool, optional
412
+ The flag to verify whether the value being returned is a period or not.
413
+
414
+ Returns
415
+ =======
416
+
417
+ period
418
+ The period of the function is returned.
419
+ ``None`` is returned when the function is aperiodic or has a complex period.
420
+ The value of $0$ is returned as the period of a constant function.
421
+
422
+ Raises
423
+ ======
424
+
425
+ NotImplementedError
426
+ The value of the period computed cannot be verified.
427
+
428
+
429
+ Notes
430
+ =====
431
+
432
+ Currently, we do not support functions with a complex period.
433
+ The period of functions having complex periodic values such
434
+ as ``exp``, ``sinh`` is evaluated to ``None``.
435
+
436
+ The value returned might not be the "fundamental" period of the given
437
+ function i.e. it may not be the smallest periodic value of the function.
438
+
439
+ The verification of the period through the ``check`` flag is not reliable
440
+ due to internal simplification of the given expression. Hence, it is set
441
+ to ``False`` by default.
442
+
443
+ Examples
444
+ ========
445
+ >>> from sympy import periodicity, Symbol, sin, cos, tan, exp
446
+ >>> x = Symbol('x')
447
+ >>> f = sin(x) + sin(2*x) + sin(3*x)
448
+ >>> periodicity(f, x)
449
+ 2*pi
450
+ >>> periodicity(sin(x)*cos(x), x)
451
+ pi
452
+ >>> periodicity(exp(tan(2*x) - 1), x)
453
+ pi/2
454
+ >>> periodicity(sin(4*x)**cos(2*x), x)
455
+ pi
456
+ >>> periodicity(exp(x), x)
457
+ """
458
+ if symbol.kind is not NumberKind:
459
+ raise NotImplementedError("Cannot use symbol of kind %s" % symbol.kind)
460
+ temp = Dummy('x', real=True)
461
+ f = f.subs(symbol, temp)
462
+ symbol = temp
463
+
464
+ def _check(orig_f, period):
465
+ '''Return the checked period or raise an error.'''
466
+ new_f = orig_f.subs(symbol, symbol + period)
467
+ if new_f.equals(orig_f):
468
+ return period
469
+ else:
470
+ raise NotImplementedError(filldedent('''
471
+ The period of the given function cannot be verified.
472
+ When `%s` was replaced with `%s + %s` in `%s`, the result
473
+ was `%s` which was not recognized as being the same as
474
+ the original function.
475
+ So either the period was wrong or the two forms were
476
+ not recognized as being equal.
477
+ Set check=False to obtain the value.''' %
478
+ (symbol, symbol, period, orig_f, new_f)))
479
+
480
+ orig_f = f
481
+ period = None
482
+
483
+ if isinstance(f, Relational):
484
+ f = f.lhs - f.rhs
485
+
486
+ f = f.simplify()
487
+
488
+ if symbol not in f.free_symbols:
489
+ return S.Zero
490
+
491
+ if isinstance(f, TrigonometricFunction):
492
+ try:
493
+ period = f.period(symbol)
494
+ except NotImplementedError:
495
+ pass
496
+
497
+ if isinstance(f, Abs):
498
+ arg = f.args[0]
499
+ if isinstance(arg, (sec, csc, cos)):
500
+ # all but tan and cot might have a
501
+ # a period that is half as large
502
+ # so recast as sin
503
+ arg = sin(arg.args[0])
504
+ period = periodicity(arg, symbol)
505
+ if period is not None and isinstance(arg, sin):
506
+ # the argument of Abs was a trigonometric other than
507
+ # cot or tan; test to see if the half-period
508
+ # is valid. Abs(arg) has behaviour equivalent to
509
+ # orig_f, so use that for test:
510
+ orig_f = Abs(arg)
511
+ try:
512
+ return _check(orig_f, period/2)
513
+ except NotImplementedError as err:
514
+ if check:
515
+ raise NotImplementedError(err)
516
+ # else let new orig_f and period be
517
+ # checked below
518
+
519
+ if isinstance(f, exp) or (f.is_Pow and f.base == S.Exp1):
520
+ f = Pow(S.Exp1, expand_mul(f.exp))
521
+ if im(f) != 0:
522
+ period_real = periodicity(re(f), symbol)
523
+ period_imag = periodicity(im(f), symbol)
524
+ if period_real is not None and period_imag is not None:
525
+ period = lcim([period_real, period_imag])
526
+
527
+ if f.is_Pow and f.base != S.Exp1:
528
+ base, expo = f.args
529
+ base_has_sym = base.has(symbol)
530
+ expo_has_sym = expo.has(symbol)
531
+
532
+ if base_has_sym and not expo_has_sym:
533
+ period = periodicity(base, symbol)
534
+
535
+ elif expo_has_sym and not base_has_sym:
536
+ period = periodicity(expo, symbol)
537
+
538
+ else:
539
+ period = _periodicity(f.args, symbol)
540
+
541
+ elif f.is_Mul:
542
+ coeff, g = f.as_independent(symbol, as_Add=False)
543
+ if isinstance(g, TrigonometricFunction) or not equal_valued(coeff, 1):
544
+ period = periodicity(g, symbol)
545
+ else:
546
+ period = _periodicity(g.args, symbol)
547
+
548
+ elif f.is_Add:
549
+ k, g = f.as_independent(symbol)
550
+ if k is not S.Zero:
551
+ return periodicity(g, symbol)
552
+
553
+ period = _periodicity(g.args, symbol)
554
+
555
+ elif isinstance(f, Mod):
556
+ a, n = f.args
557
+
558
+ if a == symbol:
559
+ period = n
560
+ elif isinstance(a, TrigonometricFunction):
561
+ period = periodicity(a, symbol)
562
+ #check if 'f' is linear in 'symbol'
563
+ elif (a.is_polynomial(symbol) and degree(a, symbol) == 1 and
564
+ symbol not in n.free_symbols):
565
+ period = Abs(n / a.diff(symbol))
566
+
567
+ elif isinstance(f, Piecewise):
568
+ pass # not handling Piecewise yet as the return type is not favorable
569
+
570
+ elif period is None:
571
+ from sympy.solvers.decompogen import compogen, decompogen
572
+ g_s = decompogen(f, symbol)
573
+ num_of_gs = len(g_s)
574
+ if num_of_gs > 1:
575
+ for index, g in enumerate(reversed(g_s)):
576
+ start_index = num_of_gs - 1 - index
577
+ g = compogen(g_s[start_index:], symbol)
578
+ if g not in (orig_f, f): # Fix for issue 12620
579
+ period = periodicity(g, symbol)
580
+ if period is not None:
581
+ break
582
+
583
+ if period is not None:
584
+ if check:
585
+ return _check(orig_f, period)
586
+ return period
587
+
588
+ return None
589
+
590
+
591
+ def _periodicity(args, symbol):
592
+ """
593
+ Helper for `periodicity` to find the period of a list of simpler
594
+ functions.
595
+ It uses the `lcim` method to find the least common period of
596
+ all the functions.
597
+
598
+ Parameters
599
+ ==========
600
+
601
+ args : Tuple of :py:class:`~.Symbol`
602
+ All the symbols present in a function.
603
+
604
+ symbol : :py:class:`~.Symbol`
605
+ The symbol over which the function is to be evaluated.
606
+
607
+ Returns
608
+ =======
609
+
610
+ period
611
+ The least common period of the function for all the symbols
612
+ of the function.
613
+ ``None`` if for at least one of the symbols the function is aperiodic.
614
+
615
+ """
616
+ periods = []
617
+ for f in args:
618
+ period = periodicity(f, symbol)
619
+ if period is None:
620
+ return None
621
+
622
+ if period is not S.Zero:
623
+ periods.append(period)
624
+
625
+ if len(periods) > 1:
626
+ return lcim(periods)
627
+
628
+ if periods:
629
+ return periods[0]
630
+
631
+
632
+ def lcim(numbers):
633
+ """Returns the least common integral multiple of a list of numbers.
634
+
635
+ The numbers can be rational or irrational or a mixture of both.
636
+ `None` is returned for incommensurable numbers.
637
+
638
+ Parameters
639
+ ==========
640
+
641
+ numbers : list
642
+ Numbers (rational and/or irrational) for which lcim is to be found.
643
+
644
+ Returns
645
+ =======
646
+
647
+ number
648
+ lcim if it exists, otherwise ``None`` for incommensurable numbers.
649
+
650
+ Examples
651
+ ========
652
+
653
+ >>> from sympy.calculus.util import lcim
654
+ >>> from sympy import S, pi
655
+ >>> lcim([S(1)/2, S(3)/4, S(5)/6])
656
+ 15/2
657
+ >>> lcim([2*pi, 3*pi, pi, pi/2])
658
+ 6*pi
659
+ >>> lcim([S(1), 2*pi])
660
+ """
661
+ result = None
662
+ if all(num.is_irrational for num in numbers):
663
+ factorized_nums = [num.factor() for num in numbers]
664
+ factors_num = [num.as_coeff_Mul() for num in factorized_nums]
665
+ term = factors_num[0][1]
666
+ if all(factor == term for coeff, factor in factors_num):
667
+ common_term = term
668
+ coeffs = [coeff for coeff, factor in factors_num]
669
+ result = lcm_list(coeffs) * common_term
670
+
671
+ elif all(num.is_rational for num in numbers):
672
+ result = lcm_list(numbers)
673
+
674
+ else:
675
+ pass
676
+
677
+ return result
678
+
679
+ def is_convex(f, *syms, domain=S.Reals):
680
+ r"""Determines the convexity of the function passed in the argument.
681
+
682
+ Parameters
683
+ ==========
684
+
685
+ f : :py:class:`~.Expr`
686
+ The concerned function.
687
+ syms : Tuple of :py:class:`~.Symbol`
688
+ The variables with respect to which the convexity is to be determined.
689
+ domain : :py:class:`~.Interval`, optional
690
+ The domain over which the convexity of the function has to be checked.
691
+ If unspecified, S.Reals will be the default domain.
692
+
693
+ Returns
694
+ =======
695
+
696
+ bool
697
+ The method returns ``True`` if the function is convex otherwise it
698
+ returns ``False``.
699
+
700
+ Raises
701
+ ======
702
+
703
+ NotImplementedError
704
+ The check for the convexity of multivariate functions is not implemented yet.
705
+
706
+ Notes
707
+ =====
708
+
709
+ To determine concavity of a function pass `-f` as the concerned function.
710
+ To determine logarithmic convexity of a function pass `\log(f)` as
711
+ concerned function.
712
+ To determine logarithmic concavity of a function pass `-\log(f)` as
713
+ concerned function.
714
+
715
+ Currently, convexity check of multivariate functions is not handled.
716
+
717
+ Examples
718
+ ========
719
+
720
+ >>> from sympy import is_convex, symbols, exp, oo, Interval
721
+ >>> x = symbols('x')
722
+ >>> is_convex(exp(x), x)
723
+ True
724
+ >>> is_convex(x**3, x, domain = Interval(-1, oo))
725
+ False
726
+ >>> is_convex(1/x**2, x, domain=Interval.open(0, oo))
727
+ True
728
+
729
+ References
730
+ ==========
731
+
732
+ .. [1] https://en.wikipedia.org/wiki/Convex_function
733
+ .. [2] http://www.ifp.illinois.edu/~angelia/L3_convfunc.pdf
734
+ .. [3] https://en.wikipedia.org/wiki/Logarithmically_convex_function
735
+ .. [4] https://en.wikipedia.org/wiki/Logarithmically_concave_function
736
+ .. [5] https://en.wikipedia.org/wiki/Concave_function
737
+
738
+ """
739
+ if len(syms) > 1 :
740
+ return hessian(f, syms).is_positive_semidefinite
741
+ from sympy.solvers.inequalities import solve_univariate_inequality
742
+ f = _sympify(f)
743
+ var = syms[0]
744
+ if any(s in domain for s in singularities(f, var)):
745
+ return False
746
+ condition = f.diff(var, 2) < 0
747
+ if solve_univariate_inequality(condition, var, False, domain):
748
+ return False
749
+ return True
750
+
751
+
752
+ def stationary_points(f, symbol, domain=S.Reals):
753
+ """
754
+ Returns the stationary points of a function (where derivative of the
755
+ function is 0) in the given domain.
756
+
757
+ Parameters
758
+ ==========
759
+
760
+ f : :py:class:`~.Expr`
761
+ The concerned function.
762
+ symbol : :py:class:`~.Symbol`
763
+ The variable for which the stationary points are to be determined.
764
+ domain : :py:class:`~.Interval`
765
+ The domain over which the stationary points have to be checked.
766
+ If unspecified, ``S.Reals`` will be the default domain.
767
+
768
+ Returns
769
+ =======
770
+
771
+ Set
772
+ A set of stationary points for the function. If there are no
773
+ stationary point, an :py:class:`~.EmptySet` is returned.
774
+
775
+ Examples
776
+ ========
777
+
778
+ >>> from sympy import Interval, Symbol, S, sin, pi, pprint, stationary_points
779
+ >>> x = Symbol('x')
780
+
781
+ >>> stationary_points(1/x, x, S.Reals)
782
+ EmptySet
783
+
784
+ >>> pprint(stationary_points(sin(x), x), use_unicode=False)
785
+ pi 3*pi
786
+ {2*n*pi + -- | n in Integers} U {2*n*pi + ---- | n in Integers}
787
+ 2 2
788
+
789
+ >>> stationary_points(sin(x),x, Interval(0, 4*pi))
790
+ {pi/2, 3*pi/2, 5*pi/2, 7*pi/2}
791
+
792
+ """
793
+ from sympy.solvers.solveset import solveset
794
+
795
+ if domain is S.EmptySet:
796
+ return S.EmptySet
797
+
798
+ domain = continuous_domain(f, symbol, domain)
799
+ set = solveset(diff(f, symbol), symbol, domain)
800
+
801
+ return set
802
+
803
+
804
+ def maximum(f, symbol, domain=S.Reals):
805
+ """
806
+ Returns the maximum value of a function in the given domain.
807
+
808
+ Parameters
809
+ ==========
810
+
811
+ f : :py:class:`~.Expr`
812
+ The concerned function.
813
+ symbol : :py:class:`~.Symbol`
814
+ The variable for maximum value needs to be determined.
815
+ domain : :py:class:`~.Interval`
816
+ The domain over which the maximum have to be checked.
817
+ If unspecified, then the global maximum is returned.
818
+
819
+ Returns
820
+ =======
821
+
822
+ number
823
+ Maximum value of the function in given domain.
824
+
825
+ Examples
826
+ ========
827
+
828
+ >>> from sympy import Interval, Symbol, S, sin, cos, pi, maximum
829
+ >>> x = Symbol('x')
830
+
831
+ >>> f = -x**2 + 2*x + 5
832
+ >>> maximum(f, x, S.Reals)
833
+ 6
834
+
835
+ >>> maximum(sin(x), x, Interval(-pi, pi/4))
836
+ sqrt(2)/2
837
+
838
+ >>> maximum(sin(x)*cos(x), x)
839
+ 1/2
840
+
841
+ """
842
+ if isinstance(symbol, Symbol):
843
+ if domain is S.EmptySet:
844
+ raise ValueError("Maximum value not defined for empty domain.")
845
+
846
+ return function_range(f, symbol, domain).sup
847
+ else:
848
+ raise ValueError("%s is not a valid symbol." % symbol)
849
+
850
+
851
+ def minimum(f, symbol, domain=S.Reals):
852
+ """
853
+ Returns the minimum value of a function in the given domain.
854
+
855
+ Parameters
856
+ ==========
857
+
858
+ f : :py:class:`~.Expr`
859
+ The concerned function.
860
+ symbol : :py:class:`~.Symbol`
861
+ The variable for minimum value needs to be determined.
862
+ domain : :py:class:`~.Interval`
863
+ The domain over which the minimum have to be checked.
864
+ If unspecified, then the global minimum is returned.
865
+
866
+ Returns
867
+ =======
868
+
869
+ number
870
+ Minimum value of the function in the given domain.
871
+
872
+ Examples
873
+ ========
874
+
875
+ >>> from sympy import Interval, Symbol, S, sin, cos, minimum
876
+ >>> x = Symbol('x')
877
+
878
+ >>> f = x**2 + 2*x + 5
879
+ >>> minimum(f, x, S.Reals)
880
+ 4
881
+
882
+ >>> minimum(sin(x), x, Interval(2, 3))
883
+ sin(3)
884
+
885
+ >>> minimum(sin(x)*cos(x), x)
886
+ -1/2
887
+
888
+ """
889
+ if isinstance(symbol, Symbol):
890
+ if domain is S.EmptySet:
891
+ raise ValueError("Minimum value not defined for empty domain.")
892
+
893
+ return function_range(f, symbol, domain).inf
894
+ else:
895
+ raise ValueError("%s is not a valid symbol." % symbol)
.venv/lib/python3.13/site-packages/sympy/categories/__init__.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Category Theory module.
3
+
4
+ Provides some of the fundamental category-theory-related classes,
5
+ including categories, morphisms, diagrams. Functors are not
6
+ implemented yet.
7
+
8
+ The general reference work this module tries to follow is
9
+
10
+ [JoyOfCats] J. Adamek, H. Herrlich. G. E. Strecker: Abstract and
11
+ Concrete Categories. The Joy of Cats.
12
+
13
+ The latest version of this book should be available for free download
14
+ from
15
+
16
+ katmat.math.uni-bremen.de/acc/acc.pdf
17
+
18
+ """
19
+
20
+ from .baseclasses import (Object, Morphism, IdentityMorphism,
21
+ NamedMorphism, CompositeMorphism, Category,
22
+ Diagram)
23
+
24
+ from .diagram_drawing import (DiagramGrid, XypicDiagramDrawer,
25
+ xypic_draw_diagram, preview_diagram)
26
+
27
+ __all__ = [
28
+ 'Object', 'Morphism', 'IdentityMorphism', 'NamedMorphism',
29
+ 'CompositeMorphism', 'Category', 'Diagram',
30
+
31
+ 'DiagramGrid', 'XypicDiagramDrawer', 'xypic_draw_diagram',
32
+ 'preview_diagram',
33
+ ]
.venv/lib/python3.13/site-packages/sympy/categories/baseclasses.py ADDED
@@ -0,0 +1,978 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sympy.core import S, Basic, Dict, Symbol, Tuple, sympify
2
+ from sympy.core.symbol import Str
3
+ from sympy.sets import Set, FiniteSet, EmptySet
4
+ from sympy.utilities.iterables import iterable
5
+
6
+
7
+ class Class(Set):
8
+ r"""
9
+ The base class for any kind of class in the set-theoretic sense.
10
+
11
+ Explanation
12
+ ===========
13
+
14
+ In axiomatic set theories, everything is a class. A class which
15
+ can be a member of another class is a set. A class which is not a
16
+ member of another class is a proper class. The class `\{1, 2\}`
17
+ is a set; the class of all sets is a proper class.
18
+
19
+ This class is essentially a synonym for :class:`sympy.core.Set`.
20
+ The goal of this class is to assure easier migration to the
21
+ eventual proper implementation of set theory.
22
+ """
23
+ is_proper = False
24
+
25
+
26
+ class Object(Symbol):
27
+ """
28
+ The base class for any kind of object in an abstract category.
29
+
30
+ Explanation
31
+ ===========
32
+
33
+ While technically any instance of :class:`~.Basic` will do, this
34
+ class is the recommended way to create abstract objects in
35
+ abstract categories.
36
+ """
37
+
38
+
39
+ class Morphism(Basic):
40
+ """
41
+ The base class for any morphism in an abstract category.
42
+
43
+ Explanation
44
+ ===========
45
+
46
+ In abstract categories, a morphism is an arrow between two
47
+ category objects. The object where the arrow starts is called the
48
+ domain, while the object where the arrow ends is called the
49
+ codomain.
50
+
51
+ Two morphisms between the same pair of objects are considered to
52
+ be the same morphisms. To distinguish between morphisms between
53
+ the same objects use :class:`NamedMorphism`.
54
+
55
+ It is prohibited to instantiate this class. Use one of the
56
+ derived classes instead.
57
+
58
+ See Also
59
+ ========
60
+
61
+ IdentityMorphism, NamedMorphism, CompositeMorphism
62
+ """
63
+ def __new__(cls, domain, codomain):
64
+ raise(NotImplementedError(
65
+ "Cannot instantiate Morphism. Use derived classes instead."))
66
+
67
+ @property
68
+ def domain(self):
69
+ """
70
+ Returns the domain of the morphism.
71
+
72
+ Examples
73
+ ========
74
+
75
+ >>> from sympy.categories import Object, NamedMorphism
76
+ >>> A = Object("A")
77
+ >>> B = Object("B")
78
+ >>> f = NamedMorphism(A, B, "f")
79
+ >>> f.domain
80
+ Object("A")
81
+
82
+ """
83
+ return self.args[0]
84
+
85
+ @property
86
+ def codomain(self):
87
+ """
88
+ Returns the codomain of the morphism.
89
+
90
+ Examples
91
+ ========
92
+
93
+ >>> from sympy.categories import Object, NamedMorphism
94
+ >>> A = Object("A")
95
+ >>> B = Object("B")
96
+ >>> f = NamedMorphism(A, B, "f")
97
+ >>> f.codomain
98
+ Object("B")
99
+
100
+ """
101
+ return self.args[1]
102
+
103
+ def compose(self, other):
104
+ r"""
105
+ Composes self with the supplied morphism.
106
+
107
+ The order of elements in the composition is the usual order,
108
+ i.e., to construct `g\circ f` use ``g.compose(f)``.
109
+
110
+ Examples
111
+ ========
112
+
113
+ >>> from sympy.categories import Object, NamedMorphism
114
+ >>> A = Object("A")
115
+ >>> B = Object("B")
116
+ >>> C = Object("C")
117
+ >>> f = NamedMorphism(A, B, "f")
118
+ >>> g = NamedMorphism(B, C, "g")
119
+ >>> g * f
120
+ CompositeMorphism((NamedMorphism(Object("A"), Object("B"), "f"),
121
+ NamedMorphism(Object("B"), Object("C"), "g")))
122
+ >>> (g * f).domain
123
+ Object("A")
124
+ >>> (g * f).codomain
125
+ Object("C")
126
+
127
+ """
128
+ return CompositeMorphism(other, self)
129
+
130
+ def __mul__(self, other):
131
+ r"""
132
+ Composes self with the supplied morphism.
133
+
134
+ The semantics of this operation is given by the following
135
+ equation: ``g * f == g.compose(f)`` for composable morphisms
136
+ ``g`` and ``f``.
137
+
138
+ See Also
139
+ ========
140
+
141
+ compose
142
+ """
143
+ return self.compose(other)
144
+
145
+
146
+ class IdentityMorphism(Morphism):
147
+ """
148
+ Represents an identity morphism.
149
+
150
+ Explanation
151
+ ===========
152
+
153
+ An identity morphism is a morphism with equal domain and codomain,
154
+ which acts as an identity with respect to composition.
155
+
156
+ Examples
157
+ ========
158
+
159
+ >>> from sympy.categories import Object, NamedMorphism, IdentityMorphism
160
+ >>> A = Object("A")
161
+ >>> B = Object("B")
162
+ >>> f = NamedMorphism(A, B, "f")
163
+ >>> id_A = IdentityMorphism(A)
164
+ >>> id_B = IdentityMorphism(B)
165
+ >>> f * id_A == f
166
+ True
167
+ >>> id_B * f == f
168
+ True
169
+
170
+ See Also
171
+ ========
172
+
173
+ Morphism
174
+ """
175
+ def __new__(cls, domain):
176
+ return Basic.__new__(cls, domain)
177
+
178
+ @property
179
+ def codomain(self):
180
+ return self.domain
181
+
182
+
183
+ class NamedMorphism(Morphism):
184
+ """
185
+ Represents a morphism which has a name.
186
+
187
+ Explanation
188
+ ===========
189
+
190
+ Names are used to distinguish between morphisms which have the
191
+ same domain and codomain: two named morphisms are equal if they
192
+ have the same domains, codomains, and names.
193
+
194
+ Examples
195
+ ========
196
+
197
+ >>> from sympy.categories import Object, NamedMorphism
198
+ >>> A = Object("A")
199
+ >>> B = Object("B")
200
+ >>> f = NamedMorphism(A, B, "f")
201
+ >>> f
202
+ NamedMorphism(Object("A"), Object("B"), "f")
203
+ >>> f.name
204
+ 'f'
205
+
206
+ See Also
207
+ ========
208
+
209
+ Morphism
210
+ """
211
+ def __new__(cls, domain, codomain, name):
212
+ if not name:
213
+ raise ValueError("Empty morphism names not allowed.")
214
+
215
+ if not isinstance(name, Str):
216
+ name = Str(name)
217
+
218
+ return Basic.__new__(cls, domain, codomain, name)
219
+
220
+ @property
221
+ def name(self):
222
+ """
223
+ Returns the name of the morphism.
224
+
225
+ Examples
226
+ ========
227
+
228
+ >>> from sympy.categories import Object, NamedMorphism
229
+ >>> A = Object("A")
230
+ >>> B = Object("B")
231
+ >>> f = NamedMorphism(A, B, "f")
232
+ >>> f.name
233
+ 'f'
234
+
235
+ """
236
+ return self.args[2].name
237
+
238
+
239
+ class CompositeMorphism(Morphism):
240
+ r"""
241
+ Represents a morphism which is a composition of other morphisms.
242
+
243
+ Explanation
244
+ ===========
245
+
246
+ Two composite morphisms are equal if the morphisms they were
247
+ obtained from (components) are the same and were listed in the
248
+ same order.
249
+
250
+ The arguments to the constructor for this class should be listed
251
+ in diagram order: to obtain the composition `g\circ f` from the
252
+ instances of :class:`Morphism` ``g`` and ``f`` use
253
+ ``CompositeMorphism(f, g)``.
254
+
255
+ Examples
256
+ ========
257
+
258
+ >>> from sympy.categories import Object, NamedMorphism, CompositeMorphism
259
+ >>> A = Object("A")
260
+ >>> B = Object("B")
261
+ >>> C = Object("C")
262
+ >>> f = NamedMorphism(A, B, "f")
263
+ >>> g = NamedMorphism(B, C, "g")
264
+ >>> g * f
265
+ CompositeMorphism((NamedMorphism(Object("A"), Object("B"), "f"),
266
+ NamedMorphism(Object("B"), Object("C"), "g")))
267
+ >>> CompositeMorphism(f, g) == g * f
268
+ True
269
+
270
+ """
271
+ @staticmethod
272
+ def _add_morphism(t, morphism):
273
+ """
274
+ Intelligently adds ``morphism`` to tuple ``t``.
275
+
276
+ Explanation
277
+ ===========
278
+
279
+ If ``morphism`` is a composite morphism, its components are
280
+ added to the tuple. If ``morphism`` is an identity, nothing
281
+ is added to the tuple.
282
+
283
+ No composability checks are performed.
284
+ """
285
+ if isinstance(morphism, CompositeMorphism):
286
+ # ``morphism`` is a composite morphism; we have to
287
+ # denest its components.
288
+ return t + morphism.components
289
+ elif isinstance(morphism, IdentityMorphism):
290
+ # ``morphism`` is an identity. Nothing happens.
291
+ return t
292
+ else:
293
+ return t + Tuple(morphism)
294
+
295
+ def __new__(cls, *components):
296
+ if components and not isinstance(components[0], Morphism):
297
+ # Maybe the user has explicitly supplied a list of
298
+ # morphisms.
299
+ return CompositeMorphism.__new__(cls, *components[0])
300
+
301
+ normalised_components = Tuple()
302
+
303
+ for current, following in zip(components, components[1:]):
304
+ if not isinstance(current, Morphism) or \
305
+ not isinstance(following, Morphism):
306
+ raise TypeError("All components must be morphisms.")
307
+
308
+ if current.codomain != following.domain:
309
+ raise ValueError("Uncomposable morphisms.")
310
+
311
+ normalised_components = CompositeMorphism._add_morphism(
312
+ normalised_components, current)
313
+
314
+ # We haven't added the last morphism to the list of normalised
315
+ # components. Add it now.
316
+ normalised_components = CompositeMorphism._add_morphism(
317
+ normalised_components, components[-1])
318
+
319
+ if not normalised_components:
320
+ # If ``normalised_components`` is empty, only identities
321
+ # were supplied. Since they all were composable, they are
322
+ # all the same identities.
323
+ return components[0]
324
+ elif len(normalised_components) == 1:
325
+ # No sense to construct a whole CompositeMorphism.
326
+ return normalised_components[0]
327
+
328
+ return Basic.__new__(cls, normalised_components)
329
+
330
+ @property
331
+ def components(self):
332
+ """
333
+ Returns the components of this composite morphism.
334
+
335
+ Examples
336
+ ========
337
+
338
+ >>> from sympy.categories import Object, NamedMorphism
339
+ >>> A = Object("A")
340
+ >>> B = Object("B")
341
+ >>> C = Object("C")
342
+ >>> f = NamedMorphism(A, B, "f")
343
+ >>> g = NamedMorphism(B, C, "g")
344
+ >>> (g * f).components
345
+ (NamedMorphism(Object("A"), Object("B"), "f"),
346
+ NamedMorphism(Object("B"), Object("C"), "g"))
347
+
348
+ """
349
+ return self.args[0]
350
+
351
+ @property
352
+ def domain(self):
353
+ """
354
+ Returns the domain of this composite morphism.
355
+
356
+ The domain of the composite morphism is the domain of its
357
+ first component.
358
+
359
+ Examples
360
+ ========
361
+
362
+ >>> from sympy.categories import Object, NamedMorphism
363
+ >>> A = Object("A")
364
+ >>> B = Object("B")
365
+ >>> C = Object("C")
366
+ >>> f = NamedMorphism(A, B, "f")
367
+ >>> g = NamedMorphism(B, C, "g")
368
+ >>> (g * f).domain
369
+ Object("A")
370
+
371
+ """
372
+ return self.components[0].domain
373
+
374
+ @property
375
+ def codomain(self):
376
+ """
377
+ Returns the codomain of this composite morphism.
378
+
379
+ The codomain of the composite morphism is the codomain of its
380
+ last component.
381
+
382
+ Examples
383
+ ========
384
+
385
+ >>> from sympy.categories import Object, NamedMorphism
386
+ >>> A = Object("A")
387
+ >>> B = Object("B")
388
+ >>> C = Object("C")
389
+ >>> f = NamedMorphism(A, B, "f")
390
+ >>> g = NamedMorphism(B, C, "g")
391
+ >>> (g * f).codomain
392
+ Object("C")
393
+
394
+ """
395
+ return self.components[-1].codomain
396
+
397
+ def flatten(self, new_name):
398
+ """
399
+ Forgets the composite structure of this morphism.
400
+
401
+ Explanation
402
+ ===========
403
+
404
+ If ``new_name`` is not empty, returns a :class:`NamedMorphism`
405
+ with the supplied name, otherwise returns a :class:`Morphism`.
406
+ In both cases the domain of the new morphism is the domain of
407
+ this composite morphism and the codomain of the new morphism
408
+ is the codomain of this composite morphism.
409
+
410
+ Examples
411
+ ========
412
+
413
+ >>> from sympy.categories import Object, NamedMorphism
414
+ >>> A = Object("A")
415
+ >>> B = Object("B")
416
+ >>> C = Object("C")
417
+ >>> f = NamedMorphism(A, B, "f")
418
+ >>> g = NamedMorphism(B, C, "g")
419
+ >>> (g * f).flatten("h")
420
+ NamedMorphism(Object("A"), Object("C"), "h")
421
+
422
+ """
423
+ return NamedMorphism(self.domain, self.codomain, new_name)
424
+
425
+
426
+ class Category(Basic):
427
+ r"""
428
+ An (abstract) category.
429
+
430
+ Explanation
431
+ ===========
432
+
433
+ A category [JoyOfCats] is a quadruple `\mbox{K} = (O, \hom, id,
434
+ \circ)` consisting of
435
+
436
+ * a (set-theoretical) class `O`, whose members are called
437
+ `K`-objects,
438
+
439
+ * for each pair `(A, B)` of `K`-objects, a set `\hom(A, B)` whose
440
+ members are called `K`-morphisms from `A` to `B`,
441
+
442
+ * for a each `K`-object `A`, a morphism `id:A\rightarrow A`,
443
+ called the `K`-identity of `A`,
444
+
445
+ * a composition law `\circ` associating with every `K`-morphisms
446
+ `f:A\rightarrow B` and `g:B\rightarrow C` a `K`-morphism `g\circ
447
+ f:A\rightarrow C`, called the composite of `f` and `g`.
448
+
449
+ Composition is associative, `K`-identities are identities with
450
+ respect to composition, and the sets `\hom(A, B)` are pairwise
451
+ disjoint.
452
+
453
+ This class knows nothing about its objects and morphisms.
454
+ Concrete cases of (abstract) categories should be implemented as
455
+ classes derived from this one.
456
+
457
+ Certain instances of :class:`Diagram` can be asserted to be
458
+ commutative in a :class:`Category` by supplying the argument
459
+ ``commutative_diagrams`` in the constructor.
460
+
461
+ Examples
462
+ ========
463
+
464
+ >>> from sympy.categories import Object, NamedMorphism, Diagram, Category
465
+ >>> from sympy import FiniteSet
466
+ >>> A = Object("A")
467
+ >>> B = Object("B")
468
+ >>> C = Object("C")
469
+ >>> f = NamedMorphism(A, B, "f")
470
+ >>> g = NamedMorphism(B, C, "g")
471
+ >>> d = Diagram([f, g])
472
+ >>> K = Category("K", commutative_diagrams=[d])
473
+ >>> K.commutative_diagrams == FiniteSet(d)
474
+ True
475
+
476
+ See Also
477
+ ========
478
+
479
+ Diagram
480
+ """
481
+ def __new__(cls, name, objects=EmptySet, commutative_diagrams=EmptySet):
482
+ if not name:
483
+ raise ValueError("A Category cannot have an empty name.")
484
+
485
+ if not isinstance(name, Str):
486
+ name = Str(name)
487
+
488
+ if not isinstance(objects, Class):
489
+ objects = Class(objects)
490
+
491
+ new_category = Basic.__new__(cls, name, objects,
492
+ FiniteSet(*commutative_diagrams))
493
+ return new_category
494
+
495
+ @property
496
+ def name(self):
497
+ """
498
+ Returns the name of this category.
499
+
500
+ Examples
501
+ ========
502
+
503
+ >>> from sympy.categories import Category
504
+ >>> K = Category("K")
505
+ >>> K.name
506
+ 'K'
507
+
508
+ """
509
+ return self.args[0].name
510
+
511
+ @property
512
+ def objects(self):
513
+ """
514
+ Returns the class of objects of this category.
515
+
516
+ Examples
517
+ ========
518
+
519
+ >>> from sympy.categories import Object, Category
520
+ >>> from sympy import FiniteSet
521
+ >>> A = Object("A")
522
+ >>> B = Object("B")
523
+ >>> K = Category("K", FiniteSet(A, B))
524
+ >>> K.objects
525
+ Class({Object("A"), Object("B")})
526
+
527
+ """
528
+ return self.args[1]
529
+
530
+ @property
531
+ def commutative_diagrams(self):
532
+ """
533
+ Returns the :class:`~.FiniteSet` of diagrams which are known to
534
+ be commutative in this category.
535
+
536
+ Examples
537
+ ========
538
+
539
+ >>> from sympy.categories import Object, NamedMorphism, Diagram, Category
540
+ >>> from sympy import FiniteSet
541
+ >>> A = Object("A")
542
+ >>> B = Object("B")
543
+ >>> C = Object("C")
544
+ >>> f = NamedMorphism(A, B, "f")
545
+ >>> g = NamedMorphism(B, C, "g")
546
+ >>> d = Diagram([f, g])
547
+ >>> K = Category("K", commutative_diagrams=[d])
548
+ >>> K.commutative_diagrams == FiniteSet(d)
549
+ True
550
+
551
+ """
552
+ return self.args[2]
553
+
554
+ def hom(self, A, B):
555
+ raise NotImplementedError(
556
+ "hom-sets are not implemented in Category.")
557
+
558
+ def all_morphisms(self):
559
+ raise NotImplementedError(
560
+ "Obtaining the class of morphisms is not implemented in Category.")
561
+
562
+
563
+ class Diagram(Basic):
564
+ r"""
565
+ Represents a diagram in a certain category.
566
+
567
+ Explanation
568
+ ===========
569
+
570
+ Informally, a diagram is a collection of objects of a category and
571
+ certain morphisms between them. A diagram is still a monoid with
572
+ respect to morphism composition; i.e., identity morphisms, as well
573
+ as all composites of morphisms included in the diagram belong to
574
+ the diagram. For a more formal approach to this notion see
575
+ [Pare1970].
576
+
577
+ The components of composite morphisms are also added to the
578
+ diagram. No properties are assigned to such morphisms by default.
579
+
580
+ A commutative diagram is often accompanied by a statement of the
581
+ following kind: "if such morphisms with such properties exist,
582
+ then such morphisms which such properties exist and the diagram is
583
+ commutative". To represent this, an instance of :class:`Diagram`
584
+ includes a collection of morphisms which are the premises and
585
+ another collection of conclusions. ``premises`` and
586
+ ``conclusions`` associate morphisms belonging to the corresponding
587
+ categories with the :class:`~.FiniteSet`'s of their properties.
588
+
589
+ The set of properties of a composite morphism is the intersection
590
+ of the sets of properties of its components. The domain and
591
+ codomain of a conclusion morphism should be among the domains and
592
+ codomains of the morphisms listed as the premises of a diagram.
593
+
594
+ No checks are carried out of whether the supplied object and
595
+ morphisms do belong to one and the same category.
596
+
597
+ Examples
598
+ ========
599
+
600
+ >>> from sympy.categories import Object, NamedMorphism, Diagram
601
+ >>> from sympy import pprint, default_sort_key
602
+ >>> A = Object("A")
603
+ >>> B = Object("B")
604
+ >>> C = Object("C")
605
+ >>> f = NamedMorphism(A, B, "f")
606
+ >>> g = NamedMorphism(B, C, "g")
607
+ >>> d = Diagram([f, g])
608
+ >>> premises_keys = sorted(d.premises.keys(), key=default_sort_key)
609
+ >>> pprint(premises_keys, use_unicode=False)
610
+ [g*f:A-->C, id:A-->A, id:B-->B, id:C-->C, f:A-->B, g:B-->C]
611
+ >>> pprint(d.premises, use_unicode=False)
612
+ {g*f:A-->C: EmptySet, id:A-->A: EmptySet, id:B-->B: EmptySet,
613
+ id:C-->C: EmptySet, f:A-->B: EmptySet, g:B-->C: EmptySet}
614
+ >>> d = Diagram([f, g], {g * f: "unique"})
615
+ >>> pprint(d.conclusions,use_unicode=False)
616
+ {g*f:A-->C: {unique}}
617
+
618
+ References
619
+ ==========
620
+
621
+ [Pare1970] B. Pareigis: Categories and functors. Academic Press, 1970.
622
+
623
+ """
624
+ @staticmethod
625
+ def _set_dict_union(dictionary, key, value):
626
+ """
627
+ If ``key`` is in ``dictionary``, set the new value of ``key``
628
+ to be the union between the old value and ``value``.
629
+ Otherwise, set the value of ``key`` to ``value.
630
+
631
+ Returns ``True`` if the key already was in the dictionary and
632
+ ``False`` otherwise.
633
+ """
634
+ if key in dictionary:
635
+ dictionary[key] = dictionary[key] | value
636
+ return True
637
+ else:
638
+ dictionary[key] = value
639
+ return False
640
+
641
+ @staticmethod
642
+ def _add_morphism_closure(morphisms, morphism, props, add_identities=True,
643
+ recurse_composites=True):
644
+ """
645
+ Adds a morphism and its attributes to the supplied dictionary
646
+ ``morphisms``. If ``add_identities`` is True, also adds the
647
+ identity morphisms for the domain and the codomain of
648
+ ``morphism``.
649
+ """
650
+ if not Diagram._set_dict_union(morphisms, morphism, props):
651
+ # We have just added a new morphism.
652
+
653
+ if isinstance(morphism, IdentityMorphism):
654
+ if props:
655
+ # Properties for identity morphisms don't really
656
+ # make sense, because very much is known about
657
+ # identity morphisms already, so much that they
658
+ # are trivial. Having properties for identity
659
+ # morphisms would only be confusing.
660
+ raise ValueError(
661
+ "Instances of IdentityMorphism cannot have properties.")
662
+ return
663
+
664
+ if add_identities:
665
+ empty = EmptySet
666
+
667
+ id_dom = IdentityMorphism(morphism.domain)
668
+ id_cod = IdentityMorphism(morphism.codomain)
669
+
670
+ Diagram._set_dict_union(morphisms, id_dom, empty)
671
+ Diagram._set_dict_union(morphisms, id_cod, empty)
672
+
673
+ for existing_morphism, existing_props in list(morphisms.items()):
674
+ new_props = existing_props & props
675
+ if morphism.domain == existing_morphism.codomain:
676
+ left = morphism * existing_morphism
677
+ Diagram._set_dict_union(morphisms, left, new_props)
678
+ if morphism.codomain == existing_morphism.domain:
679
+ right = existing_morphism * morphism
680
+ Diagram._set_dict_union(morphisms, right, new_props)
681
+
682
+ if isinstance(morphism, CompositeMorphism) and recurse_composites:
683
+ # This is a composite morphism, add its components as
684
+ # well.
685
+ empty = EmptySet
686
+ for component in morphism.components:
687
+ Diagram._add_morphism_closure(morphisms, component, empty,
688
+ add_identities)
689
+
690
+ def __new__(cls, *args):
691
+ """
692
+ Construct a new instance of Diagram.
693
+
694
+ Explanation
695
+ ===========
696
+
697
+ If no arguments are supplied, an empty diagram is created.
698
+
699
+ If at least an argument is supplied, ``args[0]`` is
700
+ interpreted as the premises of the diagram. If ``args[0]`` is
701
+ a list, it is interpreted as a list of :class:`Morphism`'s, in
702
+ which each :class:`Morphism` has an empty set of properties.
703
+ If ``args[0]`` is a Python dictionary or a :class:`Dict`, it
704
+ is interpreted as a dictionary associating to some
705
+ :class:`Morphism`'s some properties.
706
+
707
+ If at least two arguments are supplied ``args[1]`` is
708
+ interpreted as the conclusions of the diagram. The type of
709
+ ``args[1]`` is interpreted in exactly the same way as the type
710
+ of ``args[0]``. If only one argument is supplied, the diagram
711
+ has no conclusions.
712
+
713
+ Examples
714
+ ========
715
+
716
+ >>> from sympy.categories import Object, NamedMorphism
717
+ >>> from sympy.categories import IdentityMorphism, Diagram
718
+ >>> A = Object("A")
719
+ >>> B = Object("B")
720
+ >>> C = Object("C")
721
+ >>> f = NamedMorphism(A, B, "f")
722
+ >>> g = NamedMorphism(B, C, "g")
723
+ >>> d = Diagram([f, g])
724
+ >>> IdentityMorphism(A) in d.premises.keys()
725
+ True
726
+ >>> g * f in d.premises.keys()
727
+ True
728
+ >>> d = Diagram([f, g], {g * f: "unique"})
729
+ >>> d.conclusions[g * f]
730
+ {unique}
731
+
732
+ """
733
+ premises = {}
734
+ conclusions = {}
735
+
736
+ # Here we will keep track of the objects which appear in the
737
+ # premises.
738
+ objects = EmptySet
739
+
740
+ if len(args) >= 1:
741
+ # We've got some premises in the arguments.
742
+ premises_arg = args[0]
743
+
744
+ if isinstance(premises_arg, list):
745
+ # The user has supplied a list of morphisms, none of
746
+ # which have any attributes.
747
+ empty = EmptySet
748
+
749
+ for morphism in premises_arg:
750
+ objects |= FiniteSet(morphism.domain, morphism.codomain)
751
+ Diagram._add_morphism_closure(premises, morphism, empty)
752
+ elif isinstance(premises_arg, (dict, Dict)):
753
+ # The user has supplied a dictionary of morphisms and
754
+ # their properties.
755
+ for morphism, props in premises_arg.items():
756
+ objects |= FiniteSet(morphism.domain, morphism.codomain)
757
+ Diagram._add_morphism_closure(
758
+ premises, morphism, FiniteSet(*props) if iterable(props) else FiniteSet(props))
759
+
760
+ if len(args) >= 2:
761
+ # We also have some conclusions.
762
+ conclusions_arg = args[1]
763
+
764
+ if isinstance(conclusions_arg, list):
765
+ # The user has supplied a list of morphisms, none of
766
+ # which have any attributes.
767
+ empty = EmptySet
768
+
769
+ for morphism in conclusions_arg:
770
+ # Check that no new objects appear in conclusions.
771
+ if ((sympify(objects.contains(morphism.domain)) is S.true) and
772
+ (sympify(objects.contains(morphism.codomain)) is S.true)):
773
+ # No need to add identities and recurse
774
+ # composites this time.
775
+ Diagram._add_morphism_closure(
776
+ conclusions, morphism, empty, add_identities=False,
777
+ recurse_composites=False)
778
+ elif isinstance(conclusions_arg, (dict, Dict)):
779
+ # The user has supplied a dictionary of morphisms and
780
+ # their properties.
781
+ for morphism, props in conclusions_arg.items():
782
+ # Check that no new objects appear in conclusions.
783
+ if (morphism.domain in objects) and \
784
+ (morphism.codomain in objects):
785
+ # No need to add identities and recurse
786
+ # composites this time.
787
+ Diagram._add_morphism_closure(
788
+ conclusions, morphism, FiniteSet(*props) if iterable(props) else FiniteSet(props),
789
+ add_identities=False, recurse_composites=False)
790
+
791
+ return Basic.__new__(cls, Dict(premises), Dict(conclusions), objects)
792
+
793
+ @property
794
+ def premises(self):
795
+ """
796
+ Returns the premises of this diagram.
797
+
798
+ Examples
799
+ ========
800
+
801
+ >>> from sympy.categories import Object, NamedMorphism
802
+ >>> from sympy.categories import IdentityMorphism, Diagram
803
+ >>> from sympy import pretty
804
+ >>> A = Object("A")
805
+ >>> B = Object("B")
806
+ >>> f = NamedMorphism(A, B, "f")
807
+ >>> id_A = IdentityMorphism(A)
808
+ >>> id_B = IdentityMorphism(B)
809
+ >>> d = Diagram([f])
810
+ >>> print(pretty(d.premises, use_unicode=False))
811
+ {id:A-->A: EmptySet, id:B-->B: EmptySet, f:A-->B: EmptySet}
812
+
813
+ """
814
+ return self.args[0]
815
+
816
+ @property
817
+ def conclusions(self):
818
+ """
819
+ Returns the conclusions of this diagram.
820
+
821
+ Examples
822
+ ========
823
+
824
+ >>> from sympy.categories import Object, NamedMorphism
825
+ >>> from sympy.categories import IdentityMorphism, Diagram
826
+ >>> from sympy import FiniteSet
827
+ >>> A = Object("A")
828
+ >>> B = Object("B")
829
+ >>> C = Object("C")
830
+ >>> f = NamedMorphism(A, B, "f")
831
+ >>> g = NamedMorphism(B, C, "g")
832
+ >>> d = Diagram([f, g])
833
+ >>> IdentityMorphism(A) in d.premises.keys()
834
+ True
835
+ >>> g * f in d.premises.keys()
836
+ True
837
+ >>> d = Diagram([f, g], {g * f: "unique"})
838
+ >>> d.conclusions[g * f] == FiniteSet("unique")
839
+ True
840
+
841
+ """
842
+ return self.args[1]
843
+
844
+ @property
845
+ def objects(self):
846
+ """
847
+ Returns the :class:`~.FiniteSet` of objects that appear in this
848
+ diagram.
849
+
850
+ Examples
851
+ ========
852
+
853
+ >>> from sympy.categories import Object, NamedMorphism, Diagram
854
+ >>> A = Object("A")
855
+ >>> B = Object("B")
856
+ >>> C = Object("C")
857
+ >>> f = NamedMorphism(A, B, "f")
858
+ >>> g = NamedMorphism(B, C, "g")
859
+ >>> d = Diagram([f, g])
860
+ >>> d.objects
861
+ {Object("A"), Object("B"), Object("C")}
862
+
863
+ """
864
+ return self.args[2]
865
+
866
+ def hom(self, A, B):
867
+ """
868
+ Returns a 2-tuple of sets of morphisms between objects ``A`` and
869
+ ``B``: one set of morphisms listed as premises, and the other set
870
+ of morphisms listed as conclusions.
871
+
872
+ Examples
873
+ ========
874
+
875
+ >>> from sympy.categories import Object, NamedMorphism, Diagram
876
+ >>> from sympy import pretty
877
+ >>> A = Object("A")
878
+ >>> B = Object("B")
879
+ >>> C = Object("C")
880
+ >>> f = NamedMorphism(A, B, "f")
881
+ >>> g = NamedMorphism(B, C, "g")
882
+ >>> d = Diagram([f, g], {g * f: "unique"})
883
+ >>> print(pretty(d.hom(A, C), use_unicode=False))
884
+ ({g*f:A-->C}, {g*f:A-->C})
885
+
886
+ See Also
887
+ ========
888
+ Object, Morphism
889
+ """
890
+ premises = EmptySet
891
+ conclusions = EmptySet
892
+
893
+ for morphism in self.premises.keys():
894
+ if (morphism.domain == A) and (morphism.codomain == B):
895
+ premises |= FiniteSet(morphism)
896
+ for morphism in self.conclusions.keys():
897
+ if (morphism.domain == A) and (morphism.codomain == B):
898
+ conclusions |= FiniteSet(morphism)
899
+
900
+ return (premises, conclusions)
901
+
902
+ def is_subdiagram(self, diagram):
903
+ """
904
+ Checks whether ``diagram`` is a subdiagram of ``self``.
905
+ Diagram `D'` is a subdiagram of `D` if all premises
906
+ (conclusions) of `D'` are contained in the premises
907
+ (conclusions) of `D`. The morphisms contained
908
+ both in `D'` and `D` should have the same properties for `D'`
909
+ to be a subdiagram of `D`.
910
+
911
+ Examples
912
+ ========
913
+
914
+ >>> from sympy.categories import Object, NamedMorphism, Diagram
915
+ >>> A = Object("A")
916
+ >>> B = Object("B")
917
+ >>> C = Object("C")
918
+ >>> f = NamedMorphism(A, B, "f")
919
+ >>> g = NamedMorphism(B, C, "g")
920
+ >>> d = Diagram([f, g], {g * f: "unique"})
921
+ >>> d1 = Diagram([f])
922
+ >>> d.is_subdiagram(d1)
923
+ True
924
+ >>> d1.is_subdiagram(d)
925
+ False
926
+ """
927
+ premises = all((m in self.premises) and
928
+ (diagram.premises[m] == self.premises[m])
929
+ for m in diagram.premises)
930
+ if not premises:
931
+ return False
932
+
933
+ conclusions = all((m in self.conclusions) and
934
+ (diagram.conclusions[m] == self.conclusions[m])
935
+ for m in diagram.conclusions)
936
+
937
+ # Premises is surely ``True`` here.
938
+ return conclusions
939
+
940
+ def subdiagram_from_objects(self, objects):
941
+ """
942
+ If ``objects`` is a subset of the objects of ``self``, returns
943
+ a diagram which has as premises all those premises of ``self``
944
+ which have a domains and codomains in ``objects``, likewise
945
+ for conclusions. Properties are preserved.
946
+
947
+ Examples
948
+ ========
949
+
950
+ >>> from sympy.categories import Object, NamedMorphism, Diagram
951
+ >>> from sympy import FiniteSet
952
+ >>> A = Object("A")
953
+ >>> B = Object("B")
954
+ >>> C = Object("C")
955
+ >>> f = NamedMorphism(A, B, "f")
956
+ >>> g = NamedMorphism(B, C, "g")
957
+ >>> d = Diagram([f, g], {f: "unique", g*f: "veryunique"})
958
+ >>> d1 = d.subdiagram_from_objects(FiniteSet(A, B))
959
+ >>> d1 == Diagram([f], {f: "unique"})
960
+ True
961
+ """
962
+ if not objects.is_subset(self.objects):
963
+ raise ValueError(
964
+ "Supplied objects should all belong to the diagram.")
965
+
966
+ new_premises = {}
967
+ for morphism, props in self.premises.items():
968
+ if ((sympify(objects.contains(morphism.domain)) is S.true) and
969
+ (sympify(objects.contains(morphism.codomain)) is S.true)):
970
+ new_premises[morphism] = props
971
+
972
+ new_conclusions = {}
973
+ for morphism, props in self.conclusions.items():
974
+ if ((sympify(objects.contains(morphism.domain)) is S.true) and
975
+ (sympify(objects.contains(morphism.codomain)) is S.true)):
976
+ new_conclusions[morphism] = props
977
+
978
+ return Diagram(new_premises, new_conclusions)
.venv/lib/python3.13/site-packages/sympy/categories/diagram_drawing.py ADDED
@@ -0,0 +1,2580 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ r"""
2
+ This module contains the functionality to arrange the nodes of a
3
+ diagram on an abstract grid, and then to produce a graphical
4
+ representation of the grid.
5
+
6
+ The currently supported back-ends are Xy-pic [Xypic].
7
+
8
+ Layout Algorithm
9
+ ================
10
+
11
+ This section provides an overview of the algorithms implemented in
12
+ :class:`DiagramGrid` to lay out diagrams.
13
+
14
+ The first step of the algorithm is the removal composite and identity
15
+ morphisms which do not have properties in the supplied diagram. The
16
+ premises and conclusions of the diagram are then merged.
17
+
18
+ The generic layout algorithm begins with the construction of the
19
+ "skeleton" of the diagram. The skeleton is an undirected graph which
20
+ has the objects of the diagram as vertices and has an (undirected)
21
+ edge between each pair of objects between which there exist morphisms.
22
+ The direction of the morphisms does not matter at this stage. The
23
+ skeleton also includes an edge between each pair of vertices `A` and
24
+ `C` such that there exists an object `B` which is connected via
25
+ a morphism to `A`, and via a morphism to `C`.
26
+
27
+ The skeleton constructed in this way has the property that every
28
+ object is a vertex of a triangle formed by three edges of the
29
+ skeleton. This property lies at the base of the generic layout
30
+ algorithm.
31
+
32
+ After the skeleton has been constructed, the algorithm lists all
33
+ triangles which can be formed. Note that some triangles will not have
34
+ all edges corresponding to morphisms which will actually be drawn.
35
+ Triangles which have only one edge or less which will actually be
36
+ drawn are immediately discarded.
37
+
38
+ The list of triangles is sorted according to the number of edges which
39
+ correspond to morphisms, then the triangle with the least number of such
40
+ edges is selected. One of such edges is picked and the corresponding
41
+ objects are placed horizontally, on a grid. This edge is recorded to
42
+ be in the fringe. The algorithm then finds a "welding" of a triangle
43
+ to the fringe. A welding is an edge in the fringe where a triangle
44
+ could be attached. If the algorithm succeeds in finding such a
45
+ welding, it adds to the grid that vertex of the triangle which was not
46
+ yet included in any edge in the fringe and records the two new edges in
47
+ the fringe. This process continues iteratively until all objects of
48
+ the diagram has been placed or until no more weldings can be found.
49
+
50
+ An edge is only removed from the fringe when a welding to this edge
51
+ has been found, and there is no room around this edge to place
52
+ another vertex.
53
+
54
+ When no more weldings can be found, but there are still triangles
55
+ left, the algorithm searches for a possibility of attaching one of the
56
+ remaining triangles to the existing structure by a vertex. If such a
57
+ possibility is found, the corresponding edge of the found triangle is
58
+ placed in the found space and the iterative process of welding
59
+ triangles restarts.
60
+
61
+ When logical groups are supplied, each of these groups is laid out
62
+ independently. Then a diagram is constructed in which groups are
63
+ objects and any two logical groups between which there exist morphisms
64
+ are connected via a morphism. This diagram is laid out. Finally,
65
+ the grid which includes all objects of the initial diagram is
66
+ constructed by replacing the cells which contain logical groups with
67
+ the corresponding laid out grids, and by correspondingly expanding the
68
+ rows and columns.
69
+
70
+ The sequential layout algorithm begins by constructing the
71
+ underlying undirected graph defined by the morphisms obtained after
72
+ simplifying premises and conclusions and merging them (see above).
73
+ The vertex with the minimal degree is then picked up and depth-first
74
+ search is started from it. All objects which are located at distance
75
+ `n` from the root in the depth-first search tree, are positioned in
76
+ the `n`-th column of the resulting grid. The sequential layout will
77
+ therefore attempt to lay the objects out along a line.
78
+
79
+ References
80
+ ==========
81
+
82
+ .. [Xypic] https://xy-pic.sourceforge.net/
83
+
84
+ """
85
+ from sympy.categories import (CompositeMorphism, IdentityMorphism,
86
+ NamedMorphism, Diagram)
87
+ from sympy.core import Dict, Symbol, default_sort_key
88
+ from sympy.printing.latex import latex
89
+ from sympy.sets import FiniteSet
90
+ from sympy.utilities.iterables import iterable
91
+ from sympy.utilities.decorator import doctest_depends_on
92
+
93
+ from itertools import chain
94
+
95
+
96
+ __doctest_requires__ = {('preview_diagram',): 'pyglet'}
97
+
98
+
99
+ class _GrowableGrid:
100
+ """
101
+ Holds a growable grid of objects.
102
+
103
+ Explanation
104
+ ===========
105
+
106
+ It is possible to append or prepend a row or a column to the grid
107
+ using the corresponding methods. Prepending rows or columns has
108
+ the effect of changing the coordinates of the already existing
109
+ elements.
110
+
111
+ This class currently represents a naive implementation of the
112
+ functionality with little attempt at optimisation.
113
+ """
114
+ def __init__(self, width, height):
115
+ self._width = width
116
+ self._height = height
117
+
118
+ self._array = [[None for j in range(width)] for i in range(height)]
119
+
120
+ @property
121
+ def width(self):
122
+ return self._width
123
+
124
+ @property
125
+ def height(self):
126
+ return self._height
127
+
128
+ def __getitem__(self, i_j):
129
+ """
130
+ Returns the element located at in the i-th line and j-th
131
+ column.
132
+ """
133
+ i, j = i_j
134
+ return self._array[i][j]
135
+
136
+ def __setitem__(self, i_j, newvalue):
137
+ """
138
+ Sets the element located at in the i-th line and j-th
139
+ column.
140
+ """
141
+ i, j = i_j
142
+ self._array[i][j] = newvalue
143
+
144
+ def append_row(self):
145
+ """
146
+ Appends an empty row to the grid.
147
+ """
148
+ self._height += 1
149
+ self._array.append([None for j in range(self._width)])
150
+
151
+ def append_column(self):
152
+ """
153
+ Appends an empty column to the grid.
154
+ """
155
+ self._width += 1
156
+ for i in range(self._height):
157
+ self._array[i].append(None)
158
+
159
+ def prepend_row(self):
160
+ """
161
+ Prepends the grid with an empty row.
162
+ """
163
+ self._height += 1
164
+ self._array.insert(0, [None for j in range(self._width)])
165
+
166
+ def prepend_column(self):
167
+ """
168
+ Prepends the grid with an empty column.
169
+ """
170
+ self._width += 1
171
+ for i in range(self._height):
172
+ self._array[i].insert(0, None)
173
+
174
+
175
+ class DiagramGrid:
176
+ r"""
177
+ Constructs and holds the fitting of the diagram into a grid.
178
+
179
+ Explanation
180
+ ===========
181
+
182
+ The mission of this class is to analyse the structure of the
183
+ supplied diagram and to place its objects on a grid such that,
184
+ when the objects and the morphisms are actually drawn, the diagram
185
+ would be "readable", in the sense that there will not be many
186
+ intersections of moprhisms. This class does not perform any
187
+ actual drawing. It does strive nevertheless to offer sufficient
188
+ metadata to draw a diagram.
189
+
190
+ Consider the following simple diagram.
191
+
192
+ >>> from sympy.categories import Object, NamedMorphism
193
+ >>> from sympy.categories import Diagram, DiagramGrid
194
+ >>> from sympy import pprint
195
+ >>> A = Object("A")
196
+ >>> B = Object("B")
197
+ >>> C = Object("C")
198
+ >>> f = NamedMorphism(A, B, "f")
199
+ >>> g = NamedMorphism(B, C, "g")
200
+ >>> diagram = Diagram([f, g])
201
+
202
+ The simplest way to have a diagram laid out is the following:
203
+
204
+ >>> grid = DiagramGrid(diagram)
205
+ >>> (grid.width, grid.height)
206
+ (2, 2)
207
+ >>> pprint(grid)
208
+ A B
209
+ <BLANKLINE>
210
+ C
211
+
212
+ Sometimes one sees the diagram as consisting of logical groups.
213
+ One can advise ``DiagramGrid`` as to such groups by employing the
214
+ ``groups`` keyword argument.
215
+
216
+ Consider the following diagram:
217
+
218
+ >>> D = Object("D")
219
+ >>> f = NamedMorphism(A, B, "f")
220
+ >>> g = NamedMorphism(B, C, "g")
221
+ >>> h = NamedMorphism(D, A, "h")
222
+ >>> k = NamedMorphism(D, B, "k")
223
+ >>> diagram = Diagram([f, g, h, k])
224
+
225
+ Lay it out with generic layout:
226
+
227
+ >>> grid = DiagramGrid(diagram)
228
+ >>> pprint(grid)
229
+ A B D
230
+ <BLANKLINE>
231
+ C
232
+
233
+ Now, we can group the objects `A` and `D` to have them near one
234
+ another:
235
+
236
+ >>> grid = DiagramGrid(diagram, groups=[[A, D], B, C])
237
+ >>> pprint(grid)
238
+ B C
239
+ <BLANKLINE>
240
+ A D
241
+
242
+ Note how the positioning of the other objects changes.
243
+
244
+ Further indications can be supplied to the constructor of
245
+ :class:`DiagramGrid` using keyword arguments. The currently
246
+ supported hints are explained in the following paragraphs.
247
+
248
+ :class:`DiagramGrid` does not automatically guess which layout
249
+ would suit the supplied diagram better. Consider, for example,
250
+ the following linear diagram:
251
+
252
+ >>> E = Object("E")
253
+ >>> f = NamedMorphism(A, B, "f")
254
+ >>> g = NamedMorphism(B, C, "g")
255
+ >>> h = NamedMorphism(C, D, "h")
256
+ >>> i = NamedMorphism(D, E, "i")
257
+ >>> diagram = Diagram([f, g, h, i])
258
+
259
+ When laid out with the generic layout, it does not get to look
260
+ linear:
261
+
262
+ >>> grid = DiagramGrid(diagram)
263
+ >>> pprint(grid)
264
+ A B
265
+ <BLANKLINE>
266
+ C D
267
+ <BLANKLINE>
268
+ E
269
+
270
+ To get it laid out in a line, use ``layout="sequential"``:
271
+
272
+ >>> grid = DiagramGrid(diagram, layout="sequential")
273
+ >>> pprint(grid)
274
+ A B C D E
275
+
276
+ One may sometimes need to transpose the resulting layout. While
277
+ this can always be done by hand, :class:`DiagramGrid` provides a
278
+ hint for that purpose:
279
+
280
+ >>> grid = DiagramGrid(diagram, layout="sequential", transpose=True)
281
+ >>> pprint(grid)
282
+ A
283
+ <BLANKLINE>
284
+ B
285
+ <BLANKLINE>
286
+ C
287
+ <BLANKLINE>
288
+ D
289
+ <BLANKLINE>
290
+ E
291
+
292
+ Separate hints can also be provided for each group. For an
293
+ example, refer to ``tests/test_drawing.py``, and see the different
294
+ ways in which the five lemma [FiveLemma] can be laid out.
295
+
296
+ See Also
297
+ ========
298
+
299
+ Diagram
300
+
301
+ References
302
+ ==========
303
+
304
+ .. [FiveLemma] https://en.wikipedia.org/wiki/Five_lemma
305
+ """
306
+ @staticmethod
307
+ def _simplify_morphisms(morphisms):
308
+ """
309
+ Given a dictionary mapping morphisms to their properties,
310
+ returns a new dictionary in which there are no morphisms which
311
+ do not have properties, and which are compositions of other
312
+ morphisms included in the dictionary. Identities are dropped
313
+ as well.
314
+ """
315
+ newmorphisms = {}
316
+ for morphism, props in morphisms.items():
317
+ if isinstance(morphism, CompositeMorphism) and not props:
318
+ continue
319
+ elif isinstance(morphism, IdentityMorphism):
320
+ continue
321
+ else:
322
+ newmorphisms[morphism] = props
323
+ return newmorphisms
324
+
325
+ @staticmethod
326
+ def _merge_premises_conclusions(premises, conclusions):
327
+ """
328
+ Given two dictionaries of morphisms and their properties,
329
+ produces a single dictionary which includes elements from both
330
+ dictionaries. If a morphism has some properties in premises
331
+ and also in conclusions, the properties in conclusions take
332
+ priority.
333
+ """
334
+ return dict(chain(premises.items(), conclusions.items()))
335
+
336
+ @staticmethod
337
+ def _juxtapose_edges(edge1, edge2):
338
+ """
339
+ If ``edge1`` and ``edge2`` have precisely one common endpoint,
340
+ returns an edge which would form a triangle with ``edge1`` and
341
+ ``edge2``.
342
+
343
+ If ``edge1`` and ``edge2`` do not have a common endpoint,
344
+ returns ``None``.
345
+
346
+ If ``edge1`` and ``edge`` are the same edge, returns ``None``.
347
+ """
348
+ intersection = edge1 & edge2
349
+ if len(intersection) != 1:
350
+ # The edges either have no common points or are equal.
351
+ return None
352
+
353
+ # The edges have a common endpoint. Extract the different
354
+ # endpoints and set up the new edge.
355
+ return (edge1 - intersection) | (edge2 - intersection)
356
+
357
+ @staticmethod
358
+ def _add_edge_append(dictionary, edge, elem):
359
+ """
360
+ If ``edge`` is not in ``dictionary``, adds ``edge`` to the
361
+ dictionary and sets its value to ``[elem]``. Otherwise
362
+ appends ``elem`` to the value of existing entry.
363
+
364
+ Note that edges are undirected, thus `(A, B) = (B, A)`.
365
+ """
366
+ if edge in dictionary:
367
+ dictionary[edge].append(elem)
368
+ else:
369
+ dictionary[edge] = [elem]
370
+
371
+ @staticmethod
372
+ def _build_skeleton(morphisms):
373
+ """
374
+ Creates a dictionary which maps edges to corresponding
375
+ morphisms. Thus for a morphism `f:A\rightarrow B`, the edge
376
+ `(A, B)` will be associated with `f`. This function also adds
377
+ to the list those edges which are formed by juxtaposition of
378
+ two edges already in the list. These new edges are not
379
+ associated with any morphism and are only added to assure that
380
+ the diagram can be decomposed into triangles.
381
+ """
382
+ edges = {}
383
+ # Create edges for morphisms.
384
+ for morphism in morphisms:
385
+ DiagramGrid._add_edge_append(
386
+ edges, frozenset([morphism.domain, morphism.codomain]), morphism)
387
+
388
+ # Create new edges by juxtaposing existing edges.
389
+ edges1 = dict(edges)
390
+ for w in edges1:
391
+ for v in edges1:
392
+ wv = DiagramGrid._juxtapose_edges(w, v)
393
+ if wv and wv not in edges:
394
+ edges[wv] = []
395
+
396
+ return edges
397
+
398
+ @staticmethod
399
+ def _list_triangles(edges):
400
+ """
401
+ Builds the set of triangles formed by the supplied edges. The
402
+ triangles are arbitrary and need not be commutative. A
403
+ triangle is a set that contains all three of its sides.
404
+ """
405
+ triangles = set()
406
+
407
+ for w in edges:
408
+ for v in edges:
409
+ wv = DiagramGrid._juxtapose_edges(w, v)
410
+ if wv and wv in edges:
411
+ triangles.add(frozenset([w, v, wv]))
412
+
413
+ return triangles
414
+
415
+ @staticmethod
416
+ def _drop_redundant_triangles(triangles, skeleton):
417
+ """
418
+ Returns a list which contains only those triangles who have
419
+ morphisms associated with at least two edges.
420
+ """
421
+ return [tri for tri in triangles
422
+ if len([e for e in tri if skeleton[e]]) >= 2]
423
+
424
+ @staticmethod
425
+ def _morphism_length(morphism):
426
+ """
427
+ Returns the length of a morphism. The length of a morphism is
428
+ the number of components it consists of. A non-composite
429
+ morphism is of length 1.
430
+ """
431
+ if isinstance(morphism, CompositeMorphism):
432
+ return len(morphism.components)
433
+ else:
434
+ return 1
435
+
436
+ @staticmethod
437
+ def _compute_triangle_min_sizes(triangles, edges):
438
+ r"""
439
+ Returns a dictionary mapping triangles to their minimal sizes.
440
+ The minimal size of a triangle is the sum of maximal lengths
441
+ of morphisms associated to the sides of the triangle. The
442
+ length of a morphism is the number of components it consists
443
+ of. A non-composite morphism is of length 1.
444
+
445
+ Sorting triangles by this metric attempts to address two
446
+ aspects of layout. For triangles with only simple morphisms
447
+ in the edge, this assures that triangles with all three edges
448
+ visible will get typeset after triangles with less visible
449
+ edges, which sometimes minimizes the necessity in diagonal
450
+ arrows. For triangles with composite morphisms in the edges,
451
+ this assures that objects connected with shorter morphisms
452
+ will be laid out first, resulting the visual proximity of
453
+ those objects which are connected by shorter morphisms.
454
+ """
455
+ triangle_sizes = {}
456
+ for triangle in triangles:
457
+ size = 0
458
+ for e in triangle:
459
+ morphisms = edges[e]
460
+ if morphisms:
461
+ size += max(DiagramGrid._morphism_length(m)
462
+ for m in morphisms)
463
+ triangle_sizes[triangle] = size
464
+ return triangle_sizes
465
+
466
+ @staticmethod
467
+ def _triangle_objects(triangle):
468
+ """
469
+ Given a triangle, returns the objects included in it.
470
+ """
471
+ # A triangle is a frozenset of three two-element frozensets
472
+ # (the edges). This chains the three edges together and
473
+ # creates a frozenset from the iterator, thus producing a
474
+ # frozenset of objects of the triangle.
475
+ return frozenset(chain(*tuple(triangle)))
476
+
477
+ @staticmethod
478
+ def _other_vertex(triangle, edge):
479
+ """
480
+ Given a triangle and an edge of it, returns the vertex which
481
+ opposes the edge.
482
+ """
483
+ # This gets the set of objects of the triangle and then
484
+ # subtracts the set of objects employed in ``edge`` to get the
485
+ # vertex opposite to ``edge``.
486
+ return list(DiagramGrid._triangle_objects(triangle) - set(edge))[0]
487
+
488
+ @staticmethod
489
+ def _empty_point(pt, grid):
490
+ """
491
+ Checks if the cell at coordinates ``pt`` is either empty or
492
+ out of the bounds of the grid.
493
+ """
494
+ if (pt[0] < 0) or (pt[1] < 0) or \
495
+ (pt[0] >= grid.height) or (pt[1] >= grid.width):
496
+ return True
497
+ return grid[pt] is None
498
+
499
+ @staticmethod
500
+ def _put_object(coords, obj, grid, fringe):
501
+ """
502
+ Places an object at the coordinate ``cords`` in ``grid``,
503
+ growing the grid and updating ``fringe``, if necessary.
504
+ Returns (0, 0) if no row or column has been prepended, (1, 0)
505
+ if a row was prepended, (0, 1) if a column was prepended and
506
+ (1, 1) if both a column and a row were prepended.
507
+ """
508
+ (i, j) = coords
509
+ offset = (0, 0)
510
+ if i == -1:
511
+ grid.prepend_row()
512
+ i = 0
513
+ offset = (1, 0)
514
+ for k in range(len(fringe)):
515
+ ((i1, j1), (i2, j2)) = fringe[k]
516
+ fringe[k] = ((i1 + 1, j1), (i2 + 1, j2))
517
+ elif i == grid.height:
518
+ grid.append_row()
519
+
520
+ if j == -1:
521
+ j = 0
522
+ offset = (offset[0], 1)
523
+ grid.prepend_column()
524
+ for k in range(len(fringe)):
525
+ ((i1, j1), (i2, j2)) = fringe[k]
526
+ fringe[k] = ((i1, j1 + 1), (i2, j2 + 1))
527
+ elif j == grid.width:
528
+ grid.append_column()
529
+
530
+ grid[i, j] = obj
531
+ return offset
532
+
533
+ @staticmethod
534
+ def _choose_target_cell(pt1, pt2, edge, obj, skeleton, grid):
535
+ """
536
+ Given two points, ``pt1`` and ``pt2``, and the welding edge
537
+ ``edge``, chooses one of the two points to place the opposing
538
+ vertex ``obj`` of the triangle. If neither of this points
539
+ fits, returns ``None``.
540
+ """
541
+ pt1_empty = DiagramGrid._empty_point(pt1, grid)
542
+ pt2_empty = DiagramGrid._empty_point(pt2, grid)
543
+
544
+ if pt1_empty and pt2_empty:
545
+ # Both cells are empty. Of these two, choose that cell
546
+ # which will assure that a visible edge of the triangle
547
+ # will be drawn perpendicularly to the current welding
548
+ # edge.
549
+
550
+ A = grid[edge[0]]
551
+
552
+ if skeleton.get(frozenset([A, obj])):
553
+ return pt1
554
+ else:
555
+ return pt2
556
+ if pt1_empty:
557
+ return pt1
558
+ elif pt2_empty:
559
+ return pt2
560
+ else:
561
+ return None
562
+
563
+ @staticmethod
564
+ def _find_triangle_to_weld(triangles, fringe, grid):
565
+ """
566
+ Finds, if possible, a triangle and an edge in the ``fringe`` to
567
+ which the triangle could be attached. Returns the tuple
568
+ containing the triangle and the index of the corresponding
569
+ edge in the ``fringe``.
570
+
571
+ This function relies on the fact that objects are unique in
572
+ the diagram.
573
+ """
574
+ for triangle in triangles:
575
+ for (a, b) in fringe:
576
+ if frozenset([grid[a], grid[b]]) in triangle:
577
+ return (triangle, (a, b))
578
+ return None
579
+
580
+ @staticmethod
581
+ def _weld_triangle(tri, welding_edge, fringe, grid, skeleton):
582
+ """
583
+ If possible, welds the triangle ``tri`` to ``fringe`` and
584
+ returns ``False``. If this method encounters a degenerate
585
+ situation in the fringe and corrects it such that a restart of
586
+ the search is required, it returns ``True`` (which means that
587
+ a restart in finding triangle weldings is required).
588
+
589
+ A degenerate situation is a situation when an edge listed in
590
+ the fringe does not belong to the visual boundary of the
591
+ diagram.
592
+ """
593
+ a, b = welding_edge
594
+ target_cell = None
595
+
596
+ obj = DiagramGrid._other_vertex(tri, (grid[a], grid[b]))
597
+
598
+ # We now have a triangle and an edge where it can be welded to
599
+ # the fringe. Decide where to place the other vertex of the
600
+ # triangle and check for degenerate situations en route.
601
+
602
+ if (abs(a[0] - b[0]) == 1) and (abs(a[1] - b[1]) == 1):
603
+ # A diagonal edge.
604
+ target_cell = (a[0], b[1])
605
+ if grid[target_cell]:
606
+ # That cell is already occupied.
607
+ target_cell = (b[0], a[1])
608
+
609
+ if grid[target_cell]:
610
+ # Degenerate situation, this edge is not
611
+ # on the actual fringe. Correct the
612
+ # fringe and go on.
613
+ fringe.remove((a, b))
614
+ return True
615
+ elif a[0] == b[0]:
616
+ # A horizontal edge. We first attempt to build the
617
+ # triangle in the downward direction.
618
+
619
+ down_left = a[0] + 1, a[1]
620
+ down_right = a[0] + 1, b[1]
621
+
622
+ target_cell = DiagramGrid._choose_target_cell(
623
+ down_left, down_right, (a, b), obj, skeleton, grid)
624
+
625
+ if not target_cell:
626
+ # No room below this edge. Check above.
627
+ up_left = a[0] - 1, a[1]
628
+ up_right = a[0] - 1, b[1]
629
+
630
+ target_cell = DiagramGrid._choose_target_cell(
631
+ up_left, up_right, (a, b), obj, skeleton, grid)
632
+
633
+ if not target_cell:
634
+ # This edge is not in the fringe, remove it
635
+ # and restart.
636
+ fringe.remove((a, b))
637
+ return True
638
+ elif a[1] == b[1]:
639
+ # A vertical edge. We will attempt to place the other
640
+ # vertex of the triangle to the right of this edge.
641
+ right_up = a[0], a[1] + 1
642
+ right_down = b[0], a[1] + 1
643
+
644
+ target_cell = DiagramGrid._choose_target_cell(
645
+ right_up, right_down, (a, b), obj, skeleton, grid)
646
+
647
+ if not target_cell:
648
+ # No room to the left. See what's to the right.
649
+ left_up = a[0], a[1] - 1
650
+ left_down = b[0], a[1] - 1
651
+
652
+ target_cell = DiagramGrid._choose_target_cell(
653
+ left_up, left_down, (a, b), obj, skeleton, grid)
654
+
655
+ if not target_cell:
656
+ # This edge is not in the fringe, remove it
657
+ # and restart.
658
+ fringe.remove((a, b))
659
+ return True
660
+
661
+ # We now know where to place the other vertex of the
662
+ # triangle.
663
+ offset = DiagramGrid._put_object(target_cell, obj, grid, fringe)
664
+
665
+ # Take care of the displacement of coordinates if a row or
666
+ # a column was prepended.
667
+ target_cell = (target_cell[0] + offset[0],
668
+ target_cell[1] + offset[1])
669
+ a = (a[0] + offset[0], a[1] + offset[1])
670
+ b = (b[0] + offset[0], b[1] + offset[1])
671
+
672
+ fringe.extend([(a, target_cell), (b, target_cell)])
673
+
674
+ # No restart is required.
675
+ return False
676
+
677
+ @staticmethod
678
+ def _triangle_key(tri, triangle_sizes):
679
+ """
680
+ Returns a key for the supplied triangle. It should be the
681
+ same independently of the hash randomisation.
682
+ """
683
+ objects = sorted(
684
+ DiagramGrid._triangle_objects(tri), key=default_sort_key)
685
+ return (triangle_sizes[tri], default_sort_key(objects))
686
+
687
+ @staticmethod
688
+ def _pick_root_edge(tri, skeleton):
689
+ """
690
+ For a given triangle always picks the same root edge. The
691
+ root edge is the edge that will be placed first on the grid.
692
+ """
693
+ candidates = [sorted(e, key=default_sort_key)
694
+ for e in tri if skeleton[e]]
695
+ sorted_candidates = sorted(candidates, key=default_sort_key)
696
+ # Don't forget to assure the proper ordering of the vertices
697
+ # in this edge.
698
+ return tuple(sorted(sorted_candidates[0], key=default_sort_key))
699
+
700
+ @staticmethod
701
+ def _drop_irrelevant_triangles(triangles, placed_objects):
702
+ """
703
+ Returns only those triangles whose set of objects is not
704
+ completely included in ``placed_objects``.
705
+ """
706
+ return [tri for tri in triangles if not placed_objects.issuperset(
707
+ DiagramGrid._triangle_objects(tri))]
708
+
709
+ @staticmethod
710
+ def _grow_pseudopod(triangles, fringe, grid, skeleton, placed_objects):
711
+ """
712
+ Starting from an object in the existing structure on the ``grid``,
713
+ adds an edge to which a triangle from ``triangles`` could be
714
+ welded. If this method has found a way to do so, it returns
715
+ the object it has just added.
716
+
717
+ This method should be applied when ``_weld_triangle`` cannot
718
+ find weldings any more.
719
+ """
720
+ for i in range(grid.height):
721
+ for j in range(grid.width):
722
+ obj = grid[i, j]
723
+ if not obj:
724
+ continue
725
+
726
+ # Here we need to choose a triangle which has only
727
+ # ``obj`` in common with the existing structure. The
728
+ # situations when this is not possible should be
729
+ # handled elsewhere.
730
+
731
+ def good_triangle(tri):
732
+ objs = DiagramGrid._triangle_objects(tri)
733
+ return obj in objs and \
734
+ placed_objects & (objs - {obj}) == set()
735
+
736
+ tris = [tri for tri in triangles if good_triangle(tri)]
737
+ if not tris:
738
+ # This object is not interesting.
739
+ continue
740
+
741
+ # Pick the "simplest" of the triangles which could be
742
+ # attached. Remember that the list of triangles is
743
+ # sorted according to their "simplicity" (see
744
+ # _compute_triangle_min_sizes for the metric).
745
+ #
746
+ # Note that ``tris`` are sequentially built from
747
+ # ``triangles``, so we don't have to worry about hash
748
+ # randomisation.
749
+ tri = tris[0]
750
+
751
+ # We have found a triangle which could be attached to
752
+ # the existing structure by a vertex.
753
+
754
+ candidates = sorted([e for e in tri if skeleton[e]],
755
+ key=lambda e: FiniteSet(*e).sort_key())
756
+ edges = [e for e in candidates if obj in e]
757
+
758
+ # Note that a meaningful edge (i.e., and edge that is
759
+ # associated with a morphism) containing ``obj``
760
+ # always exists. That's because all triangles are
761
+ # guaranteed to have at least two meaningful edges.
762
+ # See _drop_redundant_triangles.
763
+
764
+ # Get the object at the other end of the edge.
765
+ edge = edges[0]
766
+ other_obj = tuple(edge - frozenset([obj]))[0]
767
+
768
+ # Now check for free directions. When checking for
769
+ # free directions, prefer the horizontal and vertical
770
+ # directions.
771
+ neighbours = [(i - 1, j), (i, j + 1), (i + 1, j), (i, j - 1),
772
+ (i - 1, j - 1), (i - 1, j + 1), (i + 1, j - 1), (i + 1, j + 1)]
773
+
774
+ for pt in neighbours:
775
+ if DiagramGrid._empty_point(pt, grid):
776
+ # We have a found a place to grow the
777
+ # pseudopod into.
778
+ offset = DiagramGrid._put_object(
779
+ pt, other_obj, grid, fringe)
780
+
781
+ i += offset[0]
782
+ j += offset[1]
783
+ pt = (pt[0] + offset[0], pt[1] + offset[1])
784
+ fringe.append(((i, j), pt))
785
+
786
+ return other_obj
787
+
788
+ # This diagram is actually cooler that I can handle. Fail cowardly.
789
+ return None
790
+
791
+ @staticmethod
792
+ def _handle_groups(diagram, groups, merged_morphisms, hints):
793
+ """
794
+ Given the slightly preprocessed morphisms of the diagram,
795
+ produces a grid laid out according to ``groups``.
796
+
797
+ If a group has hints, it is laid out with those hints only,
798
+ without any influence from ``hints``. Otherwise, it is laid
799
+ out with ``hints``.
800
+ """
801
+ def lay_out_group(group, local_hints):
802
+ """
803
+ If ``group`` is a set of objects, uses a ``DiagramGrid``
804
+ to lay it out and returns the grid. Otherwise returns the
805
+ object (i.e., ``group``). If ``local_hints`` is not
806
+ empty, it is supplied to ``DiagramGrid`` as the dictionary
807
+ of hints. Otherwise, the ``hints`` argument of
808
+ ``_handle_groups`` is used.
809
+ """
810
+ if isinstance(group, FiniteSet):
811
+ # Set up the corresponding object-to-group
812
+ # mappings.
813
+ for obj in group:
814
+ obj_groups[obj] = group
815
+
816
+ # Lay out the current group.
817
+ if local_hints:
818
+ groups_grids[group] = DiagramGrid(
819
+ diagram.subdiagram_from_objects(group), **local_hints)
820
+ else:
821
+ groups_grids[group] = DiagramGrid(
822
+ diagram.subdiagram_from_objects(group), **hints)
823
+ else:
824
+ obj_groups[group] = group
825
+
826
+ def group_to_finiteset(group):
827
+ """
828
+ Converts ``group`` to a :class:``FiniteSet`` if it is an
829
+ iterable.
830
+ """
831
+ if iterable(group):
832
+ return FiniteSet(*group)
833
+ else:
834
+ return group
835
+
836
+ obj_groups = {}
837
+ groups_grids = {}
838
+
839
+ # We would like to support various containers to represent
840
+ # groups. To achieve that, before laying each group out, it
841
+ # should be converted to a FiniteSet, because that is what the
842
+ # following code expects.
843
+
844
+ if isinstance(groups, (dict, Dict)):
845
+ finiteset_groups = {}
846
+ for group, local_hints in groups.items():
847
+ finiteset_group = group_to_finiteset(group)
848
+ finiteset_groups[finiteset_group] = local_hints
849
+ lay_out_group(group, local_hints)
850
+ groups = finiteset_groups
851
+ else:
852
+ finiteset_groups = []
853
+ for group in groups:
854
+ finiteset_group = group_to_finiteset(group)
855
+ finiteset_groups.append(finiteset_group)
856
+ lay_out_group(finiteset_group, None)
857
+ groups = finiteset_groups
858
+
859
+ new_morphisms = []
860
+ for morphism in merged_morphisms:
861
+ dom = obj_groups[morphism.domain]
862
+ cod = obj_groups[morphism.codomain]
863
+ # Note that we are not really interested in morphisms
864
+ # which do not employ two different groups, because
865
+ # these do not influence the layout.
866
+ if dom != cod:
867
+ # These are essentially unnamed morphisms; they are
868
+ # not going to mess in the final layout. By giving
869
+ # them the same names, we avoid unnecessary
870
+ # duplicates.
871
+ new_morphisms.append(NamedMorphism(dom, cod, "dummy"))
872
+
873
+ # Lay out the new diagram. Since these are dummy morphisms,
874
+ # properties and conclusions are irrelevant.
875
+ top_grid = DiagramGrid(Diagram(new_morphisms))
876
+
877
+ # We now have to substitute the groups with the corresponding
878
+ # grids, laid out at the beginning of this function. Compute
879
+ # the size of each row and column in the grid, so that all
880
+ # nested grids fit.
881
+
882
+ def group_size(group):
883
+ """
884
+ For the supplied group (or object, eventually), returns
885
+ the size of the cell that will hold this group (object).
886
+ """
887
+ if group in groups_grids:
888
+ grid = groups_grids[group]
889
+ return (grid.height, grid.width)
890
+ else:
891
+ return (1, 1)
892
+
893
+ row_heights = [max(group_size(top_grid[i, j])[0]
894
+ for j in range(top_grid.width))
895
+ for i in range(top_grid.height)]
896
+
897
+ column_widths = [max(group_size(top_grid[i, j])[1]
898
+ for i in range(top_grid.height))
899
+ for j in range(top_grid.width)]
900
+
901
+ grid = _GrowableGrid(sum(column_widths), sum(row_heights))
902
+
903
+ real_row = 0
904
+ real_column = 0
905
+ for logical_row in range(top_grid.height):
906
+ for logical_column in range(top_grid.width):
907
+ obj = top_grid[logical_row, logical_column]
908
+
909
+ if obj in groups_grids:
910
+ # This is a group. Copy the corresponding grid in
911
+ # place.
912
+ local_grid = groups_grids[obj]
913
+ for i in range(local_grid.height):
914
+ for j in range(local_grid.width):
915
+ grid[real_row + i,
916
+ real_column + j] = local_grid[i, j]
917
+ else:
918
+ # This is an object. Just put it there.
919
+ grid[real_row, real_column] = obj
920
+
921
+ real_column += column_widths[logical_column]
922
+ real_column = 0
923
+ real_row += row_heights[logical_row]
924
+
925
+ return grid
926
+
927
+ @staticmethod
928
+ def _generic_layout(diagram, merged_morphisms):
929
+ """
930
+ Produces the generic layout for the supplied diagram.
931
+ """
932
+ all_objects = set(diagram.objects)
933
+ if len(all_objects) == 1:
934
+ # There only one object in the diagram, just put in on 1x1
935
+ # grid.
936
+ grid = _GrowableGrid(1, 1)
937
+ grid[0, 0] = tuple(all_objects)[0]
938
+ return grid
939
+
940
+ skeleton = DiagramGrid._build_skeleton(merged_morphisms)
941
+
942
+ grid = _GrowableGrid(2, 1)
943
+
944
+ if len(skeleton) == 1:
945
+ # This diagram contains only one morphism. Draw it
946
+ # horizontally.
947
+ objects = sorted(all_objects, key=default_sort_key)
948
+ grid[0, 0] = objects[0]
949
+ grid[0, 1] = objects[1]
950
+
951
+ return grid
952
+
953
+ triangles = DiagramGrid._list_triangles(skeleton)
954
+ triangles = DiagramGrid._drop_redundant_triangles(triangles, skeleton)
955
+ triangle_sizes = DiagramGrid._compute_triangle_min_sizes(
956
+ triangles, skeleton)
957
+
958
+ triangles = sorted(triangles, key=lambda tri:
959
+ DiagramGrid._triangle_key(tri, triangle_sizes))
960
+
961
+ # Place the first edge on the grid.
962
+ root_edge = DiagramGrid._pick_root_edge(triangles[0], skeleton)
963
+ grid[0, 0], grid[0, 1] = root_edge
964
+ fringe = [((0, 0), (0, 1))]
965
+
966
+ # Record which objects we now have on the grid.
967
+ placed_objects = set(root_edge)
968
+
969
+ while placed_objects != all_objects:
970
+ welding = DiagramGrid._find_triangle_to_weld(
971
+ triangles, fringe, grid)
972
+
973
+ if welding:
974
+ (triangle, welding_edge) = welding
975
+
976
+ restart_required = DiagramGrid._weld_triangle(
977
+ triangle, welding_edge, fringe, grid, skeleton)
978
+ if restart_required:
979
+ continue
980
+
981
+ placed_objects.update(
982
+ DiagramGrid._triangle_objects(triangle))
983
+ else:
984
+ # No more weldings found. Try to attach triangles by
985
+ # vertices.
986
+ new_obj = DiagramGrid._grow_pseudopod(
987
+ triangles, fringe, grid, skeleton, placed_objects)
988
+
989
+ if not new_obj:
990
+ # No more triangles can be attached, not even by
991
+ # the edge. We will set up a new diagram out of
992
+ # what has been left, laid it out independently,
993
+ # and then attach it to this one.
994
+
995
+ remaining_objects = all_objects - placed_objects
996
+
997
+ remaining_diagram = diagram.subdiagram_from_objects(
998
+ FiniteSet(*remaining_objects))
999
+ remaining_grid = DiagramGrid(remaining_diagram)
1000
+
1001
+ # Now, let's glue ``remaining_grid`` to ``grid``.
1002
+ final_width = grid.width + remaining_grid.width
1003
+ final_height = max(grid.height, remaining_grid.height)
1004
+ final_grid = _GrowableGrid(final_width, final_height)
1005
+
1006
+ for i in range(grid.width):
1007
+ for j in range(grid.height):
1008
+ final_grid[i, j] = grid[i, j]
1009
+
1010
+ start_j = grid.width
1011
+ for i in range(remaining_grid.height):
1012
+ for j in range(remaining_grid.width):
1013
+ final_grid[i, start_j + j] = remaining_grid[i, j]
1014
+
1015
+ return final_grid
1016
+
1017
+ placed_objects.add(new_obj)
1018
+
1019
+ triangles = DiagramGrid._drop_irrelevant_triangles(
1020
+ triangles, placed_objects)
1021
+
1022
+ return grid
1023
+
1024
+ @staticmethod
1025
+ def _get_undirected_graph(objects, merged_morphisms):
1026
+ """
1027
+ Given the objects and the relevant morphisms of a diagram,
1028
+ returns the adjacency lists of the underlying undirected
1029
+ graph.
1030
+ """
1031
+ adjlists = {obj: [] for obj in objects}
1032
+
1033
+ for morphism in merged_morphisms:
1034
+ adjlists[morphism.domain].append(morphism.codomain)
1035
+ adjlists[morphism.codomain].append(morphism.domain)
1036
+
1037
+ # Assure that the objects in the adjacency list are always in
1038
+ # the same order.
1039
+ for obj in adjlists.keys():
1040
+ adjlists[obj].sort(key=default_sort_key)
1041
+
1042
+ return adjlists
1043
+
1044
+ @staticmethod
1045
+ def _sequential_layout(diagram, merged_morphisms):
1046
+ r"""
1047
+ Lays out the diagram in "sequential" layout. This method
1048
+ will attempt to produce a result as close to a line as
1049
+ possible. For linear diagrams, the result will actually be a
1050
+ line.
1051
+ """
1052
+ objects = diagram.objects
1053
+ sorted_objects = sorted(objects, key=default_sort_key)
1054
+
1055
+ # Set up the adjacency lists of the underlying undirected
1056
+ # graph of ``merged_morphisms``.
1057
+ adjlists = DiagramGrid._get_undirected_graph(objects, merged_morphisms)
1058
+
1059
+ root = min(sorted_objects, key=lambda x: len(adjlists[x]))
1060
+ grid = _GrowableGrid(1, 1)
1061
+ grid[0, 0] = root
1062
+
1063
+ placed_objects = {root}
1064
+
1065
+ def place_objects(pt, placed_objects):
1066
+ """
1067
+ Does depth-first search in the underlying graph of the
1068
+ diagram and places the objects en route.
1069
+ """
1070
+ # We will start placing new objects from here.
1071
+ new_pt = (pt[0], pt[1] + 1)
1072
+
1073
+ for adjacent_obj in adjlists[grid[pt]]:
1074
+ if adjacent_obj in placed_objects:
1075
+ # This object has already been placed.
1076
+ continue
1077
+
1078
+ DiagramGrid._put_object(new_pt, adjacent_obj, grid, [])
1079
+ placed_objects.add(adjacent_obj)
1080
+ placed_objects.update(place_objects(new_pt, placed_objects))
1081
+
1082
+ new_pt = (new_pt[0] + 1, new_pt[1])
1083
+
1084
+ return placed_objects
1085
+
1086
+ place_objects((0, 0), placed_objects)
1087
+
1088
+ return grid
1089
+
1090
+ @staticmethod
1091
+ def _drop_inessential_morphisms(merged_morphisms):
1092
+ r"""
1093
+ Removes those morphisms which should appear in the diagram,
1094
+ but which have no relevance to object layout.
1095
+
1096
+ Currently this removes "loop" morphisms: the non-identity
1097
+ morphisms with the same domains and codomains.
1098
+ """
1099
+ morphisms = [m for m in merged_morphisms if m.domain != m.codomain]
1100
+ return morphisms
1101
+
1102
+ @staticmethod
1103
+ def _get_connected_components(objects, merged_morphisms):
1104
+ """
1105
+ Given a container of morphisms, returns a list of connected
1106
+ components formed by these morphisms. A connected component
1107
+ is represented by a diagram consisting of the corresponding
1108
+ morphisms.
1109
+ """
1110
+ component_index = {}
1111
+ for o in objects:
1112
+ component_index[o] = None
1113
+
1114
+ # Get the underlying undirected graph of the diagram.
1115
+ adjlist = DiagramGrid._get_undirected_graph(objects, merged_morphisms)
1116
+
1117
+ def traverse_component(object, current_index):
1118
+ """
1119
+ Does a depth-first search traversal of the component
1120
+ containing ``object``.
1121
+ """
1122
+ component_index[object] = current_index
1123
+ for o in adjlist[object]:
1124
+ if component_index[o] is None:
1125
+ traverse_component(o, current_index)
1126
+
1127
+ # Traverse all components.
1128
+ current_index = 0
1129
+ for o in adjlist:
1130
+ if component_index[o] is None:
1131
+ traverse_component(o, current_index)
1132
+ current_index += 1
1133
+
1134
+ # List the objects of the components.
1135
+ component_objects = [[] for i in range(current_index)]
1136
+ for o, idx in component_index.items():
1137
+ component_objects[idx].append(o)
1138
+
1139
+ # Finally, list the morphisms belonging to each component.
1140
+ #
1141
+ # Note: If some objects are isolated, they will not get any
1142
+ # morphisms at this stage, and since the layout algorithm
1143
+ # relies, we are essentially going to lose this object.
1144
+ # Therefore, check if there are isolated objects and, for each
1145
+ # of them, provide the trivial identity morphism. It will get
1146
+ # discarded later, but the object will be there.
1147
+
1148
+ component_morphisms = []
1149
+ for component in component_objects:
1150
+ current_morphisms = {}
1151
+ for m in merged_morphisms:
1152
+ if (m.domain in component) and (m.codomain in component):
1153
+ current_morphisms[m] = merged_morphisms[m]
1154
+
1155
+ if len(component) == 1:
1156
+ # Let's add an identity morphism, for the sake of
1157
+ # surely having morphisms in this component.
1158
+ current_morphisms[IdentityMorphism(component[0])] = FiniteSet()
1159
+
1160
+ component_morphisms.append(Diagram(current_morphisms))
1161
+
1162
+ return component_morphisms
1163
+
1164
+ def __init__(self, diagram, groups=None, **hints):
1165
+ premises = DiagramGrid._simplify_morphisms(diagram.premises)
1166
+ conclusions = DiagramGrid._simplify_morphisms(diagram.conclusions)
1167
+ all_merged_morphisms = DiagramGrid._merge_premises_conclusions(
1168
+ premises, conclusions)
1169
+ merged_morphisms = DiagramGrid._drop_inessential_morphisms(
1170
+ all_merged_morphisms)
1171
+
1172
+ # Store the merged morphisms for later use.
1173
+ self._morphisms = all_merged_morphisms
1174
+
1175
+ components = DiagramGrid._get_connected_components(
1176
+ diagram.objects, all_merged_morphisms)
1177
+
1178
+ if groups and (groups != diagram.objects):
1179
+ # Lay out the diagram according to the groups.
1180
+ self._grid = DiagramGrid._handle_groups(
1181
+ diagram, groups, merged_morphisms, hints)
1182
+ elif len(components) > 1:
1183
+ # Note that we check for connectedness _before_ checking
1184
+ # the layout hints because the layout strategies don't
1185
+ # know how to deal with disconnected diagrams.
1186
+
1187
+ # The diagram is disconnected. Lay out the components
1188
+ # independently.
1189
+ grids = []
1190
+
1191
+ # Sort the components to eventually get the grids arranged
1192
+ # in a fixed, hash-independent order.
1193
+ components = sorted(components, key=default_sort_key)
1194
+
1195
+ for component in components:
1196
+ grid = DiagramGrid(component, **hints)
1197
+ grids.append(grid)
1198
+
1199
+ # Throw the grids together, in a line.
1200
+ total_width = sum(g.width for g in grids)
1201
+ total_height = max(g.height for g in grids)
1202
+
1203
+ grid = _GrowableGrid(total_width, total_height)
1204
+ start_j = 0
1205
+ for g in grids:
1206
+ for i in range(g.height):
1207
+ for j in range(g.width):
1208
+ grid[i, start_j + j] = g[i, j]
1209
+
1210
+ start_j += g.width
1211
+
1212
+ self._grid = grid
1213
+ elif "layout" in hints:
1214
+ if hints["layout"] == "sequential":
1215
+ self._grid = DiagramGrid._sequential_layout(
1216
+ diagram, merged_morphisms)
1217
+ else:
1218
+ self._grid = DiagramGrid._generic_layout(diagram, merged_morphisms)
1219
+
1220
+ if hints.get("transpose"):
1221
+ # Transpose the resulting grid.
1222
+ grid = _GrowableGrid(self._grid.height, self._grid.width)
1223
+ for i in range(self._grid.height):
1224
+ for j in range(self._grid.width):
1225
+ grid[j, i] = self._grid[i, j]
1226
+ self._grid = grid
1227
+
1228
+ @property
1229
+ def width(self):
1230
+ """
1231
+ Returns the number of columns in this diagram layout.
1232
+
1233
+ Examples
1234
+ ========
1235
+
1236
+ >>> from sympy.categories import Object, NamedMorphism
1237
+ >>> from sympy.categories import Diagram, DiagramGrid
1238
+ >>> A = Object("A")
1239
+ >>> B = Object("B")
1240
+ >>> C = Object("C")
1241
+ >>> f = NamedMorphism(A, B, "f")
1242
+ >>> g = NamedMorphism(B, C, "g")
1243
+ >>> diagram = Diagram([f, g])
1244
+ >>> grid = DiagramGrid(diagram)
1245
+ >>> grid.width
1246
+ 2
1247
+
1248
+ """
1249
+ return self._grid.width
1250
+
1251
+ @property
1252
+ def height(self):
1253
+ """
1254
+ Returns the number of rows in this diagram layout.
1255
+
1256
+ Examples
1257
+ ========
1258
+
1259
+ >>> from sympy.categories import Object, NamedMorphism
1260
+ >>> from sympy.categories import Diagram, DiagramGrid
1261
+ >>> A = Object("A")
1262
+ >>> B = Object("B")
1263
+ >>> C = Object("C")
1264
+ >>> f = NamedMorphism(A, B, "f")
1265
+ >>> g = NamedMorphism(B, C, "g")
1266
+ >>> diagram = Diagram([f, g])
1267
+ >>> grid = DiagramGrid(diagram)
1268
+ >>> grid.height
1269
+ 2
1270
+
1271
+ """
1272
+ return self._grid.height
1273
+
1274
+ def __getitem__(self, i_j):
1275
+ """
1276
+ Returns the object placed in the row ``i`` and column ``j``.
1277
+ The indices are 0-based.
1278
+
1279
+ Examples
1280
+ ========
1281
+
1282
+ >>> from sympy.categories import Object, NamedMorphism
1283
+ >>> from sympy.categories import Diagram, DiagramGrid
1284
+ >>> A = Object("A")
1285
+ >>> B = Object("B")
1286
+ >>> C = Object("C")
1287
+ >>> f = NamedMorphism(A, B, "f")
1288
+ >>> g = NamedMorphism(B, C, "g")
1289
+ >>> diagram = Diagram([f, g])
1290
+ >>> grid = DiagramGrid(diagram)
1291
+ >>> (grid[0, 0], grid[0, 1])
1292
+ (Object("A"), Object("B"))
1293
+ >>> (grid[1, 0], grid[1, 1])
1294
+ (None, Object("C"))
1295
+
1296
+ """
1297
+ i, j = i_j
1298
+ return self._grid[i, j]
1299
+
1300
+ @property
1301
+ def morphisms(self):
1302
+ """
1303
+ Returns those morphisms (and their properties) which are
1304
+ sufficiently meaningful to be drawn.
1305
+
1306
+ Examples
1307
+ ========
1308
+
1309
+ >>> from sympy.categories import Object, NamedMorphism
1310
+ >>> from sympy.categories import Diagram, DiagramGrid
1311
+ >>> A = Object("A")
1312
+ >>> B = Object("B")
1313
+ >>> C = Object("C")
1314
+ >>> f = NamedMorphism(A, B, "f")
1315
+ >>> g = NamedMorphism(B, C, "g")
1316
+ >>> diagram = Diagram([f, g])
1317
+ >>> grid = DiagramGrid(diagram)
1318
+ >>> grid.morphisms
1319
+ {NamedMorphism(Object("A"), Object("B"), "f"): EmptySet,
1320
+ NamedMorphism(Object("B"), Object("C"), "g"): EmptySet}
1321
+
1322
+ """
1323
+ return self._morphisms
1324
+
1325
+ def __str__(self):
1326
+ """
1327
+ Produces a string representation of this class.
1328
+
1329
+ This method returns a string representation of the underlying
1330
+ list of lists of objects.
1331
+
1332
+ Examples
1333
+ ========
1334
+
1335
+ >>> from sympy.categories import Object, NamedMorphism
1336
+ >>> from sympy.categories import Diagram, DiagramGrid
1337
+ >>> A = Object("A")
1338
+ >>> B = Object("B")
1339
+ >>> C = Object("C")
1340
+ >>> f = NamedMorphism(A, B, "f")
1341
+ >>> g = NamedMorphism(B, C, "g")
1342
+ >>> diagram = Diagram([f, g])
1343
+ >>> grid = DiagramGrid(diagram)
1344
+ >>> print(grid)
1345
+ [[Object("A"), Object("B")],
1346
+ [None, Object("C")]]
1347
+
1348
+ """
1349
+ return repr(self._grid._array)
1350
+
1351
+
1352
+ class ArrowStringDescription:
1353
+ r"""
1354
+ Stores the information necessary for producing an Xy-pic
1355
+ description of an arrow.
1356
+
1357
+ The principal goal of this class is to abstract away the string
1358
+ representation of an arrow and to also provide the functionality
1359
+ to produce the actual Xy-pic string.
1360
+
1361
+ ``unit`` sets the unit which will be used to specify the amount of
1362
+ curving and other distances. ``horizontal_direction`` should be a
1363
+ string of ``"r"`` or ``"l"`` specifying the horizontal offset of the
1364
+ target cell of the arrow relatively to the current one.
1365
+ ``vertical_direction`` should specify the vertical offset using a
1366
+ series of either ``"d"`` or ``"u"``. ``label_position`` should be
1367
+ either ``"^"``, ``"_"``, or ``"|"`` to specify that the label should
1368
+ be positioned above the arrow, below the arrow or just over the arrow,
1369
+ in a break. Note that the notions "above" and "below" are relative
1370
+ to arrow direction. ``label`` stores the morphism label.
1371
+
1372
+ This works as follows (disregard the yet unexplained arguments):
1373
+
1374
+ >>> from sympy.categories.diagram_drawing import ArrowStringDescription
1375
+ >>> astr = ArrowStringDescription(
1376
+ ... unit="mm", curving=None, curving_amount=None,
1377
+ ... looping_start=None, looping_end=None, horizontal_direction="d",
1378
+ ... vertical_direction="r", label_position="_", label="f")
1379
+ >>> print(str(astr))
1380
+ \ar[dr]_{f}
1381
+
1382
+ ``curving`` should be one of ``"^"``, ``"_"`` to specify in which
1383
+ direction the arrow is going to curve. ``curving_amount`` is a number
1384
+ describing how many ``unit``'s the morphism is going to curve:
1385
+
1386
+ >>> astr = ArrowStringDescription(
1387
+ ... unit="mm", curving="^", curving_amount=12,
1388
+ ... looping_start=None, looping_end=None, horizontal_direction="d",
1389
+ ... vertical_direction="r", label_position="_", label="f")
1390
+ >>> print(str(astr))
1391
+ \ar@/^12mm/[dr]_{f}
1392
+
1393
+ ``looping_start`` and ``looping_end`` are currently only used for
1394
+ loop morphisms, those which have the same domain and codomain.
1395
+ These two attributes should store a valid Xy-pic direction and
1396
+ specify, correspondingly, the direction the arrow gets out into
1397
+ and the direction the arrow gets back from:
1398
+
1399
+ >>> astr = ArrowStringDescription(
1400
+ ... unit="mm", curving=None, curving_amount=None,
1401
+ ... looping_start="u", looping_end="l", horizontal_direction="",
1402
+ ... vertical_direction="", label_position="_", label="f")
1403
+ >>> print(str(astr))
1404
+ \ar@(u,l)[]_{f}
1405
+
1406
+ ``label_displacement`` controls how far the arrow label is from
1407
+ the ends of the arrow. For example, to position the arrow label
1408
+ near the arrow head, use ">":
1409
+
1410
+ >>> astr = ArrowStringDescription(
1411
+ ... unit="mm", curving="^", curving_amount=12,
1412
+ ... looping_start=None, looping_end=None, horizontal_direction="d",
1413
+ ... vertical_direction="r", label_position="_", label="f")
1414
+ >>> astr.label_displacement = ">"
1415
+ >>> print(str(astr))
1416
+ \ar@/^12mm/[dr]_>{f}
1417
+
1418
+ Finally, ``arrow_style`` is used to specify the arrow style. To
1419
+ get a dashed arrow, for example, use "{-->}" as arrow style:
1420
+
1421
+ >>> astr = ArrowStringDescription(
1422
+ ... unit="mm", curving="^", curving_amount=12,
1423
+ ... looping_start=None, looping_end=None, horizontal_direction="d",
1424
+ ... vertical_direction="r", label_position="_", label="f")
1425
+ >>> astr.arrow_style = "{-->}"
1426
+ >>> print(str(astr))
1427
+ \ar@/^12mm/@{-->}[dr]_{f}
1428
+
1429
+ Notes
1430
+ =====
1431
+
1432
+ Instances of :class:`ArrowStringDescription` will be constructed
1433
+ by :class:`XypicDiagramDrawer` and provided for further use in
1434
+ formatters. The user is not expected to construct instances of
1435
+ :class:`ArrowStringDescription` themselves.
1436
+
1437
+ To be able to properly utilise this class, the reader is encouraged
1438
+ to checkout the Xy-pic user guide, available at [Xypic].
1439
+
1440
+ See Also
1441
+ ========
1442
+
1443
+ XypicDiagramDrawer
1444
+
1445
+ References
1446
+ ==========
1447
+
1448
+ .. [Xypic] https://xy-pic.sourceforge.net/
1449
+ """
1450
+ def __init__(self, unit, curving, curving_amount, looping_start,
1451
+ looping_end, horizontal_direction, vertical_direction,
1452
+ label_position, label):
1453
+ self.unit = unit
1454
+ self.curving = curving
1455
+ self.curving_amount = curving_amount
1456
+ self.looping_start = looping_start
1457
+ self.looping_end = looping_end
1458
+ self.horizontal_direction = horizontal_direction
1459
+ self.vertical_direction = vertical_direction
1460
+ self.label_position = label_position
1461
+ self.label = label
1462
+
1463
+ self.label_displacement = ""
1464
+ self.arrow_style = ""
1465
+
1466
+ # This flag shows that the position of the label of this
1467
+ # morphism was set while typesetting a curved morphism and
1468
+ # should not be modified later.
1469
+ self.forced_label_position = False
1470
+
1471
+ def __str__(self):
1472
+ if self.curving:
1473
+ curving_str = "@/%s%d%s/" % (self.curving, self.curving_amount,
1474
+ self.unit)
1475
+ else:
1476
+ curving_str = ""
1477
+
1478
+ if self.looping_start and self.looping_end:
1479
+ looping_str = "@(%s,%s)" % (self.looping_start, self.looping_end)
1480
+ else:
1481
+ looping_str = ""
1482
+
1483
+ if self.arrow_style:
1484
+
1485
+ style_str = "@" + self.arrow_style
1486
+ else:
1487
+ style_str = ""
1488
+
1489
+ return "\\ar%s%s%s[%s%s]%s%s{%s}" % \
1490
+ (curving_str, looping_str, style_str, self.horizontal_direction,
1491
+ self.vertical_direction, self.label_position,
1492
+ self.label_displacement, self.label)
1493
+
1494
+
1495
+ class XypicDiagramDrawer:
1496
+ r"""
1497
+ Given a :class:`~.Diagram` and the corresponding
1498
+ :class:`DiagramGrid`, produces the Xy-pic representation of the
1499
+ diagram.
1500
+
1501
+ The most important method in this class is ``draw``. Consider the
1502
+ following triangle diagram:
1503
+
1504
+ >>> from sympy.categories import Object, NamedMorphism, Diagram
1505
+ >>> from sympy.categories import DiagramGrid, XypicDiagramDrawer
1506
+ >>> A = Object("A")
1507
+ >>> B = Object("B")
1508
+ >>> C = Object("C")
1509
+ >>> f = NamedMorphism(A, B, "f")
1510
+ >>> g = NamedMorphism(B, C, "g")
1511
+ >>> diagram = Diagram([f, g], {g * f: "unique"})
1512
+
1513
+ To draw this diagram, its objects need to be laid out with a
1514
+ :class:`DiagramGrid`::
1515
+
1516
+ >>> grid = DiagramGrid(diagram)
1517
+
1518
+ Finally, the drawing:
1519
+
1520
+ >>> drawer = XypicDiagramDrawer()
1521
+ >>> print(drawer.draw(diagram, grid))
1522
+ \xymatrix{
1523
+ A \ar[d]_{g\circ f} \ar[r]^{f} & B \ar[ld]^{g} \\
1524
+ C &
1525
+ }
1526
+
1527
+ For further details see the docstring of this method.
1528
+
1529
+ To control the appearance of the arrows, formatters are used. The
1530
+ dictionary ``arrow_formatters`` maps morphisms to formatter
1531
+ functions. A formatter is accepts an
1532
+ :class:`ArrowStringDescription` and is allowed to modify any of
1533
+ the arrow properties exposed thereby. For example, to have all
1534
+ morphisms with the property ``unique`` appear as dashed arrows,
1535
+ and to have their names prepended with `\exists !`, the following
1536
+ should be done:
1537
+
1538
+ >>> def formatter(astr):
1539
+ ... astr.label = r"\exists !" + astr.label
1540
+ ... astr.arrow_style = "{-->}"
1541
+ >>> drawer.arrow_formatters["unique"] = formatter
1542
+ >>> print(drawer.draw(diagram, grid))
1543
+ \xymatrix{
1544
+ A \ar@{-->}[d]_{\exists !g\circ f} \ar[r]^{f} & B \ar[ld]^{g} \\
1545
+ C &
1546
+ }
1547
+
1548
+ To modify the appearance of all arrows in the diagram, set
1549
+ ``default_arrow_formatter``. For example, to place all morphism
1550
+ labels a little bit farther from the arrow head so that they look
1551
+ more centred, do as follows:
1552
+
1553
+ >>> def default_formatter(astr):
1554
+ ... astr.label_displacement = "(0.45)"
1555
+ >>> drawer.default_arrow_formatter = default_formatter
1556
+ >>> print(drawer.draw(diagram, grid))
1557
+ \xymatrix{
1558
+ A \ar@{-->}[d]_(0.45){\exists !g\circ f} \ar[r]^(0.45){f} & B \ar[ld]^(0.45){g} \\
1559
+ C &
1560
+ }
1561
+
1562
+ In some diagrams some morphisms are drawn as curved arrows.
1563
+ Consider the following diagram:
1564
+
1565
+ >>> D = Object("D")
1566
+ >>> E = Object("E")
1567
+ >>> h = NamedMorphism(D, A, "h")
1568
+ >>> k = NamedMorphism(D, B, "k")
1569
+ >>> diagram = Diagram([f, g, h, k])
1570
+ >>> grid = DiagramGrid(diagram)
1571
+ >>> drawer = XypicDiagramDrawer()
1572
+ >>> print(drawer.draw(diagram, grid))
1573
+ \xymatrix{
1574
+ A \ar[r]_{f} & B \ar[d]^{g} & D \ar[l]^{k} \ar@/_3mm/[ll]_{h} \\
1575
+ & C &
1576
+ }
1577
+
1578
+ To control how far the morphisms are curved by default, one can
1579
+ use the ``unit`` and ``default_curving_amount`` attributes:
1580
+
1581
+ >>> drawer.unit = "cm"
1582
+ >>> drawer.default_curving_amount = 1
1583
+ >>> print(drawer.draw(diagram, grid))
1584
+ \xymatrix{
1585
+ A \ar[r]_{f} & B \ar[d]^{g} & D \ar[l]^{k} \ar@/_1cm/[ll]_{h} \\
1586
+ & C &
1587
+ }
1588
+
1589
+ In some diagrams, there are multiple curved morphisms between the
1590
+ same two objects. To control by how much the curving changes
1591
+ between two such successive morphisms, use
1592
+ ``default_curving_step``:
1593
+
1594
+ >>> drawer.default_curving_step = 1
1595
+ >>> h1 = NamedMorphism(A, D, "h1")
1596
+ >>> diagram = Diagram([f, g, h, k, h1])
1597
+ >>> grid = DiagramGrid(diagram)
1598
+ >>> print(drawer.draw(diagram, grid))
1599
+ \xymatrix{
1600
+ A \ar[r]_{f} \ar@/^1cm/[rr]^{h_{1}} & B \ar[d]^{g} & D \ar[l]^{k} \ar@/_2cm/[ll]_{h} \\
1601
+ & C &
1602
+ }
1603
+
1604
+ The default value of ``default_curving_step`` is 4 units.
1605
+
1606
+ See Also
1607
+ ========
1608
+
1609
+ draw, ArrowStringDescription
1610
+ """
1611
+ def __init__(self):
1612
+ self.unit = "mm"
1613
+ self.default_curving_amount = 3
1614
+ self.default_curving_step = 4
1615
+
1616
+ # This dictionary maps properties to the corresponding arrow
1617
+ # formatters.
1618
+ self.arrow_formatters = {}
1619
+
1620
+ # This is the default arrow formatter which will be applied to
1621
+ # each arrow independently of its properties.
1622
+ self.default_arrow_formatter = None
1623
+
1624
+ @staticmethod
1625
+ def _process_loop_morphism(i, j, grid, morphisms_str_info, object_coords):
1626
+ """
1627
+ Produces the information required for constructing the string
1628
+ representation of a loop morphism. This function is invoked
1629
+ from ``_process_morphism``.
1630
+
1631
+ See Also
1632
+ ========
1633
+
1634
+ _process_morphism
1635
+ """
1636
+ curving = ""
1637
+ label_pos = "^"
1638
+ looping_start = ""
1639
+ looping_end = ""
1640
+
1641
+ # This is a loop morphism. Count how many morphisms stick
1642
+ # in each of the four quadrants. Note that straight
1643
+ # vertical and horizontal morphisms count in two quadrants
1644
+ # at the same time (i.e., a morphism going up counts both
1645
+ # in the first and the second quadrants).
1646
+
1647
+ # The usual numbering (counterclockwise) of quadrants
1648
+ # applies.
1649
+ quadrant = [0, 0, 0, 0]
1650
+
1651
+ obj = grid[i, j]
1652
+
1653
+ for m, m_str_info in morphisms_str_info.items():
1654
+ if (m.domain == obj) and (m.codomain == obj):
1655
+ # That's another loop morphism. Check how it
1656
+ # loops and mark the corresponding quadrants as
1657
+ # busy.
1658
+ (l_s, l_e) = (m_str_info.looping_start, m_str_info.looping_end)
1659
+
1660
+ if (l_s, l_e) == ("r", "u"):
1661
+ quadrant[0] += 1
1662
+ elif (l_s, l_e) == ("u", "l"):
1663
+ quadrant[1] += 1
1664
+ elif (l_s, l_e) == ("l", "d"):
1665
+ quadrant[2] += 1
1666
+ elif (l_s, l_e) == ("d", "r"):
1667
+ quadrant[3] += 1
1668
+
1669
+ continue
1670
+ if m.domain == obj:
1671
+ (end_i, end_j) = object_coords[m.codomain]
1672
+ goes_out = True
1673
+ elif m.codomain == obj:
1674
+ (end_i, end_j) = object_coords[m.domain]
1675
+ goes_out = False
1676
+ else:
1677
+ continue
1678
+
1679
+ d_i = end_i - i
1680
+ d_j = end_j - j
1681
+ m_curving = m_str_info.curving
1682
+
1683
+ if (d_i != 0) and (d_j != 0):
1684
+ # This is really a diagonal morphism. Detect the
1685
+ # quadrant.
1686
+ if (d_i > 0) and (d_j > 0):
1687
+ quadrant[0] += 1
1688
+ elif (d_i > 0) and (d_j < 0):
1689
+ quadrant[1] += 1
1690
+ elif (d_i < 0) and (d_j < 0):
1691
+ quadrant[2] += 1
1692
+ elif (d_i < 0) and (d_j > 0):
1693
+ quadrant[3] += 1
1694
+ elif d_i == 0:
1695
+ # Knowing where the other end of the morphism is
1696
+ # and which way it goes, we now have to decide
1697
+ # which quadrant is now the upper one and which is
1698
+ # the lower one.
1699
+ if d_j > 0:
1700
+ if goes_out:
1701
+ upper_quadrant = 0
1702
+ lower_quadrant = 3
1703
+ else:
1704
+ upper_quadrant = 3
1705
+ lower_quadrant = 0
1706
+ else:
1707
+ if goes_out:
1708
+ upper_quadrant = 2
1709
+ lower_quadrant = 1
1710
+ else:
1711
+ upper_quadrant = 1
1712
+ lower_quadrant = 2
1713
+
1714
+ if m_curving:
1715
+ if m_curving == "^":
1716
+ quadrant[upper_quadrant] += 1
1717
+ elif m_curving == "_":
1718
+ quadrant[lower_quadrant] += 1
1719
+ else:
1720
+ # This morphism counts in both upper and lower
1721
+ # quadrants.
1722
+ quadrant[upper_quadrant] += 1
1723
+ quadrant[lower_quadrant] += 1
1724
+ elif d_j == 0:
1725
+ # Knowing where the other end of the morphism is
1726
+ # and which way it goes, we now have to decide
1727
+ # which quadrant is now the left one and which is
1728
+ # the right one.
1729
+ if d_i < 0:
1730
+ if goes_out:
1731
+ left_quadrant = 1
1732
+ right_quadrant = 0
1733
+ else:
1734
+ left_quadrant = 0
1735
+ right_quadrant = 1
1736
+ else:
1737
+ if goes_out:
1738
+ left_quadrant = 3
1739
+ right_quadrant = 2
1740
+ else:
1741
+ left_quadrant = 2
1742
+ right_quadrant = 3
1743
+
1744
+ if m_curving:
1745
+ if m_curving == "^":
1746
+ quadrant[left_quadrant] += 1
1747
+ elif m_curving == "_":
1748
+ quadrant[right_quadrant] += 1
1749
+ else:
1750
+ # This morphism counts in both upper and lower
1751
+ # quadrants.
1752
+ quadrant[left_quadrant] += 1
1753
+ quadrant[right_quadrant] += 1
1754
+
1755
+ # Pick the freest quadrant to curve our morphism into.
1756
+ freest_quadrant = 0
1757
+ for i in range(4):
1758
+ if quadrant[i] < quadrant[freest_quadrant]:
1759
+ freest_quadrant = i
1760
+
1761
+ # Now set up proper looping.
1762
+ (looping_start, looping_end) = [("r", "u"), ("u", "l"), ("l", "d"),
1763
+ ("d", "r")][freest_quadrant]
1764
+
1765
+ return (curving, label_pos, looping_start, looping_end)
1766
+
1767
+ @staticmethod
1768
+ def _process_horizontal_morphism(i, j, target_j, grid, morphisms_str_info,
1769
+ object_coords):
1770
+ """
1771
+ Produces the information required for constructing the string
1772
+ representation of a horizontal morphism. This function is
1773
+ invoked from ``_process_morphism``.
1774
+
1775
+ See Also
1776
+ ========
1777
+
1778
+ _process_morphism
1779
+ """
1780
+ # The arrow is horizontal. Check if it goes from left to
1781
+ # right (``backwards == False``) or from right to left
1782
+ # (``backwards == True``).
1783
+ backwards = False
1784
+ start = j
1785
+ end = target_j
1786
+ if end < start:
1787
+ (start, end) = (end, start)
1788
+ backwards = True
1789
+
1790
+ # Let's see which objects are there between ``start`` and
1791
+ # ``end``, and then count how many morphisms stick out
1792
+ # upwards, and how many stick out downwards.
1793
+ #
1794
+ # For example, consider the situation:
1795
+ #
1796
+ # B1 C1
1797
+ # | |
1798
+ # A--B--C--D
1799
+ # |
1800
+ # B2
1801
+ #
1802
+ # Between the objects `A` and `D` there are two objects:
1803
+ # `B` and `C`. Further, there are two morphisms which
1804
+ # stick out upward (the ones between `B1` and `B` and
1805
+ # between `C` and `C1`) and one morphism which sticks out
1806
+ # downward (the one between `B and `B2`).
1807
+ #
1808
+ # We need this information to decide how to curve the
1809
+ # arrow between `A` and `D`. First of all, since there
1810
+ # are two objects between `A` and `D``, we must curve the
1811
+ # arrow. Then, we will have it curve downward, because
1812
+ # there is more space (less morphisms stick out downward
1813
+ # than upward).
1814
+ up = []
1815
+ down = []
1816
+ straight_horizontal = []
1817
+ for k in range(start + 1, end):
1818
+ obj = grid[i, k]
1819
+ if not obj:
1820
+ continue
1821
+
1822
+ for m in morphisms_str_info:
1823
+ if m.domain == obj:
1824
+ (end_i, end_j) = object_coords[m.codomain]
1825
+ elif m.codomain == obj:
1826
+ (end_i, end_j) = object_coords[m.domain]
1827
+ else:
1828
+ continue
1829
+
1830
+ if end_i > i:
1831
+ down.append(m)
1832
+ elif end_i < i:
1833
+ up.append(m)
1834
+ elif not morphisms_str_info[m].curving:
1835
+ # This is a straight horizontal morphism,
1836
+ # because it has no curving.
1837
+ straight_horizontal.append(m)
1838
+
1839
+ if len(up) < len(down):
1840
+ # More morphisms stick out downward than upward, let's
1841
+ # curve the morphism up.
1842
+ if backwards:
1843
+ curving = "_"
1844
+ label_pos = "_"
1845
+ else:
1846
+ curving = "^"
1847
+ label_pos = "^"
1848
+
1849
+ # Assure that the straight horizontal morphisms have
1850
+ # their labels on the lower side of the arrow.
1851
+ for m in straight_horizontal:
1852
+ (i1, j1) = object_coords[m.domain]
1853
+ (i2, j2) = object_coords[m.codomain]
1854
+
1855
+ m_str_info = morphisms_str_info[m]
1856
+ if j1 < j2:
1857
+ m_str_info.label_position = "_"
1858
+ else:
1859
+ m_str_info.label_position = "^"
1860
+
1861
+ # Don't allow any further modifications of the
1862
+ # position of this label.
1863
+ m_str_info.forced_label_position = True
1864
+ else:
1865
+ # More morphisms stick out downward than upward, let's
1866
+ # curve the morphism up.
1867
+ if backwards:
1868
+ curving = "^"
1869
+ label_pos = "^"
1870
+ else:
1871
+ curving = "_"
1872
+ label_pos = "_"
1873
+
1874
+ # Assure that the straight horizontal morphisms have
1875
+ # their labels on the upper side of the arrow.
1876
+ for m in straight_horizontal:
1877
+ (i1, j1) = object_coords[m.domain]
1878
+ (i2, j2) = object_coords[m.codomain]
1879
+
1880
+ m_str_info = morphisms_str_info[m]
1881
+ if j1 < j2:
1882
+ m_str_info.label_position = "^"
1883
+ else:
1884
+ m_str_info.label_position = "_"
1885
+
1886
+ # Don't allow any further modifications of the
1887
+ # position of this label.
1888
+ m_str_info.forced_label_position = True
1889
+
1890
+ return (curving, label_pos)
1891
+
1892
+ @staticmethod
1893
+ def _process_vertical_morphism(i, j, target_i, grid, morphisms_str_info,
1894
+ object_coords):
1895
+ """
1896
+ Produces the information required for constructing the string
1897
+ representation of a vertical morphism. This function is
1898
+ invoked from ``_process_morphism``.
1899
+
1900
+ See Also
1901
+ ========
1902
+
1903
+ _process_morphism
1904
+ """
1905
+ # This arrow is vertical. Check if it goes from top to
1906
+ # bottom (``backwards == False``) or from bottom to top
1907
+ # (``backwards == True``).
1908
+ backwards = False
1909
+ start = i
1910
+ end = target_i
1911
+ if end < start:
1912
+ (start, end) = (end, start)
1913
+ backwards = True
1914
+
1915
+ # Let's see which objects are there between ``start`` and
1916
+ # ``end``, and then count how many morphisms stick out to
1917
+ # the left, and how many stick out to the right.
1918
+ #
1919
+ # See the corresponding comment in the previous branch of
1920
+ # this if-statement for more details.
1921
+ left = []
1922
+ right = []
1923
+ straight_vertical = []
1924
+ for k in range(start + 1, end):
1925
+ obj = grid[k, j]
1926
+ if not obj:
1927
+ continue
1928
+
1929
+ for m in morphisms_str_info:
1930
+ if m.domain == obj:
1931
+ (end_i, end_j) = object_coords[m.codomain]
1932
+ elif m.codomain == obj:
1933
+ (end_i, end_j) = object_coords[m.domain]
1934
+ else:
1935
+ continue
1936
+
1937
+ if end_j > j:
1938
+ right.append(m)
1939
+ elif end_j < j:
1940
+ left.append(m)
1941
+ elif not morphisms_str_info[m].curving:
1942
+ # This is a straight vertical morphism,
1943
+ # because it has no curving.
1944
+ straight_vertical.append(m)
1945
+
1946
+ if len(left) < len(right):
1947
+ # More morphisms stick out to the left than to the
1948
+ # right, let's curve the morphism to the right.
1949
+ if backwards:
1950
+ curving = "^"
1951
+ label_pos = "^"
1952
+ else:
1953
+ curving = "_"
1954
+ label_pos = "_"
1955
+
1956
+ # Assure that the straight vertical morphisms have
1957
+ # their labels on the left side of the arrow.
1958
+ for m in straight_vertical:
1959
+ (i1, j1) = object_coords[m.domain]
1960
+ (i2, j2) = object_coords[m.codomain]
1961
+
1962
+ m_str_info = morphisms_str_info[m]
1963
+ if i1 < i2:
1964
+ m_str_info.label_position = "^"
1965
+ else:
1966
+ m_str_info.label_position = "_"
1967
+
1968
+ # Don't allow any further modifications of the
1969
+ # position of this label.
1970
+ m_str_info.forced_label_position = True
1971
+ else:
1972
+ # More morphisms stick out to the right than to the
1973
+ # left, let's curve the morphism to the left.
1974
+ if backwards:
1975
+ curving = "_"
1976
+ label_pos = "_"
1977
+ else:
1978
+ curving = "^"
1979
+ label_pos = "^"
1980
+
1981
+ # Assure that the straight vertical morphisms have
1982
+ # their labels on the right side of the arrow.
1983
+ for m in straight_vertical:
1984
+ (i1, j1) = object_coords[m.domain]
1985
+ (i2, j2) = object_coords[m.codomain]
1986
+
1987
+ m_str_info = morphisms_str_info[m]
1988
+ if i1 < i2:
1989
+ m_str_info.label_position = "_"
1990
+ else:
1991
+ m_str_info.label_position = "^"
1992
+
1993
+ # Don't allow any further modifications of the
1994
+ # position of this label.
1995
+ m_str_info.forced_label_position = True
1996
+
1997
+ return (curving, label_pos)
1998
+
1999
+ def _process_morphism(self, diagram, grid, morphism, object_coords,
2000
+ morphisms, morphisms_str_info):
2001
+ """
2002
+ Given the required information, produces the string
2003
+ representation of ``morphism``.
2004
+ """
2005
+ def repeat_string_cond(times, str_gt, str_lt):
2006
+ """
2007
+ If ``times > 0``, repeats ``str_gt`` ``times`` times.
2008
+ Otherwise, repeats ``str_lt`` ``-times`` times.
2009
+ """
2010
+ if times > 0:
2011
+ return str_gt * times
2012
+ else:
2013
+ return str_lt * (-times)
2014
+
2015
+ def count_morphisms_undirected(A, B):
2016
+ """
2017
+ Counts how many processed morphisms there are between the
2018
+ two supplied objects.
2019
+ """
2020
+ return len([m for m in morphisms_str_info
2021
+ if {m.domain, m.codomain} == {A, B}])
2022
+
2023
+ def count_morphisms_filtered(dom, cod, curving):
2024
+ """
2025
+ Counts the processed morphisms which go out of ``dom``
2026
+ into ``cod`` with curving ``curving``.
2027
+ """
2028
+ return len([m for m, m_str_info in morphisms_str_info.items()
2029
+ if (m.domain, m.codomain) == (dom, cod) and
2030
+ (m_str_info.curving == curving)])
2031
+
2032
+ (i, j) = object_coords[morphism.domain]
2033
+ (target_i, target_j) = object_coords[morphism.codomain]
2034
+
2035
+ # We now need to determine the direction of
2036
+ # the arrow.
2037
+ delta_i = target_i - i
2038
+ delta_j = target_j - j
2039
+ vertical_direction = repeat_string_cond(delta_i,
2040
+ "d", "u")
2041
+ horizontal_direction = repeat_string_cond(delta_j,
2042
+ "r", "l")
2043
+
2044
+ curving = ""
2045
+ label_pos = "^"
2046
+ looping_start = ""
2047
+ looping_end = ""
2048
+
2049
+ if (delta_i == 0) and (delta_j == 0):
2050
+ # This is a loop morphism.
2051
+ (curving, label_pos, looping_start,
2052
+ looping_end) = XypicDiagramDrawer._process_loop_morphism(
2053
+ i, j, grid, morphisms_str_info, object_coords)
2054
+ elif (delta_i == 0) and (abs(j - target_j) > 1):
2055
+ # This is a horizontal morphism.
2056
+ (curving, label_pos) = XypicDiagramDrawer._process_horizontal_morphism(
2057
+ i, j, target_j, grid, morphisms_str_info, object_coords)
2058
+ elif (delta_j == 0) and (abs(i - target_i) > 1):
2059
+ # This is a vertical morphism.
2060
+ (curving, label_pos) = XypicDiagramDrawer._process_vertical_morphism(
2061
+ i, j, target_i, grid, morphisms_str_info, object_coords)
2062
+
2063
+ count = count_morphisms_undirected(morphism.domain, morphism.codomain)
2064
+ curving_amount = ""
2065
+ if curving:
2066
+ # This morphisms should be curved anyway.
2067
+ curving_amount = self.default_curving_amount + count * \
2068
+ self.default_curving_step
2069
+ elif count:
2070
+ # There are no objects between the domain and codomain of
2071
+ # the current morphism, but this is not there already are
2072
+ # some morphisms with the same domain and codomain, so we
2073
+ # have to curve this one.
2074
+ curving = "^"
2075
+ filtered_morphisms = count_morphisms_filtered(
2076
+ morphism.domain, morphism.codomain, curving)
2077
+ curving_amount = self.default_curving_amount + \
2078
+ filtered_morphisms * \
2079
+ self.default_curving_step
2080
+
2081
+ # Let's now get the name of the morphism.
2082
+ morphism_name = ""
2083
+ if isinstance(morphism, IdentityMorphism):
2084
+ morphism_name = "id_{%s}" + latex(grid[i, j])
2085
+ elif isinstance(morphism, CompositeMorphism):
2086
+ component_names = [latex(Symbol(component.name)) for
2087
+ component in morphism.components]
2088
+ component_names.reverse()
2089
+ morphism_name = "\\circ ".join(component_names)
2090
+ elif isinstance(morphism, NamedMorphism):
2091
+ morphism_name = latex(Symbol(morphism.name))
2092
+
2093
+ return ArrowStringDescription(
2094
+ self.unit, curving, curving_amount, looping_start,
2095
+ looping_end, horizontal_direction, vertical_direction,
2096
+ label_pos, morphism_name)
2097
+
2098
+ @staticmethod
2099
+ def _check_free_space_horizontal(dom_i, dom_j, cod_j, grid):
2100
+ """
2101
+ For a horizontal morphism, checks whether there is free space
2102
+ (i.e., space not occupied by any objects) above the morphism
2103
+ or below it.
2104
+ """
2105
+ if dom_j < cod_j:
2106
+ (start, end) = (dom_j, cod_j)
2107
+ backwards = False
2108
+ else:
2109
+ (start, end) = (cod_j, dom_j)
2110
+ backwards = True
2111
+
2112
+ # Check for free space above.
2113
+ if dom_i == 0:
2114
+ free_up = True
2115
+ else:
2116
+ free_up = all(grid[dom_i - 1, j] for j in
2117
+ range(start, end + 1))
2118
+
2119
+ # Check for free space below.
2120
+ if dom_i == grid.height - 1:
2121
+ free_down = True
2122
+ else:
2123
+ free_down = not any(grid[dom_i + 1, j] for j in
2124
+ range(start, end + 1))
2125
+
2126
+ return (free_up, free_down, backwards)
2127
+
2128
+ @staticmethod
2129
+ def _check_free_space_vertical(dom_i, cod_i, dom_j, grid):
2130
+ """
2131
+ For a vertical morphism, checks whether there is free space
2132
+ (i.e., space not occupied by any objects) to the left of the
2133
+ morphism or to the right of it.
2134
+ """
2135
+ if dom_i < cod_i:
2136
+ (start, end) = (dom_i, cod_i)
2137
+ backwards = False
2138
+ else:
2139
+ (start, end) = (cod_i, dom_i)
2140
+ backwards = True
2141
+
2142
+ # Check if there's space to the left.
2143
+ if dom_j == 0:
2144
+ free_left = True
2145
+ else:
2146
+ free_left = not any(grid[i, dom_j - 1] for i in
2147
+ range(start, end + 1))
2148
+
2149
+ if dom_j == grid.width - 1:
2150
+ free_right = True
2151
+ else:
2152
+ free_right = not any(grid[i, dom_j + 1] for i in
2153
+ range(start, end + 1))
2154
+
2155
+ return (free_left, free_right, backwards)
2156
+
2157
+ @staticmethod
2158
+ def _check_free_space_diagonal(dom_i, cod_i, dom_j, cod_j, grid):
2159
+ """
2160
+ For a diagonal morphism, checks whether there is free space
2161
+ (i.e., space not occupied by any objects) above the morphism
2162
+ or below it.
2163
+ """
2164
+ def abs_xrange(start, end):
2165
+ if start < end:
2166
+ return range(start, end + 1)
2167
+ else:
2168
+ return range(end, start + 1)
2169
+
2170
+ if dom_i < cod_i and dom_j < cod_j:
2171
+ # This morphism goes from top-left to
2172
+ # bottom-right.
2173
+ (start_i, start_j) = (dom_i, dom_j)
2174
+ (end_i, end_j) = (cod_i, cod_j)
2175
+ backwards = False
2176
+ elif dom_i > cod_i and dom_j > cod_j:
2177
+ # This morphism goes from bottom-right to
2178
+ # top-left.
2179
+ (start_i, start_j) = (cod_i, cod_j)
2180
+ (end_i, end_j) = (dom_i, dom_j)
2181
+ backwards = True
2182
+ if dom_i < cod_i and dom_j > cod_j:
2183
+ # This morphism goes from top-right to
2184
+ # bottom-left.
2185
+ (start_i, start_j) = (dom_i, dom_j)
2186
+ (end_i, end_j) = (cod_i, cod_j)
2187
+ backwards = True
2188
+ elif dom_i > cod_i and dom_j < cod_j:
2189
+ # This morphism goes from bottom-left to
2190
+ # top-right.
2191
+ (start_i, start_j) = (cod_i, cod_j)
2192
+ (end_i, end_j) = (dom_i, dom_j)
2193
+ backwards = False
2194
+
2195
+ # This is an attempt at a fast and furious strategy to
2196
+ # decide where there is free space on the two sides of
2197
+ # a diagonal morphism. For a diagonal morphism
2198
+ # starting at ``(start_i, start_j)`` and ending at
2199
+ # ``(end_i, end_j)`` the rectangle defined by these
2200
+ # two points is considered. The slope of the diagonal
2201
+ # ``alpha`` is then computed. Then, for every cell
2202
+ # ``(i, j)`` within the rectangle, the slope
2203
+ # ``alpha1`` of the line through ``(start_i,
2204
+ # start_j)`` and ``(i, j)`` is considered. If
2205
+ # ``alpha1`` is between 0 and ``alpha``, the point
2206
+ # ``(i, j)`` is above the diagonal, if ``alpha1`` is
2207
+ # between ``alpha`` and infinity, the point is below
2208
+ # the diagonal. Also note that, with some beforehand
2209
+ # precautions, this trick works for both the main and
2210
+ # the secondary diagonals of the rectangle.
2211
+
2212
+ # I have considered the possibility to only follow the
2213
+ # shorter diagonals immediately above and below the
2214
+ # main (or secondary) diagonal. This, however,
2215
+ # wouldn't have resulted in much performance gain or
2216
+ # better detection of outer edges, because of
2217
+ # relatively small sizes of diagram grids, while the
2218
+ # code would have become harder to understand.
2219
+
2220
+ alpha = float(end_i - start_i)/(end_j - start_j)
2221
+ free_up = True
2222
+ free_down = True
2223
+ for i in abs_xrange(start_i, end_i):
2224
+ if not free_up and not free_down:
2225
+ break
2226
+
2227
+ for j in abs_xrange(start_j, end_j):
2228
+ if not free_up and not free_down:
2229
+ break
2230
+
2231
+ if (i, j) == (start_i, start_j):
2232
+ continue
2233
+
2234
+ if j == start_j:
2235
+ alpha1 = "inf"
2236
+ else:
2237
+ alpha1 = float(i - start_i)/(j - start_j)
2238
+
2239
+ if grid[i, j]:
2240
+ if (alpha1 == "inf") or (abs(alpha1) > abs(alpha)):
2241
+ free_down = False
2242
+ elif abs(alpha1) < abs(alpha):
2243
+ free_up = False
2244
+
2245
+ return (free_up, free_down, backwards)
2246
+
2247
+ def _push_labels_out(self, morphisms_str_info, grid, object_coords):
2248
+ """
2249
+ For all straight morphisms which form the visual boundary of
2250
+ the laid out diagram, puts their labels on their outer sides.
2251
+ """
2252
+ def set_label_position(free1, free2, pos1, pos2, backwards, m_str_info):
2253
+ """
2254
+ Given the information about room available to one side and
2255
+ to the other side of a morphism (``free1`` and ``free2``),
2256
+ sets the position of the morphism label in such a way that
2257
+ it is on the freer side. This latter operations involves
2258
+ choice between ``pos1`` and ``pos2``, taking ``backwards``
2259
+ in consideration.
2260
+
2261
+ Thus this function will do nothing if either both ``free1
2262
+ == True`` and ``free2 == True`` or both ``free1 == False``
2263
+ and ``free2 == False``. In either case, choosing one side
2264
+ over the other presents no advantage.
2265
+ """
2266
+ if backwards:
2267
+ (pos1, pos2) = (pos2, pos1)
2268
+
2269
+ if free1 and not free2:
2270
+ m_str_info.label_position = pos1
2271
+ elif free2 and not free1:
2272
+ m_str_info.label_position = pos2
2273
+
2274
+ for m, m_str_info in morphisms_str_info.items():
2275
+ if m_str_info.curving or m_str_info.forced_label_position:
2276
+ # This is either a curved morphism, and curved
2277
+ # morphisms have other magic, or the position of this
2278
+ # label has already been fixed.
2279
+ continue
2280
+
2281
+ if m.domain == m.codomain:
2282
+ # This is a loop morphism, their labels, again have a
2283
+ # different magic.
2284
+ continue
2285
+
2286
+ (dom_i, dom_j) = object_coords[m.domain]
2287
+ (cod_i, cod_j) = object_coords[m.codomain]
2288
+
2289
+ if dom_i == cod_i:
2290
+ # Horizontal morphism.
2291
+ (free_up, free_down,
2292
+ backwards) = XypicDiagramDrawer._check_free_space_horizontal(
2293
+ dom_i, dom_j, cod_j, grid)
2294
+
2295
+ set_label_position(free_up, free_down, "^", "_",
2296
+ backwards, m_str_info)
2297
+ elif dom_j == cod_j:
2298
+ # Vertical morphism.
2299
+ (free_left, free_right,
2300
+ backwards) = XypicDiagramDrawer._check_free_space_vertical(
2301
+ dom_i, cod_i, dom_j, grid)
2302
+
2303
+ set_label_position(free_left, free_right, "_", "^",
2304
+ backwards, m_str_info)
2305
+ else:
2306
+ # A diagonal morphism.
2307
+ (free_up, free_down,
2308
+ backwards) = XypicDiagramDrawer._check_free_space_diagonal(
2309
+ dom_i, cod_i, dom_j, cod_j, grid)
2310
+
2311
+ set_label_position(free_up, free_down, "^", "_",
2312
+ backwards, m_str_info)
2313
+
2314
+ @staticmethod
2315
+ def _morphism_sort_key(morphism, object_coords):
2316
+ """
2317
+ Provides a morphism sorting key such that horizontal or
2318
+ vertical morphisms between neighbouring objects come
2319
+ first, then horizontal or vertical morphisms between more
2320
+ far away objects, and finally, all other morphisms.
2321
+ """
2322
+ (i, j) = object_coords[morphism.domain]
2323
+ (target_i, target_j) = object_coords[morphism.codomain]
2324
+
2325
+ if morphism.domain == morphism.codomain:
2326
+ # Loop morphisms should get after diagonal morphisms
2327
+ # so that the proper direction in which to curve the
2328
+ # loop can be determined.
2329
+ return (3, 0, default_sort_key(morphism))
2330
+
2331
+ if target_i == i:
2332
+ return (1, abs(target_j - j), default_sort_key(morphism))
2333
+
2334
+ if target_j == j:
2335
+ return (1, abs(target_i - i), default_sort_key(morphism))
2336
+
2337
+ # Diagonal morphism.
2338
+ return (2, 0, default_sort_key(morphism))
2339
+
2340
+ @staticmethod
2341
+ def _build_xypic_string(diagram, grid, morphisms,
2342
+ morphisms_str_info, diagram_format):
2343
+ """
2344
+ Given a collection of :class:`ArrowStringDescription`
2345
+ describing the morphisms of a diagram and the object layout
2346
+ information of a diagram, produces the final Xy-pic picture.
2347
+ """
2348
+ # Build the mapping between objects and morphisms which have
2349
+ # them as domains.
2350
+ object_morphisms = {}
2351
+ for obj in diagram.objects:
2352
+ object_morphisms[obj] = []
2353
+ for morphism in morphisms:
2354
+ object_morphisms[morphism.domain].append(morphism)
2355
+
2356
+ result = "\\xymatrix%s{\n" % diagram_format
2357
+
2358
+ for i in range(grid.height):
2359
+ for j in range(grid.width):
2360
+ obj = grid[i, j]
2361
+ if obj:
2362
+ result += latex(obj) + " "
2363
+
2364
+ morphisms_to_draw = object_morphisms[obj]
2365
+ for morphism in morphisms_to_draw:
2366
+ result += str(morphisms_str_info[morphism]) + " "
2367
+
2368
+ # Don't put the & after the last column.
2369
+ if j < grid.width - 1:
2370
+ result += "& "
2371
+
2372
+ # Don't put the line break after the last row.
2373
+ if i < grid.height - 1:
2374
+ result += "\\\\"
2375
+ result += "\n"
2376
+
2377
+ result += "}\n"
2378
+
2379
+ return result
2380
+
2381
+ def draw(self, diagram, grid, masked=None, diagram_format=""):
2382
+ r"""
2383
+ Returns the Xy-pic representation of ``diagram`` laid out in
2384
+ ``grid``.
2385
+
2386
+ Consider the following simple triangle diagram.
2387
+
2388
+ >>> from sympy.categories import Object, NamedMorphism, Diagram
2389
+ >>> from sympy.categories import DiagramGrid, XypicDiagramDrawer
2390
+ >>> A = Object("A")
2391
+ >>> B = Object("B")
2392
+ >>> C = Object("C")
2393
+ >>> f = NamedMorphism(A, B, "f")
2394
+ >>> g = NamedMorphism(B, C, "g")
2395
+ >>> diagram = Diagram([f, g], {g * f: "unique"})
2396
+
2397
+ To draw this diagram, its objects need to be laid out with a
2398
+ :class:`DiagramGrid`::
2399
+
2400
+ >>> grid = DiagramGrid(diagram)
2401
+
2402
+ Finally, the drawing:
2403
+
2404
+ >>> drawer = XypicDiagramDrawer()
2405
+ >>> print(drawer.draw(diagram, grid))
2406
+ \xymatrix{
2407
+ A \ar[d]_{g\circ f} \ar[r]^{f} & B \ar[ld]^{g} \\
2408
+ C &
2409
+ }
2410
+
2411
+ The argument ``masked`` can be used to skip morphisms in the
2412
+ presentation of the diagram:
2413
+
2414
+ >>> print(drawer.draw(diagram, grid, masked=[g * f]))
2415
+ \xymatrix{
2416
+ A \ar[r]^{f} & B \ar[ld]^{g} \\
2417
+ C &
2418
+ }
2419
+
2420
+ Finally, the ``diagram_format`` argument can be used to
2421
+ specify the format string of the diagram. For example, to
2422
+ increase the spacing by 1 cm, proceeding as follows:
2423
+
2424
+ >>> print(drawer.draw(diagram, grid, diagram_format="@+1cm"))
2425
+ \xymatrix@+1cm{
2426
+ A \ar[d]_{g\circ f} \ar[r]^{f} & B \ar[ld]^{g} \\
2427
+ C &
2428
+ }
2429
+
2430
+ """
2431
+ # This method works in several steps. It starts by removing
2432
+ # the masked morphisms, if necessary, and then maps objects to
2433
+ # their positions in the grid (coordinate tuples). Remember
2434
+ # that objects are unique in ``Diagram`` and in the layout
2435
+ # produced by ``DiagramGrid``, so every object is mapped to a
2436
+ # single coordinate pair.
2437
+ #
2438
+ # The next step is the central step and is concerned with
2439
+ # analysing the morphisms of the diagram and deciding how to
2440
+ # draw them. For example, how to curve the arrows is decided
2441
+ # at this step. The bulk of the analysis is implemented in
2442
+ # ``_process_morphism``, to the result of which the
2443
+ # appropriate formatters are applied.
2444
+ #
2445
+ # The result of the previous step is a list of
2446
+ # ``ArrowStringDescription``. After the analysis and
2447
+ # application of formatters, some extra logic tries to assure
2448
+ # better positioning of morphism labels (for example, an
2449
+ # attempt is made to avoid the situations when arrows cross
2450
+ # labels). This functionality constitutes the next step and
2451
+ # is implemented in ``_push_labels_out``. Note that label
2452
+ # positions which have been set via a formatter are not
2453
+ # affected in this step.
2454
+ #
2455
+ # Finally, at the closing step, the array of
2456
+ # ``ArrowStringDescription`` and the layout information
2457
+ # incorporated in ``DiagramGrid`` are combined to produce the
2458
+ # resulting Xy-pic picture. This part of code lies in
2459
+ # ``_build_xypic_string``.
2460
+
2461
+ if not masked:
2462
+ morphisms_props = grid.morphisms
2463
+ else:
2464
+ morphisms_props = {}
2465
+ for m, props in grid.morphisms.items():
2466
+ if m in masked:
2467
+ continue
2468
+ morphisms_props[m] = props
2469
+
2470
+ # Build the mapping between objects and their position in the
2471
+ # grid.
2472
+ object_coords = {}
2473
+ for i in range(grid.height):
2474
+ for j in range(grid.width):
2475
+ if grid[i, j]:
2476
+ object_coords[grid[i, j]] = (i, j)
2477
+
2478
+ morphisms = sorted(morphisms_props,
2479
+ key=lambda m: XypicDiagramDrawer._morphism_sort_key(
2480
+ m, object_coords))
2481
+
2482
+ # Build the tuples defining the string representations of
2483
+ # morphisms.
2484
+ morphisms_str_info = {}
2485
+ for morphism in morphisms:
2486
+ string_description = self._process_morphism(
2487
+ diagram, grid, morphism, object_coords, morphisms,
2488
+ morphisms_str_info)
2489
+
2490
+ if self.default_arrow_formatter:
2491
+ self.default_arrow_formatter(string_description)
2492
+
2493
+ for prop in morphisms_props[morphism]:
2494
+ # prop is a Symbol. TODO: Find out why.
2495
+ if prop.name in self.arrow_formatters:
2496
+ formatter = self.arrow_formatters[prop.name]
2497
+ formatter(string_description)
2498
+
2499
+ morphisms_str_info[morphism] = string_description
2500
+
2501
+ # Reposition the labels a bit.
2502
+ self._push_labels_out(morphisms_str_info, grid, object_coords)
2503
+
2504
+ return XypicDiagramDrawer._build_xypic_string(
2505
+ diagram, grid, morphisms, morphisms_str_info, diagram_format)
2506
+
2507
+
2508
+ def xypic_draw_diagram(diagram, masked=None, diagram_format="",
2509
+ groups=None, **hints):
2510
+ r"""
2511
+ Provides a shortcut combining :class:`DiagramGrid` and
2512
+ :class:`XypicDiagramDrawer`. Returns an Xy-pic presentation of
2513
+ ``diagram``. The argument ``masked`` is a list of morphisms which
2514
+ will be not be drawn. The argument ``diagram_format`` is the
2515
+ format string inserted after "\xymatrix". ``groups`` should be a
2516
+ set of logical groups. The ``hints`` will be passed directly to
2517
+ the constructor of :class:`DiagramGrid`.
2518
+
2519
+ For more information about the arguments, see the docstrings of
2520
+ :class:`DiagramGrid` and ``XypicDiagramDrawer.draw``.
2521
+
2522
+ Examples
2523
+ ========
2524
+
2525
+ >>> from sympy.categories import Object, NamedMorphism, Diagram
2526
+ >>> from sympy.categories import xypic_draw_diagram
2527
+ >>> A = Object("A")
2528
+ >>> B = Object("B")
2529
+ >>> C = Object("C")
2530
+ >>> f = NamedMorphism(A, B, "f")
2531
+ >>> g = NamedMorphism(B, C, "g")
2532
+ >>> diagram = Diagram([f, g], {g * f: "unique"})
2533
+ >>> print(xypic_draw_diagram(diagram))
2534
+ \xymatrix{
2535
+ A \ar[d]_{g\circ f} \ar[r]^{f} & B \ar[ld]^{g} \\
2536
+ C &
2537
+ }
2538
+
2539
+ See Also
2540
+ ========
2541
+
2542
+ XypicDiagramDrawer, DiagramGrid
2543
+ """
2544
+ grid = DiagramGrid(diagram, groups, **hints)
2545
+ drawer = XypicDiagramDrawer()
2546
+ return drawer.draw(diagram, grid, masked, diagram_format)
2547
+
2548
+
2549
+ @doctest_depends_on(exe=('latex', 'dvipng'), modules=('pyglet',))
2550
+ def preview_diagram(diagram, masked=None, diagram_format="", groups=None,
2551
+ output='png', viewer=None, euler=True, **hints):
2552
+ """
2553
+ Combines the functionality of ``xypic_draw_diagram`` and
2554
+ ``sympy.printing.preview``. The arguments ``masked``,
2555
+ ``diagram_format``, ``groups``, and ``hints`` are passed to
2556
+ ``xypic_draw_diagram``, while ``output``, ``viewer, and ``euler``
2557
+ are passed to ``preview``.
2558
+
2559
+ Examples
2560
+ ========
2561
+
2562
+ >>> from sympy.categories import Object, NamedMorphism, Diagram
2563
+ >>> from sympy.categories import preview_diagram
2564
+ >>> A = Object("A")
2565
+ >>> B = Object("B")
2566
+ >>> C = Object("C")
2567
+ >>> f = NamedMorphism(A, B, "f")
2568
+ >>> g = NamedMorphism(B, C, "g")
2569
+ >>> d = Diagram([f, g], {g * f: "unique"})
2570
+ >>> preview_diagram(d)
2571
+
2572
+ See Also
2573
+ ========
2574
+
2575
+ XypicDiagramDrawer
2576
+ """
2577
+ from sympy.printing import preview
2578
+ latex_output = xypic_draw_diagram(diagram, masked, diagram_format,
2579
+ groups, **hints)
2580
+ preview(latex_output, output, viewer, euler, ("xypic",))
.venv/lib/python3.13/site-packages/sympy/categories/tests/__init__.py ADDED
File without changes
.venv/lib/python3.13/site-packages/sympy/categories/tests/test_baseclasses.py ADDED
@@ -0,0 +1,209 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sympy.categories import (Object, Morphism, IdentityMorphism,
2
+ NamedMorphism, CompositeMorphism,
3
+ Diagram, Category)
4
+ from sympy.categories.baseclasses import Class
5
+ from sympy.testing.pytest import raises
6
+ from sympy.core.containers import (Dict, Tuple)
7
+ from sympy.sets import EmptySet
8
+ from sympy.sets.sets import FiniteSet
9
+
10
+
11
+ def test_morphisms():
12
+ A = Object("A")
13
+ B = Object("B")
14
+ C = Object("C")
15
+ D = Object("D")
16
+
17
+ # Test the base morphism.
18
+ f = NamedMorphism(A, B, "f")
19
+ assert f.domain == A
20
+ assert f.codomain == B
21
+ assert f == NamedMorphism(A, B, "f")
22
+
23
+ # Test identities.
24
+ id_A = IdentityMorphism(A)
25
+ id_B = IdentityMorphism(B)
26
+ assert id_A.domain == A
27
+ assert id_A.codomain == A
28
+ assert id_A == IdentityMorphism(A)
29
+ assert id_A != id_B
30
+
31
+ # Test named morphisms.
32
+ g = NamedMorphism(B, C, "g")
33
+ assert g.name == "g"
34
+ assert g != f
35
+ assert g == NamedMorphism(B, C, "g")
36
+ assert g != NamedMorphism(B, C, "f")
37
+
38
+ # Test composite morphisms.
39
+ assert f == CompositeMorphism(f)
40
+
41
+ k = g.compose(f)
42
+ assert k.domain == A
43
+ assert k.codomain == C
44
+ assert k.components == Tuple(f, g)
45
+ assert g * f == k
46
+ assert CompositeMorphism(f, g) == k
47
+
48
+ assert CompositeMorphism(g * f) == g * f
49
+
50
+ # Test the associativity of composition.
51
+ h = NamedMorphism(C, D, "h")
52
+
53
+ p = h * g
54
+ u = h * g * f
55
+
56
+ assert h * k == u
57
+ assert p * f == u
58
+ assert CompositeMorphism(f, g, h) == u
59
+
60
+ # Test flattening.
61
+ u2 = u.flatten("u")
62
+ assert isinstance(u2, NamedMorphism)
63
+ assert u2.name == "u"
64
+ assert u2.domain == A
65
+ assert u2.codomain == D
66
+
67
+ # Test identities.
68
+ assert f * id_A == f
69
+ assert id_B * f == f
70
+ assert id_A * id_A == id_A
71
+ assert CompositeMorphism(id_A) == id_A
72
+
73
+ # Test bad compositions.
74
+ raises(ValueError, lambda: f * g)
75
+
76
+ raises(TypeError, lambda: f.compose(None))
77
+ raises(TypeError, lambda: id_A.compose(None))
78
+ raises(TypeError, lambda: f * None)
79
+ raises(TypeError, lambda: id_A * None)
80
+
81
+ raises(TypeError, lambda: CompositeMorphism(f, None, 1))
82
+
83
+ raises(ValueError, lambda: NamedMorphism(A, B, ""))
84
+ raises(NotImplementedError, lambda: Morphism(A, B))
85
+
86
+
87
+ def test_diagram():
88
+ A = Object("A")
89
+ B = Object("B")
90
+ C = Object("C")
91
+
92
+ f = NamedMorphism(A, B, "f")
93
+ g = NamedMorphism(B, C, "g")
94
+ id_A = IdentityMorphism(A)
95
+ id_B = IdentityMorphism(B)
96
+
97
+ empty = EmptySet
98
+
99
+ # Test the addition of identities.
100
+ d1 = Diagram([f])
101
+
102
+ assert d1.objects == FiniteSet(A, B)
103
+ assert d1.hom(A, B) == (FiniteSet(f), empty)
104
+ assert d1.hom(A, A) == (FiniteSet(id_A), empty)
105
+ assert d1.hom(B, B) == (FiniteSet(id_B), empty)
106
+
107
+ assert d1 == Diagram([id_A, f])
108
+ assert d1 == Diagram([f, f])
109
+
110
+ # Test the addition of composites.
111
+ d2 = Diagram([f, g])
112
+ homAC = d2.hom(A, C)[0]
113
+
114
+ assert d2.objects == FiniteSet(A, B, C)
115
+ assert g * f in d2.premises.keys()
116
+ assert homAC == FiniteSet(g * f)
117
+
118
+ # Test equality, inequality and hash.
119
+ d11 = Diagram([f])
120
+
121
+ assert d1 == d11
122
+ assert d1 != d2
123
+ assert hash(d1) == hash(d11)
124
+
125
+ d11 = Diagram({f: "unique"})
126
+ assert d1 != d11
127
+
128
+ # Make sure that (re-)adding composites (with new properties)
129
+ # works as expected.
130
+ d = Diagram([f, g], {g * f: "unique"})
131
+ assert d.conclusions == Dict({g * f: FiniteSet("unique")})
132
+
133
+ # Check the hom-sets when there are premises and conclusions.
134
+ assert d.hom(A, C) == (FiniteSet(g * f), FiniteSet(g * f))
135
+ d = Diagram([f, g], [g * f])
136
+ assert d.hom(A, C) == (FiniteSet(g * f), FiniteSet(g * f))
137
+
138
+ # Check how the properties of composite morphisms are computed.
139
+ d = Diagram({f: ["unique", "isomorphism"], g: "unique"})
140
+ assert d.premises[g * f] == FiniteSet("unique")
141
+
142
+ # Check that conclusion morphisms with new objects are not allowed.
143
+ d = Diagram([f], [g])
144
+ assert d.conclusions == Dict({})
145
+
146
+ # Test an empty diagram.
147
+ d = Diagram()
148
+ assert d.premises == Dict({})
149
+ assert d.conclusions == Dict({})
150
+ assert d.objects == empty
151
+
152
+ # Check a SymPy Dict object.
153
+ d = Diagram(Dict({f: FiniteSet("unique", "isomorphism"), g: "unique"}))
154
+ assert d.premises[g * f] == FiniteSet("unique")
155
+
156
+ # Check the addition of components of composite morphisms.
157
+ d = Diagram([g * f])
158
+ assert f in d.premises
159
+ assert g in d.premises
160
+
161
+ # Check subdiagrams.
162
+ d = Diagram([f, g], {g * f: "unique"})
163
+
164
+ d1 = Diagram([f])
165
+ assert d.is_subdiagram(d1)
166
+ assert not d1.is_subdiagram(d)
167
+
168
+ d = Diagram([NamedMorphism(B, A, "f'")])
169
+ assert not d.is_subdiagram(d1)
170
+ assert not d1.is_subdiagram(d)
171
+
172
+ d1 = Diagram([f, g], {g * f: ["unique", "something"]})
173
+ assert not d.is_subdiagram(d1)
174
+ assert not d1.is_subdiagram(d)
175
+
176
+ d = Diagram({f: "blooh"})
177
+ d1 = Diagram({f: "bleeh"})
178
+ assert not d.is_subdiagram(d1)
179
+ assert not d1.is_subdiagram(d)
180
+
181
+ d = Diagram([f, g], {f: "unique", g * f: "veryunique"})
182
+ d1 = d.subdiagram_from_objects(FiniteSet(A, B))
183
+ assert d1 == Diagram([f], {f: "unique"})
184
+ raises(ValueError, lambda: d.subdiagram_from_objects(FiniteSet(A,
185
+ Object("D"))))
186
+
187
+ raises(ValueError, lambda: Diagram({IdentityMorphism(A): "unique"}))
188
+
189
+
190
+ def test_category():
191
+ A = Object("A")
192
+ B = Object("B")
193
+ C = Object("C")
194
+
195
+ f = NamedMorphism(A, B, "f")
196
+ g = NamedMorphism(B, C, "g")
197
+
198
+ d1 = Diagram([f, g])
199
+ d2 = Diagram([f])
200
+
201
+ objects = d1.objects | d2.objects
202
+
203
+ K = Category("K", objects, commutative_diagrams=[d1, d2])
204
+
205
+ assert K.name == "K"
206
+ assert K.objects == Class(objects)
207
+ assert K.commutative_diagrams == FiniteSet(d1, d2)
208
+
209
+ raises(ValueError, lambda: Category(""))
.venv/lib/python3.13/site-packages/sympy/categories/tests/test_drawing.py ADDED
@@ -0,0 +1,919 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sympy.categories.diagram_drawing import _GrowableGrid, ArrowStringDescription
2
+ from sympy.categories import (DiagramGrid, Object, NamedMorphism,
3
+ Diagram, XypicDiagramDrawer, xypic_draw_diagram)
4
+ from sympy.sets.sets import FiniteSet
5
+
6
+
7
+ def test_GrowableGrid():
8
+ grid = _GrowableGrid(1, 2)
9
+
10
+ # Check dimensions.
11
+ assert grid.width == 1
12
+ assert grid.height == 2
13
+
14
+ # Check initialization of elements.
15
+ assert grid[0, 0] is None
16
+ assert grid[1, 0] is None
17
+
18
+ # Check assignment to elements.
19
+ grid[0, 0] = 1
20
+ grid[1, 0] = "two"
21
+
22
+ assert grid[0, 0] == 1
23
+ assert grid[1, 0] == "two"
24
+
25
+ # Check appending a row.
26
+ grid.append_row()
27
+
28
+ assert grid.width == 1
29
+ assert grid.height == 3
30
+
31
+ assert grid[0, 0] == 1
32
+ assert grid[1, 0] == "two"
33
+ assert grid[2, 0] is None
34
+
35
+ # Check appending a column.
36
+ grid.append_column()
37
+ assert grid.width == 2
38
+ assert grid.height == 3
39
+
40
+ assert grid[0, 0] == 1
41
+ assert grid[1, 0] == "two"
42
+ assert grid[2, 0] is None
43
+
44
+ assert grid[0, 1] is None
45
+ assert grid[1, 1] is None
46
+ assert grid[2, 1] is None
47
+
48
+ grid = _GrowableGrid(1, 2)
49
+ grid[0, 0] = 1
50
+ grid[1, 0] = "two"
51
+
52
+ # Check prepending a row.
53
+ grid.prepend_row()
54
+ assert grid.width == 1
55
+ assert grid.height == 3
56
+
57
+ assert grid[0, 0] is None
58
+ assert grid[1, 0] == 1
59
+ assert grid[2, 0] == "two"
60
+
61
+ # Check prepending a column.
62
+ grid.prepend_column()
63
+ assert grid.width == 2
64
+ assert grid.height == 3
65
+
66
+ assert grid[0, 0] is None
67
+ assert grid[1, 0] is None
68
+ assert grid[2, 0] is None
69
+
70
+ assert grid[0, 1] is None
71
+ assert grid[1, 1] == 1
72
+ assert grid[2, 1] == "two"
73
+
74
+
75
+ def test_DiagramGrid():
76
+ # Set up some objects and morphisms.
77
+ A = Object("A")
78
+ B = Object("B")
79
+ C = Object("C")
80
+ D = Object("D")
81
+ E = Object("E")
82
+
83
+ f = NamedMorphism(A, B, "f")
84
+ g = NamedMorphism(B, C, "g")
85
+ h = NamedMorphism(D, A, "h")
86
+ k = NamedMorphism(D, B, "k")
87
+
88
+ # A one-morphism diagram.
89
+ d = Diagram([f])
90
+ grid = DiagramGrid(d)
91
+
92
+ assert grid.width == 2
93
+ assert grid.height == 1
94
+ assert grid[0, 0] == A
95
+ assert grid[0, 1] == B
96
+ assert grid.morphisms == {f: FiniteSet()}
97
+
98
+ # A triangle.
99
+ d = Diagram([f, g], {g * f: "unique"})
100
+ grid = DiagramGrid(d)
101
+
102
+ assert grid.width == 2
103
+ assert grid.height == 2
104
+ assert grid[0, 0] == A
105
+ assert grid[0, 1] == B
106
+ assert grid[1, 0] == C
107
+ assert grid[1, 1] is None
108
+ assert grid.morphisms == {f: FiniteSet(), g: FiniteSet(),
109
+ g * f: FiniteSet("unique")}
110
+
111
+ # A triangle with a "loop" morphism.
112
+ l_A = NamedMorphism(A, A, "l_A")
113
+ d = Diagram([f, g, l_A])
114
+ grid = DiagramGrid(d)
115
+
116
+ assert grid.width == 2
117
+ assert grid.height == 2
118
+ assert grid[0, 0] == A
119
+ assert grid[0, 1] == B
120
+ assert grid[1, 0] is None
121
+ assert grid[1, 1] == C
122
+ assert grid.morphisms == {f: FiniteSet(), g: FiniteSet(), l_A: FiniteSet()}
123
+
124
+ # A simple diagram.
125
+ d = Diagram([f, g, h, k])
126
+ grid = DiagramGrid(d)
127
+
128
+ assert grid.width == 3
129
+ assert grid.height == 2
130
+ assert grid[0, 0] == A
131
+ assert grid[0, 1] == B
132
+ assert grid[0, 2] == D
133
+ assert grid[1, 0] is None
134
+ assert grid[1, 1] == C
135
+ assert grid[1, 2] is None
136
+ assert grid.morphisms == {f: FiniteSet(), g: FiniteSet(), h: FiniteSet(),
137
+ k: FiniteSet()}
138
+
139
+ assert str(grid) == '[[Object("A"), Object("B"), Object("D")], ' \
140
+ '[None, Object("C"), None]]'
141
+
142
+ # A chain of morphisms.
143
+ f = NamedMorphism(A, B, "f")
144
+ g = NamedMorphism(B, C, "g")
145
+ h = NamedMorphism(C, D, "h")
146
+ k = NamedMorphism(D, E, "k")
147
+ d = Diagram([f, g, h, k])
148
+ grid = DiagramGrid(d)
149
+
150
+ assert grid.width == 3
151
+ assert grid.height == 3
152
+ assert grid[0, 0] == A
153
+ assert grid[0, 1] == B
154
+ assert grid[0, 2] is None
155
+ assert grid[1, 0] is None
156
+ assert grid[1, 1] == C
157
+ assert grid[1, 2] == D
158
+ assert grid[2, 0] is None
159
+ assert grid[2, 1] is None
160
+ assert grid[2, 2] == E
161
+ assert grid.morphisms == {f: FiniteSet(), g: FiniteSet(), h: FiniteSet(),
162
+ k: FiniteSet()}
163
+
164
+ # A square.
165
+ f = NamedMorphism(A, B, "f")
166
+ g = NamedMorphism(B, D, "g")
167
+ h = NamedMorphism(A, C, "h")
168
+ k = NamedMorphism(C, D, "k")
169
+ d = Diagram([f, g, h, k])
170
+ grid = DiagramGrid(d)
171
+
172
+ assert grid.width == 2
173
+ assert grid.height == 2
174
+ assert grid[0, 0] == A
175
+ assert grid[0, 1] == B
176
+ assert grid[1, 0] == C
177
+ assert grid[1, 1] == D
178
+ assert grid.morphisms == {f: FiniteSet(), g: FiniteSet(), h: FiniteSet(),
179
+ k: FiniteSet()}
180
+
181
+ # A strange diagram which resulted from a typo when creating a
182
+ # test for five lemma, but which allowed to stop one extra problem
183
+ # in the algorithm.
184
+ A = Object("A")
185
+ B = Object("B")
186
+ C = Object("C")
187
+ D = Object("D")
188
+ E = Object("E")
189
+ A_ = Object("A'")
190
+ B_ = Object("B'")
191
+ C_ = Object("C'")
192
+ D_ = Object("D'")
193
+ E_ = Object("E'")
194
+
195
+ f = NamedMorphism(A, B, "f")
196
+ g = NamedMorphism(B, C, "g")
197
+ h = NamedMorphism(C, D, "h")
198
+ i = NamedMorphism(D, E, "i")
199
+
200
+ # These 4 morphisms should be between primed objects.
201
+ j = NamedMorphism(A, B, "j")
202
+ k = NamedMorphism(B, C, "k")
203
+ l = NamedMorphism(C, D, "l")
204
+ m = NamedMorphism(D, E, "m")
205
+
206
+ o = NamedMorphism(A, A_, "o")
207
+ p = NamedMorphism(B, B_, "p")
208
+ q = NamedMorphism(C, C_, "q")
209
+ r = NamedMorphism(D, D_, "r")
210
+ s = NamedMorphism(E, E_, "s")
211
+
212
+ d = Diagram([f, g, h, i, j, k, l, m, o, p, q, r, s])
213
+ grid = DiagramGrid(d)
214
+
215
+ assert grid.width == 3
216
+ assert grid.height == 4
217
+ assert grid[0, 0] is None
218
+ assert grid[0, 1] == A
219
+ assert grid[0, 2] == A_
220
+ assert grid[1, 0] == C
221
+ assert grid[1, 1] == B
222
+ assert grid[1, 2] == B_
223
+ assert grid[2, 0] == C_
224
+ assert grid[2, 1] == D
225
+ assert grid[2, 2] == D_
226
+ assert grid[3, 0] is None
227
+ assert grid[3, 1] == E
228
+ assert grid[3, 2] == E_
229
+
230
+ morphisms = {}
231
+ for m in [f, g, h, i, j, k, l, m, o, p, q, r, s]:
232
+ morphisms[m] = FiniteSet()
233
+ assert grid.morphisms == morphisms
234
+
235
+ # A cube.
236
+ A1 = Object("A1")
237
+ A2 = Object("A2")
238
+ A3 = Object("A3")
239
+ A4 = Object("A4")
240
+ A5 = Object("A5")
241
+ A6 = Object("A6")
242
+ A7 = Object("A7")
243
+ A8 = Object("A8")
244
+
245
+ # The top face of the cube.
246
+ f1 = NamedMorphism(A1, A2, "f1")
247
+ f2 = NamedMorphism(A1, A3, "f2")
248
+ f3 = NamedMorphism(A2, A4, "f3")
249
+ f4 = NamedMorphism(A3, A4, "f3")
250
+
251
+ # The bottom face of the cube.
252
+ f5 = NamedMorphism(A5, A6, "f5")
253
+ f6 = NamedMorphism(A5, A7, "f6")
254
+ f7 = NamedMorphism(A6, A8, "f7")
255
+ f8 = NamedMorphism(A7, A8, "f8")
256
+
257
+ # The remaining morphisms.
258
+ f9 = NamedMorphism(A1, A5, "f9")
259
+ f10 = NamedMorphism(A2, A6, "f10")
260
+ f11 = NamedMorphism(A3, A7, "f11")
261
+ f12 = NamedMorphism(A4, A8, "f11")
262
+
263
+ d = Diagram([f1, f2, f3, f4, f5, f6, f7, f8, f9, f10, f11, f12])
264
+ grid = DiagramGrid(d)
265
+
266
+ assert grid.width == 4
267
+ assert grid.height == 3
268
+ assert grid[0, 0] is None
269
+ assert grid[0, 1] == A5
270
+ assert grid[0, 2] == A6
271
+ assert grid[0, 3] is None
272
+ assert grid[1, 0] is None
273
+ assert grid[1, 1] == A1
274
+ assert grid[1, 2] == A2
275
+ assert grid[1, 3] is None
276
+ assert grid[2, 0] == A7
277
+ assert grid[2, 1] == A3
278
+ assert grid[2, 2] == A4
279
+ assert grid[2, 3] == A8
280
+
281
+ morphisms = {}
282
+ for m in [f1, f2, f3, f4, f5, f6, f7, f8, f9, f10, f11, f12]:
283
+ morphisms[m] = FiniteSet()
284
+ assert grid.morphisms == morphisms
285
+
286
+ # A line diagram.
287
+ A = Object("A")
288
+ B = Object("B")
289
+ C = Object("C")
290
+ D = Object("D")
291
+ E = Object("E")
292
+
293
+ f = NamedMorphism(A, B, "f")
294
+ g = NamedMorphism(B, C, "g")
295
+ h = NamedMorphism(C, D, "h")
296
+ i = NamedMorphism(D, E, "i")
297
+ d = Diagram([f, g, h, i])
298
+ grid = DiagramGrid(d, layout="sequential")
299
+
300
+ assert grid.width == 5
301
+ assert grid.height == 1
302
+ assert grid[0, 0] == A
303
+ assert grid[0, 1] == B
304
+ assert grid[0, 2] == C
305
+ assert grid[0, 3] == D
306
+ assert grid[0, 4] == E
307
+ assert grid.morphisms == {f: FiniteSet(), g: FiniteSet(), h: FiniteSet(),
308
+ i: FiniteSet()}
309
+
310
+ # Test the transposed version.
311
+ grid = DiagramGrid(d, layout="sequential", transpose=True)
312
+
313
+ assert grid.width == 1
314
+ assert grid.height == 5
315
+ assert grid[0, 0] == A
316
+ assert grid[1, 0] == B
317
+ assert grid[2, 0] == C
318
+ assert grid[3, 0] == D
319
+ assert grid[4, 0] == E
320
+ assert grid.morphisms == {f: FiniteSet(), g: FiniteSet(), h: FiniteSet(),
321
+ i: FiniteSet()}
322
+
323
+ # A pullback.
324
+ m1 = NamedMorphism(A, B, "m1")
325
+ m2 = NamedMorphism(A, C, "m2")
326
+ s1 = NamedMorphism(B, D, "s1")
327
+ s2 = NamedMorphism(C, D, "s2")
328
+ f1 = NamedMorphism(E, B, "f1")
329
+ f2 = NamedMorphism(E, C, "f2")
330
+ g = NamedMorphism(E, A, "g")
331
+
332
+ d = Diagram([m1, m2, s1, s2, f1, f2], {g: "unique"})
333
+ grid = DiagramGrid(d)
334
+
335
+ assert grid.width == 3
336
+ assert grid.height == 2
337
+ assert grid[0, 0] == A
338
+ assert grid[0, 1] == B
339
+ assert grid[0, 2] == E
340
+ assert grid[1, 0] == C
341
+ assert grid[1, 1] == D
342
+ assert grid[1, 2] is None
343
+
344
+ morphisms = {g: FiniteSet("unique")}
345
+ for m in [m1, m2, s1, s2, f1, f2]:
346
+ morphisms[m] = FiniteSet()
347
+ assert grid.morphisms == morphisms
348
+
349
+ # Test the pullback with sequential layout, just for stress
350
+ # testing.
351
+ grid = DiagramGrid(d, layout="sequential")
352
+
353
+ assert grid.width == 5
354
+ assert grid.height == 1
355
+ assert grid[0, 0] == D
356
+ assert grid[0, 1] == B
357
+ assert grid[0, 2] == A
358
+ assert grid[0, 3] == C
359
+ assert grid[0, 4] == E
360
+ assert grid.morphisms == morphisms
361
+
362
+ # Test a pullback with object grouping.
363
+ grid = DiagramGrid(d, groups=FiniteSet(E, FiniteSet(A, B, C, D)))
364
+
365
+ assert grid.width == 3
366
+ assert grid.height == 2
367
+ assert grid[0, 0] == E
368
+ assert grid[0, 1] == A
369
+ assert grid[0, 2] == B
370
+ assert grid[1, 0] is None
371
+ assert grid[1, 1] == C
372
+ assert grid[1, 2] == D
373
+ assert grid.morphisms == morphisms
374
+
375
+ # Five lemma, actually.
376
+ A = Object("A")
377
+ B = Object("B")
378
+ C = Object("C")
379
+ D = Object("D")
380
+ E = Object("E")
381
+ A_ = Object("A'")
382
+ B_ = Object("B'")
383
+ C_ = Object("C'")
384
+ D_ = Object("D'")
385
+ E_ = Object("E'")
386
+
387
+ f = NamedMorphism(A, B, "f")
388
+ g = NamedMorphism(B, C, "g")
389
+ h = NamedMorphism(C, D, "h")
390
+ i = NamedMorphism(D, E, "i")
391
+
392
+ j = NamedMorphism(A_, B_, "j")
393
+ k = NamedMorphism(B_, C_, "k")
394
+ l = NamedMorphism(C_, D_, "l")
395
+ m = NamedMorphism(D_, E_, "m")
396
+
397
+ o = NamedMorphism(A, A_, "o")
398
+ p = NamedMorphism(B, B_, "p")
399
+ q = NamedMorphism(C, C_, "q")
400
+ r = NamedMorphism(D, D_, "r")
401
+ s = NamedMorphism(E, E_, "s")
402
+
403
+ d = Diagram([f, g, h, i, j, k, l, m, o, p, q, r, s])
404
+ grid = DiagramGrid(d)
405
+
406
+ assert grid.width == 5
407
+ assert grid.height == 3
408
+ assert grid[0, 0] is None
409
+ assert grid[0, 1] == A
410
+ assert grid[0, 2] == A_
411
+ assert grid[0, 3] is None
412
+ assert grid[0, 4] is None
413
+ assert grid[1, 0] == C
414
+ assert grid[1, 1] == B
415
+ assert grid[1, 2] == B_
416
+ assert grid[1, 3] == C_
417
+ assert grid[1, 4] is None
418
+ assert grid[2, 0] == D
419
+ assert grid[2, 1] == E
420
+ assert grid[2, 2] is None
421
+ assert grid[2, 3] == D_
422
+ assert grid[2, 4] == E_
423
+
424
+ morphisms = {}
425
+ for m in [f, g, h, i, j, k, l, m, o, p, q, r, s]:
426
+ morphisms[m] = FiniteSet()
427
+ assert grid.morphisms == morphisms
428
+
429
+ # Test the five lemma with object grouping.
430
+ grid = DiagramGrid(d, FiniteSet(
431
+ FiniteSet(A, B, C, D, E), FiniteSet(A_, B_, C_, D_, E_)))
432
+
433
+ assert grid.width == 6
434
+ assert grid.height == 3
435
+ assert grid[0, 0] == A
436
+ assert grid[0, 1] == B
437
+ assert grid[0, 2] is None
438
+ assert grid[0, 3] == A_
439
+ assert grid[0, 4] == B_
440
+ assert grid[0, 5] is None
441
+ assert grid[1, 0] is None
442
+ assert grid[1, 1] == C
443
+ assert grid[1, 2] == D
444
+ assert grid[1, 3] is None
445
+ assert grid[1, 4] == C_
446
+ assert grid[1, 5] == D_
447
+ assert grid[2, 0] is None
448
+ assert grid[2, 1] is None
449
+ assert grid[2, 2] == E
450
+ assert grid[2, 3] is None
451
+ assert grid[2, 4] is None
452
+ assert grid[2, 5] == E_
453
+ assert grid.morphisms == morphisms
454
+
455
+ # Test the five lemma with object grouping, but mixing containers
456
+ # to represent groups.
457
+ grid = DiagramGrid(d, [(A, B, C, D, E), {A_, B_, C_, D_, E_}])
458
+
459
+ assert grid.width == 6
460
+ assert grid.height == 3
461
+ assert grid[0, 0] == A
462
+ assert grid[0, 1] == B
463
+ assert grid[0, 2] is None
464
+ assert grid[0, 3] == A_
465
+ assert grid[0, 4] == B_
466
+ assert grid[0, 5] is None
467
+ assert grid[1, 0] is None
468
+ assert grid[1, 1] == C
469
+ assert grid[1, 2] == D
470
+ assert grid[1, 3] is None
471
+ assert grid[1, 4] == C_
472
+ assert grid[1, 5] == D_
473
+ assert grid[2, 0] is None
474
+ assert grid[2, 1] is None
475
+ assert grid[2, 2] == E
476
+ assert grid[2, 3] is None
477
+ assert grid[2, 4] is None
478
+ assert grid[2, 5] == E_
479
+ assert grid.morphisms == morphisms
480
+
481
+ # Test the five lemma with object grouping and hints.
482
+ grid = DiagramGrid(d, {
483
+ FiniteSet(A, B, C, D, E): {"layout": "sequential",
484
+ "transpose": True},
485
+ FiniteSet(A_, B_, C_, D_, E_): {"layout": "sequential",
486
+ "transpose": True}},
487
+ transpose=True)
488
+
489
+ assert grid.width == 5
490
+ assert grid.height == 2
491
+ assert grid[0, 0] == A
492
+ assert grid[0, 1] == B
493
+ assert grid[0, 2] == C
494
+ assert grid[0, 3] == D
495
+ assert grid[0, 4] == E
496
+ assert grid[1, 0] == A_
497
+ assert grid[1, 1] == B_
498
+ assert grid[1, 2] == C_
499
+ assert grid[1, 3] == D_
500
+ assert grid[1, 4] == E_
501
+ assert grid.morphisms == morphisms
502
+
503
+ # A two-triangle disconnected diagram.
504
+ f = NamedMorphism(A, B, "f")
505
+ g = NamedMorphism(B, C, "g")
506
+ f_ = NamedMorphism(A_, B_, "f")
507
+ g_ = NamedMorphism(B_, C_, "g")
508
+ d = Diagram([f, g, f_, g_], {g * f: "unique", g_ * f_: "unique"})
509
+ grid = DiagramGrid(d)
510
+
511
+ assert grid.width == 4
512
+ assert grid.height == 2
513
+ assert grid[0, 0] == A
514
+ assert grid[0, 1] == B
515
+ assert grid[0, 2] == A_
516
+ assert grid[0, 3] == B_
517
+ assert grid[1, 0] == C
518
+ assert grid[1, 1] is None
519
+ assert grid[1, 2] == C_
520
+ assert grid[1, 3] is None
521
+ assert grid.morphisms == {f: FiniteSet(), g: FiniteSet(), f_: FiniteSet(),
522
+ g_: FiniteSet(), g * f: FiniteSet("unique"),
523
+ g_ * f_: FiniteSet("unique")}
524
+
525
+ # A two-morphism disconnected diagram.
526
+ f = NamedMorphism(A, B, "f")
527
+ g = NamedMorphism(C, D, "g")
528
+ d = Diagram([f, g])
529
+ grid = DiagramGrid(d)
530
+
531
+ assert grid.width == 4
532
+ assert grid.height == 1
533
+ assert grid[0, 0] == A
534
+ assert grid[0, 1] == B
535
+ assert grid[0, 2] == C
536
+ assert grid[0, 3] == D
537
+ assert grid.morphisms == {f: FiniteSet(), g: FiniteSet()}
538
+
539
+ # Test a one-object diagram.
540
+ f = NamedMorphism(A, A, "f")
541
+ d = Diagram([f])
542
+ grid = DiagramGrid(d)
543
+
544
+ assert grid.width == 1
545
+ assert grid.height == 1
546
+ assert grid[0, 0] == A
547
+
548
+ # Test a two-object disconnected diagram.
549
+ g = NamedMorphism(B, B, "g")
550
+ d = Diagram([f, g])
551
+ grid = DiagramGrid(d)
552
+
553
+ assert grid.width == 2
554
+ assert grid.height == 1
555
+ assert grid[0, 0] == A
556
+ assert grid[0, 1] == B
557
+
558
+
559
+ def test_DiagramGrid_pseudopod():
560
+ # Test a diagram in which even growing a pseudopod does not
561
+ # eventually help.
562
+ A = Object("A")
563
+ B = Object("B")
564
+ C = Object("C")
565
+ D = Object("D")
566
+ E = Object("E")
567
+ F = Object("F")
568
+ A_ = Object("A'")
569
+ B_ = Object("B'")
570
+ C_ = Object("C'")
571
+ D_ = Object("D'")
572
+ E_ = Object("E'")
573
+
574
+ f1 = NamedMorphism(A, B, "f1")
575
+ f2 = NamedMorphism(A, C, "f2")
576
+ f3 = NamedMorphism(A, D, "f3")
577
+ f4 = NamedMorphism(A, E, "f4")
578
+ f5 = NamedMorphism(A, A_, "f5")
579
+ f6 = NamedMorphism(A, B_, "f6")
580
+ f7 = NamedMorphism(A, C_, "f7")
581
+ f8 = NamedMorphism(A, D_, "f8")
582
+ f9 = NamedMorphism(A, E_, "f9")
583
+ f10 = NamedMorphism(A, F, "f10")
584
+ d = Diagram([f1, f2, f3, f4, f5, f6, f7, f8, f9, f10])
585
+ grid = DiagramGrid(d)
586
+
587
+ assert grid.width == 5
588
+ assert grid.height == 3
589
+ assert grid[0, 0] == E
590
+ assert grid[0, 1] == C
591
+ assert grid[0, 2] == C_
592
+ assert grid[0, 3] == E_
593
+ assert grid[0, 4] == F
594
+ assert grid[1, 0] == D
595
+ assert grid[1, 1] == A
596
+ assert grid[1, 2] == A_
597
+ assert grid[1, 3] is None
598
+ assert grid[1, 4] is None
599
+ assert grid[2, 0] == D_
600
+ assert grid[2, 1] == B
601
+ assert grid[2, 2] == B_
602
+ assert grid[2, 3] is None
603
+ assert grid[2, 4] is None
604
+
605
+ morphisms = {}
606
+ for f in [f1, f2, f3, f4, f5, f6, f7, f8, f9, f10]:
607
+ morphisms[f] = FiniteSet()
608
+ assert grid.morphisms == morphisms
609
+
610
+
611
+ def test_ArrowStringDescription():
612
+ astr = ArrowStringDescription("cm", "", None, "", "", "d", "r", "_", "f")
613
+ assert str(astr) == "\\ar[dr]_{f}"
614
+
615
+ astr = ArrowStringDescription("cm", "", 12, "", "", "d", "r", "_", "f")
616
+ assert str(astr) == "\\ar[dr]_{f}"
617
+
618
+ astr = ArrowStringDescription("cm", "^", 12, "", "", "d", "r", "_", "f")
619
+ assert str(astr) == "\\ar@/^12cm/[dr]_{f}"
620
+
621
+ astr = ArrowStringDescription("cm", "", 12, "r", "", "d", "r", "_", "f")
622
+ assert str(astr) == "\\ar[dr]_{f}"
623
+
624
+ astr = ArrowStringDescription("cm", "", 12, "r", "u", "d", "r", "_", "f")
625
+ assert str(astr) == "\\ar@(r,u)[dr]_{f}"
626
+
627
+ astr = ArrowStringDescription("cm", "", 12, "r", "u", "d", "r", "_", "f")
628
+ assert str(astr) == "\\ar@(r,u)[dr]_{f}"
629
+
630
+ astr = ArrowStringDescription("cm", "", 12, "r", "u", "d", "r", "_", "f")
631
+ astr.arrow_style = "{-->}"
632
+ assert str(astr) == "\\ar@(r,u)@{-->}[dr]_{f}"
633
+
634
+ astr = ArrowStringDescription("cm", "_", 12, "", "", "d", "r", "_", "f")
635
+ astr.arrow_style = "{-->}"
636
+ assert str(astr) == "\\ar@/_12cm/@{-->}[dr]_{f}"
637
+
638
+
639
+ def test_XypicDiagramDrawer_line():
640
+ # A linear diagram.
641
+ A = Object("A")
642
+ B = Object("B")
643
+ C = Object("C")
644
+ D = Object("D")
645
+ E = Object("E")
646
+
647
+ f = NamedMorphism(A, B, "f")
648
+ g = NamedMorphism(B, C, "g")
649
+ h = NamedMorphism(C, D, "h")
650
+ i = NamedMorphism(D, E, "i")
651
+ d = Diagram([f, g, h, i])
652
+ grid = DiagramGrid(d, layout="sequential")
653
+ drawer = XypicDiagramDrawer()
654
+ assert drawer.draw(d, grid) == "\\xymatrix{\n" \
655
+ "A \\ar[r]^{f} & B \\ar[r]^{g} & C \\ar[r]^{h} & D \\ar[r]^{i} & E \n" \
656
+ "}\n"
657
+
658
+ # The same diagram, transposed.
659
+ grid = DiagramGrid(d, layout="sequential", transpose=True)
660
+ drawer = XypicDiagramDrawer()
661
+ assert drawer.draw(d, grid) == "\\xymatrix{\n" \
662
+ "A \\ar[d]^{f} \\\\\n" \
663
+ "B \\ar[d]^{g} \\\\\n" \
664
+ "C \\ar[d]^{h} \\\\\n" \
665
+ "D \\ar[d]^{i} \\\\\n" \
666
+ "E \n" \
667
+ "}\n"
668
+
669
+
670
+ def test_XypicDiagramDrawer_triangle():
671
+ # A triangle diagram.
672
+ A = Object("A")
673
+ B = Object("B")
674
+ C = Object("C")
675
+ f = NamedMorphism(A, B, "f")
676
+ g = NamedMorphism(B, C, "g")
677
+
678
+ d = Diagram([f, g], {g * f: "unique"})
679
+ grid = DiagramGrid(d)
680
+ drawer = XypicDiagramDrawer()
681
+ assert drawer.draw(d, grid) == "\\xymatrix{\n" \
682
+ "A \\ar[d]_{g\\circ f} \\ar[r]^{f} & B \\ar[ld]^{g} \\\\\n" \
683
+ "C & \n" \
684
+ "}\n"
685
+
686
+ # The same diagram, transposed.
687
+ grid = DiagramGrid(d, transpose=True)
688
+ drawer = XypicDiagramDrawer()
689
+ assert drawer.draw(d, grid) == "\\xymatrix{\n" \
690
+ "A \\ar[r]^{g\\circ f} \\ar[d]_{f} & C \\\\\n" \
691
+ "B \\ar[ru]_{g} & \n" \
692
+ "}\n"
693
+
694
+ # The same diagram, with a masked morphism.
695
+ assert drawer.draw(d, grid, masked=[g]) == "\\xymatrix{\n" \
696
+ "A \\ar[r]^{g\\circ f} \\ar[d]_{f} & C \\\\\n" \
697
+ "B & \n" \
698
+ "}\n"
699
+
700
+ # The same diagram with a formatter for "unique".
701
+ def formatter(astr):
702
+ astr.label = "\\exists !" + astr.label
703
+ astr.arrow_style = "{-->}"
704
+
705
+ drawer.arrow_formatters["unique"] = formatter
706
+ assert drawer.draw(d, grid) == "\\xymatrix{\n" \
707
+ "A \\ar@{-->}[r]^{\\exists !g\\circ f} \\ar[d]_{f} & C \\\\\n" \
708
+ "B \\ar[ru]_{g} & \n" \
709
+ "}\n"
710
+
711
+ # The same diagram with a default formatter.
712
+ def default_formatter(astr):
713
+ astr.label_displacement = "(0.45)"
714
+
715
+ drawer.default_arrow_formatter = default_formatter
716
+ assert drawer.draw(d, grid) == "\\xymatrix{\n" \
717
+ "A \\ar@{-->}[r]^(0.45){\\exists !g\\circ f} \\ar[d]_(0.45){f} & C \\\\\n" \
718
+ "B \\ar[ru]_(0.45){g} & \n" \
719
+ "}\n"
720
+
721
+ # A triangle diagram with a lot of morphisms between the same
722
+ # objects.
723
+ f1 = NamedMorphism(B, A, "f1")
724
+ f2 = NamedMorphism(A, B, "f2")
725
+ g1 = NamedMorphism(C, B, "g1")
726
+ g2 = NamedMorphism(B, C, "g2")
727
+ d = Diagram([f, f1, f2, g, g1, g2], {f1 * g1: "unique", g2 * f2: "unique"})
728
+
729
+ grid = DiagramGrid(d, transpose=True)
730
+ drawer = XypicDiagramDrawer()
731
+ assert drawer.draw(d, grid, masked=[f1*g1*g2*f2, g2*f2*f1*g1]) == \
732
+ "\\xymatrix{\n" \
733
+ "A \\ar[r]^{g_{2}\\circ f_{2}} \\ar[d]_{f} \\ar@/^3mm/[d]^{f_{2}} " \
734
+ "& C \\ar@/^3mm/[l]^{f_{1}\\circ g_{1}} \\ar@/^3mm/[ld]^{g_{1}} \\\\\n" \
735
+ "B \\ar@/^3mm/[u]^{f_{1}} \\ar[ru]_{g} \\ar@/^3mm/[ru]^{g_{2}} & \n" \
736
+ "}\n"
737
+
738
+
739
+ def test_XypicDiagramDrawer_cube():
740
+ # A cube diagram.
741
+ A1 = Object("A1")
742
+ A2 = Object("A2")
743
+ A3 = Object("A3")
744
+ A4 = Object("A4")
745
+ A5 = Object("A5")
746
+ A6 = Object("A6")
747
+ A7 = Object("A7")
748
+ A8 = Object("A8")
749
+
750
+ # The top face of the cube.
751
+ f1 = NamedMorphism(A1, A2, "f1")
752
+ f2 = NamedMorphism(A1, A3, "f2")
753
+ f3 = NamedMorphism(A2, A4, "f3")
754
+ f4 = NamedMorphism(A3, A4, "f3")
755
+
756
+ # The bottom face of the cube.
757
+ f5 = NamedMorphism(A5, A6, "f5")
758
+ f6 = NamedMorphism(A5, A7, "f6")
759
+ f7 = NamedMorphism(A6, A8, "f7")
760
+ f8 = NamedMorphism(A7, A8, "f8")
761
+
762
+ # The remaining morphisms.
763
+ f9 = NamedMorphism(A1, A5, "f9")
764
+ f10 = NamedMorphism(A2, A6, "f10")
765
+ f11 = NamedMorphism(A3, A7, "f11")
766
+ f12 = NamedMorphism(A4, A8, "f11")
767
+
768
+ d = Diagram([f1, f2, f3, f4, f5, f6, f7, f8, f9, f10, f11, f12])
769
+ grid = DiagramGrid(d)
770
+ drawer = XypicDiagramDrawer()
771
+ assert drawer.draw(d, grid) == "\\xymatrix{\n" \
772
+ "& A_{5} \\ar[r]^{f_{5}} \\ar[ldd]_{f_{6}} & A_{6} \\ar[rdd]^{f_{7}} " \
773
+ "& \\\\\n" \
774
+ "& A_{1} \\ar[r]^{f_{1}} \\ar[d]^{f_{2}} \\ar[u]^{f_{9}} & A_{2} " \
775
+ "\\ar[d]^{f_{3}} \\ar[u]_{f_{10}} & \\\\\n" \
776
+ "A_{7} \\ar@/_3mm/[rrr]_{f_{8}} & A_{3} \\ar[r]^{f_{3}} \\ar[l]_{f_{11}} " \
777
+ "& A_{4} \\ar[r]^{f_{11}} & A_{8} \n" \
778
+ "}\n"
779
+
780
+ # The same diagram, transposed.
781
+ grid = DiagramGrid(d, transpose=True)
782
+ drawer = XypicDiagramDrawer()
783
+ assert drawer.draw(d, grid) == "\\xymatrix{\n" \
784
+ "& & A_{7} \\ar@/^3mm/[ddd]^{f_{8}} \\\\\n" \
785
+ "A_{5} \\ar[d]_{f_{5}} \\ar[rru]^{f_{6}} & A_{1} \\ar[d]^{f_{1}} " \
786
+ "\\ar[r]^{f_{2}} \\ar[l]^{f_{9}} & A_{3} \\ar[d]_{f_{3}} " \
787
+ "\\ar[u]^{f_{11}} \\\\\n" \
788
+ "A_{6} \\ar[rrd]_{f_{7}} & A_{2} \\ar[r]^{f_{3}} \\ar[l]^{f_{10}} " \
789
+ "& A_{4} \\ar[d]_{f_{11}} \\\\\n" \
790
+ "& & A_{8} \n" \
791
+ "}\n"
792
+
793
+
794
+ def test_XypicDiagramDrawer_curved_and_loops():
795
+ # A simple diagram, with a curved arrow.
796
+ A = Object("A")
797
+ B = Object("B")
798
+ C = Object("C")
799
+ D = Object("D")
800
+
801
+ f = NamedMorphism(A, B, "f")
802
+ g = NamedMorphism(B, C, "g")
803
+ h = NamedMorphism(D, A, "h")
804
+ k = NamedMorphism(D, B, "k")
805
+ d = Diagram([f, g, h, k])
806
+ grid = DiagramGrid(d)
807
+ drawer = XypicDiagramDrawer()
808
+ assert drawer.draw(d, grid) == "\\xymatrix{\n" \
809
+ "A \\ar[r]_{f} & B \\ar[d]^{g} & D \\ar[l]^{k} \\ar@/_3mm/[ll]_{h} \\\\\n" \
810
+ "& C & \n" \
811
+ "}\n"
812
+
813
+ # The same diagram, transposed.
814
+ grid = DiagramGrid(d, transpose=True)
815
+ drawer = XypicDiagramDrawer()
816
+ assert drawer.draw(d, grid) == "\\xymatrix{\n" \
817
+ "A \\ar[d]^{f} & \\\\\n" \
818
+ "B \\ar[r]^{g} & C \\\\\n" \
819
+ "D \\ar[u]_{k} \\ar@/^3mm/[uu]^{h} & \n" \
820
+ "}\n"
821
+
822
+ # The same diagram, larger and rotated.
823
+ assert drawer.draw(d, grid, diagram_format="@+1cm@dr") == \
824
+ "\\xymatrix@+1cm@dr{\n" \
825
+ "A \\ar[d]^{f} & \\\\\n" \
826
+ "B \\ar[r]^{g} & C \\\\\n" \
827
+ "D \\ar[u]_{k} \\ar@/^3mm/[uu]^{h} & \n" \
828
+ "}\n"
829
+
830
+ # A simple diagram with three curved arrows.
831
+ h1 = NamedMorphism(D, A, "h1")
832
+ h2 = NamedMorphism(A, D, "h2")
833
+ k = NamedMorphism(D, B, "k")
834
+ d = Diagram([f, g, h, k, h1, h2])
835
+ grid = DiagramGrid(d)
836
+ drawer = XypicDiagramDrawer()
837
+ assert drawer.draw(d, grid) == "\\xymatrix{\n" \
838
+ "A \\ar[r]_{f} \\ar@/^3mm/[rr]^{h_{2}} & B \\ar[d]^{g} & D \\ar[l]^{k} " \
839
+ "\\ar@/_7mm/[ll]_{h} \\ar@/_11mm/[ll]_{h_{1}} \\\\\n" \
840
+ "& C & \n" \
841
+ "}\n"
842
+
843
+ # The same diagram, transposed.
844
+ grid = DiagramGrid(d, transpose=True)
845
+ drawer = XypicDiagramDrawer()
846
+ assert drawer.draw(d, grid) == "\\xymatrix{\n" \
847
+ "A \\ar[d]^{f} \\ar@/_3mm/[dd]_{h_{2}} & \\\\\n" \
848
+ "B \\ar[r]^{g} & C \\\\\n" \
849
+ "D \\ar[u]_{k} \\ar@/^7mm/[uu]^{h} \\ar@/^11mm/[uu]^{h_{1}} & \n" \
850
+ "}\n"
851
+
852
+ # The same diagram, with "loop" morphisms.
853
+ l_A = NamedMorphism(A, A, "l_A")
854
+ l_D = NamedMorphism(D, D, "l_D")
855
+ l_C = NamedMorphism(C, C, "l_C")
856
+ d = Diagram([f, g, h, k, h1, h2, l_A, l_D, l_C])
857
+ grid = DiagramGrid(d)
858
+ drawer = XypicDiagramDrawer()
859
+ assert drawer.draw(d, grid) == "\\xymatrix{\n" \
860
+ "A \\ar[r]_{f} \\ar@/^3mm/[rr]^{h_{2}} \\ar@(u,l)[]^{l_{A}} " \
861
+ "& B \\ar[d]^{g} & D \\ar[l]^{k} \\ar@/_7mm/[ll]_{h} " \
862
+ "\\ar@/_11mm/[ll]_{h_{1}} \\ar@(r,u)[]^{l_{D}} \\\\\n" \
863
+ "& C \\ar@(l,d)[]^{l_{C}} & \n" \
864
+ "}\n"
865
+
866
+ # The same diagram with "loop" morphisms, transposed.
867
+ grid = DiagramGrid(d, transpose=True)
868
+ drawer = XypicDiagramDrawer()
869
+ assert drawer.draw(d, grid) == "\\xymatrix{\n" \
870
+ "A \\ar[d]^{f} \\ar@/_3mm/[dd]_{h_{2}} \\ar@(r,u)[]^{l_{A}} & \\\\\n" \
871
+ "B \\ar[r]^{g} & C \\ar@(r,u)[]^{l_{C}} \\\\\n" \
872
+ "D \\ar[u]_{k} \\ar@/^7mm/[uu]^{h} \\ar@/^11mm/[uu]^{h_{1}} " \
873
+ "\\ar@(l,d)[]^{l_{D}} & \n" \
874
+ "}\n"
875
+
876
+ # The same diagram with two "loop" morphisms per object.
877
+ l_A_ = NamedMorphism(A, A, "n_A")
878
+ l_D_ = NamedMorphism(D, D, "n_D")
879
+ l_C_ = NamedMorphism(C, C, "n_C")
880
+ d = Diagram([f, g, h, k, h1, h2, l_A, l_D, l_C, l_A_, l_D_, l_C_])
881
+ grid = DiagramGrid(d)
882
+ drawer = XypicDiagramDrawer()
883
+ assert drawer.draw(d, grid) == "\\xymatrix{\n" \
884
+ "A \\ar[r]_{f} \\ar@/^3mm/[rr]^{h_{2}} \\ar@(u,l)[]^{l_{A}} " \
885
+ "\\ar@/^3mm/@(l,d)[]^{n_{A}} & B \\ar[d]^{g} & D \\ar[l]^{k} " \
886
+ "\\ar@/_7mm/[ll]_{h} \\ar@/_11mm/[ll]_{h_{1}} \\ar@(r,u)[]^{l_{D}} " \
887
+ "\\ar@/^3mm/@(d,r)[]^{n_{D}} \\\\\n" \
888
+ "& C \\ar@(l,d)[]^{l_{C}} \\ar@/^3mm/@(d,r)[]^{n_{C}} & \n" \
889
+ "}\n"
890
+
891
+ # The same diagram with two "loop" morphisms per object, transposed.
892
+ grid = DiagramGrid(d, transpose=True)
893
+ drawer = XypicDiagramDrawer()
894
+ assert drawer.draw(d, grid) == "\\xymatrix{\n" \
895
+ "A \\ar[d]^{f} \\ar@/_3mm/[dd]_{h_{2}} \\ar@(r,u)[]^{l_{A}} " \
896
+ "\\ar@/^3mm/@(u,l)[]^{n_{A}} & \\\\\n" \
897
+ "B \\ar[r]^{g} & C \\ar@(r,u)[]^{l_{C}} \\ar@/^3mm/@(d,r)[]^{n_{C}} \\\\\n" \
898
+ "D \\ar[u]_{k} \\ar@/^7mm/[uu]^{h} \\ar@/^11mm/[uu]^{h_{1}} " \
899
+ "\\ar@(l,d)[]^{l_{D}} \\ar@/^3mm/@(d,r)[]^{n_{D}} & \n" \
900
+ "}\n"
901
+
902
+
903
+ def test_xypic_draw_diagram():
904
+ # A linear diagram.
905
+ A = Object("A")
906
+ B = Object("B")
907
+ C = Object("C")
908
+ D = Object("D")
909
+ E = Object("E")
910
+
911
+ f = NamedMorphism(A, B, "f")
912
+ g = NamedMorphism(B, C, "g")
913
+ h = NamedMorphism(C, D, "h")
914
+ i = NamedMorphism(D, E, "i")
915
+ d = Diagram([f, g, h, i])
916
+
917
+ grid = DiagramGrid(d, layout="sequential")
918
+ drawer = XypicDiagramDrawer()
919
+ assert drawer.draw(d, grid) == xypic_draw_diagram(d, layout="sequential")
.venv/lib/python3.13/site-packages/sympy/diffgeom/__init__.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .diffgeom import (
2
+ BaseCovarDerivativeOp, BaseScalarField, BaseVectorField, Commutator,
3
+ contravariant_order, CoordSystem, CoordinateSymbol,
4
+ CovarDerivativeOp, covariant_order, Differential, intcurve_diffequ,
5
+ intcurve_series, LieDerivative, Manifold, metric_to_Christoffel_1st,
6
+ metric_to_Christoffel_2nd, metric_to_Ricci_components,
7
+ metric_to_Riemann_components, Patch, Point, TensorProduct, twoform_to_matrix,
8
+ vectors_in_basis, WedgeProduct,
9
+ )
10
+
11
+ __all__ = [
12
+ 'BaseCovarDerivativeOp', 'BaseScalarField', 'BaseVectorField', 'Commutator',
13
+ 'contravariant_order', 'CoordSystem', 'CoordinateSymbol',
14
+ 'CovarDerivativeOp', 'covariant_order', 'Differential', 'intcurve_diffequ',
15
+ 'intcurve_series', 'LieDerivative', 'Manifold', 'metric_to_Christoffel_1st',
16
+ 'metric_to_Christoffel_2nd', 'metric_to_Ricci_components',
17
+ 'metric_to_Riemann_components', 'Patch', 'Point', 'TensorProduct',
18
+ 'twoform_to_matrix', 'vectors_in_basis', 'WedgeProduct',
19
+ ]
.venv/lib/python3.13/site-packages/sympy/diffgeom/diffgeom.py ADDED
@@ -0,0 +1,2270 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+ from typing import Any
3
+
4
+ from functools import reduce
5
+ from itertools import permutations
6
+
7
+ from sympy.combinatorics import Permutation
8
+ from sympy.core import (
9
+ Basic, Expr, Function, diff,
10
+ Pow, Mul, Add, Lambda, S, Tuple, Dict
11
+ )
12
+ from sympy.core.cache import cacheit
13
+
14
+ from sympy.core.symbol import Symbol, Dummy
15
+ from sympy.core.symbol import Str
16
+ from sympy.core.sympify import _sympify
17
+ from sympy.functions import factorial
18
+ from sympy.matrices import ImmutableDenseMatrix as Matrix
19
+ from sympy.solvers import solve
20
+
21
+ from sympy.utilities.exceptions import (sympy_deprecation_warning,
22
+ SymPyDeprecationWarning,
23
+ ignore_warnings)
24
+
25
+
26
+ # TODO you are a bit excessive in the use of Dummies
27
+ # TODO dummy point, literal field
28
+ # TODO too often one needs to call doit or simplify on the output, check the
29
+ # tests and find out why
30
+ from sympy.tensor.array import ImmutableDenseNDimArray
31
+
32
+
33
+ class Manifold(Basic):
34
+ """
35
+ A mathematical manifold.
36
+
37
+ Explanation
38
+ ===========
39
+
40
+ A manifold is a topological space that locally resembles
41
+ Euclidean space near each point [1].
42
+ This class does not provide any means to study the topological
43
+ characteristics of the manifold that it represents, though.
44
+
45
+ Parameters
46
+ ==========
47
+
48
+ name : str
49
+ The name of the manifold.
50
+
51
+ dim : int
52
+ The dimension of the manifold.
53
+
54
+ Examples
55
+ ========
56
+
57
+ >>> from sympy.diffgeom import Manifold
58
+ >>> m = Manifold('M', 2)
59
+ >>> m
60
+ M
61
+ >>> m.dim
62
+ 2
63
+
64
+ References
65
+ ==========
66
+
67
+ .. [1] https://en.wikipedia.org/wiki/Manifold
68
+ """
69
+
70
+ def __new__(cls, name, dim, **kwargs):
71
+ if not isinstance(name, Str):
72
+ name = Str(name)
73
+ dim = _sympify(dim)
74
+ obj = super().__new__(cls, name, dim)
75
+
76
+ obj.patches = _deprecated_list(
77
+ """
78
+ Manifold.patches is deprecated. The Manifold object is now
79
+ immutable. Instead use a separate list to keep track of the
80
+ patches.
81
+ """, [])
82
+ return obj
83
+
84
+ @property
85
+ def name(self):
86
+ return self.args[0]
87
+
88
+ @property
89
+ def dim(self):
90
+ return self.args[1]
91
+
92
+
93
+ class Patch(Basic):
94
+ """
95
+ A patch on a manifold.
96
+
97
+ Explanation
98
+ ===========
99
+
100
+ Coordinate patch, or patch in short, is a simply-connected open set around
101
+ a point in the manifold [1]. On a manifold one can have many patches that
102
+ do not always include the whole manifold. On these patches coordinate
103
+ charts can be defined that permit the parameterization of any point on the
104
+ patch in terms of a tuple of real numbers (the coordinates).
105
+
106
+ This class does not provide any means to study the topological
107
+ characteristics of the patch that it represents.
108
+
109
+ Parameters
110
+ ==========
111
+
112
+ name : str
113
+ The name of the patch.
114
+
115
+ manifold : Manifold
116
+ The manifold on which the patch is defined.
117
+
118
+ Examples
119
+ ========
120
+
121
+ >>> from sympy.diffgeom import Manifold, Patch
122
+ >>> m = Manifold('M', 2)
123
+ >>> p = Patch('P', m)
124
+ >>> p
125
+ P
126
+ >>> p.dim
127
+ 2
128
+
129
+ References
130
+ ==========
131
+
132
+ .. [1] G. Sussman, J. Wisdom, W. Farr, Functional Differential Geometry
133
+ (2013)
134
+
135
+ """
136
+ def __new__(cls, name, manifold, **kwargs):
137
+ if not isinstance(name, Str):
138
+ name = Str(name)
139
+ obj = super().__new__(cls, name, manifold)
140
+
141
+ obj.manifold.patches.append(obj) # deprecated
142
+ obj.coord_systems = _deprecated_list(
143
+ """
144
+ Patch.coord_systms is deprecated. The Patch class is now
145
+ immutable. Instead use a separate list to keep track of coordinate
146
+ systems.
147
+ """, [])
148
+ return obj
149
+
150
+ @property
151
+ def name(self):
152
+ return self.args[0]
153
+
154
+ @property
155
+ def manifold(self):
156
+ return self.args[1]
157
+
158
+ @property
159
+ def dim(self):
160
+ return self.manifold.dim
161
+
162
+
163
+ class CoordSystem(Basic):
164
+ """
165
+ A coordinate system defined on the patch.
166
+
167
+ Explanation
168
+ ===========
169
+
170
+ Coordinate system is a system that uses one or more coordinates to uniquely
171
+ determine the position of the points or other geometric elements on a
172
+ manifold [1].
173
+
174
+ By passing ``Symbols`` to *symbols* parameter, user can define the name and
175
+ assumptions of coordinate symbols of the coordinate system. If not passed,
176
+ these symbols are generated automatically and are assumed to be real valued.
177
+
178
+ By passing *relations* parameter, user can define the transform relations of
179
+ coordinate systems. Inverse transformation and indirect transformation can
180
+ be found automatically. If this parameter is not passed, coordinate
181
+ transformation cannot be done.
182
+
183
+ Parameters
184
+ ==========
185
+
186
+ name : str
187
+ The name of the coordinate system.
188
+
189
+ patch : Patch
190
+ The patch where the coordinate system is defined.
191
+
192
+ symbols : list of Symbols, optional
193
+ Defines the names and assumptions of coordinate symbols.
194
+
195
+ relations : dict, optional
196
+ Key is a tuple of two strings, who are the names of the systems where
197
+ the coordinates transform from and transform to.
198
+ Value is a tuple of the symbols before transformation and a tuple of
199
+ the expressions after transformation.
200
+
201
+ Examples
202
+ ========
203
+
204
+ We define two-dimensional Cartesian coordinate system and polar coordinate
205
+ system.
206
+
207
+ >>> from sympy import symbols, pi, sqrt, atan2, cos, sin
208
+ >>> from sympy.diffgeom import Manifold, Patch, CoordSystem
209
+ >>> m = Manifold('M', 2)
210
+ >>> p = Patch('P', m)
211
+ >>> x, y = symbols('x y', real=True)
212
+ >>> r, theta = symbols('r theta', nonnegative=True)
213
+ >>> relation_dict = {
214
+ ... ('Car2D', 'Pol'): [(x, y), (sqrt(x**2 + y**2), atan2(y, x))],
215
+ ... ('Pol', 'Car2D'): [(r, theta), (r*cos(theta), r*sin(theta))]
216
+ ... }
217
+ >>> Car2D = CoordSystem('Car2D', p, (x, y), relation_dict)
218
+ >>> Pol = CoordSystem('Pol', p, (r, theta), relation_dict)
219
+
220
+ ``symbols`` property returns ``CoordinateSymbol`` instances. These symbols
221
+ are not same with the symbols used to construct the coordinate system.
222
+
223
+ >>> Car2D
224
+ Car2D
225
+ >>> Car2D.dim
226
+ 2
227
+ >>> Car2D.symbols
228
+ (x, y)
229
+ >>> _[0].func
230
+ <class 'sympy.diffgeom.diffgeom.CoordinateSymbol'>
231
+
232
+ ``transformation()`` method returns the transformation function from
233
+ one coordinate system to another. ``transform()`` method returns the
234
+ transformed coordinates.
235
+
236
+ >>> Car2D.transformation(Pol)
237
+ Lambda((x, y), Matrix([
238
+ [sqrt(x**2 + y**2)],
239
+ [ atan2(y, x)]]))
240
+ >>> Car2D.transform(Pol)
241
+ Matrix([
242
+ [sqrt(x**2 + y**2)],
243
+ [ atan2(y, x)]])
244
+ >>> Car2D.transform(Pol, [1, 2])
245
+ Matrix([
246
+ [sqrt(5)],
247
+ [atan(2)]])
248
+
249
+ ``jacobian()`` method returns the Jacobian matrix of coordinate
250
+ transformation between two systems. ``jacobian_determinant()`` method
251
+ returns the Jacobian determinant of coordinate transformation between two
252
+ systems.
253
+
254
+ >>> Pol.jacobian(Car2D)
255
+ Matrix([
256
+ [cos(theta), -r*sin(theta)],
257
+ [sin(theta), r*cos(theta)]])
258
+ >>> Pol.jacobian(Car2D, [1, pi/2])
259
+ Matrix([
260
+ [0, -1],
261
+ [1, 0]])
262
+ >>> Car2D.jacobian_determinant(Pol)
263
+ 1/sqrt(x**2 + y**2)
264
+ >>> Car2D.jacobian_determinant(Pol, [1,0])
265
+ 1
266
+
267
+ References
268
+ ==========
269
+
270
+ .. [1] https://en.wikipedia.org/wiki/Coordinate_system
271
+
272
+ """
273
+ def __new__(cls, name, patch, symbols=None, relations={}, **kwargs):
274
+ if not isinstance(name, Str):
275
+ name = Str(name)
276
+
277
+ # canonicallize the symbols
278
+ if symbols is None:
279
+ names = kwargs.get('names', None)
280
+ if names is None:
281
+ symbols = Tuple(
282
+ *[Symbol('%s_%s' % (name.name, i), real=True)
283
+ for i in range(patch.dim)]
284
+ )
285
+ else:
286
+ sympy_deprecation_warning(
287
+ f"""
288
+ The 'names' argument to CoordSystem is deprecated. Use 'symbols' instead. That
289
+ is, replace
290
+
291
+ CoordSystem(..., names={names})
292
+
293
+ with
294
+
295
+ CoordSystem(..., symbols=[{', '.join(["Symbol(" + repr(n) + ", real=True)" for n in names])}])
296
+ """,
297
+ deprecated_since_version="1.7",
298
+ active_deprecations_target="deprecated-diffgeom-mutable",
299
+ )
300
+ symbols = Tuple(
301
+ *[Symbol(n, real=True) for n in names]
302
+ )
303
+ else:
304
+ syms = []
305
+ for s in symbols:
306
+ if isinstance(s, Symbol):
307
+ syms.append(Symbol(s.name, **s._assumptions.generator))
308
+ elif isinstance(s, str):
309
+ sympy_deprecation_warning(
310
+ f"""
311
+
312
+ Passing a string as the coordinate symbol name to CoordSystem is deprecated.
313
+ Pass a Symbol with the appropriate name and assumptions instead.
314
+
315
+ That is, replace {s} with Symbol({s!r}, real=True).
316
+ """,
317
+
318
+ deprecated_since_version="1.7",
319
+ active_deprecations_target="deprecated-diffgeom-mutable",
320
+ )
321
+ syms.append(Symbol(s, real=True))
322
+ symbols = Tuple(*syms)
323
+
324
+ # canonicallize the relations
325
+ rel_temp = {}
326
+ for k,v in relations.items():
327
+ s1, s2 = k
328
+ if not isinstance(s1, Str):
329
+ s1 = Str(s1)
330
+ if not isinstance(s2, Str):
331
+ s2 = Str(s2)
332
+ key = Tuple(s1, s2)
333
+
334
+ # Old version used Lambda as a value.
335
+ if isinstance(v, Lambda):
336
+ v = (tuple(v.signature), tuple(v.expr))
337
+ else:
338
+ v = (tuple(v[0]), tuple(v[1]))
339
+ rel_temp[key] = v
340
+ relations = Dict(rel_temp)
341
+
342
+ # construct the object
343
+ obj = super().__new__(cls, name, patch, symbols, relations)
344
+
345
+ # Add deprecated attributes
346
+ obj.transforms = _deprecated_dict(
347
+ """
348
+ CoordSystem.transforms is deprecated. The CoordSystem class is now
349
+ immutable. Use the 'relations' keyword argument to the
350
+ CoordSystems() constructor to specify relations.
351
+ """, {})
352
+ obj._names = [str(n) for n in symbols]
353
+ obj.patch.coord_systems.append(obj) # deprecated
354
+ obj._dummies = [Dummy(str(n)) for n in symbols] # deprecated
355
+ obj._dummy = Dummy()
356
+
357
+ return obj
358
+
359
+ @property
360
+ def name(self):
361
+ return self.args[0]
362
+
363
+ @property
364
+ def patch(self):
365
+ return self.args[1]
366
+
367
+ @property
368
+ def manifold(self):
369
+ return self.patch.manifold
370
+
371
+ @property
372
+ def symbols(self):
373
+ return tuple(CoordinateSymbol(self, i, **s._assumptions.generator)
374
+ for i,s in enumerate(self.args[2]))
375
+
376
+ @property
377
+ def relations(self):
378
+ return self.args[3]
379
+
380
+ @property
381
+ def dim(self):
382
+ return self.patch.dim
383
+
384
+ ##########################################################################
385
+ # Finding transformation relation
386
+ ##########################################################################
387
+
388
+ def transformation(self, sys):
389
+ """
390
+ Return coordinate transformation function from *self* to *sys*.
391
+
392
+ Parameters
393
+ ==========
394
+
395
+ sys : CoordSystem
396
+
397
+ Returns
398
+ =======
399
+
400
+ sympy.Lambda
401
+
402
+ Examples
403
+ ========
404
+
405
+ >>> from sympy.diffgeom.rn import R2_r, R2_p
406
+ >>> R2_r.transformation(R2_p)
407
+ Lambda((x, y), Matrix([
408
+ [sqrt(x**2 + y**2)],
409
+ [ atan2(y, x)]]))
410
+
411
+ """
412
+ signature = self.args[2]
413
+
414
+ key = Tuple(self.name, sys.name)
415
+ if self == sys:
416
+ expr = Matrix(self.symbols)
417
+ elif key in self.relations:
418
+ expr = Matrix(self.relations[key][1])
419
+ elif key[::-1] in self.relations:
420
+ expr = Matrix(self._inverse_transformation(sys, self))
421
+ else:
422
+ expr = Matrix(self._indirect_transformation(self, sys))
423
+ return Lambda(signature, expr)
424
+
425
+ @staticmethod
426
+ def _solve_inverse(sym1, sym2, exprs, sys1_name, sys2_name):
427
+ ret = solve(
428
+ [t[0] - t[1] for t in zip(sym2, exprs)],
429
+ list(sym1), dict=True)
430
+
431
+ if len(ret) == 0:
432
+ temp = "Cannot solve inverse relation from {} to {}."
433
+ raise NotImplementedError(temp.format(sys1_name, sys2_name))
434
+ elif len(ret) > 1:
435
+ temp = "Obtained multiple inverse relation from {} to {}."
436
+ raise ValueError(temp.format(sys1_name, sys2_name))
437
+
438
+ return ret[0]
439
+
440
+ @classmethod
441
+ def _inverse_transformation(cls, sys1, sys2):
442
+ # Find the transformation relation from sys2 to sys1
443
+ forward = sys1.transform(sys2)
444
+ inv_results = cls._solve_inverse(sys1.symbols, sys2.symbols, forward,
445
+ sys1.name, sys2.name)
446
+ signature = tuple(sys1.symbols)
447
+ return [inv_results[s] for s in signature]
448
+
449
+ @classmethod
450
+ @cacheit
451
+ def _indirect_transformation(cls, sys1, sys2):
452
+ # Find the transformation relation between two indirectly connected
453
+ # coordinate systems
454
+ rel = sys1.relations
455
+ path = cls._dijkstra(sys1, sys2)
456
+
457
+ transforms = []
458
+ for s1, s2 in zip(path, path[1:]):
459
+ if (s1, s2) in rel:
460
+ transforms.append(rel[(s1, s2)])
461
+ else:
462
+ sym2, inv_exprs = rel[(s2, s1)]
463
+ sym1 = tuple(Dummy() for i in sym2)
464
+ ret = cls._solve_inverse(sym2, sym1, inv_exprs, s2, s1)
465
+ ret = tuple(ret[s] for s in sym2)
466
+ transforms.append((sym1, ret))
467
+ syms = sys1.args[2]
468
+ exprs = syms
469
+ for newsyms, newexprs in transforms:
470
+ exprs = tuple(e.subs(zip(newsyms, exprs)) for e in newexprs)
471
+ return exprs
472
+
473
+ @staticmethod
474
+ def _dijkstra(sys1, sys2):
475
+ # Use Dijkstra algorithm to find the shortest path between two indirectly-connected
476
+ # coordinate systems
477
+ # return value is the list of the names of the systems.
478
+ relations = sys1.relations
479
+ graph = {}
480
+ for s1, s2 in relations.keys():
481
+ if s1 not in graph:
482
+ graph[s1] = {s2}
483
+ else:
484
+ graph[s1].add(s2)
485
+ if s2 not in graph:
486
+ graph[s2] = {s1}
487
+ else:
488
+ graph[s2].add(s1)
489
+
490
+ path_dict = {sys:[0, [], 0] for sys in graph} # minimum distance, path, times of visited
491
+
492
+ def visit(sys):
493
+ path_dict[sys][2] = 1
494
+ for newsys in graph[sys]:
495
+ distance = path_dict[sys][0] + 1
496
+ if path_dict[newsys][0] >= distance or not path_dict[newsys][1]:
497
+ path_dict[newsys][0] = distance
498
+ path_dict[newsys][1] = list(path_dict[sys][1])
499
+ path_dict[newsys][1].append(sys)
500
+
501
+ visit(sys1.name)
502
+
503
+ while True:
504
+ min_distance = max(path_dict.values(), key=lambda x:x[0])[0]
505
+ newsys = None
506
+ for sys, lst in path_dict.items():
507
+ if 0 < lst[0] <= min_distance and not lst[2]:
508
+ min_distance = lst[0]
509
+ newsys = sys
510
+ if newsys is None:
511
+ break
512
+ visit(newsys)
513
+
514
+ result = path_dict[sys2.name][1]
515
+ result.append(sys2.name)
516
+
517
+ if result == [sys2.name]:
518
+ raise KeyError("Two coordinate systems are not connected.")
519
+ return result
520
+
521
+ def connect_to(self, to_sys, from_coords, to_exprs, inverse=True, fill_in_gaps=False):
522
+ sympy_deprecation_warning(
523
+ """
524
+ The CoordSystem.connect_to() method is deprecated. Instead,
525
+ generate a new instance of CoordSystem with the 'relations'
526
+ keyword argument (CoordSystem classes are now immutable).
527
+ """,
528
+ deprecated_since_version="1.7",
529
+ active_deprecations_target="deprecated-diffgeom-mutable",
530
+ )
531
+
532
+ from_coords, to_exprs = dummyfy(from_coords, to_exprs)
533
+ self.transforms[to_sys] = Matrix(from_coords), Matrix(to_exprs)
534
+
535
+ if inverse:
536
+ to_sys.transforms[self] = self._inv_transf(from_coords, to_exprs)
537
+
538
+ if fill_in_gaps:
539
+ self._fill_gaps_in_transformations()
540
+
541
+ @staticmethod
542
+ def _inv_transf(from_coords, to_exprs):
543
+ # Will be removed when connect_to is removed
544
+ inv_from = [i.as_dummy() for i in from_coords]
545
+ inv_to = solve(
546
+ [t[0] - t[1] for t in zip(inv_from, to_exprs)],
547
+ list(from_coords), dict=True)[0]
548
+ inv_to = [inv_to[fc] for fc in from_coords]
549
+ return Matrix(inv_from), Matrix(inv_to)
550
+
551
+ @staticmethod
552
+ def _fill_gaps_in_transformations():
553
+ # Will be removed when connect_to is removed
554
+ raise NotImplementedError
555
+
556
+ ##########################################################################
557
+ # Coordinate transformations
558
+ ##########################################################################
559
+
560
+ def transform(self, sys, coordinates=None):
561
+ """
562
+ Return the result of coordinate transformation from *self* to *sys*.
563
+ If coordinates are not given, coordinate symbols of *self* are used.
564
+
565
+ Parameters
566
+ ==========
567
+
568
+ sys : CoordSystem
569
+
570
+ coordinates : Any iterable, optional.
571
+
572
+ Returns
573
+ =======
574
+
575
+ sympy.ImmutableDenseMatrix containing CoordinateSymbol
576
+
577
+ Examples
578
+ ========
579
+
580
+ >>> from sympy.diffgeom.rn import R2_r, R2_p
581
+ >>> R2_r.transform(R2_p)
582
+ Matrix([
583
+ [sqrt(x**2 + y**2)],
584
+ [ atan2(y, x)]])
585
+ >>> R2_r.transform(R2_p, [0, 1])
586
+ Matrix([
587
+ [ 1],
588
+ [pi/2]])
589
+
590
+ """
591
+ if coordinates is None:
592
+ coordinates = self.symbols
593
+ if self != sys:
594
+ transf = self.transformation(sys)
595
+ coordinates = transf(*coordinates)
596
+ else:
597
+ coordinates = Matrix(coordinates)
598
+ return coordinates
599
+
600
+ def coord_tuple_transform_to(self, to_sys, coords):
601
+ """Transform ``coords`` to coord system ``to_sys``."""
602
+ sympy_deprecation_warning(
603
+ """
604
+ The CoordSystem.coord_tuple_transform_to() method is deprecated.
605
+ Use the CoordSystem.transform() method instead.
606
+ """,
607
+ deprecated_since_version="1.7",
608
+ active_deprecations_target="deprecated-diffgeom-mutable",
609
+ )
610
+
611
+ coords = Matrix(coords)
612
+ if self != to_sys:
613
+ with ignore_warnings(SymPyDeprecationWarning):
614
+ transf = self.transforms[to_sys]
615
+ coords = transf[1].subs(list(zip(transf[0], coords)))
616
+ return coords
617
+
618
+ def jacobian(self, sys, coordinates=None):
619
+ """
620
+ Return the jacobian matrix of a transformation on given coordinates.
621
+ If coordinates are not given, coordinate symbols of *self* are used.
622
+
623
+ Parameters
624
+ ==========
625
+
626
+ sys : CoordSystem
627
+
628
+ coordinates : Any iterable, optional.
629
+
630
+ Returns
631
+ =======
632
+
633
+ sympy.ImmutableDenseMatrix
634
+
635
+ Examples
636
+ ========
637
+
638
+ >>> from sympy.diffgeom.rn import R2_r, R2_p
639
+ >>> R2_p.jacobian(R2_r)
640
+ Matrix([
641
+ [cos(theta), -rho*sin(theta)],
642
+ [sin(theta), rho*cos(theta)]])
643
+ >>> R2_p.jacobian(R2_r, [1, 0])
644
+ Matrix([
645
+ [1, 0],
646
+ [0, 1]])
647
+
648
+ """
649
+ result = self.transform(sys).jacobian(self.symbols)
650
+ if coordinates is not None:
651
+ result = result.subs(list(zip(self.symbols, coordinates)))
652
+ return result
653
+ jacobian_matrix = jacobian
654
+
655
+ def jacobian_determinant(self, sys, coordinates=None):
656
+ """
657
+ Return the jacobian determinant of a transformation on given
658
+ coordinates. If coordinates are not given, coordinate symbols of *self*
659
+ are used.
660
+
661
+ Parameters
662
+ ==========
663
+
664
+ sys : CoordSystem
665
+
666
+ coordinates : Any iterable, optional.
667
+
668
+ Returns
669
+ =======
670
+
671
+ sympy.Expr
672
+
673
+ Examples
674
+ ========
675
+
676
+ >>> from sympy.diffgeom.rn import R2_r, R2_p
677
+ >>> R2_r.jacobian_determinant(R2_p)
678
+ 1/sqrt(x**2 + y**2)
679
+ >>> R2_r.jacobian_determinant(R2_p, [1, 0])
680
+ 1
681
+
682
+ """
683
+ return self.jacobian(sys, coordinates).det()
684
+
685
+
686
+ ##########################################################################
687
+ # Points
688
+ ##########################################################################
689
+
690
+ def point(self, coords):
691
+ """Create a ``Point`` with coordinates given in this coord system."""
692
+ return Point(self, coords)
693
+
694
+ def point_to_coords(self, point):
695
+ """Calculate the coordinates of a point in this coord system."""
696
+ return point.coords(self)
697
+
698
+ ##########################################################################
699
+ # Base fields.
700
+ ##########################################################################
701
+
702
+ def base_scalar(self, coord_index):
703
+ """Return ``BaseScalarField`` that takes a point and returns one of the coordinates."""
704
+ return BaseScalarField(self, coord_index)
705
+ coord_function = base_scalar
706
+
707
+ def base_scalars(self):
708
+ """Returns a list of all coordinate functions.
709
+ For more details see the ``base_scalar`` method of this class."""
710
+ return [self.base_scalar(i) for i in range(self.dim)]
711
+ coord_functions = base_scalars
712
+
713
+ def base_vector(self, coord_index):
714
+ """Return a basis vector field.
715
+ The basis vector field for this coordinate system. It is also an
716
+ operator on scalar fields."""
717
+ return BaseVectorField(self, coord_index)
718
+
719
+ def base_vectors(self):
720
+ """Returns a list of all base vectors.
721
+ For more details see the ``base_vector`` method of this class."""
722
+ return [self.base_vector(i) for i in range(self.dim)]
723
+
724
+ def base_oneform(self, coord_index):
725
+ """Return a basis 1-form field.
726
+ The basis one-form field for this coordinate system. It is also an
727
+ operator on vector fields."""
728
+ return Differential(self.coord_function(coord_index))
729
+
730
+ def base_oneforms(self):
731
+ """Returns a list of all base oneforms.
732
+ For more details see the ``base_oneform`` method of this class."""
733
+ return [self.base_oneform(i) for i in range(self.dim)]
734
+
735
+
736
+ class CoordinateSymbol(Symbol):
737
+ """A symbol which denotes an abstract value of i-th coordinate of
738
+ the coordinate system with given context.
739
+
740
+ Explanation
741
+ ===========
742
+
743
+ Each coordinates in coordinate system are represented by unique symbol,
744
+ such as x, y, z in Cartesian coordinate system.
745
+
746
+ You may not construct this class directly. Instead, use `symbols` method
747
+ of CoordSystem.
748
+
749
+ Parameters
750
+ ==========
751
+
752
+ coord_sys : CoordSystem
753
+
754
+ index : integer
755
+
756
+ Examples
757
+ ========
758
+
759
+ >>> from sympy import symbols, Lambda, Matrix, sqrt, atan2, cos, sin
760
+ >>> from sympy.diffgeom import Manifold, Patch, CoordSystem
761
+ >>> m = Manifold('M', 2)
762
+ >>> p = Patch('P', m)
763
+ >>> x, y = symbols('x y', real=True)
764
+ >>> r, theta = symbols('r theta', nonnegative=True)
765
+ >>> relation_dict = {
766
+ ... ('Car2D', 'Pol'): Lambda((x, y), Matrix([sqrt(x**2 + y**2), atan2(y, x)])),
767
+ ... ('Pol', 'Car2D'): Lambda((r, theta), Matrix([r*cos(theta), r*sin(theta)]))
768
+ ... }
769
+ >>> Car2D = CoordSystem('Car2D', p, [x, y], relation_dict)
770
+ >>> Pol = CoordSystem('Pol', p, [r, theta], relation_dict)
771
+ >>> x, y = Car2D.symbols
772
+
773
+ ``CoordinateSymbol`` contains its coordinate symbol and index.
774
+
775
+ >>> x.name
776
+ 'x'
777
+ >>> x.coord_sys == Car2D
778
+ True
779
+ >>> x.index
780
+ 0
781
+ >>> x.is_real
782
+ True
783
+
784
+ You can transform ``CoordinateSymbol`` into other coordinate system using
785
+ ``rewrite()`` method.
786
+
787
+ >>> x.rewrite(Pol)
788
+ r*cos(theta)
789
+ >>> sqrt(x**2 + y**2).rewrite(Pol).simplify()
790
+ r
791
+
792
+ """
793
+ def __new__(cls, coord_sys, index, **assumptions):
794
+ name = coord_sys.args[2][index].name
795
+ obj = super().__new__(cls, name, **assumptions)
796
+ obj.coord_sys = coord_sys
797
+ obj.index = index
798
+ return obj
799
+
800
+ def __getnewargs__(self):
801
+ return (self.coord_sys, self.index)
802
+
803
+ def _hashable_content(self):
804
+ return (
805
+ self.coord_sys, self.index
806
+ ) + tuple(sorted(self.assumptions0.items()))
807
+
808
+ def _eval_rewrite(self, rule, args, **hints):
809
+ if isinstance(rule, CoordSystem):
810
+ return rule.transform(self.coord_sys)[self.index]
811
+ return super()._eval_rewrite(rule, args, **hints)
812
+
813
+
814
+ class Point(Basic):
815
+ """Point defined in a coordinate system.
816
+
817
+ Explanation
818
+ ===========
819
+
820
+ Mathematically, point is defined in the manifold and does not have any coordinates
821
+ by itself. Coordinate system is what imbues the coordinates to the point by coordinate
822
+ chart. However, due to the difficulty of realizing such logic, you must supply
823
+ a coordinate system and coordinates to define a Point here.
824
+
825
+ The usage of this object after its definition is independent of the
826
+ coordinate system that was used in order to define it, however due to
827
+ limitations in the simplification routines you can arrive at complicated
828
+ expressions if you use inappropriate coordinate systems.
829
+
830
+ Parameters
831
+ ==========
832
+
833
+ coord_sys : CoordSystem
834
+
835
+ coords : list
836
+ The coordinates of the point.
837
+
838
+ Examples
839
+ ========
840
+
841
+ >>> from sympy import pi
842
+ >>> from sympy.diffgeom import Point
843
+ >>> from sympy.diffgeom.rn import R2, R2_r, R2_p
844
+ >>> rho, theta = R2_p.symbols
845
+
846
+ >>> p = Point(R2_p, [rho, 3*pi/4])
847
+
848
+ >>> p.manifold == R2
849
+ True
850
+
851
+ >>> p.coords()
852
+ Matrix([
853
+ [ rho],
854
+ [3*pi/4]])
855
+ >>> p.coords(R2_r)
856
+ Matrix([
857
+ [-sqrt(2)*rho/2],
858
+ [ sqrt(2)*rho/2]])
859
+
860
+ """
861
+
862
+ def __new__(cls, coord_sys, coords, **kwargs):
863
+ coords = Matrix(coords)
864
+ obj = super().__new__(cls, coord_sys, coords)
865
+ obj._coord_sys = coord_sys
866
+ obj._coords = coords
867
+ return obj
868
+
869
+ @property
870
+ def patch(self):
871
+ return self._coord_sys.patch
872
+
873
+ @property
874
+ def manifold(self):
875
+ return self._coord_sys.manifold
876
+
877
+ @property
878
+ def dim(self):
879
+ return self.manifold.dim
880
+
881
+ def coords(self, sys=None):
882
+ """
883
+ Coordinates of the point in given coordinate system. If coordinate system
884
+ is not passed, it returns the coordinates in the coordinate system in which
885
+ the point was defined.
886
+ """
887
+ if sys is None:
888
+ return self._coords
889
+ else:
890
+ return self._coord_sys.transform(sys, self._coords)
891
+
892
+ @property
893
+ def free_symbols(self):
894
+ return self._coords.free_symbols
895
+
896
+
897
+ class BaseScalarField(Expr):
898
+ """Base scalar field over a manifold for a given coordinate system.
899
+
900
+ Explanation
901
+ ===========
902
+
903
+ A scalar field takes a point as an argument and returns a scalar.
904
+ A base scalar field of a coordinate system takes a point and returns one of
905
+ the coordinates of that point in the coordinate system in question.
906
+
907
+ To define a scalar field you need to choose the coordinate system and the
908
+ index of the coordinate.
909
+
910
+ The use of the scalar field after its definition is independent of the
911
+ coordinate system in which it was defined, however due to limitations in
912
+ the simplification routines you may arrive at more complicated
913
+ expression if you use unappropriate coordinate systems.
914
+ You can build complicated scalar fields by just building up SymPy
915
+ expressions containing ``BaseScalarField`` instances.
916
+
917
+ Parameters
918
+ ==========
919
+
920
+ coord_sys : CoordSystem
921
+
922
+ index : integer
923
+
924
+ Examples
925
+ ========
926
+
927
+ >>> from sympy import Function, pi
928
+ >>> from sympy.diffgeom import BaseScalarField
929
+ >>> from sympy.diffgeom.rn import R2_r, R2_p
930
+ >>> rho, _ = R2_p.symbols
931
+ >>> point = R2_p.point([rho, 0])
932
+ >>> fx, fy = R2_r.base_scalars()
933
+ >>> ftheta = BaseScalarField(R2_r, 1)
934
+
935
+ >>> fx(point)
936
+ rho
937
+ >>> fy(point)
938
+ 0
939
+
940
+ >>> (fx**2+fy**2).rcall(point)
941
+ rho**2
942
+
943
+ >>> g = Function('g')
944
+ >>> fg = g(ftheta-pi)
945
+ >>> fg.rcall(point)
946
+ g(-pi)
947
+
948
+ """
949
+
950
+ is_commutative = True
951
+
952
+ def __new__(cls, coord_sys, index, **kwargs):
953
+ index = _sympify(index)
954
+ obj = super().__new__(cls, coord_sys, index)
955
+ obj._coord_sys = coord_sys
956
+ obj._index = index
957
+ return obj
958
+
959
+ @property
960
+ def coord_sys(self):
961
+ return self.args[0]
962
+
963
+ @property
964
+ def index(self):
965
+ return self.args[1]
966
+
967
+ @property
968
+ def patch(self):
969
+ return self.coord_sys.patch
970
+
971
+ @property
972
+ def manifold(self):
973
+ return self.coord_sys.manifold
974
+
975
+ @property
976
+ def dim(self):
977
+ return self.manifold.dim
978
+
979
+ def __call__(self, *args):
980
+ """Evaluating the field at a point or doing nothing.
981
+ If the argument is a ``Point`` instance, the field is evaluated at that
982
+ point. The field is returned itself if the argument is any other
983
+ object. It is so in order to have working recursive calling mechanics
984
+ for all fields (check the ``__call__`` method of ``Expr``).
985
+ """
986
+ point = args[0]
987
+ if len(args) != 1 or not isinstance(point, Point):
988
+ return self
989
+ coords = point.coords(self._coord_sys)
990
+ # XXX Calling doit is necessary with all the Subs expressions
991
+ # XXX Calling simplify is necessary with all the trig expressions
992
+ return simplify(coords[self._index]).doit()
993
+
994
+ # XXX Workaround for limitations on the content of args
995
+ free_symbols: set[Any] = set()
996
+
997
+
998
+ class BaseVectorField(Expr):
999
+ r"""Base vector field over a manifold for a given coordinate system.
1000
+
1001
+ Explanation
1002
+ ===========
1003
+
1004
+ A vector field is an operator taking a scalar field and returning a
1005
+ directional derivative (which is also a scalar field).
1006
+ A base vector field is the same type of operator, however the derivation is
1007
+ specifically done with respect to a chosen coordinate.
1008
+
1009
+ To define a base vector field you need to choose the coordinate system and
1010
+ the index of the coordinate.
1011
+
1012
+ The use of the vector field after its definition is independent of the
1013
+ coordinate system in which it was defined, however due to limitations in the
1014
+ simplification routines you may arrive at more complicated expression if you
1015
+ use unappropriate coordinate systems.
1016
+
1017
+ Parameters
1018
+ ==========
1019
+ coord_sys : CoordSystem
1020
+
1021
+ index : integer
1022
+
1023
+ Examples
1024
+ ========
1025
+
1026
+ >>> from sympy import Function
1027
+ >>> from sympy.diffgeom.rn import R2_p, R2_r
1028
+ >>> from sympy.diffgeom import BaseVectorField
1029
+ >>> from sympy import pprint
1030
+
1031
+ >>> x, y = R2_r.symbols
1032
+ >>> rho, theta = R2_p.symbols
1033
+ >>> fx, fy = R2_r.base_scalars()
1034
+ >>> point_p = R2_p.point([rho, theta])
1035
+ >>> point_r = R2_r.point([x, y])
1036
+
1037
+ >>> g = Function('g')
1038
+ >>> s_field = g(fx, fy)
1039
+
1040
+ >>> v = BaseVectorField(R2_r, 1)
1041
+ >>> pprint(v(s_field))
1042
+ / d \|
1043
+ |---(g(x, xi))||
1044
+ \dxi /|xi=y
1045
+ >>> pprint(v(s_field).rcall(point_r).doit())
1046
+ d
1047
+ --(g(x, y))
1048
+ dy
1049
+ >>> pprint(v(s_field).rcall(point_p))
1050
+ / d \|
1051
+ |---(g(rho*cos(theta), xi))||
1052
+ \dxi /|xi=rho*sin(theta)
1053
+
1054
+ """
1055
+
1056
+ is_commutative = False
1057
+
1058
+ def __new__(cls, coord_sys, index, **kwargs):
1059
+ index = _sympify(index)
1060
+ obj = super().__new__(cls, coord_sys, index)
1061
+ obj._coord_sys = coord_sys
1062
+ obj._index = index
1063
+ return obj
1064
+
1065
+ @property
1066
+ def coord_sys(self):
1067
+ return self.args[0]
1068
+
1069
+ @property
1070
+ def index(self):
1071
+ return self.args[1]
1072
+
1073
+ @property
1074
+ def patch(self):
1075
+ return self.coord_sys.patch
1076
+
1077
+ @property
1078
+ def manifold(self):
1079
+ return self.coord_sys.manifold
1080
+
1081
+ @property
1082
+ def dim(self):
1083
+ return self.manifold.dim
1084
+
1085
+ def __call__(self, scalar_field):
1086
+ """Apply on a scalar field.
1087
+ The action of a vector field on a scalar field is a directional
1088
+ differentiation.
1089
+ If the argument is not a scalar field an error is raised.
1090
+ """
1091
+ if covariant_order(scalar_field) or contravariant_order(scalar_field):
1092
+ raise ValueError('Only scalar fields can be supplied as arguments to vector fields.')
1093
+
1094
+ if scalar_field is None:
1095
+ return self
1096
+
1097
+ base_scalars = list(scalar_field.atoms(BaseScalarField))
1098
+
1099
+ # First step: e_x(x+r**2) -> e_x(x) + 2*r*e_x(r)
1100
+ d_var = self._coord_sys._dummy
1101
+ # TODO: you need a real dummy function for the next line
1102
+ d_funcs = [Function('_#_%s' % i)(d_var) for i,
1103
+ b in enumerate(base_scalars)]
1104
+ d_result = scalar_field.subs(list(zip(base_scalars, d_funcs)))
1105
+ d_result = d_result.diff(d_var)
1106
+
1107
+ # Second step: e_x(x) -> 1 and e_x(r) -> cos(atan2(x, y))
1108
+ coords = self._coord_sys.symbols
1109
+ d_funcs_deriv = [f.diff(d_var) for f in d_funcs]
1110
+ d_funcs_deriv_sub = []
1111
+ for b in base_scalars:
1112
+ jac = self._coord_sys.jacobian(b._coord_sys, coords)
1113
+ d_funcs_deriv_sub.append(jac[b._index, self._index])
1114
+ d_result = d_result.subs(list(zip(d_funcs_deriv, d_funcs_deriv_sub)))
1115
+
1116
+ # Remove the dummies
1117
+ result = d_result.subs(list(zip(d_funcs, base_scalars)))
1118
+ result = result.subs(list(zip(coords, self._coord_sys.coord_functions())))
1119
+ return result.doit()
1120
+
1121
+
1122
+ def _find_coords(expr):
1123
+ # Finds CoordinateSystems existing in expr
1124
+ fields = expr.atoms(BaseScalarField, BaseVectorField)
1125
+ return {f._coord_sys for f in fields}
1126
+
1127
+
1128
+ class Commutator(Expr):
1129
+ r"""Commutator of two vector fields.
1130
+
1131
+ Explanation
1132
+ ===========
1133
+
1134
+ The commutator of two vector fields `v_1` and `v_2` is defined as the
1135
+ vector field `[v_1, v_2]` that evaluated on each scalar field `f` is equal
1136
+ to `v_1(v_2(f)) - v_2(v_1(f))`.
1137
+
1138
+ Examples
1139
+ ========
1140
+
1141
+
1142
+ >>> from sympy.diffgeom.rn import R2_p, R2_r
1143
+ >>> from sympy.diffgeom import Commutator
1144
+ >>> from sympy import simplify
1145
+
1146
+ >>> fx, fy = R2_r.base_scalars()
1147
+ >>> e_x, e_y = R2_r.base_vectors()
1148
+ >>> e_r = R2_p.base_vector(0)
1149
+
1150
+ >>> c_xy = Commutator(e_x, e_y)
1151
+ >>> c_xr = Commutator(e_x, e_r)
1152
+ >>> c_xy
1153
+ 0
1154
+
1155
+ Unfortunately, the current code is not able to compute everything:
1156
+
1157
+ >>> c_xr
1158
+ Commutator(e_x, e_rho)
1159
+ >>> simplify(c_xr(fy**2))
1160
+ -2*cos(theta)*y**2/(x**2 + y**2)
1161
+
1162
+ """
1163
+ def __new__(cls, v1, v2):
1164
+ if (covariant_order(v1) or contravariant_order(v1) != 1
1165
+ or covariant_order(v2) or contravariant_order(v2) != 1):
1166
+ raise ValueError(
1167
+ 'Only commutators of vector fields are supported.')
1168
+ if v1 == v2:
1169
+ return S.Zero
1170
+ coord_sys = set().union(*[_find_coords(v) for v in (v1, v2)])
1171
+ if len(coord_sys) == 1:
1172
+ # Only one coordinate systems is used, hence it is easy enough to
1173
+ # actually evaluate the commutator.
1174
+ if all(isinstance(v, BaseVectorField) for v in (v1, v2)):
1175
+ return S.Zero
1176
+ bases_1, bases_2 = [list(v.atoms(BaseVectorField))
1177
+ for v in (v1, v2)]
1178
+ coeffs_1 = [v1.expand().coeff(b) for b in bases_1]
1179
+ coeffs_2 = [v2.expand().coeff(b) for b in bases_2]
1180
+ res = 0
1181
+ for c1, b1 in zip(coeffs_1, bases_1):
1182
+ for c2, b2 in zip(coeffs_2, bases_2):
1183
+ res += c1*b1(c2)*b2 - c2*b2(c1)*b1
1184
+ return res
1185
+ else:
1186
+ obj = super().__new__(cls, v1, v2)
1187
+ obj._v1 = v1 # deprecated assignment
1188
+ obj._v2 = v2 # deprecated assignment
1189
+ return obj
1190
+
1191
+ @property
1192
+ def v1(self):
1193
+ return self.args[0]
1194
+
1195
+ @property
1196
+ def v2(self):
1197
+ return self.args[1]
1198
+
1199
+ def __call__(self, scalar_field):
1200
+ """Apply on a scalar field.
1201
+ If the argument is not a scalar field an error is raised.
1202
+ """
1203
+ return self.v1(self.v2(scalar_field)) - self.v2(self.v1(scalar_field))
1204
+
1205
+
1206
+ class Differential(Expr):
1207
+ r"""Return the differential (exterior derivative) of a form field.
1208
+
1209
+ Explanation
1210
+ ===========
1211
+
1212
+ The differential of a form (i.e. the exterior derivative) has a complicated
1213
+ definition in the general case.
1214
+ The differential `df` of the 0-form `f` is defined for any vector field `v`
1215
+ as `df(v) = v(f)`.
1216
+
1217
+ Examples
1218
+ ========
1219
+
1220
+ >>> from sympy import Function
1221
+ >>> from sympy.diffgeom.rn import R2_r
1222
+ >>> from sympy.diffgeom import Differential
1223
+ >>> from sympy import pprint
1224
+
1225
+ >>> fx, fy = R2_r.base_scalars()
1226
+ >>> e_x, e_y = R2_r.base_vectors()
1227
+ >>> g = Function('g')
1228
+ >>> s_field = g(fx, fy)
1229
+ >>> dg = Differential(s_field)
1230
+
1231
+ >>> dg
1232
+ d(g(x, y))
1233
+ >>> pprint(dg(e_x))
1234
+ / d \|
1235
+ |---(g(xi, y))||
1236
+ \dxi /|xi=x
1237
+ >>> pprint(dg(e_y))
1238
+ / d \|
1239
+ |---(g(x, xi))||
1240
+ \dxi /|xi=y
1241
+
1242
+ Applying the exterior derivative operator twice always results in:
1243
+
1244
+ >>> Differential(dg)
1245
+ 0
1246
+ """
1247
+
1248
+ is_commutative = False
1249
+
1250
+ def __new__(cls, form_field):
1251
+ if contravariant_order(form_field):
1252
+ raise ValueError(
1253
+ 'A vector field was supplied as an argument to Differential.')
1254
+ if isinstance(form_field, Differential):
1255
+ return S.Zero
1256
+ else:
1257
+ obj = super().__new__(cls, form_field)
1258
+ obj._form_field = form_field # deprecated assignment
1259
+ return obj
1260
+
1261
+ @property
1262
+ def form_field(self):
1263
+ return self.args[0]
1264
+
1265
+ def __call__(self, *vector_fields):
1266
+ """Apply on a list of vector_fields.
1267
+
1268
+ Explanation
1269
+ ===========
1270
+
1271
+ If the number of vector fields supplied is not equal to 1 + the order of
1272
+ the form field inside the differential the result is undefined.
1273
+
1274
+ For 1-forms (i.e. differentials of scalar fields) the evaluation is
1275
+ done as `df(v)=v(f)`. However if `v` is ``None`` instead of a vector
1276
+ field, the differential is returned unchanged. This is done in order to
1277
+ permit partial contractions for higher forms.
1278
+
1279
+ In the general case the evaluation is done by applying the form field
1280
+ inside the differential on a list with one less elements than the number
1281
+ of elements in the original list. Lowering the number of vector fields
1282
+ is achieved through replacing each pair of fields by their
1283
+ commutator.
1284
+
1285
+ If the arguments are not vectors or ``None``s an error is raised.
1286
+ """
1287
+ if any((contravariant_order(a) != 1 or covariant_order(a)) and a is not None
1288
+ for a in vector_fields):
1289
+ raise ValueError('The arguments supplied to Differential should be vector fields or Nones.')
1290
+ k = len(vector_fields)
1291
+ if k == 1:
1292
+ if vector_fields[0]:
1293
+ return vector_fields[0].rcall(self._form_field)
1294
+ return self
1295
+ else:
1296
+ # For higher form it is more complicated:
1297
+ # Invariant formula:
1298
+ # https://en.wikipedia.org/wiki/Exterior_derivative#Invariant_formula
1299
+ # df(v1, ... vn) = +/- vi(f(v1..no i..vn))
1300
+ # +/- f([vi,vj],v1..no i, no j..vn)
1301
+ f = self._form_field
1302
+ v = vector_fields
1303
+ ret = 0
1304
+ for i in range(k):
1305
+ t = v[i].rcall(f.rcall(*v[:i] + v[i + 1:]))
1306
+ ret += (-1)**i*t
1307
+ for j in range(i + 1, k):
1308
+ c = Commutator(v[i], v[j])
1309
+ if c: # TODO this is ugly - the Commutator can be Zero and
1310
+ # this causes the next line to fail
1311
+ t = f.rcall(*(c,) + v[:i] + v[i + 1:j] + v[j + 1:])
1312
+ ret += (-1)**(i + j)*t
1313
+ return ret
1314
+
1315
+
1316
+ class TensorProduct(Expr):
1317
+ """Tensor product of forms.
1318
+
1319
+ Explanation
1320
+ ===========
1321
+
1322
+ The tensor product permits the creation of multilinear functionals (i.e.
1323
+ higher order tensors) out of lower order fields (e.g. 1-forms and vector
1324
+ fields). However, the higher tensors thus created lack the interesting
1325
+ features provided by the other type of product, the wedge product, namely
1326
+ they are not antisymmetric and hence are not form fields.
1327
+
1328
+ Examples
1329
+ ========
1330
+
1331
+ >>> from sympy.diffgeom.rn import R2_r
1332
+ >>> from sympy.diffgeom import TensorProduct
1333
+
1334
+ >>> fx, fy = R2_r.base_scalars()
1335
+ >>> e_x, e_y = R2_r.base_vectors()
1336
+ >>> dx, dy = R2_r.base_oneforms()
1337
+
1338
+ >>> TensorProduct(dx, dy)(e_x, e_y)
1339
+ 1
1340
+ >>> TensorProduct(dx, dy)(e_y, e_x)
1341
+ 0
1342
+ >>> TensorProduct(dx, fx*dy)(fx*e_x, e_y)
1343
+ x**2
1344
+ >>> TensorProduct(e_x, e_y)(fx**2, fy**2)
1345
+ 4*x*y
1346
+ >>> TensorProduct(e_y, dx)(fy)
1347
+ dx
1348
+
1349
+ You can nest tensor products.
1350
+
1351
+ >>> tp1 = TensorProduct(dx, dy)
1352
+ >>> TensorProduct(tp1, dx)(e_x, e_y, e_x)
1353
+ 1
1354
+
1355
+ You can make partial contraction for instance when 'raising an index'.
1356
+ Putting ``None`` in the second argument of ``rcall`` means that the
1357
+ respective position in the tensor product is left as it is.
1358
+
1359
+ >>> TP = TensorProduct
1360
+ >>> metric = TP(dx, dx) + 3*TP(dy, dy)
1361
+ >>> metric.rcall(e_y, None)
1362
+ 3*dy
1363
+
1364
+ Or automatically pad the args with ``None`` without specifying them.
1365
+
1366
+ >>> metric.rcall(e_y)
1367
+ 3*dy
1368
+
1369
+ """
1370
+ def __new__(cls, *args):
1371
+ scalar = Mul(*[m for m in args if covariant_order(m) + contravariant_order(m) == 0])
1372
+ multifields = [m for m in args if covariant_order(m) + contravariant_order(m)]
1373
+ if multifields:
1374
+ if len(multifields) == 1:
1375
+ return scalar*multifields[0]
1376
+ return scalar*super().__new__(cls, *multifields)
1377
+ else:
1378
+ return scalar
1379
+
1380
+ def __call__(self, *fields):
1381
+ """Apply on a list of fields.
1382
+
1383
+ If the number of input fields supplied is not equal to the order of
1384
+ the tensor product field, the list of arguments is padded with ``None``'s.
1385
+
1386
+ The list of arguments is divided in sublists depending on the order of
1387
+ the forms inside the tensor product. The sublists are provided as
1388
+ arguments to these forms and the resulting expressions are given to the
1389
+ constructor of ``TensorProduct``.
1390
+
1391
+ """
1392
+ tot_order = covariant_order(self) + contravariant_order(self)
1393
+ tot_args = len(fields)
1394
+ if tot_args != tot_order:
1395
+ fields = list(fields) + [None]*(tot_order - tot_args)
1396
+ orders = [covariant_order(f) + contravariant_order(f) for f in self._args]
1397
+ indices = [sum(orders[:i + 1]) for i in range(len(orders) - 1)]
1398
+ fields = [fields[i:j] for i, j in zip([0] + indices, indices + [None])]
1399
+ multipliers = [t[0].rcall(*t[1]) for t in zip(self._args, fields)]
1400
+ return TensorProduct(*multipliers)
1401
+
1402
+
1403
+ class WedgeProduct(TensorProduct):
1404
+ """Wedge product of forms.
1405
+
1406
+ Explanation
1407
+ ===========
1408
+
1409
+ In the context of integration only completely antisymmetric forms make
1410
+ sense. The wedge product permits the creation of such forms.
1411
+
1412
+ Examples
1413
+ ========
1414
+
1415
+ >>> from sympy.diffgeom.rn import R2_r
1416
+ >>> from sympy.diffgeom import WedgeProduct
1417
+
1418
+ >>> fx, fy = R2_r.base_scalars()
1419
+ >>> e_x, e_y = R2_r.base_vectors()
1420
+ >>> dx, dy = R2_r.base_oneforms()
1421
+
1422
+ >>> WedgeProduct(dx, dy)(e_x, e_y)
1423
+ 1
1424
+ >>> WedgeProduct(dx, dy)(e_y, e_x)
1425
+ -1
1426
+ >>> WedgeProduct(dx, fx*dy)(fx*e_x, e_y)
1427
+ x**2
1428
+ >>> WedgeProduct(e_x, e_y)(fy, None)
1429
+ -e_x
1430
+
1431
+ You can nest wedge products.
1432
+
1433
+ >>> wp1 = WedgeProduct(dx, dy)
1434
+ >>> WedgeProduct(wp1, dx)(e_x, e_y, e_x)
1435
+ 0
1436
+
1437
+ """
1438
+ # TODO the calculation of signatures is slow
1439
+ # TODO you do not need all these permutations (neither the prefactor)
1440
+ def __call__(self, *fields):
1441
+ """Apply on a list of vector_fields.
1442
+ The expression is rewritten internally in terms of tensor products and evaluated."""
1443
+ orders = (covariant_order(e) + contravariant_order(e) for e in self.args)
1444
+ mul = 1/Mul(*(factorial(o) for o in orders))
1445
+ perms = permutations(fields)
1446
+ perms_par = (Permutation(
1447
+ p).signature() for p in permutations(range(len(fields))))
1448
+ tensor_prod = TensorProduct(*self.args)
1449
+ return mul*Add(*[tensor_prod(*p[0])*p[1] for p in zip(perms, perms_par)])
1450
+
1451
+
1452
+ class LieDerivative(Expr):
1453
+ """Lie derivative with respect to a vector field.
1454
+
1455
+ Explanation
1456
+ ===========
1457
+
1458
+ The transport operator that defines the Lie derivative is the pushforward of
1459
+ the field to be derived along the integral curve of the field with respect
1460
+ to which one derives.
1461
+
1462
+ Examples
1463
+ ========
1464
+
1465
+ >>> from sympy.diffgeom.rn import R2_r, R2_p
1466
+ >>> from sympy.diffgeom import (LieDerivative, TensorProduct)
1467
+
1468
+ >>> fx, fy = R2_r.base_scalars()
1469
+ >>> e_x, e_y = R2_r.base_vectors()
1470
+ >>> e_rho, e_theta = R2_p.base_vectors()
1471
+ >>> dx, dy = R2_r.base_oneforms()
1472
+
1473
+ >>> LieDerivative(e_x, fy)
1474
+ 0
1475
+ >>> LieDerivative(e_x, fx)
1476
+ 1
1477
+ >>> LieDerivative(e_x, e_x)
1478
+ 0
1479
+
1480
+ The Lie derivative of a tensor field by another tensor field is equal to
1481
+ their commutator:
1482
+
1483
+ >>> LieDerivative(e_x, e_rho)
1484
+ Commutator(e_x, e_rho)
1485
+ >>> LieDerivative(e_x + e_y, fx)
1486
+ 1
1487
+
1488
+ >>> tp = TensorProduct(dx, dy)
1489
+ >>> LieDerivative(e_x, tp)
1490
+ LieDerivative(e_x, TensorProduct(dx, dy))
1491
+ >>> LieDerivative(e_x, tp)
1492
+ LieDerivative(e_x, TensorProduct(dx, dy))
1493
+
1494
+ """
1495
+ def __new__(cls, v_field, expr):
1496
+ expr_form_ord = covariant_order(expr)
1497
+ if contravariant_order(v_field) != 1 or covariant_order(v_field):
1498
+ raise ValueError('Lie derivatives are defined only with respect to'
1499
+ ' vector fields. The supplied argument was not a '
1500
+ 'vector field.')
1501
+ if expr_form_ord > 0:
1502
+ obj = super().__new__(cls, v_field, expr)
1503
+ # deprecated assignments
1504
+ obj._v_field = v_field
1505
+ obj._expr = expr
1506
+ return obj
1507
+ if expr.atoms(BaseVectorField):
1508
+ return Commutator(v_field, expr)
1509
+ else:
1510
+ return v_field.rcall(expr)
1511
+
1512
+ @property
1513
+ def v_field(self):
1514
+ return self.args[0]
1515
+
1516
+ @property
1517
+ def expr(self):
1518
+ return self.args[1]
1519
+
1520
+ def __call__(self, *args):
1521
+ v = self.v_field
1522
+ expr = self.expr
1523
+ lead_term = v(expr(*args))
1524
+ rest = Add(*[Mul(*args[:i] + (Commutator(v, args[i]),) + args[i + 1:])
1525
+ for i in range(len(args))])
1526
+ return lead_term - rest
1527
+
1528
+
1529
+ class BaseCovarDerivativeOp(Expr):
1530
+ """Covariant derivative operator with respect to a base vector.
1531
+
1532
+ Examples
1533
+ ========
1534
+
1535
+ >>> from sympy.diffgeom.rn import R2_r
1536
+ >>> from sympy.diffgeom import BaseCovarDerivativeOp
1537
+ >>> from sympy.diffgeom import metric_to_Christoffel_2nd, TensorProduct
1538
+
1539
+ >>> TP = TensorProduct
1540
+ >>> fx, fy = R2_r.base_scalars()
1541
+ >>> e_x, e_y = R2_r.base_vectors()
1542
+ >>> dx, dy = R2_r.base_oneforms()
1543
+
1544
+ >>> ch = metric_to_Christoffel_2nd(TP(dx, dx) + TP(dy, dy))
1545
+ >>> ch
1546
+ [[[0, 0], [0, 0]], [[0, 0], [0, 0]]]
1547
+ >>> cvd = BaseCovarDerivativeOp(R2_r, 0, ch)
1548
+ >>> cvd(fx)
1549
+ 1
1550
+ >>> cvd(fx*e_x)
1551
+ e_x
1552
+ """
1553
+
1554
+ def __new__(cls, coord_sys, index, christoffel):
1555
+ index = _sympify(index)
1556
+ christoffel = ImmutableDenseNDimArray(christoffel)
1557
+ obj = super().__new__(cls, coord_sys, index, christoffel)
1558
+ # deprecated assignments
1559
+ obj._coord_sys = coord_sys
1560
+ obj._index = index
1561
+ obj._christoffel = christoffel
1562
+ return obj
1563
+
1564
+ @property
1565
+ def coord_sys(self):
1566
+ return self.args[0]
1567
+
1568
+ @property
1569
+ def index(self):
1570
+ return self.args[1]
1571
+
1572
+ @property
1573
+ def christoffel(self):
1574
+ return self.args[2]
1575
+
1576
+ def __call__(self, field):
1577
+ """Apply on a scalar field.
1578
+
1579
+ The action of a vector field on a scalar field is a directional
1580
+ differentiation.
1581
+ If the argument is not a scalar field the behaviour is undefined.
1582
+ """
1583
+ if covariant_order(field) != 0:
1584
+ raise NotImplementedError()
1585
+
1586
+ field = vectors_in_basis(field, self._coord_sys)
1587
+
1588
+ wrt_vector = self._coord_sys.base_vector(self._index)
1589
+ wrt_scalar = self._coord_sys.coord_function(self._index)
1590
+ vectors = list(field.atoms(BaseVectorField))
1591
+
1592
+ # First step: replace all vectors with something susceptible to
1593
+ # derivation and do the derivation
1594
+ # TODO: you need a real dummy function for the next line
1595
+ d_funcs = [Function('_#_%s' % i)(wrt_scalar) for i,
1596
+ b in enumerate(vectors)]
1597
+ d_result = field.subs(list(zip(vectors, d_funcs)))
1598
+ d_result = wrt_vector(d_result)
1599
+
1600
+ # Second step: backsubstitute the vectors in
1601
+ d_result = d_result.subs(list(zip(d_funcs, vectors)))
1602
+
1603
+ # Third step: evaluate the derivatives of the vectors
1604
+ derivs = []
1605
+ for v in vectors:
1606
+ d = Add(*[(self._christoffel[k, wrt_vector._index, v._index]
1607
+ *v._coord_sys.base_vector(k))
1608
+ for k in range(v._coord_sys.dim)])
1609
+ derivs.append(d)
1610
+ to_subs = [wrt_vector(d) for d in d_funcs]
1611
+ # XXX: This substitution can fail when there are Dummy symbols and the
1612
+ # cache is disabled: https://github.com/sympy/sympy/issues/17794
1613
+ result = d_result.subs(list(zip(to_subs, derivs)))
1614
+
1615
+ # Remove the dummies
1616
+ result = result.subs(list(zip(d_funcs, vectors)))
1617
+ return result.doit()
1618
+
1619
+
1620
+ class CovarDerivativeOp(Expr):
1621
+ """Covariant derivative operator.
1622
+
1623
+ Examples
1624
+ ========
1625
+
1626
+ >>> from sympy.diffgeom.rn import R2_r
1627
+ >>> from sympy.diffgeom import CovarDerivativeOp
1628
+ >>> from sympy.diffgeom import metric_to_Christoffel_2nd, TensorProduct
1629
+ >>> TP = TensorProduct
1630
+ >>> fx, fy = R2_r.base_scalars()
1631
+ >>> e_x, e_y = R2_r.base_vectors()
1632
+ >>> dx, dy = R2_r.base_oneforms()
1633
+ >>> ch = metric_to_Christoffel_2nd(TP(dx, dx) + TP(dy, dy))
1634
+
1635
+ >>> ch
1636
+ [[[0, 0], [0, 0]], [[0, 0], [0, 0]]]
1637
+ >>> cvd = CovarDerivativeOp(fx*e_x, ch)
1638
+ >>> cvd(fx)
1639
+ x
1640
+ >>> cvd(fx*e_x)
1641
+ x*e_x
1642
+
1643
+ """
1644
+
1645
+ def __new__(cls, wrt, christoffel):
1646
+ if len({v._coord_sys for v in wrt.atoms(BaseVectorField)}) > 1:
1647
+ raise NotImplementedError()
1648
+ if contravariant_order(wrt) != 1 or covariant_order(wrt):
1649
+ raise ValueError('Covariant derivatives are defined only with '
1650
+ 'respect to vector fields. The supplied argument '
1651
+ 'was not a vector field.')
1652
+ christoffel = ImmutableDenseNDimArray(christoffel)
1653
+ obj = super().__new__(cls, wrt, christoffel)
1654
+ # deprecated assignments
1655
+ obj._wrt = wrt
1656
+ obj._christoffel = christoffel
1657
+ return obj
1658
+
1659
+ @property
1660
+ def wrt(self):
1661
+ return self.args[0]
1662
+
1663
+ @property
1664
+ def christoffel(self):
1665
+ return self.args[1]
1666
+
1667
+ def __call__(self, field):
1668
+ vectors = list(self._wrt.atoms(BaseVectorField))
1669
+ base_ops = [BaseCovarDerivativeOp(v._coord_sys, v._index, self._christoffel)
1670
+ for v in vectors]
1671
+ return self._wrt.subs(list(zip(vectors, base_ops))).rcall(field)
1672
+
1673
+
1674
+ ###############################################################################
1675
+ # Integral curves on vector fields
1676
+ ###############################################################################
1677
+ def intcurve_series(vector_field, param, start_point, n=6, coord_sys=None, coeffs=False):
1678
+ r"""Return the series expansion for an integral curve of the field.
1679
+
1680
+ Explanation
1681
+ ===========
1682
+
1683
+ Integral curve is a function `\gamma` taking a parameter in `R` to a point
1684
+ in the manifold. It verifies the equation:
1685
+
1686
+ `V(f)\big(\gamma(t)\big) = \frac{d}{dt}f\big(\gamma(t)\big)`
1687
+
1688
+ where the given ``vector_field`` is denoted as `V`. This holds for any
1689
+ value `t` for the parameter and any scalar field `f`.
1690
+
1691
+ This equation can also be decomposed of a basis of coordinate functions
1692
+ `V(f_i)\big(\gamma(t)\big) = \frac{d}{dt}f_i\big(\gamma(t)\big) \quad \forall i`
1693
+
1694
+ This function returns a series expansion of `\gamma(t)` in terms of the
1695
+ coordinate system ``coord_sys``. The equations and expansions are necessarily
1696
+ done in coordinate-system-dependent way as there is no other way to
1697
+ represent movement between points on the manifold (i.e. there is no such
1698
+ thing as a difference of points for a general manifold).
1699
+
1700
+ Parameters
1701
+ ==========
1702
+ vector_field
1703
+ the vector field for which an integral curve will be given
1704
+
1705
+ param
1706
+ the argument of the function `\gamma` from R to the curve
1707
+
1708
+ start_point
1709
+ the point which corresponds to `\gamma(0)`
1710
+
1711
+ n
1712
+ the order to which to expand
1713
+
1714
+ coord_sys
1715
+ the coordinate system in which to expand
1716
+ coeffs (default False) - if True return a list of elements of the expansion
1717
+
1718
+ Examples
1719
+ ========
1720
+
1721
+ Use the predefined R2 manifold:
1722
+
1723
+ >>> from sympy.abc import t, x, y
1724
+ >>> from sympy.diffgeom.rn import R2_p, R2_r
1725
+ >>> from sympy.diffgeom import intcurve_series
1726
+
1727
+ Specify a starting point and a vector field:
1728
+
1729
+ >>> start_point = R2_r.point([x, y])
1730
+ >>> vector_field = R2_r.e_x
1731
+
1732
+ Calculate the series:
1733
+
1734
+ >>> intcurve_series(vector_field, t, start_point, n=3)
1735
+ Matrix([
1736
+ [t + x],
1737
+ [ y]])
1738
+
1739
+ Or get the elements of the expansion in a list:
1740
+
1741
+ >>> series = intcurve_series(vector_field, t, start_point, n=3, coeffs=True)
1742
+ >>> series[0]
1743
+ Matrix([
1744
+ [x],
1745
+ [y]])
1746
+ >>> series[1]
1747
+ Matrix([
1748
+ [t],
1749
+ [0]])
1750
+ >>> series[2]
1751
+ Matrix([
1752
+ [0],
1753
+ [0]])
1754
+
1755
+ The series in the polar coordinate system:
1756
+
1757
+ >>> series = intcurve_series(vector_field, t, start_point,
1758
+ ... n=3, coord_sys=R2_p, coeffs=True)
1759
+ >>> series[0]
1760
+ Matrix([
1761
+ [sqrt(x**2 + y**2)],
1762
+ [ atan2(y, x)]])
1763
+ >>> series[1]
1764
+ Matrix([
1765
+ [t*x/sqrt(x**2 + y**2)],
1766
+ [ -t*y/(x**2 + y**2)]])
1767
+ >>> series[2]
1768
+ Matrix([
1769
+ [t**2*(-x**2/(x**2 + y**2)**(3/2) + 1/sqrt(x**2 + y**2))/2],
1770
+ [ t**2*x*y/(x**2 + y**2)**2]])
1771
+
1772
+ See Also
1773
+ ========
1774
+
1775
+ intcurve_diffequ
1776
+
1777
+ """
1778
+ if contravariant_order(vector_field) != 1 or covariant_order(vector_field):
1779
+ raise ValueError('The supplied field was not a vector field.')
1780
+
1781
+ def iter_vfield(scalar_field, i):
1782
+ """Return ``vector_field`` called `i` times on ``scalar_field``."""
1783
+ return reduce(lambda s, v: v.rcall(s), [vector_field, ]*i, scalar_field)
1784
+
1785
+ def taylor_terms_per_coord(coord_function):
1786
+ """Return the series for one of the coordinates."""
1787
+ return [param**i*iter_vfield(coord_function, i).rcall(start_point)/factorial(i)
1788
+ for i in range(n)]
1789
+ coord_sys = coord_sys if coord_sys else start_point._coord_sys
1790
+ coord_functions = coord_sys.coord_functions()
1791
+ taylor_terms = [taylor_terms_per_coord(f) for f in coord_functions]
1792
+ if coeffs:
1793
+ return [Matrix(t) for t in zip(*taylor_terms)]
1794
+ else:
1795
+ return Matrix([sum(c) for c in taylor_terms])
1796
+
1797
+
1798
+ def intcurve_diffequ(vector_field, param, start_point, coord_sys=None):
1799
+ r"""Return the differential equation for an integral curve of the field.
1800
+
1801
+ Explanation
1802
+ ===========
1803
+
1804
+ Integral curve is a function `\gamma` taking a parameter in `R` to a point
1805
+ in the manifold. It verifies the equation:
1806
+
1807
+ `V(f)\big(\gamma(t)\big) = \frac{d}{dt}f\big(\gamma(t)\big)`
1808
+
1809
+ where the given ``vector_field`` is denoted as `V`. This holds for any
1810
+ value `t` for the parameter and any scalar field `f`.
1811
+
1812
+ This function returns the differential equation of `\gamma(t)` in terms of the
1813
+ coordinate system ``coord_sys``. The equations and expansions are necessarily
1814
+ done in coordinate-system-dependent way as there is no other way to
1815
+ represent movement between points on the manifold (i.e. there is no such
1816
+ thing as a difference of points for a general manifold).
1817
+
1818
+ Parameters
1819
+ ==========
1820
+
1821
+ vector_field
1822
+ the vector field for which an integral curve will be given
1823
+
1824
+ param
1825
+ the argument of the function `\gamma` from R to the curve
1826
+
1827
+ start_point
1828
+ the point which corresponds to `\gamma(0)`
1829
+
1830
+ coord_sys
1831
+ the coordinate system in which to give the equations
1832
+
1833
+ Returns
1834
+ =======
1835
+
1836
+ a tuple of (equations, initial conditions)
1837
+
1838
+ Examples
1839
+ ========
1840
+
1841
+ Use the predefined R2 manifold:
1842
+
1843
+ >>> from sympy.abc import t
1844
+ >>> from sympy.diffgeom.rn import R2, R2_p, R2_r
1845
+ >>> from sympy.diffgeom import intcurve_diffequ
1846
+
1847
+ Specify a starting point and a vector field:
1848
+
1849
+ >>> start_point = R2_r.point([0, 1])
1850
+ >>> vector_field = -R2.y*R2.e_x + R2.x*R2.e_y
1851
+
1852
+ Get the equation:
1853
+
1854
+ >>> equations, init_cond = intcurve_diffequ(vector_field, t, start_point)
1855
+ >>> equations
1856
+ [f_1(t) + Derivative(f_0(t), t), -f_0(t) + Derivative(f_1(t), t)]
1857
+ >>> init_cond
1858
+ [f_0(0), f_1(0) - 1]
1859
+
1860
+ The series in the polar coordinate system:
1861
+
1862
+ >>> equations, init_cond = intcurve_diffequ(vector_field, t, start_point, R2_p)
1863
+ >>> equations
1864
+ [Derivative(f_0(t), t), Derivative(f_1(t), t) - 1]
1865
+ >>> init_cond
1866
+ [f_0(0) - 1, f_1(0) - pi/2]
1867
+
1868
+ See Also
1869
+ ========
1870
+
1871
+ intcurve_series
1872
+
1873
+ """
1874
+ if contravariant_order(vector_field) != 1 or covariant_order(vector_field):
1875
+ raise ValueError('The supplied field was not a vector field.')
1876
+ coord_sys = coord_sys if coord_sys else start_point._coord_sys
1877
+ gammas = [Function('f_%d' % i)(param) for i in range(
1878
+ start_point._coord_sys.dim)]
1879
+ arbitrary_p = Point(coord_sys, gammas)
1880
+ coord_functions = coord_sys.coord_functions()
1881
+ equations = [simplify(diff(cf.rcall(arbitrary_p), param) - vector_field.rcall(cf).rcall(arbitrary_p))
1882
+ for cf in coord_functions]
1883
+ init_cond = [simplify(cf.rcall(arbitrary_p).subs(param, 0) - cf.rcall(start_point))
1884
+ for cf in coord_functions]
1885
+ return equations, init_cond
1886
+
1887
+
1888
+ ###############################################################################
1889
+ # Helpers
1890
+ ###############################################################################
1891
+ def dummyfy(args, exprs):
1892
+ # TODO Is this a good idea?
1893
+ d_args = Matrix([s.as_dummy() for s in args])
1894
+ reps = dict(zip(args, d_args))
1895
+ d_exprs = Matrix([_sympify(expr).subs(reps) for expr in exprs])
1896
+ return d_args, d_exprs
1897
+
1898
+ ###############################################################################
1899
+ # Helpers
1900
+ ###############################################################################
1901
+ def contravariant_order(expr, _strict=False):
1902
+ """Return the contravariant order of an expression.
1903
+
1904
+ Examples
1905
+ ========
1906
+
1907
+ >>> from sympy.diffgeom import contravariant_order
1908
+ >>> from sympy.diffgeom.rn import R2
1909
+ >>> from sympy.abc import a
1910
+
1911
+ >>> contravariant_order(a)
1912
+ 0
1913
+ >>> contravariant_order(a*R2.x + 2)
1914
+ 0
1915
+ >>> contravariant_order(a*R2.x*R2.e_y + R2.e_x)
1916
+ 1
1917
+
1918
+ """
1919
+ # TODO move some of this to class methods.
1920
+ # TODO rewrite using the .as_blah_blah methods
1921
+ if isinstance(expr, Add):
1922
+ orders = [contravariant_order(e) for e in expr.args]
1923
+ if len(set(orders)) != 1:
1924
+ raise ValueError('Misformed expression containing contravariant fields of varying order.')
1925
+ return orders[0]
1926
+ elif isinstance(expr, Mul):
1927
+ orders = [contravariant_order(e) for e in expr.args]
1928
+ not_zero = [o for o in orders if o != 0]
1929
+ if len(not_zero) > 1:
1930
+ raise ValueError('Misformed expression containing multiplication between vectors.')
1931
+ return 0 if not not_zero else not_zero[0]
1932
+ elif isinstance(expr, Pow):
1933
+ if covariant_order(expr.base) or covariant_order(expr.exp):
1934
+ raise ValueError(
1935
+ 'Misformed expression containing a power of a vector.')
1936
+ return 0
1937
+ elif isinstance(expr, BaseVectorField):
1938
+ return 1
1939
+ elif isinstance(expr, TensorProduct):
1940
+ return sum(contravariant_order(a) for a in expr.args)
1941
+ elif not _strict or expr.atoms(BaseScalarField):
1942
+ return 0
1943
+ else: # If it does not contain anything related to the diffgeom module and it is _strict
1944
+ return -1
1945
+
1946
+
1947
+ def covariant_order(expr, _strict=False):
1948
+ """Return the covariant order of an expression.
1949
+
1950
+ Examples
1951
+ ========
1952
+
1953
+ >>> from sympy.diffgeom import covariant_order
1954
+ >>> from sympy.diffgeom.rn import R2
1955
+ >>> from sympy.abc import a
1956
+
1957
+ >>> covariant_order(a)
1958
+ 0
1959
+ >>> covariant_order(a*R2.x + 2)
1960
+ 0
1961
+ >>> covariant_order(a*R2.x*R2.dy + R2.dx)
1962
+ 1
1963
+
1964
+ """
1965
+ # TODO move some of this to class methods.
1966
+ # TODO rewrite using the .as_blah_blah methods
1967
+ if isinstance(expr, Add):
1968
+ orders = [covariant_order(e) for e in expr.args]
1969
+ if len(set(orders)) != 1:
1970
+ raise ValueError('Misformed expression containing form fields of varying order.')
1971
+ return orders[0]
1972
+ elif isinstance(expr, Mul):
1973
+ orders = [covariant_order(e) for e in expr.args]
1974
+ not_zero = [o for o in orders if o != 0]
1975
+ if len(not_zero) > 1:
1976
+ raise ValueError('Misformed expression containing multiplication between forms.')
1977
+ return 0 if not not_zero else not_zero[0]
1978
+ elif isinstance(expr, Pow):
1979
+ if covariant_order(expr.base) or covariant_order(expr.exp):
1980
+ raise ValueError(
1981
+ 'Misformed expression containing a power of a form.')
1982
+ return 0
1983
+ elif isinstance(expr, Differential):
1984
+ return covariant_order(*expr.args) + 1
1985
+ elif isinstance(expr, TensorProduct):
1986
+ return sum(covariant_order(a) for a in expr.args)
1987
+ elif not _strict or expr.atoms(BaseScalarField):
1988
+ return 0
1989
+ else: # If it does not contain anything related to the diffgeom module and it is _strict
1990
+ return -1
1991
+
1992
+
1993
+ ###############################################################################
1994
+ # Coordinate transformation functions
1995
+ ###############################################################################
1996
+ def vectors_in_basis(expr, to_sys):
1997
+ """Transform all base vectors in base vectors of a specified coord basis.
1998
+ While the new base vectors are in the new coordinate system basis, any
1999
+ coefficients are kept in the old system.
2000
+
2001
+ Examples
2002
+ ========
2003
+
2004
+ >>> from sympy.diffgeom import vectors_in_basis
2005
+ >>> from sympy.diffgeom.rn import R2_r, R2_p
2006
+
2007
+ >>> vectors_in_basis(R2_r.e_x, R2_p)
2008
+ -y*e_theta/(x**2 + y**2) + x*e_rho/sqrt(x**2 + y**2)
2009
+ >>> vectors_in_basis(R2_p.e_r, R2_r)
2010
+ sin(theta)*e_y + cos(theta)*e_x
2011
+
2012
+ """
2013
+ vectors = list(expr.atoms(BaseVectorField))
2014
+ new_vectors = []
2015
+ for v in vectors:
2016
+ cs = v._coord_sys
2017
+ jac = cs.jacobian(to_sys, cs.coord_functions())
2018
+ new = (jac.T*Matrix(to_sys.base_vectors()))[v._index]
2019
+ new_vectors.append(new)
2020
+ return expr.subs(list(zip(vectors, new_vectors)))
2021
+
2022
+
2023
+ ###############################################################################
2024
+ # Coordinate-dependent functions
2025
+ ###############################################################################
2026
+ def twoform_to_matrix(expr):
2027
+ """Return the matrix representing the twoform.
2028
+
2029
+ For the twoform `w` return the matrix `M` such that `M[i,j]=w(e_i, e_j)`,
2030
+ where `e_i` is the i-th base vector field for the coordinate system in
2031
+ which the expression of `w` is given.
2032
+
2033
+ Examples
2034
+ ========
2035
+
2036
+ >>> from sympy.diffgeom.rn import R2
2037
+ >>> from sympy.diffgeom import twoform_to_matrix, TensorProduct
2038
+ >>> TP = TensorProduct
2039
+
2040
+ >>> twoform_to_matrix(TP(R2.dx, R2.dx) + TP(R2.dy, R2.dy))
2041
+ Matrix([
2042
+ [1, 0],
2043
+ [0, 1]])
2044
+ >>> twoform_to_matrix(R2.x*TP(R2.dx, R2.dx) + TP(R2.dy, R2.dy))
2045
+ Matrix([
2046
+ [x, 0],
2047
+ [0, 1]])
2048
+ >>> twoform_to_matrix(TP(R2.dx, R2.dx) + TP(R2.dy, R2.dy) - TP(R2.dx, R2.dy)/2)
2049
+ Matrix([
2050
+ [ 1, 0],
2051
+ [-1/2, 1]])
2052
+
2053
+ """
2054
+ if covariant_order(expr) != 2 or contravariant_order(expr):
2055
+ raise ValueError('The input expression is not a two-form.')
2056
+ coord_sys = _find_coords(expr)
2057
+ if len(coord_sys) != 1:
2058
+ raise ValueError('The input expression concerns more than one '
2059
+ 'coordinate systems, hence there is no unambiguous '
2060
+ 'way to choose a coordinate system for the matrix.')
2061
+ coord_sys = coord_sys.pop()
2062
+ vectors = coord_sys.base_vectors()
2063
+ expr = expr.expand()
2064
+ matrix_content = [[expr.rcall(v1, v2) for v1 in vectors]
2065
+ for v2 in vectors]
2066
+ return Matrix(matrix_content)
2067
+
2068
+
2069
+ def metric_to_Christoffel_1st(expr):
2070
+ """Return the nested list of Christoffel symbols for the given metric.
2071
+ This returns the Christoffel symbol of first kind that represents the
2072
+ Levi-Civita connection for the given metric.
2073
+
2074
+ Examples
2075
+ ========
2076
+
2077
+ >>> from sympy.diffgeom.rn import R2
2078
+ >>> from sympy.diffgeom import metric_to_Christoffel_1st, TensorProduct
2079
+ >>> TP = TensorProduct
2080
+
2081
+ >>> metric_to_Christoffel_1st(TP(R2.dx, R2.dx) + TP(R2.dy, R2.dy))
2082
+ [[[0, 0], [0, 0]], [[0, 0], [0, 0]]]
2083
+ >>> metric_to_Christoffel_1st(R2.x*TP(R2.dx, R2.dx) + TP(R2.dy, R2.dy))
2084
+ [[[1/2, 0], [0, 0]], [[0, 0], [0, 0]]]
2085
+
2086
+ """
2087
+ matrix = twoform_to_matrix(expr)
2088
+ if not matrix.is_symmetric():
2089
+ raise ValueError(
2090
+ 'The two-form representing the metric is not symmetric.')
2091
+ coord_sys = _find_coords(expr).pop()
2092
+ deriv_matrices = [matrix.applyfunc(d) for d in coord_sys.base_vectors()]
2093
+ indices = list(range(coord_sys.dim))
2094
+ christoffel = [[[(deriv_matrices[k][i, j] + deriv_matrices[j][i, k] - deriv_matrices[i][j, k])/2
2095
+ for k in indices]
2096
+ for j in indices]
2097
+ for i in indices]
2098
+ return ImmutableDenseNDimArray(christoffel)
2099
+
2100
+
2101
+ def metric_to_Christoffel_2nd(expr):
2102
+ """Return the nested list of Christoffel symbols for the given metric.
2103
+ This returns the Christoffel symbol of second kind that represents the
2104
+ Levi-Civita connection for the given metric.
2105
+
2106
+ Examples
2107
+ ========
2108
+
2109
+ >>> from sympy.diffgeom.rn import R2
2110
+ >>> from sympy.diffgeom import metric_to_Christoffel_2nd, TensorProduct
2111
+ >>> TP = TensorProduct
2112
+
2113
+ >>> metric_to_Christoffel_2nd(TP(R2.dx, R2.dx) + TP(R2.dy, R2.dy))
2114
+ [[[0, 0], [0, 0]], [[0, 0], [0, 0]]]
2115
+ >>> metric_to_Christoffel_2nd(R2.x*TP(R2.dx, R2.dx) + TP(R2.dy, R2.dy))
2116
+ [[[1/(2*x), 0], [0, 0]], [[0, 0], [0, 0]]]
2117
+
2118
+ """
2119
+ ch_1st = metric_to_Christoffel_1st(expr)
2120
+ coord_sys = _find_coords(expr).pop()
2121
+ indices = list(range(coord_sys.dim))
2122
+ # XXX workaround, inverting a matrix does not work if it contains non
2123
+ # symbols
2124
+ #matrix = twoform_to_matrix(expr).inv()
2125
+ matrix = twoform_to_matrix(expr)
2126
+ s_fields = set()
2127
+ for e in matrix:
2128
+ s_fields.update(e.atoms(BaseScalarField))
2129
+ s_fields = list(s_fields)
2130
+ dums = coord_sys.symbols
2131
+ matrix = matrix.subs(list(zip(s_fields, dums))).inv().subs(list(zip(dums, s_fields)))
2132
+ # XXX end of workaround
2133
+ christoffel = [[[Add(*[matrix[i, l]*ch_1st[l, j, k] for l in indices])
2134
+ for k in indices]
2135
+ for j in indices]
2136
+ for i in indices]
2137
+ return ImmutableDenseNDimArray(christoffel)
2138
+
2139
+
2140
+ def metric_to_Riemann_components(expr):
2141
+ """Return the components of the Riemann tensor expressed in a given basis.
2142
+
2143
+ Given a metric it calculates the components of the Riemann tensor in the
2144
+ canonical basis of the coordinate system in which the metric expression is
2145
+ given.
2146
+
2147
+ Examples
2148
+ ========
2149
+
2150
+ >>> from sympy import exp
2151
+ >>> from sympy.diffgeom.rn import R2
2152
+ >>> from sympy.diffgeom import metric_to_Riemann_components, TensorProduct
2153
+ >>> TP = TensorProduct
2154
+
2155
+ >>> metric_to_Riemann_components(TP(R2.dx, R2.dx) + TP(R2.dy, R2.dy))
2156
+ [[[[0, 0], [0, 0]], [[0, 0], [0, 0]]], [[[0, 0], [0, 0]], [[0, 0], [0, 0]]]]
2157
+ >>> non_trivial_metric = exp(2*R2.r)*TP(R2.dr, R2.dr) + \
2158
+ R2.r**2*TP(R2.dtheta, R2.dtheta)
2159
+ >>> non_trivial_metric
2160
+ exp(2*rho)*TensorProduct(drho, drho) + rho**2*TensorProduct(dtheta, dtheta)
2161
+ >>> riemann = metric_to_Riemann_components(non_trivial_metric)
2162
+ >>> riemann[0, :, :, :]
2163
+ [[[0, 0], [0, 0]], [[0, exp(-2*rho)*rho], [-exp(-2*rho)*rho, 0]]]
2164
+ >>> riemann[1, :, :, :]
2165
+ [[[0, -1/rho], [1/rho, 0]], [[0, 0], [0, 0]]]
2166
+
2167
+ """
2168
+ ch_2nd = metric_to_Christoffel_2nd(expr)
2169
+ coord_sys = _find_coords(expr).pop()
2170
+ indices = list(range(coord_sys.dim))
2171
+ deriv_ch = [[[[d(ch_2nd[i, j, k])
2172
+ for d in coord_sys.base_vectors()]
2173
+ for k in indices]
2174
+ for j in indices]
2175
+ for i in indices]
2176
+ riemann_a = [[[[deriv_ch[rho][sig][nu][mu] - deriv_ch[rho][sig][mu][nu]
2177
+ for nu in indices]
2178
+ for mu in indices]
2179
+ for sig in indices]
2180
+ for rho in indices]
2181
+ riemann_b = [[[[Add(*[ch_2nd[rho, l, mu]*ch_2nd[l, sig, nu] - ch_2nd[rho, l, nu]*ch_2nd[l, sig, mu] for l in indices])
2182
+ for nu in indices]
2183
+ for mu in indices]
2184
+ for sig in indices]
2185
+ for rho in indices]
2186
+ riemann = [[[[riemann_a[rho][sig][mu][nu] + riemann_b[rho][sig][mu][nu]
2187
+ for nu in indices]
2188
+ for mu in indices]
2189
+ for sig in indices]
2190
+ for rho in indices]
2191
+ return ImmutableDenseNDimArray(riemann)
2192
+
2193
+
2194
+ def metric_to_Ricci_components(expr):
2195
+
2196
+ """Return the components of the Ricci tensor expressed in a given basis.
2197
+
2198
+ Given a metric it calculates the components of the Ricci tensor in the
2199
+ canonical basis of the coordinate system in which the metric expression is
2200
+ given.
2201
+
2202
+ Examples
2203
+ ========
2204
+
2205
+ >>> from sympy import exp
2206
+ >>> from sympy.diffgeom.rn import R2
2207
+ >>> from sympy.diffgeom import metric_to_Ricci_components, TensorProduct
2208
+ >>> TP = TensorProduct
2209
+
2210
+ >>> metric_to_Ricci_components(TP(R2.dx, R2.dx) + TP(R2.dy, R2.dy))
2211
+ [[0, 0], [0, 0]]
2212
+ >>> non_trivial_metric = exp(2*R2.r)*TP(R2.dr, R2.dr) + \
2213
+ R2.r**2*TP(R2.dtheta, R2.dtheta)
2214
+ >>> non_trivial_metric
2215
+ exp(2*rho)*TensorProduct(drho, drho) + rho**2*TensorProduct(dtheta, dtheta)
2216
+ >>> metric_to_Ricci_components(non_trivial_metric)
2217
+ [[1/rho, 0], [0, exp(-2*rho)*rho]]
2218
+
2219
+ """
2220
+ riemann = metric_to_Riemann_components(expr)
2221
+ coord_sys = _find_coords(expr).pop()
2222
+ indices = list(range(coord_sys.dim))
2223
+ ricci = [[Add(*[riemann[k, i, k, j] for k in indices])
2224
+ for j in indices]
2225
+ for i in indices]
2226
+ return ImmutableDenseNDimArray(ricci)
2227
+
2228
+ ###############################################################################
2229
+ # Classes for deprecation
2230
+ ###############################################################################
2231
+
2232
+ class _deprecated_container:
2233
+ # This class gives deprecation warning.
2234
+ # When deprecated features are completely deleted, this should be removed as well.
2235
+ # See https://github.com/sympy/sympy/pull/19368
2236
+ def __init__(self, message, data):
2237
+ super().__init__(data)
2238
+ self.message = message
2239
+
2240
+ def warn(self):
2241
+ sympy_deprecation_warning(
2242
+ self.message,
2243
+ deprecated_since_version="1.7",
2244
+ active_deprecations_target="deprecated-diffgeom-mutable",
2245
+ stacklevel=4
2246
+ )
2247
+
2248
+ def __iter__(self):
2249
+ self.warn()
2250
+ return super().__iter__()
2251
+
2252
+ def __getitem__(self, key):
2253
+ self.warn()
2254
+ return super().__getitem__(key)
2255
+
2256
+ def __contains__(self, key):
2257
+ self.warn()
2258
+ return super().__contains__(key)
2259
+
2260
+
2261
+ class _deprecated_list(_deprecated_container, list):
2262
+ pass
2263
+
2264
+
2265
+ class _deprecated_dict(_deprecated_container, dict):
2266
+ pass
2267
+
2268
+
2269
+ # Import at end to avoid cyclic imports
2270
+ from sympy.simplify.simplify import simplify
.venv/lib/python3.13/site-packages/sympy/diffgeom/rn.py ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Predefined R^n manifolds together with common coord. systems.
2
+
3
+ Coordinate systems are predefined as well as the transformation laws between
4
+ them.
5
+
6
+ Coordinate functions can be accessed as attributes of the manifold (eg `R2.x`),
7
+ as attributes of the coordinate systems (eg `R2_r.x` and `R2_p.theta`), or by
8
+ using the usual `coord_sys.coord_function(index, name)` interface.
9
+ """
10
+
11
+ from typing import Any
12
+ import warnings
13
+
14
+ from sympy.core.symbol import (Dummy, symbols)
15
+ from sympy.functions.elementary.miscellaneous import sqrt
16
+ from sympy.functions.elementary.trigonometric import (acos, atan2, cos, sin)
17
+ from .diffgeom import Manifold, Patch, CoordSystem
18
+
19
+ __all__ = [
20
+ 'R2', 'R2_origin', 'relations_2d', 'R2_r', 'R2_p',
21
+ 'R3', 'R3_origin', 'relations_3d', 'R3_r', 'R3_c', 'R3_s'
22
+ ]
23
+
24
+ ###############################################################################
25
+ # R2
26
+ ###############################################################################
27
+ R2: Any = Manifold('R^2', 2)
28
+
29
+ R2_origin: Any = Patch('origin', R2)
30
+
31
+ x, y = symbols('x y', real=True)
32
+ r, theta = symbols('rho theta', nonnegative=True)
33
+
34
+ relations_2d = {
35
+ ('rectangular', 'polar'): [(x, y), (sqrt(x**2 + y**2), atan2(y, x))],
36
+ ('polar', 'rectangular'): [(r, theta), (r*cos(theta), r*sin(theta))],
37
+ }
38
+
39
+ R2_r: Any = CoordSystem('rectangular', R2_origin, (x, y), relations_2d)
40
+ R2_p: Any = CoordSystem('polar', R2_origin, (r, theta), relations_2d)
41
+
42
+ # support deprecated feature
43
+ with warnings.catch_warnings():
44
+ warnings.simplefilter("ignore")
45
+ x, y, r, theta = symbols('x y r theta', cls=Dummy)
46
+ R2_r.connect_to(R2_p, [x, y],
47
+ [sqrt(x**2 + y**2), atan2(y, x)],
48
+ inverse=False, fill_in_gaps=False)
49
+ R2_p.connect_to(R2_r, [r, theta],
50
+ [r*cos(theta), r*sin(theta)],
51
+ inverse=False, fill_in_gaps=False)
52
+
53
+ # Defining the basis coordinate functions and adding shortcuts for them to the
54
+ # manifold and the patch.
55
+ R2.x, R2.y = R2_origin.x, R2_origin.y = R2_r.x, R2_r.y = R2_r.coord_functions()
56
+ R2.r, R2.theta = R2_origin.r, R2_origin.theta = R2_p.r, R2_p.theta = R2_p.coord_functions()
57
+
58
+ # Defining the basis vector fields and adding shortcuts for them to the
59
+ # manifold and the patch.
60
+ R2.e_x, R2.e_y = R2_origin.e_x, R2_origin.e_y = R2_r.e_x, R2_r.e_y = R2_r.base_vectors()
61
+ R2.e_r, R2.e_theta = R2_origin.e_r, R2_origin.e_theta = R2_p.e_r, R2_p.e_theta = R2_p.base_vectors()
62
+
63
+ # Defining the basis oneform fields and adding shortcuts for them to the
64
+ # manifold and the patch.
65
+ R2.dx, R2.dy = R2_origin.dx, R2_origin.dy = R2_r.dx, R2_r.dy = R2_r.base_oneforms()
66
+ R2.dr, R2.dtheta = R2_origin.dr, R2_origin.dtheta = R2_p.dr, R2_p.dtheta = R2_p.base_oneforms()
67
+
68
+ ###############################################################################
69
+ # R3
70
+ ###############################################################################
71
+ R3: Any = Manifold('R^3', 3)
72
+
73
+ R3_origin: Any = Patch('origin', R3)
74
+
75
+ x, y, z = symbols('x y z', real=True)
76
+ rho, psi, r, theta, phi = symbols('rho psi r theta phi', nonnegative=True)
77
+
78
+ relations_3d = {
79
+ ('rectangular', 'cylindrical'): [(x, y, z),
80
+ (sqrt(x**2 + y**2), atan2(y, x), z)],
81
+ ('cylindrical', 'rectangular'): [(rho, psi, z),
82
+ (rho*cos(psi), rho*sin(psi), z)],
83
+ ('rectangular', 'spherical'): [(x, y, z),
84
+ (sqrt(x**2 + y**2 + z**2),
85
+ acos(z/sqrt(x**2 + y**2 + z**2)),
86
+ atan2(y, x))],
87
+ ('spherical', 'rectangular'): [(r, theta, phi),
88
+ (r*sin(theta)*cos(phi),
89
+ r*sin(theta)*sin(phi),
90
+ r*cos(theta))],
91
+ ('cylindrical', 'spherical'): [(rho, psi, z),
92
+ (sqrt(rho**2 + z**2),
93
+ acos(z/sqrt(rho**2 + z**2)),
94
+ psi)],
95
+ ('spherical', 'cylindrical'): [(r, theta, phi),
96
+ (r*sin(theta), phi, r*cos(theta))],
97
+ }
98
+
99
+ R3_r: Any = CoordSystem('rectangular', R3_origin, (x, y, z), relations_3d)
100
+ R3_c: Any = CoordSystem('cylindrical', R3_origin, (rho, psi, z), relations_3d)
101
+ R3_s: Any = CoordSystem('spherical', R3_origin, (r, theta, phi), relations_3d)
102
+
103
+ # support deprecated feature
104
+ with warnings.catch_warnings():
105
+ warnings.simplefilter("ignore")
106
+ x, y, z, rho, psi, r, theta, phi = symbols('x y z rho psi r theta phi', cls=Dummy)
107
+ R3_r.connect_to(R3_c, [x, y, z],
108
+ [sqrt(x**2 + y**2), atan2(y, x), z],
109
+ inverse=False, fill_in_gaps=False)
110
+ R3_c.connect_to(R3_r, [rho, psi, z],
111
+ [rho*cos(psi), rho*sin(psi), z],
112
+ inverse=False, fill_in_gaps=False)
113
+ ## rectangular <-> spherical
114
+ R3_r.connect_to(R3_s, [x, y, z],
115
+ [sqrt(x**2 + y**2 + z**2), acos(z/
116
+ sqrt(x**2 + y**2 + z**2)), atan2(y, x)],
117
+ inverse=False, fill_in_gaps=False)
118
+ R3_s.connect_to(R3_r, [r, theta, phi],
119
+ [r*sin(theta)*cos(phi), r*sin(
120
+ theta)*sin(phi), r*cos(theta)],
121
+ inverse=False, fill_in_gaps=False)
122
+ ## cylindrical <-> spherical
123
+ R3_c.connect_to(R3_s, [rho, psi, z],
124
+ [sqrt(rho**2 + z**2), acos(z/sqrt(rho**2 + z**2)), psi],
125
+ inverse=False, fill_in_gaps=False)
126
+ R3_s.connect_to(R3_c, [r, theta, phi],
127
+ [r*sin(theta), phi, r*cos(theta)],
128
+ inverse=False, fill_in_gaps=False)
129
+
130
+ # Defining the basis coordinate functions.
131
+ R3_r.x, R3_r.y, R3_r.z = R3_r.coord_functions()
132
+ R3_c.rho, R3_c.psi, R3_c.z = R3_c.coord_functions()
133
+ R3_s.r, R3_s.theta, R3_s.phi = R3_s.coord_functions()
134
+
135
+ # Defining the basis vector fields.
136
+ R3_r.e_x, R3_r.e_y, R3_r.e_z = R3_r.base_vectors()
137
+ R3_c.e_rho, R3_c.e_psi, R3_c.e_z = R3_c.base_vectors()
138
+ R3_s.e_r, R3_s.e_theta, R3_s.e_phi = R3_s.base_vectors()
139
+
140
+ # Defining the basis oneform fields.
141
+ R3_r.dx, R3_r.dy, R3_r.dz = R3_r.base_oneforms()
142
+ R3_c.drho, R3_c.dpsi, R3_c.dz = R3_c.base_oneforms()
143
+ R3_s.dr, R3_s.dtheta, R3_s.dphi = R3_s.base_oneforms()
.venv/lib/python3.13/site-packages/sympy/diffgeom/tests/__init__.py ADDED
File without changes
.venv/lib/python3.13/site-packages/sympy/diffgeom/tests/test_class_structure.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sympy.diffgeom import Manifold, Patch, CoordSystem, Point
2
+ from sympy.core.function import Function
3
+ from sympy.core.symbol import symbols
4
+ from sympy.testing.pytest import warns_deprecated_sympy
5
+
6
+ m = Manifold('m', 2)
7
+ p = Patch('p', m)
8
+ a, b = symbols('a b')
9
+ cs = CoordSystem('cs', p, [a, b])
10
+ x, y = symbols('x y')
11
+ f = Function('f')
12
+ s1, s2 = cs.coord_functions()
13
+ v1, v2 = cs.base_vectors()
14
+ f1, f2 = cs.base_oneforms()
15
+
16
+ def test_point():
17
+ point = Point(cs, [x, y])
18
+ assert point != Point(cs, [2, y])
19
+ #TODO assert point.subs(x, 2) == Point(cs, [2, y])
20
+ #TODO assert point.free_symbols == set([x, y])
21
+
22
+ def test_subs():
23
+ assert s1.subs(s1, s2) == s2
24
+ assert v1.subs(v1, v2) == v2
25
+ assert f1.subs(f1, f2) == f2
26
+ assert (x*f(s1) + y).subs(s1, s2) == x*f(s2) + y
27
+ assert (f(s1)*v1).subs(v1, v2) == f(s1)*v2
28
+ assert (y*f(s1)*f1).subs(f1, f2) == y*f(s1)*f2
29
+
30
+ def test_deprecated():
31
+ with warns_deprecated_sympy():
32
+ cs_wname = CoordSystem('cs', p, ['a', 'b'])
33
+ assert cs_wname == cs_wname.func(*cs_wname.args)
.venv/lib/python3.13/site-packages/sympy/diffgeom/tests/test_diffgeom.py ADDED
@@ -0,0 +1,342 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sympy.core import Lambda, Symbol, symbols
2
+ from sympy.diffgeom.rn import R2, R2_p, R2_r, R3_r, R3_c, R3_s, R2_origin
3
+ from sympy.diffgeom import (Manifold, Patch, CoordSystem, Commutator, Differential, TensorProduct,
4
+ WedgeProduct, BaseCovarDerivativeOp, CovarDerivativeOp, LieDerivative,
5
+ covariant_order, contravariant_order, twoform_to_matrix, metric_to_Christoffel_1st,
6
+ metric_to_Christoffel_2nd, metric_to_Riemann_components,
7
+ metric_to_Ricci_components, intcurve_diffequ, intcurve_series)
8
+ from sympy.simplify import trigsimp, simplify
9
+ from sympy.functions import sqrt, atan2, sin
10
+ from sympy.matrices import Matrix
11
+ from sympy.testing.pytest import raises, nocache_fail
12
+ from sympy.testing.pytest import warns_deprecated_sympy
13
+
14
+ TP = TensorProduct
15
+
16
+
17
+ def test_coordsys_transform():
18
+ # test inverse transforms
19
+ p, q, r, s = symbols('p q r s')
20
+ rel = {('first', 'second'): [(p, q), (q, -p)]}
21
+ R2_pq = CoordSystem('first', R2_origin, [p, q], rel)
22
+ R2_rs = CoordSystem('second', R2_origin, [r, s], rel)
23
+ r, s = R2_rs.symbols
24
+ assert R2_rs.transform(R2_pq) == Matrix([[-s], [r]])
25
+
26
+ # inverse transform impossible case
27
+ a, b = symbols('a b', positive=True)
28
+ rel = {('first', 'second'): [(a,), (-a,)]}
29
+ R2_a = CoordSystem('first', R2_origin, [a], rel)
30
+ R2_b = CoordSystem('second', R2_origin, [b], rel)
31
+ # This transformation is uninvertible because there is no positive a, b satisfying a = -b
32
+ with raises(NotImplementedError):
33
+ R2_b.transform(R2_a)
34
+
35
+ # inverse transform ambiguous case
36
+ c, d = symbols('c d')
37
+ rel = {('first', 'second'): [(c,), (c**2,)]}
38
+ R2_c = CoordSystem('first', R2_origin, [c], rel)
39
+ R2_d = CoordSystem('second', R2_origin, [d], rel)
40
+ # The transform method should throw if it finds multiple inverses for a coordinate transformation.
41
+ with raises(ValueError):
42
+ R2_d.transform(R2_c)
43
+
44
+ # test indirect transformation
45
+ a, b, c, d, e, f = symbols('a, b, c, d, e, f')
46
+ rel = {('C1', 'C2'): [(a, b), (2*a, 3*b)],
47
+ ('C2', 'C3'): [(c, d), (3*c, 2*d)]}
48
+ C1 = CoordSystem('C1', R2_origin, (a, b), rel)
49
+ C2 = CoordSystem('C2', R2_origin, (c, d), rel)
50
+ C3 = CoordSystem('C3', R2_origin, (e, f), rel)
51
+ a, b = C1.symbols
52
+ c, d = C2.symbols
53
+ e, f = C3.symbols
54
+ assert C2.transform(C1) == Matrix([c/2, d/3])
55
+ assert C1.transform(C3) == Matrix([6*a, 6*b])
56
+ assert C3.transform(C1) == Matrix([e/6, f/6])
57
+ assert C3.transform(C2) == Matrix([e/3, f/2])
58
+
59
+ a, b, c, d, e, f = symbols('a, b, c, d, e, f')
60
+ rel = {('C1', 'C2'): [(a, b), (2*a, 3*b + 1)],
61
+ ('C3', 'C2'): [(e, f), (-e - 2, 2*f)]}
62
+ C1 = CoordSystem('C1', R2_origin, (a, b), rel)
63
+ C2 = CoordSystem('C2', R2_origin, (c, d), rel)
64
+ C3 = CoordSystem('C3', R2_origin, (e, f), rel)
65
+ a, b = C1.symbols
66
+ c, d = C2.symbols
67
+ e, f = C3.symbols
68
+ assert C2.transform(C1) == Matrix([c/2, (d - 1)/3])
69
+ assert C1.transform(C3) == Matrix([-2*a - 2, (3*b + 1)/2])
70
+ assert C3.transform(C1) == Matrix([-e/2 - 1, (2*f - 1)/3])
71
+ assert C3.transform(C2) == Matrix([-e - 2, 2*f])
72
+
73
+ # old signature uses Lambda
74
+ a, b, c, d, e, f = symbols('a, b, c, d, e, f')
75
+ rel = {('C1', 'C2'): Lambda((a, b), (2*a, 3*b + 1)),
76
+ ('C3', 'C2'): Lambda((e, f), (-e - 2, 2*f))}
77
+ C1 = CoordSystem('C1', R2_origin, (a, b), rel)
78
+ C2 = CoordSystem('C2', R2_origin, (c, d), rel)
79
+ C3 = CoordSystem('C3', R2_origin, (e, f), rel)
80
+ a, b = C1.symbols
81
+ c, d = C2.symbols
82
+ e, f = C3.symbols
83
+ assert C2.transform(C1) == Matrix([c/2, (d - 1)/3])
84
+ assert C1.transform(C3) == Matrix([-2*a - 2, (3*b + 1)/2])
85
+ assert C3.transform(C1) == Matrix([-e/2 - 1, (2*f - 1)/3])
86
+ assert C3.transform(C2) == Matrix([-e - 2, 2*f])
87
+
88
+
89
+ def test_R2():
90
+ x0, y0, r0, theta0 = symbols('x0, y0, r0, theta0', real=True)
91
+ point_r = R2_r.point([x0, y0])
92
+ point_p = R2_p.point([r0, theta0])
93
+
94
+ # r**2 = x**2 + y**2
95
+ assert (R2.r**2 - R2.x**2 - R2.y**2).rcall(point_r) == 0
96
+ assert trigsimp( (R2.r**2 - R2.x**2 - R2.y**2).rcall(point_p) ) == 0
97
+ assert trigsimp(R2.e_r(R2.x**2 + R2.y**2).rcall(point_p).doit()) == 2*r0
98
+
99
+ # polar->rect->polar == Id
100
+ a, b = symbols('a b', positive=True)
101
+ m = Matrix([[a], [b]])
102
+
103
+ #TODO assert m == R2_r.transform(R2_p, R2_p.transform(R2_r, [a, b])).applyfunc(simplify)
104
+ assert m == R2_p.transform(R2_r, R2_r.transform(R2_p, m)).applyfunc(simplify)
105
+
106
+ # deprecated method
107
+ with warns_deprecated_sympy():
108
+ assert m == R2_p.coord_tuple_transform_to(
109
+ R2_r, R2_r.coord_tuple_transform_to(R2_p, m)).applyfunc(simplify)
110
+
111
+
112
+ def test_R3():
113
+ a, b, c = symbols('a b c', positive=True)
114
+ m = Matrix([[a], [b], [c]])
115
+
116
+ assert m == R3_c.transform(R3_r, R3_r.transform(R3_c, m)).applyfunc(simplify)
117
+ #TODO assert m == R3_r.transform(R3_c, R3_c.transform(R3_r, m)).applyfunc(simplify)
118
+ assert m == R3_s.transform(
119
+ R3_r, R3_r.transform(R3_s, m)).applyfunc(simplify)
120
+ #TODO assert m == R3_r.transform(R3_s, R3_s.transform(R3_r, m)).applyfunc(simplify)
121
+ assert m == R3_s.transform(
122
+ R3_c, R3_c.transform(R3_s, m)).applyfunc(simplify)
123
+ #TODO assert m == R3_c.transform(R3_s, R3_s.transform(R3_c, m)).applyfunc(simplify)
124
+
125
+ with warns_deprecated_sympy():
126
+ assert m == R3_c.coord_tuple_transform_to(
127
+ R3_r, R3_r.coord_tuple_transform_to(R3_c, m)).applyfunc(simplify)
128
+ #TODO assert m == R3_r.coord_tuple_transform_to(R3_c, R3_c.coord_tuple_transform_to(R3_r, m)).applyfunc(simplify)
129
+ assert m == R3_s.coord_tuple_transform_to(
130
+ R3_r, R3_r.coord_tuple_transform_to(R3_s, m)).applyfunc(simplify)
131
+ #TODO assert m == R3_r.coord_tuple_transform_to(R3_s, R3_s.coord_tuple_transform_to(R3_r, m)).applyfunc(simplify)
132
+ assert m == R3_s.coord_tuple_transform_to(
133
+ R3_c, R3_c.coord_tuple_transform_to(R3_s, m)).applyfunc(simplify)
134
+ #TODO assert m == R3_c.coord_tuple_transform_to(R3_s, R3_s.coord_tuple_transform_to(R3_c, m)).applyfunc(simplify)
135
+
136
+
137
+ def test_CoordinateSymbol():
138
+ x, y = R2_r.symbols
139
+ r, theta = R2_p.symbols
140
+ assert y.rewrite(R2_p) == r*sin(theta)
141
+
142
+
143
+ def test_point():
144
+ x, y = symbols('x, y')
145
+ p = R2_r.point([x, y])
146
+ assert p.free_symbols == {x, y}
147
+ assert p.coords(R2_r) == p.coords() == Matrix([x, y])
148
+ assert p.coords(R2_p) == Matrix([sqrt(x**2 + y**2), atan2(y, x)])
149
+
150
+
151
+ def test_commutator():
152
+ assert Commutator(R2.e_x, R2.e_y) == 0
153
+ assert Commutator(R2.x*R2.e_x, R2.x*R2.e_x) == 0
154
+ assert Commutator(R2.x*R2.e_x, R2.x*R2.e_y) == R2.x*R2.e_y
155
+ c = Commutator(R2.e_x, R2.e_r)
156
+ assert c(R2.x) == R2.y*(R2.x**2 + R2.y**2)**(-1)*sin(R2.theta)
157
+
158
+
159
+ def test_differential():
160
+ xdy = R2.x*R2.dy
161
+ dxdy = Differential(xdy)
162
+ assert xdy.rcall(None) == xdy
163
+ assert dxdy(R2.e_x, R2.e_y) == 1
164
+ assert dxdy(R2.e_x, R2.x*R2.e_y) == R2.x
165
+ assert Differential(dxdy) == 0
166
+
167
+
168
+ def test_products():
169
+ assert TensorProduct(
170
+ R2.dx, R2.dy)(R2.e_x, R2.e_y) == R2.dx(R2.e_x)*R2.dy(R2.e_y) == 1
171
+ assert TensorProduct(R2.dx, R2.dy)(None, R2.e_y) == R2.dx
172
+ assert TensorProduct(R2.dx, R2.dy)(R2.e_x, None) == R2.dy
173
+ assert TensorProduct(R2.dx, R2.dy)(R2.e_x) == R2.dy
174
+ assert TensorProduct(R2.x, R2.dx) == R2.x*R2.dx
175
+ assert TensorProduct(
176
+ R2.e_x, R2.e_y)(R2.x, R2.y) == R2.e_x(R2.x) * R2.e_y(R2.y) == 1
177
+ assert TensorProduct(R2.e_x, R2.e_y)(None, R2.y) == R2.e_x
178
+ assert TensorProduct(R2.e_x, R2.e_y)(R2.x, None) == R2.e_y
179
+ assert TensorProduct(R2.e_x, R2.e_y)(R2.x) == R2.e_y
180
+ assert TensorProduct(R2.x, R2.e_x) == R2.x * R2.e_x
181
+ assert TensorProduct(
182
+ R2.dx, R2.e_y)(R2.e_x, R2.y) == R2.dx(R2.e_x) * R2.e_y(R2.y) == 1
183
+ assert TensorProduct(R2.dx, R2.e_y)(None, R2.y) == R2.dx
184
+ assert TensorProduct(R2.dx, R2.e_y)(R2.e_x, None) == R2.e_y
185
+ assert TensorProduct(R2.dx, R2.e_y)(R2.e_x) == R2.e_y
186
+ assert TensorProduct(R2.x, R2.e_x) == R2.x * R2.e_x
187
+ assert TensorProduct(
188
+ R2.e_x, R2.dy)(R2.x, R2.e_y) == R2.e_x(R2.x) * R2.dy(R2.e_y) == 1
189
+ assert TensorProduct(R2.e_x, R2.dy)(None, R2.e_y) == R2.e_x
190
+ assert TensorProduct(R2.e_x, R2.dy)(R2.x, None) == R2.dy
191
+ assert TensorProduct(R2.e_x, R2.dy)(R2.x) == R2.dy
192
+ assert TensorProduct(R2.e_y,R2.e_x)(R2.x**2 + R2.y**2,R2.x**2 + R2.y**2) == 4*R2.x*R2.y
193
+
194
+ assert WedgeProduct(R2.dx, R2.dy)(R2.e_x, R2.e_y) == 1
195
+ assert WedgeProduct(R2.e_x, R2.e_y)(R2.x, R2.y) == 1
196
+
197
+
198
+ def test_lie_derivative():
199
+ assert LieDerivative(R2.e_x, R2.y) == R2.e_x(R2.y) == 0
200
+ assert LieDerivative(R2.e_x, R2.x) == R2.e_x(R2.x) == 1
201
+ assert LieDerivative(R2.e_x, R2.e_x) == Commutator(R2.e_x, R2.e_x) == 0
202
+ assert LieDerivative(R2.e_x, R2.e_r) == Commutator(R2.e_x, R2.e_r)
203
+ assert LieDerivative(R2.e_x + R2.e_y, R2.x) == 1
204
+ assert LieDerivative(
205
+ R2.e_x, TensorProduct(R2.dx, R2.dy))(R2.e_x, R2.e_y) == 0
206
+
207
+
208
+ @nocache_fail
209
+ def test_covar_deriv():
210
+ ch = metric_to_Christoffel_2nd(TP(R2.dx, R2.dx) + TP(R2.dy, R2.dy))
211
+ cvd = BaseCovarDerivativeOp(R2_r, 0, ch)
212
+ assert cvd(R2.x) == 1
213
+ # This line fails if the cache is disabled:
214
+ assert cvd(R2.x*R2.e_x) == R2.e_x
215
+ cvd = CovarDerivativeOp(R2.x*R2.e_x, ch)
216
+ assert cvd(R2.x) == R2.x
217
+ assert cvd(R2.x*R2.e_x) == R2.x*R2.e_x
218
+
219
+
220
+ def test_intcurve_diffequ():
221
+ t = symbols('t')
222
+ start_point = R2_r.point([1, 0])
223
+ vector_field = -R2.y*R2.e_x + R2.x*R2.e_y
224
+ equations, init_cond = intcurve_diffequ(vector_field, t, start_point)
225
+ assert str(equations) == '[f_1(t) + Derivative(f_0(t), t), -f_0(t) + Derivative(f_1(t), t)]'
226
+ assert str(init_cond) == '[f_0(0) - 1, f_1(0)]'
227
+ equations, init_cond = intcurve_diffequ(vector_field, t, start_point, R2_p)
228
+ assert str(
229
+ equations) == '[Derivative(f_0(t), t), Derivative(f_1(t), t) - 1]'
230
+ assert str(init_cond) == '[f_0(0) - 1, f_1(0)]'
231
+
232
+
233
+ def test_helpers_and_coordinate_dependent():
234
+ one_form = R2.dr + R2.dx
235
+ two_form = Differential(R2.x*R2.dr + R2.r*R2.dx)
236
+ three_form = Differential(
237
+ R2.y*two_form) + Differential(R2.x*Differential(R2.r*R2.dr))
238
+ metric = TensorProduct(R2.dx, R2.dx) + TensorProduct(R2.dy, R2.dy)
239
+ metric_ambig = TensorProduct(R2.dx, R2.dx) + TensorProduct(R2.dr, R2.dr)
240
+ misform_a = TensorProduct(R2.dr, R2.dr) + R2.dr
241
+ misform_b = R2.dr**4
242
+ misform_c = R2.dx*R2.dy
243
+ twoform_not_sym = TensorProduct(R2.dx, R2.dx) + TensorProduct(R2.dx, R2.dy)
244
+ twoform_not_TP = WedgeProduct(R2.dx, R2.dy)
245
+
246
+ one_vector = R2.e_x + R2.e_y
247
+ two_vector = TensorProduct(R2.e_x, R2.e_y)
248
+ three_vector = TensorProduct(R2.e_x, R2.e_y, R2.e_x)
249
+ two_wp = WedgeProduct(R2.e_x,R2.e_y)
250
+
251
+ assert covariant_order(one_form) == 1
252
+ assert covariant_order(two_form) == 2
253
+ assert covariant_order(three_form) == 3
254
+ assert covariant_order(two_form + metric) == 2
255
+ assert covariant_order(two_form + metric_ambig) == 2
256
+ assert covariant_order(two_form + twoform_not_sym) == 2
257
+ assert covariant_order(two_form + twoform_not_TP) == 2
258
+
259
+ assert contravariant_order(one_vector) == 1
260
+ assert contravariant_order(two_vector) == 2
261
+ assert contravariant_order(three_vector) == 3
262
+ assert contravariant_order(two_vector + two_wp) == 2
263
+
264
+ raises(ValueError, lambda: covariant_order(misform_a))
265
+ raises(ValueError, lambda: covariant_order(misform_b))
266
+ raises(ValueError, lambda: covariant_order(misform_c))
267
+
268
+ assert twoform_to_matrix(metric) == Matrix([[1, 0], [0, 1]])
269
+ assert twoform_to_matrix(twoform_not_sym) == Matrix([[1, 0], [1, 0]])
270
+ assert twoform_to_matrix(twoform_not_TP) == Matrix([[0, -1], [1, 0]])
271
+
272
+ raises(ValueError, lambda: twoform_to_matrix(one_form))
273
+ raises(ValueError, lambda: twoform_to_matrix(three_form))
274
+ raises(ValueError, lambda: twoform_to_matrix(metric_ambig))
275
+
276
+ raises(ValueError, lambda: metric_to_Christoffel_1st(twoform_not_sym))
277
+ raises(ValueError, lambda: metric_to_Christoffel_2nd(twoform_not_sym))
278
+ raises(ValueError, lambda: metric_to_Riemann_components(twoform_not_sym))
279
+ raises(ValueError, lambda: metric_to_Ricci_components(twoform_not_sym))
280
+
281
+
282
+ def test_correct_arguments():
283
+ raises(ValueError, lambda: R2.e_x(R2.e_x))
284
+ raises(ValueError, lambda: R2.e_x(R2.dx))
285
+
286
+ raises(ValueError, lambda: Commutator(R2.e_x, R2.x))
287
+ raises(ValueError, lambda: Commutator(R2.dx, R2.e_x))
288
+
289
+ raises(ValueError, lambda: Differential(Differential(R2.e_x)))
290
+
291
+ raises(ValueError, lambda: R2.dx(R2.x))
292
+
293
+ raises(ValueError, lambda: LieDerivative(R2.dx, R2.dx))
294
+ raises(ValueError, lambda: LieDerivative(R2.x, R2.dx))
295
+
296
+ raises(ValueError, lambda: CovarDerivativeOp(R2.dx, []))
297
+ raises(ValueError, lambda: CovarDerivativeOp(R2.x, []))
298
+
299
+ a = Symbol('a')
300
+ raises(ValueError, lambda: intcurve_series(R2.dx, a, R2_r.point([1, 2])))
301
+ raises(ValueError, lambda: intcurve_series(R2.x, a, R2_r.point([1, 2])))
302
+
303
+ raises(ValueError, lambda: intcurve_diffequ(R2.dx, a, R2_r.point([1, 2])))
304
+ raises(ValueError, lambda: intcurve_diffequ(R2.x, a, R2_r.point([1, 2])))
305
+
306
+ raises(ValueError, lambda: contravariant_order(R2.e_x + R2.dx))
307
+ raises(ValueError, lambda: covariant_order(R2.e_x + R2.dx))
308
+
309
+ raises(ValueError, lambda: contravariant_order(R2.e_x*R2.e_y))
310
+ raises(ValueError, lambda: covariant_order(R2.dx*R2.dy))
311
+
312
+ def test_simplify():
313
+ x, y = R2_r.coord_functions()
314
+ dx, dy = R2_r.base_oneforms()
315
+ ex, ey = R2_r.base_vectors()
316
+ assert simplify(x) == x
317
+ assert simplify(x*y) == x*y
318
+ assert simplify(dx*dy) == dx*dy
319
+ assert simplify(ex*ey) == ex*ey
320
+ assert ((1-x)*dx)/(1-x)**2 == dx/(1-x)
321
+
322
+
323
+ def test_issue_17917():
324
+ X = R2.x*R2.e_x - R2.y*R2.e_y
325
+ Y = (R2.x**2 + R2.y**2)*R2.e_x - R2.x*R2.y*R2.e_y
326
+ assert LieDerivative(X, Y).expand() == (
327
+ R2.x**2*R2.e_x - 3*R2.y**2*R2.e_x - R2.x*R2.y*R2.e_y)
328
+
329
+ def test_deprecations():
330
+ m = Manifold('M', 2)
331
+ p = Patch('P', m)
332
+ with warns_deprecated_sympy():
333
+ CoordSystem('Car2d', p, names=['x', 'y'])
334
+
335
+ with warns_deprecated_sympy():
336
+ c = CoordSystem('Car2d', p, ['x', 'y'])
337
+
338
+ with warns_deprecated_sympy():
339
+ list(m.patches)
340
+
341
+ with warns_deprecated_sympy():
342
+ list(c.transforms)
.venv/lib/python3.13/site-packages/sympy/diffgeom/tests/test_function_diffgeom_book.py ADDED
@@ -0,0 +1,145 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sympy.diffgeom.rn import R2, R2_p, R2_r, R3_r
2
+ from sympy.diffgeom import intcurve_series, Differential, WedgeProduct
3
+ from sympy.core import symbols, Function, Derivative
4
+ from sympy.simplify import trigsimp, simplify
5
+ from sympy.functions import sqrt, atan2, sin, cos
6
+ from sympy.matrices import Matrix
7
+
8
+ # Most of the functionality is covered in the
9
+ # test_functional_diffgeom_ch* tests which are based on the
10
+ # example from the paper of Sussman and Wisdom.
11
+ # If they do not cover something, additional tests are added in other test
12
+ # functions.
13
+
14
+ # From "Functional Differential Geometry" as of 2011
15
+ # by Sussman and Wisdom.
16
+
17
+
18
+ def test_functional_diffgeom_ch2():
19
+ x0, y0, r0, theta0 = symbols('x0, y0, r0, theta0', real=True)
20
+ x, y = symbols('x, y', real=True)
21
+ f = Function('f')
22
+
23
+ assert (R2_p.point_to_coords(R2_r.point([x0, y0])) ==
24
+ Matrix([sqrt(x0**2 + y0**2), atan2(y0, x0)]))
25
+ assert (R2_r.point_to_coords(R2_p.point([r0, theta0])) ==
26
+ Matrix([r0*cos(theta0), r0*sin(theta0)]))
27
+
28
+ assert R2_p.jacobian(R2_r, [r0, theta0]) == Matrix(
29
+ [[cos(theta0), -r0*sin(theta0)], [sin(theta0), r0*cos(theta0)]])
30
+
31
+ field = f(R2.x, R2.y)
32
+ p1_in_rect = R2_r.point([x0, y0])
33
+ p1_in_polar = R2_p.point([sqrt(x0**2 + y0**2), atan2(y0, x0)])
34
+ assert field.rcall(p1_in_rect) == f(x0, y0)
35
+ assert field.rcall(p1_in_polar) == f(x0, y0)
36
+
37
+ p_r = R2_r.point([x0, y0])
38
+ p_p = R2_p.point([r0, theta0])
39
+ assert R2.x(p_r) == x0
40
+ assert R2.x(p_p) == r0*cos(theta0)
41
+ assert R2.r(p_p) == r0
42
+ assert R2.r(p_r) == sqrt(x0**2 + y0**2)
43
+ assert R2.theta(p_r) == atan2(y0, x0)
44
+
45
+ h = R2.x*R2.r**2 + R2.y**3
46
+ assert h.rcall(p_r) == x0*(x0**2 + y0**2) + y0**3
47
+ assert h.rcall(p_p) == r0**3*sin(theta0)**3 + r0**3*cos(theta0)
48
+
49
+
50
+ def test_functional_diffgeom_ch3():
51
+ x0, y0 = symbols('x0, y0', real=True)
52
+ x, y, t = symbols('x, y, t', real=True)
53
+ f = Function('f')
54
+ b1 = Function('b1')
55
+ b2 = Function('b2')
56
+ p_r = R2_r.point([x0, y0])
57
+
58
+ s_field = f(R2.x, R2.y)
59
+ v_field = b1(R2.x)*R2.e_x + b2(R2.y)*R2.e_y
60
+ assert v_field.rcall(s_field).rcall(p_r).doit() == b1(
61
+ x0)*Derivative(f(x0, y0), x0) + b2(y0)*Derivative(f(x0, y0), y0)
62
+
63
+ assert R2.e_x(R2.r**2).rcall(p_r) == 2*x0
64
+ v = R2.e_x + 2*R2.e_y
65
+ s = R2.r**2 + 3*R2.x
66
+ assert v.rcall(s).rcall(p_r).doit() == 2*x0 + 4*y0 + 3
67
+
68
+ circ = -R2.y*R2.e_x + R2.x*R2.e_y
69
+ series = intcurve_series(circ, t, R2_r.point([1, 0]), coeffs=True)
70
+ series_x, series_y = zip(*series)
71
+ assert all(
72
+ term == cos(t).taylor_term(i, t) for i, term in enumerate(series_x))
73
+ assert all(
74
+ term == sin(t).taylor_term(i, t) for i, term in enumerate(series_y))
75
+
76
+
77
+ def test_functional_diffgeom_ch4():
78
+ x0, y0, theta0 = symbols('x0, y0, theta0', real=True)
79
+ x, y, r, theta = symbols('x, y, r, theta', real=True)
80
+ r0 = symbols('r0', positive=True)
81
+ f = Function('f')
82
+ b1 = Function('b1')
83
+ b2 = Function('b2')
84
+ p_r = R2_r.point([x0, y0])
85
+ p_p = R2_p.point([r0, theta0])
86
+
87
+ f_field = b1(R2.x, R2.y)*R2.dx + b2(R2.x, R2.y)*R2.dy
88
+ assert f_field.rcall(R2.e_x).rcall(p_r) == b1(x0, y0)
89
+ assert f_field.rcall(R2.e_y).rcall(p_r) == b2(x0, y0)
90
+
91
+ s_field_r = f(R2.x, R2.y)
92
+ df = Differential(s_field_r)
93
+ assert df(R2.e_x).rcall(p_r).doit() == Derivative(f(x0, y0), x0)
94
+ assert df(R2.e_y).rcall(p_r).doit() == Derivative(f(x0, y0), y0)
95
+
96
+ s_field_p = f(R2.r, R2.theta)
97
+ df = Differential(s_field_p)
98
+ assert trigsimp(df(R2.e_x).rcall(p_p).doit()) == (
99
+ cos(theta0)*Derivative(f(r0, theta0), r0) -
100
+ sin(theta0)*Derivative(f(r0, theta0), theta0)/r0)
101
+ assert trigsimp(df(R2.e_y).rcall(p_p).doit()) == (
102
+ sin(theta0)*Derivative(f(r0, theta0), r0) +
103
+ cos(theta0)*Derivative(f(r0, theta0), theta0)/r0)
104
+
105
+ assert R2.dx(R2.e_x).rcall(p_r) == 1
106
+ assert R2.dx(R2.e_x) == 1
107
+ assert R2.dx(R2.e_y).rcall(p_r) == 0
108
+ assert R2.dx(R2.e_y) == 0
109
+
110
+ circ = -R2.y*R2.e_x + R2.x*R2.e_y
111
+ assert R2.dx(circ).rcall(p_r).doit() == -y0
112
+ assert R2.dy(circ).rcall(p_r) == x0
113
+ assert R2.dr(circ).rcall(p_r) == 0
114
+ assert simplify(R2.dtheta(circ).rcall(p_r)) == 1
115
+
116
+ assert (circ - R2.e_theta).rcall(s_field_r).rcall(p_r) == 0
117
+
118
+
119
+ def test_functional_diffgeom_ch6():
120
+ u0, u1, u2, v0, v1, v2, w0, w1, w2 = symbols('u0:3, v0:3, w0:3', real=True)
121
+
122
+ u = u0*R2.e_x + u1*R2.e_y
123
+ v = v0*R2.e_x + v1*R2.e_y
124
+ wp = WedgeProduct(R2.dx, R2.dy)
125
+ assert wp(u, v) == u0*v1 - u1*v0
126
+
127
+ u = u0*R3_r.e_x + u1*R3_r.e_y + u2*R3_r.e_z
128
+ v = v0*R3_r.e_x + v1*R3_r.e_y + v2*R3_r.e_z
129
+ w = w0*R3_r.e_x + w1*R3_r.e_y + w2*R3_r.e_z
130
+ wp = WedgeProduct(R3_r.dx, R3_r.dy, R3_r.dz)
131
+ assert wp(
132
+ u, v, w) == Matrix(3, 3, [u0, u1, u2, v0, v1, v2, w0, w1, w2]).det()
133
+
134
+ a, b, c = symbols('a, b, c', cls=Function)
135
+ a_f = a(R3_r.x, R3_r.y, R3_r.z)
136
+ b_f = b(R3_r.x, R3_r.y, R3_r.z)
137
+ c_f = c(R3_r.x, R3_r.y, R3_r.z)
138
+ theta = a_f*R3_r.dx + b_f*R3_r.dy + c_f*R3_r.dz
139
+ dtheta = Differential(theta)
140
+ da = Differential(a_f)
141
+ db = Differential(b_f)
142
+ dc = Differential(c_f)
143
+ expr = dtheta - WedgeProduct(
144
+ da, R3_r.dx) - WedgeProduct(db, R3_r.dy) - WedgeProduct(dc, R3_r.dz)
145
+ assert expr.rcall(R3_r.e_x, R3_r.e_y) == 0
.venv/lib/python3.13/site-packages/sympy/diffgeom/tests/test_hyperbolic_space.py ADDED
@@ -0,0 +1,91 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ r'''
2
+ unit test describing the hyperbolic half-plane with the Poincare metric. This
3
+ is a basic model of hyperbolic geometry on the (positive) half-space
4
+
5
+ {(x,y) \in R^2 | y > 0}
6
+
7
+ with the Riemannian metric
8
+
9
+ ds^2 = (dx^2 + dy^2)/y^2
10
+
11
+ It has constant negative scalar curvature = -2
12
+
13
+ https://en.wikipedia.org/wiki/Poincare_half-plane_model
14
+ '''
15
+ from sympy.matrices.dense import diag
16
+ from sympy.diffgeom import (twoform_to_matrix,
17
+ metric_to_Christoffel_1st, metric_to_Christoffel_2nd,
18
+ metric_to_Riemann_components, metric_to_Ricci_components)
19
+ import sympy.diffgeom.rn
20
+ from sympy.tensor.array import ImmutableDenseNDimArray
21
+
22
+
23
+ def test_H2():
24
+ TP = sympy.diffgeom.TensorProduct
25
+ R2 = sympy.diffgeom.rn.R2
26
+ y = R2.y
27
+ dy = R2.dy
28
+ dx = R2.dx
29
+ g = (TP(dx, dx) + TP(dy, dy))*y**(-2)
30
+ automat = twoform_to_matrix(g)
31
+ mat = diag(y**(-2), y**(-2))
32
+ assert mat == automat
33
+
34
+ gamma1 = metric_to_Christoffel_1st(g)
35
+ assert gamma1[0, 0, 0] == 0
36
+ assert gamma1[0, 0, 1] == -y**(-3)
37
+ assert gamma1[0, 1, 0] == -y**(-3)
38
+ assert gamma1[0, 1, 1] == 0
39
+
40
+ assert gamma1[1, 1, 1] == -y**(-3)
41
+ assert gamma1[1, 1, 0] == 0
42
+ assert gamma1[1, 0, 1] == 0
43
+ assert gamma1[1, 0, 0] == y**(-3)
44
+
45
+ gamma2 = metric_to_Christoffel_2nd(g)
46
+ assert gamma2[0, 0, 0] == 0
47
+ assert gamma2[0, 0, 1] == -y**(-1)
48
+ assert gamma2[0, 1, 0] == -y**(-1)
49
+ assert gamma2[0, 1, 1] == 0
50
+
51
+ assert gamma2[1, 1, 1] == -y**(-1)
52
+ assert gamma2[1, 1, 0] == 0
53
+ assert gamma2[1, 0, 1] == 0
54
+ assert gamma2[1, 0, 0] == y**(-1)
55
+
56
+ Rm = metric_to_Riemann_components(g)
57
+ assert Rm[0, 0, 0, 0] == 0
58
+ assert Rm[0, 0, 0, 1] == 0
59
+ assert Rm[0, 0, 1, 0] == 0
60
+ assert Rm[0, 0, 1, 1] == 0
61
+
62
+ assert Rm[0, 1, 0, 0] == 0
63
+ assert Rm[0, 1, 0, 1] == -y**(-2)
64
+ assert Rm[0, 1, 1, 0] == y**(-2)
65
+ assert Rm[0, 1, 1, 1] == 0
66
+
67
+ assert Rm[1, 0, 0, 0] == 0
68
+ assert Rm[1, 0, 0, 1] == y**(-2)
69
+ assert Rm[1, 0, 1, 0] == -y**(-2)
70
+ assert Rm[1, 0, 1, 1] == 0
71
+
72
+ assert Rm[1, 1, 0, 0] == 0
73
+ assert Rm[1, 1, 0, 1] == 0
74
+ assert Rm[1, 1, 1, 0] == 0
75
+ assert Rm[1, 1, 1, 1] == 0
76
+
77
+ Ric = metric_to_Ricci_components(g)
78
+ assert Ric[0, 0] == -y**(-2)
79
+ assert Ric[0, 1] == 0
80
+ assert Ric[1, 0] == 0
81
+ assert Ric[0, 0] == -y**(-2)
82
+
83
+ assert Ric == ImmutableDenseNDimArray([-y**(-2), 0, 0, -y**(-2)], (2, 2))
84
+
85
+ ## scalar curvature is -2
86
+ #TODO - it would be nice to have index contraction built-in
87
+ R = (Ric[0, 0] + Ric[1, 1])*y**2
88
+ assert R == -2
89
+
90
+ ## Gauss curvature is -1
91
+ assert R/2 == -1
.venv/lib/python3.13/site-packages/sympy/external/__init__.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Unified place for determining if external dependencies are installed or not.
3
+
4
+ You should import all external modules using the import_module() function.
5
+
6
+ For example
7
+
8
+ >>> from sympy.external import import_module
9
+ >>> numpy = import_module('numpy')
10
+
11
+ If the resulting library is not installed, or if the installed version
12
+ is less than a given minimum version, the function will return None.
13
+ Otherwise, it will return the library. See the docstring of
14
+ import_module() for more information.
15
+
16
+ """
17
+
18
+ from sympy.external.importtools import import_module
19
+
20
+ __all__ = ['import_module']
.venv/lib/python3.13/site-packages/sympy/external/gmpy.py ADDED
@@ -0,0 +1,342 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+ import os
3
+ from ctypes import c_long, sizeof
4
+ from functools import reduce
5
+ from typing import Type
6
+ from warnings import warn
7
+
8
+ from sympy.external import import_module
9
+
10
+ from .pythonmpq import PythonMPQ
11
+
12
+ from .ntheory import (
13
+ bit_scan1 as python_bit_scan1,
14
+ bit_scan0 as python_bit_scan0,
15
+ remove as python_remove,
16
+ factorial as python_factorial,
17
+ sqrt as python_sqrt,
18
+ sqrtrem as python_sqrtrem,
19
+ gcd as python_gcd,
20
+ lcm as python_lcm,
21
+ gcdext as python_gcdext,
22
+ is_square as python_is_square,
23
+ invert as python_invert,
24
+ legendre as python_legendre,
25
+ jacobi as python_jacobi,
26
+ kronecker as python_kronecker,
27
+ iroot as python_iroot,
28
+ is_fermat_prp as python_is_fermat_prp,
29
+ is_euler_prp as python_is_euler_prp,
30
+ is_strong_prp as python_is_strong_prp,
31
+ is_fibonacci_prp as python_is_fibonacci_prp,
32
+ is_lucas_prp as python_is_lucas_prp,
33
+ is_selfridge_prp as python_is_selfridge_prp,
34
+ is_strong_lucas_prp as python_is_strong_lucas_prp,
35
+ is_strong_selfridge_prp as python_is_strong_selfridge_prp,
36
+ is_bpsw_prp as python_is_bpsw_prp,
37
+ is_strong_bpsw_prp as python_is_strong_bpsw_prp,
38
+ )
39
+
40
+
41
+ __all__ = [
42
+ # GROUND_TYPES is either 'gmpy' or 'python' depending on which is used. If
43
+ # gmpy is installed then it will be used unless the environment variable
44
+ # SYMPY_GROUND_TYPES is set to something other than 'auto', 'gmpy', or
45
+ # 'gmpy2'.
46
+ 'GROUND_TYPES',
47
+
48
+ # If HAS_GMPY is 0, no supported version of gmpy is available. Otherwise,
49
+ # HAS_GMPY will be 2 for gmpy2 if GROUND_TYPES is 'gmpy'. It used to be
50
+ # possible for HAS_GMPY to be 1 for gmpy but gmpy is no longer supported.
51
+ 'HAS_GMPY',
52
+
53
+ # SYMPY_INTS is a tuple containing the base types for valid integer types.
54
+ # This is either (int,) or (int, type(mpz(0))) depending on GROUND_TYPES.
55
+ 'SYMPY_INTS',
56
+
57
+ # MPQ is either gmpy.mpq or the Python equivalent from
58
+ # sympy.external.pythonmpq
59
+ 'MPQ',
60
+
61
+ # MPZ is either gmpy.mpz or int.
62
+ 'MPZ',
63
+
64
+ 'bit_scan1',
65
+ 'bit_scan0',
66
+ 'remove',
67
+ 'factorial',
68
+ 'sqrt',
69
+ 'is_square',
70
+ 'sqrtrem',
71
+ 'gcd',
72
+ 'lcm',
73
+ 'gcdext',
74
+ 'invert',
75
+ 'legendre',
76
+ 'jacobi',
77
+ 'kronecker',
78
+ 'iroot',
79
+ 'is_fermat_prp',
80
+ 'is_euler_prp',
81
+ 'is_strong_prp',
82
+ 'is_fibonacci_prp',
83
+ 'is_lucas_prp',
84
+ 'is_selfridge_prp',
85
+ 'is_strong_lucas_prp',
86
+ 'is_strong_selfridge_prp',
87
+ 'is_bpsw_prp',
88
+ 'is_strong_bpsw_prp',
89
+ ]
90
+
91
+
92
+ #
93
+ # Tested python-flint version. Future versions might work but we will only use
94
+ # them if explicitly requested by SYMPY_GROUND_TYPES=flint.
95
+ #
96
+ _PYTHON_FLINT_VERSION_NEEDED = ["0.6", "0.7", "0.8", "0.9", "0.10"]
97
+
98
+
99
+ def _flint_version_okay(flint_version):
100
+ major, minor = flint_version.split('.')[:2]
101
+ flint_ver = f'{major}.{minor}'
102
+ return flint_ver in _PYTHON_FLINT_VERSION_NEEDED
103
+
104
+ #
105
+ # We will only use gmpy2 >= 2.0.0
106
+ #
107
+ _GMPY2_MIN_VERSION = '2.0.0'
108
+
109
+
110
+ def _get_flint(sympy_ground_types):
111
+ if sympy_ground_types not in ('auto', 'flint'):
112
+ return None
113
+
114
+ try:
115
+ import flint
116
+ # Earlier versions of python-flint may not have __version__.
117
+ from flint import __version__ as _flint_version
118
+ except ImportError:
119
+ if sympy_ground_types == 'flint':
120
+ warn("SYMPY_GROUND_TYPES was set to flint but python-flint is not "
121
+ "installed. Falling back to other ground types.")
122
+ return None
123
+
124
+ if _flint_version_okay(_flint_version):
125
+ return flint
126
+ elif sympy_ground_types == 'auto':
127
+ return None
128
+ else:
129
+ warn(f"Using python-flint {_flint_version} because SYMPY_GROUND_TYPES "
130
+ f"is set to flint but this version of SymPy is only tested "
131
+ f"with python-flint versions {_PYTHON_FLINT_VERSION_NEEDED}.")
132
+ return flint
133
+
134
+
135
+ def _get_gmpy2(sympy_ground_types):
136
+ if sympy_ground_types not in ('auto', 'gmpy', 'gmpy2'):
137
+ return None
138
+
139
+ gmpy = import_module('gmpy2', min_module_version=_GMPY2_MIN_VERSION,
140
+ module_version_attr='version', module_version_attr_call_args=())
141
+
142
+ if sympy_ground_types != 'auto' and gmpy is None:
143
+ warn("gmpy2 library is not installed, switching to 'python' ground types")
144
+
145
+ return gmpy
146
+
147
+
148
+ #
149
+ # SYMPY_GROUND_TYPES can be flint, gmpy, gmpy2, python or auto (default)
150
+ #
151
+ _SYMPY_GROUND_TYPES = os.environ.get('SYMPY_GROUND_TYPES', 'auto').lower()
152
+ _flint = None
153
+ _gmpy = None
154
+
155
+ #
156
+ # First handle auto-detection of flint/gmpy2. We will prefer flint if available
157
+ # or otherwise gmpy2 if available and then lastly the python types.
158
+ #
159
+ if _SYMPY_GROUND_TYPES in ('auto', 'flint'):
160
+ _flint = _get_flint(_SYMPY_GROUND_TYPES)
161
+ if _flint is not None:
162
+ _SYMPY_GROUND_TYPES = 'flint'
163
+ else:
164
+ _SYMPY_GROUND_TYPES = 'auto'
165
+
166
+ if _SYMPY_GROUND_TYPES in ('auto', 'gmpy', 'gmpy2'):
167
+ _gmpy = _get_gmpy2(_SYMPY_GROUND_TYPES)
168
+ if _gmpy is not None:
169
+ _SYMPY_GROUND_TYPES = 'gmpy'
170
+ else:
171
+ _SYMPY_GROUND_TYPES = 'python'
172
+
173
+ if _SYMPY_GROUND_TYPES not in ('flint', 'gmpy', 'python'):
174
+ warn("SYMPY_GROUND_TYPES environment variable unrecognised. "
175
+ "Should be 'auto', 'flint', 'gmpy', 'gmpy2' or 'python'.")
176
+ _SYMPY_GROUND_TYPES = 'python'
177
+
178
+ #
179
+ # At this point _SYMPY_GROUND_TYPES is either flint, gmpy or python. The blocks
180
+ # below define the values exported by this module in each case.
181
+ #
182
+
183
+ #
184
+ # In gmpy2 and flint, there are functions that take a long (or unsigned long)
185
+ # argument. That is, it is not possible to input a value larger than that.
186
+ #
187
+ LONG_MAX = (1 << (8*sizeof(c_long) - 1)) - 1
188
+
189
+ #
190
+ # Type checkers are confused by what SYMPY_INTS is. There may be a better type
191
+ # hint for this like Type[Integral] or something.
192
+ #
193
+ SYMPY_INTS: tuple[Type, ...]
194
+
195
+ if _SYMPY_GROUND_TYPES == 'gmpy':
196
+
197
+ assert _gmpy is not None
198
+
199
+ flint = None
200
+ gmpy = _gmpy
201
+
202
+ HAS_GMPY = 2
203
+ GROUND_TYPES = 'gmpy'
204
+ SYMPY_INTS = (int, type(gmpy.mpz(0)))
205
+ MPZ = gmpy.mpz
206
+ MPQ = gmpy.mpq
207
+
208
+ bit_scan1 = gmpy.bit_scan1
209
+ bit_scan0 = gmpy.bit_scan0
210
+ remove = gmpy.remove
211
+ factorial = gmpy.fac
212
+ sqrt = gmpy.isqrt
213
+ is_square = gmpy.is_square
214
+ sqrtrem = gmpy.isqrt_rem
215
+ gcd = gmpy.gcd
216
+ lcm = gmpy.lcm
217
+ gcdext = gmpy.gcdext
218
+ invert = gmpy.invert
219
+ legendre = gmpy.legendre
220
+ jacobi = gmpy.jacobi
221
+ kronecker = gmpy.kronecker
222
+
223
+ def iroot(x, n):
224
+ # In the latest gmpy2, the threshold for n is ULONG_MAX,
225
+ # but adjust to the older one.
226
+ if n <= LONG_MAX:
227
+ return gmpy.iroot(x, n)
228
+ return python_iroot(x, n)
229
+
230
+ is_fermat_prp = gmpy.is_fermat_prp
231
+ is_euler_prp = gmpy.is_euler_prp
232
+ is_strong_prp = gmpy.is_strong_prp
233
+ is_fibonacci_prp = gmpy.is_fibonacci_prp
234
+ is_lucas_prp = gmpy.is_lucas_prp
235
+ is_selfridge_prp = gmpy.is_selfridge_prp
236
+ is_strong_lucas_prp = gmpy.is_strong_lucas_prp
237
+ is_strong_selfridge_prp = gmpy.is_strong_selfridge_prp
238
+ is_bpsw_prp = gmpy.is_bpsw_prp
239
+ is_strong_bpsw_prp = gmpy.is_strong_bpsw_prp
240
+
241
+ elif _SYMPY_GROUND_TYPES == 'flint':
242
+
243
+ assert _flint is not None
244
+
245
+ flint = _flint
246
+ gmpy = None
247
+
248
+ HAS_GMPY = 0
249
+ GROUND_TYPES = 'flint'
250
+ SYMPY_INTS = (int, flint.fmpz) # type: ignore
251
+ MPZ = flint.fmpz # type: ignore
252
+ MPQ = flint.fmpq # type: ignore
253
+
254
+ bit_scan1 = python_bit_scan1
255
+ bit_scan0 = python_bit_scan0
256
+ remove = python_remove
257
+ factorial = python_factorial
258
+
259
+ def sqrt(x):
260
+ return flint.fmpz(x).isqrt()
261
+
262
+ def is_square(x):
263
+ if x < 0:
264
+ return False
265
+ return flint.fmpz(x).sqrtrem()[1] == 0
266
+
267
+ def sqrtrem(x):
268
+ return flint.fmpz(x).sqrtrem()
269
+
270
+ def gcd(*args):
271
+ return reduce(flint.fmpz.gcd, args, flint.fmpz(0))
272
+
273
+ def lcm(*args):
274
+ return reduce(flint.fmpz.lcm, args, flint.fmpz(1))
275
+
276
+ gcdext = python_gcdext
277
+ invert = python_invert
278
+ legendre = python_legendre
279
+
280
+ def jacobi(x, y):
281
+ if y <= 0 or not y % 2:
282
+ raise ValueError("y should be an odd positive integer")
283
+ return flint.fmpz(x).jacobi(y)
284
+
285
+ kronecker = python_kronecker
286
+
287
+ def iroot(x, n):
288
+ if n <= LONG_MAX:
289
+ y = flint.fmpz(x).root(n)
290
+ return y, y**n == x
291
+ return python_iroot(x, n)
292
+
293
+ is_fermat_prp = python_is_fermat_prp
294
+ is_euler_prp = python_is_euler_prp
295
+ is_strong_prp = python_is_strong_prp
296
+ is_fibonacci_prp = python_is_fibonacci_prp
297
+ is_lucas_prp = python_is_lucas_prp
298
+ is_selfridge_prp = python_is_selfridge_prp
299
+ is_strong_lucas_prp = python_is_strong_lucas_prp
300
+ is_strong_selfridge_prp = python_is_strong_selfridge_prp
301
+ is_bpsw_prp = python_is_bpsw_prp
302
+ is_strong_bpsw_prp = python_is_strong_bpsw_prp
303
+
304
+ elif _SYMPY_GROUND_TYPES == 'python':
305
+
306
+ flint = None
307
+ gmpy = None
308
+
309
+ HAS_GMPY = 0
310
+ GROUND_TYPES = 'python'
311
+ SYMPY_INTS = (int,)
312
+ MPZ = int
313
+ MPQ = PythonMPQ
314
+
315
+ bit_scan1 = python_bit_scan1
316
+ bit_scan0 = python_bit_scan0
317
+ remove = python_remove
318
+ factorial = python_factorial
319
+ sqrt = python_sqrt
320
+ is_square = python_is_square
321
+ sqrtrem = python_sqrtrem
322
+ gcd = python_gcd
323
+ lcm = python_lcm
324
+ gcdext = python_gcdext
325
+ invert = python_invert
326
+ legendre = python_legendre
327
+ jacobi = python_jacobi
328
+ kronecker = python_kronecker
329
+ iroot = python_iroot
330
+ is_fermat_prp = python_is_fermat_prp
331
+ is_euler_prp = python_is_euler_prp
332
+ is_strong_prp = python_is_strong_prp
333
+ is_fibonacci_prp = python_is_fibonacci_prp
334
+ is_lucas_prp = python_is_lucas_prp
335
+ is_selfridge_prp = python_is_selfridge_prp
336
+ is_strong_lucas_prp = python_is_strong_lucas_prp
337
+ is_strong_selfridge_prp = python_is_strong_selfridge_prp
338
+ is_bpsw_prp = python_is_bpsw_prp
339
+ is_strong_bpsw_prp = python_is_strong_bpsw_prp
340
+
341
+ else:
342
+ assert False
.venv/lib/python3.13/site-packages/sympy/external/importtools.py ADDED
@@ -0,0 +1,187 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Tools to assist importing optional external modules."""
2
+
3
+ import sys
4
+ import re
5
+
6
+ # Override these in the module to change the default warning behavior.
7
+ # For example, you might set both to False before running the tests so that
8
+ # warnings are not printed to the console, or set both to True for debugging.
9
+
10
+ WARN_NOT_INSTALLED = None # Default is False
11
+ WARN_OLD_VERSION = None # Default is True
12
+
13
+
14
+ def __sympy_debug():
15
+ # helper function from sympy/__init__.py
16
+ # We don't just import SYMPY_DEBUG from that file because we don't want to
17
+ # import all of SymPy just to use this module.
18
+ import os
19
+ debug_str = os.getenv('SYMPY_DEBUG', 'False')
20
+ if debug_str in ('True', 'False'):
21
+ return eval(debug_str)
22
+ else:
23
+ raise RuntimeError("unrecognized value for SYMPY_DEBUG: %s" %
24
+ debug_str)
25
+
26
+ if __sympy_debug():
27
+ WARN_OLD_VERSION = True
28
+ WARN_NOT_INSTALLED = True
29
+
30
+
31
+ _component_re = re.compile(r'(\d+ | [a-z]+ | \.)', re.VERBOSE)
32
+
33
+ def version_tuple(vstring):
34
+ # Parse a version string to a tuple e.g. '1.2' -> (1, 2)
35
+ # Simplified from distutils.version.LooseVersion which was deprecated in
36
+ # Python 3.10.
37
+ components = []
38
+ for x in _component_re.split(vstring):
39
+ if x and x != '.':
40
+ try:
41
+ x = int(x)
42
+ except ValueError:
43
+ pass
44
+ components.append(x)
45
+ return tuple(components)
46
+
47
+
48
+ def import_module(module, min_module_version=None, min_python_version=None,
49
+ warn_not_installed=None, warn_old_version=None,
50
+ module_version_attr='__version__', module_version_attr_call_args=None,
51
+ import_kwargs={}, catch=()):
52
+ """
53
+ Import and return a module if it is installed.
54
+
55
+ If the module is not installed, it returns None.
56
+
57
+ A minimum version for the module can be given as the keyword argument
58
+ min_module_version. This should be comparable against the module version.
59
+ By default, module.__version__ is used to get the module version. To
60
+ override this, set the module_version_attr keyword argument. If the
61
+ attribute of the module to get the version should be called (e.g.,
62
+ module.version()), then set module_version_attr_call_args to the args such
63
+ that module.module_version_attr(*module_version_attr_call_args) returns the
64
+ module's version.
65
+
66
+ If the module version is less than min_module_version using the Python <
67
+ comparison, None will be returned, even if the module is installed. You can
68
+ use this to keep from importing an incompatible older version of a module.
69
+
70
+ You can also specify a minimum Python version by using the
71
+ min_python_version keyword argument. This should be comparable against
72
+ sys.version_info.
73
+
74
+ If the keyword argument warn_not_installed is set to True, the function will
75
+ emit a UserWarning when the module is not installed.
76
+
77
+ If the keyword argument warn_old_version is set to True, the function will
78
+ emit a UserWarning when the library is installed, but cannot be imported
79
+ because of the min_module_version or min_python_version options.
80
+
81
+ Note that because of the way warnings are handled, a warning will be
82
+ emitted for each module only once. You can change the default warning
83
+ behavior by overriding the values of WARN_NOT_INSTALLED and WARN_OLD_VERSION
84
+ in sympy.external.importtools. By default, WARN_NOT_INSTALLED is False and
85
+ WARN_OLD_VERSION is True.
86
+
87
+ This function uses __import__() to import the module. To pass additional
88
+ options to __import__(), use the import_kwargs keyword argument. For
89
+ example, to import a submodule A.B, you must pass a nonempty fromlist option
90
+ to __import__. See the docstring of __import__().
91
+
92
+ This catches ImportError to determine if the module is not installed. To
93
+ catch additional errors, pass them as a tuple to the catch keyword
94
+ argument.
95
+
96
+ Examples
97
+ ========
98
+
99
+ >>> from sympy.external import import_module
100
+
101
+ >>> numpy = import_module('numpy')
102
+
103
+ >>> numpy = import_module('numpy', min_python_version=(2, 7),
104
+ ... warn_old_version=False)
105
+
106
+ >>> numpy = import_module('numpy', min_module_version='1.5',
107
+ ... warn_old_version=False) # numpy.__version__ is a string
108
+
109
+ >>> # gmpy does not have __version__, but it does have gmpy.version()
110
+
111
+ >>> gmpy = import_module('gmpy', min_module_version='1.14',
112
+ ... module_version_attr='version', module_version_attr_call_args=(),
113
+ ... warn_old_version=False)
114
+
115
+ >>> # To import a submodule, you must pass a nonempty fromlist to
116
+ >>> # __import__(). The values do not matter.
117
+ >>> p3 = import_module('mpl_toolkits.mplot3d',
118
+ ... import_kwargs={'fromlist':['something']})
119
+
120
+ >>> # matplotlib.pyplot can raise RuntimeError when the display cannot be opened
121
+ >>> matplotlib = import_module('matplotlib',
122
+ ... import_kwargs={'fromlist':['pyplot']}, catch=(RuntimeError,))
123
+
124
+ """
125
+ # keyword argument overrides default, and global variable overrides
126
+ # keyword argument.
127
+ warn_old_version = (WARN_OLD_VERSION if WARN_OLD_VERSION is not None
128
+ else warn_old_version or True)
129
+ warn_not_installed = (WARN_NOT_INSTALLED if WARN_NOT_INSTALLED is not None
130
+ else warn_not_installed or False)
131
+
132
+ import warnings
133
+
134
+ # Check Python first so we don't waste time importing a module we can't use
135
+ if min_python_version:
136
+ if sys.version_info < min_python_version:
137
+ if warn_old_version:
138
+ warnings.warn("Python version is too old to use %s "
139
+ "(%s or newer required)" % (
140
+ module, '.'.join(map(str, min_python_version))),
141
+ UserWarning, stacklevel=2)
142
+ return
143
+
144
+ try:
145
+ mod = __import__(module, **import_kwargs)
146
+
147
+ ## there's something funny about imports with matplotlib and py3k. doing
148
+ ## from matplotlib import collections
149
+ ## gives python's stdlib collections module. explicitly re-importing
150
+ ## the module fixes this.
151
+ from_list = import_kwargs.get('fromlist', ())
152
+ for submod in from_list:
153
+ if submod == 'collections' and mod.__name__ == 'matplotlib':
154
+ __import__(module + '.' + submod)
155
+ except ImportError:
156
+ if warn_not_installed:
157
+ warnings.warn("%s module is not installed" % module, UserWarning,
158
+ stacklevel=2)
159
+ return
160
+ except catch as e:
161
+ if warn_not_installed:
162
+ warnings.warn(
163
+ "%s module could not be used (%s)" % (module, repr(e)),
164
+ stacklevel=2)
165
+ return
166
+
167
+ if min_module_version:
168
+ modversion = getattr(mod, module_version_attr)
169
+ if module_version_attr_call_args is not None:
170
+ modversion = modversion(*module_version_attr_call_args)
171
+ if version_tuple(modversion) < version_tuple(min_module_version):
172
+ if warn_old_version:
173
+ # Attempt to create a pretty string version of the version
174
+ if isinstance(min_module_version, str):
175
+ verstr = min_module_version
176
+ elif isinstance(min_module_version, (tuple, list)):
177
+ verstr = '.'.join(map(str, min_module_version))
178
+ else:
179
+ # Either don't know what this is. Hopefully
180
+ # it's something that has a nice str version, like an int.
181
+ verstr = str(min_module_version)
182
+ warnings.warn("%s version is too old to use "
183
+ "(%s or newer required)" % (module, verstr),
184
+ UserWarning, stacklevel=2)
185
+ return
186
+
187
+ return mod
.venv/lib/python3.13/site-packages/sympy/external/ntheory.py ADDED
@@ -0,0 +1,618 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # sympy.external.ntheory
2
+ #
3
+ # This module provides pure Python implementations of some number theory
4
+ # functions that are alternately used from gmpy2 if it is installed.
5
+
6
+ import math
7
+
8
+ import mpmath.libmp as mlib
9
+
10
+
11
+ _small_trailing = [0] * 256
12
+ for j in range(1, 8):
13
+ _small_trailing[1 << j :: 1 << (j + 1)] = [j] * (1 << (7 - j))
14
+
15
+
16
+ def bit_scan1(x, n=0):
17
+ if not x:
18
+ return
19
+ x = abs(x >> n)
20
+ low_byte = x & 0xFF
21
+ if low_byte:
22
+ return _small_trailing[low_byte] + n
23
+
24
+ t = 8 + n
25
+ x >>= 8
26
+ # 2**m is quick for z up through 2**30
27
+ z = x.bit_length() - 1
28
+ if x == 1 << z:
29
+ return z + t
30
+
31
+ if z < 300:
32
+ # fixed 8-byte reduction
33
+ while not x & 0xFF:
34
+ x >>= 8
35
+ t += 8
36
+ else:
37
+ # binary reduction important when there might be a large
38
+ # number of trailing 0s
39
+ p = z >> 1
40
+ while not x & 0xFF:
41
+ while x & ((1 << p) - 1):
42
+ p >>= 1
43
+ x >>= p
44
+ t += p
45
+ return t + _small_trailing[x & 0xFF]
46
+
47
+
48
+ def bit_scan0(x, n=0):
49
+ return bit_scan1(x + (1 << n), n)
50
+
51
+
52
+ def remove(x, f):
53
+ if f < 2:
54
+ raise ValueError("factor must be > 1")
55
+ if x == 0:
56
+ return 0, 0
57
+ if f == 2:
58
+ b = bit_scan1(x)
59
+ return x >> b, b
60
+ m = 0
61
+ y, rem = divmod(x, f)
62
+ while not rem:
63
+ x = y
64
+ m += 1
65
+ if m > 5:
66
+ pow_list = [f**2]
67
+ while pow_list:
68
+ _f = pow_list[-1]
69
+ y, rem = divmod(x, _f)
70
+ if not rem:
71
+ m += 1 << len(pow_list)
72
+ x = y
73
+ pow_list.append(_f**2)
74
+ else:
75
+ pow_list.pop()
76
+ y, rem = divmod(x, f)
77
+ return x, m
78
+
79
+
80
+ def factorial(x):
81
+ """Return x!."""
82
+ return int(mlib.ifac(int(x)))
83
+
84
+
85
+ def sqrt(x):
86
+ """Integer square root of x."""
87
+ return int(mlib.isqrt(int(x)))
88
+
89
+
90
+ def sqrtrem(x):
91
+ """Integer square root of x and remainder."""
92
+ s, r = mlib.sqrtrem(int(x))
93
+ return (int(s), int(r))
94
+
95
+
96
+ gcd = math.gcd
97
+ lcm = math.lcm
98
+
99
+
100
+ def _sign(n):
101
+ if n < 0:
102
+ return -1, -n
103
+ return 1, n
104
+
105
+
106
+ def gcdext(a, b):
107
+ if not a or not b:
108
+ g = abs(a) or abs(b)
109
+ if not g:
110
+ return (0, 0, 0)
111
+ return (g, a // g, b // g)
112
+
113
+ x_sign, a = _sign(a)
114
+ y_sign, b = _sign(b)
115
+ x, r = 1, 0
116
+ y, s = 0, 1
117
+
118
+ while b:
119
+ q, c = divmod(a, b)
120
+ a, b = b, c
121
+ x, r = r, x - q*r
122
+ y, s = s, y - q*s
123
+
124
+ return (a, x * x_sign, y * y_sign)
125
+
126
+
127
+ def is_square(x):
128
+ """Return True if x is a square number."""
129
+ if x < 0:
130
+ return False
131
+
132
+ # Note that the possible values of y**2 % n for a given n are limited.
133
+ # For example, when n=4, y**2 % n can only take 0 or 1.
134
+ # In other words, if x % 4 is 2 or 3, then x is not a square number.
135
+ # Mathematically, it determines if it belongs to the set {y**2 % n},
136
+ # but implementationally, it can be realized as a logical conjunction
137
+ # with an n-bit integer.
138
+ # see https://mersenneforum.org/showpost.php?p=110896
139
+ # def magic(n):
140
+ # s = {y**2 % n for y in range(n)}
141
+ # s = set(range(n)) - s
142
+ # return sum(1 << bit for bit in s)
143
+ # >>> print(hex(magic(128)))
144
+ # 0xfdfdfdedfdfdfdecfdfdfdedfdfcfdec
145
+ # >>> print(hex(magic(99)))
146
+ # 0x5f6f9ffb6fb7ddfcb75befdec
147
+ # >>> print(hex(magic(91)))
148
+ # 0x6fd1bfcfed5f3679d3ebdec
149
+ # >>> print(hex(magic(85)))
150
+ # 0xdef9ae771ffe3b9d67dec
151
+ if 0xfdfdfdedfdfdfdecfdfdfdedfdfcfdec & (1 << (x & 127)):
152
+ return False # e.g. 2, 3
153
+ m = x % 765765 # 765765 = 99 * 91 * 85
154
+ if 0x5f6f9ffb6fb7ddfcb75befdec & (1 << (m % 99)):
155
+ return False # e.g. 17, 68
156
+ if 0x6fd1bfcfed5f3679d3ebdec & (1 << (m % 91)):
157
+ return False # e.g. 97, 388
158
+ if 0xdef9ae771ffe3b9d67dec & (1 << (m % 85)):
159
+ return False # e.g. 793, 1408
160
+ return mlib.sqrtrem(int(x))[1] == 0
161
+
162
+
163
+ def invert(x, m):
164
+ """Modular inverse of x modulo m.
165
+
166
+ Returns y such that x*y == 1 mod m.
167
+
168
+ Uses ``math.pow`` but reproduces the behaviour of ``gmpy2.invert``
169
+ which raises ZeroDivisionError if no inverse exists.
170
+ """
171
+ try:
172
+ return pow(x, -1, m)
173
+ except ValueError:
174
+ raise ZeroDivisionError("invert() no inverse exists")
175
+
176
+
177
+ def legendre(x, y):
178
+ """Legendre symbol (x / y).
179
+
180
+ Following the implementation of gmpy2,
181
+ the error is raised only when y is an even number.
182
+ """
183
+ if y <= 0 or not y % 2:
184
+ raise ValueError("y should be an odd prime")
185
+ x %= y
186
+ if not x:
187
+ return 0
188
+ if pow(x, (y - 1) // 2, y) == 1:
189
+ return 1
190
+ return -1
191
+
192
+
193
+ def jacobi(x, y):
194
+ """Jacobi symbol (x / y)."""
195
+ if y <= 0 or not y % 2:
196
+ raise ValueError("y should be an odd positive integer")
197
+ x %= y
198
+ if not x:
199
+ return int(y == 1)
200
+ if y == 1 or x == 1:
201
+ return 1
202
+ if gcd(x, y) != 1:
203
+ return 0
204
+ j = 1
205
+ while x != 0:
206
+ while x % 2 == 0 and x > 0:
207
+ x >>= 1
208
+ if y % 8 in [3, 5]:
209
+ j = -j
210
+ x, y = y, x
211
+ if x % 4 == y % 4 == 3:
212
+ j = -j
213
+ x %= y
214
+ return j
215
+
216
+
217
+ def kronecker(x, y):
218
+ """Kronecker symbol (x / y)."""
219
+ if gcd(x, y) != 1:
220
+ return 0
221
+ if y == 0:
222
+ return 1
223
+ sign = -1 if y < 0 and x < 0 else 1
224
+ y = abs(y)
225
+ s = bit_scan1(y)
226
+ y >>= s
227
+ if s % 2 and x % 8 in [3, 5]:
228
+ sign = -sign
229
+ return sign * jacobi(x, y)
230
+
231
+
232
+ def iroot(y, n):
233
+ if y < 0:
234
+ raise ValueError("y must be nonnegative")
235
+ if n < 1:
236
+ raise ValueError("n must be positive")
237
+ if y in (0, 1):
238
+ return y, True
239
+ if n == 1:
240
+ return y, True
241
+ if n == 2:
242
+ x, rem = mlib.sqrtrem(y)
243
+ return int(x), not rem
244
+ if n >= y.bit_length():
245
+ return 1, False
246
+ # Get initial estimate for Newton's method. Care must be taken to
247
+ # avoid overflow
248
+ try:
249
+ guess = int(y**(1./n) + 0.5)
250
+ except OverflowError:
251
+ exp = math.log2(y)/n
252
+ if exp > 53:
253
+ shift = int(exp - 53)
254
+ guess = int(2.0**(exp - shift) + 1) << shift
255
+ else:
256
+ guess = int(2.0**exp)
257
+ if guess > 2**50:
258
+ # Newton iteration
259
+ xprev, x = -1, guess
260
+ while 1:
261
+ t = x**(n - 1)
262
+ xprev, x = x, ((n - 1)*x + y//t)//n
263
+ if abs(x - xprev) < 2:
264
+ break
265
+ else:
266
+ x = guess
267
+ # Compensate
268
+ t = x**n
269
+ while t < y:
270
+ x += 1
271
+ t = x**n
272
+ while t > y:
273
+ x -= 1
274
+ t = x**n
275
+ return x, t == y
276
+
277
+
278
+ def is_fermat_prp(n, a):
279
+ if a < 2:
280
+ raise ValueError("is_fermat_prp() requires 'a' greater than or equal to 2")
281
+ if n < 1:
282
+ raise ValueError("is_fermat_prp() requires 'n' be greater than 0")
283
+ if n == 1:
284
+ return False
285
+ if n % 2 == 0:
286
+ return n == 2
287
+ a %= n
288
+ if gcd(n, a) != 1:
289
+ raise ValueError("is_fermat_prp() requires gcd(n,a) == 1")
290
+ return pow(a, n - 1, n) == 1
291
+
292
+
293
+ def is_euler_prp(n, a):
294
+ if a < 2:
295
+ raise ValueError("is_euler_prp() requires 'a' greater than or equal to 2")
296
+ if n < 1:
297
+ raise ValueError("is_euler_prp() requires 'n' be greater than 0")
298
+ if n == 1:
299
+ return False
300
+ if n % 2 == 0:
301
+ return n == 2
302
+ a %= n
303
+ if gcd(n, a) != 1:
304
+ raise ValueError("is_euler_prp() requires gcd(n,a) == 1")
305
+ return pow(a, n >> 1, n) == jacobi(a, n) % n
306
+
307
+
308
+ def _is_strong_prp(n, a):
309
+ s = bit_scan1(n - 1)
310
+ a = pow(a, n >> s, n)
311
+ if a == 1 or a == n - 1:
312
+ return True
313
+ for _ in range(s - 1):
314
+ a = pow(a, 2, n)
315
+ if a == n - 1:
316
+ return True
317
+ if a == 1:
318
+ return False
319
+ return False
320
+
321
+
322
+ def is_strong_prp(n, a):
323
+ if a < 2:
324
+ raise ValueError("is_strong_prp() requires 'a' greater than or equal to 2")
325
+ if n < 1:
326
+ raise ValueError("is_strong_prp() requires 'n' be greater than 0")
327
+ if n == 1:
328
+ return False
329
+ if n % 2 == 0:
330
+ return n == 2
331
+ a %= n
332
+ if gcd(n, a) != 1:
333
+ raise ValueError("is_strong_prp() requires gcd(n,a) == 1")
334
+ return _is_strong_prp(n, a)
335
+
336
+
337
+ def _lucas_sequence(n, P, Q, k):
338
+ r"""Return the modular Lucas sequence (U_k, V_k, Q_k).
339
+
340
+ Explanation
341
+ ===========
342
+
343
+ Given a Lucas sequence defined by P, Q, returns the kth values for
344
+ U and V, along with Q^k, all modulo n. This is intended for use with
345
+ possibly very large values of n and k, where the combinatorial functions
346
+ would be completely unusable.
347
+
348
+ .. math ::
349
+ U_k = \begin{cases}
350
+ 0 & \text{if } k = 0\\
351
+ 1 & \text{if } k = 1\\
352
+ PU_{k-1} - QU_{k-2} & \text{if } k > 1
353
+ \end{cases}\\
354
+ V_k = \begin{cases}
355
+ 2 & \text{if } k = 0\\
356
+ P & \text{if } k = 1\\
357
+ PV_{k-1} - QV_{k-2} & \text{if } k > 1
358
+ \end{cases}
359
+
360
+ The modular Lucas sequences are used in numerous places in number theory,
361
+ especially in the Lucas compositeness tests and the various n + 1 proofs.
362
+
363
+ Parameters
364
+ ==========
365
+
366
+ n : int
367
+ n is an odd number greater than or equal to 3
368
+ P : int
369
+ Q : int
370
+ D determined by D = P**2 - 4*Q is non-zero
371
+ k : int
372
+ k is a nonnegative integer
373
+
374
+ Returns
375
+ =======
376
+
377
+ U, V, Qk : (int, int, int)
378
+ `(U_k \bmod{n}, V_k \bmod{n}, Q^k \bmod{n})`
379
+
380
+ Examples
381
+ ========
382
+
383
+ >>> from sympy.external.ntheory import _lucas_sequence
384
+ >>> N = 10**2000 + 4561
385
+ >>> sol = U, V, Qk = _lucas_sequence(N, 3, 1, N//2); sol
386
+ (0, 2, 1)
387
+
388
+ References
389
+ ==========
390
+
391
+ .. [1] https://en.wikipedia.org/wiki/Lucas_sequence
392
+
393
+ """
394
+ if k == 0:
395
+ return (0, 2, 1)
396
+ D = P**2 - 4*Q
397
+ U = 1
398
+ V = P
399
+ Qk = Q % n
400
+ if Q == 1:
401
+ # Optimization for extra strong tests.
402
+ for b in bin(k)[3:]:
403
+ U = (U*V) % n
404
+ V = (V*V - 2) % n
405
+ if b == "1":
406
+ U, V = U*P + V, V*P + U*D
407
+ if U & 1:
408
+ U += n
409
+ if V & 1:
410
+ V += n
411
+ U, V = U >> 1, V >> 1
412
+ elif P == 1 and Q == -1:
413
+ # Small optimization for 50% of Selfridge parameters.
414
+ for b in bin(k)[3:]:
415
+ U = (U*V) % n
416
+ if Qk == 1:
417
+ V = (V*V - 2) % n
418
+ else:
419
+ V = (V*V + 2) % n
420
+ Qk = 1
421
+ if b == "1":
422
+ # new_U = (U + V) // 2
423
+ # new_V = (5*U + V) // 2 = 2*U + new_U
424
+ U, V = U + V, U << 1
425
+ if U & 1:
426
+ U += n
427
+ U >>= 1
428
+ V += U
429
+ Qk = -1
430
+ Qk %= n
431
+ elif P == 1:
432
+ for b in bin(k)[3:]:
433
+ U = (U*V) % n
434
+ V = (V*V - 2*Qk) % n
435
+ Qk *= Qk
436
+ if b == "1":
437
+ # new_U = (U + V) // 2
438
+ # new_V = new_U - 2*Q*U
439
+ U, V = U + V, (Q*U) << 1
440
+ if U & 1:
441
+ U += n
442
+ U >>= 1
443
+ V = U - V
444
+ Qk *= Q
445
+ Qk %= n
446
+ else:
447
+ # The general case with any P and Q.
448
+ for b in bin(k)[3:]:
449
+ U = (U*V) % n
450
+ V = (V*V - 2*Qk) % n
451
+ Qk *= Qk
452
+ if b == "1":
453
+ U, V = U*P + V, V*P + U*D
454
+ if U & 1:
455
+ U += n
456
+ if V & 1:
457
+ V += n
458
+ U, V = U >> 1, V >> 1
459
+ Qk *= Q
460
+ Qk %= n
461
+ return (U % n, V % n, Qk)
462
+
463
+
464
+ def is_fibonacci_prp(n, p, q):
465
+ d = p**2 - 4*q
466
+ if d == 0 or p <= 0 or q not in [1, -1]:
467
+ raise ValueError("invalid values for p,q in is_fibonacci_prp()")
468
+ if n < 1:
469
+ raise ValueError("is_fibonacci_prp() requires 'n' be greater than 0")
470
+ if n == 1:
471
+ return False
472
+ if n % 2 == 0:
473
+ return n == 2
474
+ return _lucas_sequence(n, p, q, n)[1] == p % n
475
+
476
+
477
+ def is_lucas_prp(n, p, q):
478
+ d = p**2 - 4*q
479
+ if d == 0:
480
+ raise ValueError("invalid values for p,q in is_lucas_prp()")
481
+ if n < 1:
482
+ raise ValueError("is_lucas_prp() requires 'n' be greater than 0")
483
+ if n == 1:
484
+ return False
485
+ if n % 2 == 0:
486
+ return n == 2
487
+ if gcd(n, q*d) not in [1, n]:
488
+ raise ValueError("is_lucas_prp() requires gcd(n,2*q*D) == 1")
489
+ return _lucas_sequence(n, p, q, n - jacobi(d, n))[0] == 0
490
+
491
+
492
+ def _is_selfridge_prp(n):
493
+ """Lucas compositeness test with the Selfridge parameters for n.
494
+
495
+ Explanation
496
+ ===========
497
+
498
+ The Lucas compositeness test checks whether n is a prime number.
499
+ The test can be run with arbitrary parameters ``P`` and ``Q``, which also change the performance of the test.
500
+ So, which parameters are most effective for running the Lucas compositeness test?
501
+ As an algorithm for determining ``P`` and ``Q``, Selfridge proposed method A [1]_ page 1401
502
+ (Since two methods were proposed, referred to simply as A and B in the paper,
503
+ we will refer to one of them as "method A").
504
+
505
+ method A fixes ``P = 1``. Then, ``D`` defined by ``D = P**2 - 4Q`` is varied from 5, -7, 9, -11, 13, and so on,
506
+ with the first ``D`` being ``jacobi(D, n) == -1``. Once ``D`` is determined,
507
+ ``Q`` is determined to be ``(P**2 - D)//4``.
508
+
509
+ References
510
+ ==========
511
+
512
+ .. [1] Robert Baillie, Samuel S. Wagstaff, Lucas Pseudoprimes,
513
+ Math. Comp. Vol 35, Number 152 (1980), pp. 1391-1417,
514
+ https://doi.org/10.1090%2FS0025-5718-1980-0583518-6
515
+ http://mpqs.free.fr/LucasPseudoprimes.pdf
516
+
517
+ """
518
+ for D in range(5, 1_000_000, 2):
519
+ if D & 2: # if D % 4 == 3
520
+ D = -D
521
+ j = jacobi(D, n)
522
+ if j == -1:
523
+ return _lucas_sequence(n, 1, (1-D) // 4, n + 1)[0] == 0
524
+ if j == 0 and D % n:
525
+ return False
526
+ # When j == -1 is hard to find, suspect a square number
527
+ if D == 13 and is_square(n):
528
+ return False
529
+ raise ValueError("appropriate value for D cannot be found in is_selfridge_prp()")
530
+
531
+
532
+ def is_selfridge_prp(n):
533
+ if n < 1:
534
+ raise ValueError("is_selfridge_prp() requires 'n' be greater than 0")
535
+ if n == 1:
536
+ return False
537
+ if n % 2 == 0:
538
+ return n == 2
539
+ return _is_selfridge_prp(n)
540
+
541
+
542
+ def is_strong_lucas_prp(n, p, q):
543
+ D = p**2 - 4*q
544
+ if D == 0:
545
+ raise ValueError("invalid values for p,q in is_strong_lucas_prp()")
546
+ if n < 1:
547
+ raise ValueError("is_selfridge_prp() requires 'n' be greater than 0")
548
+ if n == 1:
549
+ return False
550
+ if n % 2 == 0:
551
+ return n == 2
552
+ if gcd(n, q*D) not in [1, n]:
553
+ raise ValueError("is_strong_lucas_prp() requires gcd(n,2*q*D) == 1")
554
+ j = jacobi(D, n)
555
+ s = bit_scan1(n - j)
556
+ U, V, Qk = _lucas_sequence(n, p, q, (n - j) >> s)
557
+ if U == 0 or V == 0:
558
+ return True
559
+ for _ in range(s - 1):
560
+ V = (V*V - 2*Qk) % n
561
+ if V == 0:
562
+ return True
563
+ Qk = pow(Qk, 2, n)
564
+ return False
565
+
566
+
567
+ def _is_strong_selfridge_prp(n):
568
+ for D in range(5, 1_000_000, 2):
569
+ if D & 2: # if D % 4 == 3
570
+ D = -D
571
+ j = jacobi(D, n)
572
+ if j == -1:
573
+ s = bit_scan1(n + 1)
574
+ U, V, Qk = _lucas_sequence(n, 1, (1-D) // 4, (n + 1) >> s)
575
+ if U == 0 or V == 0:
576
+ return True
577
+ for _ in range(s - 1):
578
+ V = (V*V - 2*Qk) % n
579
+ if V == 0:
580
+ return True
581
+ Qk = pow(Qk, 2, n)
582
+ return False
583
+ if j == 0 and D % n:
584
+ return False
585
+ # When j == -1 is hard to find, suspect a square number
586
+ if D == 13 and is_square(n):
587
+ return False
588
+ raise ValueError("appropriate value for D cannot be found in is_strong_selfridge_prp()")
589
+
590
+
591
+ def is_strong_selfridge_prp(n):
592
+ if n < 1:
593
+ raise ValueError("is_strong_selfridge_prp() requires 'n' be greater than 0")
594
+ if n == 1:
595
+ return False
596
+ if n % 2 == 0:
597
+ return n == 2
598
+ return _is_strong_selfridge_prp(n)
599
+
600
+
601
+ def is_bpsw_prp(n):
602
+ if n < 1:
603
+ raise ValueError("is_bpsw_prp() requires 'n' be greater than 0")
604
+ if n == 1:
605
+ return False
606
+ if n % 2 == 0:
607
+ return n == 2
608
+ return _is_strong_prp(n, 2) and _is_selfridge_prp(n)
609
+
610
+
611
+ def is_strong_bpsw_prp(n):
612
+ if n < 1:
613
+ raise ValueError("is_strong_bpsw_prp() requires 'n' be greater than 0")
614
+ if n == 1:
615
+ return False
616
+ if n % 2 == 0:
617
+ return n == 2
618
+ return _is_strong_prp(n, 2) and _is_strong_selfridge_prp(n)
.venv/lib/python3.13/site-packages/sympy/external/pythonmpq.py ADDED
@@ -0,0 +1,341 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ PythonMPQ: Rational number type based on Python integers.
3
+
4
+ This class is intended as a pure Python fallback for when gmpy2 is not
5
+ installed. If gmpy2 is installed then its mpq type will be used instead. The
6
+ mpq type is around 20x faster. We could just use the stdlib Fraction class
7
+ here but that is slower:
8
+
9
+ from fractions import Fraction
10
+ from sympy.external.pythonmpq import PythonMPQ
11
+ nums = range(1000)
12
+ dens = range(5, 1005)
13
+ rats = [Fraction(n, d) for n, d in zip(nums, dens)]
14
+ sum(rats) # <--- 24 milliseconds
15
+ rats = [PythonMPQ(n, d) for n, d in zip(nums, dens)]
16
+ sum(rats) # <--- 7 milliseconds
17
+
18
+ Both mpq and Fraction have some awkward features like the behaviour of
19
+ division with // and %:
20
+
21
+ >>> from fractions import Fraction
22
+ >>> Fraction(2, 3) % Fraction(1, 4)
23
+ 1/6
24
+
25
+ For the QQ domain we do not want this behaviour because there should be no
26
+ remainder when dividing rational numbers. SymPy does not make use of this
27
+ aspect of mpq when gmpy2 is installed. Since this class is a fallback for that
28
+ case we do not bother implementing e.g. __mod__ so that we can be sure we
29
+ are not using it when gmpy2 is installed either.
30
+ """
31
+
32
+ from __future__ import annotations
33
+ import operator
34
+ from math import gcd
35
+ from decimal import Decimal
36
+ from fractions import Fraction
37
+ import sys
38
+ from typing import Type
39
+
40
+
41
+ # Used for __hash__
42
+ _PyHASH_MODULUS = sys.hash_info.modulus
43
+ _PyHASH_INF = sys.hash_info.inf
44
+
45
+
46
+ class PythonMPQ:
47
+ """Rational number implementation that is intended to be compatible with
48
+ gmpy2's mpq.
49
+
50
+ Also slightly faster than fractions.Fraction.
51
+
52
+ PythonMPQ should be treated as immutable although no effort is made to
53
+ prevent mutation (since that might slow down calculations).
54
+ """
55
+ __slots__ = ('numerator', 'denominator')
56
+
57
+ def __new__(cls, numerator, denominator=None):
58
+ """Construct PythonMPQ with gcd computation and checks"""
59
+ if denominator is not None:
60
+ #
61
+ # PythonMPQ(n, d): require n and d to be int and d != 0
62
+ #
63
+ if isinstance(numerator, int) and isinstance(denominator, int):
64
+ # This is the slow part:
65
+ divisor = gcd(numerator, denominator)
66
+ numerator //= divisor
67
+ denominator //= divisor
68
+ return cls._new_check(numerator, denominator)
69
+ else:
70
+ #
71
+ # PythonMPQ(q)
72
+ #
73
+ # Here q can be PythonMPQ, int, Decimal, float, Fraction or str
74
+ #
75
+ if isinstance(numerator, int):
76
+ return cls._new(numerator, 1)
77
+ elif isinstance(numerator, PythonMPQ):
78
+ return cls._new(numerator.numerator, numerator.denominator)
79
+
80
+ # Let Fraction handle Decimal/float conversion and str parsing
81
+ if isinstance(numerator, (Decimal, float, str)):
82
+ numerator = Fraction(numerator)
83
+ if isinstance(numerator, Fraction):
84
+ return cls._new(numerator.numerator, numerator.denominator)
85
+ #
86
+ # Reject everything else. This is more strict than mpq which allows
87
+ # things like mpq(Fraction, Fraction) or mpq(Decimal, any). The mpq
88
+ # behaviour is somewhat inconsistent so we choose to accept only a
89
+ # more strict subset of what mpq allows.
90
+ #
91
+ raise TypeError("PythonMPQ() requires numeric or string argument")
92
+
93
+ @classmethod
94
+ def _new_check(cls, numerator, denominator):
95
+ """Construct PythonMPQ, check divide by zero and canonicalize signs"""
96
+ if not denominator:
97
+ raise ZeroDivisionError(f'Zero divisor {numerator}/{denominator}')
98
+ elif denominator < 0:
99
+ numerator = -numerator
100
+ denominator = -denominator
101
+ return cls._new(numerator, denominator)
102
+
103
+ @classmethod
104
+ def _new(cls, numerator, denominator):
105
+ """Construct PythonMPQ efficiently (no checks)"""
106
+ obj = super().__new__(cls)
107
+ obj.numerator = numerator
108
+ obj.denominator = denominator
109
+ return obj
110
+
111
+ def __int__(self):
112
+ """Convert to int (truncates towards zero)"""
113
+ p, q = self.numerator, self.denominator
114
+ if p < 0:
115
+ return -(-p//q)
116
+ return p//q
117
+
118
+ def __float__(self):
119
+ """Convert to float (approximately)"""
120
+ return self.numerator / self.denominator
121
+
122
+ def __bool__(self):
123
+ """True/False if nonzero/zero"""
124
+ return bool(self.numerator)
125
+
126
+ def __eq__(self, other):
127
+ """Compare equal with PythonMPQ, int, float, Decimal or Fraction"""
128
+ if isinstance(other, PythonMPQ):
129
+ return (self.numerator == other.numerator
130
+ and self.denominator == other.denominator)
131
+ elif isinstance(other, self._compatible_types):
132
+ return self.__eq__(PythonMPQ(other))
133
+ else:
134
+ return NotImplemented
135
+
136
+ def __hash__(self):
137
+ """hash - same as mpq/Fraction"""
138
+ try:
139
+ dinv = pow(self.denominator, -1, _PyHASH_MODULUS)
140
+ except ValueError:
141
+ hash_ = _PyHASH_INF
142
+ else:
143
+ hash_ = hash(hash(abs(self.numerator)) * dinv)
144
+ result = hash_ if self.numerator >= 0 else -hash_
145
+ return -2 if result == -1 else result
146
+
147
+ def __reduce__(self):
148
+ """Deconstruct for pickling"""
149
+ return type(self), (self.numerator, self.denominator)
150
+
151
+ def __str__(self):
152
+ """Convert to string"""
153
+ if self.denominator != 1:
154
+ return f"{self.numerator}/{self.denominator}"
155
+ else:
156
+ return f"{self.numerator}"
157
+
158
+ def __repr__(self):
159
+ """Convert to string"""
160
+ return f"MPQ({self.numerator},{self.denominator})"
161
+
162
+ def _cmp(self, other, op):
163
+ """Helper for lt/le/gt/ge"""
164
+ if not isinstance(other, self._compatible_types):
165
+ return NotImplemented
166
+ lhs = self.numerator * other.denominator
167
+ rhs = other.numerator * self.denominator
168
+ return op(lhs, rhs)
169
+
170
+ def __lt__(self, other):
171
+ """self < other"""
172
+ return self._cmp(other, operator.lt)
173
+
174
+ def __le__(self, other):
175
+ """self <= other"""
176
+ return self._cmp(other, operator.le)
177
+
178
+ def __gt__(self, other):
179
+ """self > other"""
180
+ return self._cmp(other, operator.gt)
181
+
182
+ def __ge__(self, other):
183
+ """self >= other"""
184
+ return self._cmp(other, operator.ge)
185
+
186
+ def __abs__(self):
187
+ """abs(q)"""
188
+ return self._new(abs(self.numerator), self.denominator)
189
+
190
+ def __pos__(self):
191
+ """+q"""
192
+ return self
193
+
194
+ def __neg__(self):
195
+ """-q"""
196
+ return self._new(-self.numerator, self.denominator)
197
+
198
+ def __add__(self, other):
199
+ """q1 + q2"""
200
+ if isinstance(other, PythonMPQ):
201
+ #
202
+ # This is much faster than the naive method used in the stdlib
203
+ # fractions module. Not sure where this method comes from
204
+ # though...
205
+ #
206
+ # Compare timings for something like:
207
+ # nums = range(1000)
208
+ # rats = [PythonMPQ(n, d) for n, d in zip(nums[:-5], nums[5:])]
209
+ # sum(rats) # <-- time this
210
+ #
211
+ ap, aq = self.numerator, self.denominator
212
+ bp, bq = other.numerator, other.denominator
213
+ g = gcd(aq, bq)
214
+ if g == 1:
215
+ p = ap*bq + aq*bp
216
+ q = bq*aq
217
+ else:
218
+ q1, q2 = aq//g, bq//g
219
+ p, q = ap*q2 + bp*q1, q1*q2
220
+ g2 = gcd(p, g)
221
+ p, q = (p // g2), q * (g // g2)
222
+
223
+ elif isinstance(other, int):
224
+ p = self.numerator + self.denominator * other
225
+ q = self.denominator
226
+ else:
227
+ return NotImplemented
228
+
229
+ return self._new(p, q)
230
+
231
+ def __radd__(self, other):
232
+ """z1 + q2"""
233
+ if isinstance(other, int):
234
+ p = self.numerator + self.denominator * other
235
+ q = self.denominator
236
+ return self._new(p, q)
237
+ else:
238
+ return NotImplemented
239
+
240
+ def __sub__(self ,other):
241
+ """q1 - q2"""
242
+ if isinstance(other, PythonMPQ):
243
+ ap, aq = self.numerator, self.denominator
244
+ bp, bq = other.numerator, other.denominator
245
+ g = gcd(aq, bq)
246
+ if g == 1:
247
+ p = ap*bq - aq*bp
248
+ q = bq*aq
249
+ else:
250
+ q1, q2 = aq//g, bq//g
251
+ p, q = ap*q2 - bp*q1, q1*q2
252
+ g2 = gcd(p, g)
253
+ p, q = (p // g2), q * (g // g2)
254
+ elif isinstance(other, int):
255
+ p = self.numerator - self.denominator*other
256
+ q = self.denominator
257
+ else:
258
+ return NotImplemented
259
+
260
+ return self._new(p, q)
261
+
262
+ def __rsub__(self, other):
263
+ """z1 - q2"""
264
+ if isinstance(other, int):
265
+ p = self.denominator * other - self.numerator
266
+ q = self.denominator
267
+ return self._new(p, q)
268
+ else:
269
+ return NotImplemented
270
+
271
+ def __mul__(self, other):
272
+ """q1 * q2"""
273
+ if isinstance(other, PythonMPQ):
274
+ ap, aq = self.numerator, self.denominator
275
+ bp, bq = other.numerator, other.denominator
276
+ x1 = gcd(ap, bq)
277
+ x2 = gcd(bp, aq)
278
+ p, q = ((ap//x1)*(bp//x2), (aq//x2)*(bq//x1))
279
+ elif isinstance(other, int):
280
+ x = gcd(other, self.denominator)
281
+ p = self.numerator*(other//x)
282
+ q = self.denominator//x
283
+ else:
284
+ return NotImplemented
285
+
286
+ return self._new(p, q)
287
+
288
+ def __rmul__(self, other):
289
+ """z1 * q2"""
290
+ if isinstance(other, int):
291
+ x = gcd(self.denominator, other)
292
+ p = self.numerator*(other//x)
293
+ q = self.denominator//x
294
+ return self._new(p, q)
295
+ else:
296
+ return NotImplemented
297
+
298
+ def __pow__(self, exp):
299
+ """q ** z"""
300
+ p, q = self.numerator, self.denominator
301
+
302
+ if exp < 0:
303
+ p, q, exp = q, p, -exp
304
+
305
+ return self._new_check(p**exp, q**exp)
306
+
307
+ def __truediv__(self, other):
308
+ """q1 / q2"""
309
+ if isinstance(other, PythonMPQ):
310
+ ap, aq = self.numerator, self.denominator
311
+ bp, bq = other.numerator, other.denominator
312
+ x1 = gcd(ap, bp)
313
+ x2 = gcd(bq, aq)
314
+ p, q = ((ap//x1)*(bq//x2), (aq//x2)*(bp//x1))
315
+ elif isinstance(other, int):
316
+ x = gcd(other, self.numerator)
317
+ p = self.numerator//x
318
+ q = self.denominator*(other//x)
319
+ else:
320
+ return NotImplemented
321
+
322
+ return self._new_check(p, q)
323
+
324
+ def __rtruediv__(self, other):
325
+ """z / q"""
326
+ if isinstance(other, int):
327
+ x = gcd(self.numerator, other)
328
+ p = self.denominator*(other//x)
329
+ q = self.numerator//x
330
+ return self._new_check(p, q)
331
+ else:
332
+ return NotImplemented
333
+
334
+ _compatible_types: tuple[Type, ...] = ()
335
+
336
+ #
337
+ # These are the types that PythonMPQ will interoperate with for operations
338
+ # and comparisons such as ==, + etc. We define this down here so that we can
339
+ # include PythonMPQ in the list as well.
340
+ #
341
+ PythonMPQ._compatible_types = (PythonMPQ, int, Decimal, Fraction)
.venv/lib/python3.13/site-packages/sympy/external/tests/__init__.py ADDED
File without changes
.venv/lib/python3.13/site-packages/sympy/external/tests/test_autowrap.py ADDED
@@ -0,0 +1,313 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sympy
2
+ import tempfile
3
+ import os
4
+ from pathlib import Path
5
+ from sympy.core.mod import Mod
6
+ from sympy.core.relational import Eq
7
+ from sympy.core.symbol import symbols
8
+ from sympy.external import import_module
9
+ from sympy.tensor import IndexedBase, Idx
10
+ from sympy.utilities.autowrap import autowrap, ufuncify, CodeWrapError
11
+ from sympy.testing.pytest import skip
12
+
13
+ numpy = import_module('numpy', min_module_version='1.6.1')
14
+ Cython = import_module('Cython', min_module_version='0.15.1')
15
+ f2py = import_module('numpy.f2py', import_kwargs={'fromlist': ['f2py']})
16
+
17
+ f2pyworks = False
18
+ if f2py:
19
+ try:
20
+ autowrap(symbols('x'), 'f95', 'f2py')
21
+ except (CodeWrapError, ImportError, OSError):
22
+ f2pyworks = False
23
+ else:
24
+ f2pyworks = True
25
+
26
+ a, b, c = symbols('a b c')
27
+ n, m, d = symbols('n m d', integer=True)
28
+ A, B, C = symbols('A B C', cls=IndexedBase)
29
+ i = Idx('i', m)
30
+ j = Idx('j', n)
31
+ k = Idx('k', d)
32
+
33
+
34
+ def has_module(module):
35
+ """
36
+ Return True if module exists, otherwise run skip().
37
+
38
+ module should be a string.
39
+ """
40
+ # To give a string of the module name to skip(), this function takes a
41
+ # string. So we don't waste time running import_module() more than once,
42
+ # just map the three modules tested here in this dict.
43
+ modnames = {'numpy': numpy, 'Cython': Cython, 'f2py': f2py}
44
+
45
+ if modnames[module]:
46
+ if module == 'f2py' and not f2pyworks:
47
+ skip("Couldn't run f2py.")
48
+ return True
49
+ skip("Couldn't import %s." % module)
50
+
51
+ #
52
+ # test runners used by several language-backend combinations
53
+ #
54
+
55
+ def runtest_autowrap_twice(language, backend):
56
+ f = autowrap((((a + b)/c)**5).expand(), language, backend)
57
+ g = autowrap((((a + b)/c)**4).expand(), language, backend)
58
+
59
+ # check that autowrap updates the module name. Else, g gives the same as f
60
+ assert f(1, -2, 1) == -1.0
61
+ assert g(1, -2, 1) == 1.0
62
+
63
+
64
+ def runtest_autowrap_trace(language, backend):
65
+ has_module('numpy')
66
+ trace = autowrap(A[i, i], language, backend)
67
+ assert trace(numpy.eye(100)) == 100
68
+
69
+
70
+ def runtest_autowrap_matrix_vector(language, backend):
71
+ has_module('numpy')
72
+ x, y = symbols('x y', cls=IndexedBase)
73
+ expr = Eq(y[i], A[i, j]*x[j])
74
+ mv = autowrap(expr, language, backend)
75
+
76
+ # compare with numpy's dot product
77
+ M = numpy.random.rand(10, 20)
78
+ x = numpy.random.rand(20)
79
+ y = numpy.dot(M, x)
80
+ assert numpy.sum(numpy.abs(y - mv(M, x))) < 1e-13
81
+
82
+
83
+ def runtest_autowrap_matrix_matrix(language, backend):
84
+ has_module('numpy')
85
+ expr = Eq(C[i, j], A[i, k]*B[k, j])
86
+ matmat = autowrap(expr, language, backend)
87
+
88
+ # compare with numpy's dot product
89
+ M1 = numpy.random.rand(10, 20)
90
+ M2 = numpy.random.rand(20, 15)
91
+ M3 = numpy.dot(M1, M2)
92
+ assert numpy.sum(numpy.abs(M3 - matmat(M1, M2))) < 1e-13
93
+
94
+
95
+ def runtest_ufuncify(language, backend):
96
+ has_module('numpy')
97
+ a, b, c = symbols('a b c')
98
+ fabc = ufuncify([a, b, c], a*b + c, backend=backend)
99
+ facb = ufuncify([a, c, b], a*b + c, backend=backend)
100
+ grid = numpy.linspace(-2, 2, 50)
101
+ b = numpy.linspace(-5, 4, 50)
102
+ c = numpy.linspace(-1, 1, 50)
103
+ expected = grid*b + c
104
+ numpy.testing.assert_allclose(fabc(grid, b, c), expected)
105
+ numpy.testing.assert_allclose(facb(grid, c, b), expected)
106
+
107
+
108
+ def runtest_issue_10274(language, backend):
109
+ expr = (a - b + c)**(13)
110
+ tmp = tempfile.mkdtemp()
111
+ f = autowrap(expr, language, backend, tempdir=tmp,
112
+ helpers=('helper', a - b + c, (a, b, c)))
113
+ assert f(1, 1, 1) == 1
114
+
115
+ for file in os.listdir(tmp):
116
+ if not (file.startswith("wrapped_code_") and file.endswith(".c")):
117
+ continue
118
+
119
+ with open(tmp + '/' + file) as fil:
120
+ lines = fil.readlines()
121
+ assert lines[0] == "/******************************************************************************\n"
122
+ assert "Code generated with SymPy " + sympy.__version__ in lines[1]
123
+ assert lines[2:] == [
124
+ " * *\n",
125
+ " * See http://www.sympy.org/ for more information. *\n",
126
+ " * *\n",
127
+ " * This file is part of 'autowrap' *\n",
128
+ " ******************************************************************************/\n",
129
+ "#include " + '"' + file[:-1]+ 'h"' + "\n",
130
+ "#include <math.h>\n",
131
+ "\n",
132
+ "double helper(double a, double b, double c) {\n",
133
+ "\n",
134
+ " double helper_result;\n",
135
+ " helper_result = a - b + c;\n",
136
+ " return helper_result;\n",
137
+ "\n",
138
+ "}\n",
139
+ "\n",
140
+ "double autofunc(double a, double b, double c) {\n",
141
+ "\n",
142
+ " double autofunc_result;\n",
143
+ " autofunc_result = pow(helper(a, b, c), 13);\n",
144
+ " return autofunc_result;\n",
145
+ "\n",
146
+ "}\n",
147
+ ]
148
+
149
+
150
+ def runtest_issue_15337(language, backend):
151
+ has_module('numpy')
152
+ # NOTE : autowrap was originally designed to only accept an iterable for
153
+ # the kwarg "helpers", but in issue 10274 the user mistakenly thought that
154
+ # if there was only a single helper it did not need to be passed via an
155
+ # iterable that wrapped the helper tuple. There were no tests for this
156
+ # behavior so when the code was changed to accept a single tuple it broke
157
+ # the original behavior. These tests below ensure that both now work.
158
+ a, b, c, d, e = symbols('a, b, c, d, e')
159
+ expr = (a - b + c - d + e)**13
160
+ exp_res = (1. - 2. + 3. - 4. + 5.)**13
161
+
162
+ f = autowrap(expr, language, backend, args=(a, b, c, d, e),
163
+ helpers=('f1', a - b + c, (a, b, c)))
164
+ numpy.testing.assert_allclose(f(1, 2, 3, 4, 5), exp_res)
165
+
166
+ f = autowrap(expr, language, backend, args=(a, b, c, d, e),
167
+ helpers=(('f1', a - b, (a, b)), ('f2', c - d, (c, d))))
168
+ numpy.testing.assert_allclose(f(1, 2, 3, 4, 5), exp_res)
169
+
170
+
171
+ def test_issue_15230():
172
+ has_module('f2py')
173
+
174
+ x, y = symbols('x, y')
175
+ expr = Mod(x, 3.0) - Mod(y, -2.0)
176
+ f = autowrap(expr, args=[x, y], language='F95')
177
+ exp_res = float(expr.xreplace({x: 3.5, y: 2.7}).evalf())
178
+ assert abs(f(3.5, 2.7) - exp_res) < 1e-14
179
+
180
+ x, y = symbols('x, y', integer=True)
181
+ expr = Mod(x, 3) - Mod(y, -2)
182
+ f = autowrap(expr, args=[x, y], language='F95')
183
+ assert f(3, 2) == expr.xreplace({x: 3, y: 2})
184
+
185
+ #
186
+ # tests of language-backend combinations
187
+ #
188
+
189
+ # f2py
190
+
191
+
192
+ def test_wrap_twice_f95_f2py():
193
+ has_module('f2py')
194
+ runtest_autowrap_twice('f95', 'f2py')
195
+
196
+
197
+ def test_autowrap_trace_f95_f2py():
198
+ has_module('f2py')
199
+ runtest_autowrap_trace('f95', 'f2py')
200
+
201
+
202
+ def test_autowrap_matrix_vector_f95_f2py():
203
+ has_module('f2py')
204
+ runtest_autowrap_matrix_vector('f95', 'f2py')
205
+
206
+
207
+ def test_autowrap_matrix_matrix_f95_f2py():
208
+ has_module('f2py')
209
+ runtest_autowrap_matrix_matrix('f95', 'f2py')
210
+
211
+
212
+ def test_ufuncify_f95_f2py():
213
+ has_module('f2py')
214
+ runtest_ufuncify('f95', 'f2py')
215
+
216
+
217
+ def test_issue_15337_f95_f2py():
218
+ has_module('f2py')
219
+ runtest_issue_15337('f95', 'f2py')
220
+
221
+ # Cython
222
+
223
+
224
+ def test_wrap_twice_c_cython():
225
+ has_module('Cython')
226
+ runtest_autowrap_twice('C', 'cython')
227
+
228
+
229
+ def test_autowrap_trace_C_Cython():
230
+ has_module('Cython')
231
+ runtest_autowrap_trace('C99', 'cython')
232
+
233
+
234
+ def test_autowrap_matrix_vector_C_cython():
235
+ has_module('Cython')
236
+ runtest_autowrap_matrix_vector('C99', 'cython')
237
+
238
+
239
+ def test_autowrap_matrix_matrix_C_cython():
240
+ has_module('Cython')
241
+ runtest_autowrap_matrix_matrix('C99', 'cython')
242
+
243
+
244
+ def test_ufuncify_C_Cython():
245
+ has_module('Cython')
246
+ runtest_ufuncify('C99', 'cython')
247
+
248
+
249
+ def test_issue_10274_C_cython():
250
+ has_module('Cython')
251
+ runtest_issue_10274('C89', 'cython')
252
+
253
+
254
+ def test_issue_15337_C_cython():
255
+ has_module('Cython')
256
+ runtest_issue_15337('C89', 'cython')
257
+
258
+
259
+ def test_autowrap_custom_printer():
260
+ has_module('Cython')
261
+
262
+ from sympy.core.numbers import pi
263
+ from sympy.utilities.codegen import C99CodeGen
264
+ from sympy.printing.c import C99CodePrinter
265
+
266
+ class PiPrinter(C99CodePrinter):
267
+ def _print_Pi(self, expr):
268
+ return "S_PI"
269
+
270
+ printer = PiPrinter()
271
+ gen = C99CodeGen(printer=printer)
272
+ gen.preprocessor_statements.append('#include "shortpi.h"')
273
+
274
+ expr = pi * a
275
+
276
+ expected = (
277
+ '#include "%s"\n'
278
+ '#include <math.h>\n'
279
+ '#include "shortpi.h"\n'
280
+ '\n'
281
+ 'double autofunc(double a) {\n'
282
+ '\n'
283
+ ' double autofunc_result;\n'
284
+ ' autofunc_result = S_PI*a;\n'
285
+ ' return autofunc_result;\n'
286
+ '\n'
287
+ '}\n'
288
+ )
289
+
290
+ tmpdir = tempfile.mkdtemp()
291
+ # write a trivial header file to use in the generated code
292
+ Path(os.path.join(tmpdir, 'shortpi.h')).write_text('#define S_PI 3.14')
293
+
294
+ func = autowrap(expr, backend='cython', tempdir=tmpdir, code_gen=gen)
295
+
296
+ assert func(4.2) == 3.14 * 4.2
297
+
298
+ # check that the generated code is correct
299
+ for filename in os.listdir(tmpdir):
300
+ if filename.startswith('wrapped_code') and filename.endswith('.c'):
301
+ with open(os.path.join(tmpdir, filename)) as f:
302
+ lines = f.readlines()
303
+ expected = expected % filename.replace('.c', '.h')
304
+ assert ''.join(lines[7:]) == expected
305
+
306
+
307
+ # Numpy
308
+
309
+ def test_ufuncify_numpy():
310
+ # This test doesn't use Cython, but if Cython works, then there is a valid
311
+ # C compiler, which is needed.
312
+ has_module('Cython')
313
+ runtest_ufuncify('C99', 'numpy')
.venv/lib/python3.13/site-packages/sympy/external/tests/test_codegen.py ADDED
@@ -0,0 +1,375 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This tests the compilation and execution of the source code generated with
2
+ # utilities.codegen. The compilation takes place in a temporary directory that
3
+ # is removed after the test. By default the test directory is always removed,
4
+ # but this behavior can be changed by setting the environment variable
5
+ # SYMPY_TEST_CLEAN_TEMP to:
6
+ # export SYMPY_TEST_CLEAN_TEMP=always : the default behavior.
7
+ # export SYMPY_TEST_CLEAN_TEMP=success : only remove the directories of working tests.
8
+ # export SYMPY_TEST_CLEAN_TEMP=never : never remove the directories with the test code.
9
+ # When a directory is not removed, the necessary information is printed on
10
+ # screen to find the files that belong to the (failed) tests. If a test does
11
+ # not fail, py.test captures all the output and you will not see the directories
12
+ # corresponding to the successful tests. Use the --nocapture option to see all
13
+ # the output.
14
+
15
+ # All tests below have a counterpart in utilities/test/test_codegen.py. In the
16
+ # latter file, the resulting code is compared with predefined strings, without
17
+ # compilation or execution.
18
+
19
+ # All the generated Fortran code should conform with the Fortran 95 standard,
20
+ # and all the generated C code should be ANSI C, which facilitates the
21
+ # incorporation in various projects. The tests below assume that the binary cc
22
+ # is somewhere in the path and that it can compile ANSI C code.
23
+
24
+ from sympy.abc import x, y, z
25
+ from sympy.testing.pytest import IS_WASM, skip
26
+ from sympy.utilities.codegen import codegen, make_routine, get_code_generator
27
+ import sys
28
+ import os
29
+ import tempfile
30
+ import subprocess
31
+ from pathlib import Path
32
+
33
+
34
+ # templates for the main program that will test the generated code.
35
+
36
+ main_template = {}
37
+ main_template['F95'] = """
38
+ program main
39
+ include "codegen.h"
40
+ integer :: result;
41
+ result = 0
42
+
43
+ %(statements)s
44
+
45
+ call exit(result)
46
+ end program
47
+ """
48
+
49
+ main_template['C89'] = """
50
+ #include "codegen.h"
51
+ #include <stdio.h>
52
+ #include <math.h>
53
+
54
+ int main() {
55
+ int result = 0;
56
+
57
+ %(statements)s
58
+
59
+ return result;
60
+ }
61
+ """
62
+ main_template['C99'] = main_template['C89']
63
+ # templates for the numerical tests
64
+
65
+ numerical_test_template = {}
66
+ numerical_test_template['C89'] = """
67
+ if (fabs(%(call)s)>%(threshold)s) {
68
+ printf("Numerical validation failed: %(call)s=%%e threshold=%(threshold)s\\n", %(call)s);
69
+ result = -1;
70
+ }
71
+ """
72
+ numerical_test_template['C99'] = numerical_test_template['C89']
73
+
74
+ numerical_test_template['F95'] = """
75
+ if (abs(%(call)s)>%(threshold)s) then
76
+ write(6,"('Numerical validation failed:')")
77
+ write(6,"('%(call)s=',e15.5,'threshold=',e15.5)") %(call)s, %(threshold)s
78
+ result = -1;
79
+ end if
80
+ """
81
+ # command sequences for supported compilers
82
+
83
+ compile_commands = {}
84
+ compile_commands['cc'] = [
85
+ "cc -c codegen.c -o codegen.o",
86
+ "cc -c main.c -o main.o",
87
+ "cc main.o codegen.o -lm -o test.exe"
88
+ ]
89
+
90
+ compile_commands['gfortran'] = [
91
+ "gfortran -c codegen.f90 -o codegen.o",
92
+ "gfortran -ffree-line-length-none -c main.f90 -o main.o",
93
+ "gfortran main.o codegen.o -o test.exe"
94
+ ]
95
+
96
+ compile_commands['g95'] = [
97
+ "g95 -c codegen.f90 -o codegen.o",
98
+ "g95 -ffree-line-length-huge -c main.f90 -o main.o",
99
+ "g95 main.o codegen.o -o test.exe"
100
+ ]
101
+
102
+ compile_commands['ifort'] = [
103
+ "ifort -c codegen.f90 -o codegen.o",
104
+ "ifort -c main.f90 -o main.o",
105
+ "ifort main.o codegen.o -o test.exe"
106
+ ]
107
+
108
+ combinations_lang_compiler = [
109
+ ('C89', 'cc'),
110
+ ('C99', 'cc'),
111
+ ('F95', 'ifort'),
112
+ ('F95', 'gfortran'),
113
+ ('F95', 'g95')
114
+ ]
115
+
116
+ def try_run(commands):
117
+ """Run a series of commands and only return True if all ran fine."""
118
+ if IS_WASM:
119
+ return False
120
+ with open(os.devnull, 'w') as null:
121
+ for command in commands:
122
+ retcode = subprocess.call(command, stdout=null, shell=True,
123
+ stderr=subprocess.STDOUT)
124
+ if retcode != 0:
125
+ return False
126
+ return True
127
+
128
+
129
+ def run_test(label, routines, numerical_tests, language, commands, friendly=True):
130
+ """A driver for the codegen tests.
131
+
132
+ This driver assumes that a compiler ifort is present in the PATH and that
133
+ ifort is (at least) a Fortran 90 compiler. The generated code is written in
134
+ a temporary directory, together with a main program that validates the
135
+ generated code. The test passes when the compilation and the validation
136
+ run correctly.
137
+ """
138
+
139
+ # Check input arguments before touching the file system
140
+ language = language.upper()
141
+ assert language in main_template
142
+ assert language in numerical_test_template
143
+
144
+ # Check that environment variable makes sense
145
+ clean = os.getenv('SYMPY_TEST_CLEAN_TEMP', 'always').lower()
146
+ if clean not in ('always', 'success', 'never'):
147
+ raise ValueError("SYMPY_TEST_CLEAN_TEMP must be one of the following: 'always', 'success' or 'never'.")
148
+
149
+ # Do all the magic to compile, run and validate the test code
150
+ # 1) prepare the temporary working directory, switch to that dir
151
+ work = tempfile.mkdtemp("_sympy_%s_test" % language, "%s_" % label)
152
+ oldwork = os.getcwd()
153
+ os.chdir(work)
154
+
155
+ # 2) write the generated code
156
+ if friendly:
157
+ # interpret the routines as a name_expr list and call the friendly
158
+ # function codegen
159
+ codegen(routines, language, "codegen", to_files=True)
160
+ else:
161
+ code_gen = get_code_generator(language, "codegen")
162
+ code_gen.write(routines, "codegen", to_files=True)
163
+
164
+ # 3) write a simple main program that links to the generated code, and that
165
+ # includes the numerical tests
166
+ test_strings = []
167
+ for fn_name, args, expected, threshold in numerical_tests:
168
+ call_string = "%s(%s)-(%s)" % (
169
+ fn_name, ",".join(str(arg) for arg in args), expected)
170
+ if language == "F95":
171
+ call_string = fortranize_double_constants(call_string)
172
+ threshold = fortranize_double_constants(str(threshold))
173
+ test_strings.append(numerical_test_template[language] % {
174
+ "call": call_string,
175
+ "threshold": threshold,
176
+ })
177
+
178
+ if language == "F95":
179
+ f_name = "main.f90"
180
+ elif language.startswith("C"):
181
+ f_name = "main.c"
182
+ else:
183
+ raise NotImplementedError(
184
+ "FIXME: filename extension unknown for language: %s" % language)
185
+
186
+ Path(f_name).write_text(
187
+ main_template[language] % {'statements': "".join(test_strings)})
188
+
189
+ # 4) Compile and link
190
+ compiled = try_run(commands)
191
+
192
+ # 5) Run if compiled
193
+ if compiled:
194
+ executed = try_run(["./test.exe"])
195
+ else:
196
+ executed = False
197
+
198
+ # 6) Clean up stuff
199
+ if clean == 'always' or (clean == 'success' and compiled and executed):
200
+ def safe_remove(filename):
201
+ if os.path.isfile(filename):
202
+ os.remove(filename)
203
+ safe_remove("codegen.f90")
204
+ safe_remove("codegen.c")
205
+ safe_remove("codegen.h")
206
+ safe_remove("codegen.o")
207
+ safe_remove("main.f90")
208
+ safe_remove("main.c")
209
+ safe_remove("main.o")
210
+ safe_remove("test.exe")
211
+ os.chdir(oldwork)
212
+ os.rmdir(work)
213
+ else:
214
+ print("TEST NOT REMOVED: %s" % work, file=sys.stderr)
215
+ os.chdir(oldwork)
216
+
217
+ # 7) Do the assertions in the end
218
+ assert compiled, "failed to compile %s code with:\n%s" % (
219
+ language, "\n".join(commands))
220
+ assert executed, "failed to execute %s code from:\n%s" % (
221
+ language, "\n".join(commands))
222
+
223
+
224
+ def fortranize_double_constants(code_string):
225
+ """
226
+ Replaces every literal float with literal doubles
227
+ """
228
+ import re
229
+ pattern_exp = re.compile(r'\d+(\.)?\d*[eE]-?\d+')
230
+ pattern_float = re.compile(r'\d+\.\d*(?!\d*d)')
231
+
232
+ def subs_exp(matchobj):
233
+ return re.sub('[eE]', 'd', matchobj.group(0))
234
+
235
+ def subs_float(matchobj):
236
+ return "%sd0" % matchobj.group(0)
237
+
238
+ code_string = pattern_exp.sub(subs_exp, code_string)
239
+ code_string = pattern_float.sub(subs_float, code_string)
240
+
241
+ return code_string
242
+
243
+
244
+ def is_feasible(language, commands):
245
+ # This test should always work, otherwise the compiler is not present.
246
+ routine = make_routine("test", x)
247
+ numerical_tests = [
248
+ ("test", ( 1.0,), 1.0, 1e-15),
249
+ ("test", (-1.0,), -1.0, 1e-15),
250
+ ]
251
+ try:
252
+ run_test("is_feasible", [routine], numerical_tests, language, commands,
253
+ friendly=False)
254
+ return True
255
+ except AssertionError:
256
+ return False
257
+
258
+ valid_lang_commands = []
259
+ invalid_lang_compilers = []
260
+ for lang, compiler in combinations_lang_compiler:
261
+ commands = compile_commands[compiler]
262
+ if is_feasible(lang, commands):
263
+ valid_lang_commands.append((lang, commands))
264
+ else:
265
+ invalid_lang_compilers.append((lang, compiler))
266
+
267
+ # We test all language-compiler combinations, just to report what is skipped
268
+
269
+ def test_C89_cc():
270
+ if ("C89", 'cc') in invalid_lang_compilers:
271
+ skip("`cc' command didn't work as expected (C89)")
272
+
273
+
274
+ def test_C99_cc():
275
+ if ("C99", 'cc') in invalid_lang_compilers:
276
+ skip("`cc' command didn't work as expected (C99)")
277
+
278
+
279
+ def test_F95_ifort():
280
+ if ("F95", 'ifort') in invalid_lang_compilers:
281
+ skip("`ifort' command didn't work as expected")
282
+
283
+
284
+ def test_F95_gfortran():
285
+ if ("F95", 'gfortran') in invalid_lang_compilers:
286
+ skip("`gfortran' command didn't work as expected")
287
+
288
+
289
+ def test_F95_g95():
290
+ if ("F95", 'g95') in invalid_lang_compilers:
291
+ skip("`g95' command didn't work as expected")
292
+
293
+ # Here comes the actual tests
294
+
295
+
296
+ def test_basic_codegen():
297
+ numerical_tests = [
298
+ ("test", (1.0, 6.0, 3.0), 21.0, 1e-15),
299
+ ("test", (-1.0, 2.0, -2.5), -2.5, 1e-15),
300
+ ]
301
+ name_expr = [("test", (x + y)*z)]
302
+ for lang, commands in valid_lang_commands:
303
+ run_test("basic_codegen", name_expr, numerical_tests, lang, commands)
304
+
305
+
306
+ def test_intrinsic_math1_codegen():
307
+ # not included: log10
308
+ from sympy.core.evalf import N
309
+ from sympy.functions import ln
310
+ from sympy.functions.elementary.exponential import log
311
+ from sympy.functions.elementary.hyperbolic import (cosh, sinh, tanh)
312
+ from sympy.functions.elementary.integers import (ceiling, floor)
313
+ from sympy.functions.elementary.miscellaneous import sqrt
314
+ from sympy.functions.elementary.trigonometric import (acos, asin, atan, cos, sin, tan)
315
+ name_expr = [
316
+ ("test_fabs", abs(x)),
317
+ ("test_acos", acos(x)),
318
+ ("test_asin", asin(x)),
319
+ ("test_atan", atan(x)),
320
+ ("test_cos", cos(x)),
321
+ ("test_cosh", cosh(x)),
322
+ ("test_log", log(x)),
323
+ ("test_ln", ln(x)),
324
+ ("test_sin", sin(x)),
325
+ ("test_sinh", sinh(x)),
326
+ ("test_sqrt", sqrt(x)),
327
+ ("test_tan", tan(x)),
328
+ ("test_tanh", tanh(x)),
329
+ ]
330
+ numerical_tests = []
331
+ for name, expr in name_expr:
332
+ for xval in 0.2, 0.5, 0.8:
333
+ expected = N(expr.subs(x, xval))
334
+ numerical_tests.append((name, (xval,), expected, 1e-14))
335
+ for lang, commands in valid_lang_commands:
336
+ if lang.startswith("C"):
337
+ name_expr_C = [("test_floor", floor(x)), ("test_ceil", ceiling(x))]
338
+ else:
339
+ name_expr_C = []
340
+ run_test("intrinsic_math1", name_expr + name_expr_C,
341
+ numerical_tests, lang, commands)
342
+
343
+
344
+ def test_instrinsic_math2_codegen():
345
+ # not included: frexp, ldexp, modf, fmod
346
+ from sympy.core.evalf import N
347
+ from sympy.functions.elementary.trigonometric import atan2
348
+ name_expr = [
349
+ ("test_atan2", atan2(x, y)),
350
+ ("test_pow", x**y),
351
+ ]
352
+ numerical_tests = []
353
+ for name, expr in name_expr:
354
+ for xval, yval in (0.2, 1.3), (0.5, -0.2), (0.8, 0.8):
355
+ expected = N(expr.subs(x, xval).subs(y, yval))
356
+ numerical_tests.append((name, (xval, yval), expected, 1e-14))
357
+ for lang, commands in valid_lang_commands:
358
+ run_test("intrinsic_math2", name_expr, numerical_tests, lang, commands)
359
+
360
+
361
+ def test_complicated_codegen():
362
+ from sympy.core.evalf import N
363
+ from sympy.functions.elementary.trigonometric import (cos, sin, tan)
364
+ name_expr = [
365
+ ("test1", ((sin(x) + cos(y) + tan(z))**7).expand()),
366
+ ("test2", cos(cos(cos(cos(cos(cos(cos(cos(x + y + z))))))))),
367
+ ]
368
+ numerical_tests = []
369
+ for name, expr in name_expr:
370
+ for xval, yval, zval in (0.2, 1.3, -0.3), (0.5, -0.2, 0.0), (0.8, 2.1, 0.8):
371
+ expected = N(expr.subs(x, xval).subs(y, yval).subs(z, zval))
372
+ numerical_tests.append((name, (xval, yval, zval), expected, 1e-12))
373
+ for lang, commands in valid_lang_commands:
374
+ run_test(
375
+ "complicated_codegen", name_expr, numerical_tests, lang, commands)
.venv/lib/python3.13/site-packages/sympy/external/tests/test_gmpy.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sympy.external.gmpy import LONG_MAX, iroot
2
+ from sympy.testing.pytest import raises
3
+
4
+
5
+ def test_iroot():
6
+ assert iroot(2, LONG_MAX) == (1, False)
7
+ assert iroot(2, LONG_MAX + 1) == (1, False)
8
+ for x in range(3):
9
+ assert iroot(x, 1) == (x, True)
10
+ raises(ValueError, lambda: iroot(-1, 1))
11
+ raises(ValueError, lambda: iroot(0, 0))
12
+ raises(ValueError, lambda: iroot(0, -1))
.venv/lib/python3.13/site-packages/sympy/external/tests/test_importtools.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from sympy.external import import_module
2
+ from sympy.testing.pytest import warns
3
+
4
+ # fixes issue that arose in addressing issue 6533
5
+ def test_no_stdlib_collections():
6
+ '''
7
+ make sure we get the right collections when it is not part of a
8
+ larger list
9
+ '''
10
+ import collections
11
+ matplotlib = import_module('matplotlib',
12
+ import_kwargs={'fromlist': ['cm', 'collections']},
13
+ min_module_version='1.1.0', catch=(RuntimeError,))
14
+ if matplotlib:
15
+ assert collections != matplotlib.collections
16
+
17
+ def test_no_stdlib_collections2():
18
+ '''
19
+ make sure we get the right collections when it is not part of a
20
+ larger list
21
+ '''
22
+ import collections
23
+ matplotlib = import_module('matplotlib',
24
+ import_kwargs={'fromlist': ['collections']},
25
+ min_module_version='1.1.0', catch=(RuntimeError,))
26
+ if matplotlib:
27
+ assert collections != matplotlib.collections
28
+
29
+ def test_no_stdlib_collections3():
30
+ '''make sure we get the right collections with no catch'''
31
+ import collections
32
+ matplotlib = import_module('matplotlib',
33
+ import_kwargs={'fromlist': ['cm', 'collections']},
34
+ min_module_version='1.1.0')
35
+ if matplotlib:
36
+ assert collections != matplotlib.collections
37
+
38
+ def test_min_module_version_python3_basestring_error():
39
+ with warns(UserWarning):
40
+ import_module('mpmath', min_module_version='1000.0.1')
.venv/lib/python3.13/site-packages/sympy/external/tests/test_ntheory.py ADDED
@@ -0,0 +1,307 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from itertools import permutations
2
+
3
+ from sympy.external.ntheory import (bit_scan1, remove, bit_scan0, is_fermat_prp,
4
+ is_euler_prp, is_strong_prp, gcdext, _lucas_sequence,
5
+ is_fibonacci_prp, is_lucas_prp, is_selfridge_prp,
6
+ is_strong_lucas_prp, is_strong_selfridge_prp,
7
+ is_bpsw_prp, is_strong_bpsw_prp)
8
+ from sympy.testing.pytest import raises
9
+
10
+
11
+ def test_bit_scan1():
12
+ assert bit_scan1(0) is None
13
+ assert bit_scan1(1) == 0
14
+ assert bit_scan1(-1) == 0
15
+ assert bit_scan1(2) == 1
16
+ assert bit_scan1(7) == 0
17
+ assert bit_scan1(-7) == 0
18
+ for i in range(100):
19
+ assert bit_scan1(1 << i) == i
20
+ assert bit_scan1((1 << i) * 31337) == i
21
+ for i in range(500):
22
+ n = (1 << 500) + (1 << i)
23
+ assert bit_scan1(n) == i
24
+ assert bit_scan1(1 << 1000001) == 1000001
25
+ assert bit_scan1((1 << 273956)*7**37) == 273956
26
+ # issue 12709
27
+ for i in range(1, 10):
28
+ big = 1 << i
29
+ assert bit_scan1(-big) == bit_scan1(big)
30
+
31
+
32
+ def test_bit_scan0():
33
+ assert bit_scan0(-1) is None
34
+ assert bit_scan0(0) == 0
35
+ assert bit_scan0(1) == 1
36
+ assert bit_scan0(-2) == 0
37
+
38
+
39
+ def test_remove():
40
+ raises(ValueError, lambda: remove(1, 1))
41
+ assert remove(0, 3) == (0, 0)
42
+ for f in range(2, 10):
43
+ for y in range(2, 1000):
44
+ for z in [1, 17, 101, 1009]:
45
+ assert remove(z*f**y, f) == (z, y)
46
+
47
+
48
+ def test_gcdext():
49
+ assert gcdext(0, 0) == (0, 0, 0)
50
+ assert gcdext(3, 0) == (3, 1, 0)
51
+ assert gcdext(0, 4) == (4, 0, 1)
52
+
53
+ for n in range(1, 10):
54
+ assert gcdext(n, 1) == gcdext(-n, 1) == (1, 0, 1)
55
+ assert gcdext(n, -1) == gcdext(-n, -1) == (1, 0, -1)
56
+ assert gcdext(n, n) == gcdext(-n, n) == (n, 0, 1)
57
+ assert gcdext(n, -n) == gcdext(-n, -n) == (n, 0, -1)
58
+
59
+ for n in range(2, 10):
60
+ assert gcdext(1, n) == gcdext(1, -n) == (1, 1, 0)
61
+ assert gcdext(-1, n) == gcdext(-1, -n) == (1, -1, 0)
62
+
63
+ for a, b in permutations([2**5, 3, 5, 7**2, 11], 2):
64
+ g, x, y = gcdext(a, b)
65
+ assert g == a*x + b*y == 1
66
+
67
+
68
+ def test_is_fermat_prp():
69
+ # invalid input
70
+ raises(ValueError, lambda: is_fermat_prp(0, 10))
71
+ raises(ValueError, lambda: is_fermat_prp(5, 1))
72
+
73
+ # n = 1
74
+ assert not is_fermat_prp(1, 3)
75
+
76
+ # n is prime
77
+ assert is_fermat_prp(2, 4)
78
+ assert is_fermat_prp(3, 2)
79
+ assert is_fermat_prp(11, 3)
80
+ assert is_fermat_prp(2**31-1, 5)
81
+
82
+ # A001567
83
+ pseudorpime = [341, 561, 645, 1105, 1387, 1729, 1905, 2047,
84
+ 2465, 2701, 2821, 3277, 4033, 4369, 4371, 4681]
85
+ for n in pseudorpime:
86
+ assert is_fermat_prp(n, 2)
87
+
88
+ # A020136
89
+ pseudorpime = [15, 85, 91, 341, 435, 451, 561, 645, 703, 1105,
90
+ 1247, 1271, 1387, 1581, 1695, 1729, 1891, 1905]
91
+ for n in pseudorpime:
92
+ assert is_fermat_prp(n, 4)
93
+
94
+
95
+ def test_is_euler_prp():
96
+ # invalid input
97
+ raises(ValueError, lambda: is_euler_prp(0, 10))
98
+ raises(ValueError, lambda: is_euler_prp(5, 1))
99
+
100
+ # n = 1
101
+ assert not is_euler_prp(1, 3)
102
+
103
+ # n is prime
104
+ assert is_euler_prp(2, 4)
105
+ assert is_euler_prp(3, 2)
106
+ assert is_euler_prp(11, 3)
107
+ assert is_euler_prp(2**31-1, 5)
108
+
109
+ # A047713
110
+ pseudorpime = [561, 1105, 1729, 1905, 2047, 2465, 3277, 4033,
111
+ 4681, 6601, 8321, 8481, 10585, 12801, 15841]
112
+ for n in pseudorpime:
113
+ assert is_euler_prp(n, 2)
114
+
115
+ # A048950
116
+ pseudorpime = [121, 703, 1729, 1891, 2821, 3281, 7381, 8401,
117
+ 8911, 10585, 12403, 15457, 15841, 16531, 18721]
118
+ for n in pseudorpime:
119
+ assert is_euler_prp(n, 3)
120
+
121
+
122
+ def test_is_strong_prp():
123
+ # invalid input
124
+ raises(ValueError, lambda: is_strong_prp(0, 10))
125
+ raises(ValueError, lambda: is_strong_prp(5, 1))
126
+
127
+ # n = 1
128
+ assert not is_strong_prp(1, 3)
129
+
130
+ # n is prime
131
+ assert is_strong_prp(2, 4)
132
+ assert is_strong_prp(3, 2)
133
+ assert is_strong_prp(11, 3)
134
+ assert is_strong_prp(2**31-1, 5)
135
+
136
+ # A001262
137
+ pseudorpime = [2047, 3277, 4033, 4681, 8321, 15841, 29341,
138
+ 42799, 49141, 52633, 65281, 74665, 80581]
139
+ for n in pseudorpime:
140
+ assert is_strong_prp(n, 2)
141
+
142
+ # A020229
143
+ pseudorpime = [121, 703, 1891, 3281, 8401, 8911, 10585, 12403,
144
+ 16531, 18721, 19345, 23521, 31621, 44287, 47197]
145
+ for n in pseudorpime:
146
+ assert is_strong_prp(n, 3)
147
+
148
+
149
+ def test_lucas_sequence():
150
+ def lucas_u(P, Q, length):
151
+ array = [0] * length
152
+ array[1] = 1
153
+ for k in range(2, length):
154
+ array[k] = P * array[k - 1] - Q * array[k - 2]
155
+ return array
156
+
157
+ def lucas_v(P, Q, length):
158
+ array = [0] * length
159
+ array[0] = 2
160
+ array[1] = P
161
+ for k in range(2, length):
162
+ array[k] = P * array[k - 1] - Q * array[k - 2]
163
+ return array
164
+
165
+ length = 20
166
+ for P in range(-10, 10):
167
+ for Q in range(-10, 10):
168
+ D = P**2 - 4*Q
169
+ if D == 0:
170
+ continue
171
+ us = lucas_u(P, Q, length)
172
+ vs = lucas_v(P, Q, length)
173
+ for n in range(3, 100, 2):
174
+ for k in range(length):
175
+ U, V, Qk = _lucas_sequence(n, P, Q, k)
176
+ assert U == us[k] % n
177
+ assert V == vs[k] % n
178
+ assert pow(Q, k, n) == Qk
179
+
180
+
181
+ def test_is_fibonacci_prp():
182
+ # invalid input
183
+ raises(ValueError, lambda: is_fibonacci_prp(3, 2, 1))
184
+ raises(ValueError, lambda: is_fibonacci_prp(3, -5, 1))
185
+ raises(ValueError, lambda: is_fibonacci_prp(3, 5, 2))
186
+ raises(ValueError, lambda: is_fibonacci_prp(0, 5, -1))
187
+
188
+ # n = 1
189
+ assert not is_fibonacci_prp(1, 3, 1)
190
+
191
+ # n is prime
192
+ assert is_fibonacci_prp(2, 5, 1)
193
+ assert is_fibonacci_prp(3, 6, -1)
194
+ assert is_fibonacci_prp(11, 7, 1)
195
+ assert is_fibonacci_prp(2**31-1, 8, -1)
196
+
197
+ # A005845
198
+ pseudorpime = [705, 2465, 2737, 3745, 4181, 5777, 6721,
199
+ 10877, 13201, 15251, 24465, 29281, 34561]
200
+ for n in pseudorpime:
201
+ assert is_fibonacci_prp(n, 1, -1)
202
+
203
+
204
+ def test_is_lucas_prp():
205
+ # invalid input
206
+ raises(ValueError, lambda: is_lucas_prp(3, 2, 1))
207
+ raises(ValueError, lambda: is_lucas_prp(0, 5, -1))
208
+ raises(ValueError, lambda: is_lucas_prp(15, 3, 1))
209
+
210
+ # n = 1
211
+ assert not is_lucas_prp(1, 3, 1)
212
+
213
+ # n is prime
214
+ assert is_lucas_prp(2, 5, 2)
215
+ assert is_lucas_prp(3, 6, -1)
216
+ assert is_lucas_prp(11, 7, 5)
217
+ assert is_lucas_prp(2**31-1, 8, -3)
218
+
219
+ # A081264
220
+ pseudorpime = [323, 377, 1891, 3827, 4181, 5777, 6601, 6721,
221
+ 8149, 10877, 11663, 13201, 13981, 15251, 17119]
222
+ for n in pseudorpime:
223
+ assert is_lucas_prp(n, 1, -1)
224
+
225
+
226
+ def test_is_selfridge_prp():
227
+ # invalid input
228
+ raises(ValueError, lambda: is_selfridge_prp(0))
229
+
230
+ # n = 1
231
+ assert not is_selfridge_prp(1)
232
+
233
+ # n is prime
234
+ assert is_selfridge_prp(2)
235
+ assert is_selfridge_prp(3)
236
+ assert is_selfridge_prp(11)
237
+ assert is_selfridge_prp(2**31-1)
238
+
239
+ # A217120
240
+ pseudorpime = [323, 377, 1159, 1829, 3827, 5459, 5777, 9071,
241
+ 9179, 10877, 11419, 11663, 13919, 14839, 16109]
242
+ for n in pseudorpime:
243
+ assert is_selfridge_prp(n)
244
+
245
+
246
+ def test_is_strong_lucas_prp():
247
+ # invalid input
248
+ raises(ValueError, lambda: is_strong_lucas_prp(3, 2, 1))
249
+ raises(ValueError, lambda: is_strong_lucas_prp(0, 5, -1))
250
+ raises(ValueError, lambda: is_strong_lucas_prp(15, 3, 1))
251
+
252
+ # n = 1
253
+ assert not is_strong_lucas_prp(1, 3, 1)
254
+
255
+ # n is prime
256
+ assert is_strong_lucas_prp(2, 5, 2)
257
+ assert is_strong_lucas_prp(3, 6, -1)
258
+ assert is_strong_lucas_prp(11, 7, 5)
259
+ assert is_strong_lucas_prp(2**31-1, 8, -3)
260
+
261
+
262
+ def test_is_strong_selfridge_prp():
263
+ # invalid input
264
+ raises(ValueError, lambda: is_strong_selfridge_prp(0))
265
+
266
+ # n = 1
267
+ assert not is_strong_selfridge_prp(1)
268
+
269
+ # n is prime
270
+ assert is_strong_selfridge_prp(2)
271
+ assert is_strong_selfridge_prp(3)
272
+ assert is_strong_selfridge_prp(11)
273
+ assert is_strong_selfridge_prp(2**31-1)
274
+
275
+ # A217255
276
+ pseudorpime = [5459, 5777, 10877, 16109, 18971, 22499, 24569,
277
+ 25199, 40309, 58519, 75077, 97439, 100127, 113573]
278
+ for n in pseudorpime:
279
+ assert is_strong_selfridge_prp(n)
280
+
281
+
282
+ def test_is_bpsw_prp():
283
+ # invalid input
284
+ raises(ValueError, lambda: is_bpsw_prp(0))
285
+
286
+ # n = 1
287
+ assert not is_bpsw_prp(1)
288
+
289
+ # n is prime
290
+ assert is_bpsw_prp(2)
291
+ assert is_bpsw_prp(3)
292
+ assert is_bpsw_prp(11)
293
+ assert is_bpsw_prp(2**31-1)
294
+
295
+
296
+ def test_is_strong_bpsw_prp():
297
+ # invalid input
298
+ raises(ValueError, lambda: is_strong_bpsw_prp(0))
299
+
300
+ # n = 1
301
+ assert not is_strong_bpsw_prp(1)
302
+
303
+ # n is prime
304
+ assert is_strong_bpsw_prp(2)
305
+ assert is_strong_bpsw_prp(3)
306
+ assert is_strong_bpsw_prp(11)
307
+ assert is_strong_bpsw_prp(2**31-1)
.venv/lib/python3.13/site-packages/sympy/external/tests/test_numpy.py ADDED
@@ -0,0 +1,335 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This testfile tests SymPy <-> NumPy compatibility
2
+
3
+ # Don't test any SymPy features here. Just pure interaction with NumPy.
4
+ # Always write regular SymPy tests for anything, that can be tested in pure
5
+ # Python (without numpy). Here we test everything, that a user may need when
6
+ # using SymPy with NumPy
7
+ from sympy.external.importtools import version_tuple
8
+ from sympy.external import import_module
9
+
10
+ numpy = import_module('numpy')
11
+ if numpy:
12
+ array, matrix, ndarray = numpy.array, numpy.matrix, numpy.ndarray
13
+ else:
14
+ #bin/test will not execute any tests now
15
+ disabled = True
16
+
17
+
18
+ from sympy.core.numbers import (Float, Integer, Rational)
19
+ from sympy.core.symbol import (Symbol, symbols)
20
+ from sympy.functions.elementary.trigonometric import sin
21
+ from sympy.matrices.dense import (Matrix, list2numpy, matrix2numpy, symarray)
22
+ from sympy.utilities.lambdify import lambdify
23
+ import sympy
24
+
25
+ import mpmath
26
+ from sympy.abc import x, y, z
27
+ from sympy.utilities.decorator import conserve_mpmath_dps
28
+ from sympy.utilities.exceptions import ignore_warnings
29
+ from sympy.testing.pytest import raises
30
+
31
+
32
+ # first, systematically check, that all operations are implemented and don't
33
+ # raise an exception
34
+
35
+
36
+ def test_systematic_basic():
37
+ def s(sympy_object, numpy_array):
38
+ _ = [sympy_object + numpy_array,
39
+ numpy_array + sympy_object,
40
+ sympy_object - numpy_array,
41
+ numpy_array - sympy_object,
42
+ sympy_object * numpy_array,
43
+ numpy_array * sympy_object,
44
+ sympy_object / numpy_array,
45
+ numpy_array / sympy_object,
46
+ sympy_object ** numpy_array,
47
+ numpy_array ** sympy_object]
48
+ x = Symbol("x")
49
+ y = Symbol("y")
50
+ sympy_objs = [
51
+ Rational(2, 3),
52
+ Float("1.3"),
53
+ x,
54
+ y,
55
+ pow(x, y)*y,
56
+ Integer(5),
57
+ Float(5.5),
58
+ ]
59
+ numpy_objs = [
60
+ array([1]),
61
+ array([3, 8, -1]),
62
+ array([x, x**2, Rational(5)]),
63
+ array([x/y*sin(y), 5, Rational(5)]),
64
+ ]
65
+ for x in sympy_objs:
66
+ for y in numpy_objs:
67
+ s(x, y)
68
+
69
+
70
+ # now some random tests, that test particular problems and that also
71
+ # check that the results of the operations are correct
72
+
73
+ def test_basics():
74
+ one = Rational(1)
75
+ zero = Rational(0)
76
+ assert array(1) == array(one)
77
+ assert array([one]) == array([one])
78
+ assert array([x]) == array([x])
79
+ assert array(x) == array(Symbol("x"))
80
+ assert array(one + x) == array(1 + x)
81
+
82
+ X = array([one, zero, zero])
83
+ assert (X == array([one, zero, zero])).all()
84
+ assert (X == array([one, 0, 0])).all()
85
+
86
+
87
+ def test_arrays():
88
+ one = Rational(1)
89
+ zero = Rational(0)
90
+ X = array([one, zero, zero])
91
+ Y = one*X
92
+ X = array([Symbol("a") + Rational(1, 2)])
93
+ Y = X + X
94
+ assert Y == array([1 + 2*Symbol("a")])
95
+ Y = Y + 1
96
+ assert Y == array([2 + 2*Symbol("a")])
97
+ Y = X - X
98
+ assert Y == array([0])
99
+
100
+
101
+ def test_conversion1():
102
+ a = list2numpy([x**2, x])
103
+ #looks like an array?
104
+ assert isinstance(a, ndarray)
105
+ assert a[0] == x**2
106
+ assert a[1] == x
107
+ assert len(a) == 2
108
+ #yes, it's the array
109
+
110
+
111
+ def test_conversion2():
112
+ a = 2*list2numpy([x**2, x])
113
+ b = list2numpy([2*x**2, 2*x])
114
+ assert (a == b).all()
115
+
116
+ one = Rational(1)
117
+ zero = Rational(0)
118
+ X = list2numpy([one, zero, zero])
119
+ Y = one*X
120
+ X = list2numpy([Symbol("a") + Rational(1, 2)])
121
+ Y = X + X
122
+ assert Y == array([1 + 2*Symbol("a")])
123
+ Y = Y + 1
124
+ assert Y == array([2 + 2*Symbol("a")])
125
+ Y = X - X
126
+ assert Y == array([0])
127
+
128
+
129
+ def test_list2numpy():
130
+ assert (array([x**2, x]) == list2numpy([x**2, x])).all()
131
+
132
+
133
+ def test_Matrix1():
134
+ m = Matrix([[x, x**2], [5, 2/x]])
135
+ assert (array(m.subs(x, 2)) == array([[2, 4], [5, 1]])).all()
136
+ m = Matrix([[sin(x), x**2], [5, 2/x]])
137
+ assert (array(m.subs(x, 2)) == array([[sin(2), 4], [5, 1]])).all()
138
+
139
+
140
+ def test_Matrix2():
141
+ m = Matrix([[x, x**2], [5, 2/x]])
142
+ with ignore_warnings(PendingDeprecationWarning):
143
+ assert (matrix(m.subs(x, 2)) == matrix([[2, 4], [5, 1]])).all()
144
+ m = Matrix([[sin(x), x**2], [5, 2/x]])
145
+ with ignore_warnings(PendingDeprecationWarning):
146
+ assert (matrix(m.subs(x, 2)) == matrix([[sin(2), 4], [5, 1]])).all()
147
+
148
+
149
+ def test_Matrix3():
150
+ a = array([[2, 4], [5, 1]])
151
+ assert Matrix(a) == Matrix([[2, 4], [5, 1]])
152
+ assert Matrix(a) != Matrix([[2, 4], [5, 2]])
153
+ a = array([[sin(2), 4], [5, 1]])
154
+ assert Matrix(a) == Matrix([[sin(2), 4], [5, 1]])
155
+ assert Matrix(a) != Matrix([[sin(0), 4], [5, 1]])
156
+
157
+
158
+ def test_Matrix4():
159
+ with ignore_warnings(PendingDeprecationWarning):
160
+ a = matrix([[2, 4], [5, 1]])
161
+ assert Matrix(a) == Matrix([[2, 4], [5, 1]])
162
+ assert Matrix(a) != Matrix([[2, 4], [5, 2]])
163
+ with ignore_warnings(PendingDeprecationWarning):
164
+ a = matrix([[sin(2), 4], [5, 1]])
165
+ assert Matrix(a) == Matrix([[sin(2), 4], [5, 1]])
166
+ assert Matrix(a) != Matrix([[sin(0), 4], [5, 1]])
167
+
168
+
169
+ def test_Matrix_sum():
170
+ M = Matrix([[1, 2, 3], [x, y, x], [2*y, -50, z*x]])
171
+ with ignore_warnings(PendingDeprecationWarning):
172
+ m = matrix([[2, 3, 4], [x, 5, 6], [x, y, z**2]])
173
+ assert M + m == Matrix([[3, 5, 7], [2*x, y + 5, x + 6], [2*y + x, y - 50, z*x + z**2]])
174
+ assert m + M == Matrix([[3, 5, 7], [2*x, y + 5, x + 6], [2*y + x, y - 50, z*x + z**2]])
175
+ assert M + m == M.add(m)
176
+
177
+
178
+ def test_Matrix_mul():
179
+ M = Matrix([[1, 2, 3], [x, y, x]])
180
+ with ignore_warnings(PendingDeprecationWarning):
181
+ m = matrix([[2, 4], [x, 6], [x, z**2]])
182
+ assert M*m == Matrix([
183
+ [ 2 + 5*x, 16 + 3*z**2],
184
+ [2*x + x*y + x**2, 4*x + 6*y + x*z**2],
185
+ ])
186
+
187
+ assert m*M == Matrix([
188
+ [ 2 + 4*x, 4 + 4*y, 6 + 4*x],
189
+ [ 7*x, 2*x + 6*y, 9*x],
190
+ [x + x*z**2, 2*x + y*z**2, 3*x + x*z**2],
191
+ ])
192
+ a = array([2])
193
+ assert a[0] * M == 2 * M
194
+ assert M * a[0] == 2 * M
195
+
196
+
197
+ def test_Matrix_array():
198
+ class matarray:
199
+ def __array__(self, dtype=object, copy=None):
200
+ if copy is not None and not copy:
201
+ raise TypeError("Cannot implement copy=False when converting Matrix to ndarray")
202
+ from numpy import array
203
+ return array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
204
+ matarr = matarray()
205
+ assert Matrix(matarr) == Matrix([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
206
+
207
+
208
+ def test_matrix2numpy():
209
+ a = matrix2numpy(Matrix([[1, x**2], [3*sin(x), 0]]))
210
+ assert isinstance(a, ndarray)
211
+ assert a.shape == (2, 2)
212
+ assert a[0, 0] == 1
213
+ assert a[0, 1] == x**2
214
+ assert a[1, 0] == 3*sin(x)
215
+ assert a[1, 1] == 0
216
+
217
+
218
+ def test_matrix2numpy_conversion():
219
+ a = Matrix([[1, 2, sin(x)], [x**2, x, Rational(1, 2)]])
220
+ b = array([[1, 2, sin(x)], [x**2, x, Rational(1, 2)]])
221
+ assert (matrix2numpy(a) == b).all()
222
+ assert matrix2numpy(a).dtype == numpy.dtype('object')
223
+
224
+ c = matrix2numpy(Matrix([[1, 2], [10, 20]]), dtype='int8')
225
+ d = matrix2numpy(Matrix([[1, 2], [10, 20]]), dtype='float64')
226
+ assert c.dtype == numpy.dtype('int8')
227
+ assert d.dtype == numpy.dtype('float64')
228
+
229
+
230
+ def test_issue_3728():
231
+ assert (Rational(1, 2)*array([2*x, 0]) == array([x, 0])).all()
232
+ assert (Rational(1, 2) + array(
233
+ [2*x, 0]) == array([2*x + Rational(1, 2), Rational(1, 2)])).all()
234
+ assert (Float("0.5")*array([2*x, 0]) == array([Float("1.0")*x, 0])).all()
235
+ assert (Float("0.5") + array(
236
+ [2*x, 0]) == array([2*x + Float("0.5"), Float("0.5")])).all()
237
+
238
+
239
+ @conserve_mpmath_dps
240
+ def test_lambdify():
241
+ mpmath.mp.dps = 16
242
+ sin02 = mpmath.mpf("0.198669330795061215459412627")
243
+ f = lambdify(x, sin(x), "numpy")
244
+ prec = 1e-15
245
+ assert -prec < f(0.2) - sin02 < prec
246
+
247
+ # if this succeeds, it can't be a numpy function
248
+
249
+ if version_tuple(numpy.__version__) >= version_tuple('1.17'):
250
+ with raises(TypeError):
251
+ f(x)
252
+ else:
253
+ with raises(AttributeError):
254
+ f(x)
255
+
256
+
257
+ def test_lambdify_matrix():
258
+ f = lambdify(x, Matrix([[x, 2*x], [1, 2]]), [{'ImmutableMatrix': numpy.array}, "numpy"])
259
+ assert (f(1) == array([[1, 2], [1, 2]])).all()
260
+
261
+
262
+ def test_lambdify_matrix_multi_input():
263
+ M = sympy.Matrix([[x**2, x*y, x*z],
264
+ [y*x, y**2, y*z],
265
+ [z*x, z*y, z**2]])
266
+ f = lambdify((x, y, z), M, [{'ImmutableMatrix': numpy.array}, "numpy"])
267
+
268
+ xh, yh, zh = 1.0, 2.0, 3.0
269
+ expected = array([[xh**2, xh*yh, xh*zh],
270
+ [yh*xh, yh**2, yh*zh],
271
+ [zh*xh, zh*yh, zh**2]])
272
+ actual = f(xh, yh, zh)
273
+ assert numpy.allclose(actual, expected)
274
+
275
+
276
+ def test_lambdify_matrix_vec_input():
277
+ X = sympy.DeferredVector('X')
278
+ M = Matrix([
279
+ [X[0]**2, X[0]*X[1], X[0]*X[2]],
280
+ [X[1]*X[0], X[1]**2, X[1]*X[2]],
281
+ [X[2]*X[0], X[2]*X[1], X[2]**2]])
282
+ f = lambdify(X, M, [{'ImmutableMatrix': numpy.array}, "numpy"])
283
+
284
+ Xh = array([1.0, 2.0, 3.0])
285
+ expected = array([[Xh[0]**2, Xh[0]*Xh[1], Xh[0]*Xh[2]],
286
+ [Xh[1]*Xh[0], Xh[1]**2, Xh[1]*Xh[2]],
287
+ [Xh[2]*Xh[0], Xh[2]*Xh[1], Xh[2]**2]])
288
+ actual = f(Xh)
289
+ assert numpy.allclose(actual, expected)
290
+
291
+
292
+ def test_lambdify_transl():
293
+ from sympy.utilities.lambdify import NUMPY_TRANSLATIONS
294
+ for sym, mat in NUMPY_TRANSLATIONS.items():
295
+ assert sym in sympy.__dict__
296
+ assert mat in numpy.__dict__
297
+
298
+
299
+ def test_symarray():
300
+ """Test creation of numpy arrays of SymPy symbols."""
301
+
302
+ import numpy as np
303
+ import numpy.testing as npt
304
+
305
+ syms = symbols('_0,_1,_2')
306
+ s1 = symarray("", 3)
307
+ s2 = symarray("", 3)
308
+ npt.assert_array_equal(s1, np.array(syms, dtype=object))
309
+ assert s1[0] == s2[0]
310
+
311
+ a = symarray('a', 3)
312
+ b = symarray('b', 3)
313
+ assert not(a[0] == b[0])
314
+
315
+ asyms = symbols('a_0,a_1,a_2')
316
+ npt.assert_array_equal(a, np.array(asyms, dtype=object))
317
+
318
+ # Multidimensional checks
319
+ a2d = symarray('a', (2, 3))
320
+ assert a2d.shape == (2, 3)
321
+ a00, a12 = symbols('a_0_0,a_1_2')
322
+ assert a2d[0, 0] == a00
323
+ assert a2d[1, 2] == a12
324
+
325
+ a3d = symarray('a', (2, 3, 2))
326
+ assert a3d.shape == (2, 3, 2)
327
+ a000, a120, a121 = symbols('a_0_0_0,a_1_2_0,a_1_2_1')
328
+ assert a3d[0, 0, 0] == a000
329
+ assert a3d[1, 2, 0] == a120
330
+ assert a3d[1, 2, 1] == a121
331
+
332
+
333
+ def test_vectorize():
334
+ assert (numpy.vectorize(
335
+ sin)([1, 2, 3]) == numpy.array([sin(1), sin(2), sin(3)])).all()
.venv/lib/python3.13/site-packages/sympy/external/tests/test_pythonmpq.py ADDED
@@ -0,0 +1,176 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ test_pythonmpq.py
3
+
4
+ Test the PythonMPQ class for consistency with gmpy2's mpq type. If gmpy2 is
5
+ installed run the same tests for both.
6
+ """
7
+ from fractions import Fraction
8
+ from decimal import Decimal
9
+ import pickle
10
+ from typing import Callable, List, Tuple, Type
11
+
12
+ from sympy.testing.pytest import raises
13
+
14
+ from sympy.external.pythonmpq import PythonMPQ
15
+
16
+ #
17
+ # If gmpy2 is installed then run the tests for both mpq and PythonMPQ.
18
+ # That should ensure consistency between the implementation here and mpq.
19
+ #
20
+ rational_types: List[Tuple[Callable, Type, Callable, Type]]
21
+ rational_types = [(PythonMPQ, PythonMPQ, int, int)]
22
+ try:
23
+ from gmpy2 import mpq, mpz
24
+ rational_types.append((mpq, type(mpq(1)), mpz, type(mpz(1))))
25
+ except ImportError:
26
+ pass
27
+
28
+
29
+ def test_PythonMPQ():
30
+ #
31
+ # Test PythonMPQ and also mpq if gmpy/gmpy2 is installed.
32
+ #
33
+ for Q, TQ, Z, TZ in rational_types:
34
+
35
+ def check_Q(q):
36
+ assert isinstance(q, TQ)
37
+ assert isinstance(q.numerator, TZ)
38
+ assert isinstance(q.denominator, TZ)
39
+ return q.numerator, q.denominator
40
+
41
+ # Check construction from different types
42
+ assert check_Q(Q(3)) == (3, 1)
43
+ assert check_Q(Q(3, 5)) == (3, 5)
44
+ assert check_Q(Q(Q(3, 5))) == (3, 5)
45
+ assert check_Q(Q(0.5)) == (1, 2)
46
+ assert check_Q(Q('0.5')) == (1, 2)
47
+ assert check_Q(Q(Fraction(3, 5))) == (3, 5)
48
+
49
+ # https://github.com/aleaxit/gmpy/issues/327
50
+ if Q is PythonMPQ:
51
+ assert check_Q(Q(Decimal('0.6'))) == (3, 5)
52
+
53
+ # Invalid types
54
+ raises(TypeError, lambda: Q([]))
55
+ raises(TypeError, lambda: Q([], []))
56
+
57
+ # Check normalisation of signs
58
+ assert check_Q(Q(2, 3)) == (2, 3)
59
+ assert check_Q(Q(-2, 3)) == (-2, 3)
60
+ assert check_Q(Q(2, -3)) == (-2, 3)
61
+ assert check_Q(Q(-2, -3)) == (2, 3)
62
+
63
+ # Check gcd calculation
64
+ assert check_Q(Q(12, 8)) == (3, 2)
65
+
66
+ # __int__/__float__
67
+ assert int(Q(5, 3)) == 1
68
+ assert int(Q(-5, 3)) == -1
69
+ assert float(Q(5, 2)) == 2.5
70
+ assert float(Q(-5, 2)) == -2.5
71
+
72
+ # __str__/__repr__
73
+ assert str(Q(2, 1)) == "2"
74
+ assert str(Q(1, 2)) == "1/2"
75
+ if Q is PythonMPQ:
76
+ assert repr(Q(2, 1)) == "MPQ(2,1)"
77
+ assert repr(Q(1, 2)) == "MPQ(1,2)"
78
+ else:
79
+ assert repr(Q(2, 1)) == "mpq(2,1)"
80
+ assert repr(Q(1, 2)) == "mpq(1,2)"
81
+
82
+ # __bool__
83
+ assert bool(Q(1, 2)) is True
84
+ assert bool(Q(0)) is False
85
+
86
+ # __eq__/__ne__
87
+ assert (Q(2, 3) == Q(2, 3)) is True
88
+ assert (Q(2, 3) == Q(2, 5)) is False
89
+ assert (Q(2, 3) != Q(2, 3)) is False
90
+ assert (Q(2, 3) != Q(2, 5)) is True
91
+
92
+ # __hash__
93
+ assert hash(Q(3, 5)) == hash(Fraction(3, 5))
94
+
95
+ # __reduce__
96
+ q = Q(2, 3)
97
+ assert pickle.loads(pickle.dumps(q)) == q
98
+
99
+ # __ge__/__gt__/__le__/__lt__
100
+ assert (Q(1, 3) < Q(2, 3)) is True
101
+ assert (Q(2, 3) < Q(2, 3)) is False
102
+ assert (Q(2, 3) < Q(1, 3)) is False
103
+ assert (Q(-2, 3) < Q(1, 3)) is True
104
+ assert (Q(1, 3) < Q(-2, 3)) is False
105
+
106
+ assert (Q(1, 3) <= Q(2, 3)) is True
107
+ assert (Q(2, 3) <= Q(2, 3)) is True
108
+ assert (Q(2, 3) <= Q(1, 3)) is False
109
+ assert (Q(-2, 3) <= Q(1, 3)) is True
110
+ assert (Q(1, 3) <= Q(-2, 3)) is False
111
+
112
+ assert (Q(1, 3) > Q(2, 3)) is False
113
+ assert (Q(2, 3) > Q(2, 3)) is False
114
+ assert (Q(2, 3) > Q(1, 3)) is True
115
+ assert (Q(-2, 3) > Q(1, 3)) is False
116
+ assert (Q(1, 3) > Q(-2, 3)) is True
117
+
118
+ assert (Q(1, 3) >= Q(2, 3)) is False
119
+ assert (Q(2, 3) >= Q(2, 3)) is True
120
+ assert (Q(2, 3) >= Q(1, 3)) is True
121
+ assert (Q(-2, 3) >= Q(1, 3)) is False
122
+ assert (Q(1, 3) >= Q(-2, 3)) is True
123
+
124
+ # __abs__/__pos__/__neg__
125
+ assert abs(Q(2, 3)) == abs(Q(-2, 3)) == Q(2, 3)
126
+ assert +Q(2, 3) == Q(2, 3)
127
+ assert -Q(2, 3) == Q(-2, 3)
128
+
129
+ # __add__/__radd__
130
+ assert Q(2, 3) + Q(5, 7) == Q(29, 21)
131
+ assert Q(2, 3) + 1 == Q(5, 3)
132
+ assert 1 + Q(2, 3) == Q(5, 3)
133
+ raises(TypeError, lambda: [] + Q(1))
134
+ raises(TypeError, lambda: Q(1) + [])
135
+
136
+ # __sub__/__rsub__
137
+ assert Q(2, 3) - Q(5, 7) == Q(-1, 21)
138
+ assert Q(2, 3) - 1 == Q(-1, 3)
139
+ assert 1 - Q(2, 3) == Q(1, 3)
140
+ raises(TypeError, lambda: [] - Q(1))
141
+ raises(TypeError, lambda: Q(1) - [])
142
+
143
+ # __mul__/__rmul__
144
+ assert Q(2, 3) * Q(5, 7) == Q(10, 21)
145
+ assert Q(2, 3) * 1 == Q(2, 3)
146
+ assert 1 * Q(2, 3) == Q(2, 3)
147
+ raises(TypeError, lambda: [] * Q(1))
148
+ raises(TypeError, lambda: Q(1) * [])
149
+
150
+ # __pow__/__rpow__
151
+ assert Q(2, 3) ** 2 == Q(4, 9)
152
+ assert Q(2, 3) ** 1 == Q(2, 3)
153
+ assert Q(-2, 3) ** 2 == Q(4, 9)
154
+ assert Q(-2, 3) ** -1 == Q(-3, 2)
155
+ if Q is PythonMPQ:
156
+ raises(TypeError, lambda: 1 ** Q(2, 3))
157
+ raises(TypeError, lambda: Q(1, 4) ** Q(1, 2))
158
+ raises(TypeError, lambda: [] ** Q(1))
159
+ raises(TypeError, lambda: Q(1) ** [])
160
+
161
+ # __div__/__rdiv__
162
+ assert Q(2, 3) / Q(5, 7) == Q(14, 15)
163
+ assert Q(2, 3) / 1 == Q(2, 3)
164
+ assert 1 / Q(2, 3) == Q(3, 2)
165
+ raises(TypeError, lambda: [] / Q(1))
166
+ raises(TypeError, lambda: Q(1) / [])
167
+ raises(ZeroDivisionError, lambda: Q(1, 2) / Q(0))
168
+
169
+ # __divmod__
170
+ if Q is PythonMPQ:
171
+ raises(TypeError, lambda: Q(2, 3) // Q(1, 3))
172
+ raises(TypeError, lambda: Q(2, 3) % Q(1, 3))
173
+ raises(TypeError, lambda: 1 // Q(1, 3))
174
+ raises(TypeError, lambda: 1 % Q(1, 3))
175
+ raises(TypeError, lambda: Q(2, 3) // 1)
176
+ raises(TypeError, lambda: Q(2, 3) % 1)
.venv/lib/python3.13/site-packages/sympy/external/tests/test_scipy.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This testfile tests SymPy <-> SciPy compatibility
2
+
3
+ # Don't test any SymPy features here. Just pure interaction with SciPy.
4
+ # Always write regular SymPy tests for anything, that can be tested in pure
5
+ # Python (without scipy). Here we test everything, that a user may need when
6
+ # using SymPy with SciPy
7
+
8
+ from sympy.external import import_module
9
+
10
+ scipy = import_module('scipy')
11
+ if not scipy:
12
+ #bin/test will not execute any tests now
13
+ disabled = True
14
+
15
+ from sympy.functions.special.bessel import jn_zeros
16
+
17
+
18
+ def eq(a, b, tol=1e-6):
19
+ for x, y in zip(a, b):
20
+ if not (abs(x - y) < tol):
21
+ return False
22
+ return True
23
+
24
+
25
+ def test_jn_zeros():
26
+ assert eq(jn_zeros(0, 4, method="scipy"),
27
+ [3.141592, 6.283185, 9.424777, 12.566370])
28
+ assert eq(jn_zeros(1, 4, method="scipy"),
29
+ [4.493409, 7.725251, 10.904121, 14.066193])
30
+ assert eq(jn_zeros(2, 4, method="scipy"),
31
+ [5.763459, 9.095011, 12.322940, 15.514603])
32
+ assert eq(jn_zeros(3, 4, method="scipy"),
33
+ [6.987932, 10.417118, 13.698023, 16.923621])
34
+ assert eq(jn_zeros(4, 4, method="scipy"),
35
+ [8.182561, 11.704907, 15.039664, 18.301255])
.venv/lib/python3.13/site-packages/sympy/functions/__init__.py ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """A functions module, includes all the standard functions.
2
+
3
+ Combinatorial - factorial, fibonacci, harmonic, bernoulli...
4
+ Elementary - hyperbolic, trigonometric, exponential, floor and ceiling, sqrt...
5
+ Special - gamma, zeta,spherical harmonics...
6
+ """
7
+
8
+ from sympy.functions.combinatorial.factorials import (factorial, factorial2,
9
+ rf, ff, binomial, RisingFactorial, FallingFactorial, subfactorial)
10
+ from sympy.functions.combinatorial.numbers import (carmichael, fibonacci, lucas, tribonacci,
11
+ harmonic, bernoulli, bell, euler, catalan, genocchi, andre, partition, divisor_sigma,
12
+ udivisor_sigma, legendre_symbol, jacobi_symbol, kronecker_symbol, mobius,
13
+ primenu, primeomega, totient, reduced_totient, primepi, motzkin)
14
+ from sympy.functions.elementary.miscellaneous import (sqrt, root, Min, Max,
15
+ Id, real_root, cbrt, Rem)
16
+ from sympy.functions.elementary.complexes import (re, im, sign, Abs,
17
+ conjugate, arg, polar_lift, periodic_argument, unbranched_argument,
18
+ principal_branch, transpose, adjoint, polarify, unpolarify)
19
+ from sympy.functions.elementary.trigonometric import (sin, cos, tan,
20
+ sec, csc, cot, sinc, asin, acos, atan, asec, acsc, acot, atan2)
21
+ from sympy.functions.elementary.exponential import (exp_polar, exp, log,
22
+ LambertW)
23
+ from sympy.functions.elementary.hyperbolic import (sinh, cosh, tanh, coth,
24
+ sech, csch, asinh, acosh, atanh, acoth, asech, acsch)
25
+ from sympy.functions.elementary.integers import floor, ceiling, frac
26
+ from sympy.functions.elementary.piecewise import (Piecewise, piecewise_fold,
27
+ piecewise_exclusive)
28
+ from sympy.functions.special.error_functions import (erf, erfc, erfi, erf2,
29
+ erfinv, erfcinv, erf2inv, Ei, expint, E1, li, Li, Si, Ci, Shi, Chi,
30
+ fresnels, fresnelc)
31
+ from sympy.functions.special.gamma_functions import (gamma, lowergamma,
32
+ uppergamma, polygamma, loggamma, digamma, trigamma, multigamma)
33
+ from sympy.functions.special.zeta_functions import (dirichlet_eta, zeta,
34
+ lerchphi, polylog, stieltjes, riemann_xi)
35
+ from sympy.functions.special.tensor_functions import (Eijk, LeviCivita,
36
+ KroneckerDelta)
37
+ from sympy.functions.special.singularity_functions import SingularityFunction
38
+ from sympy.functions.special.delta_functions import DiracDelta, Heaviside
39
+ from sympy.functions.special.bsplines import bspline_basis, bspline_basis_set, interpolating_spline
40
+ from sympy.functions.special.bessel import (besselj, bessely, besseli, besselk,
41
+ hankel1, hankel2, jn, yn, jn_zeros, hn1, hn2, airyai, airybi, airyaiprime, airybiprime, marcumq)
42
+ from sympy.functions.special.hyper import hyper, meijerg, appellf1
43
+ from sympy.functions.special.polynomials import (legendre, assoc_legendre,
44
+ hermite, hermite_prob, chebyshevt, chebyshevu, chebyshevu_root,
45
+ chebyshevt_root, laguerre, assoc_laguerre, gegenbauer, jacobi, jacobi_normalized)
46
+ from sympy.functions.special.spherical_harmonics import Ynm, Ynm_c, Znm
47
+ from sympy.functions.special.elliptic_integrals import (elliptic_k,
48
+ elliptic_f, elliptic_e, elliptic_pi)
49
+ from sympy.functions.special.beta_functions import beta, betainc, betainc_regularized
50
+ from sympy.functions.special.mathieu_functions import (mathieus, mathieuc,
51
+ mathieusprime, mathieucprime)
52
+ ln = log
53
+
54
+ __all__ = [
55
+ 'factorial', 'factorial2', 'rf', 'ff', 'binomial', 'RisingFactorial',
56
+ 'FallingFactorial', 'subfactorial',
57
+
58
+ 'carmichael', 'fibonacci', 'lucas', 'motzkin', 'tribonacci', 'harmonic',
59
+ 'bernoulli', 'bell', 'euler', 'catalan', 'genocchi', 'andre', 'partition',
60
+ 'divisor_sigma', 'udivisor_sigma', 'legendre_symbol', 'jacobi_symbol', 'kronecker_symbol',
61
+ 'mobius', 'primenu', 'primeomega', 'totient', 'reduced_totient', 'primepi',
62
+
63
+ 'sqrt', 'root', 'Min', 'Max', 'Id', 'real_root', 'cbrt', 'Rem',
64
+
65
+ 're', 'im', 'sign', 'Abs', 'conjugate', 'arg', 'polar_lift',
66
+ 'periodic_argument', 'unbranched_argument', 'principal_branch',
67
+ 'transpose', 'adjoint', 'polarify', 'unpolarify',
68
+
69
+ 'sin', 'cos', 'tan', 'sec', 'csc', 'cot', 'sinc', 'asin', 'acos', 'atan',
70
+ 'asec', 'acsc', 'acot', 'atan2',
71
+
72
+ 'exp_polar', 'exp', 'ln', 'log', 'LambertW',
73
+
74
+ 'sinh', 'cosh', 'tanh', 'coth', 'sech', 'csch', 'asinh', 'acosh', 'atanh',
75
+ 'acoth', 'asech', 'acsch',
76
+
77
+ 'floor', 'ceiling', 'frac',
78
+
79
+ 'Piecewise', 'piecewise_fold', 'piecewise_exclusive',
80
+
81
+ 'erf', 'erfc', 'erfi', 'erf2', 'erfinv', 'erfcinv', 'erf2inv', 'Ei',
82
+ 'expint', 'E1', 'li', 'Li', 'Si', 'Ci', 'Shi', 'Chi', 'fresnels',
83
+ 'fresnelc',
84
+
85
+ 'gamma', 'lowergamma', 'uppergamma', 'polygamma', 'loggamma', 'digamma',
86
+ 'trigamma', 'multigamma',
87
+
88
+ 'dirichlet_eta', 'zeta', 'lerchphi', 'polylog', 'stieltjes', 'riemann_xi',
89
+
90
+ 'Eijk', 'LeviCivita', 'KroneckerDelta',
91
+
92
+ 'SingularityFunction',
93
+
94
+ 'DiracDelta', 'Heaviside',
95
+
96
+ 'bspline_basis', 'bspline_basis_set', 'interpolating_spline',
97
+
98
+ 'besselj', 'bessely', 'besseli', 'besselk', 'hankel1', 'hankel2', 'jn',
99
+ 'yn', 'jn_zeros', 'hn1', 'hn2', 'airyai', 'airybi', 'airyaiprime',
100
+ 'airybiprime', 'marcumq',
101
+
102
+ 'hyper', 'meijerg', 'appellf1',
103
+
104
+ 'legendre', 'assoc_legendre', 'hermite', 'hermite_prob', 'chebyshevt',
105
+ 'chebyshevu', 'chebyshevu_root', 'chebyshevt_root', 'laguerre',
106
+ 'assoc_laguerre', 'gegenbauer', 'jacobi', 'jacobi_normalized',
107
+
108
+ 'Ynm', 'Ynm_c', 'Znm',
109
+
110
+ 'elliptic_k', 'elliptic_f', 'elliptic_e', 'elliptic_pi',
111
+
112
+ 'beta', 'betainc', 'betainc_regularized',
113
+
114
+ 'mathieus', 'mathieuc', 'mathieusprime', 'mathieucprime',
115
+ ]
.venv/lib/python3.13/site-packages/sympy/geometry/__init__.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ A geometry module for the SymPy library. This module contains all of the
3
+ entities and functions needed to construct basic geometrical data and to
4
+ perform simple informational queries.
5
+
6
+ Usage:
7
+ ======
8
+
9
+ Examples
10
+ ========
11
+
12
+ """
13
+ from sympy.geometry.point import Point, Point2D, Point3D
14
+ from sympy.geometry.line import Line, Ray, Segment, Line2D, Segment2D, Ray2D, \
15
+ Line3D, Segment3D, Ray3D
16
+ from sympy.geometry.plane import Plane
17
+ from sympy.geometry.ellipse import Ellipse, Circle
18
+ from sympy.geometry.polygon import Polygon, RegularPolygon, Triangle, rad, deg
19
+ from sympy.geometry.util import are_similar, centroid, convex_hull, idiff, \
20
+ intersection, closest_points, farthest_points
21
+ from sympy.geometry.exceptions import GeometryError
22
+ from sympy.geometry.curve import Curve
23
+ from sympy.geometry.parabola import Parabola
24
+
25
+ __all__ = [
26
+ 'Point', 'Point2D', 'Point3D',
27
+
28
+ 'Line', 'Ray', 'Segment', 'Line2D', 'Segment2D', 'Ray2D', 'Line3D',
29
+ 'Segment3D', 'Ray3D',
30
+
31
+ 'Plane',
32
+
33
+ 'Ellipse', 'Circle',
34
+
35
+ 'Polygon', 'RegularPolygon', 'Triangle', 'rad', 'deg',
36
+
37
+ 'are_similar', 'centroid', 'convex_hull', 'idiff', 'intersection',
38
+ 'closest_points', 'farthest_points',
39
+
40
+ 'GeometryError',
41
+
42
+ 'Curve',
43
+
44
+ 'Parabola',
45
+ ]
.venv/lib/python3.13/site-packages/sympy/geometry/curve.py ADDED
@@ -0,0 +1,424 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Curves in 2-dimensional Euclidean space.
2
+
3
+ Contains
4
+ ========
5
+ Curve
6
+
7
+ """
8
+
9
+ from sympy.functions.elementary.miscellaneous import sqrt
10
+ from sympy.core import diff
11
+ from sympy.core.containers import Tuple
12
+ from sympy.core.symbol import _symbol
13
+ from sympy.geometry.entity import GeometryEntity, GeometrySet
14
+ from sympy.geometry.point import Point
15
+ from sympy.integrals import integrate
16
+ from sympy.matrices import Matrix, rot_axis3
17
+ from sympy.utilities.iterables import is_sequence
18
+
19
+ from mpmath.libmp.libmpf import prec_to_dps
20
+
21
+
22
+ class Curve(GeometrySet):
23
+ """A curve in space.
24
+
25
+ A curve is defined by parametric functions for the coordinates, a
26
+ parameter and the lower and upper bounds for the parameter value.
27
+
28
+ Parameters
29
+ ==========
30
+
31
+ function : list of functions
32
+ limits : 3-tuple
33
+ Function parameter and lower and upper bounds.
34
+
35
+ Attributes
36
+ ==========
37
+
38
+ functions
39
+ parameter
40
+ limits
41
+
42
+ Raises
43
+ ======
44
+
45
+ ValueError
46
+ When `functions` are specified incorrectly.
47
+ When `limits` are specified incorrectly.
48
+
49
+ Examples
50
+ ========
51
+
52
+ >>> from sympy import Curve, sin, cos, interpolate
53
+ >>> from sympy.abc import t, a
54
+ >>> C = Curve((sin(t), cos(t)), (t, 0, 2))
55
+ >>> C.functions
56
+ (sin(t), cos(t))
57
+ >>> C.limits
58
+ (t, 0, 2)
59
+ >>> C.parameter
60
+ t
61
+ >>> C = Curve((t, interpolate([1, 4, 9, 16], t)), (t, 0, 1)); C
62
+ Curve((t, t**2), (t, 0, 1))
63
+ >>> C.subs(t, 4)
64
+ Point2D(4, 16)
65
+ >>> C.arbitrary_point(a)
66
+ Point2D(a, a**2)
67
+
68
+ See Also
69
+ ========
70
+
71
+ sympy.core.function.Function
72
+ sympy.polys.polyfuncs.interpolate
73
+
74
+ """
75
+
76
+ def __new__(cls, function, limits):
77
+ if not is_sequence(function) or len(function) != 2:
78
+ raise ValueError("Function argument should be (x(t), y(t)) "
79
+ "but got %s" % str(function))
80
+ if not is_sequence(limits) or len(limits) != 3:
81
+ raise ValueError("Limit argument should be (t, tmin, tmax) "
82
+ "but got %s" % str(limits))
83
+
84
+ return GeometryEntity.__new__(cls, Tuple(*function), Tuple(*limits))
85
+
86
+ def __call__(self, f):
87
+ return self.subs(self.parameter, f)
88
+
89
+ def _eval_subs(self, old, new):
90
+ if old == self.parameter:
91
+ return Point(*[f.subs(old, new) for f in self.functions])
92
+
93
+ def _eval_evalf(self, prec=15, **options):
94
+ f, (t, a, b) = self.args
95
+ dps = prec_to_dps(prec)
96
+ f = tuple([i.evalf(n=dps, **options) for i in f])
97
+ a, b = [i.evalf(n=dps, **options) for i in (a, b)]
98
+ return self.func(f, (t, a, b))
99
+
100
+ def arbitrary_point(self, parameter='t'):
101
+ """A parameterized point on the curve.
102
+
103
+ Parameters
104
+ ==========
105
+
106
+ parameter : str or Symbol, optional
107
+ Default value is 't'.
108
+ The Curve's parameter is selected with None or self.parameter
109
+ otherwise the provided symbol is used.
110
+
111
+ Returns
112
+ =======
113
+
114
+ Point :
115
+ Returns a point in parametric form.
116
+
117
+ Raises
118
+ ======
119
+
120
+ ValueError
121
+ When `parameter` already appears in the functions.
122
+
123
+ Examples
124
+ ========
125
+
126
+ >>> from sympy import Curve, Symbol
127
+ >>> from sympy.abc import s
128
+ >>> C = Curve([2*s, s**2], (s, 0, 2))
129
+ >>> C.arbitrary_point()
130
+ Point2D(2*t, t**2)
131
+ >>> C.arbitrary_point(C.parameter)
132
+ Point2D(2*s, s**2)
133
+ >>> C.arbitrary_point(None)
134
+ Point2D(2*s, s**2)
135
+ >>> C.arbitrary_point(Symbol('a'))
136
+ Point2D(2*a, a**2)
137
+
138
+ See Also
139
+ ========
140
+
141
+ sympy.geometry.point.Point
142
+
143
+ """
144
+ if parameter is None:
145
+ return Point(*self.functions)
146
+
147
+ tnew = _symbol(parameter, self.parameter, real=True)
148
+ t = self.parameter
149
+ if (tnew.name != t.name and
150
+ tnew.name in (f.name for f in self.free_symbols)):
151
+ raise ValueError('Symbol %s already appears in object '
152
+ 'and cannot be used as a parameter.' % tnew.name)
153
+ return Point(*[w.subs(t, tnew) for w in self.functions])
154
+
155
+ @property
156
+ def free_symbols(self):
157
+ """Return a set of symbols other than the bound symbols used to
158
+ parametrically define the Curve.
159
+
160
+ Returns
161
+ =======
162
+
163
+ set :
164
+ Set of all non-parameterized symbols.
165
+
166
+ Examples
167
+ ========
168
+
169
+ >>> from sympy.abc import t, a
170
+ >>> from sympy import Curve
171
+ >>> Curve((t, t**2), (t, 0, 2)).free_symbols
172
+ set()
173
+ >>> Curve((t, t**2), (t, a, 2)).free_symbols
174
+ {a}
175
+
176
+ """
177
+ free = set()
178
+ for a in self.functions + self.limits[1:]:
179
+ free |= a.free_symbols
180
+ free = free.difference({self.parameter})
181
+ return free
182
+
183
+ @property
184
+ def ambient_dimension(self):
185
+ """The dimension of the curve.
186
+
187
+ Returns
188
+ =======
189
+
190
+ int :
191
+ the dimension of curve.
192
+
193
+ Examples
194
+ ========
195
+
196
+ >>> from sympy.abc import t
197
+ >>> from sympy import Curve
198
+ >>> C = Curve((t, t**2), (t, 0, 2))
199
+ >>> C.ambient_dimension
200
+ 2
201
+
202
+ """
203
+
204
+ return len(self.args[0])
205
+
206
+ @property
207
+ def functions(self):
208
+ """The functions specifying the curve.
209
+
210
+ Returns
211
+ =======
212
+
213
+ functions :
214
+ list of parameterized coordinate functions.
215
+
216
+ Examples
217
+ ========
218
+
219
+ >>> from sympy.abc import t
220
+ >>> from sympy import Curve
221
+ >>> C = Curve((t, t**2), (t, 0, 2))
222
+ >>> C.functions
223
+ (t, t**2)
224
+
225
+ See Also
226
+ ========
227
+
228
+ parameter
229
+
230
+ """
231
+ return self.args[0]
232
+
233
+ @property
234
+ def limits(self):
235
+ """The limits for the curve.
236
+
237
+ Returns
238
+ =======
239
+
240
+ limits : tuple
241
+ Contains parameter and lower and upper limits.
242
+
243
+ Examples
244
+ ========
245
+
246
+ >>> from sympy.abc import t
247
+ >>> from sympy import Curve
248
+ >>> C = Curve([t, t**3], (t, -2, 2))
249
+ >>> C.limits
250
+ (t, -2, 2)
251
+
252
+ See Also
253
+ ========
254
+
255
+ plot_interval
256
+
257
+ """
258
+ return self.args[1]
259
+
260
+ @property
261
+ def parameter(self):
262
+ """The curve function variable.
263
+
264
+ Returns
265
+ =======
266
+
267
+ Symbol :
268
+ returns a bound symbol.
269
+
270
+ Examples
271
+ ========
272
+
273
+ >>> from sympy.abc import t
274
+ >>> from sympy import Curve
275
+ >>> C = Curve([t, t**2], (t, 0, 2))
276
+ >>> C.parameter
277
+ t
278
+
279
+ See Also
280
+ ========
281
+
282
+ functions
283
+
284
+ """
285
+ return self.args[1][0]
286
+
287
+ @property
288
+ def length(self):
289
+ """The curve length.
290
+
291
+ Examples
292
+ ========
293
+
294
+ >>> from sympy import Curve
295
+ >>> from sympy.abc import t
296
+ >>> Curve((t, t), (t, 0, 1)).length
297
+ sqrt(2)
298
+
299
+ """
300
+ integrand = sqrt(sum(diff(func, self.limits[0])**2 for func in self.functions))
301
+ return integrate(integrand, self.limits)
302
+
303
+ def plot_interval(self, parameter='t'):
304
+ """The plot interval for the default geometric plot of the curve.
305
+
306
+ Parameters
307
+ ==========
308
+
309
+ parameter : str or Symbol, optional
310
+ Default value is 't';
311
+ otherwise the provided symbol is used.
312
+
313
+ Returns
314
+ =======
315
+
316
+ List :
317
+ the plot interval as below:
318
+ [parameter, lower_bound, upper_bound]
319
+
320
+ Examples
321
+ ========
322
+
323
+ >>> from sympy import Curve, sin
324
+ >>> from sympy.abc import x, s
325
+ >>> Curve((x, sin(x)), (x, 1, 2)).plot_interval()
326
+ [t, 1, 2]
327
+ >>> Curve((x, sin(x)), (x, 1, 2)).plot_interval(s)
328
+ [s, 1, 2]
329
+
330
+ See Also
331
+ ========
332
+
333
+ limits : Returns limits of the parameter interval
334
+
335
+ """
336
+ t = _symbol(parameter, self.parameter, real=True)
337
+ return [t] + list(self.limits[1:])
338
+
339
+ def rotate(self, angle=0, pt=None):
340
+ """This function is used to rotate a curve along given point ``pt`` at given angle(in radian).
341
+
342
+ Parameters
343
+ ==========
344
+
345
+ angle :
346
+ the angle at which the curve will be rotated(in radian) in counterclockwise direction.
347
+ default value of angle is 0.
348
+
349
+ pt : Point
350
+ the point along which the curve will be rotated.
351
+ If no point given, the curve will be rotated around origin.
352
+
353
+ Returns
354
+ =======
355
+
356
+ Curve :
357
+ returns a curve rotated at given angle along given point.
358
+
359
+ Examples
360
+ ========
361
+
362
+ >>> from sympy import Curve, pi
363
+ >>> from sympy.abc import x
364
+ >>> Curve((x, x), (x, 0, 1)).rotate(pi/2)
365
+ Curve((-x, x), (x, 0, 1))
366
+
367
+ """
368
+ if pt:
369
+ pt = -Point(pt, dim=2)
370
+ else:
371
+ pt = Point(0,0)
372
+ rv = self.translate(*pt.args)
373
+ f = list(rv.functions)
374
+ f.append(0)
375
+ f = Matrix(1, 3, f)
376
+ f *= rot_axis3(angle)
377
+ rv = self.func(f[0, :2].tolist()[0], self.limits)
378
+ pt = -pt
379
+ return rv.translate(*pt.args)
380
+
381
+ def scale(self, x=1, y=1, pt=None):
382
+ """Override GeometryEntity.scale since Curve is not made up of Points.
383
+
384
+ Returns
385
+ =======
386
+
387
+ Curve :
388
+ returns scaled curve.
389
+
390
+ Examples
391
+ ========
392
+
393
+ >>> from sympy import Curve
394
+ >>> from sympy.abc import x
395
+ >>> Curve((x, x), (x, 0, 1)).scale(2)
396
+ Curve((2*x, x), (x, 0, 1))
397
+
398
+ """
399
+ if pt:
400
+ pt = Point(pt, dim=2)
401
+ return self.translate(*(-pt).args).scale(x, y).translate(*pt.args)
402
+ fx, fy = self.functions
403
+ return self.func((fx*x, fy*y), self.limits)
404
+
405
+ def translate(self, x=0, y=0):
406
+ """Translate the Curve by (x, y).
407
+
408
+ Returns
409
+ =======
410
+
411
+ Curve :
412
+ returns a translated curve.
413
+
414
+ Examples
415
+ ========
416
+
417
+ >>> from sympy import Curve
418
+ >>> from sympy.abc import x
419
+ >>> Curve((x, x), (x, 0, 1)).translate(1, 2)
420
+ Curve((x + 1, x + 2), (x, 0, 1))
421
+
422
+ """
423
+ fx, fy = self.functions
424
+ return self.func((fx + x, fy + y), self.limits)
.venv/lib/python3.13/site-packages/sympy/geometry/ellipse.py ADDED
@@ -0,0 +1,1768 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Elliptical geometrical entities.
2
+
3
+ Contains
4
+ * Ellipse
5
+ * Circle
6
+
7
+ """
8
+
9
+ from sympy.core.expr import Expr
10
+ from sympy.core.relational import Eq
11
+ from sympy.core import S, pi, sympify
12
+ from sympy.core.evalf import N
13
+ from sympy.core.parameters import global_parameters
14
+ from sympy.core.logic import fuzzy_bool
15
+ from sympy.core.numbers import Rational, oo
16
+ from sympy.core.sorting import ordered
17
+ from sympy.core.symbol import Dummy, uniquely_named_symbol, _symbol
18
+ from sympy.simplify.simplify import simplify
19
+ from sympy.simplify.trigsimp import trigsimp
20
+ from sympy.functions.elementary.miscellaneous import sqrt, Max
21
+ from sympy.functions.elementary.trigonometric import cos, sin
22
+ from sympy.functions.special.elliptic_integrals import elliptic_e
23
+ from .entity import GeometryEntity, GeometrySet
24
+ from .exceptions import GeometryError
25
+ from .line import Line, Segment, Ray2D, Segment2D, Line2D, LinearEntity3D
26
+ from .point import Point, Point2D, Point3D
27
+ from .util import idiff, find
28
+ from sympy.polys import DomainError, Poly, PolynomialError
29
+ from sympy.polys.polyutils import _not_a_coeff, _nsort
30
+ from sympy.solvers import solve
31
+ from sympy.solvers.solveset import linear_coeffs
32
+ from sympy.utilities.misc import filldedent, func_name
33
+
34
+ from mpmath.libmp.libmpf import prec_to_dps
35
+
36
+ import random
37
+
38
+ x, y = [Dummy('ellipse_dummy', real=True) for i in range(2)]
39
+
40
+
41
+ class Ellipse(GeometrySet):
42
+ """An elliptical GeometryEntity.
43
+
44
+ Parameters
45
+ ==========
46
+
47
+ center : Point, optional
48
+ Default value is Point(0, 0)
49
+ hradius : number or SymPy expression, optional
50
+ vradius : number or SymPy expression, optional
51
+ eccentricity : number or SymPy expression, optional
52
+ Two of `hradius`, `vradius` and `eccentricity` must be supplied to
53
+ create an Ellipse. The third is derived from the two supplied.
54
+
55
+ Attributes
56
+ ==========
57
+
58
+ center
59
+ hradius
60
+ vradius
61
+ area
62
+ circumference
63
+ eccentricity
64
+ periapsis
65
+ apoapsis
66
+ focus_distance
67
+ foci
68
+
69
+ Raises
70
+ ======
71
+
72
+ GeometryError
73
+ When `hradius`, `vradius` and `eccentricity` are incorrectly supplied
74
+ as parameters.
75
+ TypeError
76
+ When `center` is not a Point.
77
+
78
+ See Also
79
+ ========
80
+
81
+ Circle
82
+
83
+ Notes
84
+ -----
85
+ Constructed from a center and two radii, the first being the horizontal
86
+ radius (along the x-axis) and the second being the vertical radius (along
87
+ the y-axis).
88
+
89
+ When symbolic value for hradius and vradius are used, any calculation that
90
+ refers to the foci or the major or minor axis will assume that the ellipse
91
+ has its major radius on the x-axis. If this is not true then a manual
92
+ rotation is necessary.
93
+
94
+ Examples
95
+ ========
96
+
97
+ >>> from sympy import Ellipse, Point, Rational
98
+ >>> e1 = Ellipse(Point(0, 0), 5, 1)
99
+ >>> e1.hradius, e1.vradius
100
+ (5, 1)
101
+ >>> e2 = Ellipse(Point(3, 1), hradius=3, eccentricity=Rational(4, 5))
102
+ >>> e2
103
+ Ellipse(Point2D(3, 1), 3, 9/5)
104
+
105
+ """
106
+
107
+ def __contains__(self, o):
108
+ if isinstance(o, Point):
109
+ res = self.equation(x, y).subs({x: o.x, y: o.y})
110
+ return trigsimp(simplify(res)) is S.Zero
111
+ elif isinstance(o, Ellipse):
112
+ return self == o
113
+ return False
114
+
115
+ def __eq__(self, o):
116
+ """Is the other GeometryEntity the same as this ellipse?"""
117
+ return isinstance(o, Ellipse) and (self.center == o.center and
118
+ self.hradius == o.hradius and
119
+ self.vradius == o.vradius)
120
+
121
+ def __hash__(self):
122
+ return super().__hash__()
123
+
124
+ def __new__(
125
+ cls, center=None, hradius=None, vradius=None, eccentricity=None, **kwargs):
126
+
127
+ hradius = sympify(hradius)
128
+ vradius = sympify(vradius)
129
+
130
+ if center is None:
131
+ center = Point(0, 0)
132
+ else:
133
+ if len(center) != 2:
134
+ raise ValueError('The center of "{}" must be a two dimensional point'.format(cls))
135
+ center = Point(center, dim=2)
136
+
137
+ if len(list(filter(lambda x: x is not None, (hradius, vradius, eccentricity)))) != 2:
138
+ raise ValueError(filldedent('''
139
+ Exactly two arguments of "hradius", "vradius", and
140
+ "eccentricity" must not be None.'''))
141
+
142
+ if eccentricity is not None:
143
+ eccentricity = sympify(eccentricity)
144
+ if eccentricity.is_negative:
145
+ raise GeometryError("Eccentricity of ellipse/circle should lie between [0, 1)")
146
+ elif hradius is None:
147
+ hradius = vradius / sqrt(1 - eccentricity**2)
148
+ elif vradius is None:
149
+ vradius = hradius * sqrt(1 - eccentricity**2)
150
+
151
+ if hradius == vradius:
152
+ return Circle(center, hradius, **kwargs)
153
+
154
+ if S.Zero in (hradius, vradius):
155
+ return Segment(Point(center[0] - hradius, center[1] - vradius), Point(center[0] + hradius, center[1] + vradius))
156
+
157
+ if hradius.is_real is False or vradius.is_real is False:
158
+ raise GeometryError("Invalid value encountered when computing hradius / vradius.")
159
+
160
+ return GeometryEntity.__new__(cls, center, hradius, vradius, **kwargs)
161
+
162
+ def _svg(self, scale_factor=1., fill_color="#66cc99"):
163
+ """Returns SVG ellipse element for the Ellipse.
164
+
165
+ Parameters
166
+ ==========
167
+
168
+ scale_factor : float
169
+ Multiplication factor for the SVG stroke-width. Default is 1.
170
+ fill_color : str, optional
171
+ Hex string for fill color. Default is "#66cc99".
172
+ """
173
+
174
+ c = N(self.center)
175
+ h, v = N(self.hradius), N(self.vradius)
176
+ return (
177
+ '<ellipse fill="{1}" stroke="#555555" '
178
+ 'stroke-width="{0}" opacity="0.6" cx="{2}" cy="{3}" rx="{4}" ry="{5}"/>'
179
+ ).format(2. * scale_factor, fill_color, c.x, c.y, h, v)
180
+
181
+ @property
182
+ def ambient_dimension(self):
183
+ return 2
184
+
185
+ @property
186
+ def apoapsis(self):
187
+ """The apoapsis of the ellipse.
188
+
189
+ The greatest distance between the focus and the contour.
190
+
191
+ Returns
192
+ =======
193
+
194
+ apoapsis : number
195
+
196
+ See Also
197
+ ========
198
+
199
+ periapsis : Returns shortest distance between foci and contour
200
+
201
+ Examples
202
+ ========
203
+
204
+ >>> from sympy import Point, Ellipse
205
+ >>> p1 = Point(0, 0)
206
+ >>> e1 = Ellipse(p1, 3, 1)
207
+ >>> e1.apoapsis
208
+ 2*sqrt(2) + 3
209
+
210
+ """
211
+ return self.major * (1 + self.eccentricity)
212
+
213
+ def arbitrary_point(self, parameter='t'):
214
+ """A parameterized point on the ellipse.
215
+
216
+ Parameters
217
+ ==========
218
+
219
+ parameter : str, optional
220
+ Default value is 't'.
221
+
222
+ Returns
223
+ =======
224
+
225
+ arbitrary_point : Point
226
+
227
+ Raises
228
+ ======
229
+
230
+ ValueError
231
+ When `parameter` already appears in the functions.
232
+
233
+ See Also
234
+ ========
235
+
236
+ sympy.geometry.point.Point
237
+
238
+ Examples
239
+ ========
240
+
241
+ >>> from sympy import Point, Ellipse
242
+ >>> e1 = Ellipse(Point(0, 0), 3, 2)
243
+ >>> e1.arbitrary_point()
244
+ Point2D(3*cos(t), 2*sin(t))
245
+
246
+ """
247
+ t = _symbol(parameter, real=True)
248
+ if t.name in (f.name for f in self.free_symbols):
249
+ raise ValueError(filldedent('Symbol %s already appears in object '
250
+ 'and cannot be used as a parameter.' % t.name))
251
+ return Point(self.center.x + self.hradius*cos(t),
252
+ self.center.y + self.vradius*sin(t))
253
+
254
+ @property
255
+ def area(self):
256
+ """The area of the ellipse.
257
+
258
+ Returns
259
+ =======
260
+
261
+ area : number
262
+
263
+ Examples
264
+ ========
265
+
266
+ >>> from sympy import Point, Ellipse
267
+ >>> p1 = Point(0, 0)
268
+ >>> e1 = Ellipse(p1, 3, 1)
269
+ >>> e1.area
270
+ 3*pi
271
+
272
+ """
273
+ return simplify(S.Pi * self.hradius * self.vradius)
274
+
275
+ @property
276
+ def bounds(self):
277
+ """Return a tuple (xmin, ymin, xmax, ymax) representing the bounding
278
+ rectangle for the geometric figure.
279
+
280
+ """
281
+
282
+ h, v = self.hradius, self.vradius
283
+ return (self.center.x - h, self.center.y - v, self.center.x + h, self.center.y + v)
284
+
285
+ @property
286
+ def center(self):
287
+ """The center of the ellipse.
288
+
289
+ Returns
290
+ =======
291
+
292
+ center : number
293
+
294
+ See Also
295
+ ========
296
+
297
+ sympy.geometry.point.Point
298
+
299
+ Examples
300
+ ========
301
+
302
+ >>> from sympy import Point, Ellipse
303
+ >>> p1 = Point(0, 0)
304
+ >>> e1 = Ellipse(p1, 3, 1)
305
+ >>> e1.center
306
+ Point2D(0, 0)
307
+
308
+ """
309
+ return self.args[0]
310
+
311
+ @property
312
+ def circumference(self):
313
+ """The circumference of the ellipse.
314
+
315
+ Examples
316
+ ========
317
+
318
+ >>> from sympy import Point, Ellipse
319
+ >>> p1 = Point(0, 0)
320
+ >>> e1 = Ellipse(p1, 3, 1)
321
+ >>> e1.circumference
322
+ 12*elliptic_e(8/9)
323
+
324
+ """
325
+ if self.eccentricity == 1:
326
+ # degenerate
327
+ return 4*self.major
328
+ elif self.eccentricity == 0:
329
+ # circle
330
+ return 2*pi*self.hradius
331
+ else:
332
+ return 4*self.major*elliptic_e(self.eccentricity**2)
333
+
334
+ @property
335
+ def eccentricity(self):
336
+ """The eccentricity of the ellipse.
337
+
338
+ Returns
339
+ =======
340
+
341
+ eccentricity : number
342
+
343
+ Examples
344
+ ========
345
+
346
+ >>> from sympy import Point, Ellipse, sqrt
347
+ >>> p1 = Point(0, 0)
348
+ >>> e1 = Ellipse(p1, 3, sqrt(2))
349
+ >>> e1.eccentricity
350
+ sqrt(7)/3
351
+
352
+ """
353
+ return self.focus_distance / self.major
354
+
355
+ def encloses_point(self, p):
356
+ """
357
+ Return True if p is enclosed by (is inside of) self.
358
+
359
+ Notes
360
+ -----
361
+ Being on the border of self is considered False.
362
+
363
+ Parameters
364
+ ==========
365
+
366
+ p : Point
367
+
368
+ Returns
369
+ =======
370
+
371
+ encloses_point : True, False or None
372
+
373
+ See Also
374
+ ========
375
+
376
+ sympy.geometry.point.Point
377
+
378
+ Examples
379
+ ========
380
+
381
+ >>> from sympy import Ellipse, S
382
+ >>> from sympy.abc import t
383
+ >>> e = Ellipse((0, 0), 3, 2)
384
+ >>> e.encloses_point((0, 0))
385
+ True
386
+ >>> e.encloses_point(e.arbitrary_point(t).subs(t, S.Half))
387
+ False
388
+ >>> e.encloses_point((4, 0))
389
+ False
390
+
391
+ """
392
+ p = Point(p, dim=2)
393
+ if p in self:
394
+ return False
395
+
396
+ if len(self.foci) == 2:
397
+ # if the combined distance from the foci to p (h1 + h2) is less
398
+ # than the combined distance from the foci to the minor axis
399
+ # (which is the same as the major axis length) then p is inside
400
+ # the ellipse
401
+ h1, h2 = [f.distance(p) for f in self.foci]
402
+ test = 2*self.major - (h1 + h2)
403
+ else:
404
+ test = self.radius - self.center.distance(p)
405
+
406
+ return fuzzy_bool(test.is_positive)
407
+
408
+ def equation(self, x='x', y='y', _slope=None):
409
+ """
410
+ Returns the equation of an ellipse aligned with the x and y axes;
411
+ when slope is given, the equation returned corresponds to an ellipse
412
+ with a major axis having that slope.
413
+
414
+ Parameters
415
+ ==========
416
+
417
+ x : str, optional
418
+ Label for the x-axis. Default value is 'x'.
419
+ y : str, optional
420
+ Label for the y-axis. Default value is 'y'.
421
+ _slope : Expr, optional
422
+ The slope of the major axis. Ignored when 'None'.
423
+
424
+ Returns
425
+ =======
426
+
427
+ equation : SymPy expression
428
+
429
+ See Also
430
+ ========
431
+
432
+ arbitrary_point : Returns parameterized point on ellipse
433
+
434
+ Examples
435
+ ========
436
+
437
+ >>> from sympy import Point, Ellipse, pi
438
+ >>> from sympy.abc import x, y
439
+ >>> e1 = Ellipse(Point(1, 0), 3, 2)
440
+ >>> eq1 = e1.equation(x, y); eq1
441
+ y**2/4 + (x/3 - 1/3)**2 - 1
442
+ >>> eq2 = e1.equation(x, y, _slope=1); eq2
443
+ (-x + y + 1)**2/8 + (x + y - 1)**2/18 - 1
444
+
445
+ A point on e1 satisfies eq1. Let's use one on the x-axis:
446
+
447
+ >>> p1 = e1.center + Point(e1.major, 0)
448
+ >>> assert eq1.subs(x, p1.x).subs(y, p1.y) == 0
449
+
450
+ When rotated the same as the rotated ellipse, about the center
451
+ point of the ellipse, it will satisfy the rotated ellipse's
452
+ equation, too:
453
+
454
+ >>> r1 = p1.rotate(pi/4, e1.center)
455
+ >>> assert eq2.subs(x, r1.x).subs(y, r1.y) == 0
456
+
457
+ References
458
+ ==========
459
+
460
+ .. [1] https://math.stackexchange.com/questions/108270/what-is-the-equation-of-an-ellipse-that-is-not-aligned-with-the-axis
461
+ .. [2] https://en.wikipedia.org/wiki/Ellipse#Shifted_ellipse
462
+
463
+ """
464
+
465
+ x = _symbol(x, real=True)
466
+ y = _symbol(y, real=True)
467
+
468
+ dx = x - self.center.x
469
+ dy = y - self.center.y
470
+
471
+ if _slope is not None:
472
+ L = (dy - _slope*dx)**2
473
+ l = (_slope*dy + dx)**2
474
+ h = 1 + _slope**2
475
+ b = h*self.major**2
476
+ a = h*self.minor**2
477
+ return l/b + L/a - 1
478
+
479
+ else:
480
+ t1 = (dx/self.hradius)**2
481
+ t2 = (dy/self.vradius)**2
482
+ return t1 + t2 - 1
483
+
484
+ def evolute(self, x='x', y='y'):
485
+ """The equation of evolute of the ellipse.
486
+
487
+ Parameters
488
+ ==========
489
+
490
+ x : str, optional
491
+ Label for the x-axis. Default value is 'x'.
492
+ y : str, optional
493
+ Label for the y-axis. Default value is 'y'.
494
+
495
+ Returns
496
+ =======
497
+
498
+ equation : SymPy expression
499
+
500
+ Examples
501
+ ========
502
+
503
+ >>> from sympy import Point, Ellipse
504
+ >>> e1 = Ellipse(Point(1, 0), 3, 2)
505
+ >>> e1.evolute()
506
+ 2**(2/3)*y**(2/3) + (3*x - 3)**(2/3) - 5**(2/3)
507
+ """
508
+ if len(self.args) != 3:
509
+ raise NotImplementedError('Evolute of arbitrary Ellipse is not supported.')
510
+ x = _symbol(x, real=True)
511
+ y = _symbol(y, real=True)
512
+ t1 = (self.hradius*(x - self.center.x))**Rational(2, 3)
513
+ t2 = (self.vradius*(y - self.center.y))**Rational(2, 3)
514
+ return t1 + t2 - (self.hradius**2 - self.vradius**2)**Rational(2, 3)
515
+
516
+ @property
517
+ def foci(self):
518
+ """The foci of the ellipse.
519
+
520
+ Notes
521
+ -----
522
+ The foci can only be calculated if the major/minor axes are known.
523
+
524
+ Raises
525
+ ======
526
+
527
+ ValueError
528
+ When the major and minor axis cannot be determined.
529
+
530
+ See Also
531
+ ========
532
+
533
+ sympy.geometry.point.Point
534
+ focus_distance : Returns the distance between focus and center
535
+
536
+ Examples
537
+ ========
538
+
539
+ >>> from sympy import Point, Ellipse
540
+ >>> p1 = Point(0, 0)
541
+ >>> e1 = Ellipse(p1, 3, 1)
542
+ >>> e1.foci
543
+ (Point2D(-2*sqrt(2), 0), Point2D(2*sqrt(2), 0))
544
+
545
+ """
546
+ c = self.center
547
+ hr, vr = self.hradius, self.vradius
548
+ if hr == vr:
549
+ return (c, c)
550
+
551
+ # calculate focus distance manually, since focus_distance calls this
552
+ # routine
553
+ fd = sqrt(self.major**2 - self.minor**2)
554
+ if hr == self.minor:
555
+ # foci on the y-axis
556
+ return (c + Point(0, -fd), c + Point(0, fd))
557
+ elif hr == self.major:
558
+ # foci on the x-axis
559
+ return (c + Point(-fd, 0), c + Point(fd, 0))
560
+
561
+ @property
562
+ def focus_distance(self):
563
+ """The focal distance of the ellipse.
564
+
565
+ The distance between the center and one focus.
566
+
567
+ Returns
568
+ =======
569
+
570
+ focus_distance : number
571
+
572
+ See Also
573
+ ========
574
+
575
+ foci
576
+
577
+ Examples
578
+ ========
579
+
580
+ >>> from sympy import Point, Ellipse
581
+ >>> p1 = Point(0, 0)
582
+ >>> e1 = Ellipse(p1, 3, 1)
583
+ >>> e1.focus_distance
584
+ 2*sqrt(2)
585
+
586
+ """
587
+ return Point.distance(self.center, self.foci[0])
588
+
589
+ @property
590
+ def hradius(self):
591
+ """The horizontal radius of the ellipse.
592
+
593
+ Returns
594
+ =======
595
+
596
+ hradius : number
597
+
598
+ See Also
599
+ ========
600
+
601
+ vradius, major, minor
602
+
603
+ Examples
604
+ ========
605
+
606
+ >>> from sympy import Point, Ellipse
607
+ >>> p1 = Point(0, 0)
608
+ >>> e1 = Ellipse(p1, 3, 1)
609
+ >>> e1.hradius
610
+ 3
611
+
612
+ """
613
+ return self.args[1]
614
+
615
+ def intersection(self, o):
616
+ """The intersection of this ellipse and another geometrical entity
617
+ `o`.
618
+
619
+ Parameters
620
+ ==========
621
+
622
+ o : GeometryEntity
623
+
624
+ Returns
625
+ =======
626
+
627
+ intersection : list of GeometryEntity objects
628
+
629
+ Notes
630
+ -----
631
+ Currently supports intersections with Point, Line, Segment, Ray,
632
+ Circle and Ellipse types.
633
+
634
+ See Also
635
+ ========
636
+
637
+ sympy.geometry.entity.GeometryEntity
638
+
639
+ Examples
640
+ ========
641
+
642
+ >>> from sympy import Ellipse, Point, Line
643
+ >>> e = Ellipse(Point(0, 0), 5, 7)
644
+ >>> e.intersection(Point(0, 0))
645
+ []
646
+ >>> e.intersection(Point(5, 0))
647
+ [Point2D(5, 0)]
648
+ >>> e.intersection(Line(Point(0,0), Point(0, 1)))
649
+ [Point2D(0, -7), Point2D(0, 7)]
650
+ >>> e.intersection(Line(Point(5,0), Point(5, 1)))
651
+ [Point2D(5, 0)]
652
+ >>> e.intersection(Line(Point(6,0), Point(6, 1)))
653
+ []
654
+ >>> e = Ellipse(Point(-1, 0), 4, 3)
655
+ >>> e.intersection(Ellipse(Point(1, 0), 4, 3))
656
+ [Point2D(0, -3*sqrt(15)/4), Point2D(0, 3*sqrt(15)/4)]
657
+ >>> e.intersection(Ellipse(Point(5, 0), 4, 3))
658
+ [Point2D(2, -3*sqrt(7)/4), Point2D(2, 3*sqrt(7)/4)]
659
+ >>> e.intersection(Ellipse(Point(100500, 0), 4, 3))
660
+ []
661
+ >>> e.intersection(Ellipse(Point(0, 0), 3, 4))
662
+ [Point2D(3, 0), Point2D(-363/175, -48*sqrt(111)/175), Point2D(-363/175, 48*sqrt(111)/175)]
663
+ >>> e.intersection(Ellipse(Point(-1, 0), 3, 4))
664
+ [Point2D(-17/5, -12/5), Point2D(-17/5, 12/5), Point2D(7/5, -12/5), Point2D(7/5, 12/5)]
665
+ """
666
+ # TODO: Replace solve with nonlinsolve, when nonlinsolve will be able to solve in real domain
667
+
668
+ if isinstance(o, Point):
669
+ if o in self:
670
+ return [o]
671
+ else:
672
+ return []
673
+
674
+ elif isinstance(o, (Segment2D, Ray2D)):
675
+ ellipse_equation = self.equation(x, y)
676
+ result = solve([ellipse_equation, Line(
677
+ o.points[0], o.points[1]).equation(x, y)], [x, y],
678
+ set=True)[1]
679
+ return list(ordered([Point(i) for i in result if i in o]))
680
+
681
+ elif isinstance(o, Polygon):
682
+ return o.intersection(self)
683
+
684
+ elif isinstance(o, (Ellipse, Line2D)):
685
+ if o == self:
686
+ return self
687
+ else:
688
+ ellipse_equation = self.equation(x, y)
689
+ return list(ordered([Point(i) for i in solve(
690
+ [ellipse_equation, o.equation(x, y)], [x, y],
691
+ set=True)[1]]))
692
+ elif isinstance(o, LinearEntity3D):
693
+ raise TypeError('Entity must be two dimensional, not three dimensional')
694
+ else:
695
+ raise TypeError('Intersection not handled for %s' % func_name(o))
696
+
697
+ def is_tangent(self, o):
698
+ """Is `o` tangent to the ellipse?
699
+
700
+ Parameters
701
+ ==========
702
+
703
+ o : GeometryEntity
704
+ An Ellipse, LinearEntity or Polygon
705
+
706
+ Raises
707
+ ======
708
+
709
+ NotImplementedError
710
+ When the wrong type of argument is supplied.
711
+
712
+ Returns
713
+ =======
714
+
715
+ is_tangent: boolean
716
+ True if o is tangent to the ellipse, False otherwise.
717
+
718
+ See Also
719
+ ========
720
+
721
+ tangent_lines
722
+
723
+ Examples
724
+ ========
725
+
726
+ >>> from sympy import Point, Ellipse, Line
727
+ >>> p0, p1, p2 = Point(0, 0), Point(3, 0), Point(3, 3)
728
+ >>> e1 = Ellipse(p0, 3, 2)
729
+ >>> l1 = Line(p1, p2)
730
+ >>> e1.is_tangent(l1)
731
+ True
732
+
733
+ """
734
+ if isinstance(o, Point2D):
735
+ return False
736
+ elif isinstance(o, Ellipse):
737
+ intersect = self.intersection(o)
738
+ if isinstance(intersect, Ellipse):
739
+ return True
740
+ elif intersect:
741
+ return all((self.tangent_lines(i)[0]).equals(o.tangent_lines(i)[0]) for i in intersect)
742
+ else:
743
+ return False
744
+ elif isinstance(o, Line2D):
745
+ hit = self.intersection(o)
746
+ if not hit:
747
+ return False
748
+ if len(hit) == 1:
749
+ return True
750
+ # might return None if it can't decide
751
+ return hit[0].equals(hit[1])
752
+ elif isinstance(o, (Segment2D, Ray2D)):
753
+ intersect = self.intersection(o)
754
+ if len(intersect) == 1:
755
+ return o in self.tangent_lines(intersect[0])[0]
756
+ else:
757
+ return False
758
+ elif isinstance(o, Polygon):
759
+ return all(self.is_tangent(s) for s in o.sides)
760
+ elif isinstance(o, (LinearEntity3D, Point3D)):
761
+ raise TypeError('Entity must be two dimensional, not three dimensional')
762
+ else:
763
+ raise TypeError('Is_tangent not handled for %s' % func_name(o))
764
+
765
+ @property
766
+ def major(self):
767
+ """Longer axis of the ellipse (if it can be determined) else hradius.
768
+
769
+ Returns
770
+ =======
771
+
772
+ major : number or expression
773
+
774
+ See Also
775
+ ========
776
+
777
+ hradius, vradius, minor
778
+
779
+ Examples
780
+ ========
781
+
782
+ >>> from sympy import Point, Ellipse, Symbol
783
+ >>> p1 = Point(0, 0)
784
+ >>> e1 = Ellipse(p1, 3, 1)
785
+ >>> e1.major
786
+ 3
787
+
788
+ >>> a = Symbol('a')
789
+ >>> b = Symbol('b')
790
+ >>> Ellipse(p1, a, b).major
791
+ a
792
+ >>> Ellipse(p1, b, a).major
793
+ b
794
+
795
+ >>> m = Symbol('m')
796
+ >>> M = m + 1
797
+ >>> Ellipse(p1, m, M).major
798
+ m + 1
799
+
800
+ """
801
+ ab = self.args[1:3]
802
+ if len(ab) == 1:
803
+ return ab[0]
804
+ a, b = ab
805
+ o = b - a < 0
806
+ if o == True:
807
+ return a
808
+ elif o == False:
809
+ return b
810
+ return self.hradius
811
+
812
+ @property
813
+ def minor(self):
814
+ """Shorter axis of the ellipse (if it can be determined) else vradius.
815
+
816
+ Returns
817
+ =======
818
+
819
+ minor : number or expression
820
+
821
+ See Also
822
+ ========
823
+
824
+ hradius, vradius, major
825
+
826
+ Examples
827
+ ========
828
+
829
+ >>> from sympy import Point, Ellipse, Symbol
830
+ >>> p1 = Point(0, 0)
831
+ >>> e1 = Ellipse(p1, 3, 1)
832
+ >>> e1.minor
833
+ 1
834
+
835
+ >>> a = Symbol('a')
836
+ >>> b = Symbol('b')
837
+ >>> Ellipse(p1, a, b).minor
838
+ b
839
+ >>> Ellipse(p1, b, a).minor
840
+ a
841
+
842
+ >>> m = Symbol('m')
843
+ >>> M = m + 1
844
+ >>> Ellipse(p1, m, M).minor
845
+ m
846
+
847
+ """
848
+ ab = self.args[1:3]
849
+ if len(ab) == 1:
850
+ return ab[0]
851
+ a, b = ab
852
+ o = a - b < 0
853
+ if o == True:
854
+ return a
855
+ elif o == False:
856
+ return b
857
+ return self.vradius
858
+
859
+ def normal_lines(self, p, prec=None):
860
+ """Normal lines between `p` and the ellipse.
861
+
862
+ Parameters
863
+ ==========
864
+
865
+ p : Point
866
+
867
+ Returns
868
+ =======
869
+
870
+ normal_lines : list with 1, 2 or 4 Lines
871
+
872
+ Examples
873
+ ========
874
+
875
+ >>> from sympy import Point, Ellipse
876
+ >>> e = Ellipse((0, 0), 2, 3)
877
+ >>> c = e.center
878
+ >>> e.normal_lines(c + Point(1, 0))
879
+ [Line2D(Point2D(0, 0), Point2D(1, 0))]
880
+ >>> e.normal_lines(c)
881
+ [Line2D(Point2D(0, 0), Point2D(0, 1)), Line2D(Point2D(0, 0), Point2D(1, 0))]
882
+
883
+ Off-axis points require the solution of a quartic equation. This
884
+ often leads to very large expressions that may be of little practical
885
+ use. An approximate solution of `prec` digits can be obtained by
886
+ passing in the desired value:
887
+
888
+ >>> e.normal_lines((3, 3), prec=2)
889
+ [Line2D(Point2D(-0.81, -2.7), Point2D(0.19, -1.2)),
890
+ Line2D(Point2D(1.5, -2.0), Point2D(2.5, -2.7))]
891
+
892
+ Whereas the above solution has an operation count of 12, the exact
893
+ solution has an operation count of 2020.
894
+ """
895
+ p = Point(p, dim=2)
896
+
897
+ # XXX change True to something like self.angle == 0 if the arbitrarily
898
+ # rotated ellipse is introduced.
899
+ # https://github.com/sympy/sympy/issues/2815)
900
+ if True:
901
+ rv = []
902
+ if p.x == self.center.x:
903
+ rv.append(Line(self.center, slope=oo))
904
+ if p.y == self.center.y:
905
+ rv.append(Line(self.center, slope=0))
906
+ if rv:
907
+ # at these special orientations of p either 1 or 2 normals
908
+ # exist and we are done
909
+ return rv
910
+
911
+ # find the 4 normal points and construct lines through them with
912
+ # the corresponding slope
913
+ eq = self.equation(x, y)
914
+ dydx = idiff(eq, y, x)
915
+ norm = -1/dydx
916
+ slope = Line(p, (x, y)).slope
917
+ seq = slope - norm
918
+
919
+ # TODO: Replace solve with solveset, when this line is tested
920
+ yis = solve(seq, y)[0]
921
+ xeq = eq.subs(y, yis).as_numer_denom()[0].expand()
922
+ if len(xeq.free_symbols) == 1:
923
+ try:
924
+ # this is so much faster, it's worth a try
925
+ xsol = Poly(xeq, x).real_roots()
926
+ except (DomainError, PolynomialError, NotImplementedError):
927
+ # TODO: Replace solve with solveset, when these lines are tested
928
+ xsol = _nsort(solve(xeq, x), separated=True)[0]
929
+ points = [Point(i, solve(eq.subs(x, i), y)[0]) for i in xsol]
930
+ else:
931
+ raise NotImplementedError(
932
+ 'intersections for the general ellipse are not supported')
933
+ slopes = [norm.subs(zip((x, y), pt.args)) for pt in points]
934
+ if prec is not None:
935
+ points = [pt.n(prec) for pt in points]
936
+ slopes = [i if _not_a_coeff(i) else i.n(prec) for i in slopes]
937
+ return [Line(pt, slope=s) for pt, s in zip(points, slopes)]
938
+
939
+ @property
940
+ def periapsis(self):
941
+ """The periapsis of the ellipse.
942
+
943
+ The shortest distance between the focus and the contour.
944
+
945
+ Returns
946
+ =======
947
+
948
+ periapsis : number
949
+
950
+ See Also
951
+ ========
952
+
953
+ apoapsis : Returns greatest distance between focus and contour
954
+
955
+ Examples
956
+ ========
957
+
958
+ >>> from sympy import Point, Ellipse
959
+ >>> p1 = Point(0, 0)
960
+ >>> e1 = Ellipse(p1, 3, 1)
961
+ >>> e1.periapsis
962
+ 3 - 2*sqrt(2)
963
+
964
+ """
965
+ return self.major * (1 - self.eccentricity)
966
+
967
+ @property
968
+ def semilatus_rectum(self):
969
+ """
970
+ Calculates the semi-latus rectum of the Ellipse.
971
+
972
+ Semi-latus rectum is defined as one half of the chord through a
973
+ focus parallel to the conic section directrix of a conic section.
974
+
975
+ Returns
976
+ =======
977
+
978
+ semilatus_rectum : number
979
+
980
+ See Also
981
+ ========
982
+
983
+ apoapsis : Returns greatest distance between focus and contour
984
+
985
+ periapsis : The shortest distance between the focus and the contour
986
+
987
+ Examples
988
+ ========
989
+
990
+ >>> from sympy import Point, Ellipse
991
+ >>> p1 = Point(0, 0)
992
+ >>> e1 = Ellipse(p1, 3, 1)
993
+ >>> e1.semilatus_rectum
994
+ 1/3
995
+
996
+ References
997
+ ==========
998
+
999
+ .. [1] https://mathworld.wolfram.com/SemilatusRectum.html
1000
+ .. [2] https://en.wikipedia.org/wiki/Ellipse#Semi-latus_rectum
1001
+
1002
+ """
1003
+ return self.major * (1 - self.eccentricity ** 2)
1004
+
1005
+ def auxiliary_circle(self):
1006
+ """Returns a Circle whose diameter is the major axis of the ellipse.
1007
+
1008
+ Examples
1009
+ ========
1010
+
1011
+ >>> from sympy import Ellipse, Point, symbols
1012
+ >>> c = Point(1, 2)
1013
+ >>> Ellipse(c, 8, 7).auxiliary_circle()
1014
+ Circle(Point2D(1, 2), 8)
1015
+ >>> a, b = symbols('a b')
1016
+ >>> Ellipse(c, a, b).auxiliary_circle()
1017
+ Circle(Point2D(1, 2), Max(a, b))
1018
+ """
1019
+ return Circle(self.center, Max(self.hradius, self.vradius))
1020
+
1021
+ def director_circle(self):
1022
+ """
1023
+ Returns a Circle consisting of all points where two perpendicular
1024
+ tangent lines to the ellipse cross each other.
1025
+
1026
+ Returns
1027
+ =======
1028
+
1029
+ Circle
1030
+ A director circle returned as a geometric object.
1031
+
1032
+ Examples
1033
+ ========
1034
+
1035
+ >>> from sympy import Ellipse, Point, symbols
1036
+ >>> c = Point(3,8)
1037
+ >>> Ellipse(c, 7, 9).director_circle()
1038
+ Circle(Point2D(3, 8), sqrt(130))
1039
+ >>> a, b = symbols('a b')
1040
+ >>> Ellipse(c, a, b).director_circle()
1041
+ Circle(Point2D(3, 8), sqrt(a**2 + b**2))
1042
+
1043
+ References
1044
+ ==========
1045
+
1046
+ .. [1] https://en.wikipedia.org/wiki/Director_circle
1047
+
1048
+ """
1049
+ return Circle(self.center, sqrt(self.hradius**2 + self.vradius**2))
1050
+
1051
+ def plot_interval(self, parameter='t'):
1052
+ """The plot interval for the default geometric plot of the Ellipse.
1053
+
1054
+ Parameters
1055
+ ==========
1056
+
1057
+ parameter : str, optional
1058
+ Default value is 't'.
1059
+
1060
+ Returns
1061
+ =======
1062
+
1063
+ plot_interval : list
1064
+ [parameter, lower_bound, upper_bound]
1065
+
1066
+ Examples
1067
+ ========
1068
+
1069
+ >>> from sympy import Point, Ellipse
1070
+ >>> e1 = Ellipse(Point(0, 0), 3, 2)
1071
+ >>> e1.plot_interval()
1072
+ [t, -pi, pi]
1073
+
1074
+ """
1075
+ t = _symbol(parameter, real=True)
1076
+ return [t, -S.Pi, S.Pi]
1077
+
1078
+ def random_point(self, seed=None):
1079
+ """A random point on the ellipse.
1080
+
1081
+ Returns
1082
+ =======
1083
+
1084
+ point : Point
1085
+
1086
+ Examples
1087
+ ========
1088
+
1089
+ >>> from sympy import Point, Ellipse
1090
+ >>> e1 = Ellipse(Point(0, 0), 3, 2)
1091
+ >>> e1.random_point() # gives some random point
1092
+ Point2D(...)
1093
+ >>> p1 = e1.random_point(seed=0); p1.n(2)
1094
+ Point2D(2.1, 1.4)
1095
+
1096
+ Notes
1097
+ =====
1098
+
1099
+ When creating a random point, one may simply replace the
1100
+ parameter with a random number. When doing so, however, the
1101
+ random number should be made a Rational or else the point
1102
+ may not test as being in the ellipse:
1103
+
1104
+ >>> from sympy.abc import t
1105
+ >>> from sympy import Rational
1106
+ >>> arb = e1.arbitrary_point(t); arb
1107
+ Point2D(3*cos(t), 2*sin(t))
1108
+ >>> arb.subs(t, .1) in e1
1109
+ False
1110
+ >>> arb.subs(t, Rational(.1)) in e1
1111
+ True
1112
+ >>> arb.subs(t, Rational('.1')) in e1
1113
+ True
1114
+
1115
+ See Also
1116
+ ========
1117
+ sympy.geometry.point.Point
1118
+ arbitrary_point : Returns parameterized point on ellipse
1119
+ """
1120
+ t = _symbol('t', real=True)
1121
+ x, y = self.arbitrary_point(t).args
1122
+ # get a random value in [-1, 1) corresponding to cos(t)
1123
+ # and confirm that it will test as being in the ellipse
1124
+ if seed is not None:
1125
+ rng = random.Random(seed)
1126
+ else:
1127
+ rng = random
1128
+ # simplify this now or else the Float will turn s into a Float
1129
+ r = Rational(rng.random())
1130
+ c = 2*r - 1
1131
+ s = sqrt(1 - c**2)
1132
+ return Point(x.subs(cos(t), c), y.subs(sin(t), s))
1133
+
1134
+ def reflect(self, line):
1135
+ """Override GeometryEntity.reflect since the radius
1136
+ is not a GeometryEntity.
1137
+
1138
+ Examples
1139
+ ========
1140
+
1141
+ >>> from sympy import Circle, Line
1142
+ >>> Circle((0, 1), 1).reflect(Line((0, 0), (1, 1)))
1143
+ Circle(Point2D(1, 0), -1)
1144
+ >>> from sympy import Ellipse, Line, Point
1145
+ >>> Ellipse(Point(3, 4), 1, 3).reflect(Line(Point(0, -4), Point(5, 0)))
1146
+ Traceback (most recent call last):
1147
+ ...
1148
+ NotImplementedError:
1149
+ General Ellipse is not supported but the equation of the reflected
1150
+ Ellipse is given by the zeros of: f(x, y) = (9*x/41 + 40*y/41 +
1151
+ 37/41)**2 + (40*x/123 - 3*y/41 - 364/123)**2 - 1
1152
+
1153
+ Notes
1154
+ =====
1155
+
1156
+ Until the general ellipse (with no axis parallel to the x-axis) is
1157
+ supported a NotImplemented error is raised and the equation whose
1158
+ zeros define the rotated ellipse is given.
1159
+
1160
+ """
1161
+
1162
+ if line.slope in (0, oo):
1163
+ c = self.center
1164
+ c = c.reflect(line)
1165
+ return self.func(c, -self.hradius, self.vradius)
1166
+ else:
1167
+ x, y = [uniquely_named_symbol(
1168
+ name, (self, line), modify=lambda s: '_' + s, real=True)
1169
+ for name in 'xy']
1170
+ expr = self.equation(x, y)
1171
+ p = Point(x, y).reflect(line)
1172
+ result = expr.subs(zip((x, y), p.args
1173
+ ), simultaneous=True)
1174
+ raise NotImplementedError(filldedent(
1175
+ 'General Ellipse is not supported but the equation '
1176
+ 'of the reflected Ellipse is given by the zeros of: ' +
1177
+ "f(%s, %s) = %s" % (str(x), str(y), str(result))))
1178
+
1179
+ def rotate(self, angle=0, pt=None):
1180
+ """Rotate ``angle`` radians counterclockwise about Point ``pt``.
1181
+
1182
+ Note: since the general ellipse is not supported, only rotations that
1183
+ are integer multiples of pi/2 are allowed.
1184
+
1185
+ Examples
1186
+ ========
1187
+
1188
+ >>> from sympy import Ellipse, pi
1189
+ >>> Ellipse((1, 0), 2, 1).rotate(pi/2)
1190
+ Ellipse(Point2D(0, 1), 1, 2)
1191
+ >>> Ellipse((1, 0), 2, 1).rotate(pi)
1192
+ Ellipse(Point2D(-1, 0), 2, 1)
1193
+ """
1194
+ if self.hradius == self.vradius:
1195
+ return self.func(self.center.rotate(angle, pt), self.hradius)
1196
+ if (angle/S.Pi).is_integer:
1197
+ return super().rotate(angle, pt)
1198
+ if (2*angle/S.Pi).is_integer:
1199
+ return self.func(self.center.rotate(angle, pt), self.vradius, self.hradius)
1200
+ # XXX see https://github.com/sympy/sympy/issues/2815 for general ellipes
1201
+ raise NotImplementedError('Only rotations of pi/2 are currently supported for Ellipse.')
1202
+
1203
+ def scale(self, x=1, y=1, pt=None):
1204
+ """Override GeometryEntity.scale since it is the major and minor
1205
+ axes which must be scaled and they are not GeometryEntities.
1206
+
1207
+ Examples
1208
+ ========
1209
+
1210
+ >>> from sympy import Ellipse
1211
+ >>> Ellipse((0, 0), 2, 1).scale(2, 4)
1212
+ Circle(Point2D(0, 0), 4)
1213
+ >>> Ellipse((0, 0), 2, 1).scale(2)
1214
+ Ellipse(Point2D(0, 0), 4, 1)
1215
+ """
1216
+ c = self.center
1217
+ if pt:
1218
+ pt = Point(pt, dim=2)
1219
+ return self.translate(*(-pt).args).scale(x, y).translate(*pt.args)
1220
+ h = self.hradius
1221
+ v = self.vradius
1222
+ return self.func(c.scale(x, y), hradius=h*x, vradius=v*y)
1223
+
1224
+ def tangent_lines(self, p):
1225
+ """Tangent lines between `p` and the ellipse.
1226
+
1227
+ If `p` is on the ellipse, returns the tangent line through point `p`.
1228
+ Otherwise, returns the tangent line(s) from `p` to the ellipse, or
1229
+ None if no tangent line is possible (e.g., `p` inside ellipse).
1230
+
1231
+ Parameters
1232
+ ==========
1233
+
1234
+ p : Point
1235
+
1236
+ Returns
1237
+ =======
1238
+
1239
+ tangent_lines : list with 1 or 2 Lines
1240
+
1241
+ Raises
1242
+ ======
1243
+
1244
+ NotImplementedError
1245
+ Can only find tangent lines for a point, `p`, on the ellipse.
1246
+
1247
+ See Also
1248
+ ========
1249
+
1250
+ sympy.geometry.point.Point, sympy.geometry.line.Line
1251
+
1252
+ Examples
1253
+ ========
1254
+
1255
+ >>> from sympy import Point, Ellipse
1256
+ >>> e1 = Ellipse(Point(0, 0), 3, 2)
1257
+ >>> e1.tangent_lines(Point(3, 0))
1258
+ [Line2D(Point2D(3, 0), Point2D(3, -12))]
1259
+
1260
+ """
1261
+ p = Point(p, dim=2)
1262
+ if self.encloses_point(p):
1263
+ return []
1264
+
1265
+ if p in self:
1266
+ delta = self.center - p
1267
+ rise = (self.vradius**2)*delta.x
1268
+ run = -(self.hradius**2)*delta.y
1269
+ p2 = Point(simplify(p.x + run),
1270
+ simplify(p.y + rise))
1271
+ return [Line(p, p2)]
1272
+ else:
1273
+ if len(self.foci) == 2:
1274
+ f1, f2 = self.foci
1275
+ maj = self.hradius
1276
+ test = (2*maj -
1277
+ Point.distance(f1, p) -
1278
+ Point.distance(f2, p))
1279
+ else:
1280
+ test = self.radius - Point.distance(self.center, p)
1281
+ if test.is_number and test.is_positive:
1282
+ return []
1283
+ # else p is outside the ellipse or we can't tell. In case of the
1284
+ # latter, the solutions returned will only be valid if
1285
+ # the point is not inside the ellipse; if it is, nan will result.
1286
+ eq = self.equation(x, y)
1287
+ dydx = idiff(eq, y, x)
1288
+ slope = Line(p, Point(x, y)).slope
1289
+
1290
+ # TODO: Replace solve with solveset, when this line is tested
1291
+ tangent_points = solve([slope - dydx, eq], [x, y])
1292
+
1293
+ # handle horizontal and vertical tangent lines
1294
+ if len(tangent_points) == 1:
1295
+ if tangent_points[0][
1296
+ 0] == p.x or tangent_points[0][1] == p.y:
1297
+ return [Line(p, p + Point(1, 0)), Line(p, p + Point(0, 1))]
1298
+ else:
1299
+ return [Line(p, p + Point(0, 1)), Line(p, tangent_points[0])]
1300
+
1301
+ # others
1302
+ return [Line(p, tangent_points[0]), Line(p, tangent_points[1])]
1303
+
1304
+ @property
1305
+ def vradius(self):
1306
+ """The vertical radius of the ellipse.
1307
+
1308
+ Returns
1309
+ =======
1310
+
1311
+ vradius : number
1312
+
1313
+ See Also
1314
+ ========
1315
+
1316
+ hradius, major, minor
1317
+
1318
+ Examples
1319
+ ========
1320
+
1321
+ >>> from sympy import Point, Ellipse
1322
+ >>> p1 = Point(0, 0)
1323
+ >>> e1 = Ellipse(p1, 3, 1)
1324
+ >>> e1.vradius
1325
+ 1
1326
+
1327
+ """
1328
+ return self.args[2]
1329
+
1330
+
1331
+ def second_moment_of_area(self, point=None):
1332
+ """Returns the second moment and product moment area of an ellipse.
1333
+
1334
+ Parameters
1335
+ ==========
1336
+
1337
+ point : Point, two-tuple of sympifiable objects, or None(default=None)
1338
+ point is the point about which second moment of area is to be found.
1339
+ If "point=None" it will be calculated about the axis passing through the
1340
+ centroid of the ellipse.
1341
+
1342
+ Returns
1343
+ =======
1344
+
1345
+ I_xx, I_yy, I_xy : number or SymPy expression
1346
+ I_xx, I_yy are second moment of area of an ellise.
1347
+ I_xy is product moment of area of an ellipse.
1348
+
1349
+ Examples
1350
+ ========
1351
+
1352
+ >>> from sympy import Point, Ellipse
1353
+ >>> p1 = Point(0, 0)
1354
+ >>> e1 = Ellipse(p1, 3, 1)
1355
+ >>> e1.second_moment_of_area()
1356
+ (3*pi/4, 27*pi/4, 0)
1357
+
1358
+ References
1359
+ ==========
1360
+
1361
+ .. [1] https://en.wikipedia.org/wiki/List_of_second_moments_of_area
1362
+
1363
+ """
1364
+
1365
+ I_xx = (S.Pi*(self.hradius)*(self.vradius**3))/4
1366
+ I_yy = (S.Pi*(self.hradius**3)*(self.vradius))/4
1367
+ I_xy = 0
1368
+
1369
+ if point is None:
1370
+ return I_xx, I_yy, I_xy
1371
+
1372
+ # parallel axis theorem
1373
+ I_xx = I_xx + self.area*((point[1] - self.center.y)**2)
1374
+ I_yy = I_yy + self.area*((point[0] - self.center.x)**2)
1375
+ I_xy = I_xy + self.area*(point[0] - self.center.x)*(point[1] - self.center.y)
1376
+
1377
+ return I_xx, I_yy, I_xy
1378
+
1379
+
1380
+ def polar_second_moment_of_area(self):
1381
+ """Returns the polar second moment of area of an Ellipse
1382
+
1383
+ It is a constituent of the second moment of area, linked through
1384
+ the perpendicular axis theorem. While the planar second moment of
1385
+ area describes an object's resistance to deflection (bending) when
1386
+ subjected to a force applied to a plane parallel to the central
1387
+ axis, the polar second moment of area describes an object's
1388
+ resistance to deflection when subjected to a moment applied in a
1389
+ plane perpendicular to the object's central axis (i.e. parallel to
1390
+ the cross-section)
1391
+
1392
+ Examples
1393
+ ========
1394
+
1395
+ >>> from sympy import symbols, Circle, Ellipse
1396
+ >>> c = Circle((5, 5), 4)
1397
+ >>> c.polar_second_moment_of_area()
1398
+ 128*pi
1399
+ >>> a, b = symbols('a, b')
1400
+ >>> e = Ellipse((0, 0), a, b)
1401
+ >>> e.polar_second_moment_of_area()
1402
+ pi*a**3*b/4 + pi*a*b**3/4
1403
+
1404
+ References
1405
+ ==========
1406
+
1407
+ .. [1] https://en.wikipedia.org/wiki/Polar_moment_of_inertia
1408
+
1409
+ """
1410
+ second_moment = self.second_moment_of_area()
1411
+ return second_moment[0] + second_moment[1]
1412
+
1413
+
1414
+ def section_modulus(self, point=None):
1415
+ """Returns a tuple with the section modulus of an ellipse
1416
+
1417
+ Section modulus is a geometric property of an ellipse defined as the
1418
+ ratio of second moment of area to the distance of the extreme end of
1419
+ the ellipse from the centroidal axis.
1420
+
1421
+ Parameters
1422
+ ==========
1423
+
1424
+ point : Point, two-tuple of sympifyable objects, or None(default=None)
1425
+ point is the point at which section modulus is to be found.
1426
+ If "point=None" section modulus will be calculated for the
1427
+ point farthest from the centroidal axis of the ellipse.
1428
+
1429
+ Returns
1430
+ =======
1431
+
1432
+ S_x, S_y: numbers or SymPy expressions
1433
+ S_x is the section modulus with respect to the x-axis
1434
+ S_y is the section modulus with respect to the y-axis
1435
+ A negative sign indicates that the section modulus is
1436
+ determined for a point below the centroidal axis.
1437
+
1438
+ Examples
1439
+ ========
1440
+
1441
+ >>> from sympy import Symbol, Ellipse, Circle, Point2D
1442
+ >>> d = Symbol('d', positive=True)
1443
+ >>> c = Circle((0, 0), d/2)
1444
+ >>> c.section_modulus()
1445
+ (pi*d**3/32, pi*d**3/32)
1446
+ >>> e = Ellipse(Point2D(0, 0), 2, 4)
1447
+ >>> e.section_modulus()
1448
+ (8*pi, 4*pi)
1449
+ >>> e.section_modulus((2, 2))
1450
+ (16*pi, 4*pi)
1451
+
1452
+ References
1453
+ ==========
1454
+
1455
+ .. [1] https://en.wikipedia.org/wiki/Section_modulus
1456
+
1457
+ """
1458
+ x_c, y_c = self.center
1459
+ if point is None:
1460
+ # taking x and y as maximum distances from centroid
1461
+ x_min, y_min, x_max, y_max = self.bounds
1462
+ y = max(y_c - y_min, y_max - y_c)
1463
+ x = max(x_c - x_min, x_max - x_c)
1464
+ else:
1465
+ # taking x and y as distances of the given point from the center
1466
+ point = Point2D(point)
1467
+ y = point.y - y_c
1468
+ x = point.x - x_c
1469
+
1470
+ second_moment = self.second_moment_of_area()
1471
+ S_x = second_moment[0]/y
1472
+ S_y = second_moment[1]/x
1473
+
1474
+ return S_x, S_y
1475
+
1476
+
1477
+ class Circle(Ellipse):
1478
+ r"""A circle in space.
1479
+
1480
+ Constructed simply from a center and a radius, from three
1481
+ non-collinear points, or the equation of a circle.
1482
+
1483
+ Parameters
1484
+ ==========
1485
+
1486
+ center : Point
1487
+ radius : number or SymPy expression
1488
+ points : sequence of three Points
1489
+ equation : equation of a circle
1490
+
1491
+ Attributes
1492
+ ==========
1493
+
1494
+ radius (synonymous with hradius, vradius, major and minor)
1495
+ circumference
1496
+ equation
1497
+
1498
+ Raises
1499
+ ======
1500
+
1501
+ GeometryError
1502
+ When the given equation is not that of a circle.
1503
+ When trying to construct circle from incorrect parameters.
1504
+
1505
+ See Also
1506
+ ========
1507
+
1508
+ Ellipse, sympy.geometry.point.Point
1509
+
1510
+ Examples
1511
+ ========
1512
+
1513
+ >>> from sympy import Point, Circle, Eq
1514
+ >>> from sympy.abc import x, y, a, b
1515
+
1516
+ A circle constructed from a center and radius:
1517
+
1518
+ >>> c1 = Circle(Point(0, 0), 5)
1519
+ >>> c1.hradius, c1.vradius, c1.radius
1520
+ (5, 5, 5)
1521
+
1522
+ A circle constructed from three points:
1523
+
1524
+ >>> c2 = Circle(Point(0, 0), Point(1, 1), Point(1, 0))
1525
+ >>> c2.hradius, c2.vradius, c2.radius, c2.center
1526
+ (sqrt(2)/2, sqrt(2)/2, sqrt(2)/2, Point2D(1/2, 1/2))
1527
+
1528
+ A circle can be constructed from an equation in the form
1529
+ `ax^2 + by^2 + gx + hy + c = 0`, too:
1530
+
1531
+ >>> Circle(x**2 + y**2 - 25)
1532
+ Circle(Point2D(0, 0), 5)
1533
+
1534
+ If the variables corresponding to x and y are named something
1535
+ else, their name or symbol can be supplied:
1536
+
1537
+ >>> Circle(Eq(a**2 + b**2, 25), x='a', y=b)
1538
+ Circle(Point2D(0, 0), 5)
1539
+ """
1540
+
1541
+ def __new__(cls, *args, **kwargs):
1542
+ evaluate = kwargs.get('evaluate', global_parameters.evaluate)
1543
+ if len(args) == 1 and isinstance(args[0], (Expr, Eq)):
1544
+ x = kwargs.get('x', 'x')
1545
+ y = kwargs.get('y', 'y')
1546
+ equation = args[0].expand()
1547
+ if isinstance(equation, Eq):
1548
+ equation = equation.lhs - equation.rhs
1549
+ x = find(x, equation)
1550
+ y = find(y, equation)
1551
+
1552
+ try:
1553
+ a, b, c, d, e = linear_coeffs(equation, x**2, y**2, x, y)
1554
+ except ValueError:
1555
+ raise GeometryError("The given equation is not that of a circle.")
1556
+
1557
+ if S.Zero in (a, b) or a != b:
1558
+ raise GeometryError("The given equation is not that of a circle.")
1559
+
1560
+ center_x = -c/a/2
1561
+ center_y = -d/b/2
1562
+ r2 = (center_x**2) + (center_y**2) - e/a
1563
+
1564
+ return Circle((center_x, center_y), sqrt(r2), evaluate=evaluate)
1565
+
1566
+ else:
1567
+ c, r = None, None
1568
+ if len(args) == 3:
1569
+ args = [Point(a, dim=2, evaluate=evaluate) for a in args]
1570
+ t = Triangle(*args)
1571
+ if not isinstance(t, Triangle):
1572
+ return t
1573
+ c = t.circumcenter
1574
+ r = t.circumradius
1575
+ elif len(args) == 2:
1576
+ # Assume (center, radius) pair
1577
+ c = Point(args[0], dim=2, evaluate=evaluate)
1578
+ r = args[1]
1579
+ # this will prohibit imaginary radius
1580
+ try:
1581
+ r = Point(r, 0, evaluate=evaluate).x
1582
+ except ValueError:
1583
+ raise GeometryError("Circle with imaginary radius is not permitted")
1584
+
1585
+ if not (c is None or r is None):
1586
+ if r == 0:
1587
+ return c
1588
+ return GeometryEntity.__new__(cls, c, r, **kwargs)
1589
+
1590
+ raise GeometryError("Circle.__new__ received unknown arguments")
1591
+
1592
+ def _eval_evalf(self, prec=15, **options):
1593
+ pt, r = self.args
1594
+ dps = prec_to_dps(prec)
1595
+ pt = pt.evalf(n=dps, **options)
1596
+ r = r.evalf(n=dps, **options)
1597
+ return self.func(pt, r, evaluate=False)
1598
+
1599
+ @property
1600
+ def circumference(self):
1601
+ """The circumference of the circle.
1602
+
1603
+ Returns
1604
+ =======
1605
+
1606
+ circumference : number or SymPy expression
1607
+
1608
+ Examples
1609
+ ========
1610
+
1611
+ >>> from sympy import Point, Circle
1612
+ >>> c1 = Circle(Point(3, 4), 6)
1613
+ >>> c1.circumference
1614
+ 12*pi
1615
+
1616
+ """
1617
+ return 2 * S.Pi * self.radius
1618
+
1619
+ def equation(self, x='x', y='y'):
1620
+ """The equation of the circle.
1621
+
1622
+ Parameters
1623
+ ==========
1624
+
1625
+ x : str or Symbol, optional
1626
+ Default value is 'x'.
1627
+ y : str or Symbol, optional
1628
+ Default value is 'y'.
1629
+
1630
+ Returns
1631
+ =======
1632
+
1633
+ equation : SymPy expression
1634
+
1635
+ Examples
1636
+ ========
1637
+
1638
+ >>> from sympy import Point, Circle
1639
+ >>> c1 = Circle(Point(0, 0), 5)
1640
+ >>> c1.equation()
1641
+ x**2 + y**2 - 25
1642
+
1643
+ """
1644
+ x = _symbol(x, real=True)
1645
+ y = _symbol(y, real=True)
1646
+ t1 = (x - self.center.x)**2
1647
+ t2 = (y - self.center.y)**2
1648
+ return t1 + t2 - self.major**2
1649
+
1650
+ def intersection(self, o):
1651
+ """The intersection of this circle with another geometrical entity.
1652
+
1653
+ Parameters
1654
+ ==========
1655
+
1656
+ o : GeometryEntity
1657
+
1658
+ Returns
1659
+ =======
1660
+
1661
+ intersection : list of GeometryEntities
1662
+
1663
+ Examples
1664
+ ========
1665
+
1666
+ >>> from sympy import Point, Circle, Line, Ray
1667
+ >>> p1, p2, p3 = Point(0, 0), Point(5, 5), Point(6, 0)
1668
+ >>> p4 = Point(5, 0)
1669
+ >>> c1 = Circle(p1, 5)
1670
+ >>> c1.intersection(p2)
1671
+ []
1672
+ >>> c1.intersection(p4)
1673
+ [Point2D(5, 0)]
1674
+ >>> c1.intersection(Ray(p1, p2))
1675
+ [Point2D(5*sqrt(2)/2, 5*sqrt(2)/2)]
1676
+ >>> c1.intersection(Line(p2, p3))
1677
+ []
1678
+
1679
+ """
1680
+ return Ellipse.intersection(self, o)
1681
+
1682
+ @property
1683
+ def radius(self):
1684
+ """The radius of the circle.
1685
+
1686
+ Returns
1687
+ =======
1688
+
1689
+ radius : number or SymPy expression
1690
+
1691
+ See Also
1692
+ ========
1693
+
1694
+ Ellipse.major, Ellipse.minor, Ellipse.hradius, Ellipse.vradius
1695
+
1696
+ Examples
1697
+ ========
1698
+
1699
+ >>> from sympy import Point, Circle
1700
+ >>> c1 = Circle(Point(3, 4), 6)
1701
+ >>> c1.radius
1702
+ 6
1703
+
1704
+ """
1705
+ return self.args[1]
1706
+
1707
+ def reflect(self, line):
1708
+ """Override GeometryEntity.reflect since the radius
1709
+ is not a GeometryEntity.
1710
+
1711
+ Examples
1712
+ ========
1713
+
1714
+ >>> from sympy import Circle, Line
1715
+ >>> Circle((0, 1), 1).reflect(Line((0, 0), (1, 1)))
1716
+ Circle(Point2D(1, 0), -1)
1717
+ """
1718
+ c = self.center
1719
+ c = c.reflect(line)
1720
+ return self.func(c, -self.radius)
1721
+
1722
+ def scale(self, x=1, y=1, pt=None):
1723
+ """Override GeometryEntity.scale since the radius
1724
+ is not a GeometryEntity.
1725
+
1726
+ Examples
1727
+ ========
1728
+
1729
+ >>> from sympy import Circle
1730
+ >>> Circle((0, 0), 1).scale(2, 2)
1731
+ Circle(Point2D(0, 0), 2)
1732
+ >>> Circle((0, 0), 1).scale(2, 4)
1733
+ Ellipse(Point2D(0, 0), 2, 4)
1734
+ """
1735
+ c = self.center
1736
+ if pt:
1737
+ pt = Point(pt, dim=2)
1738
+ return self.translate(*(-pt).args).scale(x, y).translate(*pt.args)
1739
+ c = c.scale(x, y)
1740
+ x, y = [abs(i) for i in (x, y)]
1741
+ if x == y:
1742
+ return self.func(c, x*self.radius)
1743
+ h = v = self.radius
1744
+ return Ellipse(c, hradius=h*x, vradius=v*y)
1745
+
1746
+ @property
1747
+ def vradius(self):
1748
+ """
1749
+ This Ellipse property is an alias for the Circle's radius.
1750
+
1751
+ Whereas hradius, major and minor can use Ellipse's conventions,
1752
+ the vradius does not exist for a circle. It is always a positive
1753
+ value in order that the Circle, like Polygons, will have an
1754
+ area that can be positive or negative as determined by the sign
1755
+ of the hradius.
1756
+
1757
+ Examples
1758
+ ========
1759
+
1760
+ >>> from sympy import Point, Circle
1761
+ >>> c1 = Circle(Point(3, 4), 6)
1762
+ >>> c1.vradius
1763
+ 6
1764
+ """
1765
+ return abs(self.radius)
1766
+
1767
+
1768
+ from .polygon import Polygon, Triangle
.venv/lib/python3.13/site-packages/sympy/geometry/entity.py ADDED
@@ -0,0 +1,641 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """The definition of the base geometrical entity with attributes common to
2
+ all derived geometrical entities.
3
+
4
+ Contains
5
+ ========
6
+
7
+ GeometryEntity
8
+ GeometricSet
9
+
10
+ Notes
11
+ =====
12
+
13
+ A GeometryEntity is any object that has special geometric properties.
14
+ A GeometrySet is a superclass of any GeometryEntity that can also
15
+ be viewed as a sympy.sets.Set. In particular, points are the only
16
+ GeometryEntity not considered a Set.
17
+
18
+ Rn is a GeometrySet representing n-dimensional Euclidean space. R2 and
19
+ R3 are currently the only ambient spaces implemented.
20
+
21
+ """
22
+ from __future__ import annotations
23
+
24
+ from sympy.core.basic import Basic
25
+ from sympy.core.containers import Tuple
26
+ from sympy.core.evalf import EvalfMixin, N
27
+ from sympy.core.numbers import oo
28
+ from sympy.core.symbol import Dummy
29
+ from sympy.core.sympify import sympify
30
+ from sympy.functions.elementary.trigonometric import cos, sin, atan
31
+ from sympy.matrices import eye
32
+ from sympy.multipledispatch import dispatch
33
+ from sympy.printing import sstr
34
+ from sympy.sets import Set, Union, FiniteSet
35
+ from sympy.sets.handlers.intersection import intersection_sets
36
+ from sympy.sets.handlers.union import union_sets
37
+ from sympy.solvers.solvers import solve
38
+ from sympy.utilities.misc import func_name
39
+ from sympy.utilities.iterables import is_sequence
40
+
41
+
42
+ # How entities are ordered; used by __cmp__ in GeometryEntity
43
+ ordering_of_classes = [
44
+ "Point2D",
45
+ "Point3D",
46
+ "Point",
47
+ "Segment2D",
48
+ "Ray2D",
49
+ "Line2D",
50
+ "Segment3D",
51
+ "Line3D",
52
+ "Ray3D",
53
+ "Segment",
54
+ "Ray",
55
+ "Line",
56
+ "Plane",
57
+ "Triangle",
58
+ "RegularPolygon",
59
+ "Polygon",
60
+ "Circle",
61
+ "Ellipse",
62
+ "Curve",
63
+ "Parabola"
64
+ ]
65
+
66
+
67
+ x, y = [Dummy('entity_dummy') for i in range(2)]
68
+ T = Dummy('entity_dummy', real=True)
69
+
70
+
71
+ class GeometryEntity(Basic, EvalfMixin):
72
+ """The base class for all geometrical entities.
73
+
74
+ This class does not represent any particular geometric entity, it only
75
+ provides the implementation of some methods common to all subclasses.
76
+
77
+ """
78
+
79
+ __slots__: tuple[str, ...] = ()
80
+
81
+ def __cmp__(self, other):
82
+ """Comparison of two GeometryEntities."""
83
+ n1 = self.__class__.__name__
84
+ n2 = other.__class__.__name__
85
+ c = (n1 > n2) - (n1 < n2)
86
+ if not c:
87
+ return 0
88
+
89
+ i1 = -1
90
+ for cls in self.__class__.__mro__:
91
+ try:
92
+ i1 = ordering_of_classes.index(cls.__name__)
93
+ break
94
+ except ValueError:
95
+ i1 = -1
96
+ if i1 == -1:
97
+ return c
98
+
99
+ i2 = -1
100
+ for cls in other.__class__.__mro__:
101
+ try:
102
+ i2 = ordering_of_classes.index(cls.__name__)
103
+ break
104
+ except ValueError:
105
+ i2 = -1
106
+ if i2 == -1:
107
+ return c
108
+
109
+ return (i1 > i2) - (i1 < i2)
110
+
111
+ def __contains__(self, other):
112
+ """Subclasses should implement this method for anything more complex than equality."""
113
+ if type(self) is type(other):
114
+ return self == other
115
+ raise NotImplementedError()
116
+
117
+ def __getnewargs__(self):
118
+ """Returns a tuple that will be passed to __new__ on unpickling."""
119
+ return tuple(self.args)
120
+
121
+ def __ne__(self, o):
122
+ """Test inequality of two geometrical entities."""
123
+ return not self == o
124
+
125
+ def __new__(cls, *args, **kwargs):
126
+ # Points are sequences, but they should not
127
+ # be converted to Tuples, so use this detection function instead.
128
+ def is_seq_and_not_point(a):
129
+ # we cannot use isinstance(a, Point) since we cannot import Point
130
+ if hasattr(a, 'is_Point') and a.is_Point:
131
+ return False
132
+ return is_sequence(a)
133
+
134
+ args = [Tuple(*a) if is_seq_and_not_point(a) else sympify(a) for a in args]
135
+ return Basic.__new__(cls, *args)
136
+
137
+ def __radd__(self, a):
138
+ """Implementation of reverse add method."""
139
+ return a.__add__(self)
140
+
141
+ def __rtruediv__(self, a):
142
+ """Implementation of reverse division method."""
143
+ return a.__truediv__(self)
144
+
145
+ def __repr__(self):
146
+ """String representation of a GeometryEntity that can be evaluated
147
+ by sympy."""
148
+ return type(self).__name__ + repr(self.args)
149
+
150
+ def __rmul__(self, a):
151
+ """Implementation of reverse multiplication method."""
152
+ return a.__mul__(self)
153
+
154
+ def __rsub__(self, a):
155
+ """Implementation of reverse subtraction method."""
156
+ return a.__sub__(self)
157
+
158
+ def __str__(self):
159
+ """String representation of a GeometryEntity."""
160
+ return type(self).__name__ + sstr(self.args)
161
+
162
+ def _eval_subs(self, old, new):
163
+ from sympy.geometry.point import Point, Point3D
164
+ if is_sequence(old) or is_sequence(new):
165
+ if isinstance(self, Point3D):
166
+ old = Point3D(old)
167
+ new = Point3D(new)
168
+ else:
169
+ old = Point(old)
170
+ new = Point(new)
171
+ return self._subs(old, new)
172
+
173
+ def _repr_svg_(self):
174
+ """SVG representation of a GeometryEntity suitable for IPython"""
175
+
176
+ try:
177
+ bounds = self.bounds
178
+ except (NotImplementedError, TypeError):
179
+ # if we have no SVG representation, return None so IPython
180
+ # will fall back to the next representation
181
+ return None
182
+
183
+ if not all(x.is_number and x.is_finite for x in bounds):
184
+ return None
185
+
186
+ svg_top = '''<svg xmlns="http://www.w3.org/2000/svg"
187
+ xmlns:xlink="http://www.w3.org/1999/xlink"
188
+ width="{1}" height="{2}" viewBox="{0}"
189
+ preserveAspectRatio="xMinYMin meet">
190
+ <defs>
191
+ <marker id="markerCircle" markerWidth="8" markerHeight="8"
192
+ refx="5" refy="5" markerUnits="strokeWidth">
193
+ <circle cx="5" cy="5" r="1.5" style="stroke: none; fill:#000000;"/>
194
+ </marker>
195
+ <marker id="markerArrow" markerWidth="13" markerHeight="13" refx="2" refy="4"
196
+ orient="auto" markerUnits="strokeWidth">
197
+ <path d="M2,2 L2,6 L6,4" style="fill: #000000;" />
198
+ </marker>
199
+ <marker id="markerReverseArrow" markerWidth="13" markerHeight="13" refx="6" refy="4"
200
+ orient="auto" markerUnits="strokeWidth">
201
+ <path d="M6,2 L6,6 L2,4" style="fill: #000000;" />
202
+ </marker>
203
+ </defs>'''
204
+
205
+ # Establish SVG canvas that will fit all the data + small space
206
+ xmin, ymin, xmax, ymax = map(N, bounds)
207
+ if xmin == xmax and ymin == ymax:
208
+ # This is a point; buffer using an arbitrary size
209
+ xmin, ymin, xmax, ymax = xmin - .5, ymin -.5, xmax + .5, ymax + .5
210
+ else:
211
+ # Expand bounds by a fraction of the data ranges
212
+ expand = 0.1 # or 10%; this keeps arrowheads in view (R plots use 4%)
213
+ widest_part = max([xmax - xmin, ymax - ymin])
214
+ expand_amount = widest_part * expand
215
+ xmin -= expand_amount
216
+ ymin -= expand_amount
217
+ xmax += expand_amount
218
+ ymax += expand_amount
219
+ dx = xmax - xmin
220
+ dy = ymax - ymin
221
+ width = min([max([100., dx]), 300])
222
+ height = min([max([100., dy]), 300])
223
+
224
+ scale_factor = 1. if max(width, height) == 0 else max(dx, dy) / max(width, height)
225
+ try:
226
+ svg = self._svg(scale_factor)
227
+ except (NotImplementedError, TypeError):
228
+ # if we have no SVG representation, return None so IPython
229
+ # will fall back to the next representation
230
+ return None
231
+
232
+ view_box = "{} {} {} {}".format(xmin, ymin, dx, dy)
233
+ transform = "matrix(1,0,0,-1,0,{})".format(ymax + ymin)
234
+ svg_top = svg_top.format(view_box, width, height)
235
+
236
+ return svg_top + (
237
+ '<g transform="{}">{}</g></svg>'
238
+ ).format(transform, svg)
239
+
240
+ def _svg(self, scale_factor=1., fill_color="#66cc99"):
241
+ """Returns SVG path element for the GeometryEntity.
242
+
243
+ Parameters
244
+ ==========
245
+
246
+ scale_factor : float
247
+ Multiplication factor for the SVG stroke-width. Default is 1.
248
+ fill_color : str, optional
249
+ Hex string for fill color. Default is "#66cc99".
250
+ """
251
+ raise NotImplementedError()
252
+
253
+ def _sympy_(self):
254
+ return self
255
+
256
+ @property
257
+ def ambient_dimension(self):
258
+ """What is the dimension of the space that the object is contained in?"""
259
+ raise NotImplementedError()
260
+
261
+ @property
262
+ def bounds(self):
263
+ """Return a tuple (xmin, ymin, xmax, ymax) representing the bounding
264
+ rectangle for the geometric figure.
265
+
266
+ """
267
+
268
+ raise NotImplementedError()
269
+
270
+ def encloses(self, o):
271
+ """
272
+ Return True if o is inside (not on or outside) the boundaries of self.
273
+
274
+ The object will be decomposed into Points and individual Entities need
275
+ only define an encloses_point method for their class.
276
+
277
+ See Also
278
+ ========
279
+
280
+ sympy.geometry.ellipse.Ellipse.encloses_point
281
+ sympy.geometry.polygon.Polygon.encloses_point
282
+
283
+ Examples
284
+ ========
285
+
286
+ >>> from sympy import RegularPolygon, Point, Polygon
287
+ >>> t = Polygon(*RegularPolygon(Point(0, 0), 1, 3).vertices)
288
+ >>> t2 = Polygon(*RegularPolygon(Point(0, 0), 2, 3).vertices)
289
+ >>> t2.encloses(t)
290
+ True
291
+ >>> t.encloses(t2)
292
+ False
293
+
294
+ """
295
+
296
+ from sympy.geometry.point import Point
297
+ from sympy.geometry.line import Segment, Ray, Line
298
+ from sympy.geometry.ellipse import Ellipse
299
+ from sympy.geometry.polygon import Polygon, RegularPolygon
300
+
301
+ if isinstance(o, Point):
302
+ return self.encloses_point(o)
303
+ elif isinstance(o, Segment):
304
+ return all(self.encloses_point(x) for x in o.points)
305
+ elif isinstance(o, (Ray, Line)):
306
+ return False
307
+ elif isinstance(o, Ellipse):
308
+ return self.encloses_point(o.center) and \
309
+ self.encloses_point(
310
+ Point(o.center.x + o.hradius, o.center.y)) and \
311
+ not self.intersection(o)
312
+ elif isinstance(o, Polygon):
313
+ if isinstance(o, RegularPolygon):
314
+ if not self.encloses_point(o.center):
315
+ return False
316
+ return all(self.encloses_point(v) for v in o.vertices)
317
+ raise NotImplementedError()
318
+
319
+ def equals(self, o):
320
+ return self == o
321
+
322
+ def intersection(self, o):
323
+ """
324
+ Returns a list of all of the intersections of self with o.
325
+
326
+ Notes
327
+ =====
328
+
329
+ An entity is not required to implement this method.
330
+
331
+ If two different types of entities can intersect, the item with
332
+ higher index in ordering_of_classes should implement
333
+ intersections with anything having a lower index.
334
+
335
+ See Also
336
+ ========
337
+
338
+ sympy.geometry.util.intersection
339
+
340
+ """
341
+ raise NotImplementedError()
342
+
343
+ def is_similar(self, other):
344
+ """Is this geometrical entity similar to another geometrical entity?
345
+
346
+ Two entities are similar if a uniform scaling (enlarging or
347
+ shrinking) of one of the entities will allow one to obtain the other.
348
+
349
+ Notes
350
+ =====
351
+
352
+ This method is not intended to be used directly but rather
353
+ through the `are_similar` function found in util.py.
354
+ An entity is not required to implement this method.
355
+ If two different types of entities can be similar, it is only
356
+ required that one of them be able to determine this.
357
+
358
+ See Also
359
+ ========
360
+
361
+ scale
362
+
363
+ """
364
+ raise NotImplementedError()
365
+
366
+ def reflect(self, line):
367
+ """
368
+ Reflects an object across a line.
369
+
370
+ Parameters
371
+ ==========
372
+
373
+ line: Line
374
+
375
+ Examples
376
+ ========
377
+
378
+ >>> from sympy import pi, sqrt, Line, RegularPolygon
379
+ >>> l = Line((0, pi), slope=sqrt(2))
380
+ >>> pent = RegularPolygon((1, 2), 1, 5)
381
+ >>> rpent = pent.reflect(l)
382
+ >>> rpent
383
+ RegularPolygon(Point2D(-2*sqrt(2)*pi/3 - 1/3 + 4*sqrt(2)/3, 2/3 + 2*sqrt(2)/3 + 2*pi/3), -1, 5, -atan(2*sqrt(2)) + 3*pi/5)
384
+
385
+ >>> from sympy import pi, Line, Circle, Point
386
+ >>> l = Line((0, pi), slope=1)
387
+ >>> circ = Circle(Point(0, 0), 5)
388
+ >>> rcirc = circ.reflect(l)
389
+ >>> rcirc
390
+ Circle(Point2D(-pi, pi), -5)
391
+
392
+ """
393
+ from sympy.geometry.point import Point
394
+
395
+ g = self
396
+ l = line
397
+ o = Point(0, 0)
398
+ if l.slope.is_zero:
399
+ v = l.args[0].y
400
+ if not v: # x-axis
401
+ return g.scale(y=-1)
402
+ reps = [(p, p.translate(y=2*(v - p.y))) for p in g.atoms(Point)]
403
+ elif l.slope is oo:
404
+ v = l.args[0].x
405
+ if not v: # y-axis
406
+ return g.scale(x=-1)
407
+ reps = [(p, p.translate(x=2*(v - p.x))) for p in g.atoms(Point)]
408
+ else:
409
+ if not hasattr(g, 'reflect') and not all(
410
+ isinstance(arg, Point) for arg in g.args):
411
+ raise NotImplementedError(
412
+ 'reflect undefined or non-Point args in %s' % g)
413
+ a = atan(l.slope)
414
+ c = l.coefficients
415
+ d = -c[-1]/c[1] # y-intercept
416
+ # apply the transform to a single point
417
+ xf = Point(x, y)
418
+ xf = xf.translate(y=-d).rotate(-a, o).scale(y=-1
419
+ ).rotate(a, o).translate(y=d)
420
+ # replace every point using that transform
421
+ reps = [(p, xf.xreplace({x: p.x, y: p.y})) for p in g.atoms(Point)]
422
+ return g.xreplace(dict(reps))
423
+
424
+ def rotate(self, angle, pt=None):
425
+ """Rotate ``angle`` radians counterclockwise about Point ``pt``.
426
+
427
+ The default pt is the origin, Point(0, 0)
428
+
429
+ See Also
430
+ ========
431
+
432
+ scale, translate
433
+
434
+ Examples
435
+ ========
436
+
437
+ >>> from sympy import Point, RegularPolygon, Polygon, pi
438
+ >>> t = Polygon(*RegularPolygon(Point(0, 0), 1, 3).vertices)
439
+ >>> t # vertex on x axis
440
+ Triangle(Point2D(1, 0), Point2D(-1/2, sqrt(3)/2), Point2D(-1/2, -sqrt(3)/2))
441
+ >>> t.rotate(pi/2) # vertex on y axis now
442
+ Triangle(Point2D(0, 1), Point2D(-sqrt(3)/2, -1/2), Point2D(sqrt(3)/2, -1/2))
443
+
444
+ """
445
+ newargs = []
446
+ for a in self.args:
447
+ if isinstance(a, GeometryEntity):
448
+ newargs.append(a.rotate(angle, pt))
449
+ else:
450
+ newargs.append(a)
451
+ return type(self)(*newargs)
452
+
453
+ def scale(self, x=1, y=1, pt=None):
454
+ """Scale the object by multiplying the x,y-coordinates by x and y.
455
+
456
+ If pt is given, the scaling is done relative to that point; the
457
+ object is shifted by -pt, scaled, and shifted by pt.
458
+
459
+ See Also
460
+ ========
461
+
462
+ rotate, translate
463
+
464
+ Examples
465
+ ========
466
+
467
+ >>> from sympy import RegularPolygon, Point, Polygon
468
+ >>> t = Polygon(*RegularPolygon(Point(0, 0), 1, 3).vertices)
469
+ >>> t
470
+ Triangle(Point2D(1, 0), Point2D(-1/2, sqrt(3)/2), Point2D(-1/2, -sqrt(3)/2))
471
+ >>> t.scale(2)
472
+ Triangle(Point2D(2, 0), Point2D(-1, sqrt(3)/2), Point2D(-1, -sqrt(3)/2))
473
+ >>> t.scale(2, 2)
474
+ Triangle(Point2D(2, 0), Point2D(-1, sqrt(3)), Point2D(-1, -sqrt(3)))
475
+
476
+ """
477
+ from sympy.geometry.point import Point
478
+ if pt:
479
+ pt = Point(pt, dim=2)
480
+ return self.translate(*(-pt).args).scale(x, y).translate(*pt.args)
481
+ return type(self)(*[a.scale(x, y) for a in self.args]) # if this fails, override this class
482
+
483
+ def translate(self, x=0, y=0):
484
+ """Shift the object by adding to the x,y-coordinates the values x and y.
485
+
486
+ See Also
487
+ ========
488
+
489
+ rotate, scale
490
+
491
+ Examples
492
+ ========
493
+
494
+ >>> from sympy import RegularPolygon, Point, Polygon
495
+ >>> t = Polygon(*RegularPolygon(Point(0, 0), 1, 3).vertices)
496
+ >>> t
497
+ Triangle(Point2D(1, 0), Point2D(-1/2, sqrt(3)/2), Point2D(-1/2, -sqrt(3)/2))
498
+ >>> t.translate(2)
499
+ Triangle(Point2D(3, 0), Point2D(3/2, sqrt(3)/2), Point2D(3/2, -sqrt(3)/2))
500
+ >>> t.translate(2, 2)
501
+ Triangle(Point2D(3, 2), Point2D(3/2, sqrt(3)/2 + 2), Point2D(3/2, 2 - sqrt(3)/2))
502
+
503
+ """
504
+ newargs = []
505
+ for a in self.args:
506
+ if isinstance(a, GeometryEntity):
507
+ newargs.append(a.translate(x, y))
508
+ else:
509
+ newargs.append(a)
510
+ return self.func(*newargs)
511
+
512
+ def parameter_value(self, other, t):
513
+ """Return the parameter corresponding to the given point.
514
+ Evaluating an arbitrary point of the entity at this parameter
515
+ value will return the given point.
516
+
517
+ Examples
518
+ ========
519
+
520
+ >>> from sympy import Line, Point
521
+ >>> from sympy.abc import t
522
+ >>> a = Point(0, 0)
523
+ >>> b = Point(2, 2)
524
+ >>> Line(a, b).parameter_value((1, 1), t)
525
+ {t: 1/2}
526
+ >>> Line(a, b).arbitrary_point(t).subs(_)
527
+ Point2D(1, 1)
528
+ """
529
+ from sympy.geometry.point import Point
530
+ if not isinstance(other, GeometryEntity):
531
+ other = Point(other, dim=self.ambient_dimension)
532
+ if not isinstance(other, Point):
533
+ raise ValueError("other must be a point")
534
+ sol = solve(self.arbitrary_point(T) - other, T, dict=True)
535
+ if not sol:
536
+ raise ValueError("Given point is not on %s" % func_name(self))
537
+ return {t: sol[0][T]}
538
+
539
+
540
+ class GeometrySet(GeometryEntity, Set):
541
+ """Parent class of all GeometryEntity that are also Sets
542
+ (compatible with sympy.sets)
543
+ """
544
+ __slots__ = ()
545
+
546
+ def _contains(self, other):
547
+ """sympy.sets uses the _contains method, so include it for compatibility."""
548
+
549
+ if isinstance(other, Set) and other.is_FiniteSet:
550
+ return all(self.__contains__(i) for i in other)
551
+
552
+ return self.__contains__(other)
553
+
554
+ @dispatch(GeometrySet, Set) # type:ignore # noqa:F811
555
+ def union_sets(self, o): # noqa:F811
556
+ """ Returns the union of self and o
557
+ for use with sympy.sets.Set, if possible. """
558
+
559
+
560
+ # if its a FiniteSet, merge any points
561
+ # we contain and return a union with the rest
562
+ if o.is_FiniteSet:
563
+ other_points = [p for p in o if not self._contains(p)]
564
+ if len(other_points) == len(o):
565
+ return None
566
+ return Union(self, FiniteSet(*other_points))
567
+ if self._contains(o):
568
+ return self
569
+ return None
570
+
571
+
572
+ @dispatch(GeometrySet, Set) # type: ignore # noqa:F811
573
+ def intersection_sets(self, o): # noqa:F811
574
+ """ Returns a sympy.sets.Set of intersection objects,
575
+ if possible. """
576
+
577
+ from sympy.geometry.point import Point
578
+
579
+ try:
580
+ # if o is a FiniteSet, find the intersection directly
581
+ # to avoid infinite recursion
582
+ if o.is_FiniteSet:
583
+ inter = FiniteSet(*(p for p in o if self.contains(p)))
584
+ else:
585
+ inter = self.intersection(o)
586
+ except NotImplementedError:
587
+ # sympy.sets.Set.reduce expects None if an object
588
+ # doesn't know how to simplify
589
+ return None
590
+
591
+ # put the points in a FiniteSet
592
+ points = FiniteSet(*[p for p in inter if isinstance(p, Point)])
593
+ non_points = [p for p in inter if not isinstance(p, Point)]
594
+
595
+ return Union(*(non_points + [points]))
596
+
597
+ def translate(x, y):
598
+ """Return the matrix to translate a 2-D point by x and y."""
599
+ rv = eye(3)
600
+ rv[2, 0] = x
601
+ rv[2, 1] = y
602
+ return rv
603
+
604
+
605
+ def scale(x, y, pt=None):
606
+ """Return the matrix to multiply a 2-D point's coordinates by x and y.
607
+
608
+ If pt is given, the scaling is done relative to that point."""
609
+ rv = eye(3)
610
+ rv[0, 0] = x
611
+ rv[1, 1] = y
612
+ if pt:
613
+ from sympy.geometry.point import Point
614
+ pt = Point(pt, dim=2)
615
+ tr1 = translate(*(-pt).args)
616
+ tr2 = translate(*pt.args)
617
+ return tr1*rv*tr2
618
+ return rv
619
+
620
+
621
+ def rotate(th):
622
+ """Return the matrix to rotate a 2-D point about the origin by ``angle``.
623
+
624
+ The angle is measured in radians. To Point a point about a point other
625
+ then the origin, translate the Point, do the rotation, and
626
+ translate it back:
627
+
628
+ >>> from sympy.geometry.entity import rotate, translate
629
+ >>> from sympy import Point, pi
630
+ >>> rot_about_11 = translate(-1, -1)*rotate(pi/2)*translate(1, 1)
631
+ >>> Point(1, 1).transform(rot_about_11)
632
+ Point2D(1, 1)
633
+ >>> Point(0, 0).transform(rot_about_11)
634
+ Point2D(2, 0)
635
+ """
636
+ s = sin(th)
637
+ rv = eye(3)*cos(th)
638
+ rv[0, 1] = s
639
+ rv[1, 0] = -s
640
+ rv[2, 2] = 1
641
+ return rv
.venv/lib/python3.13/site-packages/sympy/geometry/exceptions.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ """Geometry Errors."""
2
+
3
+ class GeometryError(ValueError):
4
+ """An exception raised by classes in the geometry module."""
5
+ pass
.venv/lib/python3.13/site-packages/sympy/geometry/line.py ADDED
@@ -0,0 +1,2877 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Line-like geometrical entities.
2
+
3
+ Contains
4
+ ========
5
+ LinearEntity
6
+ Line
7
+ Ray
8
+ Segment
9
+ LinearEntity2D
10
+ Line2D
11
+ Ray2D
12
+ Segment2D
13
+ LinearEntity3D
14
+ Line3D
15
+ Ray3D
16
+ Segment3D
17
+
18
+ """
19
+
20
+ from sympy.core.containers import Tuple
21
+ from sympy.core.evalf import N
22
+ from sympy.core.expr import Expr
23
+ from sympy.core.numbers import Rational, oo, Float
24
+ from sympy.core.relational import Eq
25
+ from sympy.core.singleton import S
26
+ from sympy.core.sorting import ordered
27
+ from sympy.core.symbol import _symbol, Dummy, uniquely_named_symbol
28
+ from sympy.core.sympify import sympify
29
+ from sympy.functions.elementary.piecewise import Piecewise
30
+ from sympy.functions.elementary.trigonometric import (_pi_coeff, acos, tan, atan2)
31
+ from .entity import GeometryEntity, GeometrySet
32
+ from .exceptions import GeometryError
33
+ from .point import Point, Point3D
34
+ from .util import find, intersection
35
+ from sympy.logic.boolalg import And
36
+ from sympy.matrices import Matrix
37
+ from sympy.sets.sets import Intersection
38
+ from sympy.simplify.simplify import simplify
39
+ from sympy.solvers.solvers import solve
40
+ from sympy.solvers.solveset import linear_coeffs
41
+ from sympy.utilities.misc import Undecidable, filldedent
42
+
43
+
44
+ import random
45
+
46
+
47
+ t, u = [Dummy('line_dummy') for i in range(2)]
48
+
49
+
50
+ class LinearEntity(GeometrySet):
51
+ """A base class for all linear entities (Line, Ray and Segment)
52
+ in n-dimensional Euclidean space.
53
+
54
+ Attributes
55
+ ==========
56
+
57
+ ambient_dimension
58
+ direction
59
+ length
60
+ p1
61
+ p2
62
+ points
63
+
64
+ Notes
65
+ =====
66
+
67
+ This is an abstract class and is not meant to be instantiated.
68
+
69
+ See Also
70
+ ========
71
+
72
+ sympy.geometry.entity.GeometryEntity
73
+
74
+ """
75
+ def __new__(cls, p1, p2=None, **kwargs):
76
+ p1, p2 = Point._normalize_dimension(p1, p2)
77
+ if p1 == p2:
78
+ # sometimes we return a single point if we are not given two unique
79
+ # points. This is done in the specific subclass
80
+ raise ValueError(
81
+ "%s.__new__ requires two unique Points." % cls.__name__)
82
+ if len(p1) != len(p2):
83
+ raise ValueError(
84
+ "%s.__new__ requires two Points of equal dimension." % cls.__name__)
85
+
86
+ return GeometryEntity.__new__(cls, p1, p2, **kwargs)
87
+
88
+ def __contains__(self, other):
89
+ """Return a definitive answer or else raise an error if it cannot
90
+ be determined that other is on the boundaries of self."""
91
+ result = self.contains(other)
92
+
93
+ if result is not None:
94
+ return result
95
+ else:
96
+ raise Undecidable(
97
+ "Cannot decide whether '%s' contains '%s'" % (self, other))
98
+
99
+ def _span_test(self, other):
100
+ """Test whether the point `other` lies in the positive span of `self`.
101
+ A point x is 'in front' of a point y if x.dot(y) >= 0. Return
102
+ -1 if `other` is behind `self.p1`, 0 if `other` is `self.p1` and
103
+ and 1 if `other` is in front of `self.p1`."""
104
+ if self.p1 == other:
105
+ return 0
106
+
107
+ rel_pos = other - self.p1
108
+ d = self.direction
109
+ if d.dot(rel_pos) > 0:
110
+ return 1
111
+ return -1
112
+
113
+ @property
114
+ def ambient_dimension(self):
115
+ """A property method that returns the dimension of LinearEntity
116
+ object.
117
+
118
+ Parameters
119
+ ==========
120
+
121
+ p1 : LinearEntity
122
+
123
+ Returns
124
+ =======
125
+
126
+ dimension : integer
127
+
128
+ Examples
129
+ ========
130
+
131
+ >>> from sympy import Point, Line
132
+ >>> p1, p2 = Point(0, 0), Point(1, 1)
133
+ >>> l1 = Line(p1, p2)
134
+ >>> l1.ambient_dimension
135
+ 2
136
+
137
+ >>> from sympy import Point, Line
138
+ >>> p1, p2 = Point(0, 0, 0), Point(1, 1, 1)
139
+ >>> l1 = Line(p1, p2)
140
+ >>> l1.ambient_dimension
141
+ 3
142
+
143
+ """
144
+ return len(self.p1)
145
+
146
+ def angle_between(l1, l2):
147
+ """Return the non-reflex angle formed by rays emanating from
148
+ the origin with directions the same as the direction vectors
149
+ of the linear entities.
150
+
151
+ Parameters
152
+ ==========
153
+
154
+ l1 : LinearEntity
155
+ l2 : LinearEntity
156
+
157
+ Returns
158
+ =======
159
+
160
+ angle : angle in radians
161
+
162
+ Notes
163
+ =====
164
+
165
+ From the dot product of vectors v1 and v2 it is known that:
166
+
167
+ ``dot(v1, v2) = |v1|*|v2|*cos(A)``
168
+
169
+ where A is the angle formed between the two vectors. We can
170
+ get the directional vectors of the two lines and readily
171
+ find the angle between the two using the above formula.
172
+
173
+ See Also
174
+ ========
175
+
176
+ is_perpendicular, Ray2D.closing_angle
177
+
178
+ Examples
179
+ ========
180
+
181
+ >>> from sympy import Line
182
+ >>> e = Line((0, 0), (1, 0))
183
+ >>> ne = Line((0, 0), (1, 1))
184
+ >>> sw = Line((1, 1), (0, 0))
185
+ >>> ne.angle_between(e)
186
+ pi/4
187
+ >>> sw.angle_between(e)
188
+ 3*pi/4
189
+
190
+ To obtain the non-obtuse angle at the intersection of lines, use
191
+ the ``smallest_angle_between`` method:
192
+
193
+ >>> sw.smallest_angle_between(e)
194
+ pi/4
195
+
196
+ >>> from sympy import Point3D, Line3D
197
+ >>> p1, p2, p3 = Point3D(0, 0, 0), Point3D(1, 1, 1), Point3D(-1, 2, 0)
198
+ >>> l1, l2 = Line3D(p1, p2), Line3D(p2, p3)
199
+ >>> l1.angle_between(l2)
200
+ acos(-sqrt(2)/3)
201
+ >>> l1.smallest_angle_between(l2)
202
+ acos(sqrt(2)/3)
203
+ """
204
+ if not isinstance(l1, LinearEntity) and not isinstance(l2, LinearEntity):
205
+ raise TypeError('Must pass only LinearEntity objects')
206
+
207
+ v1, v2 = l1.direction, l2.direction
208
+ return acos(v1.dot(v2)/(abs(v1)*abs(v2)))
209
+
210
+ def smallest_angle_between(l1, l2):
211
+ """Return the smallest angle formed at the intersection of the
212
+ lines containing the linear entities.
213
+
214
+ Parameters
215
+ ==========
216
+
217
+ l1 : LinearEntity
218
+ l2 : LinearEntity
219
+
220
+ Returns
221
+ =======
222
+
223
+ angle : angle in radians
224
+
225
+ Examples
226
+ ========
227
+
228
+ >>> from sympy import Point, Line
229
+ >>> p1, p2, p3 = Point(0, 0), Point(0, 4), Point(2, -2)
230
+ >>> l1, l2 = Line(p1, p2), Line(p1, p3)
231
+ >>> l1.smallest_angle_between(l2)
232
+ pi/4
233
+
234
+ See Also
235
+ ========
236
+
237
+ angle_between, is_perpendicular, Ray2D.closing_angle
238
+ """
239
+ if not isinstance(l1, LinearEntity) and not isinstance(l2, LinearEntity):
240
+ raise TypeError('Must pass only LinearEntity objects')
241
+
242
+ v1, v2 = l1.direction, l2.direction
243
+ return acos(abs(v1.dot(v2))/(abs(v1)*abs(v2)))
244
+
245
+ def arbitrary_point(self, parameter='t'):
246
+ """A parameterized point on the Line.
247
+
248
+ Parameters
249
+ ==========
250
+
251
+ parameter : str, optional
252
+ The name of the parameter which will be used for the parametric
253
+ point. The default value is 't'. When this parameter is 0, the
254
+ first point used to define the line will be returned, and when
255
+ it is 1 the second point will be returned.
256
+
257
+ Returns
258
+ =======
259
+
260
+ point : Point
261
+
262
+ Raises
263
+ ======
264
+
265
+ ValueError
266
+ When ``parameter`` already appears in the Line's definition.
267
+
268
+ See Also
269
+ ========
270
+
271
+ sympy.geometry.point.Point
272
+
273
+ Examples
274
+ ========
275
+
276
+ >>> from sympy import Point, Line
277
+ >>> p1, p2 = Point(1, 0), Point(5, 3)
278
+ >>> l1 = Line(p1, p2)
279
+ >>> l1.arbitrary_point()
280
+ Point2D(4*t + 1, 3*t)
281
+ >>> from sympy import Point3D, Line3D
282
+ >>> p1, p2 = Point3D(1, 0, 0), Point3D(5, 3, 1)
283
+ >>> l1 = Line3D(p1, p2)
284
+ >>> l1.arbitrary_point()
285
+ Point3D(4*t + 1, 3*t, t)
286
+
287
+ """
288
+ t = _symbol(parameter, real=True)
289
+ if t.name in (f.name for f in self.free_symbols):
290
+ raise ValueError(filldedent('''
291
+ Symbol %s already appears in object
292
+ and cannot be used as a parameter.
293
+ ''' % t.name))
294
+ # multiply on the right so the variable gets
295
+ # combined with the coordinates of the point
296
+ return self.p1 + (self.p2 - self.p1)*t
297
+
298
+ @staticmethod
299
+ def are_concurrent(*lines):
300
+ """Is a sequence of linear entities concurrent?
301
+
302
+ Two or more linear entities are concurrent if they all
303
+ intersect at a single point.
304
+
305
+ Parameters
306
+ ==========
307
+
308
+ lines
309
+ A sequence of linear entities.
310
+
311
+ Returns
312
+ =======
313
+
314
+ True : if the set of linear entities intersect in one point
315
+ False : otherwise.
316
+
317
+ See Also
318
+ ========
319
+
320
+ sympy.geometry.util.intersection
321
+
322
+ Examples
323
+ ========
324
+
325
+ >>> from sympy import Point, Line
326
+ >>> p1, p2 = Point(0, 0), Point(3, 5)
327
+ >>> p3, p4 = Point(-2, -2), Point(0, 2)
328
+ >>> l1, l2, l3 = Line(p1, p2), Line(p1, p3), Line(p1, p4)
329
+ >>> Line.are_concurrent(l1, l2, l3)
330
+ True
331
+ >>> l4 = Line(p2, p3)
332
+ >>> Line.are_concurrent(l2, l3, l4)
333
+ False
334
+ >>> from sympy import Point3D, Line3D
335
+ >>> p1, p2 = Point3D(0, 0, 0), Point3D(3, 5, 2)
336
+ >>> p3, p4 = Point3D(-2, -2, -2), Point3D(0, 2, 1)
337
+ >>> l1, l2, l3 = Line3D(p1, p2), Line3D(p1, p3), Line3D(p1, p4)
338
+ >>> Line3D.are_concurrent(l1, l2, l3)
339
+ True
340
+ >>> l4 = Line3D(p2, p3)
341
+ >>> Line3D.are_concurrent(l2, l3, l4)
342
+ False
343
+
344
+ """
345
+ common_points = Intersection(*lines)
346
+ if common_points.is_FiniteSet and len(common_points) == 1:
347
+ return True
348
+ return False
349
+
350
+ def contains(self, other):
351
+ """Subclasses should implement this method and should return
352
+ True if other is on the boundaries of self;
353
+ False if not on the boundaries of self;
354
+ None if a determination cannot be made."""
355
+ raise NotImplementedError()
356
+
357
+ @property
358
+ def direction(self):
359
+ """The direction vector of the LinearEntity.
360
+
361
+ Returns
362
+ =======
363
+
364
+ p : a Point; the ray from the origin to this point is the
365
+ direction of `self`
366
+
367
+ Examples
368
+ ========
369
+
370
+ >>> from sympy import Line
371
+ >>> a, b = (1, 1), (1, 3)
372
+ >>> Line(a, b).direction
373
+ Point2D(0, 2)
374
+ >>> Line(b, a).direction
375
+ Point2D(0, -2)
376
+
377
+ This can be reported so the distance from the origin is 1:
378
+
379
+ >>> Line(b, a).direction.unit
380
+ Point2D(0, -1)
381
+
382
+ See Also
383
+ ========
384
+
385
+ sympy.geometry.point.Point.unit
386
+
387
+ """
388
+ return self.p2 - self.p1
389
+
390
+ def intersection(self, other):
391
+ """The intersection with another geometrical entity.
392
+
393
+ Parameters
394
+ ==========
395
+
396
+ o : Point or LinearEntity
397
+
398
+ Returns
399
+ =======
400
+
401
+ intersection : list of geometrical entities
402
+
403
+ See Also
404
+ ========
405
+
406
+ sympy.geometry.point.Point
407
+
408
+ Examples
409
+ ========
410
+
411
+ >>> from sympy import Point, Line, Segment
412
+ >>> p1, p2, p3 = Point(0, 0), Point(1, 1), Point(7, 7)
413
+ >>> l1 = Line(p1, p2)
414
+ >>> l1.intersection(p3)
415
+ [Point2D(7, 7)]
416
+ >>> p4, p5 = Point(5, 0), Point(0, 3)
417
+ >>> l2 = Line(p4, p5)
418
+ >>> l1.intersection(l2)
419
+ [Point2D(15/8, 15/8)]
420
+ >>> p6, p7 = Point(0, 5), Point(2, 6)
421
+ >>> s1 = Segment(p6, p7)
422
+ >>> l1.intersection(s1)
423
+ []
424
+ >>> from sympy import Point3D, Line3D, Segment3D
425
+ >>> p1, p2, p3 = Point3D(0, 0, 0), Point3D(1, 1, 1), Point3D(7, 7, 7)
426
+ >>> l1 = Line3D(p1, p2)
427
+ >>> l1.intersection(p3)
428
+ [Point3D(7, 7, 7)]
429
+ >>> l1 = Line3D(Point3D(4,19,12), Point3D(5,25,17))
430
+ >>> l2 = Line3D(Point3D(-3, -15, -19), direction_ratio=[2,8,8])
431
+ >>> l1.intersection(l2)
432
+ [Point3D(1, 1, -3)]
433
+ >>> p6, p7 = Point3D(0, 5, 2), Point3D(2, 6, 3)
434
+ >>> s1 = Segment3D(p6, p7)
435
+ >>> l1.intersection(s1)
436
+ []
437
+
438
+ """
439
+ def intersect_parallel_rays(ray1, ray2):
440
+ if ray1.direction.dot(ray2.direction) > 0:
441
+ # rays point in the same direction
442
+ # so return the one that is "in front"
443
+ return [ray2] if ray1._span_test(ray2.p1) >= 0 else [ray1]
444
+ else:
445
+ # rays point in opposite directions
446
+ st = ray1._span_test(ray2.p1)
447
+ if st < 0:
448
+ return []
449
+ elif st == 0:
450
+ return [ray2.p1]
451
+ return [Segment(ray1.p1, ray2.p1)]
452
+
453
+ def intersect_parallel_ray_and_segment(ray, seg):
454
+ st1, st2 = ray._span_test(seg.p1), ray._span_test(seg.p2)
455
+ if st1 < 0 and st2 < 0:
456
+ return []
457
+ elif st1 >= 0 and st2 >= 0:
458
+ return [seg]
459
+ elif st1 >= 0: # st2 < 0:
460
+ return [Segment(ray.p1, seg.p1)]
461
+ else: # st1 < 0 and st2 >= 0:
462
+ return [Segment(ray.p1, seg.p2)]
463
+
464
+ def intersect_parallel_segments(seg1, seg2):
465
+ if seg1.contains(seg2):
466
+ return [seg2]
467
+ if seg2.contains(seg1):
468
+ return [seg1]
469
+
470
+ # direct the segments so they're oriented the same way
471
+ if seg1.direction.dot(seg2.direction) < 0:
472
+ seg2 = Segment(seg2.p2, seg2.p1)
473
+ # order the segments so seg1 is "behind" seg2
474
+ if seg1._span_test(seg2.p1) < 0:
475
+ seg1, seg2 = seg2, seg1
476
+ if seg2._span_test(seg1.p2) < 0:
477
+ return []
478
+ return [Segment(seg2.p1, seg1.p2)]
479
+
480
+ if not isinstance(other, GeometryEntity):
481
+ other = Point(other, dim=self.ambient_dimension)
482
+ if other.is_Point:
483
+ if self.contains(other):
484
+ return [other]
485
+ else:
486
+ return []
487
+ elif isinstance(other, LinearEntity):
488
+ # break into cases based on whether
489
+ # the lines are parallel, non-parallel intersecting, or skew
490
+ pts = Point._normalize_dimension(self.p1, self.p2, other.p1, other.p2)
491
+ rank = Point.affine_rank(*pts)
492
+
493
+ if rank == 1:
494
+ # we're collinear
495
+ if isinstance(self, Line):
496
+ return [other]
497
+ if isinstance(other, Line):
498
+ return [self]
499
+
500
+ if isinstance(self, Ray) and isinstance(other, Ray):
501
+ return intersect_parallel_rays(self, other)
502
+ if isinstance(self, Ray) and isinstance(other, Segment):
503
+ return intersect_parallel_ray_and_segment(self, other)
504
+ if isinstance(self, Segment) and isinstance(other, Ray):
505
+ return intersect_parallel_ray_and_segment(other, self)
506
+ if isinstance(self, Segment) and isinstance(other, Segment):
507
+ return intersect_parallel_segments(self, other)
508
+ elif rank == 2:
509
+ # we're in the same plane
510
+ l1 = Line(*pts[:2])
511
+ l2 = Line(*pts[2:])
512
+
513
+ # check to see if we're parallel. If we are, we can't
514
+ # be intersecting, since the collinear case was already
515
+ # handled
516
+ if l1.direction.is_scalar_multiple(l2.direction):
517
+ return []
518
+
519
+ # find the intersection as if everything were lines
520
+ # by solving the equation t*d + p1 == s*d' + p1'
521
+ m = Matrix([l1.direction, -l2.direction]).transpose()
522
+ v = Matrix([l2.p1 - l1.p1]).transpose()
523
+
524
+ # we cannot use m.solve(v) because that only works for square matrices
525
+ m_rref, pivots = m.col_insert(2, v).rref(simplify=True)
526
+ # rank == 2 ensures we have 2 pivots, but let's check anyway
527
+ if len(pivots) != 2:
528
+ raise GeometryError("Failed when solving Mx=b when M={} and b={}".format(m, v))
529
+ coeff = m_rref[0, 2]
530
+ line_intersection = l1.direction*coeff + self.p1
531
+
532
+ # if both are lines, skip a containment check
533
+ if isinstance(self, Line) and isinstance(other, Line):
534
+ return [line_intersection]
535
+
536
+ if ((isinstance(self, Line) or
537
+ self.contains(line_intersection)) and
538
+ other.contains(line_intersection)):
539
+ return [line_intersection]
540
+ if not self.atoms(Float) and not other.atoms(Float):
541
+ # if it can fail when there are no Floats then
542
+ # maybe the following parametric check should be
543
+ # done
544
+ return []
545
+ # floats may fail exact containment so check that the
546
+ # arbitrary points, when equal, both give a
547
+ # non-negative parameter when the arbitrary point
548
+ # coordinates are equated
549
+ tu = solve(self.arbitrary_point(t) - other.arbitrary_point(u),
550
+ t, u, dict=True)[0]
551
+ def ok(p, l):
552
+ if isinstance(l, Line):
553
+ # p > -oo
554
+ return True
555
+ if isinstance(l, Ray):
556
+ # p >= 0
557
+ return p.is_nonnegative
558
+ if isinstance(l, Segment):
559
+ # 0 <= p <= 1
560
+ return p.is_nonnegative and (1 - p).is_nonnegative
561
+ raise ValueError("unexpected line type")
562
+ if ok(tu[t], self) and ok(tu[u], other):
563
+ return [line_intersection]
564
+ return []
565
+ else:
566
+ # we're skew
567
+ return []
568
+
569
+ return other.intersection(self)
570
+
571
+ def is_parallel(l1, l2):
572
+ """Are two linear entities parallel?
573
+
574
+ Parameters
575
+ ==========
576
+
577
+ l1 : LinearEntity
578
+ l2 : LinearEntity
579
+
580
+ Returns
581
+ =======
582
+
583
+ True : if l1 and l2 are parallel,
584
+ False : otherwise.
585
+
586
+ See Also
587
+ ========
588
+
589
+ coefficients
590
+
591
+ Examples
592
+ ========
593
+
594
+ >>> from sympy import Point, Line
595
+ >>> p1, p2 = Point(0, 0), Point(1, 1)
596
+ >>> p3, p4 = Point(3, 4), Point(6, 7)
597
+ >>> l1, l2 = Line(p1, p2), Line(p3, p4)
598
+ >>> Line.is_parallel(l1, l2)
599
+ True
600
+ >>> p5 = Point(6, 6)
601
+ >>> l3 = Line(p3, p5)
602
+ >>> Line.is_parallel(l1, l3)
603
+ False
604
+ >>> from sympy import Point3D, Line3D
605
+ >>> p1, p2 = Point3D(0, 0, 0), Point3D(3, 4, 5)
606
+ >>> p3, p4 = Point3D(2, 1, 1), Point3D(8, 9, 11)
607
+ >>> l1, l2 = Line3D(p1, p2), Line3D(p3, p4)
608
+ >>> Line3D.is_parallel(l1, l2)
609
+ True
610
+ >>> p5 = Point3D(6, 6, 6)
611
+ >>> l3 = Line3D(p3, p5)
612
+ >>> Line3D.is_parallel(l1, l3)
613
+ False
614
+
615
+ """
616
+ if not isinstance(l1, LinearEntity) and not isinstance(l2, LinearEntity):
617
+ raise TypeError('Must pass only LinearEntity objects')
618
+
619
+ return l1.direction.is_scalar_multiple(l2.direction)
620
+
621
+ def is_perpendicular(l1, l2):
622
+ """Are two linear entities perpendicular?
623
+
624
+ Parameters
625
+ ==========
626
+
627
+ l1 : LinearEntity
628
+ l2 : LinearEntity
629
+
630
+ Returns
631
+ =======
632
+
633
+ True : if l1 and l2 are perpendicular,
634
+ False : otherwise.
635
+
636
+ See Also
637
+ ========
638
+
639
+ coefficients
640
+
641
+ Examples
642
+ ========
643
+
644
+ >>> from sympy import Point, Line
645
+ >>> p1, p2, p3 = Point(0, 0), Point(1, 1), Point(-1, 1)
646
+ >>> l1, l2 = Line(p1, p2), Line(p1, p3)
647
+ >>> l1.is_perpendicular(l2)
648
+ True
649
+ >>> p4 = Point(5, 3)
650
+ >>> l3 = Line(p1, p4)
651
+ >>> l1.is_perpendicular(l3)
652
+ False
653
+ >>> from sympy import Point3D, Line3D
654
+ >>> p1, p2, p3 = Point3D(0, 0, 0), Point3D(1, 1, 1), Point3D(-1, 2, 0)
655
+ >>> l1, l2 = Line3D(p1, p2), Line3D(p2, p3)
656
+ >>> l1.is_perpendicular(l2)
657
+ False
658
+ >>> p4 = Point3D(5, 3, 7)
659
+ >>> l3 = Line3D(p1, p4)
660
+ >>> l1.is_perpendicular(l3)
661
+ False
662
+
663
+ """
664
+ if not isinstance(l1, LinearEntity) and not isinstance(l2, LinearEntity):
665
+ raise TypeError('Must pass only LinearEntity objects')
666
+
667
+ return S.Zero.equals(l1.direction.dot(l2.direction))
668
+
669
+ def is_similar(self, other):
670
+ """
671
+ Return True if self and other are contained in the same line.
672
+
673
+ Examples
674
+ ========
675
+
676
+ >>> from sympy import Point, Line
677
+ >>> p1, p2, p3 = Point(0, 1), Point(3, 4), Point(2, 3)
678
+ >>> l1 = Line(p1, p2)
679
+ >>> l2 = Line(p1, p3)
680
+ >>> l1.is_similar(l2)
681
+ True
682
+ """
683
+ l = Line(self.p1, self.p2)
684
+ return l.contains(other)
685
+
686
+ @property
687
+ def length(self):
688
+ """
689
+ The length of the line.
690
+
691
+ Examples
692
+ ========
693
+
694
+ >>> from sympy import Point, Line
695
+ >>> p1, p2 = Point(0, 0), Point(3, 5)
696
+ >>> l1 = Line(p1, p2)
697
+ >>> l1.length
698
+ oo
699
+ """
700
+ return S.Infinity
701
+
702
+ @property
703
+ def p1(self):
704
+ """The first defining point of a linear entity.
705
+
706
+ See Also
707
+ ========
708
+
709
+ sympy.geometry.point.Point
710
+
711
+ Examples
712
+ ========
713
+
714
+ >>> from sympy import Point, Line
715
+ >>> p1, p2 = Point(0, 0), Point(5, 3)
716
+ >>> l = Line(p1, p2)
717
+ >>> l.p1
718
+ Point2D(0, 0)
719
+
720
+ """
721
+ return self.args[0]
722
+
723
+ @property
724
+ def p2(self):
725
+ """The second defining point of a linear entity.
726
+
727
+ See Also
728
+ ========
729
+
730
+ sympy.geometry.point.Point
731
+
732
+ Examples
733
+ ========
734
+
735
+ >>> from sympy import Point, Line
736
+ >>> p1, p2 = Point(0, 0), Point(5, 3)
737
+ >>> l = Line(p1, p2)
738
+ >>> l.p2
739
+ Point2D(5, 3)
740
+
741
+ """
742
+ return self.args[1]
743
+
744
+ def parallel_line(self, p):
745
+ """Create a new Line parallel to this linear entity which passes
746
+ through the point `p`.
747
+
748
+ Parameters
749
+ ==========
750
+
751
+ p : Point
752
+
753
+ Returns
754
+ =======
755
+
756
+ line : Line
757
+
758
+ See Also
759
+ ========
760
+
761
+ is_parallel
762
+
763
+ Examples
764
+ ========
765
+
766
+ >>> from sympy import Point, Line
767
+ >>> p1, p2, p3 = Point(0, 0), Point(2, 3), Point(-2, 2)
768
+ >>> l1 = Line(p1, p2)
769
+ >>> l2 = l1.parallel_line(p3)
770
+ >>> p3 in l2
771
+ True
772
+ >>> l1.is_parallel(l2)
773
+ True
774
+ >>> from sympy import Point3D, Line3D
775
+ >>> p1, p2, p3 = Point3D(0, 0, 0), Point3D(2, 3, 4), Point3D(-2, 2, 0)
776
+ >>> l1 = Line3D(p1, p2)
777
+ >>> l2 = l1.parallel_line(p3)
778
+ >>> p3 in l2
779
+ True
780
+ >>> l1.is_parallel(l2)
781
+ True
782
+
783
+ """
784
+ p = Point(p, dim=self.ambient_dimension)
785
+ return Line(p, p + self.direction)
786
+
787
+ def perpendicular_line(self, p):
788
+ """Create a new Line perpendicular to this linear entity which passes
789
+ through the point `p`.
790
+
791
+ Parameters
792
+ ==========
793
+
794
+ p : Point
795
+
796
+ Returns
797
+ =======
798
+
799
+ line : Line
800
+
801
+ See Also
802
+ ========
803
+
804
+ sympy.geometry.line.LinearEntity.is_perpendicular, perpendicular_segment
805
+
806
+ Examples
807
+ ========
808
+
809
+ >>> from sympy import Point3D, Line3D
810
+ >>> p1, p2, p3 = Point3D(0, 0, 0), Point3D(2, 3, 4), Point3D(-2, 2, 0)
811
+ >>> L = Line3D(p1, p2)
812
+ >>> P = L.perpendicular_line(p3); P
813
+ Line3D(Point3D(-2, 2, 0), Point3D(4/29, 6/29, 8/29))
814
+ >>> L.is_perpendicular(P)
815
+ True
816
+
817
+ In 3D the, the first point used to define the line is the point
818
+ through which the perpendicular was required to pass; the
819
+ second point is (arbitrarily) contained in the given line:
820
+
821
+ >>> P.p2 in L
822
+ True
823
+ """
824
+ p = Point(p, dim=self.ambient_dimension)
825
+ if p in self:
826
+ p = p + self.direction.orthogonal_direction
827
+ return Line(p, self.projection(p))
828
+
829
+ def perpendicular_segment(self, p):
830
+ """Create a perpendicular line segment from `p` to this line.
831
+
832
+ The endpoints of the segment are ``p`` and the closest point in
833
+ the line containing self. (If self is not a line, the point might
834
+ not be in self.)
835
+
836
+ Parameters
837
+ ==========
838
+
839
+ p : Point
840
+
841
+ Returns
842
+ =======
843
+
844
+ segment : Segment
845
+
846
+ Notes
847
+ =====
848
+
849
+ Returns `p` itself if `p` is on this linear entity.
850
+
851
+ See Also
852
+ ========
853
+
854
+ perpendicular_line
855
+
856
+ Examples
857
+ ========
858
+
859
+ >>> from sympy import Point, Line
860
+ >>> p1, p2, p3 = Point(0, 0), Point(1, 1), Point(0, 2)
861
+ >>> l1 = Line(p1, p2)
862
+ >>> s1 = l1.perpendicular_segment(p3)
863
+ >>> l1.is_perpendicular(s1)
864
+ True
865
+ >>> p3 in s1
866
+ True
867
+ >>> l1.perpendicular_segment(Point(4, 0))
868
+ Segment2D(Point2D(4, 0), Point2D(2, 2))
869
+ >>> from sympy import Point3D, Line3D
870
+ >>> p1, p2, p3 = Point3D(0, 0, 0), Point3D(1, 1, 1), Point3D(0, 2, 0)
871
+ >>> l1 = Line3D(p1, p2)
872
+ >>> s1 = l1.perpendicular_segment(p3)
873
+ >>> l1.is_perpendicular(s1)
874
+ True
875
+ >>> p3 in s1
876
+ True
877
+ >>> l1.perpendicular_segment(Point3D(4, 0, 0))
878
+ Segment3D(Point3D(4, 0, 0), Point3D(4/3, 4/3, 4/3))
879
+
880
+ """
881
+ p = Point(p, dim=self.ambient_dimension)
882
+ if p in self:
883
+ return p
884
+ l = self.perpendicular_line(p)
885
+ # The intersection should be unique, so unpack the singleton
886
+ p2, = Intersection(Line(self.p1, self.p2), l)
887
+
888
+ return Segment(p, p2)
889
+
890
+ @property
891
+ def points(self):
892
+ """The two points used to define this linear entity.
893
+
894
+ Returns
895
+ =======
896
+
897
+ points : tuple of Points
898
+
899
+ See Also
900
+ ========
901
+
902
+ sympy.geometry.point.Point
903
+
904
+ Examples
905
+ ========
906
+
907
+ >>> from sympy import Point, Line
908
+ >>> p1, p2 = Point(0, 0), Point(5, 11)
909
+ >>> l1 = Line(p1, p2)
910
+ >>> l1.points
911
+ (Point2D(0, 0), Point2D(5, 11))
912
+
913
+ """
914
+ return (self.p1, self.p2)
915
+
916
+ def projection(self, other):
917
+ """Project a point, line, ray, or segment onto this linear entity.
918
+
919
+ Parameters
920
+ ==========
921
+
922
+ other : Point or LinearEntity (Line, Ray, Segment)
923
+
924
+ Returns
925
+ =======
926
+
927
+ projection : Point or LinearEntity (Line, Ray, Segment)
928
+ The return type matches the type of the parameter ``other``.
929
+
930
+ Raises
931
+ ======
932
+
933
+ GeometryError
934
+ When method is unable to perform projection.
935
+
936
+ Notes
937
+ =====
938
+
939
+ A projection involves taking the two points that define
940
+ the linear entity and projecting those points onto a
941
+ Line and then reforming the linear entity using these
942
+ projections.
943
+ A point P is projected onto a line L by finding the point
944
+ on L that is closest to P. This point is the intersection
945
+ of L and the line perpendicular to L that passes through P.
946
+
947
+ See Also
948
+ ========
949
+
950
+ sympy.geometry.point.Point, perpendicular_line
951
+
952
+ Examples
953
+ ========
954
+
955
+ >>> from sympy import Point, Line, Segment, Rational
956
+ >>> p1, p2, p3 = Point(0, 0), Point(1, 1), Point(Rational(1, 2), 0)
957
+ >>> l1 = Line(p1, p2)
958
+ >>> l1.projection(p3)
959
+ Point2D(1/4, 1/4)
960
+ >>> p4, p5 = Point(10, 0), Point(12, 1)
961
+ >>> s1 = Segment(p4, p5)
962
+ >>> l1.projection(s1)
963
+ Segment2D(Point2D(5, 5), Point2D(13/2, 13/2))
964
+ >>> p1, p2, p3 = Point(0, 0, 1), Point(1, 1, 2), Point(2, 0, 1)
965
+ >>> l1 = Line(p1, p2)
966
+ >>> l1.projection(p3)
967
+ Point3D(2/3, 2/3, 5/3)
968
+ >>> p4, p5 = Point(10, 0, 1), Point(12, 1, 3)
969
+ >>> s1 = Segment(p4, p5)
970
+ >>> l1.projection(s1)
971
+ Segment3D(Point3D(10/3, 10/3, 13/3), Point3D(5, 5, 6))
972
+
973
+ """
974
+ if not isinstance(other, GeometryEntity):
975
+ other = Point(other, dim=self.ambient_dimension)
976
+
977
+ def proj_point(p):
978
+ return Point.project(p - self.p1, self.direction) + self.p1
979
+
980
+ if isinstance(other, Point):
981
+ return proj_point(other)
982
+ elif isinstance(other, LinearEntity):
983
+ p1, p2 = proj_point(other.p1), proj_point(other.p2)
984
+ # test to see if we're degenerate
985
+ if p1 == p2:
986
+ return p1
987
+ projected = other.__class__(p1, p2)
988
+ projected = Intersection(self, projected)
989
+ if projected.is_empty:
990
+ return projected
991
+ # if we happen to have intersected in only a point, return that
992
+ if projected.is_FiniteSet and len(projected) == 1:
993
+ # projected is a set of size 1, so unpack it in `a`
994
+ a, = projected
995
+ return a
996
+ # order args so projection is in the same direction as self
997
+ if self.direction.dot(projected.direction) < 0:
998
+ p1, p2 = projected.args
999
+ projected = projected.func(p2, p1)
1000
+ return projected
1001
+
1002
+ raise GeometryError(
1003
+ "Do not know how to project %s onto %s" % (other, self))
1004
+
1005
+ def random_point(self, seed=None):
1006
+ """A random point on a LinearEntity.
1007
+
1008
+ Returns
1009
+ =======
1010
+
1011
+ point : Point
1012
+
1013
+ See Also
1014
+ ========
1015
+
1016
+ sympy.geometry.point.Point
1017
+
1018
+ Examples
1019
+ ========
1020
+
1021
+ >>> from sympy import Point, Line, Ray, Segment
1022
+ >>> p1, p2 = Point(0, 0), Point(5, 3)
1023
+ >>> line = Line(p1, p2)
1024
+ >>> r = line.random_point(seed=42) # seed value is optional
1025
+ >>> r.n(3)
1026
+ Point2D(-0.72, -0.432)
1027
+ >>> r in line
1028
+ True
1029
+ >>> Ray(p1, p2).random_point(seed=42).n(3)
1030
+ Point2D(0.72, 0.432)
1031
+ >>> Segment(p1, p2).random_point(seed=42).n(3)
1032
+ Point2D(3.2, 1.92)
1033
+
1034
+ """
1035
+ if seed is not None:
1036
+ rng = random.Random(seed)
1037
+ else:
1038
+ rng = random
1039
+ pt = self.arbitrary_point(t)
1040
+ if isinstance(self, Ray):
1041
+ v = abs(rng.gauss(0, 1))
1042
+ elif isinstance(self, Segment):
1043
+ v = rng.random()
1044
+ elif isinstance(self, Line):
1045
+ v = rng.gauss(0, 1)
1046
+ else:
1047
+ raise NotImplementedError('unhandled line type')
1048
+ return pt.subs(t, Rational(v))
1049
+
1050
+ def bisectors(self, other):
1051
+ """Returns the perpendicular lines which pass through the intersections
1052
+ of self and other that are in the same plane.
1053
+
1054
+ Parameters
1055
+ ==========
1056
+
1057
+ line : Line3D
1058
+
1059
+ Returns
1060
+ =======
1061
+
1062
+ list: two Line instances
1063
+
1064
+ Examples
1065
+ ========
1066
+
1067
+ >>> from sympy import Point3D, Line3D
1068
+ >>> r1 = Line3D(Point3D(0, 0, 0), Point3D(1, 0, 0))
1069
+ >>> r2 = Line3D(Point3D(0, 0, 0), Point3D(0, 1, 0))
1070
+ >>> r1.bisectors(r2)
1071
+ [Line3D(Point3D(0, 0, 0), Point3D(1, 1, 0)), Line3D(Point3D(0, 0, 0), Point3D(1, -1, 0))]
1072
+
1073
+ """
1074
+ if not isinstance(other, LinearEntity):
1075
+ raise GeometryError("Expecting LinearEntity, not %s" % other)
1076
+
1077
+ l1, l2 = self, other
1078
+
1079
+ # make sure dimensions match or else a warning will rise from
1080
+ # intersection calculation
1081
+ if l1.p1.ambient_dimension != l2.p1.ambient_dimension:
1082
+ if isinstance(l1, Line2D):
1083
+ l1, l2 = l2, l1
1084
+ _, p1 = Point._normalize_dimension(l1.p1, l2.p1, on_morph='ignore')
1085
+ _, p2 = Point._normalize_dimension(l1.p2, l2.p2, on_morph='ignore')
1086
+ l2 = Line(p1, p2)
1087
+
1088
+ point = intersection(l1, l2)
1089
+
1090
+ # Three cases: Lines may intersect in a point, may be equal or may not intersect.
1091
+ if not point:
1092
+ raise GeometryError("The lines do not intersect")
1093
+ else:
1094
+ pt = point[0]
1095
+ if isinstance(pt, Line):
1096
+ # Intersection is a line because both lines are coincident
1097
+ return [self]
1098
+
1099
+
1100
+ d1 = l1.direction.unit
1101
+ d2 = l2.direction.unit
1102
+
1103
+ bis1 = Line(pt, pt + d1 + d2)
1104
+ bis2 = Line(pt, pt + d1 - d2)
1105
+
1106
+ return [bis1, bis2]
1107
+
1108
+
1109
+ class Line(LinearEntity):
1110
+ """An infinite line in space.
1111
+
1112
+ A 2D line is declared with two distinct points, point and slope, or
1113
+ an equation. A 3D line may be defined with a point and a direction ratio.
1114
+
1115
+ Parameters
1116
+ ==========
1117
+
1118
+ p1 : Point
1119
+ p2 : Point
1120
+ slope : SymPy expression
1121
+ direction_ratio : list
1122
+ equation : equation of a line
1123
+
1124
+ Notes
1125
+ =====
1126
+
1127
+ `Line` will automatically subclass to `Line2D` or `Line3D` based
1128
+ on the dimension of `p1`. The `slope` argument is only relevant
1129
+ for `Line2D` and the `direction_ratio` argument is only relevant
1130
+ for `Line3D`.
1131
+
1132
+ The order of the points will define the direction of the line
1133
+ which is used when calculating the angle between lines.
1134
+
1135
+ See Also
1136
+ ========
1137
+
1138
+ sympy.geometry.point.Point
1139
+ sympy.geometry.line.Line2D
1140
+ sympy.geometry.line.Line3D
1141
+
1142
+ Examples
1143
+ ========
1144
+
1145
+ >>> from sympy import Line, Segment, Point, Eq
1146
+ >>> from sympy.abc import x, y, a, b
1147
+
1148
+ >>> L = Line(Point(2,3), Point(3,5))
1149
+ >>> L
1150
+ Line2D(Point2D(2, 3), Point2D(3, 5))
1151
+ >>> L.points
1152
+ (Point2D(2, 3), Point2D(3, 5))
1153
+ >>> L.equation()
1154
+ -2*x + y + 1
1155
+ >>> L.coefficients
1156
+ (-2, 1, 1)
1157
+
1158
+ Instantiate with keyword ``slope``:
1159
+
1160
+ >>> Line(Point(0, 0), slope=0)
1161
+ Line2D(Point2D(0, 0), Point2D(1, 0))
1162
+
1163
+ Instantiate with another linear object
1164
+
1165
+ >>> s = Segment((0, 0), (0, 1))
1166
+ >>> Line(s).equation()
1167
+ x
1168
+
1169
+ The line corresponding to an equation in the for `ax + by + c = 0`,
1170
+ can be entered:
1171
+
1172
+ >>> Line(3*x + y + 18)
1173
+ Line2D(Point2D(0, -18), Point2D(1, -21))
1174
+
1175
+ If `x` or `y` has a different name, then they can be specified, too,
1176
+ as a string (to match the name) or symbol:
1177
+
1178
+ >>> Line(Eq(3*a + b, -18), x='a', y=b)
1179
+ Line2D(Point2D(0, -18), Point2D(1, -21))
1180
+ """
1181
+ def __new__(cls, *args, **kwargs):
1182
+ if len(args) == 1 and isinstance(args[0], (Expr, Eq)):
1183
+ missing = uniquely_named_symbol('?', args)
1184
+ if not kwargs:
1185
+ x = 'x'
1186
+ y = 'y'
1187
+ else:
1188
+ x = kwargs.pop('x', missing)
1189
+ y = kwargs.pop('y', missing)
1190
+ if kwargs:
1191
+ raise ValueError('expecting only x and y as keywords')
1192
+
1193
+ equation = args[0]
1194
+ if isinstance(equation, Eq):
1195
+ equation = equation.lhs - equation.rhs
1196
+
1197
+ def find_or_missing(x):
1198
+ try:
1199
+ return find(x, equation)
1200
+ except ValueError:
1201
+ return missing
1202
+ x = find_or_missing(x)
1203
+ y = find_or_missing(y)
1204
+
1205
+ a, b, c = linear_coeffs(equation, x, y)
1206
+
1207
+ if b:
1208
+ return Line((0, -c/b), slope=-a/b)
1209
+ if a:
1210
+ return Line((-c/a, 0), slope=oo)
1211
+
1212
+ raise ValueError('not found in equation: %s' % (set('xy') - {x, y}))
1213
+
1214
+ else:
1215
+ if len(args) > 0:
1216
+ p1 = args[0]
1217
+ if len(args) > 1:
1218
+ p2 = args[1]
1219
+ else:
1220
+ p2 = None
1221
+
1222
+ if isinstance(p1, LinearEntity):
1223
+ if p2:
1224
+ raise ValueError('If p1 is a LinearEntity, p2 must be None.')
1225
+ dim = len(p1.p1)
1226
+ else:
1227
+ p1 = Point(p1)
1228
+ dim = len(p1)
1229
+ if p2 is not None or isinstance(p2, Point) and p2.ambient_dimension != dim:
1230
+ p2 = Point(p2)
1231
+
1232
+ if dim == 2:
1233
+ return Line2D(p1, p2, **kwargs)
1234
+ elif dim == 3:
1235
+ return Line3D(p1, p2, **kwargs)
1236
+ return LinearEntity.__new__(cls, p1, p2, **kwargs)
1237
+
1238
+ def contains(self, other):
1239
+ """
1240
+ Return True if `other` is on this Line, or False otherwise.
1241
+
1242
+ Examples
1243
+ ========
1244
+
1245
+ >>> from sympy import Line,Point
1246
+ >>> p1, p2 = Point(0, 1), Point(3, 4)
1247
+ >>> l = Line(p1, p2)
1248
+ >>> l.contains(p1)
1249
+ True
1250
+ >>> l.contains((0, 1))
1251
+ True
1252
+ >>> l.contains((0, 0))
1253
+ False
1254
+ >>> a = (0, 0, 0)
1255
+ >>> b = (1, 1, 1)
1256
+ >>> c = (2, 2, 2)
1257
+ >>> l1 = Line(a, b)
1258
+ >>> l2 = Line(b, a)
1259
+ >>> l1 == l2
1260
+ False
1261
+ >>> l1 in l2
1262
+ True
1263
+
1264
+ """
1265
+ if not isinstance(other, GeometryEntity):
1266
+ other = Point(other, dim=self.ambient_dimension)
1267
+ if isinstance(other, Point):
1268
+ return Point.is_collinear(other, self.p1, self.p2)
1269
+ if isinstance(other, LinearEntity):
1270
+ return Point.is_collinear(self.p1, self.p2, other.p1, other.p2)
1271
+ return False
1272
+
1273
+ def distance(self, other):
1274
+ """
1275
+ Finds the shortest distance between a line and a point.
1276
+
1277
+ Raises
1278
+ ======
1279
+
1280
+ NotImplementedError is raised if `other` is not a Point
1281
+
1282
+ Examples
1283
+ ========
1284
+
1285
+ >>> from sympy import Point, Line
1286
+ >>> p1, p2 = Point(0, 0), Point(1, 1)
1287
+ >>> s = Line(p1, p2)
1288
+ >>> s.distance(Point(-1, 1))
1289
+ sqrt(2)
1290
+ >>> s.distance((-1, 2))
1291
+ 3*sqrt(2)/2
1292
+ >>> p1, p2 = Point(0, 0, 0), Point(1, 1, 1)
1293
+ >>> s = Line(p1, p2)
1294
+ >>> s.distance(Point(-1, 1, 1))
1295
+ 2*sqrt(6)/3
1296
+ >>> s.distance((-1, 1, 1))
1297
+ 2*sqrt(6)/3
1298
+
1299
+ """
1300
+ if not isinstance(other, GeometryEntity):
1301
+ other = Point(other, dim=self.ambient_dimension)
1302
+ if self.contains(other):
1303
+ return S.Zero
1304
+ return self.perpendicular_segment(other).length
1305
+
1306
+ def equals(self, other):
1307
+ """Returns True if self and other are the same mathematical entities"""
1308
+ if not isinstance(other, Line):
1309
+ return False
1310
+ return Point.is_collinear(self.p1, other.p1, self.p2, other.p2)
1311
+
1312
+ def plot_interval(self, parameter='t'):
1313
+ """The plot interval for the default geometric plot of line. Gives
1314
+ values that will produce a line that is +/- 5 units long (where a
1315
+ unit is the distance between the two points that define the line).
1316
+
1317
+ Parameters
1318
+ ==========
1319
+
1320
+ parameter : str, optional
1321
+ Default value is 't'.
1322
+
1323
+ Returns
1324
+ =======
1325
+
1326
+ plot_interval : list (plot interval)
1327
+ [parameter, lower_bound, upper_bound]
1328
+
1329
+ Examples
1330
+ ========
1331
+
1332
+ >>> from sympy import Point, Line
1333
+ >>> p1, p2 = Point(0, 0), Point(5, 3)
1334
+ >>> l1 = Line(p1, p2)
1335
+ >>> l1.plot_interval()
1336
+ [t, -5, 5]
1337
+
1338
+ """
1339
+ t = _symbol(parameter, real=True)
1340
+ return [t, -5, 5]
1341
+
1342
+
1343
+ class Ray(LinearEntity):
1344
+ """A Ray is a semi-line in the space with a source point and a direction.
1345
+
1346
+ Parameters
1347
+ ==========
1348
+
1349
+ p1 : Point
1350
+ The source of the Ray
1351
+ p2 : Point or radian value
1352
+ This point determines the direction in which the Ray propagates.
1353
+ If given as an angle it is interpreted in radians with the positive
1354
+ direction being ccw.
1355
+
1356
+ Attributes
1357
+ ==========
1358
+
1359
+ source
1360
+
1361
+ See Also
1362
+ ========
1363
+
1364
+ sympy.geometry.line.Ray2D
1365
+ sympy.geometry.line.Ray3D
1366
+ sympy.geometry.point.Point
1367
+ sympy.geometry.line.Line
1368
+
1369
+ Notes
1370
+ =====
1371
+
1372
+ `Ray` will automatically subclass to `Ray2D` or `Ray3D` based on the
1373
+ dimension of `p1`.
1374
+
1375
+ Examples
1376
+ ========
1377
+
1378
+ >>> from sympy import Ray, Point, pi
1379
+ >>> r = Ray(Point(2, 3), Point(3, 5))
1380
+ >>> r
1381
+ Ray2D(Point2D(2, 3), Point2D(3, 5))
1382
+ >>> r.points
1383
+ (Point2D(2, 3), Point2D(3, 5))
1384
+ >>> r.source
1385
+ Point2D(2, 3)
1386
+ >>> r.xdirection
1387
+ oo
1388
+ >>> r.ydirection
1389
+ oo
1390
+ >>> r.slope
1391
+ 2
1392
+ >>> Ray(Point(0, 0), angle=pi/4).slope
1393
+ 1
1394
+
1395
+ """
1396
+ def __new__(cls, p1, p2=None, **kwargs):
1397
+ p1 = Point(p1)
1398
+ if p2 is not None:
1399
+ p1, p2 = Point._normalize_dimension(p1, Point(p2))
1400
+ dim = len(p1)
1401
+
1402
+ if dim == 2:
1403
+ return Ray2D(p1, p2, **kwargs)
1404
+ elif dim == 3:
1405
+ return Ray3D(p1, p2, **kwargs)
1406
+ return LinearEntity.__new__(cls, p1, p2, **kwargs)
1407
+
1408
+ def _svg(self, scale_factor=1., fill_color="#66cc99"):
1409
+ """Returns SVG path element for the LinearEntity.
1410
+
1411
+ Parameters
1412
+ ==========
1413
+
1414
+ scale_factor : float
1415
+ Multiplication factor for the SVG stroke-width. Default is 1.
1416
+ fill_color : str, optional
1417
+ Hex string for fill color. Default is "#66cc99".
1418
+ """
1419
+ verts = (N(self.p1), N(self.p2))
1420
+ coords = ["{},{}".format(p.x, p.y) for p in verts]
1421
+ path = "M {} L {}".format(coords[0], " L ".join(coords[1:]))
1422
+
1423
+ return (
1424
+ '<path fill-rule="evenodd" fill="{2}" stroke="#555555" '
1425
+ 'stroke-width="{0}" opacity="0.6" d="{1}" '
1426
+ 'marker-start="url(#markerCircle)" marker-end="url(#markerArrow)"/>'
1427
+ ).format(2.*scale_factor, path, fill_color)
1428
+
1429
+ def contains(self, other):
1430
+ """
1431
+ Is other GeometryEntity contained in this Ray?
1432
+
1433
+ Examples
1434
+ ========
1435
+
1436
+ >>> from sympy import Ray,Point,Segment
1437
+ >>> p1, p2 = Point(0, 0), Point(4, 4)
1438
+ >>> r = Ray(p1, p2)
1439
+ >>> r.contains(p1)
1440
+ True
1441
+ >>> r.contains((1, 1))
1442
+ True
1443
+ >>> r.contains((1, 3))
1444
+ False
1445
+ >>> s = Segment((1, 1), (2, 2))
1446
+ >>> r.contains(s)
1447
+ True
1448
+ >>> s = Segment((1, 2), (2, 5))
1449
+ >>> r.contains(s)
1450
+ False
1451
+ >>> r1 = Ray((2, 2), (3, 3))
1452
+ >>> r.contains(r1)
1453
+ True
1454
+ >>> r1 = Ray((2, 2), (3, 5))
1455
+ >>> r.contains(r1)
1456
+ False
1457
+ """
1458
+ if not isinstance(other, GeometryEntity):
1459
+ other = Point(other, dim=self.ambient_dimension)
1460
+ if isinstance(other, Point):
1461
+ if Point.is_collinear(self.p1, self.p2, other):
1462
+ # if we're in the direction of the ray, our
1463
+ # direction vector dot the ray's direction vector
1464
+ # should be non-negative
1465
+ return bool((self.p2 - self.p1).dot(other - self.p1) >= S.Zero)
1466
+ return False
1467
+ elif isinstance(other, Ray):
1468
+ if Point.is_collinear(self.p1, self.p2, other.p1, other.p2):
1469
+ return bool((self.p2 - self.p1).dot(other.p2 - other.p1) > S.Zero)
1470
+ return False
1471
+ elif isinstance(other, Segment):
1472
+ return other.p1 in self and other.p2 in self
1473
+
1474
+ # No other known entity can be contained in a Ray
1475
+ return False
1476
+
1477
+ def distance(self, other):
1478
+ """
1479
+ Finds the shortest distance between the ray and a point.
1480
+
1481
+ Raises
1482
+ ======
1483
+
1484
+ NotImplementedError is raised if `other` is not a Point
1485
+
1486
+ Examples
1487
+ ========
1488
+
1489
+ >>> from sympy import Point, Ray
1490
+ >>> p1, p2 = Point(0, 0), Point(1, 1)
1491
+ >>> s = Ray(p1, p2)
1492
+ >>> s.distance(Point(-1, -1))
1493
+ sqrt(2)
1494
+ >>> s.distance((-1, 2))
1495
+ 3*sqrt(2)/2
1496
+ >>> p1, p2 = Point(0, 0, 0), Point(1, 1, 2)
1497
+ >>> s = Ray(p1, p2)
1498
+ >>> s
1499
+ Ray3D(Point3D(0, 0, 0), Point3D(1, 1, 2))
1500
+ >>> s.distance(Point(-1, -1, 2))
1501
+ 4*sqrt(3)/3
1502
+ >>> s.distance((-1, -1, 2))
1503
+ 4*sqrt(3)/3
1504
+
1505
+ """
1506
+ if not isinstance(other, GeometryEntity):
1507
+ other = Point(other, dim=self.ambient_dimension)
1508
+ if self.contains(other):
1509
+ return S.Zero
1510
+
1511
+ proj = Line(self.p1, self.p2).projection(other)
1512
+ if self.contains(proj):
1513
+ return abs(other - proj)
1514
+ else:
1515
+ return abs(other - self.source)
1516
+
1517
+ def equals(self, other):
1518
+ """Returns True if self and other are the same mathematical entities"""
1519
+ if not isinstance(other, Ray):
1520
+ return False
1521
+ return self.source == other.source and other.p2 in self
1522
+
1523
+ def plot_interval(self, parameter='t'):
1524
+ """The plot interval for the default geometric plot of the Ray. Gives
1525
+ values that will produce a ray that is 10 units long (where a unit is
1526
+ the distance between the two points that define the ray).
1527
+
1528
+ Parameters
1529
+ ==========
1530
+
1531
+ parameter : str, optional
1532
+ Default value is 't'.
1533
+
1534
+ Returns
1535
+ =======
1536
+
1537
+ plot_interval : list
1538
+ [parameter, lower_bound, upper_bound]
1539
+
1540
+ Examples
1541
+ ========
1542
+
1543
+ >>> from sympy import Ray, pi
1544
+ >>> r = Ray((0, 0), angle=pi/4)
1545
+ >>> r.plot_interval()
1546
+ [t, 0, 10]
1547
+
1548
+ """
1549
+ t = _symbol(parameter, real=True)
1550
+ return [t, 0, 10]
1551
+
1552
+ @property
1553
+ def source(self):
1554
+ """The point from which the ray emanates.
1555
+
1556
+ See Also
1557
+ ========
1558
+
1559
+ sympy.geometry.point.Point
1560
+
1561
+ Examples
1562
+ ========
1563
+
1564
+ >>> from sympy import Point, Ray
1565
+ >>> p1, p2 = Point(0, 0), Point(4, 1)
1566
+ >>> r1 = Ray(p1, p2)
1567
+ >>> r1.source
1568
+ Point2D(0, 0)
1569
+ >>> p1, p2 = Point(0, 0, 0), Point(4, 1, 5)
1570
+ >>> r1 = Ray(p2, p1)
1571
+ >>> r1.source
1572
+ Point3D(4, 1, 5)
1573
+
1574
+ """
1575
+ return self.p1
1576
+
1577
+
1578
+ class Segment(LinearEntity):
1579
+ """A line segment in space.
1580
+
1581
+ Parameters
1582
+ ==========
1583
+
1584
+ p1 : Point
1585
+ p2 : Point
1586
+
1587
+ Attributes
1588
+ ==========
1589
+
1590
+ length : number or SymPy expression
1591
+ midpoint : Point
1592
+
1593
+ See Also
1594
+ ========
1595
+
1596
+ sympy.geometry.line.Segment2D
1597
+ sympy.geometry.line.Segment3D
1598
+ sympy.geometry.point.Point
1599
+ sympy.geometry.line.Line
1600
+
1601
+ Notes
1602
+ =====
1603
+
1604
+ If 2D or 3D points are used to define `Segment`, it will
1605
+ be automatically subclassed to `Segment2D` or `Segment3D`.
1606
+
1607
+ Examples
1608
+ ========
1609
+
1610
+ >>> from sympy import Point, Segment
1611
+ >>> Segment((1, 0), (1, 1)) # tuples are interpreted as pts
1612
+ Segment2D(Point2D(1, 0), Point2D(1, 1))
1613
+ >>> s = Segment(Point(4, 3), Point(1, 1))
1614
+ >>> s.points
1615
+ (Point2D(4, 3), Point2D(1, 1))
1616
+ >>> s.slope
1617
+ 2/3
1618
+ >>> s.length
1619
+ sqrt(13)
1620
+ >>> s.midpoint
1621
+ Point2D(5/2, 2)
1622
+ >>> Segment((1, 0, 0), (1, 1, 1)) # tuples are interpreted as pts
1623
+ Segment3D(Point3D(1, 0, 0), Point3D(1, 1, 1))
1624
+ >>> s = Segment(Point(4, 3, 9), Point(1, 1, 7)); s
1625
+ Segment3D(Point3D(4, 3, 9), Point3D(1, 1, 7))
1626
+ >>> s.points
1627
+ (Point3D(4, 3, 9), Point3D(1, 1, 7))
1628
+ >>> s.length
1629
+ sqrt(17)
1630
+ >>> s.midpoint
1631
+ Point3D(5/2, 2, 8)
1632
+
1633
+ """
1634
+ def __new__(cls, p1, p2, **kwargs):
1635
+ p1, p2 = Point._normalize_dimension(Point(p1), Point(p2))
1636
+ dim = len(p1)
1637
+
1638
+ if dim == 2:
1639
+ return Segment2D(p1, p2, **kwargs)
1640
+ elif dim == 3:
1641
+ return Segment3D(p1, p2, **kwargs)
1642
+ return LinearEntity.__new__(cls, p1, p2, **kwargs)
1643
+
1644
+ def contains(self, other):
1645
+ """
1646
+ Is the other GeometryEntity contained within this Segment?
1647
+
1648
+ Examples
1649
+ ========
1650
+
1651
+ >>> from sympy import Point, Segment
1652
+ >>> p1, p2 = Point(0, 1), Point(3, 4)
1653
+ >>> s = Segment(p1, p2)
1654
+ >>> s2 = Segment(p2, p1)
1655
+ >>> s.contains(s2)
1656
+ True
1657
+ >>> from sympy import Point3D, Segment3D
1658
+ >>> p1, p2 = Point3D(0, 1, 1), Point3D(3, 4, 5)
1659
+ >>> s = Segment3D(p1, p2)
1660
+ >>> s2 = Segment3D(p2, p1)
1661
+ >>> s.contains(s2)
1662
+ True
1663
+ >>> s.contains((p1 + p2)/2)
1664
+ True
1665
+ """
1666
+ if not isinstance(other, GeometryEntity):
1667
+ other = Point(other, dim=self.ambient_dimension)
1668
+ if isinstance(other, Point):
1669
+ if Point.is_collinear(other, self.p1, self.p2):
1670
+ if isinstance(self, Segment2D):
1671
+ # if it is collinear and is in the bounding box of the
1672
+ # segment then it must be on the segment
1673
+ vert = (1/self.slope).equals(0)
1674
+ if vert is False:
1675
+ isin = (self.p1.x - other.x)*(self.p2.x - other.x) <= 0
1676
+ if isin in (True, False):
1677
+ return isin
1678
+ if vert is True:
1679
+ isin = (self.p1.y - other.y)*(self.p2.y - other.y) <= 0
1680
+ if isin in (True, False):
1681
+ return isin
1682
+ # use the triangle inequality
1683
+ d1, d2 = other - self.p1, other - self.p2
1684
+ d = self.p2 - self.p1
1685
+ # without the call to simplify, SymPy cannot tell that an expression
1686
+ # like (a+b)*(a/2+b/2) is always non-negative. If it cannot be
1687
+ # determined, raise an Undecidable error
1688
+ try:
1689
+ # the triangle inequality says that |d1|+|d2| >= |d| and is strict
1690
+ # only if other lies in the line segment
1691
+ return bool(simplify(Eq(abs(d1) + abs(d2) - abs(d), 0)))
1692
+ except TypeError:
1693
+ raise Undecidable("Cannot determine if {} is in {}".format(other, self))
1694
+ if isinstance(other, Segment):
1695
+ return other.p1 in self and other.p2 in self
1696
+
1697
+ return False
1698
+
1699
+ def equals(self, other):
1700
+ """Returns True if self and other are the same mathematical entities"""
1701
+ return isinstance(other, self.func) and list(
1702
+ ordered(self.args)) == list(ordered(other.args))
1703
+
1704
+ def distance(self, other):
1705
+ """
1706
+ Finds the shortest distance between a line segment and a point.
1707
+
1708
+ Raises
1709
+ ======
1710
+
1711
+ NotImplementedError is raised if `other` is not a Point
1712
+
1713
+ Examples
1714
+ ========
1715
+
1716
+ >>> from sympy import Point, Segment
1717
+ >>> p1, p2 = Point(0, 1), Point(3, 4)
1718
+ >>> s = Segment(p1, p2)
1719
+ >>> s.distance(Point(10, 15))
1720
+ sqrt(170)
1721
+ >>> s.distance((0, 12))
1722
+ sqrt(73)
1723
+ >>> from sympy import Point3D, Segment3D
1724
+ >>> p1, p2 = Point3D(0, 0, 3), Point3D(1, 1, 4)
1725
+ >>> s = Segment3D(p1, p2)
1726
+ >>> s.distance(Point3D(10, 15, 12))
1727
+ sqrt(341)
1728
+ >>> s.distance((10, 15, 12))
1729
+ sqrt(341)
1730
+ """
1731
+ if not isinstance(other, GeometryEntity):
1732
+ other = Point(other, dim=self.ambient_dimension)
1733
+ if isinstance(other, Point):
1734
+ vp1 = other - self.p1
1735
+ vp2 = other - self.p2
1736
+
1737
+ dot_prod_sign_1 = self.direction.dot(vp1) >= 0
1738
+ dot_prod_sign_2 = self.direction.dot(vp2) <= 0
1739
+ if dot_prod_sign_1 and dot_prod_sign_2:
1740
+ return Line(self.p1, self.p2).distance(other)
1741
+ if dot_prod_sign_1 and not dot_prod_sign_2:
1742
+ return abs(vp2)
1743
+ if not dot_prod_sign_1 and dot_prod_sign_2:
1744
+ return abs(vp1)
1745
+ raise NotImplementedError()
1746
+
1747
+ @property
1748
+ def length(self):
1749
+ """The length of the line segment.
1750
+
1751
+ See Also
1752
+ ========
1753
+
1754
+ sympy.geometry.point.Point.distance
1755
+
1756
+ Examples
1757
+ ========
1758
+
1759
+ >>> from sympy import Point, Segment
1760
+ >>> p1, p2 = Point(0, 0), Point(4, 3)
1761
+ >>> s1 = Segment(p1, p2)
1762
+ >>> s1.length
1763
+ 5
1764
+ >>> from sympy import Point3D, Segment3D
1765
+ >>> p1, p2 = Point3D(0, 0, 0), Point3D(4, 3, 3)
1766
+ >>> s1 = Segment3D(p1, p2)
1767
+ >>> s1.length
1768
+ sqrt(34)
1769
+
1770
+ """
1771
+ return Point.distance(self.p1, self.p2)
1772
+
1773
+ @property
1774
+ def midpoint(self):
1775
+ """The midpoint of the line segment.
1776
+
1777
+ See Also
1778
+ ========
1779
+
1780
+ sympy.geometry.point.Point.midpoint
1781
+
1782
+ Examples
1783
+ ========
1784
+
1785
+ >>> from sympy import Point, Segment
1786
+ >>> p1, p2 = Point(0, 0), Point(4, 3)
1787
+ >>> s1 = Segment(p1, p2)
1788
+ >>> s1.midpoint
1789
+ Point2D(2, 3/2)
1790
+ >>> from sympy import Point3D, Segment3D
1791
+ >>> p1, p2 = Point3D(0, 0, 0), Point3D(4, 3, 3)
1792
+ >>> s1 = Segment3D(p1, p2)
1793
+ >>> s1.midpoint
1794
+ Point3D(2, 3/2, 3/2)
1795
+
1796
+ """
1797
+ return Point.midpoint(self.p1, self.p2)
1798
+
1799
+ def perpendicular_bisector(self, p=None):
1800
+ """The perpendicular bisector of this segment.
1801
+
1802
+ If no point is specified or the point specified is not on the
1803
+ bisector then the bisector is returned as a Line. Otherwise a
1804
+ Segment is returned that joins the point specified and the
1805
+ intersection of the bisector and the segment.
1806
+
1807
+ Parameters
1808
+ ==========
1809
+
1810
+ p : Point
1811
+
1812
+ Returns
1813
+ =======
1814
+
1815
+ bisector : Line or Segment
1816
+
1817
+ See Also
1818
+ ========
1819
+
1820
+ LinearEntity.perpendicular_segment
1821
+
1822
+ Examples
1823
+ ========
1824
+
1825
+ >>> from sympy import Point, Segment
1826
+ >>> p1, p2, p3 = Point(0, 0), Point(6, 6), Point(5, 1)
1827
+ >>> s1 = Segment(p1, p2)
1828
+ >>> s1.perpendicular_bisector()
1829
+ Line2D(Point2D(3, 3), Point2D(-3, 9))
1830
+
1831
+ >>> s1.perpendicular_bisector(p3)
1832
+ Segment2D(Point2D(5, 1), Point2D(3, 3))
1833
+
1834
+ """
1835
+ l = self.perpendicular_line(self.midpoint)
1836
+ if p is not None:
1837
+ p2 = Point(p, dim=self.ambient_dimension)
1838
+ if p2 in l:
1839
+ return Segment(p2, self.midpoint)
1840
+ return l
1841
+
1842
+ def plot_interval(self, parameter='t'):
1843
+ """The plot interval for the default geometric plot of the Segment gives
1844
+ values that will produce the full segment in a plot.
1845
+
1846
+ Parameters
1847
+ ==========
1848
+
1849
+ parameter : str, optional
1850
+ Default value is 't'.
1851
+
1852
+ Returns
1853
+ =======
1854
+
1855
+ plot_interval : list
1856
+ [parameter, lower_bound, upper_bound]
1857
+
1858
+ Examples
1859
+ ========
1860
+
1861
+ >>> from sympy import Point, Segment
1862
+ >>> p1, p2 = Point(0, 0), Point(5, 3)
1863
+ >>> s1 = Segment(p1, p2)
1864
+ >>> s1.plot_interval()
1865
+ [t, 0, 1]
1866
+
1867
+ """
1868
+ t = _symbol(parameter, real=True)
1869
+ return [t, 0, 1]
1870
+
1871
+
1872
+ class LinearEntity2D(LinearEntity):
1873
+ """A base class for all linear entities (line, ray and segment)
1874
+ in a 2-dimensional Euclidean space.
1875
+
1876
+ Attributes
1877
+ ==========
1878
+
1879
+ p1
1880
+ p2
1881
+ coefficients
1882
+ slope
1883
+ points
1884
+
1885
+ Notes
1886
+ =====
1887
+
1888
+ This is an abstract class and is not meant to be instantiated.
1889
+
1890
+ See Also
1891
+ ========
1892
+
1893
+ sympy.geometry.entity.GeometryEntity
1894
+
1895
+ """
1896
+ @property
1897
+ def bounds(self):
1898
+ """Return a tuple (xmin, ymin, xmax, ymax) representing the bounding
1899
+ rectangle for the geometric figure.
1900
+
1901
+ """
1902
+ verts = self.points
1903
+ xs = [p.x for p in verts]
1904
+ ys = [p.y for p in verts]
1905
+ return (min(xs), min(ys), max(xs), max(ys))
1906
+
1907
+ def perpendicular_line(self, p):
1908
+ """Create a new Line perpendicular to this linear entity which passes
1909
+ through the point `p`.
1910
+
1911
+ Parameters
1912
+ ==========
1913
+
1914
+ p : Point
1915
+
1916
+ Returns
1917
+ =======
1918
+
1919
+ line : Line
1920
+
1921
+ See Also
1922
+ ========
1923
+
1924
+ sympy.geometry.line.LinearEntity.is_perpendicular, perpendicular_segment
1925
+
1926
+ Examples
1927
+ ========
1928
+
1929
+ >>> from sympy import Point, Line
1930
+ >>> p1, p2, p3 = Point(0, 0), Point(2, 3), Point(-2, 2)
1931
+ >>> L = Line(p1, p2)
1932
+ >>> P = L.perpendicular_line(p3); P
1933
+ Line2D(Point2D(-2, 2), Point2D(-5, 4))
1934
+ >>> L.is_perpendicular(P)
1935
+ True
1936
+
1937
+ In 2D, the first point of the perpendicular line is the
1938
+ point through which was required to pass; the second
1939
+ point is arbitrarily chosen. To get a line that explicitly
1940
+ uses a point in the line, create a line from the perpendicular
1941
+ segment from the line to the point:
1942
+
1943
+ >>> Line(L.perpendicular_segment(p3))
1944
+ Line2D(Point2D(-2, 2), Point2D(4/13, 6/13))
1945
+ """
1946
+ p = Point(p, dim=self.ambient_dimension)
1947
+ # any two lines in R^2 intersect, so blindly making
1948
+ # a line through p in an orthogonal direction will work
1949
+ # and is faster than finding the projection point as in 3D
1950
+ return Line(p, p + self.direction.orthogonal_direction)
1951
+
1952
+ @property
1953
+ def slope(self):
1954
+ """The slope of this linear entity, or infinity if vertical.
1955
+
1956
+ Returns
1957
+ =======
1958
+
1959
+ slope : number or SymPy expression
1960
+
1961
+ See Also
1962
+ ========
1963
+
1964
+ coefficients
1965
+
1966
+ Examples
1967
+ ========
1968
+
1969
+ >>> from sympy import Point, Line
1970
+ >>> p1, p2 = Point(0, 0), Point(3, 5)
1971
+ >>> l1 = Line(p1, p2)
1972
+ >>> l1.slope
1973
+ 5/3
1974
+
1975
+ >>> p3 = Point(0, 4)
1976
+ >>> l2 = Line(p1, p3)
1977
+ >>> l2.slope
1978
+ oo
1979
+
1980
+ """
1981
+ d1, d2 = (self.p1 - self.p2).args
1982
+ if d1 == 0:
1983
+ return S.Infinity
1984
+ return simplify(d2/d1)
1985
+
1986
+
1987
+ class Line2D(LinearEntity2D, Line):
1988
+ """An infinite line in space 2D.
1989
+
1990
+ A line is declared with two distinct points or a point and slope
1991
+ as defined using keyword `slope`.
1992
+
1993
+ Parameters
1994
+ ==========
1995
+
1996
+ p1 : Point
1997
+ pt : Point
1998
+ slope : SymPy expression
1999
+
2000
+ See Also
2001
+ ========
2002
+
2003
+ sympy.geometry.point.Point
2004
+
2005
+ Examples
2006
+ ========
2007
+
2008
+ >>> from sympy import Line, Segment, Point
2009
+ >>> L = Line(Point(2,3), Point(3,5))
2010
+ >>> L
2011
+ Line2D(Point2D(2, 3), Point2D(3, 5))
2012
+ >>> L.points
2013
+ (Point2D(2, 3), Point2D(3, 5))
2014
+ >>> L.equation()
2015
+ -2*x + y + 1
2016
+ >>> L.coefficients
2017
+ (-2, 1, 1)
2018
+
2019
+ Instantiate with keyword ``slope``:
2020
+
2021
+ >>> Line(Point(0, 0), slope=0)
2022
+ Line2D(Point2D(0, 0), Point2D(1, 0))
2023
+
2024
+ Instantiate with another linear object
2025
+
2026
+ >>> s = Segment((0, 0), (0, 1))
2027
+ >>> Line(s).equation()
2028
+ x
2029
+ """
2030
+ def __new__(cls, p1, pt=None, slope=None, **kwargs):
2031
+ if isinstance(p1, LinearEntity):
2032
+ if pt is not None:
2033
+ raise ValueError('When p1 is a LinearEntity, pt should be None')
2034
+ p1, pt = Point._normalize_dimension(*p1.args, dim=2)
2035
+ else:
2036
+ p1 = Point(p1, dim=2)
2037
+ if pt is not None and slope is None:
2038
+ try:
2039
+ p2 = Point(pt, dim=2)
2040
+ except (NotImplementedError, TypeError, ValueError):
2041
+ raise ValueError(filldedent('''
2042
+ The 2nd argument was not a valid Point.
2043
+ If it was a slope, enter it with keyword "slope".
2044
+ '''))
2045
+ elif slope is not None and pt is None:
2046
+ slope = sympify(slope)
2047
+ if slope.is_finite is False:
2048
+ # when infinite slope, don't change x
2049
+ dx = 0
2050
+ dy = 1
2051
+ else:
2052
+ # go over 1 up slope
2053
+ dx = 1
2054
+ dy = slope
2055
+ # XXX avoiding simplification by adding to coords directly
2056
+ p2 = Point(p1.x + dx, p1.y + dy, evaluate=False)
2057
+ else:
2058
+ raise ValueError('A 2nd Point or keyword "slope" must be used.')
2059
+ return LinearEntity2D.__new__(cls, p1, p2, **kwargs)
2060
+
2061
+ def _svg(self, scale_factor=1., fill_color="#66cc99"):
2062
+ """Returns SVG path element for the LinearEntity.
2063
+
2064
+ Parameters
2065
+ ==========
2066
+
2067
+ scale_factor : float
2068
+ Multiplication factor for the SVG stroke-width. Default is 1.
2069
+ fill_color : str, optional
2070
+ Hex string for fill color. Default is "#66cc99".
2071
+ """
2072
+ verts = (N(self.p1), N(self.p2))
2073
+ coords = ["{},{}".format(p.x, p.y) for p in verts]
2074
+ path = "M {} L {}".format(coords[0], " L ".join(coords[1:]))
2075
+
2076
+ return (
2077
+ '<path fill-rule="evenodd" fill="{2}" stroke="#555555" '
2078
+ 'stroke-width="{0}" opacity="0.6" d="{1}" '
2079
+ 'marker-start="url(#markerReverseArrow)" marker-end="url(#markerArrow)"/>'
2080
+ ).format(2.*scale_factor, path, fill_color)
2081
+
2082
+ @property
2083
+ def coefficients(self):
2084
+ """The coefficients (`a`, `b`, `c`) for `ax + by + c = 0`.
2085
+
2086
+ See Also
2087
+ ========
2088
+
2089
+ sympy.geometry.line.Line2D.equation
2090
+
2091
+ Examples
2092
+ ========
2093
+
2094
+ >>> from sympy import Point, Line
2095
+ >>> from sympy.abc import x, y
2096
+ >>> p1, p2 = Point(0, 0), Point(5, 3)
2097
+ >>> l = Line(p1, p2)
2098
+ >>> l.coefficients
2099
+ (-3, 5, 0)
2100
+
2101
+ >>> p3 = Point(x, y)
2102
+ >>> l2 = Line(p1, p3)
2103
+ >>> l2.coefficients
2104
+ (-y, x, 0)
2105
+
2106
+ """
2107
+ p1, p2 = self.points
2108
+ if p1.x == p2.x:
2109
+ return (S.One, S.Zero, -p1.x)
2110
+ elif p1.y == p2.y:
2111
+ return (S.Zero, S.One, -p1.y)
2112
+ return tuple([simplify(i) for i in
2113
+ (self.p1.y - self.p2.y,
2114
+ self.p2.x - self.p1.x,
2115
+ self.p1.x*self.p2.y - self.p1.y*self.p2.x)])
2116
+
2117
+ def equation(self, x='x', y='y'):
2118
+ """The equation of the line: ax + by + c.
2119
+
2120
+ Parameters
2121
+ ==========
2122
+
2123
+ x : str, optional
2124
+ The name to use for the x-axis, default value is 'x'.
2125
+ y : str, optional
2126
+ The name to use for the y-axis, default value is 'y'.
2127
+
2128
+ Returns
2129
+ =======
2130
+
2131
+ equation : SymPy expression
2132
+
2133
+ See Also
2134
+ ========
2135
+
2136
+ sympy.geometry.line.Line2D.coefficients
2137
+
2138
+ Examples
2139
+ ========
2140
+
2141
+ >>> from sympy import Point, Line
2142
+ >>> p1, p2 = Point(1, 0), Point(5, 3)
2143
+ >>> l1 = Line(p1, p2)
2144
+ >>> l1.equation()
2145
+ -3*x + 4*y + 3
2146
+
2147
+ """
2148
+ x = _symbol(x, real=True)
2149
+ y = _symbol(y, real=True)
2150
+ p1, p2 = self.points
2151
+ if p1.x == p2.x:
2152
+ return x - p1.x
2153
+ elif p1.y == p2.y:
2154
+ return y - p1.y
2155
+
2156
+ a, b, c = self.coefficients
2157
+ return a*x + b*y + c
2158
+
2159
+
2160
+ class Ray2D(LinearEntity2D, Ray):
2161
+ """
2162
+ A Ray is a semi-line in the space with a source point and a direction.
2163
+
2164
+ Parameters
2165
+ ==========
2166
+
2167
+ p1 : Point
2168
+ The source of the Ray
2169
+ p2 : Point or radian value
2170
+ This point determines the direction in which the Ray propagates.
2171
+ If given as an angle it is interpreted in radians with the positive
2172
+ direction being ccw.
2173
+
2174
+ Attributes
2175
+ ==========
2176
+
2177
+ source
2178
+ xdirection
2179
+ ydirection
2180
+
2181
+ See Also
2182
+ ========
2183
+
2184
+ sympy.geometry.point.Point, Line
2185
+
2186
+ Examples
2187
+ ========
2188
+
2189
+ >>> from sympy import Point, pi, Ray
2190
+ >>> r = Ray(Point(2, 3), Point(3, 5))
2191
+ >>> r
2192
+ Ray2D(Point2D(2, 3), Point2D(3, 5))
2193
+ >>> r.points
2194
+ (Point2D(2, 3), Point2D(3, 5))
2195
+ >>> r.source
2196
+ Point2D(2, 3)
2197
+ >>> r.xdirection
2198
+ oo
2199
+ >>> r.ydirection
2200
+ oo
2201
+ >>> r.slope
2202
+ 2
2203
+ >>> Ray(Point(0, 0), angle=pi/4).slope
2204
+ 1
2205
+
2206
+ """
2207
+ def __new__(cls, p1, pt=None, angle=None, **kwargs):
2208
+ p1 = Point(p1, dim=2)
2209
+ if pt is not None and angle is None:
2210
+ try:
2211
+ p2 = Point(pt, dim=2)
2212
+ except (NotImplementedError, TypeError, ValueError):
2213
+ raise ValueError(filldedent('''
2214
+ The 2nd argument was not a valid Point; if
2215
+ it was meant to be an angle it should be
2216
+ given with keyword "angle".'''))
2217
+ if p1 == p2:
2218
+ raise ValueError('A Ray requires two distinct points.')
2219
+ elif angle is not None and pt is None:
2220
+ # we need to know if the angle is an odd multiple of pi/2
2221
+ angle = sympify(angle)
2222
+ c = _pi_coeff(angle)
2223
+ p2 = None
2224
+ if c is not None:
2225
+ if c.is_Rational:
2226
+ if c.q == 2:
2227
+ if c.p == 1:
2228
+ p2 = p1 + Point(0, 1)
2229
+ elif c.p == 3:
2230
+ p2 = p1 + Point(0, -1)
2231
+ elif c.q == 1:
2232
+ if c.p == 0:
2233
+ p2 = p1 + Point(1, 0)
2234
+ elif c.p == 1:
2235
+ p2 = p1 + Point(-1, 0)
2236
+ if p2 is None:
2237
+ c *= S.Pi
2238
+ else:
2239
+ c = angle % (2*S.Pi)
2240
+ if not p2:
2241
+ m = 2*c/S.Pi
2242
+ left = And(1 < m, m < 3) # is it in quadrant 2 or 3?
2243
+ x = Piecewise((-1, left), (Piecewise((0, Eq(m % 1, 0)), (1, True)), True))
2244
+ y = Piecewise((-tan(c), left), (Piecewise((1, Eq(m, 1)), (-1, Eq(m, 3)), (tan(c), True)), True))
2245
+ p2 = p1 + Point(x, y)
2246
+ else:
2247
+ raise ValueError('A 2nd point or keyword "angle" must be used.')
2248
+
2249
+ return LinearEntity2D.__new__(cls, p1, p2, **kwargs)
2250
+
2251
+ @property
2252
+ def xdirection(self):
2253
+ """The x direction of the ray.
2254
+
2255
+ Positive infinity if the ray points in the positive x direction,
2256
+ negative infinity if the ray points in the negative x direction,
2257
+ or 0 if the ray is vertical.
2258
+
2259
+ See Also
2260
+ ========
2261
+
2262
+ ydirection
2263
+
2264
+ Examples
2265
+ ========
2266
+
2267
+ >>> from sympy import Point, Ray
2268
+ >>> p1, p2, p3 = Point(0, 0), Point(1, 1), Point(0, -1)
2269
+ >>> r1, r2 = Ray(p1, p2), Ray(p1, p3)
2270
+ >>> r1.xdirection
2271
+ oo
2272
+ >>> r2.xdirection
2273
+ 0
2274
+
2275
+ """
2276
+ if self.p1.x < self.p2.x:
2277
+ return S.Infinity
2278
+ elif self.p1.x == self.p2.x:
2279
+ return S.Zero
2280
+ else:
2281
+ return S.NegativeInfinity
2282
+
2283
+ @property
2284
+ def ydirection(self):
2285
+ """The y direction of the ray.
2286
+
2287
+ Positive infinity if the ray points in the positive y direction,
2288
+ negative infinity if the ray points in the negative y direction,
2289
+ or 0 if the ray is horizontal.
2290
+
2291
+ See Also
2292
+ ========
2293
+
2294
+ xdirection
2295
+
2296
+ Examples
2297
+ ========
2298
+
2299
+ >>> from sympy import Point, Ray
2300
+ >>> p1, p2, p3 = Point(0, 0), Point(-1, -1), Point(-1, 0)
2301
+ >>> r1, r2 = Ray(p1, p2), Ray(p1, p3)
2302
+ >>> r1.ydirection
2303
+ -oo
2304
+ >>> r2.ydirection
2305
+ 0
2306
+
2307
+ """
2308
+ if self.p1.y < self.p2.y:
2309
+ return S.Infinity
2310
+ elif self.p1.y == self.p2.y:
2311
+ return S.Zero
2312
+ else:
2313
+ return S.NegativeInfinity
2314
+
2315
+ def closing_angle(r1, r2):
2316
+ """Return the angle by which r2 must be rotated so it faces the same
2317
+ direction as r1.
2318
+
2319
+ Parameters
2320
+ ==========
2321
+
2322
+ r1 : Ray2D
2323
+ r2 : Ray2D
2324
+
2325
+ Returns
2326
+ =======
2327
+
2328
+ angle : angle in radians (ccw angle is positive)
2329
+
2330
+ See Also
2331
+ ========
2332
+
2333
+ LinearEntity.angle_between
2334
+
2335
+ Examples
2336
+ ========
2337
+
2338
+ >>> from sympy import Ray, pi
2339
+ >>> r1 = Ray((0, 0), (1, 0))
2340
+ >>> r2 = r1.rotate(-pi/2)
2341
+ >>> angle = r1.closing_angle(r2); angle
2342
+ pi/2
2343
+ >>> r2.rotate(angle).direction.unit == r1.direction.unit
2344
+ True
2345
+ >>> r2.closing_angle(r1)
2346
+ -pi/2
2347
+ """
2348
+ if not all(isinstance(r, Ray2D) for r in (r1, r2)):
2349
+ # although the direction property is defined for
2350
+ # all linear entities, only the Ray is truly a
2351
+ # directed object
2352
+ raise TypeError('Both arguments must be Ray2D objects.')
2353
+
2354
+ a1 = atan2(*list(reversed(r1.direction.args)))
2355
+ a2 = atan2(*list(reversed(r2.direction.args)))
2356
+ if a1*a2 < 0:
2357
+ a1 = 2*S.Pi + a1 if a1 < 0 else a1
2358
+ a2 = 2*S.Pi + a2 if a2 < 0 else a2
2359
+ return a1 - a2
2360
+
2361
+
2362
+ class Segment2D(LinearEntity2D, Segment):
2363
+ """A line segment in 2D space.
2364
+
2365
+ Parameters
2366
+ ==========
2367
+
2368
+ p1 : Point
2369
+ p2 : Point
2370
+
2371
+ Attributes
2372
+ ==========
2373
+
2374
+ length : number or SymPy expression
2375
+ midpoint : Point
2376
+
2377
+ See Also
2378
+ ========
2379
+
2380
+ sympy.geometry.point.Point, Line
2381
+
2382
+ Examples
2383
+ ========
2384
+
2385
+ >>> from sympy import Point, Segment
2386
+ >>> Segment((1, 0), (1, 1)) # tuples are interpreted as pts
2387
+ Segment2D(Point2D(1, 0), Point2D(1, 1))
2388
+ >>> s = Segment(Point(4, 3), Point(1, 1)); s
2389
+ Segment2D(Point2D(4, 3), Point2D(1, 1))
2390
+ >>> s.points
2391
+ (Point2D(4, 3), Point2D(1, 1))
2392
+ >>> s.slope
2393
+ 2/3
2394
+ >>> s.length
2395
+ sqrt(13)
2396
+ >>> s.midpoint
2397
+ Point2D(5/2, 2)
2398
+
2399
+ """
2400
+ def __new__(cls, p1, p2, **kwargs):
2401
+ p1 = Point(p1, dim=2)
2402
+ p2 = Point(p2, dim=2)
2403
+
2404
+ if p1 == p2:
2405
+ return p1
2406
+
2407
+ return LinearEntity2D.__new__(cls, p1, p2, **kwargs)
2408
+
2409
+ def _svg(self, scale_factor=1., fill_color="#66cc99"):
2410
+ """Returns SVG path element for the LinearEntity.
2411
+
2412
+ Parameters
2413
+ ==========
2414
+
2415
+ scale_factor : float
2416
+ Multiplication factor for the SVG stroke-width. Default is 1.
2417
+ fill_color : str, optional
2418
+ Hex string for fill color. Default is "#66cc99".
2419
+ """
2420
+ verts = (N(self.p1), N(self.p2))
2421
+ coords = ["{},{}".format(p.x, p.y) for p in verts]
2422
+ path = "M {} L {}".format(coords[0], " L ".join(coords[1:]))
2423
+ return (
2424
+ '<path fill-rule="evenodd" fill="{2}" stroke="#555555" '
2425
+ 'stroke-width="{0}" opacity="0.6" d="{1}" />'
2426
+ ).format(2.*scale_factor, path, fill_color)
2427
+
2428
+
2429
+ class LinearEntity3D(LinearEntity):
2430
+ """An base class for all linear entities (line, ray and segment)
2431
+ in a 3-dimensional Euclidean space.
2432
+
2433
+ Attributes
2434
+ ==========
2435
+
2436
+ p1
2437
+ p2
2438
+ direction_ratio
2439
+ direction_cosine
2440
+ points
2441
+
2442
+ Notes
2443
+ =====
2444
+
2445
+ This is a base class and is not meant to be instantiated.
2446
+ """
2447
+ def __new__(cls, p1, p2, **kwargs):
2448
+ p1 = Point3D(p1, dim=3)
2449
+ p2 = Point3D(p2, dim=3)
2450
+ if p1 == p2:
2451
+ # if it makes sense to return a Point, handle in subclass
2452
+ raise ValueError(
2453
+ "%s.__new__ requires two unique Points." % cls.__name__)
2454
+
2455
+ return GeometryEntity.__new__(cls, p1, p2, **kwargs)
2456
+
2457
+ ambient_dimension = 3
2458
+
2459
+ @property
2460
+ def direction_ratio(self):
2461
+ """The direction ratio of a given line in 3D.
2462
+
2463
+ See Also
2464
+ ========
2465
+
2466
+ sympy.geometry.line.Line3D.equation
2467
+
2468
+ Examples
2469
+ ========
2470
+
2471
+ >>> from sympy import Point3D, Line3D
2472
+ >>> p1, p2 = Point3D(0, 0, 0), Point3D(5, 3, 1)
2473
+ >>> l = Line3D(p1, p2)
2474
+ >>> l.direction_ratio
2475
+ [5, 3, 1]
2476
+ """
2477
+ p1, p2 = self.points
2478
+ return p1.direction_ratio(p2)
2479
+
2480
+ @property
2481
+ def direction_cosine(self):
2482
+ """The normalized direction ratio of a given line in 3D.
2483
+
2484
+ See Also
2485
+ ========
2486
+
2487
+ sympy.geometry.line.Line3D.equation
2488
+
2489
+ Examples
2490
+ ========
2491
+
2492
+ >>> from sympy import Point3D, Line3D
2493
+ >>> p1, p2 = Point3D(0, 0, 0), Point3D(5, 3, 1)
2494
+ >>> l = Line3D(p1, p2)
2495
+ >>> l.direction_cosine
2496
+ [sqrt(35)/7, 3*sqrt(35)/35, sqrt(35)/35]
2497
+ >>> sum(i**2 for i in _)
2498
+ 1
2499
+ """
2500
+ p1, p2 = self.points
2501
+ return p1.direction_cosine(p2)
2502
+
2503
+
2504
+ class Line3D(LinearEntity3D, Line):
2505
+ """An infinite 3D line in space.
2506
+
2507
+ A line is declared with two distinct points or a point and direction_ratio
2508
+ as defined using keyword `direction_ratio`.
2509
+
2510
+ Parameters
2511
+ ==========
2512
+
2513
+ p1 : Point3D
2514
+ pt : Point3D
2515
+ direction_ratio : list
2516
+
2517
+ See Also
2518
+ ========
2519
+
2520
+ sympy.geometry.point.Point3D
2521
+ sympy.geometry.line.Line
2522
+ sympy.geometry.line.Line2D
2523
+
2524
+ Examples
2525
+ ========
2526
+
2527
+ >>> from sympy import Line3D, Point3D
2528
+ >>> L = Line3D(Point3D(2, 3, 4), Point3D(3, 5, 1))
2529
+ >>> L
2530
+ Line3D(Point3D(2, 3, 4), Point3D(3, 5, 1))
2531
+ >>> L.points
2532
+ (Point3D(2, 3, 4), Point3D(3, 5, 1))
2533
+ """
2534
+ def __new__(cls, p1, pt=None, direction_ratio=(), **kwargs):
2535
+ if isinstance(p1, LinearEntity3D):
2536
+ if pt is not None:
2537
+ raise ValueError('if p1 is a LinearEntity, pt must be None.')
2538
+ p1, pt = p1.args
2539
+ else:
2540
+ p1 = Point(p1, dim=3)
2541
+ if pt is not None and len(direction_ratio) == 0:
2542
+ pt = Point(pt, dim=3)
2543
+ elif len(direction_ratio) == 3 and pt is None:
2544
+ pt = Point3D(p1.x + direction_ratio[0], p1.y + direction_ratio[1],
2545
+ p1.z + direction_ratio[2])
2546
+ else:
2547
+ raise ValueError('A 2nd Point or keyword "direction_ratio" must '
2548
+ 'be used.')
2549
+
2550
+ return LinearEntity3D.__new__(cls, p1, pt, **kwargs)
2551
+
2552
+ def equation(self, x='x', y='y', z='z'):
2553
+ """Return the equations that define the line in 3D.
2554
+
2555
+ Parameters
2556
+ ==========
2557
+
2558
+ x : str, optional
2559
+ The name to use for the x-axis, default value is 'x'.
2560
+ y : str, optional
2561
+ The name to use for the y-axis, default value is 'y'.
2562
+ z : str, optional
2563
+ The name to use for the z-axis, default value is 'z'.
2564
+
2565
+ Returns
2566
+ =======
2567
+
2568
+ equation : Tuple of simultaneous equations
2569
+
2570
+ Examples
2571
+ ========
2572
+
2573
+ >>> from sympy import Point3D, Line3D, solve
2574
+ >>> from sympy.abc import x, y, z
2575
+ >>> p1, p2 = Point3D(1, 0, 0), Point3D(5, 3, 0)
2576
+ >>> l1 = Line3D(p1, p2)
2577
+ >>> eq = l1.equation(x, y, z); eq
2578
+ (-3*x + 4*y + 3, z)
2579
+ >>> solve(eq.subs(z, 0), (x, y, z))
2580
+ {x: 4*y/3 + 1}
2581
+ """
2582
+ x, y, z, k = [_symbol(i, real=True) for i in (x, y, z, 'k')]
2583
+ p1, p2 = self.points
2584
+ d1, d2, d3 = p1.direction_ratio(p2)
2585
+ x1, y1, z1 = p1
2586
+ eqs = [-d1*k + x - x1, -d2*k + y - y1, -d3*k + z - z1]
2587
+ # eliminate k from equations by solving first eq with k for k
2588
+ for i, e in enumerate(eqs):
2589
+ if e.has(k):
2590
+ kk = solve(e, k)[0]
2591
+ eqs.pop(i)
2592
+ break
2593
+ return Tuple(*[i.subs(k, kk).as_numer_denom()[0] for i in eqs])
2594
+
2595
+ def distance(self, other):
2596
+ """
2597
+ Finds the shortest distance between a line and another object.
2598
+
2599
+ Parameters
2600
+ ==========
2601
+
2602
+ Point3D, Line3D, Plane, tuple, list
2603
+
2604
+ Returns
2605
+ =======
2606
+
2607
+ distance
2608
+
2609
+ Notes
2610
+ =====
2611
+
2612
+ This method accepts only 3D entities as it's parameter
2613
+
2614
+ Tuples and lists are converted to Point3D and therefore must be of
2615
+ length 3, 2 or 1.
2616
+
2617
+ NotImplementedError is raised if `other` is not an instance of one
2618
+ of the specified classes: Point3D, Line3D, or Plane.
2619
+
2620
+ Examples
2621
+ ========
2622
+
2623
+ >>> from sympy.geometry import Line3D
2624
+ >>> l1 = Line3D((0, 0, 0), (0, 0, 1))
2625
+ >>> l2 = Line3D((0, 1, 0), (1, 1, 1))
2626
+ >>> l1.distance(l2)
2627
+ 1
2628
+
2629
+ The computed distance may be symbolic, too:
2630
+
2631
+ >>> from sympy.abc import x, y
2632
+ >>> l1 = Line3D((0, 0, 0), (0, 0, 1))
2633
+ >>> l2 = Line3D((0, x, 0), (y, x, 1))
2634
+ >>> l1.distance(l2)
2635
+ Abs(x*y)/Abs(sqrt(y**2))
2636
+
2637
+ """
2638
+
2639
+ from .plane import Plane # Avoid circular import
2640
+
2641
+ if isinstance(other, (tuple, list)):
2642
+ try:
2643
+ other = Point3D(other)
2644
+ except ValueError:
2645
+ pass
2646
+
2647
+ if isinstance(other, Point3D):
2648
+ return super().distance(other)
2649
+
2650
+ if isinstance(other, Line3D):
2651
+ if self == other:
2652
+ return S.Zero
2653
+ if self.is_parallel(other):
2654
+ return super().distance(other.p1)
2655
+
2656
+ # Skew lines
2657
+ self_direction = Matrix(self.direction_ratio)
2658
+ other_direction = Matrix(other.direction_ratio)
2659
+ normal = self_direction.cross(other_direction)
2660
+ plane_through_self = Plane(p1=self.p1, normal_vector=normal)
2661
+ return other.p1.distance(plane_through_self)
2662
+
2663
+ if isinstance(other, Plane):
2664
+ return other.distance(self)
2665
+
2666
+ msg = f"{other} has type {type(other)}, which is unsupported"
2667
+ raise NotImplementedError(msg)
2668
+
2669
+
2670
+ class Ray3D(LinearEntity3D, Ray):
2671
+ """
2672
+ A Ray is a semi-line in the space with a source point and a direction.
2673
+
2674
+ Parameters
2675
+ ==========
2676
+
2677
+ p1 : Point3D
2678
+ The source of the Ray
2679
+ p2 : Point or a direction vector
2680
+ direction_ratio: Determines the direction in which the Ray propagates.
2681
+
2682
+
2683
+ Attributes
2684
+ ==========
2685
+
2686
+ source
2687
+ xdirection
2688
+ ydirection
2689
+ zdirection
2690
+
2691
+ See Also
2692
+ ========
2693
+
2694
+ sympy.geometry.point.Point3D, Line3D
2695
+
2696
+
2697
+ Examples
2698
+ ========
2699
+
2700
+ >>> from sympy import Point3D, Ray3D
2701
+ >>> r = Ray3D(Point3D(2, 3, 4), Point3D(3, 5, 0))
2702
+ >>> r
2703
+ Ray3D(Point3D(2, 3, 4), Point3D(3, 5, 0))
2704
+ >>> r.points
2705
+ (Point3D(2, 3, 4), Point3D(3, 5, 0))
2706
+ >>> r.source
2707
+ Point3D(2, 3, 4)
2708
+ >>> r.xdirection
2709
+ oo
2710
+ >>> r.ydirection
2711
+ oo
2712
+ >>> r.direction_ratio
2713
+ [1, 2, -4]
2714
+
2715
+ """
2716
+ def __new__(cls, p1, pt=None, direction_ratio=(), **kwargs):
2717
+ if isinstance(p1, LinearEntity3D):
2718
+ if pt is not None:
2719
+ raise ValueError('If p1 is a LinearEntity, pt must be None')
2720
+ p1, pt = p1.args
2721
+ else:
2722
+ p1 = Point(p1, dim=3)
2723
+ if pt is not None and len(direction_ratio) == 0:
2724
+ pt = Point(pt, dim=3)
2725
+ elif len(direction_ratio) == 3 and pt is None:
2726
+ pt = Point3D(p1.x + direction_ratio[0], p1.y + direction_ratio[1],
2727
+ p1.z + direction_ratio[2])
2728
+ else:
2729
+ raise ValueError(filldedent('''
2730
+ A 2nd Point or keyword "direction_ratio" must be used.
2731
+ '''))
2732
+
2733
+ return LinearEntity3D.__new__(cls, p1, pt, **kwargs)
2734
+
2735
+ @property
2736
+ def xdirection(self):
2737
+ """The x direction of the ray.
2738
+
2739
+ Positive infinity if the ray points in the positive x direction,
2740
+ negative infinity if the ray points in the negative x direction,
2741
+ or 0 if the ray is vertical.
2742
+
2743
+ See Also
2744
+ ========
2745
+
2746
+ ydirection
2747
+
2748
+ Examples
2749
+ ========
2750
+
2751
+ >>> from sympy import Point3D, Ray3D
2752
+ >>> p1, p2, p3 = Point3D(0, 0, 0), Point3D(1, 1, 1), Point3D(0, -1, 0)
2753
+ >>> r1, r2 = Ray3D(p1, p2), Ray3D(p1, p3)
2754
+ >>> r1.xdirection
2755
+ oo
2756
+ >>> r2.xdirection
2757
+ 0
2758
+
2759
+ """
2760
+ if self.p1.x < self.p2.x:
2761
+ return S.Infinity
2762
+ elif self.p1.x == self.p2.x:
2763
+ return S.Zero
2764
+ else:
2765
+ return S.NegativeInfinity
2766
+
2767
+ @property
2768
+ def ydirection(self):
2769
+ """The y direction of the ray.
2770
+
2771
+ Positive infinity if the ray points in the positive y direction,
2772
+ negative infinity if the ray points in the negative y direction,
2773
+ or 0 if the ray is horizontal.
2774
+
2775
+ See Also
2776
+ ========
2777
+
2778
+ xdirection
2779
+
2780
+ Examples
2781
+ ========
2782
+
2783
+ >>> from sympy import Point3D, Ray3D
2784
+ >>> p1, p2, p3 = Point3D(0, 0, 0), Point3D(-1, -1, -1), Point3D(-1, 0, 0)
2785
+ >>> r1, r2 = Ray3D(p1, p2), Ray3D(p1, p3)
2786
+ >>> r1.ydirection
2787
+ -oo
2788
+ >>> r2.ydirection
2789
+ 0
2790
+
2791
+ """
2792
+ if self.p1.y < self.p2.y:
2793
+ return S.Infinity
2794
+ elif self.p1.y == self.p2.y:
2795
+ return S.Zero
2796
+ else:
2797
+ return S.NegativeInfinity
2798
+
2799
+ @property
2800
+ def zdirection(self):
2801
+ """The z direction of the ray.
2802
+
2803
+ Positive infinity if the ray points in the positive z direction,
2804
+ negative infinity if the ray points in the negative z direction,
2805
+ or 0 if the ray is horizontal.
2806
+
2807
+ See Also
2808
+ ========
2809
+
2810
+ xdirection
2811
+
2812
+ Examples
2813
+ ========
2814
+
2815
+ >>> from sympy import Point3D, Ray3D
2816
+ >>> p1, p2, p3 = Point3D(0, 0, 0), Point3D(-1, -1, -1), Point3D(-1, 0, 0)
2817
+ >>> r1, r2 = Ray3D(p1, p2), Ray3D(p1, p3)
2818
+ >>> r1.ydirection
2819
+ -oo
2820
+ >>> r2.ydirection
2821
+ 0
2822
+ >>> r2.zdirection
2823
+ 0
2824
+
2825
+ """
2826
+ if self.p1.z < self.p2.z:
2827
+ return S.Infinity
2828
+ elif self.p1.z == self.p2.z:
2829
+ return S.Zero
2830
+ else:
2831
+ return S.NegativeInfinity
2832
+
2833
+
2834
+ class Segment3D(LinearEntity3D, Segment):
2835
+ """A line segment in a 3D space.
2836
+
2837
+ Parameters
2838
+ ==========
2839
+
2840
+ p1 : Point3D
2841
+ p2 : Point3D
2842
+
2843
+ Attributes
2844
+ ==========
2845
+
2846
+ length : number or SymPy expression
2847
+ midpoint : Point3D
2848
+
2849
+ See Also
2850
+ ========
2851
+
2852
+ sympy.geometry.point.Point3D, Line3D
2853
+
2854
+ Examples
2855
+ ========
2856
+
2857
+ >>> from sympy import Point3D, Segment3D
2858
+ >>> Segment3D((1, 0, 0), (1, 1, 1)) # tuples are interpreted as pts
2859
+ Segment3D(Point3D(1, 0, 0), Point3D(1, 1, 1))
2860
+ >>> s = Segment3D(Point3D(4, 3, 9), Point3D(1, 1, 7)); s
2861
+ Segment3D(Point3D(4, 3, 9), Point3D(1, 1, 7))
2862
+ >>> s.points
2863
+ (Point3D(4, 3, 9), Point3D(1, 1, 7))
2864
+ >>> s.length
2865
+ sqrt(17)
2866
+ >>> s.midpoint
2867
+ Point3D(5/2, 2, 8)
2868
+
2869
+ """
2870
+ def __new__(cls, p1, p2, **kwargs):
2871
+ p1 = Point(p1, dim=3)
2872
+ p2 = Point(p2, dim=3)
2873
+
2874
+ if p1 == p2:
2875
+ return p1
2876
+
2877
+ return LinearEntity3D.__new__(cls, p1, p2, **kwargs)
.venv/lib/python3.13/site-packages/sympy/geometry/parabola.py ADDED
@@ -0,0 +1,422 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Parabolic geometrical entity.
2
+
3
+ Contains
4
+ * Parabola
5
+
6
+ """
7
+
8
+ from sympy.core import S
9
+ from sympy.core.sorting import ordered
10
+ from sympy.core.symbol import _symbol, symbols
11
+ from sympy.geometry.entity import GeometryEntity, GeometrySet
12
+ from sympy.geometry.point import Point, Point2D
13
+ from sympy.geometry.line import Line, Line2D, Ray2D, Segment2D, LinearEntity3D
14
+ from sympy.geometry.ellipse import Ellipse
15
+ from sympy.functions import sign
16
+ from sympy.simplify.simplify import simplify
17
+ from sympy.solvers.solvers import solve
18
+
19
+
20
+ class Parabola(GeometrySet):
21
+ """A parabolic GeometryEntity.
22
+
23
+ A parabola is declared with a point, that is called 'focus', and
24
+ a line, that is called 'directrix'.
25
+ Only vertical or horizontal parabolas are currently supported.
26
+
27
+ Parameters
28
+ ==========
29
+
30
+ focus : Point
31
+ Default value is Point(0, 0)
32
+ directrix : Line
33
+
34
+ Attributes
35
+ ==========
36
+
37
+ focus
38
+ directrix
39
+ axis of symmetry
40
+ focal length
41
+ p parameter
42
+ vertex
43
+ eccentricity
44
+
45
+ Raises
46
+ ======
47
+ ValueError
48
+ When `focus` is not a two dimensional point.
49
+ When `focus` is a point of directrix.
50
+ NotImplementedError
51
+ When `directrix` is neither horizontal nor vertical.
52
+
53
+ Examples
54
+ ========
55
+
56
+ >>> from sympy import Parabola, Point, Line
57
+ >>> p1 = Parabola(Point(0, 0), Line(Point(5, 8), Point(7,8)))
58
+ >>> p1.focus
59
+ Point2D(0, 0)
60
+ >>> p1.directrix
61
+ Line2D(Point2D(5, 8), Point2D(7, 8))
62
+
63
+ """
64
+
65
+ def __new__(cls, focus=None, directrix=None, **kwargs):
66
+
67
+ if focus:
68
+ focus = Point(focus, dim=2)
69
+ else:
70
+ focus = Point(0, 0)
71
+
72
+ directrix = Line(directrix)
73
+
74
+ if directrix.contains(focus):
75
+ raise ValueError('The focus must not be a point of directrix')
76
+
77
+ return GeometryEntity.__new__(cls, focus, directrix, **kwargs)
78
+
79
+ @property
80
+ def ambient_dimension(self):
81
+ """Returns the ambient dimension of parabola.
82
+
83
+ Returns
84
+ =======
85
+
86
+ ambient_dimension : integer
87
+
88
+ Examples
89
+ ========
90
+
91
+ >>> from sympy import Parabola, Point, Line
92
+ >>> f1 = Point(0, 0)
93
+ >>> p1 = Parabola(f1, Line(Point(5, 8), Point(7, 8)))
94
+ >>> p1.ambient_dimension
95
+ 2
96
+
97
+ """
98
+ return 2
99
+
100
+ @property
101
+ def axis_of_symmetry(self):
102
+ """Return the axis of symmetry of the parabola: a line
103
+ perpendicular to the directrix passing through the focus.
104
+
105
+ Returns
106
+ =======
107
+
108
+ axis_of_symmetry : Line
109
+
110
+ See Also
111
+ ========
112
+
113
+ sympy.geometry.line.Line
114
+
115
+ Examples
116
+ ========
117
+
118
+ >>> from sympy import Parabola, Point, Line
119
+ >>> p1 = Parabola(Point(0, 0), Line(Point(5, 8), Point(7, 8)))
120
+ >>> p1.axis_of_symmetry
121
+ Line2D(Point2D(0, 0), Point2D(0, 1))
122
+
123
+ """
124
+ return self.directrix.perpendicular_line(self.focus)
125
+
126
+ @property
127
+ def directrix(self):
128
+ """The directrix of the parabola.
129
+
130
+ Returns
131
+ =======
132
+
133
+ directrix : Line
134
+
135
+ See Also
136
+ ========
137
+
138
+ sympy.geometry.line.Line
139
+
140
+ Examples
141
+ ========
142
+
143
+ >>> from sympy import Parabola, Point, Line
144
+ >>> l1 = Line(Point(5, 8), Point(7, 8))
145
+ >>> p1 = Parabola(Point(0, 0), l1)
146
+ >>> p1.directrix
147
+ Line2D(Point2D(5, 8), Point2D(7, 8))
148
+
149
+ """
150
+ return self.args[1]
151
+
152
+ @property
153
+ def eccentricity(self):
154
+ """The eccentricity of the parabola.
155
+
156
+ Returns
157
+ =======
158
+
159
+ eccentricity : number
160
+
161
+ A parabola may also be characterized as a conic section with an
162
+ eccentricity of 1. As a consequence of this, all parabolas are
163
+ similar, meaning that while they can be different sizes,
164
+ they are all the same shape.
165
+
166
+ See Also
167
+ ========
168
+
169
+ https://en.wikipedia.org/wiki/Parabola
170
+
171
+
172
+ Examples
173
+ ========
174
+
175
+ >>> from sympy import Parabola, Point, Line
176
+ >>> p1 = Parabola(Point(0, 0), Line(Point(5, 8), Point(7, 8)))
177
+ >>> p1.eccentricity
178
+ 1
179
+
180
+ Notes
181
+ -----
182
+ The eccentricity for every Parabola is 1 by definition.
183
+
184
+ """
185
+ return S.One
186
+
187
+ def equation(self, x='x', y='y'):
188
+ """The equation of the parabola.
189
+
190
+ Parameters
191
+ ==========
192
+ x : str, optional
193
+ Label for the x-axis. Default value is 'x'.
194
+ y : str, optional
195
+ Label for the y-axis. Default value is 'y'.
196
+
197
+ Returns
198
+ =======
199
+ equation : SymPy expression
200
+
201
+ Examples
202
+ ========
203
+
204
+ >>> from sympy import Parabola, Point, Line
205
+ >>> p1 = Parabola(Point(0, 0), Line(Point(5, 8), Point(7, 8)))
206
+ >>> p1.equation()
207
+ -x**2 - 16*y + 64
208
+ >>> p1.equation('f')
209
+ -f**2 - 16*y + 64
210
+ >>> p1.equation(y='z')
211
+ -x**2 - 16*z + 64
212
+
213
+ """
214
+ x = _symbol(x, real=True)
215
+ y = _symbol(y, real=True)
216
+
217
+ m = self.directrix.slope
218
+ if m is S.Infinity:
219
+ t1 = 4 * (self.p_parameter) * (x - self.vertex.x)
220
+ t2 = (y - self.vertex.y)**2
221
+ elif m == 0:
222
+ t1 = 4 * (self.p_parameter) * (y - self.vertex.y)
223
+ t2 = (x - self.vertex.x)**2
224
+ else:
225
+ a, b = self.focus
226
+ c, d = self.directrix.coefficients[:2]
227
+ t1 = (x - a)**2 + (y - b)**2
228
+ t2 = self.directrix.equation(x, y)**2/(c**2 + d**2)
229
+ return t1 - t2
230
+
231
+ @property
232
+ def focal_length(self):
233
+ """The focal length of the parabola.
234
+
235
+ Returns
236
+ =======
237
+
238
+ focal_lenght : number or symbolic expression
239
+
240
+ Notes
241
+ =====
242
+
243
+ The distance between the vertex and the focus
244
+ (or the vertex and directrix), measured along the axis
245
+ of symmetry, is the "focal length".
246
+
247
+ See Also
248
+ ========
249
+
250
+ https://en.wikipedia.org/wiki/Parabola
251
+
252
+ Examples
253
+ ========
254
+
255
+ >>> from sympy import Parabola, Point, Line
256
+ >>> p1 = Parabola(Point(0, 0), Line(Point(5, 8), Point(7, 8)))
257
+ >>> p1.focal_length
258
+ 4
259
+
260
+ """
261
+ distance = self.directrix.distance(self.focus)
262
+ focal_length = distance/2
263
+
264
+ return focal_length
265
+
266
+ @property
267
+ def focus(self):
268
+ """The focus of the parabola.
269
+
270
+ Returns
271
+ =======
272
+
273
+ focus : Point
274
+
275
+ See Also
276
+ ========
277
+
278
+ sympy.geometry.point.Point
279
+
280
+ Examples
281
+ ========
282
+
283
+ >>> from sympy import Parabola, Point, Line
284
+ >>> f1 = Point(0, 0)
285
+ >>> p1 = Parabola(f1, Line(Point(5, 8), Point(7, 8)))
286
+ >>> p1.focus
287
+ Point2D(0, 0)
288
+
289
+ """
290
+ return self.args[0]
291
+
292
+ def intersection(self, o):
293
+ """The intersection of the parabola and another geometrical entity `o`.
294
+
295
+ Parameters
296
+ ==========
297
+
298
+ o : GeometryEntity, LinearEntity
299
+
300
+ Returns
301
+ =======
302
+
303
+ intersection : list of GeometryEntity objects
304
+
305
+ Examples
306
+ ========
307
+
308
+ >>> from sympy import Parabola, Point, Ellipse, Line, Segment
309
+ >>> p1 = Point(0,0)
310
+ >>> l1 = Line(Point(1, -2), Point(-1,-2))
311
+ >>> parabola1 = Parabola(p1, l1)
312
+ >>> parabola1.intersection(Ellipse(Point(0, 0), 2, 5))
313
+ [Point2D(-2, 0), Point2D(2, 0)]
314
+ >>> parabola1.intersection(Line(Point(-7, 3), Point(12, 3)))
315
+ [Point2D(-4, 3), Point2D(4, 3)]
316
+ >>> parabola1.intersection(Segment((-12, -65), (14, -68)))
317
+ []
318
+
319
+ """
320
+ x, y = symbols('x y', real=True)
321
+ parabola_eq = self.equation()
322
+ if isinstance(o, Parabola):
323
+ if o in self:
324
+ return [o]
325
+ else:
326
+ return list(ordered([Point(i) for i in solve(
327
+ [parabola_eq, o.equation()], [x, y], set=True)[1]]))
328
+ elif isinstance(o, Point2D):
329
+ if simplify(parabola_eq.subs([(x, o._args[0]), (y, o._args[1])])) == 0:
330
+ return [o]
331
+ else:
332
+ return []
333
+ elif isinstance(o, (Segment2D, Ray2D)):
334
+ result = solve([parabola_eq,
335
+ Line2D(o.points[0], o.points[1]).equation()],
336
+ [x, y], set=True)[1]
337
+ return list(ordered([Point2D(i) for i in result if i in o]))
338
+ elif isinstance(o, (Line2D, Ellipse)):
339
+ return list(ordered([Point2D(i) for i in solve(
340
+ [parabola_eq, o.equation()], [x, y], set=True)[1]]))
341
+ elif isinstance(o, LinearEntity3D):
342
+ raise TypeError('Entity must be two dimensional, not three dimensional')
343
+ else:
344
+ raise TypeError('Wrong type of argument were put')
345
+
346
+ @property
347
+ def p_parameter(self):
348
+ """P is a parameter of parabola.
349
+
350
+ Returns
351
+ =======
352
+
353
+ p : number or symbolic expression
354
+
355
+ Notes
356
+ =====
357
+
358
+ The absolute value of p is the focal length. The sign on p tells
359
+ which way the parabola faces. Vertical parabolas that open up
360
+ and horizontal that open right, give a positive value for p.
361
+ Vertical parabolas that open down and horizontal that open left,
362
+ give a negative value for p.
363
+
364
+
365
+ See Also
366
+ ========
367
+
368
+ https://www.sparknotes.com/math/precalc/conicsections/section2/
369
+
370
+ Examples
371
+ ========
372
+
373
+ >>> from sympy import Parabola, Point, Line
374
+ >>> p1 = Parabola(Point(0, 0), Line(Point(5, 8), Point(7, 8)))
375
+ >>> p1.p_parameter
376
+ -4
377
+
378
+ """
379
+ m = self.directrix.slope
380
+ if m is S.Infinity:
381
+ x = self.directrix.coefficients[2]
382
+ p = sign(self.focus.args[0] + x)
383
+ elif m == 0:
384
+ y = self.directrix.coefficients[2]
385
+ p = sign(self.focus.args[1] + y)
386
+ else:
387
+ d = self.directrix.projection(self.focus)
388
+ p = sign(self.focus.x - d.x)
389
+ return p * self.focal_length
390
+
391
+ @property
392
+ def vertex(self):
393
+ """The vertex of the parabola.
394
+
395
+ Returns
396
+ =======
397
+
398
+ vertex : Point
399
+
400
+ See Also
401
+ ========
402
+
403
+ sympy.geometry.point.Point
404
+
405
+ Examples
406
+ ========
407
+
408
+ >>> from sympy import Parabola, Point, Line
409
+ >>> p1 = Parabola(Point(0, 0), Line(Point(5, 8), Point(7, 8)))
410
+ >>> p1.vertex
411
+ Point2D(0, 4)
412
+
413
+ """
414
+ focus = self.focus
415
+ m = self.directrix.slope
416
+ if m is S.Infinity:
417
+ vertex = Point(focus.args[0] - self.p_parameter, focus.args[1])
418
+ elif m == 0:
419
+ vertex = Point(focus.args[0], focus.args[1] - self.p_parameter)
420
+ else:
421
+ vertex = self.axis_of_symmetry.intersection(self)[0]
422
+ return vertex
.venv/lib/python3.13/site-packages/sympy/geometry/plane.py ADDED
@@ -0,0 +1,878 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Geometrical Planes.
2
+
3
+ Contains
4
+ ========
5
+ Plane
6
+
7
+ """
8
+
9
+ from sympy.core import Dummy, Rational, S, Symbol
10
+ from sympy.core.symbol import _symbol
11
+ from sympy.functions.elementary.trigonometric import cos, sin, acos, asin, sqrt
12
+ from .entity import GeometryEntity
13
+ from .line import (Line, Ray, Segment, Line3D, LinearEntity, LinearEntity3D,
14
+ Ray3D, Segment3D)
15
+ from .point import Point, Point3D
16
+ from sympy.matrices import Matrix
17
+ from sympy.polys.polytools import cancel
18
+ from sympy.solvers import solve, linsolve
19
+ from sympy.utilities.iterables import uniq, is_sequence
20
+ from sympy.utilities.misc import filldedent, func_name, Undecidable
21
+
22
+ from mpmath.libmp.libmpf import prec_to_dps
23
+
24
+ import random
25
+
26
+
27
+ x, y, z, t = [Dummy('plane_dummy') for i in range(4)]
28
+
29
+
30
+ class Plane(GeometryEntity):
31
+ """
32
+ A plane is a flat, two-dimensional surface. A plane is the two-dimensional
33
+ analogue of a point (zero-dimensions), a line (one-dimension) and a solid
34
+ (three-dimensions). A plane can generally be constructed by two types of
35
+ inputs. They are:
36
+ - three non-collinear points
37
+ - a point and the plane's normal vector
38
+
39
+ Attributes
40
+ ==========
41
+
42
+ p1
43
+ normal_vector
44
+
45
+ Examples
46
+ ========
47
+
48
+ >>> from sympy import Plane, Point3D
49
+ >>> Plane(Point3D(1, 1, 1), Point3D(2, 3, 4), Point3D(2, 2, 2))
50
+ Plane(Point3D(1, 1, 1), (-1, 2, -1))
51
+ >>> Plane((1, 1, 1), (2, 3, 4), (2, 2, 2))
52
+ Plane(Point3D(1, 1, 1), (-1, 2, -1))
53
+ >>> Plane(Point3D(1, 1, 1), normal_vector=(1,4,7))
54
+ Plane(Point3D(1, 1, 1), (1, 4, 7))
55
+
56
+ """
57
+ def __new__(cls, p1, a=None, b=None, **kwargs):
58
+ p1 = Point3D(p1, dim=3)
59
+ if a and b:
60
+ p2 = Point(a, dim=3)
61
+ p3 = Point(b, dim=3)
62
+ if Point3D.are_collinear(p1, p2, p3):
63
+ raise ValueError('Enter three non-collinear points')
64
+ a = p1.direction_ratio(p2)
65
+ b = p1.direction_ratio(p3)
66
+ normal_vector = tuple(Matrix(a).cross(Matrix(b)))
67
+ else:
68
+ a = kwargs.pop('normal_vector', a)
69
+ evaluate = kwargs.get('evaluate', True)
70
+ if is_sequence(a) and len(a) == 3:
71
+ normal_vector = Point3D(a).args if evaluate else a
72
+ else:
73
+ raise ValueError(filldedent('''
74
+ Either provide 3 3D points or a point with a
75
+ normal vector expressed as a sequence of length 3'''))
76
+ if all(coord.is_zero for coord in normal_vector):
77
+ raise ValueError('Normal vector cannot be zero vector')
78
+ return GeometryEntity.__new__(cls, p1, normal_vector, **kwargs)
79
+
80
+ def __contains__(self, o):
81
+ k = self.equation(x, y, z)
82
+ if isinstance(o, (LinearEntity, LinearEntity3D)):
83
+ d = Point3D(o.arbitrary_point(t))
84
+ e = k.subs([(x, d.x), (y, d.y), (z, d.z)])
85
+ return e.equals(0)
86
+ try:
87
+ o = Point(o, dim=3, strict=True)
88
+ d = k.xreplace(dict(zip((x, y, z), o.args)))
89
+ return d.equals(0)
90
+ except TypeError:
91
+ return False
92
+
93
+ def _eval_evalf(self, prec=15, **options):
94
+ pt, tup = self.args
95
+ dps = prec_to_dps(prec)
96
+ pt = pt.evalf(n=dps, **options)
97
+ tup = tuple([i.evalf(n=dps, **options) for i in tup])
98
+ return self.func(pt, normal_vector=tup, evaluate=False)
99
+
100
+ def angle_between(self, o):
101
+ """Angle between the plane and other geometric entity.
102
+
103
+ Parameters
104
+ ==========
105
+
106
+ LinearEntity3D, Plane.
107
+
108
+ Returns
109
+ =======
110
+
111
+ angle : angle in radians
112
+
113
+ Notes
114
+ =====
115
+
116
+ This method accepts only 3D entities as it's parameter, but if you want
117
+ to calculate the angle between a 2D entity and a plane you should
118
+ first convert to a 3D entity by projecting onto a desired plane and
119
+ then proceed to calculate the angle.
120
+
121
+ Examples
122
+ ========
123
+
124
+ >>> from sympy import Point3D, Line3D, Plane
125
+ >>> a = Plane(Point3D(1, 2, 2), normal_vector=(1, 2, 3))
126
+ >>> b = Line3D(Point3D(1, 3, 4), Point3D(2, 2, 2))
127
+ >>> a.angle_between(b)
128
+ -asin(sqrt(21)/6)
129
+
130
+ """
131
+ if isinstance(o, LinearEntity3D):
132
+ a = Matrix(self.normal_vector)
133
+ b = Matrix(o.direction_ratio)
134
+ c = a.dot(b)
135
+ d = sqrt(sum(i**2 for i in self.normal_vector))
136
+ e = sqrt(sum(i**2 for i in o.direction_ratio))
137
+ return asin(c/(d*e))
138
+ if isinstance(o, Plane):
139
+ a = Matrix(self.normal_vector)
140
+ b = Matrix(o.normal_vector)
141
+ c = a.dot(b)
142
+ d = sqrt(sum(i**2 for i in self.normal_vector))
143
+ e = sqrt(sum(i**2 for i in o.normal_vector))
144
+ return acos(c/(d*e))
145
+
146
+
147
+ def arbitrary_point(self, u=None, v=None):
148
+ """ Returns an arbitrary point on the Plane. If given two
149
+ parameters, the point ranges over the entire plane. If given 1
150
+ or no parameters, returns a point with one parameter which,
151
+ when varying from 0 to 2*pi, moves the point in a circle of
152
+ radius 1 about p1 of the Plane.
153
+
154
+ Examples
155
+ ========
156
+
157
+ >>> from sympy import Plane, Ray
158
+ >>> from sympy.abc import u, v, t, r
159
+ >>> p = Plane((1, 1, 1), normal_vector=(1, 0, 0))
160
+ >>> p.arbitrary_point(u, v)
161
+ Point3D(1, u + 1, v + 1)
162
+ >>> p.arbitrary_point(t)
163
+ Point3D(1, cos(t) + 1, sin(t) + 1)
164
+
165
+ While arbitrary values of u and v can move the point anywhere in
166
+ the plane, the single-parameter point can be used to construct a
167
+ ray whose arbitrary point can be located at angle t and radius
168
+ r from p.p1:
169
+
170
+ >>> Ray(p.p1, _).arbitrary_point(r)
171
+ Point3D(1, r*cos(t) + 1, r*sin(t) + 1)
172
+
173
+ Returns
174
+ =======
175
+
176
+ Point3D
177
+
178
+ """
179
+ circle = v is None
180
+ if circle:
181
+ u = _symbol(u or 't', real=True)
182
+ else:
183
+ u = _symbol(u or 'u', real=True)
184
+ v = _symbol(v or 'v', real=True)
185
+ x, y, z = self.normal_vector
186
+ a, b, c = self.p1.args
187
+ # x1, y1, z1 is a nonzero vector parallel to the plane
188
+ if x.is_zero and y.is_zero:
189
+ x1, y1, z1 = S.One, S.Zero, S.Zero
190
+ else:
191
+ x1, y1, z1 = -y, x, S.Zero
192
+ # x2, y2, z2 is also parallel to the plane, and orthogonal to x1, y1, z1
193
+ x2, y2, z2 = tuple(Matrix((x, y, z)).cross(Matrix((x1, y1, z1))))
194
+ if circle:
195
+ x1, y1, z1 = (w/sqrt(x1**2 + y1**2 + z1**2) for w in (x1, y1, z1))
196
+ x2, y2, z2 = (w/sqrt(x2**2 + y2**2 + z2**2) for w in (x2, y2, z2))
197
+ p = Point3D(a + x1*cos(u) + x2*sin(u), \
198
+ b + y1*cos(u) + y2*sin(u), \
199
+ c + z1*cos(u) + z2*sin(u))
200
+ else:
201
+ p = Point3D(a + x1*u + x2*v, b + y1*u + y2*v, c + z1*u + z2*v)
202
+ return p
203
+
204
+
205
+ @staticmethod
206
+ def are_concurrent(*planes):
207
+ """Is a sequence of Planes concurrent?
208
+
209
+ Two or more Planes are concurrent if their intersections
210
+ are a common line.
211
+
212
+ Parameters
213
+ ==========
214
+
215
+ planes: list
216
+
217
+ Returns
218
+ =======
219
+
220
+ Boolean
221
+
222
+ Examples
223
+ ========
224
+
225
+ >>> from sympy import Plane, Point3D
226
+ >>> a = Plane(Point3D(5, 0, 0), normal_vector=(1, -1, 1))
227
+ >>> b = Plane(Point3D(0, -2, 0), normal_vector=(3, 1, 1))
228
+ >>> c = Plane(Point3D(0, -1, 0), normal_vector=(5, -1, 9))
229
+ >>> Plane.are_concurrent(a, b)
230
+ True
231
+ >>> Plane.are_concurrent(a, b, c)
232
+ False
233
+
234
+ """
235
+ planes = list(uniq(planes))
236
+ for i in planes:
237
+ if not isinstance(i, Plane):
238
+ raise ValueError('All objects should be Planes but got %s' % i.func)
239
+ if len(planes) < 2:
240
+ return False
241
+ planes = list(planes)
242
+ first = planes.pop(0)
243
+ sol = first.intersection(planes[0])
244
+ if sol == []:
245
+ return False
246
+ else:
247
+ line = sol[0]
248
+ for i in planes[1:]:
249
+ l = first.intersection(i)
250
+ if not l or l[0] not in line:
251
+ return False
252
+ return True
253
+
254
+
255
+ def distance(self, o):
256
+ """Distance between the plane and another geometric entity.
257
+
258
+ Parameters
259
+ ==========
260
+
261
+ Point3D, LinearEntity3D, Plane.
262
+
263
+ Returns
264
+ =======
265
+
266
+ distance
267
+
268
+ Notes
269
+ =====
270
+
271
+ This method accepts only 3D entities as it's parameter, but if you want
272
+ to calculate the distance between a 2D entity and a plane you should
273
+ first convert to a 3D entity by projecting onto a desired plane and
274
+ then proceed to calculate the distance.
275
+
276
+ Examples
277
+ ========
278
+
279
+ >>> from sympy import Point3D, Line3D, Plane
280
+ >>> a = Plane(Point3D(1, 1, 1), normal_vector=(1, 1, 1))
281
+ >>> b = Point3D(1, 2, 3)
282
+ >>> a.distance(b)
283
+ sqrt(3)
284
+ >>> c = Line3D(Point3D(2, 3, 1), Point3D(1, 2, 2))
285
+ >>> a.distance(c)
286
+ 0
287
+
288
+ """
289
+ if self.intersection(o) != []:
290
+ return S.Zero
291
+
292
+ if isinstance(o, (Segment3D, Ray3D)):
293
+ a, b = o.p1, o.p2
294
+ pi, = self.intersection(Line3D(a, b))
295
+ if pi in o:
296
+ return self.distance(pi)
297
+ elif a in Segment3D(pi, b):
298
+ return self.distance(a)
299
+ else:
300
+ assert isinstance(o, Segment3D) is True
301
+ return self.distance(b)
302
+
303
+ # following code handles `Point3D`, `LinearEntity3D`, `Plane`
304
+ a = o if isinstance(o, Point3D) else o.p1
305
+ n = Point3D(self.normal_vector).unit
306
+ d = (a - self.p1).dot(n)
307
+ return abs(d)
308
+
309
+
310
+ def equals(self, o):
311
+ """
312
+ Returns True if self and o are the same mathematical entities.
313
+
314
+ Examples
315
+ ========
316
+
317
+ >>> from sympy import Plane, Point3D
318
+ >>> a = Plane(Point3D(1, 2, 3), normal_vector=(1, 1, 1))
319
+ >>> b = Plane(Point3D(1, 2, 3), normal_vector=(2, 2, 2))
320
+ >>> c = Plane(Point3D(1, 2, 3), normal_vector=(-1, 4, 6))
321
+ >>> a.equals(a)
322
+ True
323
+ >>> a.equals(b)
324
+ True
325
+ >>> a.equals(c)
326
+ False
327
+ """
328
+ if isinstance(o, Plane):
329
+ a = self.equation()
330
+ b = o.equation()
331
+ return cancel(a/b).is_constant()
332
+ else:
333
+ return False
334
+
335
+
336
+ def equation(self, x=None, y=None, z=None):
337
+ """The equation of the Plane.
338
+
339
+ Examples
340
+ ========
341
+
342
+ >>> from sympy import Point3D, Plane
343
+ >>> a = Plane(Point3D(1, 1, 2), Point3D(2, 4, 7), Point3D(3, 5, 1))
344
+ >>> a.equation()
345
+ -23*x + 11*y - 2*z + 16
346
+ >>> a = Plane(Point3D(1, 4, 2), normal_vector=(6, 6, 6))
347
+ >>> a.equation()
348
+ 6*x + 6*y + 6*z - 42
349
+
350
+ """
351
+ x, y, z = [i if i else Symbol(j, real=True) for i, j in zip((x, y, z), 'xyz')]
352
+ a = Point3D(x, y, z)
353
+ b = self.p1.direction_ratio(a)
354
+ c = self.normal_vector
355
+ return (sum(i*j for i, j in zip(b, c)))
356
+
357
+
358
+ def intersection(self, o):
359
+ """ The intersection with other geometrical entity.
360
+
361
+ Parameters
362
+ ==========
363
+
364
+ Point, Point3D, LinearEntity, LinearEntity3D, Plane
365
+
366
+ Returns
367
+ =======
368
+
369
+ List
370
+
371
+ Examples
372
+ ========
373
+
374
+ >>> from sympy import Point3D, Line3D, Plane
375
+ >>> a = Plane(Point3D(1, 2, 3), normal_vector=(1, 1, 1))
376
+ >>> b = Point3D(1, 2, 3)
377
+ >>> a.intersection(b)
378
+ [Point3D(1, 2, 3)]
379
+ >>> c = Line3D(Point3D(1, 4, 7), Point3D(2, 2, 2))
380
+ >>> a.intersection(c)
381
+ [Point3D(2, 2, 2)]
382
+ >>> d = Plane(Point3D(6, 0, 0), normal_vector=(2, -5, 3))
383
+ >>> e = Plane(Point3D(2, 0, 0), normal_vector=(3, 4, -3))
384
+ >>> d.intersection(e)
385
+ [Line3D(Point3D(78/23, -24/23, 0), Point3D(147/23, 321/23, 23))]
386
+
387
+ """
388
+ if not isinstance(o, GeometryEntity):
389
+ o = Point(o, dim=3)
390
+ if isinstance(o, Point):
391
+ if o in self:
392
+ return [o]
393
+ else:
394
+ return []
395
+ if isinstance(o, (LinearEntity, LinearEntity3D)):
396
+ # recast to 3D
397
+ p1, p2 = o.p1, o.p2
398
+ if isinstance(o, Segment):
399
+ o = Segment3D(p1, p2)
400
+ elif isinstance(o, Ray):
401
+ o = Ray3D(p1, p2)
402
+ elif isinstance(o, Line):
403
+ o = Line3D(p1, p2)
404
+ else:
405
+ raise ValueError('unhandled linear entity: %s' % o.func)
406
+ if o in self:
407
+ return [o]
408
+ else:
409
+ a = Point3D(o.arbitrary_point(t))
410
+ p1, n = self.p1, Point3D(self.normal_vector)
411
+
412
+ # TODO: Replace solve with solveset, when this line is tested
413
+ c = solve((a - p1).dot(n), t)
414
+ if not c:
415
+ return []
416
+ else:
417
+ c = [i for i in c if i.is_real is not False]
418
+ if len(c) > 1:
419
+ c = [i for i in c if i.is_real]
420
+ if len(c) != 1:
421
+ raise Undecidable("not sure which point is real")
422
+ p = a.subs(t, c[0])
423
+ if p not in o:
424
+ return [] # e.g. a segment might not intersect a plane
425
+ return [p]
426
+ if isinstance(o, Plane):
427
+ if self.equals(o):
428
+ return [self]
429
+ if self.is_parallel(o):
430
+ return []
431
+ else:
432
+ x, y, z = map(Dummy, 'xyz')
433
+ a, b = Matrix([self.normal_vector]), Matrix([o.normal_vector])
434
+ c = list(a.cross(b))
435
+ d = self.equation(x, y, z)
436
+ e = o.equation(x, y, z)
437
+ result = list(linsolve([d, e], x, y, z))[0]
438
+ for i in (x, y, z): result = result.subs(i, 0)
439
+ return [Line3D(Point3D(result), direction_ratio=c)]
440
+
441
+
442
+ def is_coplanar(self, o):
443
+ """ Returns True if `o` is coplanar with self, else False.
444
+
445
+ Examples
446
+ ========
447
+
448
+ >>> from sympy import Plane
449
+ >>> o = (0, 0, 0)
450
+ >>> p = Plane(o, (1, 1, 1))
451
+ >>> p2 = Plane(o, (2, 2, 2))
452
+ >>> p == p2
453
+ False
454
+ >>> p.is_coplanar(p2)
455
+ True
456
+ """
457
+ if isinstance(o, Plane):
458
+ return not cancel(self.equation(x, y, z)/o.equation(x, y, z)).has(x, y, z)
459
+ if isinstance(o, Point3D):
460
+ return o in self
461
+ elif isinstance(o, LinearEntity3D):
462
+ return all(i in self for i in self)
463
+ elif isinstance(o, GeometryEntity): # XXX should only be handling 2D objects now
464
+ return all(i == 0 for i in self.normal_vector[:2])
465
+
466
+
467
+ def is_parallel(self, l):
468
+ """Is the given geometric entity parallel to the plane?
469
+
470
+ Parameters
471
+ ==========
472
+
473
+ LinearEntity3D or Plane
474
+
475
+ Returns
476
+ =======
477
+
478
+ Boolean
479
+
480
+ Examples
481
+ ========
482
+
483
+ >>> from sympy import Plane, Point3D
484
+ >>> a = Plane(Point3D(1,4,6), normal_vector=(2, 4, 6))
485
+ >>> b = Plane(Point3D(3,1,3), normal_vector=(4, 8, 12))
486
+ >>> a.is_parallel(b)
487
+ True
488
+
489
+ """
490
+ if isinstance(l, LinearEntity3D):
491
+ a = l.direction_ratio
492
+ b = self.normal_vector
493
+ return sum(i*j for i, j in zip(a, b)) == 0
494
+ if isinstance(l, Plane):
495
+ a = Matrix(l.normal_vector)
496
+ b = Matrix(self.normal_vector)
497
+ return bool(a.cross(b).is_zero_matrix)
498
+
499
+
500
+ def is_perpendicular(self, l):
501
+ """Is the given geometric entity perpendicualar to the given plane?
502
+
503
+ Parameters
504
+ ==========
505
+
506
+ LinearEntity3D or Plane
507
+
508
+ Returns
509
+ =======
510
+
511
+ Boolean
512
+
513
+ Examples
514
+ ========
515
+
516
+ >>> from sympy import Plane, Point3D
517
+ >>> a = Plane(Point3D(1,4,6), normal_vector=(2, 4, 6))
518
+ >>> b = Plane(Point3D(2, 2, 2), normal_vector=(-1, 2, -1))
519
+ >>> a.is_perpendicular(b)
520
+ True
521
+
522
+ """
523
+ if isinstance(l, LinearEntity3D):
524
+ a = Matrix(l.direction_ratio)
525
+ b = Matrix(self.normal_vector)
526
+ if a.cross(b).is_zero_matrix:
527
+ return True
528
+ else:
529
+ return False
530
+ elif isinstance(l, Plane):
531
+ a = Matrix(l.normal_vector)
532
+ b = Matrix(self.normal_vector)
533
+ if a.dot(b) == 0:
534
+ return True
535
+ else:
536
+ return False
537
+ else:
538
+ return False
539
+
540
+ @property
541
+ def normal_vector(self):
542
+ """Normal vector of the given plane.
543
+
544
+ Examples
545
+ ========
546
+
547
+ >>> from sympy import Point3D, Plane
548
+ >>> a = Plane(Point3D(1, 1, 1), Point3D(2, 3, 4), Point3D(2, 2, 2))
549
+ >>> a.normal_vector
550
+ (-1, 2, -1)
551
+ >>> a = Plane(Point3D(1, 1, 1), normal_vector=(1, 4, 7))
552
+ >>> a.normal_vector
553
+ (1, 4, 7)
554
+
555
+ """
556
+ return self.args[1]
557
+
558
+ @property
559
+ def p1(self):
560
+ """The only defining point of the plane. Others can be obtained from the
561
+ arbitrary_point method.
562
+
563
+ See Also
564
+ ========
565
+
566
+ sympy.geometry.point.Point3D
567
+
568
+ Examples
569
+ ========
570
+
571
+ >>> from sympy import Point3D, Plane
572
+ >>> a = Plane(Point3D(1, 1, 1), Point3D(2, 3, 4), Point3D(2, 2, 2))
573
+ >>> a.p1
574
+ Point3D(1, 1, 1)
575
+
576
+ """
577
+ return self.args[0]
578
+
579
+ def parallel_plane(self, pt):
580
+ """
581
+ Plane parallel to the given plane and passing through the point pt.
582
+
583
+ Parameters
584
+ ==========
585
+
586
+ pt: Point3D
587
+
588
+ Returns
589
+ =======
590
+
591
+ Plane
592
+
593
+ Examples
594
+ ========
595
+
596
+ >>> from sympy import Plane, Point3D
597
+ >>> a = Plane(Point3D(1, 4, 6), normal_vector=(2, 4, 6))
598
+ >>> a.parallel_plane(Point3D(2, 3, 5))
599
+ Plane(Point3D(2, 3, 5), (2, 4, 6))
600
+
601
+ """
602
+ a = self.normal_vector
603
+ return Plane(pt, normal_vector=a)
604
+
605
+ def perpendicular_line(self, pt):
606
+ """A line perpendicular to the given plane.
607
+
608
+ Parameters
609
+ ==========
610
+
611
+ pt: Point3D
612
+
613
+ Returns
614
+ =======
615
+
616
+ Line3D
617
+
618
+ Examples
619
+ ========
620
+
621
+ >>> from sympy import Plane, Point3D
622
+ >>> a = Plane(Point3D(1,4,6), normal_vector=(2, 4, 6))
623
+ >>> a.perpendicular_line(Point3D(9, 8, 7))
624
+ Line3D(Point3D(9, 8, 7), Point3D(11, 12, 13))
625
+
626
+ """
627
+ a = self.normal_vector
628
+ return Line3D(pt, direction_ratio=a)
629
+
630
+ def perpendicular_plane(self, *pts):
631
+ """
632
+ Return a perpendicular passing through the given points. If the
633
+ direction ratio between the points is the same as the Plane's normal
634
+ vector then, to select from the infinite number of possible planes,
635
+ a third point will be chosen on the z-axis (or the y-axis
636
+ if the normal vector is already parallel to the z-axis). If less than
637
+ two points are given they will be supplied as follows: if no point is
638
+ given then pt1 will be self.p1; if a second point is not given it will
639
+ be a point through pt1 on a line parallel to the z-axis (if the normal
640
+ is not already the z-axis, otherwise on the line parallel to the
641
+ y-axis).
642
+
643
+ Parameters
644
+ ==========
645
+
646
+ pts: 0, 1 or 2 Point3D
647
+
648
+ Returns
649
+ =======
650
+
651
+ Plane
652
+
653
+ Examples
654
+ ========
655
+
656
+ >>> from sympy import Plane, Point3D
657
+ >>> a, b = Point3D(0, 0, 0), Point3D(0, 1, 0)
658
+ >>> Z = (0, 0, 1)
659
+ >>> p = Plane(a, normal_vector=Z)
660
+ >>> p.perpendicular_plane(a, b)
661
+ Plane(Point3D(0, 0, 0), (1, 0, 0))
662
+ """
663
+ if len(pts) > 2:
664
+ raise ValueError('No more than 2 pts should be provided.')
665
+
666
+ pts = list(pts)
667
+ if len(pts) == 0:
668
+ pts.append(self.p1)
669
+ if len(pts) == 1:
670
+ x, y, z = self.normal_vector
671
+ if x == y == 0:
672
+ dir = (0, 1, 0)
673
+ else:
674
+ dir = (0, 0, 1)
675
+ pts.append(pts[0] + Point3D(*dir))
676
+
677
+ p1, p2 = [Point(i, dim=3) for i in pts]
678
+ l = Line3D(p1, p2)
679
+ n = Line3D(p1, direction_ratio=self.normal_vector)
680
+ if l in n: # XXX should an error be raised instead?
681
+ # there are infinitely many perpendicular planes;
682
+ x, y, z = self.normal_vector
683
+ if x == y == 0:
684
+ # the z axis is the normal so pick a pt on the y-axis
685
+ p3 = Point3D(0, 1, 0) # case 1
686
+ else:
687
+ # else pick a pt on the z axis
688
+ p3 = Point3D(0, 0, 1) # case 2
689
+ # in case that point is already given, move it a bit
690
+ if p3 in l:
691
+ p3 *= 2 # case 3
692
+ else:
693
+ p3 = p1 + Point3D(*self.normal_vector) # case 4
694
+ return Plane(p1, p2, p3)
695
+
696
+ def projection_line(self, line):
697
+ """Project the given line onto the plane through the normal plane
698
+ containing the line.
699
+
700
+ Parameters
701
+ ==========
702
+
703
+ LinearEntity or LinearEntity3D
704
+
705
+ Returns
706
+ =======
707
+
708
+ Point3D, Line3D, Ray3D or Segment3D
709
+
710
+ Notes
711
+ =====
712
+
713
+ For the interaction between 2D and 3D lines(segments, rays), you should
714
+ convert the line to 3D by using this method. For example for finding the
715
+ intersection between a 2D and a 3D line, convert the 2D line to a 3D line
716
+ by projecting it on a required plane and then proceed to find the
717
+ intersection between those lines.
718
+
719
+ Examples
720
+ ========
721
+
722
+ >>> from sympy import Plane, Line, Line3D, Point3D
723
+ >>> a = Plane(Point3D(1, 1, 1), normal_vector=(1, 1, 1))
724
+ >>> b = Line(Point3D(1, 1), Point3D(2, 2))
725
+ >>> a.projection_line(b)
726
+ Line3D(Point3D(4/3, 4/3, 1/3), Point3D(5/3, 5/3, -1/3))
727
+ >>> c = Line3D(Point3D(1, 1, 1), Point3D(2, 2, 2))
728
+ >>> a.projection_line(c)
729
+ Point3D(1, 1, 1)
730
+
731
+ """
732
+ if not isinstance(line, (LinearEntity, LinearEntity3D)):
733
+ raise NotImplementedError('Enter a linear entity only')
734
+ a, b = self.projection(line.p1), self.projection(line.p2)
735
+ if a == b:
736
+ # projection does not imply intersection so for
737
+ # this case (line parallel to plane's normal) we
738
+ # return the projection point
739
+ return a
740
+ if isinstance(line, (Line, Line3D)):
741
+ return Line3D(a, b)
742
+ if isinstance(line, (Ray, Ray3D)):
743
+ return Ray3D(a, b)
744
+ if isinstance(line, (Segment, Segment3D)):
745
+ return Segment3D(a, b)
746
+
747
+ def projection(self, pt):
748
+ """Project the given point onto the plane along the plane normal.
749
+
750
+ Parameters
751
+ ==========
752
+
753
+ Point or Point3D
754
+
755
+ Returns
756
+ =======
757
+
758
+ Point3D
759
+
760
+ Examples
761
+ ========
762
+
763
+ >>> from sympy import Plane, Point3D
764
+ >>> A = Plane(Point3D(1, 1, 2), normal_vector=(1, 1, 1))
765
+
766
+ The projection is along the normal vector direction, not the z
767
+ axis, so (1, 1) does not project to (1, 1, 2) on the plane A:
768
+
769
+ >>> b = Point3D(1, 1)
770
+ >>> A.projection(b)
771
+ Point3D(5/3, 5/3, 2/3)
772
+ >>> _ in A
773
+ True
774
+
775
+ But the point (1, 1, 2) projects to (1, 1) on the XY-plane:
776
+
777
+ >>> XY = Plane((0, 0, 0), (0, 0, 1))
778
+ >>> XY.projection((1, 1, 2))
779
+ Point3D(1, 1, 0)
780
+ """
781
+ rv = Point(pt, dim=3)
782
+ if rv in self:
783
+ return rv
784
+ return self.intersection(Line3D(rv, rv + Point3D(self.normal_vector)))[0]
785
+
786
+ def random_point(self, seed=None):
787
+ """ Returns a random point on the Plane.
788
+
789
+ Returns
790
+ =======
791
+
792
+ Point3D
793
+
794
+ Examples
795
+ ========
796
+
797
+ >>> from sympy import Plane
798
+ >>> p = Plane((1, 0, 0), normal_vector=(0, 1, 0))
799
+ >>> r = p.random_point(seed=42) # seed value is optional
800
+ >>> r.n(3)
801
+ Point3D(2.29, 0, -1.35)
802
+
803
+ The random point can be moved to lie on the circle of radius
804
+ 1 centered on p1:
805
+
806
+ >>> c = p.p1 + (r - p.p1).unit
807
+ >>> c.distance(p.p1).equals(1)
808
+ True
809
+ """
810
+ if seed is not None:
811
+ rng = random.Random(seed)
812
+ else:
813
+ rng = random
814
+ params = {
815
+ x: 2*Rational(rng.gauss(0, 1)) - 1,
816
+ y: 2*Rational(rng.gauss(0, 1)) - 1}
817
+ return self.arbitrary_point(x, y).subs(params)
818
+
819
+ def parameter_value(self, other, u, v=None):
820
+ """Return the parameter(s) corresponding to the given point.
821
+
822
+ Examples
823
+ ========
824
+
825
+ >>> from sympy import pi, Plane
826
+ >>> from sympy.abc import t, u, v
827
+ >>> p = Plane((2, 0, 0), (0, 0, 1), (0, 1, 0))
828
+
829
+ By default, the parameter value returned defines a point
830
+ that is a distance of 1 from the Plane's p1 value and
831
+ in line with the given point:
832
+
833
+ >>> on_circle = p.arbitrary_point(t).subs(t, pi/4)
834
+ >>> on_circle.distance(p.p1)
835
+ 1
836
+ >>> p.parameter_value(on_circle, t)
837
+ {t: pi/4}
838
+
839
+ Moving the point twice as far from p1 does not change
840
+ the parameter value:
841
+
842
+ >>> off_circle = p.p1 + (on_circle - p.p1)*2
843
+ >>> off_circle.distance(p.p1)
844
+ 2
845
+ >>> p.parameter_value(off_circle, t)
846
+ {t: pi/4}
847
+
848
+ If the 2-value parameter is desired, supply the two
849
+ parameter symbols and a replacement dictionary will
850
+ be returned:
851
+
852
+ >>> p.parameter_value(on_circle, u, v)
853
+ {u: sqrt(10)/10, v: sqrt(10)/30}
854
+ >>> p.parameter_value(off_circle, u, v)
855
+ {u: sqrt(10)/5, v: sqrt(10)/15}
856
+ """
857
+ if not isinstance(other, GeometryEntity):
858
+ other = Point(other, dim=self.ambient_dimension)
859
+ if not isinstance(other, Point):
860
+ raise ValueError("other must be a point")
861
+ if other == self.p1:
862
+ return other
863
+ if isinstance(u, Symbol) and v is None:
864
+ delta = self.arbitrary_point(u) - self.p1
865
+ eq = delta - (other - self.p1).unit
866
+ sol = solve(eq, u, dict=True)
867
+ elif isinstance(u, Symbol) and isinstance(v, Symbol):
868
+ pt = self.arbitrary_point(u, v)
869
+ sol = solve(pt - other, (u, v), dict=True)
870
+ else:
871
+ raise ValueError('expecting 1 or 2 symbols')
872
+ if not sol:
873
+ raise ValueError("Given point is not on %s" % func_name(self))
874
+ return sol[0] # {t: tval} or {u: uval, v: vval}
875
+
876
+ @property
877
+ def ambient_dimension(self):
878
+ return self.p1.ambient_dimension
.venv/lib/python3.13/site-packages/sympy/geometry/point.py ADDED
@@ -0,0 +1,1378 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Geometrical Points.
2
+
3
+ Contains
4
+ ========
5
+ Point
6
+ Point2D
7
+ Point3D
8
+
9
+ When methods of Point require 1 or more points as arguments, they
10
+ can be passed as a sequence of coordinates or Points:
11
+
12
+ >>> from sympy import Point
13
+ >>> Point(1, 1).is_collinear((2, 2), (3, 4))
14
+ False
15
+ >>> Point(1, 1).is_collinear(Point(2, 2), Point(3, 4))
16
+ False
17
+
18
+ """
19
+
20
+ import warnings
21
+
22
+ from sympy.core import S, sympify, Expr
23
+ from sympy.core.add import Add
24
+ from sympy.core.containers import Tuple
25
+ from sympy.core.numbers import Float
26
+ from sympy.core.parameters import global_parameters
27
+ from sympy.simplify.simplify import nsimplify, simplify
28
+ from sympy.geometry.exceptions import GeometryError
29
+ from sympy.functions.elementary.miscellaneous import sqrt
30
+ from sympy.functions.elementary.complexes import im
31
+ from sympy.functions.elementary.trigonometric import cos, sin
32
+ from sympy.matrices import Matrix
33
+ from sympy.matrices.expressions import Transpose
34
+ from sympy.utilities.iterables import uniq, is_sequence
35
+ from sympy.utilities.misc import filldedent, func_name, Undecidable
36
+
37
+ from .entity import GeometryEntity
38
+
39
+ from mpmath.libmp.libmpf import prec_to_dps
40
+
41
+
42
+ class Point(GeometryEntity):
43
+ """A point in a n-dimensional Euclidean space.
44
+
45
+ Parameters
46
+ ==========
47
+
48
+ coords : sequence of n-coordinate values. In the special
49
+ case where n=2 or 3, a Point2D or Point3D will be created
50
+ as appropriate.
51
+ evaluate : if `True` (default), all floats are turn into
52
+ exact types.
53
+ dim : number of coordinates the point should have. If coordinates
54
+ are unspecified, they are padded with zeros.
55
+ on_morph : indicates what should happen when the number of
56
+ coordinates of a point need to be changed by adding or
57
+ removing zeros. Possible values are `'warn'`, `'error'`, or
58
+ `ignore` (default). No warning or error is given when `*args`
59
+ is empty and `dim` is given. An error is always raised when
60
+ trying to remove nonzero coordinates.
61
+
62
+
63
+ Attributes
64
+ ==========
65
+
66
+ length
67
+ origin: A `Point` representing the origin of the
68
+ appropriately-dimensioned space.
69
+
70
+ Raises
71
+ ======
72
+
73
+ TypeError : When instantiating with anything but a Point or sequence
74
+ ValueError : when instantiating with a sequence with length < 2 or
75
+ when trying to reduce dimensions if keyword `on_morph='error'` is
76
+ set.
77
+
78
+ See Also
79
+ ========
80
+
81
+ sympy.geometry.line.Segment : Connects two Points
82
+
83
+ Examples
84
+ ========
85
+
86
+ >>> from sympy import Point
87
+ >>> from sympy.abc import x
88
+ >>> Point(1, 2, 3)
89
+ Point3D(1, 2, 3)
90
+ >>> Point([1, 2])
91
+ Point2D(1, 2)
92
+ >>> Point(0, x)
93
+ Point2D(0, x)
94
+ >>> Point(dim=4)
95
+ Point(0, 0, 0, 0)
96
+
97
+ Floats are automatically converted to Rational unless the
98
+ evaluate flag is False:
99
+
100
+ >>> Point(0.5, 0.25)
101
+ Point2D(1/2, 1/4)
102
+ >>> Point(0.5, 0.25, evaluate=False)
103
+ Point2D(0.5, 0.25)
104
+
105
+ """
106
+
107
+ is_Point = True
108
+
109
+ def __new__(cls, *args, **kwargs):
110
+ evaluate = kwargs.get('evaluate', global_parameters.evaluate)
111
+ on_morph = kwargs.get('on_morph', 'ignore')
112
+
113
+ # unpack into coords
114
+ coords = args[0] if len(args) == 1 else args
115
+
116
+ # check args and handle quickly handle Point instances
117
+ if isinstance(coords, Point):
118
+ # even if we're mutating the dimension of a point, we
119
+ # don't reevaluate its coordinates
120
+ evaluate = False
121
+ if len(coords) == kwargs.get('dim', len(coords)):
122
+ return coords
123
+
124
+ if not is_sequence(coords):
125
+ raise TypeError(filldedent('''
126
+ Expecting sequence of coordinates, not `{}`'''
127
+ .format(func_name(coords))))
128
+ # A point where only `dim` is specified is initialized
129
+ # to zeros.
130
+ if len(coords) == 0 and kwargs.get('dim', None):
131
+ coords = (S.Zero,)*kwargs.get('dim')
132
+
133
+ coords = Tuple(*coords)
134
+ dim = kwargs.get('dim', len(coords))
135
+
136
+ if len(coords) < 2:
137
+ raise ValueError(filldedent('''
138
+ Point requires 2 or more coordinates or
139
+ keyword `dim` > 1.'''))
140
+ if len(coords) != dim:
141
+ message = ("Dimension of {} needs to be changed "
142
+ "from {} to {}.").format(coords, len(coords), dim)
143
+ if on_morph == 'ignore':
144
+ pass
145
+ elif on_morph == "error":
146
+ raise ValueError(message)
147
+ elif on_morph == 'warn':
148
+ warnings.warn(message, stacklevel=2)
149
+ else:
150
+ raise ValueError(filldedent('''
151
+ on_morph value should be 'error',
152
+ 'warn' or 'ignore'.'''))
153
+ if any(coords[dim:]):
154
+ raise ValueError('Nonzero coordinates cannot be removed.')
155
+ if any(a.is_number and im(a).is_zero is False for a in coords):
156
+ raise ValueError('Imaginary coordinates are not permitted.')
157
+ if not all(isinstance(a, Expr) for a in coords):
158
+ raise TypeError('Coordinates must be valid SymPy expressions.')
159
+
160
+ # pad with zeros appropriately
161
+ coords = coords[:dim] + (S.Zero,)*(dim - len(coords))
162
+
163
+ # Turn any Floats into rationals and simplify
164
+ # any expressions before we instantiate
165
+ if evaluate:
166
+ coords = coords.xreplace({
167
+ f: simplify(nsimplify(f, rational=True))
168
+ for f in coords.atoms(Float)})
169
+
170
+ # return 2D or 3D instances
171
+ if len(coords) == 2:
172
+ kwargs['_nocheck'] = True
173
+ return Point2D(*coords, **kwargs)
174
+ elif len(coords) == 3:
175
+ kwargs['_nocheck'] = True
176
+ return Point3D(*coords, **kwargs)
177
+
178
+ # the general Point
179
+ return GeometryEntity.__new__(cls, *coords)
180
+
181
+ def __abs__(self):
182
+ """Returns the distance between this point and the origin."""
183
+ origin = Point([0]*len(self))
184
+ return Point.distance(origin, self)
185
+
186
+ def __add__(self, other):
187
+ """Add other to self by incrementing self's coordinates by
188
+ those of other.
189
+
190
+ Notes
191
+ =====
192
+
193
+ >>> from sympy import Point
194
+
195
+ When sequences of coordinates are passed to Point methods, they
196
+ are converted to a Point internally. This __add__ method does
197
+ not do that so if floating point values are used, a floating
198
+ point result (in terms of SymPy Floats) will be returned.
199
+
200
+ >>> Point(1, 2) + (.1, .2)
201
+ Point2D(1.1, 2.2)
202
+
203
+ If this is not desired, the `translate` method can be used or
204
+ another Point can be added:
205
+
206
+ >>> Point(1, 2).translate(.1, .2)
207
+ Point2D(11/10, 11/5)
208
+ >>> Point(1, 2) + Point(.1, .2)
209
+ Point2D(11/10, 11/5)
210
+
211
+ See Also
212
+ ========
213
+
214
+ sympy.geometry.point.Point.translate
215
+
216
+ """
217
+ try:
218
+ s, o = Point._normalize_dimension(self, Point(other, evaluate=False))
219
+ except TypeError:
220
+ raise GeometryError("Don't know how to add {} and a Point object".format(other))
221
+
222
+ coords = [simplify(a + b) for a, b in zip(s, o)]
223
+ return Point(coords, evaluate=False)
224
+
225
+ def __contains__(self, item):
226
+ return item in self.args
227
+
228
+ def __truediv__(self, divisor):
229
+ """Divide point's coordinates by a factor."""
230
+ divisor = sympify(divisor)
231
+ coords = [simplify(x/divisor) for x in self.args]
232
+ return Point(coords, evaluate=False)
233
+
234
+ def __eq__(self, other):
235
+ if not isinstance(other, Point) or len(self.args) != len(other.args):
236
+ return False
237
+ return self.args == other.args
238
+
239
+ def __getitem__(self, key):
240
+ return self.args[key]
241
+
242
+ def __hash__(self):
243
+ return hash(self.args)
244
+
245
+ def __iter__(self):
246
+ return self.args.__iter__()
247
+
248
+ def __len__(self):
249
+ return len(self.args)
250
+
251
+ def __mul__(self, factor):
252
+ """Multiply point's coordinates by a factor.
253
+
254
+ Notes
255
+ =====
256
+
257
+ >>> from sympy import Point
258
+
259
+ When multiplying a Point by a floating point number,
260
+ the coordinates of the Point will be changed to Floats:
261
+
262
+ >>> Point(1, 2)*0.1
263
+ Point2D(0.1, 0.2)
264
+
265
+ If this is not desired, the `scale` method can be used or
266
+ else only multiply or divide by integers:
267
+
268
+ >>> Point(1, 2).scale(1.1, 1.1)
269
+ Point2D(11/10, 11/5)
270
+ >>> Point(1, 2)*11/10
271
+ Point2D(11/10, 11/5)
272
+
273
+ See Also
274
+ ========
275
+
276
+ sympy.geometry.point.Point.scale
277
+ """
278
+ factor = sympify(factor)
279
+ coords = [simplify(x*factor) for x in self.args]
280
+ return Point(coords, evaluate=False)
281
+
282
+ def __rmul__(self, factor):
283
+ """Multiply a factor by point's coordinates."""
284
+ return self.__mul__(factor)
285
+
286
+ def __neg__(self):
287
+ """Negate the point."""
288
+ coords = [-x for x in self.args]
289
+ return Point(coords, evaluate=False)
290
+
291
+ def __sub__(self, other):
292
+ """Subtract two points, or subtract a factor from this point's
293
+ coordinates."""
294
+ return self + [-x for x in other]
295
+
296
+ @classmethod
297
+ def _normalize_dimension(cls, *points, **kwargs):
298
+ """Ensure that points have the same dimension.
299
+ By default `on_morph='warn'` is passed to the
300
+ `Point` constructor."""
301
+ # if we have a built-in ambient dimension, use it
302
+ dim = getattr(cls, '_ambient_dimension', None)
303
+ # override if we specified it
304
+ dim = kwargs.get('dim', dim)
305
+ # if no dim was given, use the highest dimensional point
306
+ if dim is None:
307
+ dim = max(i.ambient_dimension for i in points)
308
+ if all(i.ambient_dimension == dim for i in points):
309
+ return list(points)
310
+ kwargs['dim'] = dim
311
+ kwargs['on_morph'] = kwargs.get('on_morph', 'warn')
312
+ return [Point(i, **kwargs) for i in points]
313
+
314
+ @staticmethod
315
+ def affine_rank(*args):
316
+ """The affine rank of a set of points is the dimension
317
+ of the smallest affine space containing all the points.
318
+ For example, if the points lie on a line (and are not all
319
+ the same) their affine rank is 1. If the points lie on a plane
320
+ but not a line, their affine rank is 2. By convention, the empty
321
+ set has affine rank -1."""
322
+
323
+ if len(args) == 0:
324
+ return -1
325
+ # make sure we're genuinely points
326
+ # and translate every point to the origin
327
+ points = Point._normalize_dimension(*[Point(i) for i in args])
328
+ origin = points[0]
329
+ points = [i - origin for i in points[1:]]
330
+
331
+ m = Matrix([i.args for i in points])
332
+ # XXX fragile -- what is a better way?
333
+ return m.rank(iszerofunc = lambda x:
334
+ abs(x.n(2)) < 1e-12 if x.is_number else x.is_zero)
335
+
336
+ @property
337
+ def ambient_dimension(self):
338
+ """Number of components this point has."""
339
+ return getattr(self, '_ambient_dimension', len(self))
340
+
341
+ @classmethod
342
+ def are_coplanar(cls, *points):
343
+ """Return True if there exists a plane in which all the points
344
+ lie. A trivial True value is returned if `len(points) < 3` or
345
+ all Points are 2-dimensional.
346
+
347
+ Parameters
348
+ ==========
349
+
350
+ A set of points
351
+
352
+ Raises
353
+ ======
354
+
355
+ ValueError : if less than 3 unique points are given
356
+
357
+ Returns
358
+ =======
359
+
360
+ boolean
361
+
362
+ Examples
363
+ ========
364
+
365
+ >>> from sympy import Point3D
366
+ >>> p1 = Point3D(1, 2, 2)
367
+ >>> p2 = Point3D(2, 7, 2)
368
+ >>> p3 = Point3D(0, 0, 2)
369
+ >>> p4 = Point3D(1, 1, 2)
370
+ >>> Point3D.are_coplanar(p1, p2, p3, p4)
371
+ True
372
+ >>> p5 = Point3D(0, 1, 3)
373
+ >>> Point3D.are_coplanar(p1, p2, p3, p5)
374
+ False
375
+
376
+ """
377
+ if len(points) <= 1:
378
+ return True
379
+
380
+ points = cls._normalize_dimension(*[Point(i) for i in points])
381
+ # quick exit if we are in 2D
382
+ if points[0].ambient_dimension == 2:
383
+ return True
384
+ points = list(uniq(points))
385
+ return Point.affine_rank(*points) <= 2
386
+
387
+ def distance(self, other):
388
+ """The Euclidean distance between self and another GeometricEntity.
389
+
390
+ Returns
391
+ =======
392
+
393
+ distance : number or symbolic expression.
394
+
395
+ Raises
396
+ ======
397
+
398
+ TypeError : if other is not recognized as a GeometricEntity or is a
399
+ GeometricEntity for which distance is not defined.
400
+
401
+ See Also
402
+ ========
403
+
404
+ sympy.geometry.line.Segment.length
405
+ sympy.geometry.point.Point.taxicab_distance
406
+
407
+ Examples
408
+ ========
409
+
410
+ >>> from sympy import Point, Line
411
+ >>> p1, p2 = Point(1, 1), Point(4, 5)
412
+ >>> l = Line((3, 1), (2, 2))
413
+ >>> p1.distance(p2)
414
+ 5
415
+ >>> p1.distance(l)
416
+ sqrt(2)
417
+
418
+ The computed distance may be symbolic, too:
419
+
420
+ >>> from sympy.abc import x, y
421
+ >>> p3 = Point(x, y)
422
+ >>> p3.distance((0, 0))
423
+ sqrt(x**2 + y**2)
424
+
425
+ """
426
+ if not isinstance(other, GeometryEntity):
427
+ try:
428
+ other = Point(other, dim=self.ambient_dimension)
429
+ except TypeError:
430
+ raise TypeError("not recognized as a GeometricEntity: %s" % type(other))
431
+ if isinstance(other, Point):
432
+ s, p = Point._normalize_dimension(self, Point(other))
433
+ return sqrt(Add(*((a - b)**2 for a, b in zip(s, p))))
434
+ distance = getattr(other, 'distance', None)
435
+ if distance is None:
436
+ raise TypeError("distance between Point and %s is not defined" % type(other))
437
+ return distance(self)
438
+
439
+ def dot(self, p):
440
+ """Return dot product of self with another Point."""
441
+ if not is_sequence(p):
442
+ p = Point(p) # raise the error via Point
443
+ return Add(*(a*b for a, b in zip(self, p)))
444
+
445
+ def equals(self, other):
446
+ """Returns whether the coordinates of self and other agree."""
447
+ # a point is equal to another point if all its components are equal
448
+ if not isinstance(other, Point) or len(self) != len(other):
449
+ return False
450
+ return all(a.equals(b) for a, b in zip(self, other))
451
+
452
+ def _eval_evalf(self, prec=15, **options):
453
+ """Evaluate the coordinates of the point.
454
+
455
+ This method will, where possible, create and return a new Point
456
+ where the coordinates are evaluated as floating point numbers to
457
+ the precision indicated (default=15).
458
+
459
+ Parameters
460
+ ==========
461
+
462
+ prec : int
463
+
464
+ Returns
465
+ =======
466
+
467
+ point : Point
468
+
469
+ Examples
470
+ ========
471
+
472
+ >>> from sympy import Point, Rational
473
+ >>> p1 = Point(Rational(1, 2), Rational(3, 2))
474
+ >>> p1
475
+ Point2D(1/2, 3/2)
476
+ >>> p1.evalf()
477
+ Point2D(0.5, 1.5)
478
+
479
+ """
480
+ dps = prec_to_dps(prec)
481
+ coords = [x.evalf(n=dps, **options) for x in self.args]
482
+ return Point(*coords, evaluate=False)
483
+
484
+ def intersection(self, other):
485
+ """The intersection between this point and another GeometryEntity.
486
+
487
+ Parameters
488
+ ==========
489
+
490
+ other : GeometryEntity or sequence of coordinates
491
+
492
+ Returns
493
+ =======
494
+
495
+ intersection : list of Points
496
+
497
+ Notes
498
+ =====
499
+
500
+ The return value will either be an empty list if there is no
501
+ intersection, otherwise it will contain this point.
502
+
503
+ Examples
504
+ ========
505
+
506
+ >>> from sympy import Point
507
+ >>> p1, p2, p3 = Point(0, 0), Point(1, 1), Point(0, 0)
508
+ >>> p1.intersection(p2)
509
+ []
510
+ >>> p1.intersection(p3)
511
+ [Point2D(0, 0)]
512
+
513
+ """
514
+ if not isinstance(other, GeometryEntity):
515
+ other = Point(other)
516
+ if isinstance(other, Point):
517
+ if self == other:
518
+ return [self]
519
+ p1, p2 = Point._normalize_dimension(self, other)
520
+ if p1 == self and p1 == p2:
521
+ return [self]
522
+ return []
523
+ return other.intersection(self)
524
+
525
+ def is_collinear(self, *args):
526
+ """Returns `True` if there exists a line
527
+ that contains `self` and `points`. Returns `False` otherwise.
528
+ A trivially True value is returned if no points are given.
529
+
530
+ Parameters
531
+ ==========
532
+
533
+ args : sequence of Points
534
+
535
+ Returns
536
+ =======
537
+
538
+ is_collinear : boolean
539
+
540
+ See Also
541
+ ========
542
+
543
+ sympy.geometry.line.Line
544
+
545
+ Examples
546
+ ========
547
+
548
+ >>> from sympy import Point
549
+ >>> from sympy.abc import x
550
+ >>> p1, p2 = Point(0, 0), Point(1, 1)
551
+ >>> p3, p4, p5 = Point(2, 2), Point(x, x), Point(1, 2)
552
+ >>> Point.is_collinear(p1, p2, p3, p4)
553
+ True
554
+ >>> Point.is_collinear(p1, p2, p3, p5)
555
+ False
556
+
557
+ """
558
+ points = (self,) + args
559
+ points = Point._normalize_dimension(*[Point(i) for i in points])
560
+ points = list(uniq(points))
561
+ return Point.affine_rank(*points) <= 1
562
+
563
+ def is_concyclic(self, *args):
564
+ """Do `self` and the given sequence of points lie in a circle?
565
+
566
+ Returns True if the set of points are concyclic and
567
+ False otherwise. A trivial value of True is returned
568
+ if there are fewer than 2 other points.
569
+
570
+ Parameters
571
+ ==========
572
+
573
+ args : sequence of Points
574
+
575
+ Returns
576
+ =======
577
+
578
+ is_concyclic : boolean
579
+
580
+
581
+ Examples
582
+ ========
583
+
584
+ >>> from sympy import Point
585
+
586
+ Define 4 points that are on the unit circle:
587
+
588
+ >>> p1, p2, p3, p4 = Point(1, 0), (0, 1), (-1, 0), (0, -1)
589
+
590
+ >>> p1.is_concyclic() == p1.is_concyclic(p2, p3, p4) == True
591
+ True
592
+
593
+ Define a point not on that circle:
594
+
595
+ >>> p = Point(1, 1)
596
+
597
+ >>> p.is_concyclic(p1, p2, p3)
598
+ False
599
+
600
+ """
601
+ points = (self,) + args
602
+ points = Point._normalize_dimension(*[Point(i) for i in points])
603
+ points = list(uniq(points))
604
+ if not Point.affine_rank(*points) <= 2:
605
+ return False
606
+ origin = points[0]
607
+ points = [p - origin for p in points]
608
+ # points are concyclic if they are coplanar and
609
+ # there is a point c so that ||p_i-c|| == ||p_j-c|| for all
610
+ # i and j. Rearranging this equation gives us the following
611
+ # condition: the matrix `mat` must not a pivot in the last
612
+ # column.
613
+ mat = Matrix([list(i) + [i.dot(i)] for i in points])
614
+ rref, pivots = mat.rref()
615
+ if len(origin) not in pivots:
616
+ return True
617
+ return False
618
+
619
+ @property
620
+ def is_nonzero(self):
621
+ """True if any coordinate is nonzero, False if every coordinate is zero,
622
+ and None if it cannot be determined."""
623
+ is_zero = self.is_zero
624
+ if is_zero is None:
625
+ return None
626
+ return not is_zero
627
+
628
+ def is_scalar_multiple(self, p):
629
+ """Returns whether each coordinate of `self` is a scalar
630
+ multiple of the corresponding coordinate in point p.
631
+ """
632
+ s, o = Point._normalize_dimension(self, Point(p))
633
+ # 2d points happen a lot, so optimize this function call
634
+ if s.ambient_dimension == 2:
635
+ (x1, y1), (x2, y2) = s.args, o.args
636
+ rv = (x1*y2 - x2*y1).equals(0)
637
+ if rv is None:
638
+ raise Undecidable(filldedent(
639
+ '''Cannot determine if %s is a scalar multiple of
640
+ %s''' % (s, o)))
641
+
642
+ # if the vectors p1 and p2 are linearly dependent, then they must
643
+ # be scalar multiples of each other
644
+ m = Matrix([s.args, o.args])
645
+ return m.rank() < 2
646
+
647
+ @property
648
+ def is_zero(self):
649
+ """True if every coordinate is zero, False if any coordinate is not zero,
650
+ and None if it cannot be determined."""
651
+ nonzero = [x.is_nonzero for x in self.args]
652
+ if any(nonzero):
653
+ return False
654
+ if any(x is None for x in nonzero):
655
+ return None
656
+ return True
657
+
658
+ @property
659
+ def length(self):
660
+ """
661
+ Treating a Point as a Line, this returns 0 for the length of a Point.
662
+
663
+ Examples
664
+ ========
665
+
666
+ >>> from sympy import Point
667
+ >>> p = Point(0, 1)
668
+ >>> p.length
669
+ 0
670
+ """
671
+ return S.Zero
672
+
673
+ def midpoint(self, p):
674
+ """The midpoint between self and point p.
675
+
676
+ Parameters
677
+ ==========
678
+
679
+ p : Point
680
+
681
+ Returns
682
+ =======
683
+
684
+ midpoint : Point
685
+
686
+ See Also
687
+ ========
688
+
689
+ sympy.geometry.line.Segment.midpoint
690
+
691
+ Examples
692
+ ========
693
+
694
+ >>> from sympy import Point
695
+ >>> p1, p2 = Point(1, 1), Point(13, 5)
696
+ >>> p1.midpoint(p2)
697
+ Point2D(7, 3)
698
+
699
+ """
700
+ s, p = Point._normalize_dimension(self, Point(p))
701
+ return Point([simplify((a + b)*S.Half) for a, b in zip(s, p)])
702
+
703
+ @property
704
+ def origin(self):
705
+ """A point of all zeros of the same ambient dimension
706
+ as the current point"""
707
+ return Point([0]*len(self), evaluate=False)
708
+
709
+ @property
710
+ def orthogonal_direction(self):
711
+ """Returns a non-zero point that is orthogonal to the
712
+ line containing `self` and the origin.
713
+
714
+ Examples
715
+ ========
716
+
717
+ >>> from sympy import Line, Point
718
+ >>> a = Point(1, 2, 3)
719
+ >>> a.orthogonal_direction
720
+ Point3D(-2, 1, 0)
721
+ >>> b = _
722
+ >>> Line(b, b.origin).is_perpendicular(Line(a, a.origin))
723
+ True
724
+ """
725
+ dim = self.ambient_dimension
726
+ # if a coordinate is zero, we can put a 1 there and zeros elsewhere
727
+ if self[0].is_zero:
728
+ return Point([1] + (dim - 1)*[0])
729
+ if self[1].is_zero:
730
+ return Point([0,1] + (dim - 2)*[0])
731
+ # if the first two coordinates aren't zero, we can create a non-zero
732
+ # orthogonal vector by swapping them, negating one, and padding with zeros
733
+ return Point([-self[1], self[0]] + (dim - 2)*[0])
734
+
735
+ @staticmethod
736
+ def project(a, b):
737
+ """Project the point `a` onto the line between the origin
738
+ and point `b` along the normal direction.
739
+
740
+ Parameters
741
+ ==========
742
+
743
+ a : Point
744
+ b : Point
745
+
746
+ Returns
747
+ =======
748
+
749
+ p : Point
750
+
751
+ See Also
752
+ ========
753
+
754
+ sympy.geometry.line.LinearEntity.projection
755
+
756
+ Examples
757
+ ========
758
+
759
+ >>> from sympy import Line, Point
760
+ >>> a = Point(1, 2)
761
+ >>> b = Point(2, 5)
762
+ >>> z = a.origin
763
+ >>> p = Point.project(a, b)
764
+ >>> Line(p, a).is_perpendicular(Line(p, b))
765
+ True
766
+ >>> Point.is_collinear(z, p, b)
767
+ True
768
+ """
769
+ a, b = Point._normalize_dimension(Point(a), Point(b))
770
+ if b.is_zero:
771
+ raise ValueError("Cannot project to the zero vector.")
772
+ return b*(a.dot(b) / b.dot(b))
773
+
774
+ def taxicab_distance(self, p):
775
+ """The Taxicab Distance from self to point p.
776
+
777
+ Returns the sum of the horizontal and vertical distances to point p.
778
+
779
+ Parameters
780
+ ==========
781
+
782
+ p : Point
783
+
784
+ Returns
785
+ =======
786
+
787
+ taxicab_distance : The sum of the horizontal
788
+ and vertical distances to point p.
789
+
790
+ See Also
791
+ ========
792
+
793
+ sympy.geometry.point.Point.distance
794
+
795
+ Examples
796
+ ========
797
+
798
+ >>> from sympy import Point
799
+ >>> p1, p2 = Point(1, 1), Point(4, 5)
800
+ >>> p1.taxicab_distance(p2)
801
+ 7
802
+
803
+ """
804
+ s, p = Point._normalize_dimension(self, Point(p))
805
+ return Add(*(abs(a - b) for a, b in zip(s, p)))
806
+
807
+ def canberra_distance(self, p):
808
+ """The Canberra Distance from self to point p.
809
+
810
+ Returns the weighted sum of horizontal and vertical distances to
811
+ point p.
812
+
813
+ Parameters
814
+ ==========
815
+
816
+ p : Point
817
+
818
+ Returns
819
+ =======
820
+
821
+ canberra_distance : The weighted sum of horizontal and vertical
822
+ distances to point p. The weight used is the sum of absolute values
823
+ of the coordinates.
824
+
825
+ Examples
826
+ ========
827
+
828
+ >>> from sympy import Point
829
+ >>> p1, p2 = Point(1, 1), Point(3, 3)
830
+ >>> p1.canberra_distance(p2)
831
+ 1
832
+ >>> p1, p2 = Point(0, 0), Point(3, 3)
833
+ >>> p1.canberra_distance(p2)
834
+ 2
835
+
836
+ Raises
837
+ ======
838
+
839
+ ValueError when both vectors are zero.
840
+
841
+ See Also
842
+ ========
843
+
844
+ sympy.geometry.point.Point.distance
845
+
846
+ """
847
+
848
+ s, p = Point._normalize_dimension(self, Point(p))
849
+ if self.is_zero and p.is_zero:
850
+ raise ValueError("Cannot project to the zero vector.")
851
+ return Add(*((abs(a - b)/(abs(a) + abs(b))) for a, b in zip(s, p)))
852
+
853
+ @property
854
+ def unit(self):
855
+ """Return the Point that is in the same direction as `self`
856
+ and a distance of 1 from the origin"""
857
+ return self / abs(self)
858
+
859
+
860
+ class Point2D(Point):
861
+ """A point in a 2-dimensional Euclidean space.
862
+
863
+ Parameters
864
+ ==========
865
+
866
+ coords
867
+ A sequence of 2 coordinate values.
868
+
869
+ Attributes
870
+ ==========
871
+
872
+ x
873
+ y
874
+ length
875
+
876
+ Raises
877
+ ======
878
+
879
+ TypeError
880
+ When trying to add or subtract points with different dimensions.
881
+ When trying to create a point with more than two dimensions.
882
+ When `intersection` is called with object other than a Point.
883
+
884
+ See Also
885
+ ========
886
+
887
+ sympy.geometry.line.Segment : Connects two Points
888
+
889
+ Examples
890
+ ========
891
+
892
+ >>> from sympy import Point2D
893
+ >>> from sympy.abc import x
894
+ >>> Point2D(1, 2)
895
+ Point2D(1, 2)
896
+ >>> Point2D([1, 2])
897
+ Point2D(1, 2)
898
+ >>> Point2D(0, x)
899
+ Point2D(0, x)
900
+
901
+ Floats are automatically converted to Rational unless the
902
+ evaluate flag is False:
903
+
904
+ >>> Point2D(0.5, 0.25)
905
+ Point2D(1/2, 1/4)
906
+ >>> Point2D(0.5, 0.25, evaluate=False)
907
+ Point2D(0.5, 0.25)
908
+
909
+ """
910
+
911
+ _ambient_dimension = 2
912
+
913
+ def __new__(cls, *args, _nocheck=False, **kwargs):
914
+ if not _nocheck:
915
+ kwargs['dim'] = 2
916
+ args = Point(*args, **kwargs)
917
+ return GeometryEntity.__new__(cls, *args)
918
+
919
+ def __contains__(self, item):
920
+ return item == self
921
+
922
+ @property
923
+ def bounds(self):
924
+ """Return a tuple (xmin, ymin, xmax, ymax) representing the bounding
925
+ rectangle for the geometric figure.
926
+
927
+ """
928
+
929
+ return (self.x, self.y, self.x, self.y)
930
+
931
+ def rotate(self, angle, pt=None):
932
+ """Rotate ``angle`` radians counterclockwise about Point ``pt``.
933
+
934
+ See Also
935
+ ========
936
+
937
+ translate, scale
938
+
939
+ Examples
940
+ ========
941
+
942
+ >>> from sympy import Point2D, pi
943
+ >>> t = Point2D(1, 0)
944
+ >>> t.rotate(pi/2)
945
+ Point2D(0, 1)
946
+ >>> t.rotate(pi/2, (2, 0))
947
+ Point2D(2, -1)
948
+
949
+ """
950
+ c = cos(angle)
951
+ s = sin(angle)
952
+
953
+ rv = self
954
+ if pt is not None:
955
+ pt = Point(pt, dim=2)
956
+ rv -= pt
957
+ x, y = rv.args
958
+ rv = Point(c*x - s*y, s*x + c*y)
959
+ if pt is not None:
960
+ rv += pt
961
+ return rv
962
+
963
+ def scale(self, x=1, y=1, pt=None):
964
+ """Scale the coordinates of the Point by multiplying by
965
+ ``x`` and ``y`` after subtracting ``pt`` -- default is (0, 0) --
966
+ and then adding ``pt`` back again (i.e. ``pt`` is the point of
967
+ reference for the scaling).
968
+
969
+ See Also
970
+ ========
971
+
972
+ rotate, translate
973
+
974
+ Examples
975
+ ========
976
+
977
+ >>> from sympy import Point2D
978
+ >>> t = Point2D(1, 1)
979
+ >>> t.scale(2)
980
+ Point2D(2, 1)
981
+ >>> t.scale(2, 2)
982
+ Point2D(2, 2)
983
+
984
+ """
985
+ if pt:
986
+ pt = Point(pt, dim=2)
987
+ return self.translate(*(-pt).args).scale(x, y).translate(*pt.args)
988
+ return Point(self.x*x, self.y*y)
989
+
990
+ def transform(self, matrix):
991
+ """Return the point after applying the transformation described
992
+ by the 3x3 Matrix, ``matrix``.
993
+
994
+ See Also
995
+ ========
996
+ sympy.geometry.point.Point2D.rotate
997
+ sympy.geometry.point.Point2D.scale
998
+ sympy.geometry.point.Point2D.translate
999
+ """
1000
+ if not (matrix.is_Matrix and matrix.shape == (3, 3)):
1001
+ raise ValueError("matrix must be a 3x3 matrix")
1002
+ x, y = self.args
1003
+ return Point(*(Matrix(1, 3, [x, y, 1])*matrix).tolist()[0][:2])
1004
+
1005
+ def translate(self, x=0, y=0):
1006
+ """Shift the Point by adding x and y to the coordinates of the Point.
1007
+
1008
+ See Also
1009
+ ========
1010
+
1011
+ sympy.geometry.point.Point2D.rotate, scale
1012
+
1013
+ Examples
1014
+ ========
1015
+
1016
+ >>> from sympy import Point2D
1017
+ >>> t = Point2D(0, 1)
1018
+ >>> t.translate(2)
1019
+ Point2D(2, 1)
1020
+ >>> t.translate(2, 2)
1021
+ Point2D(2, 3)
1022
+ >>> t + Point2D(2, 2)
1023
+ Point2D(2, 3)
1024
+
1025
+ """
1026
+ return Point(self.x + x, self.y + y)
1027
+
1028
+ @property
1029
+ def coordinates(self):
1030
+ """
1031
+ Returns the two coordinates of the Point.
1032
+
1033
+ Examples
1034
+ ========
1035
+
1036
+ >>> from sympy import Point2D
1037
+ >>> p = Point2D(0, 1)
1038
+ >>> p.coordinates
1039
+ (0, 1)
1040
+ """
1041
+ return self.args
1042
+
1043
+ @property
1044
+ def x(self):
1045
+ """
1046
+ Returns the X coordinate of the Point.
1047
+
1048
+ Examples
1049
+ ========
1050
+
1051
+ >>> from sympy import Point2D
1052
+ >>> p = Point2D(0, 1)
1053
+ >>> p.x
1054
+ 0
1055
+ """
1056
+ return self.args[0]
1057
+
1058
+ @property
1059
+ def y(self):
1060
+ """
1061
+ Returns the Y coordinate of the Point.
1062
+
1063
+ Examples
1064
+ ========
1065
+
1066
+ >>> from sympy import Point2D
1067
+ >>> p = Point2D(0, 1)
1068
+ >>> p.y
1069
+ 1
1070
+ """
1071
+ return self.args[1]
1072
+
1073
+ class Point3D(Point):
1074
+ """A point in a 3-dimensional Euclidean space.
1075
+
1076
+ Parameters
1077
+ ==========
1078
+
1079
+ coords
1080
+ A sequence of 3 coordinate values.
1081
+
1082
+ Attributes
1083
+ ==========
1084
+
1085
+ x
1086
+ y
1087
+ z
1088
+ length
1089
+
1090
+ Raises
1091
+ ======
1092
+
1093
+ TypeError
1094
+ When trying to add or subtract points with different dimensions.
1095
+ When `intersection` is called with object other than a Point.
1096
+
1097
+ Examples
1098
+ ========
1099
+
1100
+ >>> from sympy import Point3D
1101
+ >>> from sympy.abc import x
1102
+ >>> Point3D(1, 2, 3)
1103
+ Point3D(1, 2, 3)
1104
+ >>> Point3D([1, 2, 3])
1105
+ Point3D(1, 2, 3)
1106
+ >>> Point3D(0, x, 3)
1107
+ Point3D(0, x, 3)
1108
+
1109
+ Floats are automatically converted to Rational unless the
1110
+ evaluate flag is False:
1111
+
1112
+ >>> Point3D(0.5, 0.25, 2)
1113
+ Point3D(1/2, 1/4, 2)
1114
+ >>> Point3D(0.5, 0.25, 3, evaluate=False)
1115
+ Point3D(0.5, 0.25, 3)
1116
+
1117
+ """
1118
+
1119
+ _ambient_dimension = 3
1120
+
1121
+ def __new__(cls, *args, _nocheck=False, **kwargs):
1122
+ if not _nocheck:
1123
+ kwargs['dim'] = 3
1124
+ args = Point(*args, **kwargs)
1125
+ return GeometryEntity.__new__(cls, *args)
1126
+
1127
+ def __contains__(self, item):
1128
+ return item == self
1129
+
1130
+ @staticmethod
1131
+ def are_collinear(*points):
1132
+ """Is a sequence of points collinear?
1133
+
1134
+ Test whether or not a set of points are collinear. Returns True if
1135
+ the set of points are collinear, or False otherwise.
1136
+
1137
+ Parameters
1138
+ ==========
1139
+
1140
+ points : sequence of Point
1141
+
1142
+ Returns
1143
+ =======
1144
+
1145
+ are_collinear : boolean
1146
+
1147
+ See Also
1148
+ ========
1149
+
1150
+ sympy.geometry.line.Line3D
1151
+
1152
+ Examples
1153
+ ========
1154
+
1155
+ >>> from sympy import Point3D
1156
+ >>> from sympy.abc import x
1157
+ >>> p1, p2 = Point3D(0, 0, 0), Point3D(1, 1, 1)
1158
+ >>> p3, p4, p5 = Point3D(2, 2, 2), Point3D(x, x, x), Point3D(1, 2, 6)
1159
+ >>> Point3D.are_collinear(p1, p2, p3, p4)
1160
+ True
1161
+ >>> Point3D.are_collinear(p1, p2, p3, p5)
1162
+ False
1163
+ """
1164
+ return Point.is_collinear(*points)
1165
+
1166
+ def direction_cosine(self, point):
1167
+ """
1168
+ Gives the direction cosine between 2 points
1169
+
1170
+ Parameters
1171
+ ==========
1172
+
1173
+ p : Point3D
1174
+
1175
+ Returns
1176
+ =======
1177
+
1178
+ list
1179
+
1180
+ Examples
1181
+ ========
1182
+
1183
+ >>> from sympy import Point3D
1184
+ >>> p1 = Point3D(1, 2, 3)
1185
+ >>> p1.direction_cosine(Point3D(2, 3, 5))
1186
+ [sqrt(6)/6, sqrt(6)/6, sqrt(6)/3]
1187
+ """
1188
+ a = self.direction_ratio(point)
1189
+ b = sqrt(Add(*(i**2 for i in a)))
1190
+ return [(point.x - self.x) / b,(point.y - self.y) / b,
1191
+ (point.z - self.z) / b]
1192
+
1193
+ def direction_ratio(self, point):
1194
+ """
1195
+ Gives the direction ratio between 2 points
1196
+
1197
+ Parameters
1198
+ ==========
1199
+
1200
+ p : Point3D
1201
+
1202
+ Returns
1203
+ =======
1204
+
1205
+ list
1206
+
1207
+ Examples
1208
+ ========
1209
+
1210
+ >>> from sympy import Point3D
1211
+ >>> p1 = Point3D(1, 2, 3)
1212
+ >>> p1.direction_ratio(Point3D(2, 3, 5))
1213
+ [1, 1, 2]
1214
+ """
1215
+ return [(point.x - self.x),(point.y - self.y),(point.z - self.z)]
1216
+
1217
+ def intersection(self, other):
1218
+ """The intersection between this point and another GeometryEntity.
1219
+
1220
+ Parameters
1221
+ ==========
1222
+
1223
+ other : GeometryEntity or sequence of coordinates
1224
+
1225
+ Returns
1226
+ =======
1227
+
1228
+ intersection : list of Points
1229
+
1230
+ Notes
1231
+ =====
1232
+
1233
+ The return value will either be an empty list if there is no
1234
+ intersection, otherwise it will contain this point.
1235
+
1236
+ Examples
1237
+ ========
1238
+
1239
+ >>> from sympy import Point3D
1240
+ >>> p1, p2, p3 = Point3D(0, 0, 0), Point3D(1, 1, 1), Point3D(0, 0, 0)
1241
+ >>> p1.intersection(p2)
1242
+ []
1243
+ >>> p1.intersection(p3)
1244
+ [Point3D(0, 0, 0)]
1245
+
1246
+ """
1247
+ if not isinstance(other, GeometryEntity):
1248
+ other = Point(other, dim=3)
1249
+ if isinstance(other, Point3D):
1250
+ if self == other:
1251
+ return [self]
1252
+ return []
1253
+ return other.intersection(self)
1254
+
1255
+ def scale(self, x=1, y=1, z=1, pt=None):
1256
+ """Scale the coordinates of the Point by multiplying by
1257
+ ``x`` and ``y`` after subtracting ``pt`` -- default is (0, 0) --
1258
+ and then adding ``pt`` back again (i.e. ``pt`` is the point of
1259
+ reference for the scaling).
1260
+
1261
+ See Also
1262
+ ========
1263
+
1264
+ translate
1265
+
1266
+ Examples
1267
+ ========
1268
+
1269
+ >>> from sympy import Point3D
1270
+ >>> t = Point3D(1, 1, 1)
1271
+ >>> t.scale(2)
1272
+ Point3D(2, 1, 1)
1273
+ >>> t.scale(2, 2)
1274
+ Point3D(2, 2, 1)
1275
+
1276
+ """
1277
+ if pt:
1278
+ pt = Point3D(pt)
1279
+ return self.translate(*(-pt).args).scale(x, y, z).translate(*pt.args)
1280
+ return Point3D(self.x*x, self.y*y, self.z*z)
1281
+
1282
+ def transform(self, matrix):
1283
+ """Return the point after applying the transformation described
1284
+ by the 4x4 Matrix, ``matrix``.
1285
+
1286
+ See Also
1287
+ ========
1288
+ sympy.geometry.point.Point3D.scale
1289
+ sympy.geometry.point.Point3D.translate
1290
+ """
1291
+ if not (matrix.is_Matrix and matrix.shape == (4, 4)):
1292
+ raise ValueError("matrix must be a 4x4 matrix")
1293
+ x, y, z = self.args
1294
+ m = Transpose(matrix)
1295
+ return Point3D(*(Matrix(1, 4, [x, y, z, 1])*m).tolist()[0][:3])
1296
+
1297
+ def translate(self, x=0, y=0, z=0):
1298
+ """Shift the Point by adding x and y to the coordinates of the Point.
1299
+
1300
+ See Also
1301
+ ========
1302
+
1303
+ scale
1304
+
1305
+ Examples
1306
+ ========
1307
+
1308
+ >>> from sympy import Point3D
1309
+ >>> t = Point3D(0, 1, 1)
1310
+ >>> t.translate(2)
1311
+ Point3D(2, 1, 1)
1312
+ >>> t.translate(2, 2)
1313
+ Point3D(2, 3, 1)
1314
+ >>> t + Point3D(2, 2, 2)
1315
+ Point3D(2, 3, 3)
1316
+
1317
+ """
1318
+ return Point3D(self.x + x, self.y + y, self.z + z)
1319
+
1320
+ @property
1321
+ def coordinates(self):
1322
+ """
1323
+ Returns the three coordinates of the Point.
1324
+
1325
+ Examples
1326
+ ========
1327
+
1328
+ >>> from sympy import Point3D
1329
+ >>> p = Point3D(0, 1, 2)
1330
+ >>> p.coordinates
1331
+ (0, 1, 2)
1332
+ """
1333
+ return self.args
1334
+
1335
+ @property
1336
+ def x(self):
1337
+ """
1338
+ Returns the X coordinate of the Point.
1339
+
1340
+ Examples
1341
+ ========
1342
+
1343
+ >>> from sympy import Point3D
1344
+ >>> p = Point3D(0, 1, 3)
1345
+ >>> p.x
1346
+ 0
1347
+ """
1348
+ return self.args[0]
1349
+
1350
+ @property
1351
+ def y(self):
1352
+ """
1353
+ Returns the Y coordinate of the Point.
1354
+
1355
+ Examples
1356
+ ========
1357
+
1358
+ >>> from sympy import Point3D
1359
+ >>> p = Point3D(0, 1, 2)
1360
+ >>> p.y
1361
+ 1
1362
+ """
1363
+ return self.args[1]
1364
+
1365
+ @property
1366
+ def z(self):
1367
+ """
1368
+ Returns the Z coordinate of the Point.
1369
+
1370
+ Examples
1371
+ ========
1372
+
1373
+ >>> from sympy import Point3D
1374
+ >>> p = Point3D(0, 1, 1)
1375
+ >>> p.z
1376
+ 1
1377
+ """
1378
+ return self.args[2]